mirror of
https://github.com/status-im/consul.git
synced 2025-01-10 22:06:20 +00:00
consul: Ensure Raft also uses TLS connections
This commit is contained in:
parent
1ab9a4ad53
commit
7884439b7c
@ -1,6 +1,7 @@
|
||||
package consul
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@ -16,6 +17,9 @@ type RaftLayer struct {
|
||||
// connCh is used to accept connections
|
||||
connCh chan net.Conn
|
||||
|
||||
// TLS configuration
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// Tracks if we are closed
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
@ -23,12 +27,14 @@ type RaftLayer struct {
|
||||
}
|
||||
|
||||
// NewRaftLayer is used to initialize a new RaftLayer which can
|
||||
// be used as a StreamLayer for Raft
|
||||
func NewRaftLayer(addr net.Addr) *RaftLayer {
|
||||
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
|
||||
// then the connection will use TLS.
|
||||
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
|
||||
layer := &RaftLayer{
|
||||
addr: addr,
|
||||
connCh: make(chan net.Conn),
|
||||
closeCh: make(chan struct{}),
|
||||
addr: addr,
|
||||
connCh: make(chan net.Conn),
|
||||
tlsConfig: tlsConfig,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
return layer
|
||||
}
|
||||
@ -79,6 +85,18 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for tls mode
|
||||
if l.tlsConfig != nil {
|
||||
// Switch the connection into TLS mode
|
||||
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wrap the connection in a TLS client
|
||||
conn = tls.Client(conn, l.tlsConfig)
|
||||
}
|
||||
|
||||
// Write the Raft byte to set the mode
|
||||
_, err = conn.Write([]byte{byte(rpcRaft)})
|
||||
if err != nil {
|
||||
|
@ -160,7 +160,7 @@ func NewServer(config *Config) (*Server, error) {
|
||||
}
|
||||
|
||||
// Initialize the RPC layer
|
||||
if err := s.setupRPC(); err != nil {
|
||||
if err := s.setupRPC(tlsConfig); err != nil {
|
||||
s.Shutdown()
|
||||
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
|
||||
}
|
||||
@ -290,7 +290,7 @@ func (s *Server) setupRaft() error {
|
||||
}
|
||||
|
||||
// setupRPC is used to setup the RPC listener
|
||||
func (s *Server) setupRPC() error {
|
||||
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
|
||||
// Create endpoints
|
||||
s.endpoints.Status = &Status{s}
|
||||
s.endpoints.Raft = &Raft{s}
|
||||
@ -329,7 +329,7 @@ func (s *Server) setupRPC() error {
|
||||
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
|
||||
}
|
||||
|
||||
s.raftLayer = NewRaftLayer(advertise)
|
||||
s.raftLayer = NewRaftLayer(advertise, tlsConfig)
|
||||
go s.listen()
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user