// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package oidcauth import ( "context" "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "net/http" "strconv" "strings" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" "github.com/mitchellh/pointerstructure" "golang.org/x/oauth2" "github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil" ) func contextWithHttpClient(ctx context.Context, client *http.Client) context.Context { return context.WithValue(ctx, oauth2.HTTPClient, client) } func createHTTPClient(caCert string) (*http.Client, error) { tr := cleanhttp.DefaultPooledTransport() if caCert != "" { certPool := x509.NewCertPool() if ok := certPool.AppendCertsFromPEM([]byte(caCert)); !ok { return nil, errors.New("could not parse CA PEM value successfully") } tr.TLSClientConfig = &tls.Config{ RootCAs: certPool, } } return &http.Client{ Transport: tr, }, nil } // extractClaims extracts all configured claims from the received claims. func (a *Authenticator) extractClaims(allClaims map[string]interface{}) (*Claims, error) { metadata, err := extractStringMetadata(a.logger, allClaims, a.config.ClaimMappings) if err != nil { return nil, err } listMetadata, err := extractListMetadata(a.logger, allClaims, a.config.ListClaimMappings) if err != nil { return nil, err } return &Claims{ Values: metadata, Lists: listMetadata, }, nil } // extractStringMetadata builds a metadata map of string values from a set of // claims and claims mappings. The referenced claims must be strings and the // claims mappings must be of the structure: // // { // "/some/claim/pointer": "metadata_key1", // "another_claim": "metadata_key2", // ... // } func extractStringMetadata(logger hclog.Logger, allClaims map[string]interface{}, claimMappings map[string]string) (map[string]string, error) { metadata := make(map[string]string) for source, target := range claimMappings { rawValue := getClaim(logger, allClaims, source) if rawValue == nil { continue } strValue, ok := stringifyMetadataValue(rawValue) if !ok { return nil, fmt.Errorf("error converting claim '%s' to string from unknown type %T", source, rawValue) } metadata[target] = strValue } return metadata, nil } // extractListMetadata builds a metadata map of string list values from a set // of claims and claims mappings. The referenced claims must be strings and // the claims mappings must be of the structure: // // { // "/some/claim/pointer": "metadata_key1", // "another_claim": "metadata_key2", // ... // } func extractListMetadata(logger hclog.Logger, allClaims map[string]interface{}, listClaimMappings map[string]string) (map[string][]string, error) { out := make(map[string][]string) for source, target := range listClaimMappings { if rawValue := getClaim(logger, allClaims, source); rawValue != nil { rawList, ok := normalizeList(rawValue) if !ok { return nil, fmt.Errorf("%q list claim could not be converted to string list", source) } list := make([]string, 0, len(rawList)) for _, raw := range rawList { value, ok := stringifyMetadataValue(raw) if !ok { return nil, fmt.Errorf("value %v in %q list claim could not be parsed as string", raw, source) } if value == "" { continue } list = append(list, value) } out[target] = list } } return out, nil } // getClaim returns a claim value from allClaims given a provided claim string. // If this string is a valid JSONPointer, it will be interpreted as such to // locate the claim. Otherwise, the claim string will be used directly. // // There is no fixup done to the returned data type here. That happens a layer // up in the caller. func getClaim(logger hclog.Logger, allClaims map[string]interface{}, claim string) interface{} { if !strings.HasPrefix(claim, "/") { return allClaims[claim] } val, err := pointerstructure.Get(allClaims, claim) if err != nil { if logger != nil { logger.Warn("unable to locate claim", "claim", claim, "error", err) } return nil } return val } // normalizeList takes an item or a slice and returns a slice. This is useful // when providers are expected to return a list (typically of strings) but // reduce it to a non-slice type when the list count is 1. // // There is no fixup done to elements of the returned slice here. That happens // a layer up in the caller. func normalizeList(raw interface{}) ([]interface{}, bool) { switch v := raw.(type) { case []interface{}: return v, true case string, // note: this list should be the same as stringifyMetadataValue bool, json.Number, float64, float32, int8, int16, int32, int64, int, uint8, uint16, uint32, uint64, uint: return []interface{}{v}, true default: return nil, false } } // stringifyMetadataValue will try to convert the provided raw value into a // faithful string representation of that value per these rules: // // - strings => unchanged // - bool => "true" / "false" // - json.Number => String() // - float32/64 => truncated to int64 and then formatted as an ascii string // - intXX/uintXX => casted to int64 and then formatted as an ascii string // // If successful the string value and true are returned. otherwise an empty // string and false are returned. func stringifyMetadataValue(rawValue interface{}) (string, bool) { switch v := rawValue.(type) { case string: return v, true case bool: return strconv.FormatBool(v), true case json.Number: return v.String(), true case float64: // The claims unmarshalled by go-oidc don't use UseNumber, so // they'll come in as float64 instead of an integer or json.Number. return strconv.FormatInt(int64(v), 10), true // The numerical type cases following here are only here for the sake // of numerical type completion. Everything is truncated to an integer // before being stringified. case float32: return strconv.FormatInt(int64(v), 10), true case int8: return strconv.FormatInt(int64(v), 10), true case int16: return strconv.FormatInt(int64(v), 10), true case int32: return strconv.FormatInt(int64(v), 10), true case int64: return strconv.FormatInt(v, 10), true case int: return strconv.FormatInt(int64(v), 10), true case uint8: return strconv.FormatInt(int64(v), 10), true case uint16: return strconv.FormatInt(int64(v), 10), true case uint32: return strconv.FormatInt(int64(v), 10), true case uint64: return strconv.FormatInt(int64(v), 10), true case uint: return strconv.FormatInt(int64(v), 10), true default: return "", false } } // validateAudience checks whether any of the audiences in audClaim match those // in boundAudiences. If strict is true and there are no bound audiences, then // the presence of any audience in the received claim is considered an error. func validateAudience(boundAudiences, audClaim []string, strict bool) error { if strict && len(boundAudiences) == 0 && len(audClaim) > 0 { return errors.New("audience claim found in JWT but no audiences are bound") } if len(boundAudiences) > 0 { for _, v := range boundAudiences { if strutil.StrListContains(audClaim, v) { return nil } } return errors.New("aud claim does not match any bound audience") } return nil }