Add accessor and helpers to SDK for fetching self-name and client service ID

This commit is contained in:
Paul Banks 2018-06-11 17:53:14 +01:00 committed by Jack Pearkes
parent 7649d630c6
commit c08b6f6fec
3 changed files with 58 additions and 11 deletions

View File

@ -59,9 +59,9 @@ type Service struct {
// //
// Caller must provide client which is already configured to speak to the local // Caller must provide client which is already configured to speak to the local
// Consul agent, and with an ACL token that has `service:write` privileges for // Consul agent, and with an ACL token that has `service:write` privileges for
// the serviceID specified. // the service specified.
func NewService(serviceID string, client *api.Client) (*Service, error) { func NewService(serviceName string, client *api.Client) (*Service, error) {
return NewServiceWithLogger(serviceID, client, return NewServiceWithLogger(serviceName, client,
log.New(os.Stderr, "", log.LstdFlags)) log.New(os.Stderr, "", log.LstdFlags))
} }
@ -125,6 +125,13 @@ func NewDevServiceWithTLSConfig(serviceName string, logger *log.Logger,
return s, nil return s, nil
} }
// Name returns the name of the service this object represents. Note it is the
// service _name_ as used during discovery, not the ID used to uniquely identify
// an instance of the service with an agent.
func (s *Service) Name() string {
return s.service
}
// ServerTLSConfig returns a *tls.Config that allows any TCP listener to accept // ServerTLSConfig returns a *tls.Config that allows any TCP listener to accept
// and authorize incoming Connect clients. It will return a single static config // and authorize incoming Connect clients. It will return a single static config
// with hooks to dynamically load certificates, and perform Connect // with hooks to dynamically load certificates, and perform Connect

View File

@ -13,6 +13,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/hashicorp/consul/agent" "github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
@ -23,6 +25,12 @@ import (
// Assert io.Closer implementation // Assert io.Closer implementation
var _ io.Closer = new(Service) var _ io.Closer = new(Service)
func TestService_Name(t *testing.T) {
ca := connect.TestCA(t, nil)
s := TestService(t, "web", ca)
assert.Equal(t, "web", s.Name())
}
func TestService_Dial(t *testing.T) { func TestService_Dial(t *testing.T) {
ca := connect.TestCA(t, nil) ca := connect.TestCA(t, nil)

View File

@ -7,6 +7,8 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"net/url"
"sync" "sync"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
@ -81,14 +83,29 @@ func devTLSConfigFromFiles(caFile, certFile,
return cfg, nil return cfg, nil
} }
// verifyServerCertMatchesURI is used on tls connections dialled to a connect // CertURIFromConn is a helper to extract the service identifier URI from a
// server to ensure that the certificate it presented has the correct identity. // net.Conn. If the net.Conn is not a *tls.Conn then an error is always
func verifyServerCertMatchesURI(certs []*x509.Certificate, // returned. If the *tls.Conn didn't present a valid connect certificate, or is
expected connect.CertURI) error { // not yet past the handshake, an error is returned.
expectedStr := expected.URI().String() func CertURIFromConn(conn net.Conn) (connect.CertURI, error) {
tc, ok := conn.(*tls.Conn)
if !ok {
return nil, fmt.Errorf("invalid non-TLS connect client")
}
gotURI, err := extractCertURI(tc.ConnectionState().PeerCertificates)
if err != nil {
return nil, err
}
return connect.ParseCertURI(gotURI)
}
// extractCertURI returns the first URI SAN from the leaf certificate presented
// in the slice. The slice is expected to be the passed from
// tls.Conn.ConnectionState().PeerCertificates and requires that the leaf has at
// least one URI and the first URI is the correct one to use.
func extractCertURI(certs []*x509.Certificate) (*url.URL, error) {
if len(certs) < 1 { if len(certs) < 1 {
return errors.New("peer certificate mismatch") return nil, errors.New("no peer certificate presented")
} }
// Only check the first cert assuming this is the only leaf. It's not clear if // Only check the first cert assuming this is the only leaf. It's not clear if
@ -98,16 +115,31 @@ func verifyServerCertMatchesURI(certs []*x509.Certificate,
// Our certs will only ever have a single URI for now so only check that // Our certs will only ever have a single URI for now so only check that
if len(cert.URIs) < 1 { if len(cert.URIs) < 1 {
return nil, errors.New("peer certificate invalid")
}
return cert.URIs[0], nil
}
// verifyServerCertMatchesURI is used on tls connections dialled to a connect
// server to ensure that the certificate it presented has the correct identity.
func verifyServerCertMatchesURI(certs []*x509.Certificate,
expected connect.CertURI) error {
expectedStr := expected.URI().String()
gotURI, err := extractCertURI(certs)
if err != nil {
return errors.New("peer certificate mismatch") return errors.New("peer certificate mismatch")
} }
// We may want to do better than string matching later in some special // We may want to do better than string matching later in some special
// cases and/or encapsulate the "match" logic inside the CertURI // cases and/or encapsulate the "match" logic inside the CertURI
// implementation but for now this is all we need. // implementation but for now this is all we need.
if cert.URIs[0].String() == expectedStr { if gotURI.String() == expectedStr {
return nil return nil
} }
return fmt.Errorf("peer certificate mismatch got %s, want %s", return fmt.Errorf("peer certificate mismatch got %s, want %s",
cert.URIs[0].String(), expectedStr) gotURI.String(), expectedStr)
} }
// newServerSideVerifier returns a verifierFunc that wraps the provided // newServerSideVerifier returns a verifierFunc that wraps the provided