diff --git a/consul/client.go b/consul/client.go index c37ace6d73..3459274a32 100644 --- a/consul/client.go +++ b/consul/client.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "math/rand" + "net" "os" "path/filepath" "strconv" @@ -13,6 +14,7 @@ import ( "time" "github.com/hashicorp/consul/consul/structs" + "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/serf/serf" ) @@ -98,13 +100,18 @@ func NewClient(config *Config) (*Client, error) { return nil, err } + // Define a TLS wrapper + tlsWrap := func(c net.Conn) (net.Conn, error) { + return tlsutil.WrapTLSClient(c, tlsConfig) + } + // Create a logger logger := log.New(config.LogOutput, "", log.LstdFlags) // Create server c := &Client{ config: config, - connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig), + connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap), eventCh: make(chan serf.Event, 256), logger: logger, shutdownCh: make(chan struct{}), diff --git a/consul/pool.go b/consul/pool.go index 89d0654f8e..53f546a532 100644 --- a/consul/pool.go +++ b/consul/pool.go @@ -2,7 +2,6 @@ package consul import ( "container/list" - "crypto/tls" "fmt" "io" "net" @@ -135,8 +134,8 @@ type ConnPool struct { // Pool maps an address to a open connection pool map[string]*Conn - // TLS settings - tlsConfig *tls.Config + // TLS wrapper + tlsWrap tlsutil.Wrapper // Used to indicate the pool is shutdown shutdown bool @@ -148,13 +147,13 @@ type ConnPool struct { // Set maxTime to 0 to disable reaping. maxStreams is used to control // the number of idle streams allowed. // If TLS settings are provided outgoing connections use TLS. -func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool { +func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool { pool := &ConnPool{ logOutput: logOutput, maxTime: maxTime, maxStreams: maxStreams, pool: make(map[string]*Conn), - tlsConfig: tlsConfig, + tlsWrap: tlsWrap, shutdownCh: make(chan struct{}), } if maxTime > 0 { @@ -220,7 +219,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) { } // Check if TLS is enabled - if p.tlsConfig != nil { + if p.tlsWrap != nil { // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { conn.Close() @@ -228,7 +227,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) { } // Wrap the connection in a TLS client - tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig) + tlsConn, err := p.tlsWrap(conn) if err != nil { conn.Close() return nil, err diff --git a/consul/raft_rpc.go b/consul/raft_rpc.go index e0ee4c68e6..545895e195 100644 --- a/consul/raft_rpc.go +++ b/consul/raft_rpc.go @@ -1,12 +1,12 @@ package consul import ( - "crypto/tls" "fmt" - "github.com/hashicorp/consul/tlsutil" "net" "sync" "time" + + "github.com/hashicorp/consul/tlsutil" ) // RaftLayer implements the raft.StreamLayer interface, @@ -18,8 +18,8 @@ type RaftLayer struct { // connCh is used to accept connections connCh chan net.Conn - // TLS configuration - tlsConfig *tls.Config + // TLS wrapper + tlsWrap tlsutil.Wrapper // Tracks if we are closed closed bool @@ -30,12 +30,12 @@ type RaftLayer struct { // NewRaftLayer is used to initialize a new RaftLayer which can // be used as a StreamLayer for Raft. If a tlsConfig is provided, // then the connection will use TLS. -func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer { +func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer { layer := &RaftLayer{ - addr: addr, - connCh: make(chan net.Conn), - tlsConfig: tlsConfig, - closeCh: make(chan struct{}), + addr: addr, + connCh: make(chan net.Conn), + tlsWrap: tlsWrap, + closeCh: make(chan struct{}), } return layer } @@ -87,7 +87,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error } // Check for tls mode - if l.tlsConfig != nil { + if l.tlsWrap != nil { // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { conn.Close() @@ -95,7 +95,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error } // Wrap the connection in a TLS client - conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig) + conn, err = l.tlsWrap(conn) if err != nil { return nil, err } diff --git a/consul/server.go b/consul/server.go index 0fe981113a..47edbf9182 100644 --- a/consul/server.go +++ b/consul/server.go @@ -15,6 +15,7 @@ import ( "time" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/golang-lru" "github.com/hashicorp/raft" "github.com/hashicorp/raft-boltdb" @@ -189,6 +190,11 @@ func NewServer(config *Config) (*Server, error) { return nil, err } + // Define a TLS wrapper + tlsWrap := func(c net.Conn) (net.Conn, error) { + return tlsutil.WrapTLSClient(c, tlsConfig) + } + // Get the incoming tls config incomingTLS, err := tlsConf.IncomingTLSConfig() if err != nil { @@ -207,7 +213,7 @@ func NewServer(config *Config) (*Server, error) { // Create server s := &Server{ config: config, - connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig), + connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), eventChLAN: make(chan serf.Event, 256), eventChWAN: make(chan serf.Event, 256), localConsuls: make(map[string]*serverParts), @@ -242,7 +248,7 @@ func NewServer(config *Config) (*Server, error) { } // Initialize the RPC layer - if err := s.setupRPC(tlsConfig); err != nil { + if err := s.setupRPC(tlsWrap); err != nil { s.Shutdown() return nil, fmt.Errorf("Failed to start RPC layer: %v", err) } @@ -410,7 +416,7 @@ func (s *Server) setupRaft() error { } // setupRPC is used to setup the RPC listener -func (s *Server) setupRPC(tlsConfig *tls.Config) error { +func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error { // Create endpoints s.endpoints.Status = &Status{s} s.endpoints.Catalog = &Catalog{s} @@ -453,7 +459,7 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error { return fmt.Errorf("RPC advertise address is not advertisable: %v", addr) } - s.raftLayer = NewRaftLayer(advertise, tlsConfig) + s.raftLayer = NewRaftLayer(advertise, tlsWrap) return nil }