rpc: bind rpc test server to port 0

This commit is contained in:
Frank Schroeder 2017-06-25 21:36:03 +02:00 committed by Frank Schröder
parent e9e2c599db
commit 53eab7e970
5 changed files with 75 additions and 21 deletions

View File

@ -85,6 +85,11 @@ type Config struct {
// as a voting member of the Raft cluster.
NonVoter bool
// NotifyListen is called after the RPC listener has been configured.
// RPCAdvertise will be set to the listener address if it hasn't been
// configured at this point.
NotifyListen func()
// RPCAddr is the RPC address used by Consul. This should be reachable
// by the WAN and LAN
RPCAddr *net.TCPAddr
@ -92,7 +97,8 @@ type Config struct {
// RPCAdvertise is the address that is advertised to other nodes for
// the RPC endpoint. This can differ from the RPC address, if for example
// the RPCAddr is unspecified "0.0.0.0:8300", but this address must be
// reachable
// reachable. If RPCAdvertise is nil then it will be set to the Listener
// address after the listening socket is configured.
RPCAdvertise *net.TCPAddr
// RPCSrcAddr is the source address for outgoing RPC connections.

View File

@ -49,7 +49,7 @@ const (
func (s *Server) listen() {
for {
// Accept a connection
conn, err := s.rpcListener.Accept()
conn, err := s.Listener.Accept()
if err != nil {
if s.shutdown {
return

View File

@ -150,9 +150,9 @@ type Server struct {
// Enterprise user-defined areas.
router *servers.Router
// rpcListener is used to listen for incoming connections
rpcListener net.Listener
rpcServer *rpc.Server
// Listener is used to listen for incoming connections
Listener net.Listener
rpcServer *rpc.Server
// rpcTLS is the TLS config for incoming TLS requests
rpcTLS *tls.Config
@ -392,7 +392,7 @@ func NewServerLogger(config *Config, logger *log.Logger) (*Server, error) {
// setupSerf is used to setup and initialize a Serf
func (s *Server) setupSerf(conf *serf.Config, ch chan serf.Event, path string, wan bool) (*serf.Serf, error) {
addr := s.rpcListener.Addr().(*net.TCPAddr)
addr := s.Listener.Addr().(*net.TCPAddr)
conf.Init()
if wan {
conf.NodeName = fmt.Sprintf("%s.%s", s.config.NodeName, s.config.Datacenter)
@ -645,7 +645,14 @@ func (s *Server) setupRPC(tlsWrap tlsutil.DCWrapper) error {
if err != nil {
return err
}
s.rpcListener = ln
s.Listener = ln
if s.config.NotifyListen != nil {
s.config.NotifyListen()
}
// todo(fs): we should probably guard this
if s.config.RPCAdvertise == nil {
s.config.RPCAdvertise = ln.Addr().(*net.TCPAddr)
}
// Verify that we have a usable advertise address
if s.config.RPCAdvertise.IP.IsUnspecified() {
@ -714,8 +721,8 @@ func (s *Server) Shutdown() error {
}
}
if s.rpcListener != nil {
s.rpcListener.Close()
if s.Listener != nil {
s.Listener.Close()
}
// Close the connection pool

View File

@ -35,25 +35,30 @@ func testServerConfig(t *testing.T, NodeName string) (string, *Config) {
config.Bootstrap = true
config.Datacenter = "dc1"
config.DataDir = dir
config.RPCAddr = &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: getPort(),
}
config.RPCAdvertise = config.RPCAddr
// bind the rpc server to a random port. config.RPCAdvertise will be
// set to the listen address unless it was set in the configuration.
// In that case get the address from srv.Listener.Addr().
config.RPCAddr = &net.TCPAddr{IP: []byte{127, 0, 0, 1}}
nodeID, err := uuid.GenerateUUID()
if err != nil {
t.Fatal(err)
}
config.NodeID = types.NodeID(nodeID)
// set the memberlist bind port to 0 to bind to a random port.
// memberlist will update the value of BindPort after bind
// to the actual value.
config.SerfLANConfig.MemberlistConfig.BindAddr = "127.0.0.1"
config.SerfLANConfig.MemberlistConfig.BindPort = getPort()
config.SerfLANConfig.MemberlistConfig.BindPort = 0
config.SerfLANConfig.MemberlistConfig.SuspicionMult = 2
config.SerfLANConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond
config.SerfLANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond
config.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond
config.SerfWANConfig.MemberlistConfig.BindAddr = "127.0.0.1"
config.SerfWANConfig.MemberlistConfig.BindPort = getPort()
config.SerfWANConfig.MemberlistConfig.BindPort = 0
config.SerfWANConfig.MemberlistConfig.SuspicionMult = 2
config.SerfWANConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond
config.SerfWANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond
@ -107,14 +112,50 @@ func testServerDCExpect(t *testing.T, dc string, expect int) (string, *Server) {
func testServerWithConfig(t *testing.T, cb func(c *Config)) (string, *Server) {
name := fmt.Sprintf("Node %d", getPort())
dir, config := testServerConfig(t, name)
cb(config)
server, err := NewServer(config)
if cb != nil {
cb(config)
}
server, err := newServer(config)
if err != nil {
t.Fatalf("err: %v", err)
}
return dir, server
}
func newServer(c *Config) (*Server, error) {
// chain server up notification
oldNotify := c.NotifyListen
up := make(chan struct{})
c.NotifyListen = func() {
close(up)
if oldNotify != nil {
oldNotify()
}
}
// start server
srv, err := NewServer(c)
if err != nil {
return nil, err
}
// wait until after listen
<-up
// get the real address
//
// the server already sets the RPCAdvertise address
// if it wasn't configured since it needs it for
// some initialization
//
// todo(fs): setting RPCAddr should probably be guarded
// todo(fs): but for now it is a shortcut to avoid fixing
// todo(fs): tests which depend on that value. They should
// todo(fs): just get the listener address instead.
c.RPCAddr = srv.Listener.Addr().(*net.TCPAddr)
return srv, nil
}
func TestServer_StartStop(t *testing.T) {
// Start up a server and then stop it.
dir1, s1 := testServer(t)
@ -381,7 +422,7 @@ func TestServer_JoinLAN_TLS(t *testing.T) {
conf1.VerifyIncoming = true
conf1.VerifyOutgoing = true
configureTLS(conf1)
s1, err := NewServer(conf1)
s1, err := newServer(conf1)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -393,7 +434,7 @@ func TestServer_JoinLAN_TLS(t *testing.T) {
conf2.VerifyIncoming = true
conf2.VerifyOutgoing = true
configureTLS(conf2)
s2, err := NewServer(conf2)
s2, err := newServer(conf2)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -13,7 +13,7 @@ import (
)
func rpcClient(t *testing.T, s *Server) rpc.ClientCodec {
addr := s.config.RPCAddr
addr := s.config.RPCAdvertise
conn, err := net.DialTimeout("tcp", addr.String(), time.Second)
if err != nil {
t.Fatalf("err: %v", err)