mirror of https://github.com/status-im/consul.git
connect: Add AWS PCA provider (#6795)
* Update AWS SDK to use PCA features.
* Add AWS PCA provider
* Add plumbing for config, config validation tests, add test for inheriting existing CA resources created by user
* Unparallel the tests so we don't exhaust PCA limits
* Merge updates
* More aggressive polling; rate limit pass through on sign; Timeout on Sign and CA create
* Add AWS PCA docs
* Fix Vault doc typo too
* Doc typo
* Apply suggestions from code review
  Co-Authored-By: R.B. Boyer <rb@hashicorp.com>
  Co-Authored-By: kaitlincarter-hc <43049322+kaitlincarter-hc@users.noreply.github.com>
* Doc fixes; tests for erroring if State is modified via API
* More review cleanup
* Uncomment tests!
* Minor suggested clean ups
This commit is contained in:
parent 58db6b5d22
commit cd1b613352
@@ -590,6 +590,10 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
"tls_server_name": "TLSServerName",
"tls_skip_verify": "TLSSkipVerify",

// AWS CA config
"existing_arn": "ExistingARN",
"delete_on_exit": "DeleteOnExit",

// Common CA config
"leaf_cert_ttl": "LeafCertTTL",
"csr_max_per_second": "CSRMaxPerSecond",

@@ -1095,6 +1099,7 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
"": true,
structs.ConsulCAProvider: true,
structs.VaultCAProvider: true,
structs.AWSCAProvider: true,
}
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)

@@ -1108,6 +1113,10 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
case structs.AWSCAProvider:
if _, err := ca.ParseAWSCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
}
}
@@ -2752,6 +2752,91 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
}
},
},
{
desc: "Connect AWS CA provider configuration",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{`{
"connect": {
"enabled": true,
"ca_provider": "aws-pca",
"ca_config": {
"existing_arn": "foo",
"delete_on_exit": true
}
}
}`},
hcl: []string{`
connect {
enabled = true
ca_provider = "aws-pca"
ca_config {
existing_arn = "foo"
delete_on_exit = true
}
}
`},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.ConnectEnabled = true
rt.ConnectCAProvider = "aws-pca"
rt.ConnectCAConfig = map[string]interface{}{
"ExistingARN": "foo",
"DeleteOnExit": true,
}
},
},
{
desc: "Connect AWS CA provider TTL validation",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{`{
"connect": {
"enabled": true,
"ca_provider": "aws-pca",
"ca_config": {
"leaf_cert_ttl": "1h"
}
}
}`},
hcl: []string{`
connect {
enabled = true
ca_provider = "aws-pca"
ca_config {
leaf_cert_ttl = "1h"
}
}
`},
err: "AWS PCA doesn't support certificates that are valid for less than 24 hours",
},
{
desc: "Connect AWS CA provider EC key length validation",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{`{
"connect": {
"enabled": true,
"ca_provider": "aws-pca",
"ca_config": {
"private_key_bits": 521
}
}
}`},
hcl: []string{`
connect {
enabled = true
ca_provider = "aws-pca"
ca_config {
private_key_bits = 521
}
}
`},
err: "AWS PCA only supports P256 EC curve",
},

// ------------------------------------------------------------
// ConfigEntry Handling
@@ -63,3 +63,29 @@ func validateSetIntermediate(

return nil
}

func validateSignIntermediate(csr *x509.CertificateRequest, spiffeID *connect.SpiffeIDSigning) error {
// We explicitly _don't_ require that the CSR has a valid SPIFFE signing URI
// SAN because AWS PCA doesn't let us set one :(. We need to relax it here
// otherwise it would be impossible to migrate from built-in provider to AWS
// in multiple DCs without downtime. Nothing in Connect actually checks that
// currently so this is OK for now but it's sad we have to break the SPIFFE
// spec for AWS sake. Hopefully they'll add that ability soon.
uriCount := len(csr.URIs)
if uriCount > 0 {
if uriCount != 1 {
return fmt.Errorf("incoming CSR has unexpected number of URIs: %d", uriCount)
}
certURI, err := connect.ParseCertURI(csr.URIs[0])
if err != nil {
return err
}

// Verify that the trust domain is valid.
if !spiffeID.CanSign(certURI) {
return fmt.Errorf("incoming CSR domain %q is not valid for our domain %q",
certURI.URI().String(), spiffeID.URI().String())
}
}
return nil
}
@@ -2,11 +2,19 @@ package ca

import (
"crypto/x509"
"errors"
"log"
)

//go:generate mockery -name Provider -inpkg

// ErrRateLimited is a sentinel error value Providers may return from any method
// to indicate that the operation can't complete due to a temporary rate limit.
// In the case of signing new certificates, Consul clients will respect this and
// intelligently backoff to optimize rotation rollout time while reducing load
// on servers and CA provider.
var ErrRateLimited = errors.New("operation rate limited by CA provider")

// ProviderConfig encapsulates all the data Consul passes to `Configure` on a
// new provider instance. The provider must treat this as read-only and make
// copies of any map or slice if it might modify them internally.

@@ -71,7 +79,7 @@ type Provider interface {

// GenerateRoot causes the creation of a new root certificate for this provider.
// This can also be a no-op if a root certificate already exists for the given
// config. If isRoot is false, calling this method is an error.
// config. If IsPrimary is false, calling this method is an error.
GenerateRoot() error

// ActiveRoot returns the currently active root CA for this

@@ -80,13 +88,13 @@ type Provider interface {
ActiveRoot() (string, error)

// GenerateIntermediateCSR generates a CSR for an intermediate CA
// certificate, to be signed by the root of another datacenter. If isRoot was
// certificate, to be signed by the root of another datacenter. If IsPrimary was
// set to true with Configure(), calling this is an error.
GenerateIntermediateCSR() (string, error)

// SetIntermediate sets the provider to use the given intermediate certificate
// as well as the root it was signed by. This completes the initialization for
// a provider where isRoot was set to false in Configure().
// a provider where IsPrimary was set to false in Configure().
SetIntermediate(intermediatePEM, rootPEM string) error

// ActiveIntermediate returns the current signing cert used by this provider

@@ -105,13 +113,19 @@ type Provider interface {
// Sign signs a leaf certificate used by Connect proxies from a CSR. The PEM
// returned should include only the leaf certificate as all Intermediates
// needed to validate it will be added by Consul based on the active
// intemediate and any cross-signed intermediates managed by Consul.
// intemediate and any cross-signed intermediates managed by Consul. Note that
// providers should return ErrRateLimited if they are unable to complete the
// operation due to upstream rate limiting so that clients can intelligently
// backoff.
Sign(*x509.CertificateRequest) (string, error)

// SignIntermediate will validate the CSR to ensure the trust domain in the
// URI SAN matches the local one and that basic constraints for a CA certificate
// are met. It should return a signed CA certificate with a path length constraint
// of 0 to ensure that the certificate cannot be used to generate further CA certs.
// URI SAN matches the local one and that basic constraints for a CA
// certificate are met. It should return a signed CA certificate with a path
// length constraint of 0 to ensure that the certificate cannot be used to
// generate further CA certs. Note that providers should return ErrRateLimited
// if they are unable to complete the operation due to upstream rate limiting
// so that clients can intelligently backoff.
SignIntermediate(*x509.CertificateRequest) (string, error)

// CrossSignCA must accept a CA certificate from another CA provider and cross

@@ -124,7 +138,10 @@ type Provider interface {
// returned as a PEM formatted string.
//
// If the CA provider does not support this operation, it may return an error
// provided `SupportsCrossSigning` also returns false.
// provided `SupportsCrossSigning` also returns false. Note that
// providers should return ErrRateLimited if they are unable to complete the
// operation due to upstream rate limiting so that clients can intelligently
// backoff.
CrossSignCA(*x509.Certificate) (string, error)

// SupportsCrossSigning should indicate whether the CA provider supports
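For orientation, a minimal sketch (not part of this commit) of how a provider can honour the ErrRateLimited contract described above: map whatever throttling error the upstream CA returns onto the sentinel so Consul agents back off instead of treating the failure as fatal. The exampleBackend, throttledErr and exampleProvider names are hypothetical stand-ins; only ErrRateLimited and the Provider interface come from the diff.

package ca

import (
	"crypto/x509"
	"errors"
)

// exampleBackend stands in for an upstream CA API.
type exampleBackend interface {
	SignLeaf(csr *x509.CertificateRequest) (string, error)
}

// throttledErr stands in for the upstream's rate-limit error.
var throttledErr = errors.New("upstream throttled")

type exampleProvider struct {
	backend exampleBackend
}

// Sign forwards the CSR and translates upstream throttling into the
// ErrRateLimited sentinel so clients retry with backoff.
func (p *exampleProvider) Sign(csr *x509.CertificateRequest) (string, error) {
	pem, err := p.backend.SignLeaf(csr)
	if errors.Is(err, throttledErr) {
		return "", ErrRateLimited
	}
	return pem, err
}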
@@ -0,0 +1,715 @@
package ca

import (
"bytes"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acmpca"
"github.com/mitchellh/mapstructure"

"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)

const (
// RootTemplateARN is the AWS-defined template we need to use when issuing a
// root cert.
RootTemplateARN = "arn:aws:acm-pca:::template/RootCACertificate/V1"

// IntermediateTemplateARN is the AWS-defined template we need to use when
// issuing an intermediate cert.
IntermediateTemplateARN = "arn:aws:acm-pca:::template/SubordinateCACertificate_PathLen0/V1"

// LeafTemplateARN is the AWS-defined template we need to use when issuing a
// leaf cert.
LeafTemplateARN = "arn:aws:acm-pca:::template/EndEntityCertificate/V1"

// RootTTL is the validity duration for root certs we create.
AWSRootTTL = 5 * 365 * 24 * time.Hour

// IntermediateTTL is the validity duration for the intermediate certs we
// create.
AWSIntermediateTTL = 1 * 365 * 24 * time.Hour

// SignTimeout is the maximum time we will spend waiting (polling) for a leaf
// certificate to be signed.
AWSSignTimeout = 45 * time.Second

// CreateTimeout is the maximum time we will spend waiting (polling)
// for the CA to be created.
AWSCreateTimeout = 2 * time.Minute

// AWSStateCAARNKey is the key in the provider State we store the ARN of the
// CA we created if any.
AWSStateCAARNKey = "CA_ARN"

// day is a more readable shorthand for a duration of 24 hours. Note time
// package doesn't provide time.Day due to ambiguity around DST and leap
// seconds where a day may not actually be 24 hours.
day = 24 * time.Hour
)

// AWSProvider implements Provider for AWS ACM PCA
type AWSProvider struct {
stopped uint32 // atomically accessed, at start to prevent alignment issues
stopCh chan struct{}

config *structs.AWSCAProviderConfig
session *session.Session
client *acmpca.ACMPCA
isPrimary bool
datacenter string
clusterID string
arn string
arnChecked bool
caCreated bool
rootPEM string
intermediatePEM string
logger *log.Logger
}

// SetLogger implements NeedsLogger
func (a *AWSProvider) SetLogger(l *log.Logger) {
a.logger = l
}

// Configure implements Provider
func (a *AWSProvider) Configure(cfg ProviderConfig) error {
config, err := ParseAWSCAConfig(cfg.RawConfig)
if err != nil {
return err
}

// We only support setting IAM credentials through the normal methods ENV,
// SharedCredentialsFile, IAM role. Per
// https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html#specifying-credentials
// Putting them in CA config amounts to writing them to disk config file in
// another place or sending them via API call and persisting them in state
// store in a new place on disk. One of the existing standard solutions seems
// better in all cases.
awsSession, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return err
}

a.config = config
a.session = awsSession
a.isPrimary = cfg.IsPrimary
a.clusterID = cfg.ClusterID
a.datacenter = cfg.Datacenter
a.client = acmpca.New(awsSession)
a.stopCh = make(chan struct{})

// Load the ARN from config or previous state.
if config.ExistingARN != "" {
a.arn = config.ExistingARN
} else if arn := cfg.State[AWSStateCAARNKey]; arn != "" {
a.arn = arn
// We only pass ARN through state if we created the resource. We don't
// "remember" previously existing resources the user configured.
a.caCreated = true
}

return nil
}

// State implements Provider
func (a *AWSProvider) State() (map[string]string, error) {
if a.arn == "" {
return nil, nil
}

// Preserve the CA ARN if there is one
state := make(map[string]string)
state[AWSStateCAARNKey] = a.arn
return state, nil
}

// GenerateRoot implements Provider
func (a *AWSProvider) GenerateRoot() error {
if !a.isPrimary {
return fmt.Errorf("provider is not the root certificate authority")
}

return a.ensureCA()
}

// ensureCA loads the CA resource to check it exists if configured by User or in
// state from previous run. Otherwise it creates a new CA of the correct type
// for this DC.
func (a *AWSProvider) ensureCA() error {
// If we already have an ARN, we assume the CA is created and sanity check
// it's available.
if a.arn != "" {
// Only check this once on startup not on every operation
if a.arnChecked {
return nil
}

// Load from the resource.
input := &acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(a.arn),
}
output, err := a.client.DescribeCertificateAuthority(input)
if err != nil {
return err
}
// Allow it to be active or pending a certificate (leadership might have
// changed during a secondary initialization for example).
if *output.CertificateAuthority.Status != acmpca.CertificateAuthorityStatusActive &&
*output.CertificateAuthority.Status != acmpca.CertificateAuthorityStatusPendingCertificate {
verb := "configured"
if a.caCreated {
verb = "created"
}
// Don't recreate CA that was manually disabled, force full deletion or
// manual recreation. We might later support this or an explicit
// "recreate" config option to allow rotating without a manual creation
// but this is simpler and less surprising default behavior if user
// disabled a CA due to a security concern and we just work around it.
return fmt.Errorf("the %s PCA is not active: status is %s", verb,
*output.CertificateAuthority.Status)
}

// Load the certs
if err := a.loadCACerts(); err != nil {
return err
}
a.arnChecked = true
return nil
}

// Need to create a Private CA resource.
err := a.createPCA()
if err != nil {
return err
}

// If we are in a secondary DC this is all we can do for now - the rest is
// handled by the Initialization routine of calling GenerateIntermediateCSR
// and then SetIntermediate.
if !a.isPrimary {
return nil
}

// CA is created and in PENDING_CERTIFICATE state, generate a self-signed cert
// and install it.
csrPEM, err := a.getCACSR()
if err != nil {
return err
}

// Self-sign it as a root
certPEM, err := a.signCSR(csrPEM, RootTemplateARN, AWSRootTTL)
if err != nil {
return err
}

// Submit the signed cert
input := acmpca.ImportCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
Certificate: []byte(certPEM),
}

a.logger.Printf("[DEBUG] connect.ca.aws: uploading certificate for %s", a.arn)
_, err = a.client.ImportCertificateAuthorityCertificate(&input)
if err != nil {
return err
}

a.rootPEM = certPEM
return nil
}

func keyTypeToAlgos(keyType string, keyBits int) (string, string, error) {
switch keyType {
case "rsa":
switch keyBits {
case 2048:
return acmpca.KeyAlgorithmRsa2048, acmpca.SigningAlgorithmSha256withrsa, nil
case 4096:
return acmpca.KeyAlgorithmRsa4096, acmpca.SigningAlgorithmSha256withrsa, nil
default:
return "", "", fmt.Errorf("AWS PCA only supports RSA key lengths 2048"+
" and 4096, PrivateKeyBits of %d configured", keyBits)
}
case "ec":
if keyBits != 256 {
return "", "", fmt.Errorf("AWS PCA only supports P256 EC curve, keyBits of %d configured", keyBits)
}
return acmpca.KeyAlgorithmEcPrime256v1, acmpca.SigningAlgorithmSha256withecdsa, nil
default:
return "", "", fmt.Errorf("AWS PCA only supports P256 EC curve, or RSA"+
" 2048/4096. %s, %d configured", keyType, keyBits)
}
}

func (a *AWSProvider) createPCA() error {
pcaType := "ROOT" // For some reason there is no constant for this in the SDK
if !a.isPrimary {
pcaType = acmpca.CertificateAuthorityTypeSubordinate
}

keyAlg, signAlg, err := keyTypeToAlgos(a.config.PrivateKeyType, a.config.PrivateKeyBits)
if err != nil {
return err
}

uid, err := connect.CompactUID()
if err != nil {
return err
}
commonName := connect.CACN("aws", uid, a.clusterID, a.isPrimary)

createInput := acmpca.CreateCertificateAuthorityInput{
CertificateAuthorityType: aws.String(pcaType),
CertificateAuthorityConfiguration: &acmpca.CertificateAuthorityConfiguration{
Subject: &acmpca.ASN1Subject{
CommonName: aws.String(commonName),
},
KeyAlgorithm: aws.String(keyAlg),
SigningAlgorithm: aws.String(signAlg),
},
RevocationConfiguration: &acmpca.RevocationConfiguration{
// TODO support CRL in future when we manage revocation in Connect more
// generally.
CrlConfiguration: &acmpca.CrlConfiguration{
Enabled: aws.Bool(false),
},
},
// uid is unique to each PCA we create so use it as an idempotency string. We
// don't actually retry on failure yet but might as well!
IdempotencyToken: aws.String(uid),
Tags: []*acmpca.Tag{
{Key: aws.String("consul_cluster_id"), Value: aws.String(a.clusterID)},
{Key: aws.String("consul_datacenter"), Value: aws.String(a.datacenter)},
},
}

a.logger.Printf("[DEBUG] creating new PCA %s", commonName)
createOutput, err := a.client.CreateCertificateAuthority(&createInput)
if err != nil {
return err
}

// wait for PCA to be created
newARN := *createOutput.CertificateAuthorityArn
describeInput := acmpca.DescribeCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(newARN),
}
_, err = a.pollLoop("Private CA", AWSCreateTimeout, func() (bool, string, error) {
describeOutput, err := a.client.DescribeCertificateAuthority(&describeInput)
if err != nil {
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
return true, "", fmt.Errorf("error waiting for PCA to be created: %s", err)
}
}
if *describeOutput.CertificateAuthority.Status == acmpca.CertificateAuthorityStatusPendingCertificate {
a.logger.Printf("[DEBUG] connect.ca.aws: new PCA %s is ready"+
" to accept a certificate", newARN)
a.arn = newARN
// We don't need to reload this ARN since we just created it and know what
// state it's in
a.arnChecked = true
return true, "", nil
}
// Retry
return false, "", nil
})
return err
}

func (a *AWSProvider) getCACSR() (string, error) {
input := &acmpca.GetCertificateAuthorityCsrInput{
CertificateAuthorityArn: aws.String(a.arn),
}
a.logger.Printf("[DEBUG] connect.ca.aws: retrieving CSR for %s", a.arn)
output, err := a.client.GetCertificateAuthorityCsr(input)
if err != nil {
return "", err
}

csrPEM := output.Csr
if csrPEM == nil {
// Probably shouldn't be able to happen but being defensive.
return "", fmt.Errorf("invalid response from AWS PCA: CSR is nil")
}
return *csrPEM, nil
}

func (a *AWSProvider) loadCACerts() error {
input := &acmpca.GetCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
}
output, err := a.client.GetCertificateAuthorityCertificate(input)
if err != nil {
return err
}

if a.isPrimary {
// Just use the cert as a root
a.rootPEM = *output.Certificate
} else {
a.intermediatePEM = *output.Certificate
// TODO(banks) support user-supplied CA being a Subordinate even in the
// primary DC. For now this assumes there is only one cert in the chain
if output.CertificateChain == nil {
return fmt.Errorf("Subordinate CA %s returned no chain", a.arn)
}
a.rootPEM = *output.CertificateChain
}
return nil
}

func (a *AWSProvider) signCSRRaw(csr *x509.CertificateRequest, templateARN string, ttl time.Duration) (string, error) {
// PEM encode the CSR
var pemBuf bytes.Buffer
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
return "", err
}

return a.signCSR(pemBuf.String(), templateARN, ttl)
}

// pollWait returns how long to wait for the next poll of an async operation. We
// optimize for times typically seen in the API. This is called _before_ the
// first poll so we can provide a typical delay since operations are never
// complete immediately so it's a waste to try.
func pollWait(attemptsMade int) time.Duration {
// Hard code times for now
waits := []time.Duration{
// Never seen it complete first time with a lower value
100 * time.Millisecond,
200 * time.Millisecond,
500 * time.Millisecond,
1 * time.Second,
3 * time.Second,
5 * time.Second,
}
maxIdx := len(waits) - 1
if attemptsMade > maxIdx {
attemptsMade = maxIdx
}
return waits[attemptsMade]
}

func (a *AWSProvider) pollLoop(desc string, timeout time.Duration, f func() (bool, string, error)) (string, error) {
attemptsMade := 0
start := time.Now()
wait := pollWait(attemptsMade)
for {
elapsed := time.Since(start)
if elapsed >= timeout {
return "", fmt.Errorf("timeout after %s waiting for %s", elapsed, desc)
}

a.logger.Printf("[DEBUG] connect.ca.aws: %s pending"+
", waiting %s to check readiness", desc, wait)
select {
case <-a.stopCh:
// Provider discarded
a.logger.Print("[WARN] connect.ca.aws: provider instance terminated"+
" while waiting for %s.", desc)
return "", fmt.Errorf("provider terminated")
case <-time.After(wait):
// Continue looping...
}

done, out, err := f()
if err != nil {
return "", err
}
if done {
return out, err
}

attemptsMade++
wait = pollWait(attemptsMade)
}
}

func (a *AWSProvider) signCSR(csrPEM string, templateARN string, ttl time.Duration) (string, error) {
_, signAlg, err := keyTypeToAlgos(a.config.PrivateKeyType, a.config.PrivateKeyBits)
if err != nil {
return "", err
}

issueInput := acmpca.IssueCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
Csr: []byte(csrPEM),
SigningAlgorithm: aws.String(signAlg),
TemplateArn: aws.String(templateARN),
Validity: &acmpca.Validity{
Value: aws.Int64(int64(ttl / day)),
Type: aws.String(acmpca.ValidityPeriodTypeDays),
},
}

issueOutput, err := a.client.IssueCertificate(&issueInput)
// ErrCodeLimitExceededException is used for both hard and soft limits in AWS
// SDK :(. In this specific context though (issuing a certificate) there is no
// hard limit on number of certs so a limit exceeded here is a rate limit.
if aerr, ok := err.(awserr.Error); ok && err != nil {
if aerr.Code() == acmpca.ErrCodeLimitExceededException {
return "", ErrRateLimited
}
}
if err != nil {
return "", fmt.Errorf("error issuing certificate from PCA: %s", err)
}

// wait for certificate to be created
certInput := acmpca.GetCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
CertificateArn: issueOutput.CertificateArn,
}
return a.pollLoop(fmt.Sprintf("certificate %s", *issueOutput.CertificateArn),
AWSSignTimeout,
func() (bool, string, error) {
certOutput, err := a.client.GetCertificate(&certInput)
if err != nil {
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
return true, "", fmt.Errorf("error retrieving certificate from PCA: %s", err)
}
}

if certOutput.Certificate != nil {
return true, *certOutput.Certificate, nil
}

return false, "", nil
})
}

// ActiveRoot implements Provider
func (a *AWSProvider) ActiveRoot() (string, error) {
err := a.ensureCA()
if err != nil {
return "", err
}

if a.rootPEM == "" {
return "", fmt.Errorf("Secondary AWS CA provider not fully Initialized")
}
return a.rootPEM, nil
}

// GenerateIntermediateCSR implements Provider
func (a *AWSProvider) GenerateIntermediateCSR() (string, error) {
if a.isPrimary {
return "", fmt.Errorf("provider is the root certificate authority, " +
"cannot generate an intermediate CSR")
}

err := a.ensureCA()
if err != nil {
return "", err
}

// We should have the CA created now and should be able to generate the CSR.
return a.getCACSR()
}

// SetIntermediate implements Provider
func (a *AWSProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
err := a.ensureCA()
if err != nil {
return err
}

// Install the certificate
input := acmpca.ImportCertificateAuthorityCertificateInput{
CertificateAuthorityArn: aws.String(a.arn),
Certificate: []byte(intermediatePEM),
CertificateChain: []byte(rootPEM),
}
a.logger.Printf("[DEBUG] uploading certificate for %s", a.arn)
_, err = a.client.ImportCertificateAuthorityCertificate(&input)
if err != nil {
return err
}

// We successfully initialized, keep track of the root and intermediate certs.
a.rootPEM = rootPEM
a.intermediatePEM = intermediatePEM

return nil
}

// ActiveIntermediate implements Provider
func (a *AWSProvider) ActiveIntermediate() (string, error) {
err := a.ensureCA()
if err != nil {
return "", err
}

if a.rootPEM == "" {
return "", fmt.Errorf("AWS CA provider not fully Initialized")
}

if a.isPrimary {
// In the simple case the primary DC owns a Root CA and signs with it
// directly so just return that for "intermediate" too since that is what we
// will sign leafs with.
//
// TODO(banks) support user-supplied CA being a Subordinate even in the
// primary DC. We'd have to figure that out here and return the actual
// signing cert as well as somehow populate the intermediate chain.
return a.rootPEM, nil
}

if a.intermediatePEM == "" {
return "", fmt.Errorf("secondary AWS CA provider not fully Initialized")
}

return a.intermediatePEM, nil
}

// GenerateIntermediate implements Provider
func (a *AWSProvider) GenerateIntermediate() (string, error) {
// Like the consul provider, for now the Primary DC just gets a root and no
// intermediate to sign with. So just return this. Secondaries use
// intermediates but this method is only called during primary DC (root)
// initialization in case a provider generates separate root and
// intermediates.
//
// TODO(banks) support user-supplied CA being a Subordinate even in the
// primary DC.
return a.ActiveIntermediate()
}

// Sign implements Provider
func (a *AWSProvider) Sign(csr *x509.CertificateRequest) (string, error) {
if a.rootPEM == "" {
return "", fmt.Errorf("AWS CA provider not fully Initialized")
}

a.logger.Printf("[DEBUG] connect.ca.aws: signing csr for %s",
csr.Subject.CommonName)

return a.signCSRRaw(csr, LeafTemplateARN, a.config.LeafCertTTL)
}

// SignIntermediate implements Provider
func (a *AWSProvider) SignIntermediate(csr *x509.CertificateRequest) (string, error) {
err := validateSignIntermediate(csr, &connect.SpiffeIDSigning{ClusterID: a.clusterID, Domain: "consul"})
if err != nil {
return "", err
}

// Sign it!
return a.signCSRRaw(csr, IntermediateTemplateARN, AWSIntermediateTTL)
}

// CrossSignCA implements Provider
func (a *AWSProvider) CrossSignCA(newCA *x509.Certificate) (string, error) {
return "", fmt.Errorf("not implemented in AWS PCA provider")
}

func (a *AWSProvider) disablePCA() error {
if a.arn == "" {
return nil
}
input := acmpca.UpdateCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(a.arn),
Status: aws.String(acmpca.CertificateAuthorityStatusDisabled),
}
a.logger.Printf("[INFO] connect.ca.aws: disabling PCA %s", a.arn)
_, err := a.client.UpdateCertificateAuthority(&input)
return err
}

func (a *AWSProvider) deletePCA() error {
if a.arn == "" {
return nil
}
input := acmpca.DeleteCertificateAuthorityInput{
CertificateAuthorityArn: aws.String(a.arn),
// We only ever use this to clean up after tests so delete as quickly as
// possible (7 days).
PermanentDeletionTimeInDays: aws.Int64(7),
}
a.logger.Printf("[INFO] connect.ca.aws: deleting PCA %s", a.arn)
_, err := a.client.DeleteCertificateAuthority(&input)
return err
}

// Cleanup implements Provider
func (a *AWSProvider) Cleanup() error {
old := atomic.SwapUint32(&a.stopped, 1)
if old == 0 {
close(a.stopCh)
}

if a.config.DeleteOnExit {
if err := a.disablePCA(); err != nil {
// Log the error but continue trying to delete as some errors may still
// allow that and this is best-effort delete anyway.
a.logger.Printf("[ERR] connect.ca.aws: failed to disable PCA %s: %s",
a.arn, err)
}
if err := a.deletePCA(); err != nil {
// Log the error but continue trying to delete as some errors may still
// allow that and this is best-effort delete anyway.
a.logger.Printf("[ERR] connect.ca.aws: failed to delete PCA %s: %s",
a.arn, err)
}
// Don't stall leader shutdown, none of the failures here are fatal.
return nil
}
return nil
}

// SupportsCrossSigning implements Provider
func (a *AWSProvider) SupportsCrossSigning() (bool, error) {
return false, nil
}

// ParseAWSCAConfig parses and validates AWS CA Provider configuration.
func ParseAWSCAConfig(raw map[string]interface{}) (*structs.AWSCAProviderConfig, error) {
config := structs.AWSCAProviderConfig{
CommonCAProviderConfig: defaultCommonConfig(),
}

decodeConf := &mapstructure.DecoderConfig{
DecodeHook: structs.ParseDurationFunc(),
Result: &config,
WeaklyTypedInput: true,
}

decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}

if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}

if err := config.CommonCAProviderConfig.Validate(); err != nil {
return nil, err
}

// Extra keytype validation since PCA is more limited than other providers
_, _, err = keyTypeToAlgos(config.PrivateKeyType, config.PrivateKeyBits)
if err != nil {
return nil, err
}

if config.LeafCertTTL < 24*time.Hour {
return nil, fmt.Errorf("AWS PCA doesn't support certificates that are valid"+
" for less than 24 hours, LeafTTL of %s configured", config.LeafCertTTL)
}

return &config, nil
}
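As a rough illustration (not part of the diff), this is the shape of the raw map that ParseAWSCAConfig above receives once the agent's snake_case ca_config keys have been translated by the config builder hunk earlier in this commit. The ARN value and the exampleParseAWSConfig helper are hypothetical; only ParseAWSCAConfig and the field names come from the diff.

package ca

import (
	"fmt"
	"log"
)

// exampleParseAWSConfig sketches the expected input shape, assuming it lives
// in the same ca package as ParseAWSCAConfig.
func exampleParseAWSConfig() {
	raw := map[string]interface{}{
		"ExistingARN":  "arn:aws:acm-pca:us-east-1:111111111111:certificate-authority/abcd1234", // placeholder ARN
		"DeleteOnExit": false,
		"LeafCertTTL":  "72h", // anything under 24h is rejected by the validation above
	}
	cfg, err := ParseAWSCAConfig(raw)
	if err != nil {
		log.Fatalf("invalid AWS PCA config: %s", err)
	}
	fmt.Println(cfg.ExistingARN, cfg.LeafCertTTL)
}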
@@ -0,0 +1,276 @@
package ca

import (
"log"
"os"
"strconv"
"testing"

"github.com/hashicorp/consul/agent/connect"
"github.com/stretchr/testify/require"
)

func skipIfAWSNotConfigured(t *testing.T) bool {
enabled := os.Getenv("ENABLE_AWS_PCA_TESTS")
ok, err := strconv.ParseBool(enabled)
if err != nil || !ok {
t.Skip("Skipping because AWS tests are not enabled")
return true
}
return false
}

func TestAWSBootstrapAndSignPrimary(t *testing.T) {
// Note not parallel since we could easily hit AWS limits of too many CAs if
// all of these tests run at once.
if skipIfAWSNotConfigured(t) {
return
}

for _, tc := range KeyTestCases {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
cfg := map[string]interface{}{
"PrivateKeyType": tc.KeyType,
"PrivateKeyBits": tc.KeyBits,
}
provider := testAWSProvider(t, testProviderConfigPrimary(t, cfg))
defer provider.Cleanup()

// Generate the root
require.NoError(provider.GenerateRoot())

// Fetch Active Root
rootPEM, err := provider.ActiveRoot()
require.NoError(err)

// Generate Intermediate (not actually needed for this provider for now
// but this simulates the calls in Server.initializeRoot).
interPEM, err := provider.GenerateIntermediate()
require.NoError(err)

// Should be the same for now
require.Equal(rootPEM, interPEM)

// Ensure they use the right key type
rootCert, err := connect.ParseCert(rootPEM)
require.NoError(err)

keyType, keyBits, err := connect.KeyInfoFromCert(rootCert)
require.NoError(err)
require.Equal(tc.KeyType, keyType)
require.Equal(tc.KeyBits, keyBits)

// Sign a leaf with it
testSignAndValidate(t, provider, rootPEM, nil)
})
}
}

func testSignAndValidate(t *testing.T, p Provider, rootPEM string, intermediatePEMs []string) {
csrPEM, _ := connect.TestCSR(t, connect.TestSpiffeIDService(t, "testsvc"))
csr, err := connect.ParseCSR(csrPEM)
require.NoError(t, err)

leafPEM, err := p.Sign(csr)
require.NoError(t, err)

err = connect.ValidateLeaf(rootPEM, leafPEM, intermediatePEMs)
require.NoError(t, err)
}

func TestAWSBootstrapAndSignSecondary(t *testing.T) {
// Note not parallel since we could easily hit AWS limits of too many CAs if
// all of these tests run at once.
if skipIfAWSNotConfigured(t) {
return
}

p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup()
rootPEM, err := p1.ActiveRoot()
require.NoError(t, err)

p2 := testAWSProvider(t, testProviderConfigSecondary(t, nil))
defer p2.Cleanup()

testSignIntermediateCrossDC(t, p1, p2)

// Fetch intermediate from s2 now for later comparison
intPEM, err := p2.ActiveIntermediate()
require.NoError(t, err)

// Capture the state of the providers we've setup
p1State, err := p1.State()
require.NoError(t, err)
p2State, err := p2.State()
require.NoError(t, err)

// TEST LOAD FROM PREVIOUS STATE
{
// Now create new providers from the state of the first ones simulating
// leadership change in both DCs
t.Log("Restarting Providers with State")

// Create new provider instances
cfg1 := testProviderConfigPrimary(t, nil)
cfg1.State = p1State
p1 = testAWSProvider(t, cfg1)
newRootPEM, err := p1.ActiveRoot()
require.NoError(t, err)

cfg2 := testProviderConfigPrimary(t, nil)
cfg2.State = p2State
p2 = testAWSProvider(t, cfg2)
// Need to call ActiveIntermediate like leader would to trigger loading from PCA
newIntPEM, err := p2.ActiveIntermediate()
require.NoError(t, err)

// Root cert should not have changed
require.Equal(t, rootPEM, newRootPEM)

// Secondary intermediate cert should not have changed
require.NoError(t, err)
require.Equal(t, rootPEM, newRootPEM)
require.Equal(t, intPEM, newIntPEM)

// Should both be able to sign leafs again
testSignAndValidate(t, p1, rootPEM, nil)
testSignAndValidate(t, p2, rootPEM, []string{intPEM})
}

// Since we have CAs created, test the use-case where User supplied CAs are
// used.
{
t.Log("Starting up Providers with ExistingARNs")

// Create new provider instances with config
cfg1 := testProviderConfigPrimary(t, map[string]interface{}{
"ExistingARN": p1State[AWSStateCAARNKey],
})
p1 = testAWSProvider(t, cfg1)
newRootPEM, err := p1.ActiveRoot()
require.NoError(t, err)

cfg2 := testProviderConfigPrimary(t, map[string]interface{}{
"ExistingARN": p2State[AWSStateCAARNKey],
})
cfg1.RawConfig["ExistingARN"] = p2State[AWSStateCAARNKey]
p2 = testAWSProvider(t, cfg2)
// Need to call ActiveIntermediate like leader would to trigger loading from PCA
newIntPEM, err := p2.ActiveIntermediate()
require.NoError(t, err)

// Root cert should not have changed
require.Equal(t, rootPEM, newRootPEM)

// Secondary intermediate cert should not have changed
require.NoError(t, err)
require.Equal(t, rootPEM, newRootPEM)
require.Equal(t, intPEM, newIntPEM)

// Should both be able to sign leafs again
testSignAndValidate(t, p1, rootPEM, nil)
testSignAndValidate(t, p2, rootPEM, []string{intPEM})
}
}

func TestAWSBootstrapAndSignSecondaryConsul(t *testing.T) {
// Note not parallel since we could easily hit AWS limits of too many CAs if
// all of these tests run at once.
if skipIfAWSNotConfigured(t) {
return
}

t.Run("pri=consul,sec=aws", func(t *testing.T) {
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
p1 := TestConsulProvider(t, delegate)
cfg := testProviderConfig(conf)
require.NoError(t, p1.Configure(cfg))
require.NoError(t, p1.GenerateRoot())

p2 := testAWSProvider(t, testProviderConfigSecondary(t, nil))
defer p2.Cleanup()

testSignIntermediateCrossDC(t, p1, p2)
})

t.Run("pri=aws,sec=consul", func(t *testing.T) {
p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup()
require.NoError(t, p1.GenerateRoot())

conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
p2 := TestConsulProvider(t, delegate)
cfg := testProviderConfig(conf)
cfg.IsPrimary = false
cfg.Datacenter = "dc2"
require.NoError(t, p2.Configure(cfg))

testSignIntermediateCrossDC(t, p1, p2)
})
}

func TestAWSNoCrossSigning(t *testing.T) {
if skipIfAWSNotConfigured(t) {
return
}

p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup()
// Don't bother initializing a PCA as that is slow and unnecessary for this
// test

ok, err := p1.SupportsCrossSigning()
require.NoError(t, err)
require.False(t, ok)

// Attempt to cross sign a CA should fail with sensible error
ca := connect.TestCA(t, nil)

caCert, err := connect.ParseCert(ca.RootCert)
require.NoError(t, err)
_, err = p1.CrossSignCA(caCert)
require.Error(t, err)
require.Contains(t, err.Error(), "not implemented")
}

func testAWSProvider(t *testing.T, cfg ProviderConfig) *AWSProvider {
p := &AWSProvider{}
p.SetLogger(log.New(&testLogger{t}, "", log.LstdFlags))
require.NoError(t, p.Configure(cfg))
return p
}

type testLogger struct {
t *testing.T
}

func (l *testLogger) Write(b []byte) (int, error) {
l.t.Log(string(b))
return len(b), nil
}

func testProviderConfigPrimary(t *testing.T, cfg map[string]interface{}) ProviderConfig {
rawCfg := make(map[string]interface{})
for k, v := range cfg {
rawCfg[k] = v
}
rawCfg["DeleteOnExit"] = true
return ProviderConfig{
ClusterID: connect.TestClusterID,
Datacenter: "dc1",
IsPrimary: true,
RawConfig: rawCfg,
}
}

func testProviderConfigSecondary(t *testing.T, cfg map[string]interface{}) ProviderConfig {
c := testProviderConfigPrimary(t, cfg)
c.IsPrimary = false
c.Datacenter = "dc2"
return c
}
@@ -63,7 +63,7 @@ func (c *ConsulProvider) Configure(cfg ProviderConfig) error {
c.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: c.clusterID})

// Passthrough test state for state handling tests. See testState doc.
c.parseTestState(cfg.RawConfig)
c.parseTestState(cfg.RawConfig, cfg.State)

// Exit early if the state store has an entry for this provider's config.
_, providerState, err := c.Delegate.State().CAProviderState(c.id)

@@ -426,20 +426,11 @@ func (c *ConsulProvider) SignIntermediate(csr *x509.CertificateRequest) (string,
return "", err
}

if uriCount := len(csr.URIs); uriCount != 1 {
return "", fmt.Errorf("incoming CSR has unexpected number of URIs: %d", uriCount)
}
certURI, err := connect.ParseCertURI(csr.URIs[0])
err = validateSignIntermediate(csr, c.spiffeID)
if err != nil {
return "", err
}

// Verify that the trust domain is valid.
if !c.spiffeID.CanSign(certURI) {
return "", fmt.Errorf("incoming CSR domain %q is not valid for our domain %q",
certURI.URI().String(), c.spiffeID.URI().String())
}

// Get the signing private key.
signer, err := connect.ParseSigner(providerState.PrivateKey)
if err != nil {

@@ -671,7 +662,7 @@ func (c *ConsulProvider) SetLogger(logger *log.Logger) {
c.logger = logger
}

func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}) {
func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}, state map[string]string) {
c.testState = nil
if rawTestState, ok := rawConfig["test_state"]; ok {
if ts, ok := rawTestState.(map[string]string); ok {

@@ -680,11 +671,12 @@ func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}) {
}

// Secondary's config takes a trip through the state store before Configure
// is called unlike primaries. This means we end up with map[string]string
// encoded as map[string]interface{}. Just handle that case and ignore the
// rest as this is test-only code and we'd rather not leave a way to error
// CA setup and leave cluster unavailable in prod by accidentally setting a
// bad test_state config.
// is called and RPC calls that msgpack encode also have the same effect. It
// means we end up with map[string]string encoded as map[string]interface{}.
// We just handle that case. There is no struct error handling because this
// is test-only code (undocumented config key) and we'd rather not leave a
// way to error CA setup and leave cluster unavailable in prod by
// accidentally setting a bad test_state config.
if ts, ok := rawTestState.(map[string]interface{}); ok {
c.testState = make(map[string]string)
for k, v := range ts {

@@ -694,4 +686,11 @@ func (c *ConsulProvider) parseTestState(rawConfig map[string]interface{}) {
}
}
}
// If config didn't explicitly specify test_state to return, but there is some
// actual state from a previous provider. Just use that since that is expected
// behavior that providers with state would preserve the state they are passed
// in the common case.
if len(state) > 0 && c.testState == nil {
c.testState = state
}
}
@@ -333,8 +333,13 @@ func (v *VaultProvider) Sign(csr *x509.CertificateRequest) (string, error) {
// SignIntermediate returns a signed CA certificate with a path length constraint
// of 0 to ensure that the certificate cannot be used to generate further CA certs.
func (v *VaultProvider) SignIntermediate(csr *x509.CertificateRequest) (string, error) {
err := validateSignIntermediate(csr, v.spiffeID)
if err != nil {
return "", err
}

var pemBuf bytes.Buffer
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw})
err = pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw})
if err != nil {
return "", err
}
@@ -24,7 +24,7 @@ var invalidDNSNameChars = regexp.MustCompile(`[^a-z0-9]`)
// name in that DC.
//
// Format is:
// <sanitized_service_name>.svc.<trust-domain-first-8>.consul
// <sanitized_service_name>.svc.<trust_domain_first_8>.consul
//
// service name is sanitized by removing any chars that are not legal in a DNS
// name and lower casing. It is truncated to the first X chars to keep the

@@ -42,7 +42,7 @@ func ServiceCN(serviceName, trustDomain string) string {
// more details on rationale.
//
// Format is:
// <sanitized_node_name>.agnt.<trust-domain-first-8>.consul
// <sanitized_node_name>.agnt.<trust_domain_first_8>.consul
//
// node name is sanitized by removing any chars that are not legal in a DNS
// name and lower casing. It is truncated to the first X chars to keep the

@@ -82,7 +82,7 @@ func CompactUID() (string, error) {
// details on rationale. A uniqueID is requires because some providers (e.g.
// Vault) cache by subject and so produce incorrect results - for example they
// won't cross-sign an older CA certificate with the same common name since they
// think they already have a valid cert for than CN and just return the current
// think they already have a valid cert for that CN and just return the current
// root.
//
// This can be generated by any means but will be truncated to 8 chars and

@@ -90,20 +90,20 @@ func CompactUID() (string, error) {
// specific purpose.
//
// Format is:
// {provider}-{uniqueID-8-b36}.{root|int}.ca.<trust-domain-first-8>.consul
// {provider}-{uniqueID_first8}.{pri|sec}.ca.<trust_domain_first_8>.consul
//
// trust domain is truncated to keep the whole name short
func CACN(provider, uniqueID, trustDomain string, isRoot bool) string {
func CACN(provider, uniqueID, trustDomain string, primaryDC bool) string {
providerSan := invalidDNSNameChars.ReplaceAllString(strings.ToLower(provider), "")
typ := "root"
if !isRoot {
typ = "int"
typ := "pri"
if !primaryDC {
typ = "sec"
}
// 33 = 7 bytes for ".consul", 8 bytes for trust domain, 9 bytes for
// ".root.ca.", 9 bytes for "-{uniqueID-8-b36}"
// 32 = 7 bytes for ".consul", 8 bytes for trust domain, 8 bytes for
// ".pri.ca.", 9 bytes for "-{uniqueID-8-b36}"
uidSAN := invalidDNSNameChars.ReplaceAllString(strings.ToLower(uniqueID), "")
return fmt.Sprintf("%s-%s.%s.ca.%s.consul", typ, truncateTo(uidSAN, 8),
truncateTo(providerSan, 64-33), truncateTo(trustDomain, 8))
truncateTo(providerSan, 64-32), truncateTo(trustDomain, 8))
}

func truncateTo(s string, n int) string {
@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"

"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/structs"
)

@@ -69,5 +70,11 @@ func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *ht

var reply interface{}
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
if err != nil && err.Error() == consul.ErrStateReadOnly.Error() {
return nil, BadRequestError{
Reason: "Provider State is read-only. It must be omitted" +
" or identical to the current value",
}
}
return nil, err
}
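The read-only State behaviour above is easiest to see from the client side. A rough sketch (not from the diff) using the api package's Connect().CASetConfig, which this PR extends with the State field; submitting a State that differs from the stored one should be rejected with the "Provider State is read-only" error produced by the handler above:

package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/consul/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	conf := &api.CAConfig{
		Provider: "consul",
		Config:   map[string]interface{}{"LeafCertTTL": "72h"},
		// State is provider-owned; writing a value that differs from the
		// current one is rejected by ConnectCA.ConfigurationSet.
		State: map[string]string{"foo": "bar"},
	}
	_, err = client.Connect().CASetConfig(conf, nil)
	fmt.Println(err) // expect a "Provider State is read-only" style error
}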
@@ -63,10 +63,11 @@ func TestConnectCAConfig(t *testing.T) {
t.Parallel()

tests := []struct {
name string
body string
wantErr bool
wantCfg structs.CAConfiguration
name string
initialState string
body string
wantErr bool
wantCfg structs.CAConfiguration
}{
{
name: "basic",

@@ -137,13 +138,64 @@ func TestConnectCAConfig(t *testing.T) {
ForceWithoutCrossSigning: true,
},
},
{
name: "setting state fails",
body: `
{
"Provider": "consul",
"State": {
"foo": "bar"
}
}`,
wantErr: true,
},
{
name: "updating config with same state",
initialState: `foo = "bar"`,
body: `
{
"Provider": "consul",
"config": {
"LeafCertTTL": "72h",
"RotationPeriod": "1h"
},
"State": {
"foo": "bar"
}
}`,
wantErr: false,
wantCfg: structs.CAConfiguration{
Provider: "consul",
ClusterID: connect.TestClusterID,
Config: map[string]interface{}{
"LeafCertTTL": "72h",
"RotationPeriod": "1h",
},
State: map[string]string{
"foo": "bar",
},
},
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
a := NewTestAgent(t, t.Name(), "")
hcl := ""
if tc.initialState != "" {
hcl = `
connect {
enabled = true
ca_provider = "consul"
ca_config {
test_state {
` + tc.initialState + `
}
}
}`
}
a := NewTestAgent(t, t.Name(), hcl)
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@@ -34,6 +34,7 @@ var (
ErrConnectNotEnabled = errors.New("Connect must be enabled in order to use this endpoint")
ErrRateLimited = errors.New("Rate limit reached, try again later")
ErrNotPrimaryDatacenter = errors.New("not the primary datacenter")
ErrStateReadOnly = errors.New("CA Provider State is read-only")
)

const (

@@ -163,6 +164,13 @@ func (s *ConnectCA) ConfigurationSet(
return err
}

// Don't allow state changes. Either it needs to be empty or the same to allow
// read-modify-write loops that don't touch the State field.
if len(args.Config.State) > 0 &&
!reflect.DeepEqual(args.Config.State, config.State) {
return ErrStateReadOnly
}

// Don't allow users to change the ClusterID.
args.Config.ClusterID = config.ClusterID
if args.Config.Provider == config.Provider && reflect.DeepEqual(args.Config.Config, config.Config) {

@@ -519,6 +527,9 @@ func (s *ConnectCA) Sign(

// All seems to be in order, actually sign it.
pem, err := provider.Sign(csr)
if err == ca.ErrRateLimited {
return ErrRateLimited
}
if err != nil {
return err
}
@@ -109,6 +109,8 @@ func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, e
p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}}
case structs.VaultCAProvider:
p = &ca.VaultProvider{}
case structs.AWSCAProvider:
p = &ca.AWSProvider{}
default:
return nil, fmt.Errorf("unknown CA provider %q", conf.Provider)
}
@@ -216,6 +216,7 @@ func (q *CARequest) RequestDatacenter() string {
const (
	ConsulCAProvider = "consul"
	VaultCAProvider  = "vault"
	AWSCAProvider    = "aws-pca"
)

// CAConfiguration is the configuration for the current CA plugin.

@@ -460,6 +461,13 @@ type VaultCAProviderConfig struct {
	TLSSkipVerify bool
}

type AWSCAProviderConfig struct {
	CommonCAProviderConfig `mapstructure:",squash"`

	ExistingARN  string
	DeleteOnExit bool
}

// CALeafOp is the operation for a request related to leaf certificates.
type CALeafOp string
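
For illustration, a minimal sketch of how a CA config map with the translated field names could decode into the new struct via mapstructure, relying on the ",squash" tag above; the import paths and literal values are assumptions for the example, and the agent's own parsing presumably layers validation and defaults on top.

package main

import (
	"fmt"

	"github.com/hashicorp/consul/agent/structs"
	"github.com/mitchellh/mapstructure"
)

func main() {
	// Keys are the struct-field names; the ARN value is made up.
	raw := map[string]interface{}{
		"ExistingARN":  "arn:aws:acm-pca:us-east-1:111111111111:certificate-authority/example",
		"DeleteOnExit": true,
	}

	var conf structs.AWSCAProviderConfig
	if err := mapstructure.Decode(raw, &conf); err != nil {
		panic(err)
	}
	fmt.Println(conf.ExistingARN, conf.DeleteOnExit)
}
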
@@ -17,6 +17,12 @@ type CAConfig struct {
	// and maps).
	Config map[string]interface{}

	// State is read-only data that the provider might have persisted for use
	// after restart or leadership transition. For example this might include
	// UUIDs of resources it has created. Setting this when writing a
	// configuration is an error.
	State map[string]string

	CreateIndex uint64
	ModifyIndex uint64
}
@@ -83,6 +83,9 @@ func TestAPI_ConnectCAConfig_get_set(t *testing.T) {
		// Change a config value and update
		conf.Config["PrivateKey"] = ""
		conf.Config["RotationPeriod"] = 120 * 24 * time.Hour
		// Pass through some state as if the provider stored it so we can make sure
		// we can read it again.
		conf.Config["test_state"] = map[string]string{"foo": "bar"}
		_, err = connect.CASetConfig(conf, nil)
		r.Check(err)

@@ -92,5 +95,6 @@ func TestAPI_ConnectCAConfig_get_set(t *testing.T) {
		parsed, err = ParseConsulCAConfig(updated.Config)
		r.Check(err)
		require.Equal(r, expected, parsed)
		require.Equal(r, "bar", updated.State["foo"])
	})
}
go.mod (1 change)
@@ -14,6 +14,7 @@ require (
	github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e
	github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878
	github.com/armon/go-radix v1.0.0
	github.com/aws/aws-sdk-go v1.25.32
	github.com/coredns/coredns v1.1.2
	github.com/digitalocean/godo v1.10.0 // indirect
	github.com/docker/go-connections v0.3.0

go.sum (4 changes)
@@ -28,6 +28,8 @@ github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/aws/aws-sdk-go v1.15.24 h1:xLAdTA/ore6xdPAljzZRed7IGqQgC+nY+ERS5vaj4Ro=
github.com/aws/aws-sdk-go v1.15.24/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.25.32 h1:GhqlDvuPXnlW46VoKvfLZkJj5IA6jGLO+/TUPCJSYOY=
github.com/aws/aws-sdk-go v1.25.32/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=

@@ -207,6 +209,8 @@ github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da h1:FjHUJJ7oBW4G/9
github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da/go.mod h1:ks+b9deReOc7jgqp+e7LuFiCBH6Rm5hL32cLcEAArb4=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE=
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 h1:JHCT6xuyPUrbbgAPE/3dqlvUKzRHMNuTBKKUb6OeR/k=
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62/go.mod h1:U+RSyWxWd04xTqnuOQxnai7XGS2PrPY2cfGoDKtMHjA=
github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE=
@@ -1,3 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.
@@ -138,8 +138,27 @@ type RequestFailure interface {
	RequestID() string
}

// NewRequestFailure returns a new request error wrapper for the given Error
// provided.
// NewRequestFailure returns a wrapped error with additional information for
// request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
	return newRequestError(err, statusCode, reqID)
}

// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
	awsError
	Bytes() []byte
}

// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
	return &unmarshalError{
		awsError: New("UnmarshalError", msg, err),
		bytes:    bytes,
	}
}
@@ -1,6 +1,9 @@
package awserr

import "fmt"
import (
	"encoding/hex"
	"fmt"
)

// SprintError returns a string of the formatted error code.
//

@@ -119,6 +122,7 @@ type requestError struct {
	awsError
	statusCode int
	requestID  string
	bytes      []byte
}

// newRequestError returns a wrapped error with additional information for

@@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
	return []error{r.OrigErr()}
}

type unmarshalError struct {
	awsError
	bytes []byte
}

// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
	extra := hex.Dump(e.bytes)
	return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}

// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
	return e.Error()
}

// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
	return e.bytes
}

// An error list that satisfies the golang interface
type errorList []error

@@ -181,7 +208,7 @@ func (e errorList) Error() string {
	// How do we want to handle the array size being zero
	if size := len(e); size > 0 {
		for i := 0; i < size; i++ {
			msg += fmt.Sprintf("%s", e[i].Error())
			msg += e[i].Error()
			// We check the next index to see if it is within the slice.
			// If it is, then we append a newline. We do this, because unit tests
			// could be broken with the additional '\n'
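
For illustration, a minimal sketch of the new helper in use, wrapping a failed JSON decode so the offending bytes are hex-dumped into the error text; the payload and message strings are made up for the example.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/aws/aws-sdk-go/aws/awserr"
)

func main() {
	payload := []byte(`{"Credentials": <not json>}`)

	var out map[string]interface{}
	if err := json.Unmarshal(payload, &out); err != nil {
		// NewUnmarshalError keeps the raw bytes; Error() appends a hex dump of them.
		werr := awserr.NewUnmarshalError(err, "failed to decode response body", payload)
		fmt.Println(werr.Error())
	}
}
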
@@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
	rb := reflect.Indirect(reflect.ValueOf(b))

	if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
		// If the elements are both nil, and of the same type the are equal
		// If the elements are both nil, and of the same type they are equal
		// If they are of different types they are not equal
		return reflect.TypeOf(a) == reflect.TypeOf(b)
	} else if raValid != rbValid {
@@ -185,13 +185,12 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
// SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) {
	if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil {
		for _, rval := range rvals {
			if rval.Kind() == reflect.Ptr && rval.IsNil() {
				continue
			}
			setValue(rval, v)
	rvals := rValuesAtPath(i, path, true, false, v == nil)
	for _, rval := range rvals {
		if rval.Kind() == reflect.Ptr && rval.IsNil() {
			continue
		}
		setValue(rval, v)
	}
}
@@ -23,28 +23,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
	case reflect.Struct:
		buf.WriteString("{\n")

		names := []string{}
		for i := 0; i < v.Type().NumField(); i++ {
			name := v.Type().Field(i).Name
			f := v.Field(i)
			if name[0:1] == strings.ToLower(name[0:1]) {
			ft := v.Type().Field(i)
			fv := v.Field(i)

			if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
				continue // ignore unexported fields
			}
			if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
			if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
				continue // ignore unset fields
			}
			names = append(names, name)
		}

		for i, n := range names {
			val := v.FieldByName(n)
			buf.WriteString(strings.Repeat(" ", indent+2))
			buf.WriteString(n + ": ")
			stringValue(val, indent+2, buf)
			buf.WriteString(ft.Name + ": ")

			if i < len(names)-1 {
				buf.WriteString(",\n")
			if tag := ft.Tag.Get("sensitive"); tag == "true" {
				buf.WriteString("<sensitive>")
			} else {
				stringValue(fv, indent+2, buf)
			}

			buf.WriteString(",\n")
		}

		buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
@@ -12,13 +12,14 @@ import (
type Config struct {
	Config        *aws.Config
	Handlers      request.Handlers
	PartitionID   string
	Endpoint      string
	SigningRegion string
	SigningName   string

	// States that the signing name did not come from a modeled source but
	// was derived based on other data. Used by service client constructors
	// to determine if the signin name can be overriden based on metadata the
	// to determine if the signin name can be overridden based on metadata the
	// service has.
	SigningNameDerived bool
}

@@ -64,7 +65,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
	default:
		maxRetries := aws.IntValue(cfg.MaxRetries)
		if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
			maxRetries = 3
			maxRetries = DefaultRetryerMaxNumRetries
		}
		svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
	}
@ -1,6 +1,7 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -9,82 +10,142 @@ import (
|
|||
)
|
||||
|
||||
// DefaultRetryer implements basic retry logic using exponential backoff for
|
||||
// most services. If you want to implement custom retry logic, implement the
|
||||
// request.Retryer interface or create a structure type that composes this
|
||||
// struct and override the specific methods. For example, to override only
|
||||
// the MaxRetries method:
|
||||
// most services. If you want to implement custom retry logic, you can implement the
|
||||
// request.Retryer interface.
|
||||
//
|
||||
// type retryer struct {
|
||||
// client.DefaultRetryer
|
||||
// }
|
||||
//
|
||||
// // This implementation always has 100 max retries
|
||||
// func (d retryer) MaxRetries() int { return 100 }
|
||||
type DefaultRetryer struct {
|
||||
// Num max Retries is the number of max retries that will be performed.
|
||||
// By default, this is zero.
|
||||
NumMaxRetries int
|
||||
|
||||
// MinRetryDelay is the minimum retry delay after which retry will be performed.
|
||||
// If not set, the value is 0ns.
|
||||
MinRetryDelay time.Duration
|
||||
|
||||
// MinThrottleRetryDelay is the minimum retry delay when throttled.
|
||||
// If not set, the value is 0ns.
|
||||
MinThrottleDelay time.Duration
|
||||
|
||||
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
|
||||
// If not set, the value is 0ns.
|
||||
MaxRetryDelay time.Duration
|
||||
|
||||
// MaxThrottleDelay is the maximum retry delay when throttled.
|
||||
// If not set, the value is 0ns.
|
||||
MaxThrottleDelay time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultRetryerMaxNumRetries sets maximum number of retries
|
||||
DefaultRetryerMaxNumRetries = 3
|
||||
|
||||
// DefaultRetryerMinRetryDelay sets minimum retry delay
|
||||
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
|
||||
|
||||
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
|
||||
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultRetryerMaxRetryDelay sets maximum retry delay
|
||||
DefaultRetryerMaxRetryDelay = 300 * time.Second
|
||||
|
||||
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
|
||||
DefaultRetryerMaxThrottleDelay = 300 * time.Second
|
||||
)
|
||||
|
||||
// MaxRetries returns the number of maximum returns the service will use to make
|
||||
// an individual API request.
|
||||
func (d DefaultRetryer) MaxRetries() int {
|
||||
return d.NumMaxRetries
|
||||
}
|
||||
|
||||
// setRetryerDefaults sets the default values of the retryer if not set
|
||||
func (d *DefaultRetryer) setRetryerDefaults() {
|
||||
if d.MinRetryDelay == 0 {
|
||||
d.MinRetryDelay = DefaultRetryerMinRetryDelay
|
||||
}
|
||||
if d.MaxRetryDelay == 0 {
|
||||
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
|
||||
}
|
||||
if d.MinThrottleDelay == 0 {
|
||||
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
|
||||
}
|
||||
if d.MaxThrottleDelay == 0 {
|
||||
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
|
||||
}
|
||||
}
|
||||
|
||||
// RetryRules returns the delay duration before retrying this request again
|
||||
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
|
||||
// Set the upper limit of delay in retrying at ~five minutes
|
||||
minTime := 30
|
||||
throttle := d.shouldThrottle(r)
|
||||
if throttle {
|
||||
if delay, ok := getRetryDelay(r); ok {
|
||||
return delay
|
||||
}
|
||||
|
||||
minTime = 500
|
||||
// if number of max retries is zero, no retries will be performed.
|
||||
if d.NumMaxRetries == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sets default value for retryer members
|
||||
d.setRetryerDefaults()
|
||||
|
||||
// minDelay is the minimum retryer delay
|
||||
minDelay := d.MinRetryDelay
|
||||
|
||||
var initialDelay time.Duration
|
||||
|
||||
isThrottle := r.IsErrorThrottle()
|
||||
if isThrottle {
|
||||
if delay, ok := getRetryAfterDelay(r); ok {
|
||||
initialDelay = delay
|
||||
}
|
||||
minDelay = d.MinThrottleDelay
|
||||
}
|
||||
|
||||
retryCount := r.RetryCount
|
||||
if throttle && retryCount > 8 {
|
||||
retryCount = 8
|
||||
} else if retryCount > 13 {
|
||||
retryCount = 13
|
||||
|
||||
// maxDelay the maximum retryer delay
|
||||
maxDelay := d.MaxRetryDelay
|
||||
|
||||
if isThrottle {
|
||||
maxDelay = d.MaxThrottleDelay
|
||||
}
|
||||
|
||||
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
|
||||
return time.Duration(delay) * time.Millisecond
|
||||
var delay time.Duration
|
||||
|
||||
// Logic to cap the retry count based on the minDelay provided
|
||||
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
|
||||
if actualRetryCount < 63-retryCount {
|
||||
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
|
||||
if delay > maxDelay {
|
||||
delay = getJitterDelay(maxDelay / 2)
|
||||
}
|
||||
} else {
|
||||
delay = getJitterDelay(maxDelay / 2)
|
||||
}
|
||||
return delay + initialDelay
|
||||
}
|
||||
|
||||
// getJitterDelay returns a jittered delay for retry
|
||||
func getJitterDelay(duration time.Duration) time.Duration {
|
||||
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
|
||||
}
|
||||
|
||||
// ShouldRetry returns true if the request should be retried.
|
||||
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
|
||||
|
||||
// ShouldRetry returns false if number of max retries is 0.
|
||||
if d.NumMaxRetries == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// If one of the other handlers already set the retry state
|
||||
// we don't want to override it based on the service's state
|
||||
if r.Retryable != nil {
|
||||
return *r.Retryable
|
||||
}
|
||||
|
||||
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
|
||||
return true
|
||||
}
|
||||
return r.IsErrorRetryable() || d.shouldThrottle(r)
|
||||
}
|
||||
|
||||
// ShouldThrottle returns true if the request should be throttled.
|
||||
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
|
||||
switch r.HTTPResponse.StatusCode {
|
||||
case 429:
|
||||
case 502:
|
||||
case 503:
|
||||
case 504:
|
||||
default:
|
||||
return r.IsErrorThrottle()
|
||||
}
|
||||
|
||||
return true
|
||||
return r.IsErrorRetryable() || r.IsErrorThrottle()
|
||||
}
|
||||
|
||||
// This will look in the Retry-After header, RFC 7231, for how long
|
||||
// it will wait before attempting another request
|
||||
func getRetryDelay(r *request.Request) (time.Duration, bool) {
|
||||
func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
|
||||
if !canUseRetryAfterHeader(r) {
|
||||
return 0, false
|
||||
}
|
||||
|
|
|
@ -67,10 +67,14 @@ func logRequest(r *request.Request) {
|
|||
if !bodySeekable {
|
||||
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
|
||||
}
|
||||
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
|
||||
// Body as a NoOpCloser and will not be reset after read by the HTTP
|
||||
// client reader.
|
||||
r.ResetBody()
|
||||
// Reset the request body because dumpRequest will re-wrap the
|
||||
// r.HTTPRequest's Body as a NoOpCloser and will not be reset after
|
||||
// read by the HTTP client reader.
|
||||
if err := r.Error; err != nil {
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
|
||||
r.ClientInfo.ServiceName, r.Operation.Name, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
|
||||
|
@ -118,6 +122,12 @@ var LogHTTPResponseHandler = request.NamedHandler{
|
|||
func logResponse(r *request.Request) {
|
||||
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
|
||||
|
||||
if r.HTTPResponse == nil {
|
||||
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
|
||||
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
|
||||
return
|
||||
}
|
||||
|
||||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
||||
if logBody {
|
||||
r.HTTPResponse.Body = &teeReaderCloser{
|
||||
|
|
|
@ -5,6 +5,7 @@ type ClientInfo struct {
|
|||
ServiceName string
|
||||
ServiceID string
|
||||
APIVersion string
|
||||
PartitionID string
|
||||
Endpoint string
|
||||
SigningName string
|
||||
SigningRegion string
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
// NoOpRetryer provides a retryer that performs no retries.
|
||||
// It should be used when we do not want retries to be performed.
|
||||
type NoOpRetryer struct{}
|
||||
|
||||
// MaxRetries returns the number of maximum returns the service will use to make
|
||||
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
|
||||
func (d NoOpRetryer) MaxRetries() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
|
||||
func (d NoOpRetryer) ShouldRetry(_ *request.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// RetryRules returns the delay duration before retrying this request again;
|
||||
// since NoOpRetryer does not retry, RetryRules always returns 0.
|
||||
func (d NoOpRetryer) RetryRules(_ *request.Request) time.Duration {
|
||||
return 0
|
||||
}
|
|
@ -18,9 +18,9 @@ const UseServiceDefaultRetries = -1
|
|||
type RequestRetryer interface{}
|
||||
|
||||
// A Config provides service configuration for service clients. By default,
|
||||
// all clients will use the defaults.DefaultConfig tructure.
|
||||
// all clients will use the defaults.DefaultConfig structure.
|
||||
//
|
||||
// // Create Session with MaxRetry configuration to be shared by multiple
|
||||
// // Create Session with MaxRetries configuration to be shared by multiple
|
||||
// // service clients.
|
||||
// sess := session.Must(session.NewSession(&aws.Config{
|
||||
// MaxRetries: aws.Int(3),
|
||||
|
@ -45,8 +45,8 @@ type Config struct {
|
|||
// that overrides the default generated endpoint for a client. Set this
|
||||
// to `""` to use the default generated endpoint.
|
||||
//
|
||||
// @note You must still provide a `Region` value when specifying an
|
||||
// endpoint for a client.
|
||||
// Note: You must still provide a `Region` value when specifying an
|
||||
// endpoint for a client.
|
||||
Endpoint *string
|
||||
|
||||
// The resolver to use for looking up endpoints for AWS service clients
|
||||
|
@ -65,8 +65,8 @@ type Config struct {
|
|||
// noted. A full list of regions is found in the "Regions and Endpoints"
|
||||
// document.
|
||||
//
|
||||
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
|
||||
// AWS Regions and Endpoints
|
||||
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
|
||||
// Regions and Endpoints.
|
||||
Region *string
|
||||
|
||||
// Set this to `true` to disable SSL when sending requests. Defaults
|
||||
|
@ -120,9 +120,10 @@ type Config struct {
|
|||
// will use virtual hosted bucket addressing when possible
|
||||
// (`http://BUCKET.s3.amazonaws.com/KEY`).
|
||||
//
|
||||
// @note This configuration option is specific to the Amazon S3 service.
|
||||
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
|
||||
// Amazon S3: Virtual Hosting of Buckets
|
||||
// Note: This configuration option is specific to the Amazon S3 service.
|
||||
//
|
||||
// See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
|
||||
// for Amazon S3: Virtual Hosting of Buckets
|
||||
S3ForcePathStyle *bool
|
||||
|
||||
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
|
||||
|
@ -223,12 +224,37 @@ type Config struct {
|
|||
// Key: aws.String("//foo//bar//moo"),
|
||||
// })
|
||||
DisableRestProtocolURICleaning *bool
|
||||
|
||||
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
|
||||
// have the definition in its model. By default, endpoint discovery is off.
|
||||
//
|
||||
// Example:
|
||||
// sess := session.Must(session.NewSession(&aws.Config{
|
||||
// EnableEndpointDiscovery: aws.Bool(true),
|
||||
// }))
|
||||
//
|
||||
// svc := s3.New(sess)
|
||||
// out, err := svc.GetObject(&s3.GetObjectInput {
|
||||
// Bucket: aws.String("bucketname"),
|
||||
// Key: aws.String("/foo/bar/moo"),
|
||||
// })
|
||||
EnableEndpointDiscovery *bool
|
||||
|
||||
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
|
||||
// request endpoint hosts with modeled information.
|
||||
//
|
||||
// Disabling this feature is useful when you want to use local endpoints
|
||||
// for testing that do not support the modeled host prefix pattern.
|
||||
DisableEndpointHostPrefix *bool
|
||||
|
||||
// STSRegionalEndpoint will enable regional or legacy endpoint resolving
|
||||
STSRegionalEndpoint endpoints.STSRegionalEndpoint
|
||||
}
|
||||
|
||||
// NewConfig returns a new Config pointer that can be chained with builder
|
||||
// methods to set multiple configuration values inline without using pointers.
|
||||
//
|
||||
// // Create Session with MaxRetry configuration to be shared by multiple
|
||||
// // Create Session with MaxRetries configuration to be shared by multiple
|
||||
// // service clients.
|
||||
// sess := session.Must(session.NewSession(aws.NewConfig().
|
||||
// WithMaxRetries(3),
|
||||
|
@ -377,6 +403,19 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
|
|||
return c
|
||||
}
|
||||
|
||||
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
|
||||
func (c *Config) WithEndpointDiscovery(t bool) *Config {
|
||||
c.EnableEndpointDiscovery = &t
|
||||
return c
|
||||
}
|
||||
|
||||
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
|
||||
// when making requests.
|
||||
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
|
||||
c.DisableEndpointHostPrefix = &t
|
||||
return c
|
||||
}
|
||||
|
||||
// MergeIn merges the passed in configs into the existing config object.
|
||||
func (c *Config) MergeIn(cfgs ...*Config) {
|
||||
for _, other := range cfgs {
|
||||
|
@ -384,6 +423,13 @@ func (c *Config) MergeIn(cfgs ...*Config) {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag
|
||||
// when resolving the endpoint for a service
|
||||
func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config {
|
||||
c.STSRegionalEndpoint = sre
|
||||
return c
|
||||
}
|
||||
|
||||
func mergeInConfig(dst *Config, other *Config) {
|
||||
if other == nil {
|
||||
return
|
||||
|
@ -476,6 +522,18 @@ func mergeInConfig(dst *Config, other *Config) {
|
|||
if other.EnforceShouldRetryCheck != nil {
|
||||
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
|
||||
}
|
||||
|
||||
if other.EnableEndpointDiscovery != nil {
|
||||
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
|
||||
}
|
||||
|
||||
if other.DisableEndpointHostPrefix != nil {
|
||||
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
|
||||
}
|
||||
|
||||
if other.STSRegionalEndpoint != endpoints.UnsetSTSEndpoint {
|
||||
dst.STSRegionalEndpoint = other.STSRegionalEndpoint
|
||||
}
|
||||
}
|
||||
|
||||
// Copy will return a shallow copy of the Config object. If any additional
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
// +build !go1.9
|
||||
|
||||
package aws
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
import "time"
|
||||
|
||||
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
|
||||
// It is represented as a SDK interface to enable you to use the "WithContext"
|
||||
|
@ -35,37 +35,3 @@ type Context interface {
|
|||
// functions.
|
||||
Value(key interface{}) interface{}
|
||||
}
|
||||
|
||||
// BackgroundContext returns a context that will never be canceled, has no
|
||||
// values, and no deadline. This context is used by the SDK to provide
|
||||
// backwards compatibility with non-context API operations and functionality.
|
||||
//
|
||||
// Go 1.6 and before:
|
||||
// This context function is equivalent to context.Background in the Go stdlib.
|
||||
//
|
||||
// Go 1.7 and later:
|
||||
// The context returned will be the value returned by context.Background()
|
||||
//
|
||||
// See https://golang.org/pkg/context for more information on Contexts.
|
||||
func BackgroundContext() Context {
|
||||
return backgroundCtx
|
||||
}
|
||||
|
||||
// SleepWithContext will wait for the timer duration to expire, or the context
|
||||
// is canceled. Which ever happens first. If the context is canceled the Context's
|
||||
// error will be returned.
|
||||
//
|
||||
// Expects Context to always return a non-nil error if the Done channel is closed.
|
||||
func SleepWithContext(ctx Context, dur time.Duration) error {
|
||||
t := time.NewTimer(dur)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-t.C:
|
||||
break
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
// +build go1.7
|
||||
|
||||
package aws
|
||||
|
||||
import "context"
|
||||
|
||||
var (
|
||||
backgroundCtx = context.Background()
|
||||
)
|
|
@ -0,0 +1,11 @@
|
|||
// +build go1.9
|
||||
|
||||
package aws
|
||||
|
||||
import "context"
|
||||
|
||||
// Context is an alias of the Go stdlib's context.Context interface.
|
||||
// It can be used within the SDK's API operation "WithContext" methods.
|
||||
//
|
||||
// See https://golang.org/pkg/context on how to use contexts.
|
||||
type Context = context.Context
|
|
@ -39,3 +39,18 @@ func (e *emptyCtx) String() string {
|
|||
var (
|
||||
backgroundCtx = new(emptyCtx)
|
||||
)
|
||||
|
||||
// BackgroundContext returns a context that will never be canceled, has no
|
||||
// values, and no deadline. This context is used by the SDK to provide
|
||||
// backwards compatibility with non-context API operations and functionality.
|
||||
//
|
||||
// Go 1.6 and before:
|
||||
// This context function is equivalent to context.Background in the Go stdlib.
|
||||
//
|
||||
// Go 1.7 and later:
|
||||
// The context returned will be the value returned by context.Background()
|
||||
//
|
||||
// See https://golang.org/pkg/context for more information on Contexts.
|
||||
func BackgroundContext() Context {
|
||||
return backgroundCtx
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
// +build go1.7
|
||||
|
||||
package aws
|
||||
|
||||
import "context"
|
||||
|
||||
// BackgroundContext returns a context that will never be canceled, has no
|
||||
// values, and no deadline. This context is used by the SDK to provide
|
||||
// backwards compatibility with non-context API operations and functionality.
|
||||
//
|
||||
// Go 1.6 and before:
|
||||
// This context function is equivalent to context.Background in the Go stdlib.
|
||||
//
|
||||
// Go 1.7 and later:
|
||||
// The context returned will be the value returned by context.Background()
|
||||
//
|
||||
// See https://golang.org/pkg/context for more information on Contexts.
|
||||
func BackgroundContext() Context {
|
||||
return context.Background()
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package aws
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SleepWithContext will wait for the timer duration to expire, or the context
|
||||
// is canceled. Which ever happens first. If the context is canceled the Context's
|
||||
// error will be returned.
|
||||
//
|
||||
// Expects Context to always return a non-nil error if the Done channel is closed.
|
||||
func SleepWithContext(ctx Context, dur time.Duration) error {
|
||||
t := time.NewTimer(dur)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-t.C:
|
||||
break
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -179,6 +179,242 @@ func IntValueMap(src map[string]*int) map[string]int {
|
|||
return dst
|
||||
}
|
||||
|
||||
// Uint returns a pointer to the uint value passed in.
|
||||
func Uint(v uint) *uint {
|
||||
return &v
|
||||
}
|
||||
|
||||
// UintValue returns the value of the uint pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func UintValue(v *uint) uint {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// UintSlice converts a slice of uint values uinto a slice of
|
||||
// uint pointers
|
||||
func UintSlice(src []uint) []*uint {
|
||||
dst := make([]*uint, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// UintValueSlice converts a slice of uint pointers uinto a slice of
|
||||
// uint values
|
||||
func UintValueSlice(src []*uint) []uint {
|
||||
dst := make([]uint, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// UintMap converts a string map of uint values uinto a string
|
||||
// map of uint pointers
|
||||
func UintMap(src map[string]uint) map[string]*uint {
|
||||
dst := make(map[string]*uint)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// UintValueMap converts a string map of uint pointers uinto a string
|
||||
// map of uint values
|
||||
func UintValueMap(src map[string]*uint) map[string]uint {
|
||||
dst := make(map[string]uint)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int8 returns a pointer to the int8 value passed in.
|
||||
func Int8(v int8) *int8 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int8Value returns the value of the int8 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Int8Value(v *int8) int8 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Int8Slice converts a slice of int8 values into a slice of
|
||||
// int8 pointers
|
||||
func Int8Slice(src []int8) []*int8 {
|
||||
dst := make([]*int8, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int8ValueSlice converts a slice of int8 pointers into a slice of
|
||||
// int8 values
|
||||
func Int8ValueSlice(src []*int8) []int8 {
|
||||
dst := make([]int8, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int8Map converts a string map of int8 values into a string
|
||||
// map of int8 pointers
|
||||
func Int8Map(src map[string]int8) map[string]*int8 {
|
||||
dst := make(map[string]*int8)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int8ValueMap converts a string map of int8 pointers into a string
|
||||
// map of int8 values
|
||||
func Int8ValueMap(src map[string]*int8) map[string]int8 {
|
||||
dst := make(map[string]int8)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int16 returns a pointer to the int16 value passed in.
|
||||
func Int16(v int16) *int16 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int16Value returns the value of the int16 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Int16Value(v *int16) int16 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Int16Slice converts a slice of int16 values into a slice of
|
||||
// int16 pointers
|
||||
func Int16Slice(src []int16) []*int16 {
|
||||
dst := make([]*int16, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int16ValueSlice converts a slice of int16 pointers into a slice of
|
||||
// int16 values
|
||||
func Int16ValueSlice(src []*int16) []int16 {
|
||||
dst := make([]int16, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int16Map converts a string map of int16 values into a string
|
||||
// map of int16 pointers
|
||||
func Int16Map(src map[string]int16) map[string]*int16 {
|
||||
dst := make(map[string]*int16)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int16ValueMap converts a string map of int16 pointers into a string
|
||||
// map of int16 values
|
||||
func Int16ValueMap(src map[string]*int16) map[string]int16 {
|
||||
dst := make(map[string]int16)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int32 returns a pointer to the int32 value passed in.
|
||||
func Int32(v int32) *int32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int32Value returns the value of the int32 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Int32Value(v *int32) int32 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Int32Slice converts a slice of int32 values into a slice of
|
||||
// int32 pointers
|
||||
func Int32Slice(src []int32) []*int32 {
|
||||
dst := make([]*int32, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int32ValueSlice converts a slice of int32 pointers into a slice of
|
||||
// int32 values
|
||||
func Int32ValueSlice(src []*int32) []int32 {
|
||||
dst := make([]int32, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int32Map converts a string map of int32 values into a string
|
||||
// map of int32 pointers
|
||||
func Int32Map(src map[string]int32) map[string]*int32 {
|
||||
dst := make(map[string]*int32)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int32ValueMap converts a string map of int32 pointers into a string
|
||||
// map of int32 values
|
||||
func Int32ValueMap(src map[string]*int32) map[string]int32 {
|
||||
dst := make(map[string]int32)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Int64 returns a pointer to the int64 value passed in.
|
||||
func Int64(v int64) *int64 {
|
||||
return &v
|
||||
|
@ -238,6 +474,301 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
|
|||
return dst
|
||||
}
|
||||
|
||||
// Uint8 returns a pointer to the uint8 value passed in.
|
||||
func Uint8(v uint8) *uint8 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint8Value returns the value of the uint8 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Uint8Value(v *uint8) uint8 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Uint8Slice converts a slice of uint8 values into a slice of
|
||||
// uint8 pointers
|
||||
func Uint8Slice(src []uint8) []*uint8 {
|
||||
dst := make([]*uint8, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint8ValueSlice converts a slice of uint8 pointers into a slice of
|
||||
// uint8 values
|
||||
func Uint8ValueSlice(src []*uint8) []uint8 {
|
||||
dst := make([]uint8, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint8Map converts a string map of uint8 values into a string
|
||||
// map of uint8 pointers
|
||||
func Uint8Map(src map[string]uint8) map[string]*uint8 {
|
||||
dst := make(map[string]*uint8)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint8ValueMap converts a string map of uint8 pointers into a string
|
||||
// map of uint8 values
|
||||
func Uint8ValueMap(src map[string]*uint8) map[string]uint8 {
|
||||
dst := make(map[string]uint8)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint16 returns a pointer to the uint16 value passed in.
|
||||
func Uint16(v uint16) *uint16 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint16Value returns the value of the uint16 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Uint16Value(v *uint16) uint16 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Uint16Slice converts a slice of uint16 values into a slice of
|
||||
// uint16 pointers
|
||||
func Uint16Slice(src []uint16) []*uint16 {
|
||||
dst := make([]*uint16, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint16ValueSlice converts a slice of uint16 pointers into a slice of
|
||||
// uint16 values
|
||||
func Uint16ValueSlice(src []*uint16) []uint16 {
|
||||
dst := make([]uint16, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint16Map converts a string map of uint16 values into a string
|
||||
// map of uint16 pointers
|
||||
func Uint16Map(src map[string]uint16) map[string]*uint16 {
|
||||
dst := make(map[string]*uint16)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint16ValueMap converts a string map of uint16 pointers into a string
|
||||
// map of uint16 values
|
||||
func Uint16ValueMap(src map[string]*uint16) map[string]uint16 {
|
||||
dst := make(map[string]uint16)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint32 returns a pointer to the uint32 value passed in.
|
||||
func Uint32(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint32Value returns the value of the uint32 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Uint32Value(v *uint32) uint32 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Uint32Slice converts a slice of uint32 values into a slice of
|
||||
// uint32 pointers
|
||||
func Uint32Slice(src []uint32) []*uint32 {
|
||||
dst := make([]*uint32, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint32ValueSlice converts a slice of uint32 pointers into a slice of
|
||||
// uint32 values
|
||||
func Uint32ValueSlice(src []*uint32) []uint32 {
|
||||
dst := make([]uint32, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint32Map converts a string map of uint32 values into a string
|
||||
// map of uint32 pointers
|
||||
func Uint32Map(src map[string]uint32) map[string]*uint32 {
|
||||
dst := make(map[string]*uint32)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint32ValueMap converts a string map of uint32 pointers into a string
|
||||
// map of uint32 values
|
||||
func Uint32ValueMap(src map[string]*uint32) map[string]uint32 {
|
||||
dst := make(map[string]uint32)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint64 returns a pointer to the uint64 value passed in.
|
||||
func Uint64(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint64Value returns the value of the uint64 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Uint64Value(v *uint64) uint64 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Uint64Slice converts a slice of uint64 values into a slice of
|
||||
// uint64 pointers
|
||||
func Uint64Slice(src []uint64) []*uint64 {
|
||||
dst := make([]*uint64, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint64ValueSlice converts a slice of uint64 pointers into a slice of
|
||||
// uint64 values
|
||||
func Uint64ValueSlice(src []*uint64) []uint64 {
|
||||
dst := make([]uint64, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint64Map converts a string map of uint64 values into a string
|
||||
// map of uint64 pointers
|
||||
func Uint64Map(src map[string]uint64) map[string]*uint64 {
|
||||
dst := make(map[string]*uint64)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Uint64ValueMap converts a string map of uint64 pointers into a string
|
||||
// map of uint64 values
|
||||
func Uint64ValueMap(src map[string]*uint64) map[string]uint64 {
|
||||
dst := make(map[string]uint64)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Float32 returns a pointer to the float32 value passed in.
|
||||
func Float32(v float32) *float32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Float32Value returns the value of the float32 pointer passed in or
|
||||
// 0 if the pointer is nil.
|
||||
func Float32Value(v *float32) float32 {
|
||||
if v != nil {
|
||||
return *v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Float32Slice converts a slice of float32 values into a slice of
|
||||
// float32 pointers
|
||||
func Float32Slice(src []float32) []*float32 {
|
||||
dst := make([]*float32, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
dst[i] = &(src[i])
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Float32ValueSlice converts a slice of float32 pointers into a slice of
|
||||
// float32 values
|
||||
func Float32ValueSlice(src []*float32) []float32 {
|
||||
dst := make([]float32, len(src))
|
||||
for i := 0; i < len(src); i++ {
|
||||
if src[i] != nil {
|
||||
dst[i] = *(src[i])
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Float32Map converts a string map of float32 values into a string
|
||||
// map of float32 pointers
|
||||
func Float32Map(src map[string]float32) map[string]*float32 {
|
||||
dst := make(map[string]*float32)
|
||||
for k, val := range src {
|
||||
v := val
|
||||
dst[k] = &v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Float32ValueMap converts a string map of float32 pointers into a string
|
||||
// map of float32 values
|
||||
func Float32ValueMap(src map[string]*float32) map[string]float32 {
|
||||
dst := make(map[string]float32)
|
||||
for k, val := range src {
|
||||
if val != nil {
|
||||
dst[k] = *val
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Float64 returns a pointer to the float64 value passed in.
|
||||
func Float64(v float64) *float64 {
|
||||
return &v
|
||||
|
|
|
@ -72,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
|
|||
signedTime = r.LastSignedAt
|
||||
}
|
||||
|
||||
// 10 minutes to allow for some clock skew/delays in transmission.
|
||||
// 5 minutes to allow for some clock skew/delays in transmission.
|
||||
// Would be improved with aws/aws-sdk-go#423
|
||||
if signedTime.Add(10 * time.Minute).After(time.Now()) {
|
||||
if signedTime.Add(5 * time.Minute).After(time.Now()) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -159,9 +159,9 @@ func handleSendError(r *request.Request, err error) {
|
|||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
|
||||
}
|
||||
}
|
||||
// Catch all other request errors.
|
||||
// Catch all request errors, and let the default retrier determine
|
||||
// if the error is retryable.
|
||||
r.Error = awserr.New("RequestError", "send request failed", err)
|
||||
r.Retryable = aws.Bool(true) // network errors are retryable
|
||||
|
||||
// Override the error with a context canceled error, if that was canceled.
|
||||
ctx := r.Context()
|
||||
|
@ -184,37 +184,39 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
|
|||
|
||||
// AfterRetryHandler performs final checks to determine if the request should
|
||||
// be retried and how long to delay.
|
||||
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
|
||||
// If one of the other handlers already set the retry state
|
||||
// we don't want to override it based on the service's state
|
||||
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
|
||||
r.Retryable = aws.Bool(r.ShouldRetry(r))
|
||||
}
|
||||
|
||||
if r.WillRetry() {
|
||||
r.RetryDelay = r.RetryRules(r)
|
||||
|
||||
if sleepFn := r.Config.SleepDelay; sleepFn != nil {
|
||||
// Support SleepDelay for backwards compatibility and testing
|
||||
sleepFn(r.RetryDelay)
|
||||
} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
|
||||
r.Error = awserr.New(request.CanceledErrorCode,
|
||||
"request context canceled", err)
|
||||
r.Retryable = aws.Bool(false)
|
||||
return
|
||||
var AfterRetryHandler = request.NamedHandler{
|
||||
Name: "core.AfterRetryHandler",
|
||||
Fn: func(r *request.Request) {
|
||||
// If one of the other handlers already set the retry state
|
||||
// we don't want to override it based on the service's state
|
||||
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
|
||||
r.Retryable = aws.Bool(r.ShouldRetry(r))
|
||||
}
|
||||
|
||||
// when the expired token exception occurs the credentials
|
||||
// need to be expired locally so that the next request to
|
||||
// get credentials will trigger a credentials refresh.
|
||||
if r.IsErrorExpired() {
|
||||
r.Config.Credentials.Expire()
|
||||
}
|
||||
if r.WillRetry() {
|
||||
r.RetryDelay = r.RetryRules(r)
|
||||
|
||||
r.RetryCount++
|
||||
r.Error = nil
|
||||
}
|
||||
}}
|
||||
if sleepFn := r.Config.SleepDelay; sleepFn != nil {
|
||||
// Support SleepDelay for backwards compatibility and testing
|
||||
sleepFn(r.RetryDelay)
|
||||
} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
|
||||
r.Error = awserr.New(request.CanceledErrorCode,
|
||||
"request context canceled", err)
|
||||
r.Retryable = aws.Bool(false)
|
||||
return
|
||||
}
|
||||
|
||||
// when the expired token exception occurs the credentials
|
||||
// need to be expired locally so that the next request to
|
||||
// get credentials will trigger a credentials refresh.
|
||||
if r.IsErrorExpired() {
|
||||
r.Config.Credentials.Expire()
|
||||
}
|
||||
|
||||
r.RetryCount++
|
||||
r.Error = nil
|
||||
}
|
||||
}}
|
||||
|
||||
// ValidateEndpointHandler is a request handler to validate a request had the
|
||||
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
|
||||
|
|
|
@ -17,7 +17,7 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
|
|||
}
|
||||
|
||||
const execEnvVar = `AWS_EXECUTION_ENV`
|
||||
const execEnvUAKey = `exec_env`
|
||||
const execEnvUAKey = `exec-env`
|
||||
|
||||
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
|
||||
// execution environment to the user agent.
|
||||
|
|
|
@ -9,9 +9,7 @@ var (
|
|||
// providers in the ChainProvider.
|
||||
//
|
||||
// This has been deprecated. For verbose error messaging set
|
||||
// aws.Config.CredentialsChainVerboseErrors to true
|
||||
//
|
||||
// @readonly
|
||||
// aws.Config.CredentialsChainVerboseErrors to true.
|
||||
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
|
||||
`no valid providers in chain. Deprecated.
|
||||
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
|
||||
|
|
|
@ -49,8 +49,11 @@
|
|||
package credentials
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
// AnonymousCredentials is an empty Credential object that can be used as
|
||||
|
@ -64,8 +67,6 @@ import (
|
|||
// Credentials: credentials.AnonymousCredentials,
|
||||
// })))
|
||||
// // Access public S3 buckets.
|
||||
//
|
||||
// @readonly
|
||||
var AnonymousCredentials = NewStaticCredentials("", "", "")
|
||||
|
||||
// A Value is the AWS credentials value for individual credential fields.
|
||||
|
@ -83,6 +84,12 @@ type Value struct {
|
|||
ProviderName string
|
||||
}
|
||||
|
||||
// HasKeys returns if the credentials Value has both AccessKeyID and
|
||||
// SecretAccessKey value set.
|
||||
func (v Value) HasKeys() bool {
|
||||
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
|
||||
}
|
||||
|
||||
// A Provider is the interface for any component which will provide credentials
|
||||
// Value. A provider is required to manage its own Expired state, and what to
|
||||
// be expired means.
|
||||
|
@ -99,6 +106,14 @@ type Provider interface {
|
|||
IsExpired() bool
|
||||
}
|
||||
|
||||
// An Expirer is an interface that Providers can implement to expose the expiration
|
||||
// time, if known. If the Provider cannot accurately provide this info,
|
||||
// it should not implement this interface.
|
||||
type Expirer interface {
|
||||
// The time at which the credentials are no longer valid
|
||||
ExpiresAt() time.Time
|
||||
}
|
||||
|
||||
// An ErrorProvider is a stub credentials provider that always returns an error
|
||||
// this is used by the SDK when construction a known provider is not possible
|
||||
// due to an error.
|
||||
|
@ -165,6 +180,11 @@ func (e *Expiry) IsExpired() bool {
|
|||
return e.expiration.Before(curTime())
|
||||
}
|
||||
|
||||
// ExpiresAt returns the expiration time of the credential
|
||||
func (e *Expiry) ExpiresAt() time.Time {
|
||||
return e.expiration
|
||||
}
|
||||
|
||||
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
|
||||
// Credentials will cache the credentials value until they expire. Once the value
|
||||
// expires the next Get will attempt to retrieve valid credentials.
|
||||
|
@ -257,3 +277,23 @@ func (c *Credentials) IsExpired() bool {
|
|||
func (c *Credentials) isExpired() bool {
|
||||
return c.forceRefresh || c.provider.IsExpired()
|
||||
}
|
||||
|
||||
// ExpiresAt provides access to the functionality of the Expirer interface of
|
||||
// the underlying Provider, if it supports that interface. Otherwise, it returns
|
||||
// an error.
|
||||
func (c *Credentials) ExpiresAt() (time.Time, error) {
|
||||
c.m.RLock()
|
||||
defer c.m.RUnlock()
|
||||
|
||||
expirer, ok := c.provider.(Expirer)
|
||||
if !ok {
|
||||
return time.Time{}, awserr.New("ProviderNotExpirer",
|
||||
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
|
||||
nil)
|
||||
}
|
||||
if c.forceRefresh {
|
||||
// set expiration time to the distant past
|
||||
return time.Time{}, nil
|
||||
}
|
||||
return expirer.ExpiresAt(), nil
|
||||
}
|
||||
|
|
vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider.go (6 changes, generated, vendored)
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkuri"
|
||||
)
|
||||
|
||||
|
@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
|
|||
}
|
||||
|
||||
if err := s.Err(); err != nil {
|
||||
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err)
|
||||
return nil, awserr.New(request.ErrCodeSerialization,
|
||||
"failed to read EC2 instance role from metadata service", err)
|
||||
}
|
||||
|
||||
return credsList, nil
|
||||
|
@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
|
|||
respCreds := ec2RoleCredRespBody{}
|
||||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
|
||||
return ec2RoleCredRespBody{},
|
||||
awserr.New("SerializationError",
|
||||
awserr.New(request.ErrCodeSerialization,
|
||||
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
|
||||
err)
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
|
||||
)
|
||||
|
||||
// ProviderName is the name of the credentials provider.
|
||||
|
@ -65,6 +66,10 @@ type Provider struct {
|
|||
//
|
||||
// If ExpiryWindow is 0 or less it will be ignored.
|
||||
ExpiryWindow time.Duration
|
||||
|
||||
// Optional authorization token value if set will be used as the value of
|
||||
// the Authorization header of the endpoint credential request.
|
||||
AuthorizationToken string
|
||||
}
|
||||
|
||||
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
|
||||
|
@ -93,8 +98,8 @@ func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint strin
|
|||
return p
|
||||
}
|
||||
|
||||
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
|
||||
// from an arbitrary endpoint concurrently. The client will request the
|
||||
// NewCredentialsClient returns a pointer to a new Credentials object
|
||||
// wrapping the endpoint credentials Provider.
|
||||
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
|
||||
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
|
||||
}
|
||||
|
@ -152,6 +157,9 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
|
|||
out := &getCredentialsOutput{}
|
||||
req := p.Client.NewRequest(op, nil, out)
|
||||
req.HTTPRequest.Header.Set("Accept", "application/json")
|
||||
if authToken := p.AuthorizationToken; len(authToken) != 0 {
|
||||
req.HTTPRequest.Header.Set("Authorization", authToken)
|
||||
}
|
||||
|
||||
return out, req.Send()
|
||||
}
|
||||
|
@ -167,7 +175,7 @@ func unmarshalHandler(r *request.Request) {
|
|||
|
||||
out := r.Data.(*getCredentialsOutput)
|
||||
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
|
||||
r.Error = awserr.New("SerializationError",
|
||||
r.Error = awserr.New(request.ErrCodeSerialization,
|
||||
"failed to decode endpoint credentials",
|
||||
err,
|
||||
)
|
||||
|
@ -178,11 +186,15 @@ func unmarshalError(r *request.Request) {
|
|||
defer r.HTTPResponse.Body.Close()
|
||||
|
||||
var errOut errorOutput
|
||||
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
|
||||
r.Error = awserr.New("SerializationError",
|
||||
"failed to decode endpoint credentials",
|
||||
err,
|
||||
err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
|
||||
if err != nil {
|
||||
r.Error = awserr.NewRequestFailure(
|
||||
awserr.New(request.ErrCodeSerialization,
|
||||
"failed to decode error message", err),
|
||||
r.HTTPResponse.StatusCode,
|
||||
r.RequestID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Response body format is not consistent between metadata endpoints.
|
||||
|
|
|
@@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
 var (
 	// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
 	// found in the process's environment.
-	//
-	// @readonly
 	ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)

 	// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
 	// can't be found in the process's environment.
-	//
-	// @readonly
 	ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
 )

vendor/github.com/aws/aws-sdk-go/aws/credentials/processcreds/provider.go (generated, vendored, new file, 425 lines)
@@ -0,0 +1,425 @@
/*
Package processcreds is a credential Provider to retrieve `credential_process`
credentials.

WARNING: The following describes a method of sourcing credentials from an external
process. This can potentially be dangerous, so proceed with caution. Other
credential providers should be preferred if at all possible. If using this
option, you should make sure that the config file is as locked down as possible
using security best practices for your operating system.

You can use credentials from a `credential_process` in a variety of ways.

One way is to setup your shared config file, located in the default
location, with the `credential_process` key and the command you want to be
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.

    [default]
    credential_process = /command/to/call

Creating a new session will use the credential process to retrieve credentials.
NOTE: If there are credentials in the profile you are using, the credential
process will not be used.

    // Initialize a session to load credentials.
    sess, _ := session.NewSession(&aws.Config{
        Region: aws.String("us-east-1")},
    )

    // Create S3 service client to use the credentials.
    svc := s3.New(sess)

Another way to use the `credential_process` method is by using
`processcreds.NewCredentials()` and providing a command to be executed to
retrieve credentials:

    // Create credentials using the ProcessProvider.
    creds := processcreds.NewCredentials("/path/to/command")

    // Create service client value configured for credentials.
    svc := s3.New(sess, &aws.Config{Credentials: creds})

You can set a non-default timeout for the `credential_process` with another
constructor, `processcreds.NewCredentialsTimeout()`, providing the timeout.
For example, to set a 500 millisecond timeout:

    // Create credentials using the ProcessProvider.
    creds := processcreds.NewCredentialsTimeout(
        "/path/to/command",
        time.Duration(500) * time.Millisecond)

If you need more control, you can set any configurable options in the
credentials using one or more option functions. For example, you can set a two
minute timeout, a credential duration of 60 minutes, and a maximum stdout
buffer size of 2k.

    creds := processcreds.NewCredentials(
        "/path/to/command",
        func(opt *ProcessProvider) {
            opt.Timeout = time.Duration(2) * time.Minute
            opt.Duration = time.Duration(60) * time.Minute
            opt.MaxBufSize = 2048
        })

You can also use your own `exec.Cmd`:

    // Create an exec.Cmd
    myCommand := exec.Command("/path/to/command")

    // Create credentials using your exec.Cmd and custom timeout
    creds := processcreds.NewCredentialsCommand(
        myCommand,
        func(opt *processcreds.ProcessProvider) {
            opt.Timeout = time.Duration(1) * time.Second
        })
*/
package processcreds

import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// ProviderName is the name this credentials provider will label any
|
||||
// returned credentials Value with.
|
||||
ProviderName = `ProcessProvider`
|
||||
|
||||
// ErrCodeProcessProviderParse error parsing process output
|
||||
ErrCodeProcessProviderParse = "ProcessProviderParseError"
|
||||
|
||||
// ErrCodeProcessProviderVersion version error in output
|
||||
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
|
||||
|
||||
// ErrCodeProcessProviderRequired required attribute missing in output
|
||||
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
|
||||
|
||||
// ErrCodeProcessProviderExecution execution of command failed
|
||||
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
|
||||
|
||||
// errMsgProcessProviderTimeout process took longer than allowed
|
||||
errMsgProcessProviderTimeout = "credential process timed out"
|
||||
|
||||
// errMsgProcessProviderProcess process error
|
||||
errMsgProcessProviderProcess = "error in credential_process"
|
||||
|
||||
// errMsgProcessProviderParse problem parsing output
|
||||
errMsgProcessProviderParse = "parse failed of credential_process output"
|
||||
|
||||
// errMsgProcessProviderVersion version error in output
|
||||
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
|
||||
|
||||
// errMsgProcessProviderMissKey missing access key id in output
|
||||
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
|
||||
|
||||
// errMsgProcessProviderMissSecret missing secret access key in output
|
||||
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
|
||||
|
||||
// errMsgProcessProviderPrepareCmd prepare of command failed
|
||||
errMsgProcessProviderPrepareCmd = "failed to prepare command"
|
||||
|
||||
// errMsgProcessProviderEmptyCmd command must not be empty
|
||||
errMsgProcessProviderEmptyCmd = "command must not be empty"
|
||||
|
||||
// errMsgProcessProviderPipe failed to initialize pipe
|
||||
errMsgProcessProviderPipe = "failed to initialize pipe"
|
||||
|
||||
// DefaultDuration is the default amount of time in minutes that the
|
||||
// credentials will be valid for.
|
||||
DefaultDuration = time.Duration(15) * time.Minute
|
||||
|
||||
// DefaultBufSize limits buffer size from growing to an enormous
|
||||
// amount due to a faulty process.
|
||||
DefaultBufSize = 1024
|
||||
|
||||
// DefaultTimeout default limit on time a process can run.
|
||||
DefaultTimeout = time.Duration(1) * time.Minute
|
||||
)
|
||||
|
||||
// ProcessProvider satisfies the credentials.Provider interface, and is a
|
||||
// client to retrieve credentials from a process.
|
||||
type ProcessProvider struct {
|
||||
staticCreds bool
|
||||
credentials.Expiry
|
||||
originalCommand []string
|
||||
|
||||
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
|
||||
Duration time.Duration
|
||||
|
||||
// ExpiryWindow will allow the credentials to trigger refreshing prior to
|
||||
// the credentials actually expiring. This is beneficial so race conditions
|
||||
// with expiring credentials do not cause request to fail unexpectedly
|
||||
// due to ExpiredTokenException exceptions.
|
||||
//
|
||||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
||||
// 10 seconds before the credentials are actually expired.
|
||||
//
|
||||
// If ExpiryWindow is 0 or less it will be ignored.
|
||||
ExpiryWindow time.Duration
|
||||
|
||||
// A string representing an os command that should return a JSON with
|
||||
// credential information.
|
||||
command *exec.Cmd
|
||||
|
||||
// MaxBufSize limits memory usage from growing to an enormous
|
||||
// amount due to a faulty process.
|
||||
MaxBufSize int
|
||||
|
||||
// Timeout limits the time a process can run.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// NewCredentials returns a pointer to a new Credentials object wrapping the
|
||||
// ProcessProvider. The credentials will expire every 15 minutes by default.
|
||||
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
|
||||
p := &ProcessProvider{
|
||||
command: exec.Command(command),
|
||||
Duration: DefaultDuration,
|
||||
Timeout: DefaultTimeout,
|
||||
MaxBufSize: DefaultBufSize,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(p)
|
||||
}
|
||||
|
||||
return credentials.NewCredentials(p)
|
||||
}
|
||||
|
||||
// NewCredentialsTimeout returns a pointer to a new Credentials object with
|
||||
// the specified command and timeout, and default duration and max buffer size.
|
||||
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
|
||||
p := NewCredentials(command, func(opt *ProcessProvider) {
|
||||
opt.Timeout = timeout
|
||||
})
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// NewCredentialsCommand returns a pointer to a new Credentials object with
|
||||
// the specified command, and default timeout, duration and max buffer size.
|
||||
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
|
||||
p := &ProcessProvider{
|
||||
command: command,
|
||||
Duration: DefaultDuration,
|
||||
Timeout: DefaultTimeout,
|
||||
MaxBufSize: DefaultBufSize,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(p)
|
||||
}
|
||||
|
||||
return credentials.NewCredentials(p)
|
||||
}
|
||||
|
||||
type credentialProcessResponse struct {
|
||||
Version int
|
||||
AccessKeyID string `json:"AccessKeyId"`
|
||||
SecretAccessKey string
|
||||
SessionToken string
|
||||
Expiration *time.Time
|
||||
}
|
||||
|
||||
// Retrieve executes the 'credential_process' and returns the credentials.
|
||||
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
|
||||
out, err := p.executeCredentialProcess()
|
||||
if err != nil {
|
||||
return credentials.Value{ProviderName: ProviderName}, err
|
||||
}
|
||||
|
||||
// Serialize and validate response
|
||||
resp := &credentialProcessResponse{}
|
||||
if err = json.Unmarshal(out, resp); err != nil {
|
||||
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||
ErrCodeProcessProviderParse,
|
||||
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
|
||||
err)
|
||||
}
|
||||
|
||||
if resp.Version != 1 {
|
||||
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||
ErrCodeProcessProviderVersion,
|
||||
errMsgProcessProviderVersion,
|
||||
nil)
|
||||
}
|
||||
|
||||
if len(resp.AccessKeyID) == 0 {
|
||||
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||
ErrCodeProcessProviderRequired,
|
||||
errMsgProcessProviderMissKey,
|
||||
nil)
|
||||
}
|
||||
|
||||
if len(resp.SecretAccessKey) == 0 {
|
||||
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||
ErrCodeProcessProviderRequired,
|
||||
errMsgProcessProviderMissSecret,
|
||||
nil)
|
||||
}
|
||||
|
||||
// Handle expiration
|
||||
p.staticCreds = resp.Expiration == nil
|
||||
if resp.Expiration != nil {
|
||||
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
|
||||
}
|
||||
|
||||
return credentials.Value{
|
||||
ProviderName: ProviderName,
|
||||
AccessKeyID: resp.AccessKeyID,
|
||||
SecretAccessKey: resp.SecretAccessKey,
|
||||
SessionToken: resp.SessionToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsExpired returns true if the credentials retrieved are expired, or not yet
|
||||
// retrieved.
|
||||
func (p *ProcessProvider) IsExpired() bool {
|
||||
if p.staticCreds {
|
||||
return false
|
||||
}
|
||||
return p.Expiry.IsExpired()
|
||||
}
|
||||
|
||||
// prepareCommand prepares the command to be executed.
|
||||
func (p *ProcessProvider) prepareCommand() error {
|
||||
|
||||
var cmdArgs []string
|
||||
if runtime.GOOS == "windows" {
|
||||
cmdArgs = []string{"cmd.exe", "/C"}
|
||||
} else {
|
||||
cmdArgs = []string{"sh", "-c"}
|
||||
}
|
||||
|
||||
if len(p.originalCommand) == 0 {
|
||||
p.originalCommand = make([]string, len(p.command.Args))
|
||||
copy(p.originalCommand, p.command.Args)
|
||||
|
||||
// check for empty command because it succeeds
|
||||
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
|
||||
return awserr.New(
|
||||
ErrCodeProcessProviderExecution,
|
||||
fmt.Sprintf(
|
||||
"%s: %s",
|
||||
errMsgProcessProviderPrepareCmd,
|
||||
errMsgProcessProviderEmptyCmd),
|
||||
nil)
|
||||
}
|
||||
}
|
||||
|
||||
cmdArgs = append(cmdArgs, p.originalCommand...)
|
||||
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
|
||||
p.command.Env = os.Environ()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCredentialProcess starts the credential process on the OS and
|
||||
// returns the results or an error.
|
||||
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
|
||||
|
||||
if err := p.prepareCommand(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup the pipes
|
||||
outReadPipe, outWritePipe, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, awserr.New(
|
||||
ErrCodeProcessProviderExecution,
|
||||
errMsgProcessProviderPipe,
|
||||
err)
|
||||
}
|
||||
|
||||
p.command.Stderr = os.Stderr // display stderr on console for MFA
|
||||
p.command.Stdout = outWritePipe // get creds json on process's stdout
|
||||
p.command.Stdin = os.Stdin // enable stdin for MFA
|
||||
|
||||
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
|
||||
|
||||
stdoutCh := make(chan error, 1)
|
||||
go readInput(
|
||||
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
|
||||
output,
|
||||
stdoutCh)
|
||||
|
||||
execCh := make(chan error, 1)
|
||||
go executeCommand(*p.command, execCh)
|
||||
|
||||
finished := false
|
||||
var errors []error
|
||||
for !finished {
|
||||
select {
|
||||
case readError := <-stdoutCh:
|
||||
errors = appendError(errors, readError)
|
||||
finished = true
|
||||
case execError := <-execCh:
|
||||
err := outWritePipe.Close()
|
||||
errors = appendError(errors, err)
|
||||
errors = appendError(errors, execError)
|
||||
if errors != nil {
|
||||
return output.Bytes(), awserr.NewBatchError(
|
||||
ErrCodeProcessProviderExecution,
|
||||
errMsgProcessProviderProcess,
|
||||
errors)
|
||||
}
|
||||
case <-time.After(p.Timeout):
|
||||
finished = true
|
||||
return output.Bytes(), awserr.NewBatchError(
|
||||
ErrCodeProcessProviderExecution,
|
||||
errMsgProcessProviderTimeout,
|
||||
errors) // errors can be nil
|
||||
}
|
||||
}
|
||||
|
||||
out := output.Bytes()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// windows adds slashes to quotes
|
||||
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// appendError conveniently checks for nil before appending slice
|
||||
func appendError(errors []error, err error) []error {
|
||||
if err != nil {
|
||||
return append(errors, err)
|
||||
}
|
||||
return errors
|
||||
}
|
||||
|
||||
func executeCommand(cmd exec.Cmd, exec chan error) {
|
||||
// Start the command
|
||||
err := cmd.Start()
|
||||
if err == nil {
|
||||
err = cmd.Wait()
|
||||
}
|
||||
|
||||
exec <- err
|
||||
}
|
||||
|
||||
func readInput(r io.Reader, w io.Writer, read chan error) {
|
||||
tee := io.TeeReader(r, w)
|
||||
|
||||
_, err := ioutil.ReadAll(tee)
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
|
||||
read <- err // will only arrive here when write end of pipe is closed
|
||||
}
|
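As an editorial aside, not part of the commit: Retrieve above expects the external command to print a single JSON document with a Version of 1 and at least AccessKeyId and SecretAccessKey. A minimal, hypothetical `credential_process` command could look like the sketch below; every value is a placeholder.

// Hypothetical credential_process command; the key values are placeholders only.
package main

import (
	"encoding/json"
	"os"
	"time"
)

func main() {
	// Emit the JSON shape credentialProcessResponse unmarshals: Version,
	// AccessKeyId, SecretAccessKey, optional SessionToken and Expiration.
	_ = json.NewEncoder(os.Stdout).Encode(map[string]interface{}{
		"Version":         1,
		"AccessKeyId":     "EXAMPLE-ACCESS-KEY-ID",
		"SecretAccessKey": "EXAMPLE-SECRET",
		"SessionToken":    "EXAMPLE-TOKEN",
		"Expiration":      time.Now().Add(15 * time.Minute).Format(time.RFC3339),
	})
}

Without an Expiration field the provider marks the credentials as static, so IsExpired always reports false.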
vendor/github.com/aws/aws-sdk-go/aws/credentials/shared_credentials_provider.go (generated, vendored, 30 changes)
@@ -4,9 +4,8 @@ import (
 	"fmt"
 	"os"

-	"github.com/go-ini/ini"
-
 	"github.com/aws/aws-sdk-go/aws/awserr"
+	"github.com/aws/aws-sdk-go/internal/ini"
 	"github.com/aws/aws-sdk-go/internal/shareddefaults"
 )

@ -77,36 +76,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
|
|||
// The credentials retrieved from the profile will be returned or error. Error will be
|
||||
// returned if it fails to read from the file, or the data is invalid.
|
||||
func loadProfile(filename, profile string) (Value, error) {
|
||||
config, err := ini.Load(filename)
|
||||
config, err := ini.OpenFile(filename)
|
||||
if err != nil {
|
||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
|
||||
}
|
||||
iniProfile, err := config.GetSection(profile)
|
||||
if err != nil {
|
||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err)
|
||||
|
||||
iniProfile, ok := config.GetSection(profile)
|
||||
if !ok {
|
||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
|
||||
}
|
||||
|
||||
id, err := iniProfile.GetKey("aws_access_key_id")
|
||||
if err != nil {
|
||||
id := iniProfile.String("aws_access_key_id")
|
||||
if len(id) == 0 {
|
||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
|
||||
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
|
||||
err)
|
||||
nil)
|
||||
}
|
||||
|
||||
secret, err := iniProfile.GetKey("aws_secret_access_key")
|
||||
if err != nil {
|
||||
secret := iniProfile.String("aws_secret_access_key")
|
||||
if len(secret) == 0 {
|
||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
|
||||
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
|
||||
nil)
|
||||
}
|
||||
|
||||
// Default to empty string if not found
|
||||
token := iniProfile.Key("aws_session_token")
|
||||
token := iniProfile.String("aws_session_token")
|
||||
|
||||
return Value{
|
||||
AccessKeyID: id.String(),
|
||||
SecretAccessKey: secret.String(),
|
||||
SessionToken: token.String(),
|
||||
AccessKeyID: id,
|
||||
SecretAccessKey: secret,
|
||||
SessionToken: token,
|
||||
ProviderName: SharedCredsProviderName,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"

 var (
 	// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
-	//
-	// @readonly
 	ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
 )

vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/assume_role_provider.go (generated, vendored, 22 changes)
@@ -80,16 +80,18 @@ package stscreds
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/internal/sdkrand"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
)
|
||||
|
||||
// StdinTokenProvider will prompt on stdout and read from stdin for a string value.
|
||||
// StdinTokenProvider will prompt on stderr and read from stdin for a string value.
|
||||
// An error is returned if reading from stdin fails.
|
||||
//
|
||||
// Use this function to read MFA tokens from stdin. The function makes no attempt
|
||||
|
@ -102,7 +104,7 @@ import (
|
|||
// Will wait forever until something is provided on the stdin.
|
||||
func StdinTokenProvider() (string, error) {
|
||||
var v string
|
||||
fmt.Printf("Assume Role MFA token code: ")
|
||||
fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
|
||||
_, err := fmt.Scanln(&v)
|
||||
|
||||
return v, err
|
||||
|
@ -193,6 +195,18 @@ type AssumeRoleProvider struct {
|
|||
//
|
||||
// If ExpiryWindow is 0 or less it will be ignored.
|
||||
ExpiryWindow time.Duration
|
||||
|
||||
// MaxJitterFrac reduces the effective Duration of each credential requested
|
||||
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
|
||||
// have a value between 0 and 1. Any other value may lead to unexpected behavior.
|
||||
// With the default MaxJitterFrac value of 0, no jitter will be used.
|
||||
//
|
||||
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
|
||||
// AssumeRole call will be made with an arbitrary Duration between 27m and
|
||||
// 30m.
|
||||
//
|
||||
// MaxJitterFrac should not be negative.
|
||||
MaxJitterFrac float64
|
||||
}
|
||||
|
||||
// NewCredentials returns a pointer to a new Credentials object wrapping the
|
||||
|
@ -244,7 +258,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
|
|||
|
||||
// Retrieve generates a new set of temporary credentials using STS.
|
||||
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
||||
|
||||
// Apply defaults where parameters are not set.
|
||||
if p.RoleSessionName == "" {
|
||||
// Try to work out a role name that will hopefully end up unique.
|
||||
|
@ -254,8 +267,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
|||
// Expire as often as AWS permits.
|
||||
p.Duration = DefaultDuration
|
||||
}
|
||||
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
|
||||
input := &sts.AssumeRoleInput{
|
||||
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
|
||||
DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
|
||||
RoleArn: aws.String(p.RoleARN),
|
||||
RoleSessionName: aws.String(p.RoleSessionName),
|
||||
ExternalId: p.ExternalID,
|
||||
|
|
vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/web_identity_provider.go (generated, vendored, new file, 100 lines)
@@ -0,0 +1,100 @@
package stscreds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
"github.com/aws/aws-sdk-go/service/sts/stsiface"
|
||||
)
|
||||
|
||||
const (
|
||||
// ErrCodeWebIdentity will be used as an error code when constructing
|
||||
// a new error to be returned during session creation or retrieval.
|
||||
ErrCodeWebIdentity = "WebIdentityErr"
|
||||
|
||||
// WebIdentityProviderName is the web identity provider name
|
||||
WebIdentityProviderName = "WebIdentityCredentials"
|
||||
)
|
||||
|
||||
// now is used to return a time.Time object representing
|
||||
// the current time. This can be used to easily test and
|
||||
// compare test values.
|
||||
var now = time.Now
|
||||
|
||||
// WebIdentityRoleProvider is used to retrieve credentials using
|
||||
// an OIDC token.
|
||||
type WebIdentityRoleProvider struct {
|
||||
credentials.Expiry
|
||||
|
||||
client stsiface.STSAPI
|
||||
ExpiryWindow time.Duration
|
||||
|
||||
tokenFilePath string
|
||||
roleARN string
|
||||
roleSessionName string
|
||||
}
|
||||
|
||||
// NewWebIdentityCredentials will return a new set of credentials with a given
|
||||
// configuration, role arn, and token file path.
|
||||
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
|
||||
svc := sts.New(c)
|
||||
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
|
||||
return credentials.NewCredentials(p)
|
||||
}
|
||||
|
||||
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
|
||||
// provided stsiface.STSAPI
|
||||
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
|
||||
return &WebIdentityRoleProvider{
|
||||
client: svc,
|
||||
tokenFilePath: path,
|
||||
roleARN: roleARN,
|
||||
roleSessionName: roleSessionName,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve will attempt to assume a role from a token which is located at
|
||||
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
|
||||
// error will be returned.
|
||||
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
|
||||
b, err := ioutil.ReadFile(p.tokenFilePath)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
|
||||
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
|
||||
}
|
||||
|
||||
sessionName := p.roleSessionName
|
||||
if len(sessionName) == 0 {
|
||||
// session name is used to uniquely identify a session. This simply
|
||||
// uses unix time in nanoseconds to uniquely identify sessions.
|
||||
sessionName = strconv.FormatInt(now().UnixNano(), 10)
|
||||
}
|
||||
req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
|
||||
RoleArn: &p.roleARN,
|
||||
RoleSessionName: &sessionName,
|
||||
WebIdentityToken: aws.String(string(b)),
|
||||
})
|
||||
// InvalidIdentityToken error is a temporary error that can occur
|
||||
// when assuming a role with a JWT web identity token.
|
||||
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
|
||||
if err := req.Send(); err != nil {
|
||||
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
|
||||
}
|
||||
|
||||
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
|
||||
|
||||
value := credentials.Value{
|
||||
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
|
||||
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
|
||||
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
|
||||
ProviderName: WebIdentityProviderName,
|
||||
}
|
||||
return value, nil
|
||||
}
|
|
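As an editorial aside, not part of the commit: a minimal sketch of wiring the new web identity provider into a client. The role ARN and token file path are placeholders; an empty session name lets Retrieve generate one from the current unix time.

package main

import (
	"log"

	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	sess := session.Must(session.NewSession())

	// Placeholder role ARN and token file path.
	creds := stscreds.NewWebIdentityCredentials(sess,
		"arn:aws:iam::123456789012:role/example-role",
		"",
		"/var/run/secrets/oidc/token")

	if v, err := creds.Get(); err != nil {
		log.Fatalf("assume role with web identity failed: %v", err)
	} else {
		log.Printf("credentials from %s", v.ProviderName)
	}
	// Pass creds to service clients via &aws.Config{Credentials: creds}.
}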
@ -1,30 +1,61 @@
|
|||
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics
|
||||
// via UDP connection. Using the Start function will enable the reporting of
|
||||
// metrics on a given port. If Start is called, with different parameters, again,
|
||||
// a panic will occur.
|
||||
// Package csm provides the Client Side Monitoring (CSM) client which enables
|
||||
// sending metrics via UDP connection to the CSM agent. This package provides
|
||||
// control options, and configuration for the CSM client. The client can be
|
||||
// controlled manually, or automatically via the SDK's Session configuration.
|
||||
//
|
||||
// Pause can be called to pause any metrics publishing on a given port. Sessions
|
||||
// that have had their handlers modified via InjectHandlers may still be used.
|
||||
// However, the handlers will act as a no-op meaning no metrics will be published.
|
||||
// Enabling CSM client via SDK's Session configuration
|
||||
//
|
||||
// The CSM client can be enabled automatically via SDK's Session configuration.
|
||||
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
|
||||
// environment variable is set to a non-empty value.
|
||||
//
|
||||
// The configuration options for the CSM client via the SDK's session
|
||||
// configuration are:
|
||||
//
|
||||
// * AWS_CSM_PORT=<port number>
|
||||
// The port number the CSM agent will receive metrics on.
|
||||
//
|
||||
// * AWS_CSM_HOST=<hostname or ip>
|
||||
// The hostname, or IP address the CSM agent will receive metrics on.
|
||||
// Without port number.
|
||||
//
|
||||
// Manually enabling the CSM client
|
||||
//
|
||||
// The CSM client can be started, paused, and resumed manually. The Start
|
||||
// function will enable the CSM client to publish metrics to the CSM agent. It
|
||||
// is safe to call Start concurrently, but if Start is called additional times
|
||||
// with different ClientID or address it will panic.
|
||||
//
|
||||
// Example:
|
||||
// r, err := csm.Start("clientID", ":31000")
|
||||
// if err != nil {
|
||||
// panic(fmt.Errorf("failed starting CSM: %v", err))
|
||||
// }
|
||||
//
|
||||
// When controlling the CSM client manually, you must also inject its request
|
||||
// handlers into the SDK's Session configuration for the SDK's API clients to
|
||||
// publish metrics.
|
||||
//
|
||||
// sess, err := session.NewSession(&aws.Config{})
|
||||
// if err != nil {
|
||||
// panic(fmt.Errorf("failed loading session: %v", err))
|
||||
// }
|
||||
//
|
||||
// // Add CSM client's metric publishing request handlers to the SDK's
|
||||
// // Session Configuration.
|
||||
// r.InjectHandlers(&sess.Handlers)
|
||||
//
|
||||
// client := s3.New(sess)
|
||||
// resp, err := client.GetObject(&s3.GetObjectInput{
|
||||
// Bucket: aws.String("bucket"),
|
||||
// Key: aws.String("key"),
|
||||
// })
|
||||
// Controlling CSM client
|
||||
//
|
||||
// Once the CSM client has been enabled the Get function will return a Reporter
|
||||
// value that you can use to pause and resume the metrics published to the CSM
|
||||
// agent. If Get function is called before the reporter is enabled with the
|
||||
// Start function or via SDK's Session configuration nil will be returned.
|
||||
//
|
||||
// The Pause method can be called to stop the CSM client publishing metrics to
|
||||
// the CSM agent. The Continue method will resume metric publishing.
|
||||
//
|
||||
// // Get the CSM client Reporter.
|
||||
// r := csm.Get()
|
||||
//
|
||||
// // Will pause monitoring
|
||||
// r.Pause()
|
||||
|
@ -35,12 +66,4 @@
|
|||
//
|
||||
// // Resume monitoring
|
||||
// r.Continue()
|
||||
//
|
||||
// Start returns a Reporter that is used to enable or disable monitoring. If
|
||||
// access to the Reporter is required later, calling Get will return the Reporter
|
||||
// singleton.
|
||||
//
|
||||
// Example:
|
||||
// r := csm.Get()
|
||||
// r.Continue()
|
||||
package csm
|
||||
|
|
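As an editorial aside, not part of the commit: per the rewritten package documentation above, the environment-variable path needs no explicit csm.Start call. A minimal sketch, with illustrative host and port values:

package main

import (
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	// With AWS_CSM_PORT set to a non-empty value, the session wires the CSM
	// metric handlers into every API client built from it (per the doc above).
	os.Setenv("AWS_CSM_PORT", "31000")
	os.Setenv("AWS_CSM_HOST", "127.0.0.1")

	sess, err := session.NewSession()
	if err != nil {
		panic(fmt.Errorf("failed loading session: %v", err))
	}
	fmt.Println("CSM-enabled session ready:", sess != nil)
}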
|
@ -2,6 +2,7 @@ package csm
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -9,19 +10,40 @@ var (
|
|||
lock sync.Mutex
|
||||
)
|
||||
|
||||
// Client side metric handler names
|
||||
const (
|
||||
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
|
||||
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
|
||||
// DefaultPort is used when no port is specified.
|
||||
DefaultPort = "31000"
|
||||
|
||||
// DefaultHost is the host that will be used when none is specified.
|
||||
DefaultHost = "127.0.0.1"
|
||||
)
|
||||
|
||||
// Start will start the a long running go routine to capture
|
||||
// AddressWithDefaults returns a CSM address built from the host and port
|
||||
// values. If the host or port is not set, default values will be used
|
||||
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
|
||||
func AddressWithDefaults(host, port string) string {
|
||||
if len(host) == 0 || strings.EqualFold(host, "localhost") {
|
||||
host = DefaultHost
|
||||
}
|
||||
|
||||
if len(port) == 0 {
|
||||
port = DefaultPort
|
||||
}
|
||||
|
||||
// Only IP6 host can contain a colon
|
||||
if strings.Contains(host, ":") {
|
||||
return "[" + host + "]:" + port
|
||||
}
|
||||
|
||||
return host + ":" + port
|
||||
}
|
||||
|
||||
// Start will start a long running go routine to capture
|
||||
// client side metrics. Calling start multiple time will only
|
||||
// start the metric listener once and will panic if a different
|
||||
// client ID or port is passed in.
|
||||
//
|
||||
// Example:
|
||||
// r, err := csm.Start("clientID", "127.0.0.1:8094")
|
||||
// r, err := csm.Start("clientID", "127.0.0.1:31000")
|
||||
// if err != nil {
|
||||
// panic(fmt.Errorf("expected no error, but received %v", err))
|
||||
// }
|
||||
|
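As an editorial aside, not part of the commit: a hypothetical table-driven check of the defaulting rules AddressWithDefaults documents above (defaults filled in, localhost rewritten, IPv6 hosts bracketed); it would live alongside the vendored csm package.

package csm

import "testing"

func TestAddressWithDefaultsSketch(t *testing.T) {
	cases := []struct{ host, port, want string }{
		{"", "", "127.0.0.1:31000"},             // both values defaulted
		{"localhost", "4000", "127.0.0.1:4000"}, // localhost rewritten
		{"::1", "", "[::1]:31000"},              // IPv6 hosts get brackets
	}
	for _, c := range cases {
		if got := AddressWithDefaults(c.host, c.port); got != c.want {
			t.Errorf("AddressWithDefaults(%q, %q) = %q, want %q", c.host, c.port, got, c.want)
		}
	}
}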
|
|
@ -3,6 +3,8 @@ package csm
|
|||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
)
|
||||
|
||||
type metricTime time.Time
|
||||
|
@ -39,6 +41,12 @@ type metric struct {
|
|||
SDKException *string `json:"SdkException,omitempty"`
|
||||
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
|
||||
|
||||
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
|
||||
FinalAWSException *string `json:"FinalAwsException,omitempty"`
|
||||
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
|
||||
FinalSDKException *string `json:"FinalSdkException,omitempty"`
|
||||
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
|
||||
|
||||
DestinationIP *string `json:"DestinationIp,omitempty"`
|
||||
ConnectionReused *int `json:"ConnectionReused,omitempty"`
|
||||
|
||||
|
@ -48,4 +56,54 @@ type metric struct {
|
|||
DNSLatency *int `json:"DnsLatency,omitempty"`
|
||||
TCPLatency *int `json:"TcpLatency,omitempty"`
|
||||
SSLLatency *int `json:"SslLatency,omitempty"`
|
||||
|
||||
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
|
||||
}
|
||||
|
||||
func (m *metric) TruncateFields() {
|
||||
m.ClientID = truncateString(m.ClientID, 255)
|
||||
m.UserAgent = truncateString(m.UserAgent, 256)
|
||||
|
||||
m.AWSException = truncateString(m.AWSException, 128)
|
||||
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
|
||||
|
||||
m.SDKException = truncateString(m.SDKException, 128)
|
||||
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
|
||||
|
||||
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
|
||||
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
|
||||
|
||||
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
|
||||
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
|
||||
}
|
||||
|
||||
func truncateString(v *string, l int) *string {
|
||||
if v != nil && len(*v) > l {
|
||||
nv := (*v)[:l]
|
||||
return &nv
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (m *metric) SetException(e metricException) {
|
||||
switch te := e.(type) {
|
||||
case awsException:
|
||||
m.AWSException = aws.String(te.exception)
|
||||
m.AWSExceptionMessage = aws.String(te.message)
|
||||
case sdkException:
|
||||
m.SDKException = aws.String(te.exception)
|
||||
m.SDKExceptionMessage = aws.String(te.message)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *metric) SetFinalException(e metricException) {
|
||||
switch te := e.(type) {
|
||||
case awsException:
|
||||
m.FinalAWSException = aws.String(te.exception)
|
||||
m.FinalAWSExceptionMessage = aws.String(te.message)
|
||||
case sdkException:
|
||||
m.FinalSDKException = aws.String(te.exception)
|
||||
m.FinalSDKExceptionMessage = aws.String(te.message)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,25 +16,26 @@ var (
|
|||
|
||||
type metricChan struct {
|
||||
ch chan metric
|
||||
paused int64
|
||||
paused *int64
|
||||
}
|
||||
|
||||
func newMetricChan(size int) metricChan {
|
||||
return metricChan{
|
||||
ch: make(chan metric, size),
|
||||
ch: make(chan metric, size),
|
||||
paused: new(int64),
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *metricChan) Pause() {
|
||||
atomic.StoreInt64(&ch.paused, pausedEnum)
|
||||
atomic.StoreInt64(ch.paused, pausedEnum)
|
||||
}
|
||||
|
||||
func (ch *metricChan) Continue() {
|
||||
atomic.StoreInt64(&ch.paused, runningEnum)
|
||||
atomic.StoreInt64(ch.paused, runningEnum)
|
||||
}
|
||||
|
||||
func (ch *metricChan) IsPaused() bool {
|
||||
v := atomic.LoadInt64(&ch.paused)
|
||||
v := atomic.LoadInt64(ch.paused)
|
||||
return v == pausedEnum
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
package csm
|
||||
|
||||
type metricException interface {
|
||||
Exception() string
|
||||
Message() string
|
||||
}
|
||||
|
||||
type requestException struct {
|
||||
exception string
|
||||
message string
|
||||
}
|
||||
|
||||
func (e requestException) Exception() string {
|
||||
return e.exception
|
||||
}
|
||||
func (e requestException) Message() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
type awsException struct {
|
||||
requestException
|
||||
}
|
||||
|
||||
type sdkException struct {
|
||||
requestException
|
||||
}
|
|
@ -10,11 +10,6 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPort is used when no port is specified
|
||||
DefaultPort = "31000"
|
||||
)
|
||||
|
||||
// Reporter will gather metrics of API requests made and
|
||||
// send those metrics to the CSM endpoint.
|
||||
type Reporter struct {
|
||||
|
@ -71,7 +66,6 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
|
|||
|
||||
XAmzRequestID: aws.String(r.RequestID),
|
||||
|
||||
AttemptCount: aws.Int(r.RetryCount + 1),
|
||||
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
|
||||
AccessKey: aws.String(creds.AccessKeyID),
|
||||
}
|
||||
|
@ -82,26 +76,29 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
|
|||
|
||||
if r.Error != nil {
|
||||
if awserr, ok := r.Error.(awserr.Error); ok {
|
||||
setError(&m, awserr)
|
||||
m.SetException(getMetricException(awserr))
|
||||
}
|
||||
}
|
||||
|
||||
m.TruncateFields()
|
||||
rep.metricsCh.Push(m)
|
||||
}
|
||||
|
||||
func setError(m *metric, err awserr.Error) {
|
||||
func getMetricException(err awserr.Error) metricException {
|
||||
msg := err.Error()
|
||||
code := err.Code()
|
||||
|
||||
switch code {
|
||||
case "RequestError",
|
||||
"SerializationError",
|
||||
request.ErrCodeSerialization,
|
||||
request.CanceledErrorCode:
|
||||
m.SDKException = &code
|
||||
m.SDKExceptionMessage = &msg
|
||||
return sdkException{
|
||||
requestException{exception: code, message: msg},
|
||||
}
|
||||
default:
|
||||
m.AWSException = &code
|
||||
m.AWSExceptionMessage = &msg
|
||||
return awsException{
|
||||
requestException{exception: code, message: msg},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,16 +109,31 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
|
|||
|
||||
now := time.Now()
|
||||
m := metric{
|
||||
ClientID: aws.String(rep.clientID),
|
||||
API: aws.String(r.Operation.Name),
|
||||
Service: aws.String(r.ClientInfo.ServiceID),
|
||||
Timestamp: (*metricTime)(&now),
|
||||
Type: aws.String("ApiCall"),
|
||||
AttemptCount: aws.Int(r.RetryCount + 1),
|
||||
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
|
||||
XAmzRequestID: aws.String(r.RequestID),
|
||||
ClientID: aws.String(rep.clientID),
|
||||
API: aws.String(r.Operation.Name),
|
||||
Service: aws.String(r.ClientInfo.ServiceID),
|
||||
Timestamp: (*metricTime)(&now),
|
||||
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
|
||||
Type: aws.String("ApiCall"),
|
||||
AttemptCount: aws.Int(r.RetryCount + 1),
|
||||
Region: r.Config.Region,
|
||||
Latency: aws.Int(int(time.Since(r.Time) / time.Millisecond)),
|
||||
XAmzRequestID: aws.String(r.RequestID),
|
||||
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
|
||||
}
|
||||
|
||||
if r.HTTPResponse != nil {
|
||||
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
|
||||
}
|
||||
|
||||
if r.Error != nil {
|
||||
if awserr, ok := r.Error.(awserr.Error); ok {
|
||||
m.SetFinalException(getMetricException(awserr))
|
||||
}
|
||||
}
|
||||
|
||||
m.TruncateFields()
|
||||
|
||||
// TODO: Probably want to figure something out for logging dropped
|
||||
// metrics
|
||||
rep.metricsCh.Push(m)
|
||||
|
@ -172,8 +184,9 @@ func (rep *Reporter) start() {
|
|||
}
|
||||
}
|
||||
|
||||
// Pause will pause the metric channel preventing any new metrics from
|
||||
// being added.
|
||||
// Pause will pause the metric channel preventing any new metrics from being
|
||||
// added. It is safe to call concurrently with other calls to Pause, but if
|
||||
// called concurrently with Continue can lead to unexpected state.
|
||||
func (rep *Reporter) Pause() {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
@ -185,8 +198,9 @@ func (rep *Reporter) Pause() {
|
|||
rep.close()
|
||||
}
|
||||
|
||||
// Continue will reopen the metric channel and allow for monitoring
|
||||
// to be resumed.
|
||||
// Continue will reopen the metric channel and allow for monitoring to be
|
||||
// resumed. It is safe to call concurrently with other calls to Continue, but
|
||||
// if called concurrently with Pause can lead to unexpected state.
|
||||
func (rep *Reporter) Continue() {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
@ -201,10 +215,18 @@ func (rep *Reporter) Continue() {
|
|||
rep.metricsCh.Continue()
|
||||
}
|
||||
|
||||
// Client side metric handler names
|
||||
const (
|
||||
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
|
||||
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
|
||||
)
|
||||
|
||||
// InjectHandlers will enable client side metrics and inject the proper
|
||||
// handlers to handle how metrics are sent.
|
||||
//
|
||||
// Example:
|
||||
// InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
|
||||
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
|
||||
//
|
||||
// // Start must be called in order to inject the correct handlers
|
||||
// r, err := csm.Start("clientID", "127.0.0.1:8094")
|
||||
// if err != nil {
|
||||
|
@ -221,11 +243,22 @@ func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
|
|||
return
|
||||
}
|
||||
|
||||
apiCallHandler := request.NamedHandler{Name: APICallMetricHandlerName, Fn: rep.sendAPICallMetric}
|
||||
apiCallAttemptHandler := request.NamedHandler{Name: APICallAttemptMetricHandlerName, Fn: rep.sendAPICallAttemptMetric}
|
||||
handlers.Complete.PushFrontNamed(request.NamedHandler{
|
||||
Name: APICallMetricHandlerName,
|
||||
Fn: rep.sendAPICallMetric,
|
||||
})
|
||||
|
||||
handlers.Complete.PushFrontNamed(apiCallHandler)
|
||||
handlers.Complete.PushFrontNamed(apiCallAttemptHandler)
|
||||
|
||||
handlers.AfterRetry.PushFrontNamed(apiCallAttemptHandler)
|
||||
handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
|
||||
Name: APICallAttemptMetricHandlerName,
|
||||
Fn: rep.sendAPICallAttemptMetric,
|
||||
})
|
||||
}
|
||||
|
||||
// boolIntValue return 1 for true and 0 for false.
|
||||
func boolIntValue(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||
)
|
||||
|
||||
// A Defaults provides a collection of default values for SDK clients.
|
||||
|
@ -112,8 +113,8 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro
|
|||
}
|
||||
|
||||
const (
|
||||
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
|
||||
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
|
||||
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
|
||||
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
|
||||
)
|
||||
|
||||
// RemoteCredProvider returns a credentials provider for the default remote
|
||||
|
@ -123,8 +124,8 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
|
|||
return localHTTPCredProvider(cfg, handlers, u)
|
||||
}
|
||||
|
||||
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
|
||||
u := fmt.Sprintf("http://169.254.170.2%s", uri)
|
||||
if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
|
||||
u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
|
||||
return httpCredProvider(cfg, handlers, u)
|
||||
}
|
||||
|
||||
|
@ -187,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
|
|||
return endpointcreds.NewProviderClient(cfg, handlers, u,
|
||||
func(p *endpointcreds.Provider) {
|
||||
p.ExpiryWindow = 5 * time.Minute
|
||||
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
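As an editorial aside, not part of the commit: a sketch of exercising the authorization-token plumbing above through the container credential endpoint environment variables; the endpoint URI and token are placeholders.

package main

import (
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go/aws/defaults"
)

func main() {
	// Placeholder endpoint and token; the full-URI variant must point at a
	// loopback address, and the token is sent as the Authorization header.
	os.Setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", "http://127.0.0.1:8080/creds")
	os.Setenv("AWS_CONTAINER_AUTHORIZATION_TOKEN", "example-token")

	cfg := defaults.Config()
	handlers := defaults.Handlers()
	provider := defaults.RemoteCredProvider(*cfg, handlers)

	creds, err := provider.Retrieve()
	fmt.Println(creds.ProviderName, err)
}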
|
@ -24,8 +24,9 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
|
|||
|
||||
output := &metadataOutput{}
|
||||
req := c.NewRequest(op, nil, output)
|
||||
err := req.Send()
|
||||
|
||||
return output.Content, req.Send()
|
||||
return output.Content, err
|
||||
}
|
||||
|
||||
// GetUserData returns the userdata that was configured for the service. If
|
||||
|
@ -45,8 +46,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
|
|||
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
|
||||
}
|
||||
})
|
||||
err := req.Send()
|
||||
|
||||
return output.Content, req.Send()
|
||||
return output.Content, err
|
||||
}
|
||||
|
||||
// GetDynamicData uses the path provided to request information from the EC2
|
||||
|
@ -61,8 +63,9 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
|
|||
|
||||
output := &metadataOutput{}
|
||||
req := c.NewRequest(op, nil, output)
|
||||
err := req.Send()
|
||||
|
||||
return output.Content, req.Send()
|
||||
return output.Content, err
|
||||
}
|
||||
|
||||
// GetInstanceIdentityDocument retrieves an identity document describing an
|
||||
|
@ -79,7 +82,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
|
|||
doc := EC2InstanceIdentityDocument{}
|
||||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
|
||||
return EC2InstanceIdentityDocument{},
|
||||
awserr.New("SerializationError",
|
||||
awserr.New(request.ErrCodeSerialization,
|
||||
"failed to decode EC2 instance identity document", err)
|
||||
}
|
||||
|
||||
|
@ -98,7 +101,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
|
|||
info := EC2IAMInfo{}
|
||||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
|
||||
return EC2IAMInfo{},
|
||||
awserr.New("SerializationError",
|
||||
awserr.New(request.ErrCodeSerialization,
|
||||
"failed to decode EC2 IAM info", err)
|
||||
}
|
||||
|
||||
|
@ -118,6 +121,10 @@ func (c *EC2Metadata) Region() (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
if len(resp) == 0 {
|
||||
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
|
||||
}
|
||||
|
||||
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
|
||||
return resp[:len(resp)-1], nil
|
||||
}
|
||||
|
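As an editorial aside, not part of the commit: Region above derives the region by trimming the zone letter from the placement availability zone, so "us-west-2a" yields "us-west-2". A minimal sketch, only meaningful when run on an EC2 instance:

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/ec2metadata"
	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	// Region() reads the placement availability zone from IMDS and trims the
	// trailing zone letter.
	svc := ec2metadata.New(session.Must(session.NewSession()))
	if region, err := svc.Region(); err == nil {
		fmt.Println("running in", region)
	}
}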
@ -145,18 +152,19 @@ type EC2IAMInfo struct {
|
|||
// An EC2InstanceIdentityDocument provides the shape for unmarshaling
|
||||
// an instance identity document
|
||||
type EC2InstanceIdentityDocument struct {
|
||||
DevpayProductCodes []string `json:"devpayProductCodes"`
|
||||
AvailabilityZone string `json:"availabilityZone"`
|
||||
PrivateIP string `json:"privateIp"`
|
||||
Version string `json:"version"`
|
||||
Region string `json:"region"`
|
||||
InstanceID string `json:"instanceId"`
|
||||
BillingProducts []string `json:"billingProducts"`
|
||||
InstanceType string `json:"instanceType"`
|
||||
AccountID string `json:"accountId"`
|
||||
PendingTime time.Time `json:"pendingTime"`
|
||||
ImageID string `json:"imageId"`
|
||||
KernelID string `json:"kernelId"`
|
||||
RamdiskID string `json:"ramdiskId"`
|
||||
Architecture string `json:"architecture"`
|
||||
DevpayProductCodes []string `json:"devpayProductCodes"`
|
||||
MarketplaceProductCodes []string `json:"marketplaceProductCodes"`
|
||||
AvailabilityZone string `json:"availabilityZone"`
|
||||
PrivateIP string `json:"privateIp"`
|
||||
Version string `json:"version"`
|
||||
Region string `json:"region"`
|
||||
InstanceID string `json:"instanceId"`
|
||||
BillingProducts []string `json:"billingProducts"`
|
||||
InstanceType string `json:"instanceType"`
|
||||
AccountID string `json:"accountId"`
|
||||
PendingTime time.Time `json:"pendingTime"`
|
||||
ImageID string `json:"imageId"`
|
||||
KernelID string `json:"kernelId"`
|
||||
RamdiskID string `json:"ramdiskId"`
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// This package's client can be disabled completely by setting the environment
|
||||
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
|
||||
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
|
||||
// be used while the environemnt variable is set to true, (case insensitive).
|
||||
// be used while the environment variable is set to true, (case insensitive).
|
||||
package ec2metadata
|
||||
|
||||
import (
|
||||
|
@ -72,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
|
|||
cfg,
|
||||
metadata.ClientInfo{
|
||||
ServiceName: ServiceName,
|
||||
ServiceID: ServiceName,
|
||||
Endpoint: endpoint,
|
||||
APIVersion: "latest",
|
||||
},
|
||||
|
@ -91,6 +92,9 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
|
|||
svc.Handlers.Send.SwapNamed(request.NamedHandler{
|
||||
Name: corehandlers.SendHandler.Name,
|
||||
Fn: func(r *request.Request) {
|
||||
r.HTTPResponse = &http.Response{
|
||||
Header: http.Header{},
|
||||
}
|
||||
r.Error = awserr.New(
|
||||
request.CanceledErrorCode,
|
||||
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
|
||||
|
@ -119,7 +123,7 @@ func unmarshalHandler(r *request.Request) {
|
|||
defer r.HTTPResponse.Body.Close()
|
||||
b := &bytes.Buffer{}
|
||||
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
|
||||
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
|
||||
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata response", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -132,7 +136,7 @@ func unmarshalError(r *request.Request) {
|
|||
defer r.HTTPResponse.Body.Close()
|
||||
b := &bytes.Buffer{}
|
||||
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
|
||||
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
|
||||
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -85,6 +85,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
|
|||
custAddS3DualStack(p)
|
||||
custRmIotDataService(p)
|
||||
custFixAppAutoscalingChina(p)
|
||||
custFixAppAutoscalingUsGov(p)
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
|
@ -95,7 +96,12 @@ func custAddS3DualStack(p *partition) {
|
|||
return
|
||||
}
|
||||
|
||||
s, ok := p.Services["s3"]
|
||||
custAddDualstack(p, "s3")
|
||||
custAddDualstack(p, "s3-control")
|
||||
}
|
||||
|
||||
func custAddDualstack(p *partition, svcName string) {
|
||||
s, ok := p.Services[svcName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -103,7 +109,7 @@ func custAddS3DualStack(p *partition) {
|
|||
s.Defaults.HasDualStack = boxedTrue
|
||||
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
|
||||
|
||||
p.Services["s3"] = s
|
||||
p.Services[svcName] = s
|
||||
}
|
||||
|
||||
func custAddEC2Metadata(p *partition) {
|
||||
|
@ -144,6 +150,33 @@ func custFixAppAutoscalingChina(p *partition) {
|
|||
p.Services[serviceName] = s
|
||||
}
|
||||
|
||||
func custFixAppAutoscalingUsGov(p *partition) {
|
||||
if p.ID != "aws-us-gov" {
|
||||
return
|
||||
}
|
||||
|
||||
const serviceName = "application-autoscaling"
|
||||
s, ok := p.Services[serviceName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if a := s.Defaults.CredentialScope.Service; a != "" {
|
||||
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
|
||||
return
|
||||
}
|
||||
|
||||
if a := s.Defaults.Hostname; a != "" {
|
||||
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
|
||||
return
|
||||
}
|
||||
|
||||
s.Defaults.CredentialScope.Service = "application-autoscaling"
|
||||
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com"
|
||||
|
||||
p.Services[serviceName] = s
|
||||
}
|
||||
|
||||
type decodeModelError struct {
|
||||
awsError
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
vendor/github.com/aws/aws-sdk-go/aws/endpoints/dep_service_ids.go (generated, vendored, new file, 141 lines)
@@ -0,0 +1,141 @@
package endpoints
|
||||
|
||||
// Service identifiers
|
||||
//
|
||||
// Deprecated: Use client package's EndpointsID value instead of these
|
||||
// ServiceIDs. These IDs are not maintained, and are out of date.
|
||||
const (
|
||||
A4bServiceID = "a4b" // A4b.
|
||||
AcmServiceID = "acm" // Acm.
|
||||
AcmPcaServiceID = "acm-pca" // AcmPca.
|
||||
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
|
||||
ApiPricingServiceID = "api.pricing" // ApiPricing.
|
||||
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
|
||||
ApigatewayServiceID = "apigateway" // Apigateway.
|
||||
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
|
||||
Appstream2ServiceID = "appstream2" // Appstream2.
|
||||
AppsyncServiceID = "appsync" // Appsync.
|
||||
AthenaServiceID = "athena" // Athena.
|
||||
AutoscalingServiceID = "autoscaling" // Autoscaling.
|
||||
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
|
||||
BatchServiceID = "batch" // Batch.
|
||||
BudgetsServiceID = "budgets" // Budgets.
|
||||
CeServiceID = "ce" // Ce.
|
||||
ChimeServiceID = "chime" // Chime.
|
||||
Cloud9ServiceID = "cloud9" // Cloud9.
|
||||
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
|
||||
CloudformationServiceID = "cloudformation" // Cloudformation.
|
||||
CloudfrontServiceID = "cloudfront" // Cloudfront.
|
||||
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
|
||||
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
|
||||
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
|
||||
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
|
||||
CodebuildServiceID = "codebuild" // Codebuild.
|
||||
CodecommitServiceID = "codecommit" // Codecommit.
|
||||
CodedeployServiceID = "codedeploy" // Codedeploy.
|
||||
CodepipelineServiceID = "codepipeline" // Codepipeline.
|
||||
CodestarServiceID = "codestar" // Codestar.
|
||||
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
|
||||
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
|
||||
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
|
||||
ComprehendServiceID = "comprehend" // Comprehend.
|
||||
ConfigServiceID = "config" // Config.
|
||||
CurServiceID = "cur" // Cur.
|
||||
DatapipelineServiceID = "datapipeline" // Datapipeline.
|
||||
DaxServiceID = "dax" // Dax.
|
||||
DevicefarmServiceID = "devicefarm" // Devicefarm.
|
||||
DirectconnectServiceID = "directconnect" // Directconnect.
|
||||
DiscoveryServiceID = "discovery" // Discovery.
|
||||
DmsServiceID = "dms" // Dms.
|
||||
DsServiceID = "ds" // Ds.
|
||||
DynamodbServiceID = "dynamodb" // Dynamodb.
|
||||
Ec2ServiceID = "ec2" // Ec2.
|
||||
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
|
||||
EcrServiceID = "ecr" // Ecr.
|
||||
EcsServiceID = "ecs" // Ecs.
|
||||
ElasticacheServiceID = "elasticache" // Elasticache.
|
||||
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
|
||||
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
|
||||
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
|
||||
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
|
||||
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
|
||||
EmailServiceID = "email" // Email.
|
||||
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
|
||||
EsServiceID = "es" // Es.
|
||||
EventsServiceID = "events" // Events.
|
||||
FirehoseServiceID = "firehose" // Firehose.
|
||||
FmsServiceID = "fms" // Fms.
|
||||
GameliftServiceID = "gamelift" // Gamelift.
|
||||
GlacierServiceID = "glacier" // Glacier.
|
||||
GlueServiceID = "glue" // Glue.
|
||||
GreengrassServiceID = "greengrass" // Greengrass.
|
||||
GuarddutyServiceID = "guardduty" // Guardduty.
|
||||
HealthServiceID = "health" // Health.
|
||||
IamServiceID = "iam" // Iam.
|
||||
ImportexportServiceID = "importexport" // Importexport.
|
||||
InspectorServiceID = "inspector" // Inspector.
|
||||
IotServiceID = "iot" // Iot.
|
||||
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
|
||||
KinesisServiceID = "kinesis" // Kinesis.
|
||||
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
|
||||
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
|
||||
KmsServiceID = "kms" // Kms.
|
||||
LambdaServiceID = "lambda" // Lambda.
|
||||
LightsailServiceID = "lightsail" // Lightsail.
|
||||
LogsServiceID = "logs" // Logs.
|
||||
MachinelearningServiceID = "machinelearning" // Machinelearning.
|
||||
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
|
||||
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
|
||||
MedialiveServiceID = "medialive" // Medialive.
|
||||
MediapackageServiceID = "mediapackage" // Mediapackage.
|
||||
MediastoreServiceID = "mediastore" // Mediastore.
|
||||
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
|
||||
MghServiceID = "mgh" // Mgh.
|
||||
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
|
||||
ModelsLexServiceID = "models.lex" // ModelsLex.
|
||||
MonitoringServiceID = "monitoring" // Monitoring.
|
||||
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
|
||||
NeptuneServiceID = "neptune" // Neptune.
|
||||
OpsworksServiceID = "opsworks" // Opsworks.
|
||||
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
|
||||
OrganizationsServiceID = "organizations" // Organizations.
|
||||
PinpointServiceID = "pinpoint" // Pinpoint.
|
||||
PollyServiceID = "polly" // Polly.
|
||||
RdsServiceID = "rds" // Rds.
|
||||
RedshiftServiceID = "redshift" // Redshift.
|
||||
RekognitionServiceID = "rekognition" // Rekognition.
|
||||
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
|
||||
Route53ServiceID = "route53" // Route53.
|
||||
Route53domainsServiceID = "route53domains" // Route53domains.
|
||||
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
|
||||
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
|
||||
S3ServiceID = "s3" // S3.
|
||||
S3ControlServiceID = "s3-control" // S3Control.
|
||||
SagemakerServiceID = "api.sagemaker" // Sagemaker.
|
||||
SdbServiceID = "sdb" // Sdb.
|
||||
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
|
||||
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
|
||||
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
|
||||
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
|
||||
ShieldServiceID = "shield" // Shield.
|
||||
SmsServiceID = "sms" // Sms.
|
||||
SnowballServiceID = "snowball" // Snowball.
|
||||
SnsServiceID = "sns" // Sns.
|
||||
SqsServiceID = "sqs" // Sqs.
|
||||
SsmServiceID = "ssm" // Ssm.
|
||||
StatesServiceID = "states" // States.
|
||||
StoragegatewayServiceID = "storagegateway" // Storagegateway.
|
||||
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
|
||||
StsServiceID = "sts" // Sts.
|
||||
SupportServiceID = "support" // Support.
|
||||
SwfServiceID = "swf" // Swf.
|
||||
TaggingServiceID = "tagging" // Tagging.
|
||||
TransferServiceID = "transfer" // Transfer.
|
||||
TranslateServiceID = "translate" // Translate.
|
||||
WafServiceID = "waf" // Waf.
|
||||
WafRegionalServiceID = "waf-regional" // WafRegional.
|
||||
WorkdocsServiceID = "workdocs" // Workdocs.
|
||||
WorkmailServiceID = "workmail" // Workmail.
|
||||
WorkspacesServiceID = "workspaces" // Workspaces.
|
||||
XrayServiceID = "xray" // Xray.
|
||||
)
|
|
@ -3,6 +3,7 @@ package endpoints
|
|||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
@ -35,7 +36,7 @@ type Options struct {
|
|||
//
|
||||
// If resolving an endpoint on the partition list the provided region will
|
||||
// be used to determine which partition's domain name pattern to the service
|
||||
// endpoint ID with. If both the service and region are unkonwn and resolving
|
||||
// endpoint ID with. If both the service and region are unknown and resolving
|
||||
// the endpoint on partition list an UnknownEndpointError error will be returned.
|
||||
//
|
||||
// If resolving and endpoint on a partition specific resolver that partition's
|
||||
|
@ -46,6 +47,43 @@ type Options struct {
|
|||
//
|
||||
// This option is ignored if StrictMatching is enabled.
|
||||
ResolveUnknownService bool
|
||||
|
||||
// STS Regional Endpoint flag helps with resolving the STS endpoint
|
||||
STSRegionalEndpoint STSRegionalEndpoint
|
||||
}
|
||||
|
||||
// STSRegionalEndpoint is an enum type alias for int
|
||||
// It is used internally by the core sdk as STS Regional Endpoint flag value
|
||||
type STSRegionalEndpoint int
|
||||
|
||||
const (
|
||||
|
||||
// UnsetSTSEndpoint represents that STS Regional Endpoint flag is not specified.
|
||||
UnsetSTSEndpoint STSRegionalEndpoint = iota
|
||||
|
||||
// LegacySTSEndpoint represents when STS Regional Endpoint flag is specified
|
||||
// to use legacy endpoints.
|
||||
LegacySTSEndpoint
|
||||
|
||||
// RegionalSTSEndpoint represents when STS Regional Endpoint flag is specified
|
||||
// to use regional endpoints.
|
||||
RegionalSTSEndpoint
|
||||
)
|
||||
|
||||
// GetSTSRegionalEndpoint function returns the STSRegionalEndpointFlag based
|
||||
// on the input string provided in env config or shared config by the user.
|
||||
//
|
||||
// `legacy`, `regional` are the only case-insensitive valid strings for
|
||||
// resolving the STS regional Endpoint flag.
|
||||
func GetSTSRegionalEndpoint(s string) (STSRegionalEndpoint, error) {
|
||||
switch {
|
||||
case strings.EqualFold(s, "legacy"):
|
||||
return LegacySTSEndpoint, nil
|
||||
case strings.EqualFold(s, "regional"):
|
||||
return RegionalSTSEndpoint, nil
|
||||
default:
|
||||
return UnsetSTSEndpoint, fmt.Errorf("unable to resolve the value of STSRegionalEndpoint for %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Set combines all of the option functions together.
|
||||
|
@ -79,6 +117,12 @@ func ResolveUnknownServiceOption(o *Options) {
|
|||
o.ResolveUnknownService = true
|
||||
}
|
||||
|
||||
// STSRegionalEndpointOption enables the STS endpoint resolver behavior to resolve
|
||||
// STS endpoint to their regional endpoint, instead of the global endpoint.
|
||||
func STSRegionalEndpointOption(o *Options) {
|
||||
o.STSRegionalEndpoint = RegionalSTSEndpoint
|
||||
}
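The option above only flips a flag; EndpointFor consults it when the service is "sts" and the region is one of the legacy global regions. A minimal caller-side sketch, assuming the package's existing DefaultResolver and the ResolvedEndpoint.URL field; the printed URLs are the expected outcomes, not values taken from this diff.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/endpoints"
)

func main() {
	resolver := endpoints.DefaultResolver()

	// Default behavior: STS in a legacy region falls back to the global endpoint.
	global, err := resolver.EndpointFor("sts", "us-west-2")
	if err != nil {
		panic(err)
	}
	fmt.Println(global.URL) // expected: https://sts.amazonaws.com

	// With the option, the regional endpoint is kept instead.
	regional, err := resolver.EndpointFor("sts", "us-west-2", endpoints.STSRegionalEndpointOption)
	if err != nil {
		panic(err)
	}
	fmt.Println(regional.URL) // expected: https://sts.us-west-2.amazonaws.com
}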
|
||||
|
||||
// A Resolver provides the interface for functionality to resolve endpoints.
|
||||
// The build in Partition and DefaultResolver return value satisfy this interface.
|
||||
type Resolver interface {
|
||||
|
@ -170,10 +214,13 @@ func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
|
|||
// A Partition provides the ability to enumerate the partition's regions
|
||||
// and services.
|
||||
type Partition struct {
|
||||
id string
|
||||
p *partition
|
||||
id, dnsSuffix string
|
||||
p *partition
|
||||
}
|
||||
|
||||
// DNSSuffix returns the base domain name of the partition.
|
||||
func (p Partition) DNSSuffix() string { return p.dnsSuffix }
|
||||
|
||||
// ID returns the identifier of the partition.
|
||||
func (p Partition) ID() string { return p.id }
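The new DNSSuffix accessor exposes the partition's base domain alongside the existing ID. A short sketch using the package's existing DefaultPartitions and PartitionForRegion helpers; the printed values are illustrative.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/endpoints"
)

func main() {
	// Look up the partition that contains a region, then read its metadata.
	p, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), "us-gov-west-1")
	if !ok {
		panic("unknown region")
	}
	fmt.Println(p.ID())        // e.g. aws-us-gov
	fmt.Println(p.DNSSuffix()) // e.g. amazonaws.com
}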
|
||||
|
||||
|
@ -191,7 +238,7 @@ func (p Partition) ID() string { return p.id }
|
|||
// require the provided service and region to be known by the partition.
|
||||
// If the endpoint cannot be strictly resolved an error will be returned. This
|
||||
// mode is useful to ensure the endpoint resolved is valid. Without
|
||||
// StrictMatching enabled the endpoint returned my look valid but may not work.
|
||||
// StrictMatching enabled the endpoint returned may look valid but may not work.
|
||||
// StrictMatching requires the SDK to be updated if you want to take advantage
|
||||
// of new regions and services expansions.
|
||||
//
|
||||
|
@ -347,6 +394,9 @@ type ResolvedEndpoint struct {
|
|||
// The endpoint URL
|
||||
URL string
|
||||
|
||||
// The endpoint partition
|
||||
PartitionID string
|
||||
|
||||
// The region that should be used for signing requests.
|
||||
SigningRegion string
|
||||
|
||||
|
|
19 vendor/github.com/aws/aws-sdk-go/aws/endpoints/sts_legacy_regions.go generated vendored Normal file
|
@ -0,0 +1,19 @@
package endpoints

var stsLegacyGlobalRegions = map[string]struct{}{
	"ap-northeast-1": {},
	"ap-south-1":     {},
	"ap-southeast-1": {},
	"ap-southeast-2": {},
	"ca-central-1":   {},
	"eu-central-1":   {},
	"eu-north-1":     {},
	"eu-west-1":      {},
	"eu-west-2":      {},
	"eu-west-3":      {},
	"sa-east-1":      {},
	"us-east-1":      {},
	"us-east-2":      {},
	"us-west-1":      {},
	"us-west-2":      {},
}
|
|
@ -54,8 +54,9 @@ type partition struct {
|
|||
|
||||
func (p partition) Partition() Partition {
|
||||
return Partition{
|
||||
id: p.ID,
|
||||
p: &p,
|
||||
dnsSuffix: p.DNSSuffix,
|
||||
id: p.ID,
|
||||
p: &p,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,24 +75,55 @@ func (p partition) canResolveEndpoint(service, region string, strictMatch bool)
|
|||
return p.RegionRegex.MatchString(region)
|
||||
}
|
||||
|
||||
func allowLegacyEmptyRegion(service string) bool {
|
||||
legacy := map[string]struct{}{
|
||||
"budgets": {},
|
||||
"ce": {},
|
||||
"chime": {},
|
||||
"cloudfront": {},
|
||||
"ec2metadata": {},
|
||||
"iam": {},
|
||||
"importexport": {},
|
||||
"organizations": {},
|
||||
"route53": {},
|
||||
"sts": {},
|
||||
"support": {},
|
||||
"waf": {},
|
||||
}
|
||||
|
||||
_, allowed := legacy[service]
|
||||
return allowed
|
||||
}
|
||||
|
||||
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
|
||||
var opt Options
|
||||
opt.Set(opts...)
|
||||
|
||||
s, hasService := p.Services[service]
|
||||
if !(hasService || opt.ResolveUnknownService) {
|
||||
if len(service) == 0 || !(hasService || opt.ResolveUnknownService) {
|
||||
// Only return error if the resolver will not fallback to creating
|
||||
// endpoint based on service endpoint ID passed in.
|
||||
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
|
||||
}
|
||||
|
||||
if len(region) == 0 && allowLegacyEmptyRegion(service) && len(s.PartitionEndpoint) != 0 {
|
||||
region = s.PartitionEndpoint
|
||||
}
|
||||
|
||||
if service == "sts" && opt.STSRegionalEndpoint != RegionalSTSEndpoint {
|
||||
if _, ok := stsLegacyGlobalRegions[region]; ok {
|
||||
region = "aws-global"
|
||||
}
|
||||
}
|
||||
|
||||
e, hasEndpoint := s.endpointForRegion(region)
|
||||
if !hasEndpoint && opt.StrictMatching {
|
||||
if len(region) == 0 || (!hasEndpoint && opt.StrictMatching) {
|
||||
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
|
||||
}
|
||||
|
||||
defs := []endpoint{p.Defaults, s.Defaults}
|
||||
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil
|
||||
|
||||
return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt), nil
|
||||
}
|
||||
|
||||
func serviceList(ss services) []string {
|
||||
|
@ -200,7 +232,7 @@ func getByPriority(s []string, p []string, def string) string {
|
|||
return s[0]
|
||||
}
|
||||
|
||||
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
|
||||
func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
|
||||
var merged endpoint
|
||||
for _, def := range defs {
|
||||
merged.mergeIn(def)
|
||||
|
@ -236,6 +268,7 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
|
|||
|
||||
return ResolvedEndpoint{
|
||||
URL: u,
|
||||
PartitionID: partitionID,
|
||||
SigningRegion: signingRegion,
|
||||
SigningName: signingName,
|
||||
SigningNameDerived: signingNameDerived,
|
||||
|
|
|
@ -16,6 +16,10 @@ import (
|
|||
type CodeGenOptions struct {
|
||||
// Options for how the model will be decoded.
|
||||
DecodeModelOptions DecodeModelOptions
|
||||
|
||||
// Disables code generation of the service endpoint prefix IDs defined in
|
||||
// the model.
|
||||
DisableGenerateServiceIDs bool
|
||||
}
|
||||
|
||||
// Set combines all of the option functions together
|
||||
|
@ -39,8 +43,16 @@ func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGe
|
|||
return err
|
||||
}
|
||||
|
||||
v := struct {
|
||||
Resolver
|
||||
CodeGenOptions
|
||||
}{
|
||||
Resolver: resolver,
|
||||
CodeGenOptions: opts,
|
||||
}
|
||||
|
||||
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
|
||||
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil {
|
||||
if err := tmpl.ExecuteTemplate(outFile, "defaults", v); err != nil {
|
||||
return fmt.Errorf("failed to execute template, %v", err)
|
||||
}
|
||||
|
||||
|
@ -166,15 +178,17 @@ import (
|
|||
"regexp"
|
||||
)
|
||||
|
||||
{{ template "partition consts" . }}
|
||||
{{ template "partition consts" $.Resolver }}
|
||||
|
||||
{{ range $_, $partition := . }}
|
||||
{{ range $_, $partition := $.Resolver }}
|
||||
{{ template "partition region consts" $partition }}
|
||||
{{ end }}
|
||||
|
||||
{{ template "service consts" . }}
|
||||
{{ if not $.DisableGenerateServiceIDs -}}
|
||||
{{ template "service consts" $.Resolver }}
|
||||
{{- end }}
|
||||
|
||||
{{ template "endpoint resolvers" . }}
|
||||
{{ template "endpoint resolvers" $.Resolver }}
|
||||
{{- end }}
|
||||
|
||||
{{ define "partition consts" }}
|
||||
|
|
|
@ -5,13 +5,9 @@ import "github.com/aws/aws-sdk-go/aws/awserr"
|
|||
var (
|
||||
// ErrMissingRegion is an error that is returned if region configuration is
|
||||
// not found.
|
||||
//
|
||||
// @readonly
|
||||
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
|
||||
|
||||
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
|
||||
// resolved for a service.
|
||||
//
|
||||
// @readonly
|
||||
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
|
||||
)
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
// +build !appengine,!plan9
|
||||
|
||||
package request
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isErrConnectionReset(err error) bool {
|
||||
if opErr, ok := err.(*net.OpError); ok {
|
||||
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
|
||||
return sysErr.Err == syscall.ECONNRESET
|
||||
}
|
||||
if strings.Contains(err.Error(), "read: connection reset") {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "connection reset") ||
|
||||
strings.Contains(err.Error(), "broken pipe") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
// +build appengine plan9
|
||||
|
||||
package request
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isErrConnectionReset(err error) bool {
|
||||
return strings.Contains(err.Error(), "connection reset")
|
||||
}
|
|
@ -19,10 +19,11 @@ type Handlers struct {
|
|||
UnmarshalError HandlerList
|
||||
Retry HandlerList
|
||||
AfterRetry HandlerList
|
||||
CompleteAttempt HandlerList
|
||||
Complete HandlerList
|
||||
}
|
||||
|
||||
// Copy returns of this handler's lists.
|
||||
// Copy returns a copy of this handler's lists.
|
||||
func (h *Handlers) Copy() Handlers {
|
||||
return Handlers{
|
||||
Validate: h.Validate.copy(),
|
||||
|
@ -36,11 +37,12 @@ func (h *Handlers) Copy() Handlers {
|
|||
UnmarshalMeta: h.UnmarshalMeta.copy(),
|
||||
Retry: h.Retry.copy(),
|
||||
AfterRetry: h.AfterRetry.copy(),
|
||||
CompleteAttempt: h.CompleteAttempt.copy(),
|
||||
Complete: h.Complete.copy(),
|
||||
}
|
||||
}
|
||||
|
||||
// Clear removes callback functions for all handlers
|
||||
// Clear removes callback functions for all handlers.
|
||||
func (h *Handlers) Clear() {
|
||||
h.Validate.Clear()
|
||||
h.Build.Clear()
|
||||
|
@ -53,9 +55,55 @@ func (h *Handlers) Clear() {
|
|||
h.ValidateResponse.Clear()
|
||||
h.Retry.Clear()
|
||||
h.AfterRetry.Clear()
|
||||
h.CompleteAttempt.Clear()
|
||||
h.Complete.Clear()
|
||||
}
|
||||
|
||||
// IsEmpty returns if there are no handlers in any of the handlerlists.
|
||||
func (h *Handlers) IsEmpty() bool {
|
||||
if h.Validate.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.Build.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.Send.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.Sign.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.Unmarshal.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.UnmarshalStream.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.UnmarshalMeta.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.UnmarshalError.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.ValidateResponse.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.Retry.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.AfterRetry.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.CompleteAttempt.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
if h.Complete.Len() != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// A HandlerListRunItem represents an entry in the HandlerList which
|
||||
// is being run.
|
||||
type HandlerListRunItem struct {
|
||||
|
|
|
@ -15,12 +15,15 @@ type offsetReader struct {
|
|||
closed bool
|
||||
}
|
||||
|
||||
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
|
||||
func newOffsetReader(buf io.ReadSeeker, offset int64) (*offsetReader, error) {
|
||||
reader := &offsetReader{}
|
||||
buf.Seek(offset, sdkio.SeekStart)
|
||||
_, err := buf.Seek(offset, sdkio.SeekStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader.buf = buf
|
||||
return reader
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// Close will close the instance of the offset reader's access to
|
||||
|
@ -54,7 +57,9 @@ func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
|
|||
|
||||
// CloseAndCopy will return a new offsetReader with a copy of the old buffer
|
||||
// and close the old buffer.
|
||||
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader {
|
||||
o.Close()
|
||||
func (o *offsetReader) CloseAndCopy(offset int64) (*offsetReader, error) {
|
||||
if err := o.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newOffsetReader(o.buf, offset)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
@ -65,6 +64,15 @@ type Request struct {
|
|||
LastSignedAt time.Time
|
||||
DisableFollowRedirects bool
|
||||
|
||||
// Additional API error codes that should be retried. IsErrorRetryable
|
||||
// will consider these codes in addition to its built in cases.
|
||||
RetryErrorCodes []string
|
||||
|
||||
// Additional API error codes that should be retried with throttle backoff
|
||||
// delay. IsErrorThrottle will consider these codes in addition to its
|
||||
// built in cases.
|
||||
ThrottleErrorCodes []string
|
||||
|
||||
// A value greater than 0 instructs the request to be signed as Presigned URL
|
||||
// You should not set this field directly. Instead use Request's
|
||||
// Presign or PresignRequest methods.
|
||||
|
@ -91,8 +99,12 @@ type Operation struct {
|
|||
BeforePresignFn func(r *Request) error
|
||||
}
|
||||
|
||||
// New returns a new Request pointer for the service API
|
||||
// operation and parameters.
|
||||
// New returns a new Request pointer for the service API operation and
|
||||
// parameters.
|
||||
//
|
||||
// A Retryer should be provided to direct how the request is retried. If
|
||||
// Retryer is nil, a default no retry value will be used. You can use
|
||||
// NoOpRetryer in the Client package to disable retry behavior directly.
|
||||
//
|
||||
// Params is any value of input parameters to be the request payload.
|
||||
// Data is pointer value to an object which the request's response
|
||||
|
@ -100,6 +112,10 @@ type Operation struct {
|
|||
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
|
||||
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
|
||||
|
||||
if retryer == nil {
|
||||
retryer = noOpRetryer{}
|
||||
}
|
||||
|
||||
method := operation.HTTPMethod
|
||||
if method == "" {
|
||||
method = "POST"
|
||||
|
@ -122,7 +138,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
|
|||
Handlers: handlers.Copy(),
|
||||
|
||||
Retryer: retryer,
|
||||
AttemptTime: time.Now(),
|
||||
Time: time.Now(),
|
||||
ExpireTime: 0,
|
||||
Operation: operation,
|
||||
|
@ -233,6 +248,10 @@ func (r *Request) WillRetry() bool {
|
|||
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
|
||||
}
|
||||
|
||||
func fmtAttemptCount(retryCount, maxRetries int) string {
|
||||
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
|
||||
}
|
||||
|
||||
// ParamsFilled returns if the request's parameters have been populated
|
||||
// and the parameters are valid. False is returned if no parameters are
|
||||
// provided or invalid.
|
||||
|
@ -261,12 +280,25 @@ func (r *Request) SetStringBody(s string) {
|
|||
// SetReaderBody will set the request's body reader.
|
||||
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
|
||||
r.Body = reader
|
||||
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
|
||||
|
||||
if aws.IsReaderSeekable(reader) {
|
||||
var err error
|
||||
// Get the Bodies current offset so retries will start from the same
|
||||
// initial position.
|
||||
r.BodyStart, err = reader.Seek(0, sdkio.SeekCurrent)
|
||||
if err != nil {
|
||||
r.Error = awserr.New(ErrCodeSerialization,
|
||||
"failed to determine start of request body", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
r.ResetBody()
|
||||
}
|
||||
|
||||
// Presign returns the request's signed URL. Error will be returned
|
||||
// if the signing fails.
|
||||
// if the signing fails. The expire parameter is only used for presigned Amazon
|
||||
// S3 API requests. All other AWS services will use a fixed expiration
|
||||
// time of 15 minutes.
|
||||
//
|
||||
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
||||
// error is returned if expire duration is 0 or less.
|
||||
|
@ -283,7 +315,9 @@ func (r *Request) Presign(expire time.Duration) (string, error) {
|
|||
}
|
||||
|
||||
// PresignRequest behaves just like presign, with the addition of returning a
|
||||
// set of headers that were signed.
|
||||
// set of headers that were signed. The expire parameter is only used for
|
||||
// presigned Amazon S3 API requests. All other AWS services will use a fixed
|
||||
// expiration time of 15 minutes.
|
||||
//
|
||||
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
||||
// error is returned if expire duration is 0 or less.
|
||||
|
@ -328,16 +362,15 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
|
|||
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
|
||||
}
|
||||
|
||||
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
|
||||
const (
|
||||
notRetrying = "not retrying"
|
||||
)
|
||||
|
||||
func debugLogReqError(r *Request, stage, retryStr string, err error) {
|
||||
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
|
||||
return
|
||||
}
|
||||
|
||||
retryStr := "not retrying"
|
||||
if retrying {
|
||||
retryStr = "will retry"
|
||||
}
|
||||
|
||||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
|
||||
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
|
||||
}
|
||||
|
@ -356,12 +389,12 @@ func (r *Request) Build() error {
|
|||
if !r.built {
|
||||
r.Handlers.Validate.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Validate Request", false, r.Error)
|
||||
debugLogReqError(r, "Validate Request", notRetrying, r.Error)
|
||||
return r.Error
|
||||
}
|
||||
r.Handlers.Build.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Build Request", false, r.Error)
|
||||
debugLogReqError(r, "Build Request", notRetrying, r.Error)
|
||||
return r.Error
|
||||
}
|
||||
r.built = true
|
||||
|
@ -377,7 +410,7 @@ func (r *Request) Build() error {
|
|||
func (r *Request) Sign() error {
|
||||
r.Build()
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Build Request", false, r.Error)
|
||||
debugLogReqError(r, "Build Request", notRetrying, r.Error)
|
||||
return r.Error
|
||||
}
|
||||
|
||||
|
@ -385,12 +418,16 @@ func (r *Request) Sign() error {
|
|||
return r.Error
|
||||
}
|
||||
|
||||
func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
|
||||
func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) {
|
||||
if r.safeBody != nil {
|
||||
r.safeBody.Close()
|
||||
}
|
||||
|
||||
r.safeBody = newOffsetReader(r.Body, r.BodyStart)
|
||||
r.safeBody, err = newOffsetReader(r.Body, r.BodyStart)
|
||||
if err != nil {
|
||||
return nil, awserr.New(ErrCodeSerialization,
|
||||
"failed to get next request body reader", err)
|
||||
}
|
||||
|
||||
// Go 1.8 tightened and clarified the rules code needs to use when building
|
||||
// requests with the http package. Go 1.8 removed the automatic detection
|
||||
|
@ -407,10 +444,10 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
|
|||
// Related golang/go#18257
|
||||
l, err := aws.SeekerLen(r.Body)
|
||||
if err != nil {
|
||||
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
|
||||
return nil, awserr.New(ErrCodeSerialization,
|
||||
"failed to compute request body size", err)
|
||||
}
|
||||
|
||||
var body io.ReadCloser
|
||||
if l == 0 {
|
||||
body = NoBody
|
||||
} else if l > 0 {
|
||||
|
@ -462,80 +499,90 @@ func (r *Request) Send() error {
|
|||
r.Handlers.Complete.Run(r)
|
||||
}()
|
||||
|
||||
if err := r.Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
r.Error = nil
|
||||
r.AttemptTime = time.Now()
|
||||
if aws.BoolValue(r.Retryable) {
|
||||
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
|
||||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
|
||||
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
|
||||
}
|
||||
|
||||
// The previous http.Request will have a reference to the r.Body
|
||||
// and the HTTP Client's Transport may still be reading from
|
||||
// the request's body even though the Client's Do returned.
|
||||
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
|
||||
r.ResetBody()
|
||||
|
||||
// Closing response body to ensure that no response body is leaked
|
||||
// between retry attempts.
|
||||
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
|
||||
r.HTTPResponse.Body.Close()
|
||||
}
|
||||
if err := r.Sign(); err != nil {
|
||||
debugLogReqError(r, "Sign Request", notRetrying, err)
|
||||
return err
|
||||
}
|
||||
|
||||
r.Sign()
|
||||
if r.Error != nil {
|
||||
if err := r.sendRequest(); err == nil {
|
||||
return nil
|
||||
}
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
|
||||
if r.Error != nil || !aws.BoolValue(r.Retryable) {
|
||||
return r.Error
|
||||
}
|
||||
|
||||
r.Retryable = nil
|
||||
|
||||
r.Handlers.Send.Run(r)
|
||||
if r.Error != nil {
|
||||
if !shouldRetryCancel(r) {
|
||||
return r.Error
|
||||
}
|
||||
|
||||
err := r.Error
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Send Request", false, err)
|
||||
return r.Error
|
||||
}
|
||||
debugLogReqError(r, "Send Request", true, err)
|
||||
continue
|
||||
if err := r.prepareRetry(); err != nil {
|
||||
r.Error = err
|
||||
return err
|
||||
}
|
||||
r.Handlers.UnmarshalMeta.Run(r)
|
||||
r.Handlers.ValidateResponse.Run(r)
|
||||
if r.Error != nil {
|
||||
r.Handlers.UnmarshalError.Run(r)
|
||||
err := r.Error
|
||||
}
|
||||
}
|
||||
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Validate Response", false, err)
|
||||
return r.Error
|
||||
}
|
||||
debugLogReqError(r, "Validate Response", true, err)
|
||||
continue
|
||||
}
|
||||
func (r *Request) prepareRetry() error {
|
||||
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
|
||||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
|
||||
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
|
||||
}
|
||||
|
||||
r.Handlers.Unmarshal.Run(r)
|
||||
if r.Error != nil {
|
||||
err := r.Error
|
||||
r.Handlers.Retry.Run(r)
|
||||
r.Handlers.AfterRetry.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Unmarshal Response", false, err)
|
||||
return r.Error
|
||||
}
|
||||
debugLogReqError(r, "Unmarshal Response", true, err)
|
||||
continue
|
||||
}
|
||||
// The previous http.Request will have a reference to the r.Body
|
||||
// and the HTTP Client's Transport may still be reading from
|
||||
// the request's body even though the Client's Do returned.
|
||||
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
|
||||
r.ResetBody()
|
||||
if err := r.Error; err != nil {
|
||||
return awserr.New(ErrCodeSerialization,
|
||||
"failed to prepare body for retry", err)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// Closing response body to ensure that no response body is leaked
|
||||
// between retry attempts.
|
||||
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
|
||||
r.HTTPResponse.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Request) sendRequest() (sendErr error) {
|
||||
defer r.Handlers.CompleteAttempt.Run(r)
|
||||
|
||||
r.Retryable = nil
|
||||
r.Handlers.Send.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Send Request",
|
||||
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
|
||||
r.Error)
|
||||
return r.Error
|
||||
}
|
||||
|
||||
r.Handlers.UnmarshalMeta.Run(r)
|
||||
r.Handlers.ValidateResponse.Run(r)
|
||||
if r.Error != nil {
|
||||
r.Handlers.UnmarshalError.Run(r)
|
||||
debugLogReqError(r, "Validate Response",
|
||||
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
|
||||
r.Error)
|
||||
return r.Error
|
||||
}
|
||||
|
||||
r.Handlers.Unmarshal.Run(r)
|
||||
if r.Error != nil {
|
||||
debugLogReqError(r, "Unmarshal Response",
|
||||
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
|
||||
r.Error)
|
||||
return r.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -561,32 +608,6 @@ func AddToUserAgent(r *Request, s string) {
|
|||
r.HTTPRequest.Header.Set("User-Agent", s)
|
||||
}
|
||||
|
||||
func shouldRetryCancel(r *Request) bool {
|
||||
awsErr, ok := r.Error.(awserr.Error)
|
||||
timeoutErr := false
|
||||
errStr := r.Error.Error()
|
||||
if ok {
|
||||
if awsErr.Code() == CanceledErrorCode {
|
||||
return false
|
||||
}
|
||||
err := awsErr.OrigErr()
|
||||
netErr, netOK := err.(net.Error)
|
||||
timeoutErr = netOK && netErr.Temporary()
|
||||
if urlErr, ok := err.(*url.Error); !timeoutErr && ok {
|
||||
errStr = urlErr.Err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
// There can be two types of canceled errors here.
|
||||
// The first being a net.Error and the other being an error.
|
||||
// If the request was timed out, we want to continue the retry
|
||||
// process. Otherwise, return the canceled error.
|
||||
return timeoutErr ||
|
||||
(errStr != "net/http: request canceled" &&
|
||||
errStr != "net/http: request canceled while waiting for connection")
|
||||
|
||||
}
|
||||
|
||||
// SanitizeHostForHeader removes default port from host and updates request.Host
|
||||
func SanitizeHostForHeader(r *http.Request) {
|
||||
host := getHost(r)
|
||||
|
|
|
@ -4,6 +4,8 @@ package request
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
// NoBody is a http.NoBody reader instructing Go HTTP client to not include
|
||||
|
@ -24,7 +26,8 @@ var NoBody = http.NoBody
|
|||
func (r *Request) ResetBody() {
|
||||
body, err := r.getNextRequestBody()
|
||||
if err != nil {
|
||||
r.Error = err
|
||||
r.Error = awserr.New(ErrCodeSerialization,
|
||||
"failed to reset request body", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -146,7 +146,7 @@ func (r *Request) nextPageTokens() []interface{} {
|
|||
return nil
|
||||
}
|
||||
case bool:
|
||||
if v == false {
|
||||
if !v {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,26 +1,75 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
// Retryer is an interface to control retry logic for a given service.
|
||||
// The default implementation used by most services is the client.DefaultRetryer
|
||||
// structure, which contains basic retry logic using exponential backoff.
|
||||
// Retryer provides the interface drive the SDK's request retry behavior. The
|
||||
// Retryer implementation is responsible for implementing exponential backoff,
|
||||
// and determine if a request API error should be retried.
|
||||
//
|
||||
// client.DefaultRetryer is the SDK's default implementation of the Retryer. It
|
||||
// uses the which uses the Request.IsErrorRetryable and Request.IsErrorThrottle
|
||||
// methods to determine if the request is retried.
|
||||
type Retryer interface {
|
||||
// RetryRules return the retry delay that should be used by the SDK before
|
||||
// making another request attempt for the failed request.
|
||||
RetryRules(*Request) time.Duration
|
||||
|
||||
// ShouldRetry returns if the failed request is retryable.
|
||||
//
|
||||
// Implementations may consider request attempt count when determining if a
|
||||
// request is retryable, but the SDK will use MaxRetries to limit the
|
||||
// number of attempts a request are made.
|
||||
ShouldRetry(*Request) bool
|
||||
|
||||
// MaxRetries is the number of times a request may be retried before
|
||||
// failing.
|
||||
MaxRetries() int
|
||||
}
|
||||
|
||||
// WithRetryer sets a config Retryer value to the given Config returning it
|
||||
// for chaining.
|
||||
// WithRetryer sets a Retryer value to the given Config returning the Config
|
||||
// value for chaining. The value must not be nil.
|
||||
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
|
||||
if retryer == nil {
|
||||
if cfg.Logger != nil {
|
||||
cfg.Logger.Log("ERROR: Request.WithRetryer called with nil retryer. Replacing with retry disabled Retryer.")
|
||||
}
|
||||
retryer = noOpRetryer{}
|
||||
}
|
||||
cfg.Retryer = retryer
|
||||
return cfg
|
||||
|
||||
}
|
||||
|
||||
// noOpRetryer is a internal no op retryer used when a request is created
|
||||
// without a retryer.
|
||||
//
|
||||
// Provides a retryer that performs no retries.
|
||||
// It should be used when we do not want retries to be performed.
|
||||
type noOpRetryer struct{}
|
||||
|
||||
// MaxRetries returns the number of maximum returns the service will use to make
|
||||
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
|
||||
func (d noOpRetryer) MaxRetries() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
|
||||
func (d noOpRetryer) ShouldRetry(_ *Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// RetryRules returns the delay duration before retrying this request again;
|
||||
// since NoOpRetryer does not retry, RetryRules always returns 0.
|
||||
func (d noOpRetryer) RetryRules(_ *Request) time.Duration {
|
||||
return 0
|
||||
}
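Since WithRetryer now replaces a nil retryer with the no-op implementation, callers that do want retries must pass a concrete Retryer. A minimal sketch of a custom implementation wired in through WithRetryer; the fixed delay and retry limit are illustrative only, and client.DefaultRetryer (not shown in this diff) remains the SDK's usual choice.

package main

import (
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
)

// fixedDelayRetryer retries up to max times with a constant backoff.
type fixedDelayRetryer struct{ max int }

func (r fixedDelayRetryer) MaxRetries() int { return r.max }

func (r fixedDelayRetryer) ShouldRetry(req *request.Request) bool {
	// Defer to the request's own classification of retryable and throttled errors.
	return req.IsErrorRetryable() || req.IsErrorThrottle()
}

func (r fixedDelayRetryer) RetryRules(*request.Request) time.Duration {
	return 500 * time.Millisecond // constant delay, for illustration only
}

func main() {
	cfg := request.WithRetryer(aws.NewConfig(), fixedDelayRetryer{max: 3})
	_ = cfg // pass cfg to a service client constructor
}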
|
||||
|
||||
// retryableCodes is a collection of service response codes which are retry-able
|
||||
|
@ -38,8 +87,10 @@ var throttleCodes = map[string]struct{}{
|
|||
"ThrottlingException": {},
|
||||
"RequestLimitExceeded": {},
|
||||
"RequestThrottled": {},
|
||||
"RequestThrottledException": {},
|
||||
"TooManyRequestsException": {}, // Lambda functions
|
||||
"PriorRequestNotComplete": {}, // Route53
|
||||
"TransactionInProgressException": {},
|
||||
}
|
||||
|
||||
// credsExpiredCodes is a collection of error codes which signify the credentials
|
||||
|
@ -74,10 +125,6 @@ var validParentCodes = map[string]struct{}{
|
|||
ErrCodeRead: {},
|
||||
}
|
||||
|
||||
type temporaryError interface {
|
||||
Temporary() bool
|
||||
}
|
||||
|
||||
func isNestedErrorRetryable(parentErr awserr.Error) bool {
|
||||
if parentErr == nil {
|
||||
return false
|
||||
|
@ -96,7 +143,7 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
|
|||
return isCodeRetryable(aerr.Code())
|
||||
}
|
||||
|
||||
if t, ok := err.(temporaryError); ok {
|
||||
if t, ok := err.(temporary); ok {
|
||||
return t.Temporary() || isErrConnectionReset(err)
|
||||
}
|
||||
|
||||
|
@ -106,32 +153,90 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
|
|||
// IsErrorRetryable returns whether the error is retryable, based on its Code.
|
||||
// Returns false if error is nil.
|
||||
func IsErrorRetryable(err error) bool {
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return shouldRetryError(err)
|
||||
}
|
||||
|
||||
type temporary interface {
|
||||
Temporary() bool
|
||||
}
|
||||
|
||||
func shouldRetryError(origErr error) bool {
|
||||
switch err := origErr.(type) {
|
||||
case awserr.Error:
|
||||
if err.Code() == CanceledErrorCode {
|
||||
return false
|
||||
}
|
||||
if isNestedErrorRetryable(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
origErr := err.OrigErr()
|
||||
var shouldRetry bool
|
||||
if origErr != nil {
|
||||
shouldRetry := shouldRetryError(origErr)
|
||||
if err.Code() == "RequestError" && !shouldRetry {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if isCodeRetryable(err.Code()) {
|
||||
return true
|
||||
}
|
||||
return shouldRetry
|
||||
|
||||
case *url.Error:
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
// Refused connections should be retried as the service may not yet
|
||||
// be running on the port. Go TCP dial considers refused
|
||||
// connections as not temporary.
|
||||
return true
|
||||
}
|
||||
// *url.Error only implements Temporary after golang 1.6 but since
|
||||
// url.Error only wraps the error:
|
||||
return shouldRetryError(err.Err)
|
||||
|
||||
case temporary:
|
||||
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
|
||||
return true
|
||||
}
|
||||
// If the error is temporary, we want to allow continuation of the
|
||||
// retry process
|
||||
return err.Temporary() || isErrConnectionReset(origErr)
|
||||
|
||||
case nil:
|
||||
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
|
||||
// because we don't know the cause, it is marked as retryable. See
|
||||
// TestRequest4xxUnretryable for an example.
|
||||
return true
|
||||
|
||||
default:
|
||||
switch err.Error() {
|
||||
case "net/http: request canceled",
|
||||
"net/http: request canceled while waiting for connection":
|
||||
// known 1.5 error case when an http request is cancelled
|
||||
return false
|
||||
}
|
||||
// here we don't know the error; so we allow a retry.
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorThrottle returns whether the error is to be throttled based on its code.
|
||||
// Returns false if error is nil.
|
||||
func IsErrorThrottle(err error) bool {
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeThrottle(aerr.Code())
|
||||
}
|
||||
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
|
||||
return isCodeThrottle(aerr.Code())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorExpiredCreds returns whether the error code is a credential expiry error.
|
||||
// Returns false if error is nil.
|
||||
// IsErrorExpiredCreds returns whether the error code is a credential expiry
|
||||
// error. Returns false if error is nil.
|
||||
func IsErrorExpiredCreds(err error) bool {
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
return isCodeExpiredCreds(aerr.Code())
|
||||
}
|
||||
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
|
||||
return isCodeExpiredCreds(aerr.Code())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -141,17 +246,58 @@ func IsErrorExpiredCreds(err error) bool {
|
|||
//
|
||||
// Alias for the utility function IsErrorRetryable
|
||||
func (r *Request) IsErrorRetryable() bool {
|
||||
if isErrCode(r.Error, r.RetryErrorCodes) {
|
||||
return true
|
||||
}
|
||||
|
||||
// HTTP response status code 501 should not be retried.
|
||||
// 501 represents Not Implemented which means the request method is not
|
||||
// supported by the server and cannot be handled.
|
||||
if r.HTTPResponse != nil {
|
||||
// HTTP response status code 500 represents internal server error and
|
||||
// should be retried without any throttle.
|
||||
if r.HTTPResponse.StatusCode == 500 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return IsErrorRetryable(r.Error)
|
||||
}
|
||||
|
||||
// IsErrorThrottle returns whether the error is to be throttled based on its code.
|
||||
// Returns false if the request has no Error set
|
||||
// IsErrorThrottle returns whether the error is to be throttled based on its
|
||||
// code. Returns false if the request has no Error set.
|
||||
//
|
||||
// Alias for the utility function IsErrorThrottle
|
||||
func (r *Request) IsErrorThrottle() bool {
|
||||
if isErrCode(r.Error, r.ThrottleErrorCodes) {
|
||||
return true
|
||||
}
|
||||
|
||||
if r.HTTPResponse != nil {
|
||||
switch r.HTTPResponse.StatusCode {
|
||||
case
|
||||
429, // error caused due to too many requests
|
||||
502, // Bad Gateway error should be throttled
|
||||
503, // caused when service is unavailable
|
||||
504: // error occurred due to gateway timeout
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return IsErrorThrottle(r.Error)
|
||||
}
|
||||
|
||||
func isErrCode(err error, codes []string) bool {
|
||||
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
|
||||
for _, code := range codes {
|
||||
if code == aerr.Code() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
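isErrCode is what makes the new RetryErrorCodes and ThrottleErrorCodes fields take effect. A hedged sketch of how a caller could opt extra service error codes into retry handling by appending them from a request handler; the package name and the error code strings are placeholders, not real AWS values.

package awsretryopts

import (
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/request"
)

// AddCustomRetryCodes registers extra error codes on every request the client
// builds, so IsErrorRetryable and IsErrorThrottle will also match them.
func AddCustomRetryCodes(c *client.Client) {
	c.Handlers.Build.PushBack(func(r *request.Request) {
		r.RetryErrorCodes = append(r.RetryErrorCodes, "MyTransientError")
		r.ThrottleErrorCodes = append(r.ThrottleErrorCodes, "MyThrottleError")
	})
}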
|
||||
|
||||
// IsErrorExpired returns whether the error code is a credential expiry error.
|
||||
// Returns false if the request has no Error set.
|
||||
//
|
||||
|
|
|
@ -17,6 +17,12 @@ const (
|
|||
ParamMinValueErrCode = "ParamMinValueError"
|
||||
// ParamMinLenErrCode is the error code for fields without enough elements.
|
||||
ParamMinLenErrCode = "ParamMinLenError"
|
||||
// ParamMaxLenErrCode is the error code for value being too long.
|
||||
ParamMaxLenErrCode = "ParamMaxLenError"
|
||||
|
||||
// ParamFormatErrCode is the error code for a field with invalid
|
||||
// format or characters.
|
||||
ParamFormatErrCode = "ParamFormatInvalidError"
|
||||
)
|
||||
|
||||
// Validator provides a way for types to perform validation logic on their
|
||||
|
@ -232,3 +238,49 @@ func NewErrParamMinLen(field string, min int) *ErrParamMinLen {
|
|||
func (e *ErrParamMinLen) MinLen() int {
|
||||
return e.min
|
||||
}
|
||||
|
||||
// An ErrParamMaxLen represents a maximum length parameter error.
|
||||
type ErrParamMaxLen struct {
|
||||
errInvalidParam
|
||||
max int
|
||||
}
|
||||
|
||||
// NewErrParamMaxLen creates a new maximum length parameter error.
|
||||
func NewErrParamMaxLen(field string, max int, value string) *ErrParamMaxLen {
|
||||
return &ErrParamMaxLen{
|
||||
errInvalidParam: errInvalidParam{
|
||||
code: ParamMaxLenErrCode,
|
||||
field: field,
|
||||
msg: fmt.Sprintf("maximum size of %v, %v", max, value),
|
||||
},
|
||||
max: max,
|
||||
}
|
||||
}
|
||||
|
||||
// MaxLen returns the field's required minimum length.
|
||||
func (e *ErrParamMaxLen) MaxLen() int {
|
||||
return e.max
|
||||
}
|
||||
|
||||
// An ErrParamFormat represents a invalid format parameter error.
|
||||
type ErrParamFormat struct {
|
||||
errInvalidParam
|
||||
format string
|
||||
}
|
||||
|
||||
// NewErrParamFormat creates a new invalid format parameter error.
|
||||
func NewErrParamFormat(field string, format, value string) *ErrParamFormat {
|
||||
return &ErrParamFormat{
|
||||
errInvalidParam: errInvalidParam{
|
||||
code: ParamFormatErrCode,
|
||||
field: field,
|
||||
msg: fmt.Sprintf("format %v, %v", format, value),
|
||||
},
|
||||
format: format,
|
||||
}
|
||||
}
|
||||
|
||||
// Format returns the field's required format.
|
||||
func (e *ErrParamFormat) Format() string {
|
||||
return e.format
|
||||
}
|
||||
|
|
26 vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport.go generated vendored Normal file
|
@ -0,0 +1,26 @@
// +build go1.7

package session

import (
	"net"
	"net/http"
	"time"
)

// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
	return &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
			DualStack: true,
		}).DialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
}
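These transports are only used when a custom CA bundle is supplied. A short caller-side sketch, assuming the existing session.Options.CustomCABundle field; the bundle path is a placeholder.

package main

import (
	"os"

	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	// Placeholder path; point this at a PEM-encoded CA bundle.
	f, err := os.Open("/etc/pki/custom-ca-bundle.pem")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	sess := session.Must(session.NewSessionWithOptions(session.Options{
		CustomCABundle: f, // the SDK then installs a transport like getCABundleTransport above
	}))
	_ = sess
}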
|
22 vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport_1_5.go generated vendored Normal file
|
@ -0,0 +1,22 @@
// +build !go1.6,go1.5

package session

import (
	"net"
	"net/http"
	"time"
)

// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
	return &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		Dial: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).Dial,
		TLSHandshakeTimeout: 10 * time.Second,
	}
}
|
23 vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport_1_6.go generated vendored Normal file
|
@ -0,0 +1,23 @@
// +build !go1.7,go1.6

package session

import (
	"net"
	"net/http"
	"time"
)

// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCABundleTransport() *http.Transport {
	return &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		Dial: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).Dial,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
}
|
|
@ -0,0 +1,259 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||
)
|
||||
|
||||
func resolveCredentials(cfg *aws.Config,
|
||||
envCfg envConfig, sharedCfg sharedConfig,
|
||||
handlers request.Handlers,
|
||||
sessOpts Options,
|
||||
) (*credentials.Credentials, error) {
|
||||
|
||||
switch {
|
||||
case len(sessOpts.Profile) != 0:
|
||||
// User explicitly provided an Profile in the session's configuration
|
||||
// so load that profile from shared config first.
|
||||
// Github(aws/aws-sdk-go#2727)
|
||||
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
|
||||
|
||||
case envCfg.Creds.HasKeys():
|
||||
// Environment credentials
|
||||
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
|
||||
|
||||
case len(envCfg.WebIdentityTokenFilePath) != 0:
|
||||
// Web identity token from environment, RoleARN required to also be
|
||||
// set.
|
||||
return assumeWebIdentity(cfg, handlers,
|
||||
envCfg.WebIdentityTokenFilePath,
|
||||
envCfg.RoleARN,
|
||||
envCfg.RoleSessionName,
|
||||
)
|
||||
|
||||
default:
|
||||
// Fallback to the "default" credential resolution chain.
|
||||
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
|
||||
}
|
||||
}
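The switch above gives an explicitly requested session Profile priority over environment credentials, which is the aws/aws-sdk-go#2727 fix referenced in the comment. A hedged sketch of the caller-visible effect; the profile name is a placeholder.

package main

import (
	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	// Even if AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are exported, an
	// explicitly requested profile is resolved from shared config first.
	sess := session.Must(session.NewSessionWithOptions(session.Options{
		Profile: "ci-deploy", // placeholder profile name
	}))
	_ = sess
}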
|
||||
|
||||
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
|
||||
// 'AWS_IAM_ROLE_ARN' was not set.
|
||||
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
|
||||
|
||||
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_IAM_ROLE_ARN' was set but
|
||||
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
|
||||
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
|
||||
|
||||
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
|
||||
filepath string,
|
||||
roleARN, sessionName string,
|
||||
) (*credentials.Credentials, error) {
|
||||
|
||||
if len(filepath) == 0 {
|
||||
return nil, WebIdentityEmptyTokenFilePathErr
|
||||
}
|
||||
|
||||
if len(roleARN) == 0 {
|
||||
return nil, WebIdentityEmptyRoleARNErr
|
||||
}
|
||||
|
||||
creds := stscreds.NewWebIdentityCredentials(
|
||||
&Session{
|
||||
Config: cfg,
|
||||
Handlers: handlers.Copy(),
|
||||
},
|
||||
roleARN,
|
||||
sessionName,
|
||||
filepath,
|
||||
)
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
func resolveCredsFromProfile(cfg *aws.Config,
|
||||
envCfg envConfig, sharedCfg sharedConfig,
|
||||
handlers request.Handlers,
|
||||
sessOpts Options,
|
||||
) (creds *credentials.Credentials, err error) {
|
||||
|
||||
switch {
|
||||
case sharedCfg.SourceProfile != nil:
|
||||
// Assume IAM role with credentials source from a different profile.
|
||||
creds, err = resolveCredsFromProfile(cfg, envCfg,
|
||||
*sharedCfg.SourceProfile, handlers, sessOpts,
|
||||
)
|
||||
|
||||
case sharedCfg.Creds.HasKeys():
|
||||
// Static Credentials from Shared Config/Credentials file.
|
||||
creds = credentials.NewStaticCredentialsFromCreds(
|
||||
sharedCfg.Creds,
|
||||
)
|
||||
|
||||
case len(sharedCfg.CredentialProcess) != 0:
|
||||
// Get credentials from CredentialProcess
|
||||
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
|
||||
|
||||
case len(sharedCfg.CredentialSource) != 0:
|
||||
creds, err = resolveCredsFromSource(cfg, envCfg,
|
||||
sharedCfg, handlers, sessOpts,
|
||||
)
|
||||
|
||||
case len(sharedCfg.WebIdentityTokenFile) != 0:
|
||||
// Credentials from Assume Web Identity token require an IAM Role, and
|
||||
// that roll will be assumed. May be wrapped with another assume role
|
||||
// via SourceProfile.
|
||||
return assumeWebIdentity(cfg, handlers,
|
||||
sharedCfg.WebIdentityTokenFile,
|
||||
sharedCfg.RoleARN,
|
||||
sharedCfg.RoleSessionName,
|
||||
)
|
||||
|
||||
default:
|
||||
// Fallback to default credentials provider, include mock errors for
|
||||
// the credential chain so user can identify why credentials failed to
|
||||
// be retrieved.
|
||||
creds = credentials.NewCredentials(&credentials.ChainProvider{
|
||||
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
|
||||
Providers: []credentials.Provider{
|
||||
&credProviderError{
|
||||
Err: awserr.New("EnvAccessKeyNotFound",
|
||||
"failed to find credentials in the environment.", nil),
|
||||
},
|
||||
&credProviderError{
|
||||
Err: awserr.New("SharedCredsLoad",
|
||||
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
|
||||
},
|
||||
defaults.RemoteCredProvider(*cfg, handlers),
|
||||
},
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(sharedCfg.RoleARN) > 0 {
|
||||
cfgCp := *cfg
|
||||
cfgCp.Credentials = creds
|
||||
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// valid credential source values
|
||||
const (
|
||||
credSourceEc2Metadata = "Ec2InstanceMetadata"
|
||||
credSourceEnvironment = "Environment"
|
||||
credSourceECSContainer = "EcsContainer"
|
||||
)
|
||||
|
||||
func resolveCredsFromSource(cfg *aws.Config,
|
||||
envCfg envConfig, sharedCfg sharedConfig,
|
||||
handlers request.Handlers,
|
||||
sessOpts Options,
|
||||
) (creds *credentials.Credentials, err error) {
|
||||
|
||||
switch sharedCfg.CredentialSource {
|
||||
case credSourceEc2Metadata:
|
||||
p := defaults.RemoteCredProvider(*cfg, handlers)
|
||||
creds = credentials.NewCredentials(p)
|
||||
|
||||
case credSourceEnvironment:
|
||||
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
|
||||
|
||||
case credSourceECSContainer:
|
||||
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
|
||||
return nil, ErrSharedConfigECSContainerEnvVarEmpty
|
||||
}
|
||||
|
||||
p := defaults.RemoteCredProvider(*cfg, handlers)
|
||||
creds = credentials.NewCredentials(p)
|
||||
|
||||
default:
|
||||
return nil, ErrSharedConfigInvalidCredSource
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
func credsFromAssumeRole(cfg aws.Config,
	handlers request.Handlers,
	sharedCfg sharedConfig,
	sessOpts Options,
) (*credentials.Credentials, error) {

	if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
		// AssumeRole Token provider is required if doing Assume Role
		// with MFA.
		return nil, AssumeRoleTokenProviderNotSetError{}
	}

	return stscreds.NewCredentials(
		&Session{
			Config:   &cfg,
			Handlers: handlers.Copy(),
		},
		sharedCfg.RoleARN,
		func(opt *stscreds.AssumeRoleProvider) {
			opt.RoleSessionName = sharedCfg.RoleSessionName
			opt.Duration = sessOpts.AssumeRoleDuration

			// Assume role with external ID
			if len(sharedCfg.ExternalID) > 0 {
				opt.ExternalID = aws.String(sharedCfg.ExternalID)
			}

			// Assume role with MFA
			if len(sharedCfg.MFASerial) > 0 {
				opt.SerialNumber = aws.String(sharedCfg.MFASerial)
				opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
			}
		},
	), nil
}

// AssumeRoleTokenProviderNotSetError is an error returned when creating a
// session when the MFAToken option is not set when shared config is configured
// load assume a role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}

// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
	return "AssumeRoleTokenProviderNotSetError"
}

// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
	return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}

// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
	return nil
}

// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
	return awserr.SprintError(e.Code(), e.Message(), "", nil)
}

type credProviderError struct {
	Err error
}

func (c credProviderError) Retrieve() (credentials.Value, error) {
	return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
	return true
}

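credsFromAssumeRole reads its MFA prompt and expiry from the session Options. A hedged sketch of wiring those options up; the 30 minute duration is an arbitrary example, and stscreds.StdinTokenProvider is the same provider the package documentation below uses.

    sess := session.Must(session.NewSessionWithOptions(session.Options{
    	SharedConfigState:       session.SharedConfigEnable,
    	AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
    	AssumeRoleDuration:      30 * time.Minute, // overrides the 15 minute default
    }))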
@@ -1,97 +1,93 @@
/*
Package session provides configuration for the SDK's service clients.

Sessions can be shared across all service clients that share the same base
configuration. The Session is built from the SDK's default configuration and
request handlers.

Sessions should be cached when possible, because creating a new Session will
load all configuration values from the environment, and config files each time
the Session is created. Sharing the Session value across all of your service
clients will ensure the configuration is loaded the fewest number of times possible.

Concurrency
Package session provides configuration for the SDK's service clients. Sessions
can be shared across service clients that share the same base configuration.

Sessions are safe to use concurrently as long as the Session is not being
modified. The SDK will not modify the Session once the Session has been created.
Creating service clients concurrently from a shared Session is safe.
modified. Sessions should be cached when possible, because creating a new
Session will load all configuration values from the environment, and config
files each time the Session is created. Sharing the Session value across all of
your service clients will ensure the configuration is loaded the fewest number
of times possible.

Sessions from Shared Config

Sessions can be created using the method above that will only load the
additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
Alternatively you can explicitly create a Session with shared config enabled.
To do this you can use NewSessionWithOptions to configure how the Session will
be created. Using the NewSessionWithOptions with SharedConfigState set to
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
environment variable was set.

Creating Sessions

When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
Sessions options from Shared Config

By default NewSession will only load credentials from the shared credentials
file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is
set to a truthy value the Session will be created from the configuration
values from the shared config (~/.aws/config) and shared credentials
(~/.aws/credentials) files. See the section Sessions from Shared Config for
more information.
(~/.aws/credentials) files. Using the NewSessionWithOptions with
SharedConfigState set to SharedConfigEnable will create the session as if the
AWS_SDK_LOAD_CONFIG environment variable was set.

Create a Session with the default config and request handlers. With credentials
region, and profile loaded from the environment and shared config automatically.
Requires the AWS_PROFILE to be set, or "default" is used.
Credential and config loading order

The Session will attempt to load configuration and credentials from the
environment, configuration files, and other credential sources. The order
configuration is loaded in is:

    * Environment Variables
    * Shared Credentials file
    * Shared Configuration file (if SharedConfig is enabled)
    * EC2 Instance Metadata (credentials only)

The Environment variables for credentials will have precedence over shared
config even if SharedConfig is enabled. To override this behavior, and use
shared config credentials instead specify the session.Options.Profile, (e.g.
when using credential_source=Environment to assume a role).

    sess, err := session.NewSessionWithOptions(session.Options{
        Profile: "myProfile",
    })

Creating Sessions

Creating a Session without additional options will load credentials region, and
profile loaded from the environment and shared config automatically. See,
"Environment Variables" section for information on environment variables used
by Session.

    // Create Session
    sess := session.Must(session.NewSession())
    sess, err := session.NewSession()

When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded, config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.

    // Create a Session with a custom region
    sess := session.Must(session.NewSession(&aws.Config{
        Region: aws.String("us-east-1"),
    }))
    sess, err := session.NewSession(&aws.Config{
        Region: aws.String("us-west-2"),
    })

    // Create a S3 client instance from a session
    sess := session.Must(session.NewSession())

    svc := s3.New(sess)

Create Session With Option Overrides

In addition to NewSession, Sessions can be created using NewSessionWithOptions.
This func allows you to control and override how the Session will be created
through code instead of being driven by environment variables only.

Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
Use NewSessionWithOptions to provide additional configuration driving how the
Session's configuration will be loaded. Such as, specifying shared config
profile, or override the shared config state, (AWS_SDK_LOAD_CONFIG).

    // Equivalent to session.NewSession()
    sess := session.Must(session.NewSessionWithOptions(session.Options{
    sess, err := session.NewSessionWithOptions(session.Options{
        // Options
    }))
    })

    // Specify profile to load for the session's config
    sess := session.Must(session.NewSessionWithOptions(session.Options{
        Profile: "profile_name",
    }))
    sess, err := session.NewSessionWithOptions(session.Options{
        // Specify profile to load for the session's config
        Profile: "profile_name",

    // Specify profile for config and region for requests
    sess := session.Must(session.NewSessionWithOptions(session.Options{
        Config: aws.Config{Region: aws.String("us-east-1")},
        Profile: "profile_name",
    }))
        // Provide SDK Config options, such as Region.
        Config: aws.Config{
            Region: aws.String("us-west-2"),
        },

    // Force enable Shared Config support
    sess := session.Must(session.NewSessionWithOptions(session.Options{
        // Force enable Shared Config support
        SharedConfigState: session.SharedConfigEnable,
    }))
    })

Adding Handlers

You can add handlers to a session for processing HTTP requests. All service
clients that use the session inherit the handlers. For example, the following
handler logs every request and its payload made by a service client:
You can add handlers to a session to decorate API operation, (e.g. adding HTTP
headers). All clients that use the Session receive a copy of the Session's
handlers. For example, the following request handler added to the Session logs
every requests made.

    // Create a session, and add additional handlers for all service
    // clients created with the Session to inherit. Adds logging handler.

@@ -99,22 +95,15 @@ handler logs every request and its payload made by a service client:

    sess.Handlers.Send.PushFront(func(r *request.Request) {
        // Log every request made and its payload
        logger.Println("Request: %s/%s, Payload: %s",
        logger.Printf("Request: %s/%s, Params: %s",
            r.ClientInfo.ServiceName, r.Operation, r.Params)
    })

Deprecated "New" function

The New session function has been deprecated because it does not provide good
way to return errors that occur when loading the configuration files and values.
Because of this, NewSession was created so errors can be retrieved when
creating a session fails.

Shared Config Fields

By default the SDK will only load the shared credentials file's (~/.aws/credentials)
credentials values, and all other config is provided by the environment variables,
SDK defaults, and user provided aws.Config values.
By default the SDK will only load the shared credentials file's
(~/.aws/credentials) credentials values, and all other config is provided by
the environment variables, SDK defaults, and user provided aws.Config values.

If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable
option is used to create the Session the full shared config values will be

@@ -125,24 +114,31 @@ files have the same format.

If both config files are present the configuration from both files will be
read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config).
credentials file (~/.aws/credentials) over those in the shared config file
(~/.aws/config).

Credentials are the values the SDK should use for authenticating requests with
AWS Services. They are from a configuration file will need to include both
aws_access_key_id and aws_secret_access_key must be provided together in the
same file to be considered valid. The values will be ignored if not a complete
group. aws_session_token is an optional field that can be provided if both of
the other two fields are also provided.
Credentials are the values the SDK uses to authenticating requests with AWS
Services. When specified in a file, both aws_access_key_id and
aws_secret_access_key must be provided together in the same file to be
considered valid. They will be ignored if both are not present.
aws_session_token is an optional field that can be provided in addition to the
other two fields.

    aws_access_key_id = AKID
    aws_secret_access_key = SECRET
    aws_session_token = TOKEN

Assume Role values allow you to configure the SDK to assume an IAM role using
a set of credentials provided in a config file via the source_profile field.
Both "role_arn" and "source_profile" are required. The SDK supports assuming
a role with MFA token if the session option AssumeRoleTokenProvider
is set.
    ; region only supported if SharedConfigEnabled.
    region = us-east-1

Assume Role configuration

The role_arn field allows you to configure the SDK to assume an IAM role using
a set of credentials from another source. Such as when paired with static
credentials, "profile_source", "credential_process", or "credential_source"
fields. If "role_arn" is provided, a source of credentials must also be
specified, such as "source_profile", "credential_source", or
"credential_process".

    role_arn = arn:aws:iam::<account_number>:role/<role_name>
    source_profile = profile_with_creds

@@ -150,40 +146,16 @@ is set.
    mfa_serial = <serial or mfa arn>
    role_session_name = session_name

Region is the region the SDK should use for looking up AWS service endpoints
and signing requests.

    region = us-east-1

Assume Role with MFA token

To create a session with support for assuming an IAM role with MFA set the
session option AssumeRoleTokenProvider to a function that will prompt for the
MFA token code when the SDK assumes the role and refreshes the role's credentials.
This allows you to configure the SDK via the shared config to assumea role
with MFA tokens.

In order for the SDK to assume a role with MFA the SharedConfigState
session option must be set to SharedConfigEnable, or AWS_SDK_LOAD_CONFIG
environment variable set.

The shared configuration instructs the SDK to assume an IAM role with MFA
when the mfa_serial configuration field is set in the shared config
(~/.aws/config) or shared credentials (~/.aws/credentials) file.

If mfa_serial is set in the configuration, the SDK will assume the role, and
the AssumeRoleTokenProvider session option is not set an an error will
be returned when creating the session.
The SDK supports assuming a role with MFA token. If "mfa_serial" is set, you
must also set the Session Option.AssumeRoleTokenProvider. The Session will fail
to load if the AssumeRoleTokenProvider is not specified.

    sess := session.Must(session.NewSessionWithOptions(session.Options{
        AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
    }))

    // Create service client value configured for credentials
    // from assumed role.
    svc := s3.New(sess)

To setup assume role outside of a session see the stscrds.AssumeRoleProvider
To setup Assume Role outside of a session see the stscreds.AssumeRoleProvider
documentation.

Environment Variables

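To illustrate the session-caching guidance in the package documentation above, a minimal sketch; the client names and the override region are arbitrary, and the usual aws/session/s3 imports are assumed.

    // Build the Session once; it loads environment and shared config on creation.
    sess := session.Must(session.NewSession())

    // Reuse it for every client. Per-client overrides go on the client, not the Session.
    uploads := s3.New(sess)
    reports := s3.New(sess, aws.NewConfig().WithRegion("us-west-2"))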
@@ -1,11 +1,14 @@
package session

import (
	"fmt"
	"os"
	"strconv"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/defaults"
	"github.com/aws/aws-sdk-go/aws/endpoints"
)

// EnvProviderName provides a name of the provider when config is loaded from environment.

@@ -79,7 +82,7 @@ type envConfig struct {
	// AWS_CONFIG_FILE=$HOME/my_shared_config
	SharedConfigFile string

	// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file
	// Sets the path to a custom Credentials Authority (CA) Bundle PEM file
	// that the SDK will use instead of the system's root CA bundle.
	// Only use this if you want to configure the SDK to use a custom set
	// of CAs.

@@ -98,15 +101,47 @@ type envConfig struct {
	CustomCABundle string

	csmEnabled  string
	CSMEnabled  bool
	CSMEnabled  *bool
	CSMPort     string
	CSMHost     string
	CSMClientID string

	// Enables endpoint discovery via environment variables.
	//
	// AWS_ENABLE_ENDPOINT_DISCOVERY=true
	EnableEndpointDiscovery *bool
	enableEndpointDiscovery string

	// Specifies the WebIdentity token the SDK should use to assume a role
	// with.
	//
	// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
	WebIdentityTokenFilePath string

	// Specifies the IAM role arn to use when assuming an role.
	//
	// AWS_ROLE_ARN=role_arn
	RoleARN string

	// Specifies the IAM role session name to use when assuming a role.
	//
	// AWS_ROLE_SESSION_NAME=session_name
	RoleSessionName string

	// Specifies the Regional Endpoint flag for the sdk to resolve the endpoint for a service
	//
	// AWS_STS_REGIONAL_ENDPOINTS=sts_regional_endpoint
	// This can take value as `regional` or `legacy`
	STSRegionalEndpoint endpoints.STSRegionalEndpoint
}
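The role and web identity fields above are filled straight from environment variables. A hedged sketch of setting them before the Session is created; the ARN, session name, and token path are placeholders.

    os.Setenv("AWS_ROLE_ARN", "arn:aws:iam::111122223333:role/app-role") // placeholder ARN
    os.Setenv("AWS_ROLE_SESSION_NAME", "my-session")
    os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/var/run/secrets/token") // placeholder path
    os.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional")

    sess, err := session.NewSession()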

var (
	csmEnabledEnvKey = []string{
		"AWS_CSM_ENABLED",
	}
	csmHostEnvKey = []string{
		"AWS_CSM_HOST",
	}
	csmPortEnvKey = []string{
		"AWS_CSM_PORT",
	}

@@ -125,6 +160,10 @@ var (
		"AWS_SESSION_TOKEN",
	}

	enableEndpointDiscoveryEnvKey = []string{
		"AWS_ENABLE_ENDPOINT_DISCOVERY",
	}

	regionEnvKeys = []string{
		"AWS_REGION",
		"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set

@@ -139,6 +178,18 @@ var (
	sharedConfigFileEnvKey = []string{
		"AWS_CONFIG_FILE",
	}
	webIdentityTokenFilePathEnvKey = []string{
		"AWS_WEB_IDENTITY_TOKEN_FILE",
	}
	roleARNEnvKey = []string{
		"AWS_ROLE_ARN",
	}
	roleSessionNameEnvKey = []string{
		"AWS_ROLE_SESSION_NAME",
	}
	stsRegionalEndpointKey = []string{
		"AWS_STS_REGIONAL_ENDPOINTS",
	}
)

// loadEnvConfig retrieves the SDK's environment configuration.

@@ -147,7 +198,7 @@ var (
// If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value
// the shared SDK config will be loaded in addition to the SDK's specific
// configuration values.
func loadEnvConfig() envConfig {
func loadEnvConfig() (envConfig, error) {
	enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG"))
	return envConfigLoad(enableSharedConfig)
}

@@ -158,30 +209,42 @@ func loadEnvConfig() envConfig {
// Loads the shared configuration in addition to the SDK's specific configuration.
// This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG`
// environment variable is set.
func loadSharedEnvConfig() envConfig {
func loadSharedEnvConfig() (envConfig, error) {
	return envConfigLoad(true)
}

func envConfigLoad(enableSharedConfig bool) envConfig {
func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
	cfg := envConfig{}

	cfg.EnableSharedConfig = enableSharedConfig

	setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey)
	setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey)
	setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey)
	// Static environment credentials
	var creds credentials.Value
	setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
	setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
	setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
	if creds.HasKeys() {
		// Require logical grouping of credentials
		creds.ProviderName = EnvProviderName
		cfg.Creds = creds
	}

	// Role Metadata
	setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
	setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)

	// Web identity environment variables
	setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)

	// CSM environment variables
	setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
	setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
	setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
	setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
	cfg.CSMEnabled = len(cfg.csmEnabled) > 0

	// Require logical grouping of credentials
	if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
		cfg.Creds = credentials.Value{}
	} else {
		cfg.Creds.ProviderName = EnvProviderName
	if len(cfg.csmEnabled) != 0 {
		v, _ := strconv.ParseBool(cfg.csmEnabled)
		cfg.CSMEnabled = &v
	}

	regionKeys := regionEnvKeys

@@ -194,6 +257,12 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
	setFromEnvVal(&cfg.Region, regionKeys)
	setFromEnvVal(&cfg.Profile, profileKeys)

	// endpoint discovery is in reference to it being enabled.
	setFromEnvVal(&cfg.enableEndpointDiscovery, enableEndpointDiscoveryEnvKey)
	if len(cfg.enableEndpointDiscovery) > 0 {
		cfg.EnableEndpointDiscovery = aws.Bool(cfg.enableEndpointDiscovery != "false")
	}

	setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
	setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)

@@ -206,12 +275,23 @@ func envConfigLoad(enableSharedConfig bool) envConfig {

	cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")

	return cfg
	// STS Regional Endpoint variable
	for _, k := range stsRegionalEndpointKey {
		if v := os.Getenv(k); len(v) != 0 {
			STSRegionalEndpoint, err := endpoints.GetSTSRegionalEndpoint(v)
			if err != nil {
				return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
			}
			cfg.STSRegionalEndpoint = STSRegionalEndpoint
		}
	}

	return cfg, nil
}

func setFromEnvVal(dst *string, keys []string) {
	for _, k := range keys {
		if v := os.Getenv(k); len(v) > 0 {
		if v := os.Getenv(k); len(v) != 0 {
			*dst = v
			break
		}

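envConfigLoad only records the raw AWS_ENABLE_ENDPOINT_DISCOVERY string; later in this diff mergeConfigSrcs copies the parsed value onto the aws.Config. A small sketch of the end-to-end effect, assuming no user-supplied config overrides it:

    os.Setenv("AWS_ENABLE_ENDPOINT_DISCOVERY", "true")

    sess := session.Must(session.NewSession())
    fmt.Println("endpoint discovery:", aws.BoolValue(sess.Config.EnableEndpointDiscovery))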
@ -8,19 +8,36 @@ import (
|
|||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/csm"
|
||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
)
|
||||
|
||||
const (
|
||||
// ErrCodeSharedConfig represents an error that occurs in the shared
|
||||
// configuration logic
|
||||
ErrCodeSharedConfig = "SharedConfigErr"
|
||||
)
|
||||
|
||||
// ErrSharedConfigSourceCollision will be returned if a section contains both
|
||||
// source_profile and credential_source
|
||||
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only source profile or credential source can be specified, not both", nil)
|
||||
|
||||
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
|
||||
// variables are empty and Environment was set as the credential source
|
||||
var ErrSharedConfigECSContainerEnvVarEmpty = awserr.New(ErrCodeSharedConfig, "EcsContainer was specified as the credential_source, but 'AWS_CONTAINER_CREDENTIALS_RELATIVE_URI' was not set", nil)
|
||||
|
||||
// ErrSharedConfigInvalidCredSource will be returned if an invalid credential source was provided
|
||||
var ErrSharedConfigInvalidCredSource = awserr.New(ErrCodeSharedConfig, "credential source values must be EcsContainer, Ec2InstanceMetadata, or Environment", nil)
|
||||
|
||||
// A Session provides a central location to create service clients from and
|
||||
// store configurations and request handlers for those services.
|
||||
//
|
||||
|
@ -56,7 +73,7 @@ type Session struct {
|
|||
// func is called instead of waiting to receive an error until a request is made.
|
||||
func New(cfgs ...*aws.Config) *Session {
|
||||
// load initial config from environment
|
||||
envCfg := loadEnvConfig()
|
||||
envCfg, envErr := loadEnvConfig()
|
||||
|
||||
if envCfg.EnableSharedConfig {
|
||||
var cfg aws.Config
|
||||
|
@ -76,19 +93,28 @@ func New(cfgs ...*aws.Config) *Session {
|
|||
// Session creation failed, need to report the error and prevent
|
||||
// any requests from succeeding.
|
||||
s = &Session{Config: defaults.Config()}
|
||||
s.Config.MergeIn(cfgs...)
|
||||
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
|
||||
s.Handlers.Validate.PushBack(func(r *request.Request) {
|
||||
r.Error = err
|
||||
})
|
||||
s.logDeprecatedNewSessionError(msg, err, cfgs)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
s := deprecatedNewSession(cfgs...)
|
||||
if envCfg.CSMEnabled {
|
||||
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
|
||||
if envErr != nil {
|
||||
msg := "failed to load env config"
|
||||
s.logDeprecatedNewSessionError(msg, envErr, cfgs)
|
||||
}
|
||||
|
||||
if csmCfg, err := loadCSMConfig(envCfg, []string{}); err != nil {
|
||||
if l := s.Config.Logger; l != nil {
|
||||
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
|
||||
}
|
||||
} else if csmCfg.Enabled {
|
||||
err := enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
|
||||
if err != nil {
|
||||
msg := "failed to enable CSM"
|
||||
s.logDeprecatedNewSessionError(msg, err, cfgs)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
|
@ -107,7 +133,7 @@ func New(cfgs ...*aws.Config) *Session {
|
|||
// to be built with retrieving credentials with AssumeRole set in the config.
|
||||
//
|
||||
// See the NewSessionWithOptions func for information on how to override or
|
||||
// control through code how the Session will be created. Such as specifying the
|
||||
// control through code how the Session will be created, such as specifying the
|
||||
// config profile, and controlling if shared config is enabled or not.
|
||||
func NewSession(cfgs ...*aws.Config) (*Session, error) {
|
||||
opts := Options{}
|
||||
|
@ -191,6 +217,12 @@ type Options struct {
|
|||
// the config enables assume role wit MFA via the mfa_serial field.
|
||||
AssumeRoleTokenProvider func() (string, error)
|
||||
|
||||
// When the SDK's shared config is configured to assume a role this option
|
||||
// may be provided to set the expiry duration of the STS credentials.
|
||||
// Defaults to 15 minutes if not set as documented in the
|
||||
// stscreds.AssumeRoleProvider.
|
||||
AssumeRoleDuration time.Duration
|
||||
|
||||
// Reader for a custom Credentials Authority (CA) bundle in PEM format that
|
||||
// the SDK will use instead of the default system's root CA bundle. Use this
|
||||
// only if you want to replace the CA bundle the SDK uses for TLS requests.
|
||||
|
@ -205,6 +237,12 @@ type Options struct {
|
|||
// to also enable this feature. CustomCABundle session option field has priority
|
||||
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
|
||||
CustomCABundle io.Reader
|
||||
|
||||
// The handlers that the session and all API clients will be created with.
|
||||
// This must be a complete set of handlers. Use the defaults.Handlers()
|
||||
// function to initialize this value before changing the handlers to be
|
||||
// used by the SDK.
|
||||
Handlers request.Handlers
|
||||
}
|
||||
|
||||
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
|
||||
|
@ -238,13 +276,20 @@ type Options struct {
|
|||
// }))
|
||||
func NewSessionWithOptions(opts Options) (*Session, error) {
|
||||
var envCfg envConfig
|
||||
var err error
|
||||
if opts.SharedConfigState == SharedConfigEnable {
|
||||
envCfg = loadSharedEnvConfig()
|
||||
envCfg, err = loadSharedEnvConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load shared config, %v", err)
|
||||
}
|
||||
} else {
|
||||
envCfg = loadEnvConfig()
|
||||
envCfg, err = loadEnvConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load environment config, %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(opts.Profile) > 0 {
|
||||
if len(opts.Profile) != 0 {
|
||||
envCfg.Profile = opts.Profile
|
||||
}
|
||||
|
||||
|
@ -310,27 +355,33 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
|
|||
return s
|
||||
}
|
||||
|
||||
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) {
|
||||
logger.Log("Enabling CSM")
|
||||
if len(port) == 0 {
|
||||
port = csm.DefaultPort
|
||||
func enableCSM(handlers *request.Handlers, cfg csmConfig, logger aws.Logger) error {
|
||||
if logger != nil {
|
||||
logger.Log("Enabling CSM")
|
||||
}
|
||||
|
||||
r, err := csm.Start(clientID, "127.0.0.1:"+port)
|
||||
r, err := csm.Start(cfg.ClientID, csm.AddressWithDefaults(cfg.Host, cfg.Port))
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
r.InjectHandlers(handlers)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
|
||||
cfg := defaults.Config()
|
||||
handlers := defaults.Handlers()
|
||||
|
||||
handlers := opts.Handlers
|
||||
if handlers.IsEmpty() {
|
||||
handlers = defaults.Handlers()
|
||||
}
|
||||
|
||||
// Get a merged version of the user provided config to determine if
|
||||
// credentials were.
|
||||
userCfg := &aws.Config{}
|
||||
userCfg.MergeIn(cfgs...)
|
||||
cfg.MergeIn(userCfg)
|
||||
|
||||
// Ordered config files will be loaded in with later files overwriting
|
||||
// previous config file values.
|
||||
|
@ -347,9 +398,17 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
|
|||
}
|
||||
|
||||
// Load additional config from file(s)
|
||||
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles)
|
||||
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(envCfg.Profile) == 0 && !envCfg.EnableSharedConfig && (envCfg.Creds.HasKeys() || userCfg.Credentials != nil) {
|
||||
// Special case where the user has not explicitly specified an AWS_PROFILE,
|
||||
// or session.Options.profile, shared config is not enabled, and the
|
||||
// environment has credentials, allow the shared config file to fail to
|
||||
// load since the user has already provided credentials, and nothing else
|
||||
// is required to be read file. Github(aws/aws-sdk-go#2455)
|
||||
} else if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
|
||||
|
@ -362,8 +421,16 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
|
|||
}
|
||||
|
||||
initHandlers(s)
|
||||
if envCfg.CSMEnabled {
|
||||
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
|
||||
|
||||
if csmCfg, err := loadCSMConfig(envCfg, cfgFiles); err != nil {
|
||||
if l := s.Config.Logger; l != nil {
|
||||
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
|
||||
}
|
||||
} else if csmCfg.Enabled {
|
||||
err = enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Setup HTTP client with custom cert bundle if enabled
|
||||
|
@ -376,6 +443,46 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
|
|||
return s, nil
|
||||
}
|
||||
|
||||
type csmConfig struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Port string
|
||||
ClientID string
|
||||
}
|
||||
|
||||
var csmProfileName = "aws_csm"
|
||||
|
||||
func loadCSMConfig(envCfg envConfig, cfgFiles []string) (csmConfig, error) {
|
||||
if envCfg.CSMEnabled != nil {
|
||||
if *envCfg.CSMEnabled {
|
||||
return csmConfig{
|
||||
Enabled: true,
|
||||
ClientID: envCfg.CSMClientID,
|
||||
Host: envCfg.CSMHost,
|
||||
Port: envCfg.CSMPort,
|
||||
}, nil
|
||||
}
|
||||
return csmConfig{}, nil
|
||||
}
|
||||
|
||||
sharedCfg, err := loadSharedConfig(csmProfileName, cfgFiles, false)
|
||||
if err != nil {
|
||||
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
|
||||
return csmConfig{}, err
|
||||
}
|
||||
}
|
||||
if sharedCfg.CSMEnabled != nil && *sharedCfg.CSMEnabled == true {
|
||||
return csmConfig{
|
||||
Enabled: true,
|
||||
ClientID: sharedCfg.CSMClientID,
|
||||
Host: sharedCfg.CSMHost,
|
||||
Port: sharedCfg.CSMPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return csmConfig{}, nil
|
||||
}
|
||||
|
||||
func loadCustomCABundle(s *Session, bundle io.Reader) error {
|
||||
var t *http.Transport
|
||||
switch v := s.Config.HTTPClient.Transport.(type) {
|
||||
|
@ -388,7 +495,10 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
|
|||
}
|
||||
}
|
||||
if t == nil {
|
||||
t = &http.Transport{}
|
||||
// Nil transport implies `http.DefaultTransport` should be used. Since
|
||||
// the SDK cannot modify, nor copy the `DefaultTransport` specifying
|
||||
// the values the next closest behavior.
|
||||
t = getCABundleTransport()
|
||||
}
|
||||
|
||||
p, err := loadCertPool(bundle)
|
||||
|
@ -421,9 +531,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error {
|
||||
// Merge in user provided configuration
|
||||
cfg.MergeIn(userCfg)
|
||||
func mergeConfigSrcs(cfg, userCfg *aws.Config,
|
||||
envCfg envConfig, sharedCfg sharedConfig,
|
||||
handlers request.Handlers,
|
||||
sessOpts Options,
|
||||
) error {
|
||||
|
||||
// Region if not already set by user
|
||||
if len(aws.StringValue(cfg.Region)) == 0 {
|
||||
|
@ -434,103 +546,46 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
|
|||
}
|
||||
}
|
||||
|
||||
// Configure credentials if not already set
|
||||
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
|
||||
if len(envCfg.Creds.AccessKeyID) > 0 {
|
||||
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
|
||||
envCfg.Creds,
|
||||
)
|
||||
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
|
||||
cfgCp := *cfg
|
||||
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
|
||||
sharedCfg.AssumeRoleSource.Creds,
|
||||
)
|
||||
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
|
||||
// AssumeRole Token provider is required if doing Assume Role
|
||||
// with MFA.
|
||||
return AssumeRoleTokenProviderNotSetError{}
|
||||
}
|
||||
cfg.Credentials = stscreds.NewCredentials(
|
||||
&Session{
|
||||
Config: &cfgCp,
|
||||
Handlers: handlers.Copy(),
|
||||
},
|
||||
sharedCfg.AssumeRole.RoleARN,
|
||||
func(opt *stscreds.AssumeRoleProvider) {
|
||||
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
|
||||
|
||||
// Assume role with external ID
|
||||
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
|
||||
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
|
||||
}
|
||||
|
||||
// Assume role with MFA
|
||||
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
|
||||
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
|
||||
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
|
||||
}
|
||||
},
|
||||
)
|
||||
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
|
||||
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
|
||||
sharedCfg.Creds,
|
||||
)
|
||||
} else {
|
||||
// Fallback to default credentials provider, include mock errors
|
||||
// for the credential chain so user can identify why credentials
|
||||
// failed to be retrieved.
|
||||
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
|
||||
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
|
||||
Providers: []credentials.Provider{
|
||||
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
|
||||
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
|
||||
defaults.RemoteCredProvider(*cfg, handlers),
|
||||
},
|
||||
})
|
||||
if cfg.EnableEndpointDiscovery == nil {
|
||||
if envCfg.EnableEndpointDiscovery != nil {
|
||||
cfg.WithEndpointDiscovery(*envCfg.EnableEndpointDiscovery)
|
||||
} else if envCfg.EnableSharedConfig && sharedCfg.EnableEndpointDiscovery != nil {
|
||||
cfg.WithEndpointDiscovery(*sharedCfg.EnableEndpointDiscovery)
|
||||
}
|
||||
}
|
||||
|
||||
// Regional Endpoint flag for STS endpoint resolving
|
||||
mergeSTSRegionalEndpointConfig(cfg, envCfg, sharedCfg)
|
||||
|
||||
// Configure credentials if not already set by the user when creating the
|
||||
// Session.
|
||||
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
|
||||
creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.Credentials = creds
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the
|
||||
// MFAToken option is not set when shared config is configured load assume a
|
||||
// role with an MFA token.
|
||||
type AssumeRoleTokenProviderNotSetError struct{}
|
||||
// mergeSTSRegionalEndpointConfig function merges the STSRegionalEndpoint into cfg from
|
||||
// envConfig and SharedConfig with envConfig being given precedence over SharedConfig
|
||||
func mergeSTSRegionalEndpointConfig(cfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig) error {
|
||||
|
||||
// Code is the short id of the error.
|
||||
func (e AssumeRoleTokenProviderNotSetError) Code() string {
|
||||
return "AssumeRoleTokenProviderNotSetError"
|
||||
}
|
||||
cfg.STSRegionalEndpoint = envCfg.STSRegionalEndpoint
|
||||
|
||||
// Message is the description of the error
|
||||
func (e AssumeRoleTokenProviderNotSetError) Message() string {
|
||||
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
|
||||
}
|
||||
if cfg.STSRegionalEndpoint == endpoints.UnsetSTSEndpoint {
|
||||
cfg.STSRegionalEndpoint = sharedCfg.STSRegionalEndpoint
|
||||
}
|
||||
|
||||
// OrigErr is the underlying error that caused the failure.
|
||||
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
|
||||
if cfg.STSRegionalEndpoint == endpoints.UnsetSTSEndpoint {
|
||||
cfg.STSRegionalEndpoint = endpoints.LegacySTSEndpoint
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error satisfies the error interface.
|
||||
func (e AssumeRoleTokenProviderNotSetError) Error() string {
|
||||
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
||||
}
|
||||
|
||||
type credProviderError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
var emptyCreds = credentials.Value{}
|
||||
|
||||
func (c credProviderError) Retrieve() (credentials.Value, error) {
|
||||
return credentials.Value{}, c.Err
|
||||
}
|
||||
func (c credProviderError) IsExpired() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func initHandlers(s *Session) {
|
||||
// Add the Validate parameter handler if it is not disabled.
|
||||
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler)
|
||||
|
@ -539,7 +594,7 @@ func initHandlers(s *Session) {
|
|||
}
|
||||
}
|
||||
|
||||
// Copy creates and returns a copy of the current Session, coping the config
|
||||
// Copy creates and returns a copy of the current Session, copying the config
|
||||
// and handlers. If any additional configs are provided they will be merged
|
||||
// on top of the Session's copied config.
|
||||
//
|
||||
|
@ -559,37 +614,15 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
|
|||
// ClientConfig satisfies the client.ConfigProvider interface and is used to
|
||||
// configure the service client instances. Passing the Session to the service
|
||||
// client's constructor (New) will use this method to configure the client.
|
||||
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
|
||||
// Backwards compatibility, the error will be eaten if user calls ClientConfig
|
||||
// directly. All SDK services will use ClientconfigWithError.
|
||||
cfg, _ := s.clientConfigWithErr(serviceName, cfgs...)
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (client.Config, error) {
|
||||
func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config {
|
||||
s = s.Copy(cfgs...)
|
||||
|
||||
var resolved endpoints.ResolvedEndpoint
|
||||
var err error
|
||||
|
||||
region := aws.StringValue(s.Config.Region)
|
||||
|
||||
if endpoint := aws.StringValue(s.Config.Endpoint); len(endpoint) != 0 {
|
||||
resolved.URL = endpoints.AddScheme(endpoint, aws.BoolValue(s.Config.DisableSSL))
|
||||
resolved.SigningRegion = region
|
||||
} else {
|
||||
resolved, err = s.Config.EndpointResolver.EndpointFor(
|
||||
serviceName, region,
|
||||
func(opt *endpoints.Options) {
|
||||
opt.DisableSSL = aws.BoolValue(s.Config.DisableSSL)
|
||||
opt.UseDualStack = aws.BoolValue(s.Config.UseDualStack)
|
||||
|
||||
// Support the condition where the service is modeled but its
|
||||
// endpoint metadata is not available.
|
||||
opt.ResolveUnknownService = true
|
||||
},
|
||||
)
|
||||
resolved, err := s.resolveEndpoint(service, region, s.Config)
|
||||
if err != nil && s.Config.Logger != nil {
|
||||
s.Config.Logger.Log(fmt.Sprintf(
|
||||
"ERROR: unable to resolve endpoint for service %q, region %q, err: %v",
|
||||
service, region, err))
|
||||
}
|
||||
|
||||
return client.Config{
|
||||
|
@ -599,7 +632,37 @@ func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (
|
|||
SigningRegion: resolved.SigningRegion,
|
||||
SigningNameDerived: resolved.SigningNameDerived,
|
||||
SigningName: resolved.SigningName,
|
||||
}, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) {
|
||||
|
||||
if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 {
|
||||
return endpoints.ResolvedEndpoint{
|
||||
URL: endpoints.AddScheme(ep, aws.BoolValue(cfg.DisableSSL)),
|
||||
SigningRegion: region,
|
||||
}, nil
|
||||
}
|
||||
|
||||
resolved, err := cfg.EndpointResolver.EndpointFor(service, region,
|
||||
func(opt *endpoints.Options) {
|
||||
opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
|
||||
opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
|
||||
// Support for STSRegionalEndpoint where the STSRegionalEndpoint is
|
||||
// provided in envConfig or sharedConfig with envConfig getting
|
||||
// precedence.
|
||||
opt.STSRegionalEndpoint = cfg.STSRegionalEndpoint
|
||||
|
||||
// Support the condition where the service is modeled but its
|
||||
// endpoint metadata is not available.
|
||||
opt.ResolveUnknownService = true
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return endpoints.ResolvedEndpoint{}, err
|
||||
}
|
||||
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception
|
||||
|
@ -609,12 +672,9 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
|
|||
s = s.Copy(cfgs...)
|
||||
|
||||
var resolved endpoints.ResolvedEndpoint
|
||||
|
||||
region := aws.StringValue(s.Config.Region)
|
||||
|
||||
if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 {
|
||||
resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL))
|
||||
resolved.SigningRegion = region
|
||||
resolved.SigningRegion = aws.StringValue(s.Config.Region)
|
||||
}
|
||||
|
||||
return client.Config{
|
||||
|
@ -626,3 +686,14 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
|
|||
SigningName: resolved.SigningName,
|
||||
}
|
||||
}
|
||||
|
||||
// logDeprecatedNewSessionError function enables error handling for session
|
||||
func (s *Session) logDeprecatedNewSessionError(msg string, err error, cfgs []*aws.Config) {
|
||||
// Session creation failed, need to report the error and prevent
|
||||
// any requests from succeeding.
|
||||
s.Config.MergeIn(cfgs...)
|
||||
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
|
||||
s.Handlers.Validate.PushBack(func(r *request.Request) {
|
||||
r.Error = err
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,11 +2,11 @@ package session
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/go-ini/ini"
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
"github.com/aws/aws-sdk-go/internal/ini"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -16,68 +16,106 @@ const (
|
|||
sessionTokenKey = `aws_session_token` // optional
|
||||
|
||||
// Assume Role Credentials group
|
||||
roleArnKey = `role_arn` // group required
|
||||
sourceProfileKey = `source_profile` // group required
|
||||
externalIDKey = `external_id` // optional
|
||||
mfaSerialKey = `mfa_serial` // optional
|
||||
roleSessionNameKey = `role_session_name` // optional
|
||||
roleArnKey = `role_arn` // group required
|
||||
sourceProfileKey = `source_profile` // group required (or credential_source)
|
||||
credentialSourceKey = `credential_source` // group required (or source_profile)
|
||||
externalIDKey = `external_id` // optional
|
||||
mfaSerialKey = `mfa_serial` // optional
|
||||
roleSessionNameKey = `role_session_name` // optional
|
||||
|
||||
// CSM options
|
||||
csmEnabledKey = `csm_enabled`
|
||||
csmHostKey = `csm_host`
|
||||
csmPortKey = `csm_port`
|
||||
csmClientIDKey = `csm_client_id`
|
||||
|
||||
// Additional Config fields
|
||||
regionKey = `region`
|
||||
|
||||
// endpoint discovery group
|
||||
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
|
||||
|
||||
// External Credential Process
|
||||
credentialProcessKey = `credential_process` // optional
|
||||
|
||||
// Web Identity Token File
|
||||
webIdentityTokenFileKey = `web_identity_token_file` // optional
|
||||
|
||||
// Additional config fields for regional or legacy endpoints
|
||||
stsRegionalEndpointSharedKey = `sts_regional_endpoints`
|
||||
|
||||
// DefaultSharedConfigProfile is the default profile to be used when
|
||||
// loading configuration from the config files if another profile name
|
||||
// is not provided.
|
||||
DefaultSharedConfigProfile = `default`
|
||||
)
|
||||
|
||||
type assumeRoleConfig struct {
|
||||
RoleARN string
|
||||
SourceProfile string
|
||||
ExternalID string
|
||||
MFASerial string
|
||||
RoleSessionName string
|
||||
}
|
||||
|
||||
// sharedConfig represents the configuration fields of the SDK config files.
|
||||
type sharedConfig struct {
|
||||
// Credentials values from the config file. Both aws_access_key_id
|
||||
// and aws_secret_access_key must be provided together in the same file
|
||||
// to be considered valid. The values will be ignored if not a complete group.
|
||||
// aws_session_token is an optional field that can be provided if both of the
|
||||
// other two fields are also provided.
|
||||
// Credentials values from the config file. Both aws_access_key_id and
|
||||
// aws_secret_access_key must be provided together in the same file to be
|
||||
// considered valid. The values will be ignored if not a complete group.
|
||||
// aws_session_token is an optional field that can be provided if both of
|
||||
// the other two fields are also provided.
|
||||
//
|
||||
// aws_access_key_id
|
||||
// aws_secret_access_key
|
||||
// aws_session_token
|
||||
Creds credentials.Value
|
||||
|
||||
AssumeRole assumeRoleConfig
|
||||
AssumeRoleSource *sharedConfig
|
||||
CredentialSource string
|
||||
CredentialProcess string
|
||||
WebIdentityTokenFile string
|
||||
|
||||
// Region is the region the SDK should use for looking up AWS service endpoints
|
||||
// and signing requests.
|
||||
RoleARN string
|
||||
RoleSessionName string
|
||||
ExternalID string
|
||||
MFASerial string
|
||||
|
||||
SourceProfileName string
|
||||
SourceProfile *sharedConfig
|
||||
|
||||
// Region is the region the SDK should use for looking up AWS service
|
||||
// endpoints and signing requests.
|
||||
//
|
||||
// region
|
||||
Region string
|
||||
|
||||
// EnableEndpointDiscovery can be enabled in the shared config by setting
|
||||
// endpoint_discovery_enabled to true
|
||||
//
|
||||
// endpoint_discovery_enabled = true
|
||||
EnableEndpointDiscovery *bool
|
||||
// CSM Options
|
||||
CSMEnabled *bool
|
||||
CSMHost string
|
||||
CSMPort string
|
||||
CSMClientID string
|
||||
|
||||
// Specifies the Regional Endpoint flag for the sdk to resolve the endpoint for a service
|
||||
//
|
||||
// sts_regional_endpoints = sts_regional_endpoint
|
||||
// This can take value as `LegacySTSEndpoint` or `RegionalSTSEndpoint`
|
||||
STSRegionalEndpoint endpoints.STSRegionalEndpoint
|
||||
}
|
||||
|
||||
type sharedConfigFile struct {
|
||||
Filename string
|
||||
IniData *ini.File
|
||||
IniData ini.Sections
|
||||
}
|
||||
|
||||
// loadSharedConfig retrieves the configuration from the list of files
|
||||
// using the profile provided. The order the files are listed will determine
|
||||
// loadSharedConfig retrieves the configuration from the list of files using
|
||||
// the profile provided. The order the files are listed will determine
|
||||
// precedence. Values in subsequent files will overwrite values defined in
|
||||
// earlier files.
|
||||
//
|
||||
// For example, given two files A and B. Both define credentials. If the order
|
||||
// of the files are A then B, B's credential values will be used instead of A's.
|
||||
// of the files are A then B, B's credential values will be used instead of
|
||||
// A's.
|
||||
//
|
||||
// See sharedConfig.setFromFile for information how the config files
|
||||
// will be loaded.
|
||||
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) {
|
||||
func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
|
||||
if len(profile) == 0 {
|
||||
profile = DefaultSharedConfigProfile
|
||||
}
|
||||
|
@ -88,16 +126,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
|
|||
}
|
||||
|
||||
cfg := sharedConfig{}
|
||||
if err = cfg.setFromIniFiles(profile, files); err != nil {
|
||||
profiles := map[string]struct{}{}
|
||||
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
|
||||
return sharedConfig{}, err
|
||||
}
|
||||
|
||||
if len(cfg.AssumeRole.SourceProfile) > 0 {
|
||||
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
|
||||
return sharedConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
|
@ -105,114 +138,258 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
|
|||
files := make([]sharedConfigFile, 0, len(filenames))
|
||||
|
||||
for _, filename := range filenames {
|
||||
b, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
sections, err := ini.OpenFile(filename)
|
||||
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == ini.ErrCodeUnableToReadFile {
|
||||
// Skip files which can't be opened and read for whatever reason
|
||||
continue
|
||||
}
|
||||
|
||||
f, err := ini.Load(b)
|
||||
if err != nil {
|
||||
} else if err != nil {
|
||||
return nil, SharedConfigLoadError{Filename: filename, Err: err}
|
||||
}
|
||||
|
||||
files = append(files, sharedConfigFile{
|
||||
Filename: filename, IniData: f,
|
||||
Filename: filename, IniData: sections,
|
||||
})
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error {
|
||||
var assumeRoleSrc sharedConfig
|
||||
|
||||
// Multiple level assume role chains are not support
|
||||
if cfg.AssumeRole.SourceProfile == origProfile {
|
||||
assumeRoleSrc = *cfg
|
||||
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
|
||||
} else {
|
||||
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
|
||||
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
|
||||
}
|
||||
|
||||
cfg.AssumeRoleSource = &assumeRoleSrc
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
|
||||
func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
|
||||
// Trim files from the list that don't exist.
|
||||
var skippedFiles int
|
||||
var profileNotFoundErr error
|
||||
for _, f := range files {
|
||||
if err := cfg.setFromIniFile(profile, f); err != nil {
|
||||
if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
|
||||
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
|
||||
// Ignore proviles missings
|
||||
// Ignore profiles not defined in individual files.
|
||||
profileNotFoundErr = err
|
||||
skippedFiles++
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
if skippedFiles == len(files) {
|
||||
// If all files were skipped because the profile is not found, return
|
||||
// the original profile not found error.
|
||||
return profileNotFoundErr
|
||||
}
|
||||
|
||||
if _, ok := profiles[profile]; ok {
|
||||
// if this is the second instance of the profile the Assume Role
|
||||
// options must be cleared because they are only valid for the
|
||||
// first reference of a profile. The self linked instance of the
|
||||
// profile only have credential provider options.
|
||||
cfg.clearAssumeRoleOptions()
|
||||
} else {
|
||||
// First time a profile has been seen, It must either be a assume role
|
||||
// or credentials. Assert if the credential type requires a role ARN,
|
||||
// the ARN is also set.
|
||||
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
profiles[profile] = struct{}{}
|
||||
|
||||
if err := cfg.validateCredentialType(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Link source profiles for assume roles
|
||||
if len(cfg.SourceProfileName) != 0 {
|
||||
// Linked profile via source_profile ignore credential provider
|
||||
// options, the source profile must provide the credentials.
|
||||
cfg.clearCredentialOptions()
|
||||
|
||||
srcCfg := &sharedConfig{}
|
||||
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
|
||||
if err != nil {
|
||||
// SourceProfile that doesn't exist is an error in configuration.
|
||||
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
|
||||
err = SharedConfigAssumeRoleError{
|
||||
RoleARN: cfg.RoleARN,
|
||||
SourceProfile: cfg.SourceProfileName,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if !srcCfg.hasCredentials() {
|
||||
return SharedConfigAssumeRoleError{
|
||||
RoleARN: cfg.RoleARN,
|
||||
SourceProfile: cfg.SourceProfileName,
|
||||
}
|
||||
}
|
||||
|
||||
cfg.SourceProfile = srcCfg
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setFromFile loads the configuration from the file using
|
||||
// the profile provided. A sharedConfig pointer type value is used so that
|
||||
// multiple config file loadings can be chained.
|
||||
// setFromFile loads the configuration from the file using the profile
|
||||
// provided. A sharedConfig pointer type value is used so that multiple config
|
||||
// file loadings can be chained.
|
||||
//
|
||||
// Only loads complete logically grouped values, and will not set fields in cfg
|
||||
// for incomplete grouped values in the config. Such as credentials. For example
|
||||
// if a config file only includes aws_access_key_id but no aws_secret_access_key
|
||||
// the aws_access_key_id will be ignored.
|
||||
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error {
|
||||
section, err := file.IniData.GetSection(profile)
|
||||
if err != nil {
|
||||
// for incomplete grouped values in the config. Such as credentials. For
|
||||
// example if a config file only includes aws_access_key_id but no
|
||||
// aws_secret_access_key the aws_access_key_id will be ignored.
|
||||
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
|
||||
section, ok := file.IniData.GetSection(profile)
|
||||
if !ok {
|
||||
// Fall back to the alternate profile name: profile <name>
|
||||
section, err = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
|
||||
if err != nil {
|
||||
return SharedConfigProfileNotExistsError{Profile: profile, Err: err}
|
||||
section, ok = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
|
||||
if !ok {
|
||||
return SharedConfigProfileNotExistsError{Profile: profile, Err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
if exOpts {
|
||||
// Assume Role Parameters
|
||||
updateString(&cfg.RoleARN, section, roleArnKey)
|
||||
updateString(&cfg.ExternalID, section, externalIDKey)
|
||||
updateString(&cfg.MFASerial, section, mfaSerialKey)
|
||||
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
|
||||
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
|
||||
updateString(&cfg.CredentialSource, section, credentialSourceKey)
|
||||
updateString(&cfg.Region, section, regionKey)
|
||||
|
||||
if v := section.String(stsRegionalEndpointSharedKey); len(v) != 0 {
|
||||
sre, err := endpoints.GetSTSRegionalEndpoint(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s from shared config, %s, %v",
|
||||
stsRegionalEndpointKey, file.Filename, err)
|
||||
}
|
||||
cfg.STSRegionalEndpoint = sre
|
||||
}
|
||||
}
|
||||
|
||||
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
|
||||
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
|
||||
|
||||
// Shared Credentials
|
||||
akid := section.Key(accessKeyIDKey).String()
|
||||
secret := section.Key(secretAccessKey).String()
|
||||
if len(akid) > 0 && len(secret) > 0 {
|
||||
cfg.Creds = credentials.Value{
|
||||
AccessKeyID: akid,
|
||||
SecretAccessKey: secret,
|
||||
SessionToken: section.Key(sessionTokenKey).String(),
|
||||
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
|
||||
}
|
||||
creds := credentials.Value{
|
||||
AccessKeyID: section.String(accessKeyIDKey),
|
||||
SecretAccessKey: section.String(secretAccessKey),
|
||||
SessionToken: section.String(sessionTokenKey),
|
||||
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
|
||||
}
|
||||
if creds.HasKeys() {
|
||||
cfg.Creds = creds
|
||||
}
|
||||
|
||||
// Assume Role
|
||||
roleArn := section.Key(roleArnKey).String()
|
||||
srcProfile := section.Key(sourceProfileKey).String()
|
||||
if len(roleArn) > 0 && len(srcProfile) > 0 {
|
||||
cfg.AssumeRole = assumeRoleConfig{
|
||||
RoleARN: roleArn,
|
||||
SourceProfile: srcProfile,
|
||||
ExternalID: section.Key(externalIDKey).String(),
|
||||
MFASerial: section.Key(mfaSerialKey).String(),
|
||||
RoleSessionName: section.Key(roleSessionNameKey).String(),
|
||||
}
|
||||
// Endpoint discovery
|
||||
updateBoolPtr(&cfg.EnableEndpointDiscovery, section, enableEndpointDiscoveryKey)
|
||||
|
||||
// CSM options
|
||||
updateBoolPtr(&cfg.CSMEnabled, section, csmEnabledKey)
|
||||
updateString(&cfg.CSMHost, section, csmHostKey)
|
||||
updateString(&cfg.CSMPort, section, csmPortKey)
|
||||
updateString(&cfg.CSMClientID, section, csmClientIDKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
|
||||
var credSource string
|
||||
|
||||
switch {
|
||||
case len(cfg.SourceProfileName) != 0:
|
||||
credSource = sourceProfileKey
|
||||
case len(cfg.CredentialSource) != 0:
|
||||
credSource = credentialSourceKey
|
||||
case len(cfg.WebIdentityTokenFile) != 0:
|
||||
credSource = webIdentityTokenFileKey
|
||||
}
|
||||
|
||||
// Region
|
||||
if v := section.Key(regionKey).String(); len(v) > 0 {
|
||||
cfg.Region = v
|
||||
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
|
||||
return CredentialRequiresARNError{
|
||||
Type: credSource,
|
||||
Profile: profile,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) validateCredentialType() error {
|
||||
// Only one or no credential type can be defined.
|
||||
if !oneOrNone(
|
||||
len(cfg.SourceProfileName) != 0,
|
||||
len(cfg.CredentialSource) != 0,
|
||||
len(cfg.CredentialProcess) != 0,
|
||||
len(cfg.WebIdentityTokenFile) != 0,
|
||||
) {
|
||||
return ErrSharedConfigSourceCollision
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) hasCredentials() bool {
|
||||
switch {
|
||||
case len(cfg.SourceProfileName) != 0:
|
||||
case len(cfg.CredentialSource) != 0:
|
||||
case len(cfg.CredentialProcess) != 0:
|
||||
case len(cfg.WebIdentityTokenFile) != 0:
|
||||
case cfg.Creds.HasKeys():
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) clearCredentialOptions() {
|
||||
cfg.CredentialSource = ""
|
||||
cfg.CredentialProcess = ""
|
||||
cfg.WebIdentityTokenFile = ""
|
||||
cfg.Creds = credentials.Value{}
|
||||
}
|
||||
|
||||
func (cfg *sharedConfig) clearAssumeRoleOptions() {
|
||||
cfg.RoleARN = ""
|
||||
cfg.ExternalID = ""
|
||||
cfg.MFASerial = ""
|
||||
cfg.RoleSessionName = ""
|
||||
cfg.SourceProfileName = ""
|
||||
}
|
||||
|
||||
func oneOrNone(bs ...bool) bool {
|
||||
var count int
|
||||
|
||||
for _, b := range bs {
|
||||
if b {
|
||||
count++
|
||||
if count > 1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// updateString will only update the dst with the value in the section's key,
// if the key is present in the section.
|
||||
func updateString(dst *string, section ini.Section, key string) {
|
||||
if !section.Has(key) {
|
||||
return
|
||||
}
|
||||
*dst = section.String(key)
|
||||
}
|
||||
|
||||
// updateBoolPtr will only update the dst with the value in the section's key,
// if the key is present in the section.
|
||||
func updateBoolPtr(dst **bool, section ini.Section, key string) {
|
||||
if !section.Has(key) {
|
||||
return
|
||||
}
|
||||
*dst = new(bool)
|
||||
**dst = section.Bool(key)
|
||||
}
|
||||
|
||||
// SharedConfigLoadError is an error for when the shared config file fails to load.
|
||||
type SharedConfigLoadError struct {
|
||||
Filename string
|
||||
|
@ -270,7 +447,8 @@ func (e SharedConfigProfileNotExistsError) Error() string {
|
|||
// profile contains assume role information, but that information is invalid
|
||||
// or not complete.
|
||||
type SharedConfigAssumeRoleError struct {
|
||||
RoleARN string
|
||||
RoleARN string
|
||||
SourceProfile string
|
||||
}
|
||||
|
||||
// Code is the short id of the error.
|
||||
|
@ -280,8 +458,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
|
|||
|
||||
// Message is the description of the error
|
||||
func (e SharedConfigAssumeRoleError) Message() string {
|
||||
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials",
|
||||
e.RoleARN)
|
||||
return fmt.Sprintf(
|
||||
"failed to load assume role for %s, source profile %s has no shared credentials",
|
||||
e.RoleARN, e.SourceProfile,
|
||||
)
|
||||
}
|
||||
|
||||
// OrigErr is the underlying error that caused the failure.
|
||||
|
@ -293,3 +473,36 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
|
|||
func (e SharedConfigAssumeRoleError) Error() string {
|
||||
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
||||
}
|
||||
|
||||
// CredentialRequiresARNError provides the error for shared config credentials
|
||||
// that are incorrectly configured in the shared config or credentials file.
|
||||
type CredentialRequiresARNError struct {
|
||||
// type of credentials that were configured.
|
||||
Type string
|
||||
|
||||
// Profile name the credentials were in.
|
||||
Profile string
|
||||
}
|
||||
|
||||
// Code is the short id of the error.
|
||||
func (e CredentialRequiresARNError) Code() string {
|
||||
return "CredentialRequiresARNError"
|
||||
}
|
||||
|
||||
// Message is the description of the error
|
||||
func (e CredentialRequiresARNError) Message() string {
|
||||
return fmt.Sprintf(
|
||||
"credential type %s requires role_arn, profile %s",
|
||||
e.Type, e.Profile,
|
||||
)
|
||||
}
|
||||
|
||||
// OrigErr is the underlying error that caused the failure.
|
||||
func (e CredentialRequiresARNError) OrigErr() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error satisfies the error interface.
|
||||
func (e CredentialRequiresARNError) Error() string {
|
||||
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
||||
}
|
||||
|
|
|
@ -134,6 +134,7 @@ var requiredSignedHeaders = rules{
|
|||
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
|
||||
"X-Amz-Storage-Class": struct{}{},
|
||||
"X-Amz-Tagging": struct{}{},
|
||||
"X-Amz-Website-Redirect-Location": struct{}{},
|
||||
"X-Amz-Content-Sha256": struct{}{},
|
||||
},
|
||||
|
@ -181,7 +182,7 @@ type Signer struct {
|
|||
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
|
||||
DisableURIPathEscaping bool
|
||||
|
||||
// Disales the automatical setting of the HTTP request's Body field with the
|
||||
// Disables the automatic setting of the HTTP request's Body field with the
|
||||
// io.ReadSeeker passed in to the signer. This is useful if you're using a
|
||||
// custom wrapper around the body for the io.ReadSeeker and want to preserve
|
||||
// the Body value on the Request.Body.
|
||||
|
@ -421,7 +422,7 @@ var SignRequestHandler = request.NamedHandler{
|
|||
// If the credentials of the request's config are set to
|
||||
// credentials.AnonymousCredentials the request will not be signed.
|
||||
func SignSDKRequest(req *request.Request) {
|
||||
signSDKRequestWithCurrTime(req, time.Now)
|
||||
SignSDKRequestWithCurrentTime(req, time.Now)
|
||||
}
|
||||
|
||||
// BuildNamedHandler will build a generic handler for signing.
|
||||
|
@ -429,12 +430,15 @@ func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler
|
|||
return request.NamedHandler{
|
||||
Name: name,
|
||||
Fn: func(req *request.Request) {
|
||||
signSDKRequestWithCurrTime(req, time.Now, opts...)
|
||||
SignSDKRequestWithCurrentTime(req, time.Now, opts...)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
|
||||
// SignSDKRequestWithCurrentTime will sign the SDK's request using the time
|
||||
// function passed in. Behaves the same as SignSDKRequest with the exception
|
||||
// the request is signed with the value returned by the current time function.
|
||||
func SignSDKRequestWithCurrentTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
|
||||
// If the request does not need to be signed ignore the signing of the
|
||||
// request if the AnonymousCredentials object is used.
|
||||
if req.Config.Credentials == credentials.AnonymousCredentials {
|
||||
|
@ -470,13 +474,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
|
|||
opt(v4)
|
||||
}
|
||||
|
||||
signingTime := req.Time
|
||||
if !req.LastSignedAt.IsZero() {
|
||||
signingTime = req.LastSignedAt
|
||||
}
|
||||
|
||||
curTime := curTimeFn()
|
||||
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
|
||||
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime,
|
||||
name, region, req.ExpireTime, req.ExpireTime > 0, curTime,
|
||||
)
|
||||
if err != nil {
|
||||
req.Error = err
|
||||
|
@ -485,7 +485,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
|
|||
}
|
||||
|
||||
req.SignedHeaderVals = signedHeaders
|
||||
req.LastSignedAt = curTimeFn()
|
||||
req.LastSignedAt = curTime
|
||||
}
|
||||
|
||||
const logSignInfoMsg = `DEBUG: Request Signature:
|
||||
|
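The handlers above sign requests inside the SDK's pipeline, but the signer can also be applied to a plain HTTP request. A minimal sketch, assuming static placeholder credentials and an illustrative service and region:

package main

import (
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws/credentials"
	v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
)

func main() {
	// Placeholder keys; real code would use a proper credentials provider.
	creds := credentials.NewStaticCredentials("AKID", "SECRET", "")
	signer := v4.NewSigner(creds)

	body := strings.NewReader("Action=ListUsers&Version=2010-05-08")
	req, err := http.NewRequest("POST", "https://iam.amazonaws.com/", body)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")

	// Sign hashes the seekable body and adds the Authorization header in place.
	if _, err := signer.Sign(req, body, "iam", "us-east-1", time.Now()); err != nil {
		log.Fatal(err)
	}
	log.Println(req.Header.Get("Authorization"))
}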
@ -687,7 +687,11 @@ func (ctx *signingCtx) buildBodyDigest() error {
|
|||
if !aws.IsReaderSeekable(ctx.Body) {
|
||||
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
|
||||
}
|
||||
hash = hex.EncodeToString(makeSha256Reader(ctx.Body))
|
||||
hashBytes, err := makeSha256Reader(ctx.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash = hex.EncodeToString(hashBytes)
|
||||
}
|
||||
|
||||
if includeSHA256Header {
|
||||
|
@ -734,19 +738,33 @@ func makeSha256(data []byte) []byte {
|
|||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func makeSha256Reader(reader io.ReadSeeker) []byte {
|
||||
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
|
||||
hash := sha256.New()
|
||||
start, _ := reader.Seek(0, sdkio.SeekCurrent)
|
||||
defer reader.Seek(start, sdkio.SeekStart)
|
||||
start, err := reader.Seek(0, sdkio.SeekCurrent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
// Ensure the error is returned if unable to seek back to the start of the payload.
|
||||
_, err = reader.Seek(start, sdkio.SeekStart)
|
||||
}()
|
||||
|
||||
io.Copy(hash, reader)
|
||||
return hash.Sum(nil)
|
||||
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
|
||||
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
|
||||
size, err := aws.SeekerLen(reader)
|
||||
if err != nil {
|
||||
io.Copy(hash, reader)
|
||||
} else {
|
||||
io.CopyN(hash, reader, size)
|
||||
}
|
||||
|
||||
return hash.Sum(nil), nil
|
||||
}
|
||||
|
||||
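The deferred seek above matters because the hashed body is still sent on the wire afterwards. A standalone stdlib sketch of the same seek-and-restore hashing pattern for any io.ReadSeeker:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"strings"
)

// hashSeekable hashes the remaining bytes of r, then restores the original
// read offset so the payload can still be transmitted afterwards.
func hashSeekable(r io.ReadSeeker) (string, error) {
	start, err := r.Seek(0, io.SeekCurrent)
	if err != nil {
		return "", err
	}
	h := sha256.New()
	if _, err := io.Copy(h, r); err != nil {
		return "", err
	}
	if _, err := r.Seek(start, io.SeekStart); err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}

func main() {
	sum, err := hashSeekable(strings.NewReader("payload"))
	fmt.Println(sum, err)
}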
const doubleSpace = " "
|
||||
|
||||
// stripExcessSpaces will rewrite the passed in slice's string values to not
|
||||
// contain muliple side-by-side spaces.
|
||||
// contain multiple side-by-side spaces.
|
||||
func stripExcessSpaces(vals []string) {
|
||||
var j, k, l, m, spaces int
|
||||
for i, str := range vals {
|
||||
|
|
|
@ -7,13 +7,18 @@ import (
|
|||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||
)
|
||||
|
||||
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should
|
||||
// only be used with an io.Reader that is also an io.Seeker. Doing so may
|
||||
// cause request signature errors, or request body's not sent for GET, HEAD
|
||||
// and DELETE HTTP methods.
|
||||
// ReadSeekCloser wraps an io.Reader returning a ReaderSeekerCloser. Allows the
|
||||
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
|
||||
// streaming payload API operations.
|
||||
//
|
||||
// Deprecated: Should only be used with io.ReadSeeker. If using for
|
||||
// S3 PutObject to stream content use s3manager.Uploader instead.
|
||||
// A ReadSeekCloser wrapping a nonseekable io.Reader used in an API
// operation's input will prevent that operation from being retried in the case of
|
||||
// network errors, and cause operation requests to fail if the operation
|
||||
// requires payload signing.
|
||||
//
|
||||
// Note: If using with S3 PutObject to stream an object upload, the SDK's S3
|
||||
// Upload manager (s3manager.Uploader) provides support for streaming with the
|
||||
// ability to retry network errors.
|
||||
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
|
||||
return ReaderSeekerCloser{r}
|
||||
}
|
||||
|
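A minimal sketch of the wrapper's intended use, assuming the body really is seekable; wrapping a non-seekable io.Reader runs into the retry and signing caveats described in the comment above:

package main

import (
	"bytes"
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
)

func main() {
	// bytes.Reader is an io.ReadSeeker, so the wrapped body stays retryable
	// and can be signed; an http.Response body, for example, would not be.
	body := aws.ReadSeekCloser(bytes.NewReader([]byte("hello")))

	buf := make([]byte, 5)
	n, err := body.Read(buf)
	fmt.Println(n, string(buf[:n]), err)
}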
@ -43,7 +48,8 @@ func IsReaderSeekable(r io.Reader) bool {
|
|||
// Read reads from the reader up to size of p. The number of bytes read, and
|
||||
// error if it occurred will be returned.
|
||||
//
|
||||
// If the reader is not an io.Reader zero bytes read, and nil error will be returned.
|
||||
// If the reader is not an io.Reader zero bytes read, and nil error will be
|
||||
// returned.
|
||||
//
|
||||
// Performs the same functionality as io.Reader Read
|
||||
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
|
||||
|
|
|
@ -5,4 +5,4 @@ package aws
|
|||
const SDKName = "aws-sdk-go"
|
||||
|
||||
// SDKVersion is the version of this SDK
|
||||
const SDKVersion = "1.15.24"
|
||||
const SDKVersion = "1.25.32"
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
package ini
|
||||
|
||||
// ASTKind represents different states in the parse table
|
||||
// and the type of AST that is being constructed
|
||||
type ASTKind int
|
||||
|
||||
// ASTKind* is used in the parse table to transition between
|
||||
// the different states
|
||||
const (
|
||||
ASTKindNone = ASTKind(iota)
|
||||
ASTKindStart
|
||||
ASTKindExpr
|
||||
ASTKindEqualExpr
|
||||
ASTKindStatement
|
||||
ASTKindSkipStatement
|
||||
ASTKindExprStatement
|
||||
ASTKindSectionStatement
|
||||
ASTKindNestedSectionStatement
|
||||
ASTKindCompletedNestedSectionStatement
|
||||
ASTKindCommentStatement
|
||||
ASTKindCompletedSectionStatement
|
||||
)
|
||||
|
||||
func (k ASTKind) String() string {
|
||||
switch k {
|
||||
case ASTKindNone:
|
||||
return "none"
|
||||
case ASTKindStart:
|
||||
return "start"
|
||||
case ASTKindExpr:
|
||||
return "expr"
|
||||
case ASTKindStatement:
|
||||
return "stmt"
|
||||
case ASTKindSectionStatement:
|
||||
return "section_stmt"
|
||||
case ASTKindExprStatement:
|
||||
return "expr_stmt"
|
||||
case ASTKindCommentStatement:
|
||||
return "comment"
|
||||
case ASTKindNestedSectionStatement:
|
||||
return "nested_section_stmt"
|
||||
case ASTKindCompletedSectionStatement:
|
||||
return "completed_stmt"
|
||||
case ASTKindSkipStatement:
|
||||
return "skip"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// AST allows us to determine what kind of node we
// are on, so casting may not be necessary.
|
||||
//
|
||||
// The root is always the first node in Children
|
||||
type AST struct {
|
||||
Kind ASTKind
|
||||
Root Token
|
||||
RootToken bool
|
||||
Children []AST
|
||||
}
|
||||
|
||||
func newAST(kind ASTKind, root AST, children ...AST) AST {
|
||||
return AST{
|
||||
Kind: kind,
|
||||
Children: append([]AST{root}, children...),
|
||||
}
|
||||
}
|
||||
|
||||
func newASTWithRootToken(kind ASTKind, root Token, children ...AST) AST {
|
||||
return AST{
|
||||
Kind: kind,
|
||||
Root: root,
|
||||
RootToken: true,
|
||||
Children: children,
|
||||
}
|
||||
}
|
||||
|
||||
// AppendChild will append to the list of children an AST has.
|
||||
func (a *AST) AppendChild(child AST) {
|
||||
a.Children = append(a.Children, child)
|
||||
}
|
||||
|
||||
// GetRoot will return the root AST which can be the first entry
|
||||
// in the children list or a token.
|
||||
func (a *AST) GetRoot() AST {
|
||||
if a.RootToken {
|
||||
return *a
|
||||
}
|
||||
|
||||
if len(a.Children) == 0 {
|
||||
return AST{}
|
||||
}
|
||||
|
||||
return a.Children[0]
|
||||
}
|
||||
|
||||
// GetChildren will return the current AST's list of children
|
||||
func (a *AST) GetChildren() []AST {
|
||||
if len(a.Children) == 0 {
|
||||
return []AST{}
|
||||
}
|
||||
|
||||
if a.RootToken {
|
||||
return a.Children
|
||||
}
|
||||
|
||||
return a.Children[1:]
|
||||
}
|
||||
|
||||
// SetChildren will set and override all children of the AST.
|
||||
func (a *AST) SetChildren(children []AST) {
|
||||
if a.RootToken {
|
||||
a.Children = children
|
||||
} else {
|
||||
a.Children = append(a.Children[:1], children...)
|
||||
}
|
||||
}
|
||||
|
||||
// Start is used to indicate the starting state of the parse table.
|
||||
var Start = newAST(ASTKindStart, AST{})
|
|
@ -0,0 +1,11 @@
|
|||
package ini
|
||||
|
||||
var commaRunes = []rune(",")
|
||||
|
||||
func isComma(b rune) bool {
|
||||
return b == ','
|
||||
}
|
||||
|
||||
func newCommaToken() Token {
|
||||
return newToken(TokenComma, commaRunes, NoneType)
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package ini
|
||||
|
||||
// isComment will return whether or not the next rune(s) start a
// comment.
|
||||
func isComment(b []rune) bool {
|
||||
if len(b) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
switch b[0] {
|
||||
case ';':
|
||||
return true
|
||||
case '#':
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// newCommentToken will create a comment token and
|
||||
// return how many bytes were read.
|
||||
func newCommentToken(b []rune) (Token, int, error) {
|
||||
i := 0
|
||||
for ; i < len(b); i++ {
|
||||
if b[i] == '\n' {
|
||||
break
|
||||
}
|
||||
|
||||
if len(b)-i > 2 && b[i] == '\r' && b[i+1] == '\n' {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return newToken(TokenComment, b[:i], NoneType), i, nil
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Package ini is an LL(1) parser for configuration files.
|
||||
//
|
||||
// Example:
|
||||
// sections, err := ini.OpenFile("/path/to/file")
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
//
|
||||
// profile := "foo"
|
||||
// section, ok := sections.GetSection(profile)
|
||||
// if !ok {
|
||||
// fmt.Printf("section %q could not be found", profile)
|
||||
// }
|
||||
//
|
||||
// Below is the BNF that describes this parser
|
||||
// Grammar:
|
||||
// stmt -> value stmt'
|
||||
// stmt' -> epsilon | op stmt
|
||||
// value -> number | string | boolean | quoted_string
|
||||
//
|
||||
// section -> [ section'
|
||||
// section' -> value section_close
|
||||
// section_close -> ]
|
||||
//
|
||||
// SkipState will skip (NL WS)+
|
||||
//
|
||||
// comment -> # comment' | ; comment'
|
||||
// comment' -> epsilon | value
|
||||
package ini
|
|
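Putting the grammar above into practice, a sketch of parsing an in-memory document with this package. internal/ini is an internal package, so outside the SDK tree this is illustrative only:

package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/internal/ini"
)

func main() {
	doc := []byte("[profile dev]\nregion = us-west-2\n; trailing comment\n")

	sections, err := ini.ParseBytes(doc)
	if err != nil {
		log.Fatal(err)
	}

	// Section labels keep their full text, e.g. "profile dev".
	sec, ok := sections.GetSection("profile dev")
	if !ok {
		log.Fatal("section not found")
	}
	fmt.Println(sec.String("region")) // us-west-2
}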
@ -0,0 +1,4 @@
|
|||
package ini
|
||||
|
||||
// emptyToken is used to satisfy the Token interface
|
||||
var emptyToken = newToken(TokenNone, []rune{}, NoneType)
|
|
@ -0,0 +1,24 @@
|
|||
package ini
|
||||
|
||||
// newExpression will return an expression AST.
|
||||
// Expr represents an expression
|
||||
//
|
||||
// grammar:
|
||||
// expr -> string | number
|
||||
func newExpression(tok Token) AST {
|
||||
return newASTWithRootToken(ASTKindExpr, tok)
|
||||
}
|
||||
|
||||
func newEqualExpr(left AST, tok Token) AST {
|
||||
return newASTWithRootToken(ASTKindEqualExpr, tok, left)
|
||||
}
|
||||
|
||||
// EqualExprKey will return a LHS value in the equal expr
|
||||
func EqualExprKey(ast AST) string {
|
||||
children := ast.GetChildren()
|
||||
if len(children) == 0 || ast.Kind != ASTKindEqualExpr {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(children[0].Root.Raw())
|
||||
}
|
|
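A test-style sketch of how these constructors compose a single key = value line inside the package; the test itself is hypothetical and only meant to show the AST shape:

package ini

import "testing"

func TestEqualExprShape(t *testing.T) {
	// Tokens normally come from the lexer; they are built by hand here.
	lhs := newExpression(newToken(TokenLit, []rune("key"), StringType))
	eq := newEqualExpr(lhs, newToken(TokenOp, []rune("="), NoneType))
	eq.AppendChild(newExpression(newToken(TokenLit, []rune("value"), StringType)))

	if got := EqualExprKey(eq); got != "key" {
		t.Fatalf("expected LHS key %q, got %q", "key", got)
	}
}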
@ -0,0 +1,17 @@
|
|||
// +build gofuzz
|
||||
|
||||
package ini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
b := bytes.NewReader(data)
|
||||
|
||||
if _, err := Parse(b); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package ini
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
// OpenFile takes a path to a given file, and will open and parse
|
||||
// that file.
|
||||
func OpenFile(path string) (Sections, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Sections{}, awserr.New(ErrCodeUnableToReadFile, "unable to open file", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return Parse(f)
|
||||
}
|
||||
|
||||
// Parse will parse the given file using the shared config
|
||||
// visitor.
|
||||
func Parse(f io.Reader) (Sections, error) {
|
||||
tree, err := ParseAST(f)
|
||||
if err != nil {
|
||||
return Sections{}, err
|
||||
}
|
||||
|
||||
v := NewDefaultVisitor()
|
||||
if err = Walk(tree, v); err != nil {
|
||||
return Sections{}, err
|
||||
}
|
||||
|
||||
return v.Sections, nil
|
||||
}
|
||||
|
||||
// ParseBytes will parse the given bytes and return the parsed sections.
|
||||
func ParseBytes(b []byte) (Sections, error) {
|
||||
tree, err := ParseASTBytes(b)
|
||||
if err != nil {
|
||||
return Sections{}, err
|
||||
}
|
||||
|
||||
v := NewDefaultVisitor()
|
||||
if err = Walk(tree, v); err != nil {
|
||||
return Sections{}, err
|
||||
}
|
||||
|
||||
return v.Sections, nil
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
package ini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
)
|
||||
|
||||
const (
|
||||
// ErrCodeUnableToReadFile is used when a file fails to be
// opened or read from.
|
||||
ErrCodeUnableToReadFile = "FailedRead"
|
||||
)
|
||||
|
||||
// TokenType represents the various token types.
|
||||
type TokenType int
|
||||
|
||||
func (t TokenType) String() string {
|
||||
switch t {
|
||||
case TokenNone:
|
||||
return "none"
|
||||
case TokenLit:
|
||||
return "literal"
|
||||
case TokenSep:
|
||||
return "sep"
|
||||
case TokenOp:
|
||||
return "op"
|
||||
case TokenWS:
|
||||
return "ws"
|
||||
case TokenNL:
|
||||
return "newline"
|
||||
case TokenComment:
|
||||
return "comment"
|
||||
case TokenComma:
|
||||
return "comma"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// TokenType enums
|
||||
const (
|
||||
TokenNone = TokenType(iota)
|
||||
TokenLit
|
||||
TokenSep
|
||||
TokenComma
|
||||
TokenOp
|
||||
TokenWS
|
||||
TokenNL
|
||||
TokenComment
|
||||
)
|
||||
|
||||
type iniLexer struct{}
|
||||
|
||||
// Tokenize will return a list of tokens during lexical analysis of the
|
||||
// io.Reader.
|
||||
func (l *iniLexer) Tokenize(r io.Reader) ([]Token, error) {
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, awserr.New(ErrCodeUnableToReadFile, "unable to read file", err)
|
||||
}
|
||||
|
||||
return l.tokenize(b)
|
||||
}
|
||||
|
||||
func (l *iniLexer) tokenize(b []byte) ([]Token, error) {
|
||||
runes := bytes.Runes(b)
|
||||
var err error
|
||||
n := 0
|
||||
tokenAmount := countTokens(runes)
|
||||
tokens := make([]Token, tokenAmount)
|
||||
count := 0
|
||||
|
||||
for len(runes) > 0 && count < tokenAmount {
|
||||
switch {
|
||||
case isWhitespace(runes[0]):
|
||||
tokens[count], n, err = newWSToken(runes)
|
||||
case isComma(runes[0]):
|
||||
tokens[count], n = newCommaToken(), 1
|
||||
case isComment(runes):
|
||||
tokens[count], n, err = newCommentToken(runes)
|
||||
case isNewline(runes):
|
||||
tokens[count], n, err = newNewlineToken(runes)
|
||||
case isSep(runes):
|
||||
tokens[count], n, err = newSepToken(runes)
|
||||
case isOp(runes):
|
||||
tokens[count], n, err = newOpToken(runes)
|
||||
default:
|
||||
tokens[count], n, err = newLitToken(runes)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count++
|
||||
|
||||
runes = runes[n:]
|
||||
}
|
||||
|
||||
return tokens[:count], nil
|
||||
}
|
||||
|
||||
func countTokens(runes []rune) int {
|
||||
count, n := 0, 0
|
||||
var err error
|
||||
|
||||
for len(runes) > 0 {
|
||||
switch {
|
||||
case isWhitespace(runes[0]):
|
||||
_, n, err = newWSToken(runes)
|
||||
case isComma(runes[0]):
|
||||
_, n = newCommaToken(), 1
|
||||
case isComment(runes):
|
||||
_, n, err = newCommentToken(runes)
|
||||
case isNewline(runes):
|
||||
_, n, err = newNewlineToken(runes)
|
||||
case isSep(runes):
|
||||
_, n, err = newSepToken(runes)
|
||||
case isOp(runes):
|
||||
_, n, err = newOpToken(runes)
|
||||
default:
|
||||
_, n, err = newLitToken(runes)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
count++
|
||||
runes = runes[n:]
|
||||
}
|
||||
|
||||
return count + 1
|
||||
}
|
||||
|
||||
// Token holds metadata about a given value.
|
||||
type Token struct {
|
||||
t TokenType
|
||||
ValueType ValueType
|
||||
base int
|
||||
raw []rune
|
||||
}
|
||||
|
||||
var emptyValue = Value{}
|
||||
|
||||
func newToken(t TokenType, raw []rune, v ValueType) Token {
|
||||
return Token{
|
||||
t: t,
|
||||
raw: raw,
|
||||
ValueType: v,
|
||||
}
|
||||
}
|
||||
|
||||
// Raw returns the raw runes that were consumed.
|
||||
func (tok Token) Raw() []rune {
|
||||
return tok.raw
|
||||
}
|
||||
|
||||
// Type returns the token type
|
||||
func (tok Token) Type() TokenType {
|
||||
return tok.t
|
||||
}
|
|
@ -0,0 +1,356 @@
|
|||
package ini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// State enums for the parse table
|
||||
const (
|
||||
InvalidState = iota
|
||||
// stmt -> value stmt'
|
||||
StatementState
|
||||
// stmt' -> MarkComplete | op stmt
|
||||
StatementPrimeState
|
||||
// value -> number | string | boolean | quoted_string
|
||||
ValueState
|
||||
// section -> [ section'
|
||||
OpenScopeState
|
||||
// section' -> value section_close
|
||||
SectionState
|
||||
// section_close -> ]
|
||||
CloseScopeState
|
||||
// SkipState will skip (NL WS)+
|
||||
SkipState
|
||||
// SkipTokenState will skip any token and push the previous
|
||||
// state onto the stack.
|
||||
SkipTokenState
|
||||
// comment -> # comment' | ; comment'
|
||||
// comment' -> MarkComplete | value
|
||||
CommentState
|
||||
// MarkComplete state will complete statements and move that
|
||||
// to the completed AST list
|
||||
MarkCompleteState
|
||||
// TerminalState signifies that the tokens have been fully parsed
|
||||
TerminalState
|
||||
)
|
||||
|
||||
// parseTable is a state machine to dictate the grammar above.
|
||||
var parseTable = map[ASTKind]map[TokenType]int{
|
||||
ASTKindStart: map[TokenType]int{
|
||||
TokenLit: StatementState,
|
||||
TokenSep: OpenScopeState,
|
||||
TokenWS: SkipTokenState,
|
||||
TokenNL: SkipTokenState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: TerminalState,
|
||||
},
|
||||
ASTKindCommentStatement: map[TokenType]int{
|
||||
TokenLit: StatementState,
|
||||
TokenSep: OpenScopeState,
|
||||
TokenWS: SkipTokenState,
|
||||
TokenNL: SkipTokenState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: MarkCompleteState,
|
||||
},
|
||||
ASTKindExpr: map[TokenType]int{
|
||||
TokenOp: StatementPrimeState,
|
||||
TokenLit: ValueState,
|
||||
TokenSep: OpenScopeState,
|
||||
TokenWS: ValueState,
|
||||
TokenNL: SkipState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: MarkCompleteState,
|
||||
},
|
||||
ASTKindEqualExpr: map[TokenType]int{
|
||||
TokenLit: ValueState,
|
||||
TokenWS: SkipTokenState,
|
||||
TokenNL: SkipState,
|
||||
},
|
||||
ASTKindStatement: map[TokenType]int{
|
||||
TokenLit: SectionState,
|
||||
TokenSep: CloseScopeState,
|
||||
TokenWS: SkipTokenState,
|
||||
TokenNL: SkipTokenState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: MarkCompleteState,
|
||||
},
|
||||
ASTKindExprStatement: map[TokenType]int{
|
||||
TokenLit: ValueState,
|
||||
TokenSep: OpenScopeState,
|
||||
TokenOp: ValueState,
|
||||
TokenWS: ValueState,
|
||||
TokenNL: MarkCompleteState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: TerminalState,
|
||||
TokenComma: SkipState,
|
||||
},
|
||||
ASTKindSectionStatement: map[TokenType]int{
|
||||
TokenLit: SectionState,
|
||||
TokenOp: SectionState,
|
||||
TokenSep: CloseScopeState,
|
||||
TokenWS: SectionState,
|
||||
TokenNL: SkipTokenState,
|
||||
},
|
||||
ASTKindCompletedSectionStatement: map[TokenType]int{
|
||||
TokenWS: SkipTokenState,
|
||||
TokenNL: SkipTokenState,
|
||||
TokenLit: StatementState,
|
||||
TokenSep: OpenScopeState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: MarkCompleteState,
|
||||
},
|
||||
ASTKindSkipStatement: map[TokenType]int{
|
||||
TokenLit: StatementState,
|
||||
TokenSep: OpenScopeState,
|
||||
TokenWS: SkipTokenState,
|
||||
TokenNL: SkipTokenState,
|
||||
TokenComment: CommentState,
|
||||
TokenNone: TerminalState,
|
||||
},
|
||||
}
|
||||
|
||||
// ParseAST will parse input from an io.Reader using
|
||||
// an LL(1) parser.
|
||||
func ParseAST(r io.Reader) ([]AST, error) {
|
||||
lexer := iniLexer{}
|
||||
tokens, err := lexer.Tokenize(r)
|
||||
if err != nil {
|
||||
return []AST{}, err
|
||||
}
|
||||
|
||||
return parse(tokens)
|
||||
}
|
||||
|
||||
// ParseASTBytes will parse input from a byte slice using
|
||||
// an LL(1) parser.
|
||||
func ParseASTBytes(b []byte) ([]AST, error) {
|
||||
lexer := iniLexer{}
|
||||
tokens, err := lexer.tokenize(b)
|
||||
if err != nil {
|
||||
return []AST{}, err
|
||||
}
|
||||
|
||||
return parse(tokens)
|
||||
}
|
||||
|
||||
func parse(tokens []Token) ([]AST, error) {
|
||||
start := Start
|
||||
stack := newParseStack(3, len(tokens))
|
||||
|
||||
stack.Push(start)
|
||||
s := newSkipper()
|
||||
|
||||
loop:
|
||||
for stack.Len() > 0 {
|
||||
k := stack.Pop()
|
||||
|
||||
var tok Token
|
||||
if len(tokens) == 0 {
|
||||
// this occurs when all the tokens have been processed
|
||||
// but reduction of what's left on the stack needs to
|
||||
// occur.
|
||||
tok = emptyToken
|
||||
} else {
|
||||
tok = tokens[0]
|
||||
}
|
||||
|
||||
step := parseTable[k.Kind][tok.Type()]
|
||||
if s.ShouldSkip(tok) {
|
||||
// being in a skip state with no tokens will break out of
|
||||
// the parse loop since there is nothing left to process.
|
||||
if len(tokens) == 0 {
|
||||
break loop
|
||||
}
|
||||
// if should skip is true, we skip the tokens until should skip is set to false.
|
||||
step = SkipTokenState
|
||||
}
|
||||
|
||||
switch step {
|
||||
case TerminalState:
|
||||
// Finished parsing. Push what should be the last
|
||||
// statement to the stack. If there is anything left
|
||||
// on the stack, an error in parsing has occurred.
|
||||
if k.Kind != ASTKindStart {
|
||||
stack.MarkComplete(k)
|
||||
}
|
||||
break loop
|
||||
case SkipTokenState:
|
||||
// When skipping a token, the previous state was popped off the stack.
|
||||
// To maintain the correct state, the previous state will be pushed
|
||||
// onto the stack.
|
||||
stack.Push(k)
|
||||
case StatementState:
|
||||
if k.Kind != ASTKindStart {
|
||||
stack.MarkComplete(k)
|
||||
}
|
||||
expr := newExpression(tok)
|
||||
stack.Push(expr)
|
||||
case StatementPrimeState:
|
||||
if tok.Type() != TokenOp {
|
||||
stack.MarkComplete(k)
|
||||
continue
|
||||
}
|
||||
|
||||
if k.Kind != ASTKindExpr {
|
||||
return nil, NewParseError(
|
||||
fmt.Sprintf("invalid expression: expected Expr type, but found %T type", k),
|
||||
)
|
||||
}
|
||||
|
||||
k = trimSpaces(k)
|
||||
expr := newEqualExpr(k, tok)
|
||||
stack.Push(expr)
|
||||
case ValueState:
|
||||
// ValueState requires the previous state to either be an equal expression
|
||||
// or an expression statement.
|
||||
//
|
||||
// This grammar occurs when the RHS is a number, word, or quoted string.
|
||||
// equal_expr -> lit op equal_expr'
|
||||
// equal_expr' -> number | string | quoted_string
|
||||
// quoted_string -> " quoted_string'
|
||||
// quoted_string' -> string quoted_string_end
|
||||
// quoted_string_end -> "
|
||||
//
|
||||
// otherwise
|
||||
// expr_stmt -> equal_expr (expr_stmt')*
|
||||
// expr_stmt' -> ws S | op S | MarkComplete
|
||||
// S -> equal_expr' expr_stmt'
|
||||
switch k.Kind {
|
||||
case ASTKindEqualExpr:
|
||||
// assigning a value to some key
|
||||
k.AppendChild(newExpression(tok))
|
||||
stack.Push(newExprStatement(k))
|
||||
case ASTKindExpr:
|
||||
k.Root.raw = append(k.Root.raw, tok.Raw()...)
|
||||
stack.Push(k)
|
||||
case ASTKindExprStatement:
|
||||
root := k.GetRoot()
|
||||
children := root.GetChildren()
|
||||
if len(children) == 0 {
|
||||
return nil, NewParseError(
|
||||
fmt.Sprintf("invalid expression: AST contains no children %s", k.Kind),
|
||||
)
|
||||
}
|
||||
|
||||
rhs := children[len(children)-1]
|
||||
|
||||
if rhs.Root.ValueType != QuotedStringType {
|
||||
rhs.Root.ValueType = StringType
|
||||
rhs.Root.raw = append(rhs.Root.raw, tok.Raw()...)
|
||||
|
||||
}
|
||||
|
||||
children[len(children)-1] = rhs
|
||||
k.SetChildren(children)
|
||||
|
||||
stack.Push(k)
|
||||
}
|
||||
case OpenScopeState:
|
||||
if !runeCompare(tok.Raw(), openBrace) {
|
||||
return nil, NewParseError("expected '['")
|
||||
}
|
||||
// If OpenScopeState is not at the start, we must mark the previous ast as complete
|
||||
//
|
||||
// for example: if previous ast was a skip statement;
|
||||
// we should mark it as complete before we create a new statement
|
||||
if k.Kind != ASTKindStart {
|
||||
stack.MarkComplete(k)
|
||||
}
|
||||
|
||||
stmt := newStatement()
|
||||
stack.Push(stmt)
|
||||
case CloseScopeState:
|
||||
if !runeCompare(tok.Raw(), closeBrace) {
|
||||
return nil, NewParseError("expected ']'")
|
||||
}
|
||||
|
||||
k = trimSpaces(k)
|
||||
stack.Push(newCompletedSectionStatement(k))
|
||||
case SectionState:
|
||||
var stmt AST
|
||||
|
||||
switch k.Kind {
|
||||
case ASTKindStatement:
|
||||
// If there are multiple literals inside of a scope declaration,
|
||||
// then the current token's raw value will be appended to the Name.
|
||||
//
|
||||
// This handles cases like [ profile default ]
|
||||
//
|
||||
// k will represent a SectionStatement with the children representing
|
||||
// the label of the section
|
||||
stmt = newSectionStatement(tok)
|
||||
case ASTKindSectionStatement:
|
||||
k.Root.raw = append(k.Root.raw, tok.Raw()...)
|
||||
stmt = k
|
||||
default:
|
||||
return nil, NewParseError(
|
||||
fmt.Sprintf("invalid statement: expected statement: %v", k.Kind),
|
||||
)
|
||||
}
|
||||
|
||||
stack.Push(stmt)
|
||||
case MarkCompleteState:
|
||||
if k.Kind != ASTKindStart {
|
||||
stack.MarkComplete(k)
|
||||
}
|
||||
|
||||
if stack.Len() == 0 {
|
||||
stack.Push(start)
|
||||
}
|
||||
case SkipState:
|
||||
stack.Push(newSkipStatement(k))
|
||||
s.Skip()
|
||||
case CommentState:
|
||||
if k.Kind == ASTKindStart {
|
||||
stack.Push(k)
|
||||
} else {
|
||||
stack.MarkComplete(k)
|
||||
}
|
||||
|
||||
stmt := newCommentStatement(tok)
|
||||
stack.Push(stmt)
|
||||
default:
|
||||
return nil, NewParseError(
|
||||
fmt.Sprintf("invalid state with ASTKind %v and TokenType %v",
|
||||
k, tok.Type()))
|
||||
}
|
||||
|
||||
if len(tokens) > 0 {
|
||||
tokens = tokens[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// this occurs when a statement has not been completed
|
||||
if stack.top > 1 {
|
||||
return nil, NewParseError(fmt.Sprintf("incomplete ini expression"))
|
||||
}
|
||||
|
||||
// returns a sublist which excludes the start symbol
|
||||
return stack.List(), nil
|
||||
}
|
||||
|
||||
// trimSpaces will trim spaces on the left and right hand side of
|
||||
// the literal.
|
||||
func trimSpaces(k AST) AST {
|
||||
// trim left hand side of spaces
|
||||
for i := 0; i < len(k.Root.raw); i++ {
|
||||
if !isWhitespace(k.Root.raw[i]) {
|
||||
break
|
||||
}
|
||||
|
||||
k.Root.raw = k.Root.raw[1:]
|
||||
i--
|
||||
}
|
||||
|
||||
// trim right hand side of spaces
|
||||
for i := len(k.Root.raw) - 1; i >= 0; i-- {
|
||||
if !isWhitespace(k.Root.raw[i]) {
|
||||
break
|
||||
}
|
||||
|
||||
k.Root.raw = k.Root.raw[:len(k.Root.raw)-1]
|
||||
}
|
||||
|
||||
return k
|
||||
}
|
|
@ -0,0 +1,324 @@
|
|||
package ini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
runesTrue = []rune("true")
|
||||
runesFalse = []rune("false")
|
||||
)
|
||||
|
||||
var literalValues = [][]rune{
|
||||
runesTrue,
|
||||
runesFalse,
|
||||
}
|
||||
|
||||
func isBoolValue(b []rune) bool {
|
||||
for _, lv := range literalValues {
|
||||
if isLitValue(lv, b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isLitValue(want, have []rune) bool {
|
||||
if len(have) < len(want) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(want); i++ {
|
||||
if want[i] != have[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNumberValue will return whether or not the leading characters in
// a rune slice are a number. A number is delimited by whitespace or
|
||||
// the newline token.
|
||||
//
|
||||
// A number is defined to be in a binary, octal, decimal (int | float), hex format,
|
||||
// or in scientific notation.
|
||||
func isNumberValue(b []rune) bool {
|
||||
negativeIndex := 0
|
||||
helper := numberHelper{}
|
||||
needDigit := false
|
||||
|
||||
for i := 0; i < len(b); i++ {
|
||||
negativeIndex++
|
||||
|
||||
switch b[i] {
|
||||
case '-':
|
||||
if helper.IsNegative() || negativeIndex != 1 {
|
||||
return false
|
||||
}
|
||||
helper.Determine(b[i])
|
||||
needDigit = true
|
||||
continue
|
||||
case 'e', 'E':
|
||||
if err := helper.Determine(b[i]); err != nil {
|
||||
return false
|
||||
}
|
||||
negativeIndex = 0
|
||||
needDigit = true
|
||||
continue
|
||||
case 'b':
|
||||
if helper.numberFormat == hex {
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
case 'o', 'x':
|
||||
needDigit = true
|
||||
if i == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case '.':
|
||||
if err := helper.Determine(b[i]); err != nil {
|
||||
return false
|
||||
}
|
||||
needDigit = true
|
||||
continue
|
||||
}
|
||||
|
||||
if i > 0 && (isNewline(b[i:]) || isWhitespace(b[i])) {
|
||||
return !needDigit
|
||||
}
|
||||
|
||||
if !helper.CorrectByte(b[i]) {
|
||||
return false
|
||||
}
|
||||
needDigit = false
|
||||
}
|
||||
|
||||
return !needDigit
|
||||
}
|
||||
|
||||
func isValid(b []rune) (bool, int, error) {
|
||||
if len(b) == 0 {
|
||||
// TODO: should probably return an error
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
return isValidRune(b[0]), 1, nil
|
||||
}
|
||||
|
||||
func isValidRune(r rune) bool {
|
||||
return r != ':' && r != '=' && r != '[' && r != ']' && r != ' ' && r != '\n'
|
||||
}
|
||||
|
||||
// ValueType is an enum that will signify what type
|
||||
// the Value is
|
||||
type ValueType int
|
||||
|
||||
func (v ValueType) String() string {
|
||||
switch v {
|
||||
case NoneType:
|
||||
return "NONE"
|
||||
case DecimalType:
|
||||
return "FLOAT"
|
||||
case IntegerType:
|
||||
return "INT"
|
||||
case StringType:
|
||||
return "STRING"
|
||||
case BoolType:
|
||||
return "BOOL"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ValueType enums
|
||||
const (
|
||||
NoneType = ValueType(iota)
|
||||
DecimalType
|
||||
IntegerType
|
||||
StringType
|
||||
QuotedStringType
|
||||
BoolType
|
||||
)
|
||||
|
||||
// Value is a union container
|
||||
type Value struct {
|
||||
Type ValueType
|
||||
raw []rune
|
||||
|
||||
integer int64
|
||||
decimal float64
|
||||
boolean bool
|
||||
str string
|
||||
}
|
||||
|
||||
func newValue(t ValueType, base int, raw []rune) (Value, error) {
|
||||
v := Value{
|
||||
Type: t,
|
||||
raw: raw,
|
||||
}
|
||||
var err error
|
||||
|
||||
switch t {
|
||||
case DecimalType:
|
||||
v.decimal, err = strconv.ParseFloat(string(raw), 64)
|
||||
case IntegerType:
|
||||
if base != 10 {
|
||||
raw = raw[2:]
|
||||
}
|
||||
|
||||
v.integer, err = strconv.ParseInt(string(raw), base, 64)
|
||||
case StringType:
|
||||
v.str = string(raw)
|
||||
case QuotedStringType:
|
||||
v.str = string(raw[1 : len(raw)-1])
|
||||
case BoolType:
|
||||
v.boolean = runeCompare(v.raw, runesTrue)
|
||||
}
|
||||
|
||||
// issue 2253
|
||||
//
|
||||
// if the value trying to be parsed is too large, then we will use
|
||||
// the 'StringType' and raw value instead.
|
||||
if nerr, ok := err.(*strconv.NumError); ok && nerr.Err == strconv.ErrRange {
|
||||
v.Type = StringType
|
||||
v.str = string(raw)
|
||||
err = nil
|
||||
}
|
||||
|
||||
return v, err
|
||||
}
|
||||
|
||||
// Append will append values and change the type to a string
|
||||
// type.
|
||||
func (v *Value) Append(tok Token) {
|
||||
r := tok.Raw()
|
||||
if v.Type != QuotedStringType {
|
||||
v.Type = StringType
|
||||
r = tok.raw[1 : len(tok.raw)-1]
|
||||
}
|
||||
if tok.Type() != TokenLit {
|
||||
v.raw = append(v.raw, tok.Raw()...)
|
||||
} else {
|
||||
v.raw = append(v.raw, r...)
|
||||
}
|
||||
}
|
||||
|
||||
func (v Value) String() string {
|
||||
switch v.Type {
|
||||
case DecimalType:
|
||||
return fmt.Sprintf("decimal: %f", v.decimal)
|
||||
case IntegerType:
|
||||
return fmt.Sprintf("integer: %d", v.integer)
|
||||
case StringType:
|
||||
return fmt.Sprintf("string: %s", string(v.raw))
|
||||
case QuotedStringType:
|
||||
return fmt.Sprintf("quoted string: %s", string(v.raw))
|
||||
case BoolType:
|
||||
return fmt.Sprintf("bool: %t", v.boolean)
|
||||
default:
|
||||
return "union not set"
|
||||
}
|
||||
}
|
||||
|
||||
func newLitToken(b []rune) (Token, int, error) {
|
||||
n := 0
|
||||
var err error
|
||||
|
||||
token := Token{}
|
||||
if b[0] == '"' {
|
||||
n, err = getStringValue(b)
|
||||
if err != nil {
|
||||
return token, n, err
|
||||
}
|
||||
|
||||
token = newToken(TokenLit, b[:n], QuotedStringType)
|
||||
} else if isNumberValue(b) {
|
||||
var base int
|
||||
base, n, err = getNumericalValue(b)
|
||||
if err != nil {
|
||||
return token, 0, err
|
||||
}
|
||||
|
||||
value := b[:n]
|
||||
vType := IntegerType
|
||||
if contains(value, '.') || hasExponent(value) {
|
||||
vType = DecimalType
|
||||
}
|
||||
token = newToken(TokenLit, value, vType)
|
||||
token.base = base
|
||||
} else if isBoolValue(b) {
|
||||
n, err = getBoolValue(b)
|
||||
|
||||
token = newToken(TokenLit, b[:n], BoolType)
|
||||
} else {
|
||||
n, err = getValue(b)
|
||||
token = newToken(TokenLit, b[:n], StringType)
|
||||
}
|
||||
|
||||
return token, n, err
|
||||
}
|
||||
|
||||
// IntValue returns an integer value
|
||||
func (v Value) IntValue() int64 {
|
||||
return v.integer
|
||||
}
|
||||
|
||||
// FloatValue returns a float value
|
||||
func (v Value) FloatValue() float64 {
|
||||
return v.decimal
|
||||
}
|
||||
|
||||
// BoolValue returns a bool value
|
||||
func (v Value) BoolValue() bool {
|
||||
return v.boolean
|
||||
}
|
||||
|
||||
func isTrimmable(r rune) bool {
|
||||
switch r {
|
||||
case '\n', ' ':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StringValue returns the string value
|
||||
func (v Value) StringValue() string {
|
||||
switch v.Type {
|
||||
case StringType:
|
||||
return strings.TrimFunc(string(v.raw), isTrimmable)
|
||||
case QuotedStringType:
|
||||
// preserve all characters in the quotes
|
||||
return string(removeEscapedCharacters(v.raw[1 : len(v.raw)-1]))
|
||||
default:
|
||||
return strings.TrimFunc(string(v.raw), isTrimmable)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(runes []rune, c rune) bool {
|
||||
for i := 0; i < len(runes); i++ {
|
||||
if runes[i] == c {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func runeCompare(v1 []rune, v2 []rune) bool {
|
||||
if len(v1) != len(v2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(v1); i++ {
|
||||
if v1[i] != v2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
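A test-style sketch of the issue 2253 fallback handled in newValue above: integer literals that overflow int64 degrade to StringType instead of failing the parse. The test itself is hypothetical:

package ini

import "testing"

func TestOverflowIntegerFallsBackToString(t *testing.T) {
	raw := []rune("123456789012345678901234567890") // larger than int64

	v, err := newValue(IntegerType, 10, raw)
	if err != nil {
		t.Fatal(err)
	}
	if v.Type != StringType || v.StringValue() != string(raw) {
		t.Fatalf("got type %v, value %q", v.Type, v.StringValue())
	}
}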
@ -0,0 +1,30 @@
|
|||
package ini
|
||||
|
||||
func isNewline(b []rune) bool {
|
||||
if len(b) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if b[0] == '\n' {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(b) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
return b[0] == '\r' && b[1] == '\n'
|
||||
}
|
||||
|
||||
func newNewlineToken(b []rune) (Token, int, error) {
|
||||
i := 1
|
||||
if b[0] == '\r' && isNewline(b[1:]) {
|
||||
i++
|
||||
}
|
||||
|
||||
if !isNewline([]rune(b[:i])) {
|
||||
return emptyToken, 0, NewParseError("invalid new line token")
|
||||
}
|
||||
|
||||
return newToken(TokenNL, b[:i], NoneType), i, nil
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
package ini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
none = numberFormat(iota)
|
||||
binary
|
||||
octal
|
||||
decimal
|
||||
hex
|
||||
exponent
|
||||
)
|
||||
|
||||
type numberFormat int
|
||||
|
||||
// numberHelper is used to dictate what format a number is in
|
||||
// and what to do for negative values. Since -1e-4 is a valid
|
||||
// number, we cannot simply check for duplicate negatives.
|
||||
type numberHelper struct {
|
||||
numberFormat numberFormat
|
||||
|
||||
negative bool
|
||||
negativeExponent bool
|
||||
}
|
||||
|
||||
func (b numberHelper) Exists() bool {
|
||||
return b.numberFormat != none
|
||||
}
|
||||
|
||||
func (b numberHelper) IsNegative() bool {
|
||||
return b.negative || b.negativeExponent
|
||||
}
|
||||
|
||||
func (b *numberHelper) Determine(c rune) error {
|
||||
if b.Exists() {
|
||||
return NewParseError(fmt.Sprintf("multiple number formats: 0%v", string(c)))
|
||||
}
|
||||
|
||||
switch c {
|
||||
case 'b':
|
||||
b.numberFormat = binary
|
||||
case 'o':
|
||||
b.numberFormat = octal
|
||||
case 'x':
|
||||
b.numberFormat = hex
|
||||
case 'e', 'E':
|
||||
b.numberFormat = exponent
|
||||
case '-':
|
||||
if b.numberFormat != exponent {
|
||||
b.negative = true
|
||||
} else {
|
||||
b.negativeExponent = true
|
||||
}
|
||||
case '.':
|
||||
b.numberFormat = decimal
|
||||
default:
|
||||
return NewParseError(fmt.Sprintf("invalid number character: %v", string(c)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b numberHelper) CorrectByte(c rune) bool {
|
||||
switch {
|
||||
case b.numberFormat == binary:
|
||||
if !isBinaryByte(c) {
|
||||
return false
|
||||
}
|
||||
case b.numberFormat == octal:
|
||||
if !isOctalByte(c) {
|
||||
return false
|
||||
}
|
||||
case b.numberFormat == hex:
|
||||
if !isHexByte(c) {
|
||||
return false
|
||||
}
|
||||
case b.numberFormat == decimal:
|
||||
if !isDigit(c) {
|
||||
return false
|
||||
}
|
||||
case b.numberFormat == exponent:
|
||||
if !isDigit(c) {
|
||||
return false
|
||||
}
|
||||
case b.negativeExponent:
|
||||
if !isDigit(c) {
|
||||
return false
|
||||
}
|
||||
case b.negative:
|
||||
if !isDigit(c) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (b numberHelper) Base() int {
|
||||
switch b.numberFormat {
|
||||
case binary:
|
||||
return 2
|
||||
case octal:
|
||||
return 8
|
||||
case hex:
|
||||
return 16
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
}
|
||||
|
||||
func (b numberHelper) String() string {
|
||||
buf := bytes.Buffer{}
|
||||
i := 0
|
||||
|
||||
switch b.numberFormat {
|
||||
case binary:
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": binary format\n")
|
||||
case octal:
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": octal format\n")
|
||||
case hex:
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": hex format\n")
|
||||
case exponent:
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": exponent format\n")
|
||||
default:
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": integer format\n")
|
||||
}
|
||||
|
||||
if b.negative {
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": negative format\n")
|
||||
}
|
||||
|
||||
if b.negativeExponent {
|
||||
i++
|
||||
buf.WriteString(strconv.Itoa(i) + ": negative exponent format\n")
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package ini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
equalOp = []rune("=")
|
||||
equalColonOp = []rune(":")
|
||||
)
|
||||
|
||||
func isOp(b []rune) bool {
|
||||
if len(b) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
switch b[0] {
|
||||
case '=':
|
||||
return true
|
||||
case ':':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func newOpToken(b []rune) (Token, int, error) {
|
||||
tok := Token{}
|
||||
|
||||
switch b[0] {
|
||||
case '=':
|
||||
tok = newToken(TokenOp, equalOp, NoneType)
|
||||
case ':':
|
||||
tok = newToken(TokenOp, equalColonOp, NoneType)
|
||||
default:
|
||||
return tok, 0, NewParseError(fmt.Sprintf("unexpected op type, %v", b[0]))
|
||||
}
|
||||
return tok, 1, nil
|
||||
}
|