consul: use tlsutil.Wrapper instead of tls.Config directly

This commit is contained in:
Armon Dadgar 2015-05-08 15:57:37 -07:00
parent 6b2390833d
commit 92e5548b23
4 changed files with 35 additions and 23 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
"math/rand" "math/rand"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@ -13,6 +14,7 @@ import (
"time" "time"
"github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/serf/serf" "github.com/hashicorp/serf/serf"
) )
@ -98,13 +100,18 @@ func NewClient(config *Config) (*Client, error) {
return nil, err return nil, err
} }
// Define a TLS wrapper
tlsWrap := func(c net.Conn) (net.Conn, error) {
return tlsutil.WrapTLSClient(c, tlsConfig)
}
// Create a logger // Create a logger
logger := log.New(config.LogOutput, "", log.LstdFlags) logger := log.New(config.LogOutput, "", log.LstdFlags)
// Create server // Create server
c := &Client{ c := &Client{
config: config, config: config,
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig), connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
eventCh: make(chan serf.Event, 256), eventCh: make(chan serf.Event, 256),
logger: logger, logger: logger,
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),

View File

@ -2,7 +2,6 @@ package consul
import ( import (
"container/list" "container/list"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -135,8 +134,8 @@ type ConnPool struct {
// Pool maps an address to a open connection // Pool maps an address to a open connection
pool map[string]*Conn pool map[string]*Conn
// TLS settings // TLS wrapper
tlsConfig *tls.Config tlsWrap tlsutil.Wrapper
// Used to indicate the pool is shutdown // Used to indicate the pool is shutdown
shutdown bool shutdown bool
@ -148,13 +147,13 @@ type ConnPool struct {
// Set maxTime to 0 to disable reaping. maxStreams is used to control // Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed. // the number of idle streams allowed.
// If TLS settings are provided outgoing connections use TLS. // If TLS settings are provided outgoing connections use TLS.
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool { func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool {
pool := &ConnPool{ pool := &ConnPool{
logOutput: logOutput, logOutput: logOutput,
maxTime: maxTime, maxTime: maxTime,
maxStreams: maxStreams, maxStreams: maxStreams,
pool: make(map[string]*Conn), pool: make(map[string]*Conn),
tlsConfig: tlsConfig, tlsWrap: tlsWrap,
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
} }
if maxTime > 0 { if maxTime > 0 {
@ -220,7 +219,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
} }
// Check if TLS is enabled // Check if TLS is enabled
if p.tlsConfig != nil { if p.tlsWrap != nil {
// Switch the connection into TLS mode // Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close() conn.Close()
@ -228,7 +227,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
} }
// Wrap the connection in a TLS client // Wrap the connection in a TLS client
tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig) tlsConn, err := p.tlsWrap(conn)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err

View File

@ -1,12 +1,12 @@
package consul package consul
import ( import (
"crypto/tls"
"fmt" "fmt"
"github.com/hashicorp/consul/tlsutil"
"net" "net"
"sync" "sync"
"time" "time"
"github.com/hashicorp/consul/tlsutil"
) )
// RaftLayer implements the raft.StreamLayer interface, // RaftLayer implements the raft.StreamLayer interface,
@ -18,8 +18,8 @@ type RaftLayer struct {
// connCh is used to accept connections // connCh is used to accept connections
connCh chan net.Conn connCh chan net.Conn
// TLS configuration // TLS wrapper
tlsConfig *tls.Config tlsWrap tlsutil.Wrapper
// Tracks if we are closed // Tracks if we are closed
closed bool closed bool
@ -30,11 +30,11 @@ type RaftLayer struct {
// NewRaftLayer is used to initialize a new RaftLayer which can // NewRaftLayer is used to initialize a new RaftLayer which can
// be used as a StreamLayer for Raft. If a tlsConfig is provided, // be used as a StreamLayer for Raft. If a tlsConfig is provided,
// then the connection will use TLS. // then the connection will use TLS.
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer { func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
layer := &RaftLayer{ layer := &RaftLayer{
addr: addr, addr: addr,
connCh: make(chan net.Conn), connCh: make(chan net.Conn),
tlsConfig: tlsConfig, tlsWrap: tlsWrap,
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
} }
return layer return layer
@ -87,7 +87,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
} }
// Check for tls mode // Check for tls mode
if l.tlsConfig != nil { if l.tlsWrap != nil {
// Switch the connection into TLS mode // Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close() conn.Close()
@ -95,7 +95,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
} }
// Wrap the connection in a TLS client // Wrap the connection in a TLS client
conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig) conn, err = l.tlsWrap(conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/golang-lru" "github.com/hashicorp/golang-lru"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"github.com/hashicorp/raft-boltdb" "github.com/hashicorp/raft-boltdb"
@ -189,6 +190,11 @@ func NewServer(config *Config) (*Server, error) {
return nil, err return nil, err
} }
// Define a TLS wrapper
tlsWrap := func(c net.Conn) (net.Conn, error) {
return tlsutil.WrapTLSClient(c, tlsConfig)
}
// Get the incoming tls config // Get the incoming tls config
incomingTLS, err := tlsConf.IncomingTLSConfig() incomingTLS, err := tlsConf.IncomingTLSConfig()
if err != nil { if err != nil {
@ -207,7 +213,7 @@ func NewServer(config *Config) (*Server, error) {
// Create server // Create server
s := &Server{ s := &Server{
config: config, config: config,
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig), connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
eventChLAN: make(chan serf.Event, 256), eventChLAN: make(chan serf.Event, 256),
eventChWAN: make(chan serf.Event, 256), eventChWAN: make(chan serf.Event, 256),
localConsuls: make(map[string]*serverParts), localConsuls: make(map[string]*serverParts),
@ -242,7 +248,7 @@ func NewServer(config *Config) (*Server, error) {
} }
// Initialize the RPC layer // Initialize the RPC layer
if err := s.setupRPC(tlsConfig); err != nil { if err := s.setupRPC(tlsWrap); err != nil {
s.Shutdown() s.Shutdown()
return nil, fmt.Errorf("Failed to start RPC layer: %v", err) return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
} }
@ -410,7 +416,7 @@ func (s *Server) setupRaft() error {
} }
// setupRPC is used to setup the RPC listener // setupRPC is used to setup the RPC listener
func (s *Server) setupRPC(tlsConfig *tls.Config) error { func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error {
// Create endpoints // Create endpoints
s.endpoints.Status = &Status{s} s.endpoints.Status = &Status{s}
s.endpoints.Catalog = &Catalog{s} s.endpoints.Catalog = &Catalog{s}
@ -453,7 +459,7 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr) return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
} }
s.raftLayer = NewRaftLayer(advertise, tlsConfig) s.raftLayer = NewRaftLayer(advertise, tlsWrap)
return nil return nil
} }