diff --git a/agent/connect/ca/provider_test.go b/agent/connect/ca/provider_test.go new file mode 100644 index 0000000000..618086b64c --- /dev/null +++ b/agent/connect/ca/provider_test.go @@ -0,0 +1,197 @@ +package ca + +import ( + "bytes" + "testing" + "time" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-msgpack/codec" + "github.com/stretchr/testify/require" +) + +func TestStructs_CAConfiguration_MsgpackEncodeDecode(t *testing.T) { + type testcase struct { + in *structs.CAConfiguration + expectConfig interface{} // provider specific + parseFunc func(*testing.T, map[string]interface{}) interface{} + } + + commonBaseMap := map[string]interface{}{ + "LeafCertTTL": "30h", + "SkipValidate": true, + "CSRMaxPerSecond": 5.25, + "CSRMaxConcurrent": int64(55), + "PrivateKeyType": "rsa", + "PrivateKeyBits": int64(4096), + } + expectCommonBase := &structs.CommonCAProviderConfig{ + LeafCertTTL: 30 * time.Hour, + SkipValidate: true, + CSRMaxPerSecond: 5.25, + CSRMaxConcurrent: 55, + PrivateKeyType: "rsa", + PrivateKeyBits: 4096, + } + + cases := map[string]testcase{ + structs.ConsulCAProvider: { + in: &structs.CAConfiguration{ + ClusterID: "abc", + Provider: structs.ConsulCAProvider, + State: map[string]string{ + "foo": "bar", + }, + ForceWithoutCrossSigning: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 5, + ModifyIndex: 99, + }, + Config: map[string]interface{}{ + "PrivateKey": "key", + "RootCert": "cert", + "RotationPeriod": "5m", + "IntermediateCertTTL": "30m", + "DisableCrossSigning": true, + }, + }, + expectConfig: &structs.ConsulCAProviderConfig{ + CommonCAProviderConfig: *expectCommonBase, + PrivateKey: "key", + RootCert: "cert", + RotationPeriod: 5 * time.Minute, + IntermediateCertTTL: 30 * time.Minute, + DisableCrossSigning: true, + }, + parseFunc: func(t *testing.T, raw map[string]interface{}) interface{} { + config, err := ParseConsulCAConfig(raw) + require.NoError(t, err) + return config + }, + }, + structs.VaultCAProvider: { + in: &structs.CAConfiguration{ + ClusterID: "abc", + Provider: structs.VaultCAProvider, + State: map[string]string{ + "foo": "bar", + }, + ForceWithoutCrossSigning: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 5, + ModifyIndex: 99, + }, + Config: map[string]interface{}{ + "Address": "addr", + "Token": "token", + "RootPKIPath": "root-pki/", + "IntermediatePKIPath": "im-pki/", + "CAFile": "ca-file", + "CAPath": "ca-path", + "CertFile": "cert-file", + "KeyFile": "key-file", + "TLSServerName": "server-name", + "TLSSkipVerify": true, + }, + }, + expectConfig: &structs.VaultCAProviderConfig{ + CommonCAProviderConfig: *expectCommonBase, + Address: "addr", + Token: "token", + RootPKIPath: "root-pki/", + IntermediatePKIPath: "im-pki/", + CAFile: "ca-file", + CAPath: "ca-path", + CertFile: "cert-file", + KeyFile: "key-file", + TLSServerName: "server-name", + TLSSkipVerify: true, + }, + parseFunc: func(t *testing.T, raw map[string]interface{}) interface{} { + config, err := ParseVaultCAConfig(raw) + require.NoError(t, err) + return config + }, + }, + structs.AWSCAProvider: { + in: &structs.CAConfiguration{ + ClusterID: "abc", + Provider: structs.AWSCAProvider, + State: map[string]string{ + "foo": "bar", + }, + ForceWithoutCrossSigning: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 5, + ModifyIndex: 99, + }, + Config: map[string]interface{}{ + "ExistingARN": "arn://foo", + "DeleteOnExit": true, + }, + }, + expectConfig: &structs.AWSCAProviderConfig{ + CommonCAProviderConfig: *expectCommonBase, + ExistingARN: "arn://foo", + DeleteOnExit: true, + }, + parseFunc: func(t *testing.T, raw map[string]interface{}) interface{} { + config, err := ParseAWSCAConfig(raw) + require.NoError(t, err) + return config + }, + }, + } + // underlay common ca config stuff + for _, tc := range cases { + for k, v := range commonBaseMap { + if _, ok := tc.in.Config[k]; !ok { + tc.in.Config[k] = v + } + } + } + + var ( + // This is the common configuration pre-1.7.0 + handle1 = structs.TestingOldPre1dot7MsgpackHandle + // This is the common configuration post-1.7.0 + handle2 = structs.MsgpackHandle + ) + + decoderCase := func(t *testing.T, tc testcase, encHandle, decHandle *codec.MsgpackHandle) { + t.Helper() + + var buf bytes.Buffer + enc := codec.NewEncoder(&buf, encHandle) + require.NoError(t, enc.Encode(tc.in)) + + out := &structs.CAConfiguration{} + dec := codec.NewDecoder(&buf, decHandle) + require.NoError(t, dec.Decode(out)) + + config := tc.parseFunc(t, out.Config) + + out.Config = tc.in.Config // no longer care about how this field decoded + require.Equal(t, tc.in, out) + require.Equal(t, tc.expectConfig, config) + // TODO: verify json? + } + + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Run("old encoder and old decoder", func(t *testing.T) { + decoderCase(t, tc, handle1, handle1) + }) + t.Run("old encoder and new decoder", func(t *testing.T) { + decoderCase(t, tc, handle1, handle2) + }) + t.Run("new encoder and old decoder", func(t *testing.T) { + decoderCase(t, tc, handle2, handle1) + }) + t.Run("new encoder and new decoder", func(t *testing.T) { + decoderCase(t, tc, handle2, handle2) + }) + }) + } +} diff --git a/agent/consul/authmethod/kubeauth/k8s_test.go b/agent/consul/authmethod/kubeauth/k8s_test.go index 20771e2759..544bd26ee6 100644 --- a/agent/consul/authmethod/kubeauth/k8s_test.go +++ b/agent/consul/authmethod/kubeauth/k8s_test.go @@ -1,13 +1,79 @@ package kubeauth import ( + "bytes" "testing" "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-msgpack/codec" "github.com/stretchr/testify/require" ) +func TestStructs_ACLAuthMethod_Kubernetes_MsgpackEncodeDecode(t *testing.T) { + in := &structs.ACLAuthMethod{ + Name: "k8s", + Type: "kubernetes", + Description: "k00b", + Config: map[string]interface{}{ + "Host": "https://kube.api.internal:8443", + "CACert": "", + "ServiceAccountJWT": "my.fake.jwt", + }, + RaftIndex: structs.RaftIndex{ + CreateIndex: 5, + ModifyIndex: 99, + }, + } + + expectConfig := &Config{ + Host: "https://kube.api.internal:8443", + CACert: "", + ServiceAccountJWT: "my.fake.jwt", + } + + var ( + // This is the common configuration pre-1.7.0 + handle1 = structs.TestingOldPre1dot7MsgpackHandle + // This is the common configuration post-1.7.0 + handle2 = structs.MsgpackHandle + ) + + decoderCase := func(t *testing.T, encHandle, decHandle *codec.MsgpackHandle) { + t.Helper() + + var buf bytes.Buffer + enc := codec.NewEncoder(&buf, encHandle) + require.NoError(t, enc.Encode(in)) + + out := &structs.ACLAuthMethod{} + dec := codec.NewDecoder(&buf, decHandle) + require.NoError(t, dec.Decode(out)) + + var config Config + require.NoError(t, authmethod.ParseConfig(in.Config, &config)) + + out.Config = in.Config // no longer care about how this field decoded + require.Equal(t, in, out) + require.Equal(t, expectConfig, &config) + // TODO: verify json? + } + + t.Run("old encoder and old decoder", func(t *testing.T) { + decoderCase(t, handle1, handle1) + }) + t.Run("old encoder and new decoder", func(t *testing.T) { + decoderCase(t, handle1, handle2) + }) + t.Run("new encoder and old decoder", func(t *testing.T) { + decoderCase(t, handle2, handle1) + }) + t.Run("new encoder and new decoder", func(t *testing.T) { + decoderCase(t, handle2, handle2) + }) +} + func TestValidateLogin(t *testing.T) { testSrv := StartTestAPIServer(t) defer testSrv.Stop() diff --git a/agent/consul/fsm/fsm.go b/agent/consul/fsm/fsm.go index 3d54319d42..7b2b7b7b01 100644 --- a/agent/consul/fsm/fsm.go +++ b/agent/consul/fsm/fsm.go @@ -15,11 +15,6 @@ import ( "github.com/hashicorp/raft" ) -// msgpackHandle is a shared handle for encoding/decoding msgpack payloads -var msgpackHandle = &codec.MsgpackHandle{ - RawToString: true, -} - // command is a command method on the FSM. type command func(buf []byte, index uint64) interface{} @@ -166,7 +161,7 @@ func (c *FSM) Restore(old io.ReadCloser) error { defer restore.Abort() // Create a decoder - dec := codec.NewDecoder(old, msgpackHandle) + dec := codec.NewDecoder(old, structs.MsgpackHandle) // Read in the header var header snapshotHeader diff --git a/agent/consul/fsm/snapshot.go b/agent/consul/fsm/snapshot.go index 11a2bd2f4a..4f3c36ab13 100644 --- a/agent/consul/fsm/snapshot.go +++ b/agent/consul/fsm/snapshot.go @@ -65,7 +65,7 @@ func (s *snapshot) Persist(sink raft.SnapshotSink) error { header := snapshotHeader{ LastIndex: s.state.LastIndex(), } - encoder := codec.NewEncoder(sink, msgpackHandle) + encoder := codec.NewEncoder(sink, structs.MsgpackHandle) if err := encoder.Encode(&header); err != nil { sink.Cancel() return err diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index 4cf210153a..cd2fc26522 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -196,7 +196,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn) { // handleConsulConn is used to service a single Consul RPC connection func (s *Server) handleConsulConn(conn net.Conn) { defer conn.Close() - rpcCodec := msgpackrpc.NewServerCodec(conn) + rpcCodec := msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle) for { select { case <-s.shutdownCh: @@ -221,7 +221,7 @@ func (s *Server) handleConsulConn(conn net.Conn) { // handleInsecureConsulConn is used to service a single Consul INSECURERPC connection func (s *Server) handleInsecureConn(conn net.Conn) { defer conn.Close() - rpcCodec := msgpackrpc.NewServerCodec(conn) + rpcCodec := msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle) for { select { case <-s.shutdownCh: diff --git a/agent/consul/snapshot_endpoint.go b/agent/consul/snapshot_endpoint.go index 27833a72ca..2f2ddf06c5 100644 --- a/agent/consul/snapshot_endpoint.go +++ b/agent/consul/snapshot_endpoint.go @@ -150,7 +150,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re // a snapshot request. func (s *Server) handleSnapshotRequest(conn net.Conn) error { var args structs.SnapshotRequest - dec := codec.NewDecoder(conn, &codec.MsgpackHandle{}) + dec := codec.NewDecoder(conn, structs.MsgpackHandle) if err := dec.Decode(&args); err != nil { return fmt.Errorf("failed to decode request: %v", err) } @@ -168,7 +168,7 @@ func (s *Server) handleSnapshotRequest(conn net.Conn) error { }() RESPOND: - enc := codec.NewEncoder(conn, &codec.MsgpackHandle{}) + enc := codec.NewEncoder(conn, structs.MsgpackHandle) if err := enc.Encode(&reply); err != nil { return fmt.Errorf("failed to encode response: %v", err) } @@ -213,7 +213,7 @@ func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool, } // Push the header encoded as msgpack, then stream the input. - enc := codec.NewEncoder(conn, &codec.MsgpackHandle{}) + enc := codec.NewEncoder(conn, structs.MsgpackHandle) if err := enc.Encode(&args); err != nil { return nil, fmt.Errorf("failed to encode request: %v", err) } @@ -235,7 +235,7 @@ func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool, // Pull the header decoded as msgpack. The caller can continue to read // the conn to stream the remaining data. - dec := codec.NewDecoder(conn, &codec.MsgpackHandle{}) + dec := codec.NewDecoder(conn, structs.MsgpackHandle) if err := dec.Decode(reply); err != nil { return nil, fmt.Errorf("failed to decode response: %v", err) } diff --git a/agent/consul/status_endpoint_test.go b/agent/consul/status_endpoint_test.go index 8343a12193..15c6594929 100644 --- a/agent/consul/status_endpoint_test.go +++ b/agent/consul/status_endpoint_test.go @@ -24,7 +24,7 @@ func rpcClient(t *testing.T, s *Server) rpc.ClientCodec { // Write the Consul RPC byte to set the mode conn.Write([]byte{byte(pool.RPCConsul)}) - return msgpackrpc.NewClientCodec(conn) + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle) } func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) { @@ -41,7 +41,7 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) { if err != nil { return nil, err } - return msgpackrpc.NewClientCodec(conn), nil + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil } func TestStatusLeader(t *testing.T) { diff --git a/agent/pool/pool.go b/agent/pool/pool.go index fd693514d1..2ee344173e 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -11,9 +11,10 @@ import ( "sync/atomic" "time" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/tlsutil" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/yamux" ) @@ -76,7 +77,7 @@ func (c *Conn) getClient() (*StreamClient, error) { } // Create the RPC client - codec := msgpackrpc.NewClientCodec(stream) + codec := msgpackrpc.NewCodecFromHandle(true, true, stream, structs.MsgpackHandle) // Return a new stream client sc := &StreamClient{ @@ -443,7 +444,7 @@ func (p *ConnPool) rpcInsecure(dc string, addr net.Addr, method string, args int if err != nil { return fmt.Errorf("rpcinsecure error establishing connection: %v", err) } - codec = msgpackrpc.NewClientCodec(conn) + codec = msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle) // Make the RPC call err = msgpackrpc.CallWithCodec(codec, method, args, reply) diff --git a/agent/structs/acl.go b/agent/structs/acl.go index 6ce8ac2f31..f837dfba26 100644 --- a/agent/structs/acl.go +++ b/agent/structs/acl.go @@ -12,7 +12,6 @@ import ( "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/lib" - "github.com/hashicorp/go-msgpack/codec" "golang.org/x/crypto/blake2b" ) @@ -1048,49 +1047,6 @@ type ACLAuthMethod struct { RaftIndex `hash:"ignore"` } -// MarshalBinary writes ACLAuthMethod as msgpack encoded. It's only here -// because we need custom decoding of the raw interface{} values and this -// completes the interface. -func (m *ACLAuthMethod) MarshalBinary() (data []byte, err error) { - // bs will grow if needed but allocate enough to avoid reallocation in common - // case. - bs := make([]byte, 256) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) - - type Alias ACLAuthMethod - - if err := enc.Encode((*Alias)(m)); err != nil { - return nil, err - } - - return bs, nil -} - -// UnmarshalBinary decodes msgpack encoded ACLAuthMethod. It used -// default msgpack encoding but fixes up the uint8 strings and other problems we -// have with encoding map[string]interface{}. -func (m *ACLAuthMethod) UnmarshalBinary(data []byte) error { - dec := codec.NewDecoderBytes(data, msgpackHandle) - - type Alias ACLAuthMethod - var a Alias - - if err := dec.Decode(&a); err != nil { - return err - } - - *m = ACLAuthMethod(a) - - var err error - - // Fix strings and maps in the returned maps - m.Config, err = lib.MapWalk(m.Config) - if err != nil { - return err - } - return nil -} - type ACLReplicationType string const ( diff --git a/agent/structs/config_entry.go b/agent/structs/config_entry.go index 742070983d..4e9b7adfa9 100644 --- a/agent/structs/config_entry.go +++ b/agent/structs/config_entry.go @@ -216,7 +216,7 @@ func (e *ProxyConfigEntry) MarshalBinary() (data []byte, err error) { // bs will grow if needed but allocate enough to avoid reallocation in common // case. bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) + enc := codec.NewEncoderBytes(&bs, MsgpackHandle) err = enc.Encode(a) if err != nil { return nil, err @@ -235,7 +235,7 @@ func (e *ProxyConfigEntry) UnmarshalBinary(data []byte) error { type alias ProxyConfigEntry var a alias - dec := codec.NewDecoderBytes(data, msgpackHandle) + dec := codec.NewDecoderBytes(data, MsgpackHandle) if err := dec.Decode(&a); err != nil { return err } @@ -406,7 +406,7 @@ func (c *ConfigEntryRequest) MarshalBinary() (data []byte, err error) { // bs will grow if needed but allocate enough to avoid reallocation in common // case. bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) + enc := codec.NewEncoderBytes(&bs, MsgpackHandle) // Encode kind first err = enc.Encode(c.Entry.GetKind()) if err != nil { @@ -428,7 +428,7 @@ func (c *ConfigEntryRequest) MarshalBinary() (data []byte, err error) { func (c *ConfigEntryRequest) UnmarshalBinary(data []byte) error { // First decode the kind prefix var kind string - dec := codec.NewDecoderBytes(data, msgpackHandle) + dec := codec.NewDecoderBytes(data, MsgpackHandle) if err := dec.Decode(&kind); err != nil { return err } @@ -611,7 +611,7 @@ func (r *ServiceConfigResponse) MarshalBinary() (data []byte, err error) { // bs will grow if needed but allocate enough to avoid reallocation in common // case. bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) + enc := codec.NewEncoderBytes(&bs, MsgpackHandle) type Alias ServiceConfigResponse @@ -626,7 +626,7 @@ func (r *ServiceConfigResponse) MarshalBinary() (data []byte, err error) { // default msgpack encoding but fixes up the uint8 strings and other problems we // have with encoding map[string]interface{}. func (r *ServiceConfigResponse) UnmarshalBinary(data []byte) error { - dec := codec.NewDecoderBytes(data, msgpackHandle) + dec := codec.NewDecoderBytes(data, MsgpackHandle) type Alias ServiceConfigResponse var a Alias @@ -670,7 +670,7 @@ func (c *ConfigEntryResponse) MarshalBinary() (data []byte, err error) { // bs will grow if needed but allocate enough to avoid reallocation in common // case. bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) + enc := codec.NewEncoderBytes(&bs, MsgpackHandle) if c.Entry != nil { if err := enc.Encode(c.Entry.GetKind()); err != nil { @@ -693,7 +693,7 @@ func (c *ConfigEntryResponse) MarshalBinary() (data []byte, err error) { } func (c *ConfigEntryResponse) UnmarshalBinary(data []byte) error { - dec := codec.NewDecoderBytes(data, msgpackHandle) + dec := codec.NewDecoderBytes(data, MsgpackHandle) var kind string if err := dec.Decode(&kind); err != nil { diff --git a/agent/structs/config_entry_test.go b/agent/structs/config_entry_test.go index fc2a3368d7..f654fa4480 100644 --- a/agent/structs/config_entry_test.go +++ b/agent/structs/config_entry_test.go @@ -592,12 +592,12 @@ func TestServiceConfigResponse_MsgPack(t *testing.T) { // Encode as msgPack using a regular handle i.e. NOT one with RawAsString // since our RPC codec doesn't use that. - enc := codec.NewEncoder(&buf, msgpackHandle) + enc := codec.NewEncoder(&buf, MsgpackHandle) require.NoError(t, enc.Encode(&a)) var b ServiceConfigResponse - dec := codec.NewDecoder(&buf, msgpackHandle) + dec := codec.NewDecoder(&buf, MsgpackHandle) require.NoError(t, dec.Decode(&b)) require.Equal(t, a, b) diff --git a/agent/structs/connect_ca.go b/agent/structs/connect_ca.go index 2e3ddb1330..448f753ddc 100644 --- a/agent/structs/connect_ca.go +++ b/agent/structs/connect_ca.go @@ -6,7 +6,6 @@ import ( "time" "github.com/hashicorp/consul/lib" - "github.com/hashicorp/go-msgpack/codec" "github.com/mitchellh/mapstructure" ) @@ -260,49 +259,6 @@ type CAConfiguration struct { RaftIndex } -// MarshalBinary writes CAConfiguration as msgpack encoded. It's only here -// because we need custom decoding of the raw interface{} values and this -// completes the interface. -func (c *CAConfiguration) MarshalBinary() (data []byte, err error) { - // bs will grow if needed but allocate enough to avoid reallocation in common - // case. - bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) - - type Alias CAConfiguration - - if err := enc.Encode((*Alias)(c)); err != nil { - return nil, err - } - - return bs, nil -} - -// UnmarshalBinary decodes msgpack encoded CAConfiguration. It used -// default msgpack encoding but fixes up the uint8 strings and other problems we -// have with encoding map[string]interface{}. -func (c *CAConfiguration) UnmarshalBinary(data []byte) error { - dec := codec.NewDecoderBytes(data, msgpackHandle) - - type Alias CAConfiguration - var a Alias - - if err := dec.Decode(&a); err != nil { - return err - } - - *c = CAConfiguration(a) - - var err error - - // Fix strings and maps in the returned maps - c.Config, err = lib.MapWalk(c.Config) - if err != nil { - return err - } - return nil -} - func (c *CAConfiguration) UnmarshalJSON(data []byte) (err error) { type Alias CAConfiguration diff --git a/agent/structs/structs.go b/agent/structs/structs.go index da720cb4c4..e7981d6009 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -1725,7 +1725,7 @@ func (c *IndexedConfigEntries) MarshalBinary() (data []byte, err error) { // bs will grow if needed but allocate enough to avoid reallocation in common // case. bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) + enc := codec.NewEncoderBytes(&bs, MsgpackHandle) // Encode length. err = enc.Encode(len(c.Entries)) @@ -1755,7 +1755,7 @@ func (c *IndexedConfigEntries) MarshalBinary() (data []byte, err error) { func (c *IndexedConfigEntries) UnmarshalBinary(data []byte) error { // First decode the number of entries. var numEntries int - dec := codec.NewDecoderBytes(data, msgpackHandle) + dec := codec.NewDecoderBytes(data, MsgpackHandle) if err := dec.Decode(&numEntries); err != nil { return err } @@ -1799,7 +1799,7 @@ func (c *IndexedGenericConfigEntries) MarshalBinary() (data []byte, err error) { // bs will grow if needed but allocate enough to avoid reallocation in common // case. bs := make([]byte, 128) - enc := codec.NewEncoderBytes(&bs, msgpackHandle) + enc := codec.NewEncoderBytes(&bs, MsgpackHandle) if err := enc.Encode(len(c.Entries)); err != nil { return nil, err @@ -1824,7 +1824,7 @@ func (c *IndexedGenericConfigEntries) MarshalBinary() (data []byte, err error) { func (c *IndexedGenericConfigEntries) UnmarshalBinary(data []byte) error { // First decode the number of entries. var numEntries int - dec := codec.NewDecoderBytes(data, msgpackHandle) + dec := codec.NewDecoderBytes(data, MsgpackHandle) if err := dec.Decode(&numEntries); err != nil { return err } @@ -2135,19 +2135,26 @@ func (r *TombstoneRequest) RequestDatacenter() string { return r.Datacenter } -// msgpackHandle is a shared handle for encoding/decoding of structs -var msgpackHandle = &codec.MsgpackHandle{} +// MsgpackHandle is a shared handle for encoding/decoding msgpack payloads +var MsgpackHandle = &codec.MsgpackHandle{ + RawToString: true, + BasicHandle: codec.BasicHandle{ + DecodeOptions: codec.DecodeOptions{ + MapType: reflect.TypeOf(map[string]interface{}{}), + }, + }, +} // Decode is used to decode a MsgPack encoded object func Decode(buf []byte, out interface{}) error { - return codec.NewDecoder(bytes.NewReader(buf), msgpackHandle).Decode(out) + return codec.NewDecoder(bytes.NewReader(buf), MsgpackHandle).Decode(out) } // Encode is used to encode a MsgPack object with type prefix func Encode(t MessageType, msg interface{}) ([]byte, error) { var buf bytes.Buffer buf.WriteByte(uint8(t)) - err := codec.NewEncoder(&buf, msgpackHandle).Encode(msg) + err := codec.NewEncoder(&buf, MsgpackHandle).Encode(msg) return buf.Bytes(), err } diff --git a/agent/structs/structs_test.go b/agent/structs/structs_test.go index de2aee5464..41bd35e665 100644 --- a/agent/structs/structs_test.go +++ b/agent/structs/structs_test.go @@ -6,6 +6,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/api" @@ -1815,3 +1816,150 @@ func TestServiceNode_JSON_Marshal(t *testing.T) { require.NoError(t, json.Unmarshal(buf, &out)) require.Equal(t, *sn, out) } + +// frankensteinStruct is an amalgamation of all of the different kinds of +// fields you could have on struct defined in the agent/structs package that we +// send through msgpack +type frankensteinStruct struct { + Child *monsterStruct + ChildSlice []*monsterStruct + ChildMap map[string]*monsterStruct +} +type monsterStruct struct { + Bool bool + Int int + Uint8 uint8 + Uint64 uint64 + Float32 float32 + Float64 float64 + String string + + Hash []byte + Uint32Slice []uint32 + Float64Slice []float64 + StringSlice []string + + MapInt map[string]int + MapString map[string]string + MapStringSlice map[string][]string + + // We explicitly DO NOT try to test the following types that involve + // interface{} as the TestMsgpackEncodeDecode test WILL fail. + // + // These are tested elsewhere for the very specific scenario in question, + // which usually takes a secondary trip through mapstructure during decode + // which papers over some of the additional conversions necessary to finish + // decoding. + // MapIface map[string]interface{} + // MapMapIface map[string]map[string]interface{} + + Dur time.Duration + DurPtr *time.Duration + Time time.Time + TimePtr *time.Time + + RaftIndex +} + +func makeFrank() *frankensteinStruct { + return &frankensteinStruct{ + Child: makeMonster(), + ChildSlice: []*monsterStruct{ + makeMonster(), + makeMonster(), + }, + ChildMap: map[string]*monsterStruct{ + "one": makeMonster(), // only put one key in here so the map order is fixed + }, + } +} + +func makeMonster() *monsterStruct { + var d time.Duration = 9 * time.Hour + var t time.Time = time.Date(2008, 1, 2, 3, 4, 5, 0, time.UTC) + + return &monsterStruct{ + Bool: true, + Int: -8, + Uint8: 5, + Uint64: 9, + Float32: 5.25, + Float64: 99.5, + String: "strval", + + Hash: []byte("hello"), + Uint32Slice: []uint32{1, 2, 3, 4}, + Float64Slice: []float64{9.2, 6.25}, + StringSlice: []string{"foo", "bar"}, + + // // MapIface will hold an amalgam of what AuthMethods and + // // CAConfigurations use in 'Config' + // MapIface: map[string]interface{}{ + // "Name": "inner", + // "Dur": "5s", + // "Bool": true, + // "Float": 15.25, + // "Int": int64(94), + // "Nested": map[string]string{ // this doesn't survive + // "foo": "bar", + // }, + // }, + // // MapMapIface map[string]map[string]interface{} + + MapInt: map[string]int{ + "int": 5, + }, + MapString: map[string]string{ + "aaa": "bbb", + }, + MapStringSlice: map[string][]string{ + "aaa": []string{"bbb"}, + }, + + Dur: 5 * time.Second, + DurPtr: &d, + Time: t.Add(-5 * time.Hour), + TimePtr: &t, + + RaftIndex: RaftIndex{ + CreateIndex: 1, + ModifyIndex: 3, + }, + } +} + +func TestStructs_MsgpackEncodeDecode_Monolith(t *testing.T) { + t.Run("monster", func(t *testing.T) { + in := makeMonster() + TestMsgpackEncodeDecode(t, in, false) + }) + t.Run("frankenstein", func(t *testing.T) { + in := makeFrank() + TestMsgpackEncodeDecode(t, in, false) + }) +} + +func TestSnapshotRequestResponse_MsgpackEncodeDecode(t *testing.T) { + t.Run("request", func(t *testing.T) { + in := &SnapshotRequest{ + Datacenter: "foo", + Token: "blah", + AllowStale: true, + Op: SnapshotRestore, + } + TestMsgpackEncodeDecode(t, in, true) + }) + t.Run("response", func(t *testing.T) { + in := &SnapshotResponse{ + Error: "blah", + QueryMeta: QueryMeta{ + Index: 3, + LastContact: 5 * time.Second, + KnownLeader: true, + ConsistencyLevel: "default", + }, + } + TestMsgpackEncodeDecode(t, in, true) + }) + +} diff --git a/agent/structs/testing.go b/agent/structs/testing.go new file mode 100644 index 0000000000..5d52d87a0a --- /dev/null +++ b/agent/structs/testing.go @@ -0,0 +1,76 @@ +package structs + +import ( + "bytes" + "reflect" + "testing" + + "github.com/hashicorp/go-msgpack/codec" + "github.com/stretchr/testify/require" +) + +// TestingOldPre1dot7MsgpackHandle is the common configuration pre-1.7.0 +var TestingOldPre1dot7MsgpackHandle = &codec.MsgpackHandle{} + +// TestMsgpackEncodeDecode is a test helper to easily write a test to verify +// msgpack encoding and decoding using two handles is identical. +func TestMsgpackEncodeDecode(t *testing.T, in interface{}, requireEncoderEquality bool) { + t.Helper() + var ( + // This is the common configuration pre-1.7.0 + handle1 = TestingOldPre1dot7MsgpackHandle + // This is the common configuration post-1.7.0 + handle2 = MsgpackHandle + ) + + // Verify the 3 interface{} args are all pointers to the same kind of type. + inType := reflect.TypeOf(in) + require.Equal(t, reflect.Ptr, inType.Kind()) + + // Encode using both handles. + var b1 []byte + { + var buf bytes.Buffer + enc := codec.NewEncoder(&buf, handle1) + require.NoError(t, enc.Encode(in)) + b1 = buf.Bytes() + } + var b2 []byte + { + var buf bytes.Buffer + enc := codec.NewEncoder(&buf, handle2) + require.NoError(t, enc.Encode(in)) + b2 = buf.Bytes() + } + + if requireEncoderEquality { + // The resulting bytes should be identical. + require.Equal(t, b1, b2) + } + + // Decode both outputs using both handles. + t.Run("old encoder and old decoder", func(t *testing.T) { + out1 := reflect.New(inType.Elem()).Interface() + dec := codec.NewDecoderBytes(b1, handle1) + require.NoError(t, dec.Decode(out1)) + require.Equal(t, in, out1) + }) + t.Run("old encoder and new decoder", func(t *testing.T) { + out1 := reflect.New(inType.Elem()).Interface() + dec := codec.NewDecoderBytes(b1, handle2) + require.NoError(t, dec.Decode(out1)) + require.Equal(t, in, out1) + }) + t.Run("new encoder and old decoder", func(t *testing.T) { + out2 := reflect.New(inType.Elem()).Interface() + dec := codec.NewDecoderBytes(b2, handle1) + require.NoError(t, dec.Decode(out2)) + require.Equal(t, in, out2) + }) + t.Run("new encoder and new decoder", func(t *testing.T) { + out2 := reflect.New(inType.Elem()).Interface() + dec := codec.NewDecoderBytes(b2, handle2) + require.NoError(t, dec.Decode(out2)) + require.Equal(t, in, out2) + }) +} diff --git a/agent/user_event.go b/agent/user_event.go index c0c87bd328..bf1f5cb16d 100644 --- a/agent/user_event.go +++ b/agent/user_event.go @@ -1,10 +1,12 @@ package agent import ( + "bytes" "fmt" "regexp" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/go-uuid" ) @@ -84,7 +86,7 @@ func (a *Agent) UserEvent(dc, token string, params *UserEvent) error { return fmt.Errorf("UUID generation failed: %v", err) } params.Version = userEventMaxVersion - payload, err := encodeMsgPack(¶ms) + payload, err := encodeMsgPackUserEvent(¶ms) if err != nil { return fmt.Errorf("UserEvent encoding failed: %v", err) } @@ -112,7 +114,7 @@ func (a *Agent) handleEvents() { case e := <-a.eventCh: // Decode the event msg := new(UserEvent) - if err := decodeMsgPack(e.Payload, msg); err != nil { + if err := decodeMsgPackUserEvent(e.Payload, msg); err != nil { a.logger.Error("Failed to decode event", "error", err) continue } @@ -280,3 +282,22 @@ func (a *Agent) LastUserEvent() *UserEvent { idx := (((a.eventIndex - 1) % n) + n) % n return a.eventBuf[idx] } + +// msgpackHandleUserEvent is a shared handle for encoding/decoding of +// messages for user events +var msgpackHandleUserEvent = &codec.MsgpackHandle{ + RawToString: true, + WriteExt: true, +} + +// decodeMsgPackUserEvent is used to decode a MsgPack encoded object +func decodeMsgPackUserEvent(buf []byte, out interface{}) error { + return codec.NewDecoder(bytes.NewReader(buf), msgpackHandleUserEvent).Decode(out) +} + +// encodeMsgPackUserEvent is used to encode an object with msgpack +func encodeMsgPackUserEvent(msg interface{}) ([]byte, error) { + var buf bytes.Buffer + err := codec.NewEncoder(&buf, msgpackHandleUserEvent).Encode(msg) + return buf.Bytes(), err +} diff --git a/agent/util.go b/agent/util.go index 835b756cb7..2649563619 100644 --- a/agent/util.go +++ b/agent/util.go @@ -1,7 +1,6 @@ package agent import ( - "bytes" "crypto/md5" "fmt" "os" @@ -13,28 +12,8 @@ import ( "time" "github.com/hashicorp/consul/types" - "github.com/hashicorp/go-msgpack/codec" ) -// msgpackHandle is a shared handle for encoding/decoding of -// messages -var msgpackHandle = &codec.MsgpackHandle{ - RawToString: true, - WriteExt: true, -} - -// decodeMsgPack is used to decode a MsgPack encoded object -func decodeMsgPack(buf []byte, out interface{}) error { - return codec.NewDecoder(bytes.NewReader(buf), msgpackHandle).Decode(out) -} - -// encodeMsgPack is used to encode an object with msgpack -func encodeMsgPack(msg interface{}) ([]byte, error) { - var buf bytes.Buffer - err := codec.NewEncoder(&buf, msgpackHandle).Encode(msg) - return buf.Bytes(), err -} - // stringHash returns a simple md5sum for a string. func stringHash(s string) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) diff --git a/snapshot/snapshot_test.go b/snapshot/snapshot_test.go index 99593b93bb..29e8ec413f 100644 --- a/snapshot/snapshot_test.go +++ b/snapshot/snapshot_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" @@ -51,8 +52,7 @@ func (m *MockFSM) Restore(in io.ReadCloser) error { m.Lock() defer m.Unlock() defer in.Close() - hd := codec.MsgpackHandle{} - dec := codec.NewDecoder(in, &hd) + dec := codec.NewDecoder(in, structs.MsgpackHandle) m.logs = nil return dec.Decode(&m.logs) @@ -60,8 +60,7 @@ func (m *MockFSM) Restore(in io.ReadCloser) error { // See raft.SnapshotSink. func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error { - hd := codec.MsgpackHandle{} - enc := codec.NewEncoder(sink, &hd) + enc := codec.NewEncoder(sink, structs.MsgpackHandle) if err := enc.Encode(m.logs[:m.maxIndex]); err != nil { sink.Cancel() return err