diff --git a/agent/consul/auto_encrypt.go b/agent/consul/auto_encrypt.go index 31c78c112e..0d453a998b 100644 --- a/agent/consul/auto_encrypt.go +++ b/agent/consul/auto_encrypt.go @@ -109,7 +109,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin for _, ip := range ips { addr := net.TCPAddr{IP: ip, Port: port} - if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", true, &args, &reply); err == nil { + if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", &args, &reply); err == nil { return &reply, pkPEM, nil } else { c.logger.Warn("AutoEncrypt failed", "error", err) diff --git a/agent/consul/client.go b/agent/consul/client.go index d46df2fb1a..bcaf5aac19 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -186,7 +186,7 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat } // Start maintenance task for servers - c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool) + c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool, "") go c.routers.Start() // Start LAN event handlers after the router is complete since the event @@ -308,7 +308,7 @@ TRY: } // Make the request. - rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, server.UseTLS, args, reply) + rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, args, reply) if rpcErr == nil { return nil } diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 852f989be2..cdaa706772 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -418,7 +418,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) { for range servers { time.Sleep(200 * time.Millisecond) s := c.routers.FindServer() - ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version, s.UseTLS) + ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version) if !ok { t.Errorf("Unable to ping server %v: %s", s.String(), err) } diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index a7a40442e5..f00b14ea82 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -307,7 +307,42 @@ func (s *Server) handleMultiplexV2(conn net.Conn) { } return } - go s.handleConsulConn(sub) + + // In the beginning only RPC was supposed to be multiplexed + // with yamux. In order to add the ability to multiplex network + // area connections, this workaround was added. + // This code peeks the first byte and checks if it is + // RPCGossip, in which case this is handled by enterprise code. + // Otherwise this connection is handled like before by the RPC + // handler. + // This wouldn't work if a normal RPC could start with + // RPCGossip(6). In messagepack a 6 encodes a positive fixint: + // https://github.com/msgpack/msgpack/blob/master/spec.md. + // None of the RPCs we are doing starts with that, usually it is + // a string for datacenter. 
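The check above leans on a msgpack framing detail: per the comment, a regular RPC stream usually begins with a msgpack-encoded string (the datacenter), whose first byte is a fixstr header of 0xa0 | length and therefore can never equal the positive fixint 6 used for RPCGossip. A self-contained sketch of that byte math follows; the constant and helper names are illustrative and not part of the patch.

package main

import "fmt"

// rpcGossip mirrors pool.RPCGossip (6); repeated here only so the sketch is
// self-contained.
const rpcGossip = 6

// msgpackFixstrHeader returns the first byte msgpack emits for a short string
// (up to 31 bytes) such as a datacenter name: 0xa0 | len, which is always
// >= 0xa0 and so can never collide with a positive fixint like 6.
func msgpackFixstrHeader(s string) byte {
	return 0xa0 | byte(len(s))
}

func main() {
	fmt.Printf("gossip type byte: %#x\n", rpcGossip)
	fmt.Printf("first byte of a msgpack-encoded %q: %#x\n", "dc1", msgpackFixstrHeader("dc1"))
}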
+ peeked, first, err := pool.PeekFirstByte(sub) + if err != nil { + s.rpcLogger().Error("Problem peeking connection", "conn", logConn(sub), "err", err) + sub.Close() + return + } + sub = peeked + switch first { + case pool.RPCGossip: + buf := make([]byte, 1) + sub.Read(buf) + go func() { + if !s.handleEnterpriseRPCConn(pool.RPCGossip, sub, false) { + s.rpcLogger().Error("unrecognized RPC byte", + "byte", pool.RPCGossip, + "conn", logConn(conn), + ) + sub.Close() + } + }() + default: + go s.handleConsulConn(sub) + } } } @@ -517,7 +552,7 @@ CHECK_LEADER: rpcErr := structs.ErrNoLeader if leader != nil { rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr, - leader.Version, method, leader.UseTLS, args, reply) + leader.Version, method, args, reply) if rpcErr != nil && canRetry(info, rpcErr) { goto RETRY } @@ -582,7 +617,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{ metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1, []metrics.Label{{Name: "datacenter", Value: dc}}) - if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil { + if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, args, reply); err != nil { manager.NotifyFailedServer(server) s.rpcLogger().Error("RPC failed to server in DC", "server", server.Addr, diff --git a/agent/consul/server.go b/agent/consul/server.go index 071c224eed..3b6e9886ab 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -391,7 +391,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token loggers: loggers, leaveCh: make(chan struct{}), reconcileCh: make(chan serf.Member, reconcileChSize), - router: router.NewRouter(serverLogger, config.Datacenter), + router: router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)), rpcServer: rpc.NewServer(), insecureRPCServer: rpc.NewServer(), tlsConfigurator: tlsConfigurator, @@ -551,7 +551,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token // Add a "static route" to the WAN Serf and hook it up to Serf events. if s.serfWAN != nil { - if err := s.router.AddArea(types.AreaWAN, s.serfWAN, s.connPool, s.config.VerifyOutgoing); err != nil { + if err := s.router.AddArea(types.AreaWAN, s.serfWAN, s.connPool); err != nil { s.Shutdown() return nil, fmt.Errorf("Failed to add WAN serf route: %v", err) } @@ -839,23 +839,16 @@ func (s *Server) setupRPC() error { return fmt.Errorf("RPC advertise address is not advertisable: %v", s.config.RPCAdvertise) } + // TODO (hans) switch NewRaftLayer to tlsConfigurator + // Provide a DC specific wrapper. Raft replication is only // ever done in the same datacenter, so we can provide it as a constant. wrapper := tlsutil.SpecificDC(s.config.Datacenter, s.tlsConfigurator.OutgoingRPCWrapper()) // Define a callback for determining whether to wrap a connection with TLS tlsFunc := func(address raft.ServerAddress) bool { - if s.config.VerifyOutgoing { - return true - } - - server := s.serverLookup.Server(address) - - if server == nil { - return false - } - - return server.UseTLS + // raft only talks to its own datacenter + return s.tlsConfigurator.UseTLS(s.config.Datacenter) } s.raftLayer = NewRaftLayer(s.config.RPCSrcAddr, s.config.RPCAdvertise, wrapper, tlsFunc) return nil @@ -1361,6 +1354,7 @@ func (s *Server) ReloadConfig(config *Config) error { // this will error if we lose leadership while bootstrapping here. 
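For reference, the first-byte dispatch added to handleMultiplexV2 above can be modelled in isolation. This is only a sketch: dispatch and the printed handler names stand in for the real s.handleEnterpriseRPCConn and s.handleConsulConn calls, and a net.Pipe replaces the yamux stream.

package main

import (
	"bufio"
	"fmt"
	"net"
)

const rpcGossip = 6 // stands in for pool.RPCGossip

// dispatch models the new branch: peek at the first byte of the stream, and
// only for the gossip type consume it before handing the stream off.
func dispatch(conn net.Conn) {
	br := bufio.NewReader(conn)
	first, err := br.Peek(1)
	if err != nil {
		conn.Close()
		return
	}
	if first[0] == rpcGossip {
		// The real code reads the byte off the peeked conn so the
		// enterprise handler sees the stream positioned after it.
		br.Discard(1)
		fmt.Println("gossip stream: would call handleEnterpriseRPCConn")
		return
	}
	fmt.Println("regular RPC stream: would call handleConsulConn")
}

func main() {
	client, server := net.Pipe()
	go func() {
		client.Write([]byte{rpcGossip})
		client.Close()
	}()
	dispatch(server)
}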
return s.bootstrapConfigEntries(config.ConfigEntryBootstrap) } + return nil } diff --git a/agent/consul/server_serf.go b/agent/consul/server_serf.go index d46ee71b71..821d9c3174 100644 --- a/agent/consul/server_serf.go +++ b/agent/consul/server_serf.go @@ -364,7 +364,7 @@ func (s *Server) maybeBootstrap() { // Retry with exponential backoff to get peer status from this server for attempt := uint(0); attempt < maxPeerRetries; attempt++ { if err := s.connPool.RPC(s.config.Datacenter, server.ShortName, server.Addr, server.Version, - "Status.Peers", server.UseTLS, &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil { + "Status.Peers", &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil { nextRetry := time.Duration((1 << attempt) * peerRetryBase) s.logger.Error("Failed to confirm peer status for server (will retry).", "server", server.Name, diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index c897607c88..61cb1a0920 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -1213,7 +1213,7 @@ func testVerifyRPC(s1, s2 *Server, t *testing.T) (bool, error) { if leader == nil { t.Fatal("no leader") } - return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version, leader.UseTLS) + return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version) } func TestServer_TLSToNoTLS(t *testing.T) { @@ -1277,7 +1277,6 @@ func TestServer_TLSToFullVerify(t *testing.T) { c.CAFile = "../../test/client_certs/rootca.crt" c.CertFile = "../../test/client_certs/server.crt" c.KeyFile = "../../test/client_certs/server.key" - c.VerifyIncoming = true c.VerifyOutgoing = true }) defer os.RemoveAll(dir1) diff --git a/agent/consul/stats_fetcher.go b/agent/consul/stats_fetcher.go index faa6d8e258..1635126d55 100644 --- a/agent/consul/stats_fetcher.go +++ b/agent/consul/stats_fetcher.go @@ -43,7 +43,7 @@ func NewStatsFetcher(logger hclog.Logger, pool *pool.ConnPool, datacenter string func (f *StatsFetcher) fetch(server *metadata.Server, replyCh chan *autopilot.ServerStats) { var args struct{} var reply autopilot.ServerStats - err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", server.UseTLS, &args, &reply) + err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", &args, &reply) if err != nil { f.logger.Warn("error getting server health from server", "server", server.Name, diff --git a/agent/pool/peek.go b/agent/pool/peek.go index 64ac8de787..5c6568153b 100644 --- a/agent/pool/peek.go +++ b/agent/pool/peek.go @@ -2,6 +2,7 @@ package pool import ( "bufio" + "fmt" "net" ) @@ -47,3 +48,32 @@ func PeekForTLS(conn net.Conn) (net.Conn, bool, error) { Conn: conn, }, isTLS, nil } + +// PeekFirstByte will read the first byte on the conn. +// +// This function does not close the conn on an error. +// +// The returned conn has the initial read buffered internally for the purposes +// of not consuming the first byte. After that buffer is drained the conn is a +// pass through to the original conn. +func PeekFirstByte(conn net.Conn) (net.Conn, byte, error) { + br := bufio.NewReader(conn) + + // Grab enough to read the first byte. Then drain the buffer so future + // reads can be direct. 
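The Peek(1) / Peek(Buffered()) pattern described in the comment above can be demonstrated with a plain bufio.Reader. This sketch substitutes a bytes.Reader for the net.Conn and io.MultiReader for peekedConn, purely for illustration.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
)

func main() {
	// bytes.Reader stands in for the net.Conn; the mechanics are the same.
	underlying := bytes.NewReader([]byte("hello world"))
	br := bufio.NewReader(underlying)

	first, err := br.Peek(1) // triggers one Read on the conn, consumes nothing
	if err != nil {
		panic(err)
	}
	// Everything that single Read pulled off the conn is now buffered; grab
	// it all so the wrapper only ever has to replay this slice.
	buffered, err := br.Peek(br.Buffered())
	if err != nil {
		panic(err)
	}
	fmt.Printf("first byte %q, buffered %q\n", first[0], buffered)

	// peekedConn-style replay: the captured bytes first, then the raw conn.
	rest, _ := io.ReadAll(io.MultiReader(bytes.NewReader(buffered), underlying))
	fmt.Printf("reassembled stream: %q\n", rest)
}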
+ peeked, err := br.Peek(1) + if err != nil { + return nil, 0, err + } else if len(peeked) == 0 { + return conn, 0, fmt.Errorf("nothing to read") + } + peeked, err = br.Peek(br.Buffered()) + if err != nil { + return nil, 0, err + } + + return &peekedConn{ + Peeked: peeked, + Conn: conn, + }, peeked[0], nil +} diff --git a/agent/pool/pool.go b/agent/pool/pool.go index a2e4a4ea17..4ce7c1e460 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -389,7 +389,7 @@ func DialTimeoutWithRPCTypeDirectly( } // Check if TLS is enabled - if (useTLS) && wrapper != nil { + if useTLS && wrapper != nil { // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { conn.Close() @@ -600,7 +600,6 @@ func (p *ConnPool) RPC( addr net.Addr, version int, method string, - useTLS bool, args interface{}, reply interface{}, ) error { @@ -611,7 +610,7 @@ func (p *ConnPool) RPC( if method == "AutoEncrypt.Sign" { return p.rpcInsecure(dc, nodeName, addr, method, args, reply) } else { - return p.rpc(dc, nodeName, addr, version, method, useTLS, args, reply) + return p.rpc(dc, nodeName, addr, version, method, args, reply) } } @@ -637,10 +636,11 @@ func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method return nil } -func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, useTLS bool, args interface{}, reply interface{}) error { +func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error { p.once.Do(p.init) // Get a usable client + useTLS := p.TLSConfigurator.UseTLS(dc) conn, sc, err := p.getClient(dc, nodeName, addr, version, useTLS) if err != nil { return fmt.Errorf("rpc error getting client: %v", err) @@ -671,9 +671,9 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, m // Ping sends a Status.Ping message to the specified server and // returns true if healthy, false if an error occurred -func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) { +func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) { var out struct{} - err := p.RPC(dc, nodeName, addr, version, "Status.Ping", useTLS, struct{}{}, &out) + err := p.RPC(dc, nodeName, addr, version, "Status.Ping", struct{}{}, &out) return err == nil, err } diff --git a/agent/router/manager.go b/agent/router/manager.go index a02392ae36..2a7af2ebf1 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -61,7 +61,7 @@ type ManagerSerfCluster interface { // Pinger is an interface wrapping client.ConnPool to prevent a cyclic import // dependency. type Pinger interface { - Ping(dc, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) + Ping(dc, nodeName string, addr net.Addr, version int) (bool, error) } // serverList is a local copy of the struct used to maintain the list of @@ -98,6 +98,10 @@ type Manager struct { // client.ConnPool. connPoolPinger Pinger + // serverName has the name of the managers's server. This is used to + // short-circuit pinging to itself. + serverName string + // notifyFailedBarrier is acts as a barrier to prevent queuing behind // serverListLog and acts as a TryLock(). notifyFailedBarrier int32 @@ -256,7 +260,7 @@ func (m *Manager) saveServerList(l serverList) { } // New is the only way to safely create a new Manager struct. 
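With useTLS removed from the Pinger interface above, an implementation now needs only the datacenter, node name, address and version. A minimal stand-in implementation in the spirit of the test fauxConnPool; the type name here is invented for the sketch.

package main

import (
	"fmt"
	"net"
)

// Pinger mirrors the updated agent/router interface: the useTLS argument is
// gone because the conn pool now derives TLS from its TLSConfigurator.
type Pinger interface {
	Ping(dc, nodeName string, addr net.Addr, version int) (bool, error)
}

// alwaysHealthy is an illustrative implementation, not part of the patch.
type alwaysHealthy struct{}

func (alwaysHealthy) Ping(dc, nodeName string, addr net.Addr, version int) (bool, error) {
	return true, nil
}

func main() {
	var p Pinger = alwaysHealthy{}
	addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8300}
	ok, err := p.Ping("dc1", "server1", addr, 3)
	fmt.Println(ok, err)
}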
-func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger) (m *Manager) { +func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } @@ -267,6 +271,7 @@ func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfC m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration) m.shutdownCh = shutdownCh + m.serverName = serverName atomic.StoreInt32(&m.offline, 1) l := serverList{} @@ -340,7 +345,12 @@ func (m *Manager) RebalanceServers() { // while Serf detects the node has failed. srv := l.servers[0] - ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version, srv.UseTLS) + // check to see if the manager is trying to ping itself, + // continue if that is the case. + if m.serverName != "" && srv.Name == m.serverName { + continue + } + ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version) if ok { foundHealthyServer = true break diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index 9b58abd9c9..b06ccc98d5 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -33,7 +33,7 @@ type fauxConnPool struct { failPct float64 } -func (cp *fauxConnPool) Ping(string, string, net.Addr, int, bool) (bool, error) { +func (cp *fauxConnPool) Ping(string, string, net.Addr, int) (bool, error) { var success bool successProb := rand.Float64() if successProb > cp.failPct { @@ -53,14 +53,14 @@ func (s *fauxSerf) NumNodes() int { func testManager() (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}) + m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "") return m } func testManagerFailProb(failPct float64) (m *Manager) { logger := GetBufferedLogger() shutdownCh := make(chan struct{}) - m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}) + m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") return m } @@ -179,7 +179,7 @@ func test_reconcileServerList(maxServers int) (bool, error) { // failPct of the servers for the reconcile. This // allows for the selected server to no longer be // healthy for the reconcile below. 
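The self-ping short circuit added to RebalanceServers above reduces to a small predicate. A hypothetical extraction, using the node.datacenter naming that servers pass in via NewServerLogger; clients pass an empty name, which disables the check.

package main

import "fmt"

// skipSelfPing sketches the check added to RebalanceServers: a manager that
// knows its own server name never pings itself.
func skipSelfPing(managerServerName, candidateName string) bool {
	return managerServerName != "" && candidateName == managerServerName
}

func main() {
	self := "node1.dc1" // matches fmt.Sprintf("%s.%s", NodeName, Datacenter)
	fmt.Println(skipSelfPing(self, "node1.dc1")) // true: skip, this is us
	fmt.Println(skipSelfPing(self, "node2.dc1")) // false: ping it
	fmt.Println(skipSelfPing("", "node1.dc1"))   // false: clients pass ""
}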
- if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version, node.UseTLS); ok { + if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version); ok { // Will still be present healthyServers = append(healthyServers, node) } else { @@ -299,7 +299,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) { shutdownCh := make(chan struct{}) for _, s := range clusters { - m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}) + m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "") for i := 0; i < s.numServers; i++ { nodeName := fmt.Sprintf("s%02d", i) m.AddServer(&metadata.Server{Name: nodeName}) diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go index 676afd016c..3b99bfe654 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -32,7 +32,7 @@ type fauxConnPool struct { failAddr net.Addr } -func (cp *fauxConnPool) Ping(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) { +func (cp *fauxConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) { var success bool successProb := rand.Float64() @@ -57,21 +57,21 @@ func (s *fauxSerf) NumNodes() int { func testManager(t testing.TB) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}) + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") return m } func testManagerFailProb(t testing.TB, failPct float64) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}) + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "") return m } func testManagerFailAddr(t testing.TB, failAddr net.Addr) (m *router.Manager) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}) + m = router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failAddr: failAddr}, "") return m } @@ -195,7 +195,7 @@ func TestServers_FindServer(t *testing.T) { func TestServers_New(t *testing.T) { logger := testutil.Logger(t) shutdownCh := make(chan struct{}) - m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}) + m := router.New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{}, "") if m == nil { t.Fatalf("Manager nil") } diff --git a/agent/router/router.go b/agent/router/router.go index 4cdc864b06..64df6a003e 100644 --- a/agent/router/router.go +++ b/agent/router/router.go @@ -26,6 +26,10 @@ type Router struct { // used to short-circuit RTT calculations for local servers. localDatacenter string + // serverName has the name of the router's server. This is used to + // short-circuit pinging to itself. + serverName string + // areas maps area IDs to structures holding information about that // area. areas map[types.AreaID]*areaInfo @@ -83,7 +87,7 @@ type areaInfo struct { } // NewRouter returns a new Router with the given configuration. 
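The NewRouter signature now carries the server's own name. A sketch of the two call patterns introduced by this patch, assuming the consul agent/router and go-hclog packages; the node and datacenter values are illustrative.

package main

import (
	"fmt"

	"github.com/hashicorp/consul/agent/router"
	"github.com/hashicorp/go-hclog"
)

func main() {
	logger := hclog.New(&hclog.LoggerOptions{Name: "router"})

	// Server agents identify themselves so managers can skip self-pings;
	// the name format matches the NewServerLogger call site in this patch.
	nodeName, datacenter := "node1", "dc1"
	r := router.NewRouter(logger, datacenter, fmt.Sprintf("%s.%s", nodeName, datacenter))

	// Client agents and tests have no server identity and pass "".
	_ = router.NewRouter(logger, datacenter, "")

	fmt.Println(r != nil)
}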
-func NewRouter(logger hclog.Logger, localDatacenter string) *Router { +func NewRouter(logger hclog.Logger, localDatacenter, serverName string) *Router { if logger == nil { logger = hclog.New(&hclog.LoggerOptions{}) } @@ -91,6 +95,7 @@ func NewRouter(logger hclog.Logger, localDatacenter string) *Router { router := &Router{ logger: logger.Named(logging.Router), localDatacenter: localDatacenter, + serverName: serverName, areas: make(map[types.AreaID]*areaInfo), managers: make(map[string][]*Manager), } @@ -120,7 +125,7 @@ func (r *Router) Shutdown() { } // AddArea registers a new network area with the router. -func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger Pinger, useTLS bool) error { +func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger Pinger) error { r.Lock() defer r.Unlock() @@ -136,7 +141,6 @@ func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger cluster: cluster, pinger: pinger, managers: make(map[string]*managerInfo), - useTLS: useTLS, } r.areas[areaID] = area @@ -162,6 +166,23 @@ func (r *Router) AddArea(areaID types.AreaID, cluster RouterSerfCluster, pinger return nil } +// GetServerMetadataByAddr returns server metadata by dc and address. If it +// didn't find anything, nil is returned. +func (r *Router) GetServerMetadataByAddr(dc, addr string) *metadata.Server { + r.RLock() + defer r.RUnlock() + if ms, ok := r.managers[dc]; ok { + for _, m := range ms { + for _, s := range m.getServerList().servers { + if s.Addr.String() == addr { + return s + } + } + } + } + return nil +} + // removeManagerFromIndex does cleanup to take a manager out of the index of // datacenters. This assumes the lock is already held for writing, and will // panic if the given manager isn't found. @@ -219,7 +240,7 @@ func (r *Router) addServer(area *areaInfo, s *metadata.Server) error { info, ok := area.managers[s.Datacenter] if !ok { shutdownCh := make(chan struct{}) - manager := New(r.logger, shutdownCh, area.cluster, area.pinger) + manager := New(r.logger, shutdownCh, area.cluster, area.pinger, r.serverName) info = &managerInfo{ manager: manager, shutdownCh: shutdownCh, diff --git a/agent/router/router_test.go b/agent/router/router_test.go index 18d01236f1..5ca440b01d 100644 --- a/agent/router/router_test.go +++ b/agent/router/router_test.go @@ -95,7 +95,7 @@ func testCluster(self string) *mockCluster { func testRouter(t testing.TB, dc string) *Router { logger := testutil.Logger(t) - return NewRouter(logger, dc) + return NewRouter(logger, dc, "") } func TestRouter_Shutdown(t *testing.T) { @@ -104,7 +104,7 @@ func TestRouter_Shutdown(t *testing.T) { // Create a WAN-looking area. self := "node0.dc0" wan := testCluster(self) - if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -112,7 +112,7 @@ func TestRouter_Shutdown(t *testing.T) { otherID := types.AreaID("other") other := newMockCluster(self) other.AddMember("dcY", "node1", nil) - if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } _, _, ok := r.FindRoute("dcY") @@ -128,7 +128,7 @@ func TestRouter_Shutdown(t *testing.T) { } // You can't add areas once the router is shut down. 
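GetServerMetadataByAddr, added above, returns nil when nothing matches, so callers need a guard. A hypothetical usage sketch; the datacenter and address values are made up.

package main

import (
	"fmt"

	"github.com/hashicorp/consul/agent/router"
)

// describeServer sketches how a caller might use the new lookup helper.
func describeServer(r *router.Router, dc, addr string) string {
	srv := r.GetServerMetadataByAddr(dc, addr)
	if srv == nil {
		return fmt.Sprintf("no known server at %s in %s", addr, dc)
	}
	return fmt.Sprintf("server %s (datacenter %s)", srv.Name, srv.Datacenter)
}

func main() {
	// A router with no areas registered knows about no servers, so the
	// lookup falls through to nil.
	r := router.NewRouter(nil, "dc1", "")
	fmt.Println(describeServer(r, "dc2", "10.0.0.2:8300"))
}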
- err := r.AddArea(otherID, other, &fauxConnPool{}, false) + err := r.AddArea(otherID, other, &fauxConnPool{}) if err == nil || !strings.Contains(err.Error(), "router is shut down") { t.Fatalf("err: %v", err) } @@ -140,7 +140,7 @@ func TestRouter_Routing(t *testing.T) { // Create a WAN-looking area. self := "node0.dc0" wan := testCluster(self) - if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -169,7 +169,7 @@ func TestRouter_Routing(t *testing.T) { other.AddMember("dc0", "node0", nil) other.AddMember("dc1", "node1", nil) other.AddMember("dcY", "node1", nil) - if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -274,7 +274,7 @@ func TestRouter_Routing_Offline(t *testing.T) { // Create a WAN-looking area. self := "node0.dc0" wan := testCluster(self) - if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{1.0}, false); err != nil { + if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{1.0}); err != nil { t.Fatalf("err: %v", err) } @@ -328,7 +328,7 @@ func TestRouter_Routing_Offline(t *testing.T) { other := newMockCluster(self) other.AddMember("dc0", "node0", nil) other.AddMember("dc1", "node1", nil) - if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -353,7 +353,7 @@ func TestRouter_GetDatacenters(t *testing.T) { self := "node0.dc0" wan := testCluster(self) - if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -385,7 +385,7 @@ func TestRouter_GetDatacentersByDistance(t *testing.T) { // Start with just the WAN area described in the diagram above. self := "node0.dc0" wan := testCluster(self) - if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -403,7 +403,7 @@ func TestRouter_GetDatacentersByDistance(t *testing.T) { other := newMockCluster(self) other.AddMember("dc0", "node0", lib.GenerateCoordinate(20*time.Millisecond)) other.AddMember("dc1", "node1", lib.GenerateCoordinate(21*time.Millisecond)) - if err := r.AddArea(otherID, other, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(otherID, other, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } @@ -422,7 +422,7 @@ func TestRouter_GetDatacenterMaps(t *testing.T) { self := "node0.dc0" wan := testCluster(self) - if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}, false); err != nil { + if err := r.AddArea(types.AreaWAN, wan, &fauxConnPool{}); err != nil { t.Fatalf("err: %v", err) } diff --git a/logging/names.go b/logging/names.go index 003c85bd9a..d4c5c8f07d 100644 --- a/logging/names.go +++ b/logging/names.go @@ -33,6 +33,7 @@ const ( Memberlist string = "memberlist" MeshGateway string = "mesh_gateway" Namespace string = "namespace" + NetworkAreas string = "network_areas" Operator string = "operator" PreparedQuery string = "prepared_query" Proxy string = "proxy" diff --git a/tlsutil/config.go b/tlsutil/config.go index a5f9c8c770..bbddb48706 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -179,9 +179,10 @@ type manual struct { // *tls.Config necessary for Consul. 
Except the one in the api package. type Configurator struct { sync.RWMutex - base *Config - autoEncrypt *autoEncrypt - manual *manual + base *Config + autoEncrypt *autoEncrypt + manual *manual + peerDatacenterUseTLS map[string]bool caPool *x509.CertPool logger hclog.Logger @@ -198,9 +199,10 @@ func NewConfigurator(config Config, logger hclog.Logger) (*Configurator, error) } c := &Configurator{ - logger: logger.Named(logging.TLSUtil), - manual: &manual{}, - autoEncrypt: &autoEncrypt{}, + logger: logger.Named(logging.TLSUtil), + manual: &manual{}, + autoEncrypt: &autoEncrypt{}, + peerDatacenterUseTLS: map[string]bool{}, } err := c.Update(config) if err != nil { @@ -323,6 +325,22 @@ func (c *Configurator) UpdateAutoEncrypt(manualCAPems, connectCAPems []string, p return nil } +func (c *Configurator) UpdateAreaPeerDatacenterUseTLS(peerDatacenter string, useTLS bool) { + c.Lock() + defer c.Unlock() + c.version++ + c.peerDatacenterUseTLS[peerDatacenter] = useTLS +} + +func (c *Configurator) getAreaForPeerDatacenterUseTLS(peerDatacenter string) bool { + c.RLock() + defer c.RUnlock() + if v, ok := c.peerDatacenterUseTLS[peerDatacenter]; ok { + return v + } + return true +} + func (c *Configurator) Base() Config { c.RLock() defer c.RUnlock() @@ -535,7 +553,7 @@ func (c *Configurator) outgoingRPCTLSDisabled() bool { } // if CAs are provided or VerifyOutgoing is set, use TLS - if c.caPool != nil || c.base.VerifyOutgoing { + if c.base.VerifyOutgoing { return false } @@ -742,16 +760,20 @@ func (c *Configurator) OutgoingALPNRPCConfig() *tls.Config { // decides if verify server hostname should be used. func (c *Configurator) OutgoingRPCWrapper() DCWrapper { c.log("OutgoingRPCWrapper") - if c.outgoingRPCTLSDisabled() { - return nil - } // Generate the wrapper based on dc return func(dc string, conn net.Conn) (net.Conn, error) { - return c.wrapTLSClient(dc, conn) + if c.UseTLS(dc) { + return c.wrapTLSClient(dc, conn) + } + return conn, nil } } +func (c *Configurator) UseTLS(dc string) bool { + return !c.outgoingRPCTLSDisabled() && c.getAreaForPeerDatacenterUseTLS(dc) +} + // OutgoingALPNRPCWrapper wraps the result of OutgoingALPNRPCConfig in an // ALPNWrapper. It configures all of the negotiation plumbing. 
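The new per-peer-datacenter switch composes with the existing global check: UseTLS(dc) is true only when outgoing TLS is enabled at all and the datacenter has not been marked plaintext via UpdateAreaPeerDatacenterUseTLS, with unknown datacenters defaulting to TLS. A small standalone model of that decision, not the real Configurator:

package main

import "fmt"

// useTLSModel mirrors the decision added to tlsutil.Configurator.UseTLS:
// outgoing TLS must be enabled globally, and the peer datacenter must not
// have been explicitly marked plaintext. Unknown datacenters default to TLS.
type useTLSModel struct {
	outgoingTLSEnabled   bool
	peerDatacenterUseTLS map[string]bool
}

func (m useTLSModel) UseTLS(dc string) bool {
	peer, ok := m.peerDatacenterUseTLS[dc]
	if !ok {
		peer = true
	}
	return m.outgoingTLSEnabled && peer
}

func main() {
	m := useTLSModel{
		outgoingTLSEnabled:   true,
		peerDatacenterUseTLS: map[string]bool{"legacy-area-dc": false},
	}
	fmt.Println(m.UseTLS("dc2"))            // true: unknown peers default to TLS
	fmt.Println(m.UseTLS("legacy-area-dc")) // false: this area opted out
}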
func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper { diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index e194974b49..91bd99de82 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -99,10 +99,11 @@ func TestConfigurator_outgoingWrapper_OK(t *testing.T) { func TestConfigurator_outgoingWrapper_noverify_OK(t *testing.T) { config := Config{ - CAFile: "../test/hostname/CertAuth.crt", - CertFile: "../test/hostname/Alice.crt", - KeyFile: "../test/hostname/Alice.key", - Domain: "consul", + VerifyOutgoing: true, + CAFile: "../test/hostname/CertAuth.crt", + CertFile: "../test/hostname/Alice.crt", + KeyFile: "../test/hostname/Alice.key", + Domain: "consul", } client, errc := startRPCTLSServer(&config) @@ -744,7 +745,7 @@ func TestConfigurator_OutgoingRPCTLSDisabled(t *testing.T) { {false, true, nil, false}, {true, true, nil, false}, - {false, false, &x509.CertPool{}, false}, + // {false, false, &x509.CertPool{}, false}, {true, false, &x509.CertPool{}, false}, {false, true, &x509.CertPool{}, false}, {true, true, &x509.CertPool{}, false}, @@ -959,32 +960,42 @@ func TestConfigurator_OutgoingALPNRPCConfig(t *testing.T) { func TestConfigurator_OutgoingRPCWrapper(t *testing.T) { c := &Configurator{base: &Config{}, autoEncrypt: &autoEncrypt{}} - require.Nil(t, c.OutgoingRPCWrapper()) + wrapper := c.OutgoingRPCWrapper() + require.NotNil(t, wrapper) + conn := &net.TCPConn{} + cWrap, err := wrapper("", conn) + require.Equal(t, conn, cWrap) - c, err := NewConfigurator(Config{ + c, err = NewConfigurator(Config{ VerifyOutgoing: true, CAFile: "../test/ca/root.cer", }, nil) require.NoError(t, err) - wrap := c.OutgoingRPCWrapper() - require.NotNil(t, wrap) - t.Log("TODO: actually call wrap here eventually") + wrapper = c.OutgoingRPCWrapper() + require.NotNil(t, wrapper) + cWrap, err = wrapper("", conn) + require.NotEqual(t, conn, cWrap) } func TestConfigurator_OutgoingALPNRPCWrapper(t *testing.T) { c := &Configurator{base: &Config{}, autoEncrypt: &autoEncrypt{}} - require.Nil(t, c.OutgoingRPCWrapper()) + wrapper := c.OutgoingRPCWrapper() + require.NotNil(t, wrapper) + conn := &net.TCPConn{} + cWrap, err := wrapper("", conn) + require.Equal(t, conn, cWrap) - c, err := NewConfigurator(Config{ - VerifyOutgoing: false, // ignored, assumed true + c, err = NewConfigurator(Config{ + VerifyOutgoing: true, CAFile: "../test/ca/root.cer", }, nil) require.NoError(t, err) - wrap := c.OutgoingRPCWrapper() - require.NotNil(t, wrap) - t.Log("TODO: actually call wrap here eventually") + wrapper = c.OutgoingRPCWrapper() + require.NotNil(t, wrapper) + cWrap, err = wrapper("", conn) + require.NotEqual(t, conn, cWrap) } func TestConfigurator_UpdateChecks(t *testing.T) {
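The updated tests above pin down the new contract: OutgoingRPCWrapper is never nil, and for a datacenter where UseTLS reports false the wrapper hands the connection back unchanged, so callers no longer need the wrapper != nil guard. A standalone model of that behaviour; DCWrapper mirrors the tlsutil callback type, everything else is invented for the sketch.

package main

import (
	"fmt"
	"net"
)

// DCWrapper mirrors the tlsutil callback type: given a datacenter and a raw
// connection, return either the same conn (plaintext) or a TLS-wrapped one.
type DCWrapper func(dc string, conn net.Conn) (net.Conn, error)

// passthroughWrapper models the post-change behaviour exercised by
// TestConfigurator_OutgoingRPCWrapper.
func passthroughWrapper(useTLS func(string) bool, wrap func(net.Conn) net.Conn) DCWrapper {
	return func(dc string, conn net.Conn) (net.Conn, error) {
		if useTLS(dc) {
			return wrap(conn), nil
		}
		return conn, nil
	}
}

func main() {
	wrapper := passthroughWrapper(
		func(dc string) bool { return dc == "secure-dc" },
		func(c net.Conn) net.Conn { return c }, // real code swaps in a *tls.Conn here
	)
	client, _ := net.Pipe()
	out, _ := wrapper("plain-dc", client)
	fmt.Println(out == client) // true: plaintext datacenters pass straight through
}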