consul/internal/iamauth/util.go

159 lines
5.1 KiB
Go

package iamauth
import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/consul/internal/iamauth/responses"
"github.com/hashicorp/go-hclog"
)
// LoginInput holds the parameters GenerateLoginData uses to build and sign an
// AWS IAM login request.
type LoginInput struct {
	// Creds are the AWS credentials used to sign the sts:GetCallerIdentity
	// request and, when IncludeIAMEntity is set, the IAM entity request.
	Creds *credentials.Credentials
	// IncludeIAMEntity, when true, embeds a signed iam:GetRole or iam:GetUser
	// request for the calling identity in the STS request's headers.
	IncludeIAMEntity bool
	// STSEndpoint overrides the STS endpoint when non-empty. When empty, the
	// default endpoint resolver is used, with the signing region forced to the
	// requested region (see stsSigningResolver).
	STSEndpoint string
	// STSRegion is the region configured on the STS session.
	STSRegion string
	// Logger is not referenced by the visible code in this file.
	// NOTE(review): presumably used by callers elsewhere — confirm.
	Logger hclog.Logger
	// ServerIDHeaderValue, when non-empty, is added under ServerIDHeaderName
	// to the STS request (and to the IAM entity request, if any) before each
	// is signed.
	ServerIDHeaderValue string
	// Customizable header names.

	// ServerIDHeaderName is the header name under which ServerIDHeaderValue
	// is sent.
	ServerIDHeaderName string
	// GetEntityMethodHeader, GetEntityURLHeader, GetEntityHeadersHeader and
	// GetEntityBodyHeader name the STS request headers that carry,
	// respectively, the IAM entity request's HTTP method, URL, JSON-encoded
	// headers and body when IncludeIAMEntity is set.
	GetEntityMethodHeader string
	GetEntityURLHeader string
	GetEntityHeadersHeader string
	GetEntityBodyHeader string
}
// GenerateLoginData populates the data a client sends for an AWS IAM bearer
// token. It builds and signs (but never sends) an sts:GetCallerIdentity
// request, then returns that request's method, plus its URL, headers (JSON
// encoded) and body, each base64-encoded, under the keys
// "iam_http_request_method", "iam_request_url", "iam_request_headers" and
// "iam_request_body".
//
// When in.IncludeIAMEntity is set, a signed iam:GetRole or iam:GetUser request
// for the calling identity is attached to the STS request as headers (named by
// the in.GetEntity*Header fields) before the STS request is signed.
//
// When in.ServerIDHeaderValue is non-empty, it is added under
// in.ServerIDHeaderName before signing.
//
// An error is returned if the session cannot be created, the entity request
// cannot be built, or either request fails to sign.
//
// Based on:
// https://github.com/hashicorp/go-secure-stdlib/blob/main/awsutil/generate_credentials.go#L232-L301
func GenerateLoginData(in *LoginInput) (map[string]interface{}, error) {
	cfg := aws.Config{
		Credentials: in.Creds,
		Region:      aws.String(in.STSRegion),
	}
	if in.STSEndpoint != "" {
		cfg.Endpoint = aws.String(in.STSEndpoint)
	} else {
		// Without an explicit endpoint, force the signature to use the
		// requested region instead of the legacy global default.
		cfg.EndpointResolver = endpoints.ResolverFunc(stsSigningResolver)
	}

	stsSession, err := session.NewSessionWithOptions(session.Options{Config: cfg})
	if err != nil {
		return nil, err
	}

	svc := sts.New(stsSession)
	// Only the request is needed: it is signed below but never sent, so the
	// output struct is discarded.
	stsRequest, _ := svc.GetCallerIdentityRequest(nil)

	// Include the iam:GetRole or iam:GetUser request in headers.
	if in.IncludeIAMEntity {
		entityRequest, err := formatSignedEntityRequest(svc, in)
		if err != nil {
			return nil, err
		}
		headersJSON, err := json.Marshal(entityRequest.HTTPRequest.Header)
		if err != nil {
			return nil, err
		}
		requestBody, err := ioutil.ReadAll(entityRequest.HTTPRequest.Body)
		if err != nil {
			return nil, err
		}
		stsRequest.HTTPRequest.Header.Add(in.GetEntityMethodHeader, entityRequest.HTTPRequest.Method)
		stsRequest.HTTPRequest.Header.Add(in.GetEntityURLHeader, entityRequest.HTTPRequest.URL.String())
		stsRequest.HTTPRequest.Header.Add(in.GetEntityHeadersHeader, string(headersJSON))
		stsRequest.HTTPRequest.Header.Add(in.GetEntityBodyHeader, string(requestBody))
	}

	// Inject the required auth header value, if supplied, and then sign the
	// request including that header.
	if in.ServerIDHeaderValue != "" {
		stsRequest.HTTPRequest.Header.Add(in.ServerIDHeaderName, in.ServerIDHeaderValue)
	}
	// A failed Sign leaves the request unsigned (and possibly unbuilt, with a
	// nil body), so it must not be serialized into login data.
	if err := stsRequest.Sign(); err != nil {
		return nil, fmt.Errorf("signing sts:GetCallerIdentity request: %w", err)
	}

	// Now extract out the relevant parts of the request.
	headersJSON, err := json.Marshal(stsRequest.HTTPRequest.Header)
	if err != nil {
		return nil, err
	}
	requestBody, err := ioutil.ReadAll(stsRequest.HTTPRequest.Body)
	if err != nil {
		return nil, err
	}
	return map[string]interface{}{
		"iam_http_request_method": stsRequest.HTTPRequest.Method,
		"iam_request_url":         base64.StdEncoding.EncodeToString([]byte(stsRequest.HTTPRequest.URL.String())),
		"iam_request_headers":     base64.StdEncoding.EncodeToString(headersJSON),
		"iam_request_body":        base64.StdEncoding.EncodeToString(requestBody),
	}, nil
}
// stsSigningResolver resolves endpoints with the SDK's default resolver, but
// always signs for the region that was actually requested.
//
// STS originally offered only a global endpoint and later added regional ones.
// For backwards compatibility, even when a region other than us-east-1 is
// requested, the signature is still computed for us-east-1; see
// https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
// Overriding SigningRegion here makes the signature match the chosen region.
// On a resolution error the default resolver's result and error are returned
// unchanged.
func stsSigningResolver(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
	resolved, err := endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
	if err == nil {
		resolved.SigningRegion = region
	}
	return resolved, err
}
// formatSignedEntityRequest builds and signs, but does not send, an
// iam:GetRole or iam:GetUser request for the identity behind svc's credentials.
//
// It first calls sts:GetCallerIdentity through svc, which performs a network
// request, to learn the caller's ARN. The ARN type then selects the IAM call:
// "role"/"assumed-role" -> GetRole, "user" -> GetUser. Any other type is
// rejected. If in.ServerIDHeaderValue is non-empty, it is added under
// in.ServerIDHeaderName before signing.
//
// An error is returned if the caller identity cannot be fetched or parsed, the
// IAM session cannot be created, the entity type is unsupported, or signing
// fails.
func formatSignedEntityRequest(svc *sts.STS, in *LoginInput) (*request.Request, error) {
	// We need to retrieve the IAM user or role for the iam:GetRole or
	// iam:GetUser request. GetCallerIdentity returns this and requires no
	// permissions.
	resp, err := svc.GetCallerIdentity(nil)
	if err != nil {
		return nil, err
	}
	// Report a missing ARN as an error instead of panicking on a nil dereference.
	if resp.Arn == nil {
		return nil, fmt.Errorf("sts:GetCallerIdentity response did not include an ARN")
	}
	arn, err := responses.ParseArn(*resp.Arn)
	if err != nil {
		return nil, err
	}

	// The IAM request is signed with the same credentials as the STS client.
	iamSession, err := session.NewSessionWithOptions(session.Options{
		Config: aws.Config{
			Credentials: svc.Config.Credentials,
		},
	})
	if err != nil {
		return nil, err
	}
	iamSvc := iam.New(iamSession)

	// The output structs are discarded: the request is only signed here.
	var req *request.Request
	switch arn.Type {
	case "role", "assumed-role":
		req, _ = iamSvc.GetRoleRequest(&iam.GetRoleInput{RoleName: &arn.FriendlyName})
	case "user":
		req, _ = iamSvc.GetUserRequest(&iam.GetUserInput{UserName: &arn.FriendlyName})
	default:
		return nil, fmt.Errorf("entity %s is not an IAM role or IAM user", arn.Type)
	}

	// Inject the required auth header value, if supplied, and then sign the
	// request including that header.
	if in.ServerIDHeaderValue != "" {
		req.HTTPRequest.Header.Add(in.ServerIDHeaderName, in.ServerIDHeaderValue)
	}
	// Do not hand back an unsigned request if signing fails.
	if err := req.Sign(); err != nil {
		return nil, fmt.Errorf("signing IAM entity request: %w", err)
	}
	return req, nil
}