diff --git a/agent/connect/testing_ca.go b/agent/connect/testing_ca.go index 3fbcf2e021..e12372589f 100644 --- a/agent/connect/testing_ca.go +++ b/agent/connect/testing_ca.go @@ -157,7 +157,7 @@ func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string t.Fatalf("error generating serial number: %s", err) } - // Genereate fresh private key + // Generate fresh private key pkSigner, pkPEM := testPrivateKey(t) // Cert template for generation diff --git a/command/connect/proxy/proxy.go b/command/connect/proxy/proxy.go index 237f4b7e2f..362e704598 100644 --- a/command/connect/proxy/proxy.go +++ b/command/connect/proxy/proxy.go @@ -1,17 +1,15 @@ package proxy import ( - "context" "flag" "fmt" "io" "log" "net/http" - // Expose pprof if configured - _ "net/http/pprof" + _ "net/http/pprof" // Expose pprof if configured "github.com/hashicorp/consul/command/flags" - proxyImpl "github.com/hashicorp/consul/proxy" + proxyImpl "github.com/hashicorp/consul/connect/proxy" "github.com/hashicorp/consul/logger" "github.com/hashicorp/logutils" @@ -46,13 +44,14 @@ type cmd struct { func (c *cmd) init() { c.flags = flag.NewFlagSet("", flag.ContinueOnError) - c.flags.StringVar(&c.cfgFile, "insecure-dev-config", "", + c.flags.StringVar(&c.cfgFile, "dev-config", "", "If set, proxy config is read on startup from this file (in HCL or JSON"+ "format). If a config file is given, the proxy will use that instead of "+ "querying the local agent for it's configuration. It will not reload it "+ "except on startup. In this mode the proxy WILL NOT authorize incoming "+ "connections with the local agent which is totally insecure. This is "+ - "ONLY for development and testing.") + "ONLY for internal development and testing and will probably be removed "+ + "once proxy implementation is more complete..") c.flags.StringVar(&c.proxyID, "proxy-id", "", "The proxy's ID on the local agent.") @@ -121,31 +120,23 @@ func (c *cmd) Run(args []string) int { } } - ctx, cancel := context.WithCancel(context.Background()) + // Hook the shutdownCh up to close the proxy go func() { - err := p.Run(ctx) - if err != nil { - c.UI.Error(fmt.Sprintf("Failed running proxy: %s", err)) - } - // If we exited early due to a fatal error, need to unblock the main - // routine. But we can't close shutdownCh since it might already be closed - // by a signal and there is no way to tell. We also can't send on it to - // unblock main routine since it's typed as receive only. So the best thing - // we can do is cancel the context and have the main routine select on both. - cancel() + <-c.shutdownCh + p.Close() }() - c.UI.Output("Consul Connect proxy running!") + c.UI.Output("Consul Connect proxy starting") c.UI.Output("Log data will now stream in as it occurs:\n") logGate.Flush() - // Wait for shutdown or context cancel (see Run() goroutine above) - select { - case <-c.shutdownCh: - cancel() - case <-ctx.Done(): + // Run the proxy + err = p.Serve() + if err != nil { + c.UI.Error(fmt.Sprintf("Failed running proxy: %s", err)) } + c.UI.Output("Consul Connect proxy shutdown") return 0 } diff --git a/command/connect/proxy/proxy_test.go b/command/connect/proxy/proxy_test.go deleted file mode 100644 index 943b369ffe..0000000000 --- a/command/connect/proxy/proxy_test.go +++ /dev/null @@ -1 +0,0 @@ -package proxy diff --git a/connect/certgen/certgen.go b/connect/certgen/certgen.go index 6fecf6ae1f..89c4245761 100644 --- a/connect/certgen/certgen.go +++ b/connect/certgen/certgen.go @@ -27,6 +27,7 @@ // NOTE: THIS IS A QUIRK OF OPENSSL; in Connect we distribute the roots alone // and stable intermediates like the XC cert to the _leaf_. package main // import "github.com/hashicorp/consul/connect/certgen" + import ( "flag" "fmt" @@ -42,7 +43,6 @@ import ( func main() { var numCAs = 2 var services = []string{"web", "db", "cache"} - //var slugRe = regexp.MustCompile("[^a-zA-Z0-9]+") var outDir string flag.StringVar(&outDir, "out-dir", "", diff --git a/connect/proxy/config.go b/connect/proxy/config.go new file mode 100644 index 0000000000..a8f83d22cf --- /dev/null +++ b/connect/proxy/config.go @@ -0,0 +1,223 @@ +package proxy + +import ( + "fmt" + "io/ioutil" + "log" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/connect" + "github.com/hashicorp/hcl" +) + +// Config is the publicly configurable state for an entire proxy instance. It's +// mostly used as the format for the local-file config mode which is mostly for +// dev/testing. In normal use, different parts of this config are pulled from +// different locations (e.g. command line, agent config endpoint, agent +// certificate endpoints). +type Config struct { + // ProxyID is the identifier for this proxy as registered in Consul. It's only + // guaranteed to be unique per agent. + ProxyID string `json:"proxy_id" hcl:"proxy_id"` + + // Token is the authentication token provided for queries to the local agent. + Token string `json:"token" hcl:"token"` + + // ProxiedServiceID is the identifier of the service this proxy is representing. + ProxiedServiceID string `json:"proxied_service_id" hcl:"proxied_service_id"` + + // ProxiedServiceNamespace is the namespace of the service this proxy is + // representing. + ProxiedServiceNamespace string `json:"proxied_service_namespace" hcl:"proxied_service_namespace"` + + // PublicListener configures the mTLS listener. + PublicListener PublicListenerConfig `json:"public_listener" hcl:"public_listener"` + + // Upstreams configures outgoing proxies for remote connect services. + Upstreams []UpstreamConfig `json:"upstreams" hcl:"upstreams"` + + // DevCAFile allows passing the file path to PEM encoded root certificate + // bundle to be used in development instead of the ones supplied by Connect. + DevCAFile string `json:"dev_ca_file" hcl:"dev_ca_file"` + + // DevServiceCertFile allows passing the file path to PEM encoded service + // certificate (client and server) to be used in development instead of the + // ones supplied by Connect. + DevServiceCertFile string `json:"dev_service_cert_file" hcl:"dev_service_cert_file"` + + // DevServiceKeyFile allows passing the file path to PEM encoded service + // private key to be used in development instead of the ones supplied by + // Connect. + DevServiceKeyFile string `json:"dev_service_key_file" hcl:"dev_service_key_file"` + + // service is a connect.Service instance representing the proxied service. It + // is created internally by the code responsible for setting up config as it + // may depend on other external dependencies + service *connect.Service +} + +// PublicListenerConfig contains the parameters needed for the incoming mTLS +// listener. +type PublicListenerConfig struct { + // BindAddress is the host:port the public mTLS listener will bind to. + BindAddress string `json:"bind_address" hcl:"bind_address"` + + // LocalServiceAddress is the host:port for the proxied application. This + // should be on loopback or otherwise protected as it's plain TCP. + LocalServiceAddress string `json:"local_service_address" hcl:"local_service_address"` + + // LocalConnectTimeout is the timeout for establishing connections with the + // local backend. Defaults to 1000 (1s). + LocalConnectTimeoutMs int `json:"local_connect_timeout_ms" hcl:"local_connect_timeout_ms"` + + // HandshakeTimeout is the timeout for incoming mTLS clients to complete a + // handshake. Setting this low avoids DOS by malicious clients holding + // resources open. Defaults to 10000 (10s). + HandshakeTimeoutMs int `json:"handshake_timeout_ms" hcl:"handshake_timeout_ms"` +} + +// applyDefaults sets zero-valued params to a sane default. +func (plc *PublicListenerConfig) applyDefaults() { + if plc.LocalConnectTimeoutMs == 0 { + plc.LocalConnectTimeoutMs = 1000 + } + if plc.HandshakeTimeoutMs == 0 { + plc.HandshakeTimeoutMs = 10000 + } +} + +// UpstreamConfig configures an upstream (outgoing) listener. +type UpstreamConfig struct { + // LocalAddress is the host:port to listen on for local app connections. + LocalBindAddress string `json:"local_bind_address" hcl:"local_bind_address,attr"` + + // DestinationName is the service name of the destination. + DestinationName string `json:"destination_name" hcl:"destination_name,attr"` + + // DestinationNamespace is the namespace of the destination. + DestinationNamespace string `json:"destination_namespace" hcl:"destination_namespace,attr"` + + // DestinationType determines which service discovery method is used to find a + // candidate instance to connect to. + DestinationType string `json:"destination_type" hcl:"destination_type,attr"` + + // DestinationDatacenter is the datacenter the destination is in. If empty, + // defaults to discovery within the same datacenter. + DestinationDatacenter string `json:"destination_datacenter" hcl:"destination_datacenter,attr"` + + // ConnectTimeout is the timeout for establishing connections with the remote + // service instance. Defaults to 10,000 (10s). + ConnectTimeoutMs int `json:"connect_timeout_ms" hcl:"connect_timeout_ms,attr"` + + // resolver is used to plug in the service discover mechanism. It can be used + // in tests to bypass discovery. In real usage it is used to inject the + // api.Client dependency from the remainder of the config struct parsed from + // the user JSON using the UpstreamResolverFromClient helper. + resolver connect.Resolver +} + +// applyDefaults sets zero-valued params to a sane default. +func (uc *UpstreamConfig) applyDefaults() { + if uc.ConnectTimeoutMs == 0 { + uc.ConnectTimeoutMs = 10000 + } +} + +// String returns a string that uniquely identifies the Upstream. Used for +// identifying the upstream in log output and map keys. +func (uc *UpstreamConfig) String() string { + return fmt.Sprintf("%s->%s:%s/%s", uc.LocalBindAddress, uc.DestinationType, + uc.DestinationNamespace, uc.DestinationName) +} + +// UpstreamResolverFromClient returns a ConsulResolver that can resolve the +// given UpstreamConfig using the provided api.Client dependency. +func UpstreamResolverFromClient(client *api.Client, + cfg UpstreamConfig) connect.Resolver { + + // For now default to service as it has the most natural meaning and the error + // that the service doesn't exist is probably reasonable if misconfigured. We + // should probably handle actual configs that have invalid types at a higher + // level anyway (like when parsing). + typ := connect.ConsulResolverTypeService + if cfg.DestinationType == "prepared_query" { + typ = connect.ConsulResolverTypePreparedQuery + } + return &connect.ConsulResolver{ + Client: client, + Namespace: cfg.DestinationNamespace, + Name: cfg.DestinationName, + Type: typ, + Datacenter: cfg.DestinationDatacenter, + } +} + +// ConfigWatcher is a simple interface to allow dynamic configurations from +// plugggable sources. +type ConfigWatcher interface { + // Watch returns a channel that will deliver new Configs if something external + // provokes it. + Watch() <-chan *Config +} + +// StaticConfigWatcher is a simple ConfigWatcher that delivers a static Config +// once and then never changes it. +type StaticConfigWatcher struct { + ch chan *Config +} + +// NewStaticConfigWatcher returns a ConfigWatcher for a config that never +// changes. It assumes only one "watcher" will ever call Watch. The config is +// delivered on the first call but will never be delivered again to allow +// callers to call repeatedly (e.g. select in a loop). +func NewStaticConfigWatcher(cfg *Config) *StaticConfigWatcher { + sc := &StaticConfigWatcher{ + // Buffer it so we can queue up the config for first delivery. + ch: make(chan *Config, 1), + } + sc.ch <- cfg + return sc +} + +// Watch implements ConfigWatcher on a static configuration for compatibility. +// It returns itself on the channel once and then leaves it open. +func (sc *StaticConfigWatcher) Watch() <-chan *Config { + return sc.ch +} + +// ParseConfigFile parses proxy configuration from a file for local dev. +func ParseConfigFile(filename string) (*Config, error) { + bs, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + var cfg Config + + err = hcl.Unmarshal(bs, &cfg) + if err != nil { + return nil, err + } + + cfg.PublicListener.applyDefaults() + for idx := range cfg.Upstreams { + cfg.Upstreams[idx].applyDefaults() + } + + return &cfg, nil +} + +// AgentConfigWatcher watches the local Consul agent for proxy config changes. +type AgentConfigWatcher struct { + client *api.Client + proxyID string + logger *log.Logger +} + +// Watch implements ConfigWatcher. +func (w *AgentConfigWatcher) Watch() <-chan *Config { + watch := make(chan *Config) + // TODO implement me, note we need to discover the Service instance to use and + // set it on the Config we return. + return watch +} diff --git a/connect/proxy/config_test.go b/connect/proxy/config_test.go new file mode 100644 index 0000000000..96782b12e6 --- /dev/null +++ b/connect/proxy/config_test.go @@ -0,0 +1,108 @@ +package proxy + +import ( + "testing" + + "github.com/hashicorp/consul/connect" + "github.com/stretchr/testify/require" +) + +func TestParseConfigFile(t *testing.T) { + cfg, err := ParseConfigFile("testdata/config-kitchensink.hcl") + require.Nil(t, err) + + expect := &Config{ + ProxyID: "foo", + Token: "11111111-2222-3333-4444-555555555555", + ProxiedServiceID: "web", + ProxiedServiceNamespace: "default", + PublicListener: PublicListenerConfig{ + BindAddress: ":9999", + LocalServiceAddress: "127.0.0.1:5000", + LocalConnectTimeoutMs: 1000, + HandshakeTimeoutMs: 10000, // From defaults + }, + Upstreams: []UpstreamConfig{ + { + LocalBindAddress: "127.0.0.1:6000", + DestinationName: "db", + DestinationNamespace: "default", + DestinationType: "service", + ConnectTimeoutMs: 10000, + }, + { + LocalBindAddress: "127.0.0.1:6001", + DestinationName: "geo-cache", + DestinationNamespace: "default", + DestinationType: "prepared_query", + ConnectTimeoutMs: 10000, + }, + }, + DevCAFile: "connect/testdata/ca1-ca-consul-internal.cert.pem", + DevServiceCertFile: "connect/testdata/ca1-svc-web.cert.pem", + DevServiceKeyFile: "connect/testdata/ca1-svc-web.key.pem", + } + + require.Equal(t, expect, cfg) +} + +func TestUpstreamResolverFromClient(t *testing.T) { + tests := []struct { + name string + cfg UpstreamConfig + want *connect.ConsulResolver + }{ + { + name: "service", + cfg: UpstreamConfig{ + DestinationNamespace: "foo", + DestinationName: "web", + DestinationDatacenter: "ny1", + DestinationType: "service", + }, + want: &connect.ConsulResolver{ + Namespace: "foo", + Name: "web", + Datacenter: "ny1", + Type: connect.ConsulResolverTypeService, + }, + }, + { + name: "prepared_query", + cfg: UpstreamConfig{ + DestinationNamespace: "foo", + DestinationName: "web", + DestinationDatacenter: "ny1", + DestinationType: "prepared_query", + }, + want: &connect.ConsulResolver{ + Namespace: "foo", + Name: "web", + Datacenter: "ny1", + Type: connect.ConsulResolverTypePreparedQuery, + }, + }, + { + name: "unknown behaves like service", + cfg: UpstreamConfig{ + DestinationNamespace: "foo", + DestinationName: "web", + DestinationDatacenter: "ny1", + DestinationType: "junk", + }, + want: &connect.ConsulResolver{ + Namespace: "foo", + Name: "web", + Datacenter: "ny1", + Type: connect.ConsulResolverTypeService, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Client doesn't really matter as long as it's passed through. + got := UpstreamResolverFromClient(nil, tt.cfg) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/connect/proxy/conn.go b/connect/proxy/conn.go new file mode 100644 index 0000000000..70019e55cb --- /dev/null +++ b/connect/proxy/conn.go @@ -0,0 +1,61 @@ +package proxy + +import ( + "io" + "net" + "sync/atomic" +) + +// Conn represents a single proxied TCP connection. +type Conn struct { + src, dst net.Conn + stopping int32 +} + +// NewConn returns a conn joining the two given net.Conn +func NewConn(src, dst net.Conn) *Conn { + return &Conn{ + src: src, + dst: dst, + stopping: 0, + } +} + +// Close closes both connections. +func (c *Conn) Close() error { + // Note that net.Conn.Close can be called multiple times and atomic store is + // idempotent so no need to ensure we only do this once. + // + // Also note that we don't wait for CopyBytes to return here since we are + // closing the conns which is the only externally visible sideeffect of that + // goroutine running and there should be no way for it to hang or leak once + // the conns are closed so we can save the extra coordination. + atomic.StoreInt32(&c.stopping, 1) + c.src.Close() + c.dst.Close() + return nil +} + +// CopyBytes will continuously copy bytes in both directions between src and dst +// until either connection is closed. +func (c *Conn) CopyBytes() error { + defer c.Close() + + go func() { + // Need this since Copy is only guaranteed to stop when it's source reader + // (second arg) hits EOF or error but either conn might close first possibly + // causing this goroutine to exit but not the outer one. See + // TestConnSrcClosing which will fail if you comment the defer below. + defer c.Close() + io.Copy(c.dst, c.src) + }() + + _, err := io.Copy(c.src, c.dst) + // Note that we don't wait for the other goroutine to finish because it either + // already has due to it's src conn closing, or it will once our defer fires + // and closes the source conn. No need for the extra coordination. + if atomic.LoadInt32(&c.stopping) == 1 { + return nil + } + return err +} diff --git a/connect/proxy/conn_test.go b/connect/proxy/conn_test.go new file mode 100644 index 0000000000..a37720ea0a --- /dev/null +++ b/connect/proxy/conn_test.go @@ -0,0 +1,185 @@ +package proxy + +import ( + "bufio" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// Assert io.Closer implementation +var _ io.Closer = new(Conn) + +// testConnPairSetup creates a TCP connection by listening on a random port, and +// returns both ends. Ready to have data sent down them. It also returns a +// closer function that will close both conns and the listener. +func testConnPairSetup(t *testing.T) (net.Conn, net.Conn, func()) { + t.Helper() + + l, err := net.Listen("tcp", "localhost:0") + require.Nil(t, err) + + ch := make(chan net.Conn, 1) + go func() { + src, err := l.Accept() + require.Nil(t, err) + ch <- src + }() + + dst, err := net.Dial("tcp", l.Addr().String()) + require.Nil(t, err) + + src := <-ch + + stopper := func() { + l.Close() + src.Close() + dst.Close() + } + + return src, dst, stopper +} + +// testConnPipelineSetup creates a pipeline consiting of two TCP connection +// pairs and a Conn that copies bytes between them. Data flow looks like this: +// +// src1 <---> dst1 <== Conn.CopyBytes ==> src2 <---> dst2 +// +// The returned values are the src1 and dst2 which should be able to send and +// receive to each other via the Conn, the Conn itself (not running), and a +// stopper func to close everything. +func testConnPipelineSetup(t *testing.T) (net.Conn, net.Conn, *Conn, func()) { + src1, dst1, stop1 := testConnPairSetup(t) + src2, dst2, stop2 := testConnPairSetup(t) + c := NewConn(dst1, src2) + return src1, dst2, c, func() { + c.Close() + stop1() + stop2() + } +} + +func TestConn(t *testing.T) { + src, dst, c, stop := testConnPipelineSetup(t) + defer stop() + + retCh := make(chan error, 1) + go func() { + retCh <- c.CopyBytes() + }() + + // Now write/read into the other ends of the pipes (src1, dst2) + srcR := bufio.NewReader(src) + dstR := bufio.NewReader(dst) + + _, err := src.Write([]byte("ping 1\n")) + require.Nil(t, err) + _, err = dst.Write([]byte("ping 2\n")) + require.Nil(t, err) + + got, err := dstR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "ping 1\n", got) + + got, err = srcR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "ping 2\n", got) + + _, err = src.Write([]byte("pong 1\n")) + require.Nil(t, err) + _, err = dst.Write([]byte("pong 2\n")) + require.Nil(t, err) + + got, err = dstR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "pong 1\n", got) + + got, err = srcR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "pong 2\n", got) + + c.Close() + + ret := <-retCh + require.Nil(t, ret, "Close() should not cause error return") +} + +func TestConnSrcClosing(t *testing.T) { + src, dst, c, stop := testConnPipelineSetup(t) + defer stop() + + retCh := make(chan error, 1) + go func() { + retCh <- c.CopyBytes() + }() + + // Wait until we can actually get some bytes through both ways so we know that + // the copy goroutines are running. + srcR := bufio.NewReader(src) + dstR := bufio.NewReader(dst) + + _, err := src.Write([]byte("ping 1\n")) + require.Nil(t, err) + _, err = dst.Write([]byte("ping 2\n")) + require.Nil(t, err) + + got, err := dstR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "ping 1\n", got) + got, err = srcR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "ping 2\n", got) + + // If we close the src conn, we expect CopyBytes to return and dst to be + // closed too. No good way to assert that the conn is closed really other than + // assume the retCh receive will hang unless CopyBytes exits and that + // CopyBytes defers Closing both. + testTimer := time.AfterFunc(3*time.Second, func() { + panic("test timeout") + }) + src.Close() + <-retCh + testTimer.Stop() +} + +func TestConnDstClosing(t *testing.T) { + src, dst, c, stop := testConnPipelineSetup(t) + defer stop() + + retCh := make(chan error, 1) + go func() { + retCh <- c.CopyBytes() + }() + + // Wait until we can actually get some bytes through both ways so we know that + // the copy goroutines are running. + srcR := bufio.NewReader(src) + dstR := bufio.NewReader(dst) + + _, err := src.Write([]byte("ping 1\n")) + require.Nil(t, err) + _, err = dst.Write([]byte("ping 2\n")) + require.Nil(t, err) + + got, err := dstR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "ping 1\n", got) + got, err = srcR.ReadString('\n') + require.Nil(t, err) + require.Equal(t, "ping 2\n", got) + + // If we close the dst conn, we expect CopyBytes to return and src to be + // closed too. No good way to assert that the conn is closed really other than + // assume the retCh receive will hang unless CopyBytes exits and that + // CopyBytes defers Closing both. i.e. if this test doesn't time out it's + // good! + testTimer := time.AfterFunc(3*time.Second, func() { + panic("test timeout") + }) + src.Close() + <-retCh + testTimer.Stop() +} diff --git a/connect/proxy/listener.go b/connect/proxy/listener.go new file mode 100644 index 0000000000..c003cb19c3 --- /dev/null +++ b/connect/proxy/listener.go @@ -0,0 +1,116 @@ +package proxy + +import ( + "context" + "crypto/tls" + "errors" + "log" + "net" + "sync/atomic" + "time" + + "github.com/hashicorp/consul/connect" +) + +// Listener is the implementation of a specific proxy listener. It has pluggable +// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles +// the lifecycle of the listener and all connections opened through it +type Listener struct { + // Service is the connect service instance to use. + Service *connect.Service + + listenFunc func() (net.Listener, error) + dialFunc func() (net.Conn, error) + + stopFlag int32 + stopChan chan struct{} + + logger *log.Logger +} + +// NewPublicListener returns a Listener setup to listen for public mTLS +// connections and proxy them to the configured local application over TCP. +func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig, + logger *log.Logger) *Listener { + return &Listener{ + Service: svc, + listenFunc: func() (net.Listener, error) { + return tls.Listen("tcp", cfg.BindAddress, svc.ServerTLSConfig()) + }, + dialFunc: func() (net.Conn, error) { + return net.DialTimeout("tcp", cfg.LocalServiceAddress, + time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond) + }, + stopChan: make(chan struct{}), + logger: logger, + } +} + +// NewUpstreamListener returns a Listener setup to listen locally for TCP +// connections that are proxied to a discovered Connect service instance. +func NewUpstreamListener(svc *connect.Service, cfg UpstreamConfig, + logger *log.Logger) *Listener { + return &Listener{ + Service: svc, + listenFunc: func() (net.Listener, error) { + return net.Listen("tcp", cfg.LocalBindAddress) + }, + dialFunc: func() (net.Conn, error) { + if cfg.resolver == nil { + return nil, errors.New("no resolver provided") + } + ctx, cancel := context.WithTimeout(context.Background(), + time.Duration(cfg.ConnectTimeoutMs)*time.Millisecond) + defer cancel() + return svc.Dial(ctx, cfg.resolver) + }, + stopChan: make(chan struct{}), + logger: logger, + } +} + +// Serve runs the listener until it is stopped. +func (l *Listener) Serve() error { + listen, err := l.listenFunc() + if err != nil { + return err + } + + for { + conn, err := listen.Accept() + if err != nil { + if atomic.LoadInt32(&l.stopFlag) == 1 { + return nil + } + return err + } + + go l.handleConn(conn) + } + return nil +} + +// handleConn is the internal connection handler goroutine. +func (l *Listener) handleConn(src net.Conn) { + defer src.Close() + + dst, err := l.dialFunc() + if err != nil { + l.logger.Printf("[ERR] failed to dial: %s", err) + return + } + // Note no need to defer dst.Close() since conn handles that for us. + conn := NewConn(src, dst) + defer conn.Close() + + err = conn.CopyBytes() + if err != nil { + l.logger.Printf("[ERR] connection failed: %s", err) + return + } +} + +// Close terminates the listener and all active connections. +func (l *Listener) Close() error { + return nil +} diff --git a/connect/proxy/listener_test.go b/connect/proxy/listener_test.go new file mode 100644 index 0000000000..ce41c81e59 --- /dev/null +++ b/connect/proxy/listener_test.go @@ -0,0 +1,91 @@ +package proxy + +import ( + "context" + "log" + "net" + "os" + "testing" + + agConnect "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/connect" + "github.com/stretchr/testify/require" +) + +func TestPublicListener(t *testing.T) { + ca := agConnect.TestCA(t, nil) + addrs := TestLocalBindAddrs(t, 2) + + cfg := PublicListenerConfig{ + BindAddress: addrs[0], + LocalServiceAddress: addrs[1], + HandshakeTimeoutMs: 100, + LocalConnectTimeoutMs: 100, + } + + testApp, err := NewTestTCPServer(t, cfg.LocalServiceAddress) + require.Nil(t, err) + defer testApp.Close() + + svc := connect.TestService(t, "db", ca) + + l := NewPublicListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags)) + + // Run proxy + go func() { + err := l.Serve() + require.Nil(t, err) + }() + defer l.Close() + + // Proxy and backend are running, play the part of a TLS client using same + // cert for now. + conn, err := svc.Dial(context.Background(), &connect.StaticResolver{ + Addr: addrs[0], + CertURI: agConnect.TestSpiffeIDService(t, "db"), + }) + require.Nilf(t, err, "unexpected err: %s", err) + TestEchoConn(t, conn, "") +} + +func TestUpstreamListener(t *testing.T) { + ca := agConnect.TestCA(t, nil) + addrs := TestLocalBindAddrs(t, 1) + + // Run a test server that we can dial. + testSvr := connect.NewTestServer(t, "db", ca) + go func() { + err := testSvr.Serve() + require.Nil(t, err) + }() + defer testSvr.Close() + + cfg := UpstreamConfig{ + DestinationType: "service", + DestinationNamespace: "default", + DestinationName: "db", + ConnectTimeoutMs: 100, + LocalBindAddress: addrs[0], + resolver: &connect.StaticResolver{ + Addr: testSvr.Addr, + CertURI: agConnect.TestSpiffeIDService(t, "db"), + }, + } + + svc := connect.TestService(t, "web", ca) + + l := NewUpstreamListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags)) + + // Run proxy + go func() { + err := l.Serve() + require.Nil(t, err) + }() + defer l.Close() + + // Proxy and fake remote service are running, play the part of the app + // connecting to a remote connect service over TCP. + conn, err := net.Dial("tcp", cfg.LocalBindAddress) + require.Nilf(t, err, "unexpected err: %s", err) + TestEchoConn(t, conn, "") +} diff --git a/connect/proxy/proxy.go b/connect/proxy/proxy.go new file mode 100644 index 0000000000..bda6f3afbd --- /dev/null +++ b/connect/proxy/proxy.go @@ -0,0 +1,134 @@ +package proxy + +import ( + "log" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/connect" +) + +// Proxy implements the built-in connect proxy. +type Proxy struct { + proxyID string + client *api.Client + cfgWatcher ConfigWatcher + stopChan chan struct{} + logger *log.Logger +} + +// NewFromConfigFile returns a Proxy instance configured just from a local file. +// This is intended mostly for development and bypasses the normal mechanisms +// for fetching config and certificates from the local agent. +func NewFromConfigFile(client *api.Client, filename string, + logger *log.Logger) (*Proxy, error) { + cfg, err := ParseConfigFile(filename) + if err != nil { + return nil, err + } + + service, err := connect.NewDevServiceFromCertFiles(cfg.ProxiedServiceID, + client, logger, cfg.DevCAFile, cfg.DevServiceCertFile, + cfg.DevServiceKeyFile) + if err != nil { + return nil, err + } + cfg.service = service + + p := &Proxy{ + proxyID: cfg.ProxyID, + client: client, + cfgWatcher: NewStaticConfigWatcher(cfg), + stopChan: make(chan struct{}), + logger: logger, + } + return p, nil +} + +// New returns a Proxy with the given id, consuming the provided (configured) +// agent. It is ready to Run(). +func New(client *api.Client, proxyID string, logger *log.Logger) (*Proxy, error) { + p := &Proxy{ + proxyID: proxyID, + client: client, + cfgWatcher: &AgentConfigWatcher{ + client: client, + proxyID: proxyID, + logger: logger, + }, + stopChan: make(chan struct{}), + logger: logger, + } + return p, nil +} + +// Serve the proxy instance until a fatal error occurs or proxy is closed. +func (p *Proxy) Serve() error { + + var cfg *Config + + // Watch for config changes (initial setup happens on first "change") + for { + select { + case newCfg := <-p.cfgWatcher.Watch(): + p.logger.Printf("[DEBUG] got new config") + if newCfg.service == nil { + p.logger.Printf("[ERR] new config has nil service") + continue + } + if cfg == nil { + // Initial setup + + newCfg.PublicListener.applyDefaults() + l := NewPublicListener(newCfg.service, newCfg.PublicListener, p.logger) + err := p.startListener("public listener", l) + if err != nil { + return err + } + } + + // TODO(banks) update/remove upstreams properly based on a diff with current. Can + // store a map of uc.String() to Listener here and then use it to only + // start one of each and stop/modify if changes occur. + for _, uc := range newCfg.Upstreams { + uc.applyDefaults() + uc.resolver = UpstreamResolverFromClient(p.client, uc) + + l := NewUpstreamListener(newCfg.service, uc, p.logger) + err := p.startListener(uc.String(), l) + if err != nil { + p.logger.Printf("[ERR] failed to start upstream %s: %s", uc.String(), + err) + } + } + cfg = newCfg + + case <-p.stopChan: + return nil + } + } +} + +// startPublicListener is run from the internal state machine loop +func (p *Proxy) startListener(name string, l *Listener) error { + go func() { + err := l.Serve() + if err != nil { + p.logger.Printf("[ERR] %s stopped with error: %s", name, err) + return + } + p.logger.Printf("[INFO] %s stopped", name) + }() + + go func() { + <-p.stopChan + l.Close() + }() + + return nil +} + +// Close stops the proxy and terminates all active connections. It must be +// called only once. +func (p *Proxy) Close() { + close(p.stopChan) +} diff --git a/connect/proxy/testdata/config-kitchensink.hcl b/connect/proxy/testdata/config-kitchensink.hcl new file mode 100644 index 0000000000..2bda997917 --- /dev/null +++ b/connect/proxy/testdata/config-kitchensink.hcl @@ -0,0 +1,32 @@ +# Example proxy config with everything specified + +proxy_id = "foo" +token = "11111111-2222-3333-4444-555555555555" + +proxied_service_id = "web" +proxied_service_namespace = "default" + +# Assumes running consul in dev mode from the repo root... +dev_ca_file = "connect/testdata/ca1-ca-consul-internal.cert.pem" +dev_service_cert_file = "connect/testdata/ca1-svc-web.cert.pem" +dev_service_key_file = "connect/testdata/ca1-svc-web.key.pem" + +public_listener { + bind_address = ":9999" + local_service_address = "127.0.0.1:5000" +} + +upstreams = [ + { + local_bind_address = "127.0.0.1:6000" + destination_name = "db" + destination_namespace = "default" + destination_type = "service" + }, + { + local_bind_address = "127.0.0.1:6001" + destination_name = "geo-cache" + destination_namespace = "default" + destination_type = "prepared_query" + } +] diff --git a/connect/proxy/testing.go b/connect/proxy/testing.go new file mode 100644 index 0000000000..9ed8c41c4e --- /dev/null +++ b/connect/proxy/testing.go @@ -0,0 +1,105 @@ +package proxy + +import ( + "fmt" + "io" + "log" + "net" + "sync/atomic" + + "github.com/hashicorp/consul/lib/freeport" + "github.com/mitchellh/go-testing-interface" + "github.com/stretchr/testify/require" +) + +// TestLocalBindAddrs returns n localhost address:port strings with free ports +// for binding test listeners to. +func TestLocalBindAddrs(t testing.T, n int) []string { + ports := freeport.GetT(t, n) + addrs := make([]string, n) + for i, p := range ports { + addrs[i] = fmt.Sprintf("localhost:%d", p) + } + return addrs +} + +// TestTCPServer is a simple TCP echo server for use during tests. +type TestTCPServer struct { + l net.Listener + stopped int32 + accepted, closed, active int32 +} + +// NewTestTCPServer opens as a listening socket on the given address and returns +// a TestTCPServer serving requests to it. The server is already started and can +// be stopped by calling Close(). +func NewTestTCPServer(t testing.T, addr string) (*TestTCPServer, error) { + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + log.Printf("test tcp server listening on %s", addr) + s := &TestTCPServer{ + l: l, + } + go s.accept() + return s, nil +} + +// Close stops the server +func (s *TestTCPServer) Close() { + atomic.StoreInt32(&s.stopped, 1) + if s.l != nil { + s.l.Close() + } +} + +func (s *TestTCPServer) accept() error { + for { + conn, err := s.l.Accept() + if err != nil { + if atomic.LoadInt32(&s.stopped) == 1 { + log.Printf("test tcp echo server %s stopped", s.l.Addr()) + return nil + } + log.Printf("test tcp echo server %s failed: %s", s.l.Addr(), err) + return err + } + + atomic.AddInt32(&s.accepted, 1) + atomic.AddInt32(&s.active, 1) + + go func(c net.Conn) { + io.Copy(c, c) + atomic.AddInt32(&s.closed, 1) + atomic.AddInt32(&s.active, -1) + }(conn) + } +} + +// TestEchoConn attempts to write some bytes to conn and expects to read them +// back within a short timeout (10ms). If prefix is not empty we expect it to be +// poresent at the start of all echoed responses (for example to distinguish +// between multiple echo server instances). +func TestEchoConn(t testing.T, conn net.Conn, prefix string) { + t.Helper() + + // Write some bytes and read them back + n, err := conn.Write([]byte("Hello World")) + require.Equal(t, 11, n) + require.Nil(t, err) + + expectLen := 11 + len(prefix) + + buf := make([]byte, expectLen) + // read until our buffer is full - it might be separate packets if prefix is + // in use. + got := 0 + for got < expectLen { + n, err = conn.Read(buf[got:]) + require.Nilf(t, err, "err: %s", err) + got += n + } + require.Equal(t, expectLen, got) + require.Equal(t, prefix+"Hello World", string(buf[:])) +} diff --git a/connect/resolver.go b/connect/resolver.go index 41dc70e82e..9873fcdf1f 100644 --- a/connect/resolver.go +++ b/connect/resolver.go @@ -10,7 +10,9 @@ import ( testing "github.com/mitchellh/go-testing-interface" ) -// Resolver is the interface implemented by a service discovery mechanism. +// Resolver is the interface implemented by a service discovery mechanism to get +// the address and identity of an instance to connect to via Connect as a +// client. type Resolver interface { // Resolve returns a single service instance to connect to. Implementations // may attempt to ensure the instance returned is currently available. It is @@ -19,7 +21,10 @@ type Resolver interface { // increases reliability. The context passed can be used to impose timeouts // which may or may not be respected by implementations that make network // calls to resolve the service. The addr returned is a string in any valid - // form for passing directly to `net.Dial("tcp", addr)`. + // form for passing directly to `net.Dial("tcp", addr)`. The certURI + // represents the identity of the service instance. It will be matched against + // the TLS certificate URI SAN presented by the server and the connection + // rejected if they don't match. Resolve(ctx context.Context) (addr string, certURI connect.CertURI, err error) } @@ -33,7 +38,8 @@ type StaticResolver struct { Addr string // CertURL is the _identity_ we expect the server to present in it's TLS - // certificate. It must be an exact match or the connection will be rejected. + // certificate. It must be an exact URI string match or the connection will be + // rejected. CertURI connect.CertURI } @@ -56,13 +62,14 @@ type ConsulResolver struct { // panic. Client *api.Client - // Namespace of the query target + // Namespace of the query target. Namespace string - // Name of the query target + // Name of the query target. Name string - // Type of the query target, + // Type of the query target. Should be one of the defined ConsulResolverType* + // constants. Currently defaults to ConsulResolverTypeService. Type int // Datacenter to resolve in, empty indicates agent's local DC. diff --git a/connect/resolver_test.go b/connect/resolver_test.go index 29a40e3d32..3ab439addb 100644 --- a/connect/resolver_test.go +++ b/connect/resolver_test.go @@ -41,7 +41,6 @@ func TestStaticResolver_Resolve(t *testing.T) { } func TestConsulResolver_Resolve(t *testing.T) { - // Setup a local test agent to query agent := agent.NewTestAgent("test-consul", "") defer agent.Shutdown() diff --git a/connect/service.go b/connect/service.go index db83ce5aad..6bbda08079 100644 --- a/connect/service.go +++ b/connect/service.go @@ -3,6 +3,7 @@ package connect import ( "context" "crypto/tls" + "errors" "log" "net" "net/http" @@ -10,6 +11,7 @@ import ( "time" "github.com/hashicorp/consul/api" + "golang.org/x/net/http2" ) // Service represents a Consul service that accepts and/or connects via Connect. @@ -41,10 +43,17 @@ type Service struct { client *api.Client // serverTLSCfg is the (reloadable) TLS config we use for serving. - serverTLSCfg *ReloadableTLSConfig + serverTLSCfg *reloadableTLSConfig // clientTLSCfg is the (reloadable) TLS config we use for dialling. - clientTLSCfg *ReloadableTLSConfig + clientTLSCfg *reloadableTLSConfig + + // httpResolverFromAddr is a function that returns a Resolver from a string + // address for HTTP clients. It's privately pluggable to make testing easier + // but will default to a simple method to parse the host as a Consul DNS host. + // + // TODO(banks): write the proper implementation + httpResolverFromAddr func(addr string) (Resolver, error) logger *log.Logger } @@ -65,8 +74,8 @@ func NewServiceWithLogger(serviceID string, client *api.Client, client: client, logger: logger, } - s.serverTLSCfg = NewReloadableTLSConfig(defaultTLSConfig(serverVerifyCerts)) - s.clientTLSCfg = NewReloadableTLSConfig(defaultTLSConfig(clientVerifyCerts)) + s.serverTLSCfg = newReloadableTLSConfig(defaultTLSConfig(serverVerifyCerts)) + s.clientTLSCfg = newReloadableTLSConfig(defaultTLSConfig(clientVerifyCerts)) // TODO(banks) run the background certificate sync return s, nil @@ -86,12 +95,12 @@ func NewDevServiceFromCertFiles(serviceID string, client *api.Client, return nil, err } - // Note that NewReloadableTLSConfig makes a copy so we can re-use the same + // Note that newReloadableTLSConfig makes a copy so we can re-use the same // base for both client and server with swapped verifiers. tlsCfg.VerifyPeerCertificate = serverVerifyCerts - s.serverTLSCfg = NewReloadableTLSConfig(tlsCfg) + s.serverTLSCfg = newReloadableTLSConfig(tlsCfg) tlsCfg.VerifyPeerCertificate = clientVerifyCerts - s.clientTLSCfg = NewReloadableTLSConfig(tlsCfg) + s.clientTLSCfg = newReloadableTLSConfig(tlsCfg) return s, nil } @@ -121,6 +130,8 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error) if err != nil { return nil, err } + s.logger.Printf("[DEBUG] resolved service instance: %s (%s)", addr, + certURI.URI()) var dialer net.Dialer tcpConn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { @@ -133,8 +144,8 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error) if ok { tlsConn.SetDeadline(deadline) } - err = tlsConn.Handshake() - if err != nil { + // Perform handshake + if err = tlsConn.Handshake(); err != nil { tlsConn.Close() return nil, err } @@ -149,20 +160,27 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error) tlsConn.Close() return nil, err } - + s.logger.Printf("[DEBUG] successfully connected to %s (%s)", addr, + certURI.URI()) return tlsConn, nil } -// HTTPDialContext is compatible with http.Transport.DialContext. It expects the -// addr hostname to be specified using Consul DNS query syntax, e.g. +// HTTPDialTLS is compatible with http.Transport.DialTLS. It expects the addr +// hostname to be specified using Consul DNS query syntax, e.g. // "web.service.consul". It converts that into the equivalent ConsulResolver and // then call s.Dial with the resolver. This is low level, clients should // typically use HTTPClient directly. -func (s *Service) HTTPDialContext(ctx context.Context, network, +func (s *Service) HTTPDialTLS(network, addr string) (net.Conn, error) { - var r ConsulResolver - // TODO(banks): parse addr into ConsulResolver - return s.Dial(ctx, &r) + if s.httpResolverFromAddr == nil { + return nil, errors.New("no http resolver configured") + } + r, err := s.httpResolverFromAddr(addr) + if err != nil { + return nil, err + } + // TODO(banks): figure out how to do timeouts better. + return s.Dial(context.Background(), r) } // HTTPClient returns an *http.Client configured to dial remote Consul Connect @@ -172,14 +190,27 @@ func (s *Service) HTTPDialContext(ctx context.Context, network, // API rather than just relying on Consul DNS. Hostnames that are not valid // Consul DNS queries will fail. func (s *Service) HTTPClient() *http.Client { + t := &http.Transport{ + // Sadly we can't use DialContext hook since that is expected to return a + // plain TCP connection an http.Client tries to start a TLS handshake over + // it. We need to control the handshake to be able to do our validation. + // So we have to use the older DialTLS which means no context/timeout + // support. + // + // TODO(banks): figure out how users can configure a timeout when using + // this and/or compatibility with http.Request.WithContext. + DialTLS: s.HTTPDialTLS, + } + // Need to manually re-enable http2 support since we set custom DialTLS. + // See https://golang.org/src/net/http/transport.go?s=8692:9036#L228 + http2.ConfigureTransport(t) return &http.Client{ - Transport: &http.Transport{ - DialContext: s.HTTPDialContext, - }, + Transport: t, } } // Close stops the service and frees resources. -func (s *Service) Close() { +func (s *Service) Close() error { // TODO(banks): stop background activity if started + return nil } diff --git a/connect/service_test.go b/connect/service_test.go index a2adfe7f1f..7bc4c97f21 100644 --- a/connect/service_test.go +++ b/connect/service_test.go @@ -2,14 +2,22 @@ package connect import ( "context" + "crypto/tls" "fmt" + "io" + "io/ioutil" + "net/http" "testing" "time" "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/testutil/retry" "github.com/stretchr/testify/require" ) +// Assert io.Closer implementation +var _ io.Closer = new(Service) + func TestService_Dial(t *testing.T) { ca := connect.TestCA(t, nil) @@ -53,30 +61,26 @@ func TestService_Dial(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - s, err := NewService("web", nil) - require.Nil(err) - - // Force TLSConfig - s.clientTLSCfg = NewReloadableTLSConfig(TestTLSConfig(t, "web", ca)) + s := TestService(t, "web", ca) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - testSvc := NewTestService(t, tt.presentService, ca) - testSvc.TimeoutHandshake = !tt.handshake + testSvr := NewTestServer(t, tt.presentService, ca) + testSvr.TimeoutHandshake = !tt.handshake if tt.accept { go func() { - err := testSvc.Serve() + err := testSvr.Serve() require.Nil(err) }() - defer testSvc.Close() + defer testSvr.Close() } // Always expect to be connecting to a "DB" resolver := &StaticResolver{ - Addr: testSvc.Addr, + Addr: testSvr.Addr, CertURI: connect.TestSpiffeIDService(t, "db"), } @@ -92,6 +96,7 @@ func TestService_Dial(t *testing.T) { if tt.wantErr == "" { require.Nil(err) + require.IsType(&tls.Conn{}, conn) } else { require.NotNil(err) require.Contains(err.Error(), tt.wantErr) @@ -103,3 +108,62 @@ func TestService_Dial(t *testing.T) { }) } } + +func TestService_ServerTLSConfig(t *testing.T) { + // TODO(banks): it's mostly meaningless to test this now since we directly set + // the tlsCfg in our TestService helper which is all we'd be asserting on here + // not the actual implementation. Once agent tls fetching is built, it becomes + // more meaningful to actually verify it's returning the correct config. +} + +func TestService_HTTPClient(t *testing.T) { + require := require.New(t) + ca := connect.TestCA(t, nil) + + s := TestService(t, "web", ca) + + // Run a test HTTP server + testSvr := NewTestServer(t, "backend", ca) + defer testSvr.Close() + go func() { + err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, I am Backend")) + })) + require.Nil(t, err) + }() + + // TODO(banks): this will talk http2 on both client and server. I hit some + // compatibility issues when testing though need to make sure that the http + // server with our TLSConfig can actually support HTTP/1.1 as well. Could make + // this a table test with all 4 permutations of client/server http version + // support. + + // Still get connection refused some times so retry on those + retry.Run(t, func(r *retry.R) { + // Hook the service resolver to avoid needing full agent setup. + s.httpResolverFromAddr = func(addr string) (Resolver, error) { + // Require in this goroutine seems to block causing a timeout on the Get. + //require.Equal("https://backend.service.consul:443", addr) + return &StaticResolver{ + Addr: testSvr.Addr, + CertURI: connect.TestSpiffeIDService(t, "backend"), + }, nil + } + + client := s.HTTPClient() + client.Timeout = 1 * time.Second + + resp, err := client.Get("https://backend.service.consul/foo") + r.Check(err) + defer resp.Body.Close() + + bodyBytes, err := ioutil.ReadAll(resp.Body) + r.Check(err) + + got := string(bodyBytes) + want := "Hello, I am Backend" + if got != want { + r.Fatalf("got %s, want %s", got, want) + } + }) +} diff --git a/connect/testing.go b/connect/testing.go index f6fa438cfe..235ff60018 100644 --- a/connect/testing.go +++ b/connect/testing.go @@ -5,26 +5,33 @@ import ( "crypto/x509" "fmt" "io" + "log" "net" + "net/http" "sync/atomic" "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib/freeport" testing "github.com/mitchellh/go-testing-interface" - "github.com/stretchr/testify/require" ) -// testVerifier creates a helper verifyFunc that can be set in a tls.Config and -// records calls made, passing back the certificates presented via the returned -// channel. The channel is buffered so up to 128 verification calls can be made -// without reading the chan before verification blocks. -func testVerifier(t testing.T, returnErr error) (verifyFunc, chan [][]byte) { - ch := make(chan [][]byte, 128) - return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - ch <- rawCerts - return returnErr - }, ch +// TestService returns a Service instance based on a static TLS Config. +func TestService(t testing.T, service string, ca *structs.CARoot) *Service { + t.Helper() + + // Don't need to talk to client since we are setting TLSConfig locally + svc, err := NewService(service, nil) + if err != nil { + t.Fatal(err) + } + + svc.serverTLSCfg = newReloadableTLSConfig( + TestTLSConfigWithVerifier(t, service, ca, serverVerifyCerts)) + svc.clientTLSCfg = newReloadableTLSConfig( + TestTLSConfigWithVerifier(t, service, ca, clientVerifyCerts)) + + return svc } // TestTLSConfig returns a *tls.Config suitable for use during tests. @@ -32,7 +39,16 @@ func TestTLSConfig(t testing.T, service string, ca *structs.CARoot) *tls.Config t.Helper() // Insecure default (nil verifier) - cfg := defaultTLSConfig(nil) + return TestTLSConfigWithVerifier(t, service, ca, nil) +} + +// TestTLSConfigWithVerifier returns a *tls.Config suitable for use during +// tests, it will use the given verifyFunc to verify tls certificates. +func TestTLSConfigWithVerifier(t testing.T, service string, ca *structs.CARoot, + verifier verifyFunc) *tls.Config { + t.Helper() + + cfg := defaultTLSConfig(verifier) cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)} cfg.RootCAs = TestCAPool(t, ca) cfg.ClientCAs = TestCAPool(t, ca) @@ -55,7 +71,9 @@ func TestSvcKeyPair(t testing.T, service string, ca *structs.CARoot) tls.Certifi t.Helper() certPEM, keyPEM := connect.TestLeaf(t, service, ca) cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) - require.Nil(t, err) + if err != nil { + t.Fatal(err) + } return cert } @@ -65,13 +83,15 @@ func TestPeerCertificates(t testing.T, service string, ca *structs.CARoot) []*x5 t.Helper() certPEM, _ := connect.TestLeaf(t, service, ca) cert, err := connect.ParseCert(certPEM) - require.Nil(t, err) + if err != nil { + t.Fatal(err) + } return []*x509.Certificate{cert} } -// TestService runs a service listener that can be used to test clients. It's +// TestServer runs a service listener that can be used to test clients. It's // behaviour can be controlled by the struct members. -type TestService struct { +type TestServer struct { // The service name to serve. Service string // The (test) CA to use for generating certs. @@ -91,11 +111,11 @@ type TestService struct { stopChan chan struct{} } -// NewTestService returns a TestService. It should be closed when test is +// NewTestServer returns a TestServer. It should be closed when test is // complete. -func NewTestService(t testing.T, service string, ca *structs.CARoot) *TestService { +func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer { ports := freeport.GetT(t, 1) - return &TestService{ + return &TestServer{ Service: service, CA: ca, stopChan: make(chan struct{}), @@ -104,14 +124,16 @@ func NewTestService(t testing.T, service string, ca *structs.CARoot) *TestServic } } -// Serve runs a TestService and blocks until it is closed or errors. -func (s *TestService) Serve() error { +// Serve runs a tcp echo server and blocks until it is closed or errors. If +// TimeoutHandshake is set it won't start TLS handshake on new connections. +func (s *TestServer) Serve() error { // Just accept TCP conn but so we can control timing of accept/handshake l, err := net.Listen("tcp", s.Addr) if err != nil { return err } s.l = l + log.Printf("test connect service listening on %s", s.Addr) for { conn, err := s.l.Accept() @@ -122,12 +144,14 @@ func (s *TestService) Serve() error { return err } - // Ignore the conn if we are not actively ha + // Ignore the conn if we are not actively handshaking if !s.TimeoutHandshake { // Upgrade conn to TLS conn = tls.Server(conn, s.TLSCfg) // Run an echo service + log.Printf("test connect service accepted conn from %s, "+ + " running echo service", conn.RemoteAddr()) go io.Copy(conn, conn) } @@ -141,8 +165,20 @@ func (s *TestService) Serve() error { return nil } -// Close stops a TestService -func (s *TestService) Close() { +// ServeHTTPS runs an HTTPS server with the given config. It invokes the passed +// Handler for all requests. +func (s *TestServer) ServeHTTPS(h http.Handler) error { + srv := http.Server{ + Addr: s.Addr, + TLSConfig: s.TLSCfg, + Handler: h, + } + log.Printf("starting test connect HTTPS server on %s", s.Addr) + return srv.ListenAndServeTLS("", "") +} + +// Close stops a TestServer +func (s *TestServer) Close() error { old := atomic.SwapInt32(&s.stopFlag, 1) if old == 0 { if s.l != nil { @@ -150,4 +186,5 @@ func (s *TestService) Close() { } close(s.stopChan) } + return nil } diff --git a/connect/tls.go b/connect/tls.go index 8d3bc3a94c..89d5ccb542 100644 --- a/connect/tls.go +++ b/connect/tls.go @@ -42,27 +42,27 @@ func defaultTLSConfig(verify verifyFunc) *tls.Config { } } -// ReloadableTLSConfig exposes a tls.Config that can have it's certificates +// reloadableTLSConfig exposes a tls.Config that can have it's certificates // reloaded. On a server, this uses GetConfigForClient to pass the current // tls.Config or client certificate for each acceptted connection. On a client, // this uses GetClientCertificate to provide the current client certificate. -type ReloadableTLSConfig struct { +type reloadableTLSConfig struct { mu sync.Mutex // cfg is the current config to use for new connections cfg *tls.Config } -// NewReloadableTLSConfig returns a reloadable config currently set to base. -func NewReloadableTLSConfig(base *tls.Config) *ReloadableTLSConfig { - c := &ReloadableTLSConfig{} +// newReloadableTLSConfig returns a reloadable config currently set to base. +func newReloadableTLSConfig(base *tls.Config) *reloadableTLSConfig { + c := &reloadableTLSConfig{} c.SetTLSConfig(base) return c } // TLSConfig returns a *tls.Config that will dynamically load certs. It's // suitable for use in either a client or server. -func (c *ReloadableTLSConfig) TLSConfig() *tls.Config { +func (c *reloadableTLSConfig) TLSConfig() *tls.Config { c.mu.Lock() cfgCopy := c.cfg c.mu.Unlock() @@ -71,7 +71,7 @@ func (c *ReloadableTLSConfig) TLSConfig() *tls.Config { // SetTLSConfig sets the config used for future connections. It is safe to call // from any goroutine. -func (c *ReloadableTLSConfig) SetTLSConfig(cfg *tls.Config) error { +func (c *reloadableTLSConfig) SetTLSConfig(cfg *tls.Config) error { copy := cfg.Clone() copy.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { current := c.TLSConfig() diff --git a/connect/tls_test.go b/connect/tls_test.go index 3605f22dbb..64c473c1e4 100644 --- a/connect/tls_test.go +++ b/connect/tls_test.go @@ -10,10 +10,9 @@ import ( func TestReloadableTLSConfig(t *testing.T) { require := require.New(t) - verify, _ := testVerifier(t, nil) - base := defaultTLSConfig(verify) + base := defaultTLSConfig(nil) - c := NewReloadableTLSConfig(base) + c := newReloadableTLSConfig(base) // The dynamic config should be the one we loaded (with some different hooks) got := c.TLSConfig()