mirror of https://github.com/status-im/consul.git
consul: use tlsutil.Wrapper instead of tls.Config directly
This commit is contained in:
parent
6b2390833d
commit
92e5548b23
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
)
|
||||
|
||||
|
@ -98,13 +100,18 @@ func NewClient(config *Config) (*Client, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Define a TLS wrapper
|
||||
tlsWrap := func(c net.Conn) (net.Conn, error) {
|
||||
return tlsutil.WrapTLSClient(c, tlsConfig)
|
||||
}
|
||||
|
||||
// Create a logger
|
||||
logger := log.New(config.LogOutput, "", log.LstdFlags)
|
||||
|
||||
// Create server
|
||||
c := &Client{
|
||||
config: config,
|
||||
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig),
|
||||
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
|
||||
eventCh: make(chan serf.Event, 256),
|
||||
logger: logger,
|
||||
shutdownCh: make(chan struct{}),
|
||||
|
|
|
@ -2,7 +2,6 @@ package consul
|
|||
|
||||
import (
|
||||
"container/list"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -135,8 +134,8 @@ type ConnPool struct {
|
|||
// Pool maps an address to a open connection
|
||||
pool map[string]*Conn
|
||||
|
||||
// TLS settings
|
||||
tlsConfig *tls.Config
|
||||
// TLS wrapper
|
||||
tlsWrap tlsutil.Wrapper
|
||||
|
||||
// Used to indicate the pool is shutdown
|
||||
shutdown bool
|
||||
|
@ -148,13 +147,13 @@ type ConnPool struct {
|
|||
// Set maxTime to 0 to disable reaping. maxStreams is used to control
|
||||
// the number of idle streams allowed.
|
||||
// If TLS settings are provided outgoing connections use TLS.
|
||||
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool {
|
||||
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool {
|
||||
pool := &ConnPool{
|
||||
logOutput: logOutput,
|
||||
maxTime: maxTime,
|
||||
maxStreams: maxStreams,
|
||||
pool: make(map[string]*Conn),
|
||||
tlsConfig: tlsConfig,
|
||||
tlsWrap: tlsWrap,
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
if maxTime > 0 {
|
||||
|
@ -220,7 +219,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
|
|||
}
|
||||
|
||||
// Check if TLS is enabled
|
||||
if p.tlsConfig != nil {
|
||||
if p.tlsWrap != nil {
|
||||
// Switch the connection into TLS mode
|
||||
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
||||
conn.Close()
|
||||
|
@ -228,7 +227,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
|
|||
}
|
||||
|
||||
// Wrap the connection in a TLS client
|
||||
tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig)
|
||||
tlsConn, err := p.tlsWrap(conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
)
|
||||
|
||||
// RaftLayer implements the raft.StreamLayer interface,
|
||||
|
@ -18,8 +18,8 @@ type RaftLayer struct {
|
|||
// connCh is used to accept connections
|
||||
connCh chan net.Conn
|
||||
|
||||
// TLS configuration
|
||||
tlsConfig *tls.Config
|
||||
// TLS wrapper
|
||||
tlsWrap tlsutil.Wrapper
|
||||
|
||||
// Tracks if we are closed
|
||||
closed bool
|
||||
|
@ -30,11 +30,11 @@ type RaftLayer struct {
|
|||
// NewRaftLayer is used to initialize a new RaftLayer which can
|
||||
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
|
||||
// then the connection will use TLS.
|
||||
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
|
||||
func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
|
||||
layer := &RaftLayer{
|
||||
addr: addr,
|
||||
connCh: make(chan net.Conn),
|
||||
tlsConfig: tlsConfig,
|
||||
tlsWrap: tlsWrap,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
return layer
|
||||
|
@ -87,7 +87,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
|
|||
}
|
||||
|
||||
// Check for tls mode
|
||||
if l.tlsConfig != nil {
|
||||
if l.tlsWrap != nil {
|
||||
// Switch the connection into TLS mode
|
||||
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
||||
conn.Close()
|
||||
|
@ -95,7 +95,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
|
|||
}
|
||||
|
||||
// Wrap the connection in a TLS client
|
||||
conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig)
|
||||
conn, err = l.tlsWrap(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/hashicorp/raft"
|
||||
"github.com/hashicorp/raft-boltdb"
|
||||
|
@ -189,6 +190,11 @@ func NewServer(config *Config) (*Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Define a TLS wrapper
|
||||
tlsWrap := func(c net.Conn) (net.Conn, error) {
|
||||
return tlsutil.WrapTLSClient(c, tlsConfig)
|
||||
}
|
||||
|
||||
// Get the incoming tls config
|
||||
incomingTLS, err := tlsConf.IncomingTLSConfig()
|
||||
if err != nil {
|
||||
|
@ -207,7 +213,7 @@ func NewServer(config *Config) (*Server, error) {
|
|||
// Create server
|
||||
s := &Server{
|
||||
config: config,
|
||||
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig),
|
||||
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
|
||||
eventChLAN: make(chan serf.Event, 256),
|
||||
eventChWAN: make(chan serf.Event, 256),
|
||||
localConsuls: make(map[string]*serverParts),
|
||||
|
@ -242,7 +248,7 @@ func NewServer(config *Config) (*Server, error) {
|
|||
}
|
||||
|
||||
// Initialize the RPC layer
|
||||
if err := s.setupRPC(tlsConfig); err != nil {
|
||||
if err := s.setupRPC(tlsWrap); err != nil {
|
||||
s.Shutdown()
|
||||
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
|
||||
}
|
||||
|
@ -410,7 +416,7 @@ func (s *Server) setupRaft() error {
|
|||
}
|
||||
|
||||
// setupRPC is used to setup the RPC listener
|
||||
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
|
||||
func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error {
|
||||
// Create endpoints
|
||||
s.endpoints.Status = &Status{s}
|
||||
s.endpoints.Catalog = &Catalog{s}
|
||||
|
@ -453,7 +459,7 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error {
|
|||
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
|
||||
}
|
||||
|
||||
s.raftLayer = NewRaftLayer(advertise, tlsConfig)
|
||||
s.raftLayer = NewRaftLayer(advertise, tlsWrap)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue