diff --git a/agent/consul/leader_peering_test.go b/agent/consul/leader_peering_test.go index 5f59787764..373d80f25d 100644 --- a/agent/consul/leader_peering_test.go +++ b/agent/consul/leader_peering_test.go @@ -567,7 +567,7 @@ func testLeader_PeeringSync_failsForTLSError(t *testing.T, tokenMutateFn func(to } // Since the Establish RPC dials the remote cluster, it will yield the TLS error. - ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) t.Cleanup(cancel) _, err = s2Client.Establish(ctx, &establishReq) require.Contains(t, err.Error(), expectErr) diff --git a/agent/consul/peering_backend.go b/agent/consul/peering_backend.go index 26c2f19436..064015cb67 100644 --- a/agent/consul/peering_backend.go +++ b/agent/consul/peering_backend.go @@ -1,12 +1,16 @@ package consul import ( + "container/ring" "encoding/base64" "encoding/json" "fmt" "strconv" "sync" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl/resolver" "github.com/hashicorp/consul/agent/connect" @@ -85,29 +89,117 @@ func (b *PeeringBackend) GetTLSMaterials(generatingToken bool) (string, []string return serverName, caPems, nil } -// GetServerAddresses looks up server or mesh gateway addresses from the state store. -func (b *PeeringBackend) GetServerAddresses() ([]string, error) { - _, rawEntry, err := b.srv.fsm.State().ConfigEntry(nil, structs.MeshConfig, structs.MeshConfigMesh, acl.DefaultEnterpriseMeta()) - if err != nil { - return nil, fmt.Errorf("failed to read mesh config entry: %w", err) - } +// GetLocalServerAddresses looks up server or mesh gateway addresses from the state store for a peer to dial. +func (b *PeeringBackend) GetLocalServerAddresses() ([]string, error) { + store := b.srv.fsm.State() - meshConfig, ok := rawEntry.(*structs.MeshConfigEntry) - if ok && meshConfig.Peering != nil && meshConfig.Peering.PeerThroughMeshGateways { - return meshGatewayAdresses(b.srv.fsm.State()) + useGateways, err := b.PeerThroughMeshGateways(nil) + if err != nil { + // For inbound traffic we prefer to fail fast if we can't determine whether we should peer through MGW. + // This prevents unexpectedly sharing local server addresses when a user only intended to peer through gateways. + return nil, fmt.Errorf("failed to determine if peering should happen through mesh gateways: %w", err) } - return serverAddresses(b.srv.fsm.State()) + if useGateways { + return meshGatewayAdresses(store, nil, true) + } + return serverAddresses(store) } -func meshGatewayAdresses(state *state.Store) ([]string, error) { - _, nodes, err := state.ServiceDump(nil, structs.ServiceKindMeshGateway, true, acl.DefaultEnterpriseMeta(), structs.DefaultPeerKeyword) +// GetDialAddresses returns: the addresses to cycle through when dialing a peer's servers, +// a boolean indicating whether mesh gateways are present, and an optional error. +// The resulting ring buffer is front-loaded with the local mesh gateway addresses if they are present. +func (b *PeeringBackend) GetDialAddresses(logger hclog.Logger, ws memdb.WatchSet, peerID string) (*ring.Ring, bool, error) { + newRing, err := b.fetchPeerServerAddresses(ws, peerID) + if err != nil { + return nil, false, fmt.Errorf("failed to refresh peer server addresses, will continue to use initial addresses: %w", err) + } + + gatewayRing, err := b.maybeFetchGatewayAddresses(ws) + if err != nil { + // If we couldn't fetch the mesh gateway addresses we fall back to dialing the remote server addresses. + logger.Warn("failed to refresh local gateway addresses, will attempt to dial peer directly: %w", "error", err) + return newRing, false, nil + } + if gatewayRing != nil { + // The ordering is important here. We always want to start with the mesh gateway + // addresses and fallback to the remote addresses, so we append the server addresses + // in newRing to gatewayRing. + newRing = gatewayRing.Link(newRing) + } + return newRing, gatewayRing != nil, nil +} + +// fetchPeerServerAddresses will return a ring buffer with the latest peer server addresses. +// If the peering is no longer active or does not have addresses, then we return an error. +func (b *PeeringBackend) fetchPeerServerAddresses(ws memdb.WatchSet, peerID string) (*ring.Ring, error) { + _, peering, err := b.Store().PeeringReadByID(ws, peerID) + if err != nil { + return nil, fmt.Errorf("failed to fetch peer %q: %w", peerID, err) + } + if !peering.IsActive() { + return nil, fmt.Errorf("there is no active peering for %q", peerID) + } + + // IMPORTANT: The address ring buffer must always be length > 0 + if len(peering.PeerServerAddresses) == 0 { + return nil, fmt.Errorf("peer %q has no addresses to dial", peerID) + } + return bufferFromAddresses(peering.PeerServerAddresses), nil +} + +// maybeFetchGatewayAddresses will return a ring buffer with the latest gateway addresses if the +// local datacenter is configured to peer through mesh gateways and there are local gateways registered. +// If neither of these are true then we return a nil buffer. +func (b *PeeringBackend) maybeFetchGatewayAddresses(ws memdb.WatchSet) (*ring.Ring, error) { + useGateways, err := b.PeerThroughMeshGateways(ws) + if err != nil { + return nil, fmt.Errorf("failed to determine if peering should happen through mesh gateways: %w", err) + } + if useGateways { + addresses, err := meshGatewayAdresses(b.srv.fsm.State(), ws, false) + + // IMPORTANT: The address ring buffer must always be length > 0 + if err != nil || len(addresses) == 0 { + return nil, fmt.Errorf("error fetching local mesh gateway addresses: %w", err) + } + return bufferFromAddresses(addresses), nil + } + return nil, nil +} + +// PeerThroughMeshGateways determines if the config entry to enable peering control plane +// traffic through a mesh gateway is set to enable. +func (b *PeeringBackend) PeerThroughMeshGateways(ws memdb.WatchSet) (bool, error) { + _, rawEntry, err := b.srv.fsm.State().ConfigEntry(ws, structs.MeshConfig, structs.MeshConfigMesh, acl.DefaultEnterpriseMeta()) + if err != nil { + return false, fmt.Errorf("failed to read mesh config entry: %w", err) + } + mesh, ok := rawEntry.(*structs.MeshConfigEntry) + if rawEntry != nil && !ok { + return false, fmt.Errorf("got unexpected type for mesh config entry: %T", rawEntry) + } + return mesh.PeerThroughMeshGateways(), nil + +} + +func bufferFromAddresses(addresses []string) *ring.Ring { + ring := ring.New(len(addresses)) + for _, addr := range addresses { + ring.Value = addr + ring = ring.Next() + } + return ring +} + +func meshGatewayAdresses(state *state.Store, ws memdb.WatchSet, wan bool) ([]string, error) { + _, nodes, err := state.ServiceDump(ws, structs.ServiceKindMeshGateway, true, acl.DefaultEnterpriseMeta(), structs.DefaultPeerKeyword) if err != nil { return nil, fmt.Errorf("failed to dump gateway addresses: %w", err) } var addrs []string for _, node := range nodes { - _, addr, port := node.BestAddress(true) + _, addr, port := node.BestAddress(wan) addrs = append(addrs, ipaddr.FormatAddressPort(addr, port)) } if len(addrs) == 0 { diff --git a/agent/consul/peering_backend_test.go b/agent/consul/peering_backend_test.go index 0d834c09a9..63a42e07a9 100644 --- a/agent/consul/peering_backend_test.go +++ b/agent/consul/peering_backend_test.go @@ -10,6 +10,7 @@ import ( gogrpc "google.golang.org/grpc" "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/proto/pbpeering" @@ -76,7 +77,7 @@ func TestPeeringBackend_ForwardToLeader(t *testing.T) { }) } -func TestPeeringBackend_GetServerAddresses(t *testing.T) { +func TestPeeringBackend_GetLocalServerAddresses(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } @@ -91,7 +92,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { backend := NewPeeringBackend(srv) testutil.RunStep(t, "peer to servers", func(t *testing.T) { - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.NoError(t, err) expect := fmt.Sprintf("127.0.0.1:%d", srv.config.GRPCTLSPort) @@ -107,7 +108,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { } require.NoError(t, srv.fsm.State().EnsureConfigEntry(1, &mesh)) - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.NoError(t, err) // Still expect server address because PeerThroughMeshGateways was not enabled. @@ -121,7 +122,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { } require.NoError(t, srv.fsm.State().EnsureConfigEntry(1, &mesh)) - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.Nil(t, addrs) testutil.RequireErrorContains(t, err, "servers are configured to PeerThroughMeshGateways, but no mesh gateway instances are registered") @@ -147,12 +148,221 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { } require.NoError(t, srv.fsm.State().EnsureRegistration(2, ®)) - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.NoError(t, err) require.Equal(t, []string{"154.238.12.252:8443"}, addrs) }) } +func TestPeeringBackend_GetDialAddresses(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + _, cfg := testServerConfig(t) + cfg.GRPCTLSPort = freeport.GetOne(t) + + srv, err := newServer(t, cfg) + require.NoError(t, err) + testrpc.WaitForLeader(t, srv.RPC, "dc1") + + backend := NewPeeringBackend(srv) + + dialerPeerID := testUUID() + acceptorPeerID := testUUID() + + type expectation struct { + addrs []string + haveGateways bool + err string + } + + type testCase struct { + name string + setup func(store *state.Store) + peerID string + expect expectation + } + + run := func(t *testing.T, tc testCase) { + if tc.setup != nil { + tc.setup(srv.fsm.State()) + } + + ring, haveGateways, err := backend.GetDialAddresses(testutil.Logger(t), nil, tc.peerID) + if tc.expect.err != "" { + testutil.RequireErrorContains(t, err, tc.expect.err) + return + } + require.Equal(t, tc.expect.haveGateways, haveGateways) + require.NotNil(t, ring) + + var addrs []string + ring.Do(func(value any) { + addr, ok := value.(string) + + require.True(t, ok) + addrs = append(addrs, addr) + }) + require.Equal(t, tc.expect.addrs, addrs) + } + + // NOTE: The following tests are set up to run serially with RunStep to save on the setup/teardown cost for a test server. + tt := []testCase{ + { + name: "unknown peering", + setup: func(store *state.Store) { + // Test peering is not written during setup + }, + peerID: acceptorPeerID, + expect: expectation{ + err: fmt.Sprintf(`there is no active peering for %q`, acceptorPeerID), + }, + }, + { + name: "no server addresses", + setup: func(store *state.Store) { + require.NoError(t, store.PeeringWrite(1, &pbpeering.PeeringWriteRequest{ + Peering: &pbpeering.Peering{ + Name: "acceptor", + ID: acceptorPeerID, + // Acceptor peers do not have PeerServerAddresses populated locally. + }, + })) + }, + peerID: acceptorPeerID, + expect: expectation{ + err: fmt.Sprintf(`peer %q has no addresses to dial`, acceptorPeerID), + }, + }, + { + name: "only server addrs are returned when mesh config does not exist", + setup: func(store *state.Store) { + require.NoError(t, store.PeeringWrite(2, &pbpeering.PeeringWriteRequest{ + Peering: &pbpeering.Peering{ + Name: "dialer", + ID: dialerPeerID, + PeerServerAddresses: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + })) + + // Mesh config entry does not exist + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: false, + addrs: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "only server addrs are returned when not peering through gateways", + setup: func(store *state.Store) { + require.NoError(t, srv.fsm.State().EnsureConfigEntry(3, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: false, // Peering through gateways is not enabled + }, + })) + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: false, + addrs: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "only server addrs are returned when peering through gateways without gateways registered", + setup: func(store *state.Store) { + require.NoError(t, srv.fsm.State().EnsureConfigEntry(4, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + })) + + // No gateways are registered + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: false, + + // Fall back to remote server addresses + addrs: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "gateway addresses are included after gateways are registered", + setup: func(store *state.Store) { + require.NoError(t, srv.fsm.State().EnsureRegistration(5, &structs.RegisterRequest{ + ID: types.NodeID(testUUID()), + Node: "gateway-node-1", + Address: "5.6.7.8", + Service: &structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway-1", + Service: "mesh-gateway", + Port: 8443, + TaggedAddresses: map[string]structs.ServiceAddress{ + structs.TaggedAddressWAN: { + Address: "my-lb-addr.not-aws.com", + Port: 443, + }, + }, + }, + })) + require.NoError(t, srv.fsm.State().EnsureRegistration(6, &structs.RegisterRequest{ + ID: types.NodeID(testUUID()), + Node: "gateway-node-2", + Address: "6.7.8.9", + Service: &structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway-2", + Service: "mesh-gateway", + Port: 8443, + TaggedAddresses: map[string]structs.ServiceAddress{ + structs.TaggedAddressWAN: { + Address: "my-other-lb-addr.not-aws.com", + Port: 443, + }, + }, + }, + })) + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: true, + + // Gateways come first, and we use their LAN addresses since this is for outbound communication. + addrs: []string{"6.7.8.9:8443", "5.6.7.8:8443", "1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "addresses are not returned if the peering is deleted", + setup: func(store *state.Store) { + require.NoError(t, store.PeeringWrite(5, &pbpeering.PeeringWriteRequest{ + Peering: &pbpeering.Peering{ + Name: "dialer", + ID: dialerPeerID, + PeerServerAddresses: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + + // Mark as deleted + State: pbpeering.PeeringState_DELETING, + DeletedAt: structs.TimeToProto(time.Now()), + }, + })) + }, + peerID: dialerPeerID, + expect: expectation{ + err: fmt.Sprintf(`there is no active peering for %q`, dialerPeerID), + }, + }, + } + + for _, tc := range tt { + testutil.RunStep(t, tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} + func newServerDialer(serverAddr string) func(context.Context, string) (net.Conn, error) { return func(ctx context.Context, addr string) (net.Conn, error) { d := net.Dialer{} diff --git a/agent/rpc/peering/service.go b/agent/rpc/peering/service.go index c5abb3d9af..c6eb78dea9 100644 --- a/agent/rpc/peering/service.go +++ b/agent/rpc/peering/service.go @@ -1,6 +1,7 @@ package peering import ( + "container/ring" "context" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/lib/retry" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-multierror" @@ -36,6 +38,15 @@ var ( errPeeringTokenEmptyPeerID = errors.New("peering token peer ID value is empty") ) +const ( + // meshGatewayWait is the initial wait on calls to exchange a secret with a peer when dialing through a gateway. + // This wait provides some time for the first gateway address to configure a route to the peer servers. + // Why 350ms? That is roughly the p50 latency we observed in a scale test for proxy config propagation: + // https://www.hashicorp.com/cgsb + meshGatewayWait = 350 * time.Millisecond + establishmentTimeout = 5 * time.Second +) + // errPeeringInvalidServerAddress is returned when an establish request contains // an invalid server address. type errPeeringInvalidServerAddress struct { @@ -118,9 +129,9 @@ type Backend interface { // It returns the server name to validate, and the CA certificate to validate with. GetTLSMaterials(generatingToken bool) (string, []string, error) - // GetServerAddresses returns the addresses used for establishing a peering connection. + // GetLocalServerAddresses returns the addresses used for establishing a peering connection. // These may be server addresses or mesh gateway addresses if peering through mesh gateways. - GetServerAddresses() ([]string, error) + GetLocalServerAddresses() ([]string, error) // EncodeToken packages a peering token into a slice of bytes. EncodeToken(tok *structs.PeeringToken) ([]byte, error) @@ -128,6 +139,12 @@ type Backend interface { // DecodeToken unpackages a peering token from a slice of bytes. DecodeToken([]byte) (*structs.PeeringToken, error) + // GetDialAddresses returns: the addresses to cycle through when dialing a peer's servers, + // a boolean indicating whether mesh gateways are present, and an optional error. + // The resulting ring buffer is front-loaded with the local mesh gateway addresses if the local + // datacenter is configured to dial through mesh gateways. + GetDialAddresses(logger hclog.Logger, ws memdb.WatchSet, peerID string) (*ring.Ring, bool, error) + EnterpriseCheckPartitions(partition string) error EnterpriseCheckNamespaces(namespace string) error @@ -298,7 +315,7 @@ func (s *Server) GenerateToken( if len(req.ServerExternalAddresses) > 0 { serverAddrs = req.ServerExternalAddresses } else { - serverAddrs, err = s.Backend.GetServerAddresses() + serverAddrs, err = s.Backend.GetLocalServerAddresses() if err != nil { return nil, err } @@ -419,7 +436,13 @@ func (s *Server) Establish( PeerServerName: tok.ServerName, PeerID: tok.PeerID, Meta: req.Meta, - State: pbpeering.PeeringState_ESTABLISHING, + + // State is intentionally not set until after the secret exchange succeeds. + // This is to prevent a scenario where an active peering is re-established, + // the secret exchange fails, and the peering state gets stuck in "Establishing" + // while the original connection is still active. + // State: pbpeering.PeeringState_ESTABLISHING, + // PartitionOrEmpty is used to avoid writing "default" in OSS. Partition: entMeta.PartitionOrEmpty(), Remote: &pbpeering.RemoteInfo{ @@ -428,39 +451,30 @@ func (s *Server) Establish( }, } - tlsOption, err := peering.TLSDialOption() - if err != nil { - return nil, fmt.Errorf("failed to build TLS dial option from peering: %w", err) + // Write the peering ahead of the ExchangeSecret handshake to give + // mesh gateways in the default partition an opportunity + // to update their config with an outbound route to this peer server. + // + // If the request to exchange a secret fails then the peering will continue to exist. + // We do not undo this write because this call to establish may actually be a re-establish call + // for an active peering. + writeReq := &pbpeering.PeeringWriteRequest{ + Peering: peering, + } + if err := s.Backend.PeeringWrite(writeReq); err != nil { + return nil, fmt.Errorf("failed to write peering: %w", err) } - exchangeReq := pbpeerstream.ExchangeSecretRequest{ - PeerID: peering.PeerID, - EstablishmentSecret: tok.EstablishmentSecret, - } - var exchangeResp *pbpeerstream.ExchangeSecretResponse - - // Loop through the known server addresses once, attempting to fetch the long-lived stream secret. - var dialErrors error - for _, addr := range serverAddrs { - exchangeResp, err = exchangeSecret(ctx, addr, tlsOption, &exchangeReq) - if err != nil { - dialErrors = multierror.Append(dialErrors, fmt.Errorf("failed to exchange peering secret with %q: %w", addr, err)) - } - if exchangeResp != nil { - break - } - } + exchangeResp, dialErrors := s.exchangeSecret(ctx, peering, tok.EstablishmentSecret) if exchangeResp == nil { return nil, dialErrors } + peering.State = pbpeering.PeeringState_ESTABLISHING // As soon as a peering is written with a non-empty list of ServerAddresses // and an active stream secret, a leader routine will see the peering and // attempt to establish a peering stream with the remote peer. - // - // This peer now has a record of both the LocalPeerID(ID) and - // RemotePeerID(PeerID) but at this point the other peer does not. - writeReq := &pbpeering.PeeringWriteRequest{ + writeReq = &pbpeering.PeeringWriteRequest{ Peering: peering, SecretsRequest: &pbpeering.SecretsWriteRequest{ PeerID: peering.ID, @@ -474,7 +488,6 @@ func (s *Server) Establish( if err := s.Backend.PeeringWrite(writeReq); err != nil { return nil, fmt.Errorf("failed to write peering: %w", err) } - // TODO(peering): low prio: consider adding response details return resp, nil } @@ -493,20 +506,78 @@ func (s *Server) validatePeeringLocality(token *structs.PeeringToken) error { return nil } -func exchangeSecret(ctx context.Context, addr string, tlsOption grpc.DialOption, req *pbpeerstream.ExchangeSecretRequest) (*pbpeerstream.ExchangeSecretResponse, error) { - dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second) +// exchangeSecret will continuously attempt to exchange the given establishment secret with the peer, up to a timeout. +// This function will attempt to dial through mesh gateways if the local DC is configured to peer through gateways, +// but will fall back to server addresses if not. +func (s *Server) exchangeSecret(ctx context.Context, peering *pbpeering.Peering, establishmentSecret string) (*pbpeerstream.ExchangeSecretResponse, error) { + req := pbpeerstream.ExchangeSecretRequest{ + PeerID: peering.PeerID, + EstablishmentSecret: establishmentSecret, + } + + tlsOption, err := peering.TLSDialOption() + if err != nil { + return nil, fmt.Errorf("failed to build TLS dial option from peering: %w", err) + } + + ringBuf, usingGateways, err := s.Backend.GetDialAddresses(s.Logger, nil, peering.ID) + if err != nil { + return nil, fmt.Errorf("failed to get addresses to dial peer: %w", err) + } + + var ( + resp *pbpeerstream.ExchangeSecretResponse + dialErrors error + ) + + retryWait := 150 * time.Millisecond + jitter := retry.NewJitter(25) + + if usingGateways { + // If we are dialing through local gateways we sleep before issuing the first request. + // This gives the local gateways some time to configure a route to the peer servers. + time.Sleep(meshGatewayWait) + } + + retryCtx, cancel := context.WithTimeout(ctx, establishmentTimeout) defer cancel() - conn, err := grpc.DialContext(dialCtx, addr, - tlsOption, - ) - if err != nil { - return nil, fmt.Errorf("failed to dial peer: %w", err) - } - defer conn.Close() + for retryCtx.Err() == nil { + addr := ringBuf.Value.(string) - client := pbpeerstream.NewPeerStreamServiceClient(conn) - return client.ExchangeSecret(ctx, req) + dialCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + + conn, err := grpc.DialContext(dialCtx, addr, + tlsOption, + ) + if err != nil { + return nil, fmt.Errorf("failed to dial peer: %w", err) + } + defer conn.Close() + + client := pbpeerstream.NewPeerStreamServiceClient(conn) + resp, err = client.ExchangeSecret(ctx, &req) + + // If we got a permission denied error that means out establishment secret is invalid, so we do not retry. + grpcErr, ok := grpcstatus.FromError(err) + if ok && grpcErr.Code() == codes.PermissionDenied { + return nil, fmt.Errorf("a new peering token must be generated: %w", grpcErr.Err()) + } + if err != nil { + dialErrors = multierror.Append(dialErrors, fmt.Errorf("failed to exchange peering secret through address %q: %w", addr, err)) + } + if resp != nil { + // Got a valid response. We're done. + break + } + + time.Sleep(jitter(retryWait)) + + // Cycle to the next possible address. + ringBuf = ringBuf.Next() + } + return resp, dialErrors } // OPTIMIZE: Handle blocking queries diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 300e463bee..f2db059e75 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/google/tcpproxy" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" @@ -476,6 +477,188 @@ func TestPeeringService_Establish(t *testing.T) { }) } +func TestPeeringService_Establish_ThroughMeshGateway(t *testing.T) { + // This test is timing-sensitive, must not be run in parallel. + // t.Parallel() + + acceptor := newTestServer(t, func(conf *consul.Config) { + conf.NodeName = "acceptor" + }) + acceptorClient := pbpeering.NewPeeringServiceClient(acceptor.ClientConn(t)) + + dialer := newTestServer(t, func(conf *consul.Config) { + conf.NodeName = "dialer" + conf.Datacenter = "dc2" + conf.PrimaryDatacenter = "dc2" + }) + dialerClient := pbpeering.NewPeeringServiceClient(dialer.ClientConn(t)) + + var peeringToken string + + testutil.RunStep(t, "retry until timeout on dial errors", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + testToken := structs.PeeringToken{ + ServerAddresses: []string{fmt.Sprintf("127.0.0.1:%d", freeport.GetOne(t))}, + PeerID: testUUID(t), + } + testTokenJSON, _ := json.Marshal(&testToken) + testTokenB64 := base64.StdEncoding.EncodeToString(testTokenJSON) + + start := time.Now() + _, err := dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: testTokenB64, + }) + require.Error(t, err) + testutil.RequireErrorContains(t, err, "connection refused") + + require.Greater(t, time.Since(start), 5*time.Second) + }) + + testutil.RunStep(t, "peering can be established from token", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + // Generate a peering token for dialer + tokenResp, err := acceptorClient.GenerateToken(ctx, &pbpeering.GenerateTokenRequest{PeerName: "my-peer-dialer"}) + require.NoError(t, err) + + // Capture peering token for re-use later + peeringToken = tokenResp.PeeringToken + + // The context timeout is short, it checks that we do not wait the 350ms that we do when peering through mesh gateways + ctx, cancel = context.WithTimeout(context.Background(), 300*time.Millisecond) + t.Cleanup(cancel) + + _, err = dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: tokenResp.PeeringToken, + }) + require.NoError(t, err) + }) + + testutil.RunStep(t, "fail fast on permission denied", func(t *testing.T) { + // This test case re-uses the previous token since the establishment secret will have been invalidated. + // The context timeout is short, it checks that we do not retry. + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + t.Cleanup(cancel) + + _, err := dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: peeringToken, + }) + testutil.RequireErrorContains(t, err, "a new peering token must be generated") + }) + + gatewayPort := freeport.GetOne(t) + + testutil.RunStep(t, "fail past bad mesh gateway", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + + // Generate a new peering token for the dialer. + tokenResp, err := acceptorClient.GenerateToken(ctx, &pbpeering.GenerateTokenRequest{PeerName: "my-peer-dialer"}) + require.NoError(t, err) + + store := dialer.Server.FSM().State() + require.NoError(t, store.EnsureConfigEntry(1, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + })) + + // Register a gateway that isn't actually listening. + require.NoError(t, store.EnsureRegistration(2, &structs.RegisterRequest{ + ID: types.NodeID(testUUID(t)), + Node: "gateway-node-1", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway-1", + Service: "mesh-gateway", + Port: gatewayPort, + }, + })) + + ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + // Call to establish should succeed when we fall back to remote server address. + _, err = dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: tokenResp.PeeringToken, + }) + require.NoError(t, err) + }) + + testutil.RunStep(t, "route through gateway", func(t *testing.T) { + // Spin up a proxy listening at the gateway port registered above. + gatewayAddr := fmt.Sprintf("127.0.0.1:%d", gatewayPort) + + // Configure a TCP proxy with an SNI route corresponding to the acceptor cluster. + var proxy tcpproxy.Proxy + target := &connWrapper{ + proxy: tcpproxy.DialProxy{ + Addr: acceptor.PublicGRPCAddr, + }, + } + proxy.AddSNIRoute(gatewayAddr, "server.dc1.peering.11111111-2222-3333-4444-555555555555.consul", target) + proxy.AddStopACMESearch(gatewayAddr) + + require.NoError(t, proxy.Start()) + t.Cleanup(func() { + proxy.Close() + proxy.Wait() + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + + // Generate a new peering token for the dialer. + tokenResp, err := acceptorClient.GenerateToken(ctx, &pbpeering.GenerateTokenRequest{PeerName: "my-peer-dialer"}) + require.NoError(t, err) + + store := dialer.Server.FSM().State() + require.NoError(t, store.EnsureConfigEntry(1, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + })) + + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + + start := time.Now() + + // Call to establish should succeed through the proxy. + _, err = dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: tokenResp.PeeringToken, + }) + require.NoError(t, err) + + // Dialing through a gateway is preceded by a mandatory 350ms sleep. + require.Greater(t, time.Since(start), 350*time.Millisecond) + + // target.called is true when the tcproxy's conn handler was invoked. + // This lets us know that the "Establish" success flowed through the proxy masquerading as a gateway. + require.True(t, target.called) + }) +} + +// connWrapper is a wrapper around tcpproxy.DialProxy to enable tracking whether the proxy handled a connection. +type connWrapper struct { + proxy tcpproxy.DialProxy + called bool +} + +func (w *connWrapper) HandleConn(src net.Conn) { + w.called = true + w.proxy.HandleConn(src) +} + func TestPeeringService_Establish_ACLEnforcement(t *testing.T) { validToken := peering.TestPeeringToken("83474a06-cca4-4ff4-99a4-4152929c8160") validTokenJSON, _ := json.Marshal(&validToken)