diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 9c149c60ef..d1904f5c3f 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -1658,7 +1658,7 @@ func TestACLEndpoint_LoginLogout_jwt(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") // spin up a fake oidc server - oidcServer := startSSOTestServer(t) + oidcServer := oidcauthtest.Start(t, oidcauthtest.WithPort(freeport.Port(t))) pubKey, privKey := oidcServer.SigningKeys() type mConfig = map[string]interface{} @@ -2330,14 +2330,6 @@ func upsertTestCustomizedBindingRule(rpc rpcFn, masterToken string, datacenter s return &out, nil } -func startSSOTestServer(t *testing.T) *oidcauthtest.Server { - ports := freeport.MustTake(1) - return oidcauthtest.Start(t, oidcauthtest.WithPort( - ports[0], - func() { freeport.Return(ports) }, - )) -} - func TestHTTPHandlers_ACLReplicationStatus(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/consul/acl_endpoint_test.go b/agent/consul/acl_endpoint_test.go index 54105fe19c..fa0325fec7 100644 --- a/agent/consul/acl_endpoint_test.go +++ b/agent/consul/acl_endpoint_test.go @@ -4868,7 +4868,7 @@ func TestACLEndpoint_Login_jwt(t *testing.T) { acl := ACL{srv: srv} // spin up a fake oidc server - oidcServer := startSSOTestServer(t) + oidcServer := oidcauthtest.Start(t, oidcauthtest.WithPort(freeport.Port(t))) pubKey, privKey := oidcServer.SigningKeys() type mConfig = map[string]interface{} @@ -5003,14 +5003,6 @@ func TestACLEndpoint_Login_jwt(t *testing.T) { } } -func startSSOTestServer(t *testing.T) *oidcauthtest.Server { - ports := freeport.MustTake(1) - return oidcauthtest.Start(t, oidcauthtest.WithPort( - ports[0], - func() { freeport.Return(ports) }, - )) -} - func TestACLEndpoint_Logout(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/consul/authmethod/ssoauth/sso_test.go b/agent/consul/authmethod/ssoauth/sso_test.go index 1623f0904c..7ef3394806 100644 --- a/agent/consul/authmethod/ssoauth/sso_test.go +++ b/agent/consul/authmethod/ssoauth/sso_test.go @@ -32,7 +32,7 @@ func TestJWT_NewValidator(t *testing.T) { return method } - oidcServer := startTestServer(t) + oidcServer := oidcauthtest.Start(t, oidcauthtest.WithPort(freeport.Port(t))) // Note that we won't test ALL of the available config variations here. // The go-sso library has exhaustive tests. @@ -110,7 +110,7 @@ func TestJWT_ValidateLogin(t *testing.T) { return v } - oidcServer := startTestServer(t) + oidcServer := oidcauthtest.Start(t, oidcauthtest.WithPort(freeport.Port(t))) pubKey, privKey := oidcServer.SigningKeys() cases := map[string]struct { @@ -260,11 +260,3 @@ func kv(a ...string) map[string]string { } return m } - -func startTestServer(t *testing.T) *oidcauthtest.Server { - ports := freeport.MustTake(1) - return oidcauthtest.Start(t, oidcauthtest.WithPort( - ports[0], - func() { freeport.Return(ports) }, - )) -} diff --git a/command/login/login_test.go b/command/login/login_test.go index 8c9309b25c..84a0a19170 100644 --- a/command/login/login_test.go +++ b/command/login/login_test.go @@ -352,7 +352,7 @@ func TestLoginCommand_jwt(t *testing.T) { bearerTokenFile := filepath.Join(testDir, "bearer.token") // spin up a fake oidc server - oidcServer := startSSOTestServer(t) + oidcServer := oidcauthtest.Start(t, oidcauthtest.WithPort(freeport.Port(t))) pubKey, privKey := oidcServer.SigningKeys() type mConfig = map[string]interface{} @@ -470,11 +470,3 @@ func TestLoginCommand_jwt(t *testing.T) { }) } } - -func startSSOTestServer(t *testing.T) *oidcauthtest.Server { - ports := freeport.MustTake(1) - return oidcauthtest.Start(t, oidcauthtest.WithPort( - ports[0], - func() { freeport.Return(ports) }, - )) -} diff --git a/internal/go-sso/oidcauth/oidcauthtest/testing.go b/internal/go-sso/oidcauth/oidcauthtest/testing.go index cdf27a19c8..e5cb56ebf6 100644 --- a/internal/go-sso/oidcauth/oidcauthtest/testing.go +++ b/internal/go-sso/oidcauth/oidcauthtest/testing.go @@ -25,7 +25,6 @@ import ( "time" "github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil" - "github.com/mitchellh/go-testing-interface" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" @@ -55,24 +54,27 @@ type Server struct { } type startOption struct { - port int - returnFunc func() + port int } // WithPort is a option for Start that lets the caller control the port // allocation. The returnFunc parameter is used when the provider is stopped to // return the port in whatever bookkeeping system the caller wants to use. -func WithPort(port int, returnFunc func()) startOption { - return startOption{ - port: port, - returnFunc: returnFunc, - } +func WithPort(port int) startOption { + return startOption{port: port} +} + +type TestingT interface { + require.TestingT + Helper() + Cleanup(func()) } // Start creates a disposable Server. If the port provided is // zero it will bind to a random free port, otherwise the provided port is // used. -func Start(t testing.T, options ...startOption) *Server { +func Start(t TestingT, options ...startOption) *Server { + t.Helper() s := &Server{ allowedRedirectURIs: []string{ "https://example.com", @@ -89,23 +91,16 @@ func Start(t testing.T, options ...startOption) *Server { require.NoError(t, err) s.jwks = jwks - var ( - port int - returnFunc func() - ) + var port int for _, option := range options { if option.port > 0 { port = option.port - returnFunc = option.returnFunc } } s.httpServer = httptestNewUnstartedServerWithPort(s, port) s.httpServer.Config.ErrorLog = log.New(ioutil.Discard, "", 0) s.httpServer.StartTLS() - if returnFunc != nil { - t.Cleanup(returnFunc) - } t.Cleanup(s.httpServer.Close) cert := s.httpServer.Certificate()