Close active listeners on error

If startListeners successfully created listeners for some of its input addresses but eventually failed, the function would return an error and existing listeners would not be cleaned up.
This commit is contained in:
Chris S. Kim 2022-08-09 12:22:39 -04:00
parent 6311c651de
commit e3046120b3
3 changed files with 86 additions and 3 deletions

3
.changelog/14081.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
agent: Fixes an issue where an agent that fails to start due to bad addresses won't clean up any existing listeners
```

View File

@ -863,8 +863,18 @@ func (a *Agent) listenAndServeDNS() error {
return merr.ErrorOrNil() return merr.ErrorOrNil()
} }
// startListeners will return a net.Listener for every address unless an
// error is encountered, in which case it will close all previously opened
// listeners and return the error.
func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) { func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) {
var ln []net.Listener var lns []net.Listener
closeAll := func() {
for _, l := range lns {
l.Close()
}
}
for _, addr := range addrs { for _, addr := range addrs {
var l net.Listener var l net.Listener
var err error var err error
@ -873,22 +883,25 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) {
case *net.UnixAddr: case *net.UnixAddr:
l, err = a.listenSocket(x.Name) l, err = a.listenSocket(x.Name)
if err != nil { if err != nil {
closeAll()
return nil, err return nil, err
} }
case *net.TCPAddr: case *net.TCPAddr:
l, err = net.Listen("tcp", x.String()) l, err = net.Listen("tcp", x.String())
if err != nil { if err != nil {
closeAll()
return nil, err return nil, err
} }
l = &tcpKeepAliveListener{l.(*net.TCPListener)} l = &tcpKeepAliveListener{l.(*net.TCPListener)}
default: default:
closeAll()
return nil, fmt.Errorf("unsupported address type %T", addr) return nil, fmt.Errorf("unsupported address type %T", addr)
} }
ln = append(ln, l) lns = append(lns, l)
} }
return ln, nil return lns, nil
} }
// listenHTTP binds listeners to the provided addresses and also returns // listenHTTP binds listeners to the provided addresses and also returns

View File

@ -5857,6 +5857,73 @@ func Test_coalesceTimerTwoPeriods(t *testing.T) {
} }
func TestAgent_startListeners(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
ports := freeport.GetN(t, 3)
bd := BaseDeps{
Deps: consul.Deps{
Logger: hclog.NewInterceptLogger(nil),
Tokens: new(token.Store),
GRPCConnPool: &fakeGRPCConnPool{},
},
RuntimeConfig: &config.RuntimeConfig{
HTTPAddrs: []net.Addr{},
},
Cache: cache.New(cache.Options{}),
}
bd, err := initEnterpriseBaseDeps(bd, nil)
require.NoError(t, err)
agent, err := New(bd)
require.NoError(t, err)
// use up an address
used := net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[2]}
l, err := net.Listen("tcp", used.String())
require.NoError(t, err)
t.Cleanup(func() { l.Close() })
var lns []net.Listener
t.Cleanup(func() {
for _, ln := range lns {
ln.Close()
}
})
// first two addresses open listeners but third address should fail
lns, err = agent.startListeners([]net.Addr{
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]},
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]},
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[2]},
})
require.Contains(t, err.Error(), "address already in use")
// first two ports should be freed up
retry.Run(t, func(r *retry.R) {
lns, err = agent.startListeners([]net.Addr{
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]},
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]},
})
require.NoError(r, err)
require.Len(r, lns, 2)
})
// first two ports should be in use
retry.Run(t, func(r *retry.R) {
_, err = agent.startListeners([]net.Addr{
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]},
&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]},
})
require.Contains(r, err.Error(), "address already in use")
})
}
func getExpectedCaPoolByFile(t *testing.T) *x509.CertPool { func getExpectedCaPoolByFile(t *testing.T) *x509.CertPool {
pool := x509.NewCertPool() pool := x509.NewCertPool()
data, err := ioutil.ReadFile("../test/ca/root.cer") data, err := ioutil.ReadFile("../test/ca/root.cer")