mirror of https://github.com/status-im/consul.git
Rework connect/proxy and command/connect/proxy. End to end demo working again
This commit is contained in:
parent
aa19be4651
commit
10db79c8ae
@@ -157,7 +157,7 @@ func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string
 		t.Fatalf("error generating serial number: %s", err)
 	}

-	// Genereate fresh private key
+	// Generate fresh private key
 	pkSigner, pkPEM := testPrivateKey(t)

 	// Cert template for generation
@@ -1,17 +1,15 @@
 package proxy

 import (
-	"context"
 	"flag"
 	"fmt"
 	"io"
 	"log"
 	"net/http"
-	// Expose pprof if configured
-	_ "net/http/pprof"
+	_ "net/http/pprof" // Expose pprof if configured

 	"github.com/hashicorp/consul/command/flags"
-	proxyImpl "github.com/hashicorp/consul/proxy"
+	proxyImpl "github.com/hashicorp/consul/connect/proxy"

 	"github.com/hashicorp/consul/logger"
 	"github.com/hashicorp/logutils"
@@ -46,13 +44,14 @@ type cmd struct {
 func (c *cmd) init() {
 	c.flags = flag.NewFlagSet("", flag.ContinueOnError)

-	c.flags.StringVar(&c.cfgFile, "insecure-dev-config", "",
+	c.flags.StringVar(&c.cfgFile, "dev-config", "",
 		"If set, proxy config is read on startup from this file (in HCL or JSON"+
 			"format). If a config file is given, the proxy will use that instead of "+
 			"querying the local agent for it's configuration. It will not reload it "+
 			"except on startup. In this mode the proxy WILL NOT authorize incoming "+
 			"connections with the local agent which is totally insecure. This is "+
-			"ONLY for development and testing.")
+			"ONLY for internal development and testing and will probably be removed "+
+			"once proxy implementation is more complete.")

 	c.flags.StringVar(&c.proxyID, "proxy-id", "",
 		"The proxy's ID on the local agent.")
@@ -121,31 +120,23 @@ func (c *cmd) Run(args []string) int {
 		}
 	}

-	ctx, cancel := context.WithCancel(context.Background())
+	// Hook the shutdownCh up to close the proxy
 	go func() {
-		err := p.Run(ctx)
-		if err != nil {
-			c.UI.Error(fmt.Sprintf("Failed running proxy: %s", err))
-		}
-		// If we exited early due to a fatal error, need to unblock the main
-		// routine. But we can't close shutdownCh since it might already be closed
-		// by a signal and there is no way to tell. We also can't send on it to
-		// unblock main routine since it's typed as receive only. So the best thing
-		// we can do is cancel the context and have the main routine select on both.
-		cancel()
+		<-c.shutdownCh
+		p.Close()
 	}()

-	c.UI.Output("Consul Connect proxy running!")
+	c.UI.Output("Consul Connect proxy starting")

 	c.UI.Output("Log data will now stream in as it occurs:\n")
 	logGate.Flush()

-	// Wait for shutdown or context cancel (see Run() goroutine above)
-	select {
-	case <-c.shutdownCh:
-		cancel()
-	case <-ctx.Done():
-	}
+	// Run the proxy
+	err = p.Serve()
+	if err != nil {
+		c.UI.Error(fmt.Sprintf("Failed running proxy: %s", err))
+	}

 	c.UI.Output("Consul Connect proxy shutdown")
 	return 0
 }
@@ -1 +0,0 @@
-package proxy
@@ -27,6 +27,7 @@
 // NOTE: THIS IS A QUIRK OF OPENSSL; in Connect we distribute the roots alone
 // and stable intermediates like the XC cert to the _leaf_.
 package main // import "github.com/hashicorp/consul/connect/certgen"

 import (
 	"flag"
 	"fmt"
@@ -42,7 +43,6 @@ import (
 func main() {
 	var numCAs = 2
 	var services = []string{"web", "db", "cache"}
-	//var slugRe = regexp.MustCompile("[^a-zA-Z0-9]+")
 	var outDir string

 	flag.StringVar(&outDir, "out-dir", "",
@@ -0,0 +1,223 @@
package proxy

import (
	"fmt"
	"io/ioutil"
	"log"

	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/connect"
	"github.com/hashicorp/hcl"
)

// Config is the publicly configurable state for an entire proxy instance. It's
// mostly used as the format for the local-file config mode which is mostly for
// dev/testing. In normal use, different parts of this config are pulled from
// different locations (e.g. command line, agent config endpoint, agent
// certificate endpoints).
type Config struct {
	// ProxyID is the identifier for this proxy as registered in Consul. It's only
	// guaranteed to be unique per agent.
	ProxyID string `json:"proxy_id" hcl:"proxy_id"`

	// Token is the authentication token provided for queries to the local agent.
	Token string `json:"token" hcl:"token"`

	// ProxiedServiceID is the identifier of the service this proxy is representing.
	ProxiedServiceID string `json:"proxied_service_id" hcl:"proxied_service_id"`

	// ProxiedServiceNamespace is the namespace of the service this proxy is
	// representing.
	ProxiedServiceNamespace string `json:"proxied_service_namespace" hcl:"proxied_service_namespace"`

	// PublicListener configures the mTLS listener.
	PublicListener PublicListenerConfig `json:"public_listener" hcl:"public_listener"`

	// Upstreams configures outgoing proxies for remote connect services.
	Upstreams []UpstreamConfig `json:"upstreams" hcl:"upstreams"`

	// DevCAFile allows passing the file path to PEM encoded root certificate
	// bundle to be used in development instead of the ones supplied by Connect.
	DevCAFile string `json:"dev_ca_file" hcl:"dev_ca_file"`

	// DevServiceCertFile allows passing the file path to PEM encoded service
	// certificate (client and server) to be used in development instead of the
	// ones supplied by Connect.
	DevServiceCertFile string `json:"dev_service_cert_file" hcl:"dev_service_cert_file"`

	// DevServiceKeyFile allows passing the file path to PEM encoded service
	// private key to be used in development instead of the ones supplied by
	// Connect.
	DevServiceKeyFile string `json:"dev_service_key_file" hcl:"dev_service_key_file"`

	// service is a connect.Service instance representing the proxied service. It
	// is created internally by the code responsible for setting up config as it
	// may depend on other external dependencies.
	service *connect.Service
}

// PublicListenerConfig contains the parameters needed for the incoming mTLS
// listener.
type PublicListenerConfig struct {
	// BindAddress is the host:port the public mTLS listener will bind to.
	BindAddress string `json:"bind_address" hcl:"bind_address"`

	// LocalServiceAddress is the host:port for the proxied application. This
	// should be on loopback or otherwise protected as it's plain TCP.
	LocalServiceAddress string `json:"local_service_address" hcl:"local_service_address"`

	// LocalConnectTimeout is the timeout for establishing connections with the
	// local backend. Defaults to 1000 (1s).
	LocalConnectTimeoutMs int `json:"local_connect_timeout_ms" hcl:"local_connect_timeout_ms"`

	// HandshakeTimeout is the timeout for incoming mTLS clients to complete a
	// handshake. Setting this low avoids DOS by malicious clients holding
	// resources open. Defaults to 10000 (10s).
	HandshakeTimeoutMs int `json:"handshake_timeout_ms" hcl:"handshake_timeout_ms"`
}

// applyDefaults sets zero-valued params to a sane default.
func (plc *PublicListenerConfig) applyDefaults() {
	if plc.LocalConnectTimeoutMs == 0 {
		plc.LocalConnectTimeoutMs = 1000
	}
	if plc.HandshakeTimeoutMs == 0 {
		plc.HandshakeTimeoutMs = 10000
	}
}

// UpstreamConfig configures an upstream (outgoing) listener.
type UpstreamConfig struct {
	// LocalAddress is the host:port to listen on for local app connections.
	LocalBindAddress string `json:"local_bind_address" hcl:"local_bind_address,attr"`

	// DestinationName is the service name of the destination.
	DestinationName string `json:"destination_name" hcl:"destination_name,attr"`

	// DestinationNamespace is the namespace of the destination.
	DestinationNamespace string `json:"destination_namespace" hcl:"destination_namespace,attr"`

	// DestinationType determines which service discovery method is used to find a
	// candidate instance to connect to.
	DestinationType string `json:"destination_type" hcl:"destination_type,attr"`

	// DestinationDatacenter is the datacenter the destination is in. If empty,
	// defaults to discovery within the same datacenter.
	DestinationDatacenter string `json:"destination_datacenter" hcl:"destination_datacenter,attr"`

	// ConnectTimeout is the timeout for establishing connections with the remote
	// service instance. Defaults to 10,000 (10s).
	ConnectTimeoutMs int `json:"connect_timeout_ms" hcl:"connect_timeout_ms,attr"`

	// resolver is used to plug in the service discovery mechanism. It can be used
	// in tests to bypass discovery. In real usage it is used to inject the
	// api.Client dependency from the remainder of the config struct parsed from
	// the user JSON using the UpstreamResolverFromClient helper.
	resolver connect.Resolver
}

// applyDefaults sets zero-valued params to a sane default.
func (uc *UpstreamConfig) applyDefaults() {
	if uc.ConnectTimeoutMs == 0 {
		uc.ConnectTimeoutMs = 10000
	}
}

// String returns a string that uniquely identifies the Upstream. Used for
// identifying the upstream in log output and map keys.
func (uc *UpstreamConfig) String() string {
	return fmt.Sprintf("%s->%s:%s/%s", uc.LocalBindAddress, uc.DestinationType,
		uc.DestinationNamespace, uc.DestinationName)
}

// UpstreamResolverFromClient returns a ConsulResolver that can resolve the
// given UpstreamConfig using the provided api.Client dependency.
func UpstreamResolverFromClient(client *api.Client,
	cfg UpstreamConfig) connect.Resolver {

	// For now default to service as it has the most natural meaning and the error
	// that the service doesn't exist is probably reasonable if misconfigured. We
	// should probably handle actual configs that have invalid types at a higher
	// level anyway (like when parsing).
	typ := connect.ConsulResolverTypeService
	if cfg.DestinationType == "prepared_query" {
		typ = connect.ConsulResolverTypePreparedQuery
	}
	return &connect.ConsulResolver{
		Client:     client,
		Namespace:  cfg.DestinationNamespace,
		Name:       cfg.DestinationName,
		Type:       typ,
		Datacenter: cfg.DestinationDatacenter,
	}
}

// ConfigWatcher is a simple interface to allow dynamic configurations from
// pluggable sources.
type ConfigWatcher interface {
	// Watch returns a channel that will deliver new Configs if something external
	// provokes it.
	Watch() <-chan *Config
}

// StaticConfigWatcher is a simple ConfigWatcher that delivers a static Config
// once and then never changes it.
type StaticConfigWatcher struct {
	ch chan *Config
}

// NewStaticConfigWatcher returns a ConfigWatcher for a config that never
// changes. It assumes only one "watcher" will ever call Watch. The config is
// delivered on the first call but will never be delivered again to allow
// callers to call repeatedly (e.g. select in a loop).
func NewStaticConfigWatcher(cfg *Config) *StaticConfigWatcher {
	sc := &StaticConfigWatcher{
		// Buffer it so we can queue up the config for first delivery.
		ch: make(chan *Config, 1),
	}
	sc.ch <- cfg
	return sc
}

// Watch implements ConfigWatcher on a static configuration for compatibility.
// It returns itself on the channel once and then leaves it open.
func (sc *StaticConfigWatcher) Watch() <-chan *Config {
	return sc.ch
}

// ParseConfigFile parses proxy configuration from a file for local dev.
func ParseConfigFile(filename string) (*Config, error) {
	bs, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err
	}

	var cfg Config

	err = hcl.Unmarshal(bs, &cfg)
	if err != nil {
		return nil, err
	}

	cfg.PublicListener.applyDefaults()
	for idx := range cfg.Upstreams {
		cfg.Upstreams[idx].applyDefaults()
	}

	return &cfg, nil
}

// AgentConfigWatcher watches the local Consul agent for proxy config changes.
type AgentConfigWatcher struct {
	client  *api.Client
	proxyID string
	logger  *log.Logger
}

// Watch implements ConfigWatcher.
func (w *AgentConfigWatcher) Watch() <-chan *Config {
	watch := make(chan *Config)
	// TODO implement me, note we need to discover the Service instance to use and
	// set it on the Config we return.
	return watch
}
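As a usage sketch for the new config helpers above (not part of this commit; the config path, the main package, and the error handling are assumptions), a small program could parse a dev config and build a resolver for each upstream like this:

// Illustrative sketch only. ParseConfigFile and UpstreamResolverFromClient come
// from the new connect/proxy package; "./proxy.hcl" is a placeholder path.
package main

import (
	"log"

	"github.com/hashicorp/consul/api"
	proxy "github.com/hashicorp/consul/connect/proxy"
)

func main() {
	cfg, err := proxy.ParseConfigFile("./proxy.hcl") // placeholder path
	if err != nil {
		log.Fatalf("failed to parse config: %s", err)
	}

	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatalf("failed to create Consul client: %s", err)
	}

	// Each upstream resolves either a service or a prepared query; unknown
	// types fall back to service, as UpstreamResolverFromClient documents.
	for _, uc := range cfg.Upstreams {
		r := proxy.UpstreamResolverFromClient(client, uc)
		log.Printf("upstream %s uses resolver %#v", uc.String(), r)
	}
}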
@@ -0,0 +1,108 @@
package proxy

import (
	"testing"

	"github.com/hashicorp/consul/connect"
	"github.com/stretchr/testify/require"
)

func TestParseConfigFile(t *testing.T) {
	cfg, err := ParseConfigFile("testdata/config-kitchensink.hcl")
	require.Nil(t, err)

	expect := &Config{
		ProxyID:                 "foo",
		Token:                   "11111111-2222-3333-4444-555555555555",
		ProxiedServiceID:        "web",
		ProxiedServiceNamespace: "default",
		PublicListener: PublicListenerConfig{
			BindAddress:           ":9999",
			LocalServiceAddress:   "127.0.0.1:5000",
			LocalConnectTimeoutMs: 1000,
			HandshakeTimeoutMs:    10000, // From defaults
		},
		Upstreams: []UpstreamConfig{
			{
				LocalBindAddress:     "127.0.0.1:6000",
				DestinationName:      "db",
				DestinationNamespace: "default",
				DestinationType:      "service",
				ConnectTimeoutMs:     10000,
			},
			{
				LocalBindAddress:     "127.0.0.1:6001",
				DestinationName:      "geo-cache",
				DestinationNamespace: "default",
				DestinationType:      "prepared_query",
				ConnectTimeoutMs:     10000,
			},
		},
		DevCAFile:          "connect/testdata/ca1-ca-consul-internal.cert.pem",
		DevServiceCertFile: "connect/testdata/ca1-svc-web.cert.pem",
		DevServiceKeyFile:  "connect/testdata/ca1-svc-web.key.pem",
	}

	require.Equal(t, expect, cfg)
}

func TestUpstreamResolverFromClient(t *testing.T) {
	tests := []struct {
		name string
		cfg  UpstreamConfig
		want *connect.ConsulResolver
	}{
		{
			name: "service",
			cfg: UpstreamConfig{
				DestinationNamespace:  "foo",
				DestinationName:       "web",
				DestinationDatacenter: "ny1",
				DestinationType:       "service",
			},
			want: &connect.ConsulResolver{
				Namespace:  "foo",
				Name:       "web",
				Datacenter: "ny1",
				Type:       connect.ConsulResolverTypeService,
			},
		},
		{
			name: "prepared_query",
			cfg: UpstreamConfig{
				DestinationNamespace:  "foo",
				DestinationName:       "web",
				DestinationDatacenter: "ny1",
				DestinationType:       "prepared_query",
			},
			want: &connect.ConsulResolver{
				Namespace:  "foo",
				Name:       "web",
				Datacenter: "ny1",
				Type:       connect.ConsulResolverTypePreparedQuery,
			},
		},
		{
			name: "unknown behaves like service",
			cfg: UpstreamConfig{
				DestinationNamespace:  "foo",
				DestinationName:       "web",
				DestinationDatacenter: "ny1",
				DestinationType:       "junk",
			},
			want: &connect.ConsulResolver{
				Namespace:  "foo",
				Name:       "web",
				Datacenter: "ny1",
				Type:       connect.ConsulResolverTypeService,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Client doesn't really matter as long as it's passed through.
			got := UpstreamResolverFromClient(nil, tt.cfg)
			require.Equal(t, tt.want, got)
		})
	}
}
@@ -0,0 +1,61 @@
package proxy

import (
	"io"
	"net"
	"sync/atomic"
)

// Conn represents a single proxied TCP connection.
type Conn struct {
	src, dst net.Conn
	stopping int32
}

// NewConn returns a conn joining the two given net.Conn
func NewConn(src, dst net.Conn) *Conn {
	return &Conn{
		src:      src,
		dst:      dst,
		stopping: 0,
	}
}

// Close closes both connections.
func (c *Conn) Close() error {
	// Note that net.Conn.Close can be called multiple times and atomic store is
	// idempotent so no need to ensure we only do this once.
	//
	// Also note that we don't wait for CopyBytes to return here since we are
	// closing the conns which is the only externally visible side effect of that
	// goroutine running and there should be no way for it to hang or leak once
	// the conns are closed so we can save the extra coordination.
	atomic.StoreInt32(&c.stopping, 1)
	c.src.Close()
	c.dst.Close()
	return nil
}

// CopyBytes will continuously copy bytes in both directions between src and dst
// until either connection is closed.
func (c *Conn) CopyBytes() error {
	defer c.Close()

	go func() {
		// Need this since Copy is only guaranteed to stop when its source reader
		// (second arg) hits EOF or error but either conn might close first possibly
		// causing this goroutine to exit but not the outer one. See
		// TestConnSrcClosing which will fail if you comment the defer below.
		defer c.Close()
		io.Copy(c.dst, c.src)
	}()

	_, err := io.Copy(c.src, c.dst)
	// Note that we don't wait for the other goroutine to finish because it either
	// already has due to its src conn closing, or it will once our defer fires
	// and closes the source conn. No need for the extra coordination.
	if atomic.LoadInt32(&c.stopping) == 1 {
		return nil
	}
	return err
}
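A minimal sketch of using Conn on its own, outside the proxy listeners above (not part of this commit; both addresses are placeholders): accept a local connection, dial a backend, and let CopyBytes shuttle bytes until either side closes.

// Illustrative sketch only. NewConn/CopyBytes come from the new connect/proxy
// package; the listen and backend addresses below are placeholders.
package main

import (
	"log"
	"net"

	proxy "github.com/hashicorp/consul/connect/proxy"
)

func main() {
	l, err := net.Listen("tcp", "127.0.0.1:8080") // placeholder listen address
	if err != nil {
		log.Fatal(err)
	}
	for {
		src, err := l.Accept()
		if err != nil {
			log.Fatal(err)
		}
		go func(src net.Conn) {
			dst, err := net.Dial("tcp", "127.0.0.1:9090") // placeholder backend
			if err != nil {
				src.Close()
				return
			}
			// CopyBytes blocks until either side closes and closes both conns
			// on the way out; errors from a deliberate Close are suppressed.
			if err := proxy.NewConn(src, dst).CopyBytes(); err != nil {
				log.Printf("proxied conn failed: %s", err)
			}
		}(src)
	}
}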
@@ -0,0 +1,185 @@
package proxy

import (
	"bufio"
	"io"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// Assert io.Closer implementation
var _ io.Closer = new(Conn)

// testConnPairSetup creates a TCP connection by listening on a random port, and
// returns both ends. Ready to have data sent down them. It also returns a
// closer function that will close both conns and the listener.
func testConnPairSetup(t *testing.T) (net.Conn, net.Conn, func()) {
	t.Helper()

	l, err := net.Listen("tcp", "localhost:0")
	require.Nil(t, err)

	ch := make(chan net.Conn, 1)
	go func() {
		src, err := l.Accept()
		require.Nil(t, err)
		ch <- src
	}()

	dst, err := net.Dial("tcp", l.Addr().String())
	require.Nil(t, err)

	src := <-ch

	stopper := func() {
		l.Close()
		src.Close()
		dst.Close()
	}

	return src, dst, stopper
}

// testConnPipelineSetup creates a pipeline consisting of two TCP connection
// pairs and a Conn that copies bytes between them. Data flow looks like this:
//
//   src1 <---> dst1 <== Conn.CopyBytes ==> src2 <---> dst2
//
// The returned values are the src1 and dst2 which should be able to send and
// receive to each other via the Conn, the Conn itself (not running), and a
// stopper func to close everything.
func testConnPipelineSetup(t *testing.T) (net.Conn, net.Conn, *Conn, func()) {
	src1, dst1, stop1 := testConnPairSetup(t)
	src2, dst2, stop2 := testConnPairSetup(t)
	c := NewConn(dst1, src2)
	return src1, dst2, c, func() {
		c.Close()
		stop1()
		stop2()
	}
}

func TestConn(t *testing.T) {
	src, dst, c, stop := testConnPipelineSetup(t)
	defer stop()

	retCh := make(chan error, 1)
	go func() {
		retCh <- c.CopyBytes()
	}()

	// Now write/read into the other ends of the pipes (src1, dst2)
	srcR := bufio.NewReader(src)
	dstR := bufio.NewReader(dst)

	_, err := src.Write([]byte("ping 1\n"))
	require.Nil(t, err)
	_, err = dst.Write([]byte("ping 2\n"))
	require.Nil(t, err)

	got, err := dstR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "ping 1\n", got)

	got, err = srcR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "ping 2\n", got)

	_, err = src.Write([]byte("pong 1\n"))
	require.Nil(t, err)
	_, err = dst.Write([]byte("pong 2\n"))
	require.Nil(t, err)

	got, err = dstR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "pong 1\n", got)

	got, err = srcR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "pong 2\n", got)

	c.Close()

	ret := <-retCh
	require.Nil(t, ret, "Close() should not cause error return")
}

func TestConnSrcClosing(t *testing.T) {
	src, dst, c, stop := testConnPipelineSetup(t)
	defer stop()

	retCh := make(chan error, 1)
	go func() {
		retCh <- c.CopyBytes()
	}()

	// Wait until we can actually get some bytes through both ways so we know that
	// the copy goroutines are running.
	srcR := bufio.NewReader(src)
	dstR := bufio.NewReader(dst)

	_, err := src.Write([]byte("ping 1\n"))
	require.Nil(t, err)
	_, err = dst.Write([]byte("ping 2\n"))
	require.Nil(t, err)

	got, err := dstR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "ping 1\n", got)
	got, err = srcR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "ping 2\n", got)

	// If we close the src conn, we expect CopyBytes to return and dst to be
	// closed too. No good way to assert that the conn is closed really other than
	// assume the retCh receive will hang unless CopyBytes exits and that
	// CopyBytes defers Closing both.
	testTimer := time.AfterFunc(3*time.Second, func() {
		panic("test timeout")
	})
	src.Close()
	<-retCh
	testTimer.Stop()
}

func TestConnDstClosing(t *testing.T) {
	src, dst, c, stop := testConnPipelineSetup(t)
	defer stop()

	retCh := make(chan error, 1)
	go func() {
		retCh <- c.CopyBytes()
	}()

	// Wait until we can actually get some bytes through both ways so we know that
	// the copy goroutines are running.
	srcR := bufio.NewReader(src)
	dstR := bufio.NewReader(dst)

	_, err := src.Write([]byte("ping 1\n"))
	require.Nil(t, err)
	_, err = dst.Write([]byte("ping 2\n"))
	require.Nil(t, err)

	got, err := dstR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "ping 1\n", got)
	got, err = srcR.ReadString('\n')
	require.Nil(t, err)
	require.Equal(t, "ping 2\n", got)

	// If we close the dst conn, we expect CopyBytes to return and src to be
	// closed too. No good way to assert that the conn is closed really other than
	// assume the retCh receive will hang unless CopyBytes exits and that
	// CopyBytes defers Closing both. i.e. if this test doesn't time out it's
	// good!
	testTimer := time.AfterFunc(3*time.Second, func() {
		panic("test timeout")
	})
	src.Close()
	<-retCh
	testTimer.Stop()
}
@@ -0,0 +1,116 @@
package proxy

import (
	"context"
	"crypto/tls"
	"errors"
	"log"
	"net"
	"sync/atomic"
	"time"

	"github.com/hashicorp/consul/connect"
)

// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it
type Listener struct {
	// Service is the connect service instance to use.
	Service *connect.Service

	listenFunc func() (net.Listener, error)
	dialFunc   func() (net.Conn, error)

	stopFlag int32
	stopChan chan struct{}

	logger *log.Logger
}

// NewPublicListener returns a Listener setup to listen for public mTLS
// connections and proxy them to the configured local application over TCP.
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
	logger *log.Logger) *Listener {
	return &Listener{
		Service: svc,
		listenFunc: func() (net.Listener, error) {
			return tls.Listen("tcp", cfg.BindAddress, svc.ServerTLSConfig())
		},
		dialFunc: func() (net.Conn, error) {
			return net.DialTimeout("tcp", cfg.LocalServiceAddress,
				time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
		},
		stopChan: make(chan struct{}),
		logger:   logger,
	}
}

// NewUpstreamListener returns a Listener setup to listen locally for TCP
// connections that are proxied to a discovered Connect service instance.
func NewUpstreamListener(svc *connect.Service, cfg UpstreamConfig,
	logger *log.Logger) *Listener {
	return &Listener{
		Service: svc,
		listenFunc: func() (net.Listener, error) {
			return net.Listen("tcp", cfg.LocalBindAddress)
		},
		dialFunc: func() (net.Conn, error) {
			if cfg.resolver == nil {
				return nil, errors.New("no resolver provided")
			}
			ctx, cancel := context.WithTimeout(context.Background(),
				time.Duration(cfg.ConnectTimeoutMs)*time.Millisecond)
			defer cancel()
			return svc.Dial(ctx, cfg.resolver)
		},
		stopChan: make(chan struct{}),
		logger:   logger,
	}
}

// Serve runs the listener until it is stopped.
func (l *Listener) Serve() error {
	listen, err := l.listenFunc()
	if err != nil {
		return err
	}

	for {
		conn, err := listen.Accept()
		if err != nil {
			if atomic.LoadInt32(&l.stopFlag) == 1 {
				return nil
			}
			return err
		}

		go l.handleConn(conn)
	}
	return nil
}

// handleConn is the internal connection handler goroutine.
func (l *Listener) handleConn(src net.Conn) {
	defer src.Close()

	dst, err := l.dialFunc()
	if err != nil {
		l.logger.Printf("[ERR] failed to dial: %s", err)
		return
	}
	// Note no need to defer dst.Close() since conn handles that for us.
	conn := NewConn(src, dst)
	defer conn.Close()

	err = conn.CopyBytes()
	if err != nil {
		l.logger.Printf("[ERR] connection failed: %s", err)
		return
	}
}

// Close terminates the listener and all active connections.
func (l *Listener) Close() error {
	return nil
}
@@ -0,0 +1,91 @@
package proxy

import (
	"context"
	"log"
	"net"
	"os"
	"testing"

	agConnect "github.com/hashicorp/consul/agent/connect"
	"github.com/hashicorp/consul/connect"
	"github.com/stretchr/testify/require"
)

func TestPublicListener(t *testing.T) {
	ca := agConnect.TestCA(t, nil)
	addrs := TestLocalBindAddrs(t, 2)

	cfg := PublicListenerConfig{
		BindAddress:           addrs[0],
		LocalServiceAddress:   addrs[1],
		HandshakeTimeoutMs:    100,
		LocalConnectTimeoutMs: 100,
	}

	testApp, err := NewTestTCPServer(t, cfg.LocalServiceAddress)
	require.Nil(t, err)
	defer testApp.Close()

	svc := connect.TestService(t, "db", ca)

	l := NewPublicListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags))

	// Run proxy
	go func() {
		err := l.Serve()
		require.Nil(t, err)
	}()
	defer l.Close()

	// Proxy and backend are running, play the part of a TLS client using same
	// cert for now.
	conn, err := svc.Dial(context.Background(), &connect.StaticResolver{
		Addr:    addrs[0],
		CertURI: agConnect.TestSpiffeIDService(t, "db"),
	})
	require.Nilf(t, err, "unexpected err: %s", err)
	TestEchoConn(t, conn, "")
}

func TestUpstreamListener(t *testing.T) {
	ca := agConnect.TestCA(t, nil)
	addrs := TestLocalBindAddrs(t, 1)

	// Run a test server that we can dial.
	testSvr := connect.NewTestServer(t, "db", ca)
	go func() {
		err := testSvr.Serve()
		require.Nil(t, err)
	}()
	defer testSvr.Close()

	cfg := UpstreamConfig{
		DestinationType:      "service",
		DestinationNamespace: "default",
		DestinationName:      "db",
		ConnectTimeoutMs:     100,
		LocalBindAddress:     addrs[0],
		resolver: &connect.StaticResolver{
			Addr:    testSvr.Addr,
			CertURI: agConnect.TestSpiffeIDService(t, "db"),
		},
	}

	svc := connect.TestService(t, "web", ca)

	l := NewUpstreamListener(svc, cfg, log.New(os.Stderr, "", log.LstdFlags))

	// Run proxy
	go func() {
		err := l.Serve()
		require.Nil(t, err)
	}()
	defer l.Close()

	// Proxy and fake remote service are running, play the part of the app
	// connecting to a remote connect service over TCP.
	conn, err := net.Dial("tcp", cfg.LocalBindAddress)
	require.Nilf(t, err, "unexpected err: %s", err)
	TestEchoConn(t, conn, "")
}
@@ -0,0 +1,134 @@
package proxy

import (
	"log"

	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/connect"
)

// Proxy implements the built-in connect proxy.
type Proxy struct {
	proxyID    string
	client     *api.Client
	cfgWatcher ConfigWatcher
	stopChan   chan struct{}
	logger     *log.Logger
}

// NewFromConfigFile returns a Proxy instance configured just from a local file.
// This is intended mostly for development and bypasses the normal mechanisms
// for fetching config and certificates from the local agent.
func NewFromConfigFile(client *api.Client, filename string,
	logger *log.Logger) (*Proxy, error) {
	cfg, err := ParseConfigFile(filename)
	if err != nil {
		return nil, err
	}

	service, err := connect.NewDevServiceFromCertFiles(cfg.ProxiedServiceID,
		client, logger, cfg.DevCAFile, cfg.DevServiceCertFile,
		cfg.DevServiceKeyFile)
	if err != nil {
		return nil, err
	}
	cfg.service = service

	p := &Proxy{
		proxyID:    cfg.ProxyID,
		client:     client,
		cfgWatcher: NewStaticConfigWatcher(cfg),
		stopChan:   make(chan struct{}),
		logger:     logger,
	}
	return p, nil
}

// New returns a Proxy with the given id, consuming the provided (configured)
// agent. It is ready to Run().
func New(client *api.Client, proxyID string, logger *log.Logger) (*Proxy, error) {
	p := &Proxy{
		proxyID: proxyID,
		client:  client,
		cfgWatcher: &AgentConfigWatcher{
			client:  client,
			proxyID: proxyID,
			logger:  logger,
		},
		stopChan: make(chan struct{}),
		logger:   logger,
	}
	return p, nil
}

// Serve the proxy instance until a fatal error occurs or proxy is closed.
func (p *Proxy) Serve() error {

	var cfg *Config

	// Watch for config changes (initial setup happens on first "change")
	for {
		select {
		case newCfg := <-p.cfgWatcher.Watch():
			p.logger.Printf("[DEBUG] got new config")
			if newCfg.service == nil {
				p.logger.Printf("[ERR] new config has nil service")
				continue
			}
			if cfg == nil {
				// Initial setup

				newCfg.PublicListener.applyDefaults()
				l := NewPublicListener(newCfg.service, newCfg.PublicListener, p.logger)
				err := p.startListener("public listener", l)
				if err != nil {
					return err
				}
			}

			// TODO(banks) update/remove upstreams properly based on a diff with current. Can
			// store a map of uc.String() to Listener here and then use it to only
			// start one of each and stop/modify if changes occur.
			for _, uc := range newCfg.Upstreams {
				uc.applyDefaults()
				uc.resolver = UpstreamResolverFromClient(p.client, uc)

				l := NewUpstreamListener(newCfg.service, uc, p.logger)
				err := p.startListener(uc.String(), l)
				if err != nil {
					p.logger.Printf("[ERR] failed to start upstream %s: %s", uc.String(),
						err)
				}
			}
			cfg = newCfg

		case <-p.stopChan:
			return nil
		}
	}
}

// startListener is run from the internal state machine loop
func (p *Proxy) startListener(name string, l *Listener) error {
	go func() {
		err := l.Serve()
		if err != nil {
			p.logger.Printf("[ERR] %s stopped with error: %s", name, err)
			return
		}
		p.logger.Printf("[INFO] %s stopped", name)
	}()

	go func() {
		<-p.stopChan
		l.Close()
	}()

	return nil
}

// Close stops the proxy and terminates all active connections. It must be
// called only once.
func (p *Proxy) Close() {
	close(p.stopChan)
}
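A sketch of driving the new Proxy type the way the reworked command above does (not part of this commit; the config path, the main package, and the signal handling are assumptions):

// Illustrative sketch only. NewFromConfigFile/Serve/Close come from the new
// connect/proxy package; "./proxy.hcl" is a placeholder path.
package main

import (
	"log"
	"os"
	"os/signal"

	"github.com/hashicorp/consul/api"
	proxy "github.com/hashicorp/consul/connect/proxy"
)

func main() {
	logger := log.New(os.Stderr, "", log.LstdFlags)

	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		logger.Fatalf("failed to create Consul client: %s", err)
	}

	p, err := proxy.NewFromConfigFile(client, "./proxy.hcl", logger) // placeholder path
	if err != nil {
		logger.Fatalf("failed to configure proxy: %s", err)
	}

	// Close the proxy when a signal arrives; Serve then returns nil.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt)
	go func() {
		<-sigCh
		p.Close()
	}()

	if err := p.Serve(); err != nil {
		logger.Fatalf("proxy failed: %s", err)
	}
}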
@@ -0,0 +1,32 @@
# Example proxy config with everything specified

proxy_id = "foo"
token = "11111111-2222-3333-4444-555555555555"

proxied_service_id = "web"
proxied_service_namespace = "default"

# Assumes running consul in dev mode from the repo root...
dev_ca_file = "connect/testdata/ca1-ca-consul-internal.cert.pem"
dev_service_cert_file = "connect/testdata/ca1-svc-web.cert.pem"
dev_service_key_file = "connect/testdata/ca1-svc-web.key.pem"

public_listener {
  bind_address = ":9999"
  local_service_address = "127.0.0.1:5000"
}

upstreams = [
  {
    local_bind_address = "127.0.0.1:6000"
    destination_name = "db"
    destination_namespace = "default"
    destination_type = "service"
  },
  {
    local_bind_address = "127.0.0.1:6001"
    destination_name = "geo-cache"
    destination_namespace = "default"
    destination_type = "prepared_query"
  }
]
@@ -0,0 +1,105 @@
package proxy

import (
	"fmt"
	"io"
	"log"
	"net"
	"sync/atomic"

	"github.com/hashicorp/consul/lib/freeport"
	"github.com/mitchellh/go-testing-interface"
	"github.com/stretchr/testify/require"
)

// TestLocalBindAddrs returns n localhost address:port strings with free ports
// for binding test listeners to.
func TestLocalBindAddrs(t testing.T, n int) []string {
	ports := freeport.GetT(t, n)
	addrs := make([]string, n)
	for i, p := range ports {
		addrs[i] = fmt.Sprintf("localhost:%d", p)
	}
	return addrs
}

// TestTCPServer is a simple TCP echo server for use during tests.
type TestTCPServer struct {
	l                        net.Listener
	stopped                  int32
	accepted, closed, active int32
}

// NewTestTCPServer opens a listening socket on the given address and returns
// a TestTCPServer serving requests to it. The server is already started and can
// be stopped by calling Close().
func NewTestTCPServer(t testing.T, addr string) (*TestTCPServer, error) {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, err
	}
	log.Printf("test tcp server listening on %s", addr)
	s := &TestTCPServer{
		l: l,
	}
	go s.accept()
	return s, nil
}

// Close stops the server
func (s *TestTCPServer) Close() {
	atomic.StoreInt32(&s.stopped, 1)
	if s.l != nil {
		s.l.Close()
	}
}

func (s *TestTCPServer) accept() error {
	for {
		conn, err := s.l.Accept()
		if err != nil {
			if atomic.LoadInt32(&s.stopped) == 1 {
				log.Printf("test tcp echo server %s stopped", s.l.Addr())
				return nil
			}
			log.Printf("test tcp echo server %s failed: %s", s.l.Addr(), err)
			return err
		}

		atomic.AddInt32(&s.accepted, 1)
		atomic.AddInt32(&s.active, 1)

		go func(c net.Conn) {
			io.Copy(c, c)
			atomic.AddInt32(&s.closed, 1)
			atomic.AddInt32(&s.active, -1)
		}(conn)
	}
}

// TestEchoConn attempts to write some bytes to conn and expects to read them
// back within a short timeout (10ms). If prefix is not empty we expect it to be
// present at the start of all echoed responses (for example to distinguish
// between multiple echo server instances).
func TestEchoConn(t testing.T, conn net.Conn, prefix string) {
	t.Helper()

	// Write some bytes and read them back
	n, err := conn.Write([]byte("Hello World"))
	require.Equal(t, 11, n)
	require.Nil(t, err)

	expectLen := 11 + len(prefix)

	buf := make([]byte, expectLen)
	// read until our buffer is full - it might be separate packets if prefix is
	// in use.
	got := 0
	for got < expectLen {
		n, err = conn.Read(buf[got:])
		require.Nilf(t, err, "err: %s", err)
		got += n
	}
	require.Equal(t, expectLen, got)
	require.Equal(t, prefix+"Hello World", string(buf[:]))
}
@@ -10,7 +10,9 @@ import (
 	testing "github.com/mitchellh/go-testing-interface"
 )

-// Resolver is the interface implemented by a service discovery mechanism.
+// Resolver is the interface implemented by a service discovery mechanism to get
+// the address and identity of an instance to connect to via Connect as a
+// client.
 type Resolver interface {
 	// Resolve returns a single service instance to connect to. Implementations
 	// may attempt to ensure the instance returned is currently available. It is
@@ -19,7 +21,10 @@ type Resolver interface {
 	// increases reliability. The context passed can be used to impose timeouts
 	// which may or may not be respected by implementations that make network
 	// calls to resolve the service. The addr returned is a string in any valid
-	// form for passing directly to `net.Dial("tcp", addr)`.
+	// form for passing directly to `net.Dial("tcp", addr)`. The certURI
+	// represents the identity of the service instance. It will be matched against
+	// the TLS certificate URI SAN presented by the server and the connection
+	// rejected if they don't match.
 	Resolve(ctx context.Context) (addr string, certURI connect.CertURI, err error)
 }

@@ -33,7 +38,8 @@ type StaticResolver struct {
 	Addr string

 	// CertURL is the _identity_ we expect the server to present in it's TLS
-	// certificate. It must be an exact match or the connection will be rejected.
+	// certificate. It must be an exact URI string match or the connection will be
+	// rejected.
 	CertURI connect.CertURI
 }

@@ -56,13 +62,14 @@ type ConsulResolver struct {
 	// panic.
 	Client *api.Client

-	// Namespace of the query target
+	// Namespace of the query target.
 	Namespace string

-	// Name of the query target
+	// Name of the query target.
 	Name string

-	// Type of the query target,
+	// Type of the query target. Should be one of the defined ConsulResolverType*
+	// constants. Currently defaults to ConsulResolverTypeService.
 	Type int

 	// Datacenter to resolve in, empty indicates agent's local DC.
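For illustration only (not part of this commit; the service name, timeout, and plain TCP dial are placeholders), resolving an instance by hand against the documented Resolve contract might look like this:

// Illustrative sketch only. ConsulResolver and the Resolve signature are from
// the connect package above; "web" and the timeout are assumptions.
package main

import (
	"context"
	"log"
	"net"
	"time"

	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/connect"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	r := &connect.ConsulResolver{
		Client: client,
		Name:   "web", // placeholder service name
		Type:   connect.ConsulResolverTypeService,
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	addr, certURI, err := r.Resolve(ctx)
	if err != nil {
		log.Fatalf("resolve failed: %s", err)
	}
	// certURI is the identity the TLS layer would verify; here we only log it.
	log.Printf("resolved %s (%s)", addr, certURI.URI())

	conn, err := net.Dial("tcp", addr)
	if err != nil {
		log.Fatalf("dial failed: %s", err)
	}
	conn.Close()
}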
@@ -41,7 +41,6 @@ func TestStaticResolver_Resolve(t *testing.T) {
 }

 func TestConsulResolver_Resolve(t *testing.T) {
-
 	// Setup a local test agent to query
 	agent := agent.NewTestAgent("test-consul", "")
 	defer agent.Shutdown()
@ -3,6 +3,7 @@ package connect
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -10,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/api"
|
"github.com/hashicorp/consul/api"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Service represents a Consul service that accepts and/or connects via Connect.
|
// Service represents a Consul service that accepts and/or connects via Connect.
|
||||||
|
@ -41,10 +43,17 @@ type Service struct {
|
||||||
client *api.Client
|
client *api.Client
|
||||||
|
|
||||||
// serverTLSCfg is the (reloadable) TLS config we use for serving.
|
// serverTLSCfg is the (reloadable) TLS config we use for serving.
|
||||||
serverTLSCfg *ReloadableTLSConfig
|
serverTLSCfg *reloadableTLSConfig
|
||||||
|
|
||||||
// clientTLSCfg is the (reloadable) TLS config we use for dialling.
|
// clientTLSCfg is the (reloadable) TLS config we use for dialling.
|
||||||
clientTLSCfg *ReloadableTLSConfig
|
clientTLSCfg *reloadableTLSConfig
|
||||||
|
|
||||||
|
// httpResolverFromAddr is a function that returns a Resolver from a string
|
||||||
|
// address for HTTP clients. It's privately pluggable to make testing easier
|
||||||
|
// but will default to a simple method to parse the host as a Consul DNS host.
|
||||||
|
//
|
||||||
|
// TODO(banks): write the proper implementation
|
||||||
|
httpResolverFromAddr func(addr string) (Resolver, error)
|
||||||
|
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
@ -65,8 +74,8 @@ func NewServiceWithLogger(serviceID string, client *api.Client,
|
||||||
client: client,
|
client: client,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
s.serverTLSCfg = NewReloadableTLSConfig(defaultTLSConfig(serverVerifyCerts))
|
s.serverTLSCfg = newReloadableTLSConfig(defaultTLSConfig(serverVerifyCerts))
|
||||||
s.clientTLSCfg = NewReloadableTLSConfig(defaultTLSConfig(clientVerifyCerts))
|
s.clientTLSCfg = newReloadableTLSConfig(defaultTLSConfig(clientVerifyCerts))
|
||||||
|
|
||||||
// TODO(banks) run the background certificate sync
|
// TODO(banks) run the background certificate sync
|
||||||
return s, nil
|
return s, nil
|
||||||
|
@ -86,12 +95,12 @@ func NewDevServiceFromCertFiles(serviceID string, client *api.Client,
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that NewReloadableTLSConfig makes a copy so we can re-use the same
|
// Note that newReloadableTLSConfig makes a copy so we can re-use the same
|
||||||
// base for both client and server with swapped verifiers.
|
// base for both client and server with swapped verifiers.
|
||||||
tlsCfg.VerifyPeerCertificate = serverVerifyCerts
|
tlsCfg.VerifyPeerCertificate = serverVerifyCerts
|
||||||
s.serverTLSCfg = NewReloadableTLSConfig(tlsCfg)
|
s.serverTLSCfg = newReloadableTLSConfig(tlsCfg)
|
||||||
tlsCfg.VerifyPeerCertificate = clientVerifyCerts
|
tlsCfg.VerifyPeerCertificate = clientVerifyCerts
|
||||||
s.clientTLSCfg = NewReloadableTLSConfig(tlsCfg)
|
s.clientTLSCfg = newReloadableTLSConfig(tlsCfg)
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,6 +130,8 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
s.logger.Printf("[DEBUG] resolved service instance: %s (%s)", addr,
|
||||||
|
certURI.URI())
|
||||||
var dialer net.Dialer
|
var dialer net.Dialer
|
||||||
tcpConn, err := dialer.DialContext(ctx, "tcp", addr)
|
tcpConn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -133,8 +144,8 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error)
|
||||||
if ok {
|
if ok {
|
||||||
tlsConn.SetDeadline(deadline)
|
tlsConn.SetDeadline(deadline)
|
||||||
}
|
}
|
||||||
err = tlsConn.Handshake()
|
// Perform handshake
|
||||||
if err != nil {
|
if err = tlsConn.Handshake(); err != nil {
|
||||||
tlsConn.Close()
|
tlsConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -149,20 +160,27 @@ func (s *Service) Dial(ctx context.Context, resolver Resolver) (net.Conn, error)
|
||||||
tlsConn.Close()
|
tlsConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
s.logger.Printf("[DEBUG] successfully connected to %s (%s)", addr,
|
||||||
|
certURI.URI())
|
||||||
return tlsConn, nil
|
return tlsConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPDialContext is compatible with http.Transport.DialContext. It expects the
|
// HTTPDialTLS is compatible with http.Transport.DialTLS. It expects the addr
|
||||||
// addr hostname to be specified using Consul DNS query syntax, e.g.
|
// hostname to be specified using Consul DNS query syntax, e.g.
|
||||||
// "web.service.consul". It converts that into the equivalent ConsulResolver and
|
// "web.service.consul". It converts that into the equivalent ConsulResolver and
|
||||||
// then call s.Dial with the resolver. This is low level, clients should
|
// then call s.Dial with the resolver. This is low level, clients should
|
||||||
// typically use HTTPClient directly.
|
// typically use HTTPClient directly.
|
||||||
func (s *Service) HTTPDialContext(ctx context.Context, network,
|
func (s *Service) HTTPDialTLS(network,
|
||||||
addr string) (net.Conn, error) {
|
addr string) (net.Conn, error) {
|
||||||
var r ConsulResolver
|
if s.httpResolverFromAddr == nil {
|
||||||
// TODO(banks): parse addr into ConsulResolver
|
return nil, errors.New("no http resolver configured")
|
||||||
return s.Dial(ctx, &r)
|
}
|
||||||
|
r, err := s.httpResolverFromAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// TODO(banks): figure out how to do timeouts better.
|
||||||
|
return s.Dial(context.Background(), r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClient returns an *http.Client configured to dial remote Consul Connect
|
// HTTPClient returns an *http.Client configured to dial remote Consul Connect
|
||||||
|
@ -172,14 +190,27 @@ func (s *Service) HTTPDialContext(ctx context.Context, network,
|
||||||
// API rather than just relying on Consul DNS. Hostnames that are not valid
|
// API rather than just relying on Consul DNS. Hostnames that are not valid
|
||||||
// Consul DNS queries will fail.
|
// Consul DNS queries will fail.
|
||||||
func (s *Service) HTTPClient() *http.Client {
|
func (s *Service) HTTPClient() *http.Client {
|
||||||
|
t := &http.Transport{
|
||||||
|
// Sadly we can't use DialContext hook since that is expected to return a
|
||||||
|
// plain TCP connection an http.Client tries to start a TLS handshake over
|
||||||
|
// it. We need to control the handshake to be able to do our validation.
|
||||||
|
// So we have to use the older DialTLS which means no context/timeout
|
||||||
|
// support.
|
||||||
|
//
|
||||||
|
// TODO(banks): figure out how users can configure a timeout when using
|
||||||
|
// this and/or compatibility with http.Request.WithContext.
|
||||||
|
DialTLS: s.HTTPDialTLS,
|
||||||
|
}
|
||||||
|
// Need to manually re-enable http2 support since we set custom DialTLS.
|
||||||
|
// See https://golang.org/src/net/http/transport.go?s=8692:9036#L228
|
||||||
|
http2.ConfigureTransport(t)
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: t,
|
||||||
DialContext: s.HTTPDialContext,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops the service and frees resources.
|
// Close stops the service and frees resources.
|
||||||
func (s *Service) Close() {
|
func (s *Service) Close() error {
|
||||||
// TODO(banks): stop background activity if started
|
// TODO(banks): stop background activity if started
|
||||||
|
return nil
|
||||||
}
|
}
|
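
Close returning an error means *Service now satisfies io.Closer, which the test file asserts at compile time, so callers can defer it like any other closer. A sketch, passing a nil client as the tests do:

    var _ io.Closer = new(Service) // compile-time check, also present in the tests

    svc, err := NewService("web", nil)
    if err != nil {
        log.Fatal(err)
    }
    defer svc.Close()
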
||||||
|
|
|
@ -2,14 +2,22 @@ package connect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect"
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
|
"github.com/hashicorp/consul/testutil/retry"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Assert io.Closer implementation
|
||||||
|
var _ io.Closer = new(Service)
|
||||||
|
|
||||||
func TestService_Dial(t *testing.T) {
|
func TestService_Dial(t *testing.T) {
|
||||||
ca := connect.TestCA(t, nil)
|
ca := connect.TestCA(t, nil)
|
||||||
|
|
||||||
|
@ -53,30 +61,26 @@ func TestService_Dial(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
s, err := NewService("web", nil)
|
s := TestService(t, "web", ca)
|
||||||
require.Nil(err)
|
|
||||||
|
|
||||||
// Force TLSConfig
|
|
||||||
s.clientTLSCfg = NewReloadableTLSConfig(TestTLSConfig(t, "web", ca))
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(),
|
ctx, cancel := context.WithTimeout(context.Background(),
|
||||||
100*time.Millisecond)
|
100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
testSvc := NewTestService(t, tt.presentService, ca)
|
testSvr := NewTestServer(t, tt.presentService, ca)
|
||||||
testSvc.TimeoutHandshake = !tt.handshake
|
testSvr.TimeoutHandshake = !tt.handshake
|
||||||
|
|
||||||
if tt.accept {
|
if tt.accept {
|
||||||
go func() {
|
go func() {
|
||||||
err := testSvc.Serve()
|
err := testSvr.Serve()
|
||||||
require.Nil(err)
|
require.Nil(err)
|
||||||
}()
|
}()
|
||||||
defer testSvc.Close()
|
defer testSvr.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always expect to be connecting to a "DB"
|
// Always expect to be connecting to a "DB"
|
||||||
resolver := &StaticResolver{
|
resolver := &StaticResolver{
|
||||||
Addr: testSvc.Addr,
|
Addr: testSvr.Addr,
|
||||||
CertURI: connect.TestSpiffeIDService(t, "db"),
|
CertURI: connect.TestSpiffeIDService(t, "db"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
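
The table's case struct is defined above this hunk and not shown here; a hypothetical shape that is consistent with the fields the loop uses (field names from the test, types assumed):

    type dialTestCase struct {
        name           string
        accept         bool   // whether the test server accepts the TCP connection
        handshake      bool   // whether the test server completes the TLS handshake
        presentService string // service name in the certificate the server presents
        wantErr        string // substring expected in the Dial error, "" for success
    }
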
@ -92,6 +96,7 @@ func TestService_Dial(t *testing.T) {
|
||||||
|
|
||||||
if tt.wantErr == "" {
|
if tt.wantErr == "" {
|
||||||
require.Nil(err)
|
require.Nil(err)
|
||||||
|
require.IsType(&tls.Conn{}, conn)
|
||||||
} else {
|
} else {
|
||||||
require.NotNil(err)
|
require.NotNil(err)
|
||||||
require.Contains(err.Error(), tt.wantErr)
|
require.Contains(err.Error(), tt.wantErr)
|
||||||
|
@ -103,3 +108,62 @@ func TestService_Dial(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestService_ServerTLSConfig(t *testing.T) {
|
||||||
|
// TODO(banks): it's mostly meaningless to test this now since we directly set
|
||||||
|
// the tlsCfg in our TestService helper, which is all we'd be asserting on here,
|
||||||
|
// not the actual implementation. Once agent TLS fetching is built, it becomes
|
||||||
|
// more meaningful to actually verify it's returning the correct config.
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_HTTPClient(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
ca := connect.TestCA(t, nil)
|
||||||
|
|
||||||
|
s := TestService(t, "web", ca)
|
||||||
|
|
||||||
|
// Run a test HTTP server
|
||||||
|
testSvr := NewTestServer(t, "backend", ca)
|
||||||
|
defer testSvr.Close()
|
||||||
|
go func() {
|
||||||
|
err := testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hello, I am Backend"))
|
||||||
|
}))
|
||||||
|
require.Nil(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// TODO(banks): this will talk http2 on both client and server. I hit some
|
||||||
|
// compatibility issues when testing, though; need to make sure that the http
|
||||||
|
// server with our TLSConfig can actually support HTTP/1.1 as well. Could make
|
||||||
|
// this a table test with all 4 permutations of client/server http version
|
||||||
|
// support.
|
||||||
|
|
||||||
|
// Still get connection refused sometimes, so retry on those
|
||||||
|
retry.Run(t, func(r *retry.R) {
|
||||||
|
// Hook the service resolver to avoid needing full agent setup.
|
||||||
|
s.httpResolverFromAddr = func(addr string) (Resolver, error) {
|
||||||
|
// Require in this goroutine seems to block, causing a timeout on the Get.
|
||||||
|
//require.Equal("https://backend.service.consul:443", addr)
|
||||||
|
return &StaticResolver{
|
||||||
|
Addr: testSvr.Addr,
|
||||||
|
CertURI: connect.TestSpiffeIDService(t, "backend"),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := s.HTTPClient()
|
||||||
|
client.Timeout = 1 * time.Second
|
||||||
|
|
||||||
|
resp, err := client.Get("https://backend.service.consul/foo")
|
||||||
|
r.Check(err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
bodyBytes, err := ioutil.ReadAll(resp.Body)
|
||||||
|
r.Check(err)
|
||||||
|
|
||||||
|
got := string(bodyBytes)
|
||||||
|
want := "Hello, I am Backend"
|
||||||
|
if got != want {
|
||||||
|
r.Fatalf("got %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -5,26 +5,33 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect"
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
"github.com/hashicorp/consul/lib/freeport"
|
"github.com/hashicorp/consul/lib/freeport"
|
||||||
testing "github.com/mitchellh/go-testing-interface"
|
testing "github.com/mitchellh/go-testing-interface"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// testVerifier creates a helper verifyFunc that can be set in a tls.Config and
|
// TestService returns a Service instance based on a static TLS Config.
|
||||||
// records calls made, passing back the certificates presented via the returned
|
func TestService(t testing.T, service string, ca *structs.CARoot) *Service {
|
||||||
// channel. The channel is buffered so up to 128 verification calls can be made
|
t.Helper()
|
||||||
// without reading the chan before verification blocks.
|
|
||||||
func testVerifier(t testing.T, returnErr error) (verifyFunc, chan [][]byte) {
|
// Don't need to talk to client since we are setting TLSConfig locally
|
||||||
ch := make(chan [][]byte, 128)
|
svc, err := NewService(service, nil)
|
||||||
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
if err != nil {
|
||||||
ch <- rawCerts
|
t.Fatal(err)
|
||||||
return returnErr
|
}
|
||||||
}, ch
|
|
||||||
|
svc.serverTLSCfg = newReloadableTLSConfig(
|
||||||
|
TestTLSConfigWithVerifier(t, service, ca, serverVerifyCerts))
|
||||||
|
svc.clientTLSCfg = newReloadableTLSConfig(
|
||||||
|
TestTLSConfigWithVerifier(t, service, ca, clientVerifyCerts))
|
||||||
|
|
||||||
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestTLSConfig returns a *tls.Config suitable for use during tests.
|
// TestTLSConfig returns a *tls.Config suitable for use during tests.
|
||||||
|
@ -32,7 +39,16 @@ func TestTLSConfig(t testing.T, service string, ca *structs.CARoot) *tls.Config
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Insecure default (nil verifier)
|
// Insecure default (nil verifier)
|
||||||
cfg := defaultTLSConfig(nil)
|
return TestTLSConfigWithVerifier(t, service, ca, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTLSConfigWithVerifier returns a *tls.Config suitable for use during
|
||||||
|
// tests. It will use the given verifyFunc to verify TLS certificates.
|
||||||
|
func TestTLSConfigWithVerifier(t testing.T, service string, ca *structs.CARoot,
|
||||||
|
verifier verifyFunc) *tls.Config {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cfg := defaultTLSConfig(verifier)
|
||||||
cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)}
|
cfg.Certificates = []tls.Certificate{TestSvcKeyPair(t, service, ca)}
|
||||||
cfg.RootCAs = TestCAPool(t, ca)
|
cfg.RootCAs = TestCAPool(t, ca)
|
||||||
cfg.ClientCAs = TestCAPool(t, ca)
|
cfg.ClientCAs = TestCAPool(t, ca)
|
||||||
|
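
TestTLSConfigWithVerifier lets a test plug in its own verifyFunc, much like the removed testVerifier helper did. A sketch using the same signature the old helper returned:

    // Records the raw certificates each handshake presents while still
    // accepting the connection.
    seen := make(chan [][]byte, 128)
    verifier := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
        seen <- rawCerts
        return nil
    }
    cfg := TestTLSConfigWithVerifier(t, "web", ca, verifier)
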
@ -55,7 +71,9 @@ func TestSvcKeyPair(t testing.T, service string, ca *structs.CARoot) tls.Certifi
|
||||||
t.Helper()
|
t.Helper()
|
||||||
certPEM, keyPEM := connect.TestLeaf(t, service, ca)
|
certPEM, keyPEM := connect.TestLeaf(t, service, ca)
|
||||||
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||||
require.Nil(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
return cert
|
return cert
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,13 +83,15 @@ func TestPeerCertificates(t testing.T, service string, ca *structs.CARoot) []*x5
|
||||||
t.Helper()
|
t.Helper()
|
||||||
certPEM, _ := connect.TestLeaf(t, service, ca)
|
certPEM, _ := connect.TestLeaf(t, service, ca)
|
||||||
cert, err := connect.ParseCert(certPEM)
|
cert, err := connect.ParseCert(certPEM)
|
||||||
require.Nil(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
return []*x509.Certificate{cert}
|
return []*x509.Certificate{cert}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestService runs a service listener that can be used to test clients. Its
|
// TestServer runs a service listener that can be used to test clients. Its
|
||||||
// behaviour can be controlled by the struct members.
|
// behaviour can be controlled by the struct members.
|
||||||
type TestService struct {
|
type TestServer struct {
|
||||||
// The service name to serve.
|
// The service name to serve.
|
||||||
Service string
|
Service string
|
||||||
// The (test) CA to use for generating certs.
|
// The (test) CA to use for generating certs.
|
||||||
|
@ -91,11 +111,11 @@ type TestService struct {
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTestService returns a TestService. It should be closed when test is
|
// NewTestServer returns a TestServer. It should be closed when test is
|
||||||
// complete.
|
// complete.
|
||||||
func NewTestService(t testing.T, service string, ca *structs.CARoot) *TestService {
|
func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer {
|
||||||
ports := freeport.GetT(t, 1)
|
ports := freeport.GetT(t, 1)
|
||||||
return &TestService{
|
return &TestServer{
|
||||||
Service: service,
|
Service: service,
|
||||||
CA: ca,
|
CA: ca,
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
|
@ -104,14 +124,16 @@ func NewTestService(t testing.T, service string, ca *structs.CARoot) *TestServic
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve runs a TestService and blocks until it is closed or errors.
|
// Serve runs a TCP echo server and blocks until it is closed or errors. If
|
||||||
func (s *TestService) Serve() error {
|
// TimeoutHandshake is set, it won't start a TLS handshake on new connections.
|
||||||
|
func (s *TestServer) Serve() error {
|
||||||
// Just accept TCP conns here so we can control timing of accept/handshake
|
// Just accept TCP conns here so we can control timing of accept/handshake
|
||||||
l, err := net.Listen("tcp", s.Addr)
|
l, err := net.Listen("tcp", s.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.l = l
|
s.l = l
|
||||||
|
log.Printf("test connect service listening on %s", s.Addr)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := s.l.Accept()
|
conn, err := s.l.Accept()
|
||||||
|
@ -122,12 +144,14 @@ func (s *TestService) Serve() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore the conn if we are not actively ha
|
// Ignore the conn if we are not actively handshaking
|
||||||
if !s.TimeoutHandshake {
|
if !s.TimeoutHandshake {
|
||||||
// Upgrade conn to TLS
|
// Upgrade conn to TLS
|
||||||
conn = tls.Server(conn, s.TLSCfg)
|
conn = tls.Server(conn, s.TLSCfg)
|
||||||
|
|
||||||
// Run an echo service
|
// Run an echo service
|
||||||
|
log.Printf("test connect service accepted conn from %s, "+
|
||||||
|
" running echo service", conn.RemoteAddr())
|
||||||
go io.Copy(conn, conn)
|
go io.Copy(conn, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
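
Typical use from a test, as TestService_Dial above shows: run the echo server in a goroutine and always close it so the accept loop stops. A sketch, assuming ca is a test CA root:

    testSvr := NewTestServer(t, "db", ca)
    testSvr.TimeoutHandshake = false // complete TLS handshakes for this case
    go func() {
        _ = testSvr.Serve() // returns after Close
    }()
    defer testSvr.Close()
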
@ -141,8 +165,20 @@ func (s *TestService) Serve() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops a TestService
|
// ServeHTTPS runs an HTTPS server with the given config. It invokes the passed
|
||||||
func (s *TestService) Close() {
|
// Handler for all requests.
|
||||||
|
func (s *TestServer) ServeHTTPS(h http.Handler) error {
|
||||||
|
srv := http.Server{
|
||||||
|
Addr: s.Addr,
|
||||||
|
TLSConfig: s.TLSCfg,
|
||||||
|
Handler: h,
|
||||||
|
}
|
||||||
|
log.Printf("starting test connect HTTPS server on %s", s.Addr)
|
||||||
|
return srv.ListenAndServeTLS("", "")
|
||||||
|
}
|
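
ListenAndServeTLS can be given empty certificate and key paths because srv.TLSConfig already carries the test service's key pair. Usage mirrors the HTTP client test earlier:

    testSvr := NewTestServer(t, "backend", ca)
    defer testSvr.Close()
    go func() {
        _ = testSvr.ServeHTTPS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            w.Write([]byte("Hello, I am Backend"))
        }))
    }()
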
||||||
|
|
||||||
|
// Close stops a TestServer
|
||||||
|
func (s *TestServer) Close() error {
|
||||||
old := atomic.SwapInt32(&s.stopFlag, 1)
|
old := atomic.SwapInt32(&s.stopFlag, 1)
|
||||||
if old == 0 {
|
if old == 0 {
|
||||||
if s.l != nil {
|
if s.l != nil {
|
||||||
|
@ -150,4 +186,5 @@ func (s *TestService) Close() {
|
||||||
}
|
}
|
||||||
close(s.stopChan)
|
close(s.stopChan)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,27 +42,27 @@ func defaultTLSConfig(verify verifyFunc) *tls.Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReloadableTLSConfig exposes a tls.Config that can have its certificates
|
// reloadableTLSConfig exposes a tls.Config that can have its certificates
|
||||||
// reloaded. On a server, this uses GetConfigForClient to pass the current
|
// reloaded. On a server, this uses GetConfigForClient to pass the current
|
||||||
// tls.Config or client certificate for each accepted connection. On a client,
|
// tls.Config or client certificate for each accepted connection. On a client,
|
||||||
// this uses GetClientCertificate to provide the current client certificate.
|
// this uses GetClientCertificate to provide the current client certificate.
|
||||||
type ReloadableTLSConfig struct {
|
type reloadableTLSConfig struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
// cfg is the current config to use for new connections
|
// cfg is the current config to use for new connections
|
||||||
cfg *tls.Config
|
cfg *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReloadableTLSConfig returns a reloadable config currently set to base.
|
// newReloadableTLSConfig returns a reloadable config currently set to base.
|
||||||
func NewReloadableTLSConfig(base *tls.Config) *ReloadableTLSConfig {
|
func newReloadableTLSConfig(base *tls.Config) *reloadableTLSConfig {
|
||||||
c := &ReloadableTLSConfig{}
|
c := &reloadableTLSConfig{}
|
||||||
c.SetTLSConfig(base)
|
c.SetTLSConfig(base)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig returns a *tls.Config that will dynamically load certs. It's
|
// TLSConfig returns a *tls.Config that will dynamically load certs. It's
|
||||||
// suitable for use in either a client or server.
|
// suitable for use in either a client or server.
|
||||||
func (c *ReloadableTLSConfig) TLSConfig() *tls.Config {
|
func (c *reloadableTLSConfig) TLSConfig() *tls.Config {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
cfgCopy := c.cfg
|
cfgCopy := c.cfg
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
@ -71,7 +71,7 @@ func (c *ReloadableTLSConfig) TLSConfig() *tls.Config {
|
||||||
|
|
||||||
// SetTLSConfig sets the config used for future connections. It is safe to call
|
// SetTLSConfig sets the config used for future connections. It is safe to call
|
||||||
// from any goroutine.
|
// from any goroutine.
|
||||||
func (c *ReloadableTLSConfig) SetTLSConfig(cfg *tls.Config) error {
|
func (c *reloadableTLSConfig) SetTLSConfig(cfg *tls.Config) error {
|
||||||
copy := cfg.Clone()
|
copy := cfg.Clone()
|
||||||
copy.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
copy.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
current := c.TLSConfig()
|
current := c.TLSConfig()
|
||||||
|
|
|
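
Only the client-side GetClientCertificate hook is visible in this hunk; the server-side behaviour described in the comment above would look roughly like the following (a sketch, not necessarily the exact code elsewhere in SetTLSConfig):

    copy.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
        // Serve every newly accepted connection with whatever config is
        // current, so certificate reloads apply without restarting listeners.
        return c.TLSConfig(), nil
    }
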
@ -10,10 +10,9 @@ import (
|
||||||
|
|
||||||
func TestReloadableTLSConfig(t *testing.T) {
|
func TestReloadableTLSConfig(t *testing.T) {
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
verify, _ := testVerifier(t, nil)
|
base := defaultTLSConfig(nil)
|
||||||
base := defaultTLSConfig(verify)
|
|
||||||
|
|
||||||
c := NewReloadableTLSConfig(base)
|
c := newReloadableTLSConfig(base)
|
||||||
|
|
||||||
// The dynamic config should be the one we loaded (with some different hooks)
|
// The dynamic config should be the one we loaded (with some different hooks)
|
||||||
got := c.TLSConfig()
|
got := c.TLSConfig()
|
||||||
|
|