82 lines
2.6 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package middleware
import (
"fmt"
"net"
"google.golang.org/grpc/credentials"
)
// Compile-time assertions that the labelled wrappers can stand in for the
// standard library listener and connection interfaces.
var (
	_ net.Listener = (*LabelledListener)(nil)
	_ net.Conn     = (*LabelledConn)(nil)
)
// Protocol identifies which kind of listener accepted a connection, and
// therefore whether a TLS handshake should be performed on it.
type Protocol int

// Protocol values. They are declared as constants (not variables) so that no
// package can reassign them at runtime, and so that a switch over them gets
// the compiler's duplicate-case checking.
const (
	// ProtocolPlaintext labels connections accepted on a plaintext listener.
	ProtocolPlaintext Protocol = 0
	// ProtocolTLS labels connections accepted on a TLS listener.
	ProtocolTLS Protocol = 1
)
// LabelledListener wraps a net.Listener so that every connection it hands out
// is tagged with the listener's Protocol.
type LabelledListener struct {
	net.Listener
	Protocol Protocol
}

// Accept waits for the next connection from the wrapped listener. A non-nil
// connection is returned wrapped in a LabelledConn carrying l.Protocol; a nil
// connection is returned as-is. The wrapped listener's error is always passed
// through unchanged.
func (l LabelledListener) Accept() (net.Conn, error) {
	raw, err := l.Listener.Accept()
	if raw == nil {
		return raw, err
	}
	return LabelledConn{Conn: raw, protocol: l.Protocol}, err
}

// LabelledConn wraps a net.Conn and records the Protocol of the listener that
// accepted it.
type LabelledConn struct {
	net.Conn
	protocol Protocol
}
var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil)

// optionalTransportCredentials provides a way to selectively perform a TLS handshake
// based on metadata extracted from the underlying connection object.
//
// Methods of credentials.TransportCredentials that are not overridden on this
// type are promoted unchanged from the embedded credentials.
type optionalTransportCredentials struct {
	credentials.TransportCredentials
	// NOTE(review): logger is not read by any method in this file — confirm
	// whether it is still needed.
	logger Logger
}

// NewOptionalTransportCredentials wraps creds so that the server-side TLS
// handshake is only performed for connections labelled with ProtocolTLS by a
// LabelledListener; plaintext-labelled connections skip it.
func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials {
	return &optionalTransportCredentials{TransportCredentials: creds, logger: logger}
}

// Clone returns a copy of tc that keeps the optional-handshake behaviour.
//
// Without this override, the embedded credentials' Clone would be promoted.
// It would return a clone of only the inner credentials, silently discarding
// this wrapper, so the clone would attempt TLS on every connection.
func (tc *optionalTransportCredentials) Clone() credentials.TransportCredentials {
	return &optionalTransportCredentials{
		TransportCredentials: tc.TransportCredentials.Clone(),
		logger:               tc.logger,
	}
}
// ServerHandshake will attempt to detect the underlying connection protocol (TLS or Plaintext)
// based on metadata attached to the underlying connection. If TLS is detected, then a handshake
// will be performed, and the corresponding AuthInfo will be attached to the gRPC context.
// For plaintext connections, this is effectively a no-op, since there is no TLS info to attach.
//
// If the underlying connection is not a LabelledConn, or is labelled with an unknown protocol,
// an error is returned so that the gRPC connection does not progress further.
func (tc *optionalTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	nc, ok := conn.(LabelledConn)
	if !ok {
		// Connections are expected to come from a LabelledListener; anything else is a
		// wiring mistake. Reject the connection rather than panicking: an unrecovered
		// panic terminates the whole process, not just this one connection.
		return nil, nil, fmt.Errorf("grpc connection is not labelled with a protocol: got %T", conn)
	}
	switch nc.protocol {
	case ProtocolPlaintext:
		// This originated from a plaintext listener, so do not use TLS auth.
		return nc, nil, nil
	case ProtocolTLS:
		// This originated from a TLS listener, so it should have a full handshake performed.
		c, ai, err := tc.TransportCredentials.ServerHandshake(conn)
		if err == nil && ai == nil {
			// This should not be possible, but ensure that it's non-nil for safety.
			return nil, nil, fmt.Errorf("missing auth info after handshake")
		}
		// err is passed through unwrapped on purpose. gRPC may compare handshake errors
		// against sentinel values (e.g. credentials.ErrConnDispatched) using ==.
		return c, ai, err
	default:
		return nil, nil, fmt.Errorf("invalid protocol for grpc connection")
	}
}