agent/consul: make router required

This commit is contained in:
Daniel Nephin 2020-09-11 12:43:29 -04:00
parent d5edce269e
commit 3aa9bd4c23
5 changed files with 37 additions and 22 deletions

View File

@ -68,8 +68,7 @@ type Client struct {
// from an agent. // from an agent.
rpcLimiter atomic.Value rpcLimiter atomic.Value
// eventCh is used to receive events from the // eventCh is used to receive events from the serf cluster in the datacenter
// serf cluster in the datacenter
eventCh chan serf.Event eventCh chan serf.Event
// Logger uses the provided LogOutput // Logger uses the provided LogOutput
@ -108,6 +107,9 @@ func NewClient(config *Config, options ...ConsulOption) (*Client, error) {
if flat.logger == nil { if flat.logger == nil {
return nil, fmt.Errorf("logger is required") return nil, fmt.Errorf("logger is required")
} }
if flat.router == nil {
return nil, fmt.Errorf("router is required")
}
if connPool == nil { if connPool == nil {
connPool = &pool.ConnPool{ connPool = &pool.ConnPool{
@ -156,23 +158,17 @@ func NewClient(config *Config, options ...ConsulOption) (*Client, error) {
} }
// Initialize the LAN Serf // Initialize the LAN Serf
c.serf, err = c.setupSerf(config.SerfLANConfig, c.serf, err = c.setupSerf(config.SerfLANConfig, c.eventCh, serfLANSnapshot)
c.eventCh, serfLANSnapshot)
if err != nil { if err != nil {
c.Shutdown() c.Shutdown()
return nil, fmt.Errorf("Failed to start lan serf: %v", err) return nil, fmt.Errorf("Failed to start lan serf: %v", err)
} }
rpcRouter := flat.router if err := flat.router.AddArea(types.AreaLAN, c.serf, c.connPool); err != nil {
if rpcRouter == nil {
rpcRouter = router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
}
if err := rpcRouter.AddArea(types.AreaLAN, c.serf, c.connPool); err != nil {
c.Shutdown() c.Shutdown()
return nil, fmt.Errorf("Failed to add LAN area to the RPC router: %w", err) return nil, fmt.Errorf("Failed to add LAN area to the RPC router: %w", err)
} }
c.router = rpcRouter c.router = flat.router
// Start LAN event handlers after the router is complete since the event // Start LAN event handlers after the router is complete since the event
// handlers depend on the router and the router depends on Serf. // handlers depend on the router and the router depends on Serf.

View File

@ -2,12 +2,14 @@ package consul
import ( import (
"bytes" "bytes"
"fmt"
"net" "net"
"os" "os"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
@ -75,7 +77,11 @@ func testClientWithConfigWithErr(t *testing.T, cb func(c *Config)) (string, *Cli
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
client, err := NewClient(config, WithLogger(logger), WithTLSConfigurator(tlsConf)) r := router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
client, err := NewClient(config,
WithLogger(logger),
WithTLSConfigurator(tlsConf),
WithRouter(r))
return dir, client, err return dir, client, err
} }
@ -473,7 +479,11 @@ func newClient(t *testing.T, config *Config) *Client {
Level: hclog.Debug, Level: hclog.Debug,
Output: testutil.NewLogBuffer(t), Output: testutil.NewLogBuffer(t),
}) })
client, err := NewClient(config, WithLogger(logger), WithTLSConfigurator(c)) r := router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
client, err := NewClient(config,
WithLogger(logger),
WithTLSConfigurator(c),
WithRouter(r))
require.NoError(t, err, "failed to create client") require.NoError(t, err, "failed to create client")
t.Cleanup(func() { t.Cleanup(func() {
client.Shutdown() client.Shutdown()

View File

@ -9,6 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
@ -1305,10 +1306,13 @@ func TestLeader_ConfigEntryBootstrap_Fail(t *testing.T) {
}) })
tlsConf, err := tlsutil.NewConfigurator(config.ToTLSUtilConfig(), logger) tlsConf, err := tlsutil.NewConfigurator(config.ToTLSUtilConfig(), logger)
require.NoError(t, err) require.NoError(t, err)
rpcRouter := router.NewRouter(logger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
srv, err := NewServer(config, srv, err := NewServer(config,
WithLogger(logger), WithLogger(logger),
WithTokenStore(new(token.Store)), WithTokenStore(new(token.Store)),
WithTLSConfigurator(tlsConf)) WithTLSConfigurator(tlsConf),
WithRouter(rpcRouter))
require.NoError(t, err) require.NoError(t, err)
defer srv.Shutdown() defer srv.Shutdown()

View File

@ -331,7 +331,6 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
tokens := flat.tokens tokens := flat.tokens
tlsConfigurator := flat.tlsConfigurator tlsConfigurator := flat.tlsConfigurator
connPool := flat.connPool connPool := flat.connPool
rpcRouter := flat.router
if err := config.CheckProtocolVersion(); err != nil { if err := config.CheckProtocolVersion(); err != nil {
return nil, err return nil, err
@ -345,6 +344,9 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
if logger == nil { if logger == nil {
return nil, fmt.Errorf("logger is required") return nil, fmt.Errorf("logger is required")
} }
if flat.router == nil {
return nil, fmt.Errorf("router is required")
}
// Check if TLS is enabled // Check if TLS is enabled
if config.CAFile != "" || config.CAPath != "" { if config.CAFile != "" || config.CAPath != "" {
@ -388,10 +390,6 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
serverLogger := logger.NamedIntercept(logging.ConsulServer) serverLogger := logger.NamedIntercept(logging.ConsulServer)
loggers := newLoggerStore(serverLogger) loggers := newLoggerStore(serverLogger)
if rpcRouter == nil {
rpcRouter = router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter))
}
// Create server. // Create server.
s := &Server{ s := &Server{
config: config, config: config,
@ -403,7 +401,7 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) {
loggers: loggers, loggers: loggers,
leaveCh: make(chan struct{}), leaveCh: make(chan struct{}),
reconcileCh: make(chan serf.Member, reconcileChSize), reconcileCh: make(chan serf.Member, reconcileChSize),
router: rpcRouter, router: flat.router,
rpcServer: rpc.NewServer(), rpcServer: rpc.NewServer(),
insecureRPCServer: rpc.NewServer(), insecureRPCServer: rpc.NewServer(),
tlsConfigurator: tlsConfigurator, tlsConfigurator: tlsConfigurator,

View File

@ -17,6 +17,7 @@ import (
"github.com/google/tcpproxy" "github.com/google/tcpproxy"
"github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/memberlist" "github.com/hashicorp/memberlist"
@ -301,10 +302,13 @@ func newServer(t *testing.T, c *Config) (*Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
rpcRouter := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter))
srv, err := NewServer(c, srv, err := NewServer(c,
WithLogger(logger), WithLogger(logger),
WithTokenStore(new(token.Store)), WithTokenStore(new(token.Store)),
WithTLSConfigurator(tlsConf)) WithTLSConfigurator(tlsConf),
WithRouter(rpcRouter))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1491,10 +1495,13 @@ func TestServer_CALogging(t *testing.T) {
c, err := tlsutil.NewConfigurator(conf1.ToTLSUtilConfig(), logger) c, err := tlsutil.NewConfigurator(conf1.ToTLSUtilConfig(), logger)
require.NoError(t, err) require.NoError(t, err)
rpcRouter := router.NewRouter(logger, "dc1", fmt.Sprintf("%s.%s", "nodename", "dc1"))
s1, err := NewServer(conf1, s1, err := NewServer(conf1,
WithLogger(logger), WithLogger(logger),
WithTokenStore(new(token.Store)), WithTokenStore(new(token.Store)),
WithTLSConfigurator(c)) WithTLSConfigurator(c),
WithRouter(rpcRouter))
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }