From e69bc727ecf11bd5d993364a4522c2b38d40286c Mon Sep 17 00:00:00 2001 From: freddygv Date: Mon, 10 Oct 2022 13:45:30 -0600 Subject: [PATCH] Update peering establishment to maybe use gateways When peering through mesh gateways we expect outbound dials to peer servers to flow through the local mesh gateway addresses. Now when establishing a peering we get a list of dial addresses as a ring buffer that includes local mesh gateway addresses if the local DC is configured to peer through mesh gateways. The ring buffer includes the mesh gateway addresses first, but also includes the remote server addresses as a fallback. This fallback is present because it's possible that direct egress from the servers may be allowed. If not allowed then the leader will cycle back to a mesh gateway address through the ring. When attempting to dial the remote servers we retry up to a fixed timeout. If using mesh gateways we also have an initial wait in order to allow for the mesh gateways to configure themselves. Note that if we encounter a permission denied error we do not retry since that error indicates that the secret in the peering token is invalid. --- agent/consul/leader_peering_test.go | 2 +- agent/consul/peering_backend.go | 118 ++++++++++++-- agent/consul/peering_backend_test.go | 220 ++++++++++++++++++++++++++- agent/rpc/peering/service.go | 151 +++++++++++++----- agent/rpc/peering/service_test.go | 183 ++++++++++++++++++++++ 5 files changed, 615 insertions(+), 59 deletions(-) diff --git a/agent/consul/leader_peering_test.go b/agent/consul/leader_peering_test.go index 5f59787764..373d80f25d 100644 --- a/agent/consul/leader_peering_test.go +++ b/agent/consul/leader_peering_test.go @@ -567,7 +567,7 @@ func testLeader_PeeringSync_failsForTLSError(t *testing.T, tokenMutateFn func(to } // Since the Establish RPC dials the remote cluster, it will yield the TLS error. 
- ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) t.Cleanup(cancel) _, err = s2Client.Establish(ctx, &establishReq) require.Contains(t, err.Error(), expectErr) diff --git a/agent/consul/peering_backend.go b/agent/consul/peering_backend.go index 26c2f19436..064015cb67 100644 --- a/agent/consul/peering_backend.go +++ b/agent/consul/peering_backend.go @@ -1,12 +1,16 @@ package consul import ( + "container/ring" "encoding/base64" "encoding/json" "fmt" "strconv" "sync" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl/resolver" "github.com/hashicorp/consul/agent/connect" @@ -85,29 +89,117 @@ func (b *PeeringBackend) GetTLSMaterials(generatingToken bool) (string, []string return serverName, caPems, nil } -// GetServerAddresses looks up server or mesh gateway addresses from the state store. -func (b *PeeringBackend) GetServerAddresses() ([]string, error) { - _, rawEntry, err := b.srv.fsm.State().ConfigEntry(nil, structs.MeshConfig, structs.MeshConfigMesh, acl.DefaultEnterpriseMeta()) - if err != nil { - return nil, fmt.Errorf("failed to read mesh config entry: %w", err) - } +// GetLocalServerAddresses looks up server or mesh gateway addresses from the state store for a peer to dial. +func (b *PeeringBackend) GetLocalServerAddresses() ([]string, error) { + store := b.srv.fsm.State() - meshConfig, ok := rawEntry.(*structs.MeshConfigEntry) - if ok && meshConfig.Peering != nil && meshConfig.Peering.PeerThroughMeshGateways { - return meshGatewayAdresses(b.srv.fsm.State()) + useGateways, err := b.PeerThroughMeshGateways(nil) + if err != nil { + // For inbound traffic we prefer to fail fast if we can't determine whether we should peer through MGW. + // This prevents unexpectedly sharing local server addresses when a user only intended to peer through gateways. 
+ return nil, fmt.Errorf("failed to determine if peering should happen through mesh gateways: %w", err) } - return serverAddresses(b.srv.fsm.State()) + if useGateways { + return meshGatewayAdresses(store, nil, true) + } + return serverAddresses(store) } -func meshGatewayAdresses(state *state.Store) ([]string, error) { - _, nodes, err := state.ServiceDump(nil, structs.ServiceKindMeshGateway, true, acl.DefaultEnterpriseMeta(), structs.DefaultPeerKeyword) +// GetDialAddresses returns: the addresses to cycle through when dialing a peer's servers, +// a boolean indicating whether mesh gateways are present, and an optional error. +// The resulting ring buffer is front-loaded with the local mesh gateway addresses if they are present. +func (b *PeeringBackend) GetDialAddresses(logger hclog.Logger, ws memdb.WatchSet, peerID string) (*ring.Ring, bool, error) { + newRing, err := b.fetchPeerServerAddresses(ws, peerID) + if err != nil { + return nil, false, fmt.Errorf("failed to refresh peer server addresses, will continue to use initial addresses: %w", err) + } + + gatewayRing, err := b.maybeFetchGatewayAddresses(ws) + if err != nil { + // If we couldn't fetch the mesh gateway addresses we fall back to dialing the remote server addresses. + logger.Warn("failed to refresh local gateway addresses, will attempt to dial peer directly: %w", "error", err) + return newRing, false, nil + } + if gatewayRing != nil { + // The ordering is important here. We always want to start with the mesh gateway + // addresses and fallback to the remote addresses, so we append the server addresses + // in newRing to gatewayRing. + newRing = gatewayRing.Link(newRing) + } + return newRing, gatewayRing != nil, nil +} + +// fetchPeerServerAddresses will return a ring buffer with the latest peer server addresses. +// If the peering is no longer active or does not have addresses, then we return an error. 
+func (b *PeeringBackend) fetchPeerServerAddresses(ws memdb.WatchSet, peerID string) (*ring.Ring, error) { + _, peering, err := b.Store().PeeringReadByID(ws, peerID) + if err != nil { + return nil, fmt.Errorf("failed to fetch peer %q: %w", peerID, err) + } + if !peering.IsActive() { + return nil, fmt.Errorf("there is no active peering for %q", peerID) + } + + // IMPORTANT: The address ring buffer must always be length > 0 + if len(peering.PeerServerAddresses) == 0 { + return nil, fmt.Errorf("peer %q has no addresses to dial", peerID) + } + return bufferFromAddresses(peering.PeerServerAddresses), nil +} + +// maybeFetchGatewayAddresses will return a ring buffer with the latest gateway addresses if the +// local datacenter is configured to peer through mesh gateways and there are local gateways registered. +// If neither of these are true then we return a nil buffer. +func (b *PeeringBackend) maybeFetchGatewayAddresses(ws memdb.WatchSet) (*ring.Ring, error) { + useGateways, err := b.PeerThroughMeshGateways(ws) + if err != nil { + return nil, fmt.Errorf("failed to determine if peering should happen through mesh gateways: %w", err) + } + if useGateways { + addresses, err := meshGatewayAdresses(b.srv.fsm.State(), ws, false) + + // IMPORTANT: The address ring buffer must always be length > 0 + if err != nil || len(addresses) == 0 { + return nil, fmt.Errorf("error fetching local mesh gateway addresses: %w", err) + } + return bufferFromAddresses(addresses), nil + } + return nil, nil +} + +// PeerThroughMeshGateways determines if the config entry to enable peering control plane +// traffic through a mesh gateway is set to enable. 
+func (b *PeeringBackend) PeerThroughMeshGateways(ws memdb.WatchSet) (bool, error) { + _, rawEntry, err := b.srv.fsm.State().ConfigEntry(ws, structs.MeshConfig, structs.MeshConfigMesh, acl.DefaultEnterpriseMeta()) + if err != nil { + return false, fmt.Errorf("failed to read mesh config entry: %w", err) + } + mesh, ok := rawEntry.(*structs.MeshConfigEntry) + if rawEntry != nil && !ok { + return false, fmt.Errorf("got unexpected type for mesh config entry: %T", rawEntry) + } + return mesh.PeerThroughMeshGateways(), nil + +} + +func bufferFromAddresses(addresses []string) *ring.Ring { + ring := ring.New(len(addresses)) + for _, addr := range addresses { + ring.Value = addr + ring = ring.Next() + } + return ring +} + +func meshGatewayAdresses(state *state.Store, ws memdb.WatchSet, wan bool) ([]string, error) { + _, nodes, err := state.ServiceDump(ws, structs.ServiceKindMeshGateway, true, acl.DefaultEnterpriseMeta(), structs.DefaultPeerKeyword) if err != nil { return nil, fmt.Errorf("failed to dump gateway addresses: %w", err) } var addrs []string for _, node := range nodes { - _, addr, port := node.BestAddress(true) + _, addr, port := node.BestAddress(wan) addrs = append(addrs, ipaddr.FormatAddressPort(addr, port)) } if len(addrs) == 0 { diff --git a/agent/consul/peering_backend_test.go b/agent/consul/peering_backend_test.go index 0d834c09a9..63a42e07a9 100644 --- a/agent/consul/peering_backend_test.go +++ b/agent/consul/peering_backend_test.go @@ -10,6 +10,7 @@ import ( gogrpc "google.golang.org/grpc" "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/proto/pbpeering" @@ -76,7 +77,7 @@ func TestPeeringBackend_ForwardToLeader(t *testing.T) { }) } -func TestPeeringBackend_GetServerAddresses(t *testing.T) { +func TestPeeringBackend_GetLocalServerAddresses(t *testing.T) { if testing.Short() { t.Skip("too slow for 
testing.Short") } @@ -91,7 +92,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { backend := NewPeeringBackend(srv) testutil.RunStep(t, "peer to servers", func(t *testing.T) { - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.NoError(t, err) expect := fmt.Sprintf("127.0.0.1:%d", srv.config.GRPCTLSPort) @@ -107,7 +108,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { } require.NoError(t, srv.fsm.State().EnsureConfigEntry(1, &mesh)) - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.NoError(t, err) // Still expect server address because PeerThroughMeshGateways was not enabled. @@ -121,7 +122,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { } require.NoError(t, srv.fsm.State().EnsureConfigEntry(1, &mesh)) - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.Nil(t, addrs) testutil.RequireErrorContains(t, err, "servers are configured to PeerThroughMeshGateways, but no mesh gateway instances are registered") @@ -147,12 +148,221 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) { } require.NoError(t, srv.fsm.State().EnsureRegistration(2, ®)) - addrs, err := backend.GetServerAddresses() + addrs, err := backend.GetLocalServerAddresses() require.NoError(t, err) require.Equal(t, []string{"154.238.12.252:8443"}, addrs) }) } +func TestPeeringBackend_GetDialAddresses(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + _, cfg := testServerConfig(t) + cfg.GRPCTLSPort = freeport.GetOne(t) + + srv, err := newServer(t, cfg) + require.NoError(t, err) + testrpc.WaitForLeader(t, srv.RPC, "dc1") + + backend := NewPeeringBackend(srv) + + dialerPeerID := testUUID() + acceptorPeerID := testUUID() + + type expectation struct { + addrs []string + haveGateways bool + err string + } + + type testCase struct { + name string + setup func(store *state.Store) + 
peerID string + expect expectation + } + + run := func(t *testing.T, tc testCase) { + if tc.setup != nil { + tc.setup(srv.fsm.State()) + } + + ring, haveGateways, err := backend.GetDialAddresses(testutil.Logger(t), nil, tc.peerID) + if tc.expect.err != "" { + testutil.RequireErrorContains(t, err, tc.expect.err) + return + } + require.Equal(t, tc.expect.haveGateways, haveGateways) + require.NotNil(t, ring) + + var addrs []string + ring.Do(func(value any) { + addr, ok := value.(string) + + require.True(t, ok) + addrs = append(addrs, addr) + }) + require.Equal(t, tc.expect.addrs, addrs) + } + + // NOTE: The following tests are set up to run serially with RunStep to save on the setup/teardown cost for a test server. + tt := []testCase{ + { + name: "unknown peering", + setup: func(store *state.Store) { + // Test peering is not written during setup + }, + peerID: acceptorPeerID, + expect: expectation{ + err: fmt.Sprintf(`there is no active peering for %q`, acceptorPeerID), + }, + }, + { + name: "no server addresses", + setup: func(store *state.Store) { + require.NoError(t, store.PeeringWrite(1, &pbpeering.PeeringWriteRequest{ + Peering: &pbpeering.Peering{ + Name: "acceptor", + ID: acceptorPeerID, + // Acceptor peers do not have PeerServerAddresses populated locally. 
+ }, + })) + }, + peerID: acceptorPeerID, + expect: expectation{ + err: fmt.Sprintf(`peer %q has no addresses to dial`, acceptorPeerID), + }, + }, + { + name: "only server addrs are returned when mesh config does not exist", + setup: func(store *state.Store) { + require.NoError(t, store.PeeringWrite(2, &pbpeering.PeeringWriteRequest{ + Peering: &pbpeering.Peering{ + Name: "dialer", + ID: dialerPeerID, + PeerServerAddresses: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + })) + + // Mesh config entry does not exist + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: false, + addrs: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "only server addrs are returned when not peering through gateways", + setup: func(store *state.Store) { + require.NoError(t, srv.fsm.State().EnsureConfigEntry(3, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: false, // Peering through gateways is not enabled + }, + })) + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: false, + addrs: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "only server addrs are returned when peering through gateways without gateways registered", + setup: func(store *state.Store) { + require.NoError(t, srv.fsm.State().EnsureConfigEntry(4, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + })) + + // No gateways are registered + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: false, + + // Fall back to remote server addresses + addrs: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "gateway addresses are included after gateways are registered", + setup: func(store *state.Store) { + require.NoError(t, srv.fsm.State().EnsureRegistration(5, &structs.RegisterRequest{ + ID: types.NodeID(testUUID()), + Node: "gateway-node-1", + Address: "5.6.7.8", + Service: &structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: 
"mesh-gateway-1", + Service: "mesh-gateway", + Port: 8443, + TaggedAddresses: map[string]structs.ServiceAddress{ + structs.TaggedAddressWAN: { + Address: "my-lb-addr.not-aws.com", + Port: 443, + }, + }, + }, + })) + require.NoError(t, srv.fsm.State().EnsureRegistration(6, &structs.RegisterRequest{ + ID: types.NodeID(testUUID()), + Node: "gateway-node-2", + Address: "6.7.8.9", + Service: &structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway-2", + Service: "mesh-gateway", + Port: 8443, + TaggedAddresses: map[string]structs.ServiceAddress{ + structs.TaggedAddressWAN: { + Address: "my-other-lb-addr.not-aws.com", + Port: 443, + }, + }, + }, + })) + }, + peerID: dialerPeerID, + expect: expectation{ + haveGateways: true, + + // Gateways come first, and we use their LAN addresses since this is for outbound communication. + addrs: []string{"6.7.8.9:8443", "5.6.7.8:8443", "1.2.3.4:8502", "2.3.4.5:8503"}, + }, + }, + { + name: "addresses are not returned if the peering is deleted", + setup: func(store *state.Store) { + require.NoError(t, store.PeeringWrite(5, &pbpeering.PeeringWriteRequest{ + Peering: &pbpeering.Peering{ + Name: "dialer", + ID: dialerPeerID, + PeerServerAddresses: []string{"1.2.3.4:8502", "2.3.4.5:8503"}, + + // Mark as deleted + State: pbpeering.PeeringState_DELETING, + DeletedAt: structs.TimeToProto(time.Now()), + }, + })) + }, + peerID: dialerPeerID, + expect: expectation{ + err: fmt.Sprintf(`there is no active peering for %q`, dialerPeerID), + }, + }, + } + + for _, tc := range tt { + testutil.RunStep(t, tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} + func newServerDialer(serverAddr string) func(context.Context, string) (net.Conn, error) { return func(ctx context.Context, addr string) (net.Conn, error) { d := net.Dialer{} diff --git a/agent/rpc/peering/service.go b/agent/rpc/peering/service.go index c5abb3d9af..c6eb78dea9 100644 --- a/agent/rpc/peering/service.go +++ b/agent/rpc/peering/service.go @@ -1,6 +1,7 @@ 
package peering import ( + "container/ring" "context" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/lib/retry" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-multierror" @@ -36,6 +38,15 @@ var ( errPeeringTokenEmptyPeerID = errors.New("peering token peer ID value is empty") ) +const ( + // meshGatewayWait is the initial wait on calls to exchange a secret with a peer when dialing through a gateway. + // This wait provides some time for the first gateway address to configure a route to the peer servers. + // Why 350ms? That is roughly the p50 latency we observed in a scale test for proxy config propagation: + // https://www.hashicorp.com/cgsb + meshGatewayWait = 350 * time.Millisecond + establishmentTimeout = 5 * time.Second +) + // errPeeringInvalidServerAddress is returned when an establish request contains // an invalid server address. type errPeeringInvalidServerAddress struct { @@ -118,9 +129,9 @@ type Backend interface { // It returns the server name to validate, and the CA certificate to validate with. GetTLSMaterials(generatingToken bool) (string, []string, error) - // GetServerAddresses returns the addresses used for establishing a peering connection. + // GetLocalServerAddresses returns the addresses used for establishing a peering connection. // These may be server addresses or mesh gateway addresses if peering through mesh gateways. - GetServerAddresses() ([]string, error) + GetLocalServerAddresses() ([]string, error) // EncodeToken packages a peering token into a slice of bytes. EncodeToken(tok *structs.PeeringToken) ([]byte, error) @@ -128,6 +139,12 @@ type Backend interface { // DecodeToken unpackages a peering token from a slice of bytes. 
DecodeToken([]byte) (*structs.PeeringToken, error) + // GetDialAddresses returns: the addresses to cycle through when dialing a peer's servers, + // a boolean indicating whether mesh gateways are present, and an optional error. + // The resulting ring buffer is front-loaded with the local mesh gateway addresses if the local + // datacenter is configured to dial through mesh gateways. + GetDialAddresses(logger hclog.Logger, ws memdb.WatchSet, peerID string) (*ring.Ring, bool, error) + EnterpriseCheckPartitions(partition string) error EnterpriseCheckNamespaces(namespace string) error @@ -298,7 +315,7 @@ func (s *Server) GenerateToken( if len(req.ServerExternalAddresses) > 0 { serverAddrs = req.ServerExternalAddresses } else { - serverAddrs, err = s.Backend.GetServerAddresses() + serverAddrs, err = s.Backend.GetLocalServerAddresses() if err != nil { return nil, err } @@ -419,7 +436,13 @@ func (s *Server) Establish( PeerServerName: tok.ServerName, PeerID: tok.PeerID, Meta: req.Meta, - State: pbpeering.PeeringState_ESTABLISHING, + + // State is intentionally not set until after the secret exchange succeeds. + // This is to prevent a scenario where an active peering is re-established, + // the secret exchange fails, and the peering state gets stuck in "Establishing" + // while the original connection is still active. + // State: pbpeering.PeeringState_ESTABLISHING, + // PartitionOrEmpty is used to avoid writing "default" in OSS. Partition: entMeta.PartitionOrEmpty(), Remote: &pbpeering.RemoteInfo{ @@ -428,39 +451,30 @@ func (s *Server) Establish( }, } - tlsOption, err := peering.TLSDialOption() - if err != nil { - return nil, fmt.Errorf("failed to build TLS dial option from peering: %w", err) + // Write the peering ahead of the ExchangeSecret handshake to give + // mesh gateways in the default partition an opportunity + // to update their config with an outbound route to this peer server. 
+ // + // If the request to exchange a secret fails then the peering will continue to exist. + // We do not undo this write because this call to establish may actually be a re-establish call + // for an active peering. + writeReq := &pbpeering.PeeringWriteRequest{ + Peering: peering, + } + if err := s.Backend.PeeringWrite(writeReq); err != nil { + return nil, fmt.Errorf("failed to write peering: %w", err) } - exchangeReq := pbpeerstream.ExchangeSecretRequest{ - PeerID: peering.PeerID, - EstablishmentSecret: tok.EstablishmentSecret, - } - var exchangeResp *pbpeerstream.ExchangeSecretResponse - - // Loop through the known server addresses once, attempting to fetch the long-lived stream secret. - var dialErrors error - for _, addr := range serverAddrs { - exchangeResp, err = exchangeSecret(ctx, addr, tlsOption, &exchangeReq) - if err != nil { - dialErrors = multierror.Append(dialErrors, fmt.Errorf("failed to exchange peering secret with %q: %w", addr, err)) - } - if exchangeResp != nil { - break - } - } + exchangeResp, dialErrors := s.exchangeSecret(ctx, peering, tok.EstablishmentSecret) if exchangeResp == nil { return nil, dialErrors } + peering.State = pbpeering.PeeringState_ESTABLISHING // As soon as a peering is written with a non-empty list of ServerAddresses // and an active stream secret, a leader routine will see the peering and // attempt to establish a peering stream with the remote peer. - // - // This peer now has a record of both the LocalPeerID(ID) and - // RemotePeerID(PeerID) but at this point the other peer does not. 
- writeReq := &pbpeering.PeeringWriteRequest{ + writeReq = &pbpeering.PeeringWriteRequest{ Peering: peering, SecretsRequest: &pbpeering.SecretsWriteRequest{ PeerID: peering.ID, @@ -474,7 +488,6 @@ func (s *Server) Establish( if err := s.Backend.PeeringWrite(writeReq); err != nil { return nil, fmt.Errorf("failed to write peering: %w", err) } - // TODO(peering): low prio: consider adding response details return resp, nil } @@ -493,20 +506,78 @@ func (s *Server) validatePeeringLocality(token *structs.PeeringToken) error { return nil } -func exchangeSecret(ctx context.Context, addr string, tlsOption grpc.DialOption, req *pbpeerstream.ExchangeSecretRequest) (*pbpeerstream.ExchangeSecretResponse, error) { - dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second) +// exchangeSecret will continuously attempt to exchange the given establishment secret with the peer, up to a timeout. +// This function will attempt to dial through mesh gateways if the local DC is configured to peer through gateways, +// but will fall back to server addresses if not. +func (s *Server) exchangeSecret(ctx context.Context, peering *pbpeering.Peering, establishmentSecret string) (*pbpeerstream.ExchangeSecretResponse, error) { + req := pbpeerstream.ExchangeSecretRequest{ + PeerID: peering.PeerID, + EstablishmentSecret: establishmentSecret, + } + + tlsOption, err := peering.TLSDialOption() + if err != nil { + return nil, fmt.Errorf("failed to build TLS dial option from peering: %w", err) + } + + ringBuf, usingGateways, err := s.Backend.GetDialAddresses(s.Logger, nil, peering.ID) + if err != nil { + return nil, fmt.Errorf("failed to get addresses to dial peer: %w", err) + } + + var ( + resp *pbpeerstream.ExchangeSecretResponse + dialErrors error + ) + + retryWait := 150 * time.Millisecond + jitter := retry.NewJitter(25) + + if usingGateways { + // If we are dialing through local gateways we sleep before issuing the first request. 
+ // This gives the local gateways some time to configure a route to the peer servers. + time.Sleep(meshGatewayWait) + } + + retryCtx, cancel := context.WithTimeout(ctx, establishmentTimeout) defer cancel() - conn, err := grpc.DialContext(dialCtx, addr, - tlsOption, - ) - if err != nil { - return nil, fmt.Errorf("failed to dial peer: %w", err) - } - defer conn.Close() + for retryCtx.Err() == nil { + addr := ringBuf.Value.(string) - client := pbpeerstream.NewPeerStreamServiceClient(conn) - return client.ExchangeSecret(ctx, req) + dialCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + + conn, err := grpc.DialContext(dialCtx, addr, + tlsOption, + ) + if err != nil { + return nil, fmt.Errorf("failed to dial peer: %w", err) + } + defer conn.Close() + + client := pbpeerstream.NewPeerStreamServiceClient(conn) + resp, err = client.ExchangeSecret(ctx, &req) + + // If we got a permission denied error that means our establishment secret is invalid, so we do not retry. + grpcErr, ok := grpcstatus.FromError(err) + if ok && grpcErr.Code() == codes.PermissionDenied { + return nil, fmt.Errorf("a new peering token must be generated: %w", grpcErr.Err()) + } + if err != nil { + dialErrors = multierror.Append(dialErrors, fmt.Errorf("failed to exchange peering secret through address %q: %w", addr, err)) + } + if resp != nil { + // Got a valid response. We're done. + break + } + + time.Sleep(jitter(retryWait)) + + // Cycle to the next possible address. 
+ ringBuf = ringBuf.Next() + } + return resp, dialErrors } // OPTIMIZE: Handle blocking queries diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 300e463bee..f2db059e75 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/google/tcpproxy" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" @@ -476,6 +477,188 @@ func TestPeeringService_Establish(t *testing.T) { }) } +func TestPeeringService_Establish_ThroughMeshGateway(t *testing.T) { + // This test is timing-sensitive, must not be run in parallel. + // t.Parallel() + + acceptor := newTestServer(t, func(conf *consul.Config) { + conf.NodeName = "acceptor" + }) + acceptorClient := pbpeering.NewPeeringServiceClient(acceptor.ClientConn(t)) + + dialer := newTestServer(t, func(conf *consul.Config) { + conf.NodeName = "dialer" + conf.Datacenter = "dc2" + conf.PrimaryDatacenter = "dc2" + }) + dialerClient := pbpeering.NewPeeringServiceClient(dialer.ClientConn(t)) + + var peeringToken string + + testutil.RunStep(t, "retry until timeout on dial errors", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + testToken := structs.PeeringToken{ + ServerAddresses: []string{fmt.Sprintf("127.0.0.1:%d", freeport.GetOne(t))}, + PeerID: testUUID(t), + } + testTokenJSON, _ := json.Marshal(&testToken) + testTokenB64 := base64.StdEncoding.EncodeToString(testTokenJSON) + + start := time.Now() + _, err := dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: testTokenB64, + }) + require.Error(t, err) + testutil.RequireErrorContains(t, err, "connection refused") + + require.Greater(t, time.Since(start), 5*time.Second) + }) + + testutil.RunStep(t, "peering can be established from token", func(t *testing.T) { + ctx, cancel := 
context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + // Generate a peering token for dialer + tokenResp, err := acceptorClient.GenerateToken(ctx, &pbpeering.GenerateTokenRequest{PeerName: "my-peer-dialer"}) + require.NoError(t, err) + + // Capture peering token for re-use later + peeringToken = tokenResp.PeeringToken + + // The context timeout is short, it checks that we do not wait the 350ms that we do when peering through mesh gateways + ctx, cancel = context.WithTimeout(context.Background(), 300*time.Millisecond) + t.Cleanup(cancel) + + _, err = dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: tokenResp.PeeringToken, + }) + require.NoError(t, err) + }) + + testutil.RunStep(t, "fail fast on permission denied", func(t *testing.T) { + // This test case re-uses the previous token since the establishment secret will have been invalidated. + // The context timeout is short, it checks that we do not retry. + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + t.Cleanup(cancel) + + _, err := dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: peeringToken, + }) + testutil.RequireErrorContains(t, err, "a new peering token must be generated") + }) + + gatewayPort := freeport.GetOne(t) + + testutil.RunStep(t, "fail past bad mesh gateway", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + + // Generate a new peering token for the dialer. + tokenResp, err := acceptorClient.GenerateToken(ctx, &pbpeering.GenerateTokenRequest{PeerName: "my-peer-dialer"}) + require.NoError(t, err) + + store := dialer.Server.FSM().State() + require.NoError(t, store.EnsureConfigEntry(1, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + })) + + // Register a gateway that isn't actually listening. 
+ require.NoError(t, store.EnsureRegistration(2, &structs.RegisterRequest{ + ID: types.NodeID(testUUID(t)), + Node: "gateway-node-1", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway-1", + Service: "mesh-gateway", + Port: gatewayPort, + }, + })) + + ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + // Call to establish should succeed when we fall back to remote server address. + _, err = dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: tokenResp.PeeringToken, + }) + require.NoError(t, err) + }) + + testutil.RunStep(t, "route through gateway", func(t *testing.T) { + // Spin up a proxy listening at the gateway port registered above. + gatewayAddr := fmt.Sprintf("127.0.0.1:%d", gatewayPort) + + // Configure a TCP proxy with an SNI route corresponding to the acceptor cluster. + var proxy tcpproxy.Proxy + target := &connWrapper{ + proxy: tcpproxy.DialProxy{ + Addr: acceptor.PublicGRPCAddr, + }, + } + proxy.AddSNIRoute(gatewayAddr, "server.dc1.peering.11111111-2222-3333-4444-555555555555.consul", target) + proxy.AddStopACMESearch(gatewayAddr) + + require.NoError(t, proxy.Start()) + t.Cleanup(func() { + proxy.Close() + proxy.Wait() + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + + // Generate a new peering token for the dialer. + tokenResp, err := acceptorClient.GenerateToken(ctx, &pbpeering.GenerateTokenRequest{PeerName: "my-peer-dialer"}) + require.NoError(t, err) + + store := dialer.Server.FSM().State() + require.NoError(t, store.EnsureConfigEntry(1, &structs.MeshConfigEntry{ + Peering: &structs.PeeringMeshConfig{ + PeerThroughMeshGateways: true, + }, + })) + + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + + start := time.Now() + + // Call to establish should succeed through the proxy. 
+ _, err = dialerClient.Establish(ctx, &pbpeering.EstablishRequest{ + PeerName: "my-peer-acceptor", + PeeringToken: tokenResp.PeeringToken, + }) + require.NoError(t, err) + + // Dialing through a gateway is preceded by a mandatory 350ms sleep. + require.Greater(t, time.Since(start), 350*time.Millisecond) + + // target.called is true when the tcpproxy's conn handler was invoked. + // This lets us know that the "Establish" success flowed through the proxy masquerading as a gateway. + require.True(t, target.called) + }) +} + +// connWrapper is a wrapper around tcpproxy.DialProxy to enable tracking whether the proxy handled a connection. +type connWrapper struct { + proxy tcpproxy.DialProxy + called bool +} + +func (w *connWrapper) HandleConn(src net.Conn) { + w.called = true + w.proxy.HandleConn(src) +} + func TestPeeringService_Establish_ACLEnforcement(t *testing.T) { validToken := peering.TestPeeringToken("83474a06-cca4-4ff4-99a4-4152929c8160") validTokenJSON, _ := json.Marshal(&validToken)