2
0
mirror of https://github.com/status-im/consul.git synced 2025-01-13 07:14:37 +00:00

connect: support AWS PCA as a CA provider ()

Port AWS PCA provider from consul-ent
This commit is contained in:
Todd Radel 2019-07-30 22:57:51 -04:00 committed by GitHub
parent 2552f4a11a
commit 3497b7c00d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
145 changed files with 40608 additions and 7424 deletions
GNUmakefile
agent
go.modgo.sum
vendor/github.com/aws/aws-sdk-go

@ -363,7 +363,7 @@ test-envoy-integ: $(ENVOY_INTEG_DEPS)
@$(SHELL) $(CURDIR)/test/integration/connect/envoy/run-tests.sh @$(SHELL) $(CURDIR)/test/integration/connect/envoy/run-tests.sh
proto: proto:
protoc agent/connect/ca/plugin/*.proto --gofast_out=plugins=grpc:../../.. protoc agent/connect/ca/plugin/*.proto --gofast_out=plugins=grpc,Mgoogle/protobuf/duration.proto=github.com/gogo/protobuf/types:../../..
.PHONY: all ci bin dev dist cov test test-ci test-internal test-install-deps cover format vet ui static-assets tools .PHONY: all ci bin dev dist cov test test-ci test-internal test-install-deps cover format vet ui static-assets tools
.PHONY: docker-images go-build-image ui-build-image static-assets-docker consul-docker ui-docker .PHONY: docker-images go-build-image ui-build-image static-assets-docker consul-docker ui-docker

@ -527,7 +527,8 @@ func (c *ConnectCALeaf) generateNewLeaf(req *ConnectCALeafRequest,
} }
// Create a CSR. // Create a CSR.
csr, err := connect.CreateCSR(id, pk) serviceName := fmt.Sprintf("%s.service.%s.%s", req.Service, req.Datacenter, "consul")
csr, err := connect.CreateCSR(serviceName, id, pk)
if err != nil { if err != nil {
return result, err return result, err
} }

@ -584,6 +584,17 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
"tls_server_name": "TLSServerName", "tls_server_name": "TLSServerName",
"tls_skip_verify": "TLSSkipVerify", "tls_skip_verify": "TLSSkipVerify",
// AWS ACM PCA config
"access_key_id": "AccessKeyID",
"secret_access_key": "SecretAccessKey",
"region": "Region",
"sleep_time": "SleepTime",
"root_arn": "RootARN",
"intermediate_arn": "IntermediateTemplateARN",
"key_algorithm": "KeyAlgorithm",
"signing_algorithm": "SigningAlgorithm",
"delete_on_exit": "DeleteOnExit",
// Common CA config // Common CA config
"leaf_cert_ttl": "LeafCertTTL", "leaf_cert_ttl": "LeafCertTTL",
"csr_max_per_second": "CSRMaxPerSecond", "csr_max_per_second": "CSRMaxPerSecond",
@ -1099,6 +1110,7 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
"": true, "": true,
structs.ConsulCAProvider: true, structs.ConsulCAProvider: true,
structs.VaultCAProvider: true, structs.VaultCAProvider: true,
structs.AWSCAProvider: true,
} }
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok { if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider) return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
@ -1112,6 +1124,10 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil { if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
return err return err
} }
case structs.AWSCAProvider:
if _, err := ca.ParseAWSCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
} }
} }

438
agent/connect/ca/aws_pca.go Normal file

@ -0,0 +1,438 @@
package ca
import (
"crypto/x509"
"fmt"
"log"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/acmpca"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
const (
RootTemplateARN = "arn:aws:acm-pca:::template/RootCACertificate/V1"
IntermediateTemplateARN = "arn:aws:acm-pca:::template/SubordinateCACertificate_PathLen0/V1"
LeafTemplateARN = "arn:aws:acm-pca:::template/EndEntityCertificate/V1"
)
const (
RootValidity = 5 * 365 * 24 * time.Hour
IntermediateValidity = 1 * 365 * 24 * time.Hour
)
type AmazonPCA struct {
arn string
pcaType string
keyAlgorithm string
signingAlgorithm string
sleepTime time.Duration
certPEM string
chainPEM string
client *acmpca.ACMPCA
logger *log.Logger
}
func NewAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, arn string, pcaType string,
keyAlgorithm string, signingAlgorithm string, logger *log.Logger) *AmazonPCA {
sleepTime, err := time.ParseDuration(config.SleepTime)
if err != nil {
sleepTime = 5 * time.Second
}
return &AmazonPCA{arn: arn, pcaType: pcaType, client: client, logger: logger,
keyAlgorithm: keyAlgorithm, signingAlgorithm: signingAlgorithm, sleepTime: sleepTime}
}
func LoadAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, arn string, pcaType string,
clusterId string, keyAlgorithm string, signingAlgorithm string, logger *log.Logger) (*AmazonPCA, error) {
pca := NewAmazonPCA(client, config, arn, pcaType, keyAlgorithm, signingAlgorithm, logger)
output, err := pca.describe()
if err != nil {
return nil, err
}
warned := false
if *output.CertificateAuthority.Status != acmpca.CertificateAuthorityStatusActive {
return nil, fmt.Errorf("the specified PCA is not active: status is %s",
*output.CertificateAuthority.Status)
}
if *output.CertificateAuthority.CertificateAuthorityConfiguration.Subject.CommonName !=
GetCommonName(pcaType, clusterId) {
logger.Printf("[WARN] name of specified PCA '%s' does not match expected '%s'",
*output.CertificateAuthority.CertificateAuthorityConfiguration.Subject.CommonName,
GetCommonName(pcaType, clusterId))
warned = true
}
if *output.CertificateAuthority.CertificateAuthorityConfiguration.KeyAlgorithm != keyAlgorithm {
logger.Printf("[WARN] specified PCA is using an unexpected key algorithm: expected=%s actual=%s",
keyAlgorithm, *output.CertificateAuthority.CertificateAuthorityConfiguration.KeyAlgorithm)
warned = true
}
if *output.CertificateAuthority.CertificateAuthorityConfiguration.SigningAlgorithm != signingAlgorithm {
logger.Printf("[WARN] specified PCA is using an unexpected signing algorithm: expected=%s actual=%s",
signingAlgorithm, *output.CertificateAuthority.CertificateAuthorityConfiguration.SigningAlgorithm)
warned = true
}
if warned {
logger.Print("[WARN] existing PCA failed some preflight checks, trying to continue anyway")
} else {
logger.Print("[WARN] existing PCA passed all preflight checks")
}
return pca, nil
}
func FindAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, pcaType string, clusterId string,
keyAlgorithm string, signingAlgorithm string, logger *log.Logger) (*AmazonPCA, error) {
var name string = GetCommonName(pcaType, clusterId)
var nextToken *string
for {
input := acmpca.ListCertificateAuthoritiesInput{
MaxResults: aws.Int64(100),
NextToken: nextToken,
}
logger.Print("[DEBUG] listing existing certificate authorities")
output, err := client.ListCertificateAuthorities(&input)
if err != nil {
logger.Printf("[ERR] error searching certificate authorities: %s", err.Error())
return nil, err
}
for _, ca := range output.CertificateAuthorities {
if *ca.CertificateAuthorityConfiguration.Subject.CommonName == name &&
*ca.CertificateAuthorityConfiguration.KeyAlgorithm == keyAlgorithm &&
*ca.CertificateAuthorityConfiguration.SigningAlgorithm == signingAlgorithm &&
*ca.Status == acmpca.CertificateAuthorityStatusActive {
logger.Printf("[INFO] found an existing active CA %s", *ca.Arn)
return NewAmazonPCA(client, config, *ca.Arn, pcaType, *ca.CertificateAuthorityConfiguration.KeyAlgorithm,
*ca.CertificateAuthorityConfiguration.SigningAlgorithm, logger), nil
}
}
nextToken = output.NextToken
if nextToken == nil {
break
}
}
logger.Print("[WARN] no existing active CA found")
return nil, nil // not found
}
func CreateAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, pcaType string, clusterId string,
keyAlgorithm string, signingAlgorithm string, logger *log.Logger) (*AmazonPCA, error) {
commonName := GetCommonName(pcaType, clusterId)
createInput := acmpca.CreateCertificateAuthorityInput{
CertificateAuthorityType: aws.String(pcaType),
CertificateAuthorityConfiguration: &acmpca.CertificateAuthorityConfiguration{
Subject: &acmpca.ASN1Subject{
CommonName: aws.String(commonName),
},
KeyAlgorithm: aws.String(keyAlgorithm),
SigningAlgorithm: aws.String(signingAlgorithm),
},
RevocationConfiguration: &acmpca.RevocationConfiguration{
CrlConfiguration: &acmpca.CrlConfiguration{
Enabled: aws.Bool(false),
},
},
Tags: []*acmpca.Tag{
{Key: aws.String("ClusterID"), Value: aws.String(clusterId)},
},
}
logger.Printf("[DEBUG] creating new PCA %s", commonName)
createOutput, err := client.CreateCertificateAuthority(&createInput)
if err != nil {
return nil, err
}
// wait for PCA to be created
newARN := *createOutput.CertificateAuthorityArn
for {
logger.Printf("[DEBUG] checking to see if PCA %s is ready", newARN)
describeInput := acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(newARN),
}
describeOutput, err := client.DescribeCertificateAuthority(&describeInput)
if err != nil {
logger.Printf("[ERR] error describing PCA: %s", err.Error())
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
return nil, fmt.Errorf("error waiting for PCA to be created: %s", err)
}
}
if *describeOutput.CertificateAuthority.Status == acmpca.CertificateAuthorityStatusPendingCertificate {
logger.Printf("[DEBUG] new PCA %s is ready to accept a certificate", newARN)
return NewAmazonPCA(client, config, newARN, pcaType, keyAlgorithm, signingAlgorithm, logger), nil
}
logger.Print("[DEBUG] sleeping until certificate is ready")
time.Sleep(5 * time.Second) // TODO: get from provider config
}
}
func (pca *AmazonPCA) describe() (*acmpca.DescribeCertificateAuthorityOutput, error) {
input := &acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(pca.arn),
}
return pca.client.DescribeCertificateAuthority(input)
}
func (pca *AmazonPCA) GetCSR() (string, error) {
input := &acmpca.GetCertificateAuthorityCsrInput{
CertificateAuthorityArn: aws.String(pca.arn),
}
pca.logger.Printf("[DEBUG] retrieving CSR for %s", pca.arn)
output, err := pca.client.GetCertificateAuthorityCsr(input)
if err != nil {
pca.logger.Printf("[ERR] error retrieving CSR: %s", err.Error())
return "", err
}
return *output.Csr, nil
}
func (pca *AmazonPCA) SetCert(certPEM string, chainPEM string) error {
chainBytes := []byte(chainPEM)
if chainPEM == "" {
chainBytes = nil
}
input := acmpca.ImportCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(pca.arn),
Certificate: []byte(certPEM),
CertificateChain: chainBytes,
}
pca.logger.Printf("[DEBUG] uploading certificate for %s", pca.arn)
_, err := pca.client.ImportCertificateAuthorityCertificate(&input)
if err != nil {
pca.logger.Printf("[ERR] error importing certificates: %s", err.Error())
return err
}
pca.certPEM = certPEM
pca.chainPEM = chainPEM
return nil
}
func (pca *AmazonPCA) getCerts() error {
if pca.certPEM == "" || pca.chainPEM == "" {
input := &acmpca.GetCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(pca.arn),
}
output, err := pca.client.GetCertificateAuthorityCertificate(input)
if err != nil {
return err
}
pca.certPEM = *output.Certificate
pca.chainPEM = *output.CertificateChain
}
return nil
}
func (pca *AmazonPCA) Certificate() string {
if pca.certPEM == "" {
_ = pca.getCerts()
}
return pca.certPEM
}
func (pca *AmazonPCA) CertificateChain() string {
if pca.chainPEM == "" {
_ = pca.getCerts()
}
return pca.certPEM
}
func (pca *AmazonPCA) Generate(signingPCA *AmazonPCA) error {
csrPEM, err := pca.GetCSR()
if err != nil {
return err
}
templateARN := GetTemplateARN(pca.pcaType)
chainPEM := ""
validity := RootValidity
if pca.pcaType == acmpca.CertificateAuthorityTypeSubordinate {
chainPEM = signingPCA.Certificate()
validity = IntermediateValidity
}
newCertPEM, err := signingPCA.Sign(csrPEM, templateARN, validity)
if err != nil {
return err
}
err = pca.SetCert(newCertPEM, chainPEM)
return err
}
func (pca *AmazonPCA) Sign(csrPEM string, templateARN string, validity time.Duration) (string, error) {
issueInput := acmpca.IssueCertificateInput{
CertificateAuthorityArn: aws.String(pca.arn),
Csr: []byte(csrPEM),
SigningAlgorithm: aws.String(pca.signingAlgorithm),
TemplateArn: aws.String(templateARN),
Validity: &acmpca.Validity{
Value: aws.Int64(int64(validity.Seconds() / 86400.0)),
Type: aws.String(acmpca.ValidityPeriodTypeDays),
},
}
csr, err := connect.ParseCSR(csrPEM)
if err != nil {
return "", fmt.Errorf("unable to parse CSR: %s", err)
}
pca.logger.Printf("[DEBUG] issuing certificate for %s with %s", csr.Subject.String(), pca.arn)
issueOutput, err := pca.client.IssueCertificate(&issueInput)
if err != nil {
pca.logger.Printf("[ERR] error issuing certificate: %s", err.Error())
return "", fmt.Errorf("error issuing certificate from PCA: %s", err)
}
// wait for certificate to be created
for {
pca.logger.Printf("[DEBUG] checking to see if certificate %s is ready", *issueOutput.CertificateArn)
certInput := acmpca.GetCertificateInput{
CertificateAuthorityArn: aws.String(pca.arn),
CertificateArn: issueOutput.CertificateArn,
}
certOutput, err := pca.client.GetCertificate(&certInput)
if err != nil {
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
pca.logger.Printf("[ERR] error retrieving new certificate from %s: %s",
*issueOutput.CertificateArn, err.Error())
return "", fmt.Errorf("error retrieving certificate from PCA: %s", err)
}
}
if certOutput.Certificate != nil {
pca.logger.Printf("[DEBUG] certificate is ready, ARN is %s", *issueOutput.CertificateArn)
newCert, err := connect.ParseCert(*certOutput.Certificate)
if err == nil {
pca.logger.Printf("[DEBUG] certificate created: commonName=%s subjectKey=%s authorityKey=%s",
newCert.Subject.CommonName,
connect.HexString(newCert.SubjectKeyId),
connect.HexString(newCert.AuthorityKeyId))
}
return *certOutput.Certificate, nil
}
pca.logger.Printf("[DEBUG] sleeping for %s until certificate is ready", pca.sleepTime)
time.Sleep(pca.sleepTime)
}
}
func (pca *AmazonPCA) SignLeaf(csrPEM string, validity time.Duration) (string, error) {
certPEM, err := pca.Sign(csrPEM, LeafTemplateARN, validity)
return certPEM, err
}
func (pca *AmazonPCA) SignIntermediate(csrPEM string) (string, error) {
certPEM, err := pca.Sign(csrPEM, IntermediateTemplateARN, IntermediateValidity)
return certPEM, err
}
func (pca *AmazonPCA) SignRoot(csrPEM string) (string, error) {
certPEM, err := pca.Sign(csrPEM, RootTemplateARN, RootValidity)
return certPEM, err
}
func (pca *AmazonPCA) Disable() error {
input := acmpca.UpdateCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(pca.arn),
Status: aws.String(acmpca.CertificateAuthorityStatusDisabled),
}
pca.logger.Printf("[INFO] disabling PCA %s", pca.arn)
_, err := pca.client.UpdateCertificateAuthority(&input)
return err
}
func (pca *AmazonPCA) Enable() error {
input := acmpca.UpdateCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(pca.arn),
Status: aws.String(acmpca.CertificateAuthorityStatusActive),
}
pca.logger.Printf("[INFO] enabling PCA %s", pca.arn)
_, err := pca.client.UpdateCertificateAuthority(&input)
return err
}
func (pca *AmazonPCA) Delete() error {
input := acmpca.DeleteCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(pca.arn),
}
pca.logger.Printf("[INFO] deleting PCA %s", pca.arn)
_, err := pca.client.DeleteCertificateAuthority(&input)
return err
}
func (pca *AmazonPCA) Undelete() error {
input := acmpca.RestoreCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(pca.arn),
}
pca.logger.Printf("[INFO] undeleting PCA %s", pca.arn)
_, err := pca.client.RestoreCertificateAuthority(&input)
return err
}
// utility functions
func GetCommonName(pcaType string, clusterId string) string {
return fmt.Sprintf("Consul %s %s", pcaType, clusterId)
}
func GetTemplateARN(pcaType string) string {
if pcaType == acmpca.CertificateAuthorityTypeRoot {
return RootTemplateARN
} else {
return IntermediateTemplateARN
}
}
func IsValidARN(arn string) bool {
const PcaArnRegex = "^arn:([\\w-]+):([\\w-]+):(\\w{2}-\\w+-\\d+):(\\d+):(?:([\\w-]+)[/:])?([[:xdigit:]]{8}-[[:xdigit:]]{4}-[[:xdigit:]]{4}-[[:xdigit:]]{4}-[[:xdigit:]]{12})$"
matched, _ := regexp.MatchString(PcaArnRegex, arn)
return matched
}
func ToSignatureAlgorithm(algo string) x509.SignatureAlgorithm {
switch algo {
case acmpca.SigningAlgorithmSha256withrsa:
return x509.SHA256WithRSA
case acmpca.SigningAlgorithmSha384withrsa:
return x509.SHA384WithRSA
case acmpca.SigningAlgorithmSha512withrsa:
return x509.SHA512WithRSA
case acmpca.SigningAlgorithmSha256withecdsa:
return x509.ECDSAWithSHA256
case acmpca.SigningAlgorithmSha384withecdsa:
return x509.ECDSAWithSHA384
case acmpca.SigningAlgorithmSha512withecdsa:
return x509.ECDSAWithSHA512
default:
return x509.SHA256WithRSA
}
}

@ -3,6 +3,7 @@
package ca package ca
import mock "github.com/stretchr/testify/mock" import mock "github.com/stretchr/testify/mock"
import time "time"
import x509 "crypto/x509" import x509 "crypto/x509"
// MockProvider is an autogenerated mock type for the Provider type // MockProvider is an autogenerated mock type for the Provider type
@ -157,6 +158,20 @@ func (_m *MockProvider) GenerateRoot() error {
return r0 return r0
} }
// MinLifetime provides a mock function with given fields:
func (_m *MockProvider) MinLifetime() time.Duration {
ret := _m.Called()
var r0 time.Duration
if rf, ok := ret.Get(0).(func() time.Duration); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(time.Duration)
}
return r0
}
// SetIntermediate provides a mock function with given fields: intermediatePEM, rootPEM // SetIntermediate provides a mock function with given fields: intermediatePEM, rootPEM
func (_m *MockProvider) SetIntermediate(intermediatePEM string, rootPEM string) error { func (_m *MockProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
ret := _m.Called(intermediatePEM, rootPEM) ret := _m.Called(intermediatePEM, rootPEM)
@ -212,3 +227,17 @@ func (_m *MockProvider) SignIntermediate(_a0 *x509.CertificateRequest) (string,
return r0, r1 return r0, r1
} }
// SupportsCrossSigning provides a mock function with given fields:
func (_m *MockProvider) SupportsCrossSigning() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}

File diff suppressed because it is too large Load Diff

@ -14,6 +14,8 @@ option go_package = "github.com/hashicorp/consul/agent/connect/ca/plugin";
package plugin; package plugin;
import "google/protobuf/duration.proto";
service CA { service CA {
rpc Configure(ConfigureRequest) returns (Empty); rpc Configure(ConfigureRequest) returns (Empty);
rpc GenerateRoot(Empty) returns (Empty); rpc GenerateRoot(Empty) returns (Empty);
@ -26,6 +28,8 @@ service CA {
rpc SignIntermediate(SignIntermediateRequest) returns (SignIntermediateResponse); rpc SignIntermediate(SignIntermediateRequest) returns (SignIntermediateResponse);
rpc CrossSignCA(CrossSignCARequest) returns (CrossSignCAResponse); rpc CrossSignCA(CrossSignCARequest) returns (CrossSignCAResponse);
rpc Cleanup(Empty) returns (Empty); rpc Cleanup(Empty) returns (Empty);
rpc SupportsCrossSigning(Empty) returns (SupportsCrossSigningResponse);
rpc MinLifetime(Empty) returns (MinLifetimeResponse);
} }
message ConfigureRequest { message ConfigureRequest {
@ -79,6 +83,14 @@ message CrossSignCAResponse {
string crt_pem = 1; string crt_pem = 1;
} }
message SupportsCrossSigningResponse {
bool supports_cross_signing = 1;
}
message MinLifetimeResponse {
google.protobuf.Duration min_lifetime = 1;
}
// Protobufs doesn't allow no req/resp so in the cases where there are // Protobufs doesn't allow no req/resp so in the cases where there are
// no arguments we use the Empty message. // no arguments we use the Empty message.
message Empty {} message Empty {}

@ -4,9 +4,12 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"time"
"github.com/gogo/protobuf/types"
"google.golang.org/grpc"
"github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/connect/ca"
"google.golang.org/grpc"
) )
// providerPluginGRPCServer implements the CAServer interface for gRPC. // providerPluginGRPCServer implements the CAServer interface for gRPC.
@ -81,6 +84,15 @@ func (p *providerPluginGRPCServer) CrossSignCA(_ context.Context, req *CrossSign
return &CrossSignCAResponse{CrtPem: crtPEM}, err return &CrossSignCAResponse{CrtPem: crtPEM}, err
} }
func (p *providerPluginGRPCServer) SupportsCrossSigning(context.Context, *Empty) (*SupportsCrossSigningResponse, error) {
s := p.impl.SupportsCrossSigning()
return &SupportsCrossSigningResponse{SupportsCrossSigning: s}, nil
}
func (p *providerPluginGRPCServer) MinLifetime(context.Context, *Empty) (*MinLifetimeResponse, error) {
return &MinLifetimeResponse{MinLifetime: types.DurationProto(p.impl.MinLifetime())}, nil
}
func (p *providerPluginGRPCServer) Cleanup(context.Context, *Empty) (*Empty, error) { func (p *providerPluginGRPCServer) Cleanup(context.Context, *Empty) (*Empty, error) {
return &Empty{}, p.impl.Cleanup() return &Empty{}, p.impl.Cleanup()
} }
@ -192,6 +204,17 @@ func (p *providerPluginGRPCClient) CrossSignCA(crt *x509.Certificate) (string, e
return resp.CrtPem, nil return resp.CrtPem, nil
} }
func (p *providerPluginGRPCClient) SupportsCrossSigning() bool {
resp, _ := p.client.SupportsCrossSigning(p.doneCtx, &Empty{})
return resp.SupportsCrossSigning
}
func (p *providerPluginGRPCClient) MinLifetime() time.Duration {
resp, _ := p.client.MinLifetime(p.doneCtx, &Empty{})
min, _ := types.DurationFromProto(resp.MinLifetime)
return min
}
func (p *providerPluginGRPCClient) Cleanup() error { func (p *providerPluginGRPCClient) Cleanup() error {
_, err := p.client.Cleanup(p.doneCtx, &Empty{}) _, err := p.client.Cleanup(p.doneCtx, &Empty{})
return p.err(err) return p.err(err)

@ -3,6 +3,9 @@ package plugin
import ( import (
"crypto/x509" "crypto/x509"
"net/rpc" "net/rpc"
"time"
"github.com/gogo/protobuf/types"
"github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/connect/ca"
) )
@ -80,6 +83,17 @@ func (p *providerPluginRPCServer) CrossSignCA(args *CrossSignCARequest, resp *Cr
return err return err
} }
func (p *providerPluginRPCServer) SupportsCrossSigning(_ struct{}, resp *SupportsCrossSigningResponse) error {
s := p.impl.SupportsCrossSigning()
resp.SupportsCrossSigning = s
return nil
}
func (p *providerPluginRPCServer) MinLifetime(_ struct{}, resp *MinLifetimeResponse) error {
resp.MinLifetime = types.DurationProto(p.impl.MinLifetime())
return nil
}
func (p *providerPluginRPCServer) Cleanup(struct{}, *struct{}) error { func (p *providerPluginRPCServer) Cleanup(struct{}, *struct{}) error {
return p.impl.Cleanup() return p.impl.Cleanup()
} }
@ -163,6 +177,19 @@ func (p *providerPluginRPCClient) CrossSignCA(crt *x509.Certificate) (string, er
return resp.CrtPem, err return resp.CrtPem, err
} }
func (p *providerPluginRPCClient) SupportsCrossSigning() bool {
var resp SupportsCrossSigningResponse
_ = p.client.Call("Plugin.SupportsCrossSigning", struct{}{}, &resp)
return resp.SupportsCrossSigning
}
func (p *providerPluginRPCClient) MinLifetime() time.Duration {
var resp MinLifetimeResponse
_ = p.client.Call("Plugin.MinLifetime", struct{}{}, &resp)
min, _ := types.DurationFromProto(resp.MinLifetime)
return min
}
func (p *providerPluginRPCClient) Cleanup() error { func (p *providerPluginRPCClient) Cleanup() error {
return p.client.Call("Plugin.Cleanup", struct{}{}, &struct{}{}) return p.client.Call("Plugin.Cleanup", struct{}{}, &struct{}{})
} }

@ -2,10 +2,21 @@ package ca
import ( import (
"crypto/x509" "crypto/x509"
"log"
"time"
) )
//go:generate mockery -name Provider -inpkg //go:generate mockery -name Provider -inpkg
// NeedsLogger is an interface that allows Consul to pass a configured
// logger into any component that needs one.
type NeedsLogger interface {
// SetLogger tells the provider to use the specified logger.
// This is called immediately after instantiating
// the provider, so that the provider can log startup messages etc.
SetLogger(l *log.Logger)
}
// Provider is the interface for Consul to interact with // Provider is the interface for Consul to interact with
// an external CA that provides leaf certificate signing for // an external CA that provides leaf certificate signing for
// given SpiffeIDServices. // given SpiffeIDServices.
@ -69,6 +80,14 @@ type Provider interface {
// returned as a PEM formatted string. // returned as a PEM formatted string.
CrossSignCA(*x509.Certificate) (string, error) CrossSignCA(*x509.Certificate) (string, error)
// SupportsCrossSigning indicates whether the provider supports cross-signing
// other CA certs via CrossSignCA().
SupportsCrossSigning() bool
// MinLifetime returns the minimum TTL allowed by the provider for certificates
// it issues.
MinLifetime() time.Duration
// Cleanup performs any necessary cleanup that should happen when the provider // Cleanup performs any necessary cleanup that should happen when the provider
// is shut down permanently, such as removing a temporary PKI backend in Vault // is shut down permanently, such as removing a temporary PKI backend in Vault
// created for an intermediate CA. // created for an intermediate CA.

@ -0,0 +1,327 @@
package ca
import (
"bytes"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acmpca"
"github.com/mitchellh/mapstructure"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
type AWSProvider struct {
config *structs.AWSCAProviderConfig
session *session.Session
client *acmpca.ACMPCA
isRoot bool
clusterId string
logger *log.Logger
rootPCA *AmazonPCA
subPCA *AmazonPCA
sleepTime time.Duration
}
func (v *AWSProvider) SetLogger(l *log.Logger) {
v.logger = l
}
func (v *AWSProvider) Configure(clusterId string, isRoot bool, rawConfig map[string]interface{}) error {
config, err := ParseAWSCAConfig(rawConfig)
if err != nil {
return err
}
sleepTime, err := time.ParseDuration(config.SleepTime)
if err != nil {
return fmt.Errorf("invalid sleep time specified: %s", err)
}
if sleepTime.Seconds() < 1 {
return fmt.Errorf("invalid sleep time specified: must be at least 1s")
}
creds := credentials.NewStaticCredentials(config.AccessKeyId, config.SecretAccessKey, "")
awsSession, err := session.NewSession(&aws.Config{
Region: aws.String(config.Region),
Credentials: creds,
})
if err != nil {
return err
}
v.config = config
v.session = awsSession
v.isRoot = isRoot
v.clusterId = clusterId
v.client = acmpca.New(awsSession)
v.sleepTime = sleepTime
return nil
}
func (v *AWSProvider) loadFindOrCreate(arn string, pcaType string) (*AmazonPCA, error) {
if arn != "" {
return LoadAmazonPCA(v.client, v.config, arn, pcaType, v.clusterId,
v.config.KeyAlgorithm, v.config.SigningAlgorithm, v.logger)
} else {
pca, err := FindAmazonPCA(v.client, v.config, pcaType, v.clusterId,
v.config.KeyAlgorithm, v.config.SigningAlgorithm, v.logger)
if err != nil {
return nil, err
}
if pca == nil {
pca, err = CreateAmazonPCA(v.client, v.config, pcaType, v.clusterId,
v.config.KeyAlgorithm, v.config.SigningAlgorithm, v.logger)
if err != nil {
return nil, err
}
}
return pca, nil
}
}
func (v *AWSProvider) GenerateRoot() error {
if !v.isRoot {
return fmt.Errorf("provider is not the root certificate authority")
}
if v.rootPCA != nil {
return nil // root PCA has already been created
}
rootPCA, err := v.loadFindOrCreate(v.config.RootARN, acmpca.CertificateAuthorityTypeRoot)
if err != nil {
return err
}
v.rootPCA = rootPCA
return v.rootPCA.Generate(v.rootPCA)
}
func (v *AWSProvider) ensureIntermediate() error {
if v.subPCA != nil {
return nil
}
subPCA, err := v.loadFindOrCreate(v.config.IntermediateARN, acmpca.CertificateAuthorityTypeSubordinate)
if err != nil {
return err
}
v.subPCA = subPCA
return v.subPCA.Generate(v.rootPCA)
}
func (v *AWSProvider) ActiveRoot() (string, error) {
return v.rootPCA.Certificate(), nil
}
func (v *AWSProvider) GenerateIntermediateCSR() (string, error) {
if v.isRoot {
return "", fmt.Errorf("provider is the root certificate authority, " +
"cannot generate an intermediate CSR")
}
if err := v.ensureIntermediate(); err != nil {
return "", err
}
v.logger.Print("[INFO] requesting CSR for new intermediate CA cert")
return v.subPCA.GetCSR()
}
func (v *AWSProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
if err := v.ensureIntermediate(); err != nil {
return err
}
return v.subPCA.SetCert(intermediatePEM, rootPEM)
}
func (v *AWSProvider) ActiveIntermediate() (string, error) {
if err := v.ensureIntermediate(); err != nil {
return "", err
}
return v.subPCA.Certificate(), nil
}
func (v *AWSProvider) GenerateIntermediate() (string, error) {
if err := v.ensureIntermediate(); err != nil {
return "", err
}
if err := v.subPCA.Generate(v.rootPCA); err != nil {
return "", err
}
return v.subPCA.Certificate(), nil
}
func (v *AWSProvider) Sign(csr *x509.CertificateRequest) (string, error) {
if err := v.ensureIntermediate(); err != nil {
return "", err
}
var pemBuf bytes.Buffer
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
return "", fmt.Errorf("error encoding CSR into PEM format: %s", err)
}
leafPEM, err := v.subPCA.SignLeaf(pemBuf.String(), v.config.LeafCertTTL)
return leafPEM, err
}
func (v *AWSProvider) SignIntermediate(csr *x509.CertificateRequest) (string, error) {
spiffeID := connect.SpiffeIDSigning{ClusterID: v.clusterId, Domain: "consul"}
if len(csr.URIs) < 1 {
return "", fmt.Errorf("intermediate does not contain a trust domain SAN")
}
if csr.URIs[0].String() != spiffeID.URI().String() {
return "", fmt.Errorf("attempt to sign intermediate from a different trust domain: "+
"mine='%s' theirs='%s'", spiffeID.URI().String(), csr.URIs[0].String())
}
var buf bytes.Buffer
if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
return "", fmt.Errorf("error encoding private key: %s", err)
}
return v.rootPCA.SignIntermediate(buf.String())
}
// I'm not sure this can actually be implemented. PCA cannot cross-sign a cert, it can only
// sign a CSR, and we cannot generate a CSR from another provider's certificate without its
// private key.
func (v *AWSProvider) CrossSignCA(newCA *x509.Certificate) (string, error) {
return "", fmt.Errorf("not implemented in AWS PCA provider")
}
func (v *AWSProvider) Cleanup() error {
if v.config.DeleteOnExit {
if v.subPCA != nil {
if err := v.subPCA.Disable(); err != nil {
v.logger.Printf("[WARN] error disabling subordinate PCA: %s", err.Error())
}
if err := v.subPCA.Delete(); err != nil {
v.logger.Printf("[WARN] error deleting subordinate PCA: %s", err.Error())
}
v.subPCA = nil
}
if v.rootPCA != nil {
if err := v.rootPCA.Disable(); err != nil {
v.logger.Printf("[WARN] error disabling root PCA: %s", err.Error())
}
if err := v.rootPCA.Delete(); err != nil {
v.logger.Printf("[WARN] error deleting root PCA: %s", err.Error())
}
v.rootPCA = nil
}
}
return nil
}
func (v *AWSProvider) SupportsCrossSigning() bool {
return false
}
func (v *AWSProvider) MinLifetime() time.Duration {
return 24 * time.Hour
}
func ParseAWSCAConfig(raw map[string]interface{}) (*structs.AWSCAProviderConfig, error) {
config := structs.AWSCAProviderConfig{
CommonCAProviderConfig: defaultCommonConfig(),
SleepTime: "5s",
}
decodeConf := &mapstructure.DecoderConfig{
DecodeHook: structs.ParseDurationFunc(),
Result: &config,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}
if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
if config.AccessKeyId == "" {
return nil, fmt.Errorf("must provide the AWS access key ID")
}
if config.SecretAccessKey == "" {
return nil, fmt.Errorf("must provide the AWS secret access key")
}
if config.Region == "" {
return nil, fmt.Errorf("must provide the AWS region")
}
if config.RootARN != "" {
if !IsValidARN(config.RootARN) {
return nil, fmt.Errorf("root PCA ARN is not in correct format")
}
}
if config.IntermediateARN != "" {
if !IsValidARN(config.IntermediateARN) {
return nil, fmt.Errorf("intermediate PCA ARN is not in correct format")
}
}
if config.KeyAlgorithm == "" {
config.KeyAlgorithm = acmpca.KeyAlgorithmEcPrime256v1
} else {
config.KeyAlgorithm, err = ValidateEnum(config.KeyAlgorithm,
acmpca.KeyAlgorithmRsa2048, acmpca.KeyAlgorithmRsa4096,
acmpca.KeyAlgorithmEcPrime256v1, acmpca.KeyAlgorithmEcSecp384r1)
if err != nil {
return nil, fmt.Errorf("invalid key algorithm specified: %s", err)
}
}
if config.SigningAlgorithm == "" {
config.SigningAlgorithm = acmpca.SigningAlgorithmSha256withecdsa
} else {
config.SigningAlgorithm, err = ValidateEnum(config.SigningAlgorithm,
acmpca.SigningAlgorithmSha256withrsa, acmpca.SigningAlgorithmSha384withrsa,
acmpca.SigningAlgorithmSha512withrsa, acmpca.SigningAlgorithmSha256withecdsa,
acmpca.SigningAlgorithmSha384withecdsa, acmpca.SigningAlgorithmSha512withecdsa)
if err != nil {
return nil, fmt.Errorf("invalid signing algorithm specified: %s", err)
}
}
if err := config.CommonCAProviderConfig.Validate(); err != nil {
return nil, err
}
return &config, nil
}
func ValidateEnum(value string, choices ...string) (string, error) {
for _, choice := range choices {
if strings.ToLower(value) == strings.ToLower(choice) {
return choice, nil
}
}
return "", fmt.Errorf("must be one of %s or %s",
strings.Join(choices[:len(choices)-1], ","),
choices[len(choices)-1])
}

@ -0,0 +1,392 @@
package ca
import (
"crypto/x509"
"log"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acmpca"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
var awsAccessKeyId string
var awsSecretAccessKey string
var awsRegion string
var awsClient *acmpca.ACMPCA
func init() {
awsAccessKeyId = os.Getenv("AWS_ACCESS_KEY_ID")
awsSecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY")
awsRegion = os.Getenv("AWS_REGION")
awsSession, _ := session.NewSession(&aws.Config{
Region: aws.String(awsRegion),
Credentials: credentials.NewStaticCredentials(awsAccessKeyId, awsSecretAccessKey, ""),
})
awsClient = acmpca.New(awsSession)
}
func makeConfig() *structs.CAConfiguration {
return &structs.CAConfiguration{
ClusterID: "asdf",
Provider: "aws",
Config: map[string]interface{}{
"LeafCertTTL": "72h",
"Region": awsRegion,
"AccessKeyId": awsAccessKeyId,
"SecretAccessKey": awsSecretAccessKey,
"KeyAlgorithm": acmpca.KeyAlgorithmEcPrime256v1,
"SigningAlgorithm": acmpca.SigningAlgorithmSha256withecdsa,
},
}
}
func makeProvider(r *require.Assertions, config *structs.CAConfiguration) *AWSProvider {
logger := log.New(os.Stderr, "aws_pca", log.LstdFlags)
provider := &AWSProvider{logger: logger}
r.NoError(provider.Configure(config.ClusterID, true, config.Config))
return provider
}
func TestAWSProvider_Configure(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.Equal(conf.Config["AccessKeyId"], provider.config.AccessKeyId)
r.Equal(conf.Config["SecretAccessKey"], provider.config.SecretAccessKey)
r.Equal(conf.Config["Region"], provider.config.Region)
}
func TestAWSProvider_ConfigureBadKeyAlgorithm(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
conf.Config["KeyAlgorithm"] = "foo"
provider := &AWSProvider{}
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
}
func TestAWSProvider_ConfigureBadSigningAlgorithm(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
conf.Config["SigningAlgorithm"] = "foo"
provider := &AWSProvider{}
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
}
func TestAWSProvider_ConfigureBadSleepTime(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
conf.Config["SleepTime"] = "-5s"
provider := &AWSProvider{}
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
conf.Config["SleepTime"] = "5foo"
provider = &AWSProvider{}
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
}
func TestAWSProvider_ConfigureBadLeafTTL(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
conf.Config["LeafCertTTL"] = "-72h"
provider := &AWSProvider{}
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
}
func TestAWSProvider_GenerateRoot(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
r.NotEmpty(provider.rootPCA.arn)
output, err := awsClient.DescribeCertificateAuthority(&acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(provider.rootPCA.arn),
})
r.NoError(err)
ca := output.CertificateAuthority
caConf := ca.CertificateAuthorityConfiguration
r.Equal(acmpca.CertificateAuthorityStatusActive, *ca.Status)
r.Equal(acmpca.KeyAlgorithmEcPrime256v1, *caConf.KeyAlgorithm)
r.Equal(acmpca.SigningAlgorithmSha256withecdsa, *caConf.SigningAlgorithm)
r.Contains(*caConf.Subject.CommonName, conf.ClusterID)
}
func TestAWSProvider_GenerateRootNotRoot(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
provider.isRoot = false
r.Error(provider.GenerateRoot())
}
func TestAWSProvider_Sign(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
cn := "foo"
pk, _, err := connect.GeneratePrivateKeyWithConfig("rsa", 2048)
r.NoError(err)
serviceID := &connect.SpiffeIDService{
Host: "11111111-2222-3333-4444-555555555555.consul",
Datacenter: "dc1",
Namespace: "default",
Service: "foo",
}
csrText, err := connect.CreateCSR(cn, serviceID, pk)
r.NoError(err)
csr, err := connect.ParseCSR(csrText)
r.NoError(err)
leafText, err := provider.Sign(csr)
r.NoError(err)
leaf, err := connect.ParseCert(leafText)
r.NoError(err)
r.Equal(csr.Subject.CommonName, leaf.Subject.CommonName)
r.Equal(serviceID.URI().String(), leaf.URIs[0].String())
r.True(leaf.NotBefore.Before(time.Now()))
r.True(leaf.NotAfter.After(time.Now()))
}
func TestAWSProvider_GenerateIntermediateCSR(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
_, err := provider.GenerateIntermediateCSR()
r.Error(err)
provider.isRoot = false
csrText, err := provider.GenerateIntermediateCSR()
csr, err := connect.ParseCSR(csrText)
r.NoError(err)
r.Contains(csr.Subject.CommonName, conf.ClusterID)
r.Equal(x509.ECDSA, csr.PublicKeyAlgorithm)
r.Equal(x509.ECDSAWithSHA256, csr.SignatureAlgorithm)
}
func TestAWSProvider_ActiveRoot(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
rootText, err := provider.ActiveRoot()
r.NoError(err)
root, err := connect.ParseCert(rootText)
r.NoError(err)
r.Equal(x509.ECDSA, root.PublicKeyAlgorithm)
r.Equal(x509.ECDSAWithSHA256, root.SignatureAlgorithm)
r.True(root.NotBefore.Before(time.Now()))
r.True(root.NotAfter.After(time.Now()))
r.True(root.IsCA)
}
func TestAWSProvider_GenerateIntermediate(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
interText, err := provider.GenerateIntermediate()
r.NoError(err)
inter, err := connect.ParseCert(interText)
r.NoError(err)
r.Contains(inter.Subject.CommonName, conf.ClusterID)
r.Equal(x509.ECDSA, inter.PublicKeyAlgorithm)
r.Equal(x509.ECDSAWithSHA256, inter.SignatureAlgorithm)
r.True(inter.NotBefore.Before(time.Now()))
r.True(inter.NotAfter.After(time.Now()))
r.True(inter.IsCA)
r.True(inter.MaxPathLenZero)
}
func TestAWSProvider_ActiveIntermediate(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
interText, err := provider.ActiveIntermediate()
r.NoError(err)
inter, err := connect.ParseCert(interText)
r.NoError(err)
r.Contains(inter.Subject.CommonName, conf.ClusterID)
r.Equal(x509.ECDSA, inter.PublicKeyAlgorithm)
r.Equal(x509.ECDSAWithSHA256, inter.SignatureAlgorithm)
r.True(inter.NotBefore.Before(time.Now()))
r.True(inter.NotAfter.After(time.Now()))
r.True(inter.IsCA)
r.True(inter.MaxPathLenZero)
}
func TestAWSProvider_SignIntermediate(t *testing.T) {
t.Parallel()
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
conf2 := testConsulCAConfig()
delegate2 := newMockDelegate(t, conf2)
provider2 := &ConsulProvider{Delegate: delegate2}
r.NoError(provider2.Configure(conf2.ClusterID, false, conf2.Config))
testSignIntermediateCrossDC(t, provider, provider2)
}
func TestAWSProvider_Cleanup(t *testing.T) {
// THIS TEST CANNOT BE RUN IN PARALLEL.
// It disables and deletes the PCA, which will cause other tests to fail if they are running
// at the same time.
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
t.Skip("skipping test due to missing AWS credentials")
}
r := require.New(t)
conf := makeConfig()
conf.Config["DeleteOnExit"] = true
provider := makeProvider(r, conf)
r.NoError(provider.GenerateRoot())
_, err := provider.GenerateIntermediate()
r.NoError(err)
rootPCA := provider.rootPCA
subPCA := provider.subPCA
r.NoError(provider.Cleanup())
output, err := awsClient.DescribeCertificateAuthority(&acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(rootPCA.arn),
})
r.NoError(err)
ca := output.CertificateAuthority
r.Equal(acmpca.CertificateAuthorityStatusDeleted, *ca.Status)
r.Equal((*AmazonPCA)(nil), provider.rootPCA)
r.NoError(rootPCA.Undelete())
r.NoError(rootPCA.Enable())
output, err = awsClient.DescribeCertificateAuthority(&acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(subPCA.arn),
})
r.NoError(err)
ca = output.CertificateAuthority
r.Equal(acmpca.CertificateAuthorityStatusDeleted, *ca.Status)
r.Equal((*AmazonPCA)(nil), provider.subPCA)
r.NoError(subPCA.Undelete())
r.NoError(subPCA.Enable())
}

@ -658,3 +658,11 @@ func (c *ConsulProvider) generateCA(privateKey string, sn uint64) (string, error
return buf.String(), nil return buf.String(), nil
} }
func (c *ConsulProvider) SupportsCrossSigning() bool {
return true
}
func (c *ConsulProvider) MinLifetime() time.Duration {
return 1 * time.Hour
}

@ -8,6 +8,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
@ -376,6 +377,14 @@ func (v *VaultProvider) Cleanup() error {
return v.client.Sys().Unmount(v.config.IntermediatePKIPath) return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
} }
func (v *VaultProvider) SupportsCrossSigning() bool {
return true
}
func (v *VaultProvider) MinLifetime() time.Duration {
return 1 * time.Hour
}
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) { func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
config := structs.VaultCAProviderConfig{ config := structs.VaultCAProviderConfig{
CommonCAProviderConfig: defaultCommonConfig(), CommonCAProviderConfig: defaultCommonConfig(),

@ -3,6 +3,7 @@ package connect
import ( import (
"bytes" "bytes"
"crypto" "crypto"
"crypto/ecdsa"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
@ -13,10 +14,21 @@ import (
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded // CreateCSR returns a CSR to sign the given service along with the PEM-encoded
// private key for this certificate. // private key for this certificate.
func CreateCSR(uri CertURI, privateKey crypto.Signer, extensions ...pkix.Extension) (string, error) { func CreateCSR(commonName string, uri CertURI, privateKey crypto.Signer,
extensions ...pkix.Extension) (string, error) {
signAlgo := x509.SHA256WithRSA
_, ok := privateKey.(*ecdsa.PrivateKey)
if ok {
signAlgo = x509.ECDSAWithSHA256
}
template := &x509.CertificateRequest{ template := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: commonName,
},
URIs: []*url.URL{uri.URI()}, URIs: []*url.URL{uri.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256, SignatureAlgorithm: signAlgo,
ExtraExtensions: extensions, ExtraExtensions: extensions,
} }
@ -43,7 +55,7 @@ func CreateCACSR(uri CertURI, privateKey crypto.Signer) (string, error) {
return "", err return "", err
} }
return CreateCSR(uri, privateKey, ext) return CreateCSR("Consul CA", uri, privateKey, ext)
} }
// CreateCAExtension creates a pkix.Extension for the x509 Basic Constraints // CreateCAExtension creates a pkix.Extension for the x509 Basic Constraints

@ -59,7 +59,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, toke
} }
// Create a CSR. // Create a CSR.
csr, err := connect.CreateCSR(id, pk) csr, err := connect.CreateCSR("Consul RPC", id, pk)
if err != nil { if err != nil {
return errFn(err) return errFn(err)
} }

@ -75,7 +75,7 @@ func TestAutoEncryptSign(t *testing.T) {
require.NoError(t, err, info) require.NoError(t, err, info)
// Create a CSR. // Create a CSR.
csr, err := connect.CreateCSR(id, pk) csr, err := connect.CreateCSR(info, id, pk)
require.NoError(t, err, info) require.NoError(t, err, info)
require.NotEmpty(t, csr, info) require.NotEmpty(t, csr, info)
args := &structs.CASignRequest{ args := &structs.CASignRequest{

@ -168,6 +168,16 @@ func (s *ConnectCA) ConfigurationSet(
return nil return nil
} }
// Have the old provider cross-sign the new intermediate
oldProvider, _ := s.srv.getCAProvider()
if oldProvider == nil {
return fmt.Errorf("internal error: CA provider is nil")
}
if !oldProvider.SupportsCrossSigning() {
return fmt.Errorf("error: current CA does not support cross-signing")
}
// Create a new instance of the provider described by the config // Create a new instance of the provider described by the config
// and get the current active root CA. This acts as a good validation // and get the current active root CA. This acts as a good validation
// of the config and makes sure the provider is functioning correctly // of the config and makes sure the provider is functioning correctly
@ -237,11 +247,6 @@ func (s *ConnectCA) ConfigurationSet(
return err return err
} }
// Have the old provider cross-sign the new intermediate
oldProvider, _ := s.srv.getCAProvider()
if oldProvider == nil {
return fmt.Errorf("internal error: CA provider is nil")
}
xcCert, err := oldProvider.CrossSignCA(newRoot) xcCert, err := oldProvider.CrossSignCA(newRoot)
if err != nil { if err != nil {
return err return err

@ -112,6 +112,53 @@ type CARoot struct {
// CARoots is a list of CARoot structures. // CARoots is a list of CARoot structures.
type CARoots []*CARoot type CARoots []*CARoot
// CAIntermediate represents an intermediate CA certificate that is trusted and used
// to sign leaf certificates.
type CAIntermediate struct {
// ID is a globally unique ID (UUID) representing this CA root.
ID string
// Name is a human-friendly name for this CA root. This value is
// opaque to Consul and is not used for anything internally.
Name string
// SerialNumber is the x509 serial number of the certificate.
SerialNumber uint64
// SigningKeyID is the ID of the public key that corresponds to the private
// key used to sign the certificate. Is is the HexString format of the raw
// AuthorityKeyID bytes.
SigningKeyID string
// ExternalTrustDomain is the trust domain this root was generated under. It
// is usually empty implying "the current cluster trust-domain". It is set
// only in the case that a cluster changes trust domain and then all old roots
// that are still trusted have the old trust domain set here.
//
// We currently DON'T validate these trust domains explicitly anywhere, see
// IndexedRoots.TrustDomain doc. We retain this information for debugging and
// future flexibility.
ExternalTrustDomain string
// Time validity bounds.
NotBefore time.Time
NotAfter time.Time
// IntermediateCert is the PEM-encoded public certificate.
IntermediateCert string
// Active is true if this is the current active CA. This must only
// be true for exactly one CA. For any method that modifies roots in the
// state store, tests should be written to verify that multiple roots
// cannot be active.
Active bool
// RotatedOutAt is the time at which this CA was removed from the state.
// This will only be set on roots that have been rotated out from being the
// active root.
RotatedOutAt time.Time `json:"-"`
}
// CASignRequest is the request for signing a service certificate. // CASignRequest is the request for signing a service certificate.
type CASignRequest struct { type CASignRequest struct {
// Datacenter is the target for this request. // Datacenter is the target for this request.
@ -207,6 +254,7 @@ func (q *CARequest) RequestDatacenter() string {
const ( const (
ConsulCAProvider = "consul" ConsulCAProvider = "consul"
VaultCAProvider = "vault" VaultCAProvider = "vault"
AWSCAProvider = "aws"
) )
// CAConfiguration is the configuration for the current CA plugin. // CAConfiguration is the configuration for the current CA plugin.
@ -349,6 +397,20 @@ type VaultCAProviderConfig struct {
TLSSkipVerify bool TLSSkipVerify bool
} }
type AWSCAProviderConfig struct {
CommonCAProviderConfig `mapstructure:",squash"`
AccessKeyId string
SecretAccessKey string
Region string
SleepTime string
RootARN string
IntermediateARN string
KeyAlgorithm string
SigningAlgorithm string
DeleteOnExit bool
}
// CALeafOp is the operation for a request related to leaf certificates. // CALeafOp is the operation for a request related to leaf certificates.
type CALeafOp string type CALeafOp string

3
go.mod

@ -20,6 +20,7 @@ require (
github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878 github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310 github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310
github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f // indirect github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f // indirect
github.com/aws/aws-sdk-go v1.21.1
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
github.com/cenkalti/backoff v2.1.1+incompatible // indirect github.com/cenkalti/backoff v2.1.1+incompatible // indirect
@ -56,7 +57,7 @@ require (
github.com/hashicorp/go-msgpack v0.5.5 github.com/hashicorp/go-msgpack v0.5.5
github.com/hashicorp/go-multierror v1.0.0 github.com/hashicorp/go-multierror v1.0.0
github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116 github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116
github.com/hashicorp/go-raftchunking v0.6.1 github.com/hashicorp/go-raftchunking v0.6.2
github.com/hashicorp/go-sockaddr v1.0.0 github.com/hashicorp/go-sockaddr v1.0.0
github.com/hashicorp/go-syslog v1.0.0 github.com/hashicorp/go-syslog v1.0.0
github.com/hashicorp/go-uuid v1.0.1 github.com/hashicorp/go-uuid v1.0.1

6
go.sum

@ -38,6 +38,8 @@ github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f h1:/8NcnxL6
github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/aws/aws-sdk-go v1.15.24 h1:xLAdTA/ore6xdPAljzZRed7IGqQgC+nY+ERS5vaj4Ro= github.com/aws/aws-sdk-go v1.15.24 h1:xLAdTA/ore6xdPAljzZRed7IGqQgC+nY+ERS5vaj4Ro=
github.com/aws/aws-sdk-go v1.15.24/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.15.24/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.21.1 h1:IOFDnCEDybcw4V8nbKqyyjBu+vpu7hFYSfZqNuogi7I=
github.com/aws/aws-sdk-go v1.21.1/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
@ -171,6 +173,8 @@ github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116 h1:Y4V/yReWjQo
github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116/go.mod h1:JSqWYsict+jzcj0+xElxyrBQRPNoiWQuddnxArJ7XHQ= github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116/go.mod h1:JSqWYsict+jzcj0+xElxyrBQRPNoiWQuddnxArJ7XHQ=
github.com/hashicorp/go-raftchunking v0.6.1 h1:moEnaG3gcwsWNyIBJoD5PCByE+Ewkqxh6N05CT+MbwA= github.com/hashicorp/go-raftchunking v0.6.1 h1:moEnaG3gcwsWNyIBJoD5PCByE+Ewkqxh6N05CT+MbwA=
github.com/hashicorp/go-raftchunking v0.6.1/go.mod h1:cGlg3JtDy7qy6c/3Bu660Mic1JF+7lWqIwCFSb08fX0= github.com/hashicorp/go-raftchunking v0.6.1/go.mod h1:cGlg3JtDy7qy6c/3Bu660Mic1JF+7lWqIwCFSb08fX0=
github.com/hashicorp/go-raftchunking v0.6.2 h1:imj6CVkwXj6VzgXZQvzS+fSrkbFCzlJ2t00F3PacnuU=
github.com/hashicorp/go-raftchunking v0.6.2/go.mod h1:cGlg3JtDy7qy6c/3Bu660Mic1JF+7lWqIwCFSb08fX0=
github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s= github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s=
github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
github.com/hashicorp/go-rootcerts v1.0.0 h1:Rqb66Oo1X/eSV1x66xbDccZjhJigjg0+e82kpwzSwCI= github.com/hashicorp/go-rootcerts v1.0.0 h1:Rqb66Oo1X/eSV1x66xbDccZjhJigjg0+e82kpwzSwCI=
@ -229,6 +233,8 @@ github.com/jefferai/jsonx v0.0.0-20160721235117-9cc31c3135ee h1:AQ/QmCk6x8ECPpf2
github.com/jefferai/jsonx v0.0.0-20160721235117-9cc31c3135ee/go.mod h1:N0t2vlmpe8nyZB5ouIbJQPDSR+mH6oe7xHB9VZHSUzM= github.com/jefferai/jsonx v0.0.0-20160721235117-9cc31c3135ee/go.mod h1:N0t2vlmpe8nyZB5ouIbJQPDSR+mH6oe7xHB9VZHSUzM=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 h1:JHCT6xuyPUrbbgAPE/3dqlvUKzRHMNuTBKKUb6OeR/k= github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 h1:JHCT6xuyPUrbbgAPE/3dqlvUKzRHMNuTBKKUb6OeR/k=
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62/go.mod h1:U+RSyWxWd04xTqnuOQxnai7XGS2PrPY2cfGoDKtMHjA= github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62/go.mod h1:U+RSyWxWd04xTqnuOQxnai7XGS2PrPY2cfGoDKtMHjA=
github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE= github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE=

@ -138,8 +138,27 @@ type RequestFailure interface {
RequestID() string RequestID() string
} }
// NewRequestFailure returns a new request error wrapper for the given Error // NewRequestFailure returns a wrapped error with additional information for
// provided. // request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure { func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID) return newRequestError(err, statusCode, reqID)
} }
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
awsError
Bytes() []byte
}
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
return &unmarshalError{
awsError: New("UnmarshalError", msg, err),
bytes: bytes,
}
}

@ -1,6 +1,9 @@
package awserr package awserr
import "fmt" import (
"encoding/hex"
"fmt"
)
// SprintError returns a string of the formatted error code. // SprintError returns a string of the formatted error code.
// //
@ -119,6 +122,7 @@ type requestError struct {
awsError awsError
statusCode int statusCode int
requestID string requestID string
bytes []byte
} }
// newRequestError returns a wrapped error with additional information for // newRequestError returns a wrapped error with additional information for
@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
return []error{r.OrigErr()} return []error{r.OrigErr()}
} }
type unmarshalError struct {
awsError
bytes []byte
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
extra := hex.Dump(e.bytes)
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
return e.Error()
}
// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
return e.bytes
}
// An error list that satisfies the golang interface // An error list that satisfies the golang interface
type errorList []error type errorList []error

@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
rb := reflect.Indirect(reflect.ValueOf(b)) rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid { if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type the are equal // If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal // If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b) return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid { } else if raValid != rbValid {

@ -23,28 +23,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
case reflect.Struct: case reflect.Struct:
buf.WriteString("{\n") buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ { for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name ft := v.Type().Field(i)
f := v.Field(i) fv := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
continue // ignore unexported fields continue // ignore unexported fields
} }
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() { if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
continue // ignore unset fields continue // ignore unset fields
} }
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2)) buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ") buf.WriteString(ft.Name + ": ")
stringValue(val, indent+2, buf)
if i < len(names)-1 { if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString(",\n") buf.WriteString("<sensitive>")
} else {
stringValue(fv, indent+2, buf)
} }
buf.WriteString(",\n")
} }
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")

@ -18,7 +18,7 @@ type Config struct {
// States that the signing name did not come from a modeled source but // States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors // was derived based on other data. Used by service client constructors
// to determine if the signin name can be overriden based on metadata the // to determine if the signin name can be overridden based on metadata the
// service has. // service has.
SigningNameDerived bool SigningNameDerived bool
} }

@ -118,6 +118,12 @@ var LogHTTPResponseHandler = request.NamedHandler{
func logResponse(r *request.Request) { func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)} lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
if r.HTTPResponse == nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody { if logBody {
r.HTTPResponse.Body = &teeReaderCloser{ r.HTTPResponse.Body = &teeReaderCloser{

@ -18,7 +18,7 @@ const UseServiceDefaultRetries = -1
type RequestRetryer interface{} type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default, // A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig tructure. // all clients will use the defaults.DefaultConfig structure.
// //
// // Create Session with MaxRetry configuration to be shared by multiple // // Create Session with MaxRetry configuration to be shared by multiple
// // service clients. // // service clients.
@ -45,7 +45,7 @@ type Config struct {
// that overrides the default generated endpoint for a client. Set this // that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint. // to `""` to use the default generated endpoint.
// //
// @note You must still provide a `Region` value when specifying an // Note: You must still provide a `Region` value when specifying an
// endpoint for a client. // endpoint for a client.
Endpoint *string Endpoint *string
@ -65,8 +65,8 @@ type Config struct {
// noted. A full list of regions is found in the "Regions and Endpoints" // noted. A full list of regions is found in the "Regions and Endpoints"
// document. // document.
// //
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html // See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
// AWS Regions and Endpoints // Regions and Endpoints.
Region *string Region *string
// Set this to `true` to disable SSL when sending requests. Defaults // Set this to `true` to disable SSL when sending requests. Defaults
@ -120,9 +120,10 @@ type Config struct {
// will use virtual hosted bucket addressing when possible // will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`). // (`http://BUCKET.s3.amazonaws.com/KEY`).
// //
// @note This configuration option is specific to the Amazon S3 service. // Note: This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html //
// Amazon S3: Virtual Hosting of Buckets // See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// for Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue` // Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
@ -223,6 +224,28 @@ type Config struct {
// Key: aws.String("//foo//bar//moo"), // Key: aws.String("//foo//bar//moo"),
// }) // })
DisableRestProtocolURICleaning *bool DisableRestProtocolURICleaning *bool
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// EnableEndpointDiscovery: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("/foo/bar/moo"),
// })
EnableEndpointDiscovery *bool
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
// request endpoint hosts with modeled information.
//
// Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool
} }
// NewConfig returns a new Config pointer that can be chained with builder // NewConfig returns a new Config pointer that can be chained with builder
@ -377,6 +400,19 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
return c return c
} }
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
func (c *Config) WithEndpointDiscovery(t bool) *Config {
c.EnableEndpointDiscovery = &t
return c
}
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
// when making requests.
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
c.DisableEndpointHostPrefix = &t
return c
}
// MergeIn merges the passed in configs into the existing config object. // MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) { func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs { for _, other := range cfgs {
@ -476,6 +512,14 @@ func mergeInConfig(dst *Config, other *Config) {
if other.EnforceShouldRetryCheck != nil { if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
} }
if other.EnableEndpointDiscovery != nil {
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
}
if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
}
} }
// Copy will return a shallow copy of the Config object. If any additional // Copy will return a shallow copy of the Config object. If any additional

@ -1,8 +1,8 @@
// +build !go1.9
package aws package aws
import ( import "time"
"time"
)
// Context is an copy of the Go v1.7 stdlib's context.Context interface. // Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext" // It is represented as a SDK interface to enable you to use the "WithContext"
@ -35,37 +35,3 @@ type Context interface {
// functions. // functions.
Value(key interface{}) interface{} Value(key interface{}) interface{}
} }
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

@ -1,9 +0,0 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)

11
vendor/github.com/aws/aws-sdk-go/aws/context_1_9.go generated vendored Normal file

@ -0,0 +1,11 @@
// +build go1.9
package aws
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context

@ -39,3 +39,18 @@ func (e *emptyCtx) String() string {
var ( var (
backgroundCtx = new(emptyCtx) backgroundCtx = new(emptyCtx)
) )
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}

@ -0,0 +1,20 @@
// +build go1.7
package aws
import "context"
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return context.Background()
}

24
vendor/github.com/aws/aws-sdk-go/aws/context_sleep.go generated vendored Normal file

@ -0,0 +1,24 @@
package aws
import (
"time"
)
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

@ -72,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
signedTime = r.LastSignedAt signedTime = r.LastSignedAt
} }
// 10 minutes to allow for some clock skew/delays in transmission. // 5 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423 // Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) { if signedTime.Add(5 * time.Minute).After(time.Now()) {
return return
} }

@ -17,7 +17,7 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
} }
const execEnvVar = `AWS_EXECUTION_ENV` const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec_env` const execEnvUAKey = `exec-env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's // AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent. // execution environment to the user agent.

@ -9,9 +9,7 @@ var (
// providers in the ChainProvider. // providers in the ChainProvider.
// //
// This has been deprecated. For verbose error messaging set // This has been deprecated. For verbose error messaging set
// aws.Config.CredentialsChainVerboseErrors to true // aws.Config.CredentialsChainVerboseErrors to true.
//
// @readonly
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated. `no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`, For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,

@ -49,8 +49,11 @@
package credentials package credentials
import ( import (
"fmt"
"sync" "sync"
"time" "time"
"github.com/aws/aws-sdk-go/aws/awserr"
) )
// AnonymousCredentials is an empty Credential object that can be used as // AnonymousCredentials is an empty Credential object that can be used as
@ -64,8 +67,6 @@ import (
// Credentials: credentials.AnonymousCredentials, // Credentials: credentials.AnonymousCredentials,
// }))) // })))
// // Access public S3 buckets. // // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "") var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields. // A Value is the AWS credentials value for individual credential fields.
@ -83,6 +84,12 @@ type Value struct {
ProviderName string ProviderName string
} }
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A Provider is the interface for any component which will provide credentials // A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to // Value. A provider is required to manage its own Expired state, and what to
// be expired means. // be expired means.
@ -99,6 +106,14 @@ type Provider interface {
IsExpired() bool IsExpired() bool
} }
// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
type Expirer interface {
// The time at which the credentials are no longer valid
ExpiresAt() time.Time
}
// An ErrorProvider is a stub credentials provider that always returns an error // An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible // this is used by the SDK when construction a known provider is not possible
// due to an error. // due to an error.
@ -165,6 +180,11 @@ func (e *Expiry) IsExpired() bool {
return e.expiration.Before(curTime()) return e.expiration.Before(curTime())
} }
// ExpiresAt returns the expiration time of the credential
func (e *Expiry) ExpiresAt() time.Time {
return e.expiration
}
// A Credentials provides concurrency safe retrieval of AWS credentials Value. // A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value // Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials. // expires the next Get will attempt to retrieve valid credentials.
@ -257,3 +277,23 @@ func (c *Credentials) IsExpired() bool {
func (c *Credentials) isExpired() bool { func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired() return c.forceRefresh || c.provider.IsExpired()
} }
// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()
expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
nil)
}
if c.forceRefresh {
// set expiration time to the distant past
return time.Time{}, nil
}
return expirer.ExpiresAt(), nil
}

@ -11,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri" "github.com/aws/aws-sdk-go/internal/sdkuri"
) )
@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
} }
if err := s.Err(); err != nil { if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err) return nil, awserr.New(request.ErrCodeSerialization,
"failed to read EC2 instance role from metadata service", err)
} }
return credsList, nil return credsList, nil
@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
respCreds := ec2RoleCredRespBody{} respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{}, return ec2RoleCredRespBody{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName), fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
err) err)
} }

@ -39,6 +39,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata" "github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
) )
// ProviderName is the name of the credentials provider. // ProviderName is the name of the credentials provider.
@ -65,6 +66,10 @@ type Provider struct {
// //
// If ExpiryWindow is 0 or less it will be ignored. // If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration ExpiryWindow time.Duration
// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
} }
// NewProviderClient returns a credentials Provider for retrieving AWS credentials // NewProviderClient returns a credentials Provider for retrieving AWS credentials
@ -152,6 +157,9 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
out := &getCredentialsOutput{} out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out) req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json") req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}
return out, req.Send() return out, req.Send()
} }
@ -167,7 +175,7 @@ func unmarshalHandler(r *request.Request) {
out := r.Data.(*getCredentialsOutput) out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil { if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError", r.Error = awserr.New(request.ErrCodeSerialization,
"failed to decode endpoint credentials", "failed to decode endpoint credentials",
err, err,
) )
@ -178,11 +186,15 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
var errOut errorOutput var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil { err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
r.Error = awserr.New("SerializationError", if err != nil {
"failed to decode endpoint credentials", r.Error = awserr.NewRequestFailure(
err, awserr.New(request.ErrCodeSerialization,
"failed to decode error message", err),
r.HTTPResponse.StatusCode,
r.RequestID,
) )
return
} }
// Response body format is not consistent between metadata endpoints. // Response body format is not consistent between metadata endpoints.

@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
var ( var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be // ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment. // found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil) ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key // ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment. // can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil) ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
) )

@ -0,0 +1,425 @@
/*
Package processcreds is a credential Provider to retrieve `credential_process`
credentials.
WARNING: The following describes a method of sourcing credentials from an external
process. This can potentially be dangerous, so proceed with caution. Other
credential providers should be preferred if at all possible. If using this
option, you should make sure that the config file is as locked down as possible
using security best practices for your operating system.
You can use credentials from a `credential_process` in a variety of ways.
One way is to setup your shared config file, located in the default
location, with the `credential_process` key and the command you want to be
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.
[default]
credential_process = /command/to/call
Creating a new session will use the credential process to retrieve credentials.
NOTE: If there are credentials in the profile you are using, the credential
process will not be used.
// Initialize a session to load credentials.
sess, _ := session.NewSession(&aws.Config{
Region: aws.String("us-east-1")},
)
// Create S3 service client to use the credentials.
svc := s3.New(sess)
Another way to use the `credential_process` method is by using
`credentials.NewCredentials()` and providing a command to be executed to
retrieve credentials:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentials("/path/to/command")
// Create service client value configured for credentials.
svc := s3.New(sess, &aws.Config{Credentials: creds})
You can set a non-default timeout for the `credential_process` with another
constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To
set a one minute timeout:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentialsTimeout(
"/path/to/command",
time.Duration(500) * time.Millisecond)
If you need more control, you can set any configurable options in the
credentials using one or more option functions. For example, you can set a two
minute timeout, a credential duration of 60 minutes, and a maximum stdout
buffer size of 2k.
creds := processcreds.NewCredentials(
"/path/to/command",
func(opt *ProcessProvider) {
opt.Timeout = time.Duration(2) * time.Minute
opt.Duration = time.Duration(60) * time.Minute
opt.MaxBufSize = 2048
})
You can also use your own `exec.Cmd`:
// Create an exec.Cmd
myCommand := exec.Command("/path/to/command")
// Create credentials using your exec.Cmd and custom timeout
creds := processcreds.NewCredentialsCommand(
myCommand,
func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
*/
package processcreds
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
const (
// ProviderName is the name this credentials provider will label any
// returned credentials Value with.
ProviderName = `ProcessProvider`
// ErrCodeProcessProviderParse error parsing process output
ErrCodeProcessProviderParse = "ProcessProviderParseError"
// ErrCodeProcessProviderVersion version error in output
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
// ErrCodeProcessProviderRequired required attribute missing in output
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
// ErrCodeProcessProviderExecution execution of command failed
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
// errMsgProcessProviderTimeout process took longer than allowed
errMsgProcessProviderTimeout = "credential process timed out"
// errMsgProcessProviderProcess process error
errMsgProcessProviderProcess = "error in credential_process"
// errMsgProcessProviderParse problem parsing output
errMsgProcessProviderParse = "parse failed of credential_process output"
// errMsgProcessProviderVersion version error in output
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
// errMsgProcessProviderMissKey missing access key id in output
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
// errMsgProcessProviderMissSecret missing secret acess key in output
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
// errMsgProcessProviderPrepareCmd prepare of command failed
errMsgProcessProviderPrepareCmd = "failed to prepare command"
// errMsgProcessProviderEmptyCmd command must not be empty
errMsgProcessProviderEmptyCmd = "command must not be empty"
// errMsgProcessProviderPipe failed to initialize pipe
errMsgProcessProviderPipe = "failed to initialize pipe"
// DefaultDuration is the default amount of time in minutes that the
// credentials will be valid for.
DefaultDuration = time.Duration(15) * time.Minute
// DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process.
DefaultBufSize = 1024
// DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute
)
// ProcessProvider satisfies the credentials.Provider interface, and is a
// client to retrieve credentials from a process.
type ProcessProvider struct {
staticCreds bool
credentials.Expiry
originalCommand []string
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// A string representing an os command that should return a JSON with
// credential information.
command *exec.Cmd
// MaxBufSize limits memory usage from growing to an enormous
// amount due to a faulty process.
MaxBufSize int
// Timeout limits the time a process can run.
Timeout time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// ProcessProvider. The credentials will expire every 15 minutes by default.
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: exec.Command(command),
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsTimeout returns a pointer to a new Credentials object with
// the specified command and timeout, and default duration and max buffer size.
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
p := NewCredentials(command, func(opt *ProcessProvider) {
opt.Timeout = timeout
})
return p
}
// NewCredentialsCommand returns a pointer to a new Credentials object with
// the specified command, and default timeout, duration and max buffer size.
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: command,
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
type credentialProcessResponse struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
SessionToken string
Expiration *time.Time
}
// Retrieve executes the 'credential_process' and returns the credentials.
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
out, err := p.executeCredentialProcess()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// Serialize and validate response
resp := &credentialProcessResponse{}
if err = json.Unmarshal(out, resp); err != nil {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderParse,
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
err)
}
if resp.Version != 1 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderVersion,
errMsgProcessProviderVersion,
nil)
}
if len(resp.AccessKeyID) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissKey,
nil)
}
if len(resp.SecretAccessKey) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissSecret,
nil)
}
// Handle expiration
p.staticCreds = resp.Expiration == nil
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
}
return credentials.Value{
ProviderName: ProviderName,
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
}, nil
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *ProcessProvider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// prepareCommand prepares the command to be executed.
func (p *ProcessProvider) prepareCommand() error {
var cmdArgs []string
if runtime.GOOS == "windows" {
cmdArgs = []string{"cmd.exe", "/C"}
} else {
cmdArgs = []string{"sh", "-c"}
}
if len(p.originalCommand) == 0 {
p.originalCommand = make([]string, len(p.command.Args))
copy(p.originalCommand, p.command.Args)
// check for empty command because it succeeds
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
return awserr.New(
ErrCodeProcessProviderExecution,
fmt.Sprintf(
"%s: %s",
errMsgProcessProviderPrepareCmd,
errMsgProcessProviderEmptyCmd),
nil)
}
}
cmdArgs = append(cmdArgs, p.originalCommand...)
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
p.command.Env = os.Environ()
return nil
}
// executeCredentialProcess starts the credential process on the OS and
// returns the results or an error.
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
if err := p.prepareCommand(); err != nil {
return nil, err
}
// Setup the pipes
outReadPipe, outWritePipe, err := os.Pipe()
if err != nil {
return nil, awserr.New(
ErrCodeProcessProviderExecution,
errMsgProcessProviderPipe,
err)
}
p.command.Stderr = os.Stderr // display stderr on console for MFA
p.command.Stdout = outWritePipe // get creds json on process's stdout
p.command.Stdin = os.Stdin // enable stdin for MFA
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
stdoutCh := make(chan error, 1)
go readInput(
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
output,
stdoutCh)
execCh := make(chan error, 1)
go executeCommand(*p.command, execCh)
finished := false
var errors []error
for !finished {
select {
case readError := <-stdoutCh:
errors = appendError(errors, readError)
finished = true
case execError := <-execCh:
err := outWritePipe.Close()
errors = appendError(errors, err)
errors = appendError(errors, execError)
if errors != nil {
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderProcess,
errors)
}
case <-time.After(p.Timeout):
finished = true
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderTimeout,
errors) // errors can be nil
}
}
out := output.Bytes()
if runtime.GOOS == "windows" {
// windows adds slashes to quotes
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
}
return out, nil
}
// appendError conveniently checks for nil before appending slice
func appendError(errors []error, err error) []error {
if err != nil {
return append(errors, err)
}
return errors
}
func executeCommand(cmd exec.Cmd, exec chan error) {
// Start the command
err := cmd.Start()
if err == nil {
err = cmd.Wait()
}
exec <- err
}
func readInput(r io.Reader, w io.Writer, read chan error) {
tee := io.TeeReader(r, w)
_, err := ioutil.ReadAll(tee)
if err == io.EOF {
err = nil
}
read <- err // will only arrive here when write end of pipe is closed
}

@ -4,9 +4,8 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/ini"
"github.com/aws/aws-sdk-go/internal/shareddefaults" "github.com/aws/aws-sdk-go/internal/shareddefaults"
) )
@ -77,36 +76,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be // The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid. // returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) { func loadProfile(filename, profile string) (Value, error) {
config, err := ini.Load(filename) config, err := ini.OpenFile(filename)
if err != nil { if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err) return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
} }
iniProfile, err := config.GetSection(profile)
if err != nil { iniProfile, ok := config.GetSection(profile)
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err) if !ok {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
} }
id, err := iniProfile.GetKey("aws_access_key_id") id := iniProfile.String("aws_access_key_id")
if err != nil { if len(id) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey", return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename), fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err) nil)
} }
secret, err := iniProfile.GetKey("aws_secret_access_key") secret := iniProfile.String("aws_secret_access_key")
if err != nil { if len(secret) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret", return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename), fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil) nil)
} }
// Default to empty string if not found // Default to empty string if not found
token := iniProfile.Key("aws_session_token") token := iniProfile.String("aws_session_token")
return Value{ return Value{
AccessKeyID: id.String(), AccessKeyID: id,
SecretAccessKey: secret.String(), SecretAccessKey: secret,
SessionToken: token.String(), SessionToken: token,
ProviderName: SharedCredsProviderName, ProviderName: SharedCredsProviderName,
}, nil }, nil
} }

@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"
var ( var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty. // ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil) ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
) )

@ -80,16 +80,18 @@ package stscreds
import ( import (
"fmt" "fmt"
"os"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts"
) )
// StdinTokenProvider will prompt on stdout and read from stdin for a string value. // StdinTokenProvider will prompt on stderr and read from stdin for a string value.
// An error is returned if reading from stdin fails. // An error is returned if reading from stdin fails.
// //
// Use this function go read MFA tokens from stdin. The function makes no attempt // Use this function go read MFA tokens from stdin. The function makes no attempt
@ -102,7 +104,7 @@ import (
// Will wait forever until something is provided on the stdin. // Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) { func StdinTokenProvider() (string, error) {
var v string var v string
fmt.Printf("Assume Role MFA token code: ") fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
_, err := fmt.Scanln(&v) _, err := fmt.Scanln(&v)
return v, err return v, err
@ -193,6 +195,18 @@ type AssumeRoleProvider struct {
// //
// If ExpiryWindow is 0 or less it will be ignored. // If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration ExpiryWindow time.Duration
// MaxJitterFrac reduces the effective Duration of each credential requested
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
// have a value between 0 and 1. Any other value may lead to expected behavior.
// With a MaxJitterFrac value of 0, default) will no jitter will be used.
//
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
// AssumeRole call will be made with an arbitrary Duration between 27m and
// 30m.
//
// MaxJitterFrac should not be negative.
MaxJitterFrac float64
} }
// NewCredentials returns a pointer to a new Credentials object wrapping the // NewCredentials returns a pointer to a new Credentials object wrapping the
@ -244,7 +258,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
// Retrieve generates a new set of temporary credentials using STS. // Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set. // Apply defaults where parameters are not set.
if p.RoleSessionName == "" { if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique. // Try to work out a role name that will hopefully end up unique.
@ -254,8 +267,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Expire as often as AWS permits. // Expire as often as AWS permits.
p.Duration = DefaultDuration p.Duration = DefaultDuration
} }
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
input := &sts.AssumeRoleInput{ input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)), DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
RoleArn: aws.String(p.RoleARN), RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName), RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID, ExternalId: p.ExternalID,

@ -0,0 +1,99 @@
package stscreds
import (
"fmt"
"io/ioutil"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)
const (
// ErrCodeWebIdentity will be used as an error code when constructing
// a new error to be returned during session creation or retrieval.
ErrCodeWebIdentity = "WebIdentityErr"
// WebIdentityProviderName is the web identity provider name
WebIdentityProviderName = "WebIdentityCredentials"
)
// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = func() time.Time {
return time.Now()
}
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry
client stsiface.STSAPI
ExpiryWindow time.Duration
tokenFilePath string
roleARN string
roleSessionName string
}
// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFilePath: path,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
}
// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
}
sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
if err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
ProviderName: WebIdentityProviderName,
}
return value, nil
}

@ -1,30 +1,61 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics // Package csm provides the Client Side Monitoring (CSM) client which enables
// via UDP connection. Using the Start function will enable the reporting of // sending metrics via UDP connection to the CSM agent. This package provides
// metrics on a given port. If Start is called, with different parameters, again, // control options, and configuration for the CSM client. The client can be
// a panic will occur. // controlled manually, or automatically via the SDK's Session configuration.
// //
// Pause can be called to pause any metrics publishing on a given port. Sessions // Enabling CSM client via SDK's Session configuration
// that have had their handlers modified via InjectHandlers may still be used. //
// However, the handlers will act as a no-op meaning no metrics will be published. // The CSM client can be enabled automatically via SDK's Session configuration.
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
// environment variable is set to a non-empty value.
//
// The configuration options for the CSM client via the SDK's session
// configuration are:
//
// * AWS_CSM_PORT=<port number>
// The port number the CSM agent will receive metrics on.
//
// * AWS_CSM_HOST=<hostname or ip>
// The hostname, or IP address the CSM agent will receive metrics on.
// Without port number.
//
// Manually enabling the CSM client
//
// The CSM client can be started, paused, and resumed manually. The Start
// function will enable the CSM client to publish metrics to the CSM agent. It
// is safe to call Start concurrently, but if Start is called additional times
// with different ClientID or address it will panic.
// //
// Example:
// r, err := csm.Start("clientID", ":31000") // r, err := csm.Start("clientID", ":31000")
// if err != nil { // if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err)) // panic(fmt.Errorf("failed starting CSM: %v", err))
// } // }
// //
// When controlling the CSM client manually, you must also inject its request
// handlers into the SDK's Session configuration for the SDK's API clients to
// publish metrics.
//
// sess, err := session.NewSession(&aws.Config{}) // sess, err := session.NewSession(&aws.Config{})
// if err != nil { // if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err)) // panic(fmt.Errorf("failed loading session: %v", err))
// } // }
// //
// // Add CSM client's metric publishing request handlers to the SDK's
// // Session Configuration.
// r.InjectHandlers(&sess.Handlers) // r.InjectHandlers(&sess.Handlers)
// //
// client := s3.New(sess) // Controlling CSM client
// resp, err := client.GetObject(&s3.GetObjectInput{ //
// Bucket: aws.String("bucket"), // Once the CSM client has been enabled the Get function will return a Reporter
// Key: aws.String("key"), // value that you can use to pause and resume the metrics published to the CSM
// }) // agent. If Get function is called before the reporter is enabled with the
// Start function or via SDK's Session configuration nil will be returned.
//
// The Pause method can be called to stop the CSM client publishing metrics to
// the CSM agent. The Continue method will resume metric publishing.
//
// // Get the CSM client Reporter.
// r := csm.Get()
// //
// // Will pause monitoring // // Will pause monitoring
// r.Pause() // r.Pause()
@ -35,12 +66,4 @@
// //
// // Resume monitoring // // Resume monitoring
// r.Continue() // r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm package csm

@ -2,6 +2,7 @@ package csm
import ( import (
"fmt" "fmt"
"strings"
"sync" "sync"
) )
@ -9,19 +10,40 @@ var (
lock sync.Mutex lock sync.Mutex
) )
// Client side metric handler names
const ( const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric" // DefaultPort is used when no port is specified.
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric" DefaultPort = "31000"
// DefaultHost is the host that will be used when none is specified.
DefaultHost = "127.0.0.1"
) )
// Start will start the a long running go routine to capture // AddressWithDefaults returns a CSM address built from the host and port
// values. If the host or port is not set, default values will be used
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
func AddressWithDefaults(host, port string) string {
if len(host) == 0 || strings.EqualFold(host, "localhost") {
host = DefaultHost
}
if len(port) == 0 {
port = DefaultPort
}
// Only IP6 host can contain a colon
if strings.Contains(host, ":") {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// Start will start a long running go routine to capture
// client side metrics. Calling start multiple time will only // client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different // start the metric listener once and will panic if a different
// client ID or port is passed in. // client ID or port is passed in.
// //
// Example: // r, err := csm.Start("clientID", "127.0.0.1:31000")
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil { // if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err)) // panic(fmt.Errorf("expected no error, but received %v", err))
// } // }

@ -3,6 +3,8 @@ package csm
import ( import (
"strconv" "strconv"
"time" "time"
"github.com/aws/aws-sdk-go/aws"
) )
type metricTime time.Time type metricTime time.Time
@ -39,6 +41,12 @@ type metric struct {
SDKException *string `json:"SdkException,omitempty"` SDKException *string `json:"SdkException,omitempty"`
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"` SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
FinalAWSException *string `json:"FinalAwsException,omitempty"`
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
FinalSDKException *string `json:"FinalSdkException,omitempty"`
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
DestinationIP *string `json:"DestinationIp,omitempty"` DestinationIP *string `json:"DestinationIp,omitempty"`
ConnectionReused *int `json:"ConnectionReused,omitempty"` ConnectionReused *int `json:"ConnectionReused,omitempty"`
@ -48,4 +56,54 @@ type metric struct {
DNSLatency *int `json:"DnsLatency,omitempty"` DNSLatency *int `json:"DnsLatency,omitempty"`
TCPLatency *int `json:"TcpLatency,omitempty"` TCPLatency *int `json:"TcpLatency,omitempty"`
SSLLatency *int `json:"SslLatency,omitempty"` SSLLatency *int `json:"SslLatency,omitempty"`
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
}
func (m *metric) TruncateFields() {
m.ClientID = truncateString(m.ClientID, 255)
m.UserAgent = truncateString(m.UserAgent, 256)
m.AWSException = truncateString(m.AWSException, 128)
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
m.SDKException = truncateString(m.SDKException, 128)
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
}
func truncateString(v *string, l int) *string {
if v != nil && len(*v) > l {
nv := (*v)[:l]
return &nv
}
return v
}
func (m *metric) SetException(e metricException) {
switch te := e.(type) {
case awsException:
m.AWSException = aws.String(te.exception)
m.AWSExceptionMessage = aws.String(te.message)
case sdkException:
m.SDKException = aws.String(te.exception)
m.SDKExceptionMessage = aws.String(te.message)
}
}
func (m *metric) SetFinalException(e metricException) {
switch te := e.(type) {
case awsException:
m.FinalAWSException = aws.String(te.exception)
m.FinalAWSExceptionMessage = aws.String(te.message)
case sdkException:
m.FinalSDKException = aws.String(te.exception)
m.FinalSDKExceptionMessage = aws.String(te.message)
}
} }

@ -0,0 +1,26 @@
package csm
type metricException interface {
Exception() string
Message() string
}
type requestException struct {
exception string
message string
}
func (e requestException) Exception() string {
return e.exception
}
func (e requestException) Message() string {
return e.message
}
type awsException struct {
requestException
}
type sdkException struct {
requestException
}

@ -10,11 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and // Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint. // send those metrics to the CSM endpoint.
type Reporter struct { type Reporter struct {
@ -82,26 +77,29 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
if r.Error != nil { if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok { if awserr, ok := r.Error.(awserr.Error); ok {
setError(&m, awserr) m.SetException(getMetricException(awserr))
} }
} }
m.TruncateFields()
rep.metricsCh.Push(m) rep.metricsCh.Push(m)
} }
func setError(m *metric, err awserr.Error) { func getMetricException(err awserr.Error) metricException {
msg := err.Error() msg := err.Error()
code := err.Code() code := err.Code()
switch code { switch code {
case "RequestError", case "RequestError",
"SerializationError", request.ErrCodeSerialization,
request.CanceledErrorCode: request.CanceledErrorCode:
m.SDKException = &code return sdkException{
m.SDKExceptionMessage = &msg requestException{exception: code, message: msg},
}
default: default:
m.AWSException = &code return awsException{
m.AWSExceptionMessage = &msg requestException{exception: code, message: msg},
}
} }
} }
@ -116,12 +114,27 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
API: aws.String(r.Operation.Name), API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID), Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now), Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Type: aws.String("ApiCall"), Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1), AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region,
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)), Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID), XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
} }
if r.HTTPResponse != nil {
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetFinalException(getMetricException(awserr))
}
}
m.TruncateFields()
// TODO: Probably want to figure something out for logging dropped // TODO: Probably want to figure something out for logging dropped
// metrics // metrics
rep.metricsCh.Push(m) rep.metricsCh.Push(m)
@ -172,8 +185,9 @@ func (rep *Reporter) start() {
} }
} }
// Pause will pause the metric channel preventing any new metrics from // Pause will pause the metric channel preventing any new metrics from being
// being added. // added. It is safe to call concurrently with other calls to Pause, but if
// called concurently with Continue can lead to unexpected state.
func (rep *Reporter) Pause() { func (rep *Reporter) Pause() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@ -185,8 +199,9 @@ func (rep *Reporter) Pause() {
rep.close() rep.close()
} }
// Continue will reopen the metric channel and allow for monitoring // Continue will reopen the metric channel and allow for monitoring to be
// to be resumed. // resumed. It is safe to call concurrently with other calls to Continue, but
// if called concurently with Pause can lead to unexpected state.
func (rep *Reporter) Continue() { func (rep *Reporter) Continue() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@ -201,10 +216,18 @@ func (rep *Reporter) Continue() {
rep.metricsCh.Continue() rep.metricsCh.Continue()
} }
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// InjectHandlers will will enable client side metrics and inject the proper // InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent. // handlers to handle how metrics are sent.
// //
// Example: // InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
//
// // Start must be called in order to inject the correct handlers // // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094") // r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil { // if err != nil {
@ -221,11 +244,22 @@ func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
return return
} }
apiCallHandler := request.NamedHandler{Name: APICallMetricHandlerName, Fn: rep.sendAPICallMetric} handlers.Complete.PushFrontNamed(request.NamedHandler{
apiCallAttemptHandler := request.NamedHandler{Name: APICallAttemptMetricHandlerName, Fn: rep.sendAPICallAttemptMetric} Name: APICallMetricHandlerName,
Fn: rep.sendAPICallMetric,
})
handlers.Complete.PushFrontNamed(apiCallHandler) handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
handlers.Complete.PushFrontNamed(apiCallAttemptHandler) Name: APICallAttemptMetricHandlerName,
Fn: rep.sendAPICallAttemptMetric,
handlers.AfterRetry.PushFrontNamed(apiCallAttemptHandler) })
}
// boolIntValue return 1 for true and 0 for false.
func boolIntValue(b bool) int {
if b {
return 1
}
return 0
} }

@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
) )
// A Defaults provides a collection of default values for SDK clients. // A Defaults provides a collection of default values for SDK clients.
@ -112,8 +113,8 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro
} }
const ( const (
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI" httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
) )
// RemoteCredProvider returns a credentials provider for the default remote // RemoteCredProvider returns a credentials provider for the default remote
@ -123,8 +124,8 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
return localHTTPCredProvider(cfg, handlers, u) return localHTTPCredProvider(cfg, handlers, u)
} }
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 { if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("http://169.254.170.2%s", uri) u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
return httpCredProvider(cfg, handlers, u) return httpCredProvider(cfg, handlers, u)
} }
@ -187,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
return endpointcreds.NewProviderClient(cfg, handlers, u, return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) { func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
}, },
) )
} }

@ -24,8 +24,9 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
err := req.Send()
return output.Content, req.Send() return output.Content, err
} }
// GetUserData returns the userdata that was configured for the service. If // GetUserData returns the userdata that was configured for the service. If
@ -45,8 +46,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error) r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
} }
}) })
err := req.Send()
return output.Content, req.Send() return output.Content, err
} }
// GetDynamicData uses the path provided to request information from the EC2 // GetDynamicData uses the path provided to request information from the EC2
@ -61,8 +63,9 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
err := req.Send()
return output.Content, req.Send() return output.Content, err
} }
// GetInstanceIdentityDocument retrieves an identity document describing an // GetInstanceIdentityDocument retrieves an identity document describing an
@ -79,7 +82,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
doc := EC2InstanceIdentityDocument{} doc := EC2InstanceIdentityDocument{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
return EC2InstanceIdentityDocument{}, return EC2InstanceIdentityDocument{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 instance identity document", err) "failed to decode EC2 instance identity document", err)
} }
@ -98,7 +101,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
info := EC2IAMInfo{} info := EC2IAMInfo{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
return EC2IAMInfo{}, return EC2IAMInfo{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 IAM info", err) "failed to decode EC2 IAM info", err)
} }
@ -118,6 +121,10 @@ func (c *EC2Metadata) Region() (string, error) {
return "", err return "", err
} }
if len(resp) == 0 {
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2 // returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil return resp[:len(resp)-1], nil
} }

@ -4,7 +4,7 @@
// This package's client can be disabled completely by setting the environment // This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to // variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot // true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive). // be used while the environment variable is set to true, (case insensitive).
package ec2metadata package ec2metadata
import ( import (
@ -72,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg, cfg,
metadata.ClientInfo{ metadata.ClientInfo{
ServiceName: ServiceName, ServiceName: ServiceName,
ServiceID: ServiceName,
Endpoint: endpoint, Endpoint: endpoint,
APIVersion: "latest", APIVersion: "latest",
}, },
@ -91,6 +92,9 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Send.SwapNamed(request.NamedHandler{ svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name, Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) { Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Header: http.Header{},
}
r.Error = awserr.New( r.Error = awserr.New(
request.CanceledErrorCode, request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var", "EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
@ -119,7 +123,7 @@ func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{} b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err) r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata respose", err)
return return
} }
@ -132,7 +136,7 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{} b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err) r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error respose", err)
return return
} }

@ -85,6 +85,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
custAddS3DualStack(p) custAddS3DualStack(p)
custRmIotDataService(p) custRmIotDataService(p)
custFixAppAutoscalingChina(p) custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p)
} }
return ps, nil return ps, nil
@ -95,7 +96,12 @@ func custAddS3DualStack(p *partition) {
return return
} }
s, ok := p.Services["s3"] custAddDualstack(p, "s3")
custAddDualstack(p, "s3-control")
}
func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName]
if !ok { if !ok {
return return
} }
@ -103,7 +109,7 @@ func custAddS3DualStack(p *partition) {
s.Defaults.HasDualStack = boxedTrue s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}" s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services["s3"] = s p.Services[svcName] = s
} }
func custAddEC2Metadata(p *partition) { func custAddEC2Metadata(p *partition) {
@ -144,6 +150,33 @@ func custFixAppAutoscalingChina(p *partition) {
p.Services[serviceName] = s p.Services[serviceName] = s
} }
func custFixAppAutoscalingUsGov(p *partition) {
if p.ID != "aws-us-gov" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
if a := s.Defaults.CredentialScope.Service; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
return
}
if a := s.Defaults.Hostname; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
return
}
s.Defaults.CredentialScope.Service = "application-autoscaling"
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com"
p.Services[serviceName] = s
}
type decodeModelError struct { type decodeModelError struct {
awsError awsError
} }

File diff suppressed because it is too large Load Diff

@ -0,0 +1,141 @@
package endpoints
// Service identifiers
//
// Deprecated: Use client package's EndpointsID value instead of these
// ServiceIDs. These IDs are not maintained, and are out of date.
const (
A4bServiceID = "a4b" // A4b.
AcmServiceID = "acm" // Acm.
AcmPcaServiceID = "acm-pca" // AcmPca.
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
ApiPricingServiceID = "api.pricing" // ApiPricing.
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
ApigatewayServiceID = "apigateway" // Apigateway.
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
Appstream2ServiceID = "appstream2" // Appstream2.
AppsyncServiceID = "appsync" // Appsync.
AthenaServiceID = "athena" // Athena.
AutoscalingServiceID = "autoscaling" // Autoscaling.
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
BatchServiceID = "batch" // Batch.
BudgetsServiceID = "budgets" // Budgets.
CeServiceID = "ce" // Ce.
ChimeServiceID = "chime" // Chime.
Cloud9ServiceID = "cloud9" // Cloud9.
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
CloudformationServiceID = "cloudformation" // Cloudformation.
CloudfrontServiceID = "cloudfront" // Cloudfront.
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
CodebuildServiceID = "codebuild" // Codebuild.
CodecommitServiceID = "codecommit" // Codecommit.
CodedeployServiceID = "codedeploy" // Codedeploy.
CodepipelineServiceID = "codepipeline" // Codepipeline.
CodestarServiceID = "codestar" // Codestar.
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
ComprehendServiceID = "comprehend" // Comprehend.
ConfigServiceID = "config" // Config.
CurServiceID = "cur" // Cur.
DatapipelineServiceID = "datapipeline" // Datapipeline.
DaxServiceID = "dax" // Dax.
DevicefarmServiceID = "devicefarm" // Devicefarm.
DirectconnectServiceID = "directconnect" // Directconnect.
DiscoveryServiceID = "discovery" // Discovery.
DmsServiceID = "dms" // Dms.
DsServiceID = "ds" // Ds.
DynamodbServiceID = "dynamodb" // Dynamodb.
Ec2ServiceID = "ec2" // Ec2.
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
EcrServiceID = "ecr" // Ecr.
EcsServiceID = "ecs" // Ecs.
ElasticacheServiceID = "elasticache" // Elasticache.
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
EmailServiceID = "email" // Email.
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
EsServiceID = "es" // Es.
EventsServiceID = "events" // Events.
FirehoseServiceID = "firehose" // Firehose.
FmsServiceID = "fms" // Fms.
GameliftServiceID = "gamelift" // Gamelift.
GlacierServiceID = "glacier" // Glacier.
GlueServiceID = "glue" // Glue.
GreengrassServiceID = "greengrass" // Greengrass.
GuarddutyServiceID = "guardduty" // Guardduty.
HealthServiceID = "health" // Health.
IamServiceID = "iam" // Iam.
ImportexportServiceID = "importexport" // Importexport.
InspectorServiceID = "inspector" // Inspector.
IotServiceID = "iot" // Iot.
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
KinesisServiceID = "kinesis" // Kinesis.
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
KmsServiceID = "kms" // Kms.
LambdaServiceID = "lambda" // Lambda.
LightsailServiceID = "lightsail" // Lightsail.
LogsServiceID = "logs" // Logs.
MachinelearningServiceID = "machinelearning" // Machinelearning.
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
MedialiveServiceID = "medialive" // Medialive.
MediapackageServiceID = "mediapackage" // Mediapackage.
MediastoreServiceID = "mediastore" // Mediastore.
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
MghServiceID = "mgh" // Mgh.
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
ModelsLexServiceID = "models.lex" // ModelsLex.
MonitoringServiceID = "monitoring" // Monitoring.
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
NeptuneServiceID = "neptune" // Neptune.
OpsworksServiceID = "opsworks" // Opsworks.
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
OrganizationsServiceID = "organizations" // Organizations.
PinpointServiceID = "pinpoint" // Pinpoint.
PollyServiceID = "polly" // Polly.
RdsServiceID = "rds" // Rds.
RedshiftServiceID = "redshift" // Redshift.
RekognitionServiceID = "rekognition" // Rekognition.
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
Route53ServiceID = "route53" // Route53.
Route53domainsServiceID = "route53domains" // Route53domains.
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
S3ServiceID = "s3" // S3.
S3ControlServiceID = "s3-control" // S3Control.
SagemakerServiceID = "api.sagemaker" // Sagemaker.
SdbServiceID = "sdb" // Sdb.
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
ShieldServiceID = "shield" // Shield.
SmsServiceID = "sms" // Sms.
SnowballServiceID = "snowball" // Snowball.
SnsServiceID = "sns" // Sns.
SqsServiceID = "sqs" // Sqs.
SsmServiceID = "ssm" // Ssm.
StatesServiceID = "states" // States.
StoragegatewayServiceID = "storagegateway" // Storagegateway.
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
StsServiceID = "sts" // Sts.
SupportServiceID = "support" // Support.
SwfServiceID = "swf" // Swf.
TaggingServiceID = "tagging" // Tagging.
TransferServiceID = "transfer" // Transfer.
TranslateServiceID = "translate" // Translate.
WafServiceID = "waf" // Waf.
WafRegionalServiceID = "waf-regional" // WafRegional.
WorkdocsServiceID = "workdocs" // Workdocs.
WorkmailServiceID = "workmail" // Workmail.
WorkspacesServiceID = "workspaces" // Workspaces.
XrayServiceID = "xray" // Xray.
)

@ -35,7 +35,7 @@ type Options struct {
// //
// If resolving an endpoint on the partition list the provided region will // If resolving an endpoint on the partition list the provided region will
// be used to determine which partition's domain name pattern to the service // be used to determine which partition's domain name pattern to the service
// endpoint ID with. If both the service and region are unkonwn and resolving // endpoint ID with. If both the service and region are unknown and resolving
// the endpoint on partition list an UnknownEndpointError error will be returned. // the endpoint on partition list an UnknownEndpointError error will be returned.
// //
// If resolving and endpoint on a partition specific resolver that partition's // If resolving and endpoint on a partition specific resolver that partition's

@ -16,6 +16,10 @@ import (
type CodeGenOptions struct { type CodeGenOptions struct {
// Options for how the model will be decoded. // Options for how the model will be decoded.
DecodeModelOptions DecodeModelOptions DecodeModelOptions DecodeModelOptions
// Disables code generation of the service endpoint prefix IDs defined in
// the model.
DisableGenerateServiceIDs bool
} }
// Set combines all of the option functions together // Set combines all of the option functions together
@ -39,8 +43,16 @@ func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGe
return err return err
} }
v := struct {
Resolver
CodeGenOptions
}{
Resolver: resolver,
CodeGenOptions: opts,
}
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl)) tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil { if err := tmpl.ExecuteTemplate(outFile, "defaults", v); err != nil {
return fmt.Errorf("failed to execute template, %v", err) return fmt.Errorf("failed to execute template, %v", err)
} }
@ -166,15 +178,17 @@ import (
"regexp" "regexp"
) )
{{ template "partition consts" . }} {{ template "partition consts" $.Resolver }}
{{ range $_, $partition := . }} {{ range $_, $partition := $.Resolver }}
{{ template "partition region consts" $partition }} {{ template "partition region consts" $partition }}
{{ end }} {{ end }}
{{ template "service consts" . }} {{ if not $.DisableGenerateServiceIDs -}}
{{ template "service consts" $.Resolver }}
{{- end }}
{{ template "endpoint resolvers" . }} {{ template "endpoint resolvers" $.Resolver }}
{{- end }} {{- end }}
{{ define "partition consts" }} {{ define "partition consts" }}

@ -5,13 +5,9 @@ import "github.com/aws/aws-sdk-go/aws/awserr"
var ( var (
// ErrMissingRegion is an error that is returned if region configuration is // ErrMissingRegion is an error that is returned if region configuration is
// not found. // not found.
//
// @readonly
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil) ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be // ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service. // resolved for a service.
//
// @readonly
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil) ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
) )

@ -1,18 +1,17 @@
// +build !appengine,!plan9
package request package request
import ( import (
"net" "strings"
"os"
"syscall"
) )
func isErrConnectionReset(err error) bool { func isErrConnectionReset(err error) bool {
if opErr, ok := err.(*net.OpError); ok { if strings.Contains(err.Error(), "read: connection reset") {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok { return false
return sysErr.Err == syscall.ECONNRESET
} }
if strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe") {
return true
} }
return false return false

@ -1,11 +0,0 @@
// +build appengine plan9
package request
import (
"strings"
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
}

@ -19,6 +19,7 @@ type Handlers struct {
UnmarshalError HandlerList UnmarshalError HandlerList
Retry HandlerList Retry HandlerList
AfterRetry HandlerList AfterRetry HandlerList
CompleteAttempt HandlerList
Complete HandlerList Complete HandlerList
} }
@ -36,6 +37,7 @@ func (h *Handlers) Copy() Handlers {
UnmarshalMeta: h.UnmarshalMeta.copy(), UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(), Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(), AfterRetry: h.AfterRetry.copy(),
CompleteAttempt: h.CompleteAttempt.copy(),
Complete: h.Complete.copy(), Complete: h.Complete.copy(),
} }
} }
@ -53,9 +55,55 @@ func (h *Handlers) Clear() {
h.ValidateResponse.Clear() h.ValidateResponse.Clear()
h.Retry.Clear() h.Retry.Clear()
h.AfterRetry.Clear() h.AfterRetry.Clear()
h.CompleteAttempt.Clear()
h.Complete.Clear() h.Complete.Clear()
} }
// IsEmpty returns if there are no handlers in any of the handlerlists.
func (h *Handlers) IsEmpty() bool {
if h.Validate.Len() != 0 {
return false
}
if h.Build.Len() != 0 {
return false
}
if h.Send.Len() != 0 {
return false
}
if h.Sign.Len() != 0 {
return false
}
if h.Unmarshal.Len() != 0 {
return false
}
if h.UnmarshalStream.Len() != 0 {
return false
}
if h.UnmarshalMeta.Len() != 0 {
return false
}
if h.UnmarshalError.Len() != 0 {
return false
}
if h.ValidateResponse.Len() != 0 {
return false
}
if h.Retry.Len() != 0 {
return false
}
if h.AfterRetry.Len() != 0 {
return false
}
if h.CompleteAttempt.Len() != 0 {
return false
}
if h.Complete.Len() != 0 {
return false
}
return true
}
// A HandlerListRunItem represents an entry in the HandlerList which // A HandlerListRunItem represents an entry in the HandlerList which
// is being run. // is being run.
type HandlerListRunItem struct { type HandlerListRunItem struct {

@ -122,7 +122,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
Handlers: handlers.Copy(), Handlers: handlers.Copy(),
Retryer: retryer, Retryer: retryer,
AttemptTime: time.Now(),
Time: time.Now(), Time: time.Now(),
ExpireTime: 0, ExpireTime: 0,
Operation: operation, Operation: operation,
@ -233,6 +232,10 @@ func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
} }
func fmtAttemptCount(retryCount, maxRetries int) string {
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
}
// ParamsFilled returns if the request's parameters have been populated // ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are // and the parameters are valid. False is returned if no parameters are
// provided or invalid. // provided or invalid.
@ -266,7 +269,9 @@ func (r *Request) SetReaderBody(reader io.ReadSeeker) {
} }
// Presign returns the request's signed URL. Error will be returned // Presign returns the request's signed URL. Error will be returned
// if the signing fails. // if the signing fails. The expire parameter is only used for presigned Amazon
// S3 API requests. All other AWS services will use a fixed expiration
// time of 15 minutes.
// //
// It is invalid to create a presigned URL with a expire duration 0 or less. An // It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less. // error is returned if expire duration is 0 or less.
@ -283,7 +288,9 @@ func (r *Request) Presign(expire time.Duration) (string, error) {
} }
// PresignRequest behaves just like presign, with the addition of returning a // PresignRequest behaves just like presign, with the addition of returning a
// set of headers that were signed. // set of headers that were signed. The expire parameter is only used for
// presigned Amazon S3 API requests. All other AWS services will use a fixed
// expiration time of 15 minutes.
// //
// It is invalid to create a presigned URL with a expire duration 0 or less. An // It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less. // error is returned if expire duration is 0 or less.
@ -328,16 +335,17 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
} }
func debugLogReqError(r *Request, stage string, retrying bool, err error) { const (
willRetry = "will retry"
notRetrying = "not retrying"
retryCount = "retry %v/%v"
)
func debugLogReqError(r *Request, stage, retryStr string, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) { if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return return
} }
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err)) stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
} }
@ -356,12 +364,12 @@ func (r *Request) Build() error {
if !r.built { if !r.built {
r.Handlers.Validate.Run(r) r.Handlers.Validate.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error) debugLogReqError(r, "Validate Request", notRetrying, r.Error)
return r.Error return r.Error
} }
r.Handlers.Build.Run(r) r.Handlers.Build.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error) debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error return r.Error
} }
r.built = true r.built = true
@ -377,7 +385,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error { func (r *Request) Sign() error {
r.Build() r.Build()
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error) debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error return r.Error
} }
@ -462,9 +470,38 @@ func (r *Request) Send() error {
r.Handlers.Complete.Run(r) r.Handlers.Complete.Run(r)
}() }()
if err := r.Error; err != nil {
return err
}
for { for {
r.Error = nil
r.AttemptTime = time.Now() r.AttemptTime = time.Now()
if aws.BoolValue(r.Retryable) {
if err := r.Sign(); err != nil {
debugLogReqError(r, "Sign Request", notRetrying, err)
return err
}
if err := r.sendRequest(); err == nil {
return nil
} else if !shouldRetryError(r.Error) {
return err
} else {
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil || !aws.BoolValue(r.Retryable) {
return r.Error
}
r.prepareRetry()
continue
}
}
}
func (r *Request) prepareRetry() {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) { if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount)) r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
@ -483,60 +520,35 @@ func (r *Request) Send() error {
} }
} }
r.Sign() func (r *Request) sendRequest() (sendErr error) {
if r.Error != nil { defer r.Handlers.CompleteAttempt.Run(r)
return r.Error
}
r.Retryable = nil r.Retryable = nil
r.Handlers.Send.Run(r) r.Handlers.Send.Run(r)
if r.Error != nil { if r.Error != nil {
if !shouldRetryCancel(r) { debugLogReqError(r, "Send Request",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error return r.Error
} }
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", false, err)
return r.Error
}
debugLogReqError(r, "Send Request", true, err)
continue
}
r.Handlers.UnmarshalMeta.Run(r) r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r) r.Handlers.ValidateResponse.Run(r)
if r.Error != nil { if r.Error != nil {
r.Handlers.UnmarshalError.Run(r) r.Handlers.UnmarshalError.Run(r)
err := r.Error debugLogReqError(r, "Validate Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Handlers.Retry.Run(r) r.Error)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Response", false, err)
return r.Error return r.Error
} }
debugLogReqError(r, "Validate Response", true, err)
continue
}
r.Handlers.Unmarshal.Run(r) r.Handlers.Unmarshal.Run(r)
if r.Error != nil { if r.Error != nil {
err := r.Error debugLogReqError(r, "Unmarshal Response",
r.Handlers.Retry.Run(r) fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Handlers.AfterRetry.Run(r) r.Error)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", false, err)
return r.Error return r.Error
} }
debugLogReqError(r, "Unmarshal Response", true, err)
continue
}
break
}
return nil return nil
} }
@ -561,30 +573,49 @@ func AddToUserAgent(r *Request, s string) {
r.HTTPRequest.Header.Set("User-Agent", s) r.HTTPRequest.Header.Set("User-Agent", s)
} }
func shouldRetryCancel(r *Request) bool { type temporary interface {
awsErr, ok := r.Error.(awserr.Error) Temporary() bool
timeoutErr := false }
errStr := r.Error.Error()
if ok { func shouldRetryError(origErr error) bool {
if awsErr.Code() == CanceledErrorCode { switch err := origErr.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false return false
} }
err := awsErr.OrigErr() return shouldRetryError(err.OrigErr())
netErr, netOK := err.(net.Error) case *url.Error:
timeoutErr = netOK && netErr.Temporary() if strings.Contains(err.Error(), "connection refused") {
if urlErr, ok := err.(*url.Error); !timeoutErr && ok { // Refused connections should be retried as the service may not yet
errStr = urlErr.Err.Error() // be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
} }
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryError(err.Err)
case temporary:
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
return true
}
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary() || isErrConnectionReset(origErr)
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
} }
// There can be two types of canceled errors here.
// The first being a net.Error and the other being an error.
// If the request was timed out, we want to continue the retry
// process. Otherwise, return the canceled error.
return timeoutErr ||
(errStr != "net/http: request canceled" &&
errStr != "net/http: request canceled while waiting for connection")
} }
// SanitizeHostForHeader removes default port from host and updates request.Host // SanitizeHostForHeader removes default port from host and updates request.Host

@ -38,8 +38,10 @@ var throttleCodes = map[string]struct{}{
"ThrottlingException": {}, "ThrottlingException": {},
"RequestLimitExceeded": {}, "RequestLimitExceeded": {},
"RequestThrottled": {}, "RequestThrottled": {},
"RequestThrottledException": {},
"TooManyRequestsException": {}, // Lambda functions "TooManyRequestsException": {}, // Lambda functions
"PriorRequestNotComplete": {}, // Route53 "PriorRequestNotComplete": {}, // Route53
"TransactionInProgressException": {},
} }
// credsExpiredCodes is a collection of error codes which signify the credentials // credsExpiredCodes is a collection of error codes which signify the credentials

@ -17,6 +17,12 @@ const (
ParamMinValueErrCode = "ParamMinValueError" ParamMinValueErrCode = "ParamMinValueError"
// ParamMinLenErrCode is the error code for fields without enough elements. // ParamMinLenErrCode is the error code for fields without enough elements.
ParamMinLenErrCode = "ParamMinLenError" ParamMinLenErrCode = "ParamMinLenError"
// ParamMaxLenErrCode is the error code for value being too long.
ParamMaxLenErrCode = "ParamMaxLenError"
// ParamFormatErrCode is the error code for a field with invalid
// format or characters.
ParamFormatErrCode = "ParamFormatInvalidError"
) )
// Validator provides a way for types to perform validation logic on their // Validator provides a way for types to perform validation logic on their
@ -232,3 +238,49 @@ func NewErrParamMinLen(field string, min int) *ErrParamMinLen {
func (e *ErrParamMinLen) MinLen() int { func (e *ErrParamMinLen) MinLen() int {
return e.min return e.min
} }
// An ErrParamMaxLen represents a maximum length parameter error.
type ErrParamMaxLen struct {
errInvalidParam
max int
}
// NewErrParamMaxLen creates a new maximum length parameter error.
func NewErrParamMaxLen(field string, max int, value string) *ErrParamMaxLen {
return &ErrParamMaxLen{
errInvalidParam: errInvalidParam{
code: ParamMaxLenErrCode,
field: field,
msg: fmt.Sprintf("maximum size of %v, %v", max, value),
},
max: max,
}
}
// MaxLen returns the field's required minimum length.
func (e *ErrParamMaxLen) MaxLen() int {
return e.max
}
// An ErrParamFormat represents a invalid format parameter error.
type ErrParamFormat struct {
errInvalidParam
format string
}
// NewErrParamFormat creates a new invalid format parameter error.
func NewErrParamFormat(field string, format, value string) *ErrParamFormat {
return &ErrParamFormat{
errInvalidParam: errInvalidParam{
code: ParamFormatErrCode,
field: field,
msg: fmt.Sprintf("format %v, %v", format, value),
},
format: format,
}
}
// Format returns the field's required format.
func (e *ErrParamFormat) Format() string {
return e.format
}

@ -0,0 +1,26 @@
// +build go1.7
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

@ -0,0 +1,22 @@
// +build !go1.6,go1.5
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
}
}

@ -0,0 +1,23 @@
// +build !go1.7,go1.6
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

@ -0,0 +1,260 @@
package session
import (
"fmt"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (*credentials.Credentials, error) {
switch {
case len(envCfg.Profile) != 0:
// User explicitly provided an Profile, so load from shared config
// first.
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
case envCfg.Creds.HasKeys():
// Environment credentials
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
case len(envCfg.WebIdentityTokenFilePath) != 0:
// Web identity token from environment, RoleARN required to also be
// set.
return assumeWebIdentity(cfg, handlers,
envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN,
envCfg.RoleSessionName,
)
default:
// Fallback to the "default" credential resolution chain.
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
}
}
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
// 'AWS_IAM_ROLE_ARN' was not set.
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_IAM_ROLE_ARN' was set but
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string,
roleARN, sessionName string,
) (*credentials.Credentials, error) {
if len(filepath) == 0 {
return nil, WebIdentityEmptyTokenFilePathErr
}
if len(roleARN) == 0 {
return nil, WebIdentityEmptyRoleARNErr
}
creds := stscreds.NewWebIdentityCredentials(
&Session{
Config: cfg,
Handlers: handlers.Copy(),
},
roleARN,
sessionName,
filepath,
)
return creds, nil
}
func resolveCredsFromProfile(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch {
case sharedCfg.SourceProfile != nil:
// Assume IAM role with credentials source from a different profile.
creds, err = resolveCredsFromProfile(cfg, envCfg,
*sharedCfg.SourceProfile, handlers, sessOpts,
)
case sharedCfg.Creds.HasKeys():
// Static Credentials from Shared Config/Credentials file.
creds = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg,
sharedCfg, handlers, sessOpts,
)
case len(sharedCfg.WebIdentityTokenFile) != 0:
// Credentials from Assume Web Identity token require an IAM Role, and
// that roll will be assumed. May be wrapped with another assume role
// via SourceProfile.
return assumeWebIdentity(cfg, handlers,
sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN,
sharedCfg.RoleSessionName,
)
default:
// Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to
// be retrieved.
creds = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{
Err: awserr.New("EnvAccessKeyNotFound",
"failed to find credentials in the environment.", nil),
},
&credProviderError{
Err: awserr.New("SharedCredsLoad",
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
}
if err != nil {
return nil, err
}
if len(sharedCfg.RoleARN) > 0 {
cfgCp := *cfg
cfgCp.Credentials = creds
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
}
return creds, nil
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
func resolveCredsFromSource(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch sharedCfg.CredentialSource {
case credSourceEc2Metadata:
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
case credSourceEnvironment:
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return nil, ErrSharedConfigECSContainerEnvVarEmpty
}
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
default:
return nil, ErrSharedConfigInvalidCredSource
}
return creds, nil
}
func credsFromAssumeRole(cfg aws.Config,
handlers request.Handlers,
sharedCfg sharedConfig,
sessOpts Options,
) (*credentials.Credentials, error) {
if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return nil, AssumeRoleTokenProviderNotSetError{}
}
return stscreds.NewCredentials(
&Session{
Config: &cfg,
Handlers: handlers.Copy(),
},
sharedCfg.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.RoleSessionName
opt.Duration = sessOpts.AssumeRoleDuration
// Assume role with external ID
if len(sharedCfg.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
), nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a
// session when the MFAToken option is not set when shared config is configured
// load assume a role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
var emptyCreds = credentials.Value{}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}

@ -99,7 +99,7 @@ handler logs every request and its payload made by a service client:
sess.Handlers.Send.PushFront(func(r *request.Request) { sess.Handlers.Send.PushFront(func(r *request.Request) {
// Log every request made and its payload // Log every request made and its payload
logger.Println("Request: %s/%s, Payload: %s", logger.Printf("Request: %s/%s, Payload: %s",
r.ClientInfo.ServiceName, r.Operation, r.Params) r.ClientInfo.ServiceName, r.Operation, r.Params)
}) })
@ -183,7 +183,7 @@ be returned when creating the session.
// from assumed role. // from assumed role.
svc := s3.New(sess) svc := s3.New(sess)
To setup assume role outside of a session see the stscrds.AssumeRoleProvider To setup assume role outside of a session see the stscreds.AssumeRoleProvider
documentation. documentation.
Environment Variables Environment Variables

@ -4,6 +4,7 @@ import (
"os" "os"
"strconv" "strconv"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
) )
@ -79,7 +80,7 @@ type envConfig struct {
// AWS_CONFIG_FILE=$HOME/my_shared_config // AWS_CONFIG_FILE=$HOME/my_shared_config
SharedConfigFile string SharedConfigFile string
// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file // Sets the path to a custom Credentials Authority (CA) Bundle PEM file
// that the SDK will use instead of the system's root CA bundle. // that the SDK will use instead of the system's root CA bundle.
// Only use this if you want to configure the SDK to use a custom set // Only use this if you want to configure the SDK to use a custom set
// of CAs. // of CAs.
@ -101,12 +102,38 @@ type envConfig struct {
CSMEnabled bool CSMEnabled bool
CSMPort string CSMPort string
CSMClientID string CSMClientID string
CSMHost string
// Enables endpoint discovery via environment variables.
//
// AWS_ENABLE_ENDPOINT_DISCOVERY=true
EnableEndpointDiscovery *bool
enableEndpointDiscovery string
// Specifies the WebIdentity token the SDK should use to assume a role
// with.
//
// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
WebIdentityTokenFilePath string
// Specifies the IAM role arn to use when assuming an role.
//
// AWS_ROLE_ARN=role_arn
RoleARN string
// Specifies the IAM role session name to use when assuming a role.
//
// AWS_ROLE_SESSION_NAME=session_name
RoleSessionName string
} }
var ( var (
csmEnabledEnvKey = []string{ csmEnabledEnvKey = []string{
"AWS_CSM_ENABLED", "AWS_CSM_ENABLED",
} }
csmHostEnvKey = []string{
"AWS_CSM_HOST",
}
csmPortEnvKey = []string{ csmPortEnvKey = []string{
"AWS_CSM_PORT", "AWS_CSM_PORT",
} }
@ -125,6 +152,10 @@ var (
"AWS_SESSION_TOKEN", "AWS_SESSION_TOKEN",
} }
enableEndpointDiscoveryEnvKey = []string{
"AWS_ENABLE_ENDPOINT_DISCOVERY",
}
regionEnvKeys = []string{ regionEnvKeys = []string{
"AWS_REGION", "AWS_REGION",
"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set "AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set
@ -139,6 +170,15 @@ var (
sharedConfigFileEnvKey = []string{ sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE", "AWS_CONFIG_FILE",
} }
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
roleARNEnvKey = []string{
"AWS_ROLE_ARN",
}
roleSessionNameEnvKey = []string{
"AWS_ROLE_SESSION_NAME",
}
) )
// loadEnvConfig retrieves the SDK's environment configuration. // loadEnvConfig retrieves the SDK's environment configuration.
@ -167,23 +207,31 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
cfg.EnableSharedConfig = enableSharedConfig cfg.EnableSharedConfig = enableSharedConfig
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey) // Static environment credentials
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey) var creds credentials.Value
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey) setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
if creds.HasKeys() {
// Require logical grouping of credentials
creds.ProviderName = EnvProviderName
cfg.Creds = creds
}
// Role Metadata
setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)
// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)
// CSM environment variables // CSM environment variables
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey) setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey) setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey) setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0 cfg.CSMEnabled = len(cfg.csmEnabled) > 0
// Require logical grouping of credentials
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
cfg.Creds = credentials.Value{}
} else {
cfg.Creds.ProviderName = EnvProviderName
}
regionKeys := regionEnvKeys regionKeys := regionEnvKeys
profileKeys := profileEnvKeys profileKeys := profileEnvKeys
if !cfg.EnableSharedConfig { if !cfg.EnableSharedConfig {
@ -194,6 +242,12 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
setFromEnvVal(&cfg.Region, regionKeys) setFromEnvVal(&cfg.Region, regionKeys)
setFromEnvVal(&cfg.Profile, profileKeys) setFromEnvVal(&cfg.Profile, profileKeys)
// endpoint discovery is in reference to it being enabled.
setFromEnvVal(&cfg.enableEndpointDiscovery, enableEndpointDiscoveryEnvKey)
if len(cfg.enableEndpointDiscovery) > 0 {
cfg.EnableEndpointDiscovery = aws.Bool(cfg.enableEndpointDiscovery != "false")
}
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey) setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey) setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)

@ -8,19 +8,36 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/csm" "github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
const (
// ErrCodeSharedConfig represents an error that occurs in the shared
// configuration logic
ErrCodeSharedConfig = "SharedConfigErr"
)
// ErrSharedConfigSourceCollision will be returned if a section contains both
// source_profile and credential_source
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only source profile or credential source can be specified, not both", nil)
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
// variables are empty and Environment was set as the credential source
var ErrSharedConfigECSContainerEnvVarEmpty = awserr.New(ErrCodeSharedConfig, "EcsContainer was specified as the credential_source, but 'AWS_CONTAINER_CREDENTIALS_RELATIVE_URI' was not set", nil)
// ErrSharedConfigInvalidCredSource will be returned if an invalid credential source was provided
var ErrSharedConfigInvalidCredSource = awserr.New(ErrCodeSharedConfig, "credential source values must be EcsContainer, Ec2InstanceMetadata, or Environment", nil)
// A Session provides a central location to create service clients from and // A Session provides a central location to create service clients from and
// store configurations and request handlers for those services. // store configurations and request handlers for those services.
// //
@ -88,7 +105,15 @@ func New(cfgs ...*aws.Config) *Session {
s := deprecatedNewSession(cfgs...) s := deprecatedNewSession(cfgs...)
if envCfg.CSMEnabled { if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger) err := enableCSM(&s.Handlers, envCfg.CSMClientID,
envCfg.CSMHost, envCfg.CSMPort, s.Config.Logger)
if err != nil {
err = fmt.Errorf("failed to enable CSM, %v", err)
s.Config.Logger.Log("ERROR:", err.Error())
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
}
} }
return s return s
@ -191,6 +216,12 @@ type Options struct {
// the config enables assume role wit MFA via the mfa_serial field. // the config enables assume role wit MFA via the mfa_serial field.
AssumeRoleTokenProvider func() (string, error) AssumeRoleTokenProvider func() (string, error)
// When the SDK's shared config is configured to assume a role this option
// may be provided to set the expiry duration of the STS credentials.
// Defaults to 15 minutes if not set as documented in the
// stscreds.AssumeRoleProvider.
AssumeRoleDuration time.Duration
// Reader for a custom Credentials Authority (CA) bundle in PEM format that // Reader for a custom Credentials Authority (CA) bundle in PEM format that
// the SDK will use instead of the default system's root CA bundle. Use this // the SDK will use instead of the default system's root CA bundle. Use this
// only if you want to replace the CA bundle the SDK uses for TLS requests. // only if you want to replace the CA bundle the SDK uses for TLS requests.
@ -205,6 +236,12 @@ type Options struct {
// to also enable this feature. CustomCABundle session option field has priority // to also enable this feature. CustomCABundle session option field has priority
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set. // over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
CustomCABundle io.Reader CustomCABundle io.Reader
// The handlers that the session and all API clients will be created with.
// This must be a complete set of handlers. Use the defaults.Handlers()
// function to initialize this value before changing the handlers to be
// used by the SDK.
Handlers request.Handlers
} }
// NewSessionWithOptions returns a new Session created from SDK defaults, config files, // NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@ -310,27 +347,36 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
return s return s
} }
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) { func enableCSM(handlers *request.Handlers,
clientID, host, port string,
logger aws.Logger,
) error {
if logger != nil {
logger.Log("Enabling CSM") logger.Log("Enabling CSM")
if len(port) == 0 {
port = csm.DefaultPort
} }
r, err := csm.Start(clientID, "127.0.0.1:"+port) r, err := csm.Start(clientID, csm.AddressWithDefaults(host, port))
if err != nil { if err != nil {
return return err
} }
r.InjectHandlers(handlers) r.InjectHandlers(handlers)
return nil
} }
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) { func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
cfg := defaults.Config() cfg := defaults.Config()
handlers := defaults.Handlers()
handlers := opts.Handlers
if handlers.IsEmpty() {
handlers = defaults.Handlers()
}
// Get a merged version of the user provided config to determine if // Get a merged version of the user provided config to determine if
// credentials were. // credentials were.
userCfg := &aws.Config{} userCfg := &aws.Config{}
userCfg.MergeIn(cfgs...) userCfg.MergeIn(cfgs...)
cfg.MergeIn(userCfg)
// Ordered config files will be loaded in with later files overwriting // Ordered config files will be loaded in with later files overwriting
// previous config file values. // previous config file values.
@ -347,10 +393,12 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
} }
// Load additional config from file(s) // Load additional config from file(s)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles) sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
if err != nil { if err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return nil, err return nil, err
} }
}
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil { if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
return nil, err return nil, err
@ -363,7 +411,11 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
initHandlers(s) initHandlers(s)
if envCfg.CSMEnabled { if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger) err := enableCSM(&s.Handlers, envCfg.CSMClientID,
envCfg.CSMHost, envCfg.CSMPort, s.Config.Logger)
if err != nil {
return nil, err
}
} }
// Setup HTTP client with custom cert bundle if enabled // Setup HTTP client with custom cert bundle if enabled
@ -388,7 +440,10 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
} }
} }
if t == nil { if t == nil {
t = &http.Transport{} // Nil transport implies `http.DefaultTransport` should be used. Since
// the SDK cannot modify, nor copy the `DefaultTransport` specifying
// the values the next closest behavior.
t = getCABundleTransport()
} }
p, err := loadCertPool(bundle) p, err := loadCertPool(bundle)
@ -421,9 +476,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) {
return p, nil return p, nil
} }
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error { func mergeConfigSrcs(cfg, userCfg *aws.Config,
// Merge in user provided configuration envCfg envConfig, sharedCfg sharedConfig,
cfg.MergeIn(userCfg) handlers request.Handlers,
sessOpts Options,
) error {
// Region if not already set by user // Region if not already set by user
if len(aws.StringValue(cfg.Region)) == 0 { if len(aws.StringValue(cfg.Region)) == 0 {
@ -434,103 +491,27 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
} }
} }
// Configure credentials if not already set if cfg.EnableEndpointDiscovery == nil {
if envCfg.EnableEndpointDiscovery != nil {
cfg.WithEndpointDiscovery(*envCfg.EnableEndpointDiscovery)
} else if envCfg.EnableSharedConfig && sharedCfg.EnableEndpointDiscovery != nil {
cfg.WithEndpointDiscovery(*sharedCfg.EnableEndpointDiscovery)
}
}
// Configure credentials if not already set by the user when creating the
// Session.
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil { if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
if len(envCfg.Creds.AccessKeyID) > 0 { creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
cfg.Credentials = credentials.NewStaticCredentialsFromCreds( if err != nil {
envCfg.Creds, return err
)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.AssumeRoleSource.Creds,
)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = stscreds.NewCredentials(
&Session{
Config: &cfgCp,
Handlers: handlers.Copy(),
},
sharedCfg.AssumeRole.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
// Assume role with external ID
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
)
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
} else {
// Fallback to default credentials provider, include mock errors
// for the credential chain so user can identify why credentials
// failed to be retrieved.
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
} }
cfg.Credentials = creds
} }
return nil return nil
} }
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the
// MFAToken option is not set when shared config is configured load assume a
// role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
var emptyCreds = credentials.Value{}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}
func initHandlers(s *Session) { func initHandlers(s *Session) {
// Add the Validate parameter handler if it is not disabled. // Add the Validate parameter handler if it is not disabled.
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler) s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler)

@ -2,11 +2,10 @@ package session
import ( import (
"fmt" "fmt"
"io/ioutil"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/go-ini/ini" "github.com/aws/aws-sdk-go/internal/ini"
) )
const ( const (
@ -17,7 +16,8 @@ const (
// Assume Role Credentials group // Assume Role Credentials group
roleArnKey = `role_arn` // group required roleArnKey = `role_arn` // group required
sourceProfileKey = `source_profile` // group required sourceProfileKey = `source_profile` // group required (or credential_source)
credentialSourceKey = `credential_source` // group required (or source_profile)
externalIDKey = `external_id` // optional externalIDKey = `external_id` // optional
mfaSerialKey = `mfa_serial` // optional mfaSerialKey = `mfa_serial` // optional
roleSessionNameKey = `role_session_name` // optional roleSessionNameKey = `role_session_name` // optional
@ -25,59 +25,76 @@ const (
// Additional Config fields // Additional Config fields
regionKey = `region` regionKey = `region`
// endpoint discovery group
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
// External Credential Process
credentialProcessKey = `credential_process` // optional
// Web Identity Token File
webIdentityTokenFileKey = `web_identity_token_file` // optional
// DefaultSharedConfigProfile is the default profile to be used when // DefaultSharedConfigProfile is the default profile to be used when
// loading configuration from the config files if another profile name // loading configuration from the config files if another profile name
// is not provided. // is not provided.
DefaultSharedConfigProfile = `default` DefaultSharedConfigProfile = `default`
) )
type assumeRoleConfig struct {
RoleARN string
SourceProfile string
ExternalID string
MFASerial string
RoleSessionName string
}
// sharedConfig represents the configuration fields of the SDK config files. // sharedConfig represents the configuration fields of the SDK config files.
type sharedConfig struct { type sharedConfig struct {
// Credentials values from the config file. Both aws_access_key_id // Credentials values from the config file. Both aws_access_key_id and
// and aws_secret_access_key must be provided together in the same file // aws_secret_access_key must be provided together in the same file to be
// to be considered valid. The values will be ignored if not a complete group. // considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of the // aws_session_token is an optional field that can be provided if both of
// other two fields are also provided. // the other two fields are also provided.
// //
// aws_access_key_id // aws_access_key_id
// aws_secret_access_key // aws_secret_access_key
// aws_session_token // aws_session_token
Creds credentials.Value Creds credentials.Value
AssumeRole assumeRoleConfig CredentialSource string
AssumeRoleSource *sharedConfig CredentialProcess string
WebIdentityTokenFile string
// Region is the region the SDK should use for looking up AWS service endpoints RoleARN string
// and signing requests. RoleSessionName string
ExternalID string
MFASerial string
SourceProfileName string
SourceProfile *sharedConfig
// Region is the region the SDK should use for looking up AWS service
// endpoints and signing requests.
// //
// region // region
Region string Region string
// EnableEndpointDiscovery can be enabled in the shared config by setting
// endpoint_discovery_enabled to true
//
// endpoint_discovery_enabled = true
EnableEndpointDiscovery *bool
} }
type sharedConfigFile struct { type sharedConfigFile struct {
Filename string Filename string
IniData *ini.File IniData ini.Sections
} }
// loadSharedConfig retrieves the configuration from the list of files // loadSharedConfig retrieves the configuration from the list of files using
// using the profile provided. The order the files are listed will determine // the profile provided. The order the files are listed will determine
// precedence. Values in subsequent files will overwrite values defined in // precedence. Values in subsequent files will overwrite values defined in
// earlier files. // earlier files.
// //
// For example, given two files A and B. Both define credentials. If the order // For example, given two files A and B. Both define credentials. If the order
// of the files are A then B, B's credential values will be used instead of A's. // of the files are A then B, B's credential values will be used instead of
// A's.
// //
// See sharedConfig.setFromFile for information how the config files // See sharedConfig.setFromFile for information how the config files
// will be loaded. // will be loaded.
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) { func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
if len(profile) == 0 { if len(profile) == 0 {
profile = DefaultSharedConfigProfile profile = DefaultSharedConfigProfile
} }
@ -88,16 +105,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
} }
cfg := sharedConfig{} cfg := sharedConfig{}
if err = cfg.setFromIniFiles(profile, files); err != nil { profiles := map[string]struct{}{}
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
return sharedConfig{}, err return sharedConfig{}, err
} }
if len(cfg.AssumeRole.SourceProfile) > 0 {
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
return sharedConfig{}, err
}
}
return cfg, nil return cfg, nil
} }
@ -105,114 +117,237 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
files := make([]sharedConfigFile, 0, len(filenames)) files := make([]sharedConfigFile, 0, len(filenames))
for _, filename := range filenames { for _, filename := range filenames {
b, err := ioutil.ReadFile(filename) sections, err := ini.OpenFile(filename)
if err != nil { if aerr, ok := err.(awserr.Error); ok && aerr.Code() == ini.ErrCodeUnableToReadFile {
// Skip files which can't be opened and read for whatever reason // Skip files which can't be opened and read for whatever reason
continue continue
} } else if err != nil {
f, err := ini.Load(b)
if err != nil {
return nil, SharedConfigLoadError{Filename: filename, Err: err} return nil, SharedConfigLoadError{Filename: filename, Err: err}
} }
files = append(files, sharedConfigFile{ files = append(files, sharedConfigFile{
Filename: filename, IniData: f, Filename: filename, IniData: sections,
}) })
} }
return files, nil return files, nil
} }
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error { func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
var assumeRoleSrc sharedConfig
// Multiple level assume role chains are not support
if cfg.AssumeRole.SourceProfile == origProfile {
assumeRoleSrc = *cfg
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
} else {
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
if err != nil {
return err
}
}
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
}
cfg.AssumeRoleSource = &assumeRoleSrc
return nil
}
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
// Trim files from the list that don't exist. // Trim files from the list that don't exist.
var skippedFiles int
var profileNotFoundErr error
for _, f := range files { for _, f := range files {
if err := cfg.setFromIniFile(profile, f); err != nil { if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); ok { if _, ok := err.(SharedConfigProfileNotExistsError); ok {
// Ignore proviles missings // Ignore profiles not defined in individual files.
profileNotFoundErr = err
skippedFiles++
continue continue
} }
return err return err
} }
} }
if skippedFiles == len(files) {
// If all files were skipped because the profile is not found, return
// the original profile not found error.
return profileNotFoundErr
}
if _, ok := profiles[profile]; ok {
// if this is the second instance of the profile the Assume Role
// options must be cleared because they are only valid for the
// first reference of a profile. The self linked instance of the
// profile only have credential provider options.
cfg.clearAssumeRoleOptions()
} else {
// First time a profile has been seen, It must either be a assume role
// or credentials. Assert if the credential type requires a role ARN,
// the ARN is also set.
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
return err
}
}
profiles[profile] = struct{}{}
if err := cfg.validateCredentialType(); err != nil {
return err
}
// Link source profiles for assume roles
if len(cfg.SourceProfileName) != 0 {
// Linked profile via source_profile ignore credential provider
// options, the source profile must provide the credentials.
cfg.clearCredentialOptions()
srcCfg := &sharedConfig{}
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
if err != nil {
// SourceProfile that doesn't exist is an error in configuration.
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
err = SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
return err
}
if !srcCfg.hasCredentials() {
return SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
cfg.SourceProfile = srcCfg
}
return nil return nil
} }
// setFromFile loads the configuration from the file using // setFromFile loads the configuration from the file using the profile
// the profile provided. A sharedConfig pointer type value is used so that // provided. A sharedConfig pointer type value is used so that multiple config
// multiple config file loadings can be chained. // file loadings can be chained.
// //
// Only loads complete logically grouped values, and will not set fields in cfg // Only loads complete logically grouped values, and will not set fields in cfg
// for incomplete grouped values in the config. Such as credentials. For example // for incomplete grouped values in the config. Such as credentials. For
// if a config file only includes aws_access_key_id but no aws_secret_access_key // example if a config file only includes aws_access_key_id but no
// the aws_access_key_id will be ignored. // aws_secret_access_key the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error { func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
section, err := file.IniData.GetSection(profile) section, ok := file.IniData.GetSection(profile)
if err != nil { if !ok {
// Fallback to to alternate profile name: profile <name> // Fallback to to alternate profile name: profile <name>
section, err = file.IniData.GetSection(fmt.Sprintf("profile %s", profile)) section, ok = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
if err != nil { if !ok {
return SharedConfigProfileNotExistsError{Profile: profile, Err: err} return SharedConfigProfileNotExistsError{Profile: profile, Err: nil}
} }
} }
if exOpts {
// Assume Role Parameters
updateString(&cfg.RoleARN, section, roleArnKey)
updateString(&cfg.ExternalID, section, externalIDKey)
updateString(&cfg.MFASerial, section, mfaSerialKey)
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
updateString(&cfg.CredentialSource, section, credentialSourceKey)
updateString(&cfg.Region, section, regionKey)
}
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
// Shared Credentials // Shared Credentials
akid := section.Key(accessKeyIDKey).String() creds := credentials.Value{
secret := section.Key(secretAccessKey).String() AccessKeyID: section.String(accessKeyIDKey),
if len(akid) > 0 && len(secret) > 0 { SecretAccessKey: section.String(secretAccessKey),
cfg.Creds = credentials.Value{ SessionToken: section.String(sessionTokenKey),
AccessKeyID: akid,
SecretAccessKey: secret,
SessionToken: section.Key(sessionTokenKey).String(),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename), ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
} }
if creds.HasKeys() {
cfg.Creds = creds
} }
// Assume Role // Endpoint discovery
roleArn := section.Key(roleArnKey).String() if section.Has(enableEndpointDiscoveryKey) {
srcProfile := section.Key(sourceProfileKey).String() v := section.Bool(enableEndpointDiscoveryKey)
if len(roleArn) > 0 && len(srcProfile) > 0 { cfg.EnableEndpointDiscovery = &v
cfg.AssumeRole = assumeRoleConfig{
RoleARN: roleArn,
SourceProfile: srcProfile,
ExternalID: section.Key(externalIDKey).String(),
MFASerial: section.Key(mfaSerialKey).String(),
RoleSessionName: section.Key(roleSessionNameKey).String(),
}
}
// Region
if v := section.Key(regionKey).String(); len(v) > 0 {
cfg.Region = v
} }
return nil return nil
} }
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
var credSource string
switch {
case len(cfg.SourceProfileName) != 0:
credSource = sourceProfileKey
case len(cfg.CredentialSource) != 0:
credSource = credentialSourceKey
case len(cfg.WebIdentityTokenFile) != 0:
credSource = webIdentityTokenFileKey
}
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
return CredentialRequiresARNError{
Type: credSource,
Profile: profile,
}
}
return nil
}
func (cfg *sharedConfig) validateCredentialType() error {
// Only one or no credential type can be defined.
if !oneOrNone(
len(cfg.SourceProfileName) != 0,
len(cfg.CredentialSource) != 0,
len(cfg.CredentialProcess) != 0,
len(cfg.WebIdentityTokenFile) != 0,
) {
return ErrSharedConfigSourceCollision
}
return nil
}
func (cfg *sharedConfig) hasCredentials() bool {
switch {
case len(cfg.SourceProfileName) != 0:
case len(cfg.CredentialSource) != 0:
case len(cfg.CredentialProcess) != 0:
case len(cfg.WebIdentityTokenFile) != 0:
case cfg.Creds.HasKeys():
default:
return false
}
return true
}
func (cfg *sharedConfig) clearCredentialOptions() {
cfg.CredentialSource = ""
cfg.CredentialProcess = ""
cfg.WebIdentityTokenFile = ""
cfg.Creds = credentials.Value{}
}
func (cfg *sharedConfig) clearAssumeRoleOptions() {
cfg.RoleARN = ""
cfg.ExternalID = ""
cfg.MFASerial = ""
cfg.RoleSessionName = ""
cfg.SourceProfileName = ""
}
func oneOrNone(bs ...bool) bool {
var count int
for _, b := range bs {
if b {
count++
if count > 1 {
return false
}
}
}
return true
}
// updateString will only update the dst with the value in the section key, key
// is present in the section.
func updateString(dst *string, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.String(key)
}
// SharedConfigLoadError is an error for the shared config file failed to load. // SharedConfigLoadError is an error for the shared config file failed to load.
type SharedConfigLoadError struct { type SharedConfigLoadError struct {
Filename string Filename string
@ -271,6 +406,7 @@ func (e SharedConfigProfileNotExistsError) Error() string {
// or not complete. // or not complete.
type SharedConfigAssumeRoleError struct { type SharedConfigAssumeRoleError struct {
RoleARN string RoleARN string
SourceProfile string
} }
// Code is the short id of the error. // Code is the short id of the error.
@ -280,8 +416,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
// Message is the description of the error // Message is the description of the error
func (e SharedConfigAssumeRoleError) Message() string { func (e SharedConfigAssumeRoleError) Message() string {
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials", return fmt.Sprintf(
e.RoleARN) "failed to load assume role for %s, source profile %s has no shared credentials",
e.RoleARN, e.SourceProfile,
)
} }
// OrigErr is the underlying error that caused the failure. // OrigErr is the underlying error that caused the failure.
@ -293,3 +431,36 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
func (e SharedConfigAssumeRoleError) Error() string { func (e SharedConfigAssumeRoleError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil) return awserr.SprintError(e.Code(), e.Message(), "", nil)
} }
// CredentialRequiresARNError provides the error for shared config credentials
// that are incorrectly configured in the shared config or credentials file.
type CredentialRequiresARNError struct {
// type of credentials that were configured.
Type string
// Profile name the credentials were in.
Profile string
}
// Code is the short id of the error.
func (e CredentialRequiresARNError) Code() string {
return "CredentialRequiresARNError"
}
// Message is the description of the error
func (e CredentialRequiresARNError) Message() string {
return fmt.Sprintf(
"credential type %s requires role_arn, profile %s",
e.Type, e.Profile,
)
}
// OrigErr is the underlying error that caused the failure.
func (e CredentialRequiresARNError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e CredentialRequiresARNError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}

@ -134,6 +134,7 @@ var requiredSignedHeaders = rules{
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{}, "X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{}, "X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{}, "X-Amz-Storage-Class": struct{}{},
"X-Amz-Tagging": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{}, "X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{}, "X-Amz-Content-Sha256": struct{}{},
}, },
@ -181,7 +182,7 @@ type Signer struct {
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html // http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool DisableURIPathEscaping bool
// Disales the automatical setting of the HTTP request's Body field with the // Disables the automatical setting of the HTTP request's Body field with the
// io.ReadSeeker passed in to the signer. This is useful if you're using a // io.ReadSeeker passed in to the signer. This is useful if you're using a
// custom wrapper around the body for the io.ReadSeeker and want to preserve // custom wrapper around the body for the io.ReadSeeker and want to preserve
// the Body value on the Request.Body. // the Body value on the Request.Body.
@ -421,7 +422,7 @@ var SignRequestHandler = request.NamedHandler{
// If the credentials of the request's config are set to // If the credentials of the request's config are set to
// credentials.AnonymousCredentials the request will not be signed. // credentials.AnonymousCredentials the request will not be signed.
func SignSDKRequest(req *request.Request) { func SignSDKRequest(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now) SignSDKRequestWithCurrentTime(req, time.Now)
} }
// BuildNamedHandler will build a generic handler for signing. // BuildNamedHandler will build a generic handler for signing.
@ -429,12 +430,15 @@ func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler
return request.NamedHandler{ return request.NamedHandler{
Name: name, Name: name,
Fn: func(req *request.Request) { Fn: func(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now, opts...) SignSDKRequestWithCurrentTime(req, time.Now, opts...)
}, },
} }
} }
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) { // SignSDKRequestWithCurrentTime will sign the SDK's request using the time
// function passed in. Behaves the same as SignSDKRequest with the exception
// the request is signed with the value returned by the current time function.
func SignSDKRequestWithCurrentTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
// If the request does not need to be signed ignore the signing of the // If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used. // request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials { if req.Config.Credentials == credentials.AnonymousCredentials {
@ -470,13 +474,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
opt(v4) opt(v4)
} }
signingTime := req.Time curTime := curTimeFn()
if !req.LastSignedAt.IsZero() {
signingTime = req.LastSignedAt
}
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(), signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime, name, region, req.ExpireTime, req.ExpireTime > 0, curTime,
) )
if err != nil { if err != nil {
req.Error = err req.Error = err
@ -485,7 +485,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
} }
req.SignedHeaderVals = signedHeaders req.SignedHeaderVals = signedHeaders
req.LastSignedAt = curTimeFn() req.LastSignedAt = curTime
} }
const logSignInfoMsg = `DEBUG: Request Signature: const logSignInfoMsg = `DEBUG: Request Signature:
@ -739,14 +739,22 @@ func makeSha256Reader(reader io.ReadSeeker) []byte {
start, _ := reader.Seek(0, sdkio.SeekCurrent) start, _ := reader.Seek(0, sdkio.SeekCurrent)
defer reader.Seek(start, sdkio.SeekStart) defer reader.Seek(start, sdkio.SeekStart)
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := aws.SeekerLen(reader)
if err != nil {
io.Copy(hash, reader) io.Copy(hash, reader)
} else {
io.CopyN(hash, reader, size)
}
return hash.Sum(nil) return hash.Sum(nil)
} }
const doubleSpace = " " const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not // stripExcessSpaces will rewrite the passed in slice's string values to not
// contain muliple side-by-side spaces. // contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) { func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int var j, k, l, m, spaces int
for i, str := range vals { for i, str := range vals {

@ -7,13 +7,18 @@ import (
"github.com/aws/aws-sdk-go/internal/sdkio" "github.com/aws/aws-sdk-go/internal/sdkio"
) )
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should // ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
// only be used with an io.Reader that is also an io.Seeker. Doing so may // SDK to accept an io.Reader that is not also an io.Seeker for unsigned
// cause request signature errors, or request body's not sent for GET, HEAD // streaming payload API operations.
// and DELETE HTTP methods.
// //
// Deprecated: Should only be used with io.ReadSeeker. If using for // A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
// S3 PutObject to stream content use s3manager.Uploader instead. // operation's input will prevent that operation being retried in the case of
// network errors, and cause operation requests to fail if the operation
// requires payload signing.
//
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
// Upload manager (s3manager.Uploader) provides support for streaming with the
// ability to retry network errors.
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser { func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r} return ReaderSeekerCloser{r}
} }
@ -43,7 +48,8 @@ func IsReaderSeekable(r io.Reader) bool {
// Read reads from the reader up to size of p. The number of bytes read, and // Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned. // error if it occurred will be returned.
// //
// If the reader is not an io.Reader zero bytes read, and nil error will be returned. // If the reader is not an io.Reader zero bytes read, and nil error will be
// returned.
// //
// Performs the same functionality as io.Reader Read // Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) { func (r ReaderSeekerCloser) Read(p []byte) (int, error) {

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go" const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK // SDKVersion is the version of this SDK
const SDKVersion = "1.15.24" const SDKVersion = "1.21.1"

120
vendor/github.com/aws/aws-sdk-go/internal/ini/ast.go generated vendored Normal file

@ -0,0 +1,120 @@
package ini
// ASTKind represents different states in the parse table
// and the type of AST that is being constructed
type ASTKind int
// ASTKind* is used in the parse table to transition between
// the different states
const (
ASTKindNone = ASTKind(iota)
ASTKindStart
ASTKindExpr
ASTKindEqualExpr
ASTKindStatement
ASTKindSkipStatement
ASTKindExprStatement
ASTKindSectionStatement
ASTKindNestedSectionStatement
ASTKindCompletedNestedSectionStatement
ASTKindCommentStatement
ASTKindCompletedSectionStatement
)
func (k ASTKind) String() string {
switch k {
case ASTKindNone:
return "none"
case ASTKindStart:
return "start"
case ASTKindExpr:
return "expr"
case ASTKindStatement:
return "stmt"
case ASTKindSectionStatement:
return "section_stmt"
case ASTKindExprStatement:
return "expr_stmt"
case ASTKindCommentStatement:
return "comment"
case ASTKindNestedSectionStatement:
return "nested_section_stmt"
case ASTKindCompletedSectionStatement:
return "completed_stmt"
case ASTKindSkipStatement:
return "skip"
default:
return ""
}
}
// AST interface allows us to determine what kind of node we
// are on and casting may not need to be necessary.
//
// The root is always the first node in Children
type AST struct {
Kind ASTKind
Root Token
RootToken bool
Children []AST
}
func newAST(kind ASTKind, root AST, children ...AST) AST {
return AST{
Kind: kind,
Children: append([]AST{root}, children...),
}
}
func newASTWithRootToken(kind ASTKind, root Token, children ...AST) AST {
return AST{
Kind: kind,
Root: root,
RootToken: true,
Children: children,
}
}
// AppendChild will append to the list of children an AST has.
func (a *AST) AppendChild(child AST) {
a.Children = append(a.Children, child)
}
// GetRoot will return the root AST which can be the first entry
// in the children list or a token.
func (a *AST) GetRoot() AST {
if a.RootToken {
return *a
}
if len(a.Children) == 0 {
return AST{}
}
return a.Children[0]
}
// GetChildren will return the current AST's list of children
func (a *AST) GetChildren() []AST {
if len(a.Children) == 0 {
return []AST{}
}
if a.RootToken {
return a.Children
}
return a.Children[1:]
}
// SetChildren will set and override all children of the AST.
func (a *AST) SetChildren(children []AST) {
if a.RootToken {
a.Children = children
} else {
a.Children = append(a.Children[:1], children...)
}
}
// Start is used to indicate the starting state of the parse table.
var Start = newAST(ASTKindStart, AST{})

@ -0,0 +1,11 @@
package ini
var commaRunes = []rune(",")
func isComma(b rune) bool {
return b == ','
}
func newCommaToken() Token {
return newToken(TokenComma, commaRunes, NoneType)
}

@ -0,0 +1,35 @@
package ini
// isComment will return whether or not the next byte(s) is a
// comment.
func isComment(b []rune) bool {
if len(b) == 0 {
return false
}
switch b[0] {
case ';':
return true
case '#':
return true
}
return false
}
// newCommentToken will create a comment token and
// return how many bytes were read.
func newCommentToken(b []rune) (Token, int, error) {
i := 0
for ; i < len(b); i++ {
if b[i] == '\n' {
break
}
if len(b)-i > 2 && b[i] == '\r' && b[i+1] == '\n' {
break
}
}
return newToken(TokenComment, b[:i], NoneType), i, nil
}

29
vendor/github.com/aws/aws-sdk-go/internal/ini/doc.go generated vendored Normal file

@ -0,0 +1,29 @@
// Package ini is an LL(1) parser for configuration files.
//
// Example:
// sections, err := ini.OpenFile("/path/to/file")
// if err != nil {
// panic(err)
// }
//
// profile := "foo"
// section, ok := sections.GetSection(profile)
// if !ok {
// fmt.Printf("section %q could not be found", profile)
// }
//
// Below is the BNF that describes this parser
// Grammar:
// stmt -> value stmt'
// stmt' -> epsilon | op stmt
// value -> number | string | boolean | quoted_string
//
// section -> [ section'
// section' -> value section_close
// section_close -> ]
//
// SkipState will skip (NL WS)+
//
// comment -> # comment' | ; comment'
// comment' -> epsilon | value
package ini

@ -0,0 +1,4 @@
package ini
// emptyToken is used to satisfy the Token interface
var emptyToken = newToken(TokenNone, []rune{}, NoneType)

@ -0,0 +1,24 @@
package ini
// newExpression will return an expression AST.
// Expr represents an expression
//
// grammar:
// expr -> string | number
func newExpression(tok Token) AST {
return newASTWithRootToken(ASTKindExpr, tok)
}
func newEqualExpr(left AST, tok Token) AST {
return newASTWithRootToken(ASTKindEqualExpr, tok, left)
}
// EqualExprKey will return a LHS value in the equal expr
func EqualExprKey(ast AST) string {
children := ast.GetChildren()
if len(children) == 0 || ast.Kind != ASTKindEqualExpr {
return ""
}
return string(children[0].Root.Raw())
}

17
vendor/github.com/aws/aws-sdk-go/internal/ini/fuzz.go generated vendored Normal file

@ -0,0 +1,17 @@
// +build gofuzz
package ini
import (
"bytes"
)
func Fuzz(data []byte) int {
b := bytes.NewReader(data)
if _, err := Parse(b); err != nil {
return 0
}
return 1
}

51
vendor/github.com/aws/aws-sdk-go/internal/ini/ini.go generated vendored Normal file

@ -0,0 +1,51 @@
package ini
import (
"io"
"os"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// OpenFile takes a path to a given file, and will open and parse
// that file.
func OpenFile(path string) (Sections, error) {
f, err := os.Open(path)
if err != nil {
return Sections{}, awserr.New(ErrCodeUnableToReadFile, "unable to open file", err)
}
defer f.Close()
return Parse(f)
}
// Parse will parse the given file using the shared config
// visitor.
func Parse(f io.Reader) (Sections, error) {
tree, err := ParseAST(f)
if err != nil {
return Sections{}, err
}
v := NewDefaultVisitor()
if err = Walk(tree, v); err != nil {
return Sections{}, err
}
return v.Sections, nil
}
// ParseBytes will parse the given bytes and return the parsed sections.
func ParseBytes(b []byte) (Sections, error) {
tree, err := ParseASTBytes(b)
if err != nil {
return Sections{}, err
}
v := NewDefaultVisitor()
if err = Walk(tree, v); err != nil {
return Sections{}, err
}
return v.Sections, nil
}

@ -0,0 +1,165 @@
package ini
import (
"bytes"
"io"
"io/ioutil"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const (
// ErrCodeUnableToReadFile is used when a file is failed to be
// opened or read from.
ErrCodeUnableToReadFile = "FailedRead"
)
// TokenType represents the various different tokens types
type TokenType int
func (t TokenType) String() string {
switch t {
case TokenNone:
return "none"
case TokenLit:
return "literal"
case TokenSep:
return "sep"
case TokenOp:
return "op"
case TokenWS:
return "ws"
case TokenNL:
return "newline"
case TokenComment:
return "comment"
case TokenComma:
return "comma"
default:
return ""
}
}
// TokenType enums
const (
TokenNone = TokenType(iota)
TokenLit
TokenSep
TokenComma
TokenOp
TokenWS
TokenNL
TokenComment
)
type iniLexer struct{}
// Tokenize will return a list of tokens during lexical analysis of the
// io.Reader.
func (l *iniLexer) Tokenize(r io.Reader) ([]Token, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, awserr.New(ErrCodeUnableToReadFile, "unable to read file", err)
}
return l.tokenize(b)
}
func (l *iniLexer) tokenize(b []byte) ([]Token, error) {
runes := bytes.Runes(b)
var err error
n := 0
tokenAmount := countTokens(runes)
tokens := make([]Token, tokenAmount)
count := 0
for len(runes) > 0 && count < tokenAmount {
switch {
case isWhitespace(runes[0]):
tokens[count], n, err = newWSToken(runes)
case isComma(runes[0]):
tokens[count], n = newCommaToken(), 1
case isComment(runes):
tokens[count], n, err = newCommentToken(runes)
case isNewline(runes):
tokens[count], n, err = newNewlineToken(runes)
case isSep(runes):
tokens[count], n, err = newSepToken(runes)
case isOp(runes):
tokens[count], n, err = newOpToken(runes)
default:
tokens[count], n, err = newLitToken(runes)
}
if err != nil {
return nil, err
}
count++
runes = runes[n:]
}
return tokens[:count], nil
}
func countTokens(runes []rune) int {
count, n := 0, 0
var err error
for len(runes) > 0 {
switch {
case isWhitespace(runes[0]):
_, n, err = newWSToken(runes)
case isComma(runes[0]):
_, n = newCommaToken(), 1
case isComment(runes):
_, n, err = newCommentToken(runes)
case isNewline(runes):
_, n, err = newNewlineToken(runes)
case isSep(runes):
_, n, err = newSepToken(runes)
case isOp(runes):
_, n, err = newOpToken(runes)
default:
_, n, err = newLitToken(runes)
}
if err != nil {
return 0
}
count++
runes = runes[n:]
}
return count + 1
}
// Token indicates a metadata about a given value.
type Token struct {
t TokenType
ValueType ValueType
base int
raw []rune
}
var emptyValue = Value{}
func newToken(t TokenType, raw []rune, v ValueType) Token {
return Token{
t: t,
raw: raw,
ValueType: v,
}
}
// Raw return the raw runes that were consumed
func (tok Token) Raw() []rune {
return tok.raw
}
// Type returns the token type
func (tok Token) Type() TokenType {
return tok.t
}

@ -0,0 +1,349 @@
package ini
import (
"fmt"
"io"
)
// State enums for the parse table
const (
InvalidState = iota
// stmt -> value stmt'
StatementState
// stmt' -> MarkComplete | op stmt
StatementPrimeState
// value -> number | string | boolean | quoted_string
ValueState
// section -> [ section'
OpenScopeState
// section' -> value section_close
SectionState
// section_close -> ]
CloseScopeState
// SkipState will skip (NL WS)+
SkipState
// SkipTokenState will skip any token and push the previous
// state onto the stack.
SkipTokenState
// comment -> # comment' | ; comment'
// comment' -> MarkComplete | value
CommentState
// MarkComplete state will complete statements and move that
// to the completed AST list
MarkCompleteState
// TerminalState signifies that the tokens have been fully parsed
TerminalState
)
// parseTable is a state machine to dictate the grammar above.
var parseTable = map[ASTKind]map[TokenType]int{
ASTKindStart: map[TokenType]int{
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: TerminalState,
},
ASTKindCommentStatement: map[TokenType]int{
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindExpr: map[TokenType]int{
TokenOp: StatementPrimeState,
TokenLit: ValueState,
TokenSep: OpenScopeState,
TokenWS: ValueState,
TokenNL: SkipState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindEqualExpr: map[TokenType]int{
TokenLit: ValueState,
TokenWS: SkipTokenState,
TokenNL: SkipState,
},
ASTKindStatement: map[TokenType]int{
TokenLit: SectionState,
TokenSep: CloseScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindExprStatement: map[TokenType]int{
TokenLit: ValueState,
TokenSep: OpenScopeState,
TokenOp: ValueState,
TokenWS: ValueState,
TokenNL: MarkCompleteState,
TokenComment: CommentState,
TokenNone: TerminalState,
TokenComma: SkipState,
},
ASTKindSectionStatement: map[TokenType]int{
TokenLit: SectionState,
TokenOp: SectionState,
TokenSep: CloseScopeState,
TokenWS: SectionState,
TokenNL: SkipTokenState,
},
ASTKindCompletedSectionStatement: map[TokenType]int{
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindSkipStatement: map[TokenType]int{
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: TerminalState,
},
}
// ParseAST will parse input from an io.Reader using
// an LL(1) parser.
func ParseAST(r io.Reader) ([]AST, error) {
lexer := iniLexer{}
tokens, err := lexer.Tokenize(r)
if err != nil {
return []AST{}, err
}
return parse(tokens)
}
// ParseASTBytes will parse input from a byte slice using
// an LL(1) parser.
func ParseASTBytes(b []byte) ([]AST, error) {
lexer := iniLexer{}
tokens, err := lexer.tokenize(b)
if err != nil {
return []AST{}, err
}
return parse(tokens)
}
func parse(tokens []Token) ([]AST, error) {
start := Start
stack := newParseStack(3, len(tokens))
stack.Push(start)
s := newSkipper()
loop:
for stack.Len() > 0 {
k := stack.Pop()
var tok Token
if len(tokens) == 0 {
// this occurs when all the tokens have been processed
// but reduction of what's left on the stack needs to
// occur.
tok = emptyToken
} else {
tok = tokens[0]
}
step := parseTable[k.Kind][tok.Type()]
if s.ShouldSkip(tok) {
// being in a skip state with no tokens will break out of
// the parse loop since there is nothing left to process.
if len(tokens) == 0 {
break loop
}
step = SkipTokenState
}
switch step {
case TerminalState:
// Finished parsing. Push what should be the last
// statement to the stack. If there is anything left
// on the stack, an error in parsing has occurred.
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
break loop
case SkipTokenState:
// When skipping a token, the previous state was popped off the stack.
// To maintain the correct state, the previous state will be pushed
// onto the stack.
stack.Push(k)
case StatementState:
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
expr := newExpression(tok)
stack.Push(expr)
case StatementPrimeState:
if tok.Type() != TokenOp {
stack.MarkComplete(k)
continue
}
if k.Kind != ASTKindExpr {
return nil, NewParseError(
fmt.Sprintf("invalid expression: expected Expr type, but found %T type", k),
)
}
k = trimSpaces(k)
expr := newEqualExpr(k, tok)
stack.Push(expr)
case ValueState:
// ValueState requires the previous state to either be an equal expression
// or an expression statement.
//
// This grammar occurs when the RHS is a number, word, or quoted string.
// equal_expr -> lit op equal_expr'
// equal_expr' -> number | string | quoted_string
// quoted_string -> " quoted_string'
// quoted_string' -> string quoted_string_end
// quoted_string_end -> "
//
// otherwise
// expr_stmt -> equal_expr (expr_stmt')*
// expr_stmt' -> ws S | op S | MarkComplete
// S -> equal_expr' expr_stmt'
switch k.Kind {
case ASTKindEqualExpr:
// assiging a value to some key
k.AppendChild(newExpression(tok))
stack.Push(newExprStatement(k))
case ASTKindExpr:
k.Root.raw = append(k.Root.raw, tok.Raw()...)
stack.Push(k)
case ASTKindExprStatement:
root := k.GetRoot()
children := root.GetChildren()
if len(children) == 0 {
return nil, NewParseError(
fmt.Sprintf("invalid expression: AST contains no children %s", k.Kind),
)
}
rhs := children[len(children)-1]
if rhs.Root.ValueType != QuotedStringType {
rhs.Root.ValueType = StringType
rhs.Root.raw = append(rhs.Root.raw, tok.Raw()...)
}
children[len(children)-1] = rhs
k.SetChildren(children)
stack.Push(k)
}
case OpenScopeState:
if !runeCompare(tok.Raw(), openBrace) {
return nil, NewParseError("expected '['")
}
stmt := newStatement()
stack.Push(stmt)
case CloseScopeState:
if !runeCompare(tok.Raw(), closeBrace) {
return nil, NewParseError("expected ']'")
}
k = trimSpaces(k)
stack.Push(newCompletedSectionStatement(k))
case SectionState:
var stmt AST
switch k.Kind {
case ASTKindStatement:
// If there are multiple literals inside of a scope declaration,
// then the current token's raw value will be appended to the Name.
//
// This handles cases like [ profile default ]
//
// k will represent a SectionStatement with the children representing
// the label of the section
stmt = newSectionStatement(tok)
case ASTKindSectionStatement:
k.Root.raw = append(k.Root.raw, tok.Raw()...)
stmt = k
default:
return nil, NewParseError(
fmt.Sprintf("invalid statement: expected statement: %v", k.Kind),
)
}
stack.Push(stmt)
case MarkCompleteState:
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
if stack.Len() == 0 {
stack.Push(start)
}
case SkipState:
stack.Push(newSkipStatement(k))
s.Skip()
case CommentState:
if k.Kind == ASTKindStart {
stack.Push(k)
} else {
stack.MarkComplete(k)
}
stmt := newCommentStatement(tok)
stack.Push(stmt)
default:
return nil, NewParseError(
fmt.Sprintf("invalid state with ASTKind %v and TokenType %v",
k, tok.Type()))
}
if len(tokens) > 0 {
tokens = tokens[1:]
}
}
// this occurs when a statement has not been completed
if stack.top > 1 {
return nil, NewParseError(fmt.Sprintf("incomplete ini expression"))
}
// returns a sublist which excludes the start symbol
return stack.List(), nil
}
// trimSpaces will trim spaces on the left and right hand side of
// the literal.
func trimSpaces(k AST) AST {
// trim left hand side of spaces
for i := 0; i < len(k.Root.raw); i++ {
if !isWhitespace(k.Root.raw[i]) {
break
}
k.Root.raw = k.Root.raw[1:]
i--
}
// trim right hand side of spaces
for i := len(k.Root.raw) - 1; i >= 0; i-- {
if !isWhitespace(k.Root.raw[i]) {
break
}
k.Root.raw = k.Root.raw[:len(k.Root.raw)-1]
}
return k
}

@ -0,0 +1,324 @@
package ini
import (
"fmt"
"strconv"
"strings"
)
var (
runesTrue = []rune("true")
runesFalse = []rune("false")
)
var literalValues = [][]rune{
runesTrue,
runesFalse,
}
func isBoolValue(b []rune) bool {
for _, lv := range literalValues {
if isLitValue(lv, b) {
return true
}
}
return false
}
func isLitValue(want, have []rune) bool {
if len(have) < len(want) {
return false
}
for i := 0; i < len(want); i++ {
if want[i] != have[i] {
return false
}
}
return true
}
// isNumberValue will return whether not the leading characters in
// a byte slice is a number. A number is delimited by whitespace or
// the newline token.
//
// A number is defined to be in a binary, octal, decimal (int | float), hex format,
// or in scientific notation.
func isNumberValue(b []rune) bool {
negativeIndex := 0
helper := numberHelper{}
needDigit := false
for i := 0; i < len(b); i++ {
negativeIndex++
switch b[i] {
case '-':
if helper.IsNegative() || negativeIndex != 1 {
return false
}
helper.Determine(b[i])
needDigit = true
continue
case 'e', 'E':
if err := helper.Determine(b[i]); err != nil {
return false
}
negativeIndex = 0
needDigit = true
continue
case 'b':
if helper.numberFormat == hex {
break
}
fallthrough
case 'o', 'x':
needDigit = true
if i == 0 {
return false
}
fallthrough
case '.':
if err := helper.Determine(b[i]); err != nil {
return false
}
needDigit = true
continue
}
if i > 0 && (isNewline(b[i:]) || isWhitespace(b[i])) {
return !needDigit
}
if !helper.CorrectByte(b[i]) {
return false
}
needDigit = false
}
return !needDigit
}
func isValid(b []rune) (bool, int, error) {
if len(b) == 0 {
// TODO: should probably return an error
return false, 0, nil
}
return isValidRune(b[0]), 1, nil
}
func isValidRune(r rune) bool {
return r != ':' && r != '=' && r != '[' && r != ']' && r != ' ' && r != '\n'
}
// ValueType is an enum that will signify what type
// the Value is
type ValueType int
func (v ValueType) String() string {
switch v {
case NoneType:
return "NONE"
case DecimalType:
return "FLOAT"
case IntegerType:
return "INT"
case StringType:
return "STRING"
case BoolType:
return "BOOL"
}
return ""
}
// ValueType enums
const (
NoneType = ValueType(iota)
DecimalType
IntegerType
StringType
QuotedStringType
BoolType
)
// Value is a union container
type Value struct {
Type ValueType
raw []rune
integer int64
decimal float64
boolean bool
str string
}
func newValue(t ValueType, base int, raw []rune) (Value, error) {
v := Value{
Type: t,
raw: raw,
}
var err error
switch t {
case DecimalType:
v.decimal, err = strconv.ParseFloat(string(raw), 64)
case IntegerType:
if base != 10 {
raw = raw[2:]
}
v.integer, err = strconv.ParseInt(string(raw), base, 64)
case StringType:
v.str = string(raw)
case QuotedStringType:
v.str = string(raw[1 : len(raw)-1])
case BoolType:
v.boolean = runeCompare(v.raw, runesTrue)
}
// issue 2253
//
// if the value trying to be parsed is too large, then we will use
// the 'StringType' and raw value instead.
if nerr, ok := err.(*strconv.NumError); ok && nerr.Err == strconv.ErrRange {
v.Type = StringType
v.str = string(raw)
err = nil
}
return v, err
}
// Append will append values and change the type to a string
// type.
func (v *Value) Append(tok Token) {
r := tok.Raw()
if v.Type != QuotedStringType {
v.Type = StringType
r = tok.raw[1 : len(tok.raw)-1]
}
if tok.Type() != TokenLit {
v.raw = append(v.raw, tok.Raw()...)
} else {
v.raw = append(v.raw, r...)
}
}
func (v Value) String() string {
switch v.Type {
case DecimalType:
return fmt.Sprintf("decimal: %f", v.decimal)
case IntegerType:
return fmt.Sprintf("integer: %d", v.integer)
case StringType:
return fmt.Sprintf("string: %s", string(v.raw))
case QuotedStringType:
return fmt.Sprintf("quoted string: %s", string(v.raw))
case BoolType:
return fmt.Sprintf("bool: %t", v.boolean)
default:
return "union not set"
}
}
func newLitToken(b []rune) (Token, int, error) {
n := 0
var err error
token := Token{}
if b[0] == '"' {
n, err = getStringValue(b)
if err != nil {
return token, n, err
}
token = newToken(TokenLit, b[:n], QuotedStringType)
} else if isNumberValue(b) {
var base int
base, n, err = getNumericalValue(b)
if err != nil {
return token, 0, err
}
value := b[:n]
vType := IntegerType
if contains(value, '.') || hasExponent(value) {
vType = DecimalType
}
token = newToken(TokenLit, value, vType)
token.base = base
} else if isBoolValue(b) {
n, err = getBoolValue(b)
token = newToken(TokenLit, b[:n], BoolType)
} else {
n, err = getValue(b)
token = newToken(TokenLit, b[:n], StringType)
}
return token, n, err
}
// IntValue returns an integer value
func (v Value) IntValue() int64 {
return v.integer
}
// FloatValue returns a float value
func (v Value) FloatValue() float64 {
return v.decimal
}
// BoolValue returns a bool value
func (v Value) BoolValue() bool {
return v.boolean
}
func isTrimmable(r rune) bool {
switch r {
case '\n', ' ':
return true
}
return false
}
// StringValue returns the string value
func (v Value) StringValue() string {
switch v.Type {
case StringType:
return strings.TrimFunc(string(v.raw), isTrimmable)
case QuotedStringType:
// preserve all characters in the quotes
return string(removeEscapedCharacters(v.raw[1 : len(v.raw)-1]))
default:
return strings.TrimFunc(string(v.raw), isTrimmable)
}
}
func contains(runes []rune, c rune) bool {
for i := 0; i < len(runes); i++ {
if runes[i] == c {
return true
}
}
return false
}
func runeCompare(v1 []rune, v2 []rune) bool {
if len(v1) != len(v2) {
return false
}
for i := 0; i < len(v1); i++ {
if v1[i] != v2[i] {
return false
}
}
return true
}

@ -0,0 +1,30 @@
package ini
func isNewline(b []rune) bool {
if len(b) == 0 {
return false
}
if b[0] == '\n' {
return true
}
if len(b) < 2 {
return false
}
return b[0] == '\r' && b[1] == '\n'
}
func newNewlineToken(b []rune) (Token, int, error) {
i := 1
if b[0] == '\r' && isNewline(b[1:]) {
i++
}
if !isNewline([]rune(b[:i])) {
return emptyToken, 0, NewParseError("invalid new line token")
}
return newToken(TokenNL, b[:i], NoneType), i, nil
}

@ -0,0 +1,152 @@
package ini
import (
"bytes"
"fmt"
"strconv"
)
const (
none = numberFormat(iota)
binary
octal
decimal
hex
exponent
)
type numberFormat int
// numberHelper is used to dictate what format a number is in
// and what to do for negative values. Since -1e-4 is a valid
// number, we cannot just simply check for duplicate negatives.
type numberHelper struct {
numberFormat numberFormat
negative bool
negativeExponent bool
}
func (b numberHelper) Exists() bool {
return b.numberFormat != none
}
func (b numberHelper) IsNegative() bool {
return b.negative || b.negativeExponent
}
func (b *numberHelper) Determine(c rune) error {
if b.Exists() {
return NewParseError(fmt.Sprintf("multiple number formats: 0%v", string(c)))
}
switch c {
case 'b':
b.numberFormat = binary
case 'o':
b.numberFormat = octal
case 'x':
b.numberFormat = hex
case 'e', 'E':
b.numberFormat = exponent
case '-':
if b.numberFormat != exponent {
b.negative = true
} else {
b.negativeExponent = true
}
case '.':
b.numberFormat = decimal
default:
return NewParseError(fmt.Sprintf("invalid number character: %v", string(c)))
}
return nil
}
func (b numberHelper) CorrectByte(c rune) bool {
switch {
case b.numberFormat == binary:
if !isBinaryByte(c) {
return false
}
case b.numberFormat == octal:
if !isOctalByte(c) {
return false
}
case b.numberFormat == hex:
if !isHexByte(c) {
return false
}
case b.numberFormat == decimal:
if !isDigit(c) {
return false
}
case b.numberFormat == exponent:
if !isDigit(c) {
return false
}
case b.negativeExponent:
if !isDigit(c) {
return false
}
case b.negative:
if !isDigit(c) {
return false
}
default:
if !isDigit(c) {
return false
}
}
return true
}
func (b numberHelper) Base() int {
switch b.numberFormat {
case binary:
return 2
case octal:
return 8
case hex:
return 16
default:
return 10
}
}
func (b numberHelper) String() string {
buf := bytes.Buffer{}
i := 0
switch b.numberFormat {
case binary:
i++
buf.WriteString(strconv.Itoa(i) + ": binary format\n")
case octal:
i++
buf.WriteString(strconv.Itoa(i) + ": octal format\n")
case hex:
i++
buf.WriteString(strconv.Itoa(i) + ": hex format\n")
case exponent:
i++
buf.WriteString(strconv.Itoa(i) + ": exponent format\n")
default:
i++
buf.WriteString(strconv.Itoa(i) + ": integer format\n")
}
if b.negative {
i++
buf.WriteString(strconv.Itoa(i) + ": negative format\n")
}
if b.negativeExponent {
i++
buf.WriteString(strconv.Itoa(i) + ": negative exponent format\n")
}
return buf.String()
}

@ -0,0 +1,39 @@
package ini
import (
"fmt"
)
var (
equalOp = []rune("=")
equalColonOp = []rune(":")
)
func isOp(b []rune) bool {
if len(b) == 0 {
return false
}
switch b[0] {
case '=':
return true
case ':':
return true
default:
return false
}
}
func newOpToken(b []rune) (Token, int, error) {
tok := Token{}
switch b[0] {
case '=':
tok = newToken(TokenOp, equalOp, NoneType)
case ':':
tok = newToken(TokenOp, equalColonOp, NoneType)
default:
return tok, 0, NewParseError(fmt.Sprintf("unexpected op type, %v", b[0]))
}
return tok, 1, nil
}

@ -0,0 +1,43 @@
package ini
import "fmt"
const (
// ErrCodeParseError is returned when a parsing error
// has occurred.
ErrCodeParseError = "INIParseError"
)
// ParseError is an error which is returned during any part of
// the parsing process.
type ParseError struct {
msg string
}
// NewParseError will return a new ParseError where message
// is the description of the error.
func NewParseError(message string) *ParseError {
return &ParseError{
msg: message,
}
}
// Code will return the ErrCodeParseError
func (err *ParseError) Code() string {
return ErrCodeParseError
}
// Message returns the error's message
func (err *ParseError) Message() string {
return err.msg
}
// OrigError return nothing since there will never be any
// original error.
func (err *ParseError) OrigError() error {
return nil
}
func (err *ParseError) Error() string {
return fmt.Sprintf("%s: %s", err.Code(), err.Message())
}

@ -0,0 +1,60 @@
package ini
import (
"bytes"
"fmt"
)
// ParseStack is a stack that contains a container, the stack portion,
// and the list which is the list of ASTs that have been successfully
// parsed.
type ParseStack struct {
top int
container []AST
list []AST
index int
}
func newParseStack(sizeContainer, sizeList int) ParseStack {
return ParseStack{
container: make([]AST, sizeContainer),
list: make([]AST, sizeList),
}
}
// Pop will return and truncate the last container element.
func (s *ParseStack) Pop() AST {
s.top--
return s.container[s.top]
}
// Push will add the new AST to the container
func (s *ParseStack) Push(ast AST) {
s.container[s.top] = ast
s.top++
}
// MarkComplete will append the AST to the list of completed statements
func (s *ParseStack) MarkComplete(ast AST) {
s.list[s.index] = ast
s.index++
}
// List will return the completed statements
func (s ParseStack) List() []AST {
return s.list[:s.index]
}
// Len will return the length of the container
func (s *ParseStack) Len() int {
return s.top
}
func (s ParseStack) String() string {
buf := bytes.Buffer{}
for i, node := range s.list {
buf.WriteString(fmt.Sprintf("%d: %v\n", i+1, node))
}
return buf.String()
}

@ -0,0 +1,41 @@
package ini
import (
"fmt"
)
var (
emptyRunes = []rune{}
)
func isSep(b []rune) bool {
if len(b) == 0 {
return false
}
switch b[0] {
case '[', ']':
return true
default:
return false
}
}
var (
openBrace = []rune("[")
closeBrace = []rune("]")
)
func newSepToken(b []rune) (Token, int, error) {
tok := Token{}
switch b[0] {
case '[':
tok = newToken(TokenSep, openBrace, NoneType)
case ']':
tok = newToken(TokenSep, closeBrace, NoneType)
default:
return tok, 0, NewParseError(fmt.Sprintf("unexpected sep type, %v", b[0]))
}
return tok, 1, nil
}

@ -0,0 +1,45 @@
package ini
// skipper is used to skip certain blocks of an ini file.
// Currently skipper is used to skip nested blocks of ini
// files. See example below
//
// [ foo ]
// nested = ; this section will be skipped
// a=b
// c=d
// bar=baz ; this will be included
type skipper struct {
shouldSkip bool
TokenSet bool
prevTok Token
}
func newSkipper() skipper {
return skipper{
prevTok: emptyToken,
}
}
func (s *skipper) ShouldSkip(tok Token) bool {
if s.shouldSkip &&
s.prevTok.Type() == TokenNL &&
tok.Type() != TokenWS {
s.Continue()
return false
}
s.prevTok = tok
return s.shouldSkip
}
func (s *skipper) Skip() {
s.shouldSkip = true
s.prevTok = emptyToken
}
func (s *skipper) Continue() {
s.shouldSkip = false
s.prevTok = emptyToken
}

@ -0,0 +1,35 @@
package ini
// Statement is an empty AST mostly used for transitioning states.
func newStatement() AST {
return newAST(ASTKindStatement, AST{})
}
// SectionStatement represents a section AST
func newSectionStatement(tok Token) AST {
return newASTWithRootToken(ASTKindSectionStatement, tok)
}
// ExprStatement represents a completed expression AST
func newExprStatement(ast AST) AST {
return newAST(ASTKindExprStatement, ast)
}
// CommentStatement represents a comment in the ini definition.
//
// grammar:
// comment -> #comment' | ;comment'
// comment' -> epsilon | value
func newCommentStatement(tok Token) AST {
return newAST(ASTKindCommentStatement, newExpression(tok))
}
// CompletedSectionStatement represents a completed section
func newCompletedSectionStatement(ast AST) AST {
return newAST(ASTKindCompletedSectionStatement, ast)
}
// SkipStatement is used to skip whole statements
func newSkipStatement(ast AST) AST {
return newAST(ASTKindSkipStatement, ast)
}

@ -0,0 +1,284 @@
package ini
import (
"fmt"
)
// getStringValue will return a quoted string and the amount
// of bytes read
//
// an error will be returned if the string is not properly formatted
func getStringValue(b []rune) (int, error) {
if b[0] != '"' {
return 0, NewParseError("strings must start with '\"'")
}
endQuote := false
i := 1
for ; i < len(b) && !endQuote; i++ {
if escaped := isEscaped(b[:i], b[i]); b[i] == '"' && !escaped {
endQuote = true
break
} else if escaped {
/*c, err := getEscapedByte(b[i])
if err != nil {
return 0, err
}
b[i-1] = c
b = append(b[:i], b[i+1:]...)
i--*/
continue
}
}
if !endQuote {
return 0, NewParseError("missing '\"' in string value")
}
return i + 1, nil
}
// getBoolValue will return a boolean and the amount
// of bytes read
//
// an error will be returned if the boolean is not of a correct
// value
func getBoolValue(b []rune) (int, error) {
if len(b) < 4 {
return 0, NewParseError("invalid boolean value")
}
n := 0
for _, lv := range literalValues {
if len(lv) > len(b) {
continue
}
if isLitValue(lv, b) {
n = len(lv)
}
}
if n == 0 {
return 0, NewParseError("invalid boolean value")
}
return n, nil
}
// getNumericalValue will return a numerical string, the amount
// of bytes read, and the base of the number
//
// an error will be returned if the number is not of a correct
// value
func getNumericalValue(b []rune) (int, int, error) {
if !isDigit(b[0]) {
return 0, 0, NewParseError("invalid digit value")
}
i := 0
helper := numberHelper{}
loop:
for negativeIndex := 0; i < len(b); i++ {
negativeIndex++
if !isDigit(b[i]) {
switch b[i] {
case '-':
if helper.IsNegative() || negativeIndex != 1 {
return 0, 0, NewParseError("parse error '-'")
}
n := getNegativeNumber(b[i:])
i += (n - 1)
helper.Determine(b[i])
continue
case '.':
if err := helper.Determine(b[i]); err != nil {
return 0, 0, err
}
case 'e', 'E':
if err := helper.Determine(b[i]); err != nil {
return 0, 0, err
}
negativeIndex = 0
case 'b':
if helper.numberFormat == hex {
break
}
fallthrough
case 'o', 'x':
if i == 0 && b[i] != '0' {
return 0, 0, NewParseError("incorrect base format, expected leading '0'")
}
if i != 1 {
return 0, 0, NewParseError(fmt.Sprintf("incorrect base format found %s at %d index", string(b[i]), i))
}
if err := helper.Determine(b[i]); err != nil {
return 0, 0, err
}
default:
if isWhitespace(b[i]) {
break loop
}
if isNewline(b[i:]) {
break loop
}
if !(helper.numberFormat == hex && isHexByte(b[i])) {
if i+2 < len(b) && !isNewline(b[i:i+2]) {
return 0, 0, NewParseError("invalid numerical character")
} else if !isNewline([]rune{b[i]}) {
return 0, 0, NewParseError("invalid numerical character")
}
break loop
}
}
}
}
return helper.Base(), i, nil
}
// isDigit will return whether or not something is an integer
func isDigit(b rune) bool {
return b >= '0' && b <= '9'
}
func hasExponent(v []rune) bool {
return contains(v, 'e') || contains(v, 'E')
}
func isBinaryByte(b rune) bool {
switch b {
case '0', '1':
return true
default:
return false
}
}
func isOctalByte(b rune) bool {
switch b {
case '0', '1', '2', '3', '4', '5', '6', '7':
return true
default:
return false
}
}
func isHexByte(b rune) bool {
if isDigit(b) {
return true
}
return (b >= 'A' && b <= 'F') ||
(b >= 'a' && b <= 'f')
}
func getValue(b []rune) (int, error) {
i := 0
for i < len(b) {
if isNewline(b[i:]) {
break
}
if isOp(b[i:]) {
break
}
valid, n, err := isValid(b[i:])
if err != nil {
return 0, err
}
if !valid {
break
}
i += n
}
return i, nil
}
// getNegativeNumber will return a negative number from a
// byte slice. This will iterate through all characters until
// a non-digit has been found.
func getNegativeNumber(b []rune) int {
if b[0] != '-' {
return 0
}
i := 1
for ; i < len(b); i++ {
if !isDigit(b[i]) {
return i
}
}
return i
}
// isEscaped will return whether or not the character is an escaped
// character.
func isEscaped(value []rune, b rune) bool {
if len(value) == 0 {
return false
}
switch b {
case '\'': // single quote
case '"': // quote
case 'n': // newline
case 't': // tab
case '\\': // backslash
default:
return false
}
return value[len(value)-1] == '\\'
}
func getEscapedByte(b rune) (rune, error) {
switch b {
case '\'': // single quote
return '\'', nil
case '"': // quote
return '"', nil
case 'n': // newline
return '\n', nil
case 't': // table
return '\t', nil
case '\\': // backslash
return '\\', nil
default:
return b, NewParseError(fmt.Sprintf("invalid escaped character %c", b))
}
}
func removeEscapedCharacters(b []rune) []rune {
for i := 0; i < len(b); i++ {
if isEscaped(b[:i], b[i]) {
c, err := getEscapedByte(b[i])
if err != nil {
return b
}
b[i-1] = c
b = append(b[:i], b[i+1:]...)
i--
}
}
return b
}

@ -0,0 +1,166 @@
package ini
import (
"fmt"
"sort"
)
// Visitor is an interface used by walkers that will
// traverse an array of ASTs.
type Visitor interface {
VisitExpr(AST) error
VisitStatement(AST) error
}
// DefaultVisitor is used to visit statements and expressions
// and ensure that they are both of the correct format.
// In addition, upon visiting this will build sections and populate
// the Sections field which can be used to retrieve profile
// configuration.
type DefaultVisitor struct {
scope string
Sections Sections
}
// NewDefaultVisitor return a DefaultVisitor
func NewDefaultVisitor() *DefaultVisitor {
return &DefaultVisitor{
Sections: Sections{
container: map[string]Section{},
},
}
}
// VisitExpr visits expressions...
func (v *DefaultVisitor) VisitExpr(expr AST) error {
t := v.Sections.container[v.scope]
if t.values == nil {
t.values = values{}
}
switch expr.Kind {
case ASTKindExprStatement:
opExpr := expr.GetRoot()
switch opExpr.Kind {
case ASTKindEqualExpr:
children := opExpr.GetChildren()
if len(children) <= 1 {
return NewParseError("unexpected token type")
}
rhs := children[1]
if rhs.Root.Type() != TokenLit {
return NewParseError("unexpected token type")
}
key := EqualExprKey(opExpr)
v, err := newValue(rhs.Root.ValueType, rhs.Root.base, rhs.Root.Raw())
if err != nil {
return err
}
t.values[key] = v
default:
return NewParseError(fmt.Sprintf("unsupported expression %v", expr))
}
default:
return NewParseError(fmt.Sprintf("unsupported expression %v", expr))
}
v.Sections.container[v.scope] = t
return nil
}
// VisitStatement visits statements...
func (v *DefaultVisitor) VisitStatement(stmt AST) error {
switch stmt.Kind {
case ASTKindCompletedSectionStatement:
child := stmt.GetRoot()
if child.Kind != ASTKindSectionStatement {
return NewParseError(fmt.Sprintf("unsupported child statement: %T", child))
}
name := string(child.Root.Raw())
v.Sections.container[name] = Section{}
v.scope = name
default:
return NewParseError(fmt.Sprintf("unsupported statement: %s", stmt.Kind))
}
return nil
}
// Sections is a map of Section structures that represent
// a configuration.
type Sections struct {
container map[string]Section
}
// GetSection will return section p. If section p does not exist,
// false will be returned in the second parameter.
func (t Sections) GetSection(p string) (Section, bool) {
v, ok := t.container[p]
return v, ok
}
// values represents a map of union values.
type values map[string]Value
// List will return a list of all sections that were successfully
// parsed.
func (t Sections) List() []string {
keys := make([]string, len(t.container))
i := 0
for k := range t.container {
keys[i] = k
i++
}
sort.Strings(keys)
return keys
}
// Section contains a name and values. This represent
// a sectioned entry in a configuration file.
type Section struct {
Name string
values values
}
// Has will return whether or not an entry exists in a given section
func (t Section) Has(k string) bool {
_, ok := t.values[k]
return ok
}
// ValueType will returned what type the union is set to. If
// k was not found, the NoneType will be returned.
func (t Section) ValueType(k string) (ValueType, bool) {
v, ok := t.values[k]
return v.Type, ok
}
// Bool returns a bool value at k
func (t Section) Bool(k string) bool {
return t.values[k].BoolValue()
}
// Int returns an integer value at k
func (t Section) Int(k string) int64 {
return t.values[k].IntValue()
}
// Float64 returns a float value at k
func (t Section) Float64(k string) float64 {
return t.values[k].FloatValue()
}
// String returns the string value at k
func (t Section) String(k string) string {
_, ok := t.values[k]
if !ok {
return ""
}
return t.values[k].StringValue()
}

@ -0,0 +1,25 @@
package ini
// Walk will traverse the AST using the v, the Visitor.
func Walk(tree []AST, v Visitor) error {
for _, node := range tree {
switch node.Kind {
case ASTKindExpr,
ASTKindExprStatement:
if err := v.VisitExpr(node); err != nil {
return err
}
case ASTKindStatement,
ASTKindCompletedSectionStatement,
ASTKindNestedSectionStatement,
ASTKindCompletedNestedSectionStatement:
if err := v.VisitStatement(node); err != nil {
return err
}
}
}
return nil
}

@ -0,0 +1,24 @@
package ini
import (
"unicode"
)
// isWhitespace will return whether or not the character is
// a whitespace character.
//
// Whitespace is defined as a space or tab.
func isWhitespace(c rune) bool {
return unicode.IsSpace(c) && c != '\n' && c != '\r'
}
func newWSToken(b []rune) (Token, int, error) {
i := 0
for ; i < len(b); i++ {
if !isWhitespace(b[i]) {
break
}
}
return newToken(TokenWS, b[:i], NoneType), i, nil
}

Some files were not shown because too many files have changed in this diff Show More