consul: use tlsutil.Wrapper instead of tls.Config directly

This commit is contained in:
Armon Dadgar 2015-05-08 15:57:37 -07:00
parent 6b2390833d
commit 92e5548b23
4 changed files with 35 additions and 23 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
"log"
"math/rand"
"net"
"os"
"path/filepath"
"strconv"
@ -13,6 +14,7 @@ import (
"time"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/serf/serf"
)
@ -98,13 +100,18 @@ func NewClient(config *Config) (*Client, error) {
return nil, err
}
// Define a TLS wrapper
tlsWrap := func(c net.Conn) (net.Conn, error) {
return tlsutil.WrapTLSClient(c, tlsConfig)
}
// Create a logger
logger := log.New(config.LogOutput, "", log.LstdFlags)
// Create server
c := &Client{
config: config,
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig),
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
eventCh: make(chan serf.Event, 256),
logger: logger,
shutdownCh: make(chan struct{}),

View File

@ -2,7 +2,6 @@ package consul
import (
"container/list"
"crypto/tls"
"fmt"
"io"
"net"
@ -135,8 +134,8 @@ type ConnPool struct {
// Pool maps an address to a open connection
pool map[string]*Conn
// TLS settings
tlsConfig *tls.Config
// TLS wrapper
tlsWrap tlsutil.Wrapper
// Used to indicate the pool is shutdown
shutdown bool
@ -148,13 +147,13 @@ type ConnPool struct {
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided outgoing connections use TLS.
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool {
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool {
pool := &ConnPool{
logOutput: logOutput,
maxTime: maxTime,
maxStreams: maxStreams,
pool: make(map[string]*Conn),
tlsConfig: tlsConfig,
tlsWrap: tlsWrap,
shutdownCh: make(chan struct{}),
}
if maxTime > 0 {
@ -220,7 +219,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
}
// Check if TLS is enabled
if p.tlsConfig != nil {
if p.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
@ -228,7 +227,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
}
// Wrap the connection in a TLS client
tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig)
tlsConn, err := p.tlsWrap(conn)
if err != nil {
conn.Close()
return nil, err

View File

@ -1,12 +1,12 @@
package consul
import (
"crypto/tls"
"fmt"
"github.com/hashicorp/consul/tlsutil"
"net"
"sync"
"time"
"github.com/hashicorp/consul/tlsutil"
)
// RaftLayer implements the raft.StreamLayer interface,
@ -18,8 +18,8 @@ type RaftLayer struct {
// connCh is used to accept connections
connCh chan net.Conn
// TLS configuration
tlsConfig *tls.Config
// TLS wrapper
tlsWrap tlsutil.Wrapper
// Tracks if we are closed
closed bool
@ -30,12 +30,12 @@ type RaftLayer struct {
// NewRaftLayer is used to initialize a new RaftLayer which can
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
// then the connection will use TLS.
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
layer := &RaftLayer{
addr: addr,
connCh: make(chan net.Conn),
tlsConfig: tlsConfig,
closeCh: make(chan struct{}),
addr: addr,
connCh: make(chan net.Conn),
tlsWrap: tlsWrap,
closeCh: make(chan struct{}),
}
return layer
}
@ -87,7 +87,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
}
// Check for tls mode
if l.tlsConfig != nil {
if l.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
@ -95,7 +95,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
}
// Wrap the connection in a TLS client
conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig)
conn, err = l.tlsWrap(conn)
if err != nil {
return nil, err
}

View File

@ -15,6 +15,7 @@ import (
"time"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/raft"
"github.com/hashicorp/raft-boltdb"
@ -189,6 +190,11 @@ func NewServer(config *Config) (*Server, error) {
return nil, err
}
// Define a TLS wrapper
tlsWrap := func(c net.Conn) (net.Conn, error) {
return tlsutil.WrapTLSClient(c, tlsConfig)
}
// Get the incoming tls config
incomingTLS, err := tlsConf.IncomingTLSConfig()
if err != nil {
@ -207,7 +213,7 @@ func NewServer(config *Config) (*Server, error) {
// Create server
s := &Server{
config: config,
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig),
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
eventChLAN: make(chan serf.Event, 256),
eventChWAN: make(chan serf.Event, 256),
localConsuls: make(map[string]*serverParts),
@ -242,7 +248,7 @@ func NewServer(config *Config) (*Server, error) {
}
// Initialize the RPC layer
if err := s.setupRPC(tlsConfig); err != nil {
if err := s.setupRPC(tlsWrap); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
}
@ -410,7 +416,7 @@ func (s *Server) setupRaft() error {
}
// setupRPC is used to setup the RPC listener
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error {
// Create endpoints
s.endpoints.Status = &Status{s}
s.endpoints.Catalog = &Catalog{s}
@ -453,7 +459,7 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
}
s.raftLayer = NewRaftLayer(advertise, tlsConfig)
s.raftLayer = NewRaftLayer(advertise, tlsWrap)
return nil
}