diff --git a/agent/grpc-middleware/auth_interceptor.go b/agent/grpc-middleware/auth_interceptor.go index 048b0b272b..bc9082c75e 100644 --- a/agent/grpc-middleware/auth_interceptor.go +++ b/agent/grpc-middleware/auth_interceptor.go @@ -66,6 +66,10 @@ func (a *AuthInterceptor) InterceptStream( // present a mutual TLS certificate, and is allowed to bypass the `tls.grpc.verify_incoming` // check as a special case. See the `tlsutil.Configurator` for this bypass. func restrictPeeringEndpoints(authInfo credentials.AuthInfo, peeringSNI string, endpoint string) error { + // No peering connection has been configured + if peeringSNI == "" { + return nil + } // This indicates a plaintext connection. if authInfo == nil { return nil @@ -75,6 +79,7 @@ func restrictPeeringEndpoints(authInfo credentials.AuthInfo, peeringSNI string, if !ok { return status.Error(codes.Unauthenticated, "invalid transport credentials") } + if tlsAuth.State.ServerName == peeringSNI { // Prevent any calls, except those in the PeerStreamService if !strings.HasPrefix(endpoint, AllowedPeerEndpointPrefix) { diff --git a/agent/grpc-middleware/auth_interceptor_test.go b/agent/grpc-middleware/auth_interceptor_test.go index f0e8c9e46a..512d5ed550 100644 --- a/agent/grpc-middleware/auth_interceptor_test.go +++ b/agent/grpc-middleware/auth_interceptor_test.go @@ -30,9 +30,16 @@ func TestGRPCMiddleware_restrictPeeringEndpoints(t *testing.T) { endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint", }, { - name: "deny_invalid_credentials", - authInfo: invalidAuthInfo{}, - expectErr: "invalid transport credentials", + name: "peering_not_enabled", + authInfo: nil, + peeringSNI: "", + endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint", + }, + { + name: "deny_invalid_credentials", + authInfo: invalidAuthInfo{}, + peeringSNI: "expected-server-sni", + expectErr: "invalid transport credentials", }, { name: "peering_sni_with_invalid_endpoint", @@ -72,6 +79,7 @@ func TestGRPCMiddleware_restrictPeeringEndpoints(t *testing.T) { if tc.expectErr == "" { require.NoError(t, err) } else { + require.NotNil(t, err) require.Contains(t, err.Error(), tc.expectErr) } })