mirror of https://github.com/status-im/consul.git
consul: use tlsutil.Wrapper instead of tls.Config directly
This commit is contained in:
parent
6b2390833d
commit
92e5548b23
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -13,6 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/consul/structs"
|
"github.com/hashicorp/consul/consul/structs"
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
"github.com/hashicorp/serf/serf"
|
"github.com/hashicorp/serf/serf"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -98,13 +100,18 @@ func NewClient(config *Config) (*Client, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Define a TLS wrapper
|
||||||
|
tlsWrap := func(c net.Conn) (net.Conn, error) {
|
||||||
|
return tlsutil.WrapTLSClient(c, tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a logger
|
// Create a logger
|
||||||
logger := log.New(config.LogOutput, "", log.LstdFlags)
|
logger := log.New(config.LogOutput, "", log.LstdFlags)
|
||||||
|
|
||||||
// Create server
|
// Create server
|
||||||
c := &Client{
|
c := &Client{
|
||||||
config: config,
|
config: config,
|
||||||
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig),
|
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
|
||||||
eventCh: make(chan serf.Event, 256),
|
eventCh: make(chan serf.Event, 256),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
|
|
|
@ -2,7 +2,6 @@ package consul
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -135,8 +134,8 @@ type ConnPool struct {
|
||||||
// Pool maps an address to a open connection
|
// Pool maps an address to a open connection
|
||||||
pool map[string]*Conn
|
pool map[string]*Conn
|
||||||
|
|
||||||
// TLS settings
|
// TLS wrapper
|
||||||
tlsConfig *tls.Config
|
tlsWrap tlsutil.Wrapper
|
||||||
|
|
||||||
// Used to indicate the pool is shutdown
|
// Used to indicate the pool is shutdown
|
||||||
shutdown bool
|
shutdown bool
|
||||||
|
@ -148,13 +147,13 @@ type ConnPool struct {
|
||||||
// Set maxTime to 0 to disable reaping. maxStreams is used to control
|
// Set maxTime to 0 to disable reaping. maxStreams is used to control
|
||||||
// the number of idle streams allowed.
|
// the number of idle streams allowed.
|
||||||
// If TLS settings are provided outgoing connections use TLS.
|
// If TLS settings are provided outgoing connections use TLS.
|
||||||
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool {
|
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool {
|
||||||
pool := &ConnPool{
|
pool := &ConnPool{
|
||||||
logOutput: logOutput,
|
logOutput: logOutput,
|
||||||
maxTime: maxTime,
|
maxTime: maxTime,
|
||||||
maxStreams: maxStreams,
|
maxStreams: maxStreams,
|
||||||
pool: make(map[string]*Conn),
|
pool: make(map[string]*Conn),
|
||||||
tlsConfig: tlsConfig,
|
tlsWrap: tlsWrap,
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
if maxTime > 0 {
|
if maxTime > 0 {
|
||||||
|
@ -220,7 +219,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if TLS is enabled
|
// Check if TLS is enabled
|
||||||
if p.tlsConfig != nil {
|
if p.tlsWrap != nil {
|
||||||
// Switch the connection into TLS mode
|
// Switch the connection into TLS mode
|
||||||
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
@ -228,7 +227,7 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the connection in a TLS client
|
// Wrap the connection in a TLS client
|
||||||
tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig)
|
tlsConn, err := p.tlsWrap(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
package consul
|
package consul
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RaftLayer implements the raft.StreamLayer interface,
|
// RaftLayer implements the raft.StreamLayer interface,
|
||||||
|
@ -18,8 +18,8 @@ type RaftLayer struct {
|
||||||
// connCh is used to accept connections
|
// connCh is used to accept connections
|
||||||
connCh chan net.Conn
|
connCh chan net.Conn
|
||||||
|
|
||||||
// TLS configuration
|
// TLS wrapper
|
||||||
tlsConfig *tls.Config
|
tlsWrap tlsutil.Wrapper
|
||||||
|
|
||||||
// Tracks if we are closed
|
// Tracks if we are closed
|
||||||
closed bool
|
closed bool
|
||||||
|
@ -30,11 +30,11 @@ type RaftLayer struct {
|
||||||
// NewRaftLayer is used to initialize a new RaftLayer which can
|
// NewRaftLayer is used to initialize a new RaftLayer which can
|
||||||
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
|
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
|
||||||
// then the connection will use TLS.
|
// then the connection will use TLS.
|
||||||
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
|
func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
|
||||||
layer := &RaftLayer{
|
layer := &RaftLayer{
|
||||||
addr: addr,
|
addr: addr,
|
||||||
connCh: make(chan net.Conn),
|
connCh: make(chan net.Conn),
|
||||||
tlsConfig: tlsConfig,
|
tlsWrap: tlsWrap,
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
return layer
|
return layer
|
||||||
|
@ -87,7 +87,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for tls mode
|
// Check for tls mode
|
||||||
if l.tlsConfig != nil {
|
if l.tlsWrap != nil {
|
||||||
// Switch the connection into TLS mode
|
// Switch the connection into TLS mode
|
||||||
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
@ -95,7 +95,7 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the connection in a TLS client
|
// Wrap the connection in a TLS client
|
||||||
conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig)
|
conn, err = l.tlsWrap(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/acl"
|
"github.com/hashicorp/consul/acl"
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
"github.com/hashicorp/golang-lru"
|
"github.com/hashicorp/golang-lru"
|
||||||
"github.com/hashicorp/raft"
|
"github.com/hashicorp/raft"
|
||||||
"github.com/hashicorp/raft-boltdb"
|
"github.com/hashicorp/raft-boltdb"
|
||||||
|
@ -189,6 +190,11 @@ func NewServer(config *Config) (*Server, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Define a TLS wrapper
|
||||||
|
tlsWrap := func(c net.Conn) (net.Conn, error) {
|
||||||
|
return tlsutil.WrapTLSClient(c, tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
// Get the incoming tls config
|
// Get the incoming tls config
|
||||||
incomingTLS, err := tlsConf.IncomingTLSConfig()
|
incomingTLS, err := tlsConf.IncomingTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -207,7 +213,7 @@ func NewServer(config *Config) (*Server, error) {
|
||||||
// Create server
|
// Create server
|
||||||
s := &Server{
|
s := &Server{
|
||||||
config: config,
|
config: config,
|
||||||
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig),
|
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
|
||||||
eventChLAN: make(chan serf.Event, 256),
|
eventChLAN: make(chan serf.Event, 256),
|
||||||
eventChWAN: make(chan serf.Event, 256),
|
eventChWAN: make(chan serf.Event, 256),
|
||||||
localConsuls: make(map[string]*serverParts),
|
localConsuls: make(map[string]*serverParts),
|
||||||
|
@ -242,7 +248,7 @@ func NewServer(config *Config) (*Server, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the RPC layer
|
// Initialize the RPC layer
|
||||||
if err := s.setupRPC(tlsConfig); err != nil {
|
if err := s.setupRPC(tlsWrap); err != nil {
|
||||||
s.Shutdown()
|
s.Shutdown()
|
||||||
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
|
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -410,7 +416,7 @@ func (s *Server) setupRaft() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupRPC is used to setup the RPC listener
|
// setupRPC is used to setup the RPC listener
|
||||||
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
|
func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error {
|
||||||
// Create endpoints
|
// Create endpoints
|
||||||
s.endpoints.Status = &Status{s}
|
s.endpoints.Status = &Status{s}
|
||||||
s.endpoints.Catalog = &Catalog{s}
|
s.endpoints.Catalog = &Catalog{s}
|
||||||
|
@ -453,7 +459,7 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error {
|
||||||
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
|
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.raftLayer = NewRaftLayer(advertise, tlsConfig)
|
s.raftLayer = NewRaftLayer(advertise, tlsWrap)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue