mirror of
https://github.com/status-im/consul.git
synced 2025-01-12 14:55:02 +00:00
Add IAM Auth Method (#12583)
This adds an aws-iam auth method type which supports authenticating to Consul using AWS IAM identities. Co-authored-by: R.B. Boyer <4903+rboyer@users.noreply.github.com>
This commit is contained in:
parent
6bf67b7ef4
commit
706c844423
3
.changelog/12583.txt
Normal file
3
.changelog/12583.txt
Normal file
@ -0,0 +1,3 @@
|
||||
```release-note:feature
|
||||
acl: Added an AWS IAM auth method that allows authenticating to Consul using AWS IAM identities
|
||||
```
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
|
||||
// register these as a builtin auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/awsauth"
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
|
||||
)
|
||||
|
193
agent/consul/authmethod/awsauth/aws.go
Normal file
193
agent/consul/authmethod/awsauth/aws.go
Normal file
@ -0,0 +1,193 @@
|
||||
package awsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/internal/iamauth"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
const (
|
||||
authMethodType string = "aws-iam"
|
||||
|
||||
IAMServerIDHeaderName string = "X-Consul-IAM-ServerID"
|
||||
GetEntityMethodHeader string = "X-Consul-IAM-GetEntity-Method"
|
||||
GetEntityURLHeader string = "X-Consul-IAM-GetEntity-URL"
|
||||
GetEntityHeadersHeader string = "X-Consul-IAM-GetEntity-Headers"
|
||||
GetEntityBodyHeader string = "X-Consul-IAM-GetEntity-Body"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// register this as an available auth method type
|
||||
authmethod.Register(authMethodType, func(logger hclog.Logger, method *structs.ACLAuthMethod) (authmethod.Validator, error) {
|
||||
v, err := NewValidator(logger, method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
})
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// BoundIAMPrincipalARNs are the trusted AWS IAM principal ARNs that are permitted
|
||||
// to login to the auth method. These can be the exact ARNs or wildcards. Wildcards
|
||||
// are only supported if EnableIAMEntityDetails is true.
|
||||
BoundIAMPrincipalARNs []string `json:",omitempty"`
|
||||
|
||||
// EnableIAMEntityDetails will fetch the IAM User or IAM Role details to include
|
||||
// in binding rules. Required if wildcard principal ARNs are used.
|
||||
EnableIAMEntityDetails bool `json:",omitempty"`
|
||||
|
||||
// IAMEntityTags are the specific IAM User or IAM Role tags to include as selectable
|
||||
// fields in the binding rule attributes. Requires EnableIAMEntityDetails = true.
|
||||
IAMEntityTags []string `json:",omitempty"`
|
||||
|
||||
// ServerIDHeaderValue adds a X-Consul-IAM-ServerID header to each AWS API request.
|
||||
// This helps protect against replay attacks.
|
||||
ServerIDHeaderValue string `json:",omitempty"`
|
||||
|
||||
// MaxRetries is the maximum number of retries on AWS API requests for recoverable errors.
|
||||
MaxRetries int `json:",omitempty"`
|
||||
// IAMEndpoint is the AWS IAM endpoint where iam:GetRole or iam:GetUser requests will be sent.
|
||||
// Note that the Host header in a signed request cannot be changed.
|
||||
IAMEndpoint string `json:",omitempty"`
|
||||
// STSEndpoint is the AWS STS endpoint where sts:GetCallerIdentity requests will be sent.
|
||||
// Note that the Host header in a signed request cannot be changed.
|
||||
STSEndpoint string `json:",omitempty"`
|
||||
// STSRegion is the region for the AWS STS service. This should only be set if STSEndpoint
|
||||
// is set, and must match the region of the STSEndpoint.
|
||||
STSRegion string `json:",omitempty"`
|
||||
|
||||
// AllowedSTSHeaderValues is a list of additional allowed headers on the sts:GetCallerIdentity
|
||||
// request in the bearer token. A default list of necessary headers is allowed in any case.
|
||||
AllowedSTSHeaderValues []string `json:",omitempty"`
|
||||
}
|
||||
|
||||
func (c *Config) convertForLibrary() *iamauth.Config {
|
||||
return &iamauth.Config{
|
||||
BoundIAMPrincipalARNs: c.BoundIAMPrincipalARNs,
|
||||
EnableIAMEntityDetails: c.EnableIAMEntityDetails,
|
||||
IAMEntityTags: c.IAMEntityTags,
|
||||
ServerIDHeaderValue: c.ServerIDHeaderValue,
|
||||
MaxRetries: c.MaxRetries,
|
||||
IAMEndpoint: c.IAMEndpoint,
|
||||
STSEndpoint: c.STSEndpoint,
|
||||
STSRegion: c.STSRegion,
|
||||
AllowedSTSHeaderValues: c.AllowedSTSHeaderValues,
|
||||
|
||||
ServerIDHeaderName: IAMServerIDHeaderName,
|
||||
GetEntityMethodHeader: GetEntityMethodHeader,
|
||||
GetEntityURLHeader: GetEntityURLHeader,
|
||||
GetEntityHeadersHeader: GetEntityHeadersHeader,
|
||||
GetEntityBodyHeader: GetEntityBodyHeader,
|
||||
}
|
||||
}
|
||||
|
||||
type Validator struct {
|
||||
name string
|
||||
config *iamauth.Config
|
||||
logger hclog.Logger
|
||||
|
||||
auth *iamauth.Authenticator
|
||||
}
|
||||
|
||||
func NewValidator(logger hclog.Logger, method *structs.ACLAuthMethod) (*Validator, error) {
|
||||
if method.Type != authMethodType {
|
||||
return nil, fmt.Errorf("%q is not an AWS IAM auth method", method.Name)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := authmethod.ParseConfig(method.Config, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iamConfig := config.convertForLibrary()
|
||||
|
||||
auth, err := iamauth.NewAuthenticator(iamConfig, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Validator{
|
||||
name: method.Name,
|
||||
config: iamConfig,
|
||||
logger: logger,
|
||||
auth: auth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name implements authmethod.Validator.
|
||||
func (v *Validator) Name() string { return v.name }
|
||||
|
||||
// Stop implements authmethod.Validator.
|
||||
func (v *Validator) Stop() {}
|
||||
|
||||
// ValidateLogin implements authmethod.Validator.
|
||||
func (v *Validator) ValidateLogin(ctx context.Context, loginToken string) (*authmethod.Identity, error) {
|
||||
details, err := v.auth.ValidateLogin(ctx, loginToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vars := map[string]string{
|
||||
"entity_name": details.EntityName,
|
||||
"entity_id": details.EntityId,
|
||||
"account_id": details.AccountId,
|
||||
}
|
||||
fields := &awsSelectableFields{
|
||||
EntityName: details.EntityName,
|
||||
EntityId: details.EntityId,
|
||||
AccountId: details.AccountId,
|
||||
}
|
||||
|
||||
if v.config.EnableIAMEntityDetails {
|
||||
vars["entity_path"] = details.EntityPath
|
||||
fields.EntityPath = details.EntityPath
|
||||
fields.EntityTags = map[string]string{}
|
||||
for _, tag := range v.config.IAMEntityTags {
|
||||
vars["entity_tags."+tag] = details.EntityTags[tag]
|
||||
fields.EntityTags[tag] = details.EntityTags[tag]
|
||||
}
|
||||
}
|
||||
|
||||
result := &authmethod.Identity{
|
||||
SelectableFields: fields,
|
||||
ProjectedVars: vars,
|
||||
EnterpriseMeta: nil,
|
||||
}
|
||||
return result, nil
|
||||
|
||||
}
|
||||
|
||||
func (v *Validator) NewIdentity() *authmethod.Identity {
|
||||
fields := &awsSelectableFields{
|
||||
EntityTags: map[string]string{},
|
||||
}
|
||||
vars := map[string]string{
|
||||
"entity_name": "",
|
||||
"entity_id": "",
|
||||
"account_id": "",
|
||||
}
|
||||
if v.config.EnableIAMEntityDetails {
|
||||
vars["entity_path"] = ""
|
||||
for _, tag := range v.config.IAMEntityTags {
|
||||
vars["entity_tags."+tag] = ""
|
||||
fields.EntityTags[tag] = ""
|
||||
}
|
||||
}
|
||||
return &authmethod.Identity{
|
||||
SelectableFields: fields,
|
||||
ProjectedVars: vars,
|
||||
}
|
||||
}
|
||||
|
||||
type awsSelectableFields struct {
|
||||
EntityName string `bexpr:"entity_name"`
|
||||
EntityId string `bexpr:"entity_id"`
|
||||
AccountId string `bexpr:"account_id"`
|
||||
|
||||
EntityPath string `bexpr:"entity_path"`
|
||||
EntityTags map[string]string `bexpr:"entity_tags"`
|
||||
}
|
342
agent/consul/authmethod/awsauth/aws_test.go
Normal file
342
agent/consul/authmethod/awsauth/aws_test.go
Normal file
@ -0,0 +1,342 @@
|
||||
package awsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/internal/iamauth"
|
||||
"github.com/hashicorp/consul/internal/iamauth/iamauthtest"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewValidator(t *testing.T) {
|
||||
f := iamauthtest.MakeFixture()
|
||||
expConfig := &iamauth.Config{
|
||||
BoundIAMPrincipalARNs: []string{f.AssumedRoleARN},
|
||||
EnableIAMEntityDetails: true,
|
||||
IAMEntityTags: []string{"tag-1"},
|
||||
ServerIDHeaderValue: "x-some-header",
|
||||
MaxRetries: 3,
|
||||
IAMEndpoint: "iam-endpoint",
|
||||
STSEndpoint: "sts-endpoint",
|
||||
STSRegion: "sts-region",
|
||||
AllowedSTSHeaderValues: []string{"header-value"},
|
||||
ServerIDHeaderName: "X-Consul-IAM-ServerID",
|
||||
GetEntityMethodHeader: "X-Consul-IAM-GetEntity-Method",
|
||||
GetEntityURLHeader: "X-Consul-IAM-GetEntity-URL",
|
||||
GetEntityHeadersHeader: "X-Consul-IAM-GetEntity-Headers",
|
||||
GetEntityBodyHeader: "X-Consul-IAM-GetEntity-Body",
|
||||
}
|
||||
|
||||
type AM = *structs.ACLAuthMethod
|
||||
// Create the auth method, with an optional modification function.
|
||||
makeMethod := func(modifyFn func(AM)) AM {
|
||||
config := map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.AssumedRoleARN},
|
||||
"EnableIAMEntityDetails": true,
|
||||
"IAMEntityTags": []string{"tag-1"},
|
||||
"ServerIDHeaderValue": "x-some-header",
|
||||
"MaxRetries": 3,
|
||||
"IAMEndpoint": "iam-endpoint",
|
||||
"STSEndpoint": "sts-endpoint",
|
||||
"STSRegion": "sts-region",
|
||||
"AllowedSTSHeaderValues": []string{"header-value"},
|
||||
}
|
||||
|
||||
m := &structs.ACLAuthMethod{
|
||||
Name: "test-iam",
|
||||
Type: "aws-iam",
|
||||
Description: "aws iam auth",
|
||||
Config: config,
|
||||
}
|
||||
if modifyFn != nil {
|
||||
modifyFn(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
cases := map[string]struct {
|
||||
ok bool
|
||||
modifyFn func(AM)
|
||||
}{
|
||||
"success": {true, nil},
|
||||
"wrong type": {false, func(m AM) { m.Type = "not-iam" }},
|
||||
"extra config": {false, func(m AM) { m.Config["extraField"] = "123" }},
|
||||
"wrong config value type": {false, func(m AM) { m.Config["MaxRetries"] = []string{"1"} }},
|
||||
"missing bound principals": {false, func(m AM) { delete(m.Config, "BoundIAMPrincipalARNs") }},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
v, err := NewValidator(nil, makeMethod(c.modifyFn))
|
||||
if c.ok {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, v)
|
||||
require.Equal(t, "test-iam", v.name)
|
||||
require.NotNil(t, v.auth)
|
||||
require.Equal(t, expConfig, v.config)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLogin(t *testing.T) {
|
||||
f := iamauthtest.MakeFixture()
|
||||
|
||||
cases := map[string]struct {
|
||||
server *iamauthtest.Server
|
||||
token string
|
||||
config map[string]interface{}
|
||||
expVars map[string]string
|
||||
expFields []string
|
||||
expError string
|
||||
}{
|
||||
"success - role login": {
|
||||
server: f.ServerForRole,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.CanonicalRoleARN},
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_id": f.EntityID,
|
||||
"entity_name": f.RoleName,
|
||||
"account_id": f.AccountID,
|
||||
},
|
||||
expFields: []string{
|
||||
fmt.Sprintf(`entity_id == %q`, f.EntityID),
|
||||
fmt.Sprintf(`entity_name == %q`, f.RoleName),
|
||||
fmt.Sprintf(`account_id == %q`, f.AccountID),
|
||||
},
|
||||
},
|
||||
"success - user login": {
|
||||
server: f.ServerForUser,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARN},
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_id": f.EntityID,
|
||||
"entity_name": f.UserName,
|
||||
"account_id": f.AccountID,
|
||||
},
|
||||
expFields: []string{
|
||||
fmt.Sprintf(`entity_id == %q`, f.EntityID),
|
||||
fmt.Sprintf(`entity_name == %q`, f.UserName),
|
||||
fmt.Sprintf(`account_id == %q`, f.AccountID),
|
||||
},
|
||||
},
|
||||
"success - role login with entity details": {
|
||||
server: f.ServerForUser,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARN},
|
||||
"EnableIAMEntityDetails": true,
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_id": f.EntityID,
|
||||
"entity_name": f.UserName,
|
||||
"account_id": f.AccountID,
|
||||
"entity_path": f.UserPath,
|
||||
},
|
||||
expFields: []string{
|
||||
fmt.Sprintf(`entity_id == %q`, f.EntityID),
|
||||
fmt.Sprintf(`entity_name == %q`, f.UserName),
|
||||
fmt.Sprintf(`account_id == %q`, f.AccountID),
|
||||
fmt.Sprintf(`entity_path == %q`, f.UserPath),
|
||||
},
|
||||
},
|
||||
"success - user login with entity details": {
|
||||
server: f.ServerForUser,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARN},
|
||||
"EnableIAMEntityDetails": true,
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_id": f.EntityID,
|
||||
"entity_name": f.UserName,
|
||||
"account_id": f.AccountID,
|
||||
"entity_path": f.UserPath,
|
||||
},
|
||||
expFields: []string{
|
||||
fmt.Sprintf(`entity_id == %q`, f.EntityID),
|
||||
fmt.Sprintf(`entity_name == %q`, f.UserName),
|
||||
fmt.Sprintf(`account_id == %q`, f.AccountID),
|
||||
fmt.Sprintf(`entity_path == %q`, f.UserPath),
|
||||
},
|
||||
},
|
||||
"invalid token": {
|
||||
server: f.ServerForUser,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARN},
|
||||
},
|
||||
token: `invalid`,
|
||||
expError: "invalid token",
|
||||
},
|
||||
"empty json token": {
|
||||
server: f.ServerForUser,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARN},
|
||||
},
|
||||
token: `{}`,
|
||||
expError: "invalid token",
|
||||
},
|
||||
"empty json fields in token": {
|
||||
server: f.ServerForUser,
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARN},
|
||||
},
|
||||
token: `{"iam_http_request_method": "",
|
||||
"iam_request_body": "",
|
||||
"iam_request_headers": "",
|
||||
"iam_request_url": ""
|
||||
}`,
|
||||
expError: "invalid token",
|
||||
},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
v, _, token := setup(t, c.config, c.server)
|
||||
if c.token != "" {
|
||||
token = c.token
|
||||
}
|
||||
id, err := v.ValidateLogin(context.Background(), token)
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
require.Nil(t, id)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
authmethod.RequireIdentityMatch(t, id, c.expVars, c.expFields...)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setup(t *testing.T, config map[string]interface{}, server *iamauthtest.Server) (*Validator, *httptest.Server, string) {
|
||||
t.Helper()
|
||||
|
||||
fakeAws := iamauthtest.NewTestServer(t, server)
|
||||
|
||||
config["STSEndpoint"] = fakeAws.URL + "/sts"
|
||||
config["STSRegion"] = "fake-region"
|
||||
config["IAMEndpoint"] = fakeAws.URL + "/iam"
|
||||
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-method",
|
||||
Type: "aws-iam",
|
||||
Config: config,
|
||||
}
|
||||
nullLogger := hclog.NewNullLogger()
|
||||
v, err := NewValidator(nullLogger, method)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate the login token
|
||||
tokenData, err := iamauth.GenerateLoginData(&iamauth.LoginInput{
|
||||
Creds: credentials.NewStaticCredentials("fake", "fake", ""),
|
||||
IncludeIAMEntity: v.config.EnableIAMEntityDetails,
|
||||
STSEndpoint: v.config.STSEndpoint,
|
||||
STSRegion: v.config.STSRegion,
|
||||
Logger: nullLogger,
|
||||
ServerIDHeaderValue: v.config.ServerIDHeaderValue,
|
||||
ServerIDHeaderName: v.config.ServerIDHeaderName,
|
||||
GetEntityMethodHeader: v.config.GetEntityMethodHeader,
|
||||
GetEntityURLHeader: v.config.GetEntityURLHeader,
|
||||
GetEntityHeadersHeader: v.config.GetEntityHeadersHeader,
|
||||
GetEntityBodyHeader: v.config.GetEntityBodyHeader,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := json.Marshal(tokenData)
|
||||
require.NoError(t, err)
|
||||
return v, fakeAws, string(token)
|
||||
}
|
||||
|
||||
func TestNewIdentity(t *testing.T) {
|
||||
principals := []string{"arn:aws:sts::1234567890:assumed-role/my-role/some-session"}
|
||||
cases := map[string]struct {
|
||||
config map[string]interface{}
|
||||
expVars map[string]string
|
||||
expFilters []string
|
||||
}{
|
||||
"entity details disabled": {
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": principals,
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_name": "",
|
||||
"entity_id": "",
|
||||
"account_id": "",
|
||||
},
|
||||
expFilters: []string{
|
||||
`entity_name == ""`,
|
||||
`entity_id == ""`,
|
||||
`account_id == ""`,
|
||||
},
|
||||
},
|
||||
"entity details enabled": {
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": principals,
|
||||
"EnableIAMEntityDetails": true,
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_name": "",
|
||||
"entity_id": "",
|
||||
"account_id": "",
|
||||
"entity_path": "",
|
||||
},
|
||||
expFilters: []string{
|
||||
`entity_name == ""`,
|
||||
`entity_id == ""`,
|
||||
`account_id == ""`,
|
||||
`entity_path == ""`,
|
||||
},
|
||||
},
|
||||
"entity tags": {
|
||||
config: map[string]interface{}{
|
||||
"BoundIAMPrincipalARNs": principals,
|
||||
"EnableIAMEntityDetails": true,
|
||||
"IAMEntityTags": []string{
|
||||
"test_tag",
|
||||
"test_tag_2",
|
||||
},
|
||||
},
|
||||
expVars: map[string]string{
|
||||
"entity_name": "",
|
||||
"entity_id": "",
|
||||
"account_id": "",
|
||||
"entity_path": "",
|
||||
"entity_tags.test_tag": "",
|
||||
"entity_tags.test_tag_2": "",
|
||||
},
|
||||
expFilters: []string{
|
||||
`entity_name == ""`,
|
||||
`entity_id == ""`,
|
||||
`account_id == ""`,
|
||||
`entity_path == ""`,
|
||||
`entity_tags.test_tag == ""`,
|
||||
`entity_tags.test_tag_2 == ""`,
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-method",
|
||||
Type: "aws-iam",
|
||||
Config: c.config,
|
||||
}
|
||||
nullLogger := hclog.NewNullLogger()
|
||||
v, err := NewValidator(nullLogger, method)
|
||||
require.NoError(t, err)
|
||||
|
||||
id := v.NewIdentity()
|
||||
authmethod.RequireIdentityMatch(t, id, c.expVars, c.expFilters...)
|
||||
})
|
||||
}
|
||||
}
|
148
command/login/aws.go
Normal file
148
command/login/aws.go
Normal file
@ -0,0 +1,148 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/awsauth"
|
||||
"github.com/hashicorp/consul/internal/iamauth"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
type AWSLogin struct {
|
||||
autoBearerToken bool
|
||||
includeEntity bool
|
||||
stsEndpoint string
|
||||
region string
|
||||
serverIDHeaderValue string
|
||||
accessKeyId string
|
||||
secretAccessKey string
|
||||
sessionToken string
|
||||
}
|
||||
|
||||
func (a *AWSLogin) flags() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
fs.BoolVar(&a.autoBearerToken, "aws-auto-bearer-token", false,
|
||||
"Construct a bearer token and login to the AWS IAM auth method. This requires AWS credentials. "+
|
||||
"AWS credentials are automatically discovered from standard sources supported by the Go SDK for "+
|
||||
"AWS. Alternatively, explicit credentials can be passed using the -aws-acesss-key-id and "+
|
||||
"-aws-secret-access-key flags. [aws-iam only]")
|
||||
|
||||
fs.BoolVar(&a.includeEntity, "aws-include-entity", false,
|
||||
"Include a signed request to get the IAM role or IAM user in the bearer token. [aws-iam only]")
|
||||
|
||||
fs.StringVar(&a.stsEndpoint, "aws-sts-endpoint", "",
|
||||
"URL for AWS STS API calls. [aws-iam only]")
|
||||
|
||||
fs.StringVar(&a.region, "aws-region", "",
|
||||
"Region for AWS API calls. If set, should match the region of -aws-sts-endpoint. "+
|
||||
"If not provided, the region will be discovered from standard sources, such as "+
|
||||
"the AWS_REGION environment variable. [aws-iam only]")
|
||||
|
||||
fs.StringVar(&a.serverIDHeaderValue, "aws-server-id-header-value", "",
|
||||
"If set, an X-Consul-IAM-ServerID header is included in signed AWS API request(s) that form "+
|
||||
"the bearer token. This value must match the server-side configured value for the auth method "+
|
||||
"in order to login. This is optional and helps protect against replay attacks. [aws-iam only]")
|
||||
|
||||
fs.StringVar(&a.accessKeyId, "aws-access-key-id", "",
|
||||
"AWS access key id to use. Requires -aws-secret-access-key if specified. [aws-iam only]")
|
||||
|
||||
fs.StringVar(&a.secretAccessKey, "aws-secret-access-key", "",
|
||||
"AWS secret access key to use. Requires -aws-access-key-id if specified. [aws-iam only]")
|
||||
|
||||
fs.StringVar(&a.sessionToken, "aws-session-token", "",
|
||||
"AWS session token to use. Requires -aws-access-key-id and -aws-secret-access-key if "+
|
||||
"specified. [aws-iam only]")
|
||||
return fs
|
||||
}
|
||||
|
||||
// checkFlags validates flags for the aws-iam auth method.
|
||||
func (a *AWSLogin) checkFlags() error {
|
||||
if !a.autoBearerToken {
|
||||
if a.includeEntity || a.stsEndpoint != "" || a.region != "" || a.serverIDHeaderValue != "" ||
|
||||
a.accessKeyId != "" || a.secretAccessKey != "" || a.sessionToken != "" {
|
||||
return fmt.Errorf("Missing '-aws-auto-bearer-token' flag")
|
||||
}
|
||||
}
|
||||
if a.accessKeyId != "" && a.secretAccessKey == "" {
|
||||
return fmt.Errorf("Missing '-aws-secret-access-key' flag")
|
||||
}
|
||||
if a.secretAccessKey != "" && a.accessKeyId == "" {
|
||||
return fmt.Errorf("Missing '-aws-access-key-id' flag")
|
||||
}
|
||||
if a.sessionToken != "" && (a.accessKeyId == "" || a.secretAccessKey == "") {
|
||||
return fmt.Errorf("Missing '-aws-access-key-id' and '-aws-secret-access-key' flags")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAWSBearerToken generates a bearer token string for the AWS IAM auth method.
|
||||
// It will discover AWS credentials which are used to sign AWS API requests.
|
||||
// Alternatively, static credentials can be passed as flags.
|
||||
//
|
||||
// The bearer token contains a signed sts:GetCallerIdentity request.
|
||||
// If aws-include-entity is specified, a signed iam:GetRole or iam:GetUser request is
|
||||
// also included. The AWS credentials are used to retrieve the current user's role
|
||||
// or user name for the iam:GetRole or iam:GetUser request.
|
||||
func (a *AWSLogin) createAWSBearerToken() (string, error) {
|
||||
cfg := aws.Config{
|
||||
Endpoint: aws.String(a.stsEndpoint),
|
||||
Region: aws.String(a.region),
|
||||
// More detailed error message to help debug credential discovery.
|
||||
CredentialsChainVerboseErrors: aws.Bool(true),
|
||||
}
|
||||
|
||||
if a.accessKeyId != "" {
|
||||
// Use creds from flags.
|
||||
cfg.Credentials = credentials.NewStaticCredentials(
|
||||
a.accessKeyId, a.secretAccessKey, a.sessionToken,
|
||||
)
|
||||
}
|
||||
|
||||
// Session loads creds from standard sources (env vars, file, EC2 metadata, ...)
|
||||
sess, err := session.NewSessionWithOptions(session.Options{
|
||||
Config: cfg,
|
||||
// Allow loading from config files by default:
|
||||
// ~/.aws/config or AWS_CONFIG_FILE
|
||||
// ~/.aws/credentials or AWS_SHARED_CREDENTIALS_FILE
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if sess.Config.Region == nil || *sess.Config.Region == "" {
|
||||
return "", fmt.Errorf("AWS region not found")
|
||||
}
|
||||
if sess.Config.Credentials == nil {
|
||||
return "", fmt.Errorf("AWS credentials not found")
|
||||
}
|
||||
creds := sess.Config.Credentials
|
||||
|
||||
loginData, err := iamauth.GenerateLoginData(&iamauth.LoginInput{
|
||||
Creds: creds,
|
||||
IncludeIAMEntity: a.includeEntity,
|
||||
STSEndpoint: a.stsEndpoint,
|
||||
STSRegion: a.region,
|
||||
Logger: hclog.New(nil),
|
||||
ServerIDHeaderValue: a.serverIDHeaderValue,
|
||||
ServerIDHeaderName: awsauth.IAMServerIDHeaderName,
|
||||
GetEntityMethodHeader: awsauth.GetEntityMethodHeader,
|
||||
GetEntityURLHeader: awsauth.GetEntityURLHeader,
|
||||
GetEntityHeadersHeader: awsauth.GetEntityHeadersHeader,
|
||||
GetEntityBodyHeader: awsauth.GetEntityBodyHeader,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
loginDataJson, err := json.Marshal(loginData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(loginDataJson), err
|
||||
}
|
@ -36,6 +36,8 @@ type cmd struct {
|
||||
tokenSinkFile string
|
||||
meta map[string]string
|
||||
|
||||
aws AWSLogin
|
||||
|
||||
enterpriseCmd
|
||||
}
|
||||
|
||||
@ -57,10 +59,10 @@ func (c *cmd) init() {
|
||||
c.flags.Var((*flags.FlagMapValue)(&c.meta), "meta",
|
||||
"Metadata to set on the token, formatted as key=value. This flag "+
|
||||
"may be specified multiple times to set multiple meta fields.")
|
||||
|
||||
c.initEnterpriseFlags()
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.aws.flags())
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
flags.Merge(c.flags, c.http.MultiTenancyFlags())
|
||||
@ -89,21 +91,38 @@ func (c *cmd) Run(args []string) int {
|
||||
}
|
||||
|
||||
func (c *cmd) bearerTokenLogin() int {
|
||||
if c.bearerTokenFile == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag"))
|
||||
return 1
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(c.bearerTokenFile)
|
||||
if err != nil {
|
||||
if err := c.aws.checkFlags(); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
c.bearerToken = strings.TrimSpace(string(data))
|
||||
|
||||
if c.bearerToken == "" {
|
||||
c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile))
|
||||
if c.aws.autoBearerToken {
|
||||
if c.bearerTokenFile != "" {
|
||||
c.UI.Error("Cannot use '-bearer-token-file' flag with '-aws-auto-bearer-token'")
|
||||
return 1
|
||||
}
|
||||
|
||||
if token, err := c.aws.createAWSBearerToken(); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error with aws-iam auth method: %s", err))
|
||||
return 1
|
||||
} else {
|
||||
c.bearerToken = token
|
||||
}
|
||||
} else if c.bearerTokenFile == "" {
|
||||
c.UI.Error("Missing required '-bearer-token-file' flag")
|
||||
return 1
|
||||
} else {
|
||||
data, err := ioutil.ReadFile(c.bearerTokenFile)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
c.bearerToken = strings.TrimSpace(string(data))
|
||||
|
||||
if c.bearerToken == "" {
|
||||
c.UI.Error(fmt.Sprintf("No bearer token found in %s", c.bearerTokenFile))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that we don't try to use a token when performing a login
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -18,6 +19,7 @@ import (
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/consul/internal/iamauth/iamauthtest"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
)
|
||||
@ -39,18 +41,7 @@ func TestLoginCommand(t *testing.T) {
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
|
||||
a := agent.NewTestAgent(t, `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
initial_management = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
a := newTestAgent(t)
|
||||
client := a.Client()
|
||||
|
||||
t.Run("method is required", func(t *testing.T) {
|
||||
@ -102,6 +93,81 @@ func TestLoginCommand(t *testing.T) {
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing required '-bearer-token-file' flag")
|
||||
})
|
||||
|
||||
t.Run("bearer-token-file disallowed with aws-auto-bearer-token", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", "none.txt",
|
||||
"-aws-auto-bearer-token",
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Cannot use '-bearer-token-file' flag with '-aws-auto-bearer-token'")
|
||||
})
|
||||
|
||||
t.Run("aws flags require aws-auto-bearer-token", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
baseArgs := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
}
|
||||
|
||||
for _, extraArgs := range [][]string{
|
||||
{"-aws-include-entity"},
|
||||
{"-aws-sts-endpoint", "some-endpoint"},
|
||||
{"-aws-region", "some-region"},
|
||||
{"-aws-server-id-header-value", "some-value"},
|
||||
{"-aws-access-key-id", "some-key"},
|
||||
{"-aws-secret-access-key", "some-secret"},
|
||||
{"-aws-session-token", "some-token"},
|
||||
} {
|
||||
ui := cli.NewMockUi()
|
||||
code := New(ui).Run(append(baseArgs, extraArgs...))
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing '-aws-auto-bearer-token' flag")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("aws-access-key-id and aws-secret-access-key require each other", func(t *testing.T) {
|
||||
defer os.Remove(tokenSinkFile)
|
||||
|
||||
baseArgs := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-aws-auto-bearer-token",
|
||||
}
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
code := New(ui).Run(append(baseArgs, "-aws-access-key-id", "some-key"))
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing '-aws-secret-access-key' flag")
|
||||
|
||||
ui = cli.NewMockUi()
|
||||
code = New(ui).Run(append(baseArgs, "-aws-secret-access-key", "some-key"))
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Missing '-aws-access-key-id' flag")
|
||||
|
||||
ui = cli.NewMockUi()
|
||||
code = New(ui).Run(append(baseArgs, "-aws-session-token", "some-token"))
|
||||
require.Equal(t, code, 1, "err: %s", ui.ErrorWriter.String())
|
||||
require.Contains(t, ui.ErrorWriter.String(),
|
||||
"Missing '-aws-access-key-id' and '-aws-secret-access-key' flags")
|
||||
|
||||
})
|
||||
|
||||
bearerTokenFile := filepath.Join(testDir, "bearer.token")
|
||||
|
||||
t.Run("bearer-token-file is empty", func(t *testing.T) {
|
||||
@ -236,18 +302,7 @@ func TestLoginCommand_k8s(t *testing.T) {
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
|
||||
a := agent.NewTestAgent(t, `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
initial_management = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
a := newTestAgent(t)
|
||||
client := a.Client()
|
||||
|
||||
tokenSinkFile := filepath.Join(testDir, "test.token")
|
||||
@ -334,18 +389,7 @@ func TestLoginCommand_jwt(t *testing.T) {
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
|
||||
a := agent.NewTestAgent(t, `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
initial_management = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
a := newTestAgent(t)
|
||||
client := a.Client()
|
||||
|
||||
tokenSinkFile := filepath.Join(testDir, "test.token")
|
||||
@ -470,3 +514,178 @@ func TestLoginCommand_jwt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginCommand_aws_iam(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("too slow for testing.Short")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
// Formats an HIL template for a BindName, and the expected value for entity tags.
|
||||
// Input: string{"a", "b"}, []string{"1", "2"}
|
||||
// Return: "${entity_tags.a}-${entity_tags.b}", "1-2"
|
||||
entityTagsBind := func(keys, values []string) (string, string) {
|
||||
parts := []string{}
|
||||
for _, k := range keys {
|
||||
parts = append(parts, fmt.Sprintf("${entity_tags.%s}", k))
|
||||
}
|
||||
return strings.Join(parts, "-"), strings.Join(values, "-")
|
||||
}
|
||||
|
||||
f := iamauthtest.MakeFixture()
|
||||
roleTagsBindName, roleTagsBindValue := entityTagsBind(f.RoleTagKeys(), f.RoleTagValues())
|
||||
userTagsBindName, userTagsBindValue := entityTagsBind(f.UserTagKeys(), f.UserTagValues())
|
||||
|
||||
cases := map[string]struct {
|
||||
awsServer *iamauthtest.Server
|
||||
cmdArgs []string
|
||||
config map[string]interface{}
|
||||
bindingRule *api.ACLBindingRule
|
||||
expServiceIdentity *api.ACLServiceIdentity
|
||||
}{
|
||||
"success - login with role": {
|
||||
awsServer: f.ServerForRole,
|
||||
cmdArgs: []string{"-aws-auto-bearer-token"},
|
||||
config: map[string]interface{}{
|
||||
// Test that an assumed-role arn is translated to the canonical role arn.
|
||||
"BoundIAMPrincipalARNs": []string{f.CanonicalRoleARN},
|
||||
},
|
||||
bindingRule: &api.ACLBindingRule{
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "${entity_name}-${entity_id}-${account_id}",
|
||||
Selector: fmt.Sprintf(`entity_name==%q and entity_id==%q and account_id==%q`,
|
||||
f.RoleName, f.EntityID, f.AccountID),
|
||||
},
|
||||
expServiceIdentity: &api.ACLServiceIdentity{
|
||||
ServiceName: fmt.Sprintf("%s-%s-%s", f.RoleName, strings.ToLower(f.EntityID), f.AccountID),
|
||||
},
|
||||
},
|
||||
"success - login with role and entity details enabled": {
|
||||
awsServer: f.ServerForRole,
|
||||
cmdArgs: []string{"-aws-auto-bearer-token", "-aws-include-entity"},
|
||||
config: map[string]interface{}{
|
||||
// Test that we can login with full user path.
|
||||
"BoundIAMPrincipalARNs": []string{f.RoleARN},
|
||||
"EnableIAMEntityDetails": true,
|
||||
},
|
||||
bindingRule: &api.ACLBindingRule{
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
// TODO: Path cannot be used as service name if it contains a '/'
|
||||
BindName: "${entity_name}",
|
||||
Selector: fmt.Sprintf(`entity_name==%q and entity_path==%q`, f.RoleName, f.RolePath),
|
||||
},
|
||||
expServiceIdentity: &api.ACLServiceIdentity{ServiceName: f.RoleName},
|
||||
},
|
||||
"success - login with role and role tags": {
|
||||
awsServer: f.ServerForRole,
|
||||
cmdArgs: []string{"-aws-auto-bearer-token", "-aws-include-entity"},
|
||||
config: map[string]interface{}{
|
||||
// Test that we can login with a wildcard.
|
||||
"BoundIAMPrincipalARNs": []string{f.RoleARNWildcard},
|
||||
"EnableIAMEntityDetails": true,
|
||||
"IAMEntityTags": f.RoleTagKeys(),
|
||||
},
|
||||
bindingRule: &api.ACLBindingRule{
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: roleTagsBindName,
|
||||
Selector: fmt.Sprintf(`entity_name==%q and entity_path==%q`, f.RoleName, f.RolePath),
|
||||
},
|
||||
expServiceIdentity: &api.ACLServiceIdentity{ServiceName: roleTagsBindValue},
|
||||
},
|
||||
"success - login with user and user tags": {
|
||||
awsServer: f.ServerForUser,
|
||||
cmdArgs: []string{"-aws-auto-bearer-token", "-aws-include-entity"},
|
||||
config: map[string]interface{}{
|
||||
// Test that we can login with a wildcard.
|
||||
"BoundIAMPrincipalARNs": []string{f.UserARNWildcard},
|
||||
"EnableIAMEntityDetails": true,
|
||||
"IAMEntityTags": f.UserTagKeys(),
|
||||
},
|
||||
bindingRule: &api.ACLBindingRule{
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "${entity_name}-" + userTagsBindName,
|
||||
Selector: fmt.Sprintf(`entity_name==%q and entity_path==%q`, f.UserName, f.UserPath),
|
||||
},
|
||||
expServiceIdentity: &api.ACLServiceIdentity{
|
||||
ServiceName: fmt.Sprintf("%s-%s", f.UserName, userTagsBindValue),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
a := newTestAgent(t)
|
||||
client := a.Client()
|
||||
|
||||
fakeAws := iamauthtest.NewTestServer(t, c.awsServer)
|
||||
|
||||
c.config["STSEndpoint"] = fakeAws.URL + "/sts"
|
||||
c.config["IAMEndpoint"] = fakeAws.URL + "/iam"
|
||||
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
&api.ACLAuthMethod{
|
||||
Name: "iam-test",
|
||||
Type: "aws-iam",
|
||||
Config: c.config,
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
c.bindingRule.AuthMethod = "iam-test"
|
||||
_, _, err = client.ACL().BindingRuleCreate(
|
||||
c.bindingRule,
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
tokenSinkFile := filepath.Join(testDir, "test.token")
|
||||
t.Cleanup(func() { _ = os.Remove(tokenSinkFile) })
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=iam-test",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-aws-sts-endpoint", fakeAws.URL + "/sts",
|
||||
"-aws-region", "fake-region",
|
||||
"-aws-access-key-id", "fake-key-id",
|
||||
"-aws-secret-access-key", "fake-secret-key",
|
||||
}
|
||||
args = append(args, c.cmdArgs...)
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, 0, code, ui.ErrorWriter.String())
|
||||
|
||||
raw, err := ioutil.ReadFile(tokenSinkFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := strings.TrimSpace(string(raw))
|
||||
require.Len(t, token, 36, "must be a valid uid: %s", token)
|
||||
|
||||
// Validate correct BindName was interpolated.
|
||||
tokenRead, _, err := client.ACL().TokenReadSelf(&api.QueryOptions{Token: token})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokenRead.ServiceIdentities, 1)
|
||||
require.Equal(t, c.expServiceIdentity, tokenRead.ServiceIdentities[0])
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestAgent(t *testing.T) *agent.TestAgent {
|
||||
a := agent.NewTestAgent(t, `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
initial_management = "root"
|
||||
}
|
||||
}`)
|
||||
t.Cleanup(func() { _ = a.Shutdown() })
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
return a
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@ -44,7 +44,7 @@ require (
|
||||
github.com/hashicorp/go-memdb v1.3.2
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-raftchunking v0.6.2
|
||||
github.com/hashicorp/go-retryablehttp v0.6.7 // indirect
|
||||
github.com/hashicorp/go-retryablehttp v0.6.7
|
||||
github.com/hashicorp/go-sockaddr v1.0.2
|
||||
github.com/hashicorp/go-syslog v1.0.0
|
||||
github.com/hashicorp/go-uuid v1.0.2
|
||||
|
2
internal/iamauth/README.md
Normal file
2
internal/iamauth/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
This is an internal package to house the AWS IAM auth method utilities for potential
|
||||
future extraction from Consul.
|
311
internal/iamauth/auth.go
Normal file
311
internal/iamauth/auth.go
Normal file
@ -0,0 +1,311 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/internal/iamauth/responses"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/lib/stringslice"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// Retry configuration
|
||||
retryWaitMin = 500 * time.Millisecond
|
||||
retryWaitMax = 30 * time.Second
|
||||
)
|
||||
|
||||
type Authenticator struct {
|
||||
config *Config
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
type IdentityDetails struct {
|
||||
EntityName string
|
||||
EntityId string
|
||||
AccountId string
|
||||
|
||||
EntityPath string
|
||||
EntityTags map[string]string
|
||||
}
|
||||
|
||||
func NewAuthenticator(config *Config, logger hclog.Logger) (*Authenticator, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Authenticator{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateLogin determines if the identity in the loginToken is permitted to login.
|
||||
// If so, it returns details about the identity. Otherwise, an error is returned.
|
||||
func (a *Authenticator) ValidateLogin(ctx context.Context, loginToken string) (*IdentityDetails, error) {
|
||||
token, err := NewBearerToken(loginToken, a.config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := token.GetCallerIdentityRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.config.ServerIDHeaderValue != "" {
|
||||
err := validateHeaderValue(req.Header, a.config.ServerIDHeaderName, a.config.ServerIDHeaderValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
callerIdentity, err := a.submitCallerIdentityRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.logger.Debug("iamauth login attempt", "arn", callerIdentity.Arn)
|
||||
|
||||
entity, err := responses.ParseArn(callerIdentity.Arn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityDetails := &IdentityDetails{
|
||||
EntityName: entity.FriendlyName,
|
||||
// This could either be a "userID:SessionID" (in the case of an assumed role) or just a "userID"
|
||||
// (in the case of an IAM user).
|
||||
EntityId: strings.Split(callerIdentity.UserId, ":")[0],
|
||||
AccountId: callerIdentity.Account,
|
||||
}
|
||||
clientArn := entity.CanonicalArn()
|
||||
|
||||
// Fetch the IAM Role or IAM User, if configured.
|
||||
// This requires the token to contain a signed iam:GetRole or iam:GetUser request.
|
||||
if a.config.EnableIAMEntityDetails {
|
||||
iamReq, err := token.GetEntityRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.config.ServerIDHeaderValue != "" {
|
||||
err := validateHeaderValue(iamReq.Header, a.config.ServerIDHeaderName, a.config.ServerIDHeaderValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
iamEntityDetails, err := a.submitGetIAMEntityRequest(ctx, iamReq, token.entityRequestType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only the CallerIdentity response is a guarantee of the client's identity.
|
||||
// The role/user details must have a unique id match to the CallerIdentity before use.
|
||||
if iamEntityDetails.EntityId() != identityDetails.EntityId {
|
||||
return nil, fmt.Errorf("unique id mismatch in login token")
|
||||
}
|
||||
|
||||
// Use the full ARN with path from the Role/User details
|
||||
clientArn = iamEntityDetails.EntityArn()
|
||||
identityDetails.EntityPath = iamEntityDetails.EntityPath()
|
||||
identityDetails.EntityTags = iamEntityDetails.EntityTags()
|
||||
}
|
||||
|
||||
if err := a.validateIdentity(clientArn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identityDetails, nil
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1321-L1361
|
||||
func (a *Authenticator) validateIdentity(clientArn string) error {
|
||||
if stringslice.Contains(a.config.BoundIAMPrincipalARNs, clientArn) {
|
||||
// Matches one of BoundIAMPrincipalARNs, so it is trusted
|
||||
return nil
|
||||
}
|
||||
if a.config.EnableIAMEntityDetails {
|
||||
for _, principalArn := range a.config.BoundIAMPrincipalARNs {
|
||||
if strings.HasSuffix(principalArn, "*") && lib.GlobbedStringsMatch(principalArn, clientArn) {
|
||||
// Wildcard match, so it is trusted
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("IAM principal %s is not trusted", clientArn)
|
||||
}
|
||||
|
||||
func (a *Authenticator) submitCallerIdentityRequest(ctx context.Context, req *http.Request) (*responses.GetCallerIdentityResult, error) {
|
||||
responseBody, err := a.submitRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
callerIdentityResponse, err := parseGetCallerIdentityResponse(responseBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing STS response")
|
||||
}
|
||||
|
||||
if n := len(callerIdentityResponse.GetCallerIdentityResult); n != 1 {
|
||||
return nil, fmt.Errorf("received %d identities in STS response but expected 1", n)
|
||||
}
|
||||
return &callerIdentityResponse.GetCallerIdentityResult[0], nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) submitGetIAMEntityRequest(ctx context.Context, req *http.Request, reqType string) (responses.IAMEntity, error) {
|
||||
responseBody, err := a.submitRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iamResponse, err := parseGetIAMEntityResponse(responseBody, reqType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing IAM response: %s", err)
|
||||
}
|
||||
return iamResponse, nil
|
||||
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1636
|
||||
func (a *Authenticator) submitRequest(ctx context.Context, req *http.Request) (string, error) {
|
||||
retryableReq, err := retryablehttp.FromRequest(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
retryableReq = retryableReq.WithContext(ctx)
|
||||
client := cleanhttp.DefaultClient()
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
retryingClient := &retryablehttp.Client{
|
||||
HTTPClient: client,
|
||||
RetryWaitMin: retryWaitMin,
|
||||
RetryWaitMax: retryWaitMax,
|
||||
RetryMax: a.config.MaxRetries,
|
||||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||
Backoff: retryablehttp.DefaultBackoff,
|
||||
}
|
||||
|
||||
response, err := retryingClient.Do(retryableReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error making request: %w", err)
|
||||
}
|
||||
if response != nil {
|
||||
defer response.Body.Close()
|
||||
}
|
||||
// Validate that the response type is XML
|
||||
if ct := response.Header.Get("Content-Type"); ct != "text/xml" {
|
||||
return "", fmt.Errorf("response body is invalid")
|
||||
}
|
||||
|
||||
// we check for status code afterwards to also print out response body
|
||||
responseBody, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if response.StatusCode != 200 {
|
||||
return "", fmt.Errorf("received error code %d: %s", response.StatusCode, string(responseBody))
|
||||
}
|
||||
return string(responseBody), nil
|
||||
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1625-L1634
|
||||
func parseGetCallerIdentityResponse(response string) (responses.GetCallerIdentityResponse, error) {
|
||||
result := responses.GetCallerIdentityResponse{}
|
||||
response = strings.TrimSpace(response)
|
||||
if !strings.HasPrefix(response, "<GetCallerIdentityResponse") && !strings.HasPrefix(response, "<?xml") {
|
||||
return result, fmt.Errorf("body of GetCallerIdentity is invalid")
|
||||
}
|
||||
decoder := xml.NewDecoder(strings.NewReader(response))
|
||||
err := decoder.Decode(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func parseGetIAMEntityResponse(response string, reqType string) (responses.IAMEntity, error) {
|
||||
if !strings.HasPrefix(response, "<GetRoleResponse") &&
|
||||
!strings.HasPrefix(response, "<GetUserResponse") &&
|
||||
!strings.HasPrefix(response, "<?xml") {
|
||||
return nil, fmt.Errorf("body of GetRole or GetUser is invalid")
|
||||
}
|
||||
|
||||
decoder := xml.NewDecoder(strings.NewReader(response))
|
||||
|
||||
switch reqType {
|
||||
case "GetRole":
|
||||
result := &responses.GetRoleResponse{}
|
||||
err := decoder.Decode(&result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n := len(result.GetRoleResult); n != 1 {
|
||||
return nil, fmt.Errorf("received %d identities in GetRole response but expected 1", n)
|
||||
}
|
||||
return &result.GetRoleResult[0].Role, nil
|
||||
case "GetUser":
|
||||
result := &responses.GetUserResponse{}
|
||||
err := decoder.Decode(&result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n := len(result.GetUserResult); n != 1 {
|
||||
return nil, fmt.Errorf("received %d identities in GetUser response but expected 1", n)
|
||||
}
|
||||
return &result.GetUserResult[0].User, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid %s request: %s", reqType, response)
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1532
|
||||
func validateHeaderValue(headers http.Header, headerName string, requiredHeaderValue string) error {
|
||||
providedValue := ""
|
||||
for k, v := range headers {
|
||||
if strings.EqualFold(headerName, k) {
|
||||
providedValue = strings.Join(v, ",")
|
||||
break
|
||||
}
|
||||
}
|
||||
if providedValue == "" {
|
||||
return fmt.Errorf("missing header %q", headerName)
|
||||
}
|
||||
|
||||
// NOT doing a constant time compare here since the value is NOT intended to be secret
|
||||
if providedValue != requiredHeaderValue {
|
||||
return fmt.Errorf("expected %q but got %q", requiredHeaderValue, providedValue)
|
||||
}
|
||||
|
||||
if authzHeaders, ok := headers["Authorization"]; ok {
|
||||
// authzHeader looks like AWS4-HMAC-SHA256 Credential=AKI..., SignedHeaders=host;x-amz-date;x-vault-awsiam-id, Signature=...
|
||||
// We need to extract out the SignedHeaders
|
||||
re := regexp.MustCompile(".*SignedHeaders=([^,]+)")
|
||||
authzHeader := strings.Join(authzHeaders, ",")
|
||||
matches := re.FindSubmatch([]byte(authzHeader))
|
||||
if len(matches) < 1 {
|
||||
return fmt.Errorf("server id header wasn't signed")
|
||||
}
|
||||
if len(matches) > 2 {
|
||||
return fmt.Errorf("found multiple SignedHeaders components")
|
||||
}
|
||||
signedHeaders := string(matches[1])
|
||||
return ensureHeaderIsSigned(signedHeaders, headerName)
|
||||
}
|
||||
// NOTE: If we support GET requests, then we need to parse the X-Amz-SignedHeaders
|
||||
// argument out of the query string and search in there for the header value
|
||||
return fmt.Errorf("missing Authorization header")
|
||||
}
|
||||
|
||||
func ensureHeaderIsSigned(signedHeaders, headerToSign string) error {
|
||||
// Not doing a constant time compare here, the values aren't secret
|
||||
for _, header := range strings.Split(signedHeaders, ";") {
|
||||
if header == strings.ToLower(headerToSign) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("header wasn't signed")
|
||||
}
|
123
internal/iamauth/auth_test.go
Normal file
123
internal/iamauth/auth_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/hashicorp/consul/internal/iamauth/iamauthtest"
|
||||
"github.com/hashicorp/consul/internal/iamauth/responsestest"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateLogin(t *testing.T) {
|
||||
f := iamauthtest.MakeFixture()
|
||||
|
||||
var (
|
||||
serverForRoleMismatchedIds = &iamauthtest.Server{
|
||||
GetCallerIdentityResponse: f.ServerForRole.GetCallerIdentityResponse,
|
||||
GetRoleResponse: responsestest.MakeGetRoleResponse(f.RoleARN, "AAAAsomenonmatchingid"),
|
||||
}
|
||||
serverForUserMismatchedIds = &iamauthtest.Server{
|
||||
GetCallerIdentityResponse: f.ServerForUser.GetCallerIdentityResponse,
|
||||
GetUserResponse: responsestest.MakeGetUserResponse(f.UserARN, "AAAAsomenonmatchingid"),
|
||||
}
|
||||
)
|
||||
|
||||
cases := map[string]struct {
|
||||
config *Config
|
||||
server *iamauthtest.Server
|
||||
expIdent *IdentityDetails
|
||||
expError string
|
||||
}{
|
||||
"no bound principals": {
|
||||
expError: "not trusted",
|
||||
server: f.ServerForRole,
|
||||
config: &Config{},
|
||||
},
|
||||
"no matching principal": {
|
||||
expError: "not trusted",
|
||||
server: f.ServerForUser,
|
||||
config: &Config{
|
||||
BoundIAMPrincipalARNs: []string{
|
||||
"arn:aws:iam::1234567890:user/some-other-role",
|
||||
"arn:aws:iam::1234567890:user/some-other-user",
|
||||
},
|
||||
},
|
||||
},
|
||||
"mismatched server id header": {
|
||||
expError: `expected "some-non-matching-value" but got "server.id.example.com"`,
|
||||
server: f.ServerForRole,
|
||||
config: &Config{
|
||||
BoundIAMPrincipalARNs: []string{f.CanonicalRoleARN},
|
||||
ServerIDHeaderValue: "some-non-matching-value",
|
||||
ServerIDHeaderName: "X-Test-ServerID",
|
||||
},
|
||||
},
|
||||
"role unique id mismatch": {
|
||||
expError: "unique id mismatch in login token",
|
||||
// The RoleId in the GetRole response must match the UserId in the GetCallerIdentity response
|
||||
// during login. If not, the RoleId cannot be used.
|
||||
server: serverForRoleMismatchedIds,
|
||||
config: &Config{
|
||||
BoundIAMPrincipalARNs: []string{f.RoleARN},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
},
|
||||
"user unique id mismatch": {
|
||||
expError: "unique id mismatch in login token",
|
||||
server: serverForUserMismatchedIds,
|
||||
config: &Config{
|
||||
BoundIAMPrincipalARNs: []string{f.UserARN},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
logger := hclog.New(nil)
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
fakeAws := iamauthtest.NewTestServer(t, c.server)
|
||||
|
||||
c.config.STSEndpoint = fakeAws.URL + "/sts"
|
||||
c.config.IAMEndpoint = fakeAws.URL + "/iam"
|
||||
setTestHeaderNames(c.config)
|
||||
|
||||
// This bypasses NewAuthenticator, which bypasses config.Validate().
|
||||
auth := &Authenticator{config: c.config, logger: logger}
|
||||
|
||||
loginInput := &LoginInput{
|
||||
Creds: credentials.NewStaticCredentials("fake", "fake", ""),
|
||||
IncludeIAMEntity: c.config.EnableIAMEntityDetails,
|
||||
STSEndpoint: c.config.STSEndpoint,
|
||||
STSRegion: "fake-region",
|
||||
Logger: logger,
|
||||
ServerIDHeaderValue: "server.id.example.com",
|
||||
}
|
||||
setLoginInputHeaderNames(loginInput)
|
||||
loginData, err := GenerateLoginData(loginInput)
|
||||
require.NoError(t, err)
|
||||
loginBytes, err := json.Marshal(loginData)
|
||||
require.NoError(t, err)
|
||||
|
||||
ident, err := auth.ValidateLogin(context.Background(), string(loginBytes))
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
require.Nil(t, ident)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c.expIdent, ident)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setLoginInputHeaderNames(in *LoginInput) {
|
||||
in.ServerIDHeaderName = "X-Test-ServerID"
|
||||
in.GetEntityMethodHeader = "X-Test-Method"
|
||||
in.GetEntityURLHeader = "X-Test-URL"
|
||||
in.GetEntityHeadersHeader = "X-Test-Headers"
|
||||
in.GetEntityBodyHeader = "X-Test-Body"
|
||||
}
|
69
internal/iamauth/config.go
Normal file
69
internal/iamauth/config.go
Normal file
@ -0,0 +1,69 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
awsArn "github.com/aws/aws-sdk-go/aws/arn"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
BoundIAMPrincipalARNs []string
|
||||
EnableIAMEntityDetails bool
|
||||
IAMEntityTags []string
|
||||
ServerIDHeaderValue string
|
||||
MaxRetries int
|
||||
IAMEndpoint string
|
||||
STSEndpoint string
|
||||
STSRegion string
|
||||
AllowedSTSHeaderValues []string
|
||||
|
||||
// Customizable header names
|
||||
ServerIDHeaderName string
|
||||
GetEntityMethodHeader string
|
||||
GetEntityURLHeader string
|
||||
GetEntityHeadersHeader string
|
||||
GetEntityBodyHeader string
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if len(c.BoundIAMPrincipalARNs) == 0 {
|
||||
return fmt.Errorf("BoundIAMPrincipalARNs is required and must have at least 1 entry")
|
||||
}
|
||||
|
||||
for _, arn := range c.BoundIAMPrincipalARNs {
|
||||
if n := strings.Count(arn, "*"); n > 0 {
|
||||
if !c.EnableIAMEntityDetails {
|
||||
return fmt.Errorf("Must set EnableIAMEntityDetails=true to use wildcards in BoundIAMPrincipalARNs")
|
||||
}
|
||||
if n != 1 || !strings.HasSuffix(arn, "*") {
|
||||
return fmt.Errorf("Only one wildcard is allowed at the end of the bound IAM principal ARN")
|
||||
}
|
||||
}
|
||||
|
||||
if parsed, err := awsArn.Parse(arn); err != nil {
|
||||
return fmt.Errorf("Invalid principal ARN: %q", arn)
|
||||
} else if parsed.Service != "iam" && parsed.Service != "sts" {
|
||||
return fmt.Errorf("Invalid principal ARN: %q", arn)
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.IAMEntityTags) > 0 && !c.EnableIAMEntityDetails {
|
||||
return fmt.Errorf("Must set EnableIAMEntityDetails=true to use IAMUserTags")
|
||||
}
|
||||
|
||||
// If server id header checking is enabled, we need the header name.
|
||||
if c.ServerIDHeaderValue != "" && c.ServerIDHeaderName == "" {
|
||||
return fmt.Errorf("Must set ServerIDHeaderName to use a server ID value")
|
||||
}
|
||||
|
||||
if c.EnableIAMEntityDetails && (c.GetEntityBodyHeader == "" ||
|
||||
c.GetEntityHeadersHeader == "" ||
|
||||
c.GetEntityMethodHeader == "" ||
|
||||
c.GetEntityURLHeader == "") {
|
||||
return fmt.Errorf("Must set all of GetEntityMethodHeader, GetEntityURLHeader, " +
|
||||
"GetEntityHeadersHeader, and GetEntityBodyHeader when EnableIAMEntityDetails=true")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
150
internal/iamauth/config_test.go
Normal file
150
internal/iamauth/config_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
principalArn := "arn:aws:iam::000000000000:role/my-role"
|
||||
|
||||
cases := map[string]struct {
|
||||
expError string
|
||||
configs []Config
|
||||
|
||||
includeHeaderNames bool
|
||||
}{
|
||||
"bound iam principals are required": {
|
||||
expError: "BoundIAMPrincipalARNs is required and must have at least 1 entry",
|
||||
configs: []Config{
|
||||
{BoundIAMPrincipalARNs: nil},
|
||||
{BoundIAMPrincipalARNs: []string{}},
|
||||
},
|
||||
},
|
||||
"entity tags require entity details": {
|
||||
expError: "Must set EnableIAMEntityDetails=true to use IAMUserTags",
|
||||
configs: []Config{
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
EnableIAMEntityDetails: false,
|
||||
IAMEntityTags: []string{"some-tag"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"entity details require all entity header names": {
|
||||
expError: "Must set all of GetEntityMethodHeader, GetEntityURLHeader, " +
|
||||
"GetEntityHeadersHeader, and GetEntityBodyHeader when EnableIAMEntityDetails=true",
|
||||
configs: []Config{
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
EnableIAMEntityDetails: true,
|
||||
GetEntityBodyHeader: "X-Test-Header",
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
EnableIAMEntityDetails: true,
|
||||
GetEntityHeadersHeader: "X-Test-Header",
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
EnableIAMEntityDetails: true,
|
||||
GetEntityURLHeader: "X-Test-Header",
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
EnableIAMEntityDetails: true,
|
||||
GetEntityMethodHeader: "X-Test-Header",
|
||||
},
|
||||
},
|
||||
},
|
||||
"wildcard principals require entity details": {
|
||||
expError: "Must set EnableIAMEntityDetails=true to use wildcards in BoundIAMPrincipalARNs",
|
||||
configs: []Config{
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*"}},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/path/*"}},
|
||||
},
|
||||
},
|
||||
"only one wildcard suffix is allowed": {
|
||||
expError: "Only one wildcard is allowed at the end of the bound IAM principal ARN",
|
||||
configs: []Config{
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/**"},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*/*"},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*/path"},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*/path/*"},
|
||||
EnableIAMEntityDetails: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
"invalid principal arns are disallowed": {
|
||||
expError: fmt.Sprintf("Invalid principal ARN"),
|
||||
configs: []Config{
|
||||
{BoundIAMPrincipalARNs: []string{""}},
|
||||
{BoundIAMPrincipalARNs: []string{" "}},
|
||||
{BoundIAMPrincipalARNs: []string{"*"}, EnableIAMEntityDetails: true},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam:role/my-role"}},
|
||||
},
|
||||
},
|
||||
"valid principal arns are allowed": {
|
||||
includeHeaderNames: true,
|
||||
configs: []Config{
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:sts::000000000000:assumed-role/my-role/some-session-name"}},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:user/my-user"}},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/my-role"}},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:*"}, EnableIAMEntityDetails: true},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/*"}, EnableIAMEntityDetails: true},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:role/path/*"}, EnableIAMEntityDetails: true},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:user/*"}, EnableIAMEntityDetails: true},
|
||||
{BoundIAMPrincipalARNs: []string{"arn:aws:iam::000000000000:user/path/*"}, EnableIAMEntityDetails: true},
|
||||
},
|
||||
},
|
||||
"server id header value requires service id header name": {
|
||||
expError: "Must set ServerIDHeaderName to use a server ID value",
|
||||
configs: []Config{
|
||||
{
|
||||
BoundIAMPrincipalARNs: []string{principalArn},
|
||||
ServerIDHeaderValue: "consul.test.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
for _, conf := range c.configs {
|
||||
if c.includeHeaderNames {
|
||||
setTestHeaderNames(&conf)
|
||||
}
|
||||
err := conf.Validate()
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setTestHeaderNames(conf *Config) {
|
||||
conf.GetEntityMethodHeader = "X-Test-Method"
|
||||
conf.GetEntityURLHeader = "X-Test-URL"
|
||||
conf.GetEntityHeadersHeader = "X-Test-Headers"
|
||||
conf.GetEntityBodyHeader = "X-Test-Body"
|
||||
}
|
187
internal/iamauth/iamauthtest/testing.go
Normal file
187
internal/iamauth/iamauthtest/testing.go
Normal file
@ -0,0 +1,187 @@
|
||||
package iamauthtest
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/internal/iamauth/responses"
|
||||
"github.com/hashicorp/consul/internal/iamauth/responsestest"
|
||||
)
|
||||
|
||||
// NewTestServer returns a fake AWS API server for local tests:
|
||||
// It supports the following paths:
|
||||
// /sts returns STS API responses
|
||||
// /iam returns IAM API responses
|
||||
func NewTestServer(t *testing.T, s *Server) *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(s)
|
||||
t.Cleanup(server.Close)
|
||||
server.Start()
|
||||
return server
|
||||
}
|
||||
|
||||
// Server contains configuration for the fake AWS API server.
|
||||
type Server struct {
|
||||
GetCallerIdentityResponse responses.GetCallerIdentityResponse
|
||||
GetRoleResponse responses.GetRoleResponse
|
||||
GetUserResponse responses.GetUserResponse
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
writeError(w, http.StatusBadRequest, r)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(r.URL.Path, "/sts"):
|
||||
writeXML(w, s.GetCallerIdentityResponse)
|
||||
case strings.HasPrefix(r.URL.Path, "/iam"):
|
||||
if bodyBytes, err := io.ReadAll(r.Body); err == nil {
|
||||
body := string(bodyBytes)
|
||||
switch {
|
||||
case strings.Contains(body, "Action=GetRole"):
|
||||
writeXML(w, s.GetRoleResponse)
|
||||
return
|
||||
case strings.Contains(body, "Action=GetUser"):
|
||||
writeXML(w, s.GetUserResponse)
|
||||
return
|
||||
}
|
||||
}
|
||||
writeError(w, http.StatusBadRequest, r)
|
||||
default:
|
||||
writeError(w, http.StatusNotFound, r)
|
||||
}
|
||||
}
|
||||
|
||||
func writeXML(w http.ResponseWriter, val interface{}) {
|
||||
str, err := xml.MarshalIndent(val, "", " ")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
fmt.Fprint(w, err.Error())
|
||||
return
|
||||
}
|
||||
w.Header().Add("Content-Type", "text/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, string(str))
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, code int, r *http.Request) {
|
||||
w.WriteHeader(code)
|
||||
msg := fmt.Sprintf("%s %s", r.Method, r.URL)
|
||||
fmt.Fprintf(w, `<ErrorResponse xmlns="https://fakeaws/">
|
||||
<Error>
|
||||
<Message>Fake AWS Server Error: %s</Message>
|
||||
</Error>
|
||||
</ErrorResponse>`, msg)
|
||||
}
|
||||
|
||||
type Fixture struct {
|
||||
AssumedRoleARN string
|
||||
CanonicalRoleARN string
|
||||
RoleARN string
|
||||
RoleARNWildcard string
|
||||
RoleName string
|
||||
RolePath string
|
||||
RoleTags map[string]string
|
||||
|
||||
EntityID string
|
||||
EntityIDWithSession string
|
||||
AccountID string
|
||||
|
||||
UserARN string
|
||||
UserARNWildcard string
|
||||
UserName string
|
||||
UserPath string
|
||||
UserTags map[string]string
|
||||
|
||||
ServerForRole *Server
|
||||
ServerForUser *Server
|
||||
}
|
||||
|
||||
func MakeFixture() Fixture {
|
||||
f := Fixture{
|
||||
AssumedRoleARN: "arn:aws:sts::1234567890:assumed-role/my-role/some-session",
|
||||
CanonicalRoleARN: "arn:aws:iam::1234567890:role/my-role",
|
||||
RoleARN: "arn:aws:iam::1234567890:role/some/path/my-role",
|
||||
RoleARNWildcard: "arn:aws:iam::1234567890:role/some/path/*",
|
||||
RoleName: "my-role",
|
||||
RolePath: "some/path",
|
||||
RoleTags: map[string]string{
|
||||
"service-name": "my-service",
|
||||
"env": "my-env",
|
||||
},
|
||||
|
||||
EntityID: "AAAsomeuniqueid",
|
||||
EntityIDWithSession: "AAAsomeuniqueid:some-session",
|
||||
AccountID: "1234567890",
|
||||
|
||||
UserARN: "arn:aws:iam::1234567890:user/my-user",
|
||||
UserARNWildcard: "arn:aws:iam::1234567890:user/*",
|
||||
UserName: "my-user",
|
||||
UserPath: "",
|
||||
UserTags: map[string]string{"user-group": "my-group"},
|
||||
}
|
||||
|
||||
f.ServerForRole = &Server{
|
||||
GetCallerIdentityResponse: responsestest.MakeGetCallerIdentityResponse(
|
||||
f.AssumedRoleARN, f.EntityIDWithSession, f.AccountID,
|
||||
),
|
||||
GetRoleResponse: responsestest.MakeGetRoleResponse(
|
||||
f.RoleARN, f.EntityID, toTags(f.RoleTags)...,
|
||||
),
|
||||
}
|
||||
|
||||
f.ServerForUser = &Server{
|
||||
GetCallerIdentityResponse: responsestest.MakeGetCallerIdentityResponse(
|
||||
f.UserARN, f.EntityID, f.AccountID,
|
||||
),
|
||||
GetUserResponse: responsestest.MakeGetUserResponse(
|
||||
f.UserARN, f.EntityID, toTags(f.UserTags)...,
|
||||
),
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fixture) RoleTagKeys() []string { return keys(f.RoleTags) }
|
||||
func (f *Fixture) UserTagKeys() []string { return keys(f.UserTags) }
|
||||
func (f *Fixture) RoleTagValues() []string { return values(f.RoleTags) }
|
||||
func (f *Fixture) UserTagValues() []string { return values(f.UserTags) }
|
||||
|
||||
// toTags converts the map to a slice of responses.Tag
|
||||
func toTags(tags map[string]string) []responses.Tag {
|
||||
result := []responses.Tag{}
|
||||
for k, v := range tags {
|
||||
result = append(result, responses.Tag{
|
||||
Key: k,
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
return result
|
||||
|
||||
}
|
||||
|
||||
// keys returns the keys in sorted order
|
||||
func keys(tags map[string]string) []string {
|
||||
result := []string{}
|
||||
for k := range tags {
|
||||
result = append(result, k)
|
||||
}
|
||||
sort.Strings(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// values returns values in tags, ordered by sorted keys
|
||||
func values(tags map[string]string) []string {
|
||||
result := []string{}
|
||||
for _, k := range keys(tags) { // ensures sorted by key
|
||||
result = append(result, tags[k])
|
||||
}
|
||||
return result
|
||||
}
|
94
internal/iamauth/responses/arn.go
Normal file
94
internal/iamauth/responses/arn.go
Normal file
@ -0,0 +1,94 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1722-L1744
|
||||
type ParsedArn struct {
|
||||
Partition string
|
||||
AccountNumber string
|
||||
Type string
|
||||
Path string
|
||||
FriendlyName string
|
||||
SessionInfo string
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/ba533d006f2244103648785ebfe8a9a9763d2b6e/builtin/credential/aws/path_login.go#L1482-L1530
|
||||
// However, instance profiles are not support in Consul.
|
||||
func ParseArn(iamArn string) (*ParsedArn, error) {
|
||||
// iamArn should look like one of the following:
|
||||
// 1. arn:aws:iam::<account_id>:<entity_type>/<UserName>
|
||||
// 2. arn:aws:sts::<account_id>:assumed-role/<RoleName>/<RoleSessionName>
|
||||
// if we get something like 2, then we want to transform that back to what
|
||||
// most people would expect, which is arn:aws:iam::<account_id>:role/<RoleName>
|
||||
var entity ParsedArn
|
||||
fullParts := strings.Split(iamArn, ":")
|
||||
if len(fullParts) != 6 {
|
||||
return nil, fmt.Errorf("unrecognized arn: contains %d colon-separated parts, expected 6", len(fullParts))
|
||||
}
|
||||
if fullParts[0] != "arn" {
|
||||
return nil, fmt.Errorf("unrecognized arn: does not begin with \"arn:\"")
|
||||
}
|
||||
// normally aws, but could be aws-cn or aws-us-gov
|
||||
entity.Partition = fullParts[1]
|
||||
if entity.Partition == "" {
|
||||
return nil, fmt.Errorf("unrecognized arn: %q is missing the partition", iamArn)
|
||||
}
|
||||
if fullParts[2] != "iam" && fullParts[2] != "sts" {
|
||||
return nil, fmt.Errorf("unrecognized service: %v, not one of iam or sts", fullParts[2])
|
||||
}
|
||||
// fullParts[3] is the region, which doesn't matter for AWS IAM entities
|
||||
entity.AccountNumber = fullParts[4]
|
||||
if entity.AccountNumber == "" {
|
||||
return nil, fmt.Errorf("unrecognized arn: %q is missing the account number", iamArn)
|
||||
}
|
||||
// fullParts[5] would now be something like user/<UserName> or assumed-role/<RoleName>/<RoleSessionName>
|
||||
parts := strings.Split(fullParts[5], "/")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("unrecognized arn: %q contains fewer than 2 slash-separated parts", fullParts[5])
|
||||
}
|
||||
entity.Type = parts[0]
|
||||
entity.Path = strings.Join(parts[1:len(parts)-1], "/")
|
||||
entity.FriendlyName = parts[len(parts)-1]
|
||||
// now, entity.FriendlyName should either be <UserName> or <RoleName>
|
||||
switch entity.Type {
|
||||
case "assumed-role":
|
||||
// Check for three parts for assumed role ARNs
|
||||
if len(parts) < 3 {
|
||||
return nil, fmt.Errorf("unrecognized arn: %q contains fewer than 3 slash-separated parts", fullParts[5])
|
||||
}
|
||||
// Assumed roles don't have paths and have a slightly different format
|
||||
// parts[2] is <RoleSessionName>
|
||||
entity.Path = ""
|
||||
entity.FriendlyName = parts[1]
|
||||
entity.SessionInfo = parts[2]
|
||||
case "user":
|
||||
case "role":
|
||||
// case "instance-profile":
|
||||
default:
|
||||
return nil, fmt.Errorf("unrecognized principal type: %q", entity.Type)
|
||||
}
|
||||
|
||||
if entity.FriendlyName == "" {
|
||||
return nil, fmt.Errorf("unrecognized arn: %q is missing the resource name", iamArn)
|
||||
}
|
||||
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// CanonicalArn returns the canonical ARN for referring to an IAM entity
|
||||
func (p *ParsedArn) CanonicalArn() string {
|
||||
entityType := p.Type
|
||||
// canonicalize "assumed-role" into "role"
|
||||
if entityType == "assumed-role" {
|
||||
entityType = "role"
|
||||
}
|
||||
// Annoyingly, the assumed-role entity type doesn't have the Path of the role which was assumed
|
||||
// So, we "canonicalize" it by just completely dropping the path. The other option would be to
|
||||
// make an AWS API call to look up the role by FriendlyName, which introduces more complexity to
|
||||
// code and test, and it also breaks backwards compatibility in an area where we would really want
|
||||
// it
|
||||
return fmt.Sprintf("arn:%s:iam::%s:%s/%s", p.Partition, p.AccountNumber, entityType, p.FriendlyName)
|
||||
}
|
92
internal/iamauth/responses/responses.go
Normal file
92
internal/iamauth/responses/responses.go
Normal file
@ -0,0 +1,92 @@
|
||||
package responses
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
type GetCallerIdentityResponse struct {
|
||||
XMLName xml.Name `xml:"GetCallerIdentityResponse"`
|
||||
GetCallerIdentityResult []GetCallerIdentityResult `xml:"GetCallerIdentityResult"`
|
||||
ResponseMetadata []ResponseMetadata `xml:"ResponseMetadata"`
|
||||
}
|
||||
|
||||
type GetCallerIdentityResult struct {
|
||||
Arn string `xml:"Arn"`
|
||||
UserId string `xml:"UserId"`
|
||||
Account string `xml:"Account"`
|
||||
}
|
||||
|
||||
type ResponseMetadata struct {
|
||||
RequestId string `xml:"RequestId"`
|
||||
}
|
||||
|
||||
// IAMEntity is an interface for getting details from an IAM Role or User.
|
||||
type IAMEntity interface {
|
||||
EntityPath() string
|
||||
EntityArn() string
|
||||
EntityName() string
|
||||
EntityId() string
|
||||
EntityTags() map[string]string
|
||||
}
|
||||
|
||||
var _ IAMEntity = (*Role)(nil)
|
||||
var _ IAMEntity = (*User)(nil)
|
||||
|
||||
type GetRoleResponse struct {
|
||||
XMLName xml.Name `xml:"GetRoleResponse"`
|
||||
GetRoleResult []GetRoleResult `xml:"GetRoleResult"`
|
||||
ResponseMetadata []ResponseMetadata `xml:"ResponseMetadata"`
|
||||
}
|
||||
|
||||
type GetRoleResult struct {
|
||||
Role Role `xml:"Role"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
Arn string `xml:"Arn"`
|
||||
Path string `xml:"Path"`
|
||||
RoleId string `xml:"RoleId"`
|
||||
RoleName string `xml:"RoleName"`
|
||||
Tags []Tag `xml:"Tags"`
|
||||
}
|
||||
|
||||
func (r *Role) EntityPath() string { return r.Path }
|
||||
func (r *Role) EntityArn() string { return r.Arn }
|
||||
func (r *Role) EntityName() string { return r.RoleName }
|
||||
func (r *Role) EntityId() string { return r.RoleId }
|
||||
func (r *Role) EntityTags() map[string]string { return tagsToMap(r.Tags) }
|
||||
|
||||
type GetUserResponse struct {
|
||||
XMLName xml.Name `xml:"GetUserResponse"`
|
||||
GetUserResult []GetUserResult `xml:"GetUserResult"`
|
||||
ResponseMetadata []ResponseMetadata `xml:"ResponseMetadata"`
|
||||
}
|
||||
|
||||
type GetUserResult struct {
|
||||
User User `xml:"User"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Arn string `xml:"Arn"`
|
||||
Path string `xml:"Path"`
|
||||
UserId string `xml:"UserId"`
|
||||
UserName string `xml:"UserName"`
|
||||
Tags []Tag `xml:"Tags"`
|
||||
}
|
||||
|
||||
func (u *User) EntityPath() string { return u.Path }
|
||||
func (u *User) EntityArn() string { return u.Arn }
|
||||
func (u *User) EntityName() string { return u.UserName }
|
||||
func (u *User) EntityId() string { return u.UserId }
|
||||
func (u *User) EntityTags() map[string]string { return tagsToMap(u.Tags) }
|
||||
|
||||
type Tag struct {
|
||||
Key string `xml:"Key"`
|
||||
Value string `xml:"Value"`
|
||||
}
|
||||
|
||||
func tagsToMap(tags []Tag) map[string]string {
|
||||
result := map[string]string{}
|
||||
for _, tag := range tags {
|
||||
result[tag.Key] = tag.Value
|
||||
}
|
||||
return result
|
||||
}
|
157
internal/iamauth/responses/responses_test.go
Normal file
157
internal/iamauth/responses/responses_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseArn(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
arn string
|
||||
expArn *ParsedArn
|
||||
}{
|
||||
"assumed-role": {
|
||||
arn: "arn:aws:sts::000000000000:assumed-role/my-role/session-name",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "assumed-role",
|
||||
Path: "",
|
||||
FriendlyName: "my-role",
|
||||
SessionInfo: "session-name",
|
||||
},
|
||||
},
|
||||
"role": {
|
||||
arn: "arn:aws:iam::000000000000:role/my-role",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "role",
|
||||
Path: "",
|
||||
FriendlyName: "my-role",
|
||||
SessionInfo: "",
|
||||
},
|
||||
},
|
||||
"user": {
|
||||
arn: "arn:aws:iam::000000000000:user/my-user",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "user",
|
||||
Path: "",
|
||||
FriendlyName: "my-user",
|
||||
SessionInfo: "",
|
||||
},
|
||||
},
|
||||
"role with path": {
|
||||
arn: "arn:aws:iam::000000000000:role/path/my-role",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "role",
|
||||
Path: "path",
|
||||
FriendlyName: "my-role",
|
||||
SessionInfo: "",
|
||||
},
|
||||
},
|
||||
"role with path 2": {
|
||||
arn: "arn:aws:iam::000000000000:role/path/to/my-role",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "role",
|
||||
Path: "path/to",
|
||||
FriendlyName: "my-role",
|
||||
SessionInfo: "",
|
||||
},
|
||||
},
|
||||
"role with path 3": {
|
||||
arn: "arn:aws:iam::000000000000:role/some/path/to/my-role",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "role",
|
||||
Path: "some/path/to",
|
||||
FriendlyName: "my-role",
|
||||
SessionInfo: "",
|
||||
},
|
||||
},
|
||||
"user with path": {
|
||||
arn: "arn:aws:iam::000000000000:user/path/my-user",
|
||||
expArn: &ParsedArn{
|
||||
Partition: "aws",
|
||||
AccountNumber: "000000000000",
|
||||
Type: "user",
|
||||
Path: "path",
|
||||
FriendlyName: "my-user",
|
||||
SessionInfo: "",
|
||||
},
|
||||
},
|
||||
|
||||
// Invalid cases
|
||||
"empty string": {arn: ""},
|
||||
"wildcard": {arn: "*"},
|
||||
"missing prefix": {arn: ":aws:sts::000000000000:assumed-role/my-role/session-name"},
|
||||
"missing partition": {arn: "arn::sts::000000000000:assumed-role/my-role/session-name"},
|
||||
"missing service": {arn: "arn:aws:::000000000000:assumed-role/my-role/session-name"},
|
||||
"missing separator": {arn: "arn:aws:sts:000000000000:assumed-role/my-role/session-name"},
|
||||
"missing account id": {arn: "arn:aws:sts:::assumed-role/my-role/session-name"},
|
||||
"missing resource": {arn: "arn:aws:sts::000000000000:"},
|
||||
"assumed-role missing parts": {arn: "arn:aws:sts::000000000000:assumed-role/my-role"},
|
||||
"role missing parts": {arn: "arn:aws:sts::000000000000:role"},
|
||||
"role missing parts 2": {arn: "arn:aws:sts::000000000000:role/"},
|
||||
"user missing parts": {arn: "arn:aws:sts::000000000000:user"},
|
||||
"user missing parts 2": {arn: "arn:aws:sts::000000000000:user/"},
|
||||
"unsupported service": {arn: "arn:aws:ecs:us-east-1:000000000000:task/my-task/00000000000000000000000000000000"},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
parsed, err := ParseArn(c.arn)
|
||||
if c.expArn != nil {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c.expArn, parsed)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, parsed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalArn(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
arn string
|
||||
expArn string
|
||||
}{
|
||||
"assumed-role arn": {
|
||||
arn: "arn:aws:sts::000000000000:assumed-role/my-role/session-name",
|
||||
expArn: "arn:aws:iam::000000000000:role/my-role",
|
||||
},
|
||||
"role arn": {
|
||||
arn: "arn:aws:iam::000000000000:role/my-role",
|
||||
expArn: "arn:aws:iam::000000000000:role/my-role",
|
||||
},
|
||||
"role arn with path": {
|
||||
arn: "arn:aws:iam::000000000000:role/path/to/my-role",
|
||||
expArn: "arn:aws:iam::000000000000:role/my-role",
|
||||
},
|
||||
"user arn": {
|
||||
arn: "arn:aws:iam::000000000000:user/my-user",
|
||||
expArn: "arn:aws:iam::000000000000:user/my-user",
|
||||
},
|
||||
"user arn with path": {
|
||||
arn: "arn:aws:iam::000000000000:user/path/to/my-user",
|
||||
expArn: "arn:aws:iam::000000000000:user/my-user",
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
parsed, err := ParseArn(c.arn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c.expArn, parsed.CanonicalArn())
|
||||
})
|
||||
}
|
||||
}
|
81
internal/iamauth/responsestest/testing.go
Normal file
81
internal/iamauth/responsestest/testing.go
Normal file
@ -0,0 +1,81 @@
|
||||
package responsestest
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/internal/iamauth/responses"
|
||||
)
|
||||
|
||||
func MakeGetCallerIdentityResponse(arn, userId, accountId string) responses.GetCallerIdentityResponse {
|
||||
// Sanity check the UserId for unit tests.
|
||||
parsed := parseArn(arn)
|
||||
switch parsed.Type {
|
||||
case "assumed-role":
|
||||
if !strings.Contains(userId, ":") {
|
||||
panic("UserId for assumed-role in GetCallerIdentity response must be '<uniqueId>:<session>'")
|
||||
}
|
||||
default:
|
||||
if strings.Contains(userId, ":") {
|
||||
panic("UserId in GetCallerIdentity must not contain ':'")
|
||||
}
|
||||
}
|
||||
|
||||
return responses.GetCallerIdentityResponse{
|
||||
GetCallerIdentityResult: []responses.GetCallerIdentityResult{
|
||||
{
|
||||
Arn: arn,
|
||||
UserId: userId,
|
||||
Account: accountId,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func MakeGetRoleResponse(arn, id string, tags ...responses.Tag) responses.GetRoleResponse {
|
||||
if strings.Contains(id, ":") {
|
||||
panic("RoleId in GetRole response must not contain ':'")
|
||||
}
|
||||
parsed := parseArn(arn)
|
||||
return responses.GetRoleResponse{
|
||||
GetRoleResult: []responses.GetRoleResult{
|
||||
{
|
||||
Role: responses.Role{
|
||||
Arn: arn,
|
||||
Path: parsed.Path,
|
||||
RoleId: id,
|
||||
RoleName: parsed.FriendlyName,
|
||||
Tags: tags,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func MakeGetUserResponse(arn, id string, tags ...responses.Tag) responses.GetUserResponse {
|
||||
if strings.Contains(id, ":") {
|
||||
panic("UserId in GetUser resposne must not contain ':'")
|
||||
}
|
||||
parsed := parseArn(arn)
|
||||
return responses.GetUserResponse{
|
||||
GetUserResult: []responses.GetUserResult{
|
||||
{
|
||||
User: responses.User{
|
||||
Arn: arn,
|
||||
Path: parsed.Path,
|
||||
UserId: id,
|
||||
UserName: parsed.FriendlyName,
|
||||
Tags: tags,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseArn(arn string) *responses.ParsedArn {
|
||||
parsed, err := responses.ParseArn(arn)
|
||||
if err != nil {
|
||||
// For testing, just fail immediately.
|
||||
panic(err)
|
||||
}
|
||||
return parsed
|
||||
}
|
343
internal/iamauth/token.go
Normal file
343
internal/iamauth/token.go
Normal file
@ -0,0 +1,343 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/lib/stringslice"
|
||||
)
|
||||
|
||||
const (
|
||||
amzHeaderPrefix = "X-Amz-"
|
||||
defaultIAMEndpoint = "https://iam.amazonaws.com"
|
||||
defaultSTSEndpoint = "https://sts.amazonaws.com"
|
||||
)
|
||||
|
||||
var defaultAllowedSTSRequestHeaders = []string{
|
||||
"X-Amz-Algorithm",
|
||||
"X-Amz-Content-Sha256",
|
||||
"X-Amz-Credential",
|
||||
"X-Amz-Date",
|
||||
"X-Amz-Security-Token",
|
||||
"X-Amz-Signature",
|
||||
"X-Amz-SignedHeaders",
|
||||
}
|
||||
|
||||
// BearerToken is a login "token" for an IAM auth method. It is a signed
|
||||
// sts:GetCallerIdentity request in JSON format. Optionally, it can include a
|
||||
// signed embedded iam:GetRole or iam:GetUser request in the headers.
|
||||
type BearerToken struct {
|
||||
config *Config
|
||||
|
||||
getCallerIdentityMethod string
|
||||
getCallerIdentityURL string
|
||||
getCallerIdentityHeader http.Header
|
||||
getCallerIdentityBody string
|
||||
|
||||
getIAMEntityMethod string
|
||||
getIAMEntityURL string
|
||||
getIAMEntityHeader http.Header
|
||||
getIAMEntityBody string
|
||||
|
||||
entityRequestType string
|
||||
parsedCallerIdentityURL *url.URL
|
||||
parsedIAMEntityURL *url.URL
|
||||
}
|
||||
|
||||
var _ json.Unmarshaler = (*BearerToken)(nil)
|
||||
|
||||
func NewBearerToken(loginToken string, config *Config) (*BearerToken, error) {
|
||||
token := &BearerToken{config: config}
|
||||
if err := json.Unmarshal([]byte(loginToken), &token); err != nil {
|
||||
return nil, fmt.Errorf("invalid token: %s", err)
|
||||
}
|
||||
|
||||
if err := token.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.EnableIAMEntityDetails {
|
||||
method, err := token.getHeader(token.config.GetEntityMethodHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawUrl, err := token.getHeader(token.config.GetEntityURLHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headerJson, err := token.getHeader(token.config.GetEntityHeadersHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var header http.Header
|
||||
if err := json.Unmarshal([]byte(headerJson), &header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := token.getHeader(token.config.GetEntityBodyHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedUrl, err := parseUrl(rawUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token.getIAMEntityMethod = method
|
||||
token.getIAMEntityBody = body
|
||||
token.getIAMEntityURL = rawUrl
|
||||
token.getIAMEntityHeader = header
|
||||
token.parsedIAMEntityURL = parsedUrl
|
||||
|
||||
reqType, err := token.validateIAMEntityBody()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.entityRequestType = reqType
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1178
|
||||
func (t *BearerToken) validate() error {
|
||||
if t.getCallerIdentityMethod != "POST" {
|
||||
return fmt.Errorf("iam_http_request_method must be POST")
|
||||
}
|
||||
if err := t.validateGetCallerIdentityBody(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.validateAllowedSTSHeaderValues(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1439
|
||||
func (t *BearerToken) validateGetCallerIdentityBody() error {
|
||||
allowedValues := url.Values{
|
||||
"Action": []string{"GetCallerIdentity"},
|
||||
// Will assume for now that future versions don't change
|
||||
// the semantics
|
||||
"Version": nil, // any value is allowed
|
||||
}
|
||||
if _, err := parseRequestBody(t.getCallerIdentityBody, allowedValues); err != nil {
|
||||
return fmt.Errorf("iam_request_body error: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *BearerToken) validateIAMEntityBody() (string, error) {
|
||||
allowedValues := url.Values{
|
||||
"Action": []string{"GetRole", "GetUser"},
|
||||
"RoleName": nil, // any value is allowed
|
||||
"UserName": nil,
|
||||
"Version": nil,
|
||||
}
|
||||
body, err := parseRequestBody(t.getIAMEntityBody, allowedValues)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("iam_request_headers[%s] error: %s", t.config.GetEntityBodyHeader, err)
|
||||
}
|
||||
|
||||
// Disallow GetRole+UserName and GetUser+RoleName.
|
||||
action := body["Action"][0]
|
||||
_, hasRoleName := body["RoleName"]
|
||||
_, hasUserName := body["UserName"]
|
||||
if action == "GetUser" && hasUserName && !hasRoleName {
|
||||
return action, nil
|
||||
} else if action == "GetRole" && hasRoleName && !hasUserName {
|
||||
return action, nil
|
||||
}
|
||||
return "", fmt.Errorf("iam_request_headers[%q] error: invalid request body %q", t.config.GetEntityBodyHeader, t.getIAMEntityBody)
|
||||
}
|
||||
|
||||
// parseRequestBody parses the AWS STS or IAM request body, such as 'Action=GetRole&RoleName=my-role'.
|
||||
// It returns the parsed values, or an error if there are unexpected fields based on allowedValues.
|
||||
//
|
||||
// A key-value pair in the body is allowed if:
|
||||
// - It is a single value (i.e. no bodies like 'Action=1&Action=2')
|
||||
// - allowedValues[key] is an empty slice or nil (any value is allowed for the key)
|
||||
// - allowedValues[key] is non-empty and contains the exact value
|
||||
// This always requires an 'Action' field is present and non-empty.
|
||||
func parseRequestBody(body string, allowedValues url.Values) (url.Values, error) {
|
||||
qs, err := url.ParseQuery(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Action field is always required.
|
||||
if _, ok := qs["Action"]; !ok || len(qs["Action"]) == 0 || qs["Action"][0] == "" {
|
||||
return nil, fmt.Errorf(`missing field "Action"`)
|
||||
}
|
||||
|
||||
// Ensure the body does not have extra fields and each
|
||||
// field in the body matches the allowed values.
|
||||
for k, v := range qs {
|
||||
exp, ok := allowedValues[k]
|
||||
if k != "Action" && !ok {
|
||||
return nil, fmt.Errorf("unexpected field %q", k)
|
||||
}
|
||||
|
||||
if len(exp) == 0 {
|
||||
// empty indicates any value is okay
|
||||
continue
|
||||
} else if len(v) != 1 || !stringslice.Contains(exp, v[0]) {
|
||||
return nil, fmt.Errorf("unexpected value %s=%v", k, v)
|
||||
}
|
||||
}
|
||||
|
||||
return qs, nil
|
||||
}
|
||||
|
||||
// https://github.com/hashicorp/vault/blob/861454e0ed1390d67ddaf1a53c1798e5e291728c/builtin/credential/aws/path_config_client.go#L349
|
||||
func (t *BearerToken) validateAllowedSTSHeaderValues() error {
|
||||
for k := range t.getCallerIdentityHeader {
|
||||
h := textproto.CanonicalMIMEHeaderKey(k)
|
||||
if strings.HasPrefix(h, amzHeaderPrefix) &&
|
||||
!stringslice.Contains(defaultAllowedSTSRequestHeaders, h) &&
|
||||
!stringslice.Contains(t.config.AllowedSTSHeaderValues, h) {
|
||||
return fmt.Errorf("invalid request header: %s", h)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the bearer token details which contains an HTTP
|
||||
// request (a signed sts:GetCallerIdentity request).
|
||||
func (t *BearerToken) UnmarshalJSON(data []byte) error {
|
||||
var rawData struct {
|
||||
Method string `json:"iam_http_request_method"`
|
||||
UrlBase64 string `json:"iam_request_url"`
|
||||
HeadersBase64 string `json:"iam_request_headers"`
|
||||
BodyBase64 string `json:"iam_request_body"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &rawData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawUrl, err := base64.StdEncoding.DecodeString(rawData.UrlBase64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headersJson, err := base64.StdEncoding.DecodeString(rawData.HeadersBase64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var headers http.Header
|
||||
// This is a JSON-string in JSON
|
||||
if err := json.Unmarshal(headersJson, &headers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := base64.StdEncoding.DecodeString(rawData.BodyBase64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.getCallerIdentityMethod = rawData.Method
|
||||
t.getCallerIdentityBody = string(body)
|
||||
t.getCallerIdentityHeader = headers
|
||||
t.getCallerIdentityURL = string(rawUrl)
|
||||
|
||||
parsedUrl, err := parseUrl(t.getCallerIdentityURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.parsedCallerIdentityURL = parsedUrl
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseUrl(s string) (*url.URL, error) {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// url.Parse doesn't error on empty string
|
||||
if u == nil || u.Scheme == "" || u.Host == "" || u.Path == "" {
|
||||
return nil, fmt.Errorf("url is invalid: %q", s)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GetCallerIdentityRequest returns the sts:GetCallerIdentity request decoded
|
||||
// from the bearer token.
|
||||
func (t *BearerToken) GetCallerIdentityRequest() (*http.Request, error) {
|
||||
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
|
||||
// The protection against this is that this method will only call the endpoint specified in the
|
||||
// client config (defaulting to sts.amazonaws.com), so it would require an admin to override
|
||||
// the endpoint to talk to alternate web addresses
|
||||
endpoint := defaultSTSEndpoint
|
||||
if t.config.STSEndpoint != "" {
|
||||
endpoint = t.config.STSEndpoint
|
||||
}
|
||||
|
||||
return buildHttpRequest(
|
||||
t.getCallerIdentityMethod,
|
||||
endpoint,
|
||||
t.parsedCallerIdentityURL,
|
||||
t.getCallerIdentityBody,
|
||||
t.getCallerIdentityHeader,
|
||||
)
|
||||
}
|
||||
|
||||
// GetEntityRequest returns the iam:GetUser or iam:GetRole request from the request details,
|
||||
// if present, embedded in the headers of the sts:GetCallerIdentity request.
|
||||
func (t *BearerToken) GetEntityRequest() (*http.Request, error) {
|
||||
endpoint := defaultIAMEndpoint
|
||||
if t.config.IAMEndpoint != "" {
|
||||
endpoint = t.config.IAMEndpoint
|
||||
}
|
||||
|
||||
return buildHttpRequest(
|
||||
t.getIAMEntityMethod,
|
||||
endpoint,
|
||||
t.parsedIAMEntityURL,
|
||||
t.getIAMEntityBody,
|
||||
t.getIAMEntityHeader,
|
||||
)
|
||||
}
|
||||
|
||||
// getHeader returns the header from s.GetCallerIdentityHeader, or an error if
|
||||
// the header is not found or is not a single value.
|
||||
func (t *BearerToken) getHeader(name string) (string, error) {
|
||||
values := t.getCallerIdentityHeader.Values(name)
|
||||
if len(values) == 0 {
|
||||
return "", fmt.Errorf("missing header %q", name)
|
||||
}
|
||||
if len(values) != 1 {
|
||||
return "", fmt.Errorf("invalid value for header %q (expected 1 item)", name)
|
||||
}
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
// buildHttpRequest returns an HTTP request from the given details.
|
||||
// This supports sending to a custom endpoint, but always preserves the
|
||||
// Host header and URI path, which are signed and cannot be modified.
|
||||
// There's a deeper explanation of this in the Vault source code.
|
||||
// https://github.com/hashicorp/vault/blob/b17e3256dde937a6248c9a2fa56206aac93d07de/builtin/credential/aws/path_login.go#L1569
|
||||
func buildHttpRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*http.Request, error) {
|
||||
targetUrl := fmt.Sprintf("%s%s", endpoint, parsedUrl.RequestURI())
|
||||
request, err := http.NewRequest(method, targetUrl, strings.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Host = parsedUrl.Host
|
||||
for k, vals := range headers {
|
||||
for _, val := range vals {
|
||||
request.Header.Add(k, val)
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
364
internal/iamauth/token_test.go
Normal file
364
internal/iamauth/token_test.go
Normal file
@ -0,0 +1,364 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBearerToken(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
tokenStr string
|
||||
config Config
|
||||
expToken BearerToken
|
||||
expError string
|
||||
}{
|
||||
"valid token": {
|
||||
tokenStr: validBearerTokenJson,
|
||||
expToken: validBearerTokenParsed,
|
||||
},
|
||||
"valid token with role": {
|
||||
tokenStr: validBearerTokenWithRoleJson,
|
||||
config: Config{
|
||||
EnableIAMEntityDetails: true,
|
||||
GetEntityMethodHeader: "X-Consul-IAM-GetEntity-Method",
|
||||
GetEntityURLHeader: "X-Consul-IAM-GetEntity-URL",
|
||||
GetEntityHeadersHeader: "X-Consul-IAM-GetEntity-Headers",
|
||||
GetEntityBodyHeader: "X-Consul-IAM-GetEntity-Body",
|
||||
},
|
||||
expToken: validBearerTokenWithRoleParsed,
|
||||
},
|
||||
|
||||
"empty json": {
|
||||
tokenStr: `{}`,
|
||||
expError: "unexpected end of JSON input",
|
||||
},
|
||||
"missing iam_request_method field": {
|
||||
tokenStr: tokenJsonMissingMethodField,
|
||||
expError: "iam_http_request_method must be POST",
|
||||
},
|
||||
"missing iam_request_url field": {
|
||||
tokenStr: tokenJsonMissingUrlField,
|
||||
expError: "url is invalid",
|
||||
},
|
||||
"missing iam_request_headers field": {
|
||||
tokenStr: tokenJsonMissingHeadersField,
|
||||
expError: "unexpected end of JSON input",
|
||||
},
|
||||
"missing iam_request_body field": {
|
||||
tokenStr: tokenJsonMissingBodyField,
|
||||
expError: "iam_request_body error",
|
||||
},
|
||||
"invalid json": {
|
||||
tokenStr: `{`,
|
||||
expError: "unexpected end of JSON input",
|
||||
},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
token, err := NewBearerToken(c.tokenStr, &c.config)
|
||||
t.Logf("token = %+v", token)
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
require.Nil(t, token)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
c.expToken.config = &c.config
|
||||
require.Equal(t, &c.expToken, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRequestBody(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
body string
|
||||
allowedValues url.Values
|
||||
expValues url.Values
|
||||
expError string
|
||||
}{
|
||||
"one allowed field": {
|
||||
body: "Action=GetCallerIdentity&Version=1234",
|
||||
allowedValues: url.Values{"Version": []string{"1234"}},
|
||||
expValues: url.Values{
|
||||
"Action": []string{"GetCallerIdentity"},
|
||||
"Version": []string{"1234"},
|
||||
},
|
||||
},
|
||||
"many allowed fields": {
|
||||
body: "Action=GetRole&RoleName=my-role&Version=1234",
|
||||
allowedValues: url.Values{
|
||||
"Action": []string{"GetUser", "GetRole"},
|
||||
"UserName": nil,
|
||||
"RoleName": nil,
|
||||
"Version": nil,
|
||||
},
|
||||
expValues: url.Values{
|
||||
"Action": []string{"GetRole"},
|
||||
"RoleName": []string{"my-role"},
|
||||
"Version": []string{"1234"},
|
||||
},
|
||||
},
|
||||
"action only": {
|
||||
body: "Action=GetRole",
|
||||
allowedValues: nil,
|
||||
expValues: url.Values{"Action": []string{"GetRole"}},
|
||||
},
|
||||
|
||||
"empty body": {
|
||||
expValues: url.Values{},
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
"disallowed field": {
|
||||
body: "Action=GetRole&Version=1234&Extra=Abc",
|
||||
allowedValues: url.Values{"Action": nil, "Version": nil},
|
||||
expError: `unexpected field "Extra"`,
|
||||
},
|
||||
"mismatched action": {
|
||||
body: "Action=GetRole",
|
||||
allowedValues: url.Values{"Action": []string{"GetUser"}},
|
||||
expError: `unexpected value Action=[GetRole]`,
|
||||
},
|
||||
"mismatched field": {
|
||||
body: "Action=GetRole&Extra=1234",
|
||||
allowedValues: url.Values{"Action": nil, "Extra": []string{"abc"}},
|
||||
expError: `unexpected value Extra=[1234]`,
|
||||
},
|
||||
"multi-valued field": {
|
||||
body: "Action=GetRole&Action=GetUser",
|
||||
allowedValues: url.Values{"Action": []string{"GetRole", "GetUser"}},
|
||||
// only one value is allowed.
|
||||
expError: `unexpected value Action=[GetRole GetUser]`,
|
||||
},
|
||||
"empty action": {
|
||||
body: "Action=",
|
||||
allowedValues: nil,
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
"missing action": {
|
||||
body: "Version=1234",
|
||||
allowedValues: url.Values{"Action": []string{"GetRole"}},
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
values, err := parseRequestBody(c.body, c.allowedValues)
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
require.Nil(t, values)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c.expValues, values)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGetCallerIdentityBody(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
body string
|
||||
expError string
|
||||
}{
|
||||
"valid": {"Action=GetCallerIdentity&Version=1234", ""},
|
||||
"valid 2": {"Action=GetCallerIdentity", ""},
|
||||
"empty action": {
|
||||
"Action=",
|
||||
`iam_request_body error: missing field "Action"`,
|
||||
},
|
||||
"invalid action": {
|
||||
"Action=GetRole",
|
||||
`iam_request_body error: unexpected value Action=[GetRole]`,
|
||||
},
|
||||
"missing action": {
|
||||
"Version=1234",
|
||||
`iam_request_body error: missing field "Action"`,
|
||||
},
|
||||
"empty": {
|
||||
"",
|
||||
`iam_request_body error: missing field "Action"`,
|
||||
},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
token := &BearerToken{getCallerIdentityBody: c.body}
|
||||
err := token.validateGetCallerIdentityBody()
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIAMEntityBody(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
body string
|
||||
expReqType string
|
||||
expError string
|
||||
}{
|
||||
"valid role": {
|
||||
body: "Action=GetRole&RoleName=my-role&Version=1234",
|
||||
expReqType: "GetRole",
|
||||
},
|
||||
"valid role without version": {
|
||||
body: "Action=GetRole&RoleName=my-role",
|
||||
expReqType: "GetRole",
|
||||
},
|
||||
"valid user": {
|
||||
body: "Action=GetUser&UserName=my-role&Version=1234",
|
||||
expReqType: "GetUser",
|
||||
},
|
||||
"valid user without version": {
|
||||
body: "Action=GetUser&UserName=my-role",
|
||||
expReqType: "GetUser",
|
||||
},
|
||||
|
||||
"invalid action": {
|
||||
body: "Action=GetCallerIdentity",
|
||||
expError: `unexpected value Action=[GetCallerIdentity]`,
|
||||
},
|
||||
"role missing action": {
|
||||
body: "RoleName=my-role&Version=1234",
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
"user missing action": {
|
||||
body: "UserName=my-role&Version=1234",
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
"empty": {
|
||||
body: "",
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
"empty action": {
|
||||
body: "Action=",
|
||||
expError: `missing field "Action"`,
|
||||
},
|
||||
"role with user name": {
|
||||
body: "Action=GetRole&UserName=my-role&Version=1234",
|
||||
expError: `invalid request body`,
|
||||
},
|
||||
"user with role name": {
|
||||
body: "Action=GetUser&RoleName=my-role&Version=1234",
|
||||
expError: `invalid request body`,
|
||||
},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
token := &BearerToken{
|
||||
config: &Config{},
|
||||
getIAMEntityBody: c.body,
|
||||
}
|
||||
reqType, err := token.validateIAMEntityBody()
|
||||
if c.expError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), c.expError)
|
||||
require.Equal(t, "", reqType)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, c.expReqType, reqType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
validBearerTokenJson = `{
|
||||
"iam_http_request_method":"POST",
|
||||
"iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==",
|
||||
"iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==",
|
||||
"iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8="
|
||||
}`
|
||||
|
||||
validBearerTokenParsed = BearerToken{
|
||||
getCallerIdentityMethod: "POST",
|
||||
getCallerIdentityURL: "https://sts.amazonaws.com/",
|
||||
getCallerIdentityHeader: http.Header{
|
||||
"Authorization": []string{"AWS4-HMAC-SHA256 Credential=fake/20220322/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token, Signature=efc320b972d07b38b65eb24256805e03149da586d804f8c6364ce98debe080b1"},
|
||||
"Content-Length": []string{"43"},
|
||||
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
|
||||
"User-Agent": []string{"aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"},
|
||||
"X-Amz-Date": []string{"20220322T211103Z"},
|
||||
"X-Amz-Security-Token": []string{"fake"},
|
||||
},
|
||||
getCallerIdentityBody: "Action=GetCallerIdentity&Version=2011-06-15",
|
||||
parsedCallerIdentityURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "sts.amazonaws.com",
|
||||
Path: "/",
|
||||
},
|
||||
}
|
||||
|
||||
validBearerTokenWithRoleJson = `{"iam_http_request_method":"POST","iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==","iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLWtleS1pZC8yMDIyMDMyMi9mYWtlLXJlZ2lvbi9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1jb25zdWwtaWFtLWdldGVudGl0eS1ib2R5O3gtY29uc3VsLWlhbS1nZXRlbnRpdHktaGVhZGVyczt4LWNvbnN1bC1pYW0tZ2V0ZW50aXR5LW1ldGhvZDt4LWNvbnN1bC1pYW0tZ2V0ZW50aXR5LXVybCwgU2lnbmF0dXJlPTU2MWFjMzFiNWFkMDFjMTI0YzU0YzE2OGY3NmVhNmJmZDY0NWI4ZWM1MzQ1ZjgzNTc3MjljOWFhMGI0NzEzMzciXSwiQ29udGVudC1MZW5ndGgiOlsiNDMiXSwiQ29udGVudC1UeXBlIjpbImFwcGxpY2F0aW9uL3gtd3d3LWZvcm0tdXJsZW5jb2RlZDsgY2hhcnNldD11dGYtOCJdLCJVc2VyLUFnZW50IjpbImF3cy1zZGstZ28vMS40Mi4zNCAoZ28xLjE3LjU7IGRhcndpbjsgYW1kNjQpIl0sIlgtQW16LURhdGUiOlsiMjAyMjAzMjJUMjI1NzQyWiJdLCJYLUNvbnN1bC1JYW0tR2V0ZW50aXR5LUJvZHkiOlsiQWN0aW9uPUdldFJvbGVcdTAwMjZSb2xlTmFtZT1teS1yb2xlXHUwMDI2VmVyc2lvbj0yMDEwLTA1LTA4Il0sIlgtQ29uc3VsLUlhbS1HZXRlbnRpdHktSGVhZGVycyI6WyJ7XCJBdXRob3JpemF0aW9uXCI6W1wiQVdTNC1ITUFDLVNIQTI1NiBDcmVkZW50aWFsPWZha2Uta2V5LWlkLzIwMjIwMzIyL3VzLWVhc3QtMS9pYW0vYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGUsIFNpZ25hdHVyZT1hYTJhMTlkMGEzMDVkNzRiYmQwMDk3NzZiY2E4ODBlNTNjZmE5OTFlNDgzZTQwMzk0NzE4MWE0MWNjNDgyOTQwXCJdLFwiQ29udGVudC1MZW5ndGhcIjpbXCI1MFwiXSxcIkNvbnRlbnQtVHlwZVwiOltcImFwcGxpY2F0aW9uL3gtd3d3LWZvcm0tdXJsZW5jb2RlZDsgY2hhcnNldD11dGYtOFwiXSxcIlVzZXItQWdlbnRcIjpbXCJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KVwiXSxcIlgtQW16LURhdGVcIjpbXCIyMDIyMDMyMlQyMjU3NDJaXCJdfSJdLCJYLUNvbnN1bC1JYW0tR2V0ZW50aXR5LU1ldGhvZCI6WyJQT1NUIl0sIlgtQ29uc3VsLUlhbS1HZXRlbnRpdHktVXJsIjpbImh0dHBzOi8vaWFtLmFtYXpvbmF3cy5jb20vIl19","iam_request_url":"aHR0cDovLzEyNy4wLjAuMTo2MzY5Ni9zdHMv"}`
|
||||
|
||||
validBearerTokenWithRoleParsed = BearerToken{
|
||||
getCallerIdentityMethod: "POST",
|
||||
getCallerIdentityURL: "http://127.0.0.1:63696/sts/",
|
||||
getCallerIdentityHeader: http.Header{
|
||||
"Authorization": []string{"AWS4-HMAC-SHA256 Credential=fake-key-id/20220322/fake-region/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-consul-iam-getentity-body;x-consul-iam-getentity-headers;x-consul-iam-getentity-method;x-consul-iam-getentity-url, Signature=561ac31b5ad01c124c54c168f76ea6bfd645b8ec5345f8357729c9aa0b471337"},
|
||||
"Content-Length": []string{"43"},
|
||||
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
|
||||
"User-Agent": []string{"aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"},
|
||||
"X-Amz-Date": []string{"20220322T225742Z"},
|
||||
"X-Consul-Iam-Getentity-Body": []string{"Action=GetRole&RoleName=my-role&Version=2010-05-08"},
|
||||
"X-Consul-Iam-Getentity-Headers": []string{`{"Authorization":["AWS4-HMAC-SHA256 Credential=fake-key-id/20220322/us-east-1/iam/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aa2a19d0a305d74bbd009776bca880e53cfa991e483e403947181a41cc482940"],"Content-Length":["50"],"Content-Type":["application/x-www-form-urlencoded; charset=utf-8"],"User-Agent":["aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"],"X-Amz-Date":["20220322T225742Z"]}`},
|
||||
"X-Consul-Iam-Getentity-Method": []string{"POST"},
|
||||
"X-Consul-Iam-Getentity-Url": []string{"https://iam.amazonaws.com/"},
|
||||
},
|
||||
getCallerIdentityBody: "Action=GetCallerIdentity&Version=2011-06-15",
|
||||
|
||||
// Fields parsed from headers above
|
||||
getIAMEntityMethod: "POST",
|
||||
getIAMEntityURL: "https://iam.amazonaws.com/",
|
||||
getIAMEntityHeader: http.Header{
|
||||
"Authorization": []string{"AWS4-HMAC-SHA256 Credential=fake-key-id/20220322/us-east-1/iam/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aa2a19d0a305d74bbd009776bca880e53cfa991e483e403947181a41cc482940"},
|
||||
"Content-Length": []string{"50"},
|
||||
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
|
||||
"User-Agent": []string{"aws-sdk-go/1.42.34 (go1.17.5; darwin; amd64)"},
|
||||
"X-Amz-Date": []string{"20220322T225742Z"},
|
||||
},
|
||||
getIAMEntityBody: "Action=GetRole&RoleName=my-role&Version=2010-05-08",
|
||||
entityRequestType: "GetRole",
|
||||
|
||||
parsedCallerIdentityURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "127.0.0.1:63696",
|
||||
Path: "/sts/",
|
||||
},
|
||||
parsedIAMEntityURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "iam.amazonaws.com",
|
||||
Path: "/",
|
||||
},
|
||||
}
|
||||
|
||||
tokenJsonMissingMethodField = `{
|
||||
"iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==",
|
||||
"iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==",
|
||||
"iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8="
|
||||
}`
|
||||
|
||||
tokenJsonMissingBodyField = `{
|
||||
"iam_http_request_method":"POST",
|
||||
"iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ==",
|
||||
"iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8="
|
||||
}`
|
||||
|
||||
tokenJsonMissingHeadersField = `{
|
||||
"iam_http_request_method":"POST",
|
||||
"iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==",
|
||||
"iam_request_url":"aHR0cHM6Ly9zdHMuYW1hem9uYXdzLmNvbS8="
|
||||
}`
|
||||
|
||||
tokenJsonMissingUrlField = `{
|
||||
"iam_http_request_method":"POST",
|
||||
"iam_request_body":"QWN0aW9uPUdldENhbGxlcklkZW50aXR5JlZlcnNpb249MjAxMS0wNi0xNQ==",
|
||||
"iam_request_headers":"eyJBdXRob3JpemF0aW9uIjpbIkFXUzQtSE1BQy1TSEEyNTYgQ3JlZGVudGlhbD1mYWtlLzIwMjIwMzIyL3VzLWVhc3QtMS9zdHMvYXdzNF9yZXF1ZXN0LCBTaWduZWRIZWFkZXJzPWNvbnRlbnQtbGVuZ3RoO2NvbnRlbnQtdHlwZTtob3N0O3gtYW16LWRhdGU7eC1hbXotc2VjdXJpdHktdG9rZW4sIFNpZ25hdHVyZT1lZmMzMjBiOTcyZDA3YjM4YjY1ZWIyNDI1NjgwNWUwMzE0OWRhNTg2ZDgwNGY4YzYzNjRjZTk4ZGViZTA4MGIxIl0sIkNvbnRlbnQtTGVuZ3RoIjpbIjQzIl0sIkNvbnRlbnQtVHlwZSI6WyJhcHBsaWNhdGlvbi94LXd3dy1mb3JtLXVybGVuY29kZWQ7IGNoYXJzZXQ9dXRmLTgiXSwiVXNlci1BZ2VudCI6WyJhd3Mtc2RrLWdvLzEuNDIuMzQgKGdvMS4xNy41OyBkYXJ3aW47IGFtZDY0KSJdLCJYLUFtei1EYXRlIjpbIjIwMjIwMzIyVDIxMTEwM1oiXSwiWC1BbXotU2VjdXJpdHktVG9rZW4iOlsiZmFrZSJdfQ=="
|
||||
}`
|
||||
)
|
158
internal/iamauth/util.go
Normal file
158
internal/iamauth/util.go
Normal file
@ -0,0 +1,158 @@
|
||||
package iamauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/iam"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
"github.com/hashicorp/consul/internal/iamauth/responses"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
type LoginInput struct {
|
||||
Creds *credentials.Credentials
|
||||
IncludeIAMEntity bool
|
||||
STSEndpoint string
|
||||
STSRegion string
|
||||
|
||||
Logger hclog.Logger
|
||||
|
||||
ServerIDHeaderValue string
|
||||
// Customizable header names
|
||||
ServerIDHeaderName string
|
||||
GetEntityMethodHeader string
|
||||
GetEntityURLHeader string
|
||||
GetEntityHeadersHeader string
|
||||
GetEntityBodyHeader string
|
||||
}
|
||||
|
||||
// GenerateLoginData populates the necessary data to send for the bearer token.
|
||||
// https://github.com/hashicorp/go-secure-stdlib/blob/main/awsutil/generate_credentials.go#L232-L301
|
||||
func GenerateLoginData(in *LoginInput) (map[string]interface{}, error) {
|
||||
cfg := aws.Config{
|
||||
Credentials: in.Creds,
|
||||
Region: aws.String(in.STSRegion),
|
||||
}
|
||||
if in.STSEndpoint != "" {
|
||||
cfg.Endpoint = aws.String(in.STSEndpoint)
|
||||
} else {
|
||||
cfg.EndpointResolver = endpoints.ResolverFunc(stsSigningResolver)
|
||||
}
|
||||
|
||||
stsSession, err := session.NewSessionWithOptions(session.Options{Config: cfg})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := sts.New(stsSession)
|
||||
stsRequest, _ := svc.GetCallerIdentityRequest(nil)
|
||||
|
||||
// Include the iam:GetRole or iam:GetUser request in headers.
|
||||
if in.IncludeIAMEntity {
|
||||
entityRequest, err := formatSignedEntityRequest(svc, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headersJson, err := json.Marshal(entityRequest.HTTPRequest.Header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody, err := ioutil.ReadAll(entityRequest.HTTPRequest.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stsRequest.HTTPRequest.Header.Add(in.GetEntityMethodHeader, entityRequest.HTTPRequest.Method)
|
||||
stsRequest.HTTPRequest.Header.Add(in.GetEntityURLHeader, entityRequest.HTTPRequest.URL.String())
|
||||
stsRequest.HTTPRequest.Header.Add(in.GetEntityHeadersHeader, string(headersJson))
|
||||
stsRequest.HTTPRequest.Header.Add(in.GetEntityBodyHeader, string(requestBody))
|
||||
}
|
||||
|
||||
// Inject the required auth header value, if supplied, and then sign the request including that header
|
||||
if in.ServerIDHeaderValue != "" {
|
||||
stsRequest.HTTPRequest.Header.Add(in.ServerIDHeaderName, in.ServerIDHeaderValue)
|
||||
}
|
||||
|
||||
stsRequest.Sign()
|
||||
|
||||
// Now extract out the relevant parts of the request
|
||||
headersJson, err := json.Marshal(stsRequest.HTTPRequest.Header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody, err := ioutil.ReadAll(stsRequest.HTTPRequest.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"iam_http_request_method": stsRequest.HTTPRequest.Method,
|
||||
"iam_request_url": base64.StdEncoding.EncodeToString([]byte(stsRequest.HTTPRequest.URL.String())),
|
||||
"iam_request_headers": base64.StdEncoding.EncodeToString(headersJson),
|
||||
"iam_request_body": base64.StdEncoding.EncodeToString(requestBody),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// STS is a really weird service that used to only have global endpoints but now has regional endpoints as well.
|
||||
// For backwards compatibility, even if you request a region other than us-east-1, it'll still sign for us-east-1.
|
||||
// See, e.g., https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
|
||||
// So we have to shim in this EndpointResolver to force it to sign for the right region
|
||||
func stsSigningResolver(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
|
||||
defaultEndpoint, err := endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
|
||||
if err != nil {
|
||||
return defaultEndpoint, err
|
||||
}
|
||||
defaultEndpoint.SigningRegion = region
|
||||
return defaultEndpoint, nil
|
||||
}
|
||||
|
||||
func formatSignedEntityRequest(svc *sts.STS, in *LoginInput) (*request.Request, error) {
|
||||
// We need to retrieve the IAM user or role for the iam:GetRole or iam:GetUser request.
|
||||
// GetCallerIdentity returns this and requires no permissions.
|
||||
resp, err := svc.GetCallerIdentity(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arn, err := responses.ParseArn(*resp.Arn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iamSession, err := session.NewSessionWithOptions(session.Options{
|
||||
Config: aws.Config{
|
||||
Credentials: svc.Config.Credentials,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iamSvc := iam.New(iamSession)
|
||||
|
||||
var req *request.Request
|
||||
switch arn.Type {
|
||||
case "role", "assumed-role":
|
||||
req, _ = iamSvc.GetRoleRequest(&iam.GetRoleInput{RoleName: &arn.FriendlyName})
|
||||
case "user":
|
||||
req, _ = iamSvc.GetUserRequest(&iam.GetUserInput{UserName: &arn.FriendlyName})
|
||||
default:
|
||||
return nil, fmt.Errorf("entity %s is not an IAM role or IAM user", arn.Type)
|
||||
}
|
||||
|
||||
// Inject the required auth header value, if supplied, and then sign the request including that header
|
||||
if in.ServerIDHeaderValue != "" {
|
||||
req.HTTPRequest.Header.Add(in.ServerIDHeaderName, in.ServerIDHeaderValue)
|
||||
}
|
||||
|
||||
req.Sign()
|
||||
return req, nil
|
||||
}
|
24
lib/glob.go
Normal file
24
lib/glob.go
Normal file
@ -0,0 +1,24 @@
|
||||
package lib
|
||||
|
||||
import "strings"
|
||||
|
||||
// GlobbedStringsMatch compares item to val with support for a leading and/or
|
||||
// trailing wildcard '*' in item.
|
||||
func GlobbedStringsMatch(item, val string) bool {
|
||||
if len(item) < 2 {
|
||||
return val == item
|
||||
}
|
||||
|
||||
hasPrefix := strings.HasPrefix(item, "*")
|
||||
hasSuffix := strings.HasSuffix(item, "*")
|
||||
|
||||
if hasPrefix && hasSuffix {
|
||||
return strings.Contains(val, item[1:len(item)-1])
|
||||
} else if hasPrefix {
|
||||
return strings.HasSuffix(val, item[1:])
|
||||
} else if hasSuffix {
|
||||
return strings.HasPrefix(val, item[:len(item)-1])
|
||||
}
|
||||
|
||||
return val == item
|
||||
}
|
37
lib/glob_test.go
Normal file
37
lib/glob_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package lib
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGlobbedStringsMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
item string
|
||||
val string
|
||||
expect bool
|
||||
}{
|
||||
{"", "", true},
|
||||
{"*", "*", true},
|
||||
{"**", "**", true},
|
||||
{"*t", "t", true},
|
||||
{"*t", "test", true},
|
||||
{"t*", "test", true},
|
||||
{"*test", "test", true},
|
||||
{"*test", "a test", true},
|
||||
{"test", "a test", false},
|
||||
{"*test", "tests", false},
|
||||
{"test*", "test", true},
|
||||
{"test*", "testsss", true},
|
||||
{"test**", "testsss", false},
|
||||
{"test**", "test*", true},
|
||||
{"**test", "*test", true},
|
||||
{"TEST", "test", false},
|
||||
{"test", "test", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
actual := GlobbedStringsMatch(tt.item, tt.val)
|
||||
|
||||
if actual != tt.expect {
|
||||
t.Fatalf("Bad testcase %#v, expected %t, got %t", tt, tt.expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user