consul: Simplify TLS handling in RPC server

This commit is contained in:
Armon Dadgar 2014-04-07 12:45:33 -07:00
parent 592953309e
commit 39a55953af
1 changed files with 12 additions and 15 deletions

View File

@ -45,14 +45,14 @@ func (s *Server) listen() {
s.rpcClients[conn] = struct{}{} s.rpcClients[conn] = struct{}{}
s.rpcClientLock.Unlock() s.rpcClientLock.Unlock()
go s.handleConn(conn) go s.handleConn(conn, false)
metrics.IncrCounter([]string{"consul", "rpc", "accept_conn"}, 1) metrics.IncrCounter([]string{"consul", "rpc", "accept_conn"}, 1)
} }
} }
// handleConn is used to determine if this is a Raft or // handleConn is used to determine if this is a Raft or
// Consul type RPC connection and invoke the correct handler // Consul type RPC connection and invoke the correct handler
func (s *Server) handleConn(conn net.Conn) { func (s *Server) handleConn(conn net.Conn, isTLS bool) {
// Read a single byte // Read a single byte
buf := make([]byte, 1) buf := make([]byte, 1)
if _, err := conn.Read(buf); err != nil { if _, err := conn.Read(buf); err != nil {
@ -61,20 +61,8 @@ func (s *Server) handleConn(conn net.Conn) {
return return
} }
// Check if entering TLS mode
isTLS := false
if RPCType(buf[0]) == rpcTLS {
if s.rpcTLS == nil {
s.logger.Printf("[WARN] consul.rpc: TLS connection attempted, server not configured for TLS")
conn.Close()
return
}
conn = tls.Server(conn, s.rpcTLS)
isTLS = true
}
// Enforce TLS if VerifyIncoming is set // Enforce TLS if VerifyIncoming is set
if s.config.VerifyIncoming && !isTLS { if s.config.VerifyIncoming && !isTLS && RPCType(buf[0]) != rpcTLS {
s.logger.Printf("[WARN] consul.rpc: Non-TLS connection attempted with VerifyIncoming set") s.logger.Printf("[WARN] consul.rpc: Non-TLS connection attempted with VerifyIncoming set")
conn.Close() conn.Close()
return return
@ -92,6 +80,15 @@ func (s *Server) handleConn(conn net.Conn) {
case rpcMultiplex: case rpcMultiplex:
s.handleMultiplex(conn) s.handleMultiplex(conn)
case rpcTLS:
if s.rpcTLS == nil {
s.logger.Printf("[WARN] consul.rpc: TLS connection attempted, server not configured for TLS")
conn.Close()
return
}
conn = tls.Server(conn, s.rpcTLS)
s.handleConn(conn, true)
default: default:
s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v", buf[0]) s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v", buf[0])
conn.Close() conn.Close()