diff --git a/agent/consul/internal_endpoint.go b/agent/consul/internal_endpoint.go index 720810f0d8..c0cefd65a2 100644 --- a/agent/consul/internal_endpoint.go +++ b/agent/consul/internal_endpoint.go @@ -250,6 +250,11 @@ func (m *Internal) KeyringOperation( args *structs.KeyringRequest, reply *structs.KeyringResponses) error { + // Error aggressively to be clear about LocalOnly behavior + if args.LocalOnly && args.Operation != structs.KeyringList { + return fmt.Errorf("argument error: LocalOnly can only be used for List operations") + } + // Check ACLs identity, rule, err := m.srv.ResolveTokenToIdentityAndAuthorizer(args.Token) if err != nil { @@ -277,44 +282,63 @@ func (m *Internal) KeyringOperation( } } - // Validate use of local-only - if args.LocalOnly && args.Operation != structs.KeyringList { - // Error aggressively to be clear about LocalOnly behavior - return fmt.Errorf("argument error: LocalOnly can only be used for List operations") - } + if args.LocalOnly || args.Forwarded || m.srv.serfWAN == nil { + // Handle operations that are localOnly, already forwarded or + // there is no serfWAN. If any of this is the case this + // operation shouldn't go out to other dcs or WAN pool. + reply.Responses = append(reply.Responses, m.executeKeyringOpLAN(args)...) + } else { + // Handle not already forwarded, non-local operations. - // args.LocalOnly should always be false for non-GET requests - if !args.LocalOnly { - // Only perform WAN keyring querying and RPC forwarding once - if !args.Forwarded && m.srv.serfWAN != nil { - args.Forwarded = true - m.executeKeyringOp(args, reply, true) - return m.srv.globalRPC("Internal.KeyringOperation", args, reply) + // Marking this as forwarded because this is what we are about + // to do. Prevents the same message from being fowarded by + // other servers. + args.Forwarded = true + reply.Responses = append(reply.Responses, m.executeKeyringOpWAN(args)) + reply.Responses = append(reply.Responses, m.executeKeyringOpLAN(args)...) + + dcs := m.srv.router.GetRemoteDatacenters(m.srv.config.Datacenter) + responses, err := m.srv.keyringRPCs("Internal.KeyringOperation", args, dcs) + if err != nil { + return err } + reply.Add(responses) } - - // Query the LAN keyring of this node's DC - m.executeKeyringOp(args, reply, false) return nil } -// executeKeyringOp executes the keyring-related operation in the request -// on either the WAN or LAN pools. -func (m *Internal) executeKeyringOp( - args *structs.KeyringRequest, - reply *structs.KeyringResponses, - wan bool) { - - if wan { - mgr := m.srv.KeyManagerWAN() - m.executeKeyringOpMgr(mgr, args, reply, wan, "") - } else { - segments := m.srv.LANSegments() - for name, segment := range segments { - mgr := segment.KeyManager() - m.executeKeyringOpMgr(mgr, args, reply, wan, name) - } +func (m *Internal) executeKeyringOpLAN(args *structs.KeyringRequest) []*structs.KeyringResponse { + responses := []*structs.KeyringResponse{} + segments := m.srv.LANSegments() + for name, segment := range segments { + mgr := segment.KeyManager() + serfResp, err := m.executeKeyringOpMgr(mgr, args) + resp := translateKeyResponseToKeyringResponse(serfResp, m.srv.config.Datacenter, err) + resp.Segment = name + responses = append(responses, &resp) } + return responses +} + +func (m *Internal) executeKeyringOpWAN(args *structs.KeyringRequest) *structs.KeyringResponse { + mgr := m.srv.KeyManagerWAN() + serfResp, err := m.executeKeyringOpMgr(mgr, args) + resp := translateKeyResponseToKeyringResponse(serfResp, m.srv.config.Datacenter, err) + resp.WAN = true + return &resp +} + +func translateKeyResponseToKeyringResponse(keyresponse *serf.KeyResponse, datacenter string, err error) structs.KeyringResponse { + resp := structs.KeyringResponse{ + Datacenter: datacenter, + Messages: keyresponse.Messages, + Keys: keyresponse.Keys, + NumNodes: keyresponse.NumNodes, + } + if err != nil { + resp.Error = err.Error() + } + return resp } // executeKeyringOpMgr executes the appropriate keyring-related function based on @@ -323,9 +347,7 @@ func (m *Internal) executeKeyringOp( func (m *Internal) executeKeyringOpMgr( mgr *serf.KeyManager, args *structs.KeyringRequest, - reply *structs.KeyringResponses, - wan bool, - segment string) { +) (*serf.KeyResponse, error) { var serfResp *serf.KeyResponse var err error @@ -341,20 +363,7 @@ func (m *Internal) executeKeyringOpMgr( serfResp, err = mgr.RemoveKeyWithOptions(args.Key, opts) } - errStr := "" - if err != nil { - errStr = err.Error() - } - - reply.Responses = append(reply.Responses, &structs.KeyringResponse{ - WAN: wan, - Datacenter: m.srv.config.Datacenter, - Segment: segment, - Messages: serfResp.Messages, - Keys: serfResp.Keys, - NumNodes: serfResp.NumNodes, - Error: errStr, - }) + return serfResp, err } // aclAccessorID is used to convert an ACLToken's secretID to its accessorID for non- diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index 4ab54a8b9e..0a520dcee0 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -635,22 +635,17 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{ return nil } -// globalRPC is used to forward an RPC request to one server in each datacenter. -// This will only error for RPC-related errors. Otherwise, application-level -// errors can be sent in the response objects. -func (s *Server) globalRPC(method string, args interface{}, - reply structs.CompoundResponse) error { +// keyringRPCs is used to forward an RPC request to a server in each dc. This +// will only error for RPC-related errors. Otherwise, application-level errors +// can be sent in the response objects. +func (s *Server) keyringRPCs(method string, args interface{}, dcs []string) (*structs.KeyringResponses, error) { - // Make a new request into each datacenter - dcs := s.router.GetDatacenters() - - replies, total := 0, len(dcs) - errorCh := make(chan error, total) - respCh := make(chan interface{}, total) + errorCh := make(chan error, len(dcs)) + respCh := make(chan *structs.KeyringResponses, len(dcs)) for _, dc := range dcs { go func(dc string) { - rr := reply.New() + rr := &structs.KeyringResponses{} if err := s.forwardDC(method, dc, args, &rr); err != nil { errorCh <- err return @@ -659,16 +654,16 @@ func (s *Server) globalRPC(method string, args interface{}, }(dc) } - for replies < total { + responses := &structs.KeyringResponses{} + for i := 0; i < len(dcs); i++ { select { case err := <-errorCh: - return err + return nil, err case rr := <-respCh: - reply.Add(rr) - replies++ + responses.Add(rr) } } - return nil + return responses, nil } type raftEncoder func(structs.MessageType, interface{}) ([]byte, error) diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 18d61618f0..a6516b87dd 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -1286,7 +1286,7 @@ func (r *fakeGlobalResp) New() interface{} { return struct{}{} } -func TestServer_globalRPCErrors(t *testing.T) { +func TestServer_keyringRPCs(t *testing.T) { t.Parallel() dir1, s1 := testServerDC(t, "dc1") defer os.RemoveAll(dir1) @@ -1298,7 +1298,7 @@ func TestServer_globalRPCErrors(t *testing.T) { }) // Check that an error from a remote DC is returned - err := s1.globalRPC("Bad.Method", nil, &fakeGlobalResp{}) + _, err := s1.keyringRPCs("Bad.Method", nil, []string{s1.config.Datacenter}) if err == nil { t.Fatalf("should have errored") } diff --git a/agent/router/router.go b/agent/router/router.go index 64df6a003e..f8fa52ff39 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -406,6 +406,24 @@ func (r *Router) GetDatacenters() []string { return dcs } +// GetRemoteDatacenters returns a list of remote datacenters known to the router, sorted by +// name. +func (r *Router) GetRemoteDatacenters(local string) []string { + r.RLock() + defer r.RUnlock() + + dcs := make([]string, 0, len(r.managers)) + for dc := range r.managers { + if dc == local { + continue + } + dcs = append(dcs, dc) + } + + sort.Strings(dcs) + return dcs +} + // HasDatacenter checks whether dc is defined in WAN func (r *Router) HasDatacenter(dc string) bool { r.RLock()