lornasong edf4610ed9
[Cloud][CC-6925] Updates to pushing server state (#19682)
* Upgrade hcp-sdk-go to latest version v0.73

Changes:
- go get github.com/hashicorp/hcp-sdk-go
- go mod tidy

* From upgrade: regenerate protobufs for upgrade from 1.30 to 1.31

Ran: `make proto`

Slack: https://hashicorp.slack.com/archives/C0253EQ5B40/p1701105418579429

* From upgrade: fix mock interface implementation

After upgrading, there is the following compile error:

cannot use &mockHCPCfg{} (value of type *mockHCPCfg) as "github.com/hashicorp/hcp-sdk-go/config".HCPConfig value in return statement: *mockHCPCfg does not implement "github.com/hashicorp/hcp-sdk-go/config".HCPConfig (missing method Logout)

Solution: update the mock to have the missing Logout method

* From upgrade: Lint: remove usage of deprecated req.ServerState.TLS

After the upgrade, linting fails because of a usage of a newly deprecated field:

22:47:56 [consul]: make lint
--> Running golangci-lint (.)
agent/hcp/testing.go:157:24: SA1019: req.ServerState.TLS is deprecated: use server_tls.internal_rpc instead. (staticcheck)
                time.Until(time.Time(req.ServerState.TLS.CertExpiry)).Hours()/24,
                                     ^

* From upgrade: adjust oidc error message

From the upgrade, this test started failing:

=== FAIL: internal/go-sso/oidcauth TestOIDC_ClaimsFromAuthCode/failed_code_exchange (re-run 2) (0.01s)
    oidc_test.go:393: unexpected error: Provider login failed: Error exchanging oidc code: oauth2: "invalid_grant" "unexpected auth code"

Prior to the upgrade, the error returned was:
```
Provider login failed: Error exchanging oidc code: oauth2: cannot fetch token: 401 Unauthorized\nResponse: {\"error\":\"invalid_grant\",\"error_description\":\"unexpected auth code\"}\n
```

Now the error returned is as below and does not contain "cannot fetch token"
```
Provider login failed: Error exchanging oidc code: oauth2: "invalid_grant" "unexpected auth code"

```

* Update AgentPushServerState structs with new fields

HCP-side changes for the new fields are in:
https://github.com/hashicorp/cloud-global-network-manager-service/pull/1195/files

* Minor refactor for hcpServerStatus to abstract tlsInfo into struct

This will make it easier to set the same tls-info information to both
 - status.TLS (deprecated field)
 - status.ServerTLSMetadata (new field to use instead)

* Update hcpServerStatus to parse out information for new fields

Changes:
 - Improve error message and handling (encountered some issues and was confused)
 - Set new field TLSInfo.CertIssuer
 - Collect certificate authority metadata and set on TLSInfo.CertificateAuthorities
 - Set TLSInfo on both server.TLS and server.ServerTLSMetadata.InternalRPC

* Update serverStatusToHCP to convert new fields to GNM rpc

* Add changelog

* Feedback: connect.ParseCert, caCerts

* Feedback: refactor and unit test server status

* Feedback: test to use expected struct

* Feedback: certificate with intermediate

* Feedback: catch no leaf, remove expectedErr

* Feedback: update todos with jira ticket

* Feedback: mock tlsConfigurator
2023-12-04 10:25:18 -05:00

515 lines
12 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package oidcauth
import (
"context"
"errors"
"net/url"
"strings"
"testing"
"time"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
// setupForOIDC starts a mock OIDC provider, registers client credentials on
// it, and builds an Authenticator from a validated TypeOIDC Config that points
// at that provider. The Authenticator's Stop method is registered with
// t.Cleanup. Both the Authenticator and the mock provider are returned.
func setupForOIDC(t *testing.T) (*Authenticator, *oidcauthtest.Server) {
	t.Helper()

	const (
		clientID     = "abc"
		clientSecret = "def"
	)

	provider := oidcauthtest.Start(t)
	provider.SetClientCreds(clientID, clientSecret)

	// Claim paths on the left are copied to the names on the right.
	valueMappings := map[string]string{
		"COLOR":            "color",
		"/nested/Size":     "size",
		"Age":              "age",
		"Admin":            "is_admin",
		"/nested/division": "division",
		"/nested/remote":   "is_remote",
		"flavor":           "flavor", // userinfo
	}
	listMappings := map[string]string{
		"/nested/Groups": "groups",
	}

	cfg := &Config{
		Type:                TypeOIDC,
		OIDCDiscoveryURL:    provider.Addr(),
		OIDCDiscoveryCACert: provider.CACert(),
		OIDCClientID:        clientID,
		OIDCClientSecret:    clientSecret,
		OIDCACRValues:       []string{"acr1", "acr2"},
		JWTSupportedAlgs:    []string{"ES256"},
		BoundAudiences:      []string{clientID},
		AllowedRedirectURIs: []string{"https://example.com"},
		ClaimMappings:       valueMappings,
		ListClaimMappings:   listMappings,
	}
	require.NoError(t, cfg.Validate())

	auth, err := New(cfg, hclog.NewNullLogger())
	require.NoError(t, err)
	t.Cleanup(auth.Stop)

	return auth, provider
}
// TestOIDC_AuthURL exercises GetAuthCodeURL: the happy path must produce a URL
// under the provider's /auth endpoint carrying the expected query parameters,
// and redirect URIs that are unauthorized or empty must be rejected.
func TestOIDC_AuthURL(t *testing.T) {
	t.Run("normal case", func(t *testing.T) {
		t.Parallel()

		auth, _ := setupForOIDC(t)

		rawURL, err := auth.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			map[string]string{"foo": "bar"},
		)
		require.NoError(t, err)

		wantPrefix := auth.config.OIDCDiscoveryURL + "/auth?"
		require.True(t, strings.HasPrefix(rawURL, wantPrefix))

		parsed, err := url.Parse(rawURL)
		require.NoError(t, err)
		query := parsed.Query()

		wantParams := map[string]string{
			"client_id":     "abc",
			"redirect_uri":  "https://example.com",
			"response_type": "code",
			"scope":         "openid",
			// optional values
			"acr_values": "acr1 acr2",
		}
		for key, want := range wantParams {
			assert.Equal(t, want, query.Get(key), "key %q is incorrect", key)
		}

		// nonce and state are generated values; only their shape is checked.
		for _, key := range []string{"nonce", "state"} {
			assert.Regexp(t, `^[a-z0-9]{40}$`, query.Get(key))
		}
	})

	rejected := []struct {
		name        string
		redirectURI string
		wantErr     string
	}{
		{
			name:        "invalid RedirectURI",
			redirectURI: "http://bitc0in-4-less.cx",
			wantErr:     "unauthorized redirect_uri: http://bitc0in-4-less.cx",
		},
		{
			name:        "missing RedirectURI",
			redirectURI: "",
			wantErr:     "missing redirect_uri",
		},
	}
	for _, tc := range rejected {
		tc := tc // per-iteration copy for the parallel closure (needed before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			auth, _ := setupForOIDC(t)

			_, err := auth.GetAuthCodeURL(
				context.Background(),
				tc.redirectURI,
				map[string]string{"foo": "bar"},
			)
			requireErrorContains(t, err, tc.wantErr)
		})
	}
}
// TestOIDC_JWT_Functions_Fail signs a JWT and checks that ClaimsFromJWT
// refuses it with an "incompatible with type" error on the authenticator
// returned by setupForOIDC.
func TestOIDC_JWT_Functions_Fail(t *testing.T) {
	auth, provider := setupForOIDC(t)

	type privateClaims struct {
		User   string   `json:"https://go-sso/user"`
		Groups []string `json:"https://go-sso/groups"`
	}

	standard := jwt.Claims{
		Subject:   "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
		Issuer:    provider.Addr(),
		NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
		Audience:  jwt.Audience{"https://go-sso.test"},
	}
	private := privateClaims{
		User:   "jeff",
		Groups: []string{"foo", "bar"},
	}

	token, err := oidcauthtest.SignJWT("", standard, private)
	require.NoError(t, err)

	_, err = auth.ClaimsFromJWT(context.Background(), token)
	requireErrorContains(t, err, `ClaimsFromJWT is incompatible with type "oidc"`)
}
// TestOIDC_ClaimsFromAuthCode drives the auth-code exchange against the mock
// provider. Each subtest builds its own authenticator and provider through
// setupForOIDC, requests an auth URL, then calls ClaimsFromAuthCode and checks
// either the mapped claims or the error text and error type.
func TestOIDC_ClaimsFromAuthCode(t *testing.T) {
	// requireProviderError fails the test unless err wraps a
	// *ProviderLoginFailedError. t.Helper makes the failure point at the
	// calling subtest rather than at this closure.
	requireProviderError := func(t *testing.T, err error) {
		t.Helper()
		var provErr *ProviderLoginFailedError
		if !errors.As(err, &provErr) {
			t.Fatalf("error was not a *ProviderLoginFailedError")
		}
	}

	// requireTokenVerificationError fails the test unless err wraps a
	// *TokenVerificationFailedError.
	requireTokenVerificationError := func(t *testing.T, err error) {
		t.Helper()
		var tokErr *TokenVerificationFailedError
		if !errors.As(err, &tokErr) {
			t.Fatalf("error was not a *TokenVerificationFailedError")
		}
	}

	// expectedClaims returns the claims the successful logins expect for
	// sampleClaims. "flavor" is only expected when the provider's userinfo
	// endpoint is enabled.
	expectedClaims := func(withUserInfo bool) *Claims {
		want := &Claims{
			Values: map[string]string{
				"color":     "green",
				"size":      "medium",
				"age":       "85",
				"is_admin":  "true",
				"division":  "3",
				"is_remote": "true",
			},
			Lists: map[string][]string{
				"groups": {"a", "b"},
			},
		}
		if withUserInfo {
			want.Values["flavor"] = "umami" // from userinfo
		}
		return want
	}

	t.Run("successful login", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		claims, payload, err := oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		require.NoError(t, err)
		require.Equal(t, origPayload, payload)
		require.Equal(t, expectedClaims(true), claims)
	})

	t.Run("failed login unusable claims", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server;
		// COLOR is replaced with a list, which the subtest expects to be
		// rejected when converting to a string
		customClaims := sampleClaims(nonce)
		customClaims["COLOR"] = []interface{}{"yellow"}
		srv.SetCustomClaims(customClaims)

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "error converting claim 'COLOR' to string from unknown type []interface {}")
		requireTokenVerificationError(t, err)
	})

	t.Run("successful login - no userinfo", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		srv.DisableUserInfo()

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		claims, payload, err := oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		require.NoError(t, err)
		require.Equal(t, origPayload, payload)
		require.Equal(t, expectedClaims(false), claims)
	})

	t.Run("failed login - bad nonce", func(t *testing.T) {
		t.Parallel()

		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		// the claims carry a nonce that was not taken from the auth URL
		srv.SetCustomClaims(sampleClaims("bad nonce"))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "Invalid ID token nonce")
		requireTokenVerificationError(t, err)
	})

	t.Run("missing state", func(t *testing.T) {
		oa, _ := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		_, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			"", "abc",
		)
		requireErrorContains(t, err, "Expired or missing OAuth state")
		requireProviderError(t, err)
	})

	t.Run("unknown state", func(t *testing.T) {
		oa, _ := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		_, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			"not_a_state", "abc",
		)
		requireErrorContains(t, err, "Expired or missing OAuth state")
		requireProviderError(t, err)
	})

	t.Run("valid state, missing code", func(t *testing.T) {
		oa, _ := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "",
		)
		requireErrorContains(t, err, "OAuth code parameter not provided")
		requireProviderError(t, err)
	})

	t.Run("failed code exchange", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		// present a code other than the one the provider expects
		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "wrong_code",
		)
		requireErrorContains(t, err, "Error exchanging oidc code")
		requireProviderError(t, err)
	})

	t.Run("no id_token returned", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		srv.OmitIDTokens()

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "No id_token found in response")
		requireTokenVerificationError(t, err)
	})

	t.Run("no response from provider", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		// close the server prematurely
		srv.Stop()
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "connection refused")
		requireProviderError(t, err)
	})

	t.Run("invalid bound audience", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		// change the provider's client ID after the authenticator was built
		srv.SetClientCreds("not_gonna_match", "def")

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, `error validating signature: oidc: expected audience "abc" got ["not_gonna_match"]`)
		requireTokenVerificationError(t, err)
	})
}
// sampleClaims builds a fresh claim set carrying the given nonce. It mixes
// top-level claims with a "nested" object so tests can exercise both plain
// claim names and nested claim paths.
func sampleClaims(nonce string) map[string]interface{} {
	nested := map[string]interface{}{
		"Size":        "medium",
		"division":    3,
		"remote":      true,
		"Groups":      []string{"a", "b"},
		"secret_code": "bar",
	}

	claims := make(map[string]interface{}, 8)
	claims["nonce"] = nonce
	claims["email"] = "bob@example.com"
	claims["COLOR"] = "green"
	claims["sk"] = "42"
	claims["Age"] = 85
	claims["Admin"] = true
	claims["nested"] = nested
	claims["password"] = "foo"
	return claims
}
// getQueryParam returns the first value of the query parameter named param in
// inputURL. It fails the test immediately if inputURL cannot be parsed or the
// parameter is absent.
func getQueryParam(t *testing.T, inputURL, param string) string {
	t.Helper()

	// Parse the URL first so that only its query component is decoded. Passing
	// the whole URL to url.ParseQuery folds "scheme://host/path?" into the name
	// of the first parameter, so that parameter could never be looked up.
	parsed, err := url.Parse(inputURL)
	if err != nil {
		t.Fatal(err)
	}

	rawQuery := parsed.RawQuery
	if rawQuery == "" {
		// No query component: keep accepting a bare query string such as
		// "state=x&nonce=y", which is how the input used to be interpreted.
		rawQuery = inputURL
	}

	values, err := url.ParseQuery(rawQuery)
	if err != nil {
		t.Fatal(err)
	}

	v, ok := values[param]
	if !ok || len(v) == 0 {
		t.Fatalf("query param %q not found", param)
	}
	return v[0]
}