connect: Add AWS PCA provider (#6795)

* Update AWS SDK to use PCA features.

* Add AWS PCA provider

* Add plumbing for config, config validation tests, add test for inheriting existing CA resources created by user

* Unparallel the tests so we don't exhaust PCA limits

* Merge updates

* More aggressive polling; rate limit pass through on sign; Timeout on Sign and CA create

* Add AWS PCA docs

* Fix Vault doc typo too

* Doc typo

* Apply suggestions from code review

Co-Authored-By: R.B. Boyer <rb@hashicorp.com>
Co-Authored-By: kaitlincarter-hc <43049322+kaitlincarter-hc@users.noreply.github.com>

* Doc fixes; tests for erroring if State is modified via API

* More review cleanup

* Uncomment tests!

* Minor suggested clean ups
This commit is contained in:
Paul Banks 2019-11-21 17:40:29 +00:00 committed by GitHub
parent 58db6b5d22
commit cd1b613352
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
166 changed files with 46202 additions and 7853 deletions

View File

@ -590,6 +590,10 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
"tls_server_name": "TLSServerName",
"tls_skip_verify": "TLSSkipVerify",
// AWS CA config
"existing_arn": "ExistingARN",
"delete_on_exit": "DeleteOnExit",
// Common CA config
"leaf_cert_ttl": "LeafCertTTL",
"csr_max_per_second": "CSRMaxPerSecond",
@ -1095,6 +1099,7 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
"": true,
structs.ConsulCAProvider: true,
structs.VaultCAProvider: true,
structs.AWSCAProvider: true,
}
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
@ -1108,6 +1113,10 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
case structs.AWSCAProvider:
if _, err := ca.ParseAWSCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
}
}

View File

@ -2752,6 +2752,91 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
}
},
},
{
desc: "Connect AWS CA provider configuration",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{`{
"connect": {
"enabled": true,
"ca_provider": "aws-pca",
"ca_config": {
"existing_arn": "foo",
"delete_on_exit": true
}
}
}`},
hcl: []string{`
connect {
enabled = true
ca_provider = "aws-pca"
ca_config {
existing_arn = "foo"
delete_on_exit = true
}
}
`},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.ConnectEnabled = true
rt.ConnectCAProvider = "aws-pca"
rt.ConnectCAConfig = map[string]interface{}{
"ExistingARN": "foo",
"DeleteOnExit": true,
}
},
},
{
desc: "Connect AWS CA provider TTL validation",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{`{
"connect": {
"enabled": true,
"ca_provider": "aws-pca",
"ca_config": {
"leaf_cert_ttl": "1h"
}
}
}`},
hcl: []string{`
connect {
enabled = true
ca_provider = "aws-pca"
ca_config {
leaf_cert_ttl = "1h"
}
}
`},
err: "AWS PCA doesn't support certificates that are valid for less than 24 hours",
},
{
desc: "Connect AWS CA provider EC key length validation",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{`{
"connect": {
"enabled": true,
"ca_provider": "aws-pca",
"ca_config": {
"private_key_bits": 521
}
}
}`},
hcl: []string{`
connect {
enabled = true
ca_provider = "aws-pca"
ca_config {
private_key_bits = 521
}
}
`},
err: "AWS PCA only supports P256 EC curve",
},
// ------------------------------------------------------------
// ConfigEntry Handling

View File

@ -63,3 +63,29 @@ func validateSetIntermediate(
return nil
}
func validateSignIntermediate(csr *x509.CertificateRequest, spiffeID *connect.SpiffeIDSigning) error {
// We explicitly _don't_ require that the CSR has a valid SPIFFE signing URI
// SAN because AWS PCA doesn't let us set one :(. We need to relax it here
// otherwise it would be impossible to migrate from built-in provider to AWS
// in multiple DCs without downtime. Nothing in Connect actually checks that
// currently so this is OK for now but it's sad we have to break the SPIFFE
// spec for AWS sake. Hopefully they'll add that ability soon.
uriCount := len(csr.URIs)
if uriCount > 0 {
if uriCount != 1 {
return fmt.Errorf("incoming CSR has unexpected number of URIs: %d", uriCount)
}
certURI, err := connect.ParseCertURI(csr.URIs[0])
if err != nil {
return err
}
// Verify that the trust domain is valid.
if !spiffeID.CanSign(certURI) {
return fmt.Errorf("incoming CSR domain %q is not valid for our domain %q",
certURI.URI().String(), spiffeID.URI().String())
}
}
return nil
}

View File

@ -2,11 +2,19 @@ package ca
import (
"crypto/x509"
"errors"
"log"
)
//go:generate mockery -name Provider -inpkg
// ErrRateLimited is a sentinel error value Providers may return from any method
// to indicate that the operation can't complete due to a temporary rate limit.
// In the case of signing new certificates, Consul clients will respect this and
// intelligently backoff to optimize rotation rollout time while reducing load
// on servers and CA provider.
var ErrRateLimited = errors.New("operation rate limited by CA provider")
// ProviderConfig encapsulates all the data Consul passes to `Configure` on a
// new provider instance. The provider must treat this as read-only and make
// copies of any map or slice if it might modify them internally.
@ -71,7 +79,7 @@ type Provider interface {
// GenerateRoot causes the creation of a new root certificate for this provider.
// This can also be a no-op if a root certificate already exists for the given
// config. If isRoot is false, calling this method is an error.
// config. If IsPrimary is false, calling this method is an error.
GenerateRoot() error
// ActiveRoot returns the currently active root CA for this
@ -80,13 +88,13 @@ type Provider interface {
ActiveRoot() (string, error)
// GenerateIntermediateCSR generates a CSR for an intermediate CA
// certificate, to be signed by the root of another datacenter. If isRoot was
// certificate, to be signed by the root of another datacenter. If IsPrimary was
// set to true with Configure(), calling this is an error.
GenerateIntermediateCSR() (string, error)
// SetIntermediate sets the provider to use the given intermediate certificate
// as well as the root it was signed by. This completes the initialization for
// a provider where isRoot was set to false in Configure().
// a provider where IsPrimary was set to false in Configure().
SetIntermediate(intermediatePEM, rootPEM string) error
// ActiveIntermediate returns the current signing cert used by this provider
@ -105,13 +113,19 @@ type Provider interface {
// Sign signs a leaf certificate used by Connect proxies from a CSR. The PEM
// returned should include only the leaf certificate as all Intermediates
// needed to validate it will be added by Consul based on the active
// intemediate and any cross-signed intermediates managed by Consul.
// intemediate and any cross-signed intermediates managed by Consul. Note that
// providers should return ErrRateLimited if they are unable to complete the
// operation due to upstream rate limiting so that clients can intelligently
// backoff.
Sign(*x509.CertificateRequest) (string, error)
// SignIntermediate will validate the CSR to ensure the trust domain in the
// URI SAN matches the local one and that basic constraints for a CA certificate
// are met. It should return a signed CA certificate with a path length constraint
// of 0 to ensure that the certificate cannot be used to generate further CA certs.
// URI SAN matches the local one and that basic constraints for a CA
// certificate are met. It should return a signed CA certificate with a path
// length constraint of 0 to ensure that the certificate cannot be used to
// generate further CA certs. Note that providers should return ErrRateLimited
// if they are unable to complete the operation due to upstream rate limiting
// so that clients can intelligently backoff.
SignIntermediate(*x509.CertificateRequest) (string, error)
// CrossSignCA must accept a CA certificate from another CA provider and cross
@ -124,7 +138,10 @@ type Provider interface {
// returned as a PEM formatted string.
//
// If the CA provider does not support this operation, it may return an error
// provided `SupportsCrossSigning` also returns false.
// provided `SupportsCrossSigning` also returns false. Note that
// providers should return ErrRateLimited if they are unable to complete the
// operation due to upstream rate limiting so that clients can intelligently
// backoff.
CrossSignCA(*x509.Certificate) (string, error)
// SupportsCrossSigning should indicate whether the CA provider supports

View File

@ -0,0 +1,715 @@
package ca
import (
"bytes"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acmpca"
"github.com/mitchellh/mapstructure"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
const (
// RootTemplateARN is the AWS-defined template we need to use when issuing a
// root cert.
RootTemplateARN = "arn:aws:acm-pca:::template/RootCACertificate/V1"
// IntermediateTemplateARN is the AWS-defined template we need to use when
// issuing an intermediate cert.
IntermediateTemplateARN = "arn:aws:acm-pca:::template/SubordinateCACertificate_PathLen0/V1"
// LeafTemplateARN is the AWS-defined template we need to use when issuing a
// leaf cert.
LeafTemplateARN = "arn:aws:acm-pca:::template/EndEntityCertificate/V1"
// RootTTL is the validity duration for root certs we create.
AWSRootTTL = 5 * 365 * 24 * time.Hour
// IntermediateTTL is the validity duration for the intermediate certs we
// create.
AWSIntermediateTTL = 1 * 365 * 24 * time.Hour
// SignTimout is the maximum time we will spend waiting (polling) for a leaf
// certificate to be signed.
AWSSignTimeout = 45 * time.Second
// CreateTimeout is the maximum time we will spend waiting (polling)
// for the CA to be created.
AWSCreateTimeout = 2 * time.Minute
// AWSStateCAARNKey is the key in the provider State we store the ARN of the
// CA we created if any.
AWSStateCAARNKey = "CA_ARN"
// day is a more readable shorthand for a duration of 24 hours. Note time
// package doesn't provide time.Day due to ambiguity around DST and leap
// seconds where a day may not actually be 24 hours.
day = 24 * time.Hour
)
// AWSProvider implements Provider for AWS ACM PCA
type AWSProvider struct {
stopped uint32 // atomically accessed, at start to prevent alignment issues
stopCh chan struct{}
config *structs.AWSCAProviderConfig
session *session.Session
client *acmpca.ACMPCA
isPrimary bool
datacenter string
clusterID string
arn string
arnChecked bool
caCreated bool
rootPEM string
intermediatePEM string
logger *log.Logger
}
// SetLogger implements NeedsLogger
func (a *AWSProvider) SetLogger(l *log.Logger) {
a.logger = l
}
// Configure implements Provider
func (a *AWSProvider) Configure(cfg ProviderConfig) error {
config, err := ParseAWSCAConfig(cfg.RawConfig)
if err != nil {
return err
}
// We only support setting IAM credentials through the normal methods ENV,
// SharedCredentialsFile, IAM role. Per
// https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html#specifying-credentials
// Putting them in CA config amounts to writing them to disk config file in
// another place or sending them via API call and persisting them in state
// store in a new place on disk. One of the existing standard solutions seems
// better in all cases.
awsSession, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return err
}
a.config = config
a.session = awsSession
a.isPrimary = cfg.IsPrimary
a.clusterID = cfg.ClusterID
a.datacenter = cfg.Datacenter
a.client = acmpca.New(awsSession)
a.stopCh = make(chan struct{})
// Load the ARN from config or previous state.
if config.ExistingARN != "" {
a.arn = config.ExistingARN
} else if arn := cfg.State[AWSStateCAARNKey]; arn != "" {
a.arn = arn
// We only pass ARN through state if we created the resource. We don't
// "remember" previously existing resources the user configured.
a.caCreated = true
}
return nil
}
// State implements Provider
func (a *AWSProvider) State() (map[string]string, error) {
if a.arn == "" {
return nil, nil
}
// Preserve the CA ARN if there is one
state := make(map[string]string)
state[AWSStateCAARNKey] = a.arn
return state, nil
}
// GenerateRoot implements Provider
func (a *AWSProvider) GenerateRoot() error {
if !a.isPrimary {
return fmt.Errorf("provider is not the root certificate authority")
}
return a.ensureCA()
}
// ensureCA loads the CA resource to check it exists if configured by User or in
// state from previous run. Otherwise it creates a new CA of the correct type
// for this DC.
func (a *AWSProvider) ensureCA() error {
// If we already have an ARN, we assume the CA is created and sanity check
// it's available.
if a.arn != "" {
// Only check this once on startup not on every operation
if a.arnChecked {
return nil
}
// Load from the resource.
input := &acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(a.arn),
}
output, err := a.client.DescribeCertificateAuthority(input)
if err != nil {
return err
}
// Allow it to be active or pending a certificate (leadership might have
// changed during a secondary initialization for example).
if *output.CertificateAuthority.Status != acmpca.CertificateAuthorityStatusActive &&
*output.CertificateAuthority.Status != acmpca.CertificateAuthorityStatusPendingCertificate {
verb := "configured"
if a.caCreated {
verb = "created"
}
// Don't recreate CA that was manually disabled, force full deletion or
// manual recreation. We might later support this or an explicit
// "recreate" config option to allow rotating without a manual creation
// but this is simpler and less surprising default behavior if user
// disabled a CA due to a security concern and we just work around it.
return fmt.Errorf("the %s PCA is not active: status is %s", verb,
*output.CertificateAuthority.Status)
}
// Load the certs
if err := a.loadCACerts(); err != nil {
return err
}
a.arnChecked = true
return nil
}
// Need to create a Private CA resource.
err := a.createPCA()
if err != nil {
return err
}
// If we are in a secondary DC this is all we can do for now - the rest is
// handled by the Initialization routine of calling GenerateIntermediateCSR
// and then SetIntermediate.
if !a.isPrimary {
return nil
}
// CA is created and in PENDING_CERTIFCATE state, generate a self-signed cert
// and install it.
csrPEM, err := a.getCACSR()
if err != nil {
return err
}
// Self-sign it as a root
certPEM, err := a.signCSR(csrPEM, RootTemplateARN, AWSRootTTL)
if err != nil {
return err
}
// Submit the signed cert
input := acmpca.ImportCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
Certificate: []byte(certPEM),
}
a.logger.Printf("[DEBUG] connect.ca.aws: uploading certificate for %s", a.arn)
_, err = a.client.ImportCertificateAuthorityCertificate(&input)
if err != nil {
return err
}
a.rootPEM = certPEM
return nil
}
func keyTypeToAlgos(keyType string, keyBits int) (string, string, error) {
switch keyType {
case "rsa":
switch keyBits {
case 2048:
return acmpca.KeyAlgorithmRsa2048, acmpca.SigningAlgorithmSha256withrsa, nil
case 4096:
return acmpca.KeyAlgorithmRsa4096, acmpca.SigningAlgorithmSha256withrsa, nil
default:
return "", "", fmt.Errorf("AWS PCA only supports RSA key lengths 2048"+
" and 4096, PrivateKeyBits of %d configured", keyBits)
}
case "ec":
if keyBits != 256 {
return "", "", fmt.Errorf("AWS PCA only supports P256 EC curve, keyBits of %d configured", keyBits)
}
return acmpca.KeyAlgorithmEcPrime256v1, acmpca.SigningAlgorithmSha256withecdsa, nil
default:
return "", "", fmt.Errorf("AWS PCA only supports P256 EC curve, or RSA"+
" 2048/4096. %s, %d configured", keyType, keyBits)
}
}
func (a *AWSProvider) createPCA() error {
pcaType := "ROOT" // For some reason there is no constant for this in the SDK
if !a.isPrimary {
pcaType = acmpca.CertificateAuthorityTypeSubordinate
}
keyAlg, signAlg, err := keyTypeToAlgos(a.config.PrivateKeyType, a.config.PrivateKeyBits)
if err != nil {
return err
}
uid, err := connect.CompactUID()
if err != nil {
return err
}
commonName := connect.CACN("aws", uid, a.clusterID, a.isPrimary)
createInput := acmpca.CreateCertificateAuthorityInput{
CertificateAuthorityType: aws.String(pcaType),
CertificateAuthorityConfiguration: &acmpca.CertificateAuthorityConfiguration{
Subject: &acmpca.ASN1Subject{
CommonName: aws.String(commonName),
},
KeyAlgorithm: aws.String(keyAlg),
SigningAlgorithm: aws.String(signAlg),
},
RevocationConfiguration: &acmpca.RevocationConfiguration{
// TODO support CRL in future when we manage revocation in Connect more
// generally.
CrlConfiguration: &acmpca.CrlConfiguration{
Enabled: aws.Bool(false),
},
},
// uid is unique to each PCA we create so use it as an idempotency string. We
// don't actually retry on failure yet but might as well!
IdempotencyToken: aws.String(uid),
Tags: []*acmpca.Tag{
{Key: aws.String("consul_cluster_id"), Value: aws.String(a.clusterID)},
{Key: aws.String("consul_datacenter"), Value: aws.String(a.datacenter)},
},
}
a.logger.Printf("[DEBUG] creating new PCA %s", commonName)
createOutput, err := a.client.CreateCertificateAuthority(&createInput)
if err != nil {
return err
}
// wait for PCA to be created
newARN := *createOutput.CertificateAuthorityArn
describeInput := acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(newARN),
}
_, err = a.pollLoop("Private CA", AWSCreateTimeout, func() (bool, string, error) {
describeOutput, err := a.client.DescribeCertificateAuthority(&describeInput)
if err != nil {
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
return true, "", fmt.Errorf("error waiting for PCA to be created: %s", err)
}
}
if *describeOutput.CertificateAuthority.Status == acmpca.CertificateAuthorityStatusPendingCertificate {
a.logger.Printf("[DEBUG] connect.ca.aws: new PCA %s is ready"+
" to accept a certificate", newARN)
a.arn = newARN
// We don't need to reload this ARN since we just created it and know what
// state it's in
a.arnChecked = true
return true, "", nil
}
// Retry
return false, "", nil
})
return err
}
func (a *AWSProvider) getCACSR() (string, error) {
input := &acmpca.GetCertificateAuthorityCsrInput{
CertificateAuthorityArn: aws.String(a.arn),
}
a.logger.Printf("[DEBUG] connect.ca.aws: retrieving CSR for %s", a.arn)
output, err := a.client.GetCertificateAuthorityCsr(input)
if err != nil {
return "", err
}
csrPEM := output.Csr
if csrPEM == nil {
// Probably shouldn't be able to happen but being defensive.
return "", fmt.Errorf("invalid response from AWS PCA: CSR is nil")
}
return *csrPEM, nil
}
func (a *AWSProvider) loadCACerts() error {
input := &acmpca.GetCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
}
output, err := a.client.GetCertificateAuthorityCertificate(input)
if err != nil {
return err
}
if a.isPrimary {
// Just use the cert as a root
a.rootPEM = *output.Certificate
} else {
a.intermediatePEM = *output.Certificate
// TODO(banks) support user-supplied CA being a Subordinate even in the
// primary DC. For now this assumes there is only one cert in the chain
if output.CertificateChain == nil {
return fmt.Errorf("Subordinate CA %s returned no chain", a.arn)
}
a.rootPEM = *output.CertificateChain
}
return nil
}
func (a *AWSProvider) signCSRRaw(csr *x509.CertificateRequest, templateARN string, ttl time.Duration) (string, error) {
// PEM encode the CSR
var pemBuf bytes.Buffer
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
return "", err
}
return a.signCSR(pemBuf.String(), templateARN, ttl)
}
// pollWait returns how long to wait for the next poll of an async operation. We
// optimize for times typically seen in the API. This is called _before_ the
// first poll so we can provide a typical delay since operations are never
// complete immediately so it's a waste to try.
func pollWait(attemptsMade int) time.Duration {
// Hard code times for now
waits := []time.Duration{
// Never seen it complete first time with a lower value
100 * time.Millisecond,
200 * time.Millisecond,
500 * time.Millisecond,
1 * time.Second,
3 * time.Second,
5 * time.Second,
}
maxIdx := len(waits) - 1
if attemptsMade > maxIdx {
attemptsMade = maxIdx
}
return waits[attemptsMade]
}
func (a *AWSProvider) pollLoop(desc string, timeout time.Duration, f func() (bool, string, error)) (string, error) {
attemptsMade := 0
start := time.Now()
wait := pollWait(attemptsMade)
for {
elapsed := time.Since(start)
if elapsed >= timeout {
return "", fmt.Errorf("timeout after %s waiting for %s", elapsed, desc)
}
a.logger.Printf("[DEBUG] connect.ca.aws: %s pending"+
", waiting %s to check readiness", desc, wait)
select {
case <-a.stopCh:
// Provider discarded
a.logger.Print("[WARN] connect.ca.aws: provider instance terminated"+
" while waiting for %s.", desc)
return "", fmt.Errorf("provider terminated")
case <-time.After(wait):
// Continue looping...
}
done, out, err := f()
if err != nil {
return "", err
}
if done {
return out, err
}
attemptsMade++
wait = pollWait(attemptsMade)
}
}
func (a *AWSProvider) signCSR(csrPEM string, templateARN string, ttl time.Duration) (string, error) {
_, signAlg, err := keyTypeToAlgos(a.config.PrivateKeyType, a.config.PrivateKeyBits)
if err != nil {
return "", err
}
issueInput := acmpca.IssueCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
Csr: []byte(csrPEM),
SigningAlgorithm: aws.String(signAlg),
TemplateArn: aws.String(templateARN),
Validity: &acmpca.Validity{
Value: aws.Int64(int64(ttl / day)),
Type: aws.String(acmpca.ValidityPeriodTypeDays),
},
}
issueOutput, err := a.client.IssueCertificate(&issueInput)
// ErrCodeLimitExceededException is used for both hard and soft limits in AWS
// SDK :(. In this specific context though (issuing a certificate) there is no
// hard limit on number of certs so a limit exceeded here is a rate limit.
if aerr, ok := err.(awserr.Error); ok && err != nil {
if aerr.Code() == acmpca.ErrCodeLimitExceededException {
return "", ErrRateLimited
}
}
if err != nil {
return "", fmt.Errorf("error issuing certificate from PCA: %s", err)
}
// wait for certificate to be created
certInput := acmpca.GetCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
CertificateArn: issueOutput.CertificateArn,
}
return a.pollLoop(fmt.Sprintf("certificate %s", *issueOutput.CertificateArn),
AWSSignTimeout,
func() (bool, string, error) {
certOutput, err := a.client.GetCertificate(&certInput)
if err != nil {
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
return true, "", fmt.Errorf("error retrieving certificate from PCA: %s", err)
}
}
if certOutput.Certificate != nil {
return true, *certOutput.Certificate, nil
}
return false, "", nil
})
}
// ActiveRoot implements Provider
func (a *AWSProvider) ActiveRoot() (string, error) {
err := a.ensureCA()
if err != nil {
return "", err
}
if a.rootPEM == "" {
return "", fmt.Errorf("Secondary AWS CA provider not fully Initialized")
}
return a.rootPEM, nil
}
// GenerateIntermediateCSR implements Provider
func (a *AWSProvider) GenerateIntermediateCSR() (string, error) {
if a.isPrimary {
return "", fmt.Errorf("provider is the root certificate authority, " +
"cannot generate an intermediate CSR")
}
err := a.ensureCA()
if err != nil {
return "", err
}
// We should have the CA created now and should be able to generate the CSR.
return a.getCACSR()
}
// SetIntermediate implements Provider
func (a *AWSProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
err := a.ensureCA()
if err != nil {
return err
}
// Install the certificate
input := acmpca.ImportCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
Certificate: []byte(intermediatePEM),
CertificateChain: []byte(rootPEM),
}
a.logger.Printf("[DEBUG] uploading certificate for %s", a.arn)
_, err = a.client.ImportCertificateAuthorityCertificate(&input)
if err != nil {
return err
}
// We succsefully initialized, keep track of the root and intermediate certs.
a.rootPEM = rootPEM
a.intermediatePEM = intermediatePEM
return nil
}
// ActiveIntermediate implements Provider
func (a *AWSProvider) ActiveIntermediate() (string, error) {
err := a.ensureCA()
if err != nil {
return "", err
}
if a.rootPEM == "" {
return "", fmt.Errorf("AWS CA provider not fully Initialized")
}
if a.isPrimary {
// In the simple case the primary DC owns a Root CA and signs with it
// directly so just return that for "intermediate" too since that is what we
// will sign leafs with.
//
// TODO(banks) support user-supplied CA being a Subordinate even in the
// primary DC. We'd have to figure that out here and return the actual
// signing cert as well as somehow populate the intermediate chain.
return a.rootPEM, nil
}
if a.intermediatePEM == "" {
return "", fmt.Errorf("secondary AWS CA provider not fully Initialized")
}
return a.intermediatePEM, nil
}
// GenerateIntermediate implements Provider
func (a *AWSProvider) GenerateIntermediate() (string, error) {
// Like the consul provider, for now the Primary DC just gets a root and no
// intermediate to sign with. so just return this. Secondaries use
// intermediates but this method is only called during primary DC (root)
// initialization in case a provider generates separate root and
// intermediates.
//
// TODO(banks) support user-supplied CA being a Subordinate even in the
// primary DC.
return a.ActiveIntermediate()
}
// Sign implements Provider
func (a *AWSProvider) Sign(csr *x509.CertificateRequest) (string, error) {
if a.rootPEM == "" {
return "", fmt.Errorf("AWS CA provider not fully Initialized")
}
a.logger.Printf("[DEBUG] connect.ca.aws: signing csr for %s",
csr.Subject.CommonName)
return a.signCSRRaw(csr, LeafTemplateARN, a.config.LeafCertTTL)
}
// SignIntermediate implements Provider
func (a *AWSProvider) SignIntermediate(csr *x509.CertificateRequest) (string, error) {
err := validateSignIntermediate(csr, &connect.SpiffeIDSigning{ClusterID: a.clusterID, Domain: "consul"})
if err != nil {
return "", err
}
// Sign it!
return a.signCSRRaw(csr, IntermediateTemplateARN, AWSIntermediateTTL)
}
// CrossSignCA implements Provider
func (a *AWSProvider) CrossSignCA(newCA *x509.Certificate) (string, error) {
return "", fmt.Errorf("not implemented in AWS PCA provider")
}
func (a *AWSProvider) disablePCA() error {
if a.arn == "" {
return nil
}
input := acmpca.UpdateCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(a.arn),
Status: aws.String(acmpca.CertificateAuthorityStatusDisabled),
}
a.logger.Printf("[INFO] connect.ca.aws: disabling PCA %s", a.arn)
_, err := a.client.UpdateCertificateAuthority(&input)
return err
}
func (a *AWSProvider) deletePCA() error {
if a.arn == "" {
return nil
}
input := acmpca.DeleteCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(a.arn),
// We only ever use this to clean up after tests so delete as quickly as
// possible (7 days).
PermanentDeletionTimeInDays: aws.Int64(7),
}
a.logger.Printf("[INFO] connect.ca.aws: deleting PCA %s", a.arn)
_, err := a.client.DeleteCertificateAuthority(&input)
return err
}
// Cleanup implements Provider
func (a *AWSProvider) Cleanup() error {
old := atomic.SwapUint32(&a.stopped, 1)
if old == 0 {
close(a.stopCh)
}
if a.config.DeleteOnExit {
if err := a.disablePCA(); err != nil {
// Log the error but continue trying to delete as some errors may still
// allow that and this is best-effort delete anyway.
a.logger.Printf("[ERR] connect.ca.aws: failed to disable PCA %s: %s",
a.arn, err)
}
if err := a.deletePCA(); err != nil {
// Log the error but continue trying to delete as some errors may still
// allow that and this is best-effort delete anyway.
a.logger.Printf("[ERR] connect.ca.aws: failed to delete PCA %s: %s",
a.arn, err)
}
// Don't stall leader shutdown, non of the failures here are fatal.
return nil
}
return nil
}
// SupportsCrossSigning implements Provider
func (a *AWSProvider) SupportsCrossSigning() (bool, error) {
return false, nil
}
// ParseAWSCAConfig parses and validates AWS CA Provider configuration.
func ParseAWSCAConfig(raw map[string]interface{}) (*structs.AWSCAProviderConfig, error) {
config := structs.AWSCAProviderConfig{
CommonCAProviderConfig: defaultCommonConfig(),
}
decodeConf := &mapstructure.DecoderConfig{
DecodeHook: structs.ParseDurationFunc(),
Result: &config,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}
if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
if err := config.CommonCAProviderConfig.Validate(); err != nil {
return nil, err
}
// Extra keytype validation since PCA is more limited than other providers
_, _, err = keyTypeToAlgos(config.PrivateKeyType, config.PrivateKeyBits)
if err != nil {
return nil, err
}
if config.LeafCertTTL < 24*time.Hour {
return nil, fmt.Errorf("AWS PCA doesn't support certificates that are valid"+
" for less than 24 hours, LeafTTL of %s configured", config.LeafCertTTL)
}
return &config, nil
}

View File

@ -0,0 +1,276 @@
package ca
import (
"log"
"os"
"strconv"
"testing"
"github.com/hashicorp/consul/agent/connect"
"github.com/stretchr/testify/require"
)
func skipIfAWSNotConfigured(t *testing.T) bool {
enabled := os.Getenv("ENABLE_AWS_PCA_TESTS")
ok, err := strconv.ParseBool(enabled)
if err != nil || !ok {
t.Skip("Skipping because AWS tests are not enabled")
return true
}
return false
}
func TestAWSBootstrapAndSignPrimary(t *testing.T) {
// Note not parallel since we could easily hit AWS limits of too many CAs if
// all of these tests run at once.
if skipIfAWSNotConfigured(t) {
return
}
for _, tc := range KeyTestCases {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
cfg := map[string]interface{}{
"PrivateKeyType": tc.KeyType,
"PrivateKeyBits": tc.KeyBits,
}
provider := testAWSProvider(t, testProviderConfigPrimary(t, cfg))
defer provider.Cleanup()
// Generate the root
require.NoError(provider.GenerateRoot())
// Fetch Active Root
rootPEM, err := provider.ActiveRoot()
require.NoError(err)
// Generate Intermediate (not actually needed for this provider for now
// but this simulates the calls in Server.initializeRoot).
interPEM, err := provider.GenerateIntermediate()
require.NoError(err)
// Should be the same for now
require.Equal(rootPEM, interPEM)
// Ensure they use the right key type
rootCert, err := connect.ParseCert(rootPEM)
require.NoError(err)
keyType, keyBits, err := connect.KeyInfoFromCert(rootCert)
require.NoError(err)
require.Equal(tc.KeyType, keyType)
require.Equal(tc.KeyBits, keyBits)
// Sign a leaf with it
testSignAndValidate(t, provider, rootPEM, nil)
})
}
}
func testSignAndValidate(t *testing.T, p Provider, rootPEM string, intermediatePEMs []string) {
csrPEM, _ := connect.TestCSR(t, connect.TestSpiffeIDService(t, "testsvc"))
csr, err := connect.ParseCSR(csrPEM)
require.NoError(t, err)
leafPEM, err := p.Sign(csr)
require.NoError(t, err)
err = connect.ValidateLeaf(rootPEM, leafPEM, intermediatePEMs)
require.NoError(t, err)
}
func TestAWSBootstrapAndSignSecondary(t *testing.T) {
// Note not parallel since we could easily hit AWS limits of too many CAs if
// all of these tests run at once.
if skipIfAWSNotConfigured(t) {
return
}
p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup()
rootPEM, err := p1.ActiveRoot()
require.NoError(t, err)
p2 := testAWSProvider(t, testProviderConfigSecondary(t, nil))
defer p2.Cleanup()
testSignIntermediateCrossDC(t, p1, p2)
// Fetch intermediate from s2 now for later comparison
intPEM, err := p2.ActiveIntermediate()
require.NoError(t, err)
// Capture the state of the providers we've setup
p1State, err := p1.State()
require.NoError(t, err)
p2State, err := p2.State()
require.NoError(t, err)
// TEST LOAD FROM PREVIOUS STATE
{
// Now create new providers fromthe state of the first ones simulating
// leadership change in both DCs
t.Log("Restarting Providers with State")
// Create new provider instances
cfg1 := testProviderConfigPrimary(t, nil)
cfg1.State = p1State
p1 = testAWSProvider(t, cfg1)
newRootPEM, err := p1.ActiveRoot()
require.NoError(t, err)
cfg2 := testProviderConfigPrimary(t, nil)
cfg2.State = p2State
p2 = testAWSProvider(t, cfg2)
// Need call ActiveIntermediate like leader would to trigger loading from PCA
newIntPEM, err := p2.ActiveIntermediate()
require.NoError(t, err)
// Root cert should not have changed
require.Equal(t, rootPEM, newRootPEM)
// Secondary intermediate cert should not have changed
require.NoError(t, err)
require.Equal(t, rootPEM, newRootPEM)
require.Equal(t, intPEM, newIntPEM)
// Should both be able to sign leafs again
testSignAndValidate(t, p1, rootPEM, nil)
testSignAndValidate(t, p2, rootPEM, []string{intPEM})
}
// Since we have CAs created, test the use-case where User supplied CAs are
// used.
{
t.Log("Starting up Providers with ExistingARNs")
// Create new provider instances with config
cfg1 := testProviderConfigPrimary(t, map[string]interface{}{
"ExistingARN": p1State[AWSStateCAARNKey],
})
p1 = testAWSProvider(t, cfg1)
newRootPEM, err := p1.ActiveRoot()
require.NoError(t, err)
cfg2 := testProviderConfigPrimary(t, map[string]interface{}{
"ExistingARN": p2State[AWSStateCAARNKey],
})
cfg1.RawConfig["ExistingARN"] = p2State[AWSStateCAARNKey]
p2 = testAWSProvider(t, cfg2)
// Need call ActiveIntermediate like leader would to trigger loading from PCA
newIntPEM, err := p2.ActiveIntermediate()
require.NoError(t, err)
// Root cert should not have changed
require.Equal(t, rootPEM, newRootPEM)
// Secondary intermediate cert should not have changed
require.NoError(t, err)
require.Equal(t, rootPEM, newRootPEM)
require.Equal(t, intPEM, newIntPEM)
// Should both be able to sign leafs again
testSignAndValidate(t, p1, rootPEM, nil)
testSignAndValidate(t, p2, rootPEM, []string{intPEM})
}
}
func TestAWSBootstrapAndSignSecondaryConsul(t *testing.T) {
// Note not parallel since we could easily hit AWS limits of too many CAs if
// all of these tests run at once.
if skipIfAWSNotConfigured(t) {
return
}
t.Run("pri=consul,sec=aws", func(t *testing.T) {
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
p1 := TestConsulProvider(t, delegate)
cfg := testProviderConfig(conf)
require.NoError(t, p1.Configure(cfg))
require.NoError(t, p1.GenerateRoot())
p2 := testAWSProvider(t, testProviderConfigSecondary(t, nil))
defer p2.Cleanup()
testSignIntermediateCrossDC(t, p1, p2)
})
t.Run("pri=aws,sec=consul", func(t *testing.T) {
p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup()
require.NoError(t, p1.GenerateRoot())
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
p2 := TestConsulProvider(t, delegate)
cfg := testProviderConfig(conf)
cfg.IsPrimary = false
cfg.Datacenter = "dc2"
require.NoError(t, p2.Configure(cfg))
testSignIntermediateCrossDC(t, p1, p2)
})
}
func TestAWSNoCrossSigning(t *testing.T) {
if skipIfAWSNotConfigured(t) {
return
}
p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup()
// Don't bother initializing a PCA as that is slow and unnecessary for this
// test
ok, err := p1.SupportsCrossSigning()
require.NoError(t, err)
require.False(t, ok)
// Attempt to cross sign a CA should fail with sensible error
ca := connect.TestCA(t, nil)
caCert, err := connect.ParseCert(ca.RootCert)
require.NoError(t, err)
_, err = p1.CrossSignCA(caCert)
require.Error(t, err)
require.Contains(t, err.Error(), "not implemented")
}
func testAWSProvider(t *testing.T, cfg ProviderConfig) *AWSProvider {
p := &AWSProvider{}
p.SetLogger(log.New(&testLogger{t}, "", log.LstdFlags))
require.NoError(t, p.Configure(cfg))
return p
}
type testLogger struct {
t *testing.T
}
func (l *testLogger) Write(b []byte) (int, error) {
l.t.Log(string(b))
return len(b), nil
}
func testProviderConfigPrimary(t *testing.T, cfg map[string]interface{}) ProviderConfig {
rawCfg := make(map[string]interface{})
for k, v := range cfg {
rawCfg[k] = v
}
rawCfg["DeleteOnExit"] = true
return ProviderConfig{
ClusterID: connect.TestClusterID,
Datacenter: "dc1",
IsPrimary: true,
RawConfig: rawCfg,
}
}
func testProviderConfigSecondary(t *testing.T, cfg map[string]interface{}) ProviderConfig {
c := testProviderConfigPrimary(t, cfg)
c.IsPrimary = false
c.Datacenter = "dc2"
return c
}

View File

@ -63,7 +63,7 @@ func (c *ConsulProvider) Configure(cfg ProviderConfig) error {
c.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: c.clusterID})
// Passthrough test state for state handling tests. See testState doc.
c.parseTestState(cfg.RawConfig)
c.parseTestState(cfg.RawConfig, cfg.State)
// Exit early if the state store has an entry for this provider's config.
_, providerState, err := c.Delegate.State().CAProviderState(c.id)
@ -426,20 +426,11 @@ func (c *ConsulProvider) SignIntermediate(csr *x509.CertificateRequest) (string,
return "", err
}
if uriCount := len(csr.URIs); uriCount != 1 {
return "", fmt.Errorf("incoming CSR has unexpected number of URIs: %d", uriCount)
}
certURI, err := connect.ParseCertURI(csr.URIs[0])
err = validateSignIntermediate(csr, c.spiffeID)
if err != nil {
return "", err
}
// Verify that the trust domain is valid.
if !c.spiffeID.CanSign(certURI) {
return "", fmt.Errorf("incoming CSR domain %q is not valid for our domain %q",
certURI.URI().String(), c.spiffeID.URI().String())
}
// Get the signing private key.
signer, err := connect.ParseSigner(providerState.PrivateKey)
if err != nil {
@ -671,7 +662,7 @@ func (c *ConsulProvider) SetLogger(logger *log.Logger) {
c.logger = logger
}
func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}) {
func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}, state map[string]string) {
c.testState = nil
if rawTestState, ok := rawConfig["test_state"]; ok {
if ts, ok := rawTestState.(map[string]string); ok {
@ -680,11 +671,12 @@ func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}) {
}
// Secondary's config takes a trip through the state store before Configure
// is called unlike primaries. This means we end up with map[string]string
// encoded as map[string]interface{}. Just handle that case and ignore the
// rest as this is test-only code and we'd rather not leave a way to error
// CA setup and leave cluster unavailable in prod by accidentally setting a
// bad test_state config.
// is called and RPC calls that msgpack encode also have the same effect. It
// means we end up with map[string]string encoded as map[string]interface{}.
// We just handle that case. There is no struct error handling because this
// is test-only code (undocumented config key) and we'd rather not leave a
// way to error CA setup and leave cluster unavailable in prod by
// accidentally setting a bad test_state config.
if ts, ok := rawTestState.(map[string]interface{}); ok {
c.testState = make(map[string]string)
for k, v := range ts {
@ -694,4 +686,11 @@ func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}) {
}
}
}
// If config didn't explicitly specify test_state to return, but there is some
// actual state from a previous provider. Just use that since that is expected
// behavior that providers with state would preserve the state they are passed
// in the common case.
if len(state) > 0 && c.testState == nil {
c.testState = state
}
}

View File

@ -333,8 +333,13 @@ func (v *VaultProvider) Sign(csr *x509.CertificateRequest) (string, error) {
// SignIntermediate returns a signed CA certificate with a path length constraint
// of 0 to ensure that the certificate cannot be used to generate further CA certs.
func (v *VaultProvider) SignIntermediate(csr *x509.CertificateRequest) (string, error) {
err := validateSignIntermediate(csr, v.spiffeID)
if err != nil {
return "", err
}
var pemBuf bytes.Buffer
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw})
err = pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw})
if err != nil {
return "", err
}

View File

@ -24,7 +24,7 @@ var invalidDNSNameChars = regexp.MustCompile(`[^a-z0-9]`)
// name in that DC.
//
// Format is:
// <sanitized_service_name>.svc.<trust-domain-first-8>.consul
// <sanitized_service_name>.svc.<trust_domain_first_8>.consul
//
// service name is sanitized by removing any chars that are not legal in a DNS
// name and lower casing. It is truncated to the first X chars to keep the
@ -42,7 +42,7 @@ func ServiceCN(serviceName, trustDomain string) string {
// more details on rationale.
//
// Format is:
// <sanitized_node_name>.agnt.<trust-domain-first-8>.consul
// <sanitized_node_name>.agnt.<trust_domain_first_8>.consul
//
// node name is sanitized by removing any chars that are not legal in a DNS
// name and lower casing. It is truncated to the first X chars to keep the
@ -82,7 +82,7 @@ func CompactUID() (string, error) {
// details on rationale. A uniqueID is requires because some providers (e.g.
// Vault) cache by subject and so produce incorrect results - for example they
// won't cross-sign an older CA certificate with the same common name since they
// think they already have a valid cert for than CN and just return the current
// think they already have a valid cert for that CN and just return the current
// root.
//
// This can be generated by any means but will be truncated to 8 chars and
@ -90,20 +90,20 @@ func CompactUID() (string, error) {
// specific purpose.
//
// Format is:
// {provider}-{uniqueID-8-b36}.{root|int}.ca.<trust-domain-first-8>.consul
// {provider}-{uniqueID_first8}.{pri|sec}.ca.<trust_domain_first_8>.consul
//
// trust domain is truncated to keep the whole name short
func CACN(provider, uniqueID, trustDomain string, isRoot bool) string {
func CACN(provider, uniqueID, trustDomain string, primaryDC bool) string {
providerSan := invalidDNSNameChars.ReplaceAllString(strings.ToLower(provider), "")
typ := "root"
if !isRoot {
typ = "int"
typ := "pri"
if !primaryDC {
typ = "sec"
}
// 33 = 7 bytes for ".consul", 8 bytes for trust domain, 9 bytes for
// ".root.ca.", 9 bytes for "-{uniqueID-8-b36}"
// 32 = 7 bytes for ".consul", 8 bytes for trust domain, 8 bytes for
// ".pri.ca.", 9 bytes for "-{uniqueID-8-b36}"
uidSAN := invalidDNSNameChars.ReplaceAllString(strings.ToLower(uniqueID), "")
return fmt.Sprintf("%s-%s.%s.ca.%s.consul", typ, truncateTo(uidSAN, 8),
truncateTo(providerSan, 64-33), truncateTo(trustDomain, 8))
truncateTo(providerSan, 64-32), truncateTo(trustDomain, 8))
}
func truncateTo(s string, n int) string {

View File

@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/structs"
)
@ -69,5 +70,11 @@ func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *ht
var reply interface{}
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
if err != nil && err.Error() == consul.ErrStateReadOnly.Error() {
return nil, BadRequestError{
Reason: "Provider State is read-only. It must be omitted" +
" or identical to the current value",
}
}
return nil, err
}

View File

@ -63,10 +63,11 @@ func TestConnectCAConfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
body string
wantErr bool
wantCfg structs.CAConfiguration
name string
initialState string
body string
wantErr bool
wantCfg structs.CAConfiguration
}{
{
name: "basic",
@ -137,13 +138,64 @@ func TestConnectCAConfig(t *testing.T) {
ForceWithoutCrossSigning: true,
},
},
{
name: "setting state fails",
body: `
{
"Provider": "consul",
"State": {
"foo": "bar"
}
}`,
wantErr: true,
},
{
name: "updating config with same state",
initialState: `foo = "bar"`,
body: `
{
"Provider": "consul",
"config": {
"LeafCertTTL": "72h",
"RotationPeriod": "1h"
},
"State": {
"foo": "bar"
}
}`,
wantErr: false,
wantCfg: structs.CAConfiguration{
Provider: "consul",
ClusterID: connect.TestClusterID,
Config: map[string]interface{}{
"LeafCertTTL": "72h",
"RotationPeriod": "1h",
},
State: map[string]string{
"foo": "bar",
},
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
a := NewTestAgent(t, t.Name(), "")
hcl := ""
if tc.initialState != "" {
hcl = `
connect {
enabled = true
ca_provider = "consul"
ca_config {
test_state {
` + tc.initialState + `
}
}
}`
}
a := NewTestAgent(t, t.Name(), hcl)
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")

View File

@ -34,6 +34,7 @@ var (
ErrConnectNotEnabled = errors.New("Connect must be enabled in order to use this endpoint")
ErrRateLimited = errors.New("Rate limit reached, try again later")
ErrNotPrimaryDatacenter = errors.New("not the primary datacenter")
ErrStateReadOnly = errors.New("CA Provider State is read-only")
)
const (
@ -163,6 +164,13 @@ func (s *ConnectCA) ConfigurationSet(
return err
}
// Don't allow state changes. Either it needs to be empty or the same to allow
// read-modify-write loops that don't touch the State field.
if len(args.Config.State) > 0 &&
!reflect.DeepEqual(args.Config.State, config.State) {
return ErrStateReadOnly
}
// Don't allow users to change the ClusterID.
args.Config.ClusterID = config.ClusterID
if args.Config.Provider == config.Provider && reflect.DeepEqual(args.Config.Config, config.Config) {
@ -519,6 +527,9 @@ func (s *ConnectCA) Sign(
// All seems to be in order, actually sign it.
pem, err := provider.Sign(csr)
if err == ca.ErrRateLimited {
return ErrRateLimited
}
if err != nil {
return err
}

View File

@ -109,6 +109,8 @@ func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, e
p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}}
case structs.VaultCAProvider:
p = &ca.VaultProvider{}
case structs.AWSCAProvider:
p = &ca.AWSProvider{}
default:
return nil, fmt.Errorf("unknown CA provider %q", conf.Provider)
}

View File

@ -216,6 +216,7 @@ func (q *CARequest) RequestDatacenter() string {
const (
ConsulCAProvider = "consul"
VaultCAProvider = "vault"
AWSCAProvider = "aws-pca"
)
// CAConfiguration is the configuration for the current CA plugin.
@ -460,6 +461,13 @@ type VaultCAProviderConfig struct {
TLSSkipVerify bool
}
type AWSCAProviderConfig struct {
CommonCAProviderConfig `mapstructure:",squash"`
ExistingARN string
DeleteOnExit bool
}
// CALeafOp is the operation for a request related to leaf certificates.
type CALeafOp string

View File

@ -17,6 +17,12 @@ type CAConfig struct {
// and maps).
Config map[string]interface{}
// State is read-only data that the provider might have persisted for use
// after restart or leadership transition. For example this might include
// UUIDs of resources it has created. Setting this when writing a
// configuration is an error.
State map[string]string
CreateIndex uint64
ModifyIndex uint64
}

View File

@ -83,6 +83,9 @@ func TestAPI_ConnectCAConfig_get_set(t *testing.T) {
// Change a config value and update
conf.Config["PrivateKey"] = ""
conf.Config["RotationPeriod"] = 120 * 24 * time.Hour
// Pass through some state as if the provider stored it so we can make sure
// we can read it again.
conf.Config["test_state"] = map[string]string{"foo": "bar"}
_, err = connect.CASetConfig(conf, nil)
r.Check(err)
@ -92,5 +95,6 @@ func TestAPI_ConnectCAConfig_get_set(t *testing.T) {
parsed, err = ParseConsulCAConfig(updated.Config)
r.Check(err)
require.Equal(r, expected, parsed)
require.Equal(r, "bar", updated.State["foo"])
})
}

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e
github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878
github.com/armon/go-radix v1.0.0
github.com/aws/aws-sdk-go v1.25.32
github.com/coredns/coredns v1.1.2
github.com/digitalocean/godo v1.10.0 // indirect
github.com/docker/go-connections v0.3.0

4
go.sum
View File

@ -28,6 +28,8 @@ github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/aws/aws-sdk-go v1.15.24 h1:xLAdTA/ore6xdPAljzZRed7IGqQgC+nY+ERS5vaj4Ro=
github.com/aws/aws-sdk-go v1.15.24/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.25.32 h1:GhqlDvuPXnlW46VoKvfLZkJj5IA6jGLO+/TUPCJSYOY=
github.com/aws/aws-sdk-go v1.25.32/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
@ -207,6 +209,8 @@ github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da h1:FjHUJJ7oBW4G/9
github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da/go.mod h1:ks+b9deReOc7jgqp+e7LuFiCBH6Rm5hL32cLcEAArb4=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 h1:JHCT6xuyPUrbbgAPE/3dqlvUKzRHMNuTBKKUb6OeR/k=
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62/go.mod h1:U+RSyWxWd04xTqnuOQxnai7XGS2PrPY2cfGoDKtMHjA=
github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE=

View File

@ -1,3 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

View File

@ -138,8 +138,27 @@ type RequestFailure interface {
RequestID() string
}
// NewRequestFailure returns a new request error wrapper for the given Error
// provided.
// NewRequestFailure returns a wrapped error with additional information for
// request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID)
}
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
awsError
Bytes() []byte
}
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
return &unmarshalError{
awsError: New("UnmarshalError", msg, err),
bytes: bytes,
}
}

View File

@ -1,6 +1,9 @@
package awserr
import "fmt"
import (
"encoding/hex"
"fmt"
)
// SprintError returns a string of the formatted error code.
//
@ -119,6 +122,7 @@ type requestError struct {
awsError
statusCode int
requestID string
bytes []byte
}
// newRequestError returns a wrapped error with additional information for
@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
return []error{r.OrigErr()}
}
type unmarshalError struct {
awsError
bytes []byte
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
extra := hex.Dump(e.bytes)
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
return e.Error()
}
// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
return e.bytes
}
// An error list that satisfies the golang interface
type errorList []error
@ -181,7 +208,7 @@ func (e errorList) Error() string {
// How do we want to handle the array size being zero
if size := len(e); size > 0 {
for i := 0; i < size; i++ {
msg += fmt.Sprintf("%s", e[i].Error())
msg += e[i].Error()
// We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n'

View File

@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type the are equal
// If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid {

View File

@ -185,13 +185,12 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
// SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil {
for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue
}
setValue(rval, v)
rvals := rValuesAtPath(i, path, true, false, v == nil)
for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue
}
setValue(rval, v)
}
}

View File

@ -23,28 +23,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
case reflect.Struct:
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
ft := v.Type().Field(i)
fv := v.Field(i)
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
buf.WriteString(ft.Name + ": ")
if i < len(names)-1 {
buf.WriteString(",\n")
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
stringValue(fv, indent+2, buf)
}
buf.WriteString(",\n")
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")

View File

@ -12,13 +12,14 @@ import (
type Config struct {
Config *aws.Config
Handlers request.Handlers
PartitionID string
Endpoint string
SigningRegion string
SigningName string
// States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors
// to determine if the signin name can be overriden based on metadata the
// to determine if the signin name can be overridden based on metadata the
// service has.
SigningNameDerived bool
}
@ -64,7 +65,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
default:
maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3
maxRetries = DefaultRetryerMaxNumRetries
}
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
}

View File

@ -1,6 +1,7 @@
package client
import (
"math"
"strconv"
"time"
@ -9,82 +10,142 @@ import (
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the
// request.Retryer interface or create a structure type that composes this
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
// most services. If you want to implement custom retry logic, you can implement the
// request.Retryer interface.
//
// type retryer struct {
// client.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() int { return 100 }
type DefaultRetryer struct {
// Num max Retries is the number of max retries that will be performed.
// By default, this is zero.
NumMaxRetries int
// MinRetryDelay is the minimum retry delay after which retry will be performed.
// If not set, the value is 0ns.
MinRetryDelay time.Duration
// MinThrottleRetryDelay is the minimum retry delay when throttled.
// If not set, the value is 0ns.
MinThrottleDelay time.Duration
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
// If not set, the value is 0ns.
MaxRetryDelay time.Duration
// MaxThrottleDelay is the maximum retry delay when throttled.
// If not set, the value is 0ns.
MaxThrottleDelay time.Duration
}
const (
// DefaultRetryerMaxNumRetries sets maximum number of retries
DefaultRetryerMaxNumRetries = 3
// DefaultRetryerMinRetryDelay sets minimum retry delay
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
// DefaultRetryerMaxRetryDelay sets maximum retry delay
DefaultRetryerMaxRetryDelay = 300 * time.Second
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
DefaultRetryerMaxThrottleDelay = 300 * time.Second
)
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
// setRetryerDefaults sets the default values of the retryer if not set
func (d *DefaultRetryer) setRetryerDefaults() {
if d.MinRetryDelay == 0 {
d.MinRetryDelay = DefaultRetryerMinRetryDelay
}
if d.MaxRetryDelay == 0 {
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
}
if d.MinThrottleDelay == 0 {
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
}
if d.MaxThrottleDelay == 0 {
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
}
}
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
if delay, ok := getRetryDelay(r); ok {
return delay
}
minTime = 500
// if number of max retries is zero, no retries will be performed.
if d.NumMaxRetries == 0 {
return 0
}
// Sets default value for retryer members
d.setRetryerDefaults()
// minDelay is the minimum retryer delay
minDelay := d.MinRetryDelay
var initialDelay time.Duration
isThrottle := r.IsErrorThrottle()
if isThrottle {
if delay, ok := getRetryAfterDelay(r); ok {
initialDelay = delay
}
minDelay = d.MinThrottleDelay
}
retryCount := r.RetryCount
if throttle && retryCount > 8 {
retryCount = 8
} else if retryCount > 13 {
retryCount = 13
// maxDelay the maximum retryer delay
maxDelay := d.MaxRetryDelay
if isThrottle {
maxDelay = d.MaxThrottleDelay
}
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond
var delay time.Duration
// Logic to cap the retry count based on the minDelay provided
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
if actualRetryCount < 63-retryCount {
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
if delay > maxDelay {
delay = getJitterDelay(maxDelay / 2)
}
} else {
delay = getJitterDelay(maxDelay / 2)
}
return delay + initialDelay
}
// getJitterDelay returns a jittered delay for retry
func getJitterDelay(duration time.Duration) time.Duration {
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
}
// ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldRetry returns false if number of max retries is 0.
if d.NumMaxRetries == 0 {
return false
}
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable != nil {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return true
return r.IsErrorRetryable() || r.IsErrorThrottle()
}
// This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) {
func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) {
return 0, false
}

View File

@ -67,10 +67,14 @@ func logRequest(r *request.Request) {
if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
}
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
// Reset the request body because dumpRequest will re-wrap the
// r.HTTPRequest's Body as a NoOpCloser and will not be reset after
// read by the HTTP client reader.
if err := r.Error; err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
@ -118,6 +122,12 @@ var LogHTTPResponseHandler = request.NamedHandler{
func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
if r.HTTPResponse == nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody {
r.HTTPResponse.Body = &teeReaderCloser{

View File

@ -5,6 +5,7 @@ type ClientInfo struct {
ServiceName string
ServiceID string
APIVersion string
PartitionID string
Endpoint string
SigningName string
SigningRegion string

View File

@ -0,0 +1,28 @@
package client
import (
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
// NoOpRetryer provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type NoOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d NoOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d NoOpRetryer) ShouldRetry(_ *request.Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d NoOpRetryer) RetryRules(_ *request.Request) time.Duration {
return 0
}

View File

@ -18,9 +18,9 @@ const UseServiceDefaultRetries = -1
type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig tructure.
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
@ -45,8 +45,8 @@ type Config struct {
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
// Note: You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The resolver to use for looking up endpoints for AWS service clients
@ -65,8 +65,8 @@ type Config struct {
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
// Regions and Endpoints.
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
@ -120,9 +120,10 @@ type Config struct {
// will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
// Note: This configuration option is specific to the Amazon S3 service.
//
// See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// for Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
@ -223,12 +224,37 @@ type Config struct {
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// EnableEndpointDiscovery: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("/foo/bar/moo"),
// })
EnableEndpointDiscovery *bool
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
// request endpoint hosts with modeled information.
//
// Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool
// STSRegionalEndpoint will enable regional or legacy endpoint resolving
STSRegionalEndpoint endpoints.STSRegionalEndpoint
}
// NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
@ -377,6 +403,19 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
return c
}
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
func (c *Config) WithEndpointDiscovery(t bool) *Config {
c.EnableEndpointDiscovery = &t
return c
}
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
// when making requests.
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
c.DisableEndpointHostPrefix = &t
return c
}
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
@ -384,6 +423,13 @@ func (c *Config) MergeIn(cfgs ...*Config) {
}
}
// WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config {
c.STSRegionalEndpoint = sre
return c
}
func mergeInConfig(dst *Config, other *Config) {
if other == nil {
return
@ -476,6 +522,18 @@ func mergeInConfig(dst *Config, other *Config) {
if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
}
if other.EnableEndpointDiscovery != nil {
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
}
if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
}
if other.STSRegionalEndpoint != endpoints.UnsetSTSEndpoint {
dst.STSRegionalEndpoint = other.STSRegionalEndpoint
}
}
// Copy will return a shallow copy of the Config object. If any additional

View File

@ -1,8 +1,8 @@
// +build !go1.9
package aws
import (
"time"
)
import "time"
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
@ -35,37 +35,3 @@ type Context interface {
// functions.
Value(key interface{}) interface{}
}
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

View File

@ -1,9 +0,0 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)

11
vendor/github.com/aws/aws-sdk-go/aws/context_1_9.go generated vendored Normal file
View File

@ -0,0 +1,11 @@
// +build go1.9
package aws
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context

View File

@ -39,3 +39,18 @@ func (e *emptyCtx) String() string {
var (
backgroundCtx = new(emptyCtx)
)
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}

View File

@ -0,0 +1,20 @@
// +build go1.7
package aws
import "context"
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return context.Background()
}

24
vendor/github.com/aws/aws-sdk-go/aws/context_sleep.go generated vendored Normal file
View File

@ -0,0 +1,24 @@
package aws
import (
"time"
)
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

View File

@ -179,6 +179,242 @@ func IntValueMap(src map[string]*int) map[string]int {
return dst
}
// Uint returns a pointer to the uint value passed in.
func Uint(v uint) *uint {
return &v
}
// UintValue returns the value of the uint pointer passed in or
// 0 if the pointer is nil.
func UintValue(v *uint) uint {
if v != nil {
return *v
}
return 0
}
// UintSlice converts a slice of uint values uinto a slice of
// uint pointers
func UintSlice(src []uint) []*uint {
dst := make([]*uint, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// UintValueSlice converts a slice of uint pointers uinto a slice of
// uint values
func UintValueSlice(src []*uint) []uint {
dst := make([]uint, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// UintMap converts a string map of uint values uinto a string
// map of uint pointers
func UintMap(src map[string]uint) map[string]*uint {
dst := make(map[string]*uint)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// UintValueMap converts a string map of uint pointers uinto a string
// map of uint values
func UintValueMap(src map[string]*uint) map[string]uint {
dst := make(map[string]uint)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int8 returns a pointer to the int8 value passed in.
func Int8(v int8) *int8 {
return &v
}
// Int8Value returns the value of the int8 pointer passed in or
// 0 if the pointer is nil.
func Int8Value(v *int8) int8 {
if v != nil {
return *v
}
return 0
}
// Int8Slice converts a slice of int8 values into a slice of
// int8 pointers
func Int8Slice(src []int8) []*int8 {
dst := make([]*int8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int8ValueSlice converts a slice of int8 pointers into a slice of
// int8 values
func Int8ValueSlice(src []*int8) []int8 {
dst := make([]int8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int8Map converts a string map of int8 values into a string
// map of int8 pointers
func Int8Map(src map[string]int8) map[string]*int8 {
dst := make(map[string]*int8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int8ValueMap converts a string map of int8 pointers into a string
// map of int8 values
func Int8ValueMap(src map[string]*int8) map[string]int8 {
dst := make(map[string]int8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int16 returns a pointer to the int16 value passed in.
func Int16(v int16) *int16 {
return &v
}
// Int16Value returns the value of the int16 pointer passed in or
// 0 if the pointer is nil.
func Int16Value(v *int16) int16 {
if v != nil {
return *v
}
return 0
}
// Int16Slice converts a slice of int16 values into a slice of
// int16 pointers
func Int16Slice(src []int16) []*int16 {
dst := make([]*int16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int16ValueSlice converts a slice of int16 pointers into a slice of
// int16 values
func Int16ValueSlice(src []*int16) []int16 {
dst := make([]int16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int16Map converts a string map of int16 values into a string
// map of int16 pointers
func Int16Map(src map[string]int16) map[string]*int16 {
dst := make(map[string]*int16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int16ValueMap converts a string map of int16 pointers into a string
// map of int16 values
func Int16ValueMap(src map[string]*int16) map[string]int16 {
dst := make(map[string]int16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int32 returns a pointer to the int32 value passed in.
func Int32(v int32) *int32 {
return &v
}
// Int32Value returns the value of the int32 pointer passed in or
// 0 if the pointer is nil.
func Int32Value(v *int32) int32 {
if v != nil {
return *v
}
return 0
}
// Int32Slice converts a slice of int32 values into a slice of
// int32 pointers
func Int32Slice(src []int32) []*int32 {
dst := make([]*int32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int32ValueSlice converts a slice of int32 pointers into a slice of
// int32 values
func Int32ValueSlice(src []*int32) []int32 {
dst := make([]int32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int32Map converts a string map of int32 values into a string
// map of int32 pointers
func Int32Map(src map[string]int32) map[string]*int32 {
dst := make(map[string]*int32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int32ValueMap converts a string map of int32 pointers into a string
// map of int32 values
func Int32ValueMap(src map[string]*int32) map[string]int32 {
dst := make(map[string]int32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int64 returns a pointer to the int64 value passed in.
func Int64(v int64) *int64 {
return &v
@ -238,6 +474,301 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
return dst
}
// Uint8 returns a pointer to the uint8 value passed in.
func Uint8(v uint8) *uint8 {
return &v
}
// Uint8Value returns the value of the uint8 pointer passed in or
// 0 if the pointer is nil.
func Uint8Value(v *uint8) uint8 {
if v != nil {
return *v
}
return 0
}
// Uint8Slice converts a slice of uint8 values into a slice of
// uint8 pointers
func Uint8Slice(src []uint8) []*uint8 {
dst := make([]*uint8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint8ValueSlice converts a slice of uint8 pointers into a slice of
// uint8 values
func Uint8ValueSlice(src []*uint8) []uint8 {
dst := make([]uint8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint8Map converts a string map of uint8 values into a string
// map of uint8 pointers
func Uint8Map(src map[string]uint8) map[string]*uint8 {
dst := make(map[string]*uint8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint8ValueMap converts a string map of uint8 pointers into a string
// map of uint8 values
func Uint8ValueMap(src map[string]*uint8) map[string]uint8 {
dst := make(map[string]uint8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint16 returns a pointer to the uint16 value passed in.
func Uint16(v uint16) *uint16 {
return &v
}
// Uint16Value returns the value of the uint16 pointer passed in or
// 0 if the pointer is nil.
func Uint16Value(v *uint16) uint16 {
if v != nil {
return *v
}
return 0
}
// Uint16Slice converts a slice of uint16 values into a slice of
// uint16 pointers
func Uint16Slice(src []uint16) []*uint16 {
dst := make([]*uint16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint16ValueSlice converts a slice of uint16 pointers into a slice of
// uint16 values
func Uint16ValueSlice(src []*uint16) []uint16 {
dst := make([]uint16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint16Map converts a string map of uint16 values into a string
// map of uint16 pointers
func Uint16Map(src map[string]uint16) map[string]*uint16 {
dst := make(map[string]*uint16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint16ValueMap converts a string map of uint16 pointers into a string
// map of uint16 values
func Uint16ValueMap(src map[string]*uint16) map[string]uint16 {
dst := make(map[string]uint16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint32 returns a pointer to the uint32 value passed in.
func Uint32(v uint32) *uint32 {
return &v
}
// Uint32Value returns the value of the uint32 pointer passed in or
// 0 if the pointer is nil.
func Uint32Value(v *uint32) uint32 {
if v != nil {
return *v
}
return 0
}
// Uint32Slice converts a slice of uint32 values into a slice of
// uint32 pointers
func Uint32Slice(src []uint32) []*uint32 {
dst := make([]*uint32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint32ValueSlice converts a slice of uint32 pointers into a slice of
// uint32 values
func Uint32ValueSlice(src []*uint32) []uint32 {
dst := make([]uint32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint32Map converts a string map of uint32 values into a string
// map of uint32 pointers
func Uint32Map(src map[string]uint32) map[string]*uint32 {
dst := make(map[string]*uint32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint32ValueMap converts a string map of uint32 pointers into a string
// map of uint32 values
func Uint32ValueMap(src map[string]*uint32) map[string]uint32 {
dst := make(map[string]uint32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint64 returns a pointer to the uint64 value passed in.
func Uint64(v uint64) *uint64 {
return &v
}
// Uint64Value returns the value of the uint64 pointer passed in or
// 0 if the pointer is nil.
func Uint64Value(v *uint64) uint64 {
if v != nil {
return *v
}
return 0
}
// Uint64Slice converts a slice of uint64 values into a slice of
// uint64 pointers
func Uint64Slice(src []uint64) []*uint64 {
dst := make([]*uint64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint64ValueSlice converts a slice of uint64 pointers into a slice of
// uint64 values
func Uint64ValueSlice(src []*uint64) []uint64 {
dst := make([]uint64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint64Map converts a string map of uint64 values into a string
// map of uint64 pointers
func Uint64Map(src map[string]uint64) map[string]*uint64 {
dst := make(map[string]*uint64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint64ValueMap converts a string map of uint64 pointers into a string
// map of uint64 values
func Uint64ValueMap(src map[string]*uint64) map[string]uint64 {
dst := make(map[string]uint64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float32 returns a pointer to the float32 value passed in.
func Float32(v float32) *float32 {
return &v
}
// Float32Value returns the value of the float32 pointer passed in or
// 0 if the pointer is nil.
func Float32Value(v *float32) float32 {
if v != nil {
return *v
}
return 0
}
// Float32Slice converts a slice of float32 values into a slice of
// float32 pointers
func Float32Slice(src []float32) []*float32 {
dst := make([]*float32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Float32ValueSlice converts a slice of float32 pointers into a slice of
// float32 values
func Float32ValueSlice(src []*float32) []float32 {
dst := make([]float32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Float32Map converts a string map of float32 values into a string
// map of float32 pointers
func Float32Map(src map[string]float32) map[string]*float32 {
dst := make(map[string]*float32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Float32ValueMap converts a string map of float32 pointers into a string
// map of float32 values
func Float32ValueMap(src map[string]*float32) map[string]float32 {
dst := make(map[string]float32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float64 returns a pointer to the float64 value passed in.
func Float64(v float64) *float64 {
return &v

View File

@ -72,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
signedTime = r.LastSignedAt
}
// 10 minutes to allow for some clock skew/delays in transmission.
// 5 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) {
if signedTime.Add(5 * time.Minute).After(time.Now()) {
return
}
@ -159,9 +159,9 @@ func handleSendError(r *request.Request, err error) {
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
// Catch all request errors, and let the default retrier determine
// if the error is retryable.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
// Override the error with a context canceled error, if that was canceled.
ctx := r.Context()
@ -184,37 +184,39 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.RetryRules(r)
if sleepFn := r.Config.SleepDelay; sleepFn != nil {
// Support SleepDelay for backwards compatibility and testing
sleepFn(r.RetryDelay)
} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
r.Error = awserr.New(request.CanceledErrorCode,
"request context canceled", err)
r.Retryable = aws.Bool(false)
return
var AfterRetryHandler = request.NamedHandler{
Name: "core.AfterRetryHandler",
Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.IsErrorExpired() {
r.Config.Credentials.Expire()
}
if r.WillRetry() {
r.RetryDelay = r.RetryRules(r)
r.RetryCount++
r.Error = nil
}
}}
if sleepFn := r.Config.SleepDelay; sleepFn != nil {
// Support SleepDelay for backwards compatibility and testing
sleepFn(r.RetryDelay)
} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
r.Error = awserr.New(request.CanceledErrorCode,
"request context canceled", err)
r.Retryable = aws.Bool(false)
return
}
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.IsErrorExpired() {
r.Config.Credentials.Expire()
}
r.RetryCount++
r.Error = nil
}
}}
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or

View File

@ -17,7 +17,7 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
}
const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec_env`
const execEnvUAKey = `exec-env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent.

View File

@ -9,9 +9,7 @@ var (
// providers in the ChainProvider.
//
// This has been deprecated. For verbose error messaging set
// aws.Config.CredentialsChainVerboseErrors to true
//
// @readonly
// aws.Config.CredentialsChainVerboseErrors to true.
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,

View File

@ -49,8 +49,11 @@
package credentials
import (
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// AnonymousCredentials is an empty Credential object that can be used as
@ -64,8 +67,6 @@ import (
// Credentials: credentials.AnonymousCredentials,
// })))
// // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.
@ -83,6 +84,12 @@ type Value struct {
ProviderName string
}
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to
// be expired means.
@ -99,6 +106,14 @@ type Provider interface {
IsExpired() bool
}
// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
type Expirer interface {
// The time at which the credentials are no longer valid
ExpiresAt() time.Time
}
// An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible
// due to an error.
@ -165,6 +180,11 @@ func (e *Expiry) IsExpired() bool {
return e.expiration.Before(curTime())
}
// ExpiresAt returns the expiration time of the credential
func (e *Expiry) ExpiresAt() time.Time {
return e.expiration
}
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
@ -257,3 +277,23 @@ func (c *Credentials) IsExpired() bool {
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
}
// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()
expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
nil)
}
if c.forceRefresh {
// set expiration time to the distant past
return time.Time{}, nil
}
return expirer.ExpiresAt(), nil
}

View File

@ -11,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
}
if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err)
return nil, awserr.New(request.ErrCodeSerialization,
"failed to read EC2 instance role from metadata service", err)
}
return credsList, nil
@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{},
awserr.New("SerializationError",
awserr.New(request.ErrCodeSerialization,
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
err)
}

View File

@ -39,6 +39,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
)
// ProviderName is the name of the credentials provider.
@ -65,6 +66,10 @@ type Provider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
}
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
@ -93,8 +98,8 @@ func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint strin
return p
}
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
// from an arbitrary endpoint concurrently. The client will request the
// NewCredentialsClient returns a pointer to a new Credentials object
// wrapping the endpoint credentials Provider.
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
}
@ -152,6 +157,9 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}
return out, req.Send()
}
@ -167,7 +175,7 @@ func unmarshalHandler(r *request.Request) {
out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError",
r.Error = awserr.New(request.ErrCodeSerialization,
"failed to decode endpoint credentials",
err,
)
@ -178,11 +186,15 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
r.Error = awserr.New("SerializationError",
"failed to decode endpoint credentials",
err,
err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to decode error message", err),
r.HTTPResponse.StatusCode,
r.RequestID,
)
return
}
// Response body format is not consistent between metadata endpoints.

View File

@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)

View File

@ -0,0 +1,425 @@
/*
Package processcreds is a credential Provider to retrieve `credential_process`
credentials.
WARNING: The following describes a method of sourcing credentials from an external
process. This can potentially be dangerous, so proceed with caution. Other
credential providers should be preferred if at all possible. If using this
option, you should make sure that the config file is as locked down as possible
using security best practices for your operating system.
You can use credentials from a `credential_process` in a variety of ways.
One way is to setup your shared config file, located in the default
location, with the `credential_process` key and the command you want to be
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.
[default]
credential_process = /command/to/call
Creating a new session will use the credential process to retrieve credentials.
NOTE: If there are credentials in the profile you are using, the credential
process will not be used.
// Initialize a session to load credentials.
sess, _ := session.NewSession(&aws.Config{
Region: aws.String("us-east-1")},
)
// Create S3 service client to use the credentials.
svc := s3.New(sess)
Another way to use the `credential_process` method is by using
`credentials.NewCredentials()` and providing a command to be executed to
retrieve credentials:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentials("/path/to/command")
// Create service client value configured for credentials.
svc := s3.New(sess, &aws.Config{Credentials: creds})
You can set a non-default timeout for the `credential_process` with another
constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To
set a one minute timeout:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentialsTimeout(
"/path/to/command",
time.Duration(500) * time.Millisecond)
If you need more control, you can set any configurable options in the
credentials using one or more option functions. For example, you can set a two
minute timeout, a credential duration of 60 minutes, and a maximum stdout
buffer size of 2k.
creds := processcreds.NewCredentials(
"/path/to/command",
func(opt *ProcessProvider) {
opt.Timeout = time.Duration(2) * time.Minute
opt.Duration = time.Duration(60) * time.Minute
opt.MaxBufSize = 2048
})
You can also use your own `exec.Cmd`:
// Create an exec.Cmd
myCommand := exec.Command("/path/to/command")
// Create credentials using your exec.Cmd and custom timeout
creds := processcreds.NewCredentialsCommand(
myCommand,
func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
*/
package processcreds
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
const (
// ProviderName is the name this credentials provider will label any
// returned credentials Value with.
ProviderName = `ProcessProvider`
// ErrCodeProcessProviderParse error parsing process output
ErrCodeProcessProviderParse = "ProcessProviderParseError"
// ErrCodeProcessProviderVersion version error in output
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
// ErrCodeProcessProviderRequired required attribute missing in output
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
// ErrCodeProcessProviderExecution execution of command failed
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
// errMsgProcessProviderTimeout process took longer than allowed
errMsgProcessProviderTimeout = "credential process timed out"
// errMsgProcessProviderProcess process error
errMsgProcessProviderProcess = "error in credential_process"
// errMsgProcessProviderParse problem parsing output
errMsgProcessProviderParse = "parse failed of credential_process output"
// errMsgProcessProviderVersion version error in output
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
// errMsgProcessProviderMissKey missing access key id in output
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
// errMsgProcessProviderMissSecret missing secret acess key in output
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
// errMsgProcessProviderPrepareCmd prepare of command failed
errMsgProcessProviderPrepareCmd = "failed to prepare command"
// errMsgProcessProviderEmptyCmd command must not be empty
errMsgProcessProviderEmptyCmd = "command must not be empty"
// errMsgProcessProviderPipe failed to initialize pipe
errMsgProcessProviderPipe = "failed to initialize pipe"
// DefaultDuration is the default amount of time in minutes that the
// credentials will be valid for.
DefaultDuration = time.Duration(15) * time.Minute
// DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process.
DefaultBufSize = 1024
// DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute
)
// ProcessProvider satisfies the credentials.Provider interface, and is a
// client to retrieve credentials from a process.
type ProcessProvider struct {
staticCreds bool
credentials.Expiry
originalCommand []string
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// A string representing an os command that should return a JSON with
// credential information.
command *exec.Cmd
// MaxBufSize limits memory usage from growing to an enormous
// amount due to a faulty process.
MaxBufSize int
// Timeout limits the time a process can run.
Timeout time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// ProcessProvider. The credentials will expire every 15 minutes by default.
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: exec.Command(command),
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsTimeout returns a pointer to a new Credentials object with
// the specified command and timeout, and default duration and max buffer size.
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
p := NewCredentials(command, func(opt *ProcessProvider) {
opt.Timeout = timeout
})
return p
}
// NewCredentialsCommand returns a pointer to a new Credentials object with
// the specified command, and default timeout, duration and max buffer size.
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: command,
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
type credentialProcessResponse struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
SessionToken string
Expiration *time.Time
}
// Retrieve executes the 'credential_process' and returns the credentials.
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
out, err := p.executeCredentialProcess()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// Serialize and validate response
resp := &credentialProcessResponse{}
if err = json.Unmarshal(out, resp); err != nil {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderParse,
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
err)
}
if resp.Version != 1 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderVersion,
errMsgProcessProviderVersion,
nil)
}
if len(resp.AccessKeyID) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissKey,
nil)
}
if len(resp.SecretAccessKey) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissSecret,
nil)
}
// Handle expiration
p.staticCreds = resp.Expiration == nil
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
}
return credentials.Value{
ProviderName: ProviderName,
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
}, nil
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *ProcessProvider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// prepareCommand prepares the command to be executed.
func (p *ProcessProvider) prepareCommand() error {
var cmdArgs []string
if runtime.GOOS == "windows" {
cmdArgs = []string{"cmd.exe", "/C"}
} else {
cmdArgs = []string{"sh", "-c"}
}
if len(p.originalCommand) == 0 {
p.originalCommand = make([]string, len(p.command.Args))
copy(p.originalCommand, p.command.Args)
// check for empty command because it succeeds
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
return awserr.New(
ErrCodeProcessProviderExecution,
fmt.Sprintf(
"%s: %s",
errMsgProcessProviderPrepareCmd,
errMsgProcessProviderEmptyCmd),
nil)
}
}
cmdArgs = append(cmdArgs, p.originalCommand...)
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
p.command.Env = os.Environ()
return nil
}
// executeCredentialProcess starts the credential process on the OS and
// returns the results or an error.
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
if err := p.prepareCommand(); err != nil {
return nil, err
}
// Setup the pipes
outReadPipe, outWritePipe, err := os.Pipe()
if err != nil {
return nil, awserr.New(
ErrCodeProcessProviderExecution,
errMsgProcessProviderPipe,
err)
}
p.command.Stderr = os.Stderr // display stderr on console for MFA
p.command.Stdout = outWritePipe // get creds json on process's stdout
p.command.Stdin = os.Stdin // enable stdin for MFA
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
stdoutCh := make(chan error, 1)
go readInput(
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
output,
stdoutCh)
execCh := make(chan error, 1)
go executeCommand(*p.command, execCh)
finished := false
var errors []error
for !finished {
select {
case readError := <-stdoutCh:
errors = appendError(errors, readError)
finished = true
case execError := <-execCh:
err := outWritePipe.Close()
errors = appendError(errors, err)
errors = appendError(errors, execError)
if errors != nil {
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderProcess,
errors)
}
case <-time.After(p.Timeout):
finished = true
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderTimeout,
errors) // errors can be nil
}
}
out := output.Bytes()
if runtime.GOOS == "windows" {
// windows adds slashes to quotes
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
}
return out, nil
}
// appendError conveniently checks for nil before appending slice
func appendError(errors []error, err error) []error {
if err != nil {
return append(errors, err)
}
return errors
}
func executeCommand(cmd exec.Cmd, exec chan error) {
// Start the command
err := cmd.Start()
if err == nil {
err = cmd.Wait()
}
exec <- err
}
func readInput(r io.Reader, w io.Writer, read chan error) {
tee := io.TeeReader(r, w)
_, err := ioutil.ReadAll(tee)
if err == io.EOF {
err = nil
}
read <- err // will only arrive here when write end of pipe is closed
}

View File

@ -4,9 +4,8 @@ import (
"fmt"
"os"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/ini"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
@ -77,36 +76,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.Load(filename)
config, err := ini.OpenFile(filename)
if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile, err := config.GetSection(profile)
if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err)
iniProfile, ok := config.GetSection(profile)
if !ok {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
}
id, err := iniProfile.GetKey("aws_access_key_id")
if err != nil {
id := iniProfile.String("aws_access_key_id")
if len(id) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err)
nil)
}
secret, err := iniProfile.GetKey("aws_secret_access_key")
if err != nil {
secret := iniProfile.String("aws_secret_access_key")
if len(secret) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}
// Default to empty string if not found
token := iniProfile.Key("aws_session_token")
token := iniProfile.String("aws_session_token")
return Value{
AccessKeyID: id.String(),
SecretAccessKey: secret.String(),
SessionToken: token.String(),
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
ProviderName: SharedCredsProviderName,
}, nil
}

View File

@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)

View File

@ -80,16 +80,18 @@ package stscreds
import (
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts"
)
// StdinTokenProvider will prompt on stdout and read from stdin for a string value.
// StdinTokenProvider will prompt on stderr and read from stdin for a string value.
// An error is returned if reading from stdin fails.
//
// Use this function go read MFA tokens from stdin. The function makes no attempt
@ -102,7 +104,7 @@ import (
// Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) {
var v string
fmt.Printf("Assume Role MFA token code: ")
fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
_, err := fmt.Scanln(&v)
return v, err
@ -193,6 +195,18 @@ type AssumeRoleProvider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// MaxJitterFrac reduces the effective Duration of each credential requested
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
// have a value between 0 and 1. Any other value may lead to expected behavior.
// With a MaxJitterFrac value of 0, default) will no jitter will be used.
//
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
// AssumeRole call will be made with an arbitrary Duration between 27m and
// 30m.
//
// MaxJitterFrac should not be negative.
MaxJitterFrac float64
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
@ -244,7 +258,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
@ -254,8 +267,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Expire as often as AWS permits.
p.Duration = DefaultDuration
}
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,

View File

@ -0,0 +1,100 @@
package stscreds
import (
"fmt"
"io/ioutil"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)
const (
// ErrCodeWebIdentity will be used as an error code when constructing
// a new error to be returned during session creation or retrieval.
ErrCodeWebIdentity = "WebIdentityErr"
// WebIdentityProviderName is the web identity provider name
WebIdentityProviderName = "WebIdentityCredentials"
)
// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = time.Now
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry
client stsiface.STSAPI
ExpiryWindow time.Duration
tokenFilePath string
roleARN string
roleSessionName string
}
// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFilePath: path,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
}
// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
}
sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
if err := req.Send(); err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
ProviderName: WebIdentityProviderName,
}
return value, nil
}

View File

@ -1,30 +1,61 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics
// via UDP connection. Using the Start function will enable the reporting of
// metrics on a given port. If Start is called, with different parameters, again,
// a panic will occur.
// Package csm provides the Client Side Monitoring (CSM) client which enables
// sending metrics via UDP connection to the CSM agent. This package provides
// control options, and configuration for the CSM client. The client can be
// controlled manually, or automatically via the SDK's Session configuration.
//
// Pause can be called to pause any metrics publishing on a given port. Sessions
// that have had their handlers modified via InjectHandlers may still be used.
// However, the handlers will act as a no-op meaning no metrics will be published.
// Enabling CSM client via SDK's Session configuration
//
// The CSM client can be enabled automatically via SDK's Session configuration.
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
// environment variable is set to a non-empty value.
//
// The configuration options for the CSM client via the SDK's session
// configuration are:
//
// * AWS_CSM_PORT=<port number>
// The port number the CSM agent will receive metrics on.
//
// * AWS_CSM_HOST=<hostname or ip>
// The hostname, or IP address the CSM agent will receive metrics on.
// Without port number.
//
// Manually enabling the CSM client
//
// The CSM client can be started, paused, and resumed manually. The Start
// function will enable the CSM client to publish metrics to the CSM agent. It
// is safe to call Start concurrently, but if Start is called additional times
// with different ClientID or address it will panic.
//
// Example:
// r, err := csm.Start("clientID", ":31000")
// if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err))
// }
//
// When controlling the CSM client manually, you must also inject its request
// handlers into the SDK's Session configuration for the SDK's API clients to
// publish metrics.
//
// sess, err := session.NewSession(&aws.Config{})
// if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err))
// }
//
// // Add CSM client's metric publishing request handlers to the SDK's
// // Session Configuration.
// r.InjectHandlers(&sess.Handlers)
//
// client := s3.New(sess)
// resp, err := client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
// Controlling CSM client
//
// Once the CSM client has been enabled the Get function will return a Reporter
// value that you can use to pause and resume the metrics published to the CSM
// agent. If Get function is called before the reporter is enabled with the
// Start function or via SDK's Session configuration nil will be returned.
//
// The Pause method can be called to stop the CSM client publishing metrics to
// the CSM agent. The Continue method will resume metric publishing.
//
// // Get the CSM client Reporter.
// r := csm.Get()
//
// // Will pause monitoring
// r.Pause()
@ -35,12 +66,4 @@
//
// // Resume monitoring
// r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm

View File

@ -2,6 +2,7 @@ package csm
import (
"fmt"
"strings"
"sync"
)
@ -9,19 +10,40 @@ var (
lock sync.Mutex
)
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
// DefaultPort is used when no port is specified.
DefaultPort = "31000"
// DefaultHost is the host that will be used when none is specified.
DefaultHost = "127.0.0.1"
)
// Start will start the a long running go routine to capture
// AddressWithDefaults returns a CSM address built from the host and port
// values. If the host or port is not set, default values will be used
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
func AddressWithDefaults(host, port string) string {
if len(host) == 0 || strings.EqualFold(host, "localhost") {
host = DefaultHost
}
if len(port) == 0 {
port = DefaultPort
}
// Only IP6 host can contain a colon
if strings.Contains(host, ":") {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// Start will start a long running go routine to capture
// client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different
// client ID or port is passed in.
//
// Example:
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// r, err := csm.Start("clientID", "127.0.0.1:31000")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }

View File

@ -3,6 +3,8 @@ package csm
import (
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
)
type metricTime time.Time
@ -39,6 +41,12 @@ type metric struct {
SDKException *string `json:"SdkException,omitempty"`
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
FinalAWSException *string `json:"FinalAwsException,omitempty"`
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
FinalSDKException *string `json:"FinalSdkException,omitempty"`
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
DestinationIP *string `json:"DestinationIp,omitempty"`
ConnectionReused *int `json:"ConnectionReused,omitempty"`
@ -48,4 +56,54 @@ type metric struct {
DNSLatency *int `json:"DnsLatency,omitempty"`
TCPLatency *int `json:"TcpLatency,omitempty"`
SSLLatency *int `json:"SslLatency,omitempty"`
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
}
func (m *metric) TruncateFields() {
m.ClientID = truncateString(m.ClientID, 255)
m.UserAgent = truncateString(m.UserAgent, 256)
m.AWSException = truncateString(m.AWSException, 128)
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
m.SDKException = truncateString(m.SDKException, 128)
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
}
func truncateString(v *string, l int) *string {
if v != nil && len(*v) > l {
nv := (*v)[:l]
return &nv
}
return v
}
func (m *metric) SetException(e metricException) {
switch te := e.(type) {
case awsException:
m.AWSException = aws.String(te.exception)
m.AWSExceptionMessage = aws.String(te.message)
case sdkException:
m.SDKException = aws.String(te.exception)
m.SDKExceptionMessage = aws.String(te.message)
}
}
func (m *metric) SetFinalException(e metricException) {
switch te := e.(type) {
case awsException:
m.FinalAWSException = aws.String(te.exception)
m.FinalAWSExceptionMessage = aws.String(te.message)
case sdkException:
m.FinalSDKException = aws.String(te.exception)
m.FinalSDKExceptionMessage = aws.String(te.message)
}
}

View File

@ -16,25 +16,26 @@ var (
type metricChan struct {
ch chan metric
paused int64
paused *int64
}
func newMetricChan(size int) metricChan {
return metricChan{
ch: make(chan metric, size),
ch: make(chan metric, size),
paused: new(int64),
}
}
func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum)
atomic.StoreInt64(ch.paused, pausedEnum)
}
func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum)
atomic.StoreInt64(ch.paused, runningEnum)
}
func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused)
v := atomic.LoadInt64(ch.paused)
return v == pausedEnum
}

View File

@ -0,0 +1,26 @@
package csm
type metricException interface {
Exception() string
Message() string
}
type requestException struct {
exception string
message string
}
func (e requestException) Exception() string {
return e.exception
}
func (e requestException) Message() string {
return e.message
}
type awsException struct {
requestException
}
type sdkException struct {
requestException
}

View File

@ -10,11 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
)
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint.
type Reporter struct {
@ -71,7 +66,6 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID),
}
@ -82,26 +76,29 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
setError(&m, awserr)
m.SetException(getMetricException(awserr))
}
}
m.TruncateFields()
rep.metricsCh.Push(m)
}
func setError(m *metric, err awserr.Error) {
func getMetricException(err awserr.Error) metricException {
msg := err.Error()
code := err.Code()
switch code {
case "RequestError",
"SerializationError",
request.ErrCodeSerialization,
request.CanceledErrorCode:
m.SDKException = &code
m.SDKExceptionMessage = &msg
return sdkException{
requestException{exception: code, message: msg},
}
default:
m.AWSException = &code
m.AWSExceptionMessage = &msg
return awsException{
requestException{exception: code, message: msg},
}
}
}
@ -112,16 +109,31 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
now := time.Now()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1),
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID),
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region,
Latency: aws.Int(int(time.Since(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
}
if r.HTTPResponse != nil {
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetFinalException(getMetricException(awserr))
}
}
m.TruncateFields()
// TODO: Probably want to figure something out for logging dropped
// metrics
rep.metricsCh.Push(m)
@ -172,8 +184,9 @@ func (rep *Reporter) start() {
}
}
// Pause will pause the metric channel preventing any new metrics from
// being added.
// Pause will pause the metric channel preventing any new metrics from being
// added. It is safe to call concurrently with other calls to Pause, but if
// called concurently with Continue can lead to unexpected state.
func (rep *Reporter) Pause() {
lock.Lock()
defer lock.Unlock()
@ -185,8 +198,9 @@ func (rep *Reporter) Pause() {
rep.close()
}
// Continue will reopen the metric channel and allow for monitoring
// to be resumed.
// Continue will reopen the metric channel and allow for monitoring to be
// resumed. It is safe to call concurrently with other calls to Continue, but
// if called concurently with Pause can lead to unexpected state.
func (rep *Reporter) Continue() {
lock.Lock()
defer lock.Unlock()
@ -201,10 +215,18 @@ func (rep *Reporter) Continue() {
rep.metricsCh.Continue()
}
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent.
//
// Example:
// InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
//
// // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {
@ -221,11 +243,22 @@ func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
return
}
apiCallHandler := request.NamedHandler{Name: APICallMetricHandlerName, Fn: rep.sendAPICallMetric}
apiCallAttemptHandler := request.NamedHandler{Name: APICallAttemptMetricHandlerName, Fn: rep.sendAPICallAttemptMetric}
handlers.Complete.PushFrontNamed(request.NamedHandler{
Name: APICallMetricHandlerName,
Fn: rep.sendAPICallMetric,
})
handlers.Complete.PushFrontNamed(apiCallHandler)
handlers.Complete.PushFrontNamed(apiCallAttemptHandler)
handlers.AfterRetry.PushFrontNamed(apiCallAttemptHandler)
handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
Name: APICallAttemptMetricHandlerName,
Fn: rep.sendAPICallAttemptMetric,
})
}
// boolIntValue return 1 for true and 0 for false.
func boolIntValue(b bool) int {
if b {
return 1
}
return 0
}

View File

@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
// A Defaults provides a collection of default values for SDK clients.
@ -112,8 +113,8 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro
}
const (
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
)
// RemoteCredProvider returns a credentials provider for the default remote
@ -123,8 +124,8 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
return localHTTPCredProvider(cfg, handlers, u)
}
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("http://169.254.170.2%s", uri)
if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
return httpCredProvider(cfg, handlers, u)
}
@ -187,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
},
)
}

View File

@ -24,8 +24,9 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
err := req.Send()
return output.Content, req.Send()
return output.Content, err
}
// GetUserData returns the userdata that was configured for the service. If
@ -45,8 +46,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
err := req.Send()
return output.Content, req.Send()
return output.Content, err
}
// GetDynamicData uses the path provided to request information from the EC2
@ -61,8 +63,9 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
err := req.Send()
return output.Content, req.Send()
return output.Content, err
}
// GetInstanceIdentityDocument retrieves an identity document describing an
@ -79,7 +82,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
doc := EC2InstanceIdentityDocument{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
return EC2InstanceIdentityDocument{},
awserr.New("SerializationError",
awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 instance identity document", err)
}
@ -98,7 +101,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
info := EC2IAMInfo{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
return EC2IAMInfo{},
awserr.New("SerializationError",
awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 IAM info", err)
}
@ -118,6 +121,10 @@ func (c *EC2Metadata) Region() (string, error) {
return "", err
}
if len(resp) == 0 {
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
}
@ -145,18 +152,19 @@ type EC2IAMInfo struct {
// An EC2InstanceIdentityDocument provides the shape for unmarshaling
// an instance identity document
type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"`
AvailabilityZone string `json:"availabilityZone"`
PrivateIP string `json:"privateIp"`
Version string `json:"version"`
Region string `json:"region"`
InstanceID string `json:"instanceId"`
BillingProducts []string `json:"billingProducts"`
InstanceType string `json:"instanceType"`
AccountID string `json:"accountId"`
PendingTime time.Time `json:"pendingTime"`
ImageID string `json:"imageId"`
KernelID string `json:"kernelId"`
RamdiskID string `json:"ramdiskId"`
Architecture string `json:"architecture"`
DevpayProductCodes []string `json:"devpayProductCodes"`
MarketplaceProductCodes []string `json:"marketplaceProductCodes"`
AvailabilityZone string `json:"availabilityZone"`
PrivateIP string `json:"privateIp"`
Version string `json:"version"`
Region string `json:"region"`
InstanceID string `json:"instanceId"`
BillingProducts []string `json:"billingProducts"`
InstanceType string `json:"instanceType"`
AccountID string `json:"accountId"`
PendingTime time.Time `json:"pendingTime"`
ImageID string `json:"imageId"`
KernelID string `json:"kernelId"`
RamdiskID string `json:"ramdiskId"`
Architecture string `json:"architecture"`
}

View File

@ -4,7 +4,7 @@
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive).
// be used while the environment variable is set to true, (case insensitive).
package ec2metadata
import (
@ -72,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
ServiceID: ServiceName,
Endpoint: endpoint,
APIVersion: "latest",
},
@ -91,6 +92,9 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Header: http.Header{},
}
r.Error = awserr.New(
request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
@ -119,7 +123,7 @@ func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata response", err)
return
}
@ -132,7 +136,7 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err)
return
}

View File

@ -85,6 +85,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
custAddS3DualStack(p)
custRmIotDataService(p)
custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p)
}
return ps, nil
@ -95,7 +96,12 @@ func custAddS3DualStack(p *partition) {
return
}
s, ok := p.Services["s3"]
custAddDualstack(p, "s3")
custAddDualstack(p, "s3-control")
}
func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName]
if !ok {
return
}
@ -103,7 +109,7 @@ func custAddS3DualStack(p *partition) {
s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services["s3"] = s
p.Services[svcName] = s
}
func custAddEC2Metadata(p *partition) {
@ -144,6 +150,33 @@ func custFixAppAutoscalingChina(p *partition) {
p.Services[serviceName] = s
}
func custFixAppAutoscalingUsGov(p *partition) {
if p.ID != "aws-us-gov" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
if a := s.Defaults.CredentialScope.Service; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
return
}
if a := s.Defaults.Hostname; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
return
}
s.Defaults.CredentialScope.Service = "application-autoscaling"
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com"
p.Services[serviceName] = s
}
type decodeModelError struct {
awsError
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,141 @@
package endpoints
// Service identifiers
//
// Deprecated: Use client package's EndpointsID value instead of these
// ServiceIDs. These IDs are not maintained, and are out of date.
const (
A4bServiceID = "a4b" // A4b.
AcmServiceID = "acm" // Acm.
AcmPcaServiceID = "acm-pca" // AcmPca.
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
ApiPricingServiceID = "api.pricing" // ApiPricing.
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
ApigatewayServiceID = "apigateway" // Apigateway.
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
Appstream2ServiceID = "appstream2" // Appstream2.
AppsyncServiceID = "appsync" // Appsync.
AthenaServiceID = "athena" // Athena.
AutoscalingServiceID = "autoscaling" // Autoscaling.
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
BatchServiceID = "batch" // Batch.
BudgetsServiceID = "budgets" // Budgets.
CeServiceID = "ce" // Ce.
ChimeServiceID = "chime" // Chime.
Cloud9ServiceID = "cloud9" // Cloud9.
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
CloudformationServiceID = "cloudformation" // Cloudformation.
CloudfrontServiceID = "cloudfront" // Cloudfront.
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
CodebuildServiceID = "codebuild" // Codebuild.
CodecommitServiceID = "codecommit" // Codecommit.
CodedeployServiceID = "codedeploy" // Codedeploy.
CodepipelineServiceID = "codepipeline" // Codepipeline.
CodestarServiceID = "codestar" // Codestar.
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
ComprehendServiceID = "comprehend" // Comprehend.
ConfigServiceID = "config" // Config.
CurServiceID = "cur" // Cur.
DatapipelineServiceID = "datapipeline" // Datapipeline.
DaxServiceID = "dax" // Dax.
DevicefarmServiceID = "devicefarm" // Devicefarm.
DirectconnectServiceID = "directconnect" // Directconnect.
DiscoveryServiceID = "discovery" // Discovery.
DmsServiceID = "dms" // Dms.
DsServiceID = "ds" // Ds.
DynamodbServiceID = "dynamodb" // Dynamodb.
Ec2ServiceID = "ec2" // Ec2.
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
EcrServiceID = "ecr" // Ecr.
EcsServiceID = "ecs" // Ecs.
ElasticacheServiceID = "elasticache" // Elasticache.
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
EmailServiceID = "email" // Email.
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
EsServiceID = "es" // Es.
EventsServiceID = "events" // Events.
FirehoseServiceID = "firehose" // Firehose.
FmsServiceID = "fms" // Fms.
GameliftServiceID = "gamelift" // Gamelift.
GlacierServiceID = "glacier" // Glacier.
GlueServiceID = "glue" // Glue.
GreengrassServiceID = "greengrass" // Greengrass.
GuarddutyServiceID = "guardduty" // Guardduty.
HealthServiceID = "health" // Health.
IamServiceID = "iam" // Iam.
ImportexportServiceID = "importexport" // Importexport.
InspectorServiceID = "inspector" // Inspector.
IotServiceID = "iot" // Iot.
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
KinesisServiceID = "kinesis" // Kinesis.
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
KmsServiceID = "kms" // Kms.
LambdaServiceID = "lambda" // Lambda.
LightsailServiceID = "lightsail" // Lightsail.
LogsServiceID = "logs" // Logs.
MachinelearningServiceID = "machinelearning" // Machinelearning.
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
MedialiveServiceID = "medialive" // Medialive.
MediapackageServiceID = "mediapackage" // Mediapackage.
MediastoreServiceID = "mediastore" // Mediastore.
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
MghServiceID = "mgh" // Mgh.
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
ModelsLexServiceID = "models.lex" // ModelsLex.
MonitoringServiceID = "monitoring" // Monitoring.
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
NeptuneServiceID = "neptune" // Neptune.
OpsworksServiceID = "opsworks" // Opsworks.
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
OrganizationsServiceID = "organizations" // Organizations.
PinpointServiceID = "pinpoint" // Pinpoint.
PollyServiceID = "polly" // Polly.
RdsServiceID = "rds" // Rds.
RedshiftServiceID = "redshift" // Redshift.
RekognitionServiceID = "rekognition" // Rekognition.
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
Route53ServiceID = "route53" // Route53.
Route53domainsServiceID = "route53domains" // Route53domains.
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
S3ServiceID = "s3" // S3.
S3ControlServiceID = "s3-control" // S3Control.
SagemakerServiceID = "api.sagemaker" // Sagemaker.
SdbServiceID = "sdb" // Sdb.
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
ShieldServiceID = "shield" // Shield.
SmsServiceID = "sms" // Sms.
SnowballServiceID = "snowball" // Snowball.
SnsServiceID = "sns" // Sns.
SqsServiceID = "sqs" // Sqs.
SsmServiceID = "ssm" // Ssm.
StatesServiceID = "states" // States.
StoragegatewayServiceID = "storagegateway" // Storagegateway.
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
StsServiceID = "sts" // Sts.
SupportServiceID = "support" // Support.
SwfServiceID = "swf" // Swf.
TaggingServiceID = "tagging" // Tagging.
TransferServiceID = "transfer" // Transfer.
TranslateServiceID = "translate" // Translate.
WafServiceID = "waf" // Waf.
WafRegionalServiceID = "waf-regional" // WafRegional.
WorkdocsServiceID = "workdocs" // Workdocs.
WorkmailServiceID = "workmail" // Workmail.
WorkspacesServiceID = "workspaces" // Workspaces.
XrayServiceID = "xray" // Xray.
)

View File

@ -3,6 +3,7 @@ package endpoints
import (
"fmt"
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
@ -35,7 +36,7 @@ type Options struct {
//
// If resolving an endpoint on the partition list the provided region will
// be used to determine which partition's domain name pattern to the service
// endpoint ID with. If both the service and region are unkonwn and resolving
// endpoint ID with. If both the service and region are unknown and resolving
// the endpoint on partition list an UnknownEndpointError error will be returned.
//
// If resolving and endpoint on a partition specific resolver that partition's
@ -46,6 +47,43 @@ type Options struct {
//
// This option is ignored if StrictMatching is enabled.
ResolveUnknownService bool
// STS Regional Endpoint flag helps with resolving the STS endpoint
STSRegionalEndpoint STSRegionalEndpoint
}
// STSRegionalEndpoint is an enum type alias for int
// It is used internally by the core sdk as STS Regional Endpoint flag value
type STSRegionalEndpoint int
const (
// UnsetSTSEndpoint represents that STS Regional Endpoint flag is not specified.
UnsetSTSEndpoint STSRegionalEndpoint = iota
// LegacySTSEndpoint represents when STS Regional Endpoint flag is specified
// to use legacy endpoints.
LegacySTSEndpoint
// RegionalSTSEndpoint represents when STS Regional Endpoint flag is specified
// to use regional endpoints.
RegionalSTSEndpoint
)
// GetSTSRegionalEndpoint function returns the STSRegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the STS regional Endpoint flag.
func GetSTSRegionalEndpoint(s string) (STSRegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacySTSEndpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalSTSEndpoint, nil
default:
return UnsetSTSEndpoint, fmt.Errorf("unable to resolve the value of STSRegionalEndpoint for %v", s)
}
}
// Set combines all of the option functions together.
@ -79,6 +117,12 @@ func ResolveUnknownServiceOption(o *Options) {
o.ResolveUnknownService = true
}
// STSRegionalEndpointOption enables the STS endpoint resolver behavior to resolve
// STS endpoint to their regional endpoint, instead of the global endpoint.
func STSRegionalEndpointOption(o *Options) {
o.STSRegionalEndpoint = RegionalSTSEndpoint
}
// A Resolver provides the interface for functionality to resolve endpoints.
// The build in Partition and DefaultResolver return value satisfy this interface.
type Resolver interface {
@ -170,10 +214,13 @@ func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
// A Partition provides the ability to enumerate the partition's regions
// and services.
type Partition struct {
id string
p *partition
id, dnsSuffix string
p *partition
}
// DNSSuffix returns the base domain name of the partition.
func (p Partition) DNSSuffix() string { return p.dnsSuffix }
// ID returns the identifier of the partition.
func (p Partition) ID() string { return p.id }
@ -191,7 +238,7 @@ func (p Partition) ID() string { return p.id }
// require the provided service and region to be known by the partition.
// If the endpoint cannot be strictly resolved an error will be returned. This
// mode is useful to ensure the endpoint resolved is valid. Without
// StrictMatching enabled the endpoint returned my look valid but may not work.
// StrictMatching enabled the endpoint returned may look valid but may not work.
// StrictMatching requires the SDK to be updated if you want to take advantage
// of new regions and services expansions.
//
@ -347,6 +394,9 @@ type ResolvedEndpoint struct {
// The endpoint URL
URL string
// The endpoint partition
PartitionID string
// The region that should be used for signing requests.
SigningRegion string

View File

@ -0,0 +1,19 @@
package endpoints
var stsLegacyGlobalRegions = map[string]struct{}{
"ap-northeast-1": {},
"ap-south-1": {},
"ap-southeast-1": {},
"ap-southeast-2": {},
"ca-central-1": {},
"eu-central-1": {},
"eu-north-1": {},
"eu-west-1": {},
"eu-west-2": {},
"eu-west-3": {},
"sa-east-1": {},
"us-east-1": {},
"us-east-2": {},
"us-west-1": {},
"us-west-2": {},
}

View File

@ -54,8 +54,9 @@ type partition struct {
func (p partition) Partition() Partition {
return Partition{
id: p.ID,
p: &p,
dnsSuffix: p.DNSSuffix,
id: p.ID,
p: &p,
}
}
@ -74,24 +75,55 @@ func (p partition) canResolveEndpoint(service, region string, strictMatch bool)
return p.RegionRegex.MatchString(region)
}
func allowLegacyEmptyRegion(service string) bool {
legacy := map[string]struct{}{
"budgets": {},
"ce": {},
"chime": {},
"cloudfront": {},
"ec2metadata": {},
"iam": {},
"importexport": {},
"organizations": {},
"route53": {},
"sts": {},
"support": {},
"waf": {},
}
_, allowed := legacy[service]
return allowed
}
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
var opt Options
opt.Set(opts...)
s, hasService := p.Services[service]
if !(hasService || opt.ResolveUnknownService) {
if len(service) == 0 || !(hasService || opt.ResolveUnknownService) {
// Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in.
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
}
if len(region) == 0 && allowLegacyEmptyRegion(service) && len(s.PartitionEndpoint) != 0 {
region = s.PartitionEndpoint
}
if service == "sts" && opt.STSRegionalEndpoint != RegionalSTSEndpoint {
if _, ok := stsLegacyGlobalRegions[region]; ok {
region = "aws-global"
}
}
e, hasEndpoint := s.endpointForRegion(region)
if !hasEndpoint && opt.StrictMatching {
if len(region) == 0 || (!hasEndpoint && opt.StrictMatching) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
}
defs := []endpoint{p.Defaults, s.Defaults}
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil
return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt), nil
}
func serviceList(ss services) []string {
@ -200,7 +232,7 @@ func getByPriority(s []string, p []string, def string) string {
return s[0]
}
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
var merged endpoint
for _, def := range defs {
merged.mergeIn(def)
@ -236,6 +268,7 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
return ResolvedEndpoint{
URL: u,
PartitionID: partitionID,
SigningRegion: signingRegion,
SigningName: signingName,
SigningNameDerived: signingNameDerived,

View File

@ -16,6 +16,10 @@ import (
type CodeGenOptions struct {
// Options for how the model will be decoded.
DecodeModelOptions DecodeModelOptions
// Disables code generation of the service endpoint prefix IDs defined in
// the model.
DisableGenerateServiceIDs bool
}
// Set combines all of the option functions together
@ -39,8 +43,16 @@ func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGe
return err
}
v := struct {
Resolver
CodeGenOptions
}{
Resolver: resolver,
CodeGenOptions: opts,
}
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil {
if err := tmpl.ExecuteTemplate(outFile, "defaults", v); err != nil {
return fmt.Errorf("failed to execute template, %v", err)
}
@ -166,15 +178,17 @@ import (
"regexp"
)
{{ template "partition consts" . }}
{{ template "partition consts" $.Resolver }}
{{ range $_, $partition := . }}
{{ range $_, $partition := $.Resolver }}
{{ template "partition region consts" $partition }}
{{ end }}
{{ template "service consts" . }}
{{ if not $.DisableGenerateServiceIDs -}}
{{ template "service consts" $.Resolver }}
{{- end }}
{{ template "endpoint resolvers" . }}
{{ template "endpoint resolvers" $.Resolver }}
{{- end }}
{{ define "partition consts" }}

View File

@ -5,13 +5,9 @@ import "github.com/aws/aws-sdk-go/aws/awserr"
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)

View File

@ -1,18 +1,17 @@
// +build !appengine,!plan9
package request
import (
"net"
"os"
"syscall"
"strings"
)
func isErrConnectionReset(err error) bool {
if opErr, ok := err.(*net.OpError); ok {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
return sysErr.Err == syscall.ECONNRESET
}
if strings.Contains(err.Error(), "read: connection reset") {
return false
}
if strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe") {
return true
}
return false

View File

@ -1,11 +0,0 @@
// +build appengine plan9
package request
import (
"strings"
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
}

View File

@ -19,10 +19,11 @@ type Handlers struct {
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
CompleteAttempt HandlerList
Complete HandlerList
}
// Copy returns of this handler's lists.
// Copy returns a copy of this handler's lists.
func (h *Handlers) Copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
@ -36,11 +37,12 @@ func (h *Handlers) Copy() Handlers {
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
CompleteAttempt: h.CompleteAttempt.copy(),
Complete: h.Complete.copy(),
}
}
// Clear removes callback functions for all handlers
// Clear removes callback functions for all handlers.
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
@ -53,9 +55,55 @@ func (h *Handlers) Clear() {
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
h.CompleteAttempt.Clear()
h.Complete.Clear()
}
// IsEmpty returns if there are no handlers in any of the handlerlists.
func (h *Handlers) IsEmpty() bool {
if h.Validate.Len() != 0 {
return false
}
if h.Build.Len() != 0 {
return false
}
if h.Send.Len() != 0 {
return false
}
if h.Sign.Len() != 0 {
return false
}
if h.Unmarshal.Len() != 0 {
return false
}
if h.UnmarshalStream.Len() != 0 {
return false
}
if h.UnmarshalMeta.Len() != 0 {
return false
}
if h.UnmarshalError.Len() != 0 {
return false
}
if h.ValidateResponse.Len() != 0 {
return false
}
if h.Retry.Len() != 0 {
return false
}
if h.AfterRetry.Len() != 0 {
return false
}
if h.CompleteAttempt.Len() != 0 {
return false
}
if h.Complete.Len() != 0 {
return false
}
return true
}
// A HandlerListRunItem represents an entry in the HandlerList which
// is being run.
type HandlerListRunItem struct {

View File

@ -15,12 +15,15 @@ type offsetReader struct {
closed bool
}
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
func newOffsetReader(buf io.ReadSeeker, offset int64) (*offsetReader, error) {
reader := &offsetReader{}
buf.Seek(offset, sdkio.SeekStart)
_, err := buf.Seek(offset, sdkio.SeekStart)
if err != nil {
return nil, err
}
reader.buf = buf
return reader
return reader, nil
}
// Close will close the instance of the offset reader's access to
@ -54,7 +57,9 @@ func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
// CloseAndCopy will return a new offsetReader with a copy of the old buffer
// and close the old buffer.
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader {
o.Close()
func (o *offsetReader) CloseAndCopy(offset int64) (*offsetReader, error) {
if err := o.Close(); err != nil {
return nil, err
}
return newOffsetReader(o.buf, offset)
}

View File

@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
@ -65,6 +64,15 @@ type Request struct {
LastSignedAt time.Time
DisableFollowRedirects bool
// Additional API error codes that should be retried. IsErrorRetryable
// will consider these codes in addition to its built in cases.
RetryErrorCodes []string
// Additional API error codes that should be retried with throttle backoff
// delay. IsErrorThrottle will consider these codes in addition to its
// built in cases.
ThrottleErrorCodes []string
// A value greater than 0 instructs the request to be signed as Presigned URL
// You should not set this field directly. Instead use Request's
// Presign or PresignRequest methods.
@ -91,8 +99,12 @@ type Operation struct {
BeforePresignFn func(r *Request) error
}
// New returns a new Request pointer for the service API
// operation and parameters.
// New returns a new Request pointer for the service API operation and
// parameters.
//
// A Retryer should be provided to direct how the request is retried. If
// Retryer is nil, a default no retry value will be used. You can use
// NoOpRetryer in the Client package to disable retry behavior directly.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
@ -100,6 +112,10 @@ type Operation struct {
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
if retryer == nil {
retryer = noOpRetryer{}
}
method := operation.HTTPMethod
if method == "" {
method = "POST"
@ -122,7 +138,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
Handlers: handlers.Copy(),
Retryer: retryer,
AttemptTime: time.Now(),
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
@ -233,6 +248,10 @@ func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
func fmtAttemptCount(retryCount, maxRetries int) string {
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
}
// ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are
// provided or invalid.
@ -261,12 +280,25 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.Body = reader
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
if aws.IsReaderSeekable(reader) {
var err error
// Get the Bodies current offset so retries will start from the same
// initial position.
r.BodyStart, err = reader.Seek(0, sdkio.SeekCurrent)
if err != nil {
r.Error = awserr.New(ErrCodeSerialization,
"failed to determine start of request body", err)
return
}
}
r.ResetBody()
}
// Presign returns the request's signed URL. Error will be returned
// if the signing fails.
// if the signing fails. The expire parameter is only used for presigned Amazon
// S3 API requests. All other AWS services will use a fixed expiration
// time of 15 minutes.
//
// It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less.
@ -283,7 +315,9 @@ func (r *Request) Presign(expire time.Duration) (string, error) {
}
// PresignRequest behaves just like presign, with the addition of returning a
// set of headers that were signed.
// set of headers that were signed. The expire parameter is only used for
// presigned Amazon S3 API requests. All other AWS services will use a fixed
// expiration time of 15 minutes.
//
// It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less.
@ -328,16 +362,15 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
}
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
const (
notRetrying = "not retrying"
)
func debugLogReqError(r *Request, stage, retryStr string, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return
}
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
}
@ -356,12 +389,12 @@ func (r *Request) Build() error {
if !r.built {
r.Handlers.Validate.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error)
debugLogReqError(r, "Validate Request", notRetrying, r.Error)
return r.Error
}
r.Handlers.Build.Run(r)
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error
}
r.built = true
@ -377,7 +410,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error
}
@ -385,12 +418,16 @@ func (r *Request) Sign() error {
return r.Error
}
func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) {
if r.safeBody != nil {
r.safeBody.Close()
}
r.safeBody = newOffsetReader(r.Body, r.BodyStart)
r.safeBody, err = newOffsetReader(r.Body, r.BodyStart)
if err != nil {
return nil, awserr.New(ErrCodeSerialization,
"failed to get next request body reader", err)
}
// Go 1.8 tightened and clarified the rules code needs to use when building
// requests with the http package. Go 1.8 removed the automatic detection
@ -407,10 +444,10 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// Related golang/go#18257
l, err := aws.SeekerLen(r.Body)
if err != nil {
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
return nil, awserr.New(ErrCodeSerialization,
"failed to compute request body size", err)
}
var body io.ReadCloser
if l == 0 {
body = NoBody
} else if l > 0 {
@ -462,80 +499,90 @@ func (r *Request) Send() error {
r.Handlers.Complete.Run(r)
}()
if err := r.Error; err != nil {
return err
}
for {
r.Error = nil
r.AttemptTime = time.Now()
if aws.BoolValue(r.Retryable) {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
}
// The previous http.Request will have a reference to the r.Body
// and the HTTP Client's Transport may still be reading from
// the request's body even though the Client's Do returned.
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
r.ResetBody()
// Closing response body to ensure that no response body is leaked
// between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
r.HTTPResponse.Body.Close()
}
if err := r.Sign(); err != nil {
debugLogReqError(r, "Sign Request", notRetrying, err)
return err
}
r.Sign()
if r.Error != nil {
if err := r.sendRequest(); err == nil {
return nil
}
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil || !aws.BoolValue(r.Retryable) {
return r.Error
}
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
if !shouldRetryCancel(r) {
return r.Error
}
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", false, err)
return r.Error
}
debugLogReqError(r, "Send Request", true, err)
continue
if err := r.prepareRetry(); err != nil {
r.Error = err
return err
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
err := r.Error
}
}
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Response", false, err)
return r.Error
}
debugLogReqError(r, "Validate Response", true, err)
continue
}
func (r *Request) prepareRetry() error {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", false, err)
return r.Error
}
debugLogReqError(r, "Unmarshal Response", true, err)
continue
}
// The previous http.Request will have a reference to the r.Body
// and the HTTP Client's Transport may still be reading from
// the request's body even though the Client's Do returned.
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
r.ResetBody()
if err := r.Error; err != nil {
return awserr.New(ErrCodeSerialization,
"failed to prepare body for retry", err)
break
}
// Closing response body to ensure that no response body is leaked
// between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
r.HTTPResponse.Body.Close()
}
return nil
}
func (r *Request) sendRequest() (sendErr error) {
defer r.Handlers.CompleteAttempt.Run(r)
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
debugLogReqError(r, "Validate Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error
}
return nil
@ -561,32 +608,6 @@ func AddToUserAgent(r *Request, s string) {
r.HTTPRequest.Header.Set("User-Agent", s)
}
func shouldRetryCancel(r *Request) bool {
awsErr, ok := r.Error.(awserr.Error)
timeoutErr := false
errStr := r.Error.Error()
if ok {
if awsErr.Code() == CanceledErrorCode {
return false
}
err := awsErr.OrigErr()
netErr, netOK := err.(net.Error)
timeoutErr = netOK && netErr.Temporary()
if urlErr, ok := err.(*url.Error); !timeoutErr && ok {
errStr = urlErr.Err.Error()
}
}
// There can be two types of canceled errors here.
// The first being a net.Error and the other being an error.
// If the request was timed out, we want to continue the retry
// process. Otherwise, return the canceled error.
return timeoutErr ||
(errStr != "net/http: request canceled" &&
errStr != "net/http: request canceled while waiting for connection")
}
// SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) {
host := getHost(r)

View File

@ -4,6 +4,8 @@ package request
import (
"net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// NoBody is a http.NoBody reader instructing Go HTTP client to not include
@ -24,7 +26,8 @@ var NoBody = http.NoBody
func (r *Request) ResetBody() {
body, err := r.getNextRequestBody()
if err != nil {
r.Error = err
r.Error = awserr.New(ErrCodeSerialization,
"failed to reset request body", err)
return
}

View File

@ -146,7 +146,7 @@ func (r *Request) nextPageTokens() []interface{} {
return nil
}
case bool:
if v == false {
if !v {
return nil
}
}

View File

@ -1,26 +1,75 @@
package request
import (
"net"
"net/url"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Retryer is an interface to control retry logic for a given service.
// The default implementation used by most services is the client.DefaultRetryer
// structure, which contains basic retry logic using exponential backoff.
// Retryer provides the interface drive the SDK's request retry behavior. The
// Retryer implementation is responsible for implementing exponential backoff,
// and determine if a request API error should be retried.
//
// client.DefaultRetryer is the SDK's default implementation of the Retryer. It
// uses the which uses the Request.IsErrorRetryable and Request.IsErrorThrottle
// methods to determine if the request is retried.
type Retryer interface {
// RetryRules return the retry delay that should be used by the SDK before
// making another request attempt for the failed request.
RetryRules(*Request) time.Duration
// ShouldRetry returns if the failed request is retryable.
//
// Implementations may consider request attempt count when determining if a
// request is retryable, but the SDK will use MaxRetries to limit the
// number of attempts a request are made.
ShouldRetry(*Request) bool
// MaxRetries is the number of times a request may be retried before
// failing.
MaxRetries() int
}
// WithRetryer sets a config Retryer value to the given Config returning it
// for chaining.
// WithRetryer sets a Retryer value to the given Config returning the Config
// value for chaining. The value must not be nil.
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
if retryer == nil {
if cfg.Logger != nil {
cfg.Logger.Log("ERROR: Request.WithRetryer called with nil retryer. Replacing with retry disabled Retryer.")
}
retryer = noOpRetryer{}
}
cfg.Retryer = retryer
return cfg
}
// noOpRetryer is a internal no op retryer used when a request is created
// without a retryer.
//
// Provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type noOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d noOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d noOpRetryer) ShouldRetry(_ *Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d noOpRetryer) RetryRules(_ *Request) time.Duration {
return 0
}
// retryableCodes is a collection of service response codes which are retry-able
@ -38,8 +87,10 @@ var throttleCodes = map[string]struct{}{
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
"RequestThrottledException": {},
"TooManyRequestsException": {}, // Lambda functions
"PriorRequestNotComplete": {}, // Route53
"TransactionInProgressException": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
@ -74,10 +125,6 @@ var validParentCodes = map[string]struct{}{
ErrCodeRead: {},
}
type temporaryError interface {
Temporary() bool
}
func isNestedErrorRetryable(parentErr awserr.Error) bool {
if parentErr == nil {
return false
@ -96,7 +143,7 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
return isCodeRetryable(aerr.Code())
}
if t, ok := err.(temporaryError); ok {
if t, ok := err.(temporary); ok {
return t.Temporary() || isErrConnectionReset(err)
}
@ -106,32 +153,90 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if error is nil.
func IsErrorRetryable(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
}
if err == nil {
return false
}
return shouldRetryError(err)
}
type temporary interface {
Temporary() bool
}
func shouldRetryError(origErr error) bool {
switch err := origErr.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
}
if isNestedErrorRetryable(err) {
return true
}
origErr := err.OrigErr()
var shouldRetry bool
if origErr != nil {
shouldRetry := shouldRetryError(origErr)
if err.Code() == "RequestError" && !shouldRetry {
return false
}
}
if isCodeRetryable(err.Code()) {
return true
}
return shouldRetry
case *url.Error:
if strings.Contains(err.Error(), "connection refused") {
// Refused connections should be retried as the service may not yet
// be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
}
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryError(err.Err)
case temporary:
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
return true
}
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary() || isErrConnectionReset(origErr)
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
}
return false
}
// IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if error is nil.
func IsErrorThrottle(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeThrottle(aerr.Code())
}
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
return isCodeThrottle(aerr.Code())
}
return false
}
// IsErrorExpiredCreds returns whether the error code is a credential expiry error.
// Returns false if error is nil.
// IsErrorExpiredCreds returns whether the error code is a credential expiry
// error. Returns false if error is nil.
func IsErrorExpiredCreds(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeExpiredCreds(aerr.Code())
}
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
return isCodeExpiredCreds(aerr.Code())
}
return false
}
@ -141,17 +246,58 @@ func IsErrorExpiredCreds(err error) bool {
//
// Alias for the utility function IsErrorRetryable
func (r *Request) IsErrorRetryable() bool {
if isErrCode(r.Error, r.RetryErrorCodes) {
return true
}
// HTTP response status code 501 should not be retried.
// 501 represents Not Implemented which means the request method is not
// supported by the server and cannot be handled.
if r.HTTPResponse != nil {
// HTTP response status code 500 represents internal server error and
// should be retried without any throttle.
if r.HTTPResponse.StatusCode == 500 {
return true
}
}
return IsErrorRetryable(r.Error)
}
// IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if the request has no Error set
// IsErrorThrottle returns whether the error is to be throttled based on its
// code. Returns false if the request has no Error set.
//
// Alias for the utility function IsErrorThrottle
func (r *Request) IsErrorThrottle() bool {
if isErrCode(r.Error, r.ThrottleErrorCodes) {
return true
}
if r.HTTPResponse != nil {
switch r.HTTPResponse.StatusCode {
case
429, // error caused due to too many requests
502, // Bad Gateway error should be throttled
503, // caused when service is unavailable
504: // error occurred due to gateway timeout
return true
}
}
return IsErrorThrottle(r.Error)
}
func isErrCode(err error, codes []string) bool {
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
for _, code := range codes {
if code == aerr.Code() {
return true
}
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
//

View File

@ -17,6 +17,12 @@ const (
ParamMinValueErrCode = "ParamMinValueError"
// ParamMinLenErrCode is the error code for fields without enough elements.
ParamMinLenErrCode = "ParamMinLenError"
// ParamMaxLenErrCode is the error code for value being too long.
ParamMaxLenErrCode = "ParamMaxLenError"
// ParamFormatErrCode is the error code for a field with invalid
// format or characters.
ParamFormatErrCode = "ParamFormatInvalidError"
)
// Validator provides a way for types to perform validation logic on their
@ -232,3 +238,49 @@ func NewErrParamMinLen(field string, min int) *ErrParamMinLen {
func (e *ErrParamMinLen) MinLen() int {
return e.min
}
// An ErrParamMaxLen represents a maximum length parameter error.
type ErrParamMaxLen struct {
errInvalidParam
max int
}
// NewErrParamMaxLen creates a new maximum length parameter error.
func NewErrParamMaxLen(field string, max int, value string) *ErrParamMaxLen {
return &ErrParamMaxLen{
errInvalidParam: errInvalidParam{
code: ParamMaxLenErrCode,
field: field,
msg: fmt.Sprintf("maximum size of %v, %v", max, value),
},
max: max,
}
}
// MaxLen returns the field's required minimum length.
func (e *ErrParamMaxLen) MaxLen() int {
return e.max
}
// An ErrParamFormat represents a invalid format parameter error.
type ErrParamFormat struct {
errInvalidParam
format string
}
// NewErrParamFormat creates a new invalid format parameter error.
func NewErrParamFormat(field string, format, value string) *ErrParamFormat {
return &ErrParamFormat{
errInvalidParam: errInvalidParam{
code: ParamFormatErrCode,
field: field,
msg: fmt.Sprintf("format %v, %v", format, value),
},
format: format,
}
}
// Format returns the field's required format.
func (e *ErrParamFormat) Format() string {
return e.format
}

View File

@ -0,0 +1,26 @@
// +build go1.7
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View File

@ -0,0 +1,22 @@
// +build !go1.6,go1.5
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
}
}

View File

@ -0,0 +1,23 @@
// +build !go1.7,go1.6
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View File

@ -0,0 +1,259 @@
package session
import (
"fmt"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (*credentials.Credentials, error) {
switch {
case len(sessOpts.Profile) != 0:
// User explicitly provided an Profile in the session's configuration
// so load that profile from shared config first.
// Github(aws/aws-sdk-go#2727)
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
case envCfg.Creds.HasKeys():
// Environment credentials
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
case len(envCfg.WebIdentityTokenFilePath) != 0:
// Web identity token from environment, RoleARN required to also be
// set.
return assumeWebIdentity(cfg, handlers,
envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN,
envCfg.RoleSessionName,
)
default:
// Fallback to the "default" credential resolution chain.
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
}
}
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
// 'AWS_IAM_ROLE_ARN' was not set.
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_IAM_ROLE_ARN' was set but
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string,
roleARN, sessionName string,
) (*credentials.Credentials, error) {
if len(filepath) == 0 {
return nil, WebIdentityEmptyTokenFilePathErr
}
if len(roleARN) == 0 {
return nil, WebIdentityEmptyRoleARNErr
}
creds := stscreds.NewWebIdentityCredentials(
&Session{
Config: cfg,
Handlers: handlers.Copy(),
},
roleARN,
sessionName,
filepath,
)
return creds, nil
}
func resolveCredsFromProfile(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch {
case sharedCfg.SourceProfile != nil:
// Assume IAM role with credentials source from a different profile.
creds, err = resolveCredsFromProfile(cfg, envCfg,
*sharedCfg.SourceProfile, handlers, sessOpts,
)
case sharedCfg.Creds.HasKeys():
// Static Credentials from Shared Config/Credentials file.
creds = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg,
sharedCfg, handlers, sessOpts,
)
case len(sharedCfg.WebIdentityTokenFile) != 0:
// Credentials from Assume Web Identity token require an IAM Role, and
// that roll will be assumed. May be wrapped with another assume role
// via SourceProfile.
return assumeWebIdentity(cfg, handlers,
sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN,
sharedCfg.RoleSessionName,
)
default:
// Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to
// be retrieved.
creds = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{
Err: awserr.New("EnvAccessKeyNotFound",
"failed to find credentials in the environment.", nil),
},
&credProviderError{
Err: awserr.New("SharedCredsLoad",
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
}
if err != nil {
return nil, err
}
if len(sharedCfg.RoleARN) > 0 {
cfgCp := *cfg
cfgCp.Credentials = creds
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
}
return creds, nil
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
func resolveCredsFromSource(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch sharedCfg.CredentialSource {
case credSourceEc2Metadata:
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
case credSourceEnvironment:
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return nil, ErrSharedConfigECSContainerEnvVarEmpty
}
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
default:
return nil, ErrSharedConfigInvalidCredSource
}
return creds, nil
}
func credsFromAssumeRole(cfg aws.Config,
handlers request.Handlers,
sharedCfg sharedConfig,
sessOpts Options,
) (*credentials.Credentials, error) {
if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return nil, AssumeRoleTokenProviderNotSetError{}
}
return stscreds.NewCredentials(
&Session{
Config: &cfg,
Handlers: handlers.Copy(),
},
sharedCfg.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.RoleSessionName
opt.Duration = sessOpts.AssumeRoleDuration
// Assume role with external ID
if len(sharedCfg.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
), nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a
// session when the MFAToken option is not set when shared config is configured
// load assume a role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}

View File

@ -1,97 +1,93 @@
/*
Package session provides configuration for the SDK's service clients.
Sessions can be shared across all service clients that share the same base
configuration. The Session is built from the SDK's default configuration and
request handlers.
Sessions should be cached when possible, because creating a new Session will
load all configuration values from the environment, and config files each time
the Session is created. Sharing the Session value across all of your service
clients will ensure the configuration is loaded the fewest number of times possible.
Concurrency
Package session provides configuration for the SDK's service clients. Sessions
can be shared across service clients that share the same base configuration.
Sessions are safe to use concurrently as long as the Session is not being
modified. The SDK will not modify the Session once the Session has been created.
Creating service clients concurrently from a shared Session is safe.
modified. Sessions should be cached when possible, because creating a new
Session will load all configuration values from the environment, and config
files each time the Session is created. Sharing the Session value across all of
your service clients will ensure the configuration is loaded the fewest number
of times possible.
Sessions from Shared Config
Sessions can be created using the method above that will only load the
additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
Alternatively you can explicitly create a Session with shared config enabled.
To do this you can use NewSessionWithOptions to configure how the Session will
be created. Using the NewSessionWithOptions with SharedConfigState set to
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
environment variable was set.
Creating Sessions
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
Sessions options from Shared Config
By default NewSession will only load credentials from the shared credentials
file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is
set to a truthy value the Session will be created from the configuration
values from the shared config (~/.aws/config) and shared credentials
(~/.aws/credentials) files. See the section Sessions from Shared Config for
more information.
(~/.aws/credentials) files. Using the NewSessionWithOptions with
SharedConfigState set to SharedConfigEnable will create the session as if the
AWS_SDK_LOAD_CONFIG environment variable was set.
Create a Session with the default config and request handlers. With credentials
region, and profile loaded from the environment and shared config automatically.
Requires the AWS_PROFILE to be set, or "default" is used.
Credential and config loading order
The Session will attempt to load configuration and credentials from the
environment, configuration files, and other credential sources. The order
configuration is loaded in is:
* Environment Variables
* Shared Credentials file
* Shared Configuration file (if SharedConfig is enabled)
* EC2 Instance Metadata (credentials only)
The Environment variables for credentials will have precedence over shared
config even if SharedConfig is enabled. To override this behavior, and use
shared config credentials instead specify the session.Options.Profile, (e.g.
when using credential_source=Environment to assume a role).
sess, err := session.NewSessionWithOptions(session.Options{
Profile: "myProfile",
})
Creating Sessions
Creating a Session without additional options will load credentials region, and
profile loaded from the environment and shared config automatically. See,
"Environment Variables" section for information on environment variables used
by Session.
// Create Session
sess := session.Must(session.NewSession())
sess, err := session.NewSession()
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded, config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
// Create a Session with a custom region
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
}))
sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-west-2"),
})
// Create a S3 client instance from a session
sess := session.Must(session.NewSession())
svc := s3.New(sess)
Create Session With Option Overrides
In addition to NewSession, Sessions can be created using NewSessionWithOptions.
This func allows you to control and override how the Session will be created
through code instead of being driven by environment variables only.
Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
Use NewSessionWithOptions to provide additional configuration driving how the
Session's configuration will be loaded. Such as, specifying shared config
profile, or override the shared config state, (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.NewSession()
sess := session.Must(session.NewSessionWithOptions(session.Options{
sess, err := session.NewSessionWithOptions(session.Options{
// Options
}))
})
// Specify profile to load for the session's config
sess := session.Must(session.NewSessionWithOptions(session.Options{
Profile: "profile_name",
}))
sess, err := session.NewSessionWithOptions(session.Options{
// Specify profile to load for the session's config
Profile: "profile_name",
// Specify profile for config and region for requests
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String("us-east-1")},
Profile: "profile_name",
}))
// Provide SDK Config options, such as Region.
Config: aws.Config{
Region: aws.String("us-west-2"),
},
// Force enable Shared Config support
sess := session.Must(session.NewSessionWithOptions(session.Options{
// Force enable Shared Config support
SharedConfigState: session.SharedConfigEnable,
}))
})
Adding Handlers
You can add handlers to a session for processing HTTP requests. All service
clients that use the session inherit the handlers. For example, the following
handler logs every request and its payload made by a service client:
You can add handlers to a session to decorate API operation, (e.g. adding HTTP
headers). All clients that use the Session receive a copy of the Session's
handlers. For example, the following request handler added to the Session logs
every requests made.
// Create a session, and add additional handlers for all service
// clients created with the Session to inherit. Adds logging handler.
@ -99,22 +95,15 @@ handler logs every request and its payload made by a service client:
sess.Handlers.Send.PushFront(func(r *request.Request) {
// Log every request made and its payload
logger.Println("Request: %s/%s, Payload: %s",
logger.Printf("Request: %s/%s, Params: %s",
r.ClientInfo.ServiceName, r.Operation, r.Params)
})
Deprecated "New" function
The New session function has been deprecated because it does not provide good
way to return errors that occur when loading the configuration files and values.
Because of this, NewSession was created so errors can be retrieved when
creating a session fails.
Shared Config Fields
By default the SDK will only load the shared credentials file's (~/.aws/credentials)
credentials values, and all other config is provided by the environment variables,
SDK defaults, and user provided aws.Config values.
By default the SDK will only load the shared credentials file's
(~/.aws/credentials) credentials values, and all other config is provided by
the environment variables, SDK defaults, and user provided aws.Config values.
If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable
option is used to create the Session the full shared config values will be
@ -125,24 +114,31 @@ files have the same format.
If both config files are present the configuration from both files will be
read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config).
credentials file (~/.aws/credentials) over those in the shared config file
(~/.aws/config).
Credentials are the values the SDK should use for authenticating requests with
AWS Services. They are from a configuration file will need to include both
aws_access_key_id and aws_secret_access_key must be provided together in the
same file to be considered valid. The values will be ignored if not a complete
group. aws_session_token is an optional field that can be provided if both of
the other two fields are also provided.
Credentials are the values the SDK uses to authenticating requests with AWS
Services. When specified in a file, both aws_access_key_id and
aws_secret_access_key must be provided together in the same file to be
considered valid. They will be ignored if both are not present.
aws_session_token is an optional field that can be provided in addition to the
other two fields.
aws_access_key_id = AKID
aws_secret_access_key = SECRET
aws_session_token = TOKEN
Assume Role values allow you to configure the SDK to assume an IAM role using
a set of credentials provided in a config file via the source_profile field.
Both "role_arn" and "source_profile" are required. The SDK supports assuming
a role with MFA token if the session option AssumeRoleTokenProvider
is set.
; region only supported if SharedConfigEnabled.
region = us-east-1
Assume Role configuration
The role_arn field allows you to configure the SDK to assume an IAM role using
a set of credentials from another source. Such as when paired with static
credentials, "profile_source", "credential_process", or "credential_source"
fields. If "role_arn" is provided, a source of credentials must also be
specified, such as "source_profile", "credential_source", or
"credential_process".
role_arn = arn:aws:iam::<account_number>:role/<role_name>
source_profile = profile_with_creds
@ -150,40 +146,16 @@ is set.
mfa_serial = <serial or mfa arn>
role_session_name = session_name
Region is the region the SDK should use for looking up AWS service endpoints
and signing requests.
region = us-east-1
Assume Role with MFA token
To create a session with support for assuming an IAM role with MFA set the
session option AssumeRoleTokenProvider to a function that will prompt for the
MFA token code when the SDK assumes the role and refreshes the role's credentials.
This allows you to configure the SDK via the shared config to assumea role
with MFA tokens.
In order for the SDK to assume a role with MFA the SharedConfigState
session option must be set to SharedConfigEnable, or AWS_SDK_LOAD_CONFIG
environment variable set.
The shared configuration instructs the SDK to assume an IAM role with MFA
when the mfa_serial configuration field is set in the shared config
(~/.aws/config) or shared credentials (~/.aws/credentials) file.
If mfa_serial is set in the configuration, the SDK will assume the role, and
the AssumeRoleTokenProvider session option is not set an an error will
be returned when creating the session.
The SDK supports assuming a role with MFA token. If "mfa_serial" is set, you
must also set the Session Option.AssumeRoleTokenProvider. The Session will fail
to load if the AssumeRoleTokenProvider is not specified.
sess := session.Must(session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
}))
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess)
To setup assume role outside of a session see the stscrds.AssumeRoleProvider
To setup Assume Role outside of a session see the stscreds.AssumeRoleProvider
documentation.
Environment Variables

View File

@ -1,11 +1,14 @@
package session
import (
"fmt"
"os"
"strconv"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
)
// EnvProviderName provides a name of the provider when config is loaded from environment.
@ -79,7 +82,7 @@ type envConfig struct {
// AWS_CONFIG_FILE=$HOME/my_shared_config
SharedConfigFile string
// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file
// Sets the path to a custom Credentials Authority (CA) Bundle PEM file
// that the SDK will use instead of the system's root CA bundle.
// Only use this if you want to configure the SDK to use a custom set
// of CAs.
@ -98,15 +101,47 @@ type envConfig struct {
CustomCABundle string
csmEnabled string
CSMEnabled bool
CSMEnabled *bool
CSMPort string
CSMHost string
CSMClientID string
// Enables endpoint discovery via environment variables.
//
// AWS_ENABLE_ENDPOINT_DISCOVERY=true
EnableEndpointDiscovery *bool
enableEndpointDiscovery string
// Specifies the WebIdentity token the SDK should use to assume a role
// with.
//
// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
WebIdentityTokenFilePath string
// Specifies the IAM role arn to use when assuming an role.
//
// AWS_ROLE_ARN=role_arn
RoleARN string
// Specifies the IAM role session name to use when assuming a role.
//
// AWS_ROLE_SESSION_NAME=session_name
RoleSessionName string
// Specifies the Regional Endpoint flag for the sdk to resolve the endpoint for a service
//
// AWS_STS_REGIONAL_ENDPOINTS =sts_regional_endpoint
// This can take value as `regional` or `legacy`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
}
var (
csmEnabledEnvKey = []string{
"AWS_CSM_ENABLED",
}
csmHostEnvKey = []string{
"AWS_CSM_HOST",
}
csmPortEnvKey = []string{
"AWS_CSM_PORT",
}
@ -125,6 +160,10 @@ var (
"AWS_SESSION_TOKEN",
}
enableEndpointDiscoveryEnvKey = []string{
"AWS_ENABLE_ENDPOINT_DISCOVERY",
}
regionEnvKeys = []string{
"AWS_REGION",
"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set
@ -139,6 +178,18 @@ var (
sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE",
}
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
roleARNEnvKey = []string{
"AWS_ROLE_ARN",
}
roleSessionNameEnvKey = []string{
"AWS_ROLE_SESSION_NAME",
}
stsRegionalEndpointKey = []string{
"AWS_STS_REGIONAL_ENDPOINTS",
}
)
// loadEnvConfig retrieves the SDK's environment configuration.
@ -147,7 +198,7 @@ var (
// If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value
// the shared SDK config will be loaded in addition to the SDK's specific
// configuration values.
func loadEnvConfig() envConfig {
func loadEnvConfig() (envConfig, error) {
enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG"))
return envConfigLoad(enableSharedConfig)
}
@ -158,30 +209,42 @@ func loadEnvConfig() envConfig {
// Loads the shared configuration in addition to the SDK's specific configuration.
// This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG`
// environment variable is set.
func loadSharedEnvConfig() envConfig {
func loadSharedEnvConfig() (envConfig, error) {
return envConfigLoad(true)
}
func envConfigLoad(enableSharedConfig bool) envConfig {
func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
cfg := envConfig{}
cfg.EnableSharedConfig = enableSharedConfig
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey)
// Static environment credentials
var creds credentials.Value
setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
if creds.HasKeys() {
// Require logical grouping of credentials
creds.ProviderName = EnvProviderName
cfg.Creds = creds
}
// Role Metadata
setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)
// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)
// CSM environment variables
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0
// Require logical grouping of credentials
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
cfg.Creds = credentials.Value{}
} else {
cfg.Creds.ProviderName = EnvProviderName
if len(cfg.csmEnabled) != 0 {
v, _ := strconv.ParseBool(cfg.csmEnabled)
cfg.CSMEnabled = &v
}
regionKeys := regionEnvKeys
@ -194,6 +257,12 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
setFromEnvVal(&cfg.Region, regionKeys)
setFromEnvVal(&cfg.Profile, profileKeys)
// endpoint discovery is in reference to it being enabled.
setFromEnvVal(&cfg.enableEndpointDiscovery, enableEndpointDiscoveryEnvKey)
if len(cfg.enableEndpointDiscovery) > 0 {
cfg.EnableEndpointDiscovery = aws.Bool(cfg.enableEndpointDiscovery != "false")
}
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
@ -206,12 +275,23 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
return cfg
// STS Regional Endpoint variable
for _, k := range stsRegionalEndpointKey {
if v := os.Getenv(k); len(v) != 0 {
STSRegionalEndpoint, err := endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
cfg.STSRegionalEndpoint = STSRegionalEndpoint
}
}
return cfg, nil
}
func setFromEnvVal(dst *string, keys []string) {
for _, k := range keys {
if v := os.Getenv(k); len(v) > 0 {
if v := os.Getenv(k); len(v) != 0 {
*dst = v
break
}

View File

@ -8,19 +8,36 @@ import (
"io/ioutil"
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
)
const (
// ErrCodeSharedConfig represents an error that occurs in the shared
// configuration logic
ErrCodeSharedConfig = "SharedConfigErr"
)
// ErrSharedConfigSourceCollision will be returned if a section contains both
// source_profile and credential_source
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only source profile or credential source can be specified, not both", nil)
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
// variables are empty and Environment was set as the credential source
var ErrSharedConfigECSContainerEnvVarEmpty = awserr.New(ErrCodeSharedConfig, "EcsContainer was specified as the credential_source, but 'AWS_CONTAINER_CREDENTIALS_RELATIVE_URI' was not set", nil)
// ErrSharedConfigInvalidCredSource will be returned if an invalid credential source was provided
var ErrSharedConfigInvalidCredSource = awserr.New(ErrCodeSharedConfig, "credential source values must be EcsContainer, Ec2InstanceMetadata, or Environment", nil)
// A Session provides a central location to create service clients from and
// store configurations and request handlers for those services.
//
@ -56,7 +73,7 @@ type Session struct {
// func is called instead of waiting to receive an error until a request is made.
func New(cfgs ...*aws.Config) *Session {
// load initial config from environment
envCfg := loadEnvConfig()
envCfg, envErr := loadEnvConfig()
if envCfg.EnableSharedConfig {
var cfg aws.Config
@ -76,19 +93,28 @@ func New(cfgs ...*aws.Config) *Session {
// Session creation failed, need to report the error and prevent
// any requests from succeeding.
s = &Session{Config: defaults.Config()}
s.Config.MergeIn(cfgs...)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
s.logDeprecatedNewSessionError(msg, err, cfgs)
}
return s
}
s := deprecatedNewSession(cfgs...)
if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
if envErr != nil {
msg := "failed to load env config"
s.logDeprecatedNewSessionError(msg, envErr, cfgs)
}
if csmCfg, err := loadCSMConfig(envCfg, []string{}); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err := enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
msg := "failed to enable CSM"
s.logDeprecatedNewSessionError(msg, err, cfgs)
}
}
return s
@ -107,7 +133,7 @@ func New(cfgs ...*aws.Config) *Session {
// to be built with retrieving credentials with AssumeRole set in the config.
//
// See the NewSessionWithOptions func for information on how to override or
// control through code how the Session will be created. Such as specifying the
// control through code how the Session will be created, such as specifying the
// config profile, and controlling if shared config is enabled or not.
func NewSession(cfgs ...*aws.Config) (*Session, error) {
opts := Options{}
@ -191,6 +217,12 @@ type Options struct {
// the config enables assume role wit MFA via the mfa_serial field.
AssumeRoleTokenProvider func() (string, error)
// When the SDK's shared config is configured to assume a role this option
// may be provided to set the expiry duration of the STS credentials.
// Defaults to 15 minutes if not set as documented in the
// stscreds.AssumeRoleProvider.
AssumeRoleDuration time.Duration
// Reader for a custom Credentials Authority (CA) bundle in PEM format that
// the SDK will use instead of the default system's root CA bundle. Use this
// only if you want to replace the CA bundle the SDK uses for TLS requests.
@ -205,6 +237,12 @@ type Options struct {
// to also enable this feature. CustomCABundle session option field has priority
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
CustomCABundle io.Reader
// The handlers that the session and all API clients will be created with.
// This must be a complete set of handlers. Use the defaults.Handlers()
// function to initialize this value before changing the handlers to be
// used by the SDK.
Handlers request.Handlers
}
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@ -238,13 +276,20 @@ type Options struct {
// }))
func NewSessionWithOptions(opts Options) (*Session, error) {
var envCfg envConfig
var err error
if opts.SharedConfigState == SharedConfigEnable {
envCfg = loadSharedEnvConfig()
envCfg, err = loadSharedEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load shared config, %v", err)
}
} else {
envCfg = loadEnvConfig()
envCfg, err = loadEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load environment config, %v", err)
}
}
if len(opts.Profile) > 0 {
if len(opts.Profile) != 0 {
envCfg.Profile = opts.Profile
}
@ -310,27 +355,33 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
return s
}
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) {
logger.Log("Enabling CSM")
if len(port) == 0 {
port = csm.DefaultPort
func enableCSM(handlers *request.Handlers, cfg csmConfig, logger aws.Logger) error {
if logger != nil {
logger.Log("Enabling CSM")
}
r, err := csm.Start(clientID, "127.0.0.1:"+port)
r, err := csm.Start(cfg.ClientID, csm.AddressWithDefaults(cfg.Host, cfg.Port))
if err != nil {
return
return err
}
r.InjectHandlers(handlers)
return nil
}
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
cfg := defaults.Config()
handlers := defaults.Handlers()
handlers := opts.Handlers
if handlers.IsEmpty() {
handlers = defaults.Handlers()
}
// Get a merged version of the user provided config to determine if
// credentials were.
userCfg := &aws.Config{}
userCfg.MergeIn(cfgs...)
cfg.MergeIn(userCfg)
// Ordered config files will be loaded in with later files overwriting
// previous config file values.
@ -347,9 +398,17 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
}
// Load additional config from file(s)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
if err != nil {
return nil, err
if len(envCfg.Profile) == 0 && !envCfg.EnableSharedConfig && (envCfg.Creds.HasKeys() || userCfg.Credentials != nil) {
// Special case where the user has not explicitly specified an AWS_PROFILE,
// or session.Options.profile, shared config is not enabled, and the
// environment has credentials, allow the shared config file to fail to
// load since the user has already provided credentials, and nothing else
// is required to be read file. Github(aws/aws-sdk-go#2455)
} else if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return nil, err
}
}
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
@ -362,8 +421,16 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
}
initHandlers(s)
if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
if csmCfg, err := loadCSMConfig(envCfg, cfgFiles); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err = enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
return nil, err
}
}
// Setup HTTP client with custom cert bundle if enabled
@ -376,6 +443,46 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
return s, nil
}
type csmConfig struct {
Enabled bool
Host string
Port string
ClientID string
}
var csmProfileName = "aws_csm"
func loadCSMConfig(envCfg envConfig, cfgFiles []string) (csmConfig, error) {
if envCfg.CSMEnabled != nil {
if *envCfg.CSMEnabled {
return csmConfig{
Enabled: true,
ClientID: envCfg.CSMClientID,
Host: envCfg.CSMHost,
Port: envCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
sharedCfg, err := loadSharedConfig(csmProfileName, cfgFiles, false)
if err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return csmConfig{}, err
}
}
if sharedCfg.CSMEnabled != nil && *sharedCfg.CSMEnabled == true {
return csmConfig{
Enabled: true,
ClientID: sharedCfg.CSMClientID,
Host: sharedCfg.CSMHost,
Port: sharedCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
func loadCustomCABundle(s *Session, bundle io.Reader) error {
var t *http.Transport
switch v := s.Config.HTTPClient.Transport.(type) {
@ -388,7 +495,10 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
}
}
if t == nil {
t = &http.Transport{}
// Nil transport implies `http.DefaultTransport` should be used. Since
// the SDK cannot modify, nor copy the `DefaultTransport` specifying
// the values the next closest behavior.
t = getCABundleTransport()
}
p, err := loadCertPool(bundle)
@ -421,9 +531,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) {
return p, nil
}
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error {
// Merge in user provided configuration
cfg.MergeIn(userCfg)
func mergeConfigSrcs(cfg, userCfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) error {
// Region if not already set by user
if len(aws.StringValue(cfg.Region)) == 0 {
@ -434,103 +546,46 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
}
}
// Configure credentials if not already set
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
if len(envCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.AssumeRoleSource.Creds,
)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = stscreds.NewCredentials(
&Session{
Config: &cfgCp,
Handlers: handlers.Copy(),
},
sharedCfg.AssumeRole.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
// Assume role with external ID
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
)
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
} else {
// Fallback to default credentials provider, include mock errors
// for the credential chain so user can identify why credentials
// failed to be retrieved.
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
if cfg.EnableEndpointDiscovery == nil {
if envCfg.EnableEndpointDiscovery != nil {
cfg.WithEndpointDiscovery(*envCfg.EnableEndpointDiscovery)
} else if envCfg.EnableSharedConfig && sharedCfg.EnableEndpointDiscovery != nil {
cfg.WithEndpointDiscovery(*sharedCfg.EnableEndpointDiscovery)
}
}
// Regional Endpoint flag for STS endpoint resolving
mergeSTSRegionalEndpointConfig(cfg, envCfg, sharedCfg)
// Configure credentials if not already set by the user when creating the
// Session.
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
if err != nil {
return err
}
cfg.Credentials = creds
}
return nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the
// MFAToken option is not set when shared config is configured load assume a
// role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// mergeSTSRegionalEndpointConfig function merges the STSRegionalEndpoint into cfg from
// envConfig and SharedConfig with envConfig being given precedence over SharedConfig
func mergeSTSRegionalEndpointConfig(cfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig) error {
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
cfg.STSRegionalEndpoint = envCfg.STSRegionalEndpoint
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
if cfg.STSRegionalEndpoint == endpoints.UnsetSTSEndpoint {
cfg.STSRegionalEndpoint = sharedCfg.STSRegionalEndpoint
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
if cfg.STSRegionalEndpoint == endpoints.UnsetSTSEndpoint {
cfg.STSRegionalEndpoint = endpoints.LegacySTSEndpoint
}
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
var emptyCreds = credentials.Value{}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}
func initHandlers(s *Session) {
// Add the Validate parameter handler if it is not disabled.
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler)
@ -539,7 +594,7 @@ func initHandlers(s *Session) {
}
}
// Copy creates and returns a copy of the current Session, coping the config
// Copy creates and returns a copy of the current Session, copying the config
// and handlers. If any additional configs are provided they will be merged
// on top of the Session's copied config.
//
@ -559,37 +614,15 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
// ClientConfig satisfies the client.ConfigProvider interface and is used to
// configure the service client instances. Passing the Session to the service
// client's constructor (New) will use this method to configure the client.
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
// Backwards compatibility, the error will be eaten if user calls ClientConfig
// directly. All SDK services will use ClientconfigWithError.
cfg, _ := s.clientConfigWithErr(serviceName, cfgs...)
return cfg
}
func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (client.Config, error) {
func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config {
s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint
var err error
region := aws.StringValue(s.Config.Region)
if endpoint := aws.StringValue(s.Config.Endpoint); len(endpoint) != 0 {
resolved.URL = endpoints.AddScheme(endpoint, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region
} else {
resolved, err = s.Config.EndpointResolver.EndpointFor(
serviceName, region,
func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(s.Config.DisableSSL)
opt.UseDualStack = aws.BoolValue(s.Config.UseDualStack)
// Support the condition where the service is modeled but its
// endpoint metadata is not available.
opt.ResolveUnknownService = true
},
)
resolved, err := s.resolveEndpoint(service, region, s.Config)
if err != nil && s.Config.Logger != nil {
s.Config.Logger.Log(fmt.Sprintf(
"ERROR: unable to resolve endpoint for service %q, region %q, err: %v",
service, region, err))
}
return client.Config{
@ -599,7 +632,37 @@ func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (
SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}, err
}
}
func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) {
if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 {
return endpoints.ResolvedEndpoint{
URL: endpoints.AddScheme(ep, aws.BoolValue(cfg.DisableSSL)),
SigningRegion: region,
}, nil
}
resolved, err := cfg.EndpointResolver.EndpointFor(service, region,
func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
// Support for STSRegionalEndpoint where the STSRegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.STSRegionalEndpoint = cfg.STSRegionalEndpoint
// Support the condition where the service is modeled but its
// endpoint metadata is not available.
opt.ResolveUnknownService = true
},
)
if err != nil {
return endpoints.ResolvedEndpoint{}, err
}
return resolved, nil
}
// ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception
@ -609,12 +672,9 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint
region := aws.StringValue(s.Config.Region)
if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 {
resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region
resolved.SigningRegion = aws.StringValue(s.Config.Region)
}
return client.Config{
@ -626,3 +686,14 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
SigningName: resolved.SigningName,
}
}
// logDeprecatedNewSessionError function enables error handling for session
func (s *Session) logDeprecatedNewSessionError(msg string, err error, cfgs []*aws.Config) {
// Session creation failed, need to report the error and prevent
// any requests from succeeding.
s.Config.MergeIn(cfgs...)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
}

View File

@ -2,11 +2,11 @@ package session
import (
"fmt"
"io/ioutil"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/internal/ini"
)
const (
@ -16,68 +16,106 @@ const (
sessionTokenKey = `aws_session_token` // optional
// Assume Role Credentials group
roleArnKey = `role_arn` // group required
sourceProfileKey = `source_profile` // group required
externalIDKey = `external_id` // optional
mfaSerialKey = `mfa_serial` // optional
roleSessionNameKey = `role_session_name` // optional
roleArnKey = `role_arn` // group required
sourceProfileKey = `source_profile` // group required (or credential_source)
credentialSourceKey = `credential_source` // group required (or source_profile)
externalIDKey = `external_id` // optional
mfaSerialKey = `mfa_serial` // optional
roleSessionNameKey = `role_session_name` // optional
// CSM options
csmEnabledKey = `csm_enabled`
csmHostKey = `csm_host`
csmPortKey = `csm_port`
csmClientIDKey = `csm_client_id`
// Additional Config fields
regionKey = `region`
// endpoint discovery group
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
// External Credential Process
credentialProcessKey = `credential_process` // optional
// Web Identity Token File
webIdentityTokenFileKey = `web_identity_token_file` // optional
// Additional config fields for regional or legacy endpoints
stsRegionalEndpointSharedKey = `sts_regional_endpoints`
// DefaultSharedConfigProfile is the default profile to be used when
// loading configuration from the config files if another profile name
// is not provided.
DefaultSharedConfigProfile = `default`
)
type assumeRoleConfig struct {
RoleARN string
SourceProfile string
ExternalID string
MFASerial string
RoleSessionName string
}
// sharedConfig represents the configuration fields of the SDK config files.
type sharedConfig struct {
// Credentials values from the config file. Both aws_access_key_id
// and aws_secret_access_key must be provided together in the same file
// to be considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of the
// other two fields are also provided.
// Credentials values from the config file. Both aws_access_key_id and
// aws_secret_access_key must be provided together in the same file to be
// considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of
// the other two fields are also provided.
//
// aws_access_key_id
// aws_secret_access_key
// aws_session_token
Creds credentials.Value
AssumeRole assumeRoleConfig
AssumeRoleSource *sharedConfig
CredentialSource string
CredentialProcess string
WebIdentityTokenFile string
// Region is the region the SDK should use for looking up AWS service endpoints
// and signing requests.
RoleARN string
RoleSessionName string
ExternalID string
MFASerial string
SourceProfileName string
SourceProfile *sharedConfig
// Region is the region the SDK should use for looking up AWS service
// endpoints and signing requests.
//
// region
Region string
// EnableEndpointDiscovery can be enabled in the shared config by setting
// endpoint_discovery_enabled to true
//
// endpoint_discovery_enabled = true
EnableEndpointDiscovery *bool
// CSM Options
CSMEnabled *bool
CSMHost string
CSMPort string
CSMClientID string
// Specifies the Regional Endpoint flag for the sdk to resolve the endpoint for a service
//
// sts_regional_endpoints = sts_regional_endpoint
// This can take value as `LegacySTSEndpoint` or `RegionalSTSEndpoint`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
}
type sharedConfigFile struct {
Filename string
IniData *ini.File
IniData ini.Sections
}
// loadSharedConfig retrieves the configuration from the list of files
// using the profile provided. The order the files are listed will determine
// loadSharedConfig retrieves the configuration from the list of files using
// the profile provided. The order the files are listed will determine
// precedence. Values in subsequent files will overwrite values defined in
// earlier files.
//
// For example, given two files A and B. Both define credentials. If the order
// of the files are A then B, B's credential values will be used instead of A's.
// of the files are A then B, B's credential values will be used instead of
// A's.
//
// See sharedConfig.setFromFile for information how the config files
// will be loaded.
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) {
func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
if len(profile) == 0 {
profile = DefaultSharedConfigProfile
}
@ -88,16 +126,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
}
cfg := sharedConfig{}
if err = cfg.setFromIniFiles(profile, files); err != nil {
profiles := map[string]struct{}{}
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
return sharedConfig{}, err
}
if len(cfg.AssumeRole.SourceProfile) > 0 {
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
return sharedConfig{}, err
}
}
return cfg, nil
}
@ -105,114 +138,258 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
files := make([]sharedConfigFile, 0, len(filenames))
for _, filename := range filenames {
b, err := ioutil.ReadFile(filename)
if err != nil {
sections, err := ini.OpenFile(filename)
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == ini.ErrCodeUnableToReadFile {
// Skip files which can't be opened and read for whatever reason
continue
}
f, err := ini.Load(b)
if err != nil {
} else if err != nil {
return nil, SharedConfigLoadError{Filename: filename, Err: err}
}
files = append(files, sharedConfigFile{
Filename: filename, IniData: f,
Filename: filename, IniData: sections,
})
}
return files, nil
}
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error {
var assumeRoleSrc sharedConfig
// Multiple level assume role chains are not support
if cfg.AssumeRole.SourceProfile == origProfile {
assumeRoleSrc = *cfg
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
} else {
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
if err != nil {
return err
}
}
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
}
cfg.AssumeRoleSource = &assumeRoleSrc
return nil
}
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
// Trim files from the list that don't exist.
var skippedFiles int
var profileNotFoundErr error
for _, f := range files {
if err := cfg.setFromIniFile(profile, f); err != nil {
if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
// Ignore proviles missings
// Ignore profiles not defined in individual files.
profileNotFoundErr = err
skippedFiles++
continue
}
return err
}
}
if skippedFiles == len(files) {
// If all files were skipped because the profile is not found, return
// the original profile not found error.
return profileNotFoundErr
}
if _, ok := profiles[profile]; ok {
// if this is the second instance of the profile the Assume Role
// options must be cleared because they are only valid for the
// first reference of a profile. The self linked instance of the
// profile only have credential provider options.
cfg.clearAssumeRoleOptions()
} else {
// First time a profile has been seen, It must either be a assume role
// or credentials. Assert if the credential type requires a role ARN,
// the ARN is also set.
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
return err
}
}
profiles[profile] = struct{}{}
if err := cfg.validateCredentialType(); err != nil {
return err
}
// Link source profiles for assume roles
if len(cfg.SourceProfileName) != 0 {
// Linked profile via source_profile ignore credential provider
// options, the source profile must provide the credentials.
cfg.clearCredentialOptions()
srcCfg := &sharedConfig{}
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
if err != nil {
// SourceProfile that doesn't exist is an error in configuration.
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
err = SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
return err
}
if !srcCfg.hasCredentials() {
return SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
cfg.SourceProfile = srcCfg
}
return nil
}
// setFromFile loads the configuration from the file using
// the profile provided. A sharedConfig pointer type value is used so that
// multiple config file loadings can be chained.
// setFromFile loads the configuration from the file using the profile
// provided. A sharedConfig pointer type value is used so that multiple config
// file loadings can be chained.
//
// Only loads complete logically grouped values, and will not set fields in cfg
// for incomplete grouped values in the config. Such as credentials. For example
// if a config file only includes aws_access_key_id but no aws_secret_access_key
// the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error {
section, err := file.IniData.GetSection(profile)
if err != nil {
// for incomplete grouped values in the config. Such as credentials. For
// example if a config file only includes aws_access_key_id but no
// aws_secret_access_key the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
section, ok := file.IniData.GetSection(profile)
if !ok {
// Fallback to to alternate profile name: profile <name>
section, err = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
if err != nil {
return SharedConfigProfileNotExistsError{Profile: profile, Err: err}
section, ok = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
if !ok {
return SharedConfigProfileNotExistsError{Profile: profile, Err: nil}
}
}
if exOpts {
// Assume Role Parameters
updateString(&cfg.RoleARN, section, roleArnKey)
updateString(&cfg.ExternalID, section, externalIDKey)
updateString(&cfg.MFASerial, section, mfaSerialKey)
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
updateString(&cfg.CredentialSource, section, credentialSourceKey)
updateString(&cfg.Region, section, regionKey)
if v := section.String(stsRegionalEndpointSharedKey); len(v) != 0 {
sre, err := endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
stsRegionalEndpointKey, file.Filename, err)
}
cfg.STSRegionalEndpoint = sre
}
}
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
// Shared Credentials
akid := section.Key(accessKeyIDKey).String()
secret := section.Key(secretAccessKey).String()
if len(akid) > 0 && len(secret) > 0 {
cfg.Creds = credentials.Value{
AccessKeyID: akid,
SecretAccessKey: secret,
SessionToken: section.Key(sessionTokenKey).String(),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
}
creds := credentials.Value{
AccessKeyID: section.String(accessKeyIDKey),
SecretAccessKey: section.String(secretAccessKey),
SessionToken: section.String(sessionTokenKey),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
}
if creds.HasKeys() {
cfg.Creds = creds
}
// Assume Role
roleArn := section.Key(roleArnKey).String()
srcProfile := section.Key(sourceProfileKey).String()
if len(roleArn) > 0 && len(srcProfile) > 0 {
cfg.AssumeRole = assumeRoleConfig{
RoleARN: roleArn,
SourceProfile: srcProfile,
ExternalID: section.Key(externalIDKey).String(),
MFASerial: section.Key(mfaSerialKey).String(),
RoleSessionName: section.Key(roleSessionNameKey).String(),
}
// Endpoint discovery
updateBoolPtr(&cfg.EnableEndpointDiscovery, section, enableEndpointDiscoveryKey)
// CSM options
updateBoolPtr(&cfg.CSMEnabled, section, csmEnabledKey)
updateString(&cfg.CSMHost, section, csmHostKey)
updateString(&cfg.CSMPort, section, csmPortKey)
updateString(&cfg.CSMClientID, section, csmClientIDKey)
return nil
}
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
var credSource string
switch {
case len(cfg.SourceProfileName) != 0:
credSource = sourceProfileKey
case len(cfg.CredentialSource) != 0:
credSource = credentialSourceKey
case len(cfg.WebIdentityTokenFile) != 0:
credSource = webIdentityTokenFileKey
}
// Region
if v := section.Key(regionKey).String(); len(v) > 0 {
cfg.Region = v
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
return CredentialRequiresARNError{
Type: credSource,
Profile: profile,
}
}
return nil
}
func (cfg *sharedConfig) validateCredentialType() error {
// Only one or no credential type can be defined.
if !oneOrNone(
len(cfg.SourceProfileName) != 0,
len(cfg.CredentialSource) != 0,
len(cfg.CredentialProcess) != 0,
len(cfg.WebIdentityTokenFile) != 0,
) {
return ErrSharedConfigSourceCollision
}
return nil
}
func (cfg *sharedConfig) hasCredentials() bool {
switch {
case len(cfg.SourceProfileName) != 0:
case len(cfg.CredentialSource) != 0:
case len(cfg.CredentialProcess) != 0:
case len(cfg.WebIdentityTokenFile) != 0:
case cfg.Creds.HasKeys():
default:
return false
}
return true
}
func (cfg *sharedConfig) clearCredentialOptions() {
cfg.CredentialSource = ""
cfg.CredentialProcess = ""
cfg.WebIdentityTokenFile = ""
cfg.Creds = credentials.Value{}
}
func (cfg *sharedConfig) clearAssumeRoleOptions() {
cfg.RoleARN = ""
cfg.ExternalID = ""
cfg.MFASerial = ""
cfg.RoleSessionName = ""
cfg.SourceProfileName = ""
}
func oneOrNone(bs ...bool) bool {
var count int
for _, b := range bs {
if b {
count++
if count > 1 {
return false
}
}
}
return true
}
// updateString will only update the dst with the value in the section key, key
// is present in the section.
func updateString(dst *string, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.String(key)
}
// updateBoolPtr will only update the dst with the value in the section key,
// key is present in the section.
func updateBoolPtr(dst **bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = new(bool)
**dst = section.Bool(key)
}
// SharedConfigLoadError is an error for the shared config file failed to load.
type SharedConfigLoadError struct {
Filename string
@ -270,7 +447,8 @@ func (e SharedConfigProfileNotExistsError) Error() string {
// profile contains assume role information, but that information is invalid
// or not complete.
type SharedConfigAssumeRoleError struct {
RoleARN string
RoleARN string
SourceProfile string
}
// Code is the short id of the error.
@ -280,8 +458,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
// Message is the description of the error
func (e SharedConfigAssumeRoleError) Message() string {
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials",
e.RoleARN)
return fmt.Sprintf(
"failed to load assume role for %s, source profile %s has no shared credentials",
e.RoleARN, e.SourceProfile,
)
}
// OrigErr is the underlying error that caused the failure.
@ -293,3 +473,36 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
func (e SharedConfigAssumeRoleError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
// CredentialRequiresARNError provides the error for shared config credentials
// that are incorrectly configured in the shared config or credentials file.
type CredentialRequiresARNError struct {
// type of credentials that were configured.
Type string
// Profile name the credentials were in.
Profile string
}
// Code is the short id of the error.
func (e CredentialRequiresARNError) Code() string {
return "CredentialRequiresARNError"
}
// Message is the description of the error
func (e CredentialRequiresARNError) Message() string {
return fmt.Sprintf(
"credential type %s requires role_arn, profile %s",
e.Type, e.Profile,
)
}
// OrigErr is the underlying error that caused the failure.
func (e CredentialRequiresARNError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e CredentialRequiresARNError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}

View File

@ -134,6 +134,7 @@ var requiredSignedHeaders = rules{
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{},
"X-Amz-Tagging": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{},
},
@ -181,7 +182,7 @@ type Signer struct {
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
// Disales the automatical setting of the HTTP request's Body field with the
// Disables the automatical setting of the HTTP request's Body field with the
// io.ReadSeeker passed in to the signer. This is useful if you're using a
// custom wrapper around the body for the io.ReadSeeker and want to preserve
// the Body value on the Request.Body.
@ -421,7 +422,7 @@ var SignRequestHandler = request.NamedHandler{
// If the credentials of the request's config are set to
// credentials.AnonymousCredentials the request will not be signed.
func SignSDKRequest(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now)
SignSDKRequestWithCurrentTime(req, time.Now)
}
// BuildNamedHandler will build a generic handler for signing.
@ -429,12 +430,15 @@ func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler
return request.NamedHandler{
Name: name,
Fn: func(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now, opts...)
SignSDKRequestWithCurrentTime(req, time.Now, opts...)
},
}
}
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
// SignSDKRequestWithCurrentTime will sign the SDK's request using the time
// function passed in. Behaves the same as SignSDKRequest with the exception
// the request is signed with the value returned by the current time function.
func SignSDKRequestWithCurrentTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials {
@ -470,13 +474,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
opt(v4)
}
signingTime := req.Time
if !req.LastSignedAt.IsZero() {
signingTime = req.LastSignedAt
}
curTime := curTimeFn()
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime,
name, region, req.ExpireTime, req.ExpireTime > 0, curTime,
)
if err != nil {
req.Error = err
@ -485,7 +485,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
}
req.SignedHeaderVals = signedHeaders
req.LastSignedAt = curTimeFn()
req.LastSignedAt = curTime
}
const logSignInfoMsg = `DEBUG: Request Signature:
@ -687,7 +687,11 @@ func (ctx *signingCtx) buildBodyDigest() error {
if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
}
hash = hex.EncodeToString(makeSha256Reader(ctx.Body))
hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
}
if includeSHA256Header {
@ -734,19 +738,33 @@ func makeSha256(data []byte) []byte {
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New()
start, _ := reader.Seek(0, sdkio.SeekCurrent)
defer reader.Seek(start, sdkio.SeekStart)
start, err := reader.Seek(0, sdkio.SeekCurrent)
if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, sdkio.SeekStart)
}()
io.Copy(hash, reader)
return hash.Sum(nil)
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := aws.SeekerLen(reader)
if err != nil {
io.Copy(hash, reader)
} else {
io.CopyN(hash, reader, size)
}
return hash.Sum(nil), nil
}
const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain muliple side-by-side spaces.
// contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int
for i, str := range vals {

View File

@ -7,13 +7,18 @@ import (
"github.com/aws/aws-sdk-go/internal/sdkio"
)
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should
// only be used with an io.Reader that is also an io.Seeker. Doing so may
// cause request signature errors, or request body's not sent for GET, HEAD
// and DELETE HTTP methods.
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
// streaming payload API operations.
//
// Deprecated: Should only be used with io.ReadSeeker. If using for
// S3 PutObject to stream content use s3manager.Uploader instead.
// A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
// operation's input will prevent that operation being retried in the case of
// network errors, and cause operation requests to fail if the operation
// requires payload signing.
//
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
// Upload manager (s3manager.Uploader) provides support for streaming with the
// ability to retry network errors.
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
@ -43,7 +48,8 @@ func IsReaderSeekable(r io.Reader) bool {
// Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned.
//
// If the reader is not an io.Reader zero bytes read, and nil error will be returned.
// If the reader is not an io.Reader zero bytes read, and nil error will be
// returned.
//
// Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {

View File

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "1.15.24"
const SDKVersion = "1.25.32"

120
vendor/github.com/aws/aws-sdk-go/internal/ini/ast.go generated vendored Normal file
View File

@ -0,0 +1,120 @@
package ini
// ASTKind represents different states in the parse table
// and the type of AST that is being constructed
type ASTKind int
// ASTKind* is used in the parse table to transition between
// the different states
const (
ASTKindNone = ASTKind(iota)
ASTKindStart
ASTKindExpr
ASTKindEqualExpr
ASTKindStatement
ASTKindSkipStatement
ASTKindExprStatement
ASTKindSectionStatement
ASTKindNestedSectionStatement
ASTKindCompletedNestedSectionStatement
ASTKindCommentStatement
ASTKindCompletedSectionStatement
)
func (k ASTKind) String() string {
switch k {
case ASTKindNone:
return "none"
case ASTKindStart:
return "start"
case ASTKindExpr:
return "expr"
case ASTKindStatement:
return "stmt"
case ASTKindSectionStatement:
return "section_stmt"
case ASTKindExprStatement:
return "expr_stmt"
case ASTKindCommentStatement:
return "comment"
case ASTKindNestedSectionStatement:
return "nested_section_stmt"
case ASTKindCompletedSectionStatement:
return "completed_stmt"
case ASTKindSkipStatement:
return "skip"
default:
return ""
}
}
// AST interface allows us to determine what kind of node we
// are on and casting may not need to be necessary.
//
// The root is always the first node in Children
type AST struct {
Kind ASTKind
Root Token
RootToken bool
Children []AST
}
func newAST(kind ASTKind, root AST, children ...AST) AST {
return AST{
Kind: kind,
Children: append([]AST{root}, children...),
}
}
func newASTWithRootToken(kind ASTKind, root Token, children ...AST) AST {
return AST{
Kind: kind,
Root: root,
RootToken: true,
Children: children,
}
}
// AppendChild will append to the list of children an AST has.
func (a *AST) AppendChild(child AST) {
a.Children = append(a.Children, child)
}
// GetRoot will return the root AST which can be the first entry
// in the children list or a token.
func (a *AST) GetRoot() AST {
if a.RootToken {
return *a
}
if len(a.Children) == 0 {
return AST{}
}
return a.Children[0]
}
// GetChildren will return the current AST's list of children
func (a *AST) GetChildren() []AST {
if len(a.Children) == 0 {
return []AST{}
}
if a.RootToken {
return a.Children
}
return a.Children[1:]
}
// SetChildren will set and override all children of the AST.
func (a *AST) SetChildren(children []AST) {
if a.RootToken {
a.Children = children
} else {
a.Children = append(a.Children[:1], children...)
}
}
// Start is used to indicate the starting state of the parse table.
var Start = newAST(ASTKindStart, AST{})

View File

@ -0,0 +1,11 @@
package ini
var commaRunes = []rune(",")
func isComma(b rune) bool {
return b == ','
}
func newCommaToken() Token {
return newToken(TokenComma, commaRunes, NoneType)
}

View File

@ -0,0 +1,35 @@
package ini
// isComment will return whether or not the next byte(s) is a
// comment.
func isComment(b []rune) bool {
if len(b) == 0 {
return false
}
switch b[0] {
case ';':
return true
case '#':
return true
}
return false
}
// newCommentToken will create a comment token and
// return how many bytes were read.
func newCommentToken(b []rune) (Token, int, error) {
i := 0
for ; i < len(b); i++ {
if b[i] == '\n' {
break
}
if len(b)-i > 2 && b[i] == '\r' && b[i+1] == '\n' {
break
}
}
return newToken(TokenComment, b[:i], NoneType), i, nil
}

29
vendor/github.com/aws/aws-sdk-go/internal/ini/doc.go generated vendored Normal file
View File

@ -0,0 +1,29 @@
// Package ini is an LL(1) parser for configuration files.
//
// Example:
// sections, err := ini.OpenFile("/path/to/file")
// if err != nil {
// panic(err)
// }
//
// profile := "foo"
// section, ok := sections.GetSection(profile)
// if !ok {
// fmt.Printf("section %q could not be found", profile)
// }
//
// Below is the BNF that describes this parser
// Grammar:
// stmt -> value stmt'
// stmt' -> epsilon | op stmt
// value -> number | string | boolean | quoted_string
//
// section -> [ section'
// section' -> value section_close
// section_close -> ]
//
// SkipState will skip (NL WS)+
//
// comment -> # comment' | ; comment'
// comment' -> epsilon | value
package ini

View File

@ -0,0 +1,4 @@
package ini
// emptyToken is used to satisfy the Token interface
var emptyToken = newToken(TokenNone, []rune{}, NoneType)

View File

@ -0,0 +1,24 @@
package ini
// newExpression will return an expression AST.
// Expr represents an expression
//
// grammar:
// expr -> string | number
func newExpression(tok Token) AST {
return newASTWithRootToken(ASTKindExpr, tok)
}
func newEqualExpr(left AST, tok Token) AST {
return newASTWithRootToken(ASTKindEqualExpr, tok, left)
}
// EqualExprKey will return a LHS value in the equal expr
func EqualExprKey(ast AST) string {
children := ast.GetChildren()
if len(children) == 0 || ast.Kind != ASTKindEqualExpr {
return ""
}
return string(children[0].Root.Raw())
}

17
vendor/github.com/aws/aws-sdk-go/internal/ini/fuzz.go generated vendored Normal file
View File

@ -0,0 +1,17 @@
// +build gofuzz
package ini
import (
"bytes"
)
func Fuzz(data []byte) int {
b := bytes.NewReader(data)
if _, err := Parse(b); err != nil {
return 0
}
return 1
}

51
vendor/github.com/aws/aws-sdk-go/internal/ini/ini.go generated vendored Normal file
View File

@ -0,0 +1,51 @@
package ini
import (
"io"
"os"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// OpenFile takes a path to a given file, and will open and parse
// that file.
func OpenFile(path string) (Sections, error) {
f, err := os.Open(path)
if err != nil {
return Sections{}, awserr.New(ErrCodeUnableToReadFile, "unable to open file", err)
}
defer f.Close()
return Parse(f)
}
// Parse will parse the given file using the shared config
// visitor.
func Parse(f io.Reader) (Sections, error) {
tree, err := ParseAST(f)
if err != nil {
return Sections{}, err
}
v := NewDefaultVisitor()
if err = Walk(tree, v); err != nil {
return Sections{}, err
}
return v.Sections, nil
}
// ParseBytes will parse the given bytes and return the parsed sections.
func ParseBytes(b []byte) (Sections, error) {
tree, err := ParseASTBytes(b)
if err != nil {
return Sections{}, err
}
v := NewDefaultVisitor()
if err = Walk(tree, v); err != nil {
return Sections{}, err
}
return v.Sections, nil
}

View File

@ -0,0 +1,165 @@
package ini
import (
"bytes"
"io"
"io/ioutil"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const (
// ErrCodeUnableToReadFile is used when a file is failed to be
// opened or read from.
ErrCodeUnableToReadFile = "FailedRead"
)
// TokenType represents the various different tokens types
type TokenType int
func (t TokenType) String() string {
switch t {
case TokenNone:
return "none"
case TokenLit:
return "literal"
case TokenSep:
return "sep"
case TokenOp:
return "op"
case TokenWS:
return "ws"
case TokenNL:
return "newline"
case TokenComment:
return "comment"
case TokenComma:
return "comma"
default:
return ""
}
}
// TokenType enums
const (
TokenNone = TokenType(iota)
TokenLit
TokenSep
TokenComma
TokenOp
TokenWS
TokenNL
TokenComment
)
type iniLexer struct{}
// Tokenize will return a list of tokens during lexical analysis of the
// io.Reader.
func (l *iniLexer) Tokenize(r io.Reader) ([]Token, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, awserr.New(ErrCodeUnableToReadFile, "unable to read file", err)
}
return l.tokenize(b)
}
func (l *iniLexer) tokenize(b []byte) ([]Token, error) {
runes := bytes.Runes(b)
var err error
n := 0
tokenAmount := countTokens(runes)
tokens := make([]Token, tokenAmount)
count := 0
for len(runes) > 0 && count < tokenAmount {
switch {
case isWhitespace(runes[0]):
tokens[count], n, err = newWSToken(runes)
case isComma(runes[0]):
tokens[count], n = newCommaToken(), 1
case isComment(runes):
tokens[count], n, err = newCommentToken(runes)
case isNewline(runes):
tokens[count], n, err = newNewlineToken(runes)
case isSep(runes):
tokens[count], n, err = newSepToken(runes)
case isOp(runes):
tokens[count], n, err = newOpToken(runes)
default:
tokens[count], n, err = newLitToken(runes)
}
if err != nil {
return nil, err
}
count++
runes = runes[n:]
}
return tokens[:count], nil
}
func countTokens(runes []rune) int {
count, n := 0, 0
var err error
for len(runes) > 0 {
switch {
case isWhitespace(runes[0]):
_, n, err = newWSToken(runes)
case isComma(runes[0]):
_, n = newCommaToken(), 1
case isComment(runes):
_, n, err = newCommentToken(runes)
case isNewline(runes):
_, n, err = newNewlineToken(runes)
case isSep(runes):
_, n, err = newSepToken(runes)
case isOp(runes):
_, n, err = newOpToken(runes)
default:
_, n, err = newLitToken(runes)
}
if err != nil {
return 0
}
count++
runes = runes[n:]
}
return count + 1
}
// Token indicates a metadata about a given value.
type Token struct {
t TokenType
ValueType ValueType
base int
raw []rune
}
var emptyValue = Value{}
func newToken(t TokenType, raw []rune, v ValueType) Token {
return Token{
t: t,
raw: raw,
ValueType: v,
}
}
// Raw return the raw runes that were consumed
func (tok Token) Raw() []rune {
return tok.raw
}
// Type returns the token type
func (tok Token) Type() TokenType {
return tok.t
}

View File

@ -0,0 +1,356 @@
package ini
import (
"fmt"
"io"
)
// State enums for the parse table
const (
InvalidState = iota
// stmt -> value stmt'
StatementState
// stmt' -> MarkComplete | op stmt
StatementPrimeState
// value -> number | string | boolean | quoted_string
ValueState
// section -> [ section'
OpenScopeState
// section' -> value section_close
SectionState
// section_close -> ]
CloseScopeState
// SkipState will skip (NL WS)+
SkipState
// SkipTokenState will skip any token and push the previous
// state onto the stack.
SkipTokenState
// comment -> # comment' | ; comment'
// comment' -> MarkComplete | value
CommentState
// MarkComplete state will complete statements and move that
// to the completed AST list
MarkCompleteState
// TerminalState signifies that the tokens have been fully parsed
TerminalState
)
// parseTable is a state machine to dictate the grammar above.
var parseTable = map[ASTKind]map[TokenType]int{
ASTKindStart: map[TokenType]int{
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: TerminalState,
},
ASTKindCommentStatement: map[TokenType]int{
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindExpr: map[TokenType]int{
TokenOp: StatementPrimeState,
TokenLit: ValueState,
TokenSep: OpenScopeState,
TokenWS: ValueState,
TokenNL: SkipState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindEqualExpr: map[TokenType]int{
TokenLit: ValueState,
TokenWS: SkipTokenState,
TokenNL: SkipState,
},
ASTKindStatement: map[TokenType]int{
TokenLit: SectionState,
TokenSep: CloseScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindExprStatement: map[TokenType]int{
TokenLit: ValueState,
TokenSep: OpenScopeState,
TokenOp: ValueState,
TokenWS: ValueState,
TokenNL: MarkCompleteState,
TokenComment: CommentState,
TokenNone: TerminalState,
TokenComma: SkipState,
},
ASTKindSectionStatement: map[TokenType]int{
TokenLit: SectionState,
TokenOp: SectionState,
TokenSep: CloseScopeState,
TokenWS: SectionState,
TokenNL: SkipTokenState,
},
ASTKindCompletedSectionStatement: map[TokenType]int{
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenComment: CommentState,
TokenNone: MarkCompleteState,
},
ASTKindSkipStatement: map[TokenType]int{
TokenLit: StatementState,
TokenSep: OpenScopeState,
TokenWS: SkipTokenState,
TokenNL: SkipTokenState,
TokenComment: CommentState,
TokenNone: TerminalState,
},
}
// ParseAST will parse input from an io.Reader using
// an LL(1) parser.
func ParseAST(r io.Reader) ([]AST, error) {
lexer := iniLexer{}
tokens, err := lexer.Tokenize(r)
if err != nil {
return []AST{}, err
}
return parse(tokens)
}
// ParseASTBytes will parse input from a byte slice using
// an LL(1) parser.
func ParseASTBytes(b []byte) ([]AST, error) {
lexer := iniLexer{}
tokens, err := lexer.tokenize(b)
if err != nil {
return []AST{}, err
}
return parse(tokens)
}
func parse(tokens []Token) ([]AST, error) {
start := Start
stack := newParseStack(3, len(tokens))
stack.Push(start)
s := newSkipper()
loop:
for stack.Len() > 0 {
k := stack.Pop()
var tok Token
if len(tokens) == 0 {
// this occurs when all the tokens have been processed
// but reduction of what's left on the stack needs to
// occur.
tok = emptyToken
} else {
tok = tokens[0]
}
step := parseTable[k.Kind][tok.Type()]
if s.ShouldSkip(tok) {
// being in a skip state with no tokens will break out of
// the parse loop since there is nothing left to process.
if len(tokens) == 0 {
break loop
}
// if should skip is true, we skip the tokens until should skip is set to false.
step = SkipTokenState
}
switch step {
case TerminalState:
// Finished parsing. Push what should be the last
// statement to the stack. If there is anything left
// on the stack, an error in parsing has occurred.
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
break loop
case SkipTokenState:
// When skipping a token, the previous state was popped off the stack.
// To maintain the correct state, the previous state will be pushed
// onto the stack.
stack.Push(k)
case StatementState:
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
expr := newExpression(tok)
stack.Push(expr)
case StatementPrimeState:
if tok.Type() != TokenOp {
stack.MarkComplete(k)
continue
}
if k.Kind != ASTKindExpr {
return nil, NewParseError(
fmt.Sprintf("invalid expression: expected Expr type, but found %T type", k),
)
}
k = trimSpaces(k)
expr := newEqualExpr(k, tok)
stack.Push(expr)
case ValueState:
// ValueState requires the previous state to either be an equal expression
// or an expression statement.
//
// This grammar occurs when the RHS is a number, word, or quoted string.
// equal_expr -> lit op equal_expr'
// equal_expr' -> number | string | quoted_string
// quoted_string -> " quoted_string'
// quoted_string' -> string quoted_string_end
// quoted_string_end -> "
//
// otherwise
// expr_stmt -> equal_expr (expr_stmt')*
// expr_stmt' -> ws S | op S | MarkComplete
// S -> equal_expr' expr_stmt'
switch k.Kind {
case ASTKindEqualExpr:
// assigning a value to some key
k.AppendChild(newExpression(tok))
stack.Push(newExprStatement(k))
case ASTKindExpr:
k.Root.raw = append(k.Root.raw, tok.Raw()...)
stack.Push(k)
case ASTKindExprStatement:
root := k.GetRoot()
children := root.GetChildren()
if len(children) == 0 {
return nil, NewParseError(
fmt.Sprintf("invalid expression: AST contains no children %s", k.Kind),
)
}
rhs := children[len(children)-1]
if rhs.Root.ValueType != QuotedStringType {
rhs.Root.ValueType = StringType
rhs.Root.raw = append(rhs.Root.raw, tok.Raw()...)
}
children[len(children)-1] = rhs
k.SetChildren(children)
stack.Push(k)
}
case OpenScopeState:
if !runeCompare(tok.Raw(), openBrace) {
return nil, NewParseError("expected '['")
}
// If OpenScopeState is not at the start, we must mark the previous ast as complete
//
// for example: if previous ast was a skip statement;
// we should mark it as complete before we create a new statement
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
stmt := newStatement()
stack.Push(stmt)
case CloseScopeState:
if !runeCompare(tok.Raw(), closeBrace) {
return nil, NewParseError("expected ']'")
}
k = trimSpaces(k)
stack.Push(newCompletedSectionStatement(k))
case SectionState:
var stmt AST
switch k.Kind {
case ASTKindStatement:
// If there are multiple literals inside of a scope declaration,
// then the current token's raw value will be appended to the Name.
//
// This handles cases like [ profile default ]
//
// k will represent a SectionStatement with the children representing
// the label of the section
stmt = newSectionStatement(tok)
case ASTKindSectionStatement:
k.Root.raw = append(k.Root.raw, tok.Raw()...)
stmt = k
default:
return nil, NewParseError(
fmt.Sprintf("invalid statement: expected statement: %v", k.Kind),
)
}
stack.Push(stmt)
case MarkCompleteState:
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
if stack.Len() == 0 {
stack.Push(start)
}
case SkipState:
stack.Push(newSkipStatement(k))
s.Skip()
case CommentState:
if k.Kind == ASTKindStart {
stack.Push(k)
} else {
stack.MarkComplete(k)
}
stmt := newCommentStatement(tok)
stack.Push(stmt)
default:
return nil, NewParseError(
fmt.Sprintf("invalid state with ASTKind %v and TokenType %v",
k, tok.Type()))
}
if len(tokens) > 0 {
tokens = tokens[1:]
}
}
// this occurs when a statement has not been completed
if stack.top > 1 {
return nil, NewParseError(fmt.Sprintf("incomplete ini expression"))
}
// returns a sublist which excludes the start symbol
return stack.List(), nil
}
// trimSpaces will trim spaces on the left and right hand side of
// the literal.
func trimSpaces(k AST) AST {
// trim left hand side of spaces
for i := 0; i < len(k.Root.raw); i++ {
if !isWhitespace(k.Root.raw[i]) {
break
}
k.Root.raw = k.Root.raw[1:]
i--
}
// trim right hand side of spaces
for i := len(k.Root.raw) - 1; i >= 0; i-- {
if !isWhitespace(k.Root.raw[i]) {
break
}
k.Root.raw = k.Root.raw[:len(k.Root.raw)-1]
}
return k
}

View File

@ -0,0 +1,324 @@
package ini
import (
"fmt"
"strconv"
"strings"
)
var (
runesTrue = []rune("true")
runesFalse = []rune("false")
)
var literalValues = [][]rune{
runesTrue,
runesFalse,
}
func isBoolValue(b []rune) bool {
for _, lv := range literalValues {
if isLitValue(lv, b) {
return true
}
}
return false
}
func isLitValue(want, have []rune) bool {
if len(have) < len(want) {
return false
}
for i := 0; i < len(want); i++ {
if want[i] != have[i] {
return false
}
}
return true
}
// isNumberValue will return whether not the leading characters in
// a byte slice is a number. A number is delimited by whitespace or
// the newline token.
//
// A number is defined to be in a binary, octal, decimal (int | float), hex format,
// or in scientific notation.
func isNumberValue(b []rune) bool {
negativeIndex := 0
helper := numberHelper{}
needDigit := false
for i := 0; i < len(b); i++ {
negativeIndex++
switch b[i] {
case '-':
if helper.IsNegative() || negativeIndex != 1 {
return false
}
helper.Determine(b[i])
needDigit = true
continue
case 'e', 'E':
if err := helper.Determine(b[i]); err != nil {
return false
}
negativeIndex = 0
needDigit = true
continue
case 'b':
if helper.numberFormat == hex {
break
}
fallthrough
case 'o', 'x':
needDigit = true
if i == 0 {
return false
}
fallthrough
case '.':
if err := helper.Determine(b[i]); err != nil {
return false
}
needDigit = true
continue
}
if i > 0 && (isNewline(b[i:]) || isWhitespace(b[i])) {
return !needDigit
}
if !helper.CorrectByte(b[i]) {
return false
}
needDigit = false
}
return !needDigit
}
func isValid(b []rune) (bool, int, error) {
if len(b) == 0 {
// TODO: should probably return an error
return false, 0, nil
}
return isValidRune(b[0]), 1, nil
}
func isValidRune(r rune) bool {
return r != ':' && r != '=' && r != '[' && r != ']' && r != ' ' && r != '\n'
}
// ValueType is an enum that will signify what type
// the Value is
type ValueType int
func (v ValueType) String() string {
switch v {
case NoneType:
return "NONE"
case DecimalType:
return "FLOAT"
case IntegerType:
return "INT"
case StringType:
return "STRING"
case BoolType:
return "BOOL"
}
return ""
}
// ValueType enums
const (
NoneType = ValueType(iota)
DecimalType
IntegerType
StringType
QuotedStringType
BoolType
)
// Value is a union container
type Value struct {
Type ValueType
raw []rune
integer int64
decimal float64
boolean bool
str string
}
func newValue(t ValueType, base int, raw []rune) (Value, error) {
v := Value{
Type: t,
raw: raw,
}
var err error
switch t {
case DecimalType:
v.decimal, err = strconv.ParseFloat(string(raw), 64)
case IntegerType:
if base != 10 {
raw = raw[2:]
}
v.integer, err = strconv.ParseInt(string(raw), base, 64)
case StringType:
v.str = string(raw)
case QuotedStringType:
v.str = string(raw[1 : len(raw)-1])
case BoolType:
v.boolean = runeCompare(v.raw, runesTrue)
}
// issue 2253
//
// if the value trying to be parsed is too large, then we will use
// the 'StringType' and raw value instead.
if nerr, ok := err.(*strconv.NumError); ok && nerr.Err == strconv.ErrRange {
v.Type = StringType
v.str = string(raw)
err = nil
}
return v, err
}
// Append will append values and change the type to a string
// type.
func (v *Value) Append(tok Token) {
r := tok.Raw()
if v.Type != QuotedStringType {
v.Type = StringType
r = tok.raw[1 : len(tok.raw)-1]
}
if tok.Type() != TokenLit {
v.raw = append(v.raw, tok.Raw()...)
} else {
v.raw = append(v.raw, r...)
}
}
func (v Value) String() string {
switch v.Type {
case DecimalType:
return fmt.Sprintf("decimal: %f", v.decimal)
case IntegerType:
return fmt.Sprintf("integer: %d", v.integer)
case StringType:
return fmt.Sprintf("string: %s", string(v.raw))
case QuotedStringType:
return fmt.Sprintf("quoted string: %s", string(v.raw))
case BoolType:
return fmt.Sprintf("bool: %t", v.boolean)
default:
return "union not set"
}
}
func newLitToken(b []rune) (Token, int, error) {
n := 0
var err error
token := Token{}
if b[0] == '"' {
n, err = getStringValue(b)
if err != nil {
return token, n, err
}
token = newToken(TokenLit, b[:n], QuotedStringType)
} else if isNumberValue(b) {
var base int
base, n, err = getNumericalValue(b)
if err != nil {
return token, 0, err
}
value := b[:n]
vType := IntegerType
if contains(value, '.') || hasExponent(value) {
vType = DecimalType
}
token = newToken(TokenLit, value, vType)
token.base = base
} else if isBoolValue(b) {
n, err = getBoolValue(b)
token = newToken(TokenLit, b[:n], BoolType)
} else {
n, err = getValue(b)
token = newToken(TokenLit, b[:n], StringType)
}
return token, n, err
}
// IntValue returns an integer value
func (v Value) IntValue() int64 {
return v.integer
}
// FloatValue returns a float value
func (v Value) FloatValue() float64 {
return v.decimal
}
// BoolValue returns a bool value
func (v Value) BoolValue() bool {
return v.boolean
}
func isTrimmable(r rune) bool {
switch r {
case '\n', ' ':
return true
}
return false
}
// StringValue returns the string value
func (v Value) StringValue() string {
switch v.Type {
case StringType:
return strings.TrimFunc(string(v.raw), isTrimmable)
case QuotedStringType:
// preserve all characters in the quotes
return string(removeEscapedCharacters(v.raw[1 : len(v.raw)-1]))
default:
return strings.TrimFunc(string(v.raw), isTrimmable)
}
}
func contains(runes []rune, c rune) bool {
for i := 0; i < len(runes); i++ {
if runes[i] == c {
return true
}
}
return false
}
func runeCompare(v1 []rune, v2 []rune) bool {
if len(v1) != len(v2) {
return false
}
for i := 0; i < len(v1); i++ {
if v1[i] != v2[i] {
return false
}
}
return true
}

View File

@ -0,0 +1,30 @@
package ini
func isNewline(b []rune) bool {
if len(b) == 0 {
return false
}
if b[0] == '\n' {
return true
}
if len(b) < 2 {
return false
}
return b[0] == '\r' && b[1] == '\n'
}
func newNewlineToken(b []rune) (Token, int, error) {
i := 1
if b[0] == '\r' && isNewline(b[1:]) {
i++
}
if !isNewline([]rune(b[:i])) {
return emptyToken, 0, NewParseError("invalid new line token")
}
return newToken(TokenNL, b[:i], NoneType), i, nil
}

View File

@ -0,0 +1,152 @@
package ini
import (
"bytes"
"fmt"
"strconv"
)
const (
none = numberFormat(iota)
binary
octal
decimal
hex
exponent
)
type numberFormat int
// numberHelper is used to dictate what format a number is in
// and what to do for negative values. Since -1e-4 is a valid
// number, we cannot just simply check for duplicate negatives.
type numberHelper struct {
numberFormat numberFormat
negative bool
negativeExponent bool
}
func (b numberHelper) Exists() bool {
return b.numberFormat != none
}
func (b numberHelper) IsNegative() bool {
return b.negative || b.negativeExponent
}
func (b *numberHelper) Determine(c rune) error {
if b.Exists() {
return NewParseError(fmt.Sprintf("multiple number formats: 0%v", string(c)))
}
switch c {
case 'b':
b.numberFormat = binary
case 'o':
b.numberFormat = octal
case 'x':
b.numberFormat = hex
case 'e', 'E':
b.numberFormat = exponent
case '-':
if b.numberFormat != exponent {
b.negative = true
} else {
b.negativeExponent = true
}
case '.':
b.numberFormat = decimal
default:
return NewParseError(fmt.Sprintf("invalid number character: %v", string(c)))
}
return nil
}
func (b numberHelper) CorrectByte(c rune) bool {
switch {
case b.numberFormat == binary:
if !isBinaryByte(c) {
return false
}
case b.numberFormat == octal:
if !isOctalByte(c) {
return false
}
case b.numberFormat == hex:
if !isHexByte(c) {
return false
}
case b.numberFormat == decimal:
if !isDigit(c) {
return false
}
case b.numberFormat == exponent:
if !isDigit(c) {
return false
}
case b.negativeExponent:
if !isDigit(c) {
return false
}
case b.negative:
if !isDigit(c) {
return false
}
default:
if !isDigit(c) {
return false
}
}
return true
}
func (b numberHelper) Base() int {
switch b.numberFormat {
case binary:
return 2
case octal:
return 8
case hex:
return 16
default:
return 10
}
}
func (b numberHelper) String() string {
buf := bytes.Buffer{}
i := 0
switch b.numberFormat {
case binary:
i++
buf.WriteString(strconv.Itoa(i) + ": binary format\n")
case octal:
i++
buf.WriteString(strconv.Itoa(i) + ": octal format\n")
case hex:
i++
buf.WriteString(strconv.Itoa(i) + ": hex format\n")
case exponent:
i++
buf.WriteString(strconv.Itoa(i) + ": exponent format\n")
default:
i++
buf.WriteString(strconv.Itoa(i) + ": integer format\n")
}
if b.negative {
i++
buf.WriteString(strconv.Itoa(i) + ": negative format\n")
}
if b.negativeExponent {
i++
buf.WriteString(strconv.Itoa(i) + ": negative exponent format\n")
}
return buf.String()
}

View File

@ -0,0 +1,39 @@
package ini
import (
"fmt"
)
var (
equalOp = []rune("=")
equalColonOp = []rune(":")
)
func isOp(b []rune) bool {
if len(b) == 0 {
return false
}
switch b[0] {
case '=':
return true
case ':':
return true
default:
return false
}
}
func newOpToken(b []rune) (Token, int, error) {
tok := Token{}
switch b[0] {
case '=':
tok = newToken(TokenOp, equalOp, NoneType)
case ':':
tok = newToken(TokenOp, equalColonOp, NoneType)
default:
return tok, 0, NewParseError(fmt.Sprintf("unexpected op type, %v", b[0]))
}
return tok, 1, nil
}

Some files were not shown because too many files have changed in this diff Show More