From bafa5c7156c5e4137c9e4955917c0f0bd3f63784 Mon Sep 17 00:00:00 2001 From: Semir Patel Date: Wed, 14 Dec 2022 09:24:22 -0600 Subject: [PATCH] Pass remote addr of incoming HTTP requests through to RPC(..) calls (#15700) --- .gitignore | 1 + agent/acl_endpoint.go | 54 ++-- agent/acl_endpoint_test.go | 13 +- agent/acl_test.go | 3 +- agent/agent.go | 9 +- agent/agent_endpoint_test.go | 14 +- agent/agent_test.go | 10 +- agent/cache-types/catalog_datacenters.go | 3 +- agent/cache-types/catalog_list_services.go | 3 +- agent/cache-types/catalog_service_list.go | 3 +- agent/cache-types/catalog_services.go | 3 +- agent/cache-types/config_entry.go | 5 +- agent/cache-types/connect_ca_leaf.go | 2 +- agent/cache-types/connect_ca_leaf_test.go | 3 +- agent/cache-types/connect_ca_root.go | 3 +- agent/cache-types/discovery_chain.go | 3 +- agent/cache-types/exported_peered_services.go | 3 +- .../federation_state_list_gateways.go | 3 +- agent/cache-types/gateway_services.go | 3 +- agent/cache-types/health_services.go | 3 +- agent/cache-types/intention_match.go | 3 +- agent/cache-types/intention_upstreams.go | 3 +- .../intention_upstreams_destination.go | 3 +- agent/cache-types/mock_RPC.go | 5 +- agent/cache-types/node_services.go | 3 +- agent/cache-types/peered_upstreams.go | 3 +- agent/cache-types/prepared_query.go | 3 +- agent/cache-types/resolved_service_config.go | 3 +- agent/cache-types/rpc.go | 4 +- agent/cache-types/service_dump.go | 3 +- agent/cache-types/service_gateways.go | 3 +- agent/catalog_endpoint.go | 18 +- agent/catalog_endpoint_test.go | 85 +++--- agent/checks/alias.go | 5 +- agent/checks/alias_test.go | 3 +- agent/checks/check.go | 2 +- agent/config_endpoint.go | 8 +- agent/config_endpoint_test.go | 31 +- agent/connect/testing_ca.go | 5 +- agent/connect_ca_endpoint.go | 6 +- agent/consul/acl.go | 9 +- agent/consul/acl_oss_test.go | 3 +- agent/consul/acl_replication.go | 12 +- agent/consul/acl_replication_test.go | 47 +-- agent/consul/acl_test.go | 5 +- agent/consul/client.go | 3 +- agent/consul/client_test.go | 35 +-- agent/consul/config_replication.go | 2 +- agent/consul/config_replication_test.go | 13 +- agent/consul/context.go | 20 ++ agent/consul/context_test.go | 27 ++ agent/consul/federation_state_replication.go | 2 +- .../federation_state_replication_test.go | 7 +- agent/consul/leader_connect_ca_test.go | 2 +- agent/consul/leader_connect_test.go | 14 +- .../consul/leader_federation_state_ae_test.go | 5 +- agent/consul/leader_intentions_test.go | 7 +- agent/consul/leader_test.go | 19 +- agent/consul/rpc_test.go | 2 +- agent/consul/server.go | 23 +- agent/consul/server_test.go | 18 +- agent/consul/session_ttl_test.go | 3 +- agent/consul/subscribe_backend_test.go | 6 +- agent/consul/txn_endpoint_test.go | 3 +- agent/coordinate_endpoint.go | 8 +- agent/coordinate_endpoint_test.go | 17 +- agent/delegate_mock_test.go | 3 +- agent/discovery_chain_endpoint.go | 2 +- agent/discovery_chain_endpoint_test.go | 3 +- agent/dns.go | 10 +- agent/dns_oss_test.go | 9 +- agent/dns_test.go | 267 +++++++++--------- agent/federation_state_endpoint.go | 6 +- agent/health_endpoint.go | 6 +- agent/health_endpoint_test.go | 77 ++--- agent/http.go | 16 ++ agent/http_test.go | 41 +++ agent/intentions_endpoint.go | 20 +- agent/intentions_endpoint_test.go | 29 +- agent/keyring.go | 3 +- agent/kvs_endpoint.go | 8 +- agent/local/state.go | 19 +- agent/local/state_test.go | 103 +++---- agent/metrics_test.go | 8 +- agent/operator_endpoint.go | 17 +- agent/operator_endpoint_test.go | 7 +- agent/prepared_query_endpoint.go | 14 +- agent/prepared_query_endpoint_test.go | 3 +- agent/remote_exec.go | 5 +- agent/remote_exec_test.go | 9 +- agent/rpcclient/health/health.go | 6 +- agent/rpcclient/health/health_test.go | 2 +- agent/service_manager_test.go | 3 +- agent/session_endpoint.go | 12 +- agent/session_endpoint_test.go | 9 +- agent/status_endpoint.go | 4 +- agent/testagent.go | 2 +- agent/txn_endpoint.go | 4 +- agent/ui_endpoint.go | 16 +- agent/ui_endpoint_oss_test.go | 5 +- agent/ui_endpoint_test.go | 54 ++-- agent/user_event.go | 3 +- agent/user_event_test.go | 7 +- command/connect/ca/set/connect_ca_set_test.go | 3 +- .../set/operator_autopilot_set_test.go | 3 +- command/rtt/rtt_test.go | 7 +- go.mod | 2 +- go.sum | 4 +- test/integration/consul-container/go.mod | 2 +- test/integration/consul-container/go.sum | 4 +- testrpc/wait.go | 17 +- 111 files changed, 845 insertions(+), 664 deletions(-) create mode 100644 agent/consul/context.go create mode 100644 agent/consul/context_test.go diff --git a/.gitignore b/.gitignore index faa0961470..ade8cd97d1 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ exit-code Thumbs.db .idea .vscode +__debug_bin # MacOS .DS_Store diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index 23e43ba5c4..d3fa62b12e 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -38,7 +38,7 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request) Datacenter: s.agent.config.Datacenter, } var out structs.ACLToken - err := s.agent.RPC("ACL.BootstrapTokens", &args, &out) + err := s.agent.RPC(req.Context(), "ACL.BootstrapTokens", &args, &out) if err != nil { if strings.Contains(err.Error(), structs.ACLBootstrapNotAllowedErr.Error()) { return nil, acl.PermissionDeniedError{Cause: err.Error()} @@ -64,7 +64,7 @@ func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http. // Make the request. var out structs.ACLReplicationStatus - if err := s.agent.RPC("ACL.ReplicationStatus", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.ReplicationStatus", &args, &out); err != nil { return nil, err } return out, nil @@ -89,7 +89,7 @@ func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request var out structs.ACLPolicyListResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.PolicyList", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.PolicyList", &args, &out); err != nil { return nil, err } @@ -150,7 +150,7 @@ func (s *HTTPHandlers) ACLPolicyRead(resp http.ResponseWriter, req *http.Request var out structs.ACLPolicyResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.PolicyRead", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.PolicyRead", &args, &out); err != nil { return nil, err } @@ -219,7 +219,7 @@ func (s *HTTPHandlers) aclPolicyWriteInternal(_resp http.ResponseWriter, req *ht } var out structs.ACLPolicy - if err := s.agent.RPC("ACL.PolicySet", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.PolicySet", args, &out); err != nil { return nil, err } @@ -237,7 +237,7 @@ func (s *HTTPHandlers) ACLPolicyDelete(resp http.ResponseWriter, req *http.Reque } var ignored string - if err := s.agent.RPC("ACL.PolicyDelete", args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.PolicyDelete", args, &ignored); err != nil { return nil, err } @@ -274,7 +274,7 @@ func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request) var out structs.ACLTokenListResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.TokenList", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.TokenList", &args, &out); err != nil { return nil, err } @@ -336,7 +336,7 @@ func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request) var out structs.ACLTokenResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.TokenRead", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.TokenRead", &args, &out); err != nil { return nil, err } @@ -379,7 +379,7 @@ func (s *HTTPHandlers) ACLTokenGet(resp http.ResponseWriter, req *http.Request, var out structs.ACLTokenResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.TokenRead", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.TokenRead", &args, &out); err != nil { return nil, err } @@ -425,7 +425,7 @@ func (s *HTTPHandlers) aclTokenSetInternal(req *http.Request, tokenID string, cr } var out structs.ACLToken - if err := s.agent.RPC("ACL.TokenSet", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.TokenSet", args, &out); err != nil { return nil, err } @@ -443,7 +443,7 @@ func (s *HTTPHandlers) ACLTokenDelete(resp http.ResponseWriter, req *http.Reques } var ignored string - if err := s.agent.RPC("ACL.TokenDelete", args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.TokenDelete", args, &ignored); err != nil { return nil, err } return true, nil @@ -471,7 +471,7 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request args.ACLToken.AccessorID = tokenID var out structs.ACLToken - if err := s.agent.RPC("ACL.TokenClone", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.TokenClone", args, &out); err != nil { return nil, err } @@ -499,7 +499,7 @@ func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request) var out structs.ACLRoleListResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.RoleList", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.RoleList", &args, &out); err != nil { return nil, err } @@ -576,7 +576,7 @@ func (s *HTTPHandlers) ACLRoleRead(resp http.ResponseWriter, req *http.Request, var out structs.ACLRoleResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.RoleRead", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.RoleRead", &args, &out); err != nil { return nil, err } @@ -616,7 +616,7 @@ func (s *HTTPHandlers) ACLRoleWrite(resp http.ResponseWriter, req *http.Request, } var out structs.ACLRole - if err := s.agent.RPC("ACL.RoleSet", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.RoleSet", args, &out); err != nil { return nil, err } @@ -634,7 +634,7 @@ func (s *HTTPHandlers) ACLRoleDelete(resp http.ResponseWriter, req *http.Request } var ignored string - if err := s.agent.RPC("ACL.RoleDelete", args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.RoleDelete", args, &ignored); err != nil { return nil, err } @@ -663,7 +663,7 @@ func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Re var out structs.ACLBindingRuleListResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.BindingRuleList", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.BindingRuleList", &args, &out); err != nil { return nil, err } @@ -723,7 +723,7 @@ func (s *HTTPHandlers) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Re var out structs.ACLBindingRuleResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.BindingRuleRead", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.BindingRuleRead", &args, &out); err != nil { return nil, err } @@ -762,7 +762,7 @@ func (s *HTTPHandlers) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.R } var out structs.ACLBindingRule - if err := s.agent.RPC("ACL.BindingRuleSet", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.BindingRuleSet", args, &out); err != nil { return nil, err } @@ -780,7 +780,7 @@ func (s *HTTPHandlers) ACLBindingRuleDelete(resp http.ResponseWriter, req *http. } var ignored bool - if err := s.agent.RPC("ACL.BindingRuleDelete", args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.BindingRuleDelete", args, &ignored); err != nil { return nil, err } @@ -806,7 +806,7 @@ func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Req var out structs.ACLAuthMethodListResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.AuthMethodList", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.AuthMethodList", &args, &out); err != nil { return nil, err } @@ -865,7 +865,7 @@ func (s *HTTPHandlers) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Req var out structs.ACLAuthMethodResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("ACL.AuthMethodRead", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.AuthMethodRead", &args, &out); err != nil { return nil, err } @@ -907,7 +907,7 @@ func (s *HTTPHandlers) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Re } var out structs.ACLAuthMethod - if err := s.agent.RPC("ACL.AuthMethodSet", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.AuthMethodSet", args, &out); err != nil { return nil, err } @@ -926,7 +926,7 @@ func (s *HTTPHandlers) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.R } var ignored bool - if err := s.agent.RPC("ACL.AuthMethodDelete", args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.AuthMethodDelete", args, &ignored); err != nil { return nil, err } @@ -952,7 +952,7 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in } var out structs.ACLToken - if err := s.agent.RPC("ACL.Login", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.Login", args, &out); err != nil { return nil, err } @@ -975,7 +975,7 @@ func (s *HTTPHandlers) ACLLogout(resp http.ResponseWriter, req *http.Request) (i } var ignored bool - if err := s.agent.RPC("ACL.Logout", &args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.Logout", &args, &ignored); err != nil { return nil, err } @@ -1051,7 +1051,7 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request) if request.Datacenter != "" && request.Datacenter != s.agent.config.Datacenter { // when we are targeting a datacenter other than our own then we must issue an RPC // to perform the resolution as it may involve a local token - if err := s.agent.RPC("ACL.Authorize", &request, &responses); err != nil { + if err := s.agent.RPC(req.Context(), "ACL.Authorize", &request, &responses); err != nil { return nil, err } } else { diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 5cffef6ee3..da1b5b685a 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -1905,7 +1906,7 @@ func TestACL_Authorize(t *testing.T) { WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken}, } var policy structs.ACLPolicy - require.NoError(t, a1.RPC("ACL.PolicySet", &policyReq, &policy)) + require.NoError(t, a1.RPC(context.Background(), "ACL.PolicySet", &policyReq, &policy)) tokenReq := structs.ACLTokenSetRequest{ ACLToken: structs.ACLToken{ @@ -1920,7 +1921,7 @@ func TestACL_Authorize(t *testing.T) { } var token structs.ACLToken - require.NoError(t, a1.RPC("ACL.TokenSet", &tokenReq, &token)) + require.NoError(t, a1.RPC(context.Background(), "ACL.TokenSet", &tokenReq, &token)) // secondary also needs to setup a replication token to pull tokens and policies secondaryParams := DefaultTestACLConfigParams() @@ -1953,7 +1954,7 @@ func TestACL_Authorize(t *testing.T) { } var localToken structs.ACLToken - require.NoError(t, a2.RPC("ACL.TokenSet", &localTokenReq, &localToken)) + require.NoError(t, a2.RPC(context.Background(), "ACL.TokenSet", &localTokenReq, &localToken)) t.Run("initial-management-token", func(t *testing.T) { request := []structs.ACLAuthorizationRequest{ @@ -2367,7 +2368,7 @@ func TestACL_Authorize(t *testing.T) { }) } -type rpcFn func(string, interface{}, interface{}) error +type rpcFn func(context.Context, string, interface{}, interface{}) error func upsertTestCustomizedAuthMethod( rpc rpcFn, initialManagementToken string, datacenter string, @@ -2393,7 +2394,7 @@ func upsertTestCustomizedAuthMethod( var out structs.ACLAuthMethod - err = rpc("ACL.AuthMethodSet", &req, &out) + err = rpc(context.Background(), "ACL.AuthMethodSet", &req, &out) if err != nil { return nil, err } @@ -2414,7 +2415,7 @@ func upsertTestCustomizedBindingRule(rpc rpcFn, initialManagementToken string, d var out structs.ACLBindingRule - err := rpc("ACL.BindingRuleSet", &req, &out) + err := rpc(context.Background(), "ACL.BindingRuleSet", &req, &out) if err != nil { return nil, err } diff --git a/agent/acl_test.go b/agent/acl_test.go index 48679122c0..310f8e0fd0 100644 --- a/agent/acl_test.go +++ b/agent/acl_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "io" "testing" @@ -144,7 +145,7 @@ func (a *TestACLAgent) JoinLAN(addrs []string, entMeta *acl.EnterpriseMeta) (n i func (a *TestACLAgent) RemoveFailedNode(node string, prune bool, entMeta *acl.EnterpriseMeta) error { return fmt.Errorf("Unimplemented") } -func (a *TestACLAgent) RPC(method string, args interface{}, reply interface{}) error { +func (a *TestACLAgent) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { return fmt.Errorf("Unimplemented") } func (a *TestACLAgent) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error { diff --git a/agent/agent.go b/agent/agent.go index 3998ba3abb..33dc577c27 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -189,7 +189,8 @@ type delegate interface { // default partition and namespace from the token. ResolveTokenAndDefaultMeta(token string, entMeta *acl.EnterpriseMeta, authzContext *acl.AuthorizerContext) (resolver.Result, error) - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error + SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error Shutdown() error Stats() map[string]map[string]string @@ -1552,7 +1553,7 @@ func (a *Agent) registerEndpoint(name string, handler interface{}) error { // RPC is used to make an RPC call to the Consul servers // This allows the agent to implement the Consul.Interface -func (a *Agent) RPC(method string, args interface{}, reply interface{}) error { +func (a *Agent) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { a.endpointsLock.RLock() // fast path: only translate if there are overrides if len(a.endpoints) > 0 { @@ -1562,7 +1563,7 @@ func (a *Agent) RPC(method string, args interface{}, reply interface{}) error { } } a.endpointsLock.RUnlock() - return a.delegate.RPC(method, args, reply) + return a.delegate.RPC(ctx, method, args, reply) } // Leave is used to prepare the agent for a graceful shutdown @@ -1950,7 +1951,7 @@ OUTER: var reply struct{} // todo(kit) port all of these logger calls to hclog w/ loglevel configuration // todo(kit) handle acl.ErrNotFound cases here in the future - if err := a.RPC("Coordinate.Update", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &req, &reply); err != nil { if acl.IsErrPermissionDenied(err) { accessorID := a.aclAccessorID(agentToken) a.logger.Warn("Coordinate update blocked by ACLs", "accessorID", accessorID) diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index 252b2f8766..c19004115c 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -7122,7 +7122,7 @@ func TestAgentConnectCALeafCert_Vault_doesNotChurnLeafCertsAtIdle(t *testing.T) { args := &structs.DCSpecificRequest{Datacenter: "dc1"} var reply structs.IndexedCARoots - require.NoError(t, a.RPC("ConnectCA.Roots", args, &reply)) + require.NoError(t, a.RPC(context.Background(), "ConnectCA.Roots", args, &reply)) for _, r := range reply.Roots { if r.ID == reply.ActiveRootID { ca1 = r @@ -7550,7 +7550,7 @@ func TestAgentConnectAuthorize_allow(t *testing.T) { req.Intention.DestinationName = target req.Intention.Action = structs.IntentionActionAllow - require.Nil(t, a.RPC("Intention.Apply", &req, &ixnId)) + require.Nil(t, a.RPC(context.Background(), "Intention.Apply", &req, &ixnId)) } args := &structs.ConnectAuthorizeRequest{ @@ -7600,7 +7600,7 @@ func TestAgentConnectAuthorize_allow(t *testing.T) { req.Intention.DestinationName = target req.Intention.Action = structs.IntentionActionDeny - require.Nil(t, a.RPC("Intention.Apply", &req, &ixnId)) + require.Nil(t, a.RPC(context.Background(), "Intention.Apply", &req, &ixnId)) } // Short sleep lets the cache background refresh happen @@ -7653,7 +7653,7 @@ func TestAgentConnectAuthorize_deny(t *testing.T) { req.Intention.Action = structs.IntentionActionDeny var reply string - assert.Nil(t, a.RPC("Intention.Apply", &req, &reply)) + assert.Nil(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } args := &structs.ConnectAuthorizeRequest{ @@ -7706,7 +7706,7 @@ func TestAgentConnectAuthorize_allowTrustDomain(t *testing.T) { req.Intention.Action = structs.IntentionActionAllow var reply string - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } { @@ -7755,7 +7755,7 @@ func TestAgentConnectAuthorize_denyWildcard(t *testing.T) { req.Intention.Action = structs.IntentionActionDeny var reply string - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } { // Allow web to DB @@ -7771,7 +7771,7 @@ func TestAgentConnectAuthorize_denyWildcard(t *testing.T) { req.Intention.Action = structs.IntentionActionAllow var reply string - assert.Nil(t, a.RPC("Intention.Apply", &req, &reply)) + assert.Nil(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } // Web should be allowed diff --git a/agent/agent_test.go b/agent/agent_test.go index 9d9f710a6f..db7d457402 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -208,7 +208,7 @@ func TestAgent_RPCPing(t *testing.T) { testrpc.WaitForTestAgent(t, a.RPC, "dc1") var out struct{} - if err := a.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := a.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -837,7 +837,7 @@ func TestAgent_CheckAliasRPC(t *testing.T) { args.Node = "node1" args.AllowStale = true var out structs.IndexedNodeServices - err := a.RPC("Catalog.NodeServices", &args, &out) + err := a.RPC(context.Background(), "Catalog.NodeServices", &args, &out) assert.NoError(r, err) foundService := false lookup := structs.NewServiceID("svcid1", structs.WildcardEnterpriseMetaInDefaultPartition()) @@ -1451,7 +1451,7 @@ func verifyIndexChurn(t *testing.T, tags []string) { // check is added to an agent. 500ms so that we don't see flakiness ever. time.Sleep(500 * time.Millisecond) - if err := a.RPC("Health.ServiceNodes", args, &before); err != nil { + if err := a.RPC(context.Background(), "Health.ServiceNodes", args, &before); err != nil { t.Fatalf("err: %v", err) } for _, name := range before.Nodes[0].Checks { @@ -1474,7 +1474,7 @@ func verifyIndexChurn(t *testing.T, tags []string) { // has changed for the RPC, which means that idempotent ops // are not working as intended. var after structs.IndexedCheckServiceNodes - if err := a.RPC("Health.ServiceNodes", args, &after); err != nil { + if err := a.RPC(context.Background(), "Health.ServiceNodes", args, &after); err != nil { t.Fatalf("err: %v", err) } require.Equal(t, before, after) @@ -5281,7 +5281,7 @@ func TestAutoConfig_Integration(t *testing.T) { }, } var reply interface{} - require.NoError(t, srv.RPC("ConnectCA.ConfigurationSet", &req, &reply)) + require.NoError(t, srv.RPC(context.Background(), "ConnectCA.ConfigurationSet", &req, &reply)) // ensure that a new cert gets generated and pushed into the TLS configurator retry.Run(t, func(r *retry.R) { diff --git a/agent/cache-types/catalog_datacenters.go b/agent/cache-types/catalog_datacenters.go index 46c7e97e40..b84d2a933a 100644 --- a/agent/cache-types/catalog_datacenters.go +++ b/agent/cache-types/catalog_datacenters.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -38,7 +39,7 @@ func (c *CatalogDatacenters) Fetch(opts cache.FetchOptions, req cache.Request) ( // Fetch var reply []string - if err := c.RPC.RPC("Catalog.ListDatacenters", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Catalog.ListDatacenters", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/catalog_list_services.go b/agent/cache-types/catalog_list_services.go index 1c602cdb29..0324d4a6e1 100644 --- a/agent/cache-types/catalog_list_services.go +++ b/agent/cache-types/catalog_list_services.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -45,7 +46,7 @@ func (c *CatalogListServices) Fetch(opts cache.FetchOptions, req cache.Request) } var reply structs.IndexedServices - if err := c.RPC.RPC("Catalog.ListServices", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Catalog.ListServices", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/catalog_service_list.go b/agent/cache-types/catalog_service_list.go index aacdf3e2e2..7e417cd2c2 100644 --- a/agent/cache-types/catalog_service_list.go +++ b/agent/cache-types/catalog_service_list.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (c *CatalogServiceList) Fetch(opts cache.FetchOptions, req cache.Request) ( // Fetch var reply structs.IndexedServiceList - if err := c.RPC.RPC("Catalog.ServiceList", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Catalog.ServiceList", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/catalog_services.go b/agent/cache-types/catalog_services.go index 43559de423..0f5f7f8aa0 100644 --- a/agent/cache-types/catalog_services.go +++ b/agent/cache-types/catalog_services.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -43,7 +44,7 @@ func (c *CatalogServices) Fetch(opts cache.FetchOptions, req cache.Request) (cac // Fetch var reply structs.IndexedServiceNodes - if err := c.RPC.RPC("Catalog.ServiceNodes", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Catalog.ServiceNodes", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/config_entry.go b/agent/cache-types/config_entry.go index d48572d7fb..3c434c24f6 100644 --- a/agent/cache-types/config_entry.go +++ b/agent/cache-types/config_entry.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -45,7 +46,7 @@ func (c *ConfigEntryList) Fetch(opts cache.FetchOptions, req cache.Request) (cac // Fetch var reply structs.IndexedConfigEntries - if err := c.RPC.RPC("ConfigEntry.List", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "ConfigEntry.List", reqReal, &reply); err != nil { return result, err } @@ -86,7 +87,7 @@ func (c *ConfigEntry) Fetch(opts cache.FetchOptions, req cache.Request) (cache.F // Fetch var reply structs.ConfigEntryResponse - if err := c.RPC.RPC("ConfigEntry.Get", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "ConfigEntry.Get", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/connect_ca_leaf.go b/agent/cache-types/connect_ca_leaf.go index f12ce1ece6..9bee39af7d 100644 --- a/agent/cache-types/connect_ca_leaf.go +++ b/agent/cache-types/connect_ca_leaf.go @@ -618,7 +618,7 @@ func (c *ConnectCALeaf) generateNewLeaf(req *ConnectCALeafRequest, Datacenter: req.Datacenter, CSR: csr, } - if err := c.RPC.RPC("ConnectCA.Sign", &args, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "ConnectCA.Sign", &args, &reply); err != nil { if err.Error() == consul.ErrRateLimited.Error() { if result.Value == nil { // This was a first fetch - we have no good value in cache. In this case diff --git a/agent/cache-types/connect_ca_leaf_test.go b/agent/cache-types/connect_ca_leaf_test.go index 04fe805cb0..1f07240657 100644 --- a/agent/cache-types/connect_ca_leaf_test.go +++ b/agent/cache-types/connect_ca_leaf_test.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "crypto/x509" "encoding/pem" "fmt" @@ -1093,7 +1094,7 @@ type testGatedRootsRPC struct { ValueCh chan structs.IndexedCARoots } -func (r *testGatedRootsRPC) RPC(method string, args interface{}, reply interface{}) error { +func (r *testGatedRootsRPC) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { if method != "ConnectCA.Roots" { return fmt.Errorf("invalid RPC method: %s", method) } diff --git a/agent/cache-types/connect_ca_root.go b/agent/cache-types/connect_ca_root.go index e4f3816f88..deca68e31c 100644 --- a/agent/cache-types/connect_ca_root.go +++ b/agent/cache-types/connect_ca_root.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -38,7 +39,7 @@ func (c *ConnectCARoot) Fetch(opts cache.FetchOptions, req cache.Request) (cache // Fetch var reply structs.IndexedCARoots - if err := c.RPC.RPC("ConnectCA.Roots", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "ConnectCA.Roots", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/discovery_chain.go b/agent/cache-types/discovery_chain.go index 5dd6eaed2e..2d8cd1bea6 100644 --- a/agent/cache-types/discovery_chain.go +++ b/agent/cache-types/discovery_chain.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -43,7 +44,7 @@ func (c *CompiledDiscoveryChain) Fetch(opts cache.FetchOptions, req cache.Reques // Fetch var reply structs.DiscoveryChainResponse - if err := c.RPC.RPC("DiscoveryChain.Get", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "DiscoveryChain.Get", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/exported_peered_services.go b/agent/cache-types/exported_peered_services.go index 02bc46a4c2..21ff779f56 100644 --- a/agent/cache-types/exported_peered_services.go +++ b/agent/cache-types/exported_peered_services.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -41,7 +42,7 @@ func (c *ExportedPeeredServices) Fetch(opts cache.FetchOptions, req cache.Reques // Fetch var reply structs.IndexedExportedServiceList - if err := c.RPC.RPC("Internal.ExportedPeeredServices", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Internal.ExportedPeeredServices", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/federation_state_list_gateways.go b/agent/cache-types/federation_state_list_gateways.go index 83df66978e..c28ad3700d 100644 --- a/agent/cache-types/federation_state_list_gateways.go +++ b/agent/cache-types/federation_state_list_gateways.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (c *FederationStateListMeshGateways) Fetch(opts cache.FetchOptions, req cac // Fetch var reply structs.DatacenterIndexedCheckServiceNodes - if err := c.RPC.RPC("FederationState.ListMeshGateways", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "FederationState.ListMeshGateways", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/gateway_services.go b/agent/cache-types/gateway_services.go index 6e3ca380d4..02ae60b080 100644 --- a/agent/cache-types/gateway_services.go +++ b/agent/cache-types/gateway_services.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (g *GatewayServices) Fetch(opts cache.FetchOptions, req cache.Request) (cac // Fetch var reply structs.IndexedGatewayServices - if err := g.RPC.RPC("Catalog.GatewayServices", reqReal, &reply); err != nil { + if err := g.RPC.RPC(context.Background(), "Catalog.GatewayServices", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/health_services.go b/agent/cache-types/health_services.go index d73b952090..63e52470a1 100644 --- a/agent/cache-types/health_services.go +++ b/agent/cache-types/health_services.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -43,7 +44,7 @@ func (c *HealthServices) Fetch(opts cache.FetchOptions, req cache.Request) (cach // Fetch var reply structs.IndexedCheckServiceNodes - if err := c.RPC.RPC("Health.ServiceNodes", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Health.ServiceNodes", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/intention_match.go b/agent/cache-types/intention_match.go index 688adfe438..3b4b519c72 100644 --- a/agent/cache-types/intention_match.go +++ b/agent/cache-types/intention_match.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -36,7 +37,7 @@ func (c *IntentionMatch) Fetch(opts cache.FetchOptions, req cache.Request) (cach // Fetch var reply structs.IndexedIntentionMatches - if err := c.RPC.RPC("Intention.Match", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Intention.Match", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/intention_upstreams.go b/agent/cache-types/intention_upstreams.go index 489bf4cd96..80d657ebb8 100644 --- a/agent/cache-types/intention_upstreams.go +++ b/agent/cache-types/intention_upstreams.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (i *IntentionUpstreams) Fetch(opts cache.FetchOptions, req cache.Request) ( // Fetch var reply structs.IndexedServiceList - if err := i.RPC.RPC("Internal.IntentionUpstreams", reqReal, &reply); err != nil { + if err := i.RPC.RPC(context.Background(), "Internal.IntentionUpstreams", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/intention_upstreams_destination.go b/agent/cache-types/intention_upstreams_destination.go index ae1012c354..9513003364 100644 --- a/agent/cache-types/intention_upstreams_destination.go +++ b/agent/cache-types/intention_upstreams_destination.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (i *IntentionUpstreamsDestination) Fetch(opts cache.FetchOptions, req cache // Fetch var reply structs.IndexedServiceList - if err := i.RPC.RPC("Internal.IntentionUpstreamsDestination", reqReal, &reply); err != nil { + if err := i.RPC.RPC(context.Background(), "Internal.IntentionUpstreamsDestination", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/mock_RPC.go b/agent/cache-types/mock_RPC.go index 059623f04a..a4289afa30 100644 --- a/agent/cache-types/mock_RPC.go +++ b/agent/cache-types/mock_RPC.go @@ -3,6 +3,7 @@ package cachetype import ( + "context" testing "testing" mock "github.com/stretchr/testify/mock" @@ -13,8 +14,8 @@ type MockRPC struct { mock.Mock } -// RPC provides a mock function with given fields: method, args, reply -func (_m *MockRPC) RPC(method string, args interface{}, reply interface{}) error { +// RPC provides a mock function with given fields: ctx, method, args, reply +func (_m *MockRPC) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { ret := _m.Called(method, args, reply) var r0 error diff --git a/agent/cache-types/node_services.go b/agent/cache-types/node_services.go index 9ead9ba436..9856b50180 100644 --- a/agent/cache-types/node_services.go +++ b/agent/cache-types/node_services.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -43,7 +44,7 @@ func (c *NodeServices) Fetch(opts cache.FetchOptions, req cache.Request) (cache. // Fetch var reply structs.IndexedNodeServices - if err := c.RPC.RPC("Catalog.NodeServices", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Catalog.NodeServices", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/peered_upstreams.go b/agent/cache-types/peered_upstreams.go index 8e8f9001a6..6aa8a3e34a 100644 --- a/agent/cache-types/peered_upstreams.go +++ b/agent/cache-types/peered_upstreams.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -41,7 +42,7 @@ func (i *PeeredUpstreams) Fetch(opts cache.FetchOptions, req cache.Request) (cac // Fetch var reply structs.IndexedPeeredServiceList - if err := i.RPC.RPC("Internal.PeeredUpstreams", reqReal, &reply); err != nil { + if err := i.RPC.RPC(context.Background(), "Internal.PeeredUpstreams", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/prepared_query.go b/agent/cache-types/prepared_query.go index 3592bc2106..5a5230c7c2 100644 --- a/agent/cache-types/prepared_query.go +++ b/agent/cache-types/prepared_query.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -39,7 +40,7 @@ func (c *PreparedQuery) Fetch(_ cache.FetchOptions, req cache.Request) (cache.Fe // Fetch var reply structs.PreparedQueryExecuteResponse - if err := c.RPC.RPC("PreparedQuery.Execute", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "PreparedQuery.Execute", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/resolved_service_config.go b/agent/cache-types/resolved_service_config.go index 7c17e06186..3065ab4eb3 100644 --- a/agent/cache-types/resolved_service_config.go +++ b/agent/cache-types/resolved_service_config.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -43,7 +44,7 @@ func (c *ResolvedServiceConfig) Fetch(opts cache.FetchOptions, req cache.Request // Fetch var reply structs.ServiceConfigResponse - if err := c.RPC.RPC("ConfigEntry.ResolveServiceConfig", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "ConfigEntry.ResolveServiceConfig", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/rpc.go b/agent/cache-types/rpc.go index 445ca20463..750e5058bc 100644 --- a/agent/cache-types/rpc.go +++ b/agent/cache-types/rpc.go @@ -1,10 +1,12 @@ package cachetype +import "context" + // RPC is an interface that an RPC client must implement. This is a helper // interface that is implemented by the agent delegate so that Type // implementations can request RPC access. // //go:generate mockery --name RPC --inpackage type RPC interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error } diff --git a/agent/cache-types/service_dump.go b/agent/cache-types/service_dump.go index 88ea1ed68b..fccd4da8ae 100644 --- a/agent/cache-types/service_dump.go +++ b/agent/cache-types/service_dump.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (c *InternalServiceDump) Fetch(opts cache.FetchOptions, req cache.Request) // Fetch var reply structs.IndexedNodesWithGateways - if err := c.RPC.RPC("Internal.ServiceDump", reqReal, &reply); err != nil { + if err := c.RPC.RPC(context.Background(), "Internal.ServiceDump", reqReal, &reply); err != nil { return result, err } diff --git a/agent/cache-types/service_gateways.go b/agent/cache-types/service_gateways.go index 1c7a8e8557..54ba269497 100644 --- a/agent/cache-types/service_gateways.go +++ b/agent/cache-types/service_gateways.go @@ -1,6 +1,7 @@ package cachetype import ( + "context" "fmt" "github.com/hashicorp/consul/agent/cache" @@ -42,7 +43,7 @@ func (g *ServiceGateways) Fetch(opts cache.FetchOptions, req cache.Request) (cac // Fetch var reply structs.IndexedCheckServiceNodes - if err := g.RPC.RPC("Internal.ServiceGateways", reqReal, &reply); err != nil { + if err := g.RPC.RPC(context.Background(), "Internal.ServiceGateways", reqReal, &reply); err != nil { return result, err } diff --git a/agent/catalog_endpoint.go b/agent/catalog_endpoint.go index 9623da6af4..305b317db8 100644 --- a/agent/catalog_endpoint.go +++ b/agent/catalog_endpoint.go @@ -148,7 +148,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque // Forward to the servers var out struct{} - if err := s.agent.RPC("Catalog.Register", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.Register", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_register"}, 1, s.nodeMetricsLabels()) return nil, err @@ -178,7 +178,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req // Forward to the servers var out struct{} - if err := s.agent.RPC("Catalog.Deregister", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.Deregister", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_deregister"}, 1, s.nodeMetricsLabels()) return nil, err @@ -212,7 +212,7 @@ func (s *HTTPHandlers) CatalogDatacenters(resp http.ResponseWriter, req *http.Re defer setCacheMeta(resp, &m) out = *reply } else { - if err := s.agent.RPC("Catalog.ListDatacenters", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.ListDatacenters", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_datacenters"}, 1, s.nodeMetricsLabels()) return nil, err @@ -244,7 +244,7 @@ func (s *HTTPHandlers) CatalogNodes(resp http.ResponseWriter, req *http.Request) var out structs.IndexedNodes defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Catalog.ListNodes", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.ListNodes", &args, &out); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < out.LastContact { @@ -297,7 +297,7 @@ func (s *HTTPHandlers) CatalogServices(resp http.ResponseWriter, req *http.Reque out = *reply } else { RETRY_ONCE: - if err := s.agent.RPC("Catalog.ListServices", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.ListServices", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_services"}, 1, s.nodeMetricsLabels()) return nil, err @@ -387,7 +387,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R out = *reply } else { RETRY_ONCE: - if err := s.agent.RPC("Catalog.ServiceNodes", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.ServiceNodes", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_service_nodes"}, 1, s.nodeMetricsLabels()) return nil, err @@ -442,7 +442,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R var out structs.IndexedNodeServices defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.NodeServices", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_node_services"}, 1, s.nodeMetricsLabels()) return nil, err @@ -507,7 +507,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt var out structs.IndexedNodeServiceList defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Catalog.NodeServiceList", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.NodeServiceList", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_node_service_list"}, 1, s.nodeMetricsLabels()) return nil, err @@ -554,7 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt var out structs.IndexedGatewayServices defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Catalog.GatewayServices", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Catalog.GatewayServices", &args, &out); err != nil { metrics.IncrCounterWithLabels([]string{"client", "rpc", "error", "catalog_gateway_services"}, 1, s.nodeMetricsLabels()) return nil, err diff --git a/agent/catalog_endpoint_test.go b/agent/catalog_endpoint_test.go index f3487c2126..b0a5922e0f 100644 --- a/agent/catalog_endpoint_test.go +++ b/agent/catalog_endpoint_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -167,7 +168,7 @@ func TestCatalogNodes(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -208,7 +209,7 @@ func TestCatalogNodes_MetaFilter(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -253,7 +254,7 @@ func TestCatalogNodes_Filter(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/catalog/nodes?filter="+url.QueryEscape("Meta.somekey == somevalue"), nil) resp := httptest.NewRecorder() @@ -322,7 +323,7 @@ func TestCatalogNodes_WanTranslation(t *testing.T) { } var out struct{} - if err := a2.RPC("Catalog.Register", args, &out); err != nil { + if err := a2.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -389,7 +390,7 @@ func TestCatalogNodes_Blocking(t *testing.T) { Datacenter: "dc1", } var out structs.IndexedNodes - if err := a.RPC("Catalog.ListNodes", *args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.ListNodes", *args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -404,7 +405,7 @@ func TestCatalogNodes_Blocking(t *testing.T) { Address: "127.0.0.1", } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Errorf("err: %v", err) } }() @@ -469,14 +470,14 @@ func TestCatalogNodes_DistanceSort(t *testing.T) { Address: "127.0.0.1", } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) args = &structs.RegisterRequest{ Datacenter: "dc1", Node: "bar", Address: "127.0.0.2", } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Nobody has coordinates set so this will still return them in the // order they are indexed. @@ -498,7 +499,7 @@ func TestCatalogNodes_DistanceSort(t *testing.T) { Node: "foo", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - require.NoError(t, a.RPC("Coordinate.Update", &arg, &out)) + require.NoError(t, a.RPC(context.Background(), "Coordinate.Update", &arg, &out)) time.Sleep(300 * time.Millisecond) // Query again and now foo should have moved to the front of the line. @@ -536,7 +537,7 @@ func TestCatalogServices(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -578,7 +579,7 @@ func TestCatalogServices_NodeMetaFilter(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -632,7 +633,7 @@ func TestCatalogRegister_checkRegistration(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -689,7 +690,7 @@ func TestCatalogRegister_checkRegistration_UDP(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -752,7 +753,7 @@ func TestCatalogServiceNodes(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -804,7 +805,7 @@ func TestCatalogServiceNodes(t *testing.T) { args2 := args args2.Node = "bar" args2.Address = "127.0.0.2" - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) retry.Run(t, func(r *retry.R) { // List it again @@ -868,7 +869,7 @@ func TestCatalogServiceNodes_NodeMetaFilter(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -925,7 +926,7 @@ func TestCatalogServiceNodes_Filter(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Register a second service for the node args = &structs.RegisterRequest{ @@ -942,7 +943,7 @@ func TestCatalogServiceNodes_Filter(t *testing.T) { SkipNodeUpdate: true, } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", queryPath, nil) resp := httptest.NewRecorder() @@ -1005,7 +1006,7 @@ func TestCatalogServiceNodes_WanTranslation(t *testing.T) { } var out struct{} - require.NoError(t, a2.RPC("Catalog.Register", args, &out)) + require.NoError(t, a2.RPC(context.Background(), "Catalog.Register", args, &out)) } // Query for the node in DC2 from DC1. @@ -1061,7 +1062,7 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) { }, } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1075,7 +1076,7 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) { Tags: []string{"a"}, }, } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1106,7 +1107,7 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) { Node: "foo", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - if err := a.RPC("Coordinate.Update", &arg, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1149,7 +1150,7 @@ func TestCatalogServiceNodes_ConnectProxy(t *testing.T) { // Register args := structs.TestRegisterRequestProxy(t) var out struct{} - assert.Nil(t, a.RPC("Catalog.Register", args, &out)) + assert.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", fmt.Sprintf( "/v1/catalog/service/%s", args.Service.Service), nil) @@ -1177,7 +1178,7 @@ func registerService(t *testing.T, a *TestAgent) (registerServiceReq *structs.Re } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", registerServiceReq, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", registerServiceReq, &out)) return } @@ -1201,7 +1202,7 @@ func registerProxyDefaults(t *testing.T, a *TestAgent) (proxyGlobalEntry structs Entry: &proxyGlobalEntry, } var proxyDefaultsConfigEntryResp bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &proxyDefaultsConfigEntryReq, &proxyDefaultsConfigEntryResp)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &proxyDefaultsConfigEntryReq, &proxyDefaultsConfigEntryResp)) return } @@ -1230,7 +1231,7 @@ func registerServiceDefaults(t *testing.T, a *TestAgent, serviceName string) (se Entry: &serviceDefaultsConfigEntry, } var serviceDefaultsConfigEntryResp bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &serviceDefaultsConfigEntryReq, &serviceDefaultsConfigEntryResp)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &serviceDefaultsConfigEntryReq, &serviceDefaultsConfigEntryResp)) return } @@ -1352,7 +1353,7 @@ func TestCatalogServiceNodes_MergeCentralConfigBlocking(t *testing.T) { MergeCentralConfig: true, } var rpcResp structs.IndexedServiceNodes - require.NoError(t, a.RPC("Catalog.ServiceNodes", &rpcReq, &rpcResp)) + require.NoError(t, a.RPC(context.Background(), "Catalog.ServiceNodes", &rpcReq, &rpcResp)) require.Len(t, rpcResp.ServiceNodes, 1) serviceNode := rpcResp.ServiceNodes[0] @@ -1424,7 +1425,7 @@ func TestCatalogConnectServiceNodes_good(t *testing.T) { args := structs.TestRegisterRequestProxy(t) args.Service.Address = "127.0.0.55" var out struct{} - assert.Nil(t, a.RPC("Catalog.Register", args, &out)) + assert.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", fmt.Sprintf( "/v1/catalog/connect/%s", args.Service.Proxy.DestinationServiceName), nil) @@ -1455,7 +1456,7 @@ func TestCatalogConnectServiceNodes_Filter(t *testing.T) { args := structs.TestRegisterRequestProxy(t) args.Service.Address = "127.0.0.55" var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) args = structs.TestRegisterRequestProxy(t) args.Service.Address = "127.0.0.55" @@ -1464,7 +1465,7 @@ func TestCatalogConnectServiceNodes_Filter(t *testing.T) { } args.Service.ID = "web-proxy2" args.SkipNodeUpdate = true - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", fmt.Sprintf( "/v1/catalog/connect/%s?filter=%s", @@ -1504,13 +1505,13 @@ func TestCatalogNodeServices(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } // Register a connect proxy args.Service = structs.TestNodeServiceProxy(t) - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/catalog/node/foo?dc=dc1", nil) resp := httptest.NewRecorder() @@ -1551,13 +1552,13 @@ func TestCatalogNodeServiceList(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } // Register a connect proxy args.Service = structs.TestNodeServiceProxy(t) - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/catalog/node-services/foo?dc=dc1", nil) resp := httptest.NewRecorder() @@ -1637,7 +1638,7 @@ func TestCatalogNodeServiceList_MergeCentralConfigBlocking(t *testing.T) { MergeCentralConfig: true, } var rpcResp structs.IndexedNodeServiceList - require.NoError(t, a.RPC("Catalog.NodeServiceList", &rpcReq, &rpcResp)) + require.NoError(t, a.RPC(context.Background(), "Catalog.NodeServiceList", &rpcReq, &rpcResp)) require.Len(t, rpcResp.NodeServices.Services, 1) nodeService := rpcResp.NodeServices.Services[0] require.Equal(t, registerServiceReq.Service.Service, nodeService.Service) @@ -1710,11 +1711,11 @@ func TestCatalogNodeServices_Filter(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Register a connect proxy args.Service = structs.TestNodeServiceProxy(t) - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/catalog/node/foo?dc=dc1&filter="+url.QueryEscape("Kind == `connect-proxy`"), nil) resp := httptest.NewRecorder() @@ -1745,7 +1746,7 @@ func TestCatalogNodeServices_ConnectProxy(t *testing.T) { // Register args := structs.TestRegisterRequestProxy(t) var out struct{} - assert.Nil(t, a.RPC("Catalog.Register", args, &out)) + assert.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", fmt.Sprintf( "/v1/catalog/node/%s", args.Node), nil) @@ -1813,7 +1814,7 @@ func TestCatalogNodeServices_WanTranslation(t *testing.T) { } var out struct{} - require.NoError(t, a2.RPC("Catalog.Register", args, &out)) + require.NoError(t, a2.RPC(context.Background(), "Catalog.Register", args, &out)) } // Query for the node in DC2 from DC1. @@ -1872,7 +1873,7 @@ func TestCatalog_GatewayServices_Terminating(t *testing.T) { ServiceID: args.Service.Service, } var out struct{} - assert.NoError(t, a.RPC("Catalog.Register", &args, &out)) + assert.NoError(t, a.RPC(context.Background(), "Catalog.Register", &args, &out)) // Associate the gateway and api/redis services entryArgs := &structs.ConfigEntryRequest{ @@ -1900,7 +1901,7 @@ func TestCatalog_GatewayServices_Terminating(t *testing.T) { }, } var entryResp bool - assert.NoError(t, a.RPC("ConfigEntry.Apply", &entryArgs, &entryResp)) + assert.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &entryArgs, &entryResp)) retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("GET", "/v1/catalog/gateway-services/terminating", nil) @@ -1985,7 +1986,7 @@ func TestCatalog_GatewayServices_Ingress(t *testing.T) { } var entryResp bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &entryArgs, &entryResp)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &entryArgs, &entryResp)) retry.Run(t, func(r *retry.R) { req, _ := http.NewRequest("GET", "/v1/catalog/gateway-services/ingress", nil) diff --git a/agent/checks/alias.go b/agent/checks/alias.go index 9553745af1..a301daa4fa 100644 --- a/agent/checks/alias.go +++ b/agent/checks/alias.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "strings" "sync" @@ -150,7 +151,7 @@ func (c *CheckAlias) checkServiceExistsOnRemoteServer(serviceID *structs.Service RETRY_CALL: var out structs.IndexedNodeServices attempts++ - if err := c.RPC.RPC("Catalog.NodeServices", &args, &out); err != nil { + if err := c.RPC.RPC(context.Background(), "Catalog.NodeServices", &args, &out); err != nil { if attempts <= 3 { time.Sleep(time.Duration(attempts) * time.Second) goto RETRY_CALL @@ -207,7 +208,7 @@ func (c *CheckAlias) runQuery(stopCh chan struct{}) { // index is global to the cluster. var out structs.IndexedHealthChecks - if err := c.RPC.RPC("Health.NodeChecks", &args, &out); err != nil { + if err := c.RPC.RPC(context.Background(), "Health.NodeChecks", &args, &out); err != nil { attempt++ if attempt > 1 { c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, diff --git a/agent/checks/alias_test.go b/agent/checks/alias_test.go index 673e833044..4f77e7495e 100644 --- a/agent/checks/alias_test.go +++ b/agent/checks/alias_test.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "reflect" "sync/atomic" @@ -550,7 +551,7 @@ func (m *mockRPC) AddReply(method string, reply interface{}) { } -func (m *mockRPC) RPC(method string, args interface{}, reply interface{}) error { +func (m *mockRPC) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { atomic.AddUint32(&m.Calls, 1) m.Args.Store(args) diff --git a/agent/checks/check.go b/agent/checks/check.go index b1bdad66a1..2eb200dd85 100644 --- a/agent/checks/check.go +++ b/agent/checks/check.go @@ -48,7 +48,7 @@ const ( // interface that is implemented by the agent delegate for checks that need // to make RPC calls. type RPC interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error } // CheckNotifier interface is used by the CheckMonitor diff --git a/agent/config_endpoint.go b/agent/config_endpoint.go index 7e67d851ae..fbcac83ce3 100644 --- a/agent/config_endpoint.go +++ b/agent/config_endpoint.go @@ -47,7 +47,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i } var reply structs.ConfigEntryResponse - if err := s.agent.RPC("ConfigEntry.Get", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "ConfigEntry.Get", &args, &reply); err != nil { return nil, err } setMeta(resp, &reply.QueryMeta) @@ -65,7 +65,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i args.Kind = pathArgs[0] var reply structs.IndexedConfigEntries - if err := s.agent.RPC("ConfigEntry.List", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "ConfigEntry.List", &args, &reply); err != nil { return nil, err } setMeta(resp, &reply.QueryMeta) @@ -111,7 +111,7 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request) } var reply structs.ConfigEntryDeleteResponse - if err := s.agent.RPC("ConfigEntry.Delete", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "ConfigEntry.Delete", &args, &reply); err != nil { return nil, err } @@ -160,7 +160,7 @@ func (s *HTTPHandlers) ConfigApply(resp http.ResponseWriter, req *http.Request) } var reply bool - if err := s.agent.RPC("ConfigEntry.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "ConfigEntry.Apply", &args, &reply); err != nil { return nil, err } diff --git a/agent/config_endpoint_test.go b/agent/config_endpoint_test.go index b8ed9d5507..f3c1c17201 100644 --- a/agent/config_endpoint_test.go +++ b/agent/config_endpoint_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "fmt" "net/http" "net/http/httptest" @@ -65,7 +66,7 @@ func TestConfig_Get(t *testing.T) { } for _, req := range reqs { out := false - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &out)) } t.Run("get a single service entry", func(t *testing.T) { @@ -171,7 +172,7 @@ func TestConfig_Delete(t *testing.T) { } for _, req := range reqs { out := false - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &out)) } // Delete an entry. @@ -188,7 +189,7 @@ func TestConfig_Delete(t *testing.T) { Datacenter: "dc1", } var out structs.IndexedConfigEntries - require.NoError(t, a.RPC("ConfigEntry.List", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.List", &args, &out)) require.Equal(t, structs.ServiceDefaults, out.Kind) require.Len(t, out.Entries, 1) entry := out.Entries[0].(*structs.ServiceConfigEntry) @@ -212,7 +213,7 @@ func TestConfig_Delete_CAS(t *testing.T) { Name: "foo", } var created bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &structs.ConfigEntryRequest{ + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &structs.ConfigEntryRequest{ Datacenter: "dc1", Entry: entry, }, &created)) @@ -220,7 +221,7 @@ func TestConfig_Delete_CAS(t *testing.T) { // Read it back to get its ModifyIndex. var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{ + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &structs.ConfigEntryQuery{ Datacenter: "dc1", Kind: entry.Kind, Name: entry.Name, @@ -244,7 +245,7 @@ func TestConfig_Delete_CAS(t *testing.T) { // Verify it was not deleted. var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{ + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &structs.ConfigEntryQuery{ Datacenter: "dc1", Kind: entry.Kind, Name: entry.Name, @@ -267,7 +268,7 @@ func TestConfig_Delete_CAS(t *testing.T) { // Verify it was deleted. var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{ + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &structs.ConfigEntryQuery{ Datacenter: "dc1", Kind: entry.Kind, Name: entry.Name, @@ -311,7 +312,7 @@ func TestConfig_Apply(t *testing.T) { Datacenter: "dc1", } var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &args, &out)) require.NotNil(t, out.Entry) entry := out.Entry.(*structs.ServiceConfigEntry) require.Equal(t, entry.Name, "foo") @@ -360,7 +361,7 @@ func TestConfig_Apply_TerminatingGateway(t *testing.T) { Datacenter: "dc1", } var out structs.IndexedConfigEntries - require.NoError(t, a.RPC("ConfigEntry.List", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.List", &args, &out)) require.NotNil(t, out) require.Len(t, out.Entries, 1) @@ -421,7 +422,7 @@ func TestConfig_Apply_IngressGateway(t *testing.T) { Datacenter: "dc1", } var out structs.IndexedConfigEntries - require.NoError(t, a.RPC("ConfigEntry.List", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.List", &args, &out)) require.NotNil(t, out) require.Len(t, out.Entries, 1) @@ -486,7 +487,7 @@ func TestConfig_Apply_ProxyDefaultsMeshGateway(t *testing.T) { Datacenter: "dc1", } var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &args, &out)) require.NotNil(t, out.Entry) entry := out.Entry.(*structs.ProxyConfigEntry) require.Equal(t, structs.MeshGatewayModeLocal, entry.MeshGateway.Mode) @@ -528,7 +529,7 @@ func TestConfig_Apply_CAS(t *testing.T) { } out := &structs.ConfigEntryResponse{} - require.NoError(t, a.RPC("ConfigEntry.Get", &args, out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &args, out)) require.NotNil(t, out.Entry) entry := out.Entry.(*structs.ServiceConfigEntry) @@ -572,7 +573,7 @@ func TestConfig_Apply_CAS(t *testing.T) { } out = &structs.ConfigEntryResponse{} - require.NoError(t, a.RPC("ConfigEntry.Get", &args, out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &args, out)) require.NotNil(t, out.Entry) newEntry := out.Entry.(*structs.ServiceConfigEntry) require.NotEqual(t, entry.GetRaftIndex(), newEntry.GetRaftIndex()) @@ -644,7 +645,7 @@ func TestConfig_Apply_Decoding(t *testing.T) { Datacenter: "dc1", } var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &args, &out)) require.NotNil(t, out.Entry) entry := out.Entry.(*structs.ServiceConfigEntry) require.Equal(t, entry.Name, "foo") @@ -695,7 +696,7 @@ func TestConfig_Apply_ProxyDefaultsExpose(t *testing.T) { Datacenter: "dc1", } var out structs.ConfigEntryResponse - require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Get", &args, &out)) require.NotNil(t, out.Entry) entry := out.Entry.(*structs.ProxyConfigEntry) diff --git a/agent/connect/testing_ca.go b/agent/connect/testing_ca.go index 06d965ce6e..aa44859b37 100644 --- a/agent/connect/testing_ca.go +++ b/agent/connect/testing_ca.go @@ -2,6 +2,7 @@ package connect import ( "bytes" + "context" "crypto" "crypto/rand" "crypto/x509" @@ -414,7 +415,7 @@ func testUUID(t testing.T) string { // helper interface that is implemented by the agent delegate so that test // helpers can make RPCs without introducing an import cycle on `agent`. type TestAgentRPC interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error } func testCAConfigSet(t testing.T, a TestAgentRPC, @@ -438,7 +439,7 @@ func testCAConfigSet(t testing.T, a TestAgentRPC, } var reply interface{} - err := a.RPC("ConnectCA.ConfigurationSet", args, &reply) + err := a.RPC(context.Background(), "ConnectCA.ConfigurationSet", args, &reply) if err != nil { t.Fatalf("failed to set test CA config: %s", err) } diff --git a/agent/connect_ca_endpoint.go b/agent/connect_ca_endpoint.go index 2e78bc7d89..bb2ec637c4 100644 --- a/agent/connect_ca_endpoint.go +++ b/agent/connect_ca_endpoint.go @@ -27,7 +27,7 @@ func (s *HTTPHandlers) ConnectCARoots(resp http.ResponseWriter, req *http.Reques var reply structs.IndexedCARoots defer setMeta(resp, &reply.QueryMeta) - if err := s.agent.RPC("ConnectCA.Roots", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "ConnectCA.Roots", &args, &reply); err != nil { return nil, err } @@ -74,7 +74,7 @@ func (s *HTTPHandlers) ConnectCAConfigurationGet(resp http.ResponseWriter, req * } var reply structs.CAConfiguration - err := s.agent.RPC("ConnectCA.ConfigurationGet", &args, &reply) + err := s.agent.RPC(req.Context(), "ConnectCA.ConfigurationGet", &args, &reply) if err != nil { return nil, err } @@ -94,7 +94,7 @@ func (s *HTTPHandlers) ConnectCAConfigurationSet(req *http.Request) (interface{} } var reply interface{} - err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply) + err := s.agent.RPC(req.Context(), "ConnectCA.ConfigurationSet", &args, &reply) if err != nil && err.Error() == consul.ErrStateReadOnly.Error() { return nil, HTTPError{ StatusCode: http.StatusBadRequest, diff --git a/agent/consul/acl.go b/agent/consul/acl.go index e0d244b5d6..3dbb19a618 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "sort" "sync" @@ -134,7 +135,7 @@ type ACLResolverBackend interface { ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) IsServerManagementToken(token string) bool // TODO: separate methods for each RPC call (there are 4) - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error EnterpriseACLResolverDelegate } @@ -354,7 +355,7 @@ func (r *ACLResolver) fetchAndCacheIdentityFromToken(token string, cached *struc } var resp structs.ACLTokenResponse - err := r.backend.RPC("ACL.TokenRead", &req, &resp) + err := r.backend.RPC(context.Background(), "ACL.TokenRead", &req, &resp) if err == nil { if resp.Token == nil { r.cache.RemoveIdentityWithSecretToken(token) @@ -441,7 +442,7 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent } var resp structs.ACLPolicyBatchResponse - err := r.backend.RPC("ACL.PolicyResolve", &req, &resp) + err := r.backend.RPC(context.Background(), "ACL.PolicyResolve", &req, &resp) if err == nil { out := make(map[string]*structs.ACLPolicy) for _, policy := range resp.Policies { @@ -496,7 +497,7 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity } var resp structs.ACLRoleBatchResponse - err := r.backend.RPC("ACL.RoleResolve", &req, &resp) + err := r.backend.RPC(context.Background(), "ACL.RoleResolve", &req, &resp) if err == nil { out := make(map[string]*structs.ACLRole) for _, role := range resp.Roles { diff --git a/agent/consul/acl_oss_test.go b/agent/consul/acl_oss_test.go index 917696105a..df2165e1c5 100644 --- a/agent/consul/acl_oss_test.go +++ b/agent/consul/acl_oss_test.go @@ -4,6 +4,7 @@ package consul import ( + "context" "fmt" "github.com/hashicorp/consul/acl" @@ -26,7 +27,7 @@ func testRoleForIDEnterprise(string) (bool, *structs.ACLRole, error) { type EnterpriseACLResolverTestDelegate struct{} // RPC stub for the EnterpriseACLResolverTestDelegate -func (d *EnterpriseACLResolverTestDelegate) RPC(string, interface{}, interface{}) (bool, error) { +func (d *EnterpriseACLResolverTestDelegate) RPC(context.Context, string, interface{}, interface{}) (bool, error) { return false, nil } diff --git a/agent/consul/acl_replication.go b/agent/consul/acl_replication.go index f95cabbcf7..5107afe451 100644 --- a/agent/consul/acl_replication.go +++ b/agent/consul/acl_replication.go @@ -96,7 +96,7 @@ func (s *Server) fetchACLRolesBatch(roleIDs []string) (*structs.ACLRoleBatchResp } var response structs.ACLRoleBatchResponse - if err := s.RPC("ACL.RoleBatchRead", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ACL.RoleBatchRead", &req, &response); err != nil { return nil, err } @@ -117,7 +117,7 @@ func (s *Server) fetchACLRoles(lastRemoteIndex uint64) (*structs.ACLRoleListResp } var response structs.ACLRoleListResponse - if err := s.RPC("ACL.RoleList", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ACL.RoleList", &req, &response); err != nil { return nil, err } return &response, nil @@ -134,7 +134,7 @@ func (s *Server) fetchACLPoliciesBatch(policyIDs []string) (*structs.ACLPolicyBa } var response structs.ACLPolicyBatchResponse - if err := s.RPC("ACL.PolicyBatchRead", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ACL.PolicyBatchRead", &req, &response); err != nil { return nil, err } @@ -155,7 +155,7 @@ func (s *Server) fetchACLPolicies(lastRemoteIndex uint64) (*structs.ACLPolicyLis } var response structs.ACLPolicyListResponse - if err := s.RPC("ACL.PolicyList", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ACL.PolicyList", &req, &response); err != nil { return nil, err } return &response, nil @@ -324,7 +324,7 @@ func (s *Server) fetchACLTokensBatch(tokenIDs []string) (*structs.ACLTokenBatchR } var response structs.ACLTokenBatchResponse - if err := s.RPC("ACL.TokenBatchRead", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ACL.TokenBatchRead", &req, &response); err != nil { return nil, err } @@ -347,7 +347,7 @@ func (s *Server) fetchACLTokens(lastRemoteIndex uint64) (*structs.ACLTokenListRe } var response structs.ACLTokenListResponse - if err := s.RPC("ACL.TokenList", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ACL.TokenList", &req, &response); err != nil { return nil, err } return &response, nil diff --git a/agent/consul/acl_replication_test.go b/agent/consul/acl_replication_test.go index 102a7b4b26..25bd3929ff 100644 --- a/agent/consul/acl_replication_test.go +++ b/agent/consul/acl_replication_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "os" "strconv" @@ -347,7 +348,7 @@ func TestACLReplication_Tokens(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var token structs.ACLToken - require.NoError(t, s1.RPC("ACL.TokenSet", &arg, &token)) + require.NoError(t, s1.RPC(context.Background(), "ACL.TokenSet", &arg, &token)) tokens = append(tokens, &token) } @@ -368,7 +369,7 @@ func TestACLReplication_Tokens(t *testing.T) { // Create one token via this process. methodToken := structs.ACLToken{} - require.NoError(t, s1.RPC("ACL.Login", &structs.ACLLoginRequest{ + require.NoError(t, s1.RPC(context.Background(), "ACL.Login", &structs.ACLLoginRequest{ Auth: &structs.ACLLoginParams{ AuthMethod: method1.Name, BearerToken: "fake-token", @@ -433,7 +434,7 @@ func TestACLReplication_Tokens(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var token structs.ACLToken - require.NoError(t, s2.RPC("ACL.TokenSet", &arg, &token)) + require.NoError(t, s2.RPC(context.Background(), "ACL.TokenSet", &arg, &token)) } // add some local tokens to the primary DC @@ -453,7 +454,7 @@ func TestACLReplication_Tokens(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var token structs.ACLToken - require.NoError(t, s1.RPC("ACL.TokenSet", &arg, &token)) + require.NoError(t, s1.RPC(context.Background(), "ACL.TokenSet", &arg, &token)) } // Update those other tokens @@ -474,7 +475,7 @@ func TestACLReplication_Tokens(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var token structs.ACLToken - require.NoError(t, s1.RPC("ACL.TokenSet", &arg, &token)) + require.NoError(t, s1.RPC(context.Background(), "ACL.TokenSet", &arg, &token)) } // Wait for the replica to converge. @@ -496,7 +497,7 @@ func TestACLReplication_Tokens(t *testing.T) { } var dontCare string - require.NoError(t, s1.RPC("ACL.TokenDelete", &arg, &dontCare)) + require.NoError(t, s1.RPC(context.Background(), "ACL.TokenDelete", &arg, &dontCare)) } // Wait for the replica to converge. @@ -555,7 +556,7 @@ func TestACLReplication_Policies(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var policy structs.ACLPolicy - require.NoError(t, s1.RPC("ACL.PolicySet", &arg, &policy)) + require.NoError(t, s1.RPC(context.Background(), "ACL.PolicySet", &arg, &policy)) policies = append(policies, &policy) } @@ -599,7 +600,7 @@ func TestACLReplication_Policies(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var policy structs.ACLPolicy - require.NoError(t, s1.RPC("ACL.PolicySet", &arg, &policy)) + require.NoError(t, s1.RPC(context.Background(), "ACL.PolicySet", &arg, &policy)) } // Wait for the replica to converge. @@ -616,7 +617,7 @@ func TestACLReplication_Policies(t *testing.T) { } var dontCare string - require.NoError(t, s1.RPC("ACL.PolicyDelete", &arg, &dontCare)) + require.NoError(t, s1.RPC(context.Background(), "ACL.PolicyDelete", &arg, &dontCare)) } // Wait for the replica to converge. @@ -653,7 +654,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var policy structs.ACLPolicy - require.NoError(t, s1.RPC("ACL.PolicySet", &policyArg, &policy)) + require.NoError(t, s1.RPC(context.Background(), "ACL.PolicySet", &policyArg, &policy)) // Create the dc2 replication token tokenArg := structs.ACLTokenSetRequest{ @@ -671,7 +672,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { } var token structs.ACLToken - require.NoError(t, s1.RPC("ACL.TokenSet", &tokenArg, &token)) + require.NoError(t, s1.RPC(context.Background(), "ACL.TokenSet", &tokenArg, &token)) dir2, s2 := testServerWithConfig(t, func(c *Config) { c.Datacenter = "dc2" @@ -701,7 +702,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { TokenIDType: structs.ACLTokenSecret, QueryOptions: structs.QueryOptions{Token: "root"}, } - err := s2.RPC("ACL.TokenRead", &req, &tokenResp) + err := s2.RPC(context.Background(), "ACL.TokenRead", &req, &tokenResp) require.NoError(r, err) require.NotNil(r, tokenResp.Token) require.Equal(r, "root", tokenResp.Token.SecretID) @@ -710,7 +711,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { statusReq := structs.DCSpecificRequest{ Datacenter: "dc2", } - require.NoError(r, s2.RPC("ACL.ReplicationStatus", &statusReq, &status)) + require.NoError(r, s2.RPC(context.Background(), "ACL.ReplicationStatus", &statusReq, &status)) // ensures that tokens are not being synced require.True(r, status.ReplicatedTokenIndex > 0, "ReplicatedTokenIndex not greater than 0") @@ -727,7 +728,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { }, WriteRequest: structs.WriteRequest{Token: "root"}, } - require.NoError(t, s1.RPC("ACL.PolicySet", &policyArg, &policy)) + require.NoError(t, s1.RPC(context.Background(), "ACL.PolicySet", &policyArg, &policy)) // Create the another token so that replication will attempt to read it. tokenArg = structs.ACLTokenSetRequest{ @@ -747,7 +748,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { // record the time right before we are touching the token minErrorTime := time.Now() - require.NoError(t, s1.RPC("ACL.TokenSet", &tokenArg, &token2)) + require.NoError(t, s1.RPC(context.Background(), "ACL.TokenSet", &tokenArg, &token2)) retry.Run(t, func(r *retry.R) { var tokenResp structs.ACLTokenResponse @@ -757,7 +758,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { TokenIDType: structs.ACLTokenSecret, QueryOptions: structs.QueryOptions{Token: aclfilter.RedactedToken}, } - err := s2.RPC("ACL.TokenRead", &req, &tokenResp) + err := s2.RPC(context.Background(), "ACL.TokenRead", &req, &tokenResp) // its not an error for the secret to not be found. require.NoError(r, err) require.Nil(r, tokenResp.Token) @@ -766,7 +767,7 @@ func TestACLReplication_TokensRedacted(t *testing.T) { statusReq := structs.DCSpecificRequest{ Datacenter: "dc2", } - require.NoError(r, s2.RPC("ACL.ReplicationStatus", &statusReq, &status)) + require.NoError(r, s2.RPC(context.Background(), "ACL.ReplicationStatus", &statusReq, &status)) // ensures that tokens are not being synced require.True(r, status.ReplicatedTokenIndex < token2.CreateIndex, "ReplicatedTokenIndex is not less than the token2s create index") // ensures that token replication is erroring @@ -914,7 +915,7 @@ func TestACLReplication_AllTypes(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var dontCare string - if err := s1.RPC("ACL.TokenDelete", &arg, &dontCare); err != nil { + if err := s1.RPC(context.Background(), "ACL.TokenDelete", &arg, &dontCare); err != nil { t.Fatalf("err: %v", err) } } @@ -927,7 +928,7 @@ func TestACLReplication_AllTypes(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var dontCare string - if err := s1.RPC("ACL.RoleDelete", &arg, &dontCare); err != nil { + if err := s1.RPC(context.Background(), "ACL.RoleDelete", &arg, &dontCare); err != nil { t.Fatalf("err: %v", err) } } @@ -940,7 +941,7 @@ func TestACLReplication_AllTypes(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var dontCare string - if err := s1.RPC("ACL.PolicyDelete", &arg, &dontCare); err != nil { + if err := s1.RPC(context.Background(), "ACL.PolicyDelete", &arg, &dontCare); err != nil { t.Fatalf("err: %v", err) } } @@ -966,7 +967,7 @@ func createACLTestData(t *testing.T, srv *Server, namePrefix string, numObjects, WriteRequest: structs.WriteRequest{Token: "root"}, } var out structs.ACLPolicy - if err := srv.RPC("ACL.PolicySet", &arg, &out); err != nil { + if err := srv.RPC(context.Background(), "ACL.PolicySet", &arg, &out); err != nil { t.Fatalf("err: %v", err) } policyIDs = append(policyIDs, out.ID) @@ -987,7 +988,7 @@ func createACLTestData(t *testing.T, srv *Server, namePrefix string, numObjects, WriteRequest: structs.WriteRequest{Token: "root"}, } var out structs.ACLRole - if err := srv.RPC("ACL.RoleSet", &arg, &out); err != nil { + if err := srv.RPC(context.Background(), "ACL.RoleSet", &arg, &out); err != nil { t.Fatalf("err: %v", err) } roleIDs = append(roleIDs, out.ID) @@ -1011,7 +1012,7 @@ func createACLTestData(t *testing.T, srv *Server, namePrefix string, numObjects, WriteRequest: structs.WriteRequest{Token: "root"}, } var out structs.ACLToken - if err := srv.RPC("ACL.TokenSet", &arg, &out); err != nil { + if err := srv.RPC(context.Background(), "ACL.TokenSet", &arg, &out); err != nil { t.Fatalf("err: %v", err) } tokenIDs = append(tokenIDs, out.AccessorID) diff --git a/agent/consul/acl_test.go b/agent/consul/acl_test.go index 7c5288d1ec..c6601289d5 100644 --- a/agent/consul/acl_test.go +++ b/agent/consul/acl_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "strings" "sync/atomic" @@ -671,7 +672,7 @@ func (d *ACLResolverTestDelegate) ResolveRoleFromID(roleID string) (bool, *struc return testRoleForID(roleID) } -func (d *ACLResolverTestDelegate) RPC(method string, args interface{}, reply interface{}) error { +func (d *ACLResolverTestDelegate) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { switch method { case "ACL.TokenRead": atomic.AddInt32(&d.remoteTokenResolutions, 1) @@ -692,7 +693,7 @@ func (d *ACLResolverTestDelegate) RPC(method string, args interface{}, reply int } panic("Bad Test Implementation: should provide a roleResolveFn to the ACLResolverTestDelegate") } - if handled, err := d.EnterpriseACLResolverTestDelegate.RPC(method, args, reply); handled { + if handled, err := d.EnterpriseACLResolverTestDelegate.RPC(context.Background(), method, args, reply); handled { return err } panic("Bad Test Implementation: Was the ACLResolver updated to use new RPC methods") diff --git a/agent/consul/client.go b/agent/consul/client.go index a2e5bb57e1..94f806a849 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "io" "strconv" @@ -262,7 +263,7 @@ func (c *Client) KeyManagerLAN() *serf.KeyManager { } // RPC is used to forward an RPC call to a consul server, or fail if no servers -func (c *Client) RPC(method string, args interface{}, reply interface{}) error { +func (c *Client) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { // This is subtle but we start measuring the time on the client side // right at the time of the first request, vs. on the first retry as // is done on the server side inside forward(). This is because the diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index cef20c291b..da1f462c30 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -2,6 +2,7 @@ package consul import ( "bytes" + "context" "fmt" "net" "os" @@ -242,7 +243,7 @@ func TestClient_RPC(t *testing.T) { // Try an RPC var out struct{} - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrNoServers { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != structs.ErrNoServers { t.Fatalf("err: %v", err) } @@ -260,7 +261,7 @@ func TestClient_RPC(t *testing.T) { // RPC should succeed retry.Run(t, func(r *retry.R) { - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { r.Fatal("ping failed", err) } }) @@ -311,7 +312,7 @@ func TestClient_RPC_Retry(t *testing.T) { joinLAN(t, c1, s1) retry.Run(t, func(r *retry.R) { var out struct{} - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { r.Fatalf("err: %v", err) } }) @@ -322,13 +323,13 @@ func TestClient_RPC_Retry(t *testing.T) { } var out struct{} - if err := c1.RPC("Fail.Always", struct{}{}, &out); !structs.IsErrNoLeader(err) { + if err := c1.RPC(context.Background(), "Fail.Always", struct{}{}, &out); !structs.IsErrNoLeader(err) { t.Fatalf("err: %v", err) } if got, want := failer.totalCalls, 2; got < want { t.Fatalf("got %d want >= %d", got, want) } - if err := c1.RPC("Fail.Once", struct{}{}, &out); err != nil { + if err := c1.RPC(context.Background(), "Fail.Once", struct{}{}, &out); err != nil { t.Fatalf("err: %v", err) } if got, want := failer.onceCalls, 2; got < want { @@ -372,7 +373,7 @@ func TestClient_RPC_Pool(t *testing.T) { defer wg.Done() var out struct{} retry.Run(t, func(r *retry.R) { - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { r.Fatal("ping failed", err) } }) @@ -467,7 +468,7 @@ func TestClient_RPC_TLS(t *testing.T) { // Try an RPC var out struct{} - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrNoServers { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != structs.ErrNoServers { t.Fatalf("err: %v", err) } @@ -482,7 +483,7 @@ func TestClient_RPC_TLS(t *testing.T) { if got, want := len(c1.LANMembersInAgentPartition()), 2; got != want { r.Fatalf("got %d client LAN members want %d", got, want) } - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { r.Fatal("ping failed", err) } }) @@ -579,7 +580,7 @@ func TestClient_RPC_RateLimit(t *testing.T) { joinLAN(t, c1, s1) retry.Run(t, func(r *retry.R) { var out struct{} - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded { r.Fatalf("err: %v", err) } }) @@ -890,7 +891,7 @@ func TestClient_RPC_Timeout(t *testing.T) { retry.Run(t, func(r *retry.R) { var out struct{} - if err := c1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := c1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { r.Fatalf("err: %v", err) } }) @@ -902,22 +903,22 @@ func TestClient_RPC_Timeout(t *testing.T) { // Requests with QueryOptions have a default timeout of // RPCClientTimeout (10ms) so we expect the RPC call to timeout. var out struct{} - err := c1.RPC("Long.Wait", &structs.NodeSpecificRequest{}, &out) + err := c1.RPC(context.Background(), "Long.Wait", &structs.NodeSpecificRequest{}, &out) require.Error(t, err) require.Contains(t, err.Error(), "rpc error making call: i/o deadline reached") }) t.Run("non-blocking query succeeds", func(t *testing.T) { var out struct{} - require.NoError(t, c1.RPC("Short.Wait", &structs.NodeSpecificRequest{}, &out)) + require.NoError(t, c1.RPC(context.Background(), "Short.Wait", &structs.NodeSpecificRequest{}, &out)) }) t.Run("check that deadline does not persist across calls", func(t *testing.T) { var out struct{} - err := c1.RPC("Long.Wait", &structs.NodeSpecificRequest{}, &out) + err := c1.RPC(context.Background(), "Long.Wait", &structs.NodeSpecificRequest{}, &out) require.Error(t, err) require.Contains(t, err.Error(), "rpc error making call: i/o deadline reached") - require.NoError(t, c1.RPC("Long.Wait", &structs.NodeSpecificRequest{ + require.NoError(t, c1.RPC(context.Background(), "Long.Wait", &structs.NodeSpecificRequest{ QueryOptions: structs.QueryOptions{ MinQueryIndex: 1, }, @@ -926,7 +927,7 @@ func TestClient_RPC_Timeout(t *testing.T) { t.Run("blocking query succeeds", func(t *testing.T) { var out struct{} - require.NoError(t, c1.RPC("Long.Wait", &structs.NodeSpecificRequest{ + require.NoError(t, c1.RPC(context.Background(), "Long.Wait", &structs.NodeSpecificRequest{ QueryOptions: structs.QueryOptions{ MinQueryIndex: 1, }, @@ -939,7 +940,7 @@ func TestClient_RPC_Timeout(t *testing.T) { // jitter (100ms / 16 = 6.25ms) as well as RPCHoldTimeout (50ms). // Client waits 156.25ms while the server waits 106.25ms (artifically // adds maximum jitter) so the server will always return first. - require.NoError(t, c1.RPC("Long.Wait", &structs.NodeSpecificRequest{ + require.NoError(t, c1.RPC(context.Background(), "Long.Wait", &structs.NodeSpecificRequest{ QueryOptions: structs.QueryOptions{ MinQueryIndex: 1, MaxQueryTime: 100 * time.Millisecond, @@ -957,7 +958,7 @@ func TestClient_RPC_Timeout(t *testing.T) { // jitter (20ms / 16 = 1.25ms) as well as RPCHoldTimeout (50ms). // Client waits 71.25ms while the server waits 106.25ms (artifically // adds maximum jitter) so the client will error first. - err := c1.RPC("Long.Wait", &structs.NodeSpecificRequest{ + err := c1.RPC(context.Background(), "Long.Wait", &structs.NodeSpecificRequest{ QueryOptions: structs.QueryOptions{ MinQueryIndex: 1, MaxQueryTime: 20 * time.Millisecond, diff --git a/agent/consul/config_replication.go b/agent/consul/config_replication.go index 8b1a2273a8..902f92a6af 100644 --- a/agent/consul/config_replication.go +++ b/agent/consul/config_replication.go @@ -142,7 +142,7 @@ func (s *Server) fetchConfigEntries(lastRemoteIndex uint64) (*structs.IndexedGen } var response structs.IndexedGenericConfigEntries - if err := s.RPC("ConfigEntry.ListAll", &req, &response); err != nil { + if err := s.RPC(context.Background(), "ConfigEntry.ListAll", &req, &response); err != nil { return nil, err } diff --git a/agent/consul/config_replication_test.go b/agent/consul/config_replication_test.go index 24cf0d4e41..24d51de884 100644 --- a/agent/consul/config_replication_test.go +++ b/agent/consul/config_replication_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "os" "testing" @@ -137,7 +138,7 @@ func TestReplication_ConfigEntries(t *testing.T) { } out := false - require.NoError(t, s1.RPC("ConfigEntry.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "ConfigEntry.Apply", &arg, &out)) entries = append(entries, arg.Entry) } @@ -155,7 +156,7 @@ func TestReplication_ConfigEntries(t *testing.T) { } out := false - require.NoError(t, s1.RPC("ConfigEntry.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "ConfigEntry.Apply", &arg, &out)) entries = append(entries, arg.Entry) checkSame := func(t *retry.R) error { @@ -208,7 +209,7 @@ func TestReplication_ConfigEntries(t *testing.T) { } out := false - require.NoError(t, s1.RPC("ConfigEntry.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "ConfigEntry.Apply", &arg, &out)) } arg = structs.ConfigEntryRequest{ @@ -224,7 +225,7 @@ func TestReplication_ConfigEntries(t *testing.T) { }, } - require.NoError(t, s1.RPC("ConfigEntry.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "ConfigEntry.Apply", &arg, &out)) // Wait for the replica to converge. retry.Run(t, func(r *retry.R) { @@ -239,7 +240,7 @@ func TestReplication_ConfigEntries(t *testing.T) { } var out structs.ConfigEntryDeleteResponse - require.NoError(t, s1.RPC("ConfigEntry.Delete", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "ConfigEntry.Delete", &arg, &out)) } // Wait for the replica to converge. @@ -299,7 +300,7 @@ func TestReplication_ConfigEntries_GraphValidationErrorDuringReplication(t *test } out := false - require.NoError(t, s1.RPC("ConfigEntry.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "ConfigEntry.Apply", &arg, &out)) } // Try to join which should kick off replication. diff --git a/agent/consul/context.go b/agent/consul/context.go new file mode 100644 index 0000000000..ecf782911d --- /dev/null +++ b/agent/consul/context.go @@ -0,0 +1,20 @@ +package consul + +import ( + "context" + "net" +) + +type contextKeyRemoteAddr struct{} + +func ContextWithRemoteAddr(ctx context.Context, addr net.Addr) context.Context { + return context.WithValue(ctx, contextKeyRemoteAddr{}, addr) +} + +func RemoteAddrFromContext(ctx context.Context) (net.Addr, bool) { + v := ctx.Value(contextKeyRemoteAddr{}) + if v == nil { + return nil, false + } + return v.(net.Addr), true +} diff --git a/agent/consul/context_test.go b/agent/consul/context_test.go new file mode 100644 index 0000000000..b6e136fea7 --- /dev/null +++ b/agent/consul/context_test.go @@ -0,0 +1,27 @@ +package consul + +import ( + "context" + "net" + "net/netip" + "testing" +) + +func TestRemoteAddrFromContext_Found(t *testing.T) { + in := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:8080")) + ctx := ContextWithRemoteAddr(context.Background(), in) + out, ok := RemoteAddrFromContext(ctx) + if !ok { + t.Fatalf("cannot get remote addr from context") + } + if in != out { + t.Fatalf("expected %s but got %s instead", in, out) + } +} + +func TestRemoteAddrFromContext_NotFound(t *testing.T) { + out, ok := RemoteAddrFromContext(context.Background()) + if ok || out != nil { + t.Fatalf("expected remote addr %s to not be in context", out) + } +} diff --git a/agent/consul/federation_state_replication.go b/agent/consul/federation_state_replication.go index bfb433085a..69fb896940 100644 --- a/agent/consul/federation_state_replication.go +++ b/agent/consul/federation_state_replication.go @@ -55,7 +55,7 @@ func (r *FederationStateReplicator) fetchRemote(lastRemoteIndex uint64) (int, in } var response structs.IndexedFederationStates - if err := r.srv.RPC("FederationState.List", &req, &response); err != nil { + if err := r.srv.RPC(context.Background(), "FederationState.List", &req, &response); err != nil { return 0, nil, 0, err } diff --git a/agent/consul/federation_state_replication_test.go b/agent/consul/federation_state_replication_test.go index a7292e8913..b908aec96d 100644 --- a/agent/consul/federation_state_replication_test.go +++ b/agent/consul/federation_state_replication_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "os" "testing" @@ -70,7 +71,7 @@ func TestReplication_FederationStates(t *testing.T) { } out := false - require.NoError(t, s1.RPC("FederationState.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "FederationState.Apply", &arg, &out)) fedStateDCs = append(fedStateDCs, dc) } @@ -126,7 +127,7 @@ func TestReplication_FederationStates(t *testing.T) { } out := false - require.NoError(t, s1.RPC("FederationState.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "FederationState.Apply", &arg, &out)) } // Wait for the replica to converge. @@ -144,7 +145,7 @@ func TestReplication_FederationStates(t *testing.T) { } out := false - require.NoError(t, s1.RPC("FederationState.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "FederationState.Apply", &arg, &out)) } // Wait for the replica to converge. diff --git a/agent/consul/leader_connect_ca_test.go b/agent/consul/leader_connect_ca_test.go index ec23cc5433..8ffee0b67c 100644 --- a/agent/consul/leader_connect_ca_test.go +++ b/agent/consul/leader_connect_ca_test.go @@ -571,7 +571,7 @@ func TestCAManager_Initialize_Logging(t *testing.T) { // Wait til CA root is setup retry.Run(t, func(r *retry.R) { var out structs.IndexedCARoots - r.Check(s1.RPC("ConnectCA.Roots", structs.DCSpecificRequest{ + r.Check(s1.RPC(context.Background(), "ConnectCA.Roots", structs.DCSpecificRequest{ Datacenter: conf1.Datacenter, }, &out)) }) diff --git a/agent/consul/leader_connect_test.go b/agent/consul/leader_connect_test.go index bfa1fefc25..9c1218f9c2 100644 --- a/agent/consul/leader_connect_test.go +++ b/agent/consul/leader_connect_test.go @@ -614,7 +614,7 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) { } var reply interface{} - require.NoError(t, s1.RPC("ConnectCA.ConfigurationSet", args, &reply)) + require.NoError(t, s1.RPC(context.Background(), "ConnectCA.ConfigurationSet", args, &reply)) } var updatedRoot *structs.CARoot @@ -1010,7 +1010,7 @@ func TestCAManager_Initialize_TransitionFromPrimaryToSecondary(t *testing.T) { testrpc.WaitForLeader(t, s2.RPC, "dc2") args := structs.DCSpecificRequest{Datacenter: "dc2"} var dc2PrimaryRoots structs.IndexedCARoots - require.NoError(t, s2.RPC("ConnectCA.Roots", &args, &dc2PrimaryRoots)) + require.NoError(t, s2.RPC(context.Background(), "ConnectCA.Roots", &args, &dc2PrimaryRoots)) require.Len(t, dc2PrimaryRoots.Roots, 1) // Shutdown s2 and restart it with the dc1 as the primary @@ -1033,12 +1033,12 @@ func TestCAManager_Initialize_TransitionFromPrimaryToSecondary(t *testing.T) { retry.Run(t, func(r *retry.R) { args = structs.DCSpecificRequest{Datacenter: "dc1"} var dc1Roots structs.IndexedCARoots - require.NoError(r, s1.RPC("ConnectCA.Roots", &args, &dc1Roots)) + require.NoError(r, s1.RPC(context.Background(), "ConnectCA.Roots", &args, &dc1Roots)) require.Len(r, dc1Roots.Roots, 1) args = structs.DCSpecificRequest{Datacenter: "dc2"} var dc2SecondaryRoots structs.IndexedCARoots - require.NoError(r, s3.RPC("ConnectCA.Roots", &args, &dc2SecondaryRoots)) + require.NoError(r, s3.RPC(context.Background(), "ConnectCA.Roots", &args, &dc2SecondaryRoots)) // dc2's TrustDomain should have changed to the primary's require.Equal(r, dc2SecondaryRoots.TrustDomain, dc1Roots.TrustDomain) @@ -1191,7 +1191,7 @@ func getTestRoots(s *Server, datacenter string) (*structs.IndexedCARoots, *struc Datacenter: datacenter, } var rootList structs.IndexedCARoots - if err := s.RPC("ConnectCA.Roots", rootReq, &rootList); err != nil { + if err := s.RPC(context.Background(), "ConnectCA.Roots", rootReq, &rootList); err != nil { return nil, nil, err } @@ -1586,7 +1586,7 @@ func TestCAManager_Initialize_Vault_BadCAConfigDoesNotPreventLeaderEstablishment var reply interface{} retry.Run(t, func(r *retry.R) { - require.NoError(r, s1.RPC("ConnectCA.ConfigurationSet", args, &reply)) + require.NoError(r, s1.RPC(context.Background(), "ConnectCA.ConfigurationSet", args, &reply)) }) } @@ -1628,7 +1628,7 @@ func TestCAManager_Initialize_BadCAConfigDoesNotPreventLeaderEstablishment(t *te var reply interface{} retry.Run(t, func(r *retry.R) { - require.NoError(r, s1.RPC("ConnectCA.ConfigurationSet", args, &reply)) + require.NoError(r, s1.RPC(context.Background(), "ConnectCA.ConfigurationSet", args, &reply)) }) } diff --git a/agent/consul/leader_federation_state_ae_test.go b/agent/consul/leader_federation_state_ae_test.go index d7f6d108f1..597a927530 100644 --- a/agent/consul/leader_federation_state_ae_test.go +++ b/agent/consul/leader_federation_state_ae_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "os" "testing" "time" @@ -170,7 +171,7 @@ func TestLeader_FederationStateAntiEntropy_BlockingQuery(t *testing.T) { } out := false - require.NoError(t, s1.RPC("FederationState.Apply", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "FederationState.Apply", &arg, &out)) } makeGateways := func(t *testing.T, csn structs.CheckServiceNode) { @@ -185,7 +186,7 @@ func TestLeader_FederationStateAntiEntropy_BlockingQuery(t *testing.T) { Checks: csn.Checks, } var out struct{} - require.NoError(t, s2.RPC("Catalog.Register", &arg, &out)) + require.NoError(t, s2.RPC(context.Background(), "Catalog.Register", &arg, &out)) } type result struct { diff --git a/agent/consul/leader_intentions_test.go b/agent/consul/leader_intentions_test.go index b431567fc6..e0dcb8b3d1 100644 --- a/agent/consul/leader_intentions_test.go +++ b/agent/consul/leader_intentions_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "os" "strings" "testing" @@ -140,7 +141,7 @@ func TestLeader_ReplicateIntentions(t *testing.T) { IntentionID: ixn.Intention.ID, } var resp structs.IndexedIntentions - require.NoError(r, s2.RPC("Intention.Get", req, &resp), "ID=%q", ixn.Intention.ID) + require.NoError(r, s2.RPC(context.Background(), "Intention.Get", req, &resp), "ID=%q", ixn.Intention.ID) require.Len(r, resp.Intentions, 1) actual := resp.Intentions[0] @@ -171,7 +172,7 @@ func TestLeader_ReplicateIntentions(t *testing.T) { IntentionID: ixn.Intention.ID, } - require.NoError(r, s2.RPC("Intention.Get", req, &resp), "ID=%q", ixn.Intention.ID) + require.NoError(r, s2.RPC(context.Background(), "Intention.Get", req, &resp), "ID=%q", ixn.Intention.ID) require.Len(r, resp.Intentions, 1) require.Equal(r, "*", resp.Intentions[0].SourceName) }) @@ -205,7 +206,7 @@ func TestLeader_ReplicateIntentions(t *testing.T) { IntentionID: ixn.Intention.ID, } var resp structs.IndexedIntentions - err := s2.RPC("Intention.Get", req, &resp) + err := s2.RPC(context.Background(), "Intention.Get", req, &resp) require.Error(r, err) if !strings.Contains(err.Error(), ErrIntentionNotFound.Error()) { r.Fatalf("expected intention not found, got: %v", err) diff --git a/agent/consul/leader_test.go b/agent/consul/leader_test.go index 35bc924b7a..33094f59d5 100644 --- a/agent/consul/leader_test.go +++ b/agent/consul/leader_test.go @@ -2,6 +2,7 @@ package consul import ( "bufio" + "context" "encoding/json" "fmt" "io" @@ -587,7 +588,7 @@ func TestLeader_Reconcile_ReapMember(t *testing.T) { }, } var out struct{} - if err := s1.RPC("Catalog.Register", &dead, &out); err != nil { + if err := s1.RPC(context.Background(), "Catalog.Register", &dead, &out); err != nil { t.Fatalf("err: %v", err) } @@ -701,7 +702,7 @@ func TestLeader_Reconcile_Races(t *testing.T) { }, } var out struct{} - if err := s1.RPC("Catalog.Register", &req, &out); err != nil { + if err := s1.RPC(context.Background(), "Catalog.Register", &req, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1592,7 +1593,7 @@ func TestDatacenterSupportsFederationStates(t *testing.T) { } var out struct{} - require.NoError(t, srv.RPC("Catalog.Register", &arg, &out)) + require.NoError(t, srv.RPC(context.Background(), "Catalog.Register", &arg, &out)) } t.Run("one node primary with old version", func(t *testing.T) { @@ -1644,7 +1645,7 @@ func TestDatacenterSupportsFederationStates(t *testing.T) { } var out structs.FederationStateResponse - require.NoError(r, s1.RPC("FederationState.Get", &arg, &out)) + require.NoError(r, s1.RPC(context.Background(), "FederationState.Get", &arg, &out)) require.NotNil(r, out.State) require.Len(r, out.State.MeshGateways, 1) }) @@ -1749,7 +1750,7 @@ func TestDatacenterSupportsFederationStates(t *testing.T) { } var out structs.IndexedFederationStates - require.NoError(r, s1.RPC("FederationState.List", &arg, &out)) + require.NoError(r, s1.RPC(context.Background(), "FederationState.List", &arg, &out)) require.Len(r, out.States, 1) require.Len(r, out.States[0].MeshGateways, 1) }) @@ -1805,7 +1806,7 @@ func TestDatacenterSupportsFederationStates(t *testing.T) { } var out structs.IndexedFederationStates - require.NoError(r, s1.RPC("FederationState.List", &arg, &out)) + require.NoError(r, s1.RPC(context.Background(), "FederationState.List", &arg, &out)) require.Len(r, out.States, 2) require.Len(r, out.States[0].MeshGateways, 1) require.Len(r, out.States[1].MeshGateways, 1) @@ -1818,7 +1819,7 @@ func TestDatacenterSupportsFederationStates(t *testing.T) { } var out structs.IndexedFederationStates - require.NoError(r, s1.RPC("FederationState.List", &arg, &out)) + require.NoError(r, s1.RPC(context.Background(), "FederationState.List", &arg, &out)) require.Len(r, out.States, 2) require.Len(r, out.States[0].MeshGateways, 1) require.Len(r, out.States[1].MeshGateways, 1) @@ -1905,7 +1906,7 @@ func TestDatacenterSupportsIntentionsAsConfigEntries(t *testing.T) { } var id string - return srv.RPC("Intention.Apply", &arg, &id) + return srv.RPC(context.Background(), "Intention.Apply", &arg, &id) } getConfigEntry := func(srv *Server, dc, kind, name string) (structs.ConfigEntry, error) { @@ -1915,7 +1916,7 @@ func TestDatacenterSupportsIntentionsAsConfigEntries(t *testing.T) { Name: name, } var reply structs.ConfigEntryResponse - if err := srv.RPC("ConfigEntry.Get", &arg, &reply); err != nil { + if err := srv.RPC(context.Background(), "ConfigEntry.Get", &arg, &reply); err != nil { return nil, err } return reply.Entry, nil diff --git a/agent/consul/rpc_test.go b/agent/consul/rpc_test.go index 01ea34961b..2dce38ed0e 100644 --- a/agent/consul/rpc_test.go +++ b/agent/consul/rpc_test.go @@ -1158,7 +1158,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var out struct{} - require.NoError(t, s1.RPC("Catalog.Register", &req, &out)) + require.NoError(t, s1.RPC(context.Background(), "Catalog.Register", &req, &out)) }) var conn *grpc.ClientConn diff --git a/agent/consul/server.go b/agent/consul/server.go index c37f415e50..cf16f5d012 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -1494,10 +1494,11 @@ func (s *Server) AgentEnterpriseMeta() *acl.EnterpriseMeta { // inmemCodec is used to do an RPC call without going over a network type inmemCodec struct { - method string - args interface{} - reply interface{} - err error + method string + args interface{} + reply interface{} + err error + sourceAddr net.Addr } func (i *inmemCodec) ReadRequestHeader(req *rpc.Request) error { @@ -1523,16 +1524,22 @@ func (i *inmemCodec) WriteResponse(resp *rpc.Response, reply interface{}) error return nil } +func (i *inmemCodec) SourceAddr() net.Addr { + return i.sourceAddr +} + func (i *inmemCodec) Close() error { return nil } // RPC is used to make a local RPC call -func (s *Server) RPC(method string, args interface{}, reply interface{}) error { +func (s *Server) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { + remoteAddr, _ := RemoteAddrFromContext(ctx) codec := &inmemCodec{ - method: method, - args: args, - reply: reply, + method: method, + args: args, + reply: reply, + sourceAddr: remoteAddr, } // Enforce the RPC limit. diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 4cb7b5d97e..759a682f28 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -862,7 +862,7 @@ func TestServer_JoinWAN_viaMeshGateway(t *testing.T) { } var out struct{} - require.NoError(t, s1.RPC("Catalog.Register", &arg, &out)) + require.NoError(t, s1.RPC(context.Background(), "Catalog.Register", &arg, &out)) } // Wait for it to make it into the gateway locator. @@ -917,7 +917,7 @@ func TestServer_JoinWAN_viaMeshGateway(t *testing.T) { } var out struct{} - require.NoError(t, s2.RPC("Catalog.Register", &arg, &out)) + require.NoError(t, s2.RPC(context.Background(), "Catalog.Register", &arg, &out)) } { arg := structs.RegisterRequest{ @@ -934,7 +934,7 @@ func TestServer_JoinWAN_viaMeshGateway(t *testing.T) { } var out struct{} - require.NoError(t, s3.RPC("Catalog.Register", &arg, &out)) + require.NoError(t, s3.RPC(context.Background(), "Catalog.Register", &arg, &out)) } // Wait for it to make it into the gateway locator in dc2 and then for @@ -988,7 +988,7 @@ func TestServer_JoinWAN_viaMeshGateway(t *testing.T) { Datacenter: dstDC, } var out structs.IndexedNodes - require.NoError(t, srv.RPC("Catalog.ListNodes", &arg, &out)) + require.NoError(t, srv.RPC(context.Background(), "Catalog.ListNodes", &arg, &out)) require.Len(t, out.Nodes, 1) node := out.Nodes[0] require.Equal(t, dstDC, node.Datacenter) @@ -1195,7 +1195,7 @@ func TestServer_RPC(t *testing.T) { defer s1.Shutdown() var out struct{} - if err := s1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := s1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1248,7 +1248,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) { t.Cleanup(func() { s1.Shutdown() }) var out struct{} - if err := s1.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := s1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1289,7 +1289,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) { } var out struct{} - if err := s2.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := s2.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1398,7 +1398,7 @@ func TestServer_RPC_MetricsIntercept(t *testing.T) { // asserts t.Run("test happy path for metrics interceptor", func(t *testing.T) { var out struct{} - if err := s.RPC("Status.Ping", struct{}{}, &out); err != nil { + if err := s.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != nil { t.Fatalf("err: %v", err) } @@ -2018,7 +2018,7 @@ func TestServer_RPC_RateLimit(t *testing.T) { retry.Run(t, func(r *retry.R) { var out struct{} - if err := s1.RPC("Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded { + if err := s1.RPC(context.Background(), "Status.Ping", struct{}{}, &out); err != structs.ErrRPCRateExceeded { r.Fatalf("err: %v", err) } }) diff --git a/agent/consul/session_ttl_test.go b/agent/consul/session_ttl_test.go index c380962e41..2ebff6c092 100644 --- a/agent/consul/session_ttl_test.go +++ b/agent/consul/session_ttl_test.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "os" "strings" @@ -321,7 +322,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) { Address: "127.0.0.1", } var out struct{} - if err := s1.RPC("Catalog.Register", &node, &out); err != nil { + if err := s1.RPC(context.Background(), "Catalog.Register", &node, &out); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go index a4d1134e18..26bd3f90b1 100644 --- a/agent/consul/subscribe_backend_test.go +++ b/agent/consul/subscribe_backend_test.go @@ -58,7 +58,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { }, } var out struct{} - require.NoError(t, server.RPC("Catalog.Register", &req, &out)) + require.NoError(t, server.RPC(context.Background(), "Catalog.Register", &req, &out)) } // Start a Subscribe call to our streaming endpoint from the client. @@ -301,7 +301,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T }, } var out struct{} - require.NoError(t, server.RPC("Catalog.Register", &req, &out)) + require.NoError(t, server.RPC(context.Background(), "Catalog.Register", &req, &out)) } // Start background writer @@ -326,7 +326,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T return } var out struct{} - require.NoError(t, server.RPC("Catalog.Register", &req, &out)) + require.NoError(t, server.RPC(context.Background(), "Catalog.Register", &req, &out)) req.Service.Port++ if req.Service.Port > 100 { return diff --git a/agent/consul/txn_endpoint_test.go b/agent/consul/txn_endpoint_test.go index 62aca0f7aa..1cd33f18d1 100644 --- a/agent/consul/txn_endpoint_test.go +++ b/agent/consul/txn_endpoint_test.go @@ -2,6 +2,7 @@ package consul import ( "bytes" + "context" "os" "strings" "testing" @@ -551,7 +552,7 @@ func TestTxn_Apply_ACLDeny(t *testing.T) { }, } var out structs.TxnResponse - if err := s1.RPC("Txn.Apply", &arg, &out); err != nil { + if err := s1.RPC(context.Background(), "Txn.Apply", &arg, &out); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/coordinate_endpoint.go b/agent/coordinate_endpoint.go index bb4f328ede..a238f52ae6 100644 --- a/agent/coordinate_endpoint.go +++ b/agent/coordinate_endpoint.go @@ -47,7 +47,7 @@ func (s *HTTPHandlers) CoordinateDatacenters(resp http.ResponseWriter, req *http } var out []structs.DatacenterMap - if err := s.agent.RPC("Coordinate.ListDatacenters", struct{}{}, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Coordinate.ListDatacenters", struct{}{}, &out); err != nil { for i := range out { sort.Sort(&sorter{out[i].Coordinates}) } @@ -85,7 +85,7 @@ func (s *HTTPHandlers) CoordinateNodes(resp http.ResponseWriter, req *http.Reque var out structs.IndexedCoordinates defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("Coordinate.ListNodes", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Coordinate.ListNodes", &args, &out); err != nil { sort.Sort(&sorter{out.Coordinates}) return nil, err } @@ -111,7 +111,7 @@ func (s *HTTPHandlers) CoordinateNode(resp http.ResponseWriter, req *http.Reques var out structs.IndexedCoordinates defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("Coordinate.Node", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Coordinate.Node", &args, &out); err != nil { return nil, err } @@ -164,7 +164,7 @@ func (s *HTTPHandlers) CoordinateUpdate(resp http.ResponseWriter, req *http.Requ } var reply struct{} - if err := s.agent.RPC("Coordinate.Update", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Coordinate.Update", &args, &reply); err != nil { return nil, err } diff --git a/agent/coordinate_endpoint_test.go b/agent/coordinate_endpoint_test.go index 331451641f..38fc97cec4 100644 --- a/agent/coordinate_endpoint_test.go +++ b/agent/coordinate_endpoint_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -130,7 +131,7 @@ func TestCoordinate_Nodes(t *testing.T) { Address: "127.0.0.1", } var reply struct{} - if err := a.RPC("Catalog.Register", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", &req, &reply); err != nil { t.Fatalf("err: %s", err) } } @@ -143,7 +144,7 @@ func TestCoordinate_Nodes(t *testing.T) { Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } var out struct{} - if err := a.RPC("Coordinate.Update", &arg1, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg1, &out); err != nil { t.Fatalf("err: %v", err) } @@ -152,7 +153,7 @@ func TestCoordinate_Nodes(t *testing.T) { Node: "bar", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - if err := a.RPC("Coordinate.Update", &arg2, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg2, &out); err != nil { t.Fatalf("err: %v", err) } time.Sleep(300 * time.Millisecond) @@ -213,7 +214,7 @@ func TestCoordinate_Node(t *testing.T) { Address: "127.0.0.1", } var reply struct{} - if err := a.RPC("Catalog.Register", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", &req, &reply); err != nil { t.Fatalf("err: %s", err) } } @@ -226,7 +227,7 @@ func TestCoordinate_Node(t *testing.T) { Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } var out struct{} - if err := a.RPC("Coordinate.Update", &arg1, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg1, &out); err != nil { t.Fatalf("err: %v", err) } @@ -235,7 +236,7 @@ func TestCoordinate_Node(t *testing.T) { Node: "bar", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - if err := a.RPC("Coordinate.Update", &arg2, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg2, &out); err != nil { t.Fatalf("err: %v", err) } time.Sleep(300 * time.Millisecond) @@ -276,7 +277,7 @@ func TestCoordinate_Update(t *testing.T) { Address: "127.0.0.1", } var reply struct{} - if err := a.RPC("Catalog.Register", ®, &reply); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", ®, &reply); err != nil { t.Fatalf("err: %s", err) } @@ -304,7 +305,7 @@ func TestCoordinate_Update(t *testing.T) { // Query back and check the coordinates are present. args := structs.NodeSpecificRequest{Node: "foo", Datacenter: "dc1"} var coords structs.IndexedCoordinates - if err := a.RPC("Coordinate.Node", &args, &coords); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Node", &args, &coords); err != nil { t.Fatalf("err: %s", err) } diff --git a/agent/delegate_mock_test.go b/agent/delegate_mock_test.go index 23b93b829a..67ba0abd8d 100644 --- a/agent/delegate_mock_test.go +++ b/agent/delegate_mock_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "io" "github.com/hashicorp/serf/serf" @@ -53,7 +54,7 @@ func (m *delegateMock) ResolveTokenAndDefaultMeta(token string, entMeta *acl.Ent return ret.Get(0).(resolver.Result), ret.Error(1) } -func (m *delegateMock) RPC(method string, args interface{}, reply interface{}) error { +func (m *delegateMock) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { return m.Called(method, args, reply).Error(0) } diff --git a/agent/discovery_chain_endpoint.go b/agent/discovery_chain_endpoint.go index 4f762faf04..1b39a90966 100644 --- a/agent/discovery_chain_endpoint.go +++ b/agent/discovery_chain_endpoint.go @@ -74,7 +74,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re out = *reply } else { RETRY_ONCE: - if err := s.agent.RPC("DiscoveryChain.Get", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "DiscoveryChain.Get", &args, &out); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < out.LastContact { diff --git a/agent/discovery_chain_endpoint_test.go b/agent/discovery_chain_endpoint_test.go index 42c0825916..f83fa49293 100644 --- a/agent/discovery_chain_endpoint_test.go +++ b/agent/discovery_chain_endpoint_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "net/http" "net/http/httptest" "reflect" @@ -216,7 +217,7 @@ func TestDiscoveryChainRead(t *testing.T) { { // Now create one config entry. out := false - require.NoError(t, a.RPC("ConfigEntry.Apply", &structs.ConfigEntryRequest{ + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &structs.ConfigEntryRequest{ Datacenter: "dc1", Entry: &structs.ServiceResolverConfigEntry{ Kind: structs.ServiceResolver, diff --git a/agent/dns.go b/agent/dns.go index b35f80c630..3dce6410d7 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -421,7 +421,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { // TODO: Replace ListNodes with an internal RPC that can do the filter // server side to avoid transferring the entire node list. - if err := d.agent.RPC("Catalog.ListNodes", &args, &out); err == nil { + if err := d.agent.RPC(context.Background(), "Catalog.ListNodes", &args, &out); err == nil { for _, n := range out.Nodes { lookup := serviceLookup{ // Peering PTR lookups are currently not supported, so we don't @@ -457,7 +457,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { } var sout structs.IndexedServiceNodes - if err := d.agent.RPC("Catalog.ServiceNodes", &sargs, &sout); err == nil { + if err := d.agent.RPC(context.Background(), "Catalog.ServiceNodes", &sargs, &sout); err == nil { for _, n := range sout.ServiceNodes { if n.ServiceAddress == serviceAddress { ptr := &dns.PTR{ @@ -872,7 +872,7 @@ func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursi } var out string - if err := d.agent.RPC("Catalog.VirtualIPForService", &args, &out); err != nil { + if err := d.agent.RPC(context.Background(), "Catalog.VirtualIPForService", &args, &out); err != nil { return err } if out != "" { @@ -1135,7 +1135,7 @@ RPC: } out = *reply } else { - if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { + if err := d.agent.RPC(context.Background(), "Catalog.NodeServices", &args, &out); err != nil { return nil, err } } @@ -1599,7 +1599,7 @@ RPC: out = *reply } else { - if err := d.agent.RPC("PreparedQuery.Execute", &args, &out); err != nil { + if err := d.agent.RPC(context.Background(), "PreparedQuery.Execute", &args, &out); err != nil { return nil, err } } diff --git a/agent/dns_oss_test.go b/agent/dns_oss_test.go index a394aa8718..c3cf264c68 100644 --- a/agent/dns_oss_test.go +++ b/agent/dns_oss_test.go @@ -4,6 +4,7 @@ package agent import ( + "context" "testing" "github.com/hashicorp/consul/acl" @@ -69,7 +70,7 @@ func TestDNS_OSS_PeeredServices(t *testing.T) { } t.Run("srv-with-addr-reply", func(t *testing.T) { - require.NoError(t, a.RPC("Catalog.Register", makeReq(), &struct{}{})) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", makeReq(), &struct{}{})) q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV) require.Len(t, q.Answer, 1) require.Len(t, q.Extra, 1) @@ -89,7 +90,7 @@ func TestDNS_OSS_PeeredServices(t *testing.T) { req := makeReq() // Clear service address to trigger node response req.Service.Address = "" - require.NoError(t, a.RPC("Catalog.Register", req, &struct{}{})) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", req, &struct{}{})) q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV) require.Len(t, q.Answer, 1) require.Len(t, q.Extra, 1) @@ -110,7 +111,7 @@ func TestDNS_OSS_PeeredServices(t *testing.T) { // Set non-ip address to trigger external response req.Address = "localhost" req.Service.Address = "" - require.NoError(t, a.RPC("Catalog.Register", req, &struct{}{})) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", req, &struct{}{})) q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeSRV) require.Len(t, q.Answer, 1) require.Len(t, q.Extra, 0) @@ -118,7 +119,7 @@ func TestDNS_OSS_PeeredServices(t *testing.T) { }) t.Run("a-reply", func(t *testing.T) { - require.NoError(t, a.RPC("Catalog.Register", makeReq(), &struct{}{})) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", makeReq(), &struct{}{})) q := dnsQuery(t, "web-proxy.service.peer1.peer.consul.", dns.TypeA) require.Len(t, q.Answer, 1) require.Len(t, q.Extra, 0) diff --git a/agent/dns_test.go b/agent/dns_test.go index 7b1512fbf6..501564c66c 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "errors" "fmt" "math" @@ -177,7 +178,7 @@ func TestDNS_Over_TCP(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -239,7 +240,7 @@ func TestDNS_NodeLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -332,7 +333,7 @@ func TestDNS_CaseInsensitiveNodeLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -368,7 +369,7 @@ func TestDNS_NodeLookup_PeriodName(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -412,7 +413,7 @@ func TestDNS_NodeLookup_AAAA(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -541,7 +542,7 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -596,7 +597,7 @@ func TestDNS_NodeLookup_TXT(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -646,7 +647,7 @@ func TestDNS_NodeLookup_TXT_DontSuppress(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -695,7 +696,7 @@ func TestDNS_NodeLookup_ANY(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -740,7 +741,7 @@ func TestDNS_NodeLookup_ANY_DontSuppressTXT(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -785,7 +786,7 @@ func TestDNS_NodeLookup_A_SuppressTXT(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) m := new(dns.Msg) m.SetQuestion("bar.node.consul.", dns.TypeA) @@ -824,7 +825,7 @@ func TestDNS_EDNS0(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -874,7 +875,7 @@ func TestDNS_EDNS0_ECS(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Register an equivalent prepared query. @@ -890,7 +891,7 @@ func TestDNS_EDNS0_ECS(t *testing.T) { }, }, } - require.NoError(t, a.RPC("PreparedQuery.Apply", args, &id)) + require.NoError(t, a.RPC(context.Background(), "PreparedQuery.Apply", args, &id)) } cases := []struct { @@ -963,7 +964,7 @@ func TestDNS_ReverseLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1009,7 +1010,7 @@ func TestDNS_ReverseLookup_CustomDomain(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1053,7 +1054,7 @@ func TestDNS_ReverseLookup_IPV6(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1104,7 +1105,7 @@ func TestDNS_ServiceReverseLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1156,7 +1157,7 @@ func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1210,7 +1211,7 @@ func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1299,7 +1300,7 @@ func TestDNS_ServiceReverseLookupNodeAddress(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1350,7 +1351,7 @@ func TestDNS_ServiceLookupNoMultiCNAME(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Register a second node node with the same service. @@ -1367,7 +1368,7 @@ func TestDNS_ServiceLookupNoMultiCNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1409,7 +1410,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Register a second node node with the same service. @@ -1426,7 +1427,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1471,7 +1472,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Register a second node node with the same service. @@ -1488,7 +1489,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1507,7 +1508,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1549,7 +1550,7 @@ func TestDNS_ServiceLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -1567,7 +1568,7 @@ func TestDNS_ServiceLookup(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -1678,7 +1679,7 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1730,7 +1731,7 @@ func TestDNS_ConnectServiceLookup(t *testing.T) { args.Service.Address = "" args.Service.Port = 12345 var out struct{} - require.Nil(t, a.RPC("Catalog.Register", args, &out)) + require.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Look up the service @@ -1791,7 +1792,7 @@ func TestDNS_VirtualIPLookup(t *testing.T) { run := func(t *testing.T, tc testCase) { var out struct{} - require.Nil(t, a.RPC("Catalog.Register", tc.reg, &out)) + require.Nil(t, a.RPC(context.Background(), "Catalog.Register", tc.reg, &out)) m := new(dns.Msg) m.SetQuestion(tc.question, dns.TypeA) @@ -1869,7 +1870,7 @@ func TestDNS_IngressServiceLookup(t *testing.T) { { args := structs.TestRegisterIngressGateway(t) var out struct{} - require.Nil(t, a.RPC("Catalog.Register", args, &out)) + require.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Register db service @@ -1886,7 +1887,7 @@ func TestDNS_IngressServiceLookup(t *testing.T) { } var out struct{} - require.Nil(t, a.RPC("Catalog.Register", args, &out)) + require.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Register proxy-defaults with 'http' protocol @@ -1904,7 +1905,7 @@ func TestDNS_IngressServiceLookup(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var out bool - require.Nil(t, a.RPC("ConfigEntry.Apply", req, &out)) + require.Nil(t, a.RPC(context.Background(), "ConfigEntry.Apply", req, &out)) require.True(t, out) } @@ -1931,7 +1932,7 @@ func TestDNS_IngressServiceLookup(t *testing.T) { Entry: args, } var out bool - require.Nil(t, a.RPC("ConfigEntry.Apply", req, &out)) + require.Nil(t, a.RPC(context.Background(), "ConfigEntry.Apply", req, &out)) require.True(t, out) } @@ -1984,7 +1985,7 @@ func TestDNS_ExternalServiceLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2051,7 +2052,7 @@ func TestDNS_InifiniteRecursion(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2108,7 +2109,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2126,7 +2127,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2395,7 +2396,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2413,7 +2414,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2431,7 +2432,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2526,7 +2527,7 @@ func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2544,7 +2545,7 @@ func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -2625,7 +2626,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2643,7 +2644,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -2737,7 +2738,7 @@ func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2758,7 +2759,7 @@ func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -2844,7 +2845,7 @@ func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2862,7 +2863,7 @@ func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -2943,7 +2944,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -2961,7 +2962,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -3062,7 +3063,7 @@ func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { }, }, } - require.NoError(t, a2.RPC("PreparedQuery.Apply", args, &id)) + require.NoError(t, a2.RPC(context.Background(), "PreparedQuery.Apply", args, &id)) } type testCase struct { @@ -3184,7 +3185,7 @@ func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { } var out struct{} - require.NoError(t, a2.RPC("Catalog.Register", args, &out)) + require.NoError(t, a2.RPC(context.Background(), "Catalog.Register", args, &out)) }) // Look up the SRV record via service and prepared query. @@ -3255,7 +3256,7 @@ func TestDNS_Lookup_TaggedIPAddresses(t *testing.T) { }, }, } - require.NoError(t, a.RPC("PreparedQuery.Apply", args, &id)) + require.NoError(t, a.RPC(context.Background(), "PreparedQuery.Apply", args, &id)) } type testCase struct { @@ -3341,7 +3342,7 @@ func TestDNS_Lookup_TaggedIPAddresses(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Look up the SRV record via service and prepared query. questions := []string{ @@ -3467,7 +3468,7 @@ func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -3485,7 +3486,7 @@ func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -3547,7 +3548,7 @@ func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -3636,7 +3637,7 @@ func TestDNS_PreparedQueryNearIPEDNS(t *testing.T) { } var out struct{} - err := a.RPC("Catalog.Register", args, &out) + err := a.RPC(context.Background(), "Catalog.Register", args, &out) require.NoError(t, err) // Send coordinate updates @@ -3645,7 +3646,7 @@ func TestDNS_PreparedQueryNearIPEDNS(t *testing.T) { Node: cfg.name, Coord: cfg.coord, } - err = a.RPC("Coordinate.Update", &coordArgs, &out) + err = a.RPC(context.Background(), "Coordinate.Update", &coordArgs, &out) require.NoError(t, err) added += 1 @@ -3662,7 +3663,7 @@ func TestDNS_PreparedQueryNearIPEDNS(t *testing.T) { } var out struct{} - err := a.RPC("Catalog.Register", args, &out) + err := a.RPC(context.Background(), "Catalog.Register", args, &out) require.NoError(t, err) // Send coordinate updates for a few nodes. @@ -3671,7 +3672,7 @@ func TestDNS_PreparedQueryNearIPEDNS(t *testing.T) { Node: "bar", Coord: ipCoord, } - err = a.RPC("Coordinate.Update", &coordArgs, &out) + err = a.RPC(context.Background(), "Coordinate.Update", &coordArgs, &out) require.NoError(t, err) } @@ -3690,7 +3691,7 @@ func TestDNS_PreparedQueryNearIPEDNS(t *testing.T) { } var id string - err := a.RPC("PreparedQuery.Apply", args, &id) + err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id) require.NoError(t, err) } retry.Run(t, func(r *retry.R) { @@ -3767,7 +3768,7 @@ func TestDNS_PreparedQueryNearIP(t *testing.T) { } var out struct{} - err := a.RPC("Catalog.Register", args, &out) + err := a.RPC(context.Background(), "Catalog.Register", args, &out) require.NoError(t, err) // Send coordinate updates @@ -3776,7 +3777,7 @@ func TestDNS_PreparedQueryNearIP(t *testing.T) { Node: cfg.name, Coord: cfg.coord, } - err = a.RPC("Coordinate.Update", &coordArgs, &out) + err = a.RPC(context.Background(), "Coordinate.Update", &coordArgs, &out) require.NoError(t, err) added += 1 @@ -3793,7 +3794,7 @@ func TestDNS_PreparedQueryNearIP(t *testing.T) { } var out struct{} - err := a.RPC("Catalog.Register", args, &out) + err := a.RPC(context.Background(), "Catalog.Register", args, &out) require.NoError(t, err) // Send coordinate updates for a few nodes. @@ -3802,7 +3803,7 @@ func TestDNS_PreparedQueryNearIP(t *testing.T) { Node: "bar", Coord: ipCoord, } - err = a.RPC("Coordinate.Update", &coordArgs, &out) + err = a.RPC(context.Background(), "Coordinate.Update", &coordArgs, &out) require.NoError(t, err) } @@ -3821,7 +3822,7 @@ func TestDNS_PreparedQueryNearIP(t *testing.T) { } var id string - err := a.RPC("PreparedQuery.Apply", args, &id) + err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id) require.NoError(t, err) } @@ -3874,7 +3875,7 @@ func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -3893,7 +3894,7 @@ func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { } var id string - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -3958,7 +3959,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -3973,7 +3974,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { Port: 12345, }, } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -3988,7 +3989,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { Port: 12346, }, } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4006,7 +4007,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -4065,7 +4066,7 @@ func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4080,7 +4081,7 @@ func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { Port: 12345, }, } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4095,7 +4096,7 @@ func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { Port: 12346, }, } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4113,7 +4114,7 @@ func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -4334,7 +4335,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4353,7 +4354,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { Status: api.HealthCritical, }, } - if err := a.RPC("Catalog.Register", args2, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args2, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4373,7 +4374,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { Status: api.HealthCritical, }, } - if err := a.RPC("Catalog.Register", args3, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args3, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4387,7 +4388,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { Port: 12345, }, } - if err := a.RPC("Catalog.Register", args4, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args4, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4407,7 +4408,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { Status: api.HealthWarning, }, } - if err := a.RPC("Catalog.Register", args5, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args5, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4425,7 +4426,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -4494,7 +4495,7 @@ func TestDNS_ServiceLookup_OnlyFailing(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4513,7 +4514,7 @@ func TestDNS_ServiceLookup_OnlyFailing(t *testing.T) { Status: api.HealthCritical, }, } - if err := a.RPC("Catalog.Register", args2, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args2, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4533,7 +4534,7 @@ func TestDNS_ServiceLookup_OnlyFailing(t *testing.T) { Status: api.HealthCritical, }, } - if err := a.RPC("Catalog.Register", args3, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args3, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4551,7 +4552,7 @@ func TestDNS_ServiceLookup_OnlyFailing(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -4616,7 +4617,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4637,7 +4638,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { }, } - if err := a.RPC("Catalog.Register", args2, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args2, &out); err != nil { t.Fatalf("err: %v", err) } @@ -4658,7 +4659,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { }, } - if err := a.RPC("Catalog.Register", args3, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args3, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4677,7 +4678,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -4752,7 +4753,7 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4770,7 +4771,7 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -4894,7 +4895,7 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -4912,7 +4913,7 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -5001,7 +5002,7 @@ func TestDNS_ServiceLookup_Truncate(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -5019,7 +5020,7 @@ func TestDNS_ServiceLookup_Truncate(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -5077,7 +5078,7 @@ func TestDNS_ServiceLookup_LargeResponses(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -5096,7 +5097,7 @@ func TestDNS_ServiceLookup_LargeResponses(t *testing.T) { }, } var id string - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -5185,7 +5186,7 @@ func testDNSServiceLookupResponseLimits(t *testing.T, answerLimit int, qType uin } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { return false, fmt.Errorf("err: %v", err) } } @@ -5202,7 +5203,7 @@ func testDNSServiceLookupResponseLimits(t *testing.T, answerLimit int, qType uin }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { return false, fmt.Errorf("err: %v", err) } } @@ -5281,7 +5282,7 @@ func checkDNSService( } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } var id string { @@ -5296,7 +5297,7 @@ func checkDNSService( }, } - require.NoError(t, a.RPC("PreparedQuery.Apply", args, &id)) + require.NoError(t, a.RPC(context.Background(), "PreparedQuery.Apply", args, &id)) } // Look up the service directly and via prepared query. @@ -5520,7 +5521,7 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -5538,7 +5539,7 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -5622,7 +5623,7 @@ func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -5640,7 +5641,7 @@ func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -5723,7 +5724,7 @@ func TestDNS_NodeLookup_TTL(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -5757,7 +5758,7 @@ func TestDNS_NodeLookup_TTL(t *testing.T) { Node: "bar", Address: "::4242:4242", } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -5791,7 +5792,7 @@ func TestDNS_NodeLookup_TTL(t *testing.T) { Node: "google", Address: "www.google.com", } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -5855,7 +5856,7 @@ func TestDNS_ServiceLookup_TTL(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -5938,7 +5939,7 @@ func TestDNS_PreparedQuery_TTL(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } // Register prepared query without TTL and with TTL @@ -5955,7 +5956,7 @@ func TestDNS_PreparedQuery_TTL(t *testing.T) { } var id string - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } queryTTL := fmt.Sprintf("%s-ttl", service) @@ -5973,7 +5974,7 @@ func TestDNS_PreparedQuery_TTL(t *testing.T) { }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -6077,7 +6078,7 @@ func TestDNS_PreparedQuery_Failover(t *testing.T) { } var out struct{} - if err := a2.RPC("Catalog.Register", args, &out); err != nil { + if err := a2.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { r.Fatalf("err: %v", err) } }) @@ -6098,7 +6099,7 @@ func TestDNS_PreparedQuery_Failover(t *testing.T) { }, } var id string - if err := a1.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a1.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -6165,7 +6166,7 @@ func TestDNS_ServiceLookup_SRV_RFC(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -6244,7 +6245,7 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -6344,7 +6345,7 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -6387,7 +6388,7 @@ func TestDNS_ServiceLookup_MetaTXT(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -6438,7 +6439,7 @@ func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -6754,7 +6755,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -6768,7 +6769,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { }, } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -6787,7 +6788,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { } var id string - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } @@ -6802,7 +6803,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -6898,7 +6899,7 @@ func TestDNS_AltDomains_Service(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -7894,7 +7895,7 @@ func TestDNS_Compression_Query(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -7912,7 +7913,7 @@ func TestDNS_Compression_Query(t *testing.T) { }, }, } - if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + if err := a.RPC(context.Background(), "PreparedQuery.Apply", args, &id); err != nil { t.Fatalf("err: %v", err) } } @@ -7978,7 +7979,7 @@ func TestDNS_Compression_ReverseLookup(t *testing.T) { Address: "127.0.0.2", } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/federation_state_endpoint.go b/agent/federation_state_endpoint.go index 4b7757ef82..bac8807ec4 100644 --- a/agent/federation_state_endpoint.go +++ b/agent/federation_state_endpoint.go @@ -23,7 +23,7 @@ func (s *HTTPHandlers) FederationStateGet(resp http.ResponseWriter, req *http.Re var out structs.FederationStateResponse defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("FederationState.Get", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "FederationState.Get", &args, &out); err != nil { return nil, err } @@ -48,7 +48,7 @@ func (s *HTTPHandlers) FederationStateList(resp http.ResponseWriter, req *http.R var out structs.IndexedFederationStates defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("FederationState.List", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "FederationState.List", &args, &out); err != nil { return nil, err } @@ -73,7 +73,7 @@ func (s *HTTPHandlers) FederationStateListMeshGateways(resp http.ResponseWriter, var out structs.DatacenterIndexedCheckServiceNodes defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("FederationState.ListMeshGateways", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "FederationState.ListMeshGateways", &args, &out); err != nil { return nil, err } diff --git a/agent/health_endpoint.go b/agent/health_endpoint.go index 6ea64528b0..8060d93bd5 100644 --- a/agent/health_endpoint.go +++ b/agent/health_endpoint.go @@ -38,7 +38,7 @@ func (s *HTTPHandlers) HealthChecksInState(resp http.ResponseWriter, req *http.R var out structs.IndexedHealthChecks defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Health.ChecksInState", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Health.ChecksInState", &args, &out); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < out.LastContact { @@ -82,7 +82,7 @@ func (s *HTTPHandlers) HealthNodeChecks(resp http.ResponseWriter, req *http.Requ var out structs.IndexedHealthChecks defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Health.NodeChecks", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Health.NodeChecks", &args, &out); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < out.LastContact { @@ -128,7 +128,7 @@ func (s *HTTPHandlers) HealthServiceChecks(resp http.ResponseWriter, req *http.R var out structs.IndexedHealthChecks defer setMeta(resp, &out.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("Health.ServiceChecks", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Health.ServiceChecks", &args, &out); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < out.LastContact { diff --git a/agent/health_endpoint_test.go b/agent/health_endpoint_test.go index b822bdde82..8217fca69b 100644 --- a/agent/health_endpoint_test.go +++ b/agent/health_endpoint_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -98,7 +99,7 @@ func TestHealthChecksInState_NodeMetaFilter(t *testing.T) { }, } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -142,7 +143,7 @@ func TestHealthChecksInState_Filter(t *testing.T) { }, } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) args = &structs.RegisterRequest{ Datacenter: "dc1", @@ -156,7 +157,7 @@ func TestHealthChecksInState_Filter(t *testing.T) { }, SkipNodeUpdate: true, } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/health/state/critical?filter="+url.QueryEscape("Name == `node check 2`"), nil) retry.Run(t, func(r *retry.R) { @@ -192,12 +193,12 @@ func TestHealthChecksInState_DistanceSort(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } args.Node, args.Check.Node = "foo", "foo" - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -225,7 +226,7 @@ func TestHealthChecksInState_DistanceSort(t *testing.T) { Node: "foo", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - if err := a.RPC("Coordinate.Update", &arg, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg, &out); err != nil { t.Fatalf("err: %v", err) } // Retry until foo moves to the front of the line. @@ -310,7 +311,7 @@ func TestHealthNodeChecks_Filtering(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Create a second check args = &structs.RegisterRequest{ @@ -323,7 +324,7 @@ func TestHealthNodeChecks_Filtering(t *testing.T) { }, SkipNodeUpdate: true, } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/health/node/test-health-node?filter="+url.QueryEscape("Name == check2"), nil) resp := httptest.NewRecorder() @@ -374,7 +375,7 @@ func TestHealthServiceChecks(t *testing.T) { } var out struct{} - if err = a.RPC("Catalog.Register", args, &out); err != nil { + if err = a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -434,7 +435,7 @@ func TestHealthServiceChecks_NodeMetaFilter(t *testing.T) { } var out struct{} - if err = a.RPC("Catalog.Register", args, &out); err != nil { + if err = a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -487,7 +488,7 @@ func TestHealthServiceChecks_Filtering(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Create a new node, service and check args = &structs.RegisterRequest{ @@ -505,7 +506,7 @@ func TestHealthServiceChecks_Filtering(t *testing.T) { ServiceID: "consul", }, } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ = http.NewRequest("GET", "/v1/health/checks/consul?dc=dc1&filter="+url.QueryEscape("Node == `test-health-node`"), nil) resp = httptest.NewRecorder() @@ -545,12 +546,12 @@ func TestHealthServiceChecks_DistanceSort(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } args.Node, args.Check.Node = "foo", "foo" - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -578,7 +579,7 @@ func TestHealthServiceChecks_DistanceSort(t *testing.T) { Node: "foo", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - if err := a.RPC("Coordinate.Update", &arg, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg, &out); err != nil { t.Fatalf("err: %v", err) } // Retry until foo has moved to the front of the line. @@ -663,7 +664,7 @@ func TestHealthServiceNodes(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) originalRegister[peerName] = args } @@ -730,7 +731,7 @@ func TestHealthServiceNodes(t *testing.T) { args2.Node = "baz" args2.Address = "127.0.0.2" var out struct{} - require.NoError(t, a.RPC("Catalog.Register", &args2, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &args2, &out)) } for _, peerName := range testingPeerNames { @@ -840,7 +841,7 @@ use_streaming_backend = true } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // Initial request should return two instances @@ -894,7 +895,7 @@ use_streaming_backend = true } var out struct{} - errCh <- a.RPC("Catalog.Register", args, &out) + errCh <- a.RPC(context.Background(), "Catalog.Register", args, &out) }() { @@ -1010,7 +1011,7 @@ use_streaming_backend = true } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } for _, tc := range cases { @@ -1169,7 +1170,7 @@ func TestHealthServiceNodes_NodeMetaFilter(t *testing.T) { } var ignored struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &ignored)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &ignored)) }) testutil.RunStep(t, "register item 2", func(t *testing.T) { @@ -1184,7 +1185,7 @@ func TestHealthServiceNodes_NodeMetaFilter(t *testing.T) { }, } var ignored struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &ignored)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &ignored)) }) testutil.RunStep(t, "do blocking read", func(t *testing.T) { @@ -1243,7 +1244,7 @@ func TestHealthServiceNodes_Filter(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Create a new node, service and check args = &structs.RegisterRequest{ @@ -1261,7 +1262,7 @@ func TestHealthServiceNodes_Filter(t *testing.T) { ServiceID: "consul", }, } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ = http.NewRequest("GET", "/v1/health/service/consul?dc=dc1&filter="+url.QueryEscape("Node.Node == `test-health-node`"), nil) resp = httptest.NewRecorder() @@ -1302,12 +1303,12 @@ func TestHealthServiceNodes_DistanceSort(t *testing.T) { } testrpc.WaitForLeader(t, a.RPC, dc) var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } args.Node, args.Check.Node = "foo", "foo" - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1335,7 +1336,7 @@ func TestHealthServiceNodes_DistanceSort(t *testing.T) { Node: "foo", Coord: coordinate.NewCoordinate(coordinate.DefaultConfig()), } - if err := a.RPC("Coordinate.Update", &arg, &out); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &arg, &out); err != nil { t.Fatalf("err: %v", err) } // Retry until foo has moved to the front of the line. @@ -1384,7 +1385,7 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) { retry.Run(t, func(r *retry.R) { var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { r.Fatalf("err: %v", err) } }) @@ -1488,7 +1489,7 @@ func TestHealthServiceNodes_CheckType(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ = http.NewRequest("GET", "/v1/health/service/consul?dc=dc1", nil) resp = httptest.NewRecorder() @@ -1560,7 +1561,7 @@ func TestHealthServiceNodes_WanTranslation(t *testing.T) { } var out struct{} - require.NoError(t, a2.RPC("Catalog.Register", args, &out)) + require.NoError(t, a2.RPC(context.Background(), "Catalog.Register", args, &out)) } // Query for a service in DC2 from DC1. @@ -1612,7 +1613,7 @@ func TestHealthConnectServiceNodes(t *testing.T) { // Register args := structs.TestRegisterRequestProxy(t) var out struct{} - assert.Nil(t, a.RPC("Catalog.Register", args, &out)) + assert.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Request req, _ := http.NewRequest("GET", fmt.Sprintf( @@ -1651,10 +1652,10 @@ func testHealthIngressServiceNodes(t *testing.T, agentHCL string) { gatewayArgs := structs.TestRegisterIngressGateway(t) gatewayArgs.Service.Address = "127.0.0.27" var out struct{} - require.NoError(t, a.RPC("Catalog.Register", gatewayArgs, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", gatewayArgs, &out)) args := structs.TestRegisterRequest(t) - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) // Associate service to gateway cfgArgs := &structs.IngressGatewayConfigEntry{ @@ -1677,7 +1678,7 @@ func testHealthIngressServiceNodes(t *testing.T, agentHCL string) { Entry: cfgArgs, } var outB bool - require.Nil(t, a.RPC("ConfigEntry.Apply", req, &outB)) + require.Nil(t, a.RPC(context.Background(), "ConfigEntry.Apply", req, &outB)) require.True(t, outB) checkResults := func(t *testing.T, obj interface{}) { @@ -1759,7 +1760,7 @@ func TestHealthConnectServiceNodes_Filter(t *testing.T) { args := structs.TestRegisterRequestProxy(t) args.Service.Address = "127.0.0.55" var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) args = structs.TestRegisterRequestProxy(t) args.Service.Address = "127.0.0.55" @@ -1768,7 +1769,7 @@ func TestHealthConnectServiceNodes_Filter(t *testing.T) { } args.Service.ID = "web-proxy2" args.SkipNodeUpdate = true - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", fmt.Sprintf( "/v1/health/connect/%s?filter=%s", @@ -1805,7 +1806,7 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) { Status: api.HealthCritical, } var out struct{} - assert.Nil(t, a.RPC("Catalog.Register", args, &out)) + assert.Nil(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) t.Run("bc_no_query_value", func(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf( @@ -1986,7 +1987,7 @@ func TestHealthServiceNodes_MergeCentralConfigBlocking(t *testing.T) { MergeCentralConfig: true, } var rpcResp structs.IndexedCheckServiceNodes - require.NoError(t, a.RPC("Health.ServiceNodes", &rpcReq, &rpcResp)) + require.NoError(t, a.RPC(context.Background(), "Health.ServiceNodes", &rpcReq, &rpcResp)) require.Len(t, rpcResp.Nodes, 1) nodeService := rpcResp.Nodes[0].Service diff --git a/agent/http.go b/agent/http.go index a3529a67a7..401e94ecd5 100644 --- a/agent/http.go +++ b/agent/http.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/pprof" + "net/netip" "net/url" "reflect" "regexp" @@ -288,7 +289,9 @@ func (s *HTTPHandlers) handler(enableDebug bool) http.Handler { if s.agent.config.DisableHTTPUnprintableCharFilter { h = mux } + h = s.enterpriseHandler(h) + h = withRemoteAddrHandler(h) s.h = &wrappedMux{ mux: mux, handler: h, @@ -296,6 +299,19 @@ func (s *HTTPHandlers) handler(enableDebug bool) http.Handler { return s.h } +// Injects remote addr into the request's context +func withRemoteAddrHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + addrPort, err := netip.ParseAddrPort(req.RemoteAddr) + if err == nil { + remoteAddr := net.TCPAddrFromAddrPort(addrPort) + ctx := consul.ContextWithRemoteAddr(req.Context(), remoteAddr) + req = req.WithContext(ctx) + } + next.ServeHTTP(resp, req) + }) +} + // nodeName returns the node name of the agent func (s *HTTPHandlers) nodeName() string { return s.agent.config.NodeName diff --git a/agent/http_test.go b/agent/http_test.go index 4464fe1834..39963be041 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/netip" "os" "path/filepath" "runtime" @@ -26,6 +27,7 @@ import ( "golang.org/x/net/http2" "github.com/hashicorp/consul/agent/config" + "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/api" @@ -1685,3 +1687,42 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) { }) } } + +func TestWithRemoteAddrHandler_ValidAddr(t *testing.T) { + expected := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:8080")) + nextHandlerCalled := false + + assertionHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextHandlerCalled = true + remoteAddr, ok := consul.RemoteAddrFromContext(r.Context()) + if !ok || remoteAddr.String() != expected.String() { + t.Errorf("remote addr not present but expected %v", expected) + } + }) + + remoteAddrHandler := withRemoteAddrHandler(assertionHandler) + req := httptest.NewRequest("GET", "http://ignoreme", nil) + req.RemoteAddr = expected.String() + remoteAddrHandler.ServeHTTP(httptest.NewRecorder(), req) + + assert.True(t, nextHandlerCalled, "expected next handler to be called") +} + +func TestWithRemoteAddrHandler_InvalidAddr(t *testing.T) { + nextHandlerCalled := false + + assertionHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextHandlerCalled = true + remoteAddr, ok := consul.RemoteAddrFromContext(r.Context()) + if ok || remoteAddr != nil { + t.Errorf("remote addr %v present but not expected", remoteAddr) + } + }) + + remoteAddrHandler := withRemoteAddrHandler(assertionHandler) + req := httptest.NewRequest("GET", "http://ignoreme", nil) + req.RemoteAddr = "i.am.not.a.valid.ipaddr:port" + remoteAddrHandler.ServeHTTP(httptest.NewRecorder(), req) + + assert.True(t, nextHandlerCalled, "expected next handler to be called") +} diff --git a/agent/intentions_endpoint.go b/agent/intentions_endpoint.go index f43dc3ecf4..7144f026f1 100644 --- a/agent/intentions_endpoint.go +++ b/agent/intentions_endpoint.go @@ -40,7 +40,7 @@ func (s *HTTPHandlers) IntentionList(resp http.ResponseWriter, req *http.Request var reply structs.IndexedIntentions defer setMeta(resp, &reply.QueryMeta) - if err := s.agent.RPC("Intention.List", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.List", &args, &reply); err != nil { return nil, err } @@ -84,7 +84,7 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque } var reply string - if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Apply", &args, &reply); err != nil { return nil, err } @@ -176,7 +176,7 @@ func (s *HTTPHandlers) IntentionMatch(resp http.ResponseWriter, req *http.Reques out = *reply } else { RETRY_ONCE: - if err := s.agent.RPC("Intention.Match", args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Match", args, &out); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < out.LastContact { @@ -254,7 +254,7 @@ func (s *HTTPHandlers) IntentionCheck(resp http.ResponseWriter, req *http.Reques args.Check.DestinationName = parsed.name var reply structs.IntentionQueryCheckResponse - if err := s.agent.RPC("Intention.Check", args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Check", args, &reply); err != nil { return nil, err } @@ -324,7 +324,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req } var reply structs.IndexedIntentions - if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds the error type if strings.Contains(err.Error(), consul.ErrIntentionNotFound.Error()) { return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} @@ -386,7 +386,7 @@ func (s *HTTPHandlers) IntentionPutExact(resp http.ResponseWriter, req *http.Req args.Intention.FillPartitionAndNamespace(&entMeta, false) var ignored string - if err := s.agent.RPC("Intention.Apply", &args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Apply", &args, &ignored); err != nil { return nil, err } @@ -421,7 +421,7 @@ func (s *HTTPHandlers) IntentionDeleteExact(resp http.ResponseWriter, req *http. s.parseToken(req, &args.Token) var ignored string - if err := s.agent.RPC("Intention.Apply", &args, &ignored); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Apply", &args, &ignored); err != nil { return nil, err } @@ -542,7 +542,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter, } var reply structs.IndexedIntentions - if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds the error type if err.Error() == consul.ErrIntentionNotFound.Error() { return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} @@ -604,7 +604,7 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit args.Intention.ID = id var reply string - if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Apply", &args, &reply); err != nil { return nil, err } @@ -624,7 +624,7 @@ func (s *HTTPHandlers) IntentionSpecificDelete(id string, resp http.ResponseWrit s.parseToken(req, &args.Token) var reply string - if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Intention.Apply", &args, &reply); err != nil { return nil, err } diff --git a/agent/intentions_endpoint_test.go b/agent/intentions_endpoint_test.go index ff17297fc5..6bd7a5fec2 100644 --- a/agent/intentions_endpoint_test.go +++ b/agent/intentions_endpoint_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -58,7 +59,7 @@ func TestIntentionList(t *testing.T) { } var reply string - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) ids = append(ids, reply) } @@ -162,7 +163,7 @@ func TestIntentionMatch(t *testing.T) { // Create var reply string - require.NoError(t, a.RPC("Intention.Apply", &ixn, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &ixn, &reply)) } } @@ -302,7 +303,7 @@ func TestIntentionCheck(t *testing.T) { // Create var reply string - require.NoError(t, a.RPC("Intention.Apply", &ixn, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &ixn, &reply)) } } @@ -510,7 +511,7 @@ func TestIntentionPutExact(t *testing.T) { } var resp structs.IndexedIntentions - require.NoError(t, a.RPC("Intention.Get", req, &resp)) + require.NoError(t, a.RPC(context.Background(), "Intention.Get", req, &resp)) require.Len(t, resp.Intentions, 1) actual := resp.Intentions[0] require.Equal(t, "foo", actual.SourceName) @@ -557,7 +558,7 @@ func TestIntentionCreate(t *testing.T) { IntentionID: value.ID, } var resp structs.IndexedIntentions - require.NoError(t, a.RPC("Intention.Get", req, &resp)) + require.NoError(t, a.RPC(context.Background(), "Intention.Get", req, &resp)) require.Len(t, resp.Intentions, 1) actual := resp.Intentions[0] require.Equal(t, "foo", actual.SourceName) @@ -607,7 +608,7 @@ func TestIntentionSpecificGet(t *testing.T) { Op: structs.IntentionOpCreate, Intention: ixn, } - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } t.Run("invalid id", func(t *testing.T) { @@ -662,7 +663,7 @@ func TestIntentionSpecificUpdate(t *testing.T) { Op: structs.IntentionOpCreate, Intention: ixn, } - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } // Update the intention @@ -683,7 +684,7 @@ func TestIntentionSpecificUpdate(t *testing.T) { IntentionID: reply, } var resp structs.IndexedIntentions - require.NoError(t, a.RPC("Intention.Get", req, &resp)) + require.NoError(t, a.RPC(context.Background(), "Intention.Get", req, &resp)) require.Len(t, resp.Intentions, 1) actual := resp.Intentions[0] require.Equal(t, "bar", actual.SourceName) @@ -745,7 +746,7 @@ func TestIntentionDeleteExact(t *testing.T) { Intention: ixn, } var ignored string - require.NoError(t, a.RPC("Intention.Apply", &req, &ignored)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &ignored)) } // Sanity check that the intention exists @@ -755,7 +756,7 @@ func TestIntentionDeleteExact(t *testing.T) { Exact: exact, } var resp structs.IndexedIntentions - require.NoError(t, a.RPC("Intention.Get", req, &resp)) + require.NoError(t, a.RPC(context.Background(), "Intention.Get", req, &resp)) require.Len(t, resp.Intentions, 1) actual := resp.Intentions[0] require.Equal(t, "foo", actual.SourceName) @@ -799,7 +800,7 @@ func TestIntentionDeleteExact(t *testing.T) { Exact: exact, } var resp structs.IndexedIntentions - err := a.RPC("Intention.Get", req, &resp) + err := a.RPC(context.Background(), "Intention.Get", req, &resp) testutil.RequireErrorContains(t, err, "not found") } }) @@ -840,7 +841,7 @@ func TestIntentionSpecificDelete(t *testing.T) { Op: structs.IntentionOpCreate, Intention: ixn, } - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } // Sanity check that the intention exists @@ -850,7 +851,7 @@ func TestIntentionSpecificDelete(t *testing.T) { IntentionID: reply, } var resp structs.IndexedIntentions - require.NoError(t, a.RPC("Intention.Get", req, &resp)) + require.NoError(t, a.RPC(context.Background(), "Intention.Get", req, &resp)) require.Len(t, resp.Intentions, 1) actual := resp.Intentions[0] require.Equal(t, "foo", actual.SourceName) @@ -870,7 +871,7 @@ func TestIntentionSpecificDelete(t *testing.T) { IntentionID: reply, } var resp structs.IndexedIntentions - err := a.RPC("Intention.Get", req, &resp) + err := a.RPC(context.Background(), "Intention.Get", req, &resp) testutil.RequireErrorContains(t, err, "not found") } } diff --git a/agent/keyring.go b/agent/keyring.go index a430eee92c..ad4496ead1 100644 --- a/agent/keyring.go +++ b/agent/keyring.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "encoding/base64" "encoding/json" "fmt" @@ -232,7 +233,7 @@ func decodeStringKey(key string) ([]byte, error) { func (a *Agent) keyringProcess(args *structs.KeyringRequest) (*structs.KeyringResponses, error) { var reply structs.KeyringResponses - if err := a.RPC("Internal.KeyringOperation", args, &reply); err != nil { + if err := a.RPC(context.Background(), "Internal.KeyringOperation", args, &reply); err != nil { return &reply, err } diff --git a/agent/kvs_endpoint.go b/agent/kvs_endpoint.go index 1aed8178ad..ce65604b55 100644 --- a/agent/kvs_endpoint.go +++ b/agent/kvs_endpoint.go @@ -69,7 +69,7 @@ func (s *HTTPHandlers) KVSGet(resp http.ResponseWriter, req *http.Request, args // Make the RPC var out structs.IndexedDirEntries - if err := s.agent.RPC(method, args, &out); err != nil { + if err := s.agent.RPC(req.Context(), method, args, &out); err != nil { return nil, err } setMeta(resp, &out.QueryMeta) @@ -129,7 +129,7 @@ func (s *HTTPHandlers) KVSGetKeys(resp http.ResponseWriter, req *http.Request, a // Make the RPC var out structs.IndexedKeyList - if err := s.agent.RPC("KVS.ListKeys", &listArgs, &out); err != nil { + if err := s.agent.RPC(req.Context(), "KVS.ListKeys", &listArgs, &out); err != nil { return nil, err } setMeta(resp, &out.QueryMeta) @@ -221,7 +221,7 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args // Make the RPC var out bool - if err := s.agent.RPC("KVS.Apply", &applyReq, &out); err != nil { + if err := s.agent.RPC(req.Context(), "KVS.Apply", &applyReq, &out); err != nil { return nil, err } @@ -270,7 +270,7 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar // Make the RPC var out bool - if err := s.agent.RPC("KVS.Apply", &applyReq, &out); err != nil { + if err := s.agent.RPC(req.Context(), "KVS.Apply", &applyReq, &out); err != nil { return nil, err } diff --git a/agent/local/state.go b/agent/local/state.go index 68a29b3a2e..d2a634ebf0 100644 --- a/agent/local/state.go +++ b/agent/local/state.go @@ -1,6 +1,7 @@ package local import ( + "context" "fmt" "reflect" "strconv" @@ -149,7 +150,7 @@ func (c *CheckState) CriticalFor() time.Duration { } type rpc interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error ResolveTokenAndDefaultMeta(token string, entMeta *acl.EnterpriseMeta, authzContext *acl.AuthorizerContext) (resolver.Result, error) } @@ -1007,7 +1008,7 @@ func (l *State) updateSyncState() error { remoteServices := make(map[structs.ServiceID]*structs.NodeService) var svcNode *structs.Node - if err := l.Delegate.RPC("Catalog.NodeServiceList", &req, &out1); err == nil { + if err := l.Delegate.RPC(context.Background(), "Catalog.NodeServiceList", &req, &out1); err == nil { for _, svc := range out1.NodeServices.Services { remoteServices[svc.CompoundServiceID()] = svc } @@ -1016,7 +1017,7 @@ func (l *State) updateSyncState() error { } else if errMsg := err.Error(); strings.Contains(errMsg, "rpc: can't find method") { // fallback to the old RPC var out1 structs.IndexedNodeServices - if err := l.Delegate.RPC("Catalog.NodeServices", &req, &out1); err != nil { + if err := l.Delegate.RPC(context.Background(), "Catalog.NodeServices", &req, &out1); err != nil { return err } @@ -1032,7 +1033,7 @@ func (l *State) updateSyncState() error { } var out2 structs.IndexedHealthChecks - if err := l.Delegate.RPC("Health.NodeChecks", &req, &out2); err != nil { + if err := l.Delegate.RPC(context.Background(), "Health.NodeChecks", &req, &out2); err != nil { return err } @@ -1279,7 +1280,7 @@ func (l *State) deleteService(key structs.ServiceID) error { WriteRequest: structs.WriteRequest{Token: st}, } var out struct{} - err := l.Delegate.RPC("Catalog.Deregister", &req, &out) + err := l.Delegate.RPC(context.Background(), "Catalog.Deregister", &req, &out) switch { case err == nil || strings.Contains(err.Error(), "Unknown service"): delete(l.services, key) @@ -1328,7 +1329,7 @@ func (l *State) deleteCheck(key structs.CheckID) error { WriteRequest: structs.WriteRequest{Token: ct}, } var out struct{} - err := l.Delegate.RPC("Catalog.Deregister", &req, &out) + err := l.Delegate.RPC(context.Background(), "Catalog.Deregister", &req, &out) switch { case err == nil || strings.Contains(err.Error(), "Unknown check"): l.pruneCheck(key) @@ -1406,7 +1407,7 @@ func (l *State) syncService(key structs.ServiceID) error { } var out struct{} - err := l.Delegate.RPC("Catalog.Register", &req, &out) + err := l.Delegate.RPC(context.Background(), "Catalog.Register", &req, &out) switch { case err == nil: l.services[key].InSync = true @@ -1468,7 +1469,7 @@ func (l *State) syncCheck(key structs.CheckID) error { } var out struct{} - err := l.Delegate.RPC("Catalog.Register", &req, &out) + err := l.Delegate.RPC(context.Background(), "Catalog.Register", &req, &out) switch { case err == nil: l.checks[key].InSync = true @@ -1509,7 +1510,7 @@ func (l *State) syncNodeInfo() error { WriteRequest: structs.WriteRequest{Token: at}, } var out struct{} - err := l.Delegate.RPC("Catalog.Register", &req, &out) + err := l.Delegate.RPC(context.Background(), "Catalog.Register", &req, &out) switch { case err == nil: l.nodeInfoInSync = true diff --git a/agent/local/state_test.go b/agent/local/state_test.go index 7aa539ea0b..448cfde044 100644 --- a/agent/local/state_test.go +++ b/agent/local/state_test.go @@ -1,6 +1,7 @@ package local_test import ( + "context" "errors" "fmt" "os" @@ -67,7 +68,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { a.State.AddServiceWithChecks(srv1, nil, "") assert.True(t, a.State.ServiceExists(structs.ServiceID{ID: srv1.ID})) args.Service = srv1 - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -89,7 +90,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { *srv2_mod = *srv2 srv2_mod.Port = 9000 args.Service = srv2_mod - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -120,7 +121,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), } args.Service = srv4 - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -143,7 +144,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { *srv5_mod = *srv5 srv5_mod.Address = "127.0.0.1" args.Service = srv5_mod - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -174,7 +175,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { Node: a.Config.NodeName, } - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -224,7 +225,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) { t.Fatalf("err: %v", err) } - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -291,7 +292,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) { EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), } a.State.AddServiceWithChecks(srv1, nil, "") - require.NoError(t, a.RPC("Catalog.Register", &structs.RegisterRequest{ + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &structs.RegisterRequest{ Datacenter: "dc1", Node: a.Config.NodeName, Address: "127.0.0.1", @@ -315,7 +316,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) { srv2_mod := clone(srv2) srv2_mod.Port = 9000 - require.NoError(t, a.RPC("Catalog.Register", &structs.RegisterRequest{ + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &structs.RegisterRequest{ Datacenter: "dc1", Node: a.Config.NodeName, Address: "127.0.0.1", @@ -350,7 +351,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) { }, EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), } - require.NoError(t, a.RPC("Catalog.Register", &structs.RegisterRequest{ + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &structs.RegisterRequest{ Datacenter: "dc1", Node: a.Config.NodeName, Address: "127.0.0.1", @@ -382,7 +383,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) { Datacenter: "dc1", Node: a.Config.NodeName, } - require.NoError(t, a.RPC("Catalog.NodeServices", &req, &services)) + require.NoError(t, a.RPC(context.Background(), "Catalog.NodeServices", &req, &services)) // We should have 5 services (consul included) require.Len(t, services.NodeServices.Services, 5) @@ -454,7 +455,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) { // Remove one of the services a.State.RemoveService(structs.NewServiceID("cache-proxy", nil)) require.NoError(t, a.State.SyncFull()) - require.NoError(t, a.RPC("Catalog.NodeServices", &req, &services)) + require.NoError(t, a.RPC(context.Background(), "Catalog.NodeServices", &req, &services)) // We should have 4 services (consul included) require.Len(t, services.NodeServices.Services, 4) @@ -632,7 +633,7 @@ func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) { }, EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -648,7 +649,7 @@ func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) { }, EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), } - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -664,7 +665,7 @@ func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) { var services structs.IndexedNodeServices retry.Run(t, func(r *retry.R) { - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { r.Fatalf("err: %v", err) } @@ -743,7 +744,7 @@ func TestAgentAntiEntropy_Services_WithChecks(t *testing.T) { Node: a.Config.NodeName, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &svcReq, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &svcReq, &services); err != nil { t.Fatalf("err: %v", err) } if len(services.NodeServices.Services) != 2 { @@ -756,7 +757,7 @@ func TestAgentAntiEntropy_Services_WithChecks(t *testing.T) { ServiceName: "mysql", } var checks structs.IndexedHealthChecks - if err := a.RPC("Health.ServiceChecks", &chkReq, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.ServiceChecks", &chkReq, &checks); err != nil { t.Fatalf("err: %v", err) } if len(checks.HealthChecks) != 1 { @@ -802,7 +803,7 @@ func TestAgentAntiEntropy_Services_WithChecks(t *testing.T) { Node: a.Config.NodeName, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &svcReq, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &svcReq, &services); err != nil { t.Fatalf("err: %v", err) } if len(services.NodeServices.Services) != 3 { @@ -815,7 +816,7 @@ func TestAgentAntiEntropy_Services_WithChecks(t *testing.T) { ServiceName: "redis", } var checks structs.IndexedHealthChecks - if err := a.RPC("Health.ServiceChecks", &chkReq, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.ServiceChecks", &chkReq, &checks); err != nil { t.Fatalf("err: %v", err) } if len(checks.HealthChecks) != 2 { @@ -903,7 +904,7 @@ func TestAgentAntiEntropy_Services_ACLDeny(t *testing.T) { }, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -948,7 +949,7 @@ func TestAgentAntiEntropy_Services_ACLDeny(t *testing.T) { }, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -984,7 +985,7 @@ func TestAgentAntiEntropy_Services_ACLDeny(t *testing.T) { } type RPC interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error } func createToken(t *testing.T, rpc RPC, policyRules string) string { @@ -998,7 +999,7 @@ func createToken(t *testing.T, rpc RPC, policyRules string) string { }, WriteRequest: structs.WriteRequest{Token: "root"}, } - err := rpc.RPC("ACL.PolicySet", &reqPolicy, &structs.ACLPolicy{}) + err := rpc.RPC(context.Background(), "ACL.PolicySet", &reqPolicy, &structs.ACLPolicy{}) require.NoError(t, err) token, err := uuid.GenerateUUID() @@ -1012,7 +1013,7 @@ func createToken(t *testing.T, rpc RPC, policyRules string) string { }, WriteRequest: structs.WriteRequest{Token: "root"}, } - err = rpc.RPC("ACL.TokenSet", &reqToken, &structs.ACLToken{}) + err = rpc.RPC(context.Background(), "ACL.TokenSet", &reqToken, &structs.ACLToken{}) require.NoError(t, err) return token } @@ -1045,7 +1046,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { } a.State.AddCheck(chk1, "") args.Check = chk1 - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1063,7 +1064,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { *chk2_mod = *chk2 chk2_mod.Status = api.HealthCritical args.Check = chk2_mod - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1086,7 +1087,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), } args.Check = chk4 - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1116,7 +1117,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { retry.Run(t, func(r *retry.R) { // Verify that we are in sync - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { r.Fatalf("err: %v", err) } @@ -1158,7 +1159,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { Node: a.Config.NodeName, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { r.Fatalf("err: %v", err) } @@ -1181,7 +1182,7 @@ func TestAgentAntiEntropy_Checks(t *testing.T) { } // Verify that we are in sync - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { r.Fatalf("err: %v", err) } @@ -1242,7 +1243,7 @@ func TestAgentAntiEntropy_RemovingServiceAndCheck(t *testing.T) { Port: 8080, } args.Service = srv - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1257,7 +1258,7 @@ func TestAgentAntiEntropy_RemovingServiceAndCheck(t *testing.T) { } args.Check = chk - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1271,7 +1272,7 @@ func TestAgentAntiEntropy_RemovingServiceAndCheck(t *testing.T) { Node: a.Config.NodeName, } - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -1282,7 +1283,7 @@ func TestAgentAntiEntropy_RemovingServiceAndCheck(t *testing.T) { var checks structs.IndexedHealthChecks // Verify that we are in sync - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } @@ -1360,7 +1361,7 @@ func TestAgentAntiEntropy_Checks_ACLDeny(t *testing.T) { }, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -1428,7 +1429,7 @@ func TestAgentAntiEntropy_Checks_ACLDeny(t *testing.T) { }, } var checks structs.IndexedHealthChecks - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } @@ -1472,7 +1473,7 @@ func TestAgentAntiEntropy_Checks_ACLDeny(t *testing.T) { }, } var checks structs.IndexedHealthChecks - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } @@ -1598,7 +1599,7 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { } var checks structs.IndexedHealthChecks retry.Run(t, func(r *retry.R) { - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { r.Fatalf("err: %v", err) } if got, want := len(checks.HealthChecks), 2; got != want { @@ -1652,7 +1653,7 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { // synced. timer = &retry.Timer{Timeout: 6 * time.Second, Wait: 100 * time.Millisecond} retry.RunWith(timer, t, func(r *retry.R) { - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { r.Fatalf("err: %v", err) } @@ -1679,12 +1680,12 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { WriteRequest: structs.WriteRequest{}, } var out struct{} - if err := a.RPC("Catalog.Register", ®, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", ®, &out); err != nil { t.Fatalf("err: %s", err) } // Verify that the output is out of sync. - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } for _, chk := range checks.HealthChecks { @@ -1701,7 +1702,7 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { } // Verify that the output was synced back to the agent's value. - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } for _, chk := range checks.HealthChecks { @@ -1714,12 +1715,12 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { } // Reset the catalog again. - if err := a.RPC("Catalog.Register", ®, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", ®, &out); err != nil { t.Fatalf("err: %s", err) } // Verify that the output is out of sync. - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } for _, chk := range checks.HealthChecks { @@ -1740,7 +1741,7 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { // Verify that the output is still out of sync since there's a deferred // update pending. - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { t.Fatalf("err: %v", err) } for _, chk := range checks.HealthChecks { @@ -1753,7 +1754,7 @@ func TestAgentAntiEntropy_Check_DeferSync(t *testing.T) { } // Wait for the deferred update. retry.Run(t, func(r *retry.R) { - if err := a.RPC("Health.NodeChecks", &req, &checks); err != nil { + if err := a.RPC(context.Background(), "Health.NodeChecks", &req, &checks); err != nil { r.Fatal(err) } @@ -1798,7 +1799,7 @@ func TestAgentAntiEntropy_NodeInfo(t *testing.T) { Address: "127.0.0.1", } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1811,7 +1812,7 @@ func TestAgentAntiEntropy_NodeInfo(t *testing.T) { Node: a.Config.NodeName, } var services structs.IndexedNodeServices - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -1824,7 +1825,7 @@ func TestAgentAntiEntropy_NodeInfo(t *testing.T) { assert.Equal(t, unNilMap(a.Config.NodeMeta), meta) // Blow away the catalog version of the node info - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -1833,7 +1834,7 @@ func TestAgentAntiEntropy_NodeInfo(t *testing.T) { } // Wait for the sync - this should have been a sync of just the node info - if err := a.RPC("Catalog.NodeServices", &req, &services); err != nil { + if err := a.RPC(context.Background(), "Catalog.NodeServices", &req, &services); err != nil { t.Fatalf("err: %v", err) } @@ -2143,7 +2144,7 @@ func TestAgent_sendCoordinate(t *testing.T) { } var reply structs.IndexedCoordinates retry.Run(t, func(r *retry.R) { - if err := a.RPC("Coordinate.ListNodes", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Coordinate.ListNodes", &req, &reply); err != nil { r.Fatalf("err: %s", err) } if len(reply.Coordinates) != 1 { @@ -2416,7 +2417,7 @@ type callRPC struct { reply interface{} } -func (f *fakeRPC) RPC(method string, args interface{}, reply interface{}) error { +func (f *fakeRPC) RPC(ctx context.Context, method string, args interface{}, reply interface{}) error { f.calls = append(f.calls, callRPC{method: method, args: args, reply: reply}) return nil } diff --git a/agent/metrics_test.go b/agent/metrics_test.go index a66beae3bd..1f649dd07a 100644 --- a/agent/metrics_test.go +++ b/agent/metrics_test.go @@ -188,7 +188,7 @@ func TestAgent_OneTwelveRPCMetrics(t *testing.T) { defer a.Shutdown() var out struct{} - err := a.RPC("Status.Ping", struct{}{}, &out) + err := a.RPC(context.Background(), "Status.Ping", struct{}{}, &out) require.NoError(t, err) respRec := httptest.NewRecorder() @@ -213,11 +213,11 @@ func TestAgent_OneTwelveRPCMetrics(t *testing.T) { defer a.Shutdown() var out struct{} - err := a.RPC("Status.Ping", struct{}{}, &out) + err := a.RPC(context.Background(), "Status.Ping", struct{}{}, &out) require.NoError(t, err) - err = a.RPC("Status.Ping", struct{}{}, &out) + err = a.RPC(context.Background(), "Status.Ping", struct{}{}, &out) require.NoError(t, err) - err = a.RPC("Status.Ping", struct{}{}, &out) + err = a.RPC(context.Background(), "Status.Ping", struct{}{}, &out) require.NoError(t, err) respRec := httptest.NewRecorder() diff --git a/agent/operator_endpoint.go b/agent/operator_endpoint.go index 10af5e31d0..9baf0e8b87 100644 --- a/agent/operator_endpoint.go +++ b/agent/operator_endpoint.go @@ -2,12 +2,13 @@ package agent import ( "fmt" - external "github.com/hashicorp/consul/agent/grpc-external" - "github.com/hashicorp/consul/proto/pboperator" "net/http" "strconv" "time" + external "github.com/hashicorp/consul/agent/grpc-external" + "github.com/hashicorp/consul/proto/pboperator" + multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/raft" autopilot "github.com/hashicorp/raft-autopilot" @@ -26,7 +27,7 @@ func (s *HTTPHandlers) OperatorRaftConfiguration(resp http.ResponseWriter, req * } var reply structs.RaftConfigurationResponse - if err := s.agent.RPC("Operator.RaftGetConfiguration", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Operator.RaftGetConfiguration", &args, &reply); err != nil { return nil, err } @@ -102,7 +103,7 @@ func (s *HTTPHandlers) OperatorRaftPeer(resp http.ResponseWriter, req *http.Requ if hasAddress { method = "Operator.RaftRemovePeerByAddress" } - if err := s.agent.RPC(method, &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), method, &args, &reply); err != nil { return nil, err } @@ -242,7 +243,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter, } var reply structs.AutopilotConfig - if err := s.agent.RPC("Operator.AutopilotGetConfiguration", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Operator.AutopilotGetConfiguration", &args, &reply); err != nil { return nil, err } @@ -294,7 +295,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter, } var reply bool - if err := s.agent.RPC("Operator.AutopilotSetConfiguration", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Operator.AutopilotSetConfiguration", &args, &reply); err != nil { return nil, err } @@ -317,7 +318,7 @@ func (s *HTTPHandlers) OperatorServerHealth(resp http.ResponseWriter, req *http. } var reply structs.AutopilotHealthReply - if err := s.agent.RPC("Operator.ServerHealth", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Operator.ServerHealth", &args, &reply); err != nil { return nil, err } @@ -357,7 +358,7 @@ func (s *HTTPHandlers) OperatorAutopilotState(resp http.ResponseWriter, req *htt } var reply autopilot.State - if err := s.agent.RPC("Operator.AutopilotState", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Operator.AutopilotState", &args, &reply); err != nil { return nil, err } diff --git a/agent/operator_endpoint_test.go b/agent/operator_endpoint_test.go index 37f893b441..25e3caf75c 100644 --- a/agent/operator_endpoint_test.go +++ b/agent/operator_endpoint_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "fmt" "net/http" "net/http/httptest" @@ -464,7 +465,7 @@ func TestOperator_AutopilotSetConfiguration(t *testing.T) { } var reply structs.AutopilotConfig - if err := a.RPC("Operator.AutopilotGetConfiguration", &args, &reply); err != nil { + if err := a.RPC(context.Background(), "Operator.AutopilotGetConfiguration", &args, &reply); err != nil { t.Fatalf("err: %v", err) } @@ -499,7 +500,7 @@ func TestOperator_AutopilotCASConfiguration(t *testing.T) { } var reply structs.AutopilotConfig - if err := a.RPC("Operator.AutopilotGetConfiguration", &args, &reply); err != nil { + if err := a.RPC(context.Background(), "Operator.AutopilotGetConfiguration", &args, &reply); err != nil { t.Fatalf("err: %v", err) } @@ -538,7 +539,7 @@ func TestOperator_AutopilotCASConfiguration(t *testing.T) { } // Verify the update - if err := a.RPC("Operator.AutopilotGetConfiguration", &args, &reply); err != nil { + if err := a.RPC(context.Background(), "Operator.AutopilotGetConfiguration", &args, &reply); err != nil { t.Fatalf("err: %v", err) } if !reply.CleanupDeadServers { diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index 6f5e7a9c2e..2aafb7d964 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -27,7 +27,7 @@ func (s *HTTPHandlers) preparedQueryCreate(resp http.ResponseWriter, req *http.R } var reply string - if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.Apply", &args, &reply); err != nil { return nil, err } return preparedQueryCreateResponse{reply}, nil @@ -43,7 +43,7 @@ func (s *HTTPHandlers) preparedQueryList(resp http.ResponseWriter, req *http.Req var reply structs.IndexedPreparedQueries defer setMeta(resp, &reply.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("PreparedQuery.List", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.List", &args, &reply); err != nil { return nil, err } if args.QueryOptions.AllowStale && args.MaxStaleDuration > 0 && args.MaxStaleDuration < reply.LastContact { @@ -139,7 +139,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter, reply = *r } else { RETRY_ONCE: - if err := s.agent.RPC("PreparedQuery.Execute", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.Execute", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { @@ -192,7 +192,7 @@ func (s *HTTPHandlers) preparedQueryExplain(id string, resp http.ResponseWriter, var reply structs.PreparedQueryExplainResponse defer setMeta(resp, &reply.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("PreparedQuery.Explain", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.Explain", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { @@ -221,7 +221,7 @@ func (s *HTTPHandlers) preparedQueryGet(id string, resp http.ResponseWriter, req var reply structs.IndexedPreparedQueries defer setMeta(resp, &reply.QueryMeta) RETRY_ONCE: - if err := s.agent.RPC("PreparedQuery.Get", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { @@ -259,7 +259,7 @@ func (s *HTTPHandlers) preparedQueryUpdate(id string, resp http.ResponseWriter, args.Query.ID = id var reply string - if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.Apply", &args, &reply); err != nil { return nil, err } return nil, nil @@ -277,7 +277,7 @@ func (s *HTTPHandlers) preparedQueryDelete(id string, resp http.ResponseWriter, s.parseToken(req, &args.Token) var reply string - if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "PreparedQuery.Apply", &args, &reply); err != nil { return nil, err } return nil, nil diff --git a/agent/prepared_query_endpoint_test.go b/agent/prepared_query_endpoint_test.go index 9cf805b88c..689e9510ec 100644 --- a/agent/prepared_query_endpoint_test.go +++ b/agent/prepared_query_endpoint_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -1047,7 +1048,7 @@ func TestPreparedQuery_Integration(t *testing.T) { }, } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } } diff --git a/agent/remote_exec.go b/agent/remote_exec.go index 3668ef8406..8519ca4857 100644 --- a/agent/remote_exec.go +++ b/agent/remote_exec.go @@ -1,6 +1,7 @@ package agent import ( + "context" "encoding/json" "fmt" "os" @@ -252,7 +253,7 @@ func (a *Agent) remoteExecGetSpec(event *remoteExecEvent, spec *remoteExecSpec) get.Token = a.tokens.AgentToken() var out structs.IndexedDirEntries QUERY: - if err := a.RPC("KVS.Get", &get, &out); err != nil { + if err := a.RPC(context.Background(), "KVS.Get", &get, &out); err != nil { a.logger.Error("failed to get remote exec job", "error", err) return false } @@ -318,7 +319,7 @@ func (a *Agent) remoteExecWriteKey(event *remoteExecEvent, suffix string, val [] } write.Token = a.tokens.AgentToken() var success bool - if err := a.RPC("KVS.Apply", &write, &success); err != nil { + if err := a.RPC(context.Background(), "KVS.Apply", &write, &success); err != nil { return err } if !success { diff --git a/agent/remote_exec_test.go b/agent/remote_exec_test.go index dc6489fa54..b08f9e14e6 100644 --- a/agent/remote_exec_test.go +++ b/agent/remote_exec_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "encoding/json" "fmt" "reflect" @@ -438,7 +439,7 @@ func makeRexecSession(t *testing.T, a *Agent, token string) string { }, } var out string - if err := a.RPC("Session.Apply", &args, &out); err != nil { + if err := a.RPC(context.Background(), "Session.Apply", &args, &out); err != nil { t.Fatalf("err: %v", err) } return out @@ -456,7 +457,7 @@ func destroySession(t *testing.T, a *Agent, session string, token string) { }, } var out string - if err := a.RPC("Session.Apply", &args, &out); err != nil { + if err := a.RPC(context.Background(), "Session.Apply", &args, &out); err != nil { t.Fatalf("err: %v", err) } } @@ -474,7 +475,7 @@ func setKV(a *Agent, key string, val []byte, token string) error { }, } var success bool - if err := a.RPC("KVS.Apply", &write, &success); err != nil { + if err := a.RPC(context.Background(), "KVS.Apply", &write, &success); err != nil { return err } return nil @@ -489,7 +490,7 @@ func getKV(a *Agent, key string, token string) (*structs.DirEntry, error) { }, } var out structs.IndexedDirEntries - if err := a.RPC("KVS.Get", &req, &out); err != nil { + if err := a.RPC(context.Background(), "KVS.Get", &req, &out); err != nil { return nil, err } if len(out.Entries) > 0 { diff --git a/agent/rpcclient/health/health.go b/agent/rpcclient/health/health.go index 9b24a164dd..1f1781e1a6 100644 --- a/agent/rpcclient/health/health.go +++ b/agent/rpcclient/health/health.go @@ -23,7 +23,7 @@ type Client struct { } type NetRPC interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error } type CacheGetter interface { @@ -71,7 +71,7 @@ func (c *Client) ServiceNodes( // TODO: DNSServer emitted a metric here, do we still need it? if req.QueryOptions.AllowStale && req.QueryOptions.MaxStaleDuration > 0 && out.QueryMeta.LastContact > req.MaxStaleDuration { req.AllowStale = false - err := c.NetRPC.RPC("Health.ServiceNodes", &req, &out) + err := c.NetRPC.RPC(context.Background(), "Health.ServiceNodes", &req, &out) return out, cache.ResultMeta{}, err } @@ -84,7 +84,7 @@ func (c *Client) getServiceNodes( ) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { var out structs.IndexedCheckServiceNodes if !req.QueryOptions.UseCache { - err := c.NetRPC.RPC("Health.ServiceNodes", &req, &out) + err := c.NetRPC.RPC(context.Background(), "Health.ServiceNodes", &req, &out) return out, cache.ResultMeta{}, err } diff --git a/agent/rpcclient/health/health_test.go b/agent/rpcclient/health/health_test.go index 00bc224b7a..a9cc7a087d 100644 --- a/agent/rpcclient/health/health_test.go +++ b/agent/rpcclient/health/health_test.go @@ -171,7 +171,7 @@ type fakeNetRPC struct { calls []string } -func (f *fakeNetRPC) RPC(method string, _ interface{}, _ interface{}) error { +func (f *fakeNetRPC) RPC(ctx context.Context, method string, _ interface{}, _ interface{}) error { f.calls = append(f.calls, method) return nil } diff --git a/agent/service_manager_test.go b/agent/service_manager_test.go index 6b7757a76e..c346268303 100644 --- a/agent/service_manager_test.go +++ b/agent/service_manager_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "encoding/json" "fmt" "os" @@ -777,7 +778,7 @@ func testApplyConfigEntries(t *testing.T, a *TestAgent, entries ...structs.Confi Entry: entry, } var out bool - require.NoError(t, a.RPC("ConfigEntry.Apply", args, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", args, &out)) } } diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index 8e76d577ab..2f10b149b2 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -53,7 +53,7 @@ func (s *HTTPHandlers) SessionCreate(resp http.ResponseWriter, req *http.Request // Create the session, get the ID var out string - if err := s.agent.RPC("Session.Apply", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Session.Apply", &args, &out); err != nil { return nil, err } @@ -80,7 +80,7 @@ func (s *HTTPHandlers) SessionDestroy(resp http.ResponseWriter, req *http.Reques } var out string - if err := s.agent.RPC("Session.Apply", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Session.Apply", &args, &out); err != nil { return nil, err } return true, nil @@ -104,7 +104,7 @@ func (s *HTTPHandlers) SessionRenew(resp http.ResponseWriter, req *http.Request) } var out structs.IndexedSessions - if err := s.agent.RPC("Session.Renew", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Session.Renew", &args, &out); err != nil { return nil, err } else if out.Sessions == nil { return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)} @@ -132,7 +132,7 @@ func (s *HTTPHandlers) SessionGet(resp http.ResponseWriter, req *http.Request) ( var out structs.IndexedSessions defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("Session.Get", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Session.Get", &args, &out); err != nil { return nil, err } @@ -155,7 +155,7 @@ func (s *HTTPHandlers) SessionList(resp http.ResponseWriter, req *http.Request) var out structs.IndexedSessions defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("Session.List", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Session.List", &args, &out); err != nil { return nil, err } @@ -184,7 +184,7 @@ func (s *HTTPHandlers) SessionsForNode(resp http.ResponseWriter, req *http.Reque var out structs.IndexedSessions defer setMeta(resp, &out.QueryMeta) - if err := s.agent.RPC("Session.NodeSessions", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Session.NodeSessions", &args, &out); err != nil { return nil, err } diff --git a/agent/session_endpoint_test.go b/agent/session_endpoint_test.go index eb50e99fde..a1b011f25f 100644 --- a/agent/session_endpoint_test.go +++ b/agent/session_endpoint_test.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -25,7 +26,7 @@ func verifySession(t *testing.T, r *retry.R, a *TestAgent, want structs.Session) SessionID: want.ID, } var out structs.IndexedSessions - if err := a.RPC("Session.Get", args, &out); err != nil { + if err := a.RPC(context.Background(), "Session.Get", args, &out); err != nil { r.Fatalf("err: %v", err) } if len(out.Sessions) != 1 { @@ -88,7 +89,7 @@ func TestSessionCreate(t *testing.T) { retry.Run(t, func(r *retry.R) { var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { r.Fatalf("err: %v", err) } @@ -150,7 +151,7 @@ func TestSessionCreate_NodeChecks(t *testing.T) { retry.Run(t, func(r *retry.R) { var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { r.Fatalf("err: %v", err) } @@ -213,7 +214,7 @@ func TestSessionCreate_Delete(t *testing.T) { } retry.Run(t, func(r *retry.R) { var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { r.Fatalf("err: %v", err) } diff --git a/agent/status_endpoint.go b/agent/status_endpoint.go index 5cc329ac9e..0fbd26837b 100644 --- a/agent/status_endpoint.go +++ b/agent/status_endpoint.go @@ -13,7 +13,7 @@ func (s *HTTPHandlers) StatusLeader(resp http.ResponseWriter, req *http.Request) } var out string - if err := s.agent.RPC("Status.Leader", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Status.Leader", &args, &out); err != nil { return nil, err } return out, nil @@ -26,7 +26,7 @@ func (s *HTTPHandlers) StatusPeers(resp http.ResponseWriter, req *http.Request) } var out []string - if err := s.agent.RPC("Status.Peers", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Status.Peers", &args, &out); err != nil { return nil, err } return out, nil diff --git a/agent/testagent.go b/agent/testagent.go index 9642fca668..db08be40a4 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -296,7 +296,7 @@ func (a *TestAgent) waitForUp() error { MaxQueryTime: 25 * time.Millisecond, }, } - if err := a.RPC("Catalog.ListNodes", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.ListNodes", args, &out); err != nil { retErr = fmt.Errorf("Catalog.ListNodes failed: %v", err) continue // fail, try again } diff --git a/agent/txn_endpoint.go b/agent/txn_endpoint.go index 55c687503a..1eeb4b6d49 100644 --- a/agent/txn_endpoint.go +++ b/agent/txn_endpoint.go @@ -356,7 +356,7 @@ func (s *HTTPHandlers) Txn(resp http.ResponseWriter, req *http.Request) (interfa } var reply structs.TxnReadResponse - if err := s.agent.RPC("Txn.Read", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Txn.Read", &args, &reply); err != nil { return nil, err } @@ -372,7 +372,7 @@ func (s *HTTPHandlers) Txn(resp http.ResponseWriter, req *http.Request) (interfa s.parseToken(req, &args.Token) var reply structs.TxnResponse - if err := s.agent.RPC("Txn.Apply", &args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Txn.Apply", &args, &reply); err != nil { return nil, err } ret, conflict = reply, len(reply.Errors) > 0 diff --git a/agent/ui_endpoint.go b/agent/ui_endpoint.go index dde74c5853..aaeb7003e6 100644 --- a/agent/ui_endpoint.go +++ b/agent/ui_endpoint.go @@ -98,7 +98,7 @@ func (s *HTTPHandlers) UINodes(resp http.ResponseWriter, req *http.Request) (int var out structs.IndexedNodeDump defer setMeta(resp, &out.QueryMeta) RPC: - if err := s.agent.RPC("Internal.NodeDump", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.NodeDump", &args, &out); err != nil { // Retry the request allowing stale data if no leader if strings.Contains(err.Error(), structs.ErrNoLeader.Error()) && !args.AllowStale { args.AllowStale = true @@ -160,7 +160,7 @@ func (s *HTTPHandlers) UINodeInfo(resp http.ResponseWriter, req *http.Request) ( var out structs.IndexedNodeDump defer setMeta(resp, &out.QueryMeta) RPC: - if err := s.agent.RPC("Internal.NodeInfo", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.NodeInfo", &args, &out); err != nil { // Retry the request allowing stale data if no leader if strings.Contains(err.Error(), structs.ErrNoLeader.Error()) && !args.AllowStale { args.AllowStale = true @@ -196,7 +196,7 @@ func (s *HTTPHandlers) UICatalogOverview(resp http.ResponseWriter, req *http.Req // Make the RPC request var out structs.CatalogSummary - if err := s.agent.RPC("Internal.CatalogOverview", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.CatalogOverview", &args, &out); err != nil { return nil, err } @@ -224,7 +224,7 @@ func (s *HTTPHandlers) UIServices(resp http.ResponseWriter, req *http.Request) ( var out structs.IndexedNodesWithGateways defer setMeta(resp, &out.QueryMeta) RPC: - if err := s.agent.RPC("Internal.ServiceDump", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.ServiceDump", &args, &out); err != nil { // Retry the request allowing stale data if no leader if strings.Contains(err.Error(), structs.ErrNoLeader.Error()) && !args.AllowStale { args.AllowStale = true @@ -293,7 +293,7 @@ func (s *HTTPHandlers) UIGatewayServicesNodes(resp http.ResponseWriter, req *htt var out structs.IndexedServiceDump defer setMeta(resp, &out.QueryMeta) RPC: - if err := s.agent.RPC("Internal.GatewayServiceDump", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.GatewayServiceDump", &args, &out); err != nil { // Retry the request allowing stale data if no leader if strings.Contains(err.Error(), structs.ErrNoLeader.Error()) && !args.AllowStale { args.AllowStale = true @@ -346,7 +346,7 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req var out structs.IndexedServiceTopology defer setMeta(resp, &out.QueryMeta) RPC: - if err := s.agent.RPC("Internal.ServiceTopology", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.ServiceTopology", &args, &out); err != nil { // Retry the request allowing stale data if no leader if strings.Contains(err.Error(), structs.ErrNoLeader.Error()) && !args.AllowStale { args.AllowStale = true @@ -631,7 +631,7 @@ func (s *HTTPHandlers) UIGatewayIntentions(resp http.ResponseWriter, req *http.R var reply structs.IndexedIntentions defer setMeta(resp, &reply.QueryMeta) - if err := s.agent.RPC("Internal.GatewayIntentions", args, &reply); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.GatewayIntentions", args, &reply); err != nil { return nil, err } @@ -801,7 +801,7 @@ func (s *HTTPHandlers) UIExportedServices(resp http.ResponseWriter, req *http.Re var out structs.IndexedServiceList defer setMeta(resp, &out.QueryMeta) RPC: - if err := s.agent.RPC("Internal.ExportedServicesForPeer", &args, &out); err != nil { + if err := s.agent.RPC(req.Context(), "Internal.ExportedServicesForPeer", &args, &out); err != nil { // Retry the request allowing stale data if no leader if strings.Contains(err.Error(), structs.ErrNoLeader.Error()) && !args.AllowStale { args.AllowStale = true diff --git a/agent/ui_endpoint_oss_test.go b/agent/ui_endpoint_oss_test.go index 2022c32c6b..14fbacd066 100644 --- a/agent/ui_endpoint_oss_test.go +++ b/agent/ui_endpoint_oss_test.go @@ -4,6 +4,7 @@ package agent import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -76,7 +77,7 @@ func TestUIEndpoint_MetricsProxy_ACLDeny(t *testing.T) { WriteRequest: structs.WriteRequest{Token: "root"}, } var policy structs.ACLPolicy - require.NoError(t, a.RPC("ACL.PolicySet", &req, &policy)) + require.NoError(t, a.RPC(context.Background(), "ACL.PolicySet", &req, &policy)) } makeToken := func(t *testing.T, policyNames []string) string { @@ -91,7 +92,7 @@ func TestUIEndpoint_MetricsProxy_ACLDeny(t *testing.T) { require.Len(t, req.ACLToken.Policies, len(policyNames)) var token structs.ACLToken - require.NoError(t, a.RPC("ACL.TokenSet", &req, &token)) + require.NoError(t, a.RPC(context.Background(), "ACL.TokenSet", &req, &token)) return token.SecretID } diff --git a/agent/ui_endpoint_test.go b/agent/ui_endpoint_test.go index 1bd9ff6c19..2d7ff7feb5 100644 --- a/agent/ui_endpoint_test.go +++ b/agent/ui_endpoint_test.go @@ -106,7 +106,7 @@ func TestUINodes(t *testing.T) { for _, reg := range args { var out struct{} - err := a.RPC("Catalog.Register", reg, &out) + err := a.RPC(context.Background(), "Catalog.Register", reg, &out) require.NoError(t, err) } @@ -181,7 +181,7 @@ func TestUINodes_Filter(t *testing.T) { } var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) args = &structs.RegisterRequest{ Datacenter: "dc1", @@ -191,7 +191,7 @@ func TestUINodes_Filter(t *testing.T) { "os": "macos", }, } - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) req, _ := http.NewRequest("GET", "/v1/internal/ui/nodes/dc1?filter="+url.QueryEscape("Meta.os == linux"), nil) resp := httptest.NewRecorder() @@ -237,7 +237,7 @@ func TestUINodeInfo(t *testing.T) { } var out struct{} - if err := a.RPC("Catalog.Register", args, &out); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { t.Fatalf("err: %v", err) } @@ -392,7 +392,7 @@ func TestUIServices(t *testing.T) { for _, args := range requests { var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // establish "peer1" @@ -427,7 +427,7 @@ func TestUIServices(t *testing.T) { }, } var regOutput struct{} - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) args := &structs.TerminatingGatewayConfigEntry{ Name: "terminating-gateway", @@ -448,7 +448,7 @@ func TestUIServices(t *testing.T) { Entry: args, } var configOutput bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &configOutput)) require.True(t, configOutput) // Web should not show up as ConnectedWithGateway since this one does not have any instances @@ -467,7 +467,7 @@ func TestUIServices(t *testing.T) { Datacenter: "dc1", Entry: args, } - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &configOutput)) require.True(t, configOutput) } @@ -755,7 +755,7 @@ func TestUIExportedServices(t *testing.T) { for _, args := range requests { var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } // establish "peer1" @@ -791,7 +791,7 @@ func TestUIExportedServices(t *testing.T) { Entry: args, } var configOutput bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &configOutput)) require.True(t, configOutput) } @@ -874,7 +874,7 @@ func TestUIGatewayServiceNodes_Terminating(t *testing.T) { }, } var regOutput struct{} - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) arg = structs.RegisterRequest{ Datacenter: "dc1", @@ -891,7 +891,7 @@ func TestUIGatewayServiceNodes_Terminating(t *testing.T) { ServiceID: "db", }, } - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) arg = structs.RegisterRequest{ Datacenter: "dc1", @@ -908,7 +908,7 @@ func TestUIGatewayServiceNodes_Terminating(t *testing.T) { ServiceID: "db2", }, } - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) } { @@ -944,7 +944,7 @@ func TestUIGatewayServiceNodes_Terminating(t *testing.T) { Entry: args, } var configOutput bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &configOutput)) require.True(t, configOutput) } @@ -1012,7 +1012,7 @@ func TestUIGatewayServiceNodes_Ingress(t *testing.T) { }, } var regOutput struct{} - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) arg = structs.RegisterRequest{ Datacenter: "dc1", @@ -1029,7 +1029,7 @@ func TestUIGatewayServiceNodes_Ingress(t *testing.T) { ServiceID: "db", }, } - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) arg = structs.RegisterRequest{ Datacenter: "dc1", @@ -1046,7 +1046,7 @@ func TestUIGatewayServiceNodes_Ingress(t *testing.T) { ServiceID: "db2", }, } - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) // Set web protocol to http svcDefaultsReq := structs.ConfigEntryRequest{ @@ -1057,7 +1057,7 @@ func TestUIGatewayServiceNodes_Ingress(t *testing.T) { }, } var configOutput bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &svcDefaultsReq, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &svcDefaultsReq, &configOutput)) require.True(t, configOutput) // Register ingress-gateway config entry, linking it to db and redis (does not exist) @@ -1101,7 +1101,7 @@ func TestUIGatewayServiceNodes_Ingress(t *testing.T) { Datacenter: "dc1", Entry: args, } - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &configOutput)) require.True(t, configOutput) } @@ -1190,7 +1190,7 @@ func TestUIGatewayIntentions(t *testing.T) { }, } var regOutput struct{} - require.NoError(t, a.RPC("Catalog.Register", &arg, ®Output)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", &arg, ®Output)) args := &structs.TerminatingGatewayConfigEntry{ Name: "terminating-gateway", @@ -1214,7 +1214,7 @@ func TestUIGatewayIntentions(t *testing.T) { Entry: args, } var configOutput bool - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &configOutput)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &configOutput)) require.True(t, configOutput) } @@ -1230,7 +1230,7 @@ func TestUIGatewayIntentions(t *testing.T) { req.Intention.DestinationName = v var reply string - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) req = structs.IntentionRequest{ Datacenter: "dc1", @@ -1239,7 +1239,7 @@ func TestUIGatewayIntentions(t *testing.T) { } req.Intention.SourceName = v req.Intention.DestinationName = "api" - require.NoError(t, a.RPC("Intention.Apply", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "Intention.Apply", &req, &reply)) } } @@ -1698,7 +1698,7 @@ func TestUIServiceTopology(t *testing.T) { } for _, args := range registrations { var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } } @@ -1836,7 +1836,7 @@ func TestUIServiceTopology(t *testing.T) { } for _, req := range entries { out := false - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &out)) } } @@ -2294,7 +2294,7 @@ func TestUIServiceTopology_RoutingConfigs(t *testing.T) { } for _, args := range registrations { var out struct{} - require.NoError(t, a.RPC("Catalog.Register", args, &out)) + require.NoError(t, a.RPC(context.Background(), "Catalog.Register", args, &out)) } } { @@ -2341,7 +2341,7 @@ func TestUIServiceTopology_RoutingConfigs(t *testing.T) { } for _, req := range entries { out := false - require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out)) + require.NoError(t, a.RPC(context.Background(), "ConfigEntry.Apply", &req, &out)) } } diff --git a/agent/user_event.go b/agent/user_event.go index bcbd9c5a41..ab7aa83ca3 100644 --- a/agent/user_event.go +++ b/agent/user_event.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "fmt" "regexp" @@ -105,7 +106,7 @@ func (a *Agent) UserEvent(dc, token string, params *UserEvent) error { // gossip will take over anyways args.AllowStale = true var out structs.EventFireResponse - return a.RPC("Internal.EventFire", &args, &out) + return a.RPC(context.Background(), "Internal.EventFire", &args, &out) } // handleEvents is used to process incoming user events diff --git a/agent/user_event_test.go b/agent/user_event_test.go index 8cae94dde2..419e0ca12a 100644 --- a/agent/user_event_test.go +++ b/agent/user_event_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "strings" "testing" @@ -231,7 +232,7 @@ func TestUserEventToken(t *testing.T) { } type RPC interface { - RPC(method string, args interface{}, reply interface{}) error + RPC(ctx context.Context, method string, args interface{}, reply interface{}) error } func createToken(t *testing.T, rpc RPC, policyRules string) string { @@ -245,7 +246,7 @@ func createToken(t *testing.T, rpc RPC, policyRules string) string { }, WriteRequest: structs.WriteRequest{Token: "root"}, } - err := rpc.RPC("ACL.PolicySet", &reqPolicy, &structs.ACLPolicy{}) + err := rpc.RPC(context.Background(), "ACL.PolicySet", &reqPolicy, &structs.ACLPolicy{}) require.NoError(t, err) token, err := uuid.GenerateUUID() @@ -259,7 +260,7 @@ func createToken(t *testing.T, rpc RPC, policyRules string) string { }, WriteRequest: structs.WriteRequest{Token: "root"}, } - err = rpc.RPC("ACL.TokenSet", &reqToken, &structs.ACLToken{}) + err = rpc.RPC(context.Background(), "ACL.TokenSet", &reqToken, &structs.ACLToken{}) require.NoError(t, err) return token } diff --git a/command/connect/ca/set/connect_ca_set_test.go b/command/connect/ca/set/connect_ca_set_test.go index a1c1576b7a..76bb90e698 100644 --- a/command/connect/ca/set/connect_ca_set_test.go +++ b/command/connect/ca/set/connect_ca_set_test.go @@ -1,6 +1,7 @@ package set import ( + "context" "strings" "testing" "time" @@ -48,7 +49,7 @@ func TestConnectCASetConfigCommand(t *testing.T) { Datacenter: "dc1", } var reply structs.CAConfiguration - require.NoError(t, a.RPC("ConnectCA.ConfigurationGet", &req, &reply)) + require.NoError(t, a.RPC(context.Background(), "ConnectCA.ConfigurationGet", &req, &reply)) require.Equal(t, "consul", reply.Provider) parsed, err := ca.ParseConsulCAConfig(reply.Config) diff --git a/command/operator/autopilot/set/operator_autopilot_set_test.go b/command/operator/autopilot/set/operator_autopilot_set_test.go index bbd263fcd1..21963216fe 100644 --- a/command/operator/autopilot/set/operator_autopilot_set_test.go +++ b/command/operator/autopilot/set/operator_autopilot_set_test.go @@ -1,6 +1,7 @@ package set import ( + "context" "strings" "testing" "time" @@ -53,7 +54,7 @@ func TestOperatorAutopilotSetConfigCommand(t *testing.T) { Datacenter: "dc1", } var reply structs.AutopilotConfig - if err := a.RPC("Operator.AutopilotGetConfiguration", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Operator.AutopilotGetConfiguration", &req, &reply); err != nil { t.Fatalf("err: %v", err) } diff --git a/command/rtt/rtt_test.go b/command/rtt/rtt_test.go index cfe236f343..a5bf58a24b 100644 --- a/command/rtt/rtt_test.go +++ b/command/rtt/rtt_test.go @@ -1,6 +1,7 @@ package rtt import ( + "context" "fmt" "strings" "testing" @@ -73,7 +74,7 @@ func TestRTTCommand_LAN(t *testing.T) { Coord: c1, } var reply struct{} - if err := a.RPC("Coordinate.Update", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &req, &reply); err != nil { t.Fatalf("err: %s", err) } } @@ -84,7 +85,7 @@ func TestRTTCommand_LAN(t *testing.T) { Address: "127.0.0.2", } var reply struct{} - if err := a.RPC("Catalog.Register", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Catalog.Register", &req, &reply); err != nil { t.Fatalf("err: %s", err) } } @@ -95,7 +96,7 @@ func TestRTTCommand_LAN(t *testing.T) { Node: "dogs", Coord: c2, } - if err := a.RPC("Coordinate.Update", &req, &reply); err != nil { + if err := a.RPC(context.Background(), "Coordinate.Update", &req, &reply); err != nil { t.Fatalf("err: %s", err) } } diff --git a/go.mod b/go.mod index cc1f74be85..0d7d0ebe79 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/google/tcpproxy v0.0.0-20180808230851-dfa16c61dad2 github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 github.com/hashicorp/consul-awsauth v0.0.0-20220713182709-05ac1c5c2706 - github.com/hashicorp/consul-net-rpc v0.0.0-20220307172752-3602954411b4 + github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69 github.com/hashicorp/consul/api v1.18.0 github.com/hashicorp/consul/proto-public v0.2.1 github.com/hashicorp/consul/sdk v0.13.0 diff --git a/go.sum b/go.sum index 3918f3d42a..5132c306e0 100644 --- a/go.sum +++ b/go.sum @@ -455,8 +455,8 @@ github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= github.com/hashicorp/consul-awsauth v0.0.0-20220713182709-05ac1c5c2706 h1:1ZEjnveDe20yFa6lSkfdQZm5BR/b271n0MsB5R2L3us= github.com/hashicorp/consul-awsauth v0.0.0-20220713182709-05ac1c5c2706/go.mod h1:1Cs8FlmD1BfSQXJGcFLSV5FuIx1AbJP+EJGdxosoS2g= -github.com/hashicorp/consul-net-rpc v0.0.0-20220307172752-3602954411b4 h1:Com/5n/omNSBusX11zdyIYtidiqewLIanchbm//McZA= -github.com/hashicorp/consul-net-rpc v0.0.0-20220307172752-3602954411b4/go.mod h1:vWEAHAeAqfOwB3pSgHMQpIu8VH1jL+Ltg54Tw0wt/NI= +github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69 h1:wzWurXrxfSyG1PHskIZlfuXlTSCj1Tsyatp9DtaasuY= +github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69/go.mod h1:svUZZDvotY8zTODknUePc6mZ9pX8nN0ViGwWcUSOBEA= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-bexpr v0.1.2 h1:ijMXI4qERbzxbCnkxmfUtwMyjrrk3y+Vt0MxojNCbBs= diff --git a/test/integration/consul-container/go.mod b/test/integration/consul-container/go.mod index 4dd37cea1c..ec478dd155 100644 --- a/test/integration/consul-container/go.mod +++ b/test/integration/consul-container/go.mod @@ -67,7 +67,7 @@ require ( github.com/gorilla/mux v1.7.3 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 // indirect github.com/hashicorp/consul-awsauth v0.0.0-20220713182709-05ac1c5c2706 // indirect - github.com/hashicorp/consul-net-rpc v0.0.0-20220307172752-3602954411b4 // indirect + github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69 // indirect github.com/hashicorp/consul/proto-public v0.2.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-bexpr v0.1.2 // indirect diff --git a/test/integration/consul-container/go.sum b/test/integration/consul-container/go.sum index 0348b0bea3..86a8443c87 100644 --- a/test/integration/consul-container/go.sum +++ b/test/integration/consul-container/go.sum @@ -488,8 +488,8 @@ github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/hashicorp/consul-awsauth v0.0.0-20220713182709-05ac1c5c2706 h1:1ZEjnveDe20yFa6lSkfdQZm5BR/b271n0MsB5R2L3us= github.com/hashicorp/consul-awsauth v0.0.0-20220713182709-05ac1c5c2706/go.mod h1:1Cs8FlmD1BfSQXJGcFLSV5FuIx1AbJP+EJGdxosoS2g= -github.com/hashicorp/consul-net-rpc v0.0.0-20220307172752-3602954411b4 h1:Com/5n/omNSBusX11zdyIYtidiqewLIanchbm//McZA= -github.com/hashicorp/consul-net-rpc v0.0.0-20220307172752-3602954411b4/go.mod h1:vWEAHAeAqfOwB3pSgHMQpIu8VH1jL+Ltg54Tw0wt/NI= +github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69 h1:wzWurXrxfSyG1PHskIZlfuXlTSCj1Tsyatp9DtaasuY= +github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69/go.mod h1:svUZZDvotY8zTODknUePc6mZ9pX8nN0ViGwWcUSOBEA= github.com/hashicorp/consul/proto-public v0.2.1 h1:9dZGW68IEuajEkaAAdXCUovVuKyccBOS0jub4Gee5II= github.com/hashicorp/consul/proto-public v0.2.1/go.mod h1:iWNlBDJIZQJC3bBiCThoqg9i7uk/4RQZYkqH1wiQrss= github.com/hashicorp/errwrap v0.0.0-20141028054710-7554cd9344ce/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/testrpc/wait.go b/testrpc/wait.go index 39e3d65922..31208c0c48 100644 --- a/testrpc/wait.go +++ b/testrpc/wait.go @@ -1,6 +1,7 @@ package testrpc import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -9,7 +10,7 @@ import ( "github.com/hashicorp/consul/sdk/testutil/retry" ) -type rpcFn func(string, interface{}, interface{}) error +type rpcFn func(context.Context, string, interface{}, interface{}) error // WaitForLeader ensures we have a leader and a node registration. It // does not wait for the Consul (node) service to be ready. Use `WaitForTestAgent` @@ -31,7 +32,7 @@ func WaitForLeader(t *testing.T, rpc rpcFn, dc string, options ...waitOption) { Datacenter: dc, QueryOptions: structs.QueryOptions{Token: flat.Token}, } - if err := rpc("Catalog.ListNodes", args, &out); err != nil { + if err := rpc(context.Background(), "Catalog.ListNodes", args, &out); err != nil { r.Fatalf("Catalog.ListNodes failed: %v", err) } if !out.QueryMeta.KnownLeader { @@ -58,7 +59,7 @@ func WaitUntilNoLeader(t *testing.T, rpc rpcFn, dc string, options ...waitOption Datacenter: dc, QueryOptions: structs.QueryOptions{Token: flat.Token}, } - if err := rpc("Catalog.ListNodes", args, &out); err == nil { + if err := rpc(context.Background(), "Catalog.ListNodes", args, &out); err == nil { r.Fatalf("It still has a leader: %#v", out) } if out.QueryMeta.KnownLeader { @@ -108,7 +109,7 @@ func WaitForTestAgent(t *testing.T, rpc rpcFn, dc string, options ...waitOption) Datacenter: dc, QueryOptions: structs.QueryOptions{Token: flat.Token}, } - if err := rpc("Catalog.ListNodes", dcReq, &nodes); err != nil { + if err := rpc(context.Background(), "Catalog.ListNodes", dcReq, &nodes); err != nil { r.Fatalf("Catalog.ListNodes failed: %v", err) } if len(nodes.Nodes) == 0 { @@ -127,7 +128,7 @@ func WaitForTestAgent(t *testing.T, rpc rpcFn, dc string, options ...waitOption) Node: nodes.Nodes[0].Node, QueryOptions: structs.QueryOptions{Token: flat.Token}, } - if err := rpc("Health.NodeChecks", nodeReq, &checks); err != nil { + if err := rpc(context.Background(), "Health.NodeChecks", nodeReq, &checks); err != nil { r.Fatalf("Health.NodeChecks failed: %v", err) } @@ -156,7 +157,7 @@ func WaitForActiveCARoot(t *testing.T, rpc rpcFn, dc string, expect *structs.CAR Datacenter: dc, } var reply structs.IndexedCARoots - if err := rpc("ConnectCA.Roots", args, &reply); err != nil { + if err := rpc(context.Background(), "ConnectCA.Roots", args, &reply); err != nil { r.Fatalf("err: %v", err) } @@ -185,7 +186,7 @@ func WaitForServiceIntentions(t *testing.T, rpc rpcFn, dc string) { }, } var ignored structs.ConfigEntryDeleteResponse - if err := rpc("ConfigEntry.Delete", args, &ignored); err != nil { + if err := rpc(context.Background(), "ConfigEntry.Delete", args, &ignored); err != nil { r.Fatalf("err: %v", err) } }) @@ -198,7 +199,7 @@ func WaitForACLReplication(t *testing.T, rpc rpcFn, dc string, expectedReplicati } var reply structs.ACLReplicationStatus - require.NoError(r, rpc("ACL.ReplicationStatus", &args, &reply)) + require.NoError(r, rpc(context.Background(), "ACL.ReplicationStatus", &args, &reply)) require.Equal(r, expectedReplicationType, reply.ReplicationType) require.True(r, reply.Running, "Server not running new replicator yet")