consul: Connection pool supports TLS mode

This commit is contained in:
Armon Dadgar 2014-04-04 16:27:56 -07:00
parent c593632d3c
commit 7beac2a1bc
1 changed files with 28 additions and 8 deletions

View File

@ -1,6 +1,7 @@
package consul package consul
import ( import (
"crypto/tls"
"fmt" "fmt"
"github.com/inconshreveable/muxado" "github.com/inconshreveable/muxado"
"github.com/ugorji/go/codec" "github.com/ugorji/go/codec"
@ -37,6 +38,9 @@ type ConnPool struct {
// Pool maps an address to a open connection // Pool maps an address to a open connection
pool map[string]*Conn pool map[string]*Conn
// TLS settings
tlsConfig *tls.Config
// Used to indicate the pool is shutdown // Used to indicate the pool is shutdown
shutdown bool shutdown bool
shutdownCh chan struct{} shutdownCh chan struct{}
@ -44,11 +48,13 @@ type ConnPool struct {
// NewPool is used to make a new connection pool // NewPool is used to make a new connection pool
// Maintain at most one connection per host, for up to maxTime. // Maintain at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping. // Set maxTime to 0 to disable reaping. If TLS settings are provided
func NewPool(maxTime time.Duration) *ConnPool { // outgoing connections use TLS.
func NewPool(maxTime time.Duration, tlsConfig *tls.Config) *ConnPool {
pool := &ConnPool{ pool := &ConnPool{
maxTime: maxTime, maxTime: maxTime,
pool: make(map[string]*Conn), pool: make(map[string]*Conn),
tlsConfig: tlsConfig,
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
} }
if maxTime > 0 { if maxTime > 0 {
@ -104,20 +110,34 @@ func (p *ConnPool) getPooled(addr net.Addr) *Conn {
// getNewConn is used to return a new connection // getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(addr net.Addr) (*Conn, error) { func (p *ConnPool) getNewConn(addr net.Addr) (*Conn, error) {
// Try to dial the conn // Try to dial the conn
rawConn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second) conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Cast to TCPConn // Cast to TCPConn
conn := rawConn.(*net.TCPConn) if tcp, ok := conn.(*net.TCPConn); ok {
tcp.SetKeepAlive(true)
tcp.SetNoDelay(true)
}
// Enable keep alives // Check if TLS is enabled
conn.SetKeepAlive(true) if p.tlsConfig != nil {
conn.SetNoDelay(true) // Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
return nil, err
}
// Wrap the connection in a TLS client
conn = tls.Client(conn, p.tlsConfig)
}
// Write the Consul multiplex byte to set the mode // Write the Consul multiplex byte to set the mode
conn.Write([]byte{byte(rpcMultiplex)}) if _, err := conn.Write([]byte{byte(rpcMultiplex)}); err != nil {
conn.Close()
return nil, err
}
// Create a multiplexed session // Create a multiplexed session
session := muxado.Client(conn) session := muxado.Client(conn)