Rework connect/proxy and command/connect/proxy. End to end demo working again

This commit is contained in:
Paul Banks 2018-04-03 19:10:59 +01:00 committed by Mitchell Hashimoto
parent aa19be4651
commit 10db79c8ae
No known key found for this signature in database
GPG Key ID: 744E147AA52F5B0A
20 changed files with 1279 additions and 97 deletions

View File

@ -157,7 +157,7 @@ func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string
t.Fatalf("error generating serial number: %s", err) t.Fatalf("error generating serial number: %s", err)
} }
// Genereate fresh private key // Generate fresh private key
pkSigner, pkPEM := testPrivateKey(t) pkSigner, pkPEM := testPrivateKey(t)
// Cert template for generation // Cert template for generation

View File

@ -1,17 +1,15 @@
package proxy package proxy
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
// Expose pprof if configured _ "net/http/pprof" // Expose pprof if configured
_ "net/http/pprof"
"github.com/hashicorp/consul/command/flags" "github.com/hashicorp/consul/command/flags"
proxyImpl "github.com/hashicorp/consul/proxy" proxyImpl "github.com/hashicorp/consul/connect/proxy"
"github.com/hashicorp/consul/logger" "github.com/hashicorp/consul/logger"
"github.com/hashicorp/logutils" "github.com/hashicorp/logutils"
@ -46,13 +44,14 @@ type cmd struct {
func (c *cmd) init() { func (c *cmd) init() {
c.flags = flag.NewFlagSet("", flag.ContinueOnError) c.flags = flag.NewFlagSet("", flag.ContinueOnError)
c.flags.StringVar(&c.cfgFile, "insecure-dev-config", "", c.flags.StringVar(&c.cfgFile, "dev-config", "",
"If set, proxy config is read on startup from this file (in HCL or JSON"+ "If set, proxy config is read on startup from this file (in HCL or JSON"+
"format). If a config file is given, the proxy will use that instead of "+ "format). If a config file is given, the proxy will use that instead of "+
"querying the local agent for it's configuration. It will not reload it "+ "querying the local agent for it's configuration. It will not reload it "+
"except on startup. In this mode the proxy WILL NOT authorize incoming "+ "except on startup. In this mode the proxy WILL NOT authorize incoming "+
"connections with the local agent which is totally insecure. This is "+ "connections with the local agent which is totally insecure. This is "+
"ONLY for development and testing.") "ONLY for internal development and testing and will probably be removed "+
"once proxy implementation is more complete..")
c.flags.StringVar(&c.proxyID, "proxy-id", "", c.flags.StringVar(&c.proxyID, "proxy-id", "",
"The proxy's ID on the local agent.") "The proxy's ID on the local agent.")
@ -121,31 +120,23 @@ func (c *cmd) Run(args []string) int {
} }
} }
ctx, cancel := context.WithCancel(context.Background()) // Hook the shutdownCh up to close the proxy
go func() { go func() {
err := p.Run(ctx) <-c.shutdownCh
if err != nil { p.Close()
c.UI.Error(fmt.Sprintf("Failed running proxy: %s", err))
}
// If we exited early due to a fatal error, need to unblock the main
// routine. But we can't close shutdownCh since it might already be closed
// by a signal and there is no way to tell. We also can't send on it to
// unblock main routine since it's typed as receive only. So the best thing
// we can do is cancel the context and have the main routine select on both.
cancel()
}() }()
c.UI.Output("Consul Connect proxy running!") c.UI.Output("Consul Connect proxy starting")
c.UI.Output("Log data will now stream in as it occurs:\n") c.UI.Output("Log data will now stream in as it occurs:\n")
logGate.Flush() logGate.Flush()
// Wait for shutdown or context cancel (see Run() goroutine above) // Run the proxy
select { err = p.Serve()
case <-c.shutdownCh: if err != nil {
cancel() c.UI.Error(fmt.Sprintf("Failed running proxy: %s", err))
case <-ctx.Done():
} }
c.UI.Output("Consul Connect proxy shutdown") c.UI.Output("Consul Connect proxy shutdown")
return 0 return 0
} }

View File

@ -1 +0,0 @@
package proxy

View File

@ -27,6 +27,7 @@
// NOTE: THIS IS A QUIRK OF OPENSSL; in Connect we distribute the roots alone // NOTE: THIS IS A QUIRK OF OPENSSL; in Connect we distribute the roots alone
// and stable intermediates like the XC cert to the _leaf_. // and stable intermediates like the XC cert to the _leaf_.
package main // import "github.com/hashicorp/consul/connect/certgen" package main // import "github.com/hashicorp/consul/connect/certgen"
import ( import (
"flag" "flag"
"fmt" "fmt"
@ -42,7 +43,6 @@ import (
func main() { func main() {
var numCAs = 2 var numCAs = 2
var services = []string{"web", "db", "cache"} var services = []string{"web", "db", "cache"}
//var slugRe = regexp.MustCompile("[^a-zA-Z0-9]+")
var outDir string var outDir string
flag.StringVar(&outDir, "out-dir", "", flag.StringVar(&outDir, "out-dir", "",

223
connect/proxy/config.go Normal file
View File

@ -0,0 +1,223 @@
package proxy
import (
"fmt"
"io/ioutil"
"log"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/connect"
"github.com/hashicorp/hcl"
)
// Config is the publicly configurable state for an entire proxy instance. It's
// mostly used as the format for the local-file config mode which is mostly for
// dev/testing. In normal use, different parts of this config are pulled from
// different locations (e.g. command line, agent config endpoint, agent
// certificate endpoints).
type Config struct {
// ProxyID is the identifier for this proxy as registered in Consul. It's only
// guaranteed to be unique per agent.
ProxyID string `json:"proxy_id" hcl:"proxy_id"`
// Token is the authentication token provided for queries to the local agent.
Token string `json:"token" hcl:"token"`
// ProxiedServiceID is the identifier of the service this proxy is representing.
ProxiedServiceID string `json:"proxied_service_id" hcl:"proxied_service_id"`
// ProxiedServiceNamespace is the namespace of the service this proxy is
// representing.
ProxiedServiceNamespace string `json:"proxied_service_namespace" hcl:"proxied_service_namespace"`
// PublicListener configures the mTLS listener.
PublicListener PublicListenerConfig `json:"public_listener" hcl:"public_listener"`
// Upstreams configures outgoing proxies for remote connect services.
Upstreams []UpstreamConfig `json:"upstreams" hcl:"upstreams"`
// DevCAFile allows passing the file path to PEM encoded root certificate
// bundle to be used in development instead of the ones supplied by Connect.
DevCAFile string `json:"dev_ca_file" hcl:"dev_ca_file"`
// DevServiceCertFile allows passing the file path to PEM encoded service
// certificate (client and server) to be used in development instead of the
// ones supplied by Connect.
DevServiceCertFile string `json:"dev_service_cert_file" hcl:"dev_service_cert_file"`
// DevServiceKeyFile allows passing the file path to PEM encoded service
// private key to be used in development instead of the ones supplied by
// Connect.
DevServiceKeyFile string `json:"dev_service_key_file" hcl:"dev_service_key_file"`
// service is a connect.Service instance representing the proxied service. It
// is created internally by the code responsible for setting up config as it
// may depend on other external dependencies
service *connect.Service
}
// PublicListenerConfig contains the parameters needed for the incoming mTLS
// listener.
type PublicListenerConfig struct {
// BindAddress is the host:port the public mTLS listener will bind to.
BindAddress string `json:"bind_address" hcl:"bind_address"`
// LocalServiceAddress is the host:port for the proxied application. This
// should be on loopback or otherwise protected as it's plain TCP.
LocalServiceAddress string `json:"local_service_address" hcl:"local_service_address"`
// LocalConnectTimeout is the timeout for establishing connections with the
// local backend. Defaults to 1000 (1s).
LocalConnectTimeoutMs int `json:"local_connect_timeout_ms" hcl:"local_connect_timeout_ms"`
// HandshakeTimeout is the timeout for incoming mTLS clients to complete a
// handshake. Setting this low avoids DOS by malicious clients holding
// resources open. Defaults to 10000 (10s).
HandshakeTimeoutMs int `json:"handshake_timeout_ms" hcl:"handshake_timeout_ms"`
}
// applyDefaults sets zero-valued params to a sane default.
func (plc *PublicListenerConfig) applyDefaults() {
if plc.LocalConnectTimeoutMs == 0 {
plc.LocalConnectTimeoutMs = 1000
}
if plc.HandshakeTimeoutMs == 0 {
plc.HandshakeTimeoutMs = 10000
}
}
// UpstreamConfig configures an upstream (outgoing) listener.
type UpstreamConfig struct {
// LocalAddress is the host:port to listen on for local app connections.
LocalBindAddress string `json:"local_bind_address" hcl:"local_bind_address,attr"`
// DestinationName is the service name of the destination.
DestinationName string `json:"destination_name" hcl:"destination_name,attr"`
// DestinationNamespace is the namespace of the destination.
DestinationNamespace string `json:"destination_namespace" hcl:"destination_namespace,attr"`
// DestinationType determines which service discovery method is used to find a
// candidate instance to connect to.
DestinationType string `json:"destination_type" hcl:"destination_type,attr"`
// DestinationDatacenter is the datacenter the destination is in. If empty,
// defaults to discovery within the same datacenter.
DestinationDatacenter string `json:"destination_datacenter" hcl:"destination_datacenter,attr"`
// ConnectTimeout is the timeout for establishing connections with the remote
// service instance. Defaults to 10,000 (10s).
ConnectTimeoutMs int `json:"connect_timeout_ms" hcl:"connect_timeout_ms,attr"`
// resolver is used to plug in the service discover mechanism. It can be used
// in tests to bypass discovery. In real usage it is used to inject the
// api.Client dependency from the remainder of the config struct parsed from
// the user JSON using the UpstreamResolverFromClient helper.
resolver connect.Resolver
}
// applyDefaults sets zero-valued params to a sane default.
func (uc *UpstreamConfig) applyDefaults() {
if uc.ConnectTimeoutMs == 0 {
uc.ConnectTimeoutMs = 10000
}
}
// String returns a string that uniquely identifies the Upstream. Used for
// identifying the upstream in log output and map keys.
func (uc *UpstreamConfig) String() string {
return fmt.Sprintf("%s->%s:%s/%s", uc.LocalBindAddress, uc.DestinationType,
uc.DestinationNamespace, uc.DestinationName)
}
// UpstreamResolverFromClient returns a ConsulResolver that can resolve the
// given UpstreamConfig using the provided api.Client dependency.
func UpstreamResolverFromClient(client *api.Client,
cfg UpstreamConfig) connect.Resolver {
// For now default to service as it has the most natural meaning and the error
// that the service doesn't exist is probably reasonable if misconfigured. We
// should probably handle actual configs that have invalid types at a higher
// level anyway (like when parsing).
typ := connect.ConsulResolverTypeService
if cfg.DestinationType == "prepared_query" {
typ = connect.ConsulResolverTypePreparedQuery
}
return &connect.ConsulResolver{
Client: client,
Namespace: cfg.DestinationNamespace,
Name: cfg.DestinationName,
Type: typ,
Datacenter: cfg.DestinationDatacenter,
}
}
// ConfigWatcher is a simple interface to allow dynamic configurations from
// plugggable sources.
type ConfigWatcher interface {
// Watch returns a channel that will deliver new Configs if something external
// provokes it.
Watch() <-chan *Config
}
// StaticConfigWatcher is a simple ConfigWatcher that delivers a static Config
// once and then never changes it.
type StaticConfigWatcher struct {
ch chan *Config
}
// NewStaticConfigWatcher returns a ConfigWatcher for a config that never
// changes. It assumes only one "watcher" will ever call Watch. The config is
// delivered on the first call but will never be delivered again to allow
// callers to call repeatedly (e.g. select in a loop).
func NewStaticConfigWatcher(cfg *Config) *StaticConfigWatcher {
sc := &StaticConfigWatcher{
// Buffer it so we can queue up the config for first delivery.
ch: make(chan *Config, 1),
}
sc.ch <- cfg
return sc
}
// Watch implements ConfigWatcher on a static configuration for compatibility.
// It returns itself on the channel once and then leaves it open.
func (sc *StaticConfigWatcher) Watch() <-chan *Config {
return sc.ch
}
// ParseConfigFile parses proxy configuration from a file for local dev.
func ParseConfigFile(filename string) (*Config, error) {
bs, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
var cfg Config
err = hcl.Unmarshal(bs, &cfg)
if err != nil {
return nil, err
}
cfg.PublicListener.applyDefaults()
for idx := range cfg.Upstreams {
cfg.Upstreams[idx].applyDefaults()
}
return &cfg, nil
}
// AgentConfigWatcher watches the local Consul agent for proxy config changes.
type AgentConfigWatcher struct {
client *api.Client
proxyID string
logger *log.Logger
}
// Watch implements ConfigWatcher.
func (w *AgentConfigWatcher) Watch() <-chan *Config {
watch := make(chan *Config)
// TODO implement me, note we need to discover the Service instance to use and
// set it on the Config we return.
return watch
}

View File

@ -0,0 +1,108 @@
package proxy
import (
"testing"
"github.com/hashicorp/consul/connect"
"github.com/stretchr/testify/require"
)
func TestParseConfigFile(t *testing.T) {
cfg, err := ParseConfigFile("testdata/config-kitchensink.hcl")
require.Nil(t, err)
expect := &Config{
ProxyID: "foo",
Token: "11111111-2222-3333-4444-555555555555",
ProxiedServiceID: "web",
ProxiedServiceNamespace: "default",
PublicListener: PublicListenerConfig{
BindAddress: ":9999",
LocalServiceAddress: "127.0.0.1:5000",
LocalConnectTimeoutMs: 1000,
HandshakeTimeoutMs: 10000, // From defaults
},
Upstreams: []UpstreamConfig{
{
LocalBindAddress: "127.0.0.1:6000",
DestinationName: "db",
DestinationNamespace: "default",
DestinationType: "service",
ConnectTimeoutMs: 10000,
},
{
LocalBindAddress: "127.0.0.1:6001",
DestinationName: "geo-cache",
DestinationNamespace: "default",
DestinationType: "prepared_query",
ConnectTimeoutMs: 10000,
},
},
DevCAFile: "connect/testdata/ca1-ca-consul-internal.cert.pem",
DevServiceCertFile: "connect/testdata/ca1-svc-web.cert.pem",
DevServiceKeyFile: "connect/testdata/ca1-svc-web.key.pem",
}
require.Equal(t, expect, cfg)
}
func TestUpstreamResolverFromClient(t *testing.T) {
tests := []struct {
name string
cfg UpstreamConfig
want *connect.ConsulResolver
}{
{
name: "service",
cfg: UpstreamConfig{
DestinationNamespace: "foo",
DestinationName: "web",
DestinationDatacenter: "ny1",
DestinationType: "service",
},
want: &connect.ConsulResolver{
Namespace: "foo",
Name: "web",
Datacenter: "ny1",
Type: connect.ConsulResolverTypeService,
},
},
{
name: "prepared_query",
cfg: UpstreamConfig{
DestinationNamespace: "foo",
DestinationName: "web",
DestinationDatacenter: "ny1",
DestinationType: "prepared_query",
},
want: &connect.ConsulResolver{
Namespace: "foo",
Name: "web",
Datacenter: "ny1",
Type: connect.ConsulResolverTypePreparedQuery,
},
},
{
name: "unknown behaves like service",
cfg: UpstreamConfig{
DestinationNamespace: "foo",
DestinationName: "web",
DestinationDatacenter: "ny1",
DestinationType: "junk",
},
want: &connect.ConsulResolver{
Namespace: "foo",
Name: "web",
Datacenter: "ny1",
Type: connect.ConsulResolverTypeService,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Client doesn't really matter as long as it's passed through.
got := UpstreamResolverFromClient(nil, tt.cfg)
require.Equal(t, tt.want, got)
})
}
}

61
connect/proxy/conn.go Normal file
View File

@ -0,0 +1,61 @@
package proxy
import (
"io"
"net"
"sync/atomic"
)
// Conn represents a single proxied TCP connection.
type Conn struct {
src, dst net.Conn
stopping int32
}
// NewConn returns a conn joining the two given net.Conn
func NewConn(src, dst net.Conn) *Conn {
return &Conn{
src: src,
dst: dst,
stopping: 0,
}
}
// Close closes both connections.
func (c *Conn) Close() error {
// Note that net.Conn.Close can be called multiple times and atomic store is
// idempotent so no need to ensure we only do this once.
//
// Also note that we don't wait for CopyBytes to return here since we are
// closing the conns which is the only externally visible sideeffect of that
// goroutine running and there should be no way for it to hang or leak once
// the conns are closed so we can save the extra coordination.
atomic.StoreInt32(&c.stopping, 1)
c.src.Close()
c.dst.Close()
return nil
}
// CopyBytes will continuously copy bytes in both directions between src and dst
// until either connection is closed.
func (c *Conn) CopyBytes() error {
defer c.Close()
go func() {
// Need this since Copy is only guaranteed to stop when it's source reader
// (second arg) hits EOF or error but either conn might close first possibly
// causing this goroutine to exit but not the outer one. See
// TestConnSrcClosing which will fail if you comment the defer below.
defer c.Close()
io.Copy(c.dst, c.src)
}()
_, err := io.Copy(c.src, c.dst)
// Note that we don't wait for the other goroutine to finish because it either
// already has due to it's src conn closing, or it will once our defer fires
// and closes the source conn. No need for the extra coordination.
if atomic.LoadInt32(&c.stopping) == 1 {
return nil
}
return err
}

185
connect/proxy/conn_test.go Normal file
View File

@ -0,0 +1,185 @@
package proxy
import (
"bufio"
"io"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// Assert io.Closer implementation
var _ io.Closer = new(Conn)
// testConnPairSetup creates a TCP connection by listening on a random port, and
// returns both ends. Ready to have data sent down them. It also returns a
// closer function that will close both conns and the listener.
func testConnPairSetup(t *testing.T) (net.Conn, net.Conn, func()) {
t.Helper()
l, err := net.Listen("tcp", "localhost:0")
require.Nil(t, err)
ch := make(chan net.Conn, 1)
go func() {
src, err := l.Accept()
require.Nil(t, err)
ch <- src
}()
dst, err := net.Dial("tcp", l.Addr().String())
require.Nil(t, err)
src := <-ch
stopper := func() {
l.Close()
src.Close()
dst.Close()
}
return src, dst, stopper
}
// testConnPipelineSetup creates a pipeline consiting of two TCP connection
// pairs and a Conn that copies bytes between them. Data flow looks like this:
//
// src1 <---> dst1 <== Conn.CopyBytes ==> src2 <---> dst2
//
// The returned values are the src1 and dst2 which should be able to send and
// receive to each other via the Conn, the Conn itself (not running), and a
// stopper func to close everything.
func testConnPipelineSetup(t *testing.T) (net.Conn, net.Conn, *Conn, func()) {
src1, dst1, stop1 := testConnPairSetup(t)
src2, dst2, stop2 := testConnPairSetup(t)
c := NewConn(dst1, src2)
return src1, dst2, c, func() {
c.Close()
stop1()
stop2()
}
}
func TestConn(t *testing.T) {
src, dst, c, stop := testConnPipelineSetup(t)
defer stop()
retCh := make(chan error, 1)
go func() {
retCh <- c.CopyBytes()
}()
// Now write/read into the other ends of the pipes (src1, dst2)
srcR := bufio.NewReader(src)
dstR := bufio.NewReader(dst)
_, err := src.Write([]byte("ping 1\n"))
require.Nil(t, err)
_, err = dst.Write([]byte("ping 2\n"))
require.Nil(t, err)
got, err := dstR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "ping 1\n", got)
got, err = srcR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "ping 2\n", got)
_, err = src.Write([]byte("pong 1\n"))
require.Nil(t, err)
_, err = dst.Write([]byte("pong 2\n"))
require.Nil(t, err)
got, err = dstR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "pong 1\n", got)
got, err = srcR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "pong 2\n", got)
c.Close()
ret := <-retCh
require.Nil(t, ret, "Close() should not cause error return")
}
func TestConnSrcClosing(t *testing.T) {
src, dst, c, stop := testConnPipelineSetup(t)
defer stop()
retCh := make(chan error, 1)
go func() {
retCh <- c.CopyBytes()
}()
// Wait until we can actually get some bytes through both ways so we know that
// the copy goroutines are running.
srcR := bufio.NewReader(src)
dstR := bufio.NewReader(dst)
_, err := src.Write([]byte("ping 1\n"))
require.Nil(t, err)
_, err = dst.Write([]byte("ping 2\n"))
require.Nil(t, err)
got, err := dstR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "ping 1\n", got)
got, err = srcR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "ping 2\n", got)
// If we close the src conn, we expect CopyBytes to return and dst to be
// closed too. No good way to assert that the conn is closed really other than
// assume the retCh receive will hang unless CopyBytes exits and that
// CopyBytes defers Closing both.
testTimer := time.AfterFunc(3*time.Second, func() {
panic("test timeout")
})
src.Close()
<-retCh
testTimer.Stop()
}
func TestConnDstClosing(t *testing.T) {
src, dst, c, stop := testConnPipelineSetup(t)
defer stop()
retCh := make(chan error, 1)
go func() {
retCh <- c.CopyBytes()
}()
// Wait until we can actually get some bytes through both ways so we know that
// the copy goroutines are running.
srcR := bufio.NewReader(src)
dstR := bufio.NewReader(dst)
_, err := src.Write([]byte("ping 1\n"))
require.Nil(t, err)
_, err = dst.Write([]byte("ping 2\n"))
require.Nil(t, err)
got, err := dstR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "ping 1\n", got)
got, err = srcR.ReadString('\n')
require.Nil(t, err)
require.Equal(t, "ping 2\n", got)
// If we close the dst conn, we expect CopyBytes to return and src to be
// closed too. No good way to assert that the conn is closed really other than
// assume the retCh receive will hang unless CopyBytes exits and that
// CopyBytes defers Closing both. i.e. if this test doesn't time out it's
// good!
testTimer := time.AfterFunc(3*time.Second, func() {
panic("test timeout")
})
src.Close()
<-retCh
testTimer.Stop()
}

116
connect/proxy/listener.go Normal file
View File

@ -0,0 +1,116 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"sync/atomic"
"time"
"github.com/hashicorp/consul/connect"
)
// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it
type Listener struct {
// Service is the connect service instance to use.
Service *connect.Service
listenFunc func() (net.Listener, error)
dialFunc func() (net.Conn, error)
stopFlag int32
stopChan chan struct{}
logger *log.Logger
}
// NewPublicListener returns a Listener setup to listen for public mTLS
// connections and proxy them to the configured local application over TCP.
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
logger *log.Logger) *Listener {
return &Listener{
Service: svc,
listenFunc: func() (net.Listener, error) {
return tls.Listen("tcp", cfg.BindAddress, svc.ServerTLSConfig())
},
dialFunc: func() (net.Conn, error) {
return net.DialTimeout("tcp", cfg.LocalServiceAddress,
time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
},
stopChan: make(chan struct{}),
logger: logger,
}
}
// NewUpstreamListener returns a Listener setup to listen locally for TCP
// connections that are proxied to a discovered Connect service instance.
func NewUpstreamListener(svc *connect.Service, cfg UpstreamConfig,
logger *log.Logger) *Listener {
return &Listener{
Service: svc,
listenFunc: func() (net.Listener, error) {
return net.Listen("tcp", cfg.LocalBindAddress)
},
dialFunc: func() (net.Conn, error) {
if cfg.resolver == nil {
return nil, errors.New("no resolver provided")
}
ctx, cancel := context.WithTimeout(context.Background(),
time.Duration(cfg.ConnectTimeoutMs)*time.Millisecond)
defer cancel()
return svc.Dial(ctx, cfg.resolver)
},
stopChan: make(chan struct{}),
logger: logger,
}
}
// Serve runs the listener until it is stopped.
func (l *Listener) Serve() error {
listen, err := l.listenFunc()
if err != nil {
return err
}
for {
conn, err := listen.Accept()
if err != nil {
if atomic.LoadInt32(&l.stopFlag) == 1 {
return nil
}
return err
}
go l.handleConn(conn)
}
return nil
}
// handleConn is the internal connection handler goroutine.
func (l *Listener) handleConn(src net.Conn) {
defer src.Close()
dst, err := l.dialFunc()
if err != nil {
l.logger.Printf("[ERR] failed to dial: %s", err)
return
}
// Note no need to defer dst.Close() since conn handles that for us.
conn := NewConn(src, dst)
defer conn.Close()
err = conn.CopyBytes()
if err != nil {
l.logger.Printf("[ERR] connection failed: %s", err)
return
}
}
// Close terminates the listener and all active connections.
func (l *Listener) Close() error {
return nil
}

View File

@ -0,0 +1,91 @@
package proxy
import (
"context"
"log"
"net"
"os"
"testing"
agConnect "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/connect"
"github.com/stretchr/testify/require"
)
func TestPublicListener(t *testing.T) {
ca := agConnect.TestCA(t, nil)
addrs := TestLocalBindAddrs(t, 2)
cfg := PublicListenerConfig{
BindAddress: addrs[0],
LocalServiceAddress: addrs[1],
HandshakeTimeoutMs: 100,
LocalConnectTimeoutMs: 100,
}
testApp, err := NewTestTCPServer(t, cfg.LocalServiceAddress)
require.Nil(t, err)
defer testApp.Close()
svc := connect.TestService(t, "db", ca)
l := NewPublicListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags))
// Run proxy
go func() {
err := l.Serve()
require.Nil(t, err)
}()
defer l.Close()
// Proxy and backend are running, play the part of a TLS client using same
// cert for now.
conn, err := svc.Dial(context.Background(), &connect.StaticResolver{
Addr: addrs[0],
CertURI: agConnect.TestSpiffeIDService(t, "db"),
})
require.Nilf(t, err, "unexpected err: %s", err)
TestEchoConn(t, conn, "")
}
func TestUpstreamListener(t *testing.T) {
ca := agConnect.TestCA(t, nil)
addrs := TestLocalBindAddrs(t, 1)
// Run a test server that we can dial.
testSvr := connect.NewTestServer(t, "db", ca)
go func() {
err := testSvr.Serve()
require.Nil(t, err)
}()
defer testSvr.Close()
cfg := UpstreamConfig{
DestinationType: "service",
DestinationNamespace: "default",
DestinationName: "db",
ConnectTimeoutMs: 100,
LocalBindAddress: addrs[0],
resolver: &connect.StaticResolver{
Addr: testSvr.Addr,
CertURI: agConnect.TestSpiffeIDService(t, "db"),
},
}
svc := connect.TestService(t, "web", ca)
l := NewUpstreamListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags))
// Run proxy
go func() {
err := l.Serve()
require.Nil(t, err)
}()
defer l.Close()
// Proxy and fake remote service are running, play the part of the app
// connecting to a remote connect service over TCP.
conn, err := net.Dial("tcp", cfg.LocalBindAddress)
require.Nilf(t, err, "unexpected err: %s", err)
TestEchoConn(t, conn, "")
}

134
connect/proxy/proxy.go Normal file
View File

@ -0,0 +1,134 @@
package proxy
import (
"log"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/connect"
)
// Proxy implements the built-in connect proxy.
type Proxy struct {
proxyID string
client *api.Client
cfgWatcher ConfigWatcher
stopChan chan struct{}
logger *log.Logger
}
// NewFromConfigFile returns a Proxy instance configured just from a local file.
// This is intended mostly for development and bypasses the normal mechanisms
// for fetching config and certificates from the local agent.
func NewFromConfigFile(client *api.Client, filename string,
logger *log.Logger) (*Proxy, error) {
cfg, err := ParseConfigFile(filename)
if err != nil {
return nil, err
}
service, err := connect.NewDevServiceFromCertFiles(cfg.ProxiedServiceID,
client, logger, cfg.DevCAFile, cfg.DevServiceCertFile,
cfg.DevServiceKeyFile)
if err != nil {
return nil, err
}
cfg.service = service
p := &Proxy{
proxyID: cfg.ProxyID,
client: client,
cfgWatcher: NewStaticConfigWatcher(cfg),
stopChan: make(chan struct{}),
logger: logger,
}
return p, nil
}
// New returns a Proxy with the given id, consuming the provided (configured)
// agent. It is ready to Run().
func New(client *api.Client, proxyID string, logger *log.Logger) (*Proxy, error) {
p := &Proxy{
proxyID: proxyID,
client: client,
cfgWatcher: &AgentConfigWatcher{
client: client,
proxyID: proxyID,
logger: logger,
},
stopChan: make(chan struct{}),
logger: logger,
}
return p, nil
}
// Serve the proxy instance until a fatal error occurs or proxy is closed.
func (p *Proxy) Serve() error {
var cfg *Config
// Watch for config changes (initial setup happens on first "change")
for {
select {
case newCfg := <-p.cfgWatcher.Watch():
p.logger.Printf("[DEBUG] got new config")
if newCfg.service == nil {
p.logger.Printf("[ERR] new config has nil service")
continue
}
if cfg == nil {
// Initial setup
newCfg.PublicListener.applyDefaults()
l := NewPublicListener(newCfg.service, newCfg.PublicListener, p.logger)
err := p.startListener("public listener", l)
if err != nil {
return err
}
}
// TODO(banks) update/remove upstreams properly based on a diff with current. Can
// store a map of uc.String() to Listener here and then use it to only
// start one of each and stop/modify if changes occur.
for _, uc := range newCfg.Upstreams {
uc.applyDefaults()
uc.resolver = UpstreamResolverFromClient(p.client, uc)
l := NewUpstreamListener(newCfg.service, uc, p.logger)
err := p.startListener(uc.String(), l)
if err != nil {
p.logger.Printf("[ERR] failed to start upstream %s: %s", uc.String(),
err)
}
}
cfg = newCfg
case <-p.stopChan:
return nil
}
}
}
// startPublicListener is run from the internal state machine loop
func (p *Proxy) startListener(name string, l *Listener) error {
go func() {
err := l.Serve()
if err != nil {
p.logger.Printf("[ERR] %s stopped with error: %s", name, err)
return
}
p.logger.Printf("[INFO] %s stopped", name)
}()
go func() {
<-p.stopChan
l.Close()
}()
return nil
}
// Close stops the proxy and terminates all active connections. It must be
// called only once.
func (p *Proxy) Close() {
close(p.stopChan)
}

View File

@ -0,0 +1,32 @@
# Example proxy config with everything specified
proxy_id = "foo"
token = "11111111-2222-3333-4444-555555555555"
proxied_service_id = "web"
proxied_service_namespace = "default"
# Assumes running consul in dev mode from the repo root...
dev_ca_file = "connect/testdata/ca1-ca-consul-internal.cert.pem"
dev_service_cert_file = "connect/testdata/ca1-svc-web.cert.pem"
dev_service_key_file = "connect/testdata/ca1-svc-web.key.pem"
public_listener {
bind_address = ":9999"
local_service_address = "127.0.0.1:5000"
}
upstreams = [
{
local_bind_address = "127.0.0.1:6000"
destination_name = "db"
destination_namespace = "default"
destination_type = "service"
},
{
local_bind_address = "127.0.0.1:6001"
destination_name = "geo-cache"
destination_namespace = "default"
destination_type = "prepared_query"
}
]

105
connect/proxy/testing.go Normal file
View File

@ -0,0 +1,105 @@
package proxy
import (
"fmt"
"io"
"log"
"net"
"sync/atomic"
"github.com/hashicorp/consul/lib/freeport"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
)
// TestLocalBindAddrs returns n localhost address:port strings with free ports
// for binding test listeners to.
func TestLocalBindAddrs(t testing.T, n int) []string {
ports := freeport.GetT(t, n)
addrs := make([]string, n)
for i, p := range ports {
addrs[i] = fmt.Sprintf("localhost:%d", p)
}
return addrs
}
// TestTCPServer is a simple TCP echo server for use during tests.
type TestTCPServer struct {
l net.Listener
stopped int32
accepted, closed, active int32
}
// NewTestTCPServer opens as a listening socket on the given address and returns
// a TestTCPServer serving requests to it. The server is already started and can
// be stopped by calling Close().
func NewTestTCPServer(t testing.T, addr string) (*TestTCPServer, error) {
l, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
log.Printf("test tcp server listening on %s", addr)
s := &TestTCPServer{
l: l,
}
go s.accept()
return s, nil
}
// Close stops the server
func (s *TestTCPServer) Close() {
atomic.StoreInt32(&s.stopped, 1)
if s.l != nil {
s.l.Close()
}
}
func (s *TestTCPServer) accept() error {
for {
conn, err := s.l.Accept()
if err != nil {
if atomic.LoadInt32(&s.stopped) == 1 {
log.Printf("test tcp echo server %s stopped", s.l.Addr())
return nil
}
log.Printf("test tcp echo server %s failed: %s", s.l.Addr(), err)
return err
}
atomic.AddInt32(&s.accepted, 1)
atomic.AddInt32(&s.active, 1)
go func(c net.Conn) {
io.Copy(c, c)
atomic.AddInt32(&s.closed, 1)
atomic.AddInt32(&s.active, -1)
}(conn)
}
}
// TestEchoConn attempts to write some bytes to conn and expects to read them
// back within a short timeout (10ms). If prefix is not empty we expect it to be
// poresent at the start of all echoed responses (for example to distinguish
// between multiple echo server instances).
func TestEchoConn(t testing.T, conn net.Conn, prefix string) {
t.Helper()
// Write some bytes and read them back
n, err := conn.Write([]byte("Hello World"))
require.Equal(t, 11, n)
require.Nil(t, err)
expectLen := 11 + len(prefix)
buf := make([]byte, expectLen)
// read until our buffer is full - it might be separate packets if prefix is
// in use.
got := 0
for got < expectLen {
n, err = conn.Read(buf[got:])
require.Nilf(t, err, "err: %s", err)
got += n
}
require.Equal(t, expectLen, got)
require.Equal(t, prefix+"Hello World", string(buf[:]))
}

View File

@ -10,7 +10,9 @@ import (
testing "github.com/mitchellh/go-testing-interface" testing "github.com/mitchellh/go-testing-interface"
) )
// Resolver is the interface implemented by a service discovery mechanism. // Resolver is the interface implemented by a service discovery mechanism to get
// the address and identity of an instance to connect to via Connect as a
// client.
type Resolver interface { type Resolver interface {
// Resolve returns a single service instance to connect to. Implementations // Resolve returns a single service instance to connect to. Implementations
// may attempt to ensure the instance returned is currently available. It is // may attempt to ensure the instance returned is currently available. It is
@ -19,7 +21,10 @@ type Resolver interface {
// increases reliability. The context passed can be used to impose timeouts // increases reliability. The context passed can be used to impose timeouts
// which may or may not be respected by implementations that make network // which may or may not be respected by implementations that make network
// calls to resolve the service. The addr returned is a string in any valid // calls to resolve the service. The addr returned is a string in any valid
// form for passing directly to `net.Dial("tcp", addr)`. // form for passing directly to `net.Dial("tcp", addr)`. The certURI
// represents the identity of the service instance. It will be matched against
// the TLS certificate URI SAN presented by the server and the connection
// rejected if they don't match.
Resolve(ctx context.Context) (addr string, certURI connect.CertURI, err error) Resolve(ctx context.Context) (addr string, certURI connect.CertURI, err error)
} }
@ -33,7 +38,8 @@ type StaticResolver struct {
Addr string Addr string
// CertURL is the _identity_ we expect the server to present in it's TLS // CertURL is the _identity_ we expect the server to present in it's TLS
// certificate. It must be an exact match or the connection will be rejected. // certificate. It must be an exact URI string match or the connection will be
// rejected.
CertURI connect.CertURI CertURI connect.CertURI
} }
@ -56,13 +62,14 @@ type ConsulResolver struct {
// panic. // panic.
Client *api.Client Client *api.Client
// Namespace of the query target // Namespace of the query target.
Namespace string Namespace string
// Name of the query target // Name of the query target.
Name string Name string
// Type of the query target, // Type of the query target. Should be one of the defined ConsulResolverType*
// constants. Currently defaults to ConsulResolverTypeService.
Type int Type int
// Datacenter to resolve in, empty indicates agent's local DC. // Datacenter to resolve in, empty indicates agent's local DC.

View File

@ -41,7 +41,6 @@ func TestStaticResolver_Resolve(t *testing.T) {
} }
func TestConsulResolver_Resolve(t *testing.T) { func TestConsulResolver_Resolve(t *testing.T) {
// Setup a local test agent to query // Setup a local test agent to query
agent := agent.NewTestAgent("test-consul", "") agent := agent.NewTestAgent("test-consul", "")
defer agent.Shutdown() defer agent.Shutdown()

View File

@ -3,6 +3,7 @@ package connect
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -10,6 +11,7 @@ import (
"time" "time"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"golang.org/x/net/http2"
) )
// Service represents a Consul service that accepts and/or connects via Connect. // Service represents a Consul service that accepts and/or connects via Connect.
@ -41,10 +43,17 @@ type Service struct {
client *api.Client client *api.Client
// serverTLSCfg is the (reloadable) TLS config we use for serving. // serverTLSCfg is the (reloadable) TLS config we use for serving.
serverTLSCfg *ReloadableTLSConfig serverTLSCfg *reloadableTLSConfig
// clientTLSCfg is the (reloadable) TLS config we use for dialling. // clientTLSCfg is the (reloadable) TLS config we use for dialling.
clientTLSCfg *ReloadableTLSConfig clientTLSCfg *reloadableTLSConfig
// httpResolverFromAddr is a function that returns a Resolver from a string
// address for HTTP clients. It's privately pluggable to make testing easier
// but will default to a simple method to parse the host as a Consul DNS host.
//
// TODO(banks): write the proper implementation
httpResolverFromAddr func(addr string) (Resolver, error)
logger *log.Logger logger *log.Logger
} }
@ -65,8 +74,8 @@ func NewServiceWithLogger(serviceID string, client *api.Client,
client: client, client: client,
logger: logger, logger: logger,
} }
s.serverTLSCfg = NewReloadableTLSConfig(defaultTLSConfig(serverVerifyCerts)) s.serverTLSCfg = newReloadableTLSConfig(defaultTLSConfig(serverVerifyCerts))
s.clientTLSCfg = NewReloadableTLSConfig(defaultTLSConfig(clientVerifyCerts)) s.clientTLSCfg = newReloadableTLSConfig(defaultTLSConfig(clientVerifyCerts))
// TODO(banks) run the background certificate sync // TODO(banks) run the background certificate sync
return s, nil return s, nil
@ -86,12 +95,12 @@ func NewDevServiceFromCertFiles(serviceID string, client *api.Client,
return nil, err return nil, err
} }
// Note that NewReloadableTLSConfig makes a copy so we can re-use the same // Note that newReloadableTLSConfig makes a copy so we can re-use the same
// base for both client and server with swapped verifiers. // base for both client and server with swapped verifiers.
tlsCfg.VerifyPeerCertificate = serverVerifyCerts tlsCfg.VerifyPeerCertificate = serverVerifyCerts
s.serverTLSCfg = NewReloadableTLSConfig(tlsCfg) s.serverTLSCfg = newReloadableTLSConfig(tlsCfg)
tlsCfg.VerifyPeerCertificate = clientVerifyCerts tlsCfg.VerifyPeerCertificate = clientVerifyCerts
s.clientTLSCfg = NewReloadableTLSConfig(tlsCfg) s.clientTLSCfg = newReloadableTLSConfig(tlsCfg)
return s, nil return s, nil
} }
@ -121,6 +130,8 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logger.Printf("[DEBUG] resolved service instance: %s (%s)", addr,
certURI.URI())
var dialer net.Dialer var dialer net.Dialer
tcpConn, err := dialer.DialContext(ctx, "tcp", addr) tcpConn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
@ -133,8 +144,8 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error)
if ok { if ok {
tlsConn.SetDeadline(deadline) tlsConn.SetDeadline(deadline)
} }
err = tlsConn.Handshake() // Perform handshake
if err != nil { if err = tlsConn.Handshake(); err != nil {
tlsConn.Close() tlsConn.Close()
return nil, err return nil, err
} }
@ -149,20 +160,27 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error)
tlsConn.Close() tlsConn.Close()
return nil, err return nil, err
} }
s.logger.Printf("[DEBUG] successfully connected to %s (%s)", addr,
certURI.URI())
return tlsConn, nil return tlsConn, nil
} }
// HTTPDialContext is compatible with http.Transport.DialContext. It expects the // HTTPDialTLS is compatible with http.Transport.DialTLS. It expects the addr
// addr hostname to be specified using Consul DNS query syntax, e.g. // hostname to be specified using Consul DNS query syntax, e.g.
// "web.service.consul". It converts that into the equivalent ConsulResolver and // "web.service.consul". It converts that into the equivalent ConsulResolver and
// then call s.Dial with the resolver. This is low level, clients should // then call s.Dial with the resolver. This is low level, clients should
// typically use HTTPClient directly. // typically use HTTPClient directly.
func (s *Service) HTTPDialContext(ctx context.Context, network, func (s *Service) HTTPDialTLS(network,
addr string) (net.Conn, error) { addr string) (net.Conn, error) {
var r ConsulResolver if s.httpResolverFromAddr == nil {
// TODO(banks): parse addr into ConsulResolver return nil, errors.New("no http resolver configured")
return s.Dial(ctx, &r) }
r, err := s.httpResolverFromAddr(addr)
if err != nil {
return nil, err
}
// TODO(banks): figure out how to do timeouts better.
return s.Dial(context.Background(), r)
} }
// HTTPClient returns an *http.Client configured to dial remote Consul Connect // HTTPClient returns an *http.Client configured to dial remote Consul Connect
@ -172,14 +190,27 @@ func (s *Service) HTTPDialContext(ctx context.Context, network,
// API rather than just relying on Consul DNS. Hostnames that are not valid // API rather than just relying on Consul DNS. Hostnames that are not valid
// Consul DNS queries will fail. // Consul DNS queries will fail.
func (s *Service) HTTPClient() *http.Client { func (s *Service) HTTPClient() *http.Client {
t := &http.Transport{
// Sadly we can't use DialContext hook since that is expected to return a
// plain TCP connection an http.Client tries to start a TLS handshake over
// it. We need to control the handshake to be able to do our validation.
// So we have to use the older DialTLS which means no context/timeout
// support.
//
// TODO(banks): figure out how users can configure a timeout when using
// this and/or compatibility with http.Request.WithContext.
DialTLS: s.HTTPDialTLS,
}
// Need to manually re-enable http2 support since we set custom DialTLS.
// See https://golang.org/src/net/http/transport.go?s=8692:9036#L228
http2.ConfigureTransport(t)
return &http.Client{ return &http.Client{
Transport: &http.Transport{ Transport: t,
DialContext: s.HTTPDialContext,
},
} }
} }
// Close stops the service and frees resources. // Close stops the service and frees resources.
func (s *Service) Close() { func (s *Service) Close() error {
// TODO(banks): stop background activity if started // TODO(banks): stop background activity if started
return nil
} }

View File

@ -2,14 +2,22 @@ package connect
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io"
"io/ioutil"
"net/http"
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/testutil/retry"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Assert io.Closer implementation
var _ io.Closer = new(Service)
func TestService_Dial(t *testing.T) { func TestService_Dial(t *testing.T) {
ca := connect.TestCA(t, nil) ca := connect.TestCA(t, nil)
@ -53,30 +61,26 @@ func TestService_Dial(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t) require := require.New(t)
s, err := NewService("web", nil) s := TestService(t, "web", ca)
require.Nil(err)
// Force TLSConfig
s.clientTLSCfg = NewReloadableTLSConfig(TestTLSConfig(t, "web", ca))
ctx, cancel := context.WithTimeout(context.Background(), ctx, cancel := context.WithTimeout(context.Background(),
100*time.Millisecond) 100*time.Millisecond)
defer cancel() defer cancel()
testSvc := NewTestService(t, tt.presentService, ca) testSvr := NewTestServer(t, tt.presentService, ca)
testSvc.TimeoutHandshake = !tt.handshake testSvr.TimeoutHandshake = !tt.handshake
if tt.accept { if tt.accept {
go func() { go func() {
err := testSvc.Serve() err := testSvr.Serve()
require.Nil(err) require.Nil(err)
}() }()
defer testSvc.Close() defer testSvr.Close()
} }
// Always expect to be connecting to a "DB" // Always expect to be connecting to a "DB"
resolver := &StaticResolver{ resolver := &StaticResolver{
Addr: testSvc.Addr, Addr: testSvr.Addr,
CertURI: connect.TestSpiffeIDService(t, "db"), CertURI: connect.TestSpiffeIDService(t, "db"),
} }
@ -92,6 +96,7 @@ func TestService_Dial(t *testing.T) {
if tt.wantErr == "" { if tt.wantErr == "" {
require.Nil(err) require.Nil(err)
require.IsType(&tls.Conn{}, conn)
} else { } else {
require.NotNil(err) require.NotNil(err)
require.Contains(err.Error(), tt.wantErr) require.Contains(err.Error(), tt.wantErr)
@ -103,3 +108,62 @@ func TestService_Dial(t *testing.T) {
}) })
} }
} }
func TestService_ServerTLSConfig(t *testing.T) {
// TODO(banks): it's mostly meaningless to test this now since we directly set
// the tlsCfg in our TestService helper which is all we'd be asserting on here
// not the actual implementation. Once agent tls fetching is built, it becomes
// more meaningful to actually verify it's returning the correct config.
}
func TestService_HTTPClient(t *testing.T) {
require := require.New(t)
ca := connect.TestCA(t, nil)
s := TestService(t, "web", ca)
// Run a test HTTP server
testSvr := NewTestServer(t, "backend", ca)
defer testSvr.Close()
go func() {
err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, I am Backend"))
}))
require.Nil(t, err)
}()
// TODO(banks): this will talk http2 on both client and server. I hit some
// compatibility issues when testing though need to make sure that the http
// server with our TLSConfig can actually support HTTP/1.1 as well. Could make
// this a table test with all 4 permutations of client/server http version
// support.
// Still get connection refused some times so retry on those
retry.Run(t, func(r *retry.R) {
// Hook the service resolver to avoid needing full agent setup.
s.httpResolverFromAddr = func(addr string) (Resolver, error) {
// Require in this goroutine seems to block causing a timeout on the Get.
//require.Equal("https://backend.service.consul:443", addr)
return &StaticResolver{
Addr: testSvr.Addr,
CertURI: connect.TestSpiffeIDService(t, "backend"),
}, nil
}
client := s.HTTPClient()
client.Timeout = 1 * time.Second
resp, err := client.Get("https://backend.service.consul/foo")
r.Check(err)
defer resp.Body.Close()
bodyBytes, err := ioutil.ReadAll(resp.Body)
r.Check(err)
got := string(bodyBytes)
want := "Hello, I am Backend"
if got != want {
r.Fatalf("got %s, want %s", got, want)
}
})
}

View File

@ -5,26 +5,33 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/http"
"sync/atomic" "sync/atomic"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib/freeport" "github.com/hashicorp/consul/lib/freeport"
testing "github.com/mitchellh/go-testing-interface" testing "github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
) )
// testVerifier creates a helper verifyFunc that can be set in a tls.Config and // TestService returns a Service instance based on a static TLS Config.
// records calls made, passing back the certificates presented via the returned func TestService(t testing.T, service string, ca *structs.CARoot) *Service {
// channel. The channel is buffered so up to 128 verification calls can be made t.Helper()
// without reading the chan before verification blocks.
func testVerifier(t testing.T, returnErr error) (verifyFunc, chan [][]byte) { // Don't need to talk to client since we are setting TLSConfig locally
ch := make(chan [][]byte, 128) svc, err := NewService(service, nil)
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { if err != nil {
ch <- rawCerts t.Fatal(err)
return returnErr }
}, ch
svc.serverTLSCfg = newReloadableTLSConfig(
TestTLSConfigWithVerifier(t, service, ca, serverVerifyCerts))
svc.clientTLSCfg = newReloadableTLSConfig(
TestTLSConfigWithVerifier(t, service, ca, clientVerifyCerts))
return svc
} }
// TestTLSConfig returns a *tls.Config suitable for use during tests. // TestTLSConfig returns a *tls.Config suitable for use during tests.
@ -32,7 +39,16 @@ func TestTLSConfig(t testing.T, service string, ca *structs.CARoot) *tls.Config
t.Helper() t.Helper()
// Insecure default (nil verifier) // Insecure default (nil verifier)
cfg := defaultTLSConfig(nil) return TestTLSConfigWithVerifier(t, service, ca, nil)
}
// TestTLSConfigWithVerifier returns a *tls.Config suitable for use during
// tests, it will use the given verifyFunc to verify tls certificates.
func TestTLSConfigWithVerifier(t testing.T, service string, ca *structs.CARoot,
verifier verifyFunc) *tls.Config {
t.Helper()
cfg := defaultTLSConfig(verifier)
cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)} cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)}
cfg.RootCAs = TestCAPool(t, ca) cfg.RootCAs = TestCAPool(t, ca)
cfg.ClientCAs = TestCAPool(t, ca) cfg.ClientCAs = TestCAPool(t, ca)
@ -55,7 +71,9 @@ func TestSvcKeyPair(t testing.T, service string, ca *structs.CARoot) tls.Certifi
t.Helper() t.Helper()
certPEM, keyPEM := connect.TestLeaf(t, service, ca) certPEM, keyPEM := connect.TestLeaf(t, service, ca)
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
require.Nil(t, err) if err != nil {
t.Fatal(err)
}
return cert return cert
} }
@ -65,13 +83,15 @@ func TestPeerCertificates(t testing.T, service string, ca *structs.CARoot) []*x5
t.Helper() t.Helper()
certPEM, _ := connect.TestLeaf(t, service, ca) certPEM, _ := connect.TestLeaf(t, service, ca)
cert, err := connect.ParseCert(certPEM) cert, err := connect.ParseCert(certPEM)
require.Nil(t, err) if err != nil {
t.Fatal(err)
}
return []*x509.Certificate{cert} return []*x509.Certificate{cert}
} }
// TestService runs a service listener that can be used to test clients. It's // TestServer runs a service listener that can be used to test clients. It's
// behaviour can be controlled by the struct members. // behaviour can be controlled by the struct members.
type TestService struct { type TestServer struct {
// The service name to serve. // The service name to serve.
Service string Service string
// The (test) CA to use for generating certs. // The (test) CA to use for generating certs.
@ -91,11 +111,11 @@ type TestService struct {
stopChan chan struct{} stopChan chan struct{}
} }
// NewTestService returns a TestService. It should be closed when test is // NewTestServer returns a TestServer. It should be closed when test is
// complete. // complete.
func NewTestService(t testing.T, service string, ca *structs.CARoot) *TestService { func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer {
ports := freeport.GetT(t, 1) ports := freeport.GetT(t, 1)
return &TestService{ return &TestServer{
Service: service, Service: service,
CA: ca, CA: ca,
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
@ -104,14 +124,16 @@ func NewTestService(t testing.T, service string, ca *structs.CARoot) *TestServic
} }
} }
// Serve runs a TestService and blocks until it is closed or errors. // Serve runs a tcp echo server and blocks until it is closed or errors. If
func (s *TestService) Serve() error { // TimeoutHandshake is set it won't start TLS handshake on new connections.
func (s *TestServer) Serve() error {
// Just accept TCP conn but so we can control timing of accept/handshake // Just accept TCP conn but so we can control timing of accept/handshake
l, err := net.Listen("tcp", s.Addr) l, err := net.Listen("tcp", s.Addr)
if err != nil { if err != nil {
return err return err
} }
s.l = l s.l = l
log.Printf("test connect service listening on %s", s.Addr)
for { for {
conn, err := s.l.Accept() conn, err := s.l.Accept()
@ -122,12 +144,14 @@ func (s *TestService) Serve() error {
return err return err
} }
// Ignore the conn if we are not actively ha // Ignore the conn if we are not actively handshaking
if !s.TimeoutHandshake { if !s.TimeoutHandshake {
// Upgrade conn to TLS // Upgrade conn to TLS
conn = tls.Server(conn, s.TLSCfg) conn = tls.Server(conn, s.TLSCfg)
// Run an echo service // Run an echo service
log.Printf("test connect service accepted conn from %s, "+
" running echo service", conn.RemoteAddr())
go io.Copy(conn, conn) go io.Copy(conn, conn)
} }
@ -141,8 +165,20 @@ func (s *TestService) Serve() error {
return nil return nil
} }
// Close stops a TestService // ServeHTTPS runs an HTTPS server with the given config. It invokes the passed
func (s *TestService) Close() { // Handler for all requests.
func (s *TestServer) ServeHTTPS(h http.Handler) error {
srv := http.Server{
Addr: s.Addr,
TLSConfig: s.TLSCfg,
Handler: h,
}
log.Printf("starting test connect HTTPS server on %s", s.Addr)
return srv.ListenAndServeTLS("", "")
}
// Close stops a TestServer
func (s *TestServer) Close() error {
old := atomic.SwapInt32(&s.stopFlag, 1) old := atomic.SwapInt32(&s.stopFlag, 1)
if old == 0 { if old == 0 {
if s.l != nil { if s.l != nil {
@ -150,4 +186,5 @@ func (s *TestService) Close() {
} }
close(s.stopChan) close(s.stopChan)
} }
return nil
} }

View File

@ -42,27 +42,27 @@ func defaultTLSConfig(verify verifyFunc) *tls.Config {
} }
} }
// ReloadableTLSConfig exposes a tls.Config that can have it's certificates // reloadableTLSConfig exposes a tls.Config that can have it's certificates
// reloaded. On a server, this uses GetConfigForClient to pass the current // reloaded. On a server, this uses GetConfigForClient to pass the current
// tls.Config or client certificate for each acceptted connection. On a client, // tls.Config or client certificate for each acceptted connection. On a client,
// this uses GetClientCertificate to provide the current client certificate. // this uses GetClientCertificate to provide the current client certificate.
type ReloadableTLSConfig struct { type reloadableTLSConfig struct {
mu sync.Mutex mu sync.Mutex
// cfg is the current config to use for new connections // cfg is the current config to use for new connections
cfg *tls.Config cfg *tls.Config
} }
// NewReloadableTLSConfig returns a reloadable config currently set to base. // newReloadableTLSConfig returns a reloadable config currently set to base.
func NewReloadableTLSConfig(base *tls.Config) *ReloadableTLSConfig { func newReloadableTLSConfig(base *tls.Config) *reloadableTLSConfig {
c := &ReloadableTLSConfig{} c := &reloadableTLSConfig{}
c.SetTLSConfig(base) c.SetTLSConfig(base)
return c return c
} }
// TLSConfig returns a *tls.Config that will dynamically load certs. It's // TLSConfig returns a *tls.Config that will dynamically load certs. It's
// suitable for use in either a client or server. // suitable for use in either a client or server.
func (c *ReloadableTLSConfig) TLSConfig() *tls.Config { func (c *reloadableTLSConfig) TLSConfig() *tls.Config {
c.mu.Lock() c.mu.Lock()
cfgCopy := c.cfg cfgCopy := c.cfg
c.mu.Unlock() c.mu.Unlock()
@ -71,7 +71,7 @@ func (c *ReloadableTLSConfig) TLSConfig() *tls.Config {
// SetTLSConfig sets the config used for future connections. It is safe to call // SetTLSConfig sets the config used for future connections. It is safe to call
// from any goroutine. // from any goroutine.
func (c *ReloadableTLSConfig) SetTLSConfig(cfg *tls.Config) error { func (c *reloadableTLSConfig) SetTLSConfig(cfg *tls.Config) error {
copy := cfg.Clone() copy := cfg.Clone()
copy.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { copy.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
current := c.TLSConfig() current := c.TLSConfig()

View File

@ -10,10 +10,9 @@ import (
func TestReloadableTLSConfig(t *testing.T) { func TestReloadableTLSConfig(t *testing.T) {
require := require.New(t) require := require.New(t)
verify, _ := testVerifier(t, nil) base := defaultTLSConfig(nil)
base := defaultTLSConfig(verify)
c := NewReloadableTLSConfig(base) c := newReloadableTLSConfig(base)
// The dynamic config should be the one we loaded (with some different hooks) // The dynamic config should be the one we loaded (with some different hooks)
got := c.TLSConfig() got := c.TLSConfig()