consul/internal/go-sso/oidcauth/oidc_test.go

515 lines
12 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
[COMPLIANCE] License changes (#18443) * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 13:12:13 +00:00
// SPDX-License-Identifier: BUSL-1.1
2020-05-12 01:59:29 +00:00
package oidcauth
import (
"context"
"errors"
"net/url"
"strings"
"testing"
"time"
2024-03-22 14:54:58 +00:00
"github.com/go-jose/go-jose/v3/jwt"
2020-05-12 01:59:29 +00:00
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupForOIDC starts a mock OIDC provider and returns an Authenticator of
// type TypeOIDC that is configured against it, together with the provider so
// individual tests can program its responses (claims, auth code, ...).
//
// The Authenticator's Stop method is registered with t.Cleanup, so callers do
// not need to stop it themselves.
func setupForOIDC(t *testing.T) (*Authenticator, *oidcauthtest.Server) {
	t.Helper()

	provider := oidcauthtest.Start(t)
	provider.SetClientCreds("abc", "def")

	cfg := &Config{
		Type:                TypeOIDC,
		OIDCDiscoveryURL:    provider.Addr(),
		OIDCDiscoveryCACert: provider.CACert(),
		OIDCClientID:        "abc",
		OIDCClientSecret:    "def",
		OIDCACRValues:       []string{"acr1", "acr2"},
		JWTSupportedAlgs:    []string{"ES256"},
		BoundAudiences:      []string{"abc"},
		AllowedRedirectURIs: []string{"https://example.com"},
		// Keys are provider claim names; a leading "/" addresses a value
		// inside a nested claim object (see sampleClaims' "nested" entry).
		ClaimMappings: map[string]string{
			"COLOR":            "color",
			"/nested/Size":     "size",
			"Age":              "age",
			"Admin":            "is_admin",
			"/nested/division": "division",
			"/nested/remote":   "is_remote",
			"flavor":           "flavor", // userinfo
		},
		ListClaimMappings: map[string]string{
			"/nested/Groups": "groups",
		},
	}
	require.NoError(t, cfg.Validate())

	authenticator, err := New(cfg, hclog.NewNullLogger())
	require.NoError(t, err)
	t.Cleanup(authenticator.Stop)

	return authenticator, provider
}
// TestOIDC_AuthURL checks the authorization URL produced by GetAuthCodeURL for
// an allowed redirect URI, and the errors returned for a disallowed or an
// empty redirect URI.
func TestOIDC_AuthURL(t *testing.T) {
	t.Run("normal case", func(t *testing.T) {
		t.Parallel()

		oa, _ := setupForOIDC(t)

		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			map[string]string{"foo": "bar"},
		)
		require.NoError(t, err)
		require.True(t, strings.HasPrefix(authURL, oa.config.OIDCDiscoveryURL+"/auth?"))

		parsed, err := url.Parse(authURL)
		require.NoError(t, err)
		query := parsed.Query()

		wantParams := map[string]string{
			"client_id":     "abc",
			"redirect_uri":  "https://example.com",
			"response_type": "code",
			"scope":         "openid",
			// optional values
			"acr_values": "acr1 acr2",
		}
		for key, want := range wantParams {
			assert.Equal(t, want, query.Get(key), "key %q is incorrect", key)
		}

		// nonce and state are generated values; only their shape is checked.
		assert.Regexp(t, `^[a-z0-9]{40}$`, query.Get("nonce"))
		assert.Regexp(t, `^[a-z0-9]{40}$`, query.Get("state"))
	})

	badRedirects := []struct {
		name        string
		redirectURI string
		wantErr     string
	}{
		{
			name:        "invalid RedirectURI",
			redirectURI: "http://bitc0in-4-less.cx",
			wantErr:     "unauthorized redirect_uri: http://bitc0in-4-less.cx",
		},
		{
			name:        "missing RedirectURI",
			redirectURI: "",
			wantErr:     "missing redirect_uri",
		},
	}
	for _, tc := range badRedirects {
		tc := tc // capture for the parallel closure (pre-Go 1.22 semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			oa, _ := setupForOIDC(t)
			_, err := oa.GetAuthCodeURL(
				context.Background(),
				tc.redirectURI,
				map[string]string{"foo": "bar"},
			)
			requireErrorContains(t, err, tc.wantErr)
		})
	}
}
// TestOIDC_JWT_Functions_Fail verifies that an Authenticator of type "oidc"
// refuses ClaimsFromJWT, even when handed a token produced by the test
// signing helper.
func TestOIDC_JWT_Functions_Fail(t *testing.T) {
	oa, srv := setupForOIDC(t)

	standardClaims := jwt.Claims{
		Issuer:    srv.Addr(),
		Subject:   "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
		Audience:  jwt.Audience{"https://go-sso.test"},
		NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
	}
	customClaims := struct {
		User   string   `json:"https://go-sso/user"`
		Groups []string `json:"https://go-sso/groups"`
	}{
		User:   "jeff",
		Groups: []string{"foo", "bar"},
	}

	token, err := oidcauthtest.SignJWT("", standardClaims, customClaims)
	require.NoError(t, err)

	_, err = oa.ClaimsFromJWT(context.Background(), token)
	requireErrorContains(t, err, `ClaimsFromJWT is incompatible with type "oidc"`)
}
// TestOIDC_ClaimsFromAuthCode exercises ClaimsFromAuthCode against the mock
// provider. It covers successful logins (with and without the userinfo
// endpoint) and the failure paths. Each failure case checks both the error
// text and the error type a caller can match with errors.As.
func TestOIDC_ClaimsFromAuthCode(t *testing.T) {
	// requireProviderError fails the test unless err wraps a
	// *ProviderLoginFailedError.
	requireProviderError := func(t *testing.T, err error) {
		t.Helper() // report failures at the calling subtest, not here
		var provErr *ProviderLoginFailedError
		if !errors.As(err, &provErr) {
			t.Fatalf("error was not a *ProviderLoginFailedError: %v", err)
		}
	}
	// requireTokenVerificationError fails the test unless err wraps a
	// *TokenVerificationFailedError.
	requireTokenVerificationError := func(t *testing.T, err error) {
		t.Helper() // report failures at the calling subtest, not here
		var tokErr *TokenVerificationFailedError
		if !errors.As(err, &tokErr) {
			t.Fatalf("error was not a *TokenVerificationFailedError: %v", err)
		}
	}

	t.Run("successful login", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		claims, payload, err := oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		require.NoError(t, err)

		require.Equal(t, origPayload, payload)

		expectedClaims := &Claims{
			Values: map[string]string{
				"color":     "green",
				"size":      "medium",
				"age":       "85",
				"is_admin":  "true",
				"division":  "3",
				"is_remote": "true",
				"flavor":    "umami", // from userinfo
			},
			Lists: map[string][]string{
				"groups": {"a", "b"},
			},
		}

		require.Equal(t, expectedClaims, claims)
	})

	t.Run("failed login unusable claims", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server;
		// "COLOR" is mapped to a single value, so a list cannot be converted.
		customClaims := sampleClaims(nonce)
		customClaims["COLOR"] = []interface{}{"yellow"}
		srv.SetCustomClaims(customClaims)

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "error converting claim 'COLOR' to string from unknown type []interface {}")
		requireTokenVerificationError(t, err)
	})

	t.Run("successful login - no userinfo", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		srv.DisableUserInfo()

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		claims, payload, err := oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		require.NoError(t, err)

		require.Equal(t, origPayload, payload)

		expectedClaims := &Claims{
			Values: map[string]string{
				"color":     "green",
				"size":      "medium",
				"age":       "85",
				"is_admin":  "true",
				"division":  "3",
				"is_remote": "true",
				// "flavor": "umami", // from userinfo
			},
			Lists: map[string][]string{
				"groups": {"a", "b"},
			},
		}

		require.Equal(t, expectedClaims, claims)
	})

	t.Run("failed login - bad nonce", func(t *testing.T) {
		t.Parallel()

		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		// deliberately issue an ID token whose nonce does not match the one
		// embedded in authURL
		srv.SetCustomClaims(sampleClaims("bad nonce"))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "Invalid ID token nonce")
		requireTokenVerificationError(t, err)
	})

	t.Run("missing state", func(t *testing.T) {
		oa, _ := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		_, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			"", "abc",
		)
		requireErrorContains(t, err, "Expired or missing OAuth state")
		requireProviderError(t, err)
	})

	t.Run("unknown state", func(t *testing.T) {
		oa, _ := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		_, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			"not_a_state", "abc",
		)
		requireErrorContains(t, err, "Expired or missing OAuth state")
		requireProviderError(t, err)
	})

	t.Run("valid state, missing code", func(t *testing.T) {
		oa, _ := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "",
		)
		requireErrorContains(t, err, "OAuth code parameter not provided")
		requireProviderError(t, err)
	})

	t.Run("failed code exchange", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "wrong_code",
		)
		// Only the stable prefix is asserted: the oauth2 library's wording of
		// the underlying token-endpoint failure varies between versions.
		requireErrorContains(t, err, "Error exchanging oidc code")
		requireProviderError(t, err)
	})

	t.Run("no id_token returned", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		srv.OmitIDTokens()

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "No id_token found in response")
		requireTokenVerificationError(t, err)
	})

	t.Run("no response from provider", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")

		// close the server prematurely
		srv.Stop()
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, "connection refused")
		requireProviderError(t, err)
	})

	t.Run("invalid bound audience", func(t *testing.T) {
		oa, srv := setupForOIDC(t)

		// The provider now issues tokens for a different client ID than the
		// Authenticator's bound audience ("abc").
		srv.SetClientCreds("not_gonna_match", "def")

		origPayload := map[string]string{"foo": "bar"}
		authURL, err := oa.GetAuthCodeURL(
			context.Background(),
			"https://example.com",
			origPayload,
		)
		require.NoError(t, err)

		state := getQueryParam(t, authURL, "state")
		nonce := getQueryParam(t, authURL, "nonce")

		// set provider claims that will be returned by the mock server
		srv.SetCustomClaims(sampleClaims(nonce))

		// set mock provider's expected code
		srv.SetExpectedAuthCode("abc")

		_, _, err = oa.ClaimsFromAuthCode(
			context.Background(),
			state, "abc",
		)
		requireErrorContains(t, err, `error validating signature: oidc: expected audience "abc" got ["not_gonna_match"]`)
		requireTokenVerificationError(t, err)
	})
}
// sampleClaims returns the ID-token claims the mock provider is told to issue
// in these tests. nonce is embedded verbatim so it can be matched against the
// nonce from the authorization URL. Some entries ("email", "sk", "password",
// "nested.secret_code") have no claim mapping in setupForOIDC, so they do not
// appear in the resulting Claims.
//
// A fresh map is built on every call, so callers may mutate the result.
func sampleClaims(nonce string) map[string]interface{} {
	nested := map[string]interface{}{
		"Size":        "medium",
		"division":    3,
		"remote":      true,
		"Groups":      []string{"a", "b"},
		"secret_code": "bar",
	}

	claims := make(map[string]interface{}, 8)
	claims["nonce"] = nonce
	claims["email"] = "bob@example.com"
	claims["COLOR"] = "green"
	claims["sk"] = "42"
	claims["Age"] = 85
	claims["Admin"] = true
	claims["nested"] = nested
	claims["password"] = "foo"

	return claims
}
// getQueryParam returns the first value of the query parameter param in
// inputURL, failing the test if the query cannot be parsed or the parameter is
// absent.
//
// inputURL may be a full URL ("https://host/path?a=1&b=2") or a bare query
// string ("a=1&b=2").
func getQueryParam(t *testing.T, inputURL, param string) string {
	t.Helper()

	// Decode only the query component. Passing the whole URL to
	// url.ParseQuery would fold the scheme, host and path into the name of
	// the first parameter (e.g. "https://host/auth?client_id"), making that
	// parameter impossible to look up.
	query := inputURL
	if _, rest, found := strings.Cut(inputURL, "?"); found {
		query = rest
	}
	// A raw '#' cannot appear inside an encoded query; it starts the fragment.
	query, _, _ = strings.Cut(query, "#")

	m, err := url.ParseQuery(query)
	if err != nil {
		t.Fatal(err)
	}

	v, ok := m[param]
	if !ok || len(v) == 0 {
		t.Fatalf("query param %q not found", param)
	}
	return v[0]
}