consul: Ensure Raft also uses TLS connections

This commit is contained in:
Armon Dadgar 2014-04-04 16:36:47 -07:00
parent 1ab9a4ad53
commit 7884439b7c
2 changed files with 26 additions and 8 deletions

View File

@ -1,6 +1,7 @@
package consul
import (
"crypto/tls"
"fmt"
"net"
"sync"
@ -16,6 +17,9 @@ type RaftLayer struct {
// connCh is used to accept connections
connCh chan net.Conn
// TLS configuration
tlsConfig *tls.Config
// Tracks if we are closed
closed bool
closeCh chan struct{}
@ -23,12 +27,14 @@ type RaftLayer struct {
}
// NewRaftLayer is used to initialize a new RaftLayer which can
// be used as a StreamLayer for Raft
func NewRaftLayer(addr net.Addr) *RaftLayer {
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
// then the connection will use TLS.
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
layer := &RaftLayer{
addr: addr,
connCh: make(chan net.Conn),
closeCh: make(chan struct{}),
addr: addr,
connCh: make(chan net.Conn),
tlsConfig: tlsConfig,
closeCh: make(chan struct{}),
}
return layer
}
@ -79,6 +85,18 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
return nil, err
}
// Check for tls mode
if l.tlsConfig != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
return nil, err
}
// Wrap the connection in a TLS client
conn = tls.Client(conn, l.tlsConfig)
}
// Write the Raft byte to set the mode
_, err = conn.Write([]byte{byte(rpcRaft)})
if err != nil {

View File

@ -160,7 +160,7 @@ func NewServer(config *Config) (*Server, error) {
}
// Initialize the RPC layer
if err := s.setupRPC(); err != nil {
if err := s.setupRPC(tlsConfig); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
}
@ -290,7 +290,7 @@ func (s *Server) setupRaft() error {
}
// setupRPC is used to setup the RPC listener
func (s *Server) setupRPC() error {
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
// Create endpoints
s.endpoints.Status = &Status{s}
s.endpoints.Raft = &Raft{s}
@ -329,7 +329,7 @@ func (s *Server) setupRPC() error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
}
s.raftLayer = NewRaftLayer(advertise)
s.raftLayer = NewRaftLayer(advertise, tlsConfig)
go s.listen()
return nil
}