sdk: add freelist tracking and ephemeral port range skipping to freeport

This should cut down on test flakiness.

Problems handled:

- If you had enough parallel test cases running, the former circular
approach to handling the port block could hand out the same port to
multiple cases before they each had a chance to bind them, leading to
one of the two tests to fail.

- The freeport library would allocate out of the ephemeral port range.
This has been corrected for Linux (which should cover CI).

- The library now waits until a formerly-in-use port is verified to be
free before putting it back into circulation.
This commit is contained in:
R.B. Boyer 2019-08-27 16:16:41 -05:00 committed by R.B. Boyer
parent 90d945590a
commit f9496dc627
18 changed files with 699 additions and 99 deletions

View File

@ -1229,8 +1229,11 @@ func TestAgent_RestoreServiceWithAliasCheck(t *testing.T) {
testCtx, testCancel := context.WithCancel(context.Background()) testCtx, testCancel := context.WithCancel(context.Background())
defer testCancel() defer testCancel()
testHTTPServer := launchHTTPCheckServer(t, testCtx) testHTTPServer, returnPort := launchHTTPCheckServer(t, testCtx)
defer testHTTPServer.Close() defer func() {
testHTTPServer.Close()
returnPort()
}()
registerServicesAndChecks := func(t *testing.T, a *TestAgent) { registerServicesAndChecks := func(t *testing.T, a *TestAgent) {
// add one persistent service with a simple check // add one persistent service with a simple check
@ -1338,8 +1341,8 @@ node_name = "` + a.Config.NodeName + `"
} }
} }
func launchHTTPCheckServer(t *testing.T, ctx context.Context) *httptest.Server { func launchHTTPCheckServer(t *testing.T, ctx context.Context) (srv *httptest.Server, returnPortsFn func()) {
ports := freeport.GetT(t, 1) ports := freeport.MustTake(1)
port := ports[0] port := ports[0]
addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
@ -1353,12 +1356,12 @@ func launchHTTPCheckServer(t *testing.T, ctx context.Context) *httptest.Server {
_, _ = w.Write([]byte("OK\n")) _, _ = w.Write([]byte("OK\n"))
}) })
srv := &httptest.Server{ srv = &httptest.Server{
Listener: listener, Listener: listener,
Config: &http.Server{Handler: handler}, Config: &http.Server{Handler: handler},
} }
srv.Start() srv.Start()
return srv return srv, func() { freeport.Return(ports) }
} }
func TestAgent_AddCheck_Alias(t *testing.T) { func TestAgent_AddCheck_Alias(t *testing.T) {

View File

@ -13,7 +13,7 @@ import (
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/net-rpc-msgpackrpc" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf" "github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/time/rate" "golang.org/x/time/rate"
@ -22,15 +22,27 @@ import (
func testClientConfig(t *testing.T) (string, *Config) { func testClientConfig(t *testing.T) (string, *Config) {
dir := testutil.TempDir(t, "consul") dir := testutil.TempDir(t, "consul")
config := DefaultConfig() config := DefaultConfig()
ports := freeport.MustTake(2)
returnPortsFn := func() {
// The method of plumbing this into the client shutdown hook doesn't
// cover all exit points, so we insulate this against multiple
// invocations and then it's safe to call it a bunch of times.
freeport.Return(ports)
config.NotifyShutdown = nil // self-erasing
}
config.NotifyShutdown = returnPortsFn
config.Datacenter = "dc1" config.Datacenter = "dc1"
config.DataDir = dir config.DataDir = dir
config.NodeName = uniqueNodeName(t.Name()) config.NodeName = uniqueNodeName(t.Name())
config.RPCAddr = &net.TCPAddr{ config.RPCAddr = &net.TCPAddr{
IP: []byte{127, 0, 0, 1}, IP: []byte{127, 0, 0, 1},
Port: freeport.Get(1)[0], Port: ports[0],
} }
config.SerfLANConfig.MemberlistConfig.BindAddr = "127.0.0.1" config.SerfLANConfig.MemberlistConfig.BindAddr = "127.0.0.1"
config.SerfLANConfig.MemberlistConfig.BindPort = freeport.Get(1)[0] config.SerfLANConfig.MemberlistConfig.BindPort = ports[1]
config.SerfLANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond config.SerfLANConfig.MemberlistConfig.ProbeTimeout = 200 * time.Millisecond
config.SerfLANConfig.MemberlistConfig.ProbeInterval = time.Second config.SerfLANConfig.MemberlistConfig.ProbeInterval = time.Second
config.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond config.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond
@ -59,6 +71,7 @@ func testClientWithConfig(t *testing.T, cb func(c *Config)) (string, *Client) {
} }
client, err := NewClient(config) client, err := NewClient(config)
if err != nil { if err != nil {
config.NotifyShutdown()
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
return dir, client return dir, client
@ -416,6 +429,7 @@ func TestClient_RPC_TLS(t *testing.T) {
defer s1.Shutdown() defer s1.Shutdown()
dir2, conf2 := testClientConfig(t) dir2, conf2 := testClientConfig(t)
defer conf2.NotifyShutdown()
conf2.VerifyOutgoing = true conf2.VerifyOutgoing = true
configureTLS(conf2) configureTLS(conf2)
c1, err := NewClient(conf2) c1, err := NewClient(conf2)
@ -460,6 +474,7 @@ func TestClient_RPC_RateLimit(t *testing.T) {
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
dir2, conf2 := testClientConfig(t) dir2, conf2 := testClientConfig(t)
defer conf2.NotifyShutdown()
conf2.RPCRate = 2 conf2.RPCRate = 2
conf2.RPCMaxBurst = 2 conf2.RPCMaxBurst = 2
c1, err := NewClient(conf2) c1, err := NewClient(conf2)
@ -527,6 +542,7 @@ func TestClient_SnapshotRPC_RateLimit(t *testing.T) {
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
dir2, conf1 := testClientConfig(t) dir2, conf1 := testClientConfig(t)
defer conf1.NotifyShutdown()
conf1.RPCRate = 2 conf1.RPCRate = 2
conf1.RPCMaxBurst = 2 conf1.RPCMaxBurst = 2
c1, err := NewClient(conf1) c1, err := NewClient(conf1)
@ -569,6 +585,7 @@ func TestClient_SnapshotRPC_TLS(t *testing.T) {
defer s1.Shutdown() defer s1.Shutdown()
dir2, conf2 := testClientConfig(t) dir2, conf2 := testClientConfig(t)
defer conf2.NotifyShutdown()
conf2.VerifyOutgoing = true conf2.VerifyOutgoing = true
configureTLS(conf2) configureTLS(conf2)
c1, err := NewClient(conf2) c1, err := NewClient(conf2)

View File

@ -110,6 +110,9 @@ type Config struct {
// configured at this point. // configured at this point.
NotifyListen func() NotifyListen func()
// NotifyShutdown is called after Server is completely Shutdown.
NotifyShutdown func()
// RPCAddr is the RPC address used by Consul. This should be reachable // RPCAddr is the RPC address used by Consul. This should be reachable
// by the WAN and LAN // by the WAN and LAN
RPCAddr *net.TCPAddr RPCAddr *net.TCPAddr

View File

@ -11,7 +11,7 @@ import (
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/net-rpc-msgpackrpc" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"github.com/pascaldekloe/goe/verify" "github.com/pascaldekloe/goe/verify"
) )
@ -145,10 +145,13 @@ func TestOperator_RaftRemovePeerByAddress(t *testing.T) {
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
ports := freeport.MustTake(1)
defer freeport.Return(ports)
// Try to remove a peer that's not there. // Try to remove a peer that's not there.
arg := structs.RaftRemovePeerRequest{ arg := structs.RaftRemovePeerRequest{
Datacenter: "dc1", Datacenter: "dc1",
Address: raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", freeport.Get(1)[0])), Address: raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", ports[0])),
} }
var reply struct{} var reply struct{}
err := msgpackrpc.CallWithCodec(codec, "Operator.RaftRemovePeerByAddress", &arg, &reply) err := msgpackrpc.CallWithCodec(codec, "Operator.RaftRemovePeerByAddress", &arg, &reply)
@ -277,7 +280,10 @@ func TestOperator_RaftRemovePeerByID(t *testing.T) {
// Add it manually to Raft. // Add it manually to Raft.
{ {
future := s1.raft.AddVoter(arg.ID, raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", freeport.Get(1)[0])), 0, 0) ports := freeport.MustTake(1)
defer freeport.Return(ports)
future := s1.raft.AddVoter(arg.ID, raft.ServerAddress(fmt.Sprintf("127.0.0.1:%d", ports[0])), 0, 0)
if err := future.Error(); err != nil { if err := future.Error(); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -843,6 +843,10 @@ func (s *Server) Shutdown() error {
// Close the connection pool // Close the connection pool
s.connPool.Shutdown() s.connPool.Shutdown()
if s.config.NotifyShutdown != nil {
s.config.NotifyShutdown()
}
return nil return nil
} }

View File

@ -42,7 +42,17 @@ func testServerConfig(t *testing.T) (string, *Config) {
dir := testutil.TempDir(t, "consul") dir := testutil.TempDir(t, "consul")
config := DefaultConfig() config := DefaultConfig()
ports := freeport.Get(3) ports := freeport.MustTake(3)
returnPortsFn := func() {
// The method of plumbing this into the server shutdown hook doesn't
// cover all exit points, so we insulate this against multiple
// invocations and then it's safe to call it a bunch of times.
freeport.Return(ports)
config.NotifyShutdown = nil // self-erasing
}
config.NotifyShutdown = returnPortsFn
config.NodeName = uniqueNodeName(t.Name()) config.NodeName = uniqueNodeName(t.Name())
config.Bootstrap = true config.Bootstrap = true
config.Datacenter = "dc1" config.Datacenter = "dc1"
@ -56,6 +66,7 @@ func testServerConfig(t *testing.T) (string, *Config) {
nodeID, err := uuid.GenerateUUID() nodeID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
returnPortsFn()
t.Fatal(err) t.Fatal(err)
} }
config.NodeID = types.NodeID(nodeID) config.NodeID = types.NodeID(nodeID)
@ -112,6 +123,8 @@ func testServerConfig(t *testing.T) (string, *Config) {
}, },
} }
config.NotifyShutdown = returnPortsFn
return dir, config return dir, config
} }
@ -168,6 +181,7 @@ func testServerWithConfig(t *testing.T, cb func(*Config)) (string, *Server) {
srv, err = newServer(config) srv, err = newServer(config)
if err != nil { if err != nil {
config.NotifyShutdown()
os.RemoveAll(dir) os.RemoveAll(dir)
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }

View File

@ -738,6 +738,7 @@ func TestParseWait(t *testing.T) {
t.Fatalf("Bad: %v", b) t.Fatalf("Bad: %v", b)
} }
} }
func TestPProfHandlers_EnableDebug(t *testing.T) { func TestPProfHandlers_EnableDebug(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) require := require.New(t)
@ -751,6 +752,7 @@ func TestPProfHandlers_EnableDebug(t *testing.T) {
require.Equal(http.StatusOK, resp.Code) require.Equal(http.StatusOK, resp.Code)
} }
func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) { func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) require := require.New(t)

View File

@ -52,6 +52,10 @@ type TestAgent struct {
// when Shutdown() is called. // when Shutdown() is called.
Config *config.RuntimeConfig Config *config.RuntimeConfig
// returnPortsFn will put the ports claimed for the test back into the
// general freeport pool
returnPortsFn func()
// LogOutput is the sink for the logs. If nil, logs are written // LogOutput is the sink for the logs. If nil, logs are written
// to os.Stderr. // to os.Stderr.
LogOutput io.Writer LogOutput io.Writer
@ -150,12 +154,21 @@ func (a *TestAgent) Start() (err error) {
hclDataDir = `data_dir = "` + d + `"` hclDataDir = `data_dir = "` + d + `"`
} }
portsConfig, returnPortsFn := randomPortsSource(a.UseTLS)
a.returnPortsFn = returnPortsFn
a.Config = TestConfig( a.Config = TestConfig(
randomPortsSource(a.UseTLS), portsConfig,
config.Source{Name: a.Name, Format: "hcl", Data: a.HCL}, config.Source{Name: a.Name, Format: "hcl", Data: a.HCL},
config.Source{Name: a.Name + ".data_dir", Format: "hcl", Data: hclDataDir}, config.Source{Name: a.Name + ".data_dir", Format: "hcl", Data: hclDataDir},
) )
defer func() {
if err != nil && a.returnPortsFn != nil {
a.returnPortsFn()
a.returnPortsFn = nil
}
}()
// write the keyring // write the keyring
if a.Key != "" { if a.Key != "" {
writeKey := func(key, filename string) error { writeKey := func(key, filename string) error {
@ -286,6 +299,14 @@ func (a *TestAgent) Shutdown() error {
return nil return nil
} }
// Return ports last of all
defer func() {
if a.returnPortsFn != nil {
a.returnPortsFn()
a.returnPortsFn = nil
}
}()
// shutdown agent before endpoints // shutdown agent before endpoints
defer a.Agent.ShutdownEndpoints() defer a.Agent.ShutdownEndpoints()
if err := a.Agent.ShutdownAgent(); err != nil { if err := a.Agent.ShutdownAgent(); err != nil {
@ -350,27 +371,32 @@ func (a *TestAgent) consulConfig() *consul.Config {
// chance of port conflicts for concurrently executed test binaries. // chance of port conflicts for concurrently executed test binaries.
// Instead of relying on one set of ports to be sufficient we retry // Instead of relying on one set of ports to be sufficient we retry
// starting the agent with different ports on port conflict. // starting the agent with different ports on port conflict.
func randomPortsSource(tls bool) config.Source { func randomPortsSource(tls bool) (src config.Source, returnPortsFn func()) {
ports := freeport.Get(6) ports := freeport.MustTake(6)
var http, https int
if tls { if tls {
ports[1] = -1 http = -1
https = ports[2]
} else { } else {
ports[2] = -1 http = ports[1]
https = -1
} }
return config.Source{ return config.Source{
Name: "ports", Name: "ports",
Format: "hcl", Format: "hcl",
Data: ` Data: `
ports = { ports = {
dns = ` + strconv.Itoa(ports[0]) + ` dns = ` + strconv.Itoa(ports[0]) + `
http = ` + strconv.Itoa(ports[1]) + ` http = ` + strconv.Itoa(http) + `
https = ` + strconv.Itoa(ports[2]) + ` https = ` + strconv.Itoa(https) + `
serf_lan = ` + strconv.Itoa(ports[3]) + ` serf_lan = ` + strconv.Itoa(ports[3]) + `
serf_wan = ` + strconv.Itoa(ports[4]) + ` serf_wan = ` + strconv.Itoa(ports[4]) + `
server = ` + strconv.Itoa(ports[5]) + ` server = ` + strconv.Itoa(ports[5]) + `
} }
`, `,
} }, func() { freeport.Return(ports) }
} }
func NodeID() string { func NodeID() string {

View File

@ -4,13 +4,14 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/hashicorp/consul/connect"
"log" "log"
"net" "net"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/connect"
metrics "github.com/armon/go-metrics" metrics "github.com/armon/go-metrics"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -110,7 +111,8 @@ func TestPublicListener(t *testing.T) {
// Can't enable t.Parallel since we rely on the global metrics instance. // Can't enable t.Parallel since we rely on the global metrics instance.
ca := agConnect.TestCA(t, nil) ca := agConnect.TestCA(t, nil)
ports := freeport.GetT(t, 1) ports := freeport.MustTake(1)
defer freeport.Return(ports)
testApp := NewTestTCPServer(t) testApp := NewTestTCPServer(t)
defer testApp.Close() defer testApp.Close()
@ -162,7 +164,8 @@ func TestUpstreamListener(t *testing.T) {
// Can't enable t.Parallel since we rely on the global metrics instance. // Can't enable t.Parallel since we rely on the global metrics instance.
ca := agConnect.TestCA(t, nil) ca := agConnect.TestCA(t, nil)
ports := freeport.GetT(t, 1) ports := freeport.MustTake(1)
defer freeport.Return(ports)
// Run a test server that we can dial. // Run a test server that we can dial.
testSvr := connect.NewTestServer(t, "db", ca) testSvr := connect.NewTestServer(t, "db", ca)

View File

@ -22,7 +22,9 @@ func TestProxy_public(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) require := require.New(t)
ports := freeport.GetT(t, 1)
ports := freeport.MustTake(1)
defer freeport.Return(ports)
a := agent.NewTestAgent(t, t.Name(), "") a := agent.NewTestAgent(t, t.Name(), "")
defer a.Shutdown() defer a.Shutdown()

View File

@ -24,14 +24,15 @@ type TestTCPServer struct {
l net.Listener l net.Listener
stopped int32 stopped int32
accepted, closed, active int32 accepted, closed, active int32
returnPortsFn func()
} }
// NewTestTCPServer opens as a listening socket on the given address and returns // NewTestTCPServer opens as a listening socket on the given address and returns
// a TestTCPServer serving requests to it. The server is already started and can // a TestTCPServer serving requests to it. The server is already started and can
// be stopped by calling Close(). // be stopped by calling Close().
func NewTestTCPServer(t testing.T) *TestTCPServer { func NewTestTCPServer(t testing.T) *TestTCPServer {
port := freeport.GetT(t, 1) ports := freeport.MustTake(1)
addr := TestLocalAddr(port[0]) addr := TestLocalAddr(ports[0])
l, err := net.Listen("tcp", addr) l, err := net.Listen("tcp", addr)
require.NoError(t, err) require.NoError(t, err)
@ -39,6 +40,7 @@ func NewTestTCPServer(t testing.T) *TestTCPServer {
log.Printf("test tcp server listening on %s", addr) log.Printf("test tcp server listening on %s", addr)
s := &TestTCPServer{ s := &TestTCPServer{
l: l, l: l,
returnPortsFn: func() { freeport.Return(ports) },
} }
go s.accept() go s.accept()
@ -51,6 +53,10 @@ func (s *TestTCPServer) Close() {
if s.l != nil { if s.l != nil {
s.l.Close() s.l.Close()
} }
if s.returnPortsFn != nil {
s.returnPortsFn()
s.returnPortsFn = nil
}
} }
// Addr returns the address that this server is listening on. // Addr returns the address that this server is listening on.

View File

@ -95,6 +95,7 @@ type TestServer struct {
Listening chan struct{} Listening chan struct{}
l net.Listener l net.Listener
returnPortsFn func()
stopFlag int32 stopFlag int32
stopChan chan struct{} stopChan chan struct{}
} }
@ -102,7 +103,7 @@ type TestServer struct {
// NewTestServer returns a TestServer. It should be closed when test is // NewTestServer returns a TestServer. It should be closed when test is
// complete. // complete.
func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer { func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer {
ports := freeport.GetT(t, 1) ports := freeport.MustTake(1)
return &TestServer{ return &TestServer{
Service: service, Service: service,
CA: ca, CA: ca,
@ -110,6 +111,7 @@ func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer
TLSCfg: TestTLSConfig(t, service, ca), TLSCfg: TestTLSConfig(t, service, ca),
Addr: fmt.Sprintf("127.0.0.1:%d", ports[0]), Addr: fmt.Sprintf("127.0.0.1:%d", ports[0]),
Listening: make(chan struct{}), Listening: make(chan struct{}),
returnPortsFn: func() { freeport.Return(ports) },
} }
} }
@ -186,6 +188,10 @@ func (s *TestServer) Close() error {
if s.l != nil { if s.l != nil {
s.l.Close() s.l.Close()
} }
if s.returnPortsFn != nil {
s.returnPortsFn()
s.returnPortsFn = nil
}
close(s.stopChan) close(s.stopChan)
} }
return nil return nil

View File

@ -0,0 +1,7 @@
//+build !linux
package freeport
func getEphemeralPortRange() (int, int, error) {
return 0, 0, nil
}

View File

@ -0,0 +1,36 @@
//+build linux
package freeport
import (
"fmt"
"os/exec"
"regexp"
"strconv"
)
const ephemeralPortRangeSysctlKey = "net.ipv4.ip_local_port_range"
var ephemeralPortRangePatt = regexp.MustCompile(`^\s*(\d+)\s+(\d+)\s*$`)
func getEphemeralPortRange() (int, int, error) {
cmd := exec.Command("sysctl", "-n", ephemeralPortRangeSysctlKey)
out, err := cmd.Output()
if err != nil {
return 0, 0, err
}
val := string(out)
m := ephemeralPortRangePatt.FindStringSubmatch(val)
if m != nil {
min, err1 := strconv.Atoi(m[1])
max, err2 := strconv.Atoi(m[2])
if err1 == nil && err2 == nil {
return min, max, nil
}
}
return 0, 0, fmt.Errorf("unexpected sysctl value %q for key %q", val, ephemeralPortRangeSysctlKey)
}

View File

@ -0,0 +1,18 @@
//+build linux
package freeport
import (
"testing"
)
func TestGetEphemeralPortRange(t *testing.T) {
min, max, err := getEphemeralPortRange()
if err != nil {
t.Fatalf("err: %v", err)
}
if min <= 0 || max <= 0 || min > max {
t.Fatalf("unexpected values: min=%d, max=%d", min, max)
}
t.Logf("min=%d, max=%d", min, max)
}

View File

@ -3,9 +3,12 @@
package freeport package freeport
import ( import (
"container/list"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"os"
"runtime"
"sync" "sync"
"time" "time"
@ -14,12 +17,10 @@ import (
const ( const (
// blockSize is the size of the allocated port block. ports are given out // blockSize is the size of the allocated port block. ports are given out
// consecutively from that block with roll-over for the lifetime of the // consecutively from that block and after that point in a LRU fashion.
// application/test run.
blockSize = 1500 blockSize = 1500
// maxBlocks is the number of available port blocks. // maxBlocks is the number of available port blocks before exclusions.
// lowPort + maxBlocks * blockSize must be less than 65535.
maxBlocks = 30 maxBlocks = 30
// lowPort is the lowest port number that should be used. // lowPort is the lowest port number that should be used.
@ -31,31 +32,158 @@ const (
) )
var ( var (
// effectiveMaxBlocks is the number of available port blocks.
// lowPort + effectiveMaxBlocks * blockSize must be less than 65535.
effectiveMaxBlocks int
// firstPort is the first port of the allocated block. // firstPort is the first port of the allocated block.
firstPort int firstPort int
// lockLn is the system-wide mutex for the port block. // lockLn is the system-wide mutex for the port block.
lockLn net.Listener lockLn net.Listener
// mu guards nextPort // mu guards:
// - pendingPorts
// - freePorts
// - total
mu sync.Mutex mu sync.Mutex
// once is used to do the initialization on the first call to retrieve free // once is used to do the initialization on the first call to retrieve free
// ports // ports
once sync.Once once sync.Once
// port is the last allocated port. // condNotEmpty is a condition variable to wait for freePorts to be not
port int // empty. Linked to 'mu'
condNotEmpty *sync.Cond
// freePorts is a FIFO of all currently free ports. Take from the front,
// and return to the back.
freePorts *list.List
// pendingPorts is a FIFO of recently freed ports that have not yet passed
// the not-in-use check.
pendingPorts *list.List
// total is the total number of available ports in the block for use.
total int
) )
// initialize is used to initialize freeport. // initialize is used to initialize freeport.
func initialize() { func initialize() {
if lowPort+maxBlocks*blockSize > 65535 { var err error
effectiveMaxBlocks, err = adjustMaxBlocks()
if err != nil {
panic("freeport: ephemeral port range detection failed: " + err.Error())
}
if effectiveMaxBlocks < 0 {
panic("freeport: no blocks of ports available outside of ephemeral range")
}
if lowPort+effectiveMaxBlocks*blockSize > 65535 {
panic("freeport: block size too big or too many blocks requested") panic("freeport: block size too big or too many blocks requested")
} }
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
firstPort, lockLn = alloc() firstPort, lockLn = alloc()
condNotEmpty = sync.NewCond(&mu)
freePorts = list.New()
pendingPorts = list.New()
// fill with all available free ports
for port := firstPort + 1; port < firstPort+blockSize; port++ {
if used := isPortInUse(port); !used {
freePorts.PushBack(port)
}
}
total = freePorts.Len()
go checkFreedPorts()
}
// reset will reverse the setup from initialize() and then redo it (for tests)
func reset() {
mu.Lock()
defer mu.Unlock()
logf("INFO", "resetting the freeport package state")
effectiveMaxBlocks = 0
firstPort = 0
if lockLn != nil {
lockLn.Close()
lockLn = nil
}
once = sync.Once{}
freePorts = nil
pendingPorts = nil
total = 0
}
func checkFreedPorts() {
ticker := time.NewTicker(250 * time.Millisecond)
for {
<-ticker.C
checkFreedPortsOnce()
}
}
func checkFreedPortsOnce() {
mu.Lock()
defer mu.Unlock()
pending := pendingPorts.Len()
remove := make([]*list.Element, 0, pending)
for elem := pendingPorts.Front(); elem != nil; elem = elem.Next() {
port := elem.Value.(int)
if used := isPortInUse(port); !used {
freePorts.PushBack(port)
remove = append(remove, elem)
}
}
retained := pending - len(remove)
if retained > 0 {
logf("WARN", "%d out of %d pending ports are still in use; something probably didn't wait around for the port to be closed!", retained, pending)
}
if len(remove) == 0 {
return
}
for _, elem := range remove {
pendingPorts.Remove(elem)
}
condNotEmpty.Broadcast()
}
// adjustMaxBlocks avoids having the allocation ranges overlap the ephemeral
// port range.
func adjustMaxBlocks() (int, error) {
ephemeralPortMin, ephemeralPortMax, err := getEphemeralPortRange()
if err != nil {
return 0, err
}
if ephemeralPortMin <= 0 || ephemeralPortMax <= 0 {
logf("INFO", "ephemeral port range detection not configured for GOOS=%q", runtime.GOOS)
return maxBlocks, nil
}
logf("INFO", "detected ephemeral port range of [%d, %d]", ephemeralPortMin, ephemeralPortMax)
for block := 0; block < maxBlocks; block++ {
min := lowPort + block*blockSize
max := min + blockSize
overlap := intervalOverlap(min, max-1, ephemeralPortMin, ephemeralPortMax)
if overlap {
logf("INFO", "reducing max blocks from %d to %d to avoid the ephemeral port range", maxBlocks, block)
return block, nil
}
}
return maxBlocks, nil
} }
// alloc reserves a port block for exclusive use for the lifetime of the // alloc reserves a port block for exclusive use for the lifetime of the
@ -64,76 +192,154 @@ func initialize() {
// be automatically released when the application terminates. // be automatically released when the application terminates.
func alloc() (int, net.Listener) { func alloc() (int, net.Listener) {
for i := 0; i < attempts; i++ { for i := 0; i < attempts; i++ {
block := int(rand.Int31n(int32(maxBlocks))) block := int(rand.Int31n(int32(effectiveMaxBlocks)))
firstPort := lowPort + block*blockSize firstPort := lowPort + block*blockSize
ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", firstPort)) ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", firstPort))
if err != nil { if err != nil {
continue continue
} }
// log.Printf("[DEBUG] freeport: allocated port block %d (%d-%d)", block, firstPort, firstPort+blockSize-1) // logf("DEBUG", "allocated port block %d (%d-%d)", block, firstPort, firstPort+blockSize-1)
return firstPort, ln return firstPort, ln
} }
panic("freeport: cannot allocate port block") panic("freeport: cannot allocate port block")
} }
// MustTake is the same as Take except it panics on error.
func MustTake(n int) (ports []int) {
ports, err := Take(n)
if err != nil {
panic(err)
}
return ports
}
// Take returns a list of free ports from the allocated port block. It is safe
// to call this method concurrently. Ports have been tested to be available on
// 127.0.0.1 TCP but there is no guarantee that they will remain free in the
// future.
func Take(n int) (ports []int, err error) {
if n <= 0 {
return nil, fmt.Errorf("freeport: cannot take %d ports", n)
}
mu.Lock()
defer mu.Unlock()
// Reserve a port block
once.Do(initialize)
if n > total {
return nil, fmt.Errorf("freeport: block size too small")
}
for len(ports) < n {
for freePorts.Len() == 0 {
if total == 0 {
return nil, fmt.Errorf("freeport: impossible to satisfy request; there are no actual free ports in the block anymore")
}
condNotEmpty.Wait()
}
elem := freePorts.Front()
freePorts.Remove(elem)
port := elem.Value.(int)
if used := isPortInUse(port); used {
// Something outside of the test suite has stolen this port, possibly
// due to assignment to an ephemeral port, remove it completely.
logf("WARN", "leaked port %d due to theft; removing from circulation", port)
total--
continue
}
ports = append(ports, port)
}
// logf("DEBUG", "free ports: %v", ports)
return ports, nil
}
// peekFree returns the next port that will be returned by Take to aid in testing.
func peekFree() int {
mu.Lock()
defer mu.Unlock()
return freePorts.Front().Value.(int)
}
// peekAllFree returns all free ports that could be returned by Take to aid in testing.
func peekAllFree() []int {
mu.Lock()
defer mu.Unlock()
var out []int
for elem := freePorts.Front(); elem != nil; elem = elem.Next() {
port := elem.Value.(int)
out = append(out, port)
}
return out
}
// stats returns diagnostic data to aid in testing
func stats() (numTotal, numPending, numFree int) {
mu.Lock()
defer mu.Unlock()
return total, pendingPorts.Len(), freePorts.Len()
}
// Return returns a block of ports back to the general pool. These ports should
// have been returned from a call to Take().
func Return(ports []int) {
if len(ports) == 0 {
return // convenience short circuit for test ergonomics
}
mu.Lock()
defer mu.Unlock()
for _, port := range ports {
if port > firstPort && port < firstPort+blockSize {
pendingPorts.PushBack(port)
}
}
}
func isPortInUse(port int) bool {
ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port))
if err != nil {
return true
}
ln.Close()
return false
}
func tcpAddr(ip string, port int) *net.TCPAddr { func tcpAddr(ip string, port int) *net.TCPAddr {
return &net.TCPAddr{IP: net.ParseIP(ip), Port: port} return &net.TCPAddr{IP: net.ParseIP(ip), Port: port}
} }
// Get wraps the Free function and panics on any failure retrieving ports. // intervalOverlap returns true if the doubly-inclusive integer intervals
func Get(n int) (ports []int) { // represented by [min1, max1] and [min2, max2] overlap.
ports, err := Free(n) func intervalOverlap(min1, max1, min2, max2 int) bool {
if err != nil { if min1 > max1 {
panic(err) logf("WARN", "interval1 is not ordered [%d, %d]", min1, max1)
return false
} }
if min2 > max2 {
return ports logf("WARN", "interval2 is not ordered [%d, %d]", min2, max2)
return false
}
return min1 <= max2 && min2 <= max1
} }
// GetT is suitable for use when retrieving unused ports in tests. If there is func logf(severity string, format string, a ...interface{}) {
// an error retrieving free ports, the test will be failed. fmt.Fprintf(os.Stderr, "["+severity+"] freeport: "+format+"\n", a...)
func GetT(t testing.T, n int) (ports []int) {
ports, err := Free(n)
if err != nil {
t.Fatalf("Failed retrieving free port: %v", err)
}
return ports
} }
// Free returns a list of free ports from the allocated port block. It is safe // Deprecated: Please use Take/Return calls instead.
// to call this method concurrently. Ports have been tested to be available on func Get(n int) (ports []int) { return MustTake(n) }
// 127.0.0.1 TCP but there is no guarantee that they will remain free in the
// future.
func Free(n int) (ports []int, err error) {
mu.Lock()
defer mu.Unlock()
if n > blockSize-1 { // Deprecated: Please use Take/Return calls instead.
return nil, fmt.Errorf("freeport: block size too small") func GetT(t testing.T, n int) (ports []int) { return MustTake(n) }
}
// Reserve a port block // Deprecated: Please use Take/Return calls instead.
once.Do(initialize) func Free(n int) (ports []int, err error) { return MustTake(n), nil }
for len(ports) < n {
port++
// roll-over the port
if port < firstPort+1 || port >= firstPort+blockSize {
port = firstPort + 1
}
// if the port is in use then skip it
ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port))
if err != nil {
// log.Println("[DEBUG] freeport: port already in use: ", port)
continue
}
ln.Close()
ports = append(ports, port)
}
// log.Println("[DEBUG] freeport: free ports:", ports)
return ports, nil
}

View File

@ -0,0 +1,231 @@
package freeport
import (
"fmt"
"io"
"net"
"testing"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
func TestTakeReturn(t *testing.T) {
// NOTE: for global var reasons this cannot execute in parallel
// t.Parallel()
// Since this test is destructive (i.e. it leaks all ports) it means that
// any other test cases in this package will not function after it runs. To
// help out we reset the global state after we run this test.
defer reset()
// OK: do a simple take/return cycle to trigger the package initialization
func() {
ports, err := Take(1)
if err != nil {
t.Fatalf("err: %v", err)
}
defer Return(ports)
if len(ports) != 1 {
t.Fatalf("expected %d but got %d ports", 1, len(ports))
}
}()
waitForStatsReset := func() (numTotal int) {
t.Helper()
numTotal, numPending, numFree := stats()
if numTotal != numFree+numPending {
t.Fatalf("expected total (%d) and free+pending (%d) ports to match", numTotal, numFree+numPending)
}
retry.Run(t, func(r *retry.R) {
numTotal, numPending, numFree = stats()
if numPending != 0 {
r.Fatalf("pending is still non zero: %d", numPending)
}
if numTotal != numFree {
r.Fatalf("total (%d) does not equal free (%d)", numTotal, numFree)
}
})
return numTotal
}
// Reset
numTotal := waitForStatsReset()
// --------------------
// OK: take the max
func() {
ports, err := Take(numTotal)
if err != nil {
t.Fatalf("err: %v", err)
}
defer Return(ports)
if len(ports) != numTotal {
t.Fatalf("expected %d but got %d ports", numTotal, len(ports))
}
}()
// Reset
numTotal = waitForStatsReset()
expectError := func(expected string, got error) {
t.Helper()
if got == nil {
t.Fatalf("expected error but was nil")
}
if got.Error() != expected {
t.Fatalf("expected error %q but got %q", expected, got.Error())
}
}
// --------------------
// ERROR: take too many ports
func() {
ports, err := Take(numTotal + 1)
defer Return(ports)
expectError("freeport: block size too small", err)
}()
// --------------------
// ERROR: invalid ports request (negative)
func() {
_, err := Take(-1)
expectError("freeport: cannot take -1 ports", err)
}()
// --------------------
// ERROR: invalid ports request (zero)
func() {
_, err := Take(0)
expectError("freeport: cannot take 0 ports", err)
}()
// --------------------
// OK: Steal a port under the covers and let freeport detect the theft and compensate
leakedPort := peekFree()
func() {
leakyListener, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", leakedPort))
if err != nil {
t.Fatalf("err: %v", err)
}
defer leakyListener.Close()
func() {
ports, err := Take(3)
if err != nil {
t.Fatalf("err: %v", err)
}
defer Return(ports)
if len(ports) != 3 {
t.Fatalf("expected %d but got %d ports", 3, len(ports))
}
for _, port := range ports {
if port == leakedPort {
t.Fatalf("did not expect for Take to return the leaked port")
}
}
}()
newNumTotal := waitForStatsReset()
if newNumTotal != numTotal-1 {
t.Fatalf("expected total to drop to %d but got %d", numTotal-1, newNumTotal)
}
numTotal = newNumTotal // update outer variable for later tests
}()
// --------------------
// OK: sequence it so that one Take must wait on another Take to Return.
func() {
mostPorts, err := Take(numTotal - 5)
if err != nil {
t.Fatalf("err: %v", err)
}
type reply struct {
ports []int
err error
}
ch := make(chan reply, 1)
go func() {
ports, err := Take(10)
ch <- reply{ports: ports, err: err}
}()
Return(mostPorts)
r := <-ch
if r.err != nil {
t.Fatalf("err: %v", r.err)
}
defer Return(r.ports)
if len(r.ports) != 10 {
t.Fatalf("expected %d ports but got %d", 10, len(r.ports))
}
}()
// Reset
numTotal = waitForStatsReset()
// --------------------
// ERROR: Now we end on the crazy "Ocean's 11" level port theft where we
// orchestrate a situation where all ports are stolen and we don't find out
// until Take.
func() {
// 1. Grab all of the ports.
allPorts := peekAllFree()
// 2. Leak all of the ports
leaked := make([]io.Closer, 0, len(allPorts))
defer func() {
for _, c := range leaked {
c.Close()
}
}()
for _, port := range allPorts {
ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port))
if err != nil {
t.Fatalf("err: %v", err)
}
leaked = append(leaked, ln)
}
// 3. Request 1 port which will detect the leaked ports and fail.
_, err := Take(1)
expectError("freeport: impossible to satisfy request; there are no actual free ports in the block anymore", err)
// 4. Wait for the block to zero out.
newNumTotal := waitForStatsReset()
if newNumTotal != 0 {
t.Fatalf("expected total to drop to %d but got %d", 0, newNumTotal)
}
}()
}
func TestIntervalOverlap(t *testing.T) {
cases := []struct {
min1, max1, min2, max2 int
overlap bool
}{
{0, 0, 0, 0, true},
{1, 1, 1, 1, true},
{1, 3, 1, 3, true}, // same
{1, 3, 4, 6, false}, // serial
{1, 4, 3, 6, true}, // inner overlap
{1, 6, 3, 4, true}, // nest
}
for _, tc := range cases {
t.Run(fmt.Sprintf("%d:%d vs %d:%d", tc.min1, tc.max1, tc.min2, tc.max2), func(t *testing.T) {
if tc.overlap != intervalOverlap(tc.min1, tc.max1, tc.min2, tc.max2) { // 1 vs 2
t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap)
}
if tc.overlap != intervalOverlap(tc.min2, tc.max2, tc.min1, tc.max1) { // 2 vs 1
t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap)
}
})
}
}

View File

@ -104,6 +104,7 @@ type TestServerConfig struct {
ReadyTimeout time.Duration `json:"-"` ReadyTimeout time.Duration `json:"-"`
Stdout, Stderr io.Writer `json:"-"` Stdout, Stderr io.Writer `json:"-"`
Args []string `json:"-"` Args []string `json:"-"`
ReturnPorts func() `json:"-"`
} }
type TestACLs struct { type TestACLs struct {
@ -138,7 +139,8 @@ func defaultServerConfig() *TestServerConfig {
panic(err) panic(err)
} }
ports := freeport.Get(6) ports := freeport.MustTake(6)
return &TestServerConfig{ return &TestServerConfig{
NodeName: "node-" + nodeID, NodeName: "node-" + nodeID,
NodeID: nodeID, NodeID: nodeID,
@ -167,6 +169,9 @@ func defaultServerConfig() *TestServerConfig {
"cluster_id": "11111111-2222-3333-4444-555555555555", "cluster_id": "11111111-2222-3333-4444-555555555555",
}, },
}, },
ReturnPorts: func() {
freeport.Return(ports)
},
} }
} }
@ -244,6 +249,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e
} }
cfg := defaultServerConfig() cfg := defaultServerConfig()
cfg.DataDir = filepath.Join(tmpdir, "data") cfg.DataDir = filepath.Join(tmpdir, "data")
if cb != nil { if cb != nil {
cb(cfg) cb(cfg)
@ -251,6 +257,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e
b, err := json.Marshal(cfg) b, err := json.Marshal(cfg)
if err != nil { if err != nil {
cfg.ReturnPorts()
os.RemoveAll(tmpdir) os.RemoveAll(tmpdir)
return nil, errors.Wrap(err, "failed marshaling json") return nil, errors.Wrap(err, "failed marshaling json")
} }
@ -258,6 +265,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e
log.Printf("CONFIG JSON: %s", string(b)) log.Printf("CONFIG JSON: %s", string(b))
configFile := filepath.Join(tmpdir, "config.json") configFile := filepath.Join(tmpdir, "config.json")
if err := ioutil.WriteFile(configFile, b, 0644); err != nil { if err := ioutil.WriteFile(configFile, b, 0644); err != nil {
cfg.ReturnPorts()
os.RemoveAll(tmpdir) os.RemoveAll(tmpdir)
return nil, errors.Wrap(err, "failed writing config content") return nil, errors.Wrap(err, "failed writing config content")
} }
@ -278,6 +286,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e
cmd.Stdout = stdout cmd.Stdout = stdout
cmd.Stderr = stderr cmd.Stderr = stderr
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
cfg.ReturnPorts()
os.RemoveAll(tmpdir) os.RemoveAll(tmpdir)
return nil, errors.Wrap(err, "failed starting command") return nil, errors.Wrap(err, "failed starting command")
} }
@ -319,6 +328,7 @@ func newTestServerConfigT(t *testing.T, cb ServerConfigCallback) (*TestServer, e
// Stop stops the test Consul server, and removes the Consul data // Stop stops the test Consul server, and removes the Consul data
// directory once we are done. // directory once we are done.
func (s *TestServer) Stop() error { func (s *TestServer) Stop() error {
defer s.Config.ReturnPorts()
defer os.RemoveAll(s.tmpdir) defer os.RemoveAll(s.tmpdir)
// There was no process // There was no process