mirror of https://github.com/status-im/consul.git
connect: support AWS PCA as a CA provider (#6189)
Port AWS PCA provider from consul-ent
This commit is contained in:
parent
2552f4a11a
commit
3497b7c00d
|
@ -363,7 +363,7 @@ test-envoy-integ: $(ENVOY_INTEG_DEPS)
|
||||||
@$(SHELL) $(CURDIR)/test/integration/connect/envoy/run-tests.sh
|
@$(SHELL) $(CURDIR)/test/integration/connect/envoy/run-tests.sh
|
||||||
|
|
||||||
proto:
|
proto:
|
||||||
protoc agent/connect/ca/plugin/*.proto --gofast_out=plugins=grpc:../../..
|
protoc agent/connect/ca/plugin/*.proto --gofast_out=plugins=grpc,Mgoogle/protobuf/duration.proto=github.com/gogo/protobuf/types:../../..
|
||||||
|
|
||||||
.PHONY: all ci bin dev dist cov test test-ci test-internal test-install-deps cover format vet ui static-assets tools
|
.PHONY: all ci bin dev dist cov test test-ci test-internal test-install-deps cover format vet ui static-assets tools
|
||||||
.PHONY: docker-images go-build-image ui-build-image static-assets-docker consul-docker ui-docker
|
.PHONY: docker-images go-build-image ui-build-image static-assets-docker consul-docker ui-docker
|
||||||
|
|
|
@ -527,7 +527,8 @@ func (c *ConnectCALeaf) generateNewLeaf(req *ConnectCALeafRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a CSR.
|
// Create a CSR.
|
||||||
csr, err := connect.CreateCSR(id, pk)
|
serviceName := fmt.Sprintf("%s.service.%s.%s", req.Service, req.Datacenter, "consul")
|
||||||
|
csr, err := connect.CreateCSR(serviceName, id, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -584,6 +584,17 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
||||||
"tls_server_name": "TLSServerName",
|
"tls_server_name": "TLSServerName",
|
||||||
"tls_skip_verify": "TLSSkipVerify",
|
"tls_skip_verify": "TLSSkipVerify",
|
||||||
|
|
||||||
|
// AWS ACM PCA config
|
||||||
|
"access_key_id": "AccessKeyID",
|
||||||
|
"secret_access_key": "SecretAccessKey",
|
||||||
|
"region": "Region",
|
||||||
|
"sleep_time": "SleepTime",
|
||||||
|
"root_arn": "RootARN",
|
||||||
|
"intermediate_arn": "IntermediateTemplateARN",
|
||||||
|
"key_algorithm": "KeyAlgorithm",
|
||||||
|
"signing_algorithm": "SigningAlgorithm",
|
||||||
|
"delete_on_exit": "DeleteOnExit",
|
||||||
|
|
||||||
// Common CA config
|
// Common CA config
|
||||||
"leaf_cert_ttl": "LeafCertTTL",
|
"leaf_cert_ttl": "LeafCertTTL",
|
||||||
"csr_max_per_second": "CSRMaxPerSecond",
|
"csr_max_per_second": "CSRMaxPerSecond",
|
||||||
|
@ -1099,6 +1110,7 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
|
||||||
"": true,
|
"": true,
|
||||||
structs.ConsulCAProvider: true,
|
structs.ConsulCAProvider: true,
|
||||||
structs.VaultCAProvider: true,
|
structs.VaultCAProvider: true,
|
||||||
|
structs.AWSCAProvider: true,
|
||||||
}
|
}
|
||||||
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
|
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
|
||||||
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
|
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
|
||||||
|
@ -1112,6 +1124,10 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
|
||||||
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
|
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case structs.AWSCAProvider:
|
||||||
|
if _, err := ca.ParseAWSCAConfig(rt.ConnectCAConfig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,438 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
"github.com/aws/aws-sdk-go/service/acmpca"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RootTemplateARN = "arn:aws:acm-pca:::template/RootCACertificate/V1"
|
||||||
|
IntermediateTemplateARN = "arn:aws:acm-pca:::template/SubordinateCACertificate_PathLen0/V1"
|
||||||
|
LeafTemplateARN = "arn:aws:acm-pca:::template/EndEntityCertificate/V1"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RootValidity = 5 * 365 * 24 * time.Hour
|
||||||
|
IntermediateValidity = 1 * 365 * 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type AmazonPCA struct {
|
||||||
|
arn string
|
||||||
|
pcaType string
|
||||||
|
keyAlgorithm string
|
||||||
|
signingAlgorithm string
|
||||||
|
sleepTime time.Duration
|
||||||
|
certPEM string
|
||||||
|
chainPEM string
|
||||||
|
client *acmpca.ACMPCA
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, arn string, pcaType string,
|
||||||
|
keyAlgorithm string, signingAlgorithm string, logger *log.Logger) *AmazonPCA {
|
||||||
|
sleepTime, err := time.ParseDuration(config.SleepTime)
|
||||||
|
if err != nil {
|
||||||
|
sleepTime = 5 * time.Second
|
||||||
|
}
|
||||||
|
return &AmazonPCA{arn: arn, pcaType: pcaType, client: client, logger: logger,
|
||||||
|
keyAlgorithm: keyAlgorithm, signingAlgorithm: signingAlgorithm, sleepTime: sleepTime}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, arn string, pcaType string,
|
||||||
|
clusterId string, keyAlgorithm string, signingAlgorithm string, logger *log.Logger) (*AmazonPCA, error) {
|
||||||
|
pca := NewAmazonPCA(client, config, arn, pcaType, keyAlgorithm, signingAlgorithm, logger)
|
||||||
|
output, err := pca.describe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
warned := false
|
||||||
|
if *output.CertificateAuthority.Status != acmpca.CertificateAuthorityStatusActive {
|
||||||
|
return nil, fmt.Errorf("the specified PCA is not active: status is %s",
|
||||||
|
*output.CertificateAuthority.Status)
|
||||||
|
}
|
||||||
|
if *output.CertificateAuthority.CertificateAuthorityConfiguration.Subject.CommonName !=
|
||||||
|
GetCommonName(pcaType, clusterId) {
|
||||||
|
logger.Printf("[WARN] name of specified PCA '%s' does not match expected '%s'",
|
||||||
|
*output.CertificateAuthority.CertificateAuthorityConfiguration.Subject.CommonName,
|
||||||
|
GetCommonName(pcaType, clusterId))
|
||||||
|
warned = true
|
||||||
|
}
|
||||||
|
if *output.CertificateAuthority.CertificateAuthorityConfiguration.KeyAlgorithm != keyAlgorithm {
|
||||||
|
logger.Printf("[WARN] specified PCA is using an unexpected key algorithm: expected=%s actual=%s",
|
||||||
|
keyAlgorithm, *output.CertificateAuthority.CertificateAuthorityConfiguration.KeyAlgorithm)
|
||||||
|
warned = true
|
||||||
|
}
|
||||||
|
if *output.CertificateAuthority.CertificateAuthorityConfiguration.SigningAlgorithm != signingAlgorithm {
|
||||||
|
logger.Printf("[WARN] specified PCA is using an unexpected signing algorithm: expected=%s actual=%s",
|
||||||
|
signingAlgorithm, *output.CertificateAuthority.CertificateAuthorityConfiguration.SigningAlgorithm)
|
||||||
|
warned = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if warned {
|
||||||
|
logger.Print("[WARN] existing PCA failed some preflight checks, trying to continue anyway")
|
||||||
|
} else {
|
||||||
|
logger.Print("[WARN] existing PCA passed all preflight checks")
|
||||||
|
}
|
||||||
|
return pca, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, pcaType string, clusterId string,
|
||||||
|
keyAlgorithm string, signingAlgorithm string, logger *log.Logger) (*AmazonPCA, error) {
|
||||||
|
var name string = GetCommonName(pcaType, clusterId)
|
||||||
|
var nextToken *string
|
||||||
|
for {
|
||||||
|
input := acmpca.ListCertificateAuthoritiesInput{
|
||||||
|
MaxResults: aws.Int64(100),
|
||||||
|
NextToken: nextToken,
|
||||||
|
}
|
||||||
|
logger.Print("[DEBUG] listing existing certificate authorities")
|
||||||
|
output, err := client.ListCertificateAuthorities(&input)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("[ERR] error searching certificate authorities: %s", err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ca := range output.CertificateAuthorities {
|
||||||
|
if *ca.CertificateAuthorityConfiguration.Subject.CommonName == name &&
|
||||||
|
*ca.CertificateAuthorityConfiguration.KeyAlgorithm == keyAlgorithm &&
|
||||||
|
*ca.CertificateAuthorityConfiguration.SigningAlgorithm == signingAlgorithm &&
|
||||||
|
*ca.Status == acmpca.CertificateAuthorityStatusActive {
|
||||||
|
logger.Printf("[INFO] found an existing active CA %s", *ca.Arn)
|
||||||
|
return NewAmazonPCA(client, config, *ca.Arn, pcaType, *ca.CertificateAuthorityConfiguration.KeyAlgorithm,
|
||||||
|
*ca.CertificateAuthorityConfiguration.SigningAlgorithm, logger), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextToken = output.NextToken
|
||||||
|
if nextToken == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Print("[WARN] no existing active CA found")
|
||||||
|
return nil, nil // not found
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateAmazonPCA(client *acmpca.ACMPCA, config *structs.AWSCAProviderConfig, pcaType string, clusterId string,
|
||||||
|
keyAlgorithm string, signingAlgorithm string, logger *log.Logger) (*AmazonPCA, error) {
|
||||||
|
commonName := GetCommonName(pcaType, clusterId)
|
||||||
|
createInput := acmpca.CreateCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityType: aws.String(pcaType),
|
||||||
|
CertificateAuthorityConfiguration: &acmpca.CertificateAuthorityConfiguration{
|
||||||
|
Subject: &acmpca.ASN1Subject{
|
||||||
|
CommonName: aws.String(commonName),
|
||||||
|
},
|
||||||
|
KeyAlgorithm: aws.String(keyAlgorithm),
|
||||||
|
SigningAlgorithm: aws.String(signingAlgorithm),
|
||||||
|
},
|
||||||
|
RevocationConfiguration: &acmpca.RevocationConfiguration{
|
||||||
|
CrlConfiguration: &acmpca.CrlConfiguration{
|
||||||
|
Enabled: aws.Bool(false),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tags: []*acmpca.Tag{
|
||||||
|
{Key: aws.String("ClusterID"), Value: aws.String(clusterId)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Printf("[DEBUG] creating new PCA %s", commonName)
|
||||||
|
createOutput, err := client.CreateCertificateAuthority(&createInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for PCA to be created
|
||||||
|
newARN := *createOutput.CertificateAuthorityArn
|
||||||
|
for {
|
||||||
|
logger.Printf("[DEBUG] checking to see if PCA %s is ready", newARN)
|
||||||
|
|
||||||
|
describeInput := acmpca.DescribeCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(newARN),
|
||||||
|
}
|
||||||
|
|
||||||
|
describeOutput, err := client.DescribeCertificateAuthority(&describeInput)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("[ERR] error describing PCA: %s", err.Error())
|
||||||
|
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
|
||||||
|
return nil, fmt.Errorf("error waiting for PCA to be created: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if *describeOutput.CertificateAuthority.Status == acmpca.CertificateAuthorityStatusPendingCertificate {
|
||||||
|
logger.Printf("[DEBUG] new PCA %s is ready to accept a certificate", newARN)
|
||||||
|
return NewAmazonPCA(client, config, newARN, pcaType, keyAlgorithm, signingAlgorithm, logger), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Print("[DEBUG] sleeping until certificate is ready")
|
||||||
|
time.Sleep(5 * time.Second) // TODO: get from provider config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) describe() (*acmpca.DescribeCertificateAuthorityOutput, error) {
|
||||||
|
input := &acmpca.DescribeCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
}
|
||||||
|
|
||||||
|
return pca.client.DescribeCertificateAuthority(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) GetCSR() (string, error) {
|
||||||
|
input := &acmpca.GetCertificateAuthorityCsrInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
}
|
||||||
|
pca.logger.Printf("[DEBUG] retrieving CSR for %s", pca.arn)
|
||||||
|
output, err := pca.client.GetCertificateAuthorityCsr(input)
|
||||||
|
if err != nil {
|
||||||
|
pca.logger.Printf("[ERR] error retrieving CSR: %s", err.Error())
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return *output.Csr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) SetCert(certPEM string, chainPEM string) error {
|
||||||
|
chainBytes := []byte(chainPEM)
|
||||||
|
if chainPEM == "" {
|
||||||
|
chainBytes = nil
|
||||||
|
}
|
||||||
|
input := acmpca.ImportCertificateAuthorityCertificateInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
Certificate: []byte(certPEM),
|
||||||
|
CertificateChain: chainBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
pca.logger.Printf("[DEBUG] uploading certificate for %s", pca.arn)
|
||||||
|
_, err := pca.client.ImportCertificateAuthorityCertificate(&input)
|
||||||
|
if err != nil {
|
||||||
|
pca.logger.Printf("[ERR] error importing certificates: %s", err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pca.certPEM = certPEM
|
||||||
|
pca.chainPEM = chainPEM
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) getCerts() error {
|
||||||
|
if pca.certPEM == "" || pca.chainPEM == "" {
|
||||||
|
input := &acmpca.GetCertificateAuthorityCertificateInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := pca.client.GetCertificateAuthorityCertificate(input)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pca.certPEM = *output.Certificate
|
||||||
|
pca.chainPEM = *output.CertificateChain
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Certificate() string {
|
||||||
|
if pca.certPEM == "" {
|
||||||
|
_ = pca.getCerts()
|
||||||
|
}
|
||||||
|
return pca.certPEM
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) CertificateChain() string {
|
||||||
|
if pca.chainPEM == "" {
|
||||||
|
_ = pca.getCerts()
|
||||||
|
}
|
||||||
|
return pca.certPEM
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Generate(signingPCA *AmazonPCA) error {
|
||||||
|
csrPEM, err := pca.GetCSR()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
templateARN := GetTemplateARN(pca.pcaType)
|
||||||
|
chainPEM := ""
|
||||||
|
validity := RootValidity
|
||||||
|
if pca.pcaType == acmpca.CertificateAuthorityTypeSubordinate {
|
||||||
|
chainPEM = signingPCA.Certificate()
|
||||||
|
validity = IntermediateValidity
|
||||||
|
}
|
||||||
|
|
||||||
|
newCertPEM, err := signingPCA.Sign(csrPEM, templateARN, validity)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pca.SetCert(newCertPEM, chainPEM)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Sign(csrPEM string, templateARN string, validity time.Duration) (string, error) {
|
||||||
|
issueInput := acmpca.IssueCertificateInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
Csr: []byte(csrPEM),
|
||||||
|
SigningAlgorithm: aws.String(pca.signingAlgorithm),
|
||||||
|
TemplateArn: aws.String(templateARN),
|
||||||
|
Validity: &acmpca.Validity{
|
||||||
|
Value: aws.Int64(int64(validity.Seconds() / 86400.0)),
|
||||||
|
Type: aws.String(acmpca.ValidityPeriodTypeDays),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := connect.ParseCSR(csrPEM)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("unable to parse CSR: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pca.logger.Printf("[DEBUG] issuing certificate for %s with %s", csr.Subject.String(), pca.arn)
|
||||||
|
issueOutput, err := pca.client.IssueCertificate(&issueInput)
|
||||||
|
if err != nil {
|
||||||
|
pca.logger.Printf("[ERR] error issuing certificate: %s", err.Error())
|
||||||
|
return "", fmt.Errorf("error issuing certificate from PCA: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for certificate to be created
|
||||||
|
for {
|
||||||
|
pca.logger.Printf("[DEBUG] checking to see if certificate %s is ready", *issueOutput.CertificateArn)
|
||||||
|
|
||||||
|
certInput := acmpca.GetCertificateInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
CertificateArn: issueOutput.CertificateArn,
|
||||||
|
}
|
||||||
|
certOutput, err := pca.client.GetCertificate(&certInput)
|
||||||
|
if err != nil {
|
||||||
|
if err.(awserr.Error).Code() != acmpca.ErrCodeRequestInProgressException {
|
||||||
|
pca.logger.Printf("[ERR] error retrieving new certificate from %s: %s",
|
||||||
|
*issueOutput.CertificateArn, err.Error())
|
||||||
|
return "", fmt.Errorf("error retrieving certificate from PCA: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if certOutput.Certificate != nil {
|
||||||
|
pca.logger.Printf("[DEBUG] certificate is ready, ARN is %s", *issueOutput.CertificateArn)
|
||||||
|
|
||||||
|
newCert, err := connect.ParseCert(*certOutput.Certificate)
|
||||||
|
if err == nil {
|
||||||
|
pca.logger.Printf("[DEBUG] certificate created: commonName=%s subjectKey=%s authorityKey=%s",
|
||||||
|
newCert.Subject.CommonName,
|
||||||
|
connect.HexString(newCert.SubjectKeyId),
|
||||||
|
connect.HexString(newCert.AuthorityKeyId))
|
||||||
|
}
|
||||||
|
|
||||||
|
return *certOutput.Certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pca.logger.Printf("[DEBUG] sleeping for %s until certificate is ready", pca.sleepTime)
|
||||||
|
time.Sleep(pca.sleepTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) SignLeaf(csrPEM string, validity time.Duration) (string, error) {
|
||||||
|
certPEM, err := pca.Sign(csrPEM, LeafTemplateARN, validity)
|
||||||
|
return certPEM, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) SignIntermediate(csrPEM string) (string, error) {
|
||||||
|
certPEM, err := pca.Sign(csrPEM, IntermediateTemplateARN, IntermediateValidity)
|
||||||
|
return certPEM, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) SignRoot(csrPEM string) (string, error) {
|
||||||
|
certPEM, err := pca.Sign(csrPEM, RootTemplateARN, RootValidity)
|
||||||
|
return certPEM, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Disable() error {
|
||||||
|
input := acmpca.UpdateCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
Status: aws.String(acmpca.CertificateAuthorityStatusDisabled),
|
||||||
|
}
|
||||||
|
pca.logger.Printf("[INFO] disabling PCA %s", pca.arn)
|
||||||
|
_, err := pca.client.UpdateCertificateAuthority(&input)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Enable() error {
|
||||||
|
input := acmpca.UpdateCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
Status: aws.String(acmpca.CertificateAuthorityStatusActive),
|
||||||
|
}
|
||||||
|
pca.logger.Printf("[INFO] enabling PCA %s", pca.arn)
|
||||||
|
_, err := pca.client.UpdateCertificateAuthority(&input)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Delete() error {
|
||||||
|
input := acmpca.DeleteCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
}
|
||||||
|
pca.logger.Printf("[INFO] deleting PCA %s", pca.arn)
|
||||||
|
_, err := pca.client.DeleteCertificateAuthority(&input)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pca *AmazonPCA) Undelete() error {
|
||||||
|
input := acmpca.RestoreCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(pca.arn),
|
||||||
|
}
|
||||||
|
pca.logger.Printf("[INFO] undeleting PCA %s", pca.arn)
|
||||||
|
_, err := pca.client.RestoreCertificateAuthority(&input)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// utility functions
|
||||||
|
|
||||||
|
func GetCommonName(pcaType string, clusterId string) string {
|
||||||
|
return fmt.Sprintf("Consul %s %s", pcaType, clusterId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTemplateARN(pcaType string) string {
|
||||||
|
if pcaType == acmpca.CertificateAuthorityTypeRoot {
|
||||||
|
return RootTemplateARN
|
||||||
|
} else {
|
||||||
|
return IntermediateTemplateARN
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsValidARN(arn string) bool {
|
||||||
|
const PcaArnRegex = "^arn:([\\w-]+):([\\w-]+):(\\w{2}-\\w+-\\d+):(\\d+):(?:([\\w-]+)[/:])?([[:xdigit:]]{8}-[[:xdigit:]]{4}-[[:xdigit:]]{4}-[[:xdigit:]]{4}-[[:xdigit:]]{12})$"
|
||||||
|
|
||||||
|
matched, _ := regexp.MatchString(PcaArnRegex, arn)
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToSignatureAlgorithm(algo string) x509.SignatureAlgorithm {
|
||||||
|
switch algo {
|
||||||
|
case acmpca.SigningAlgorithmSha256withrsa:
|
||||||
|
return x509.SHA256WithRSA
|
||||||
|
case acmpca.SigningAlgorithmSha384withrsa:
|
||||||
|
return x509.SHA384WithRSA
|
||||||
|
case acmpca.SigningAlgorithmSha512withrsa:
|
||||||
|
return x509.SHA512WithRSA
|
||||||
|
case acmpca.SigningAlgorithmSha256withecdsa:
|
||||||
|
return x509.ECDSAWithSHA256
|
||||||
|
case acmpca.SigningAlgorithmSha384withecdsa:
|
||||||
|
return x509.ECDSAWithSHA384
|
||||||
|
case acmpca.SigningAlgorithmSha512withecdsa:
|
||||||
|
return x509.ECDSAWithSHA512
|
||||||
|
default:
|
||||||
|
return x509.SHA256WithRSA
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,7 @@
|
||||||
package ca
|
package ca
|
||||||
|
|
||||||
import mock "github.com/stretchr/testify/mock"
|
import mock "github.com/stretchr/testify/mock"
|
||||||
|
import time "time"
|
||||||
import x509 "crypto/x509"
|
import x509 "crypto/x509"
|
||||||
|
|
||||||
// MockProvider is an autogenerated mock type for the Provider type
|
// MockProvider is an autogenerated mock type for the Provider type
|
||||||
|
@ -157,6 +158,20 @@ func (_m *MockProvider) GenerateRoot() error {
|
||||||
return r0
|
return r0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MinLifetime provides a mock function with given fields:
|
||||||
|
func (_m *MockProvider) MinLifetime() time.Duration {
|
||||||
|
ret := _m.Called()
|
||||||
|
|
||||||
|
var r0 time.Duration
|
||||||
|
if rf, ok := ret.Get(0).(func() time.Duration); ok {
|
||||||
|
r0 = rf()
|
||||||
|
} else {
|
||||||
|
r0 = ret.Get(0).(time.Duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
// SetIntermediate provides a mock function with given fields: intermediatePEM, rootPEM
|
// SetIntermediate provides a mock function with given fields: intermediatePEM, rootPEM
|
||||||
func (_m *MockProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
|
func (_m *MockProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
|
||||||
ret := _m.Called(intermediatePEM, rootPEM)
|
ret := _m.Called(intermediatePEM, rootPEM)
|
||||||
|
@ -212,3 +227,17 @@ func (_m *MockProvider) SignIntermediate(_a0 *x509.CertificateRequest) (string,
|
||||||
|
|
||||||
return r0, r1
|
return r0, r1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SupportsCrossSigning provides a mock function with given fields:
|
||||||
|
func (_m *MockProvider) SupportsCrossSigning() bool {
|
||||||
|
ret := _m.Called()
|
||||||
|
|
||||||
|
var r0 bool
|
||||||
|
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||||
|
r0 = rf()
|
||||||
|
} else {
|
||||||
|
r0 = ret.Get(0).(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,8 @@ option go_package = "github.com/hashicorp/consul/agent/connect/ca/plugin";
|
||||||
|
|
||||||
package plugin;
|
package plugin;
|
||||||
|
|
||||||
|
import "google/protobuf/duration.proto";
|
||||||
|
|
||||||
service CA {
|
service CA {
|
||||||
rpc Configure(ConfigureRequest) returns (Empty);
|
rpc Configure(ConfigureRequest) returns (Empty);
|
||||||
rpc GenerateRoot(Empty) returns (Empty);
|
rpc GenerateRoot(Empty) returns (Empty);
|
||||||
|
@ -26,6 +28,8 @@ service CA {
|
||||||
rpc SignIntermediate(SignIntermediateRequest) returns (SignIntermediateResponse);
|
rpc SignIntermediate(SignIntermediateRequest) returns (SignIntermediateResponse);
|
||||||
rpc CrossSignCA(CrossSignCARequest) returns (CrossSignCAResponse);
|
rpc CrossSignCA(CrossSignCARequest) returns (CrossSignCAResponse);
|
||||||
rpc Cleanup(Empty) returns (Empty);
|
rpc Cleanup(Empty) returns (Empty);
|
||||||
|
rpc SupportsCrossSigning(Empty) returns (SupportsCrossSigningResponse);
|
||||||
|
rpc MinLifetime(Empty) returns (MinLifetimeResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
message ConfigureRequest {
|
message ConfigureRequest {
|
||||||
|
@ -79,6 +83,14 @@ message CrossSignCAResponse {
|
||||||
string crt_pem = 1;
|
string crt_pem = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message SupportsCrossSigningResponse {
|
||||||
|
bool supports_cross_signing = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MinLifetimeResponse {
|
||||||
|
google.protobuf.Duration min_lifetime = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Protobufs doesn't allow no req/resp so in the cases where there are
|
// Protobufs doesn't allow no req/resp so in the cases where there are
|
||||||
// no arguments we use the Empty message.
|
// no arguments we use the Empty message.
|
||||||
message Empty {}
|
message Empty {}
|
||||||
|
|
|
@ -4,9 +4,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gogo/protobuf/types"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect/ca"
|
"github.com/hashicorp/consul/agent/connect/ca"
|
||||||
"google.golang.org/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// providerPluginGRPCServer implements the CAServer interface for gRPC.
|
// providerPluginGRPCServer implements the CAServer interface for gRPC.
|
||||||
|
@ -81,6 +84,15 @@ func (p *providerPluginGRPCServer) CrossSignCA(_ context.Context, req *CrossSign
|
||||||
return &CrossSignCAResponse{CrtPem: crtPEM}, err
|
return &CrossSignCAResponse{CrtPem: crtPEM}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginGRPCServer) SupportsCrossSigning(context.Context, *Empty) (*SupportsCrossSigningResponse, error) {
|
||||||
|
s := p.impl.SupportsCrossSigning()
|
||||||
|
return &SupportsCrossSigningResponse{SupportsCrossSigning: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginGRPCServer) MinLifetime(context.Context, *Empty) (*MinLifetimeResponse, error) {
|
||||||
|
return &MinLifetimeResponse{MinLifetime: types.DurationProto(p.impl.MinLifetime())}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *providerPluginGRPCServer) Cleanup(context.Context, *Empty) (*Empty, error) {
|
func (p *providerPluginGRPCServer) Cleanup(context.Context, *Empty) (*Empty, error) {
|
||||||
return &Empty{}, p.impl.Cleanup()
|
return &Empty{}, p.impl.Cleanup()
|
||||||
}
|
}
|
||||||
|
@ -192,6 +204,17 @@ func (p *providerPluginGRPCClient) CrossSignCA(crt *x509.Certificate) (string, e
|
||||||
return resp.CrtPem, nil
|
return resp.CrtPem, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginGRPCClient) SupportsCrossSigning() bool {
|
||||||
|
resp, _ := p.client.SupportsCrossSigning(p.doneCtx, &Empty{})
|
||||||
|
return resp.SupportsCrossSigning
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginGRPCClient) MinLifetime() time.Duration {
|
||||||
|
resp, _ := p.client.MinLifetime(p.doneCtx, &Empty{})
|
||||||
|
min, _ := types.DurationFromProto(resp.MinLifetime)
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
|
||||||
func (p *providerPluginGRPCClient) Cleanup() error {
|
func (p *providerPluginGRPCClient) Cleanup() error {
|
||||||
_, err := p.client.Cleanup(p.doneCtx, &Empty{})
|
_, err := p.client.Cleanup(p.doneCtx, &Empty{})
|
||||||
return p.err(err)
|
return p.err(err)
|
||||||
|
|
|
@ -3,6 +3,9 @@ package plugin
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gogo/protobuf/types"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect/ca"
|
"github.com/hashicorp/consul/agent/connect/ca"
|
||||||
)
|
)
|
||||||
|
@ -80,6 +83,17 @@ func (p *providerPluginRPCServer) CrossSignCA(args *CrossSignCARequest, resp *Cr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginRPCServer) SupportsCrossSigning(_ struct{}, resp *SupportsCrossSigningResponse) error {
|
||||||
|
s := p.impl.SupportsCrossSigning()
|
||||||
|
resp.SupportsCrossSigning = s
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginRPCServer) MinLifetime(_ struct{}, resp *MinLifetimeResponse) error {
|
||||||
|
resp.MinLifetime = types.DurationProto(p.impl.MinLifetime())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *providerPluginRPCServer) Cleanup(struct{}, *struct{}) error {
|
func (p *providerPluginRPCServer) Cleanup(struct{}, *struct{}) error {
|
||||||
return p.impl.Cleanup()
|
return p.impl.Cleanup()
|
||||||
}
|
}
|
||||||
|
@ -163,6 +177,19 @@ func (p *providerPluginRPCClient) CrossSignCA(crt *x509.Certificate) (string, er
|
||||||
return resp.CrtPem, err
|
return resp.CrtPem, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginRPCClient) SupportsCrossSigning() bool {
|
||||||
|
var resp SupportsCrossSigningResponse
|
||||||
|
_ = p.client.Call("Plugin.SupportsCrossSigning", struct{}{}, &resp)
|
||||||
|
return resp.SupportsCrossSigning
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *providerPluginRPCClient) MinLifetime() time.Duration {
|
||||||
|
var resp MinLifetimeResponse
|
||||||
|
_ = p.client.Call("Plugin.MinLifetime", struct{}{}, &resp)
|
||||||
|
min, _ := types.DurationFromProto(resp.MinLifetime)
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
|
||||||
func (p *providerPluginRPCClient) Cleanup() error {
|
func (p *providerPluginRPCClient) Cleanup() error {
|
||||||
return p.client.Call("Plugin.Cleanup", struct{}{}, &struct{}{})
|
return p.client.Call("Plugin.Cleanup", struct{}{}, &struct{}{})
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,21 @@ package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:generate mockery -name Provider -inpkg
|
//go:generate mockery -name Provider -inpkg
|
||||||
|
|
||||||
|
// NeedsLogger is an interface that allows Consul to pass a configured
|
||||||
|
// logger into any component that needs one.
|
||||||
|
type NeedsLogger interface {
|
||||||
|
// SetLogger tells the provider to use the specified logger.
|
||||||
|
// This is called immediately after instantiating
|
||||||
|
// the provider, so that the provider can log startup messages etc.
|
||||||
|
SetLogger(l *log.Logger)
|
||||||
|
}
|
||||||
|
|
||||||
// Provider is the interface for Consul to interact with
|
// Provider is the interface for Consul to interact with
|
||||||
// an external CA that provides leaf certificate signing for
|
// an external CA that provides leaf certificate signing for
|
||||||
// given SpiffeIDServices.
|
// given SpiffeIDServices.
|
||||||
|
@ -69,6 +80,14 @@ type Provider interface {
|
||||||
// returned as a PEM formatted string.
|
// returned as a PEM formatted string.
|
||||||
CrossSignCA(*x509.Certificate) (string, error)
|
CrossSignCA(*x509.Certificate) (string, error)
|
||||||
|
|
||||||
|
// SupportsCrossSigning indicates whether the provider supports cross-signing
|
||||||
|
// other CA certs via CrossSignCA().
|
||||||
|
SupportsCrossSigning() bool
|
||||||
|
|
||||||
|
// MinLifetime returns the minimum TTL allowed by the provider for certificates
|
||||||
|
// it issues.
|
||||||
|
MinLifetime() time.Duration
|
||||||
|
|
||||||
// Cleanup performs any necessary cleanup that should happen when the provider
|
// Cleanup performs any necessary cleanup that should happen when the provider
|
||||||
// is shut down permanently, such as removing a temporary PKI backend in Vault
|
// is shut down permanently, such as removing a temporary PKI backend in Vault
|
||||||
// created for an intermediate CA.
|
// created for an intermediate CA.
|
||||||
|
|
|
@ -0,0 +1,327 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
|
"github.com/aws/aws-sdk-go/service/acmpca"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AWSProvider struct {
|
||||||
|
config *structs.AWSCAProviderConfig
|
||||||
|
session *session.Session
|
||||||
|
client *acmpca.ACMPCA
|
||||||
|
isRoot bool
|
||||||
|
clusterId string
|
||||||
|
logger *log.Logger
|
||||||
|
rootPCA *AmazonPCA
|
||||||
|
subPCA *AmazonPCA
|
||||||
|
sleepTime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) SetLogger(l *log.Logger) {
|
||||||
|
v.logger = l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) Configure(clusterId string, isRoot bool, rawConfig map[string]interface{}) error {
|
||||||
|
config, err := ParseAWSCAConfig(rawConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sleepTime, err := time.ParseDuration(config.SleepTime)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid sleep time specified: %s", err)
|
||||||
|
}
|
||||||
|
if sleepTime.Seconds() < 1 {
|
||||||
|
return fmt.Errorf("invalid sleep time specified: must be at least 1s")
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := credentials.NewStaticCredentials(config.AccessKeyId, config.SecretAccessKey, "")
|
||||||
|
awsSession, err := session.NewSession(&aws.Config{
|
||||||
|
Region: aws.String(config.Region),
|
||||||
|
Credentials: creds,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
v.config = config
|
||||||
|
v.session = awsSession
|
||||||
|
v.isRoot = isRoot
|
||||||
|
v.clusterId = clusterId
|
||||||
|
v.client = acmpca.New(awsSession)
|
||||||
|
v.sleepTime = sleepTime
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) loadFindOrCreate(arn string, pcaType string) (*AmazonPCA, error) {
|
||||||
|
if arn != "" {
|
||||||
|
return LoadAmazonPCA(v.client, v.config, arn, pcaType, v.clusterId,
|
||||||
|
v.config.KeyAlgorithm, v.config.SigningAlgorithm, v.logger)
|
||||||
|
} else {
|
||||||
|
pca, err := FindAmazonPCA(v.client, v.config, pcaType, v.clusterId,
|
||||||
|
v.config.KeyAlgorithm, v.config.SigningAlgorithm, v.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if pca == nil {
|
||||||
|
pca, err = CreateAmazonPCA(v.client, v.config, pcaType, v.clusterId,
|
||||||
|
v.config.KeyAlgorithm, v.config.SigningAlgorithm, v.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pca, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) GenerateRoot() error {
|
||||||
|
if !v.isRoot {
|
||||||
|
return fmt.Errorf("provider is not the root certificate authority")
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.rootPCA != nil {
|
||||||
|
return nil // root PCA has already been created
|
||||||
|
}
|
||||||
|
|
||||||
|
rootPCA, err := v.loadFindOrCreate(v.config.RootARN, acmpca.CertificateAuthorityTypeRoot)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
v.rootPCA = rootPCA
|
||||||
|
return v.rootPCA.Generate(v.rootPCA)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) ensureIntermediate() error {
|
||||||
|
if v.subPCA != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
subPCA, err := v.loadFindOrCreate(v.config.IntermediateARN, acmpca.CertificateAuthorityTypeSubordinate)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
v.subPCA = subPCA
|
||||||
|
return v.subPCA.Generate(v.rootPCA)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) ActiveRoot() (string, error) {
|
||||||
|
return v.rootPCA.Certificate(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) GenerateIntermediateCSR() (string, error) {
|
||||||
|
if v.isRoot {
|
||||||
|
return "", fmt.Errorf("provider is the root certificate authority, " +
|
||||||
|
"cannot generate an intermediate CSR")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.ensureIntermediate(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
v.logger.Print("[INFO] requesting CSR for new intermediate CA cert")
|
||||||
|
return v.subPCA.GetCSR()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) SetIntermediate(intermediatePEM string, rootPEM string) error {
|
||||||
|
if err := v.ensureIntermediate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return v.subPCA.SetCert(intermediatePEM, rootPEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) ActiveIntermediate() (string, error) {
|
||||||
|
if err := v.ensureIntermediate(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return v.subPCA.Certificate(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) GenerateIntermediate() (string, error) {
|
||||||
|
if err := v.ensureIntermediate(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := v.subPCA.Generate(v.rootPCA); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return v.subPCA.Certificate(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) Sign(csr *x509.CertificateRequest) (string, error) {
|
||||||
|
|
||||||
|
if err := v.ensureIntermediate(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var pemBuf bytes.Buffer
|
||||||
|
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
|
||||||
|
return "", fmt.Errorf("error encoding CSR into PEM format: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
leafPEM, err := v.subPCA.SignLeaf(pemBuf.String(), v.config.LeafCertTTL)
|
||||||
|
|
||||||
|
return leafPEM, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) SignIntermediate(csr *x509.CertificateRequest) (string, error) {
|
||||||
|
|
||||||
|
spiffeID := connect.SpiffeIDSigning{ClusterID: v.clusterId, Domain: "consul"}
|
||||||
|
|
||||||
|
if len(csr.URIs) < 1 {
|
||||||
|
return "", fmt.Errorf("intermediate does not contain a trust domain SAN")
|
||||||
|
}
|
||||||
|
|
||||||
|
if csr.URIs[0].String() != spiffeID.URI().String() {
|
||||||
|
return "", fmt.Errorf("attempt to sign intermediate from a different trust domain: "+
|
||||||
|
"mine='%s' theirs='%s'", spiffeID.URI().String(), csr.URIs[0].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
|
||||||
|
return "", fmt.Errorf("error encoding private key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return v.rootPCA.SignIntermediate(buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// I'm not sure this can actually be implemented. PCA cannot cross-sign a cert, it can only
|
||||||
|
// sign a CSR, and we cannot generate a CSR from another provider's certificate without its
|
||||||
|
// private key.
|
||||||
|
func (v *AWSProvider) CrossSignCA(newCA *x509.Certificate) (string, error) {
|
||||||
|
return "", fmt.Errorf("not implemented in AWS PCA provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) Cleanup() error {
|
||||||
|
if v.config.DeleteOnExit {
|
||||||
|
if v.subPCA != nil {
|
||||||
|
if err := v.subPCA.Disable(); err != nil {
|
||||||
|
v.logger.Printf("[WARN] error disabling subordinate PCA: %s", err.Error())
|
||||||
|
}
|
||||||
|
if err := v.subPCA.Delete(); err != nil {
|
||||||
|
v.logger.Printf("[WARN] error deleting subordinate PCA: %s", err.Error())
|
||||||
|
}
|
||||||
|
v.subPCA = nil
|
||||||
|
}
|
||||||
|
if v.rootPCA != nil {
|
||||||
|
if err := v.rootPCA.Disable(); err != nil {
|
||||||
|
v.logger.Printf("[WARN] error disabling root PCA: %s", err.Error())
|
||||||
|
}
|
||||||
|
if err := v.rootPCA.Delete(); err != nil {
|
||||||
|
v.logger.Printf("[WARN] error deleting root PCA: %s", err.Error())
|
||||||
|
}
|
||||||
|
v.rootPCA = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) SupportsCrossSigning() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *AWSProvider) MinLifetime() time.Duration {
|
||||||
|
return 24 * time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseAWSCAConfig(raw map[string]interface{}) (*structs.AWSCAProviderConfig, error) {
|
||||||
|
config := structs.AWSCAProviderConfig{
|
||||||
|
CommonCAProviderConfig: defaultCommonConfig(),
|
||||||
|
SleepTime: "5s",
|
||||||
|
}
|
||||||
|
|
||||||
|
decodeConf := &mapstructure.DecoderConfig{
|
||||||
|
DecodeHook: structs.ParseDurationFunc(),
|
||||||
|
Result: &config,
|
||||||
|
WeaklyTypedInput: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := decoder.Decode(raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("error decoding config: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AccessKeyId == "" {
|
||||||
|
return nil, fmt.Errorf("must provide the AWS access key ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.SecretAccessKey == "" {
|
||||||
|
return nil, fmt.Errorf("must provide the AWS secret access key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Region == "" {
|
||||||
|
return nil, fmt.Errorf("must provide the AWS region")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.RootARN != "" {
|
||||||
|
if !IsValidARN(config.RootARN) {
|
||||||
|
return nil, fmt.Errorf("root PCA ARN is not in correct format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.IntermediateARN != "" {
|
||||||
|
if !IsValidARN(config.IntermediateARN) {
|
||||||
|
return nil, fmt.Errorf("intermediate PCA ARN is not in correct format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.KeyAlgorithm == "" {
|
||||||
|
config.KeyAlgorithm = acmpca.KeyAlgorithmEcPrime256v1
|
||||||
|
} else {
|
||||||
|
config.KeyAlgorithm, err = ValidateEnum(config.KeyAlgorithm,
|
||||||
|
acmpca.KeyAlgorithmRsa2048, acmpca.KeyAlgorithmRsa4096,
|
||||||
|
acmpca.KeyAlgorithmEcPrime256v1, acmpca.KeyAlgorithmEcSecp384r1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid key algorithm specified: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.SigningAlgorithm == "" {
|
||||||
|
config.SigningAlgorithm = acmpca.SigningAlgorithmSha256withecdsa
|
||||||
|
} else {
|
||||||
|
config.SigningAlgorithm, err = ValidateEnum(config.SigningAlgorithm,
|
||||||
|
acmpca.SigningAlgorithmSha256withrsa, acmpca.SigningAlgorithmSha384withrsa,
|
||||||
|
acmpca.SigningAlgorithmSha512withrsa, acmpca.SigningAlgorithmSha256withecdsa,
|
||||||
|
acmpca.SigningAlgorithmSha384withecdsa, acmpca.SigningAlgorithmSha512withecdsa)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid signing algorithm specified: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.CommonCAProviderConfig.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateEnum(value string, choices ...string) (string, error) {
|
||||||
|
for _, choice := range choices {
|
||||||
|
if strings.ToLower(value) == strings.ToLower(choice) {
|
||||||
|
return choice, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("must be one of %s or %s",
|
||||||
|
strings.Join(choices[:len(choices)-1], ","),
|
||||||
|
choices[len(choices)-1])
|
||||||
|
}
|
|
@ -0,0 +1,392 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
|
"github.com/aws/aws-sdk-go/service/acmpca"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
|
)
|
||||||
|
|
||||||
|
var awsAccessKeyId string
|
||||||
|
var awsSecretAccessKey string
|
||||||
|
var awsRegion string
|
||||||
|
var awsClient *acmpca.ACMPCA
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
awsAccessKeyId = os.Getenv("AWS_ACCESS_KEY_ID")
|
||||||
|
awsSecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY")
|
||||||
|
awsRegion = os.Getenv("AWS_REGION")
|
||||||
|
awsSession, _ := session.NewSession(&aws.Config{
|
||||||
|
Region: aws.String(awsRegion),
|
||||||
|
Credentials: credentials.NewStaticCredentials(awsAccessKeyId, awsSecretAccessKey, ""),
|
||||||
|
})
|
||||||
|
awsClient = acmpca.New(awsSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConfig() *structs.CAConfiguration {
|
||||||
|
return &structs.CAConfiguration{
|
||||||
|
ClusterID: "asdf",
|
||||||
|
Provider: "aws",
|
||||||
|
Config: map[string]interface{}{
|
||||||
|
"LeafCertTTL": "72h",
|
||||||
|
"Region": awsRegion,
|
||||||
|
"AccessKeyId": awsAccessKeyId,
|
||||||
|
"SecretAccessKey": awsSecretAccessKey,
|
||||||
|
"KeyAlgorithm": acmpca.KeyAlgorithmEcPrime256v1,
|
||||||
|
"SigningAlgorithm": acmpca.SigningAlgorithmSha256withecdsa,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeProvider(r *require.Assertions, config *structs.CAConfiguration) *AWSProvider {
|
||||||
|
logger := log.New(os.Stderr, "aws_pca", log.LstdFlags)
|
||||||
|
provider := &AWSProvider{logger: logger}
|
||||||
|
r.NoError(provider.Configure(config.ClusterID, true, config.Config))
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_Configure(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
r.Equal(conf.Config["AccessKeyId"], provider.config.AccessKeyId)
|
||||||
|
r.Equal(conf.Config["SecretAccessKey"], provider.config.SecretAccessKey)
|
||||||
|
r.Equal(conf.Config["Region"], provider.config.Region)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_ConfigureBadKeyAlgorithm(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
conf.Config["KeyAlgorithm"] = "foo"
|
||||||
|
provider := &AWSProvider{}
|
||||||
|
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_ConfigureBadSigningAlgorithm(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
conf.Config["SigningAlgorithm"] = "foo"
|
||||||
|
provider := &AWSProvider{}
|
||||||
|
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_ConfigureBadSleepTime(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
|
||||||
|
conf.Config["SleepTime"] = "-5s"
|
||||||
|
provider := &AWSProvider{}
|
||||||
|
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
|
||||||
|
|
||||||
|
conf.Config["SleepTime"] = "5foo"
|
||||||
|
provider = &AWSProvider{}
|
||||||
|
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_ConfigureBadLeafTTL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
conf.Config["LeafCertTTL"] = "-72h"
|
||||||
|
provider := &AWSProvider{}
|
||||||
|
r.Error(provider.Configure(conf.ClusterID, true, conf.Config))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_GenerateRoot(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
r.NotEmpty(provider.rootPCA.arn)
|
||||||
|
|
||||||
|
output, err := awsClient.DescribeCertificateAuthority(&acmpca.DescribeCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(provider.rootPCA.arn),
|
||||||
|
})
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
ca := output.CertificateAuthority
|
||||||
|
caConf := ca.CertificateAuthorityConfiguration
|
||||||
|
r.Equal(acmpca.CertificateAuthorityStatusActive, *ca.Status)
|
||||||
|
r.Equal(acmpca.KeyAlgorithmEcPrime256v1, *caConf.KeyAlgorithm)
|
||||||
|
r.Equal(acmpca.SigningAlgorithmSha256withecdsa, *caConf.SigningAlgorithm)
|
||||||
|
r.Contains(*caConf.Subject.CommonName, conf.ClusterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_GenerateRootNotRoot(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
provider.isRoot = false
|
||||||
|
r.Error(provider.GenerateRoot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_Sign(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
|
||||||
|
cn := "foo"
|
||||||
|
pk, _, err := connect.GeneratePrivateKeyWithConfig("rsa", 2048)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
serviceID := &connect.SpiffeIDService{
|
||||||
|
Host: "11111111-2222-3333-4444-555555555555.consul",
|
||||||
|
Datacenter: "dc1",
|
||||||
|
Namespace: "default",
|
||||||
|
Service: "foo",
|
||||||
|
}
|
||||||
|
|
||||||
|
csrText, err := connect.CreateCSR(cn, serviceID, pk)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
csr, err := connect.ParseCSR(csrText)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
leafText, err := provider.Sign(csr)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
leaf, err := connect.ParseCert(leafText)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
r.Equal(csr.Subject.CommonName, leaf.Subject.CommonName)
|
||||||
|
r.Equal(serviceID.URI().String(), leaf.URIs[0].String())
|
||||||
|
r.True(leaf.NotBefore.Before(time.Now()))
|
||||||
|
r.True(leaf.NotAfter.After(time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_GenerateIntermediateCSR(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
|
||||||
|
_, err := provider.GenerateIntermediateCSR()
|
||||||
|
r.Error(err)
|
||||||
|
|
||||||
|
provider.isRoot = false
|
||||||
|
csrText, err := provider.GenerateIntermediateCSR()
|
||||||
|
|
||||||
|
csr, err := connect.ParseCSR(csrText)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
r.Contains(csr.Subject.CommonName, conf.ClusterID)
|
||||||
|
r.Equal(x509.ECDSA, csr.PublicKeyAlgorithm)
|
||||||
|
r.Equal(x509.ECDSAWithSHA256, csr.SignatureAlgorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_ActiveRoot(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
|
||||||
|
rootText, err := provider.ActiveRoot()
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
root, err := connect.ParseCert(rootText)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
r.Equal(x509.ECDSA, root.PublicKeyAlgorithm)
|
||||||
|
r.Equal(x509.ECDSAWithSHA256, root.SignatureAlgorithm)
|
||||||
|
r.True(root.NotBefore.Before(time.Now()))
|
||||||
|
r.True(root.NotAfter.After(time.Now()))
|
||||||
|
r.True(root.IsCA)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_GenerateIntermediate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
|
||||||
|
interText, err := provider.GenerateIntermediate()
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
inter, err := connect.ParseCert(interText)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
r.Contains(inter.Subject.CommonName, conf.ClusterID)
|
||||||
|
r.Equal(x509.ECDSA, inter.PublicKeyAlgorithm)
|
||||||
|
r.Equal(x509.ECDSAWithSHA256, inter.SignatureAlgorithm)
|
||||||
|
r.True(inter.NotBefore.Before(time.Now()))
|
||||||
|
r.True(inter.NotAfter.After(time.Now()))
|
||||||
|
r.True(inter.IsCA)
|
||||||
|
r.True(inter.MaxPathLenZero)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_ActiveIntermediate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
|
||||||
|
interText, err := provider.ActiveIntermediate()
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
inter, err := connect.ParseCert(interText)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
r.Contains(inter.Subject.CommonName, conf.ClusterID)
|
||||||
|
r.Equal(x509.ECDSA, inter.PublicKeyAlgorithm)
|
||||||
|
r.Equal(x509.ECDSAWithSHA256, inter.SignatureAlgorithm)
|
||||||
|
r.True(inter.NotBefore.Before(time.Now()))
|
||||||
|
r.True(inter.NotAfter.After(time.Now()))
|
||||||
|
r.True(inter.IsCA)
|
||||||
|
r.True(inter.MaxPathLenZero)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_SignIntermediate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
|
||||||
|
conf2 := testConsulCAConfig()
|
||||||
|
delegate2 := newMockDelegate(t, conf2)
|
||||||
|
provider2 := &ConsulProvider{Delegate: delegate2}
|
||||||
|
r.NoError(provider2.Configure(conf2.ClusterID, false, conf2.Config))
|
||||||
|
|
||||||
|
testSignIntermediateCrossDC(t, provider, provider2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAWSProvider_Cleanup(t *testing.T) {
|
||||||
|
// THIS TEST CANNOT BE RUN IN PARALLEL.
|
||||||
|
// It disables and deletes the PCA, which will cause other tests to fail if they are running
|
||||||
|
// at the same time.
|
||||||
|
|
||||||
|
if awsAccessKeyId == "" || awsSecretAccessKey == "" || awsRegion == "" {
|
||||||
|
t.Skip("skipping test due to missing AWS credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := require.New(t)
|
||||||
|
conf := makeConfig()
|
||||||
|
conf.Config["DeleteOnExit"] = true
|
||||||
|
provider := makeProvider(r, conf)
|
||||||
|
|
||||||
|
r.NoError(provider.GenerateRoot())
|
||||||
|
_, err := provider.GenerateIntermediate()
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
rootPCA := provider.rootPCA
|
||||||
|
subPCA := provider.subPCA
|
||||||
|
r.NoError(provider.Cleanup())
|
||||||
|
|
||||||
|
output, err := awsClient.DescribeCertificateAuthority(&acmpca.DescribeCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(rootPCA.arn),
|
||||||
|
})
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
ca := output.CertificateAuthority
|
||||||
|
r.Equal(acmpca.CertificateAuthorityStatusDeleted, *ca.Status)
|
||||||
|
r.Equal((*AmazonPCA)(nil), provider.rootPCA)
|
||||||
|
|
||||||
|
r.NoError(rootPCA.Undelete())
|
||||||
|
r.NoError(rootPCA.Enable())
|
||||||
|
|
||||||
|
output, err = awsClient.DescribeCertificateAuthority(&acmpca.DescribeCertificateAuthorityInput{
|
||||||
|
CertificateAuthorityArn: aws.String(subPCA.arn),
|
||||||
|
})
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
ca = output.CertificateAuthority
|
||||||
|
r.Equal(acmpca.CertificateAuthorityStatusDeleted, *ca.Status)
|
||||||
|
r.Equal((*AmazonPCA)(nil), provider.subPCA)
|
||||||
|
|
||||||
|
r.NoError(subPCA.Undelete())
|
||||||
|
r.NoError(subPCA.Enable())
|
||||||
|
}
|
|
@ -658,3 +658,11 @@ func (c *ConsulProvider) generateCA(privateKey string, sn uint64) (string, error
|
||||||
|
|
||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConsulProvider) SupportsCrossSigning() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConsulProvider) MinLifetime() time.Duration {
|
||||||
|
return 1 * time.Hour
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect"
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
|
@ -376,6 +377,14 @@ func (v *VaultProvider) Cleanup() error {
|
||||||
return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
|
return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *VaultProvider) SupportsCrossSigning() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VaultProvider) MinLifetime() time.Duration {
|
||||||
|
return 1 * time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
|
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
|
||||||
config := structs.VaultCAProviderConfig{
|
config := structs.VaultCAProviderConfig{
|
||||||
CommonCAProviderConfig: defaultCommonConfig(),
|
CommonCAProviderConfig: defaultCommonConfig(),
|
||||||
|
|
|
@ -3,6 +3,7 @@ package connect
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
|
@ -13,10 +14,21 @@ import (
|
||||||
|
|
||||||
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded
|
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded
|
||||||
// private key for this certificate.
|
// private key for this certificate.
|
||||||
func CreateCSR(uri CertURI, privateKey crypto.Signer, extensions ...pkix.Extension) (string, error) {
|
func CreateCSR(commonName string, uri CertURI, privateKey crypto.Signer,
|
||||||
|
extensions ...pkix.Extension) (string, error) {
|
||||||
|
|
||||||
|
signAlgo := x509.SHA256WithRSA
|
||||||
|
_, ok := privateKey.(*ecdsa.PrivateKey)
|
||||||
|
if ok {
|
||||||
|
signAlgo = x509.ECDSAWithSHA256
|
||||||
|
}
|
||||||
|
|
||||||
template := &x509.CertificateRequest{
|
template := &x509.CertificateRequest{
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: commonName,
|
||||||
|
},
|
||||||
URIs: []*url.URL{uri.URI()},
|
URIs: []*url.URL{uri.URI()},
|
||||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
SignatureAlgorithm: signAlgo,
|
||||||
ExtraExtensions: extensions,
|
ExtraExtensions: extensions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +55,7 @@ func CreateCACSR(uri CertURI, privateKey crypto.Signer) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return CreateCSR(uri, privateKey, ext)
|
return CreateCSR("Consul CA", uri, privateKey, ext)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateCAExtension creates a pkix.Extension for the x509 Basic Constraints
|
// CreateCAExtension creates a pkix.Extension for the x509 Basic Constraints
|
||||||
|
|
|
@ -59,7 +59,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, toke
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a CSR.
|
// Create a CSR.
|
||||||
csr, err := connect.CreateCSR(id, pk)
|
csr, err := connect.CreateCSR("Consul RPC", id, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errFn(err)
|
return errFn(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,7 @@ func TestAutoEncryptSign(t *testing.T) {
|
||||||
require.NoError(t, err, info)
|
require.NoError(t, err, info)
|
||||||
|
|
||||||
// Create a CSR.
|
// Create a CSR.
|
||||||
csr, err := connect.CreateCSR(id, pk)
|
csr, err := connect.CreateCSR(info, id, pk)
|
||||||
require.NoError(t, err, info)
|
require.NoError(t, err, info)
|
||||||
require.NotEmpty(t, csr, info)
|
require.NotEmpty(t, csr, info)
|
||||||
args := &structs.CASignRequest{
|
args := &structs.CASignRequest{
|
||||||
|
|
|
@ -168,6 +168,16 @@ func (s *ConnectCA) ConfigurationSet(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Have the old provider cross-sign the new intermediate
|
||||||
|
oldProvider, _ := s.srv.getCAProvider()
|
||||||
|
if oldProvider == nil {
|
||||||
|
return fmt.Errorf("internal error: CA provider is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !oldProvider.SupportsCrossSigning() {
|
||||||
|
return fmt.Errorf("error: current CA does not support cross-signing")
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new instance of the provider described by the config
|
// Create a new instance of the provider described by the config
|
||||||
// and get the current active root CA. This acts as a good validation
|
// and get the current active root CA. This acts as a good validation
|
||||||
// of the config and makes sure the provider is functioning correctly
|
// of the config and makes sure the provider is functioning correctly
|
||||||
|
@ -237,11 +247,6 @@ func (s *ConnectCA) ConfigurationSet(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Have the old provider cross-sign the new intermediate
|
|
||||||
oldProvider, _ := s.srv.getCAProvider()
|
|
||||||
if oldProvider == nil {
|
|
||||||
return fmt.Errorf("internal error: CA provider is nil")
|
|
||||||
}
|
|
||||||
xcCert, err := oldProvider.CrossSignCA(newRoot)
|
xcCert, err := oldProvider.CrossSignCA(newRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -112,6 +112,53 @@ type CARoot struct {
|
||||||
// CARoots is a list of CARoot structures.
|
// CARoots is a list of CARoot structures.
|
||||||
type CARoots []*CARoot
|
type CARoots []*CARoot
|
||||||
|
|
||||||
|
// CAIntermediate represents an intermediate CA certificate that is trusted and used
|
||||||
|
// to sign leaf certificates.
|
||||||
|
type CAIntermediate struct {
|
||||||
|
// ID is a globally unique ID (UUID) representing this CA root.
|
||||||
|
ID string
|
||||||
|
|
||||||
|
// Name is a human-friendly name for this CA root. This value is
|
||||||
|
// opaque to Consul and is not used for anything internally.
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// SerialNumber is the x509 serial number of the certificate.
|
||||||
|
SerialNumber uint64
|
||||||
|
|
||||||
|
// SigningKeyID is the ID of the public key that corresponds to the private
|
||||||
|
// key used to sign the certificate. Is is the HexString format of the raw
|
||||||
|
// AuthorityKeyID bytes.
|
||||||
|
SigningKeyID string
|
||||||
|
|
||||||
|
// ExternalTrustDomain is the trust domain this root was generated under. It
|
||||||
|
// is usually empty implying "the current cluster trust-domain". It is set
|
||||||
|
// only in the case that a cluster changes trust domain and then all old roots
|
||||||
|
// that are still trusted have the old trust domain set here.
|
||||||
|
//
|
||||||
|
// We currently DON'T validate these trust domains explicitly anywhere, see
|
||||||
|
// IndexedRoots.TrustDomain doc. We retain this information for debugging and
|
||||||
|
// future flexibility.
|
||||||
|
ExternalTrustDomain string
|
||||||
|
|
||||||
|
// Time validity bounds.
|
||||||
|
NotBefore time.Time
|
||||||
|
NotAfter time.Time
|
||||||
|
|
||||||
|
// IntermediateCert is the PEM-encoded public certificate.
|
||||||
|
IntermediateCert string
|
||||||
|
|
||||||
|
// Active is true if this is the current active CA. This must only
|
||||||
|
// be true for exactly one CA. For any method that modifies roots in the
|
||||||
|
// state store, tests should be written to verify that multiple roots
|
||||||
|
// cannot be active.
|
||||||
|
Active bool
|
||||||
|
|
||||||
|
// RotatedOutAt is the time at which this CA was removed from the state.
|
||||||
|
// This will only be set on roots that have been rotated out from being the
|
||||||
|
// active root.
|
||||||
|
RotatedOutAt time.Time `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
// CASignRequest is the request for signing a service certificate.
|
// CASignRequest is the request for signing a service certificate.
|
||||||
type CASignRequest struct {
|
type CASignRequest struct {
|
||||||
// Datacenter is the target for this request.
|
// Datacenter is the target for this request.
|
||||||
|
@ -207,6 +254,7 @@ func (q *CARequest) RequestDatacenter() string {
|
||||||
const (
|
const (
|
||||||
ConsulCAProvider = "consul"
|
ConsulCAProvider = "consul"
|
||||||
VaultCAProvider = "vault"
|
VaultCAProvider = "vault"
|
||||||
|
AWSCAProvider = "aws"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CAConfiguration is the configuration for the current CA plugin.
|
// CAConfiguration is the configuration for the current CA plugin.
|
||||||
|
@ -349,6 +397,20 @@ type VaultCAProviderConfig struct {
|
||||||
TLSSkipVerify bool
|
TLSSkipVerify bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AWSCAProviderConfig struct {
|
||||||
|
CommonCAProviderConfig `mapstructure:",squash"`
|
||||||
|
|
||||||
|
AccessKeyId string
|
||||||
|
SecretAccessKey string
|
||||||
|
Region string
|
||||||
|
SleepTime string
|
||||||
|
RootARN string
|
||||||
|
IntermediateARN string
|
||||||
|
KeyAlgorithm string
|
||||||
|
SigningAlgorithm string
|
||||||
|
DeleteOnExit bool
|
||||||
|
}
|
||||||
|
|
||||||
// CALeafOp is the operation for a request related to leaf certificates.
|
// CALeafOp is the operation for a request related to leaf certificates.
|
||||||
type CALeafOp string
|
type CALeafOp string
|
||||||
|
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -20,6 +20,7 @@ require (
|
||||||
github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878
|
github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878
|
||||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310
|
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310
|
||||||
github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f // indirect
|
github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f // indirect
|
||||||
|
github.com/aws/aws-sdk-go v1.21.1
|
||||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
|
||||||
github.com/cenkalti/backoff v2.1.1+incompatible // indirect
|
github.com/cenkalti/backoff v2.1.1+incompatible // indirect
|
||||||
|
@ -56,7 +57,7 @@ require (
|
||||||
github.com/hashicorp/go-msgpack v0.5.5
|
github.com/hashicorp/go-msgpack v0.5.5
|
||||||
github.com/hashicorp/go-multierror v1.0.0
|
github.com/hashicorp/go-multierror v1.0.0
|
||||||
github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116
|
github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116
|
||||||
github.com/hashicorp/go-raftchunking v0.6.1
|
github.com/hashicorp/go-raftchunking v0.6.2
|
||||||
github.com/hashicorp/go-sockaddr v1.0.0
|
github.com/hashicorp/go-sockaddr v1.0.0
|
||||||
github.com/hashicorp/go-syslog v1.0.0
|
github.com/hashicorp/go-syslog v1.0.0
|
||||||
github.com/hashicorp/go-uuid v1.0.1
|
github.com/hashicorp/go-uuid v1.0.1
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -38,6 +38,8 @@ github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f h1:/8NcnxL6
|
||||||
github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
|
github.com/asaskevich/govalidator v0.0.0-20180319081651-7d2e70ef918f/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
|
||||||
github.com/aws/aws-sdk-go v1.15.24 h1:xLAdTA/ore6xdPAljzZRed7IGqQgC+nY+ERS5vaj4Ro=
|
github.com/aws/aws-sdk-go v1.15.24 h1:xLAdTA/ore6xdPAljzZRed7IGqQgC+nY+ERS5vaj4Ro=
|
||||||
github.com/aws/aws-sdk-go v1.15.24/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
|
github.com/aws/aws-sdk-go v1.15.24/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
|
||||||
|
github.com/aws/aws-sdk-go v1.21.1 h1:IOFDnCEDybcw4V8nbKqyyjBu+vpu7hFYSfZqNuogi7I=
|
||||||
|
github.com/aws/aws-sdk-go v1.21.1/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
|
||||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
|
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
|
||||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||||
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
|
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
|
||||||
|
@ -171,6 +173,8 @@ github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116 h1:Y4V/yReWjQo
|
||||||
github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116/go.mod h1:JSqWYsict+jzcj0+xElxyrBQRPNoiWQuddnxArJ7XHQ=
|
github.com/hashicorp/go-plugin v0.0.0-20180331002553-e8d22c780116/go.mod h1:JSqWYsict+jzcj0+xElxyrBQRPNoiWQuddnxArJ7XHQ=
|
||||||
github.com/hashicorp/go-raftchunking v0.6.1 h1:moEnaG3gcwsWNyIBJoD5PCByE+Ewkqxh6N05CT+MbwA=
|
github.com/hashicorp/go-raftchunking v0.6.1 h1:moEnaG3gcwsWNyIBJoD5PCByE+Ewkqxh6N05CT+MbwA=
|
||||||
github.com/hashicorp/go-raftchunking v0.6.1/go.mod h1:cGlg3JtDy7qy6c/3Bu660Mic1JF+7lWqIwCFSb08fX0=
|
github.com/hashicorp/go-raftchunking v0.6.1/go.mod h1:cGlg3JtDy7qy6c/3Bu660Mic1JF+7lWqIwCFSb08fX0=
|
||||||
|
github.com/hashicorp/go-raftchunking v0.6.2 h1:imj6CVkwXj6VzgXZQvzS+fSrkbFCzlJ2t00F3PacnuU=
|
||||||
|
github.com/hashicorp/go-raftchunking v0.6.2/go.mod h1:cGlg3JtDy7qy6c/3Bu660Mic1JF+7lWqIwCFSb08fX0=
|
||||||
github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s=
|
github.com/hashicorp/go-retryablehttp v0.5.3 h1:QlWt0KvWT0lq8MFppF9tsJGF+ynG7ztc2KIPhzRGk7s=
|
||||||
github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
|
github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
|
||||||
github.com/hashicorp/go-rootcerts v1.0.0 h1:Rqb66Oo1X/eSV1x66xbDccZjhJigjg0+e82kpwzSwCI=
|
github.com/hashicorp/go-rootcerts v1.0.0 h1:Rqb66Oo1X/eSV1x66xbDccZjhJigjg0+e82kpwzSwCI=
|
||||||
|
@ -229,6 +233,8 @@ github.com/jefferai/jsonx v0.0.0-20160721235117-9cc31c3135ee h1:AQ/QmCk6x8ECPpf2
|
||||||
github.com/jefferai/jsonx v0.0.0-20160721235117-9cc31c3135ee/go.mod h1:N0t2vlmpe8nyZB5ouIbJQPDSR+mH6oe7xHB9VZHSUzM=
|
github.com/jefferai/jsonx v0.0.0-20160721235117-9cc31c3135ee/go.mod h1:N0t2vlmpe8nyZB5ouIbJQPDSR+mH6oe7xHB9VZHSUzM=
|
||||||
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE=
|
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 h1:12VvqtR6Aowv3l/EQUlocDHW2Cp4G9WJVH7uyH8QFJE=
|
||||||
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||||
|
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
|
||||||
|
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||||
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 h1:JHCT6xuyPUrbbgAPE/3dqlvUKzRHMNuTBKKUb6OeR/k=
|
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 h1:JHCT6xuyPUrbbgAPE/3dqlvUKzRHMNuTBKKUb6OeR/k=
|
||||||
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62/go.mod h1:U+RSyWxWd04xTqnuOQxnai7XGS2PrPY2cfGoDKtMHjA=
|
github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62/go.mod h1:U+RSyWxWd04xTqnuOQxnai7XGS2PrPY2cfGoDKtMHjA=
|
||||||
github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE=
|
github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE=
|
||||||
|
|
|
@ -138,8 +138,27 @@ type RequestFailure interface {
|
||||||
RequestID() string
|
RequestID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequestFailure returns a new request error wrapper for the given Error
|
// NewRequestFailure returns a wrapped error with additional information for
|
||||||
// provided.
|
// request status code, and service requestID.
|
||||||
|
//
|
||||||
|
// Should be used to wrap all request which involve service requests. Even if
|
||||||
|
// the request failed without a service response, but had an HTTP status code
|
||||||
|
// that may be meaningful.
|
||||||
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
|
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
|
||||||
return newRequestError(err, statusCode, reqID)
|
return newRequestError(err, statusCode, reqID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
|
||||||
|
type UnmarshalError interface {
|
||||||
|
awsError
|
||||||
|
Bytes() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
|
||||||
|
// the bytes that fail to unmarshal to the error.
|
||||||
|
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
|
||||||
|
return &unmarshalError{
|
||||||
|
awsError: New("UnmarshalError", msg, err),
|
||||||
|
bytes: bytes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package awserr
|
package awserr
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// SprintError returns a string of the formatted error code.
|
// SprintError returns a string of the formatted error code.
|
||||||
//
|
//
|
||||||
|
@ -119,6 +122,7 @@ type requestError struct {
|
||||||
awsError
|
awsError
|
||||||
statusCode int
|
statusCode int
|
||||||
requestID string
|
requestID string
|
||||||
|
bytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// newRequestError returns a wrapped error with additional information for
|
// newRequestError returns a wrapped error with additional information for
|
||||||
|
@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
|
||||||
return []error{r.OrigErr()}
|
return []error{r.OrigErr()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type unmarshalError struct {
|
||||||
|
awsError
|
||||||
|
bytes []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns the string representation of the error.
|
||||||
|
// Satisfies the error interface.
|
||||||
|
func (e unmarshalError) Error() string {
|
||||||
|
extra := hex.Dump(e.bytes)
|
||||||
|
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the error.
|
||||||
|
// Alias for Error to satisfy the stringer interface.
|
||||||
|
func (e unmarshalError) String() string {
|
||||||
|
return e.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns the bytes that failed to unmarshal.
|
||||||
|
func (e unmarshalError) Bytes() []byte {
|
||||||
|
return e.bytes
|
||||||
|
}
|
||||||
|
|
||||||
// An error list that satisfies the golang interface
|
// An error list that satisfies the golang interface
|
||||||
type errorList []error
|
type errorList []error
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
|
||||||
rb := reflect.Indirect(reflect.ValueOf(b))
|
rb := reflect.Indirect(reflect.ValueOf(b))
|
||||||
|
|
||||||
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
|
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
|
||||||
// If the elements are both nil, and of the same type the are equal
|
// If the elements are both nil, and of the same type they are equal
|
||||||
// If they are of different types they are not equal
|
// If they are of different types they are not equal
|
||||||
return reflect.TypeOf(a) == reflect.TypeOf(b)
|
return reflect.TypeOf(a) == reflect.TypeOf(b)
|
||||||
} else if raValid != rbValid {
|
} else if raValid != rbValid {
|
||||||
|
|
|
@ -23,28 +23,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
buf.WriteString("{\n")
|
buf.WriteString("{\n")
|
||||||
|
|
||||||
names := []string{}
|
|
||||||
for i := 0; i < v.Type().NumField(); i++ {
|
for i := 0; i < v.Type().NumField(); i++ {
|
||||||
name := v.Type().Field(i).Name
|
ft := v.Type().Field(i)
|
||||||
f := v.Field(i)
|
fv := v.Field(i)
|
||||||
if name[0:1] == strings.ToLower(name[0:1]) {
|
|
||||||
|
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
|
||||||
continue // ignore unexported fields
|
continue // ignore unexported fields
|
||||||
}
|
}
|
||||||
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
|
if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
|
||||||
continue // ignore unset fields
|
continue // ignore unset fields
|
||||||
}
|
}
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, n := range names {
|
|
||||||
val := v.FieldByName(n)
|
|
||||||
buf.WriteString(strings.Repeat(" ", indent+2))
|
buf.WriteString(strings.Repeat(" ", indent+2))
|
||||||
buf.WriteString(n + ": ")
|
buf.WriteString(ft.Name + ": ")
|
||||||
stringValue(val, indent+2, buf)
|
|
||||||
|
|
||||||
if i < len(names)-1 {
|
if tag := ft.Tag.Get("sensitive"); tag == "true" {
|
||||||
buf.WriteString(",\n")
|
buf.WriteString("<sensitive>")
|
||||||
|
} else {
|
||||||
|
stringValue(fv, indent+2, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buf.WriteString(",\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
|
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
|
||||||
|
|
|
@ -18,7 +18,7 @@ type Config struct {
|
||||||
|
|
||||||
// States that the signing name did not come from a modeled source but
|
// States that the signing name did not come from a modeled source but
|
||||||
// was derived based on other data. Used by service client constructors
|
// was derived based on other data. Used by service client constructors
|
||||||
// to determine if the signin name can be overriden based on metadata the
|
// to determine if the signin name can be overridden based on metadata the
|
||||||
// service has.
|
// service has.
|
||||||
SigningNameDerived bool
|
SigningNameDerived bool
|
||||||
}
|
}
|
||||||
|
|
|
@ -118,6 +118,12 @@ var LogHTTPResponseHandler = request.NamedHandler{
|
||||||
func logResponse(r *request.Request) {
|
func logResponse(r *request.Request) {
|
||||||
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
|
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
|
||||||
|
|
||||||
|
if r.HTTPResponse == nil {
|
||||||
|
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
|
||||||
|
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
|
||||||
if logBody {
|
if logBody {
|
||||||
r.HTTPResponse.Body = &teeReaderCloser{
|
r.HTTPResponse.Body = &teeReaderCloser{
|
||||||
|
|
|
@ -18,7 +18,7 @@ const UseServiceDefaultRetries = -1
|
||||||
type RequestRetryer interface{}
|
type RequestRetryer interface{}
|
||||||
|
|
||||||
// A Config provides service configuration for service clients. By default,
|
// A Config provides service configuration for service clients. By default,
|
||||||
// all clients will use the defaults.DefaultConfig tructure.
|
// all clients will use the defaults.DefaultConfig structure.
|
||||||
//
|
//
|
||||||
// // Create Session with MaxRetry configuration to be shared by multiple
|
// // Create Session with MaxRetry configuration to be shared by multiple
|
||||||
// // service clients.
|
// // service clients.
|
||||||
|
@ -45,7 +45,7 @@ type Config struct {
|
||||||
// that overrides the default generated endpoint for a client. Set this
|
// that overrides the default generated endpoint for a client. Set this
|
||||||
// to `""` to use the default generated endpoint.
|
// to `""` to use the default generated endpoint.
|
||||||
//
|
//
|
||||||
// @note You must still provide a `Region` value when specifying an
|
// Note: You must still provide a `Region` value when specifying an
|
||||||
// endpoint for a client.
|
// endpoint for a client.
|
||||||
Endpoint *string
|
Endpoint *string
|
||||||
|
|
||||||
|
@ -65,8 +65,8 @@ type Config struct {
|
||||||
// noted. A full list of regions is found in the "Regions and Endpoints"
|
// noted. A full list of regions is found in the "Regions and Endpoints"
|
||||||
// document.
|
// document.
|
||||||
//
|
//
|
||||||
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
|
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
|
||||||
// AWS Regions and Endpoints
|
// Regions and Endpoints.
|
||||||
Region *string
|
Region *string
|
||||||
|
|
||||||
// Set this to `true` to disable SSL when sending requests. Defaults
|
// Set this to `true` to disable SSL when sending requests. Defaults
|
||||||
|
@ -120,9 +120,10 @@ type Config struct {
|
||||||
// will use virtual hosted bucket addressing when possible
|
// will use virtual hosted bucket addressing when possible
|
||||||
// (`http://BUCKET.s3.amazonaws.com/KEY`).
|
// (`http://BUCKET.s3.amazonaws.com/KEY`).
|
||||||
//
|
//
|
||||||
// @note This configuration option is specific to the Amazon S3 service.
|
// Note: This configuration option is specific to the Amazon S3 service.
|
||||||
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
|
//
|
||||||
// Amazon S3: Virtual Hosting of Buckets
|
// See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
|
||||||
|
// for Amazon S3: Virtual Hosting of Buckets
|
||||||
S3ForcePathStyle *bool
|
S3ForcePathStyle *bool
|
||||||
|
|
||||||
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
|
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
|
||||||
|
@ -223,6 +224,28 @@ type Config struct {
|
||||||
// Key: aws.String("//foo//bar//moo"),
|
// Key: aws.String("//foo//bar//moo"),
|
||||||
// })
|
// })
|
||||||
DisableRestProtocolURICleaning *bool
|
DisableRestProtocolURICleaning *bool
|
||||||
|
|
||||||
|
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
|
||||||
|
// have the definition in its model. By default, endpoint discovery is off.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// sess := session.Must(session.NewSession(&aws.Config{
|
||||||
|
// EnableEndpointDiscovery: aws.Bool(true),
|
||||||
|
// }))
|
||||||
|
//
|
||||||
|
// svc := s3.New(sess)
|
||||||
|
// out, err := svc.GetObject(&s3.GetObjectInput {
|
||||||
|
// Bucket: aws.String("bucketname"),
|
||||||
|
// Key: aws.String("/foo/bar/moo"),
|
||||||
|
// })
|
||||||
|
EnableEndpointDiscovery *bool
|
||||||
|
|
||||||
|
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
|
||||||
|
// request endpoint hosts with modeled information.
|
||||||
|
//
|
||||||
|
// Disabling this feature is useful when you want to use local endpoints
|
||||||
|
// for testing that do not support the modeled host prefix pattern.
|
||||||
|
DisableEndpointHostPrefix *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfig returns a new Config pointer that can be chained with builder
|
// NewConfig returns a new Config pointer that can be chained with builder
|
||||||
|
@ -377,6 +400,19 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
|
||||||
|
func (c *Config) WithEndpointDiscovery(t bool) *Config {
|
||||||
|
c.EnableEndpointDiscovery = &t
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
|
||||||
|
// when making requests.
|
||||||
|
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
|
||||||
|
c.DisableEndpointHostPrefix = &t
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
// MergeIn merges the passed in configs into the existing config object.
|
// MergeIn merges the passed in configs into the existing config object.
|
||||||
func (c *Config) MergeIn(cfgs ...*Config) {
|
func (c *Config) MergeIn(cfgs ...*Config) {
|
||||||
for _, other := range cfgs {
|
for _, other := range cfgs {
|
||||||
|
@ -476,6 +512,14 @@ func mergeInConfig(dst *Config, other *Config) {
|
||||||
if other.EnforceShouldRetryCheck != nil {
|
if other.EnforceShouldRetryCheck != nil {
|
||||||
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
|
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if other.EnableEndpointDiscovery != nil {
|
||||||
|
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
|
||||||
|
}
|
||||||
|
|
||||||
|
if other.DisableEndpointHostPrefix != nil {
|
||||||
|
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy will return a shallow copy of the Config object. If any additional
|
// Copy will return a shallow copy of the Config object. If any additional
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
|
// +build !go1.9
|
||||||
|
|
||||||
package aws
|
package aws
|
||||||
|
|
||||||
import (
|
import "time"
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
|
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
|
||||||
// It is represented as a SDK interface to enable you to use the "WithContext"
|
// It is represented as a SDK interface to enable you to use the "WithContext"
|
||||||
|
@ -35,37 +35,3 @@ type Context interface {
|
||||||
// functions.
|
// functions.
|
||||||
Value(key interface{}) interface{}
|
Value(key interface{}) interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackgroundContext returns a context that will never be canceled, has no
|
|
||||||
// values, and no deadline. This context is used by the SDK to provide
|
|
||||||
// backwards compatibility with non-context API operations and functionality.
|
|
||||||
//
|
|
||||||
// Go 1.6 and before:
|
|
||||||
// This context function is equivalent to context.Background in the Go stdlib.
|
|
||||||
//
|
|
||||||
// Go 1.7 and later:
|
|
||||||
// The context returned will be the value returned by context.Background()
|
|
||||||
//
|
|
||||||
// See https://golang.org/pkg/context for more information on Contexts.
|
|
||||||
func BackgroundContext() Context {
|
|
||||||
return backgroundCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
// SleepWithContext will wait for the timer duration to expire, or the context
|
|
||||||
// is canceled. Which ever happens first. If the context is canceled the Context's
|
|
||||||
// error will be returned.
|
|
||||||
//
|
|
||||||
// Expects Context to always return a non-nil error if the Done channel is closed.
|
|
||||||
func SleepWithContext(ctx Context, dur time.Duration) error {
|
|
||||||
t := time.NewTimer(dur)
|
|
||||||
defer t.Stop()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-t.C:
|
|
||||||
break
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,9 +0,0 @@
|
||||||
// +build go1.7
|
|
||||||
|
|
||||||
package aws
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
var (
|
|
||||||
backgroundCtx = context.Background()
|
|
||||||
)
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
// +build go1.9
|
||||||
|
|
||||||
|
package aws
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Context is an alias of the Go stdlib's context.Context interface.
|
||||||
|
// It can be used within the SDK's API operation "WithContext" methods.
|
||||||
|
//
|
||||||
|
// See https://golang.org/pkg/context on how to use contexts.
|
||||||
|
type Context = context.Context
|
|
@ -39,3 +39,18 @@ func (e *emptyCtx) String() string {
|
||||||
var (
|
var (
|
||||||
backgroundCtx = new(emptyCtx)
|
backgroundCtx = new(emptyCtx)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// BackgroundContext returns a context that will never be canceled, has no
|
||||||
|
// values, and no deadline. This context is used by the SDK to provide
|
||||||
|
// backwards compatibility with non-context API operations and functionality.
|
||||||
|
//
|
||||||
|
// Go 1.6 and before:
|
||||||
|
// This context function is equivalent to context.Background in the Go stdlib.
|
||||||
|
//
|
||||||
|
// Go 1.7 and later:
|
||||||
|
// The context returned will be the value returned by context.Background()
|
||||||
|
//
|
||||||
|
// See https://golang.org/pkg/context for more information on Contexts.
|
||||||
|
func BackgroundContext() Context {
|
||||||
|
return backgroundCtx
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
package aws
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// BackgroundContext returns a context that will never be canceled, has no
|
||||||
|
// values, and no deadline. This context is used by the SDK to provide
|
||||||
|
// backwards compatibility with non-context API operations and functionality.
|
||||||
|
//
|
||||||
|
// Go 1.6 and before:
|
||||||
|
// This context function is equivalent to context.Background in the Go stdlib.
|
||||||
|
//
|
||||||
|
// Go 1.7 and later:
|
||||||
|
// The context returned will be the value returned by context.Background()
|
||||||
|
//
|
||||||
|
// See https://golang.org/pkg/context for more information on Contexts.
|
||||||
|
func BackgroundContext() Context {
|
||||||
|
return context.Background()
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SleepWithContext will wait for the timer duration to expire, or the context
|
||||||
|
// is canceled. Which ever happens first. If the context is canceled the Context's
|
||||||
|
// error will be returned.
|
||||||
|
//
|
||||||
|
// Expects Context to always return a non-nil error if the Done channel is closed.
|
||||||
|
func SleepWithContext(ctx Context, dur time.Duration) error {
|
||||||
|
t := time.NewTimer(dur)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
break
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -72,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
|
||||||
signedTime = r.LastSignedAt
|
signedTime = r.LastSignedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
// 10 minutes to allow for some clock skew/delays in transmission.
|
// 5 minutes to allow for some clock skew/delays in transmission.
|
||||||
// Would be improved with aws/aws-sdk-go#423
|
// Would be improved with aws/aws-sdk-go#423
|
||||||
if signedTime.Add(10 * time.Minute).After(time.Now()) {
|
if signedTime.Add(5 * time.Minute).After(time.Now()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
|
||||||
}
|
}
|
||||||
|
|
||||||
const execEnvVar = `AWS_EXECUTION_ENV`
|
const execEnvVar = `AWS_EXECUTION_ENV`
|
||||||
const execEnvUAKey = `exec_env`
|
const execEnvUAKey = `exec-env`
|
||||||
|
|
||||||
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
|
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
|
||||||
// execution environment to the user agent.
|
// execution environment to the user agent.
|
||||||
|
|
|
@ -9,9 +9,7 @@ var (
|
||||||
// providers in the ChainProvider.
|
// providers in the ChainProvider.
|
||||||
//
|
//
|
||||||
// This has been deprecated. For verbose error messaging set
|
// This has been deprecated. For verbose error messaging set
|
||||||
// aws.Config.CredentialsChainVerboseErrors to true
|
// aws.Config.CredentialsChainVerboseErrors to true.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
|
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
|
||||||
`no valid providers in chain. Deprecated.
|
`no valid providers in chain. Deprecated.
|
||||||
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
|
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
|
||||||
|
|
|
@ -49,8 +49,11 @@
|
||||||
package credentials
|
package credentials
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AnonymousCredentials is an empty Credential object that can be used as
|
// AnonymousCredentials is an empty Credential object that can be used as
|
||||||
|
@ -64,8 +67,6 @@ import (
|
||||||
// Credentials: credentials.AnonymousCredentials,
|
// Credentials: credentials.AnonymousCredentials,
|
||||||
// })))
|
// })))
|
||||||
// // Access public S3 buckets.
|
// // Access public S3 buckets.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
var AnonymousCredentials = NewStaticCredentials("", "", "")
|
var AnonymousCredentials = NewStaticCredentials("", "", "")
|
||||||
|
|
||||||
// A Value is the AWS credentials value for individual credential fields.
|
// A Value is the AWS credentials value for individual credential fields.
|
||||||
|
@ -83,6 +84,12 @@ type Value struct {
|
||||||
ProviderName string
|
ProviderName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasKeys returns if the credentials Value has both AccessKeyID and
|
||||||
|
// SecretAccessKey value set.
|
||||||
|
func (v Value) HasKeys() bool {
|
||||||
|
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
|
||||||
|
}
|
||||||
|
|
||||||
// A Provider is the interface for any component which will provide credentials
|
// A Provider is the interface for any component which will provide credentials
|
||||||
// Value. A provider is required to manage its own Expired state, and what to
|
// Value. A provider is required to manage its own Expired state, and what to
|
||||||
// be expired means.
|
// be expired means.
|
||||||
|
@ -99,6 +106,14 @@ type Provider interface {
|
||||||
IsExpired() bool
|
IsExpired() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// An Expirer is an interface that Providers can implement to expose the expiration
|
||||||
|
// time, if known. If the Provider cannot accurately provide this info,
|
||||||
|
// it should not implement this interface.
|
||||||
|
type Expirer interface {
|
||||||
|
// The time at which the credentials are no longer valid
|
||||||
|
ExpiresAt() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// An ErrorProvider is a stub credentials provider that always returns an error
|
// An ErrorProvider is a stub credentials provider that always returns an error
|
||||||
// this is used by the SDK when construction a known provider is not possible
|
// this is used by the SDK when construction a known provider is not possible
|
||||||
// due to an error.
|
// due to an error.
|
||||||
|
@ -165,6 +180,11 @@ func (e *Expiry) IsExpired() bool {
|
||||||
return e.expiration.Before(curTime())
|
return e.expiration.Before(curTime())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExpiresAt returns the expiration time of the credential
|
||||||
|
func (e *Expiry) ExpiresAt() time.Time {
|
||||||
|
return e.expiration
|
||||||
|
}
|
||||||
|
|
||||||
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
|
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
|
||||||
// Credentials will cache the credentials value until they expire. Once the value
|
// Credentials will cache the credentials value until they expire. Once the value
|
||||||
// expires the next Get will attempt to retrieve valid credentials.
|
// expires the next Get will attempt to retrieve valid credentials.
|
||||||
|
@ -257,3 +277,23 @@ func (c *Credentials) IsExpired() bool {
|
||||||
func (c *Credentials) isExpired() bool {
|
func (c *Credentials) isExpired() bool {
|
||||||
return c.forceRefresh || c.provider.IsExpired()
|
return c.forceRefresh || c.provider.IsExpired()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExpiresAt provides access to the functionality of the Expirer interface of
|
||||||
|
// the underlying Provider, if it supports that interface. Otherwise, it returns
|
||||||
|
// an error.
|
||||||
|
func (c *Credentials) ExpiresAt() (time.Time, error) {
|
||||||
|
c.m.RLock()
|
||||||
|
defer c.m.RUnlock()
|
||||||
|
|
||||||
|
expirer, ok := c.provider.(Expirer)
|
||||||
|
if !ok {
|
||||||
|
return time.Time{}, awserr.New("ProviderNotExpirer",
|
||||||
|
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
|
||||||
|
nil)
|
||||||
|
}
|
||||||
|
if c.forceRefresh {
|
||||||
|
// set expiration time to the distant past
|
||||||
|
return time.Time{}, nil
|
||||||
|
}
|
||||||
|
return expirer.ExpiresAt(), nil
|
||||||
|
}
|
||||||
|
|
6
vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider.go
generated
vendored
6
vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider.go
generated
vendored
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/aws/client"
|
"github.com/aws/aws-sdk-go/aws/client"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
"github.com/aws/aws-sdk-go/internal/sdkuri"
|
"github.com/aws/aws-sdk-go/internal/sdkuri"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.Err(); err != nil {
|
if err := s.Err(); err != nil {
|
||||||
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err)
|
return nil, awserr.New(request.ErrCodeSerialization,
|
||||||
|
"failed to read EC2 instance role from metadata service", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return credsList, nil
|
return credsList, nil
|
||||||
|
@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
|
||||||
respCreds := ec2RoleCredRespBody{}
|
respCreds := ec2RoleCredRespBody{}
|
||||||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
|
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
|
||||||
return ec2RoleCredRespBody{},
|
return ec2RoleCredRespBody{},
|
||||||
awserr.New("SerializationError",
|
awserr.New(request.ErrCodeSerialization,
|
||||||
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
|
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
|
||||||
err)
|
err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
"github.com/aws/aws-sdk-go/aws/client/metadata"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
|
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProviderName is the name of the credentials provider.
|
// ProviderName is the name of the credentials provider.
|
||||||
|
@ -65,6 +66,10 @@ type Provider struct {
|
||||||
//
|
//
|
||||||
// If ExpiryWindow is 0 or less it will be ignored.
|
// If ExpiryWindow is 0 or less it will be ignored.
|
||||||
ExpiryWindow time.Duration
|
ExpiryWindow time.Duration
|
||||||
|
|
||||||
|
// Optional authorization token value if set will be used as the value of
|
||||||
|
// the Authorization header of the endpoint credential request.
|
||||||
|
AuthorizationToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
|
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
|
||||||
|
@ -152,6 +157,9 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
|
||||||
out := &getCredentialsOutput{}
|
out := &getCredentialsOutput{}
|
||||||
req := p.Client.NewRequest(op, nil, out)
|
req := p.Client.NewRequest(op, nil, out)
|
||||||
req.HTTPRequest.Header.Set("Accept", "application/json")
|
req.HTTPRequest.Header.Set("Accept", "application/json")
|
||||||
|
if authToken := p.AuthorizationToken; len(authToken) != 0 {
|
||||||
|
req.HTTPRequest.Header.Set("Authorization", authToken)
|
||||||
|
}
|
||||||
|
|
||||||
return out, req.Send()
|
return out, req.Send()
|
||||||
}
|
}
|
||||||
|
@ -167,7 +175,7 @@ func unmarshalHandler(r *request.Request) {
|
||||||
|
|
||||||
out := r.Data.(*getCredentialsOutput)
|
out := r.Data.(*getCredentialsOutput)
|
||||||
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
|
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
|
||||||
r.Error = awserr.New("SerializationError",
|
r.Error = awserr.New(request.ErrCodeSerialization,
|
||||||
"failed to decode endpoint credentials",
|
"failed to decode endpoint credentials",
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
|
@ -178,11 +186,15 @@ func unmarshalError(r *request.Request) {
|
||||||
defer r.HTTPResponse.Body.Close()
|
defer r.HTTPResponse.Body.Close()
|
||||||
|
|
||||||
var errOut errorOutput
|
var errOut errorOutput
|
||||||
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
|
err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
|
||||||
r.Error = awserr.New("SerializationError",
|
if err != nil {
|
||||||
"failed to decode endpoint credentials",
|
r.Error = awserr.NewRequestFailure(
|
||||||
err,
|
awserr.New(request.ErrCodeSerialization,
|
||||||
|
"failed to decode error message", err),
|
||||||
|
r.HTTPResponse.StatusCode,
|
||||||
|
r.RequestID,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Response body format is not consistent between metadata endpoints.
|
// Response body format is not consistent between metadata endpoints.
|
||||||
|
|
|
@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
|
||||||
var (
|
var (
|
||||||
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
|
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
|
||||||
// found in the process's environment.
|
// found in the process's environment.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
|
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
|
||||||
|
|
||||||
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
|
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
|
||||||
// can't be found in the process's environment.
|
// can't be found in the process's environment.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
|
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
425
vendor/github.com/aws/aws-sdk-go/aws/credentials/processcreds/provider.go
generated
vendored
Normal file
425
vendor/github.com/aws/aws-sdk-go/aws/credentials/processcreds/provider.go
generated
vendored
Normal file
|
@ -0,0 +1,425 @@
|
||||||
|
/*
|
||||||
|
Package processcreds is a credential Provider to retrieve `credential_process`
|
||||||
|
credentials.
|
||||||
|
|
||||||
|
WARNING: The following describes a method of sourcing credentials from an external
|
||||||
|
process. This can potentially be dangerous, so proceed with caution. Other
|
||||||
|
credential providers should be preferred if at all possible. If using this
|
||||||
|
option, you should make sure that the config file is as locked down as possible
|
||||||
|
using security best practices for your operating system.
|
||||||
|
|
||||||
|
You can use credentials from a `credential_process` in a variety of ways.
|
||||||
|
|
||||||
|
One way is to setup your shared config file, located in the default
|
||||||
|
location, with the `credential_process` key and the command you want to be
|
||||||
|
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
|
||||||
|
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.
|
||||||
|
|
||||||
|
[default]
|
||||||
|
credential_process = /command/to/call
|
||||||
|
|
||||||
|
Creating a new session will use the credential process to retrieve credentials.
|
||||||
|
NOTE: If there are credentials in the profile you are using, the credential
|
||||||
|
process will not be used.
|
||||||
|
|
||||||
|
// Initialize a session to load credentials.
|
||||||
|
sess, _ := session.NewSession(&aws.Config{
|
||||||
|
Region: aws.String("us-east-1")},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create S3 service client to use the credentials.
|
||||||
|
svc := s3.New(sess)
|
||||||
|
|
||||||
|
Another way to use the `credential_process` method is by using
|
||||||
|
`credentials.NewCredentials()` and providing a command to be executed to
|
||||||
|
retrieve credentials:
|
||||||
|
|
||||||
|
// Create credentials using the ProcessProvider.
|
||||||
|
creds := processcreds.NewCredentials("/path/to/command")
|
||||||
|
|
||||||
|
// Create service client value configured for credentials.
|
||||||
|
svc := s3.New(sess, &aws.Config{Credentials: creds})
|
||||||
|
|
||||||
|
You can set a non-default timeout for the `credential_process` with another
|
||||||
|
constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To
|
||||||
|
set a one minute timeout:
|
||||||
|
|
||||||
|
// Create credentials using the ProcessProvider.
|
||||||
|
creds := processcreds.NewCredentialsTimeout(
|
||||||
|
"/path/to/command",
|
||||||
|
time.Duration(500) * time.Millisecond)
|
||||||
|
|
||||||
|
If you need more control, you can set any configurable options in the
|
||||||
|
credentials using one or more option functions. For example, you can set a two
|
||||||
|
minute timeout, a credential duration of 60 minutes, and a maximum stdout
|
||||||
|
buffer size of 2k.
|
||||||
|
|
||||||
|
creds := processcreds.NewCredentials(
|
||||||
|
"/path/to/command",
|
||||||
|
func(opt *ProcessProvider) {
|
||||||
|
opt.Timeout = time.Duration(2) * time.Minute
|
||||||
|
opt.Duration = time.Duration(60) * time.Minute
|
||||||
|
opt.MaxBufSize = 2048
|
||||||
|
})
|
||||||
|
|
||||||
|
You can also use your own `exec.Cmd`:
|
||||||
|
|
||||||
|
// Create an exec.Cmd
|
||||||
|
myCommand := exec.Command("/path/to/command")
|
||||||
|
|
||||||
|
// Create credentials using your exec.Cmd and custom timeout
|
||||||
|
creds := processcreds.NewCredentialsCommand(
|
||||||
|
myCommand,
|
||||||
|
func(opt *processcreds.ProcessProvider) {
|
||||||
|
opt.Timeout = time.Duration(1) * time.Second
|
||||||
|
})
|
||||||
|
*/
|
||||||
|
package processcreds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProviderName is the name this credentials provider will label any
|
||||||
|
// returned credentials Value with.
|
||||||
|
ProviderName = `ProcessProvider`
|
||||||
|
|
||||||
|
// ErrCodeProcessProviderParse error parsing process output
|
||||||
|
ErrCodeProcessProviderParse = "ProcessProviderParseError"
|
||||||
|
|
||||||
|
// ErrCodeProcessProviderVersion version error in output
|
||||||
|
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
|
||||||
|
|
||||||
|
// ErrCodeProcessProviderRequired required attribute missing in output
|
||||||
|
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
|
||||||
|
|
||||||
|
// ErrCodeProcessProviderExecution execution of command failed
|
||||||
|
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
|
||||||
|
|
||||||
|
// errMsgProcessProviderTimeout process took longer than allowed
|
||||||
|
errMsgProcessProviderTimeout = "credential process timed out"
|
||||||
|
|
||||||
|
// errMsgProcessProviderProcess process error
|
||||||
|
errMsgProcessProviderProcess = "error in credential_process"
|
||||||
|
|
||||||
|
// errMsgProcessProviderParse problem parsing output
|
||||||
|
errMsgProcessProviderParse = "parse failed of credential_process output"
|
||||||
|
|
||||||
|
// errMsgProcessProviderVersion version error in output
|
||||||
|
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
|
||||||
|
|
||||||
|
// errMsgProcessProviderMissKey missing access key id in output
|
||||||
|
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
|
||||||
|
|
||||||
|
// errMsgProcessProviderMissSecret missing secret acess key in output
|
||||||
|
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
|
||||||
|
|
||||||
|
// errMsgProcessProviderPrepareCmd prepare of command failed
|
||||||
|
errMsgProcessProviderPrepareCmd = "failed to prepare command"
|
||||||
|
|
||||||
|
// errMsgProcessProviderEmptyCmd command must not be empty
|
||||||
|
errMsgProcessProviderEmptyCmd = "command must not be empty"
|
||||||
|
|
||||||
|
// errMsgProcessProviderPipe failed to initialize pipe
|
||||||
|
errMsgProcessProviderPipe = "failed to initialize pipe"
|
||||||
|
|
||||||
|
// DefaultDuration is the default amount of time in minutes that the
|
||||||
|
// credentials will be valid for.
|
||||||
|
DefaultDuration = time.Duration(15) * time.Minute
|
||||||
|
|
||||||
|
// DefaultBufSize limits buffer size from growing to an enormous
|
||||||
|
// amount due to a faulty process.
|
||||||
|
DefaultBufSize = 1024
|
||||||
|
|
||||||
|
// DefaultTimeout default limit on time a process can run.
|
||||||
|
DefaultTimeout = time.Duration(1) * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessProvider satisfies the credentials.Provider interface, and is a
|
||||||
|
// client to retrieve credentials from a process.
|
||||||
|
type ProcessProvider struct {
|
||||||
|
staticCreds bool
|
||||||
|
credentials.Expiry
|
||||||
|
originalCommand []string
|
||||||
|
|
||||||
|
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
|
||||||
|
Duration time.Duration
|
||||||
|
|
||||||
|
// ExpiryWindow will allow the credentials to trigger refreshing prior to
|
||||||
|
// the credentials actually expiring. This is beneficial so race conditions
|
||||||
|
// with expiring credentials do not cause request to fail unexpectedly
|
||||||
|
// due to ExpiredTokenException exceptions.
|
||||||
|
//
|
||||||
|
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
||||||
|
// 10 seconds before the credentials are actually expired.
|
||||||
|
//
|
||||||
|
// If ExpiryWindow is 0 or less it will be ignored.
|
||||||
|
ExpiryWindow time.Duration
|
||||||
|
|
||||||
|
// A string representing an os command that should return a JSON with
|
||||||
|
// credential information.
|
||||||
|
command *exec.Cmd
|
||||||
|
|
||||||
|
// MaxBufSize limits memory usage from growing to an enormous
|
||||||
|
// amount due to a faulty process.
|
||||||
|
MaxBufSize int
|
||||||
|
|
||||||
|
// Timeout limits the time a process can run.
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCredentials returns a pointer to a new Credentials object wrapping the
|
||||||
|
// ProcessProvider. The credentials will expire every 15 minutes by default.
|
||||||
|
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
|
||||||
|
p := &ProcessProvider{
|
||||||
|
command: exec.Command(command),
|
||||||
|
Duration: DefaultDuration,
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
MaxBufSize: DefaultBufSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, option := range options {
|
||||||
|
option(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials.NewCredentials(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCredentialsTimeout returns a pointer to a new Credentials object with
|
||||||
|
// the specified command and timeout, and default duration and max buffer size.
|
||||||
|
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
|
||||||
|
p := NewCredentials(command, func(opt *ProcessProvider) {
|
||||||
|
opt.Timeout = timeout
|
||||||
|
})
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCredentialsCommand returns a pointer to a new Credentials object with
|
||||||
|
// the specified command, and default timeout, duration and max buffer size.
|
||||||
|
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
|
||||||
|
p := &ProcessProvider{
|
||||||
|
command: command,
|
||||||
|
Duration: DefaultDuration,
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
MaxBufSize: DefaultBufSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, option := range options {
|
||||||
|
option(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials.NewCredentials(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
type credentialProcessResponse struct {
|
||||||
|
Version int
|
||||||
|
AccessKeyID string `json:"AccessKeyId"`
|
||||||
|
SecretAccessKey string
|
||||||
|
SessionToken string
|
||||||
|
Expiration *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve executes the 'credential_process' and returns the credentials.
|
||||||
|
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
|
||||||
|
out, err := p.executeCredentialProcess()
|
||||||
|
if err != nil {
|
||||||
|
return credentials.Value{ProviderName: ProviderName}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize and validate response
|
||||||
|
resp := &credentialProcessResponse{}
|
||||||
|
if err = json.Unmarshal(out, resp); err != nil {
|
||||||
|
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||||
|
ErrCodeProcessProviderParse,
|
||||||
|
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Version != 1 {
|
||||||
|
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||||
|
ErrCodeProcessProviderVersion,
|
||||||
|
errMsgProcessProviderVersion,
|
||||||
|
nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.AccessKeyID) == 0 {
|
||||||
|
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||||
|
ErrCodeProcessProviderRequired,
|
||||||
|
errMsgProcessProviderMissKey,
|
||||||
|
nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.SecretAccessKey) == 0 {
|
||||||
|
return credentials.Value{ProviderName: ProviderName}, awserr.New(
|
||||||
|
ErrCodeProcessProviderRequired,
|
||||||
|
errMsgProcessProviderMissSecret,
|
||||||
|
nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle expiration
|
||||||
|
p.staticCreds = resp.Expiration == nil
|
||||||
|
if resp.Expiration != nil {
|
||||||
|
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials.Value{
|
||||||
|
ProviderName: ProviderName,
|
||||||
|
AccessKeyID: resp.AccessKeyID,
|
||||||
|
SecretAccessKey: resp.SecretAccessKey,
|
||||||
|
SessionToken: resp.SessionToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired returns true if the credentials retrieved are expired, or not yet
|
||||||
|
// retrieved.
|
||||||
|
func (p *ProcessProvider) IsExpired() bool {
|
||||||
|
if p.staticCreds {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return p.Expiry.IsExpired()
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareCommand prepares the command to be executed.
|
||||||
|
func (p *ProcessProvider) prepareCommand() error {
|
||||||
|
|
||||||
|
var cmdArgs []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cmdArgs = []string{"cmd.exe", "/C"}
|
||||||
|
} else {
|
||||||
|
cmdArgs = []string{"sh", "-c"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.originalCommand) == 0 {
|
||||||
|
p.originalCommand = make([]string, len(p.command.Args))
|
||||||
|
copy(p.originalCommand, p.command.Args)
|
||||||
|
|
||||||
|
// check for empty command because it succeeds
|
||||||
|
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
|
||||||
|
return awserr.New(
|
||||||
|
ErrCodeProcessProviderExecution,
|
||||||
|
fmt.Sprintf(
|
||||||
|
"%s: %s",
|
||||||
|
errMsgProcessProviderPrepareCmd,
|
||||||
|
errMsgProcessProviderEmptyCmd),
|
||||||
|
nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdArgs = append(cmdArgs, p.originalCommand...)
|
||||||
|
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
|
||||||
|
p.command.Env = os.Environ()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeCredentialProcess starts the credential process on the OS and
|
||||||
|
// returns the results or an error.
|
||||||
|
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
|
||||||
|
|
||||||
|
if err := p.prepareCommand(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the pipes
|
||||||
|
outReadPipe, outWritePipe, err := os.Pipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, awserr.New(
|
||||||
|
ErrCodeProcessProviderExecution,
|
||||||
|
errMsgProcessProviderPipe,
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.command.Stderr = os.Stderr // display stderr on console for MFA
|
||||||
|
p.command.Stdout = outWritePipe // get creds json on process's stdout
|
||||||
|
p.command.Stdin = os.Stdin // enable stdin for MFA
|
||||||
|
|
||||||
|
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
|
||||||
|
|
||||||
|
stdoutCh := make(chan error, 1)
|
||||||
|
go readInput(
|
||||||
|
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
|
||||||
|
output,
|
||||||
|
stdoutCh)
|
||||||
|
|
||||||
|
execCh := make(chan error, 1)
|
||||||
|
go executeCommand(*p.command, execCh)
|
||||||
|
|
||||||
|
finished := false
|
||||||
|
var errors []error
|
||||||
|
for !finished {
|
||||||
|
select {
|
||||||
|
case readError := <-stdoutCh:
|
||||||
|
errors = appendError(errors, readError)
|
||||||
|
finished = true
|
||||||
|
case execError := <-execCh:
|
||||||
|
err := outWritePipe.Close()
|
||||||
|
errors = appendError(errors, err)
|
||||||
|
errors = appendError(errors, execError)
|
||||||
|
if errors != nil {
|
||||||
|
return output.Bytes(), awserr.NewBatchError(
|
||||||
|
ErrCodeProcessProviderExecution,
|
||||||
|
errMsgProcessProviderProcess,
|
||||||
|
errors)
|
||||||
|
}
|
||||||
|
case <-time.After(p.Timeout):
|
||||||
|
finished = true
|
||||||
|
return output.Bytes(), awserr.NewBatchError(
|
||||||
|
ErrCodeProcessProviderExecution,
|
||||||
|
errMsgProcessProviderTimeout,
|
||||||
|
errors) // errors can be nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out := output.Bytes()
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
// windows adds slashes to quotes
|
||||||
|
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendError conveniently checks for nil before appending slice
|
||||||
|
func appendError(errors []error, err error) []error {
|
||||||
|
if err != nil {
|
||||||
|
return append(errors, err)
|
||||||
|
}
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeCommand(cmd exec.Cmd, exec chan error) {
|
||||||
|
// Start the command
|
||||||
|
err := cmd.Start()
|
||||||
|
if err == nil {
|
||||||
|
err = cmd.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
exec <- err
|
||||||
|
}
|
||||||
|
|
||||||
|
func readInput(r io.Reader, w io.Writer, read chan error) {
|
||||||
|
tee := io.TeeReader(r, w)
|
||||||
|
|
||||||
|
_, err := ioutil.ReadAll(tee)
|
||||||
|
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
read <- err // will only arrive here when write end of pipe is closed
|
||||||
|
}
|
30
vendor/github.com/aws/aws-sdk-go/aws/credentials/shared_credentials_provider.go
generated
vendored
30
vendor/github.com/aws/aws-sdk-go/aws/credentials/shared_credentials_provider.go
generated
vendored
|
@ -4,9 +4,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/go-ini/ini"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
"github.com/aws/aws-sdk-go/internal/ini"
|
||||||
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -77,36 +76,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
|
||||||
// The credentials retrieved from the profile will be returned or error. Error will be
|
// The credentials retrieved from the profile will be returned or error. Error will be
|
||||||
// returned if it fails to read from the file, or the data is invalid.
|
// returned if it fails to read from the file, or the data is invalid.
|
||||||
func loadProfile(filename, profile string) (Value, error) {
|
func loadProfile(filename, profile string) (Value, error) {
|
||||||
config, err := ini.Load(filename)
|
config, err := ini.OpenFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
|
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
|
||||||
}
|
}
|
||||||
iniProfile, err := config.GetSection(profile)
|
|
||||||
if err != nil {
|
iniProfile, ok := config.GetSection(profile)
|
||||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err)
|
if !ok {
|
||||||
|
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := iniProfile.GetKey("aws_access_key_id")
|
id := iniProfile.String("aws_access_key_id")
|
||||||
if err != nil {
|
if len(id) == 0 {
|
||||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
|
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
|
||||||
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
|
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
|
||||||
err)
|
nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
secret, err := iniProfile.GetKey("aws_secret_access_key")
|
secret := iniProfile.String("aws_secret_access_key")
|
||||||
if err != nil {
|
if len(secret) == 0 {
|
||||||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
|
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
|
||||||
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
|
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
|
||||||
nil)
|
nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to empty string if not found
|
// Default to empty string if not found
|
||||||
token := iniProfile.Key("aws_session_token")
|
token := iniProfile.String("aws_session_token")
|
||||||
|
|
||||||
return Value{
|
return Value{
|
||||||
AccessKeyID: id.String(),
|
AccessKeyID: id,
|
||||||
SecretAccessKey: secret.String(),
|
SecretAccessKey: secret,
|
||||||
SessionToken: token.String(),
|
SessionToken: token,
|
||||||
ProviderName: SharedCredsProviderName,
|
ProviderName: SharedCredsProviderName,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
|
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
|
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
22
vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/assume_role_provider.go
generated
vendored
22
vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/assume_role_provider.go
generated
vendored
|
@ -80,16 +80,18 @@ package stscreds
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
"github.com/aws/aws-sdk-go/aws/client"
|
"github.com/aws/aws-sdk-go/aws/client"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go/internal/sdkrand"
|
||||||
"github.com/aws/aws-sdk-go/service/sts"
|
"github.com/aws/aws-sdk-go/service/sts"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StdinTokenProvider will prompt on stdout and read from stdin for a string value.
|
// StdinTokenProvider will prompt on stderr and read from stdin for a string value.
|
||||||
// An error is returned if reading from stdin fails.
|
// An error is returned if reading from stdin fails.
|
||||||
//
|
//
|
||||||
// Use this function go read MFA tokens from stdin. The function makes no attempt
|
// Use this function go read MFA tokens from stdin. The function makes no attempt
|
||||||
|
@ -102,7 +104,7 @@ import (
|
||||||
// Will wait forever until something is provided on the stdin.
|
// Will wait forever until something is provided on the stdin.
|
||||||
func StdinTokenProvider() (string, error) {
|
func StdinTokenProvider() (string, error) {
|
||||||
var v string
|
var v string
|
||||||
fmt.Printf("Assume Role MFA token code: ")
|
fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
|
||||||
_, err := fmt.Scanln(&v)
|
_, err := fmt.Scanln(&v)
|
||||||
|
|
||||||
return v, err
|
return v, err
|
||||||
|
@ -193,6 +195,18 @@ type AssumeRoleProvider struct {
|
||||||
//
|
//
|
||||||
// If ExpiryWindow is 0 or less it will be ignored.
|
// If ExpiryWindow is 0 or less it will be ignored.
|
||||||
ExpiryWindow time.Duration
|
ExpiryWindow time.Duration
|
||||||
|
|
||||||
|
// MaxJitterFrac reduces the effective Duration of each credential requested
|
||||||
|
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
|
||||||
|
// have a value between 0 and 1. Any other value may lead to expected behavior.
|
||||||
|
// With a MaxJitterFrac value of 0, default) will no jitter will be used.
|
||||||
|
//
|
||||||
|
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
|
||||||
|
// AssumeRole call will be made with an arbitrary Duration between 27m and
|
||||||
|
// 30m.
|
||||||
|
//
|
||||||
|
// MaxJitterFrac should not be negative.
|
||||||
|
MaxJitterFrac float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCredentials returns a pointer to a new Credentials object wrapping the
|
// NewCredentials returns a pointer to a new Credentials object wrapping the
|
||||||
|
@ -244,7 +258,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
|
||||||
|
|
||||||
// Retrieve generates a new set of temporary credentials using STS.
|
// Retrieve generates a new set of temporary credentials using STS.
|
||||||
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
||||||
|
|
||||||
// Apply defaults where parameters are not set.
|
// Apply defaults where parameters are not set.
|
||||||
if p.RoleSessionName == "" {
|
if p.RoleSessionName == "" {
|
||||||
// Try to work out a role name that will hopefully end up unique.
|
// Try to work out a role name that will hopefully end up unique.
|
||||||
|
@ -254,8 +267,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
||||||
// Expire as often as AWS permits.
|
// Expire as often as AWS permits.
|
||||||
p.Duration = DefaultDuration
|
p.Duration = DefaultDuration
|
||||||
}
|
}
|
||||||
|
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
|
||||||
input := &sts.AssumeRoleInput{
|
input := &sts.AssumeRoleInput{
|
||||||
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
|
DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
|
||||||
RoleArn: aws.String(p.RoleARN),
|
RoleArn: aws.String(p.RoleARN),
|
||||||
RoleSessionName: aws.String(p.RoleSessionName),
|
RoleSessionName: aws.String(p.RoleSessionName),
|
||||||
ExternalId: p.ExternalID,
|
ExternalId: p.ExternalID,
|
||||||
|
|
99
vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/web_identity_provider.go
generated
vendored
Normal file
99
vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds/web_identity_provider.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package stscreds
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/client"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go/service/sts"
|
||||||
|
"github.com/aws/aws-sdk-go/service/sts/stsiface"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ErrCodeWebIdentity will be used as an error code when constructing
|
||||||
|
// a new error to be returned during session creation or retrieval.
|
||||||
|
ErrCodeWebIdentity = "WebIdentityErr"
|
||||||
|
|
||||||
|
// WebIdentityProviderName is the web identity provider name
|
||||||
|
WebIdentityProviderName = "WebIdentityCredentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
// now is used to return a time.Time object representing
|
||||||
|
// the current time. This can be used to easily test and
|
||||||
|
// compare test values.
|
||||||
|
var now = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebIdentityRoleProvider is used to retrieve credentials using
|
||||||
|
// an OIDC token.
|
||||||
|
type WebIdentityRoleProvider struct {
|
||||||
|
credentials.Expiry
|
||||||
|
|
||||||
|
client stsiface.STSAPI
|
||||||
|
ExpiryWindow time.Duration
|
||||||
|
|
||||||
|
tokenFilePath string
|
||||||
|
roleARN string
|
||||||
|
roleSessionName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebIdentityCredentials will return a new set of credentials with a given
|
||||||
|
// configuration, role arn, and token file path.
|
||||||
|
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
|
||||||
|
svc := sts.New(c)
|
||||||
|
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
|
||||||
|
return credentials.NewCredentials(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
|
||||||
|
// provided stsiface.STSAPI
|
||||||
|
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
|
||||||
|
return &WebIdentityRoleProvider{
|
||||||
|
client: svc,
|
||||||
|
tokenFilePath: path,
|
||||||
|
roleARN: roleARN,
|
||||||
|
roleSessionName: roleSessionName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve will attempt to assume a role from a token which is located at
|
||||||
|
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
|
||||||
|
// error will be returned.
|
||||||
|
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
|
||||||
|
b, err := ioutil.ReadFile(p.tokenFilePath)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
|
||||||
|
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionName := p.roleSessionName
|
||||||
|
if len(sessionName) == 0 {
|
||||||
|
// session name is used to uniquely identify a session. This simply
|
||||||
|
// uses unix time in nanoseconds to uniquely identify sessions.
|
||||||
|
sessionName = strconv.FormatInt(now().UnixNano(), 10)
|
||||||
|
}
|
||||||
|
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
|
||||||
|
RoleArn: &p.roleARN,
|
||||||
|
RoleSessionName: &sessionName,
|
||||||
|
WebIdentityToken: aws.String(string(b)),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
|
||||||
|
|
||||||
|
value := credentials.Value{
|
||||||
|
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
|
||||||
|
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
|
||||||
|
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
|
||||||
|
ProviderName: WebIdentityProviderName,
|
||||||
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
|
@ -1,30 +1,61 @@
|
||||||
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics
|
// Package csm provides the Client Side Monitoring (CSM) client which enables
|
||||||
// via UDP connection. Using the Start function will enable the reporting of
|
// sending metrics via UDP connection to the CSM agent. This package provides
|
||||||
// metrics on a given port. If Start is called, with different parameters, again,
|
// control options, and configuration for the CSM client. The client can be
|
||||||
// a panic will occur.
|
// controlled manually, or automatically via the SDK's Session configuration.
|
||||||
//
|
//
|
||||||
// Pause can be called to pause any metrics publishing on a given port. Sessions
|
// Enabling CSM client via SDK's Session configuration
|
||||||
// that have had their handlers modified via InjectHandlers may still be used.
|
//
|
||||||
// However, the handlers will act as a no-op meaning no metrics will be published.
|
// The CSM client can be enabled automatically via SDK's Session configuration.
|
||||||
|
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
|
||||||
|
// environment variable is set to a non-empty value.
|
||||||
|
//
|
||||||
|
// The configuration options for the CSM client via the SDK's session
|
||||||
|
// configuration are:
|
||||||
|
//
|
||||||
|
// * AWS_CSM_PORT=<port number>
|
||||||
|
// The port number the CSM agent will receive metrics on.
|
||||||
|
//
|
||||||
|
// * AWS_CSM_HOST=<hostname or ip>
|
||||||
|
// The hostname, or IP address the CSM agent will receive metrics on.
|
||||||
|
// Without port number.
|
||||||
|
//
|
||||||
|
// Manually enabling the CSM client
|
||||||
|
//
|
||||||
|
// The CSM client can be started, paused, and resumed manually. The Start
|
||||||
|
// function will enable the CSM client to publish metrics to the CSM agent. It
|
||||||
|
// is safe to call Start concurrently, but if Start is called additional times
|
||||||
|
// with different ClientID or address it will panic.
|
||||||
//
|
//
|
||||||
// Example:
|
|
||||||
// r, err := csm.Start("clientID", ":31000")
|
// r, err := csm.Start("clientID", ":31000")
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// panic(fmt.Errorf("failed starting CSM: %v", err))
|
// panic(fmt.Errorf("failed starting CSM: %v", err))
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
|
// When controlling the CSM client manually, you must also inject its request
|
||||||
|
// handlers into the SDK's Session configuration for the SDK's API clients to
|
||||||
|
// publish metrics.
|
||||||
|
//
|
||||||
// sess, err := session.NewSession(&aws.Config{})
|
// sess, err := session.NewSession(&aws.Config{})
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// panic(fmt.Errorf("failed loading session: %v", err))
|
// panic(fmt.Errorf("failed loading session: %v", err))
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
|
// // Add CSM client's metric publishing request handlers to the SDK's
|
||||||
|
// // Session Configuration.
|
||||||
// r.InjectHandlers(&sess.Handlers)
|
// r.InjectHandlers(&sess.Handlers)
|
||||||
//
|
//
|
||||||
// client := s3.New(sess)
|
// Controlling CSM client
|
||||||
// resp, err := client.GetObject(&s3.GetObjectInput{
|
//
|
||||||
// Bucket: aws.String("bucket"),
|
// Once the CSM client has been enabled the Get function will return a Reporter
|
||||||
// Key: aws.String("key"),
|
// value that you can use to pause and resume the metrics published to the CSM
|
||||||
// })
|
// agent. If Get function is called before the reporter is enabled with the
|
||||||
|
// Start function or via SDK's Session configuration nil will be returned.
|
||||||
|
//
|
||||||
|
// The Pause method can be called to stop the CSM client publishing metrics to
|
||||||
|
// the CSM agent. The Continue method will resume metric publishing.
|
||||||
|
//
|
||||||
|
// // Get the CSM client Reporter.
|
||||||
|
// r := csm.Get()
|
||||||
//
|
//
|
||||||
// // Will pause monitoring
|
// // Will pause monitoring
|
||||||
// r.Pause()
|
// r.Pause()
|
||||||
|
@ -35,12 +66,4 @@
|
||||||
//
|
//
|
||||||
// // Resume monitoring
|
// // Resume monitoring
|
||||||
// r.Continue()
|
// r.Continue()
|
||||||
//
|
|
||||||
// Start returns a Reporter that is used to enable or disable monitoring. If
|
|
||||||
// access to the Reporter is required later, calling Get will return the Reporter
|
|
||||||
// singleton.
|
|
||||||
//
|
|
||||||
// Example:
|
|
||||||
// r := csm.Get()
|
|
||||||
// r.Continue()
|
|
||||||
package csm
|
package csm
|
||||||
|
|
|
@ -2,6 +2,7 @@ package csm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,19 +10,40 @@ var (
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client side metric handler names
|
|
||||||
const (
|
const (
|
||||||
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
|
// DefaultPort is used when no port is specified.
|
||||||
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
|
DefaultPort = "31000"
|
||||||
|
|
||||||
|
// DefaultHost is the host that will be used when none is specified.
|
||||||
|
DefaultHost = "127.0.0.1"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Start will start the a long running go routine to capture
|
// AddressWithDefaults returns a CSM address built from the host and port
|
||||||
|
// values. If the host or port is not set, default values will be used
|
||||||
|
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
|
||||||
|
func AddressWithDefaults(host, port string) string {
|
||||||
|
if len(host) == 0 || strings.EqualFold(host, "localhost") {
|
||||||
|
host = DefaultHost
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(port) == 0 {
|
||||||
|
port = DefaultPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only IP6 host can contain a colon
|
||||||
|
if strings.Contains(host, ":") {
|
||||||
|
return "[" + host + "]:" + port
|
||||||
|
}
|
||||||
|
|
||||||
|
return host + ":" + port
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start will start a long running go routine to capture
|
||||||
// client side metrics. Calling start multiple time will only
|
// client side metrics. Calling start multiple time will only
|
||||||
// start the metric listener once and will panic if a different
|
// start the metric listener once and will panic if a different
|
||||||
// client ID or port is passed in.
|
// client ID or port is passed in.
|
||||||
//
|
//
|
||||||
// Example:
|
// r, err := csm.Start("clientID", "127.0.0.1:31000")
|
||||||
// r, err := csm.Start("clientID", "127.0.0.1:8094")
|
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// panic(fmt.Errorf("expected no error, but received %v", err))
|
// panic(fmt.Errorf("expected no error, but received %v", err))
|
||||||
// }
|
// }
|
||||||
|
|
|
@ -3,6 +3,8 @@ package csm
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
)
|
)
|
||||||
|
|
||||||
type metricTime time.Time
|
type metricTime time.Time
|
||||||
|
@ -39,6 +41,12 @@ type metric struct {
|
||||||
SDKException *string `json:"SdkException,omitempty"`
|
SDKException *string `json:"SdkException,omitempty"`
|
||||||
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
|
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
|
||||||
|
|
||||||
|
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
|
||||||
|
FinalAWSException *string `json:"FinalAwsException,omitempty"`
|
||||||
|
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
|
||||||
|
FinalSDKException *string `json:"FinalSdkException,omitempty"`
|
||||||
|
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
|
||||||
|
|
||||||
DestinationIP *string `json:"DestinationIp,omitempty"`
|
DestinationIP *string `json:"DestinationIp,omitempty"`
|
||||||
ConnectionReused *int `json:"ConnectionReused,omitempty"`
|
ConnectionReused *int `json:"ConnectionReused,omitempty"`
|
||||||
|
|
||||||
|
@ -48,4 +56,54 @@ type metric struct {
|
||||||
DNSLatency *int `json:"DnsLatency,omitempty"`
|
DNSLatency *int `json:"DnsLatency,omitempty"`
|
||||||
TCPLatency *int `json:"TcpLatency,omitempty"`
|
TCPLatency *int `json:"TcpLatency,omitempty"`
|
||||||
SSLLatency *int `json:"SslLatency,omitempty"`
|
SSLLatency *int `json:"SslLatency,omitempty"`
|
||||||
|
|
||||||
|
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metric) TruncateFields() {
|
||||||
|
m.ClientID = truncateString(m.ClientID, 255)
|
||||||
|
m.UserAgent = truncateString(m.UserAgent, 256)
|
||||||
|
|
||||||
|
m.AWSException = truncateString(m.AWSException, 128)
|
||||||
|
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
|
||||||
|
|
||||||
|
m.SDKException = truncateString(m.SDKException, 128)
|
||||||
|
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
|
||||||
|
|
||||||
|
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
|
||||||
|
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
|
||||||
|
|
||||||
|
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
|
||||||
|
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateString(v *string, l int) *string {
|
||||||
|
if v != nil && len(*v) > l {
|
||||||
|
nv := (*v)[:l]
|
||||||
|
return &nv
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metric) SetException(e metricException) {
|
||||||
|
switch te := e.(type) {
|
||||||
|
case awsException:
|
||||||
|
m.AWSException = aws.String(te.exception)
|
||||||
|
m.AWSExceptionMessage = aws.String(te.message)
|
||||||
|
case sdkException:
|
||||||
|
m.SDKException = aws.String(te.exception)
|
||||||
|
m.SDKExceptionMessage = aws.String(te.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *metric) SetFinalException(e metricException) {
|
||||||
|
switch te := e.(type) {
|
||||||
|
case awsException:
|
||||||
|
m.FinalAWSException = aws.String(te.exception)
|
||||||
|
m.FinalAWSExceptionMessage = aws.String(te.message)
|
||||||
|
case sdkException:
|
||||||
|
m.FinalSDKException = aws.String(te.exception)
|
||||||
|
m.FinalSDKExceptionMessage = aws.String(te.message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
package csm
|
||||||
|
|
||||||
|
type metricException interface {
|
||||||
|
Exception() string
|
||||||
|
Message() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestException struct {
|
||||||
|
exception string
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e requestException) Exception() string {
|
||||||
|
return e.exception
|
||||||
|
}
|
||||||
|
func (e requestException) Message() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
|
type awsException struct {
|
||||||
|
requestException
|
||||||
|
}
|
||||||
|
|
||||||
|
type sdkException struct {
|
||||||
|
requestException
|
||||||
|
}
|
|
@ -10,11 +10,6 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultPort is used when no port is specified
|
|
||||||
DefaultPort = "31000"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Reporter will gather metrics of API requests made and
|
// Reporter will gather metrics of API requests made and
|
||||||
// send those metrics to the CSM endpoint.
|
// send those metrics to the CSM endpoint.
|
||||||
type Reporter struct {
|
type Reporter struct {
|
||||||
|
@ -82,26 +77,29 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
|
||||||
|
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
if awserr, ok := r.Error.(awserr.Error); ok {
|
if awserr, ok := r.Error.(awserr.Error); ok {
|
||||||
setError(&m, awserr)
|
m.SetException(getMetricException(awserr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.TruncateFields()
|
||||||
rep.metricsCh.Push(m)
|
rep.metricsCh.Push(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setError(m *metric, err awserr.Error) {
|
func getMetricException(err awserr.Error) metricException {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
code := err.Code()
|
code := err.Code()
|
||||||
|
|
||||||
switch code {
|
switch code {
|
||||||
case "RequestError",
|
case "RequestError",
|
||||||
"SerializationError",
|
request.ErrCodeSerialization,
|
||||||
request.CanceledErrorCode:
|
request.CanceledErrorCode:
|
||||||
m.SDKException = &code
|
return sdkException{
|
||||||
m.SDKExceptionMessage = &msg
|
requestException{exception: code, message: msg},
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
m.AWSException = &code
|
return awsException{
|
||||||
m.AWSExceptionMessage = &msg
|
requestException{exception: code, message: msg},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,12 +114,27 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
|
||||||
API: aws.String(r.Operation.Name),
|
API: aws.String(r.Operation.Name),
|
||||||
Service: aws.String(r.ClientInfo.ServiceID),
|
Service: aws.String(r.ClientInfo.ServiceID),
|
||||||
Timestamp: (*metricTime)(&now),
|
Timestamp: (*metricTime)(&now),
|
||||||
|
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
|
||||||
Type: aws.String("ApiCall"),
|
Type: aws.String("ApiCall"),
|
||||||
AttemptCount: aws.Int(r.RetryCount + 1),
|
AttemptCount: aws.Int(r.RetryCount + 1),
|
||||||
|
Region: r.Config.Region,
|
||||||
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
|
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
|
||||||
XAmzRequestID: aws.String(r.RequestID),
|
XAmzRequestID: aws.String(r.RequestID),
|
||||||
|
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.HTTPResponse != nil {
|
||||||
|
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Error != nil {
|
||||||
|
if awserr, ok := r.Error.(awserr.Error); ok {
|
||||||
|
m.SetFinalException(getMetricException(awserr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.TruncateFields()
|
||||||
|
|
||||||
// TODO: Probably want to figure something out for logging dropped
|
// TODO: Probably want to figure something out for logging dropped
|
||||||
// metrics
|
// metrics
|
||||||
rep.metricsCh.Push(m)
|
rep.metricsCh.Push(m)
|
||||||
|
@ -172,8 +185,9 @@ func (rep *Reporter) start() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pause will pause the metric channel preventing any new metrics from
|
// Pause will pause the metric channel preventing any new metrics from being
|
||||||
// being added.
|
// added. It is safe to call concurrently with other calls to Pause, but if
|
||||||
|
// called concurently with Continue can lead to unexpected state.
|
||||||
func (rep *Reporter) Pause() {
|
func (rep *Reporter) Pause() {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
@ -185,8 +199,9 @@ func (rep *Reporter) Pause() {
|
||||||
rep.close()
|
rep.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Continue will reopen the metric channel and allow for monitoring
|
// Continue will reopen the metric channel and allow for monitoring to be
|
||||||
// to be resumed.
|
// resumed. It is safe to call concurrently with other calls to Continue, but
|
||||||
|
// if called concurently with Pause can lead to unexpected state.
|
||||||
func (rep *Reporter) Continue() {
|
func (rep *Reporter) Continue() {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
@ -201,10 +216,18 @@ func (rep *Reporter) Continue() {
|
||||||
rep.metricsCh.Continue()
|
rep.metricsCh.Continue()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Client side metric handler names
|
||||||
|
const (
|
||||||
|
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
|
||||||
|
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
|
||||||
|
)
|
||||||
|
|
||||||
// InjectHandlers will will enable client side metrics and inject the proper
|
// InjectHandlers will will enable client side metrics and inject the proper
|
||||||
// handlers to handle how metrics are sent.
|
// handlers to handle how metrics are sent.
|
||||||
//
|
//
|
||||||
// Example:
|
// InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
|
||||||
|
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
|
||||||
|
//
|
||||||
// // Start must be called in order to inject the correct handlers
|
// // Start must be called in order to inject the correct handlers
|
||||||
// r, err := csm.Start("clientID", "127.0.0.1:8094")
|
// r, err := csm.Start("clientID", "127.0.0.1:8094")
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
|
@ -221,11 +244,22 @@ func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiCallHandler := request.NamedHandler{Name: APICallMetricHandlerName, Fn: rep.sendAPICallMetric}
|
handlers.Complete.PushFrontNamed(request.NamedHandler{
|
||||||
apiCallAttemptHandler := request.NamedHandler{Name: APICallAttemptMetricHandlerName, Fn: rep.sendAPICallAttemptMetric}
|
Name: APICallMetricHandlerName,
|
||||||
|
Fn: rep.sendAPICallMetric,
|
||||||
|
})
|
||||||
|
|
||||||
handlers.Complete.PushFrontNamed(apiCallHandler)
|
handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
|
||||||
handlers.Complete.PushFrontNamed(apiCallAttemptHandler)
|
Name: APICallAttemptMetricHandlerName,
|
||||||
|
Fn: rep.sendAPICallAttemptMetric,
|
||||||
handlers.AfterRetry.PushFrontNamed(apiCallAttemptHandler)
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// boolIntValue return 1 for true and 0 for false.
|
||||||
|
func boolIntValue(b bool) int {
|
||||||
|
if b {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
"github.com/aws/aws-sdk-go/aws/ec2metadata"
|
||||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
|
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Defaults provides a collection of default values for SDK clients.
|
// A Defaults provides a collection of default values for SDK clients.
|
||||||
|
@ -112,8 +113,8 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
|
||||||
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
|
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
|
||||||
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RemoteCredProvider returns a credentials provider for the default remote
|
// RemoteCredProvider returns a credentials provider for the default remote
|
||||||
|
@ -123,8 +124,8 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
|
||||||
return localHTTPCredProvider(cfg, handlers, u)
|
return localHTTPCredProvider(cfg, handlers, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
|
if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
|
||||||
u := fmt.Sprintf("http://169.254.170.2%s", uri)
|
u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
|
||||||
return httpCredProvider(cfg, handlers, u)
|
return httpCredProvider(cfg, handlers, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
|
||||||
return endpointcreds.NewProviderClient(cfg, handlers, u,
|
return endpointcreds.NewProviderClient(cfg, handlers, u,
|
||||||
func(p *endpointcreds.Provider) {
|
func(p *endpointcreds.Provider) {
|
||||||
p.ExpiryWindow = 5 * time.Minute
|
p.ExpiryWindow = 5 * time.Minute
|
||||||
|
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,8 +24,9 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
|
||||||
|
|
||||||
output := &metadataOutput{}
|
output := &metadataOutput{}
|
||||||
req := c.NewRequest(op, nil, output)
|
req := c.NewRequest(op, nil, output)
|
||||||
|
err := req.Send()
|
||||||
|
|
||||||
return output.Content, req.Send()
|
return output.Content, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserData returns the userdata that was configured for the service. If
|
// GetUserData returns the userdata that was configured for the service. If
|
||||||
|
@ -45,8 +46,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
|
||||||
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
|
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
err := req.Send()
|
||||||
|
|
||||||
return output.Content, req.Send()
|
return output.Content, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDynamicData uses the path provided to request information from the EC2
|
// GetDynamicData uses the path provided to request information from the EC2
|
||||||
|
@ -61,8 +63,9 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
|
||||||
|
|
||||||
output := &metadataOutput{}
|
output := &metadataOutput{}
|
||||||
req := c.NewRequest(op, nil, output)
|
req := c.NewRequest(op, nil, output)
|
||||||
|
err := req.Send()
|
||||||
|
|
||||||
return output.Content, req.Send()
|
return output.Content, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetInstanceIdentityDocument retrieves an identity document describing an
|
// GetInstanceIdentityDocument retrieves an identity document describing an
|
||||||
|
@ -79,7 +82,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
|
||||||
doc := EC2InstanceIdentityDocument{}
|
doc := EC2InstanceIdentityDocument{}
|
||||||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
|
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
|
||||||
return EC2InstanceIdentityDocument{},
|
return EC2InstanceIdentityDocument{},
|
||||||
awserr.New("SerializationError",
|
awserr.New(request.ErrCodeSerialization,
|
||||||
"failed to decode EC2 instance identity document", err)
|
"failed to decode EC2 instance identity document", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +101,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
|
||||||
info := EC2IAMInfo{}
|
info := EC2IAMInfo{}
|
||||||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
|
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
|
||||||
return EC2IAMInfo{},
|
return EC2IAMInfo{},
|
||||||
awserr.New("SerializationError",
|
awserr.New(request.ErrCodeSerialization,
|
||||||
"failed to decode EC2 IAM info", err)
|
"failed to decode EC2 IAM info", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +121,10 @@ func (c *EC2Metadata) Region() (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(resp) == 0 {
|
||||||
|
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
|
||||||
|
}
|
||||||
|
|
||||||
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
|
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
|
||||||
return resp[:len(resp)-1], nil
|
return resp[:len(resp)-1], nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
// This package's client can be disabled completely by setting the environment
|
// This package's client can be disabled completely by setting the environment
|
||||||
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
|
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
|
||||||
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
|
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
|
||||||
// be used while the environemnt variable is set to true, (case insensitive).
|
// be used while the environment variable is set to true, (case insensitive).
|
||||||
package ec2metadata
|
package ec2metadata
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -72,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
|
||||||
cfg,
|
cfg,
|
||||||
metadata.ClientInfo{
|
metadata.ClientInfo{
|
||||||
ServiceName: ServiceName,
|
ServiceName: ServiceName,
|
||||||
|
ServiceID: ServiceName,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
APIVersion: "latest",
|
APIVersion: "latest",
|
||||||
},
|
},
|
||||||
|
@ -91,6 +92,9 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
|
||||||
svc.Handlers.Send.SwapNamed(request.NamedHandler{
|
svc.Handlers.Send.SwapNamed(request.NamedHandler{
|
||||||
Name: corehandlers.SendHandler.Name,
|
Name: corehandlers.SendHandler.Name,
|
||||||
Fn: func(r *request.Request) {
|
Fn: func(r *request.Request) {
|
||||||
|
r.HTTPResponse = &http.Response{
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
r.Error = awserr.New(
|
r.Error = awserr.New(
|
||||||
request.CanceledErrorCode,
|
request.CanceledErrorCode,
|
||||||
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
|
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
|
||||||
|
@ -119,7 +123,7 @@ func unmarshalHandler(r *request.Request) {
|
||||||
defer r.HTTPResponse.Body.Close()
|
defer r.HTTPResponse.Body.Close()
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
|
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
|
||||||
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
|
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata respose", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +136,7 @@ func unmarshalError(r *request.Request) {
|
||||||
defer r.HTTPResponse.Body.Close()
|
defer r.HTTPResponse.Body.Close()
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
|
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
|
||||||
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
|
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error respose", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
|
||||||
custAddS3DualStack(p)
|
custAddS3DualStack(p)
|
||||||
custRmIotDataService(p)
|
custRmIotDataService(p)
|
||||||
custFixAppAutoscalingChina(p)
|
custFixAppAutoscalingChina(p)
|
||||||
|
custFixAppAutoscalingUsGov(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ps, nil
|
return ps, nil
|
||||||
|
@ -95,7 +96,12 @@ func custAddS3DualStack(p *partition) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s, ok := p.Services["s3"]
|
custAddDualstack(p, "s3")
|
||||||
|
custAddDualstack(p, "s3-control")
|
||||||
|
}
|
||||||
|
|
||||||
|
func custAddDualstack(p *partition, svcName string) {
|
||||||
|
s, ok := p.Services[svcName]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -103,7 +109,7 @@ func custAddS3DualStack(p *partition) {
|
||||||
s.Defaults.HasDualStack = boxedTrue
|
s.Defaults.HasDualStack = boxedTrue
|
||||||
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
|
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
|
||||||
|
|
||||||
p.Services["s3"] = s
|
p.Services[svcName] = s
|
||||||
}
|
}
|
||||||
|
|
||||||
func custAddEC2Metadata(p *partition) {
|
func custAddEC2Metadata(p *partition) {
|
||||||
|
@ -144,6 +150,33 @@ func custFixAppAutoscalingChina(p *partition) {
|
||||||
p.Services[serviceName] = s
|
p.Services[serviceName] = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func custFixAppAutoscalingUsGov(p *partition) {
|
||||||
|
if p.ID != "aws-us-gov" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const serviceName = "application-autoscaling"
|
||||||
|
s, ok := p.Services[serviceName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a := s.Defaults.CredentialScope.Service; a != "" {
|
||||||
|
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a := s.Defaults.Hostname; a != "" {
|
||||||
|
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Defaults.CredentialScope.Service = "application-autoscaling"
|
||||||
|
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com"
|
||||||
|
|
||||||
|
p.Services[serviceName] = s
|
||||||
|
}
|
||||||
|
|
||||||
type decodeModelError struct {
|
type decodeModelError struct {
|
||||||
awsError
|
awsError
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
141
vendor/github.com/aws/aws-sdk-go/aws/endpoints/dep_service_ids.go
generated
vendored
Normal file
141
vendor/github.com/aws/aws-sdk-go/aws/endpoints/dep_service_ids.go
generated
vendored
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
package endpoints
|
||||||
|
|
||||||
|
// Service identifiers
|
||||||
|
//
|
||||||
|
// Deprecated: Use client package's EndpointsID value instead of these
|
||||||
|
// ServiceIDs. These IDs are not maintained, and are out of date.
|
||||||
|
const (
|
||||||
|
A4bServiceID = "a4b" // A4b.
|
||||||
|
AcmServiceID = "acm" // Acm.
|
||||||
|
AcmPcaServiceID = "acm-pca" // AcmPca.
|
||||||
|
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
|
||||||
|
ApiPricingServiceID = "api.pricing" // ApiPricing.
|
||||||
|
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
|
||||||
|
ApigatewayServiceID = "apigateway" // Apigateway.
|
||||||
|
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
|
||||||
|
Appstream2ServiceID = "appstream2" // Appstream2.
|
||||||
|
AppsyncServiceID = "appsync" // Appsync.
|
||||||
|
AthenaServiceID = "athena" // Athena.
|
||||||
|
AutoscalingServiceID = "autoscaling" // Autoscaling.
|
||||||
|
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
|
||||||
|
BatchServiceID = "batch" // Batch.
|
||||||
|
BudgetsServiceID = "budgets" // Budgets.
|
||||||
|
CeServiceID = "ce" // Ce.
|
||||||
|
ChimeServiceID = "chime" // Chime.
|
||||||
|
Cloud9ServiceID = "cloud9" // Cloud9.
|
||||||
|
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
|
||||||
|
CloudformationServiceID = "cloudformation" // Cloudformation.
|
||||||
|
CloudfrontServiceID = "cloudfront" // Cloudfront.
|
||||||
|
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
|
||||||
|
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
|
||||||
|
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
|
||||||
|
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
|
||||||
|
CodebuildServiceID = "codebuild" // Codebuild.
|
||||||
|
CodecommitServiceID = "codecommit" // Codecommit.
|
||||||
|
CodedeployServiceID = "codedeploy" // Codedeploy.
|
||||||
|
CodepipelineServiceID = "codepipeline" // Codepipeline.
|
||||||
|
CodestarServiceID = "codestar" // Codestar.
|
||||||
|
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
|
||||||
|
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
|
||||||
|
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
|
||||||
|
ComprehendServiceID = "comprehend" // Comprehend.
|
||||||
|
ConfigServiceID = "config" // Config.
|
||||||
|
CurServiceID = "cur" // Cur.
|
||||||
|
DatapipelineServiceID = "datapipeline" // Datapipeline.
|
||||||
|
DaxServiceID = "dax" // Dax.
|
||||||
|
DevicefarmServiceID = "devicefarm" // Devicefarm.
|
||||||
|
DirectconnectServiceID = "directconnect" // Directconnect.
|
||||||
|
DiscoveryServiceID = "discovery" // Discovery.
|
||||||
|
DmsServiceID = "dms" // Dms.
|
||||||
|
DsServiceID = "ds" // Ds.
|
||||||
|
DynamodbServiceID = "dynamodb" // Dynamodb.
|
||||||
|
Ec2ServiceID = "ec2" // Ec2.
|
||||||
|
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
|
||||||
|
EcrServiceID = "ecr" // Ecr.
|
||||||
|
EcsServiceID = "ecs" // Ecs.
|
||||||
|
ElasticacheServiceID = "elasticache" // Elasticache.
|
||||||
|
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
|
||||||
|
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
|
||||||
|
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
|
||||||
|
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
|
||||||
|
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
|
||||||
|
EmailServiceID = "email" // Email.
|
||||||
|
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
|
||||||
|
EsServiceID = "es" // Es.
|
||||||
|
EventsServiceID = "events" // Events.
|
||||||
|
FirehoseServiceID = "firehose" // Firehose.
|
||||||
|
FmsServiceID = "fms" // Fms.
|
||||||
|
GameliftServiceID = "gamelift" // Gamelift.
|
||||||
|
GlacierServiceID = "glacier" // Glacier.
|
||||||
|
GlueServiceID = "glue" // Glue.
|
||||||
|
GreengrassServiceID = "greengrass" // Greengrass.
|
||||||
|
GuarddutyServiceID = "guardduty" // Guardduty.
|
||||||
|
HealthServiceID = "health" // Health.
|
||||||
|
IamServiceID = "iam" // Iam.
|
||||||
|
ImportexportServiceID = "importexport" // Importexport.
|
||||||
|
InspectorServiceID = "inspector" // Inspector.
|
||||||
|
IotServiceID = "iot" // Iot.
|
||||||
|
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
|
||||||
|
KinesisServiceID = "kinesis" // Kinesis.
|
||||||
|
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
|
||||||
|
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
|
||||||
|
KmsServiceID = "kms" // Kms.
|
||||||
|
LambdaServiceID = "lambda" // Lambda.
|
||||||
|
LightsailServiceID = "lightsail" // Lightsail.
|
||||||
|
LogsServiceID = "logs" // Logs.
|
||||||
|
MachinelearningServiceID = "machinelearning" // Machinelearning.
|
||||||
|
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
|
||||||
|
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
|
||||||
|
MedialiveServiceID = "medialive" // Medialive.
|
||||||
|
MediapackageServiceID = "mediapackage" // Mediapackage.
|
||||||
|
MediastoreServiceID = "mediastore" // Mediastore.
|
||||||
|
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
|
||||||
|
MghServiceID = "mgh" // Mgh.
|
||||||
|
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
|
||||||
|
ModelsLexServiceID = "models.lex" // ModelsLex.
|
||||||
|
MonitoringServiceID = "monitoring" // Monitoring.
|
||||||
|
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
|
||||||
|
NeptuneServiceID = "neptune" // Neptune.
|
||||||
|
OpsworksServiceID = "opsworks" // Opsworks.
|
||||||
|
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
|
||||||
|
OrganizationsServiceID = "organizations" // Organizations.
|
||||||
|
PinpointServiceID = "pinpoint" // Pinpoint.
|
||||||
|
PollyServiceID = "polly" // Polly.
|
||||||
|
RdsServiceID = "rds" // Rds.
|
||||||
|
RedshiftServiceID = "redshift" // Redshift.
|
||||||
|
RekognitionServiceID = "rekognition" // Rekognition.
|
||||||
|
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
|
||||||
|
Route53ServiceID = "route53" // Route53.
|
||||||
|
Route53domainsServiceID = "route53domains" // Route53domains.
|
||||||
|
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
|
||||||
|
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
|
||||||
|
S3ServiceID = "s3" // S3.
|
||||||
|
S3ControlServiceID = "s3-control" // S3Control.
|
||||||
|
SagemakerServiceID = "api.sagemaker" // Sagemaker.
|
||||||
|
SdbServiceID = "sdb" // Sdb.
|
||||||
|
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
|
||||||
|
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
|
||||||
|
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
|
||||||
|
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
|
||||||
|
ShieldServiceID = "shield" // Shield.
|
||||||
|
SmsServiceID = "sms" // Sms.
|
||||||
|
SnowballServiceID = "snowball" // Snowball.
|
||||||
|
SnsServiceID = "sns" // Sns.
|
||||||
|
SqsServiceID = "sqs" // Sqs.
|
||||||
|
SsmServiceID = "ssm" // Ssm.
|
||||||
|
StatesServiceID = "states" // States.
|
||||||
|
StoragegatewayServiceID = "storagegateway" // Storagegateway.
|
||||||
|
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
|
||||||
|
StsServiceID = "sts" // Sts.
|
||||||
|
SupportServiceID = "support" // Support.
|
||||||
|
SwfServiceID = "swf" // Swf.
|
||||||
|
TaggingServiceID = "tagging" // Tagging.
|
||||||
|
TransferServiceID = "transfer" // Transfer.
|
||||||
|
TranslateServiceID = "translate" // Translate.
|
||||||
|
WafServiceID = "waf" // Waf.
|
||||||
|
WafRegionalServiceID = "waf-regional" // WafRegional.
|
||||||
|
WorkdocsServiceID = "workdocs" // Workdocs.
|
||||||
|
WorkmailServiceID = "workmail" // Workmail.
|
||||||
|
WorkspacesServiceID = "workspaces" // Workspaces.
|
||||||
|
XrayServiceID = "xray" // Xray.
|
||||||
|
)
|
|
@ -35,7 +35,7 @@ type Options struct {
|
||||||
//
|
//
|
||||||
// If resolving an endpoint on the partition list the provided region will
|
// If resolving an endpoint on the partition list the provided region will
|
||||||
// be used to determine which partition's domain name pattern to the service
|
// be used to determine which partition's domain name pattern to the service
|
||||||
// endpoint ID with. If both the service and region are unkonwn and resolving
|
// endpoint ID with. If both the service and region are unknown and resolving
|
||||||
// the endpoint on partition list an UnknownEndpointError error will be returned.
|
// the endpoint on partition list an UnknownEndpointError error will be returned.
|
||||||
//
|
//
|
||||||
// If resolving and endpoint on a partition specific resolver that partition's
|
// If resolving and endpoint on a partition specific resolver that partition's
|
||||||
|
|
|
@ -16,6 +16,10 @@ import (
|
||||||
type CodeGenOptions struct {
|
type CodeGenOptions struct {
|
||||||
// Options for how the model will be decoded.
|
// Options for how the model will be decoded.
|
||||||
DecodeModelOptions DecodeModelOptions
|
DecodeModelOptions DecodeModelOptions
|
||||||
|
|
||||||
|
// Disables code generation of the service endpoint prefix IDs defined in
|
||||||
|
// the model.
|
||||||
|
DisableGenerateServiceIDs bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set combines all of the option functions together
|
// Set combines all of the option functions together
|
||||||
|
@ -39,8 +43,16 @@ func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGe
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
v := struct {
|
||||||
|
Resolver
|
||||||
|
CodeGenOptions
|
||||||
|
}{
|
||||||
|
Resolver: resolver,
|
||||||
|
CodeGenOptions: opts,
|
||||||
|
}
|
||||||
|
|
||||||
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
|
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
|
||||||
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil {
|
if err := tmpl.ExecuteTemplate(outFile, "defaults", v); err != nil {
|
||||||
return fmt.Errorf("failed to execute template, %v", err)
|
return fmt.Errorf("failed to execute template, %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,15 +178,17 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
{{ template "partition consts" . }}
|
{{ template "partition consts" $.Resolver }}
|
||||||
|
|
||||||
{{ range $_, $partition := . }}
|
{{ range $_, $partition := $.Resolver }}
|
||||||
{{ template "partition region consts" $partition }}
|
{{ template "partition region consts" $partition }}
|
||||||
{{ end }}
|
{{ end }}
|
||||||
|
|
||||||
{{ template "service consts" . }}
|
{{ if not $.DisableGenerateServiceIDs -}}
|
||||||
|
{{ template "service consts" $.Resolver }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
{{ template "endpoint resolvers" . }}
|
{{ template "endpoint resolvers" $.Resolver }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
|
|
||||||
{{ define "partition consts" }}
|
{{ define "partition consts" }}
|
||||||
|
|
|
@ -5,13 +5,9 @@ import "github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
var (
|
var (
|
||||||
// ErrMissingRegion is an error that is returned if region configuration is
|
// ErrMissingRegion is an error that is returned if region configuration is
|
||||||
// not found.
|
// not found.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
|
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
|
||||||
|
|
||||||
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
|
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
|
||||||
// resolved for a service.
|
// resolved for a service.
|
||||||
//
|
|
||||||
// @readonly
|
|
||||||
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
|
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,18 +1,17 @@
|
||||||
// +build !appengine,!plan9
|
|
||||||
|
|
||||||
package request
|
package request
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"strings"
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func isErrConnectionReset(err error) bool {
|
func isErrConnectionReset(err error) bool {
|
||||||
if opErr, ok := err.(*net.OpError); ok {
|
if strings.Contains(err.Error(), "read: connection reset") {
|
||||||
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
|
return false
|
||||||
return sysErr.Err == syscall.ECONNRESET
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.Contains(err.Error(), "connection reset") ||
|
||||||
|
strings.Contains(err.Error(), "broken pipe") {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
// +build appengine plan9
|
|
||||||
|
|
||||||
package request
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isErrConnectionReset(err error) bool {
|
|
||||||
return strings.Contains(err.Error(), "connection reset")
|
|
||||||
}
|
|
|
@ -19,6 +19,7 @@ type Handlers struct {
|
||||||
UnmarshalError HandlerList
|
UnmarshalError HandlerList
|
||||||
Retry HandlerList
|
Retry HandlerList
|
||||||
AfterRetry HandlerList
|
AfterRetry HandlerList
|
||||||
|
CompleteAttempt HandlerList
|
||||||
Complete HandlerList
|
Complete HandlerList
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,6 +37,7 @@ func (h *Handlers) Copy() Handlers {
|
||||||
UnmarshalMeta: h.UnmarshalMeta.copy(),
|
UnmarshalMeta: h.UnmarshalMeta.copy(),
|
||||||
Retry: h.Retry.copy(),
|
Retry: h.Retry.copy(),
|
||||||
AfterRetry: h.AfterRetry.copy(),
|
AfterRetry: h.AfterRetry.copy(),
|
||||||
|
CompleteAttempt: h.CompleteAttempt.copy(),
|
||||||
Complete: h.Complete.copy(),
|
Complete: h.Complete.copy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,9 +55,55 @@ func (h *Handlers) Clear() {
|
||||||
h.ValidateResponse.Clear()
|
h.ValidateResponse.Clear()
|
||||||
h.Retry.Clear()
|
h.Retry.Clear()
|
||||||
h.AfterRetry.Clear()
|
h.AfterRetry.Clear()
|
||||||
|
h.CompleteAttempt.Clear()
|
||||||
h.Complete.Clear()
|
h.Complete.Clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEmpty returns if there are no handlers in any of the handlerlists.
|
||||||
|
func (h *Handlers) IsEmpty() bool {
|
||||||
|
if h.Validate.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.Build.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.Send.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.Sign.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.Unmarshal.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.UnmarshalStream.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.UnmarshalMeta.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.UnmarshalError.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.ValidateResponse.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.Retry.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.AfterRetry.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.CompleteAttempt.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.Complete.Len() != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// A HandlerListRunItem represents an entry in the HandlerList which
|
// A HandlerListRunItem represents an entry in the HandlerList which
|
||||||
// is being run.
|
// is being run.
|
||||||
type HandlerListRunItem struct {
|
type HandlerListRunItem struct {
|
||||||
|
|
|
@ -122,7 +122,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
|
||||||
Handlers: handlers.Copy(),
|
Handlers: handlers.Copy(),
|
||||||
|
|
||||||
Retryer: retryer,
|
Retryer: retryer,
|
||||||
AttemptTime: time.Now(),
|
|
||||||
Time: time.Now(),
|
Time: time.Now(),
|
||||||
ExpireTime: 0,
|
ExpireTime: 0,
|
||||||
Operation: operation,
|
Operation: operation,
|
||||||
|
@ -233,6 +232,10 @@ func (r *Request) WillRetry() bool {
|
||||||
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
|
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fmtAttemptCount(retryCount, maxRetries int) string {
|
||||||
|
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
|
||||||
|
}
|
||||||
|
|
||||||
// ParamsFilled returns if the request's parameters have been populated
|
// ParamsFilled returns if the request's parameters have been populated
|
||||||
// and the parameters are valid. False is returned if no parameters are
|
// and the parameters are valid. False is returned if no parameters are
|
||||||
// provided or invalid.
|
// provided or invalid.
|
||||||
|
@ -266,7 +269,9 @@ func (r *Request) SetReaderBody(reader io.ReadSeeker) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Presign returns the request's signed URL. Error will be returned
|
// Presign returns the request's signed URL. Error will be returned
|
||||||
// if the signing fails.
|
// if the signing fails. The expire parameter is only used for presigned Amazon
|
||||||
|
// S3 API requests. All other AWS services will use a fixed expiration
|
||||||
|
// time of 15 minutes.
|
||||||
//
|
//
|
||||||
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
||||||
// error is returned if expire duration is 0 or less.
|
// error is returned if expire duration is 0 or less.
|
||||||
|
@ -283,7 +288,9 @@ func (r *Request) Presign(expire time.Duration) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// PresignRequest behaves just like presign, with the addition of returning a
|
// PresignRequest behaves just like presign, with the addition of returning a
|
||||||
// set of headers that were signed.
|
// set of headers that were signed. The expire parameter is only used for
|
||||||
|
// presigned Amazon S3 API requests. All other AWS services will use a fixed
|
||||||
|
// expiration time of 15 minutes.
|
||||||
//
|
//
|
||||||
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
// It is invalid to create a presigned URL with a expire duration 0 or less. An
|
||||||
// error is returned if expire duration is 0 or less.
|
// error is returned if expire duration is 0 or less.
|
||||||
|
@ -328,16 +335,17 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
|
||||||
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
|
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
|
const (
|
||||||
|
willRetry = "will retry"
|
||||||
|
notRetrying = "not retrying"
|
||||||
|
retryCount = "retry %v/%v"
|
||||||
|
)
|
||||||
|
|
||||||
|
func debugLogReqError(r *Request, stage, retryStr string, err error) {
|
||||||
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
|
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
retryStr := "not retrying"
|
|
||||||
if retrying {
|
|
||||||
retryStr = "will retry"
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
|
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
|
||||||
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
|
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
|
||||||
}
|
}
|
||||||
|
@ -356,12 +364,12 @@ func (r *Request) Build() error {
|
||||||
if !r.built {
|
if !r.built {
|
||||||
r.Handlers.Validate.Run(r)
|
r.Handlers.Validate.Run(r)
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
debugLogReqError(r, "Validate Request", false, r.Error)
|
debugLogReqError(r, "Validate Request", notRetrying, r.Error)
|
||||||
return r.Error
|
return r.Error
|
||||||
}
|
}
|
||||||
r.Handlers.Build.Run(r)
|
r.Handlers.Build.Run(r)
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
debugLogReqError(r, "Build Request", false, r.Error)
|
debugLogReqError(r, "Build Request", notRetrying, r.Error)
|
||||||
return r.Error
|
return r.Error
|
||||||
}
|
}
|
||||||
r.built = true
|
r.built = true
|
||||||
|
@ -377,7 +385,7 @@ func (r *Request) Build() error {
|
||||||
func (r *Request) Sign() error {
|
func (r *Request) Sign() error {
|
||||||
r.Build()
|
r.Build()
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
debugLogReqError(r, "Build Request", false, r.Error)
|
debugLogReqError(r, "Build Request", notRetrying, r.Error)
|
||||||
return r.Error
|
return r.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -462,9 +470,38 @@ func (r *Request) Send() error {
|
||||||
r.Handlers.Complete.Run(r)
|
r.Handlers.Complete.Run(r)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if err := r.Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
r.Error = nil
|
||||||
r.AttemptTime = time.Now()
|
r.AttemptTime = time.Now()
|
||||||
if aws.BoolValue(r.Retryable) {
|
|
||||||
|
if err := r.Sign(); err != nil {
|
||||||
|
debugLogReqError(r, "Sign Request", notRetrying, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.sendRequest(); err == nil {
|
||||||
|
return nil
|
||||||
|
} else if !shouldRetryError(r.Error) {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
r.Handlers.Retry.Run(r)
|
||||||
|
r.Handlers.AfterRetry.Run(r)
|
||||||
|
|
||||||
|
if r.Error != nil || !aws.BoolValue(r.Retryable) {
|
||||||
|
return r.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
r.prepareRetry()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) prepareRetry() {
|
||||||
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
|
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
|
||||||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
|
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
|
||||||
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
|
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
|
||||||
|
@ -481,62 +518,37 @@ func (r *Request) Send() error {
|
||||||
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
|
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
|
||||||
r.HTTPResponse.Body.Close()
|
r.HTTPResponse.Body.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Sign()
|
func (r *Request) sendRequest() (sendErr error) {
|
||||||
if r.Error != nil {
|
defer r.Handlers.CompleteAttempt.Run(r)
|
||||||
return r.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Retryable = nil
|
r.Retryable = nil
|
||||||
|
|
||||||
r.Handlers.Send.Run(r)
|
r.Handlers.Send.Run(r)
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
if !shouldRetryCancel(r) {
|
debugLogReqError(r, "Send Request",
|
||||||
|
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
|
||||||
|
r.Error)
|
||||||
return r.Error
|
return r.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.Error
|
|
||||||
r.Handlers.Retry.Run(r)
|
|
||||||
r.Handlers.AfterRetry.Run(r)
|
|
||||||
if r.Error != nil {
|
|
||||||
debugLogReqError(r, "Send Request", false, err)
|
|
||||||
return r.Error
|
|
||||||
}
|
|
||||||
debugLogReqError(r, "Send Request", true, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
r.Handlers.UnmarshalMeta.Run(r)
|
r.Handlers.UnmarshalMeta.Run(r)
|
||||||
r.Handlers.ValidateResponse.Run(r)
|
r.Handlers.ValidateResponse.Run(r)
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
r.Handlers.UnmarshalError.Run(r)
|
r.Handlers.UnmarshalError.Run(r)
|
||||||
err := r.Error
|
debugLogReqError(r, "Validate Response",
|
||||||
|
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
|
||||||
r.Handlers.Retry.Run(r)
|
r.Error)
|
||||||
r.Handlers.AfterRetry.Run(r)
|
|
||||||
if r.Error != nil {
|
|
||||||
debugLogReqError(r, "Validate Response", false, err)
|
|
||||||
return r.Error
|
return r.Error
|
||||||
}
|
}
|
||||||
debugLogReqError(r, "Validate Response", true, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Handlers.Unmarshal.Run(r)
|
r.Handlers.Unmarshal.Run(r)
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
err := r.Error
|
debugLogReqError(r, "Unmarshal Response",
|
||||||
r.Handlers.Retry.Run(r)
|
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
|
||||||
r.Handlers.AfterRetry.Run(r)
|
r.Error)
|
||||||
if r.Error != nil {
|
|
||||||
debugLogReqError(r, "Unmarshal Response", false, err)
|
|
||||||
return r.Error
|
return r.Error
|
||||||
}
|
}
|
||||||
debugLogReqError(r, "Unmarshal Response", true, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -561,30 +573,49 @@ func AddToUserAgent(r *Request, s string) {
|
||||||
r.HTTPRequest.Header.Set("User-Agent", s)
|
r.HTTPRequest.Header.Set("User-Agent", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldRetryCancel(r *Request) bool {
|
type temporary interface {
|
||||||
awsErr, ok := r.Error.(awserr.Error)
|
Temporary() bool
|
||||||
timeoutErr := false
|
}
|
||||||
errStr := r.Error.Error()
|
|
||||||
if ok {
|
func shouldRetryError(origErr error) bool {
|
||||||
if awsErr.Code() == CanceledErrorCode {
|
switch err := origErr.(type) {
|
||||||
|
case awserr.Error:
|
||||||
|
if err.Code() == CanceledErrorCode {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
err := awsErr.OrigErr()
|
return shouldRetryError(err.OrigErr())
|
||||||
netErr, netOK := err.(net.Error)
|
case *url.Error:
|
||||||
timeoutErr = netOK && netErr.Temporary()
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
if urlErr, ok := err.(*url.Error); !timeoutErr && ok {
|
// Refused connections should be retried as the service may not yet
|
||||||
errStr = urlErr.Err.Error()
|
// be running on the port. Go TCP dial considers refused
|
||||||
|
// connections as not temporary.
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
// *url.Error only implements Temporary after golang 1.6 but since
|
||||||
|
// url.Error only wraps the error:
|
||||||
|
return shouldRetryError(err.Err)
|
||||||
|
case temporary:
|
||||||
|
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// If the error is temporary, we want to allow continuation of the
|
||||||
|
// retry process
|
||||||
|
return err.Temporary() || isErrConnectionReset(origErr)
|
||||||
|
case nil:
|
||||||
|
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
|
||||||
|
// because we don't know the cause, it is marked as retryable. See
|
||||||
|
// TestRequest4xxUnretryable for an example.
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
switch err.Error() {
|
||||||
|
case "net/http: request canceled",
|
||||||
|
"net/http: request canceled while waiting for connection":
|
||||||
|
// known 1.5 error case when an http request is cancelled
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// here we don't know the error; so we allow a retry.
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// There can be two types of canceled errors here.
|
|
||||||
// The first being a net.Error and the other being an error.
|
|
||||||
// If the request was timed out, we want to continue the retry
|
|
||||||
// process. Otherwise, return the canceled error.
|
|
||||||
return timeoutErr ||
|
|
||||||
(errStr != "net/http: request canceled" &&
|
|
||||||
errStr != "net/http: request canceled while waiting for connection")
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeHostForHeader removes default port from host and updates request.Host
|
// SanitizeHostForHeader removes default port from host and updates request.Host
|
||||||
|
|
|
@ -38,8 +38,10 @@ var throttleCodes = map[string]struct{}{
|
||||||
"ThrottlingException": {},
|
"ThrottlingException": {},
|
||||||
"RequestLimitExceeded": {},
|
"RequestLimitExceeded": {},
|
||||||
"RequestThrottled": {},
|
"RequestThrottled": {},
|
||||||
|
"RequestThrottledException": {},
|
||||||
"TooManyRequestsException": {}, // Lambda functions
|
"TooManyRequestsException": {}, // Lambda functions
|
||||||
"PriorRequestNotComplete": {}, // Route53
|
"PriorRequestNotComplete": {}, // Route53
|
||||||
|
"TransactionInProgressException": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// credsExpiredCodes is a collection of error codes which signify the credentials
|
// credsExpiredCodes is a collection of error codes which signify the credentials
|
||||||
|
|
|
@ -17,6 +17,12 @@ const (
|
||||||
ParamMinValueErrCode = "ParamMinValueError"
|
ParamMinValueErrCode = "ParamMinValueError"
|
||||||
// ParamMinLenErrCode is the error code for fields without enough elements.
|
// ParamMinLenErrCode is the error code for fields without enough elements.
|
||||||
ParamMinLenErrCode = "ParamMinLenError"
|
ParamMinLenErrCode = "ParamMinLenError"
|
||||||
|
// ParamMaxLenErrCode is the error code for value being too long.
|
||||||
|
ParamMaxLenErrCode = "ParamMaxLenError"
|
||||||
|
|
||||||
|
// ParamFormatErrCode is the error code for a field with invalid
|
||||||
|
// format or characters.
|
||||||
|
ParamFormatErrCode = "ParamFormatInvalidError"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validator provides a way for types to perform validation logic on their
|
// Validator provides a way for types to perform validation logic on their
|
||||||
|
@ -232,3 +238,49 @@ func NewErrParamMinLen(field string, min int) *ErrParamMinLen {
|
||||||
func (e *ErrParamMinLen) MinLen() int {
|
func (e *ErrParamMinLen) MinLen() int {
|
||||||
return e.min
|
return e.min
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// An ErrParamMaxLen represents a maximum length parameter error.
|
||||||
|
type ErrParamMaxLen struct {
|
||||||
|
errInvalidParam
|
||||||
|
max int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrParamMaxLen creates a new maximum length parameter error.
|
||||||
|
func NewErrParamMaxLen(field string, max int, value string) *ErrParamMaxLen {
|
||||||
|
return &ErrParamMaxLen{
|
||||||
|
errInvalidParam: errInvalidParam{
|
||||||
|
code: ParamMaxLenErrCode,
|
||||||
|
field: field,
|
||||||
|
msg: fmt.Sprintf("maximum size of %v, %v", max, value),
|
||||||
|
},
|
||||||
|
max: max,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxLen returns the field's required minimum length.
|
||||||
|
func (e *ErrParamMaxLen) MaxLen() int {
|
||||||
|
return e.max
|
||||||
|
}
|
||||||
|
|
||||||
|
// An ErrParamFormat represents a invalid format parameter error.
|
||||||
|
type ErrParamFormat struct {
|
||||||
|
errInvalidParam
|
||||||
|
format string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrParamFormat creates a new invalid format parameter error.
|
||||||
|
func NewErrParamFormat(field string, format, value string) *ErrParamFormat {
|
||||||
|
return &ErrParamFormat{
|
||||||
|
errInvalidParam: errInvalidParam{
|
||||||
|
code: ParamFormatErrCode,
|
||||||
|
field: field,
|
||||||
|
msg: fmt.Sprintf("format %v, %v", format, value),
|
||||||
|
},
|
||||||
|
format: format,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format returns the field's required format.
|
||||||
|
func (e *ErrParamFormat) Format() string {
|
||||||
|
return e.format
|
||||||
|
}
|
||||||
|
|
26
vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport.go
generated
vendored
Normal file
26
vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport.go
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transport that should be used when a custom CA bundle is specified with the
|
||||||
|
// SDK.
|
||||||
|
func getCABundleTransport() *http.Transport {
|
||||||
|
return &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
DualStack: true,
|
||||||
|
}).DialContext,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
22
vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport_1_5.go
generated
vendored
Normal file
22
vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport_1_5.go
generated
vendored
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
// +build !go1.6,go1.5
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transport that should be used when a custom CA bundle is specified with the
|
||||||
|
// SDK.
|
||||||
|
func getCABundleTransport() *http.Transport {
|
||||||
|
return &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).Dial,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
23
vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport_1_6.go
generated
vendored
Normal file
23
vendor/github.com/aws/aws-sdk-go/aws/session/cabundle_transport_1_6.go
generated
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
// +build !go1.7,go1.6
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transport that should be used when a custom CA bundle is specified with the
|
||||||
|
// SDK.
|
||||||
|
func getCABundleTransport() *http.Transport {
|
||||||
|
return &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).Dial,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,260 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
|
"github.com/aws/aws-sdk-go/internal/shareddefaults"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resolveCredentials(cfg *aws.Config,
|
||||||
|
envCfg envConfig, sharedCfg sharedConfig,
|
||||||
|
handlers request.Handlers,
|
||||||
|
sessOpts Options,
|
||||||
|
) (*credentials.Credentials, error) {
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(envCfg.Profile) != 0:
|
||||||
|
// User explicitly provided an Profile, so load from shared config
|
||||||
|
// first.
|
||||||
|
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
|
||||||
|
|
||||||
|
case envCfg.Creds.HasKeys():
|
||||||
|
// Environment credentials
|
||||||
|
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
|
||||||
|
|
||||||
|
case len(envCfg.WebIdentityTokenFilePath) != 0:
|
||||||
|
// Web identity token from environment, RoleARN required to also be
|
||||||
|
// set.
|
||||||
|
return assumeWebIdentity(cfg, handlers,
|
||||||
|
envCfg.WebIdentityTokenFilePath,
|
||||||
|
envCfg.RoleARN,
|
||||||
|
envCfg.RoleSessionName,
|
||||||
|
)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Fallback to the "default" credential resolution chain.
|
||||||
|
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
|
||||||
|
// 'AWS_IAM_ROLE_ARN' was not set.
|
||||||
|
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
|
||||||
|
|
||||||
|
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_IAM_ROLE_ARN' was set but
|
||||||
|
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
|
||||||
|
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
|
||||||
|
|
||||||
|
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
|
||||||
|
filepath string,
|
||||||
|
roleARN, sessionName string,
|
||||||
|
) (*credentials.Credentials, error) {
|
||||||
|
|
||||||
|
if len(filepath) == 0 {
|
||||||
|
return nil, WebIdentityEmptyTokenFilePathErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(roleARN) == 0 {
|
||||||
|
return nil, WebIdentityEmptyRoleARNErr
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := stscreds.NewWebIdentityCredentials(
|
||||||
|
&Session{
|
||||||
|
Config: cfg,
|
||||||
|
Handlers: handlers.Copy(),
|
||||||
|
},
|
||||||
|
roleARN,
|
||||||
|
sessionName,
|
||||||
|
filepath,
|
||||||
|
)
|
||||||
|
|
||||||
|
return creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCredsFromProfile(cfg *aws.Config,
|
||||||
|
envCfg envConfig, sharedCfg sharedConfig,
|
||||||
|
handlers request.Handlers,
|
||||||
|
sessOpts Options,
|
||||||
|
) (creds *credentials.Credentials, err error) {
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case sharedCfg.SourceProfile != nil:
|
||||||
|
// Assume IAM role with credentials source from a different profile.
|
||||||
|
creds, err = resolveCredsFromProfile(cfg, envCfg,
|
||||||
|
*sharedCfg.SourceProfile, handlers, sessOpts,
|
||||||
|
)
|
||||||
|
|
||||||
|
case sharedCfg.Creds.HasKeys():
|
||||||
|
// Static Credentials from Shared Config/Credentials file.
|
||||||
|
creds = credentials.NewStaticCredentialsFromCreds(
|
||||||
|
sharedCfg.Creds,
|
||||||
|
)
|
||||||
|
|
||||||
|
case len(sharedCfg.CredentialProcess) != 0:
|
||||||
|
// Get credentials from CredentialProcess
|
||||||
|
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
|
||||||
|
|
||||||
|
case len(sharedCfg.CredentialSource) != 0:
|
||||||
|
creds, err = resolveCredsFromSource(cfg, envCfg,
|
||||||
|
sharedCfg, handlers, sessOpts,
|
||||||
|
)
|
||||||
|
|
||||||
|
case len(sharedCfg.WebIdentityTokenFile) != 0:
|
||||||
|
// Credentials from Assume Web Identity token require an IAM Role, and
|
||||||
|
// that roll will be assumed. May be wrapped with another assume role
|
||||||
|
// via SourceProfile.
|
||||||
|
return assumeWebIdentity(cfg, handlers,
|
||||||
|
sharedCfg.WebIdentityTokenFile,
|
||||||
|
sharedCfg.RoleARN,
|
||||||
|
sharedCfg.RoleSessionName,
|
||||||
|
)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Fallback to default credentials provider, include mock errors for
|
||||||
|
// the credential chain so user can identify why credentials failed to
|
||||||
|
// be retrieved.
|
||||||
|
creds = credentials.NewCredentials(&credentials.ChainProvider{
|
||||||
|
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
|
||||||
|
Providers: []credentials.Provider{
|
||||||
|
&credProviderError{
|
||||||
|
Err: awserr.New("EnvAccessKeyNotFound",
|
||||||
|
"failed to find credentials in the environment.", nil),
|
||||||
|
},
|
||||||
|
&credProviderError{
|
||||||
|
Err: awserr.New("SharedCredsLoad",
|
||||||
|
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
|
||||||
|
},
|
||||||
|
defaults.RemoteCredProvider(*cfg, handlers),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sharedCfg.RoleARN) > 0 {
|
||||||
|
cfgCp := *cfg
|
||||||
|
cfgCp.Credentials = creds
|
||||||
|
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// valid credential source values
|
||||||
|
const (
|
||||||
|
credSourceEc2Metadata = "Ec2InstanceMetadata"
|
||||||
|
credSourceEnvironment = "Environment"
|
||||||
|
credSourceECSContainer = "EcsContainer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resolveCredsFromSource(cfg *aws.Config,
|
||||||
|
envCfg envConfig, sharedCfg sharedConfig,
|
||||||
|
handlers request.Handlers,
|
||||||
|
sessOpts Options,
|
||||||
|
) (creds *credentials.Credentials, err error) {
|
||||||
|
|
||||||
|
switch sharedCfg.CredentialSource {
|
||||||
|
case credSourceEc2Metadata:
|
||||||
|
p := defaults.RemoteCredProvider(*cfg, handlers)
|
||||||
|
creds = credentials.NewCredentials(p)
|
||||||
|
|
||||||
|
case credSourceEnvironment:
|
||||||
|
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
|
||||||
|
|
||||||
|
case credSourceECSContainer:
|
||||||
|
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
|
||||||
|
return nil, ErrSharedConfigECSContainerEnvVarEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
p := defaults.RemoteCredProvider(*cfg, handlers)
|
||||||
|
creds = credentials.NewCredentials(p)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, ErrSharedConfigInvalidCredSource
|
||||||
|
}
|
||||||
|
|
||||||
|
return creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func credsFromAssumeRole(cfg aws.Config,
|
||||||
|
handlers request.Handlers,
|
||||||
|
sharedCfg sharedConfig,
|
||||||
|
sessOpts Options,
|
||||||
|
) (*credentials.Credentials, error) {
|
||||||
|
|
||||||
|
if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
|
||||||
|
// AssumeRole Token provider is required if doing Assume Role
|
||||||
|
// with MFA.
|
||||||
|
return nil, AssumeRoleTokenProviderNotSetError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stscreds.NewCredentials(
|
||||||
|
&Session{
|
||||||
|
Config: &cfg,
|
||||||
|
Handlers: handlers.Copy(),
|
||||||
|
},
|
||||||
|
sharedCfg.RoleARN,
|
||||||
|
func(opt *stscreds.AssumeRoleProvider) {
|
||||||
|
opt.RoleSessionName = sharedCfg.RoleSessionName
|
||||||
|
opt.Duration = sessOpts.AssumeRoleDuration
|
||||||
|
|
||||||
|
// Assume role with external ID
|
||||||
|
if len(sharedCfg.ExternalID) > 0 {
|
||||||
|
opt.ExternalID = aws.String(sharedCfg.ExternalID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assume role with MFA
|
||||||
|
if len(sharedCfg.MFASerial) > 0 {
|
||||||
|
opt.SerialNumber = aws.String(sharedCfg.MFASerial)
|
||||||
|
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
|
||||||
|
}
|
||||||
|
},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssumeRoleTokenProviderNotSetError is an error returned when creating a
|
||||||
|
// session when the MFAToken option is not set when shared config is configured
|
||||||
|
// load assume a role with an MFA token.
|
||||||
|
type AssumeRoleTokenProviderNotSetError struct{}
|
||||||
|
|
||||||
|
// Code is the short id of the error.
|
||||||
|
func (e AssumeRoleTokenProviderNotSetError) Code() string {
|
||||||
|
return "AssumeRoleTokenProviderNotSetError"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message is the description of the error
|
||||||
|
func (e AssumeRoleTokenProviderNotSetError) Message() string {
|
||||||
|
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrigErr is the underlying error that caused the failure.
|
||||||
|
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error satisfies the error interface.
|
||||||
|
func (e AssumeRoleTokenProviderNotSetError) Error() string {
|
||||||
|
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type credProviderError struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
var emptyCreds = credentials.Value{}
|
||||||
|
|
||||||
|
func (c credProviderError) Retrieve() (credentials.Value, error) {
|
||||||
|
return credentials.Value{}, c.Err
|
||||||
|
}
|
||||||
|
func (c credProviderError) IsExpired() bool {
|
||||||
|
return true
|
||||||
|
}
|
|
@ -99,7 +99,7 @@ handler logs every request and its payload made by a service client:
|
||||||
|
|
||||||
sess.Handlers.Send.PushFront(func(r *request.Request) {
|
sess.Handlers.Send.PushFront(func(r *request.Request) {
|
||||||
// Log every request made and its payload
|
// Log every request made and its payload
|
||||||
logger.Println("Request: %s/%s, Payload: %s",
|
logger.Printf("Request: %s/%s, Payload: %s",
|
||||||
r.ClientInfo.ServiceName, r.Operation, r.Params)
|
r.ClientInfo.ServiceName, r.Operation, r.Params)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ be returned when creating the session.
|
||||||
// from assumed role.
|
// from assumed role.
|
||||||
svc := s3.New(sess)
|
svc := s3.New(sess)
|
||||||
|
|
||||||
To setup assume role outside of a session see the stscrds.AssumeRoleProvider
|
To setup assume role outside of a session see the stscreds.AssumeRoleProvider
|
||||||
documentation.
|
documentation.
|
||||||
|
|
||||||
Environment Variables
|
Environment Variables
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||||
)
|
)
|
||||||
|
@ -79,7 +80,7 @@ type envConfig struct {
|
||||||
// AWS_CONFIG_FILE=$HOME/my_shared_config
|
// AWS_CONFIG_FILE=$HOME/my_shared_config
|
||||||
SharedConfigFile string
|
SharedConfigFile string
|
||||||
|
|
||||||
// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file
|
// Sets the path to a custom Credentials Authority (CA) Bundle PEM file
|
||||||
// that the SDK will use instead of the system's root CA bundle.
|
// that the SDK will use instead of the system's root CA bundle.
|
||||||
// Only use this if you want to configure the SDK to use a custom set
|
// Only use this if you want to configure the SDK to use a custom set
|
||||||
// of CAs.
|
// of CAs.
|
||||||
|
@ -101,12 +102,38 @@ type envConfig struct {
|
||||||
CSMEnabled bool
|
CSMEnabled bool
|
||||||
CSMPort string
|
CSMPort string
|
||||||
CSMClientID string
|
CSMClientID string
|
||||||
|
CSMHost string
|
||||||
|
|
||||||
|
// Enables endpoint discovery via environment variables.
|
||||||
|
//
|
||||||
|
// AWS_ENABLE_ENDPOINT_DISCOVERY=true
|
||||||
|
EnableEndpointDiscovery *bool
|
||||||
|
enableEndpointDiscovery string
|
||||||
|
|
||||||
|
// Specifies the WebIdentity token the SDK should use to assume a role
|
||||||
|
// with.
|
||||||
|
//
|
||||||
|
// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
|
||||||
|
WebIdentityTokenFilePath string
|
||||||
|
|
||||||
|
// Specifies the IAM role arn to use when assuming an role.
|
||||||
|
//
|
||||||
|
// AWS_ROLE_ARN=role_arn
|
||||||
|
RoleARN string
|
||||||
|
|
||||||
|
// Specifies the IAM role session name to use when assuming a role.
|
||||||
|
//
|
||||||
|
// AWS_ROLE_SESSION_NAME=session_name
|
||||||
|
RoleSessionName string
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
csmEnabledEnvKey = []string{
|
csmEnabledEnvKey = []string{
|
||||||
"AWS_CSM_ENABLED",
|
"AWS_CSM_ENABLED",
|
||||||
}
|
}
|
||||||
|
csmHostEnvKey = []string{
|
||||||
|
"AWS_CSM_HOST",
|
||||||
|
}
|
||||||
csmPortEnvKey = []string{
|
csmPortEnvKey = []string{
|
||||||
"AWS_CSM_PORT",
|
"AWS_CSM_PORT",
|
||||||
}
|
}
|
||||||
|
@ -125,6 +152,10 @@ var (
|
||||||
"AWS_SESSION_TOKEN",
|
"AWS_SESSION_TOKEN",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enableEndpointDiscoveryEnvKey = []string{
|
||||||
|
"AWS_ENABLE_ENDPOINT_DISCOVERY",
|
||||||
|
}
|
||||||
|
|
||||||
regionEnvKeys = []string{
|
regionEnvKeys = []string{
|
||||||
"AWS_REGION",
|
"AWS_REGION",
|
||||||
"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set
|
"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set
|
||||||
|
@ -139,6 +170,15 @@ var (
|
||||||
sharedConfigFileEnvKey = []string{
|
sharedConfigFileEnvKey = []string{
|
||||||
"AWS_CONFIG_FILE",
|
"AWS_CONFIG_FILE",
|
||||||
}
|
}
|
||||||
|
webIdentityTokenFilePathEnvKey = []string{
|
||||||
|
"AWS_WEB_IDENTITY_TOKEN_FILE",
|
||||||
|
}
|
||||||
|
roleARNEnvKey = []string{
|
||||||
|
"AWS_ROLE_ARN",
|
||||||
|
}
|
||||||
|
roleSessionNameEnvKey = []string{
|
||||||
|
"AWS_ROLE_SESSION_NAME",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// loadEnvConfig retrieves the SDK's environment configuration.
|
// loadEnvConfig retrieves the SDK's environment configuration.
|
||||||
|
@ -167,23 +207,31 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
|
||||||
|
|
||||||
cfg.EnableSharedConfig = enableSharedConfig
|
cfg.EnableSharedConfig = enableSharedConfig
|
||||||
|
|
||||||
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey)
|
// Static environment credentials
|
||||||
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey)
|
var creds credentials.Value
|
||||||
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey)
|
setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
|
||||||
|
setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
|
||||||
|
setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
|
||||||
|
if creds.HasKeys() {
|
||||||
|
// Require logical grouping of credentials
|
||||||
|
creds.ProviderName = EnvProviderName
|
||||||
|
cfg.Creds = creds
|
||||||
|
}
|
||||||
|
|
||||||
|
// Role Metadata
|
||||||
|
setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
|
||||||
|
setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)
|
||||||
|
|
||||||
|
// Web identity environment variables
|
||||||
|
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)
|
||||||
|
|
||||||
// CSM environment variables
|
// CSM environment variables
|
||||||
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
|
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
|
||||||
|
setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
|
||||||
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
|
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
|
||||||
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
|
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
|
||||||
cfg.CSMEnabled = len(cfg.csmEnabled) > 0
|
cfg.CSMEnabled = len(cfg.csmEnabled) > 0
|
||||||
|
|
||||||
// Require logical grouping of credentials
|
|
||||||
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
|
|
||||||
cfg.Creds = credentials.Value{}
|
|
||||||
} else {
|
|
||||||
cfg.Creds.ProviderName = EnvProviderName
|
|
||||||
}
|
|
||||||
|
|
||||||
regionKeys := regionEnvKeys
|
regionKeys := regionEnvKeys
|
||||||
profileKeys := profileEnvKeys
|
profileKeys := profileEnvKeys
|
||||||
if !cfg.EnableSharedConfig {
|
if !cfg.EnableSharedConfig {
|
||||||
|
@ -194,6 +242,12 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
|
||||||
setFromEnvVal(&cfg.Region, regionKeys)
|
setFromEnvVal(&cfg.Region, regionKeys)
|
||||||
setFromEnvVal(&cfg.Profile, profileKeys)
|
setFromEnvVal(&cfg.Profile, profileKeys)
|
||||||
|
|
||||||
|
// endpoint discovery is in reference to it being enabled.
|
||||||
|
setFromEnvVal(&cfg.enableEndpointDiscovery, enableEndpointDiscoveryEnvKey)
|
||||||
|
if len(cfg.enableEndpointDiscovery) > 0 {
|
||||||
|
cfg.EnableEndpointDiscovery = aws.Bool(cfg.enableEndpointDiscovery != "false")
|
||||||
|
}
|
||||||
|
|
||||||
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
|
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
|
||||||
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
|
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
|
||||||
|
|
||||||
|
|
|
@ -8,19 +8,36 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
"github.com/aws/aws-sdk-go/aws/client"
|
"github.com/aws/aws-sdk-go/aws/client"
|
||||||
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
"github.com/aws/aws-sdk-go/aws/corehandlers"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
|
||||||
"github.com/aws/aws-sdk-go/aws/csm"
|
"github.com/aws/aws-sdk-go/aws/csm"
|
||||||
"github.com/aws/aws-sdk-go/aws/defaults"
|
"github.com/aws/aws-sdk-go/aws/defaults"
|
||||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||||
"github.com/aws/aws-sdk-go/aws/request"
|
"github.com/aws/aws-sdk-go/aws/request"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ErrCodeSharedConfig represents an error that occurs in the shared
|
||||||
|
// configuration logic
|
||||||
|
ErrCodeSharedConfig = "SharedConfigErr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrSharedConfigSourceCollision will be returned if a section contains both
|
||||||
|
// source_profile and credential_source
|
||||||
|
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only source profile or credential source can be specified, not both", nil)
|
||||||
|
|
||||||
|
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
|
||||||
|
// variables are empty and Environment was set as the credential source
|
||||||
|
var ErrSharedConfigECSContainerEnvVarEmpty = awserr.New(ErrCodeSharedConfig, "EcsContainer was specified as the credential_source, but 'AWS_CONTAINER_CREDENTIALS_RELATIVE_URI' was not set", nil)
|
||||||
|
|
||||||
|
// ErrSharedConfigInvalidCredSource will be returned if an invalid credential source was provided
|
||||||
|
var ErrSharedConfigInvalidCredSource = awserr.New(ErrCodeSharedConfig, "credential source values must be EcsContainer, Ec2InstanceMetadata, or Environment", nil)
|
||||||
|
|
||||||
// A Session provides a central location to create service clients from and
|
// A Session provides a central location to create service clients from and
|
||||||
// store configurations and request handlers for those services.
|
// store configurations and request handlers for those services.
|
||||||
//
|
//
|
||||||
|
@ -88,7 +105,15 @@ func New(cfgs ...*aws.Config) *Session {
|
||||||
|
|
||||||
s := deprecatedNewSession(cfgs...)
|
s := deprecatedNewSession(cfgs...)
|
||||||
if envCfg.CSMEnabled {
|
if envCfg.CSMEnabled {
|
||||||
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
|
err := enableCSM(&s.Handlers, envCfg.CSMClientID,
|
||||||
|
envCfg.CSMHost, envCfg.CSMPort, s.Config.Logger)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to enable CSM, %v", err)
|
||||||
|
s.Config.Logger.Log("ERROR:", err.Error())
|
||||||
|
s.Handlers.Validate.PushBack(func(r *request.Request) {
|
||||||
|
r.Error = err
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
@ -191,6 +216,12 @@ type Options struct {
|
||||||
// the config enables assume role wit MFA via the mfa_serial field.
|
// the config enables assume role wit MFA via the mfa_serial field.
|
||||||
AssumeRoleTokenProvider func() (string, error)
|
AssumeRoleTokenProvider func() (string, error)
|
||||||
|
|
||||||
|
// When the SDK's shared config is configured to assume a role this option
|
||||||
|
// may be provided to set the expiry duration of the STS credentials.
|
||||||
|
// Defaults to 15 minutes if not set as documented in the
|
||||||
|
// stscreds.AssumeRoleProvider.
|
||||||
|
AssumeRoleDuration time.Duration
|
||||||
|
|
||||||
// Reader for a custom Credentials Authority (CA) bundle in PEM format that
|
// Reader for a custom Credentials Authority (CA) bundle in PEM format that
|
||||||
// the SDK will use instead of the default system's root CA bundle. Use this
|
// the SDK will use instead of the default system's root CA bundle. Use this
|
||||||
// only if you want to replace the CA bundle the SDK uses for TLS requests.
|
// only if you want to replace the CA bundle the SDK uses for TLS requests.
|
||||||
|
@ -205,6 +236,12 @@ type Options struct {
|
||||||
// to also enable this feature. CustomCABundle session option field has priority
|
// to also enable this feature. CustomCABundle session option field has priority
|
||||||
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
|
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
|
||||||
CustomCABundle io.Reader
|
CustomCABundle io.Reader
|
||||||
|
|
||||||
|
// The handlers that the session and all API clients will be created with.
|
||||||
|
// This must be a complete set of handlers. Use the defaults.Handlers()
|
||||||
|
// function to initialize this value before changing the handlers to be
|
||||||
|
// used by the SDK.
|
||||||
|
Handlers request.Handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
|
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
|
||||||
|
@ -310,27 +347,36 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) {
|
func enableCSM(handlers *request.Handlers,
|
||||||
|
clientID, host, port string,
|
||||||
|
logger aws.Logger,
|
||||||
|
) error {
|
||||||
|
if logger != nil {
|
||||||
logger.Log("Enabling CSM")
|
logger.Log("Enabling CSM")
|
||||||
if len(port) == 0 {
|
|
||||||
port = csm.DefaultPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := csm.Start(clientID, "127.0.0.1:"+port)
|
r, err := csm.Start(clientID, csm.AddressWithDefaults(host, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
r.InjectHandlers(handlers)
|
r.InjectHandlers(handlers)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
|
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
|
||||||
cfg := defaults.Config()
|
cfg := defaults.Config()
|
||||||
handlers := defaults.Handlers()
|
|
||||||
|
handlers := opts.Handlers
|
||||||
|
if handlers.IsEmpty() {
|
||||||
|
handlers = defaults.Handlers()
|
||||||
|
}
|
||||||
|
|
||||||
// Get a merged version of the user provided config to determine if
|
// Get a merged version of the user provided config to determine if
|
||||||
// credentials were.
|
// credentials were.
|
||||||
userCfg := &aws.Config{}
|
userCfg := &aws.Config{}
|
||||||
userCfg.MergeIn(cfgs...)
|
userCfg.MergeIn(cfgs...)
|
||||||
|
cfg.MergeIn(userCfg)
|
||||||
|
|
||||||
// Ordered config files will be loaded in with later files overwriting
|
// Ordered config files will be loaded in with later files overwriting
|
||||||
// previous config file values.
|
// previous config file values.
|
||||||
|
@ -347,10 +393,12 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load additional config from file(s)
|
// Load additional config from file(s)
|
||||||
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles)
|
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
|
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -363,7 +411,11 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
|
||||||
|
|
||||||
initHandlers(s)
|
initHandlers(s)
|
||||||
if envCfg.CSMEnabled {
|
if envCfg.CSMEnabled {
|
||||||
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
|
err := enableCSM(&s.Handlers, envCfg.CSMClientID,
|
||||||
|
envCfg.CSMHost, envCfg.CSMPort, s.Config.Logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup HTTP client with custom cert bundle if enabled
|
// Setup HTTP client with custom cert bundle if enabled
|
||||||
|
@ -388,7 +440,10 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if t == nil {
|
if t == nil {
|
||||||
t = &http.Transport{}
|
// Nil transport implies `http.DefaultTransport` should be used. Since
|
||||||
|
// the SDK cannot modify, nor copy the `DefaultTransport` specifying
|
||||||
|
// the values the next closest behavior.
|
||||||
|
t = getCABundleTransport()
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := loadCertPool(bundle)
|
p, err := loadCertPool(bundle)
|
||||||
|
@ -421,9 +476,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error {
|
func mergeConfigSrcs(cfg, userCfg *aws.Config,
|
||||||
// Merge in user provided configuration
|
envCfg envConfig, sharedCfg sharedConfig,
|
||||||
cfg.MergeIn(userCfg)
|
handlers request.Handlers,
|
||||||
|
sessOpts Options,
|
||||||
|
) error {
|
||||||
|
|
||||||
// Region if not already set by user
|
// Region if not already set by user
|
||||||
if len(aws.StringValue(cfg.Region)) == 0 {
|
if len(aws.StringValue(cfg.Region)) == 0 {
|
||||||
|
@ -434,103 +491,27 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure credentials if not already set
|
if cfg.EnableEndpointDiscovery == nil {
|
||||||
|
if envCfg.EnableEndpointDiscovery != nil {
|
||||||
|
cfg.WithEndpointDiscovery(*envCfg.EnableEndpointDiscovery)
|
||||||
|
} else if envCfg.EnableSharedConfig && sharedCfg.EnableEndpointDiscovery != nil {
|
||||||
|
cfg.WithEndpointDiscovery(*sharedCfg.EnableEndpointDiscovery)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure credentials if not already set by the user when creating the
|
||||||
|
// Session.
|
||||||
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
|
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
|
||||||
if len(envCfg.Creds.AccessKeyID) > 0 {
|
creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
|
||||||
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
|
if err != nil {
|
||||||
envCfg.Creds,
|
return err
|
||||||
)
|
|
||||||
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
|
|
||||||
cfgCp := *cfg
|
|
||||||
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
|
|
||||||
sharedCfg.AssumeRoleSource.Creds,
|
|
||||||
)
|
|
||||||
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
|
|
||||||
// AssumeRole Token provider is required if doing Assume Role
|
|
||||||
// with MFA.
|
|
||||||
return AssumeRoleTokenProviderNotSetError{}
|
|
||||||
}
|
|
||||||
cfg.Credentials = stscreds.NewCredentials(
|
|
||||||
&Session{
|
|
||||||
Config: &cfgCp,
|
|
||||||
Handlers: handlers.Copy(),
|
|
||||||
},
|
|
||||||
sharedCfg.AssumeRole.RoleARN,
|
|
||||||
func(opt *stscreds.AssumeRoleProvider) {
|
|
||||||
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
|
|
||||||
|
|
||||||
// Assume role with external ID
|
|
||||||
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
|
|
||||||
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assume role with MFA
|
|
||||||
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
|
|
||||||
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
|
|
||||||
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
|
|
||||||
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
|
|
||||||
sharedCfg.Creds,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
// Fallback to default credentials provider, include mock errors
|
|
||||||
// for the credential chain so user can identify why credentials
|
|
||||||
// failed to be retrieved.
|
|
||||||
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
|
|
||||||
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
|
|
||||||
Providers: []credentials.Provider{
|
|
||||||
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
|
|
||||||
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
|
|
||||||
defaults.RemoteCredProvider(*cfg, handlers),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
cfg.Credentials = creds
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the
|
|
||||||
// MFAToken option is not set when shared config is configured load assume a
|
|
||||||
// role with an MFA token.
|
|
||||||
type AssumeRoleTokenProviderNotSetError struct{}
|
|
||||||
|
|
||||||
// Code is the short id of the error.
|
|
||||||
func (e AssumeRoleTokenProviderNotSetError) Code() string {
|
|
||||||
return "AssumeRoleTokenProviderNotSetError"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Message is the description of the error
|
|
||||||
func (e AssumeRoleTokenProviderNotSetError) Message() string {
|
|
||||||
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrigErr is the underlying error that caused the failure.
|
|
||||||
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error satisfies the error interface.
|
|
||||||
func (e AssumeRoleTokenProviderNotSetError) Error() string {
|
|
||||||
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
type credProviderError struct {
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
var emptyCreds = credentials.Value{}
|
|
||||||
|
|
||||||
func (c credProviderError) Retrieve() (credentials.Value, error) {
|
|
||||||
return credentials.Value{}, c.Err
|
|
||||||
}
|
|
||||||
func (c credProviderError) IsExpired() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func initHandlers(s *Session) {
|
func initHandlers(s *Session) {
|
||||||
// Add the Validate parameter handler if it is not disabled.
|
// Add the Validate parameter handler if it is not disabled.
|
||||||
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler)
|
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler)
|
||||||
|
|
|
@ -2,11 +2,10 @@ package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
"github.com/go-ini/ini"
|
"github.com/aws/aws-sdk-go/internal/ini"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -17,7 +16,8 @@ const (
|
||||||
|
|
||||||
// Assume Role Credentials group
|
// Assume Role Credentials group
|
||||||
roleArnKey = `role_arn` // group required
|
roleArnKey = `role_arn` // group required
|
||||||
sourceProfileKey = `source_profile` // group required
|
sourceProfileKey = `source_profile` // group required (or credential_source)
|
||||||
|
credentialSourceKey = `credential_source` // group required (or source_profile)
|
||||||
externalIDKey = `external_id` // optional
|
externalIDKey = `external_id` // optional
|
||||||
mfaSerialKey = `mfa_serial` // optional
|
mfaSerialKey = `mfa_serial` // optional
|
||||||
roleSessionNameKey = `role_session_name` // optional
|
roleSessionNameKey = `role_session_name` // optional
|
||||||
|
@ -25,59 +25,76 @@ const (
|
||||||
// Additional Config fields
|
// Additional Config fields
|
||||||
regionKey = `region`
|
regionKey = `region`
|
||||||
|
|
||||||
|
// endpoint discovery group
|
||||||
|
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
|
||||||
|
|
||||||
|
// External Credential Process
|
||||||
|
credentialProcessKey = `credential_process` // optional
|
||||||
|
|
||||||
|
// Web Identity Token File
|
||||||
|
webIdentityTokenFileKey = `web_identity_token_file` // optional
|
||||||
|
|
||||||
// DefaultSharedConfigProfile is the default profile to be used when
|
// DefaultSharedConfigProfile is the default profile to be used when
|
||||||
// loading configuration from the config files if another profile name
|
// loading configuration from the config files if another profile name
|
||||||
// is not provided.
|
// is not provided.
|
||||||
DefaultSharedConfigProfile = `default`
|
DefaultSharedConfigProfile = `default`
|
||||||
)
|
)
|
||||||
|
|
||||||
type assumeRoleConfig struct {
|
|
||||||
RoleARN string
|
|
||||||
SourceProfile string
|
|
||||||
ExternalID string
|
|
||||||
MFASerial string
|
|
||||||
RoleSessionName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// sharedConfig represents the configuration fields of the SDK config files.
|
// sharedConfig represents the configuration fields of the SDK config files.
|
||||||
type sharedConfig struct {
|
type sharedConfig struct {
|
||||||
// Credentials values from the config file. Both aws_access_key_id
|
// Credentials values from the config file. Both aws_access_key_id and
|
||||||
// and aws_secret_access_key must be provided together in the same file
|
// aws_secret_access_key must be provided together in the same file to be
|
||||||
// to be considered valid. The values will be ignored if not a complete group.
|
// considered valid. The values will be ignored if not a complete group.
|
||||||
// aws_session_token is an optional field that can be provided if both of the
|
// aws_session_token is an optional field that can be provided if both of
|
||||||
// other two fields are also provided.
|
// the other two fields are also provided.
|
||||||
//
|
//
|
||||||
// aws_access_key_id
|
// aws_access_key_id
|
||||||
// aws_secret_access_key
|
// aws_secret_access_key
|
||||||
// aws_session_token
|
// aws_session_token
|
||||||
Creds credentials.Value
|
Creds credentials.Value
|
||||||
|
|
||||||
AssumeRole assumeRoleConfig
|
CredentialSource string
|
||||||
AssumeRoleSource *sharedConfig
|
CredentialProcess string
|
||||||
|
WebIdentityTokenFile string
|
||||||
|
|
||||||
// Region is the region the SDK should use for looking up AWS service endpoints
|
RoleARN string
|
||||||
// and signing requests.
|
RoleSessionName string
|
||||||
|
ExternalID string
|
||||||
|
MFASerial string
|
||||||
|
|
||||||
|
SourceProfileName string
|
||||||
|
SourceProfile *sharedConfig
|
||||||
|
|
||||||
|
// Region is the region the SDK should use for looking up AWS service
|
||||||
|
// endpoints and signing requests.
|
||||||
//
|
//
|
||||||
// region
|
// region
|
||||||
Region string
|
Region string
|
||||||
|
|
||||||
|
// EnableEndpointDiscovery can be enabled in the shared config by setting
|
||||||
|
// endpoint_discovery_enabled to true
|
||||||
|
//
|
||||||
|
// endpoint_discovery_enabled = true
|
||||||
|
EnableEndpointDiscovery *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type sharedConfigFile struct {
|
type sharedConfigFile struct {
|
||||||
Filename string
|
Filename string
|
||||||
IniData *ini.File
|
IniData ini.Sections
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadSharedConfig retrieves the configuration from the list of files
|
// loadSharedConfig retrieves the configuration from the list of files using
|
||||||
// using the profile provided. The order the files are listed will determine
|
// the profile provided. The order the files are listed will determine
|
||||||
// precedence. Values in subsequent files will overwrite values defined in
|
// precedence. Values in subsequent files will overwrite values defined in
|
||||||
// earlier files.
|
// earlier files.
|
||||||
//
|
//
|
||||||
// For example, given two files A and B. Both define credentials. If the order
|
// For example, given two files A and B. Both define credentials. If the order
|
||||||
// of the files are A then B, B's credential values will be used instead of A's.
|
// of the files are A then B, B's credential values will be used instead of
|
||||||
|
// A's.
|
||||||
//
|
//
|
||||||
// See sharedConfig.setFromFile for information how the config files
|
// See sharedConfig.setFromFile for information how the config files
|
||||||
// will be loaded.
|
// will be loaded.
|
||||||
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) {
|
func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
|
||||||
if len(profile) == 0 {
|
if len(profile) == 0 {
|
||||||
profile = DefaultSharedConfigProfile
|
profile = DefaultSharedConfigProfile
|
||||||
}
|
}
|
||||||
|
@ -88,16 +105,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := sharedConfig{}
|
cfg := sharedConfig{}
|
||||||
if err = cfg.setFromIniFiles(profile, files); err != nil {
|
profiles := map[string]struct{}{}
|
||||||
|
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
|
||||||
return sharedConfig{}, err
|
return sharedConfig{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.AssumeRole.SourceProfile) > 0 {
|
|
||||||
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
|
|
||||||
return sharedConfig{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,114 +117,237 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
|
||||||
files := make([]sharedConfigFile, 0, len(filenames))
|
files := make([]sharedConfigFile, 0, len(filenames))
|
||||||
|
|
||||||
for _, filename := range filenames {
|
for _, filename := range filenames {
|
||||||
b, err := ioutil.ReadFile(filename)
|
sections, err := ini.OpenFile(filename)
|
||||||
if err != nil {
|
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == ini.ErrCodeUnableToReadFile {
|
||||||
// Skip files which can't be opened and read for whatever reason
|
// Skip files which can't be opened and read for whatever reason
|
||||||
continue
|
continue
|
||||||
}
|
} else if err != nil {
|
||||||
|
|
||||||
f, err := ini.Load(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, SharedConfigLoadError{Filename: filename, Err: err}
|
return nil, SharedConfigLoadError{Filename: filename, Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
files = append(files, sharedConfigFile{
|
files = append(files, sharedConfigFile{
|
||||||
Filename: filename, IniData: f,
|
Filename: filename, IniData: sections,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return files, nil
|
return files, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error {
|
func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
|
||||||
var assumeRoleSrc sharedConfig
|
|
||||||
|
|
||||||
// Multiple level assume role chains are not support
|
|
||||||
if cfg.AssumeRole.SourceProfile == origProfile {
|
|
||||||
assumeRoleSrc = *cfg
|
|
||||||
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
|
|
||||||
} else {
|
|
||||||
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
|
|
||||||
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.AssumeRoleSource = &assumeRoleSrc
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
|
|
||||||
// Trim files from the list that don't exist.
|
// Trim files from the list that don't exist.
|
||||||
|
var skippedFiles int
|
||||||
|
var profileNotFoundErr error
|
||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
if err := cfg.setFromIniFile(profile, f); err != nil {
|
if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
|
||||||
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
|
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
|
||||||
// Ignore proviles missings
|
// Ignore profiles not defined in individual files.
|
||||||
|
profileNotFoundErr = err
|
||||||
|
skippedFiles++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if skippedFiles == len(files) {
|
||||||
|
// If all files were skipped because the profile is not found, return
|
||||||
|
// the original profile not found error.
|
||||||
|
return profileNotFoundErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := profiles[profile]; ok {
|
||||||
|
// if this is the second instance of the profile the Assume Role
|
||||||
|
// options must be cleared because they are only valid for the
|
||||||
|
// first reference of a profile. The self linked instance of the
|
||||||
|
// profile only have credential provider options.
|
||||||
|
cfg.clearAssumeRoleOptions()
|
||||||
|
} else {
|
||||||
|
// First time a profile has been seen, It must either be a assume role
|
||||||
|
// or credentials. Assert if the credential type requires a role ARN,
|
||||||
|
// the ARN is also set.
|
||||||
|
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
profiles[profile] = struct{}{}
|
||||||
|
|
||||||
|
if err := cfg.validateCredentialType(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link source profiles for assume roles
|
||||||
|
if len(cfg.SourceProfileName) != 0 {
|
||||||
|
// Linked profile via source_profile ignore credential provider
|
||||||
|
// options, the source profile must provide the credentials.
|
||||||
|
cfg.clearCredentialOptions()
|
||||||
|
|
||||||
|
srcCfg := &sharedConfig{}
|
||||||
|
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
|
||||||
|
if err != nil {
|
||||||
|
// SourceProfile that doesn't exist is an error in configuration.
|
||||||
|
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
|
||||||
|
err = SharedConfigAssumeRoleError{
|
||||||
|
RoleARN: cfg.RoleARN,
|
||||||
|
SourceProfile: cfg.SourceProfileName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !srcCfg.hasCredentials() {
|
||||||
|
return SharedConfigAssumeRoleError{
|
||||||
|
RoleARN: cfg.RoleARN,
|
||||||
|
SourceProfile: cfg.SourceProfileName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SourceProfile = srcCfg
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setFromFile loads the configuration from the file using
|
// setFromFile loads the configuration from the file using the profile
|
||||||
// the profile provided. A sharedConfig pointer type value is used so that
|
// provided. A sharedConfig pointer type value is used so that multiple config
|
||||||
// multiple config file loadings can be chained.
|
// file loadings can be chained.
|
||||||
//
|
//
|
||||||
// Only loads complete logically grouped values, and will not set fields in cfg
|
// Only loads complete logically grouped values, and will not set fields in cfg
|
||||||
// for incomplete grouped values in the config. Such as credentials. For example
|
// for incomplete grouped values in the config. Such as credentials. For
|
||||||
// if a config file only includes aws_access_key_id but no aws_secret_access_key
|
// example if a config file only includes aws_access_key_id but no
|
||||||
// the aws_access_key_id will be ignored.
|
// aws_secret_access_key the aws_access_key_id will be ignored.
|
||||||
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error {
|
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
|
||||||
section, err := file.IniData.GetSection(profile)
|
section, ok := file.IniData.GetSection(profile)
|
||||||
if err != nil {
|
if !ok {
|
||||||
// Fallback to to alternate profile name: profile <name>
|
// Fallback to to alternate profile name: profile <name>
|
||||||
section, err = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
|
section, ok = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
|
||||||
if err != nil {
|
if !ok {
|
||||||
return SharedConfigProfileNotExistsError{Profile: profile, Err: err}
|
return SharedConfigProfileNotExistsError{Profile: profile, Err: nil}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if exOpts {
|
||||||
|
// Assume Role Parameters
|
||||||
|
updateString(&cfg.RoleARN, section, roleArnKey)
|
||||||
|
updateString(&cfg.ExternalID, section, externalIDKey)
|
||||||
|
updateString(&cfg.MFASerial, section, mfaSerialKey)
|
||||||
|
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
|
||||||
|
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
|
||||||
|
updateString(&cfg.CredentialSource, section, credentialSourceKey)
|
||||||
|
|
||||||
|
updateString(&cfg.Region, section, regionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
|
||||||
|
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
|
||||||
|
|
||||||
// Shared Credentials
|
// Shared Credentials
|
||||||
akid := section.Key(accessKeyIDKey).String()
|
creds := credentials.Value{
|
||||||
secret := section.Key(secretAccessKey).String()
|
AccessKeyID: section.String(accessKeyIDKey),
|
||||||
if len(akid) > 0 && len(secret) > 0 {
|
SecretAccessKey: section.String(secretAccessKey),
|
||||||
cfg.Creds = credentials.Value{
|
SessionToken: section.String(sessionTokenKey),
|
||||||
AccessKeyID: akid,
|
|
||||||
SecretAccessKey: secret,
|
|
||||||
SessionToken: section.Key(sessionTokenKey).String(),
|
|
||||||
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
|
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
|
||||||
}
|
}
|
||||||
|
if creds.HasKeys() {
|
||||||
|
cfg.Creds = creds
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assume Role
|
// Endpoint discovery
|
||||||
roleArn := section.Key(roleArnKey).String()
|
if section.Has(enableEndpointDiscoveryKey) {
|
||||||
srcProfile := section.Key(sourceProfileKey).String()
|
v := section.Bool(enableEndpointDiscoveryKey)
|
||||||
if len(roleArn) > 0 && len(srcProfile) > 0 {
|
cfg.EnableEndpointDiscovery = &v
|
||||||
cfg.AssumeRole = assumeRoleConfig{
|
|
||||||
RoleARN: roleArn,
|
|
||||||
SourceProfile: srcProfile,
|
|
||||||
ExternalID: section.Key(externalIDKey).String(),
|
|
||||||
MFASerial: section.Key(mfaSerialKey).String(),
|
|
||||||
RoleSessionName: section.Key(roleSessionNameKey).String(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Region
|
|
||||||
if v := section.Key(regionKey).String(); len(v) > 0 {
|
|
||||||
cfg.Region = v
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
|
||||||
|
var credSource string
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(cfg.SourceProfileName) != 0:
|
||||||
|
credSource = sourceProfileKey
|
||||||
|
case len(cfg.CredentialSource) != 0:
|
||||||
|
credSource = credentialSourceKey
|
||||||
|
case len(cfg.WebIdentityTokenFile) != 0:
|
||||||
|
credSource = webIdentityTokenFileKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
|
||||||
|
return CredentialRequiresARNError{
|
||||||
|
Type: credSource,
|
||||||
|
Profile: profile,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *sharedConfig) validateCredentialType() error {
|
||||||
|
// Only one or no credential type can be defined.
|
||||||
|
if !oneOrNone(
|
||||||
|
len(cfg.SourceProfileName) != 0,
|
||||||
|
len(cfg.CredentialSource) != 0,
|
||||||
|
len(cfg.CredentialProcess) != 0,
|
||||||
|
len(cfg.WebIdentityTokenFile) != 0,
|
||||||
|
) {
|
||||||
|
return ErrSharedConfigSourceCollision
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *sharedConfig) hasCredentials() bool {
|
||||||
|
switch {
|
||||||
|
case len(cfg.SourceProfileName) != 0:
|
||||||
|
case len(cfg.CredentialSource) != 0:
|
||||||
|
case len(cfg.CredentialProcess) != 0:
|
||||||
|
case len(cfg.WebIdentityTokenFile) != 0:
|
||||||
|
case cfg.Creds.HasKeys():
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *sharedConfig) clearCredentialOptions() {
|
||||||
|
cfg.CredentialSource = ""
|
||||||
|
cfg.CredentialProcess = ""
|
||||||
|
cfg.WebIdentityTokenFile = ""
|
||||||
|
cfg.Creds = credentials.Value{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *sharedConfig) clearAssumeRoleOptions() {
|
||||||
|
cfg.RoleARN = ""
|
||||||
|
cfg.ExternalID = ""
|
||||||
|
cfg.MFASerial = ""
|
||||||
|
cfg.RoleSessionName = ""
|
||||||
|
cfg.SourceProfileName = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func oneOrNone(bs ...bool) bool {
|
||||||
|
var count int
|
||||||
|
|
||||||
|
for _, b := range bs {
|
||||||
|
if b {
|
||||||
|
count++
|
||||||
|
if count > 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateString will only update the dst with the value in the section key, key
|
||||||
|
// is present in the section.
|
||||||
|
func updateString(dst *string, section ini.Section, key string) {
|
||||||
|
if !section.Has(key) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*dst = section.String(key)
|
||||||
|
}
|
||||||
|
|
||||||
// SharedConfigLoadError is an error for the shared config file failed to load.
|
// SharedConfigLoadError is an error for the shared config file failed to load.
|
||||||
type SharedConfigLoadError struct {
|
type SharedConfigLoadError struct {
|
||||||
Filename string
|
Filename string
|
||||||
|
@ -271,6 +406,7 @@ func (e SharedConfigProfileNotExistsError) Error() string {
|
||||||
// or not complete.
|
// or not complete.
|
||||||
type SharedConfigAssumeRoleError struct {
|
type SharedConfigAssumeRoleError struct {
|
||||||
RoleARN string
|
RoleARN string
|
||||||
|
SourceProfile string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Code is the short id of the error.
|
// Code is the short id of the error.
|
||||||
|
@ -280,8 +416,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
|
||||||
|
|
||||||
// Message is the description of the error
|
// Message is the description of the error
|
||||||
func (e SharedConfigAssumeRoleError) Message() string {
|
func (e SharedConfigAssumeRoleError) Message() string {
|
||||||
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials",
|
return fmt.Sprintf(
|
||||||
e.RoleARN)
|
"failed to load assume role for %s, source profile %s has no shared credentials",
|
||||||
|
e.RoleARN, e.SourceProfile,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OrigErr is the underlying error that caused the failure.
|
// OrigErr is the underlying error that caused the failure.
|
||||||
|
@ -293,3 +431,36 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
|
||||||
func (e SharedConfigAssumeRoleError) Error() string {
|
func (e SharedConfigAssumeRoleError) Error() string {
|
||||||
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CredentialRequiresARNError provides the error for shared config credentials
|
||||||
|
// that are incorrectly configured in the shared config or credentials file.
|
||||||
|
type CredentialRequiresARNError struct {
|
||||||
|
// type of credentials that were configured.
|
||||||
|
Type string
|
||||||
|
|
||||||
|
// Profile name the credentials were in.
|
||||||
|
Profile string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code is the short id of the error.
|
||||||
|
func (e CredentialRequiresARNError) Code() string {
|
||||||
|
return "CredentialRequiresARNError"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message is the description of the error
|
||||||
|
func (e CredentialRequiresARNError) Message() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"credential type %s requires role_arn, profile %s",
|
||||||
|
e.Type, e.Profile,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrigErr is the underlying error that caused the failure.
|
||||||
|
func (e CredentialRequiresARNError) OrigErr() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error satisfies the error interface.
|
||||||
|
func (e CredentialRequiresARNError) Error() string {
|
||||||
|
return awserr.SprintError(e.Code(), e.Message(), "", nil)
|
||||||
|
}
|
||||||
|
|
|
@ -134,6 +134,7 @@ var requiredSignedHeaders = rules{
|
||||||
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
|
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
|
||||||
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
|
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
|
||||||
"X-Amz-Storage-Class": struct{}{},
|
"X-Amz-Storage-Class": struct{}{},
|
||||||
|
"X-Amz-Tagging": struct{}{},
|
||||||
"X-Amz-Website-Redirect-Location": struct{}{},
|
"X-Amz-Website-Redirect-Location": struct{}{},
|
||||||
"X-Amz-Content-Sha256": struct{}{},
|
"X-Amz-Content-Sha256": struct{}{},
|
||||||
},
|
},
|
||||||
|
@ -181,7 +182,7 @@ type Signer struct {
|
||||||
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
|
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
|
||||||
DisableURIPathEscaping bool
|
DisableURIPathEscaping bool
|
||||||
|
|
||||||
// Disales the automatical setting of the HTTP request's Body field with the
|
// Disables the automatical setting of the HTTP request's Body field with the
|
||||||
// io.ReadSeeker passed in to the signer. This is useful if you're using a
|
// io.ReadSeeker passed in to the signer. This is useful if you're using a
|
||||||
// custom wrapper around the body for the io.ReadSeeker and want to preserve
|
// custom wrapper around the body for the io.ReadSeeker and want to preserve
|
||||||
// the Body value on the Request.Body.
|
// the Body value on the Request.Body.
|
||||||
|
@ -421,7 +422,7 @@ var SignRequestHandler = request.NamedHandler{
|
||||||
// If the credentials of the request's config are set to
|
// If the credentials of the request's config are set to
|
||||||
// credentials.AnonymousCredentials the request will not be signed.
|
// credentials.AnonymousCredentials the request will not be signed.
|
||||||
func SignSDKRequest(req *request.Request) {
|
func SignSDKRequest(req *request.Request) {
|
||||||
signSDKRequestWithCurrTime(req, time.Now)
|
SignSDKRequestWithCurrentTime(req, time.Now)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildNamedHandler will build a generic handler for signing.
|
// BuildNamedHandler will build a generic handler for signing.
|
||||||
|
@ -429,12 +430,15 @@ func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler
|
||||||
return request.NamedHandler{
|
return request.NamedHandler{
|
||||||
Name: name,
|
Name: name,
|
||||||
Fn: func(req *request.Request) {
|
Fn: func(req *request.Request) {
|
||||||
signSDKRequestWithCurrTime(req, time.Now, opts...)
|
SignSDKRequestWithCurrentTime(req, time.Now, opts...)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
|
// SignSDKRequestWithCurrentTime will sign the SDK's request using the time
|
||||||
|
// function passed in. Behaves the same as SignSDKRequest with the exception
|
||||||
|
// the request is signed with the value returned by the current time function.
|
||||||
|
func SignSDKRequestWithCurrentTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
|
||||||
// If the request does not need to be signed ignore the signing of the
|
// If the request does not need to be signed ignore the signing of the
|
||||||
// request if the AnonymousCredentials object is used.
|
// request if the AnonymousCredentials object is used.
|
||||||
if req.Config.Credentials == credentials.AnonymousCredentials {
|
if req.Config.Credentials == credentials.AnonymousCredentials {
|
||||||
|
@ -470,13 +474,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
|
||||||
opt(v4)
|
opt(v4)
|
||||||
}
|
}
|
||||||
|
|
||||||
signingTime := req.Time
|
curTime := curTimeFn()
|
||||||
if !req.LastSignedAt.IsZero() {
|
|
||||||
signingTime = req.LastSignedAt
|
|
||||||
}
|
|
||||||
|
|
||||||
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
|
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
|
||||||
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime,
|
name, region, req.ExpireTime, req.ExpireTime > 0, curTime,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Error = err
|
req.Error = err
|
||||||
|
@ -485,7 +485,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
req.SignedHeaderVals = signedHeaders
|
req.SignedHeaderVals = signedHeaders
|
||||||
req.LastSignedAt = curTimeFn()
|
req.LastSignedAt = curTime
|
||||||
}
|
}
|
||||||
|
|
||||||
const logSignInfoMsg = `DEBUG: Request Signature:
|
const logSignInfoMsg = `DEBUG: Request Signature:
|
||||||
|
@ -739,14 +739,22 @@ func makeSha256Reader(reader io.ReadSeeker) []byte {
|
||||||
start, _ := reader.Seek(0, sdkio.SeekCurrent)
|
start, _ := reader.Seek(0, sdkio.SeekCurrent)
|
||||||
defer reader.Seek(start, sdkio.SeekStart)
|
defer reader.Seek(start, sdkio.SeekStart)
|
||||||
|
|
||||||
|
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
|
||||||
|
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
|
||||||
|
size, err := aws.SeekerLen(reader)
|
||||||
|
if err != nil {
|
||||||
io.Copy(hash, reader)
|
io.Copy(hash, reader)
|
||||||
|
} else {
|
||||||
|
io.CopyN(hash, reader, size)
|
||||||
|
}
|
||||||
|
|
||||||
return hash.Sum(nil)
|
return hash.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
const doubleSpace = " "
|
const doubleSpace = " "
|
||||||
|
|
||||||
// stripExcessSpaces will rewrite the passed in slice's string values to not
|
// stripExcessSpaces will rewrite the passed in slice's string values to not
|
||||||
// contain muliple side-by-side spaces.
|
// contain multiple side-by-side spaces.
|
||||||
func stripExcessSpaces(vals []string) {
|
func stripExcessSpaces(vals []string) {
|
||||||
var j, k, l, m, spaces int
|
var j, k, l, m, spaces int
|
||||||
for i, str := range vals {
|
for i, str := range vals {
|
||||||
|
|
|
@ -7,13 +7,18 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/internal/sdkio"
|
"github.com/aws/aws-sdk-go/internal/sdkio"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should
|
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
|
||||||
// only be used with an io.Reader that is also an io.Seeker. Doing so may
|
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
|
||||||
// cause request signature errors, or request body's not sent for GET, HEAD
|
// streaming payload API operations.
|
||||||
// and DELETE HTTP methods.
|
|
||||||
//
|
//
|
||||||
// Deprecated: Should only be used with io.ReadSeeker. If using for
|
// A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
|
||||||
// S3 PutObject to stream content use s3manager.Uploader instead.
|
// operation's input will prevent that operation being retried in the case of
|
||||||
|
// network errors, and cause operation requests to fail if the operation
|
||||||
|
// requires payload signing.
|
||||||
|
//
|
||||||
|
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
|
||||||
|
// Upload manager (s3manager.Uploader) provides support for streaming with the
|
||||||
|
// ability to retry network errors.
|
||||||
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
|
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
|
||||||
return ReaderSeekerCloser{r}
|
return ReaderSeekerCloser{r}
|
||||||
}
|
}
|
||||||
|
@ -43,7 +48,8 @@ func IsReaderSeekable(r io.Reader) bool {
|
||||||
// Read reads from the reader up to size of p. The number of bytes read, and
|
// Read reads from the reader up to size of p. The number of bytes read, and
|
||||||
// error if it occurred will be returned.
|
// error if it occurred will be returned.
|
||||||
//
|
//
|
||||||
// If the reader is not an io.Reader zero bytes read, and nil error will be returned.
|
// If the reader is not an io.Reader zero bytes read, and nil error will be
|
||||||
|
// returned.
|
||||||
//
|
//
|
||||||
// Performs the same functionality as io.Reader Read
|
// Performs the same functionality as io.Reader Read
|
||||||
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
|
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
|
||||||
|
|
|
@ -5,4 +5,4 @@ package aws
|
||||||
const SDKName = "aws-sdk-go"
|
const SDKName = "aws-sdk-go"
|
||||||
|
|
||||||
// SDKVersion is the version of this SDK
|
// SDKVersion is the version of this SDK
|
||||||
const SDKVersion = "1.15.24"
|
const SDKVersion = "1.21.1"
|
||||||
|
|
|
@ -0,0 +1,120 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// ASTKind represents different states in the parse table
|
||||||
|
// and the type of AST that is being constructed
|
||||||
|
type ASTKind int
|
||||||
|
|
||||||
|
// ASTKind* is used in the parse table to transition between
|
||||||
|
// the different states
|
||||||
|
const (
|
||||||
|
ASTKindNone = ASTKind(iota)
|
||||||
|
ASTKindStart
|
||||||
|
ASTKindExpr
|
||||||
|
ASTKindEqualExpr
|
||||||
|
ASTKindStatement
|
||||||
|
ASTKindSkipStatement
|
||||||
|
ASTKindExprStatement
|
||||||
|
ASTKindSectionStatement
|
||||||
|
ASTKindNestedSectionStatement
|
||||||
|
ASTKindCompletedNestedSectionStatement
|
||||||
|
ASTKindCommentStatement
|
||||||
|
ASTKindCompletedSectionStatement
|
||||||
|
)
|
||||||
|
|
||||||
|
func (k ASTKind) String() string {
|
||||||
|
switch k {
|
||||||
|
case ASTKindNone:
|
||||||
|
return "none"
|
||||||
|
case ASTKindStart:
|
||||||
|
return "start"
|
||||||
|
case ASTKindExpr:
|
||||||
|
return "expr"
|
||||||
|
case ASTKindStatement:
|
||||||
|
return "stmt"
|
||||||
|
case ASTKindSectionStatement:
|
||||||
|
return "section_stmt"
|
||||||
|
case ASTKindExprStatement:
|
||||||
|
return "expr_stmt"
|
||||||
|
case ASTKindCommentStatement:
|
||||||
|
return "comment"
|
||||||
|
case ASTKindNestedSectionStatement:
|
||||||
|
return "nested_section_stmt"
|
||||||
|
case ASTKindCompletedSectionStatement:
|
||||||
|
return "completed_stmt"
|
||||||
|
case ASTKindSkipStatement:
|
||||||
|
return "skip"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AST interface allows us to determine what kind of node we
|
||||||
|
// are on and casting may not need to be necessary.
|
||||||
|
//
|
||||||
|
// The root is always the first node in Children
|
||||||
|
type AST struct {
|
||||||
|
Kind ASTKind
|
||||||
|
Root Token
|
||||||
|
RootToken bool
|
||||||
|
Children []AST
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAST(kind ASTKind, root AST, children ...AST) AST {
|
||||||
|
return AST{
|
||||||
|
Kind: kind,
|
||||||
|
Children: append([]AST{root}, children...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newASTWithRootToken(kind ASTKind, root Token, children ...AST) AST {
|
||||||
|
return AST{
|
||||||
|
Kind: kind,
|
||||||
|
Root: root,
|
||||||
|
RootToken: true,
|
||||||
|
Children: children,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendChild will append to the list of children an AST has.
|
||||||
|
func (a *AST) AppendChild(child AST) {
|
||||||
|
a.Children = append(a.Children, child)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoot will return the root AST which can be the first entry
|
||||||
|
// in the children list or a token.
|
||||||
|
func (a *AST) GetRoot() AST {
|
||||||
|
if a.RootToken {
|
||||||
|
return *a
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.Children) == 0 {
|
||||||
|
return AST{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.Children[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChildren will return the current AST's list of children
|
||||||
|
func (a *AST) GetChildren() []AST {
|
||||||
|
if len(a.Children) == 0 {
|
||||||
|
return []AST{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.RootToken {
|
||||||
|
return a.Children
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.Children[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetChildren will set and override all children of the AST.
|
||||||
|
func (a *AST) SetChildren(children []AST) {
|
||||||
|
if a.RootToken {
|
||||||
|
a.Children = children
|
||||||
|
} else {
|
||||||
|
a.Children = append(a.Children[:1], children...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start is used to indicate the starting state of the parse table.
|
||||||
|
var Start = newAST(ASTKindStart, AST{})
|
|
@ -0,0 +1,11 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
var commaRunes = []rune(",")
|
||||||
|
|
||||||
|
func isComma(b rune) bool {
|
||||||
|
return b == ','
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCommaToken() Token {
|
||||||
|
return newToken(TokenComma, commaRunes, NoneType)
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// isComment will return whether or not the next byte(s) is a
|
||||||
|
// comment.
|
||||||
|
func isComment(b []rune) bool {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch b[0] {
|
||||||
|
case ';':
|
||||||
|
return true
|
||||||
|
case '#':
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCommentToken will create a comment token and
|
||||||
|
// return how many bytes were read.
|
||||||
|
func newCommentToken(b []rune) (Token, int, error) {
|
||||||
|
i := 0
|
||||||
|
for ; i < len(b); i++ {
|
||||||
|
if b[i] == '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b)-i > 2 && b[i] == '\r' && b[i+1] == '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newToken(TokenComment, b[:i], NoneType), i, nil
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
// Package ini is an LL(1) parser for configuration files.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// sections, err := ini.OpenFile("/path/to/file")
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// profile := "foo"
|
||||||
|
// section, ok := sections.GetSection(profile)
|
||||||
|
// if !ok {
|
||||||
|
// fmt.Printf("section %q could not be found", profile)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Below is the BNF that describes this parser
|
||||||
|
// Grammar:
|
||||||
|
// stmt -> value stmt'
|
||||||
|
// stmt' -> epsilon | op stmt
|
||||||
|
// value -> number | string | boolean | quoted_string
|
||||||
|
//
|
||||||
|
// section -> [ section'
|
||||||
|
// section' -> value section_close
|
||||||
|
// section_close -> ]
|
||||||
|
//
|
||||||
|
// SkipState will skip (NL WS)+
|
||||||
|
//
|
||||||
|
// comment -> # comment' | ; comment'
|
||||||
|
// comment' -> epsilon | value
|
||||||
|
package ini
|
|
@ -0,0 +1,4 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// emptyToken is used to satisfy the Token interface
|
||||||
|
var emptyToken = newToken(TokenNone, []rune{}, NoneType)
|
|
@ -0,0 +1,24 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// newExpression will return an expression AST.
|
||||||
|
// Expr represents an expression
|
||||||
|
//
|
||||||
|
// grammar:
|
||||||
|
// expr -> string | number
|
||||||
|
func newExpression(tok Token) AST {
|
||||||
|
return newASTWithRootToken(ASTKindExpr, tok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEqualExpr(left AST, tok Token) AST {
|
||||||
|
return newASTWithRootToken(ASTKindEqualExpr, tok, left)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EqualExprKey will return a LHS value in the equal expr
|
||||||
|
func EqualExprKey(ast AST) string {
|
||||||
|
children := ast.GetChildren()
|
||||||
|
if len(children) == 0 || ast.Kind != ASTKindEqualExpr {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(children[0].Root.Raw())
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
// +build gofuzz
|
||||||
|
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Fuzz(data []byte) int {
|
||||||
|
b := bytes.NewReader(data)
|
||||||
|
|
||||||
|
if _, err := Parse(b); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenFile takes a path to a given file, and will open and parse
|
||||||
|
// that file.
|
||||||
|
func OpenFile(path string) (Sections, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return Sections{}, awserr.New(ErrCodeUnableToReadFile, "unable to open file", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return Parse(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse will parse the given file using the shared config
|
||||||
|
// visitor.
|
||||||
|
func Parse(f io.Reader) (Sections, error) {
|
||||||
|
tree, err := ParseAST(f)
|
||||||
|
if err != nil {
|
||||||
|
return Sections{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
v := NewDefaultVisitor()
|
||||||
|
if err = Walk(tree, v); err != nil {
|
||||||
|
return Sections{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return v.Sections, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseBytes will parse the given bytes and return the parsed sections.
|
||||||
|
func ParseBytes(b []byte) (Sections, error) {
|
||||||
|
tree, err := ParseASTBytes(b)
|
||||||
|
if err != nil {
|
||||||
|
return Sections{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
v := NewDefaultVisitor()
|
||||||
|
if err = Walk(tree, v); err != nil {
|
||||||
|
return Sections{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return v.Sections, nil
|
||||||
|
}
|
|
@ -0,0 +1,165 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ErrCodeUnableToReadFile is used when a file is failed to be
|
||||||
|
// opened or read from.
|
||||||
|
ErrCodeUnableToReadFile = "FailedRead"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenType represents the various different tokens types
|
||||||
|
type TokenType int
|
||||||
|
|
||||||
|
func (t TokenType) String() string {
|
||||||
|
switch t {
|
||||||
|
case TokenNone:
|
||||||
|
return "none"
|
||||||
|
case TokenLit:
|
||||||
|
return "literal"
|
||||||
|
case TokenSep:
|
||||||
|
return "sep"
|
||||||
|
case TokenOp:
|
||||||
|
return "op"
|
||||||
|
case TokenWS:
|
||||||
|
return "ws"
|
||||||
|
case TokenNL:
|
||||||
|
return "newline"
|
||||||
|
case TokenComment:
|
||||||
|
return "comment"
|
||||||
|
case TokenComma:
|
||||||
|
return "comma"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenType enums
|
||||||
|
const (
|
||||||
|
TokenNone = TokenType(iota)
|
||||||
|
TokenLit
|
||||||
|
TokenSep
|
||||||
|
TokenComma
|
||||||
|
TokenOp
|
||||||
|
TokenWS
|
||||||
|
TokenNL
|
||||||
|
TokenComment
|
||||||
|
)
|
||||||
|
|
||||||
|
type iniLexer struct{}
|
||||||
|
|
||||||
|
// Tokenize will return a list of tokens during lexical analysis of the
|
||||||
|
// io.Reader.
|
||||||
|
func (l *iniLexer) Tokenize(r io.Reader) ([]Token, error) {
|
||||||
|
b, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, awserr.New(ErrCodeUnableToReadFile, "unable to read file", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return l.tokenize(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *iniLexer) tokenize(b []byte) ([]Token, error) {
|
||||||
|
runes := bytes.Runes(b)
|
||||||
|
var err error
|
||||||
|
n := 0
|
||||||
|
tokenAmount := countTokens(runes)
|
||||||
|
tokens := make([]Token, tokenAmount)
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
for len(runes) > 0 && count < tokenAmount {
|
||||||
|
switch {
|
||||||
|
case isWhitespace(runes[0]):
|
||||||
|
tokens[count], n, err = newWSToken(runes)
|
||||||
|
case isComma(runes[0]):
|
||||||
|
tokens[count], n = newCommaToken(), 1
|
||||||
|
case isComment(runes):
|
||||||
|
tokens[count], n, err = newCommentToken(runes)
|
||||||
|
case isNewline(runes):
|
||||||
|
tokens[count], n, err = newNewlineToken(runes)
|
||||||
|
case isSep(runes):
|
||||||
|
tokens[count], n, err = newSepToken(runes)
|
||||||
|
case isOp(runes):
|
||||||
|
tokens[count], n, err = newOpToken(runes)
|
||||||
|
default:
|
||||||
|
tokens[count], n, err = newLitToken(runes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
count++
|
||||||
|
|
||||||
|
runes = runes[n:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens[:count], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokens(runes []rune) int {
|
||||||
|
count, n := 0, 0
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for len(runes) > 0 {
|
||||||
|
switch {
|
||||||
|
case isWhitespace(runes[0]):
|
||||||
|
_, n, err = newWSToken(runes)
|
||||||
|
case isComma(runes[0]):
|
||||||
|
_, n = newCommaToken(), 1
|
||||||
|
case isComment(runes):
|
||||||
|
_, n, err = newCommentToken(runes)
|
||||||
|
case isNewline(runes):
|
||||||
|
_, n, err = newNewlineToken(runes)
|
||||||
|
case isSep(runes):
|
||||||
|
_, n, err = newSepToken(runes)
|
||||||
|
case isOp(runes):
|
||||||
|
_, n, err = newOpToken(runes)
|
||||||
|
default:
|
||||||
|
_, n, err = newLitToken(runes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
count++
|
||||||
|
runes = runes[n:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return count + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token indicates a metadata about a given value.
|
||||||
|
type Token struct {
|
||||||
|
t TokenType
|
||||||
|
ValueType ValueType
|
||||||
|
base int
|
||||||
|
raw []rune
|
||||||
|
}
|
||||||
|
|
||||||
|
var emptyValue = Value{}
|
||||||
|
|
||||||
|
func newToken(t TokenType, raw []rune, v ValueType) Token {
|
||||||
|
return Token{
|
||||||
|
t: t,
|
||||||
|
raw: raw,
|
||||||
|
ValueType: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw return the raw runes that were consumed
|
||||||
|
func (tok Token) Raw() []rune {
|
||||||
|
return tok.raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the token type
|
||||||
|
func (tok Token) Type() TokenType {
|
||||||
|
return tok.t
|
||||||
|
}
|
|
@ -0,0 +1,349 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State enums for the parse table
|
||||||
|
const (
|
||||||
|
InvalidState = iota
|
||||||
|
// stmt -> value stmt'
|
||||||
|
StatementState
|
||||||
|
// stmt' -> MarkComplete | op stmt
|
||||||
|
StatementPrimeState
|
||||||
|
// value -> number | string | boolean | quoted_string
|
||||||
|
ValueState
|
||||||
|
// section -> [ section'
|
||||||
|
OpenScopeState
|
||||||
|
// section' -> value section_close
|
||||||
|
SectionState
|
||||||
|
// section_close -> ]
|
||||||
|
CloseScopeState
|
||||||
|
// SkipState will skip (NL WS)+
|
||||||
|
SkipState
|
||||||
|
// SkipTokenState will skip any token and push the previous
|
||||||
|
// state onto the stack.
|
||||||
|
SkipTokenState
|
||||||
|
// comment -> # comment' | ; comment'
|
||||||
|
// comment' -> MarkComplete | value
|
||||||
|
CommentState
|
||||||
|
// MarkComplete state will complete statements and move that
|
||||||
|
// to the completed AST list
|
||||||
|
MarkCompleteState
|
||||||
|
// TerminalState signifies that the tokens have been fully parsed
|
||||||
|
TerminalState
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseTable is a state machine to dictate the grammar above.
|
||||||
|
var parseTable = map[ASTKind]map[TokenType]int{
|
||||||
|
ASTKindStart: map[TokenType]int{
|
||||||
|
TokenLit: StatementState,
|
||||||
|
TokenSep: OpenScopeState,
|
||||||
|
TokenWS: SkipTokenState,
|
||||||
|
TokenNL: SkipTokenState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: TerminalState,
|
||||||
|
},
|
||||||
|
ASTKindCommentStatement: map[TokenType]int{
|
||||||
|
TokenLit: StatementState,
|
||||||
|
TokenSep: OpenScopeState,
|
||||||
|
TokenWS: SkipTokenState,
|
||||||
|
TokenNL: SkipTokenState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: MarkCompleteState,
|
||||||
|
},
|
||||||
|
ASTKindExpr: map[TokenType]int{
|
||||||
|
TokenOp: StatementPrimeState,
|
||||||
|
TokenLit: ValueState,
|
||||||
|
TokenSep: OpenScopeState,
|
||||||
|
TokenWS: ValueState,
|
||||||
|
TokenNL: SkipState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: MarkCompleteState,
|
||||||
|
},
|
||||||
|
ASTKindEqualExpr: map[TokenType]int{
|
||||||
|
TokenLit: ValueState,
|
||||||
|
TokenWS: SkipTokenState,
|
||||||
|
TokenNL: SkipState,
|
||||||
|
},
|
||||||
|
ASTKindStatement: map[TokenType]int{
|
||||||
|
TokenLit: SectionState,
|
||||||
|
TokenSep: CloseScopeState,
|
||||||
|
TokenWS: SkipTokenState,
|
||||||
|
TokenNL: SkipTokenState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: MarkCompleteState,
|
||||||
|
},
|
||||||
|
ASTKindExprStatement: map[TokenType]int{
|
||||||
|
TokenLit: ValueState,
|
||||||
|
TokenSep: OpenScopeState,
|
||||||
|
TokenOp: ValueState,
|
||||||
|
TokenWS: ValueState,
|
||||||
|
TokenNL: MarkCompleteState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: TerminalState,
|
||||||
|
TokenComma: SkipState,
|
||||||
|
},
|
||||||
|
ASTKindSectionStatement: map[TokenType]int{
|
||||||
|
TokenLit: SectionState,
|
||||||
|
TokenOp: SectionState,
|
||||||
|
TokenSep: CloseScopeState,
|
||||||
|
TokenWS: SectionState,
|
||||||
|
TokenNL: SkipTokenState,
|
||||||
|
},
|
||||||
|
ASTKindCompletedSectionStatement: map[TokenType]int{
|
||||||
|
TokenWS: SkipTokenState,
|
||||||
|
TokenNL: SkipTokenState,
|
||||||
|
TokenLit: StatementState,
|
||||||
|
TokenSep: OpenScopeState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: MarkCompleteState,
|
||||||
|
},
|
||||||
|
ASTKindSkipStatement: map[TokenType]int{
|
||||||
|
TokenLit: StatementState,
|
||||||
|
TokenSep: OpenScopeState,
|
||||||
|
TokenWS: SkipTokenState,
|
||||||
|
TokenNL: SkipTokenState,
|
||||||
|
TokenComment: CommentState,
|
||||||
|
TokenNone: TerminalState,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseAST will parse input from an io.Reader using
|
||||||
|
// an LL(1) parser.
|
||||||
|
func ParseAST(r io.Reader) ([]AST, error) {
|
||||||
|
lexer := iniLexer{}
|
||||||
|
tokens, err := lexer.Tokenize(r)
|
||||||
|
if err != nil {
|
||||||
|
return []AST{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return parse(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseASTBytes will parse input from a byte slice using
|
||||||
|
// an LL(1) parser.
|
||||||
|
func ParseASTBytes(b []byte) ([]AST, error) {
|
||||||
|
lexer := iniLexer{}
|
||||||
|
tokens, err := lexer.tokenize(b)
|
||||||
|
if err != nil {
|
||||||
|
return []AST{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return parse(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parse(tokens []Token) ([]AST, error) {
|
||||||
|
start := Start
|
||||||
|
stack := newParseStack(3, len(tokens))
|
||||||
|
|
||||||
|
stack.Push(start)
|
||||||
|
s := newSkipper()
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for stack.Len() > 0 {
|
||||||
|
k := stack.Pop()
|
||||||
|
|
||||||
|
var tok Token
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
// this occurs when all the tokens have been processed
|
||||||
|
// but reduction of what's left on the stack needs to
|
||||||
|
// occur.
|
||||||
|
tok = emptyToken
|
||||||
|
} else {
|
||||||
|
tok = tokens[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
step := parseTable[k.Kind][tok.Type()]
|
||||||
|
if s.ShouldSkip(tok) {
|
||||||
|
// being in a skip state with no tokens will break out of
|
||||||
|
// the parse loop since there is nothing left to process.
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
|
||||||
|
step = SkipTokenState
|
||||||
|
}
|
||||||
|
|
||||||
|
switch step {
|
||||||
|
case TerminalState:
|
||||||
|
// Finished parsing. Push what should be the last
|
||||||
|
// statement to the stack. If there is anything left
|
||||||
|
// on the stack, an error in parsing has occurred.
|
||||||
|
if k.Kind != ASTKindStart {
|
||||||
|
stack.MarkComplete(k)
|
||||||
|
}
|
||||||
|
break loop
|
||||||
|
case SkipTokenState:
|
||||||
|
// When skipping a token, the previous state was popped off the stack.
|
||||||
|
// To maintain the correct state, the previous state will be pushed
|
||||||
|
// onto the stack.
|
||||||
|
stack.Push(k)
|
||||||
|
case StatementState:
|
||||||
|
if k.Kind != ASTKindStart {
|
||||||
|
stack.MarkComplete(k)
|
||||||
|
}
|
||||||
|
expr := newExpression(tok)
|
||||||
|
stack.Push(expr)
|
||||||
|
case StatementPrimeState:
|
||||||
|
if tok.Type() != TokenOp {
|
||||||
|
stack.MarkComplete(k)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if k.Kind != ASTKindExpr {
|
||||||
|
return nil, NewParseError(
|
||||||
|
fmt.Sprintf("invalid expression: expected Expr type, but found %T type", k),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
k = trimSpaces(k)
|
||||||
|
expr := newEqualExpr(k, tok)
|
||||||
|
stack.Push(expr)
|
||||||
|
case ValueState:
|
||||||
|
// ValueState requires the previous state to either be an equal expression
|
||||||
|
// or an expression statement.
|
||||||
|
//
|
||||||
|
// This grammar occurs when the RHS is a number, word, or quoted string.
|
||||||
|
// equal_expr -> lit op equal_expr'
|
||||||
|
// equal_expr' -> number | string | quoted_string
|
||||||
|
// quoted_string -> " quoted_string'
|
||||||
|
// quoted_string' -> string quoted_string_end
|
||||||
|
// quoted_string_end -> "
|
||||||
|
//
|
||||||
|
// otherwise
|
||||||
|
// expr_stmt -> equal_expr (expr_stmt')*
|
||||||
|
// expr_stmt' -> ws S | op S | MarkComplete
|
||||||
|
// S -> equal_expr' expr_stmt'
|
||||||
|
switch k.Kind {
|
||||||
|
case ASTKindEqualExpr:
|
||||||
|
// assiging a value to some key
|
||||||
|
k.AppendChild(newExpression(tok))
|
||||||
|
stack.Push(newExprStatement(k))
|
||||||
|
case ASTKindExpr:
|
||||||
|
k.Root.raw = append(k.Root.raw, tok.Raw()...)
|
||||||
|
stack.Push(k)
|
||||||
|
case ASTKindExprStatement:
|
||||||
|
root := k.GetRoot()
|
||||||
|
children := root.GetChildren()
|
||||||
|
if len(children) == 0 {
|
||||||
|
return nil, NewParseError(
|
||||||
|
fmt.Sprintf("invalid expression: AST contains no children %s", k.Kind),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
rhs := children[len(children)-1]
|
||||||
|
|
||||||
|
if rhs.Root.ValueType != QuotedStringType {
|
||||||
|
rhs.Root.ValueType = StringType
|
||||||
|
rhs.Root.raw = append(rhs.Root.raw, tok.Raw()...)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
children[len(children)-1] = rhs
|
||||||
|
k.SetChildren(children)
|
||||||
|
|
||||||
|
stack.Push(k)
|
||||||
|
}
|
||||||
|
case OpenScopeState:
|
||||||
|
if !runeCompare(tok.Raw(), openBrace) {
|
||||||
|
return nil, NewParseError("expected '['")
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := newStatement()
|
||||||
|
stack.Push(stmt)
|
||||||
|
case CloseScopeState:
|
||||||
|
if !runeCompare(tok.Raw(), closeBrace) {
|
||||||
|
return nil, NewParseError("expected ']'")
|
||||||
|
}
|
||||||
|
|
||||||
|
k = trimSpaces(k)
|
||||||
|
stack.Push(newCompletedSectionStatement(k))
|
||||||
|
case SectionState:
|
||||||
|
var stmt AST
|
||||||
|
|
||||||
|
switch k.Kind {
|
||||||
|
case ASTKindStatement:
|
||||||
|
// If there are multiple literals inside of a scope declaration,
|
||||||
|
// then the current token's raw value will be appended to the Name.
|
||||||
|
//
|
||||||
|
// This handles cases like [ profile default ]
|
||||||
|
//
|
||||||
|
// k will represent a SectionStatement with the children representing
|
||||||
|
// the label of the section
|
||||||
|
stmt = newSectionStatement(tok)
|
||||||
|
case ASTKindSectionStatement:
|
||||||
|
k.Root.raw = append(k.Root.raw, tok.Raw()...)
|
||||||
|
stmt = k
|
||||||
|
default:
|
||||||
|
return nil, NewParseError(
|
||||||
|
fmt.Sprintf("invalid statement: expected statement: %v", k.Kind),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
stack.Push(stmt)
|
||||||
|
case MarkCompleteState:
|
||||||
|
if k.Kind != ASTKindStart {
|
||||||
|
stack.MarkComplete(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stack.Len() == 0 {
|
||||||
|
stack.Push(start)
|
||||||
|
}
|
||||||
|
case SkipState:
|
||||||
|
stack.Push(newSkipStatement(k))
|
||||||
|
s.Skip()
|
||||||
|
case CommentState:
|
||||||
|
if k.Kind == ASTKindStart {
|
||||||
|
stack.Push(k)
|
||||||
|
} else {
|
||||||
|
stack.MarkComplete(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := newCommentStatement(tok)
|
||||||
|
stack.Push(stmt)
|
||||||
|
default:
|
||||||
|
return nil, NewParseError(
|
||||||
|
fmt.Sprintf("invalid state with ASTKind %v and TokenType %v",
|
||||||
|
k, tok.Type()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) > 0 {
|
||||||
|
tokens = tokens[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// this occurs when a statement has not been completed
|
||||||
|
if stack.top > 1 {
|
||||||
|
return nil, NewParseError(fmt.Sprintf("incomplete ini expression"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns a sublist which excludes the start symbol
|
||||||
|
return stack.List(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// trimSpaces will trim spaces on the left and right hand side of
|
||||||
|
// the literal.
|
||||||
|
func trimSpaces(k AST) AST {
|
||||||
|
// trim left hand side of spaces
|
||||||
|
for i := 0; i < len(k.Root.raw); i++ {
|
||||||
|
if !isWhitespace(k.Root.raw[i]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
k.Root.raw = k.Root.raw[1:]
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
|
||||||
|
// trim right hand side of spaces
|
||||||
|
for i := len(k.Root.raw) - 1; i >= 0; i-- {
|
||||||
|
if !isWhitespace(k.Root.raw[i]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
k.Root.raw = k.Root.raw[:len(k.Root.raw)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return k
|
||||||
|
}
|
|
@ -0,0 +1,324 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
runesTrue = []rune("true")
|
||||||
|
runesFalse = []rune("false")
|
||||||
|
)
|
||||||
|
|
||||||
|
var literalValues = [][]rune{
|
||||||
|
runesTrue,
|
||||||
|
runesFalse,
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBoolValue(b []rune) bool {
|
||||||
|
for _, lv := range literalValues {
|
||||||
|
if isLitValue(lv, b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLitValue(want, have []rune) bool {
|
||||||
|
if len(have) < len(want) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(want); i++ {
|
||||||
|
if want[i] != have[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNumberValue will return whether not the leading characters in
|
||||||
|
// a byte slice is a number. A number is delimited by whitespace or
|
||||||
|
// the newline token.
|
||||||
|
//
|
||||||
|
// A number is defined to be in a binary, octal, decimal (int | float), hex format,
|
||||||
|
// or in scientific notation.
|
||||||
|
func isNumberValue(b []rune) bool {
|
||||||
|
negativeIndex := 0
|
||||||
|
helper := numberHelper{}
|
||||||
|
needDigit := false
|
||||||
|
|
||||||
|
for i := 0; i < len(b); i++ {
|
||||||
|
negativeIndex++
|
||||||
|
|
||||||
|
switch b[i] {
|
||||||
|
case '-':
|
||||||
|
if helper.IsNegative() || negativeIndex != 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
helper.Determine(b[i])
|
||||||
|
needDigit = true
|
||||||
|
continue
|
||||||
|
case 'e', 'E':
|
||||||
|
if err := helper.Determine(b[i]); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
negativeIndex = 0
|
||||||
|
needDigit = true
|
||||||
|
continue
|
||||||
|
case 'b':
|
||||||
|
if helper.numberFormat == hex {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case 'o', 'x':
|
||||||
|
needDigit = true
|
||||||
|
if i == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
fallthrough
|
||||||
|
case '.':
|
||||||
|
if err := helper.Determine(b[i]); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
needDigit = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if i > 0 && (isNewline(b[i:]) || isWhitespace(b[i])) {
|
||||||
|
return !needDigit
|
||||||
|
}
|
||||||
|
|
||||||
|
if !helper.CorrectByte(b[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
needDigit = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return !needDigit
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValid(b []rune) (bool, int, error) {
|
||||||
|
if len(b) == 0 {
|
||||||
|
// TODO: should probably return an error
|
||||||
|
return false, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return isValidRune(b[0]), 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidRune(r rune) bool {
|
||||||
|
return r != ':' && r != '=' && r != '[' && r != ']' && r != ' ' && r != '\n'
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueType is an enum that will signify what type
|
||||||
|
// the Value is
|
||||||
|
type ValueType int
|
||||||
|
|
||||||
|
func (v ValueType) String() string {
|
||||||
|
switch v {
|
||||||
|
case NoneType:
|
||||||
|
return "NONE"
|
||||||
|
case DecimalType:
|
||||||
|
return "FLOAT"
|
||||||
|
case IntegerType:
|
||||||
|
return "INT"
|
||||||
|
case StringType:
|
||||||
|
return "STRING"
|
||||||
|
case BoolType:
|
||||||
|
return "BOOL"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueType enums
|
||||||
|
const (
|
||||||
|
NoneType = ValueType(iota)
|
||||||
|
DecimalType
|
||||||
|
IntegerType
|
||||||
|
StringType
|
||||||
|
QuotedStringType
|
||||||
|
BoolType
|
||||||
|
)
|
||||||
|
|
||||||
|
// Value is a union container
|
||||||
|
type Value struct {
|
||||||
|
Type ValueType
|
||||||
|
raw []rune
|
||||||
|
|
||||||
|
integer int64
|
||||||
|
decimal float64
|
||||||
|
boolean bool
|
||||||
|
str string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newValue(t ValueType, base int, raw []rune) (Value, error) {
|
||||||
|
v := Value{
|
||||||
|
Type: t,
|
||||||
|
raw: raw,
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case DecimalType:
|
||||||
|
v.decimal, err = strconv.ParseFloat(string(raw), 64)
|
||||||
|
case IntegerType:
|
||||||
|
if base != 10 {
|
||||||
|
raw = raw[2:]
|
||||||
|
}
|
||||||
|
|
||||||
|
v.integer, err = strconv.ParseInt(string(raw), base, 64)
|
||||||
|
case StringType:
|
||||||
|
v.str = string(raw)
|
||||||
|
case QuotedStringType:
|
||||||
|
v.str = string(raw[1 : len(raw)-1])
|
||||||
|
case BoolType:
|
||||||
|
v.boolean = runeCompare(v.raw, runesTrue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue 2253
|
||||||
|
//
|
||||||
|
// if the value trying to be parsed is too large, then we will use
|
||||||
|
// the 'StringType' and raw value instead.
|
||||||
|
if nerr, ok := err.(*strconv.NumError); ok && nerr.Err == strconv.ErrRange {
|
||||||
|
v.Type = StringType
|
||||||
|
v.str = string(raw)
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append will append values and change the type to a string
|
||||||
|
// type.
|
||||||
|
func (v *Value) Append(tok Token) {
|
||||||
|
r := tok.Raw()
|
||||||
|
if v.Type != QuotedStringType {
|
||||||
|
v.Type = StringType
|
||||||
|
r = tok.raw[1 : len(tok.raw)-1]
|
||||||
|
}
|
||||||
|
if tok.Type() != TokenLit {
|
||||||
|
v.raw = append(v.raw, tok.Raw()...)
|
||||||
|
} else {
|
||||||
|
v.raw = append(v.raw, r...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v Value) String() string {
|
||||||
|
switch v.Type {
|
||||||
|
case DecimalType:
|
||||||
|
return fmt.Sprintf("decimal: %f", v.decimal)
|
||||||
|
case IntegerType:
|
||||||
|
return fmt.Sprintf("integer: %d", v.integer)
|
||||||
|
case StringType:
|
||||||
|
return fmt.Sprintf("string: %s", string(v.raw))
|
||||||
|
case QuotedStringType:
|
||||||
|
return fmt.Sprintf("quoted string: %s", string(v.raw))
|
||||||
|
case BoolType:
|
||||||
|
return fmt.Sprintf("bool: %t", v.boolean)
|
||||||
|
default:
|
||||||
|
return "union not set"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLitToken(b []rune) (Token, int, error) {
|
||||||
|
n := 0
|
||||||
|
var err error
|
||||||
|
|
||||||
|
token := Token{}
|
||||||
|
if b[0] == '"' {
|
||||||
|
n, err = getStringValue(b)
|
||||||
|
if err != nil {
|
||||||
|
return token, n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
token = newToken(TokenLit, b[:n], QuotedStringType)
|
||||||
|
} else if isNumberValue(b) {
|
||||||
|
var base int
|
||||||
|
base, n, err = getNumericalValue(b)
|
||||||
|
if err != nil {
|
||||||
|
return token, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
value := b[:n]
|
||||||
|
vType := IntegerType
|
||||||
|
if contains(value, '.') || hasExponent(value) {
|
||||||
|
vType = DecimalType
|
||||||
|
}
|
||||||
|
token = newToken(TokenLit, value, vType)
|
||||||
|
token.base = base
|
||||||
|
} else if isBoolValue(b) {
|
||||||
|
n, err = getBoolValue(b)
|
||||||
|
|
||||||
|
token = newToken(TokenLit, b[:n], BoolType)
|
||||||
|
} else {
|
||||||
|
n, err = getValue(b)
|
||||||
|
token = newToken(TokenLit, b[:n], StringType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntValue returns an integer value
|
||||||
|
func (v Value) IntValue() int64 {
|
||||||
|
return v.integer
|
||||||
|
}
|
||||||
|
|
||||||
|
// FloatValue returns a float value
|
||||||
|
func (v Value) FloatValue() float64 {
|
||||||
|
return v.decimal
|
||||||
|
}
|
||||||
|
|
||||||
|
// BoolValue returns a bool value
|
||||||
|
func (v Value) BoolValue() bool {
|
||||||
|
return v.boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTrimmable(r rune) bool {
|
||||||
|
switch r {
|
||||||
|
case '\n', ' ':
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringValue returns the string value
|
||||||
|
func (v Value) StringValue() string {
|
||||||
|
switch v.Type {
|
||||||
|
case StringType:
|
||||||
|
return strings.TrimFunc(string(v.raw), isTrimmable)
|
||||||
|
case QuotedStringType:
|
||||||
|
// preserve all characters in the quotes
|
||||||
|
return string(removeEscapedCharacters(v.raw[1 : len(v.raw)-1]))
|
||||||
|
default:
|
||||||
|
return strings.TrimFunc(string(v.raw), isTrimmable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(runes []rune, c rune) bool {
|
||||||
|
for i := 0; i < len(runes); i++ {
|
||||||
|
if runes[i] == c {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func runeCompare(v1 []rune, v2 []rune) bool {
|
||||||
|
if len(v1) != len(v2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(v1); i++ {
|
||||||
|
if v1[i] != v2[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
func isNewline(b []rune) bool {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if b[0] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return b[0] == '\r' && b[1] == '\n'
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNewlineToken(b []rune) (Token, int, error) {
|
||||||
|
i := 1
|
||||||
|
if b[0] == '\r' && isNewline(b[1:]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isNewline([]rune(b[:i])) {
|
||||||
|
return emptyToken, 0, NewParseError("invalid new line token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return newToken(TokenNL, b[:i], NoneType), i, nil
|
||||||
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
none = numberFormat(iota)
|
||||||
|
binary
|
||||||
|
octal
|
||||||
|
decimal
|
||||||
|
hex
|
||||||
|
exponent
|
||||||
|
)
|
||||||
|
|
||||||
|
type numberFormat int
|
||||||
|
|
||||||
|
// numberHelper is used to dictate what format a number is in
|
||||||
|
// and what to do for negative values. Since -1e-4 is a valid
|
||||||
|
// number, we cannot just simply check for duplicate negatives.
|
||||||
|
type numberHelper struct {
|
||||||
|
numberFormat numberFormat
|
||||||
|
|
||||||
|
negative bool
|
||||||
|
negativeExponent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b numberHelper) Exists() bool {
|
||||||
|
return b.numberFormat != none
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b numberHelper) IsNegative() bool {
|
||||||
|
return b.negative || b.negativeExponent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *numberHelper) Determine(c rune) error {
|
||||||
|
if b.Exists() {
|
||||||
|
return NewParseError(fmt.Sprintf("multiple number formats: 0%v", string(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c {
|
||||||
|
case 'b':
|
||||||
|
b.numberFormat = binary
|
||||||
|
case 'o':
|
||||||
|
b.numberFormat = octal
|
||||||
|
case 'x':
|
||||||
|
b.numberFormat = hex
|
||||||
|
case 'e', 'E':
|
||||||
|
b.numberFormat = exponent
|
||||||
|
case '-':
|
||||||
|
if b.numberFormat != exponent {
|
||||||
|
b.negative = true
|
||||||
|
} else {
|
||||||
|
b.negativeExponent = true
|
||||||
|
}
|
||||||
|
case '.':
|
||||||
|
b.numberFormat = decimal
|
||||||
|
default:
|
||||||
|
return NewParseError(fmt.Sprintf("invalid number character: %v", string(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b numberHelper) CorrectByte(c rune) bool {
|
||||||
|
switch {
|
||||||
|
case b.numberFormat == binary:
|
||||||
|
if !isBinaryByte(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case b.numberFormat == octal:
|
||||||
|
if !isOctalByte(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case b.numberFormat == hex:
|
||||||
|
if !isHexByte(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case b.numberFormat == decimal:
|
||||||
|
if !isDigit(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case b.numberFormat == exponent:
|
||||||
|
if !isDigit(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case b.negativeExponent:
|
||||||
|
if !isDigit(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case b.negative:
|
||||||
|
if !isDigit(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !isDigit(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b numberHelper) Base() int {
|
||||||
|
switch b.numberFormat {
|
||||||
|
case binary:
|
||||||
|
return 2
|
||||||
|
case octal:
|
||||||
|
return 8
|
||||||
|
case hex:
|
||||||
|
return 16
|
||||||
|
default:
|
||||||
|
return 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b numberHelper) String() string {
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
switch b.numberFormat {
|
||||||
|
case binary:
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": binary format\n")
|
||||||
|
case octal:
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": octal format\n")
|
||||||
|
case hex:
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": hex format\n")
|
||||||
|
case exponent:
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": exponent format\n")
|
||||||
|
default:
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": integer format\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.negative {
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": negative format\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.negativeExponent {
|
||||||
|
i++
|
||||||
|
buf.WriteString(strconv.Itoa(i) + ": negative exponent format\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
|
@ -0,0 +1,39 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
equalOp = []rune("=")
|
||||||
|
equalColonOp = []rune(":")
|
||||||
|
)
|
||||||
|
|
||||||
|
func isOp(b []rune) bool {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch b[0] {
|
||||||
|
case '=':
|
||||||
|
return true
|
||||||
|
case ':':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpToken(b []rune) (Token, int, error) {
|
||||||
|
tok := Token{}
|
||||||
|
|
||||||
|
switch b[0] {
|
||||||
|
case '=':
|
||||||
|
tok = newToken(TokenOp, equalOp, NoneType)
|
||||||
|
case ':':
|
||||||
|
tok = newToken(TokenOp, equalColonOp, NoneType)
|
||||||
|
default:
|
||||||
|
return tok, 0, NewParseError(fmt.Sprintf("unexpected op type, %v", b[0]))
|
||||||
|
}
|
||||||
|
return tok, 1, nil
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ErrCodeParseError is returned when a parsing error
|
||||||
|
// has occurred.
|
||||||
|
ErrCodeParseError = "INIParseError"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseError is an error which is returned during any part of
|
||||||
|
// the parsing process.
|
||||||
|
type ParseError struct {
|
||||||
|
msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewParseError will return a new ParseError where message
|
||||||
|
// is the description of the error.
|
||||||
|
func NewParseError(message string) *ParseError {
|
||||||
|
return &ParseError{
|
||||||
|
msg: message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code will return the ErrCodeParseError
|
||||||
|
func (err *ParseError) Code() string {
|
||||||
|
return ErrCodeParseError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message returns the error's message
|
||||||
|
func (err *ParseError) Message() string {
|
||||||
|
return err.msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrigError return nothing since there will never be any
|
||||||
|
// original error.
|
||||||
|
func (err *ParseError) OrigError() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err *ParseError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %s", err.Code(), err.Message())
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseStack is a stack that contains a container, the stack portion,
|
||||||
|
// and the list which is the list of ASTs that have been successfully
|
||||||
|
// parsed.
|
||||||
|
type ParseStack struct {
|
||||||
|
top int
|
||||||
|
container []AST
|
||||||
|
list []AST
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newParseStack(sizeContainer, sizeList int) ParseStack {
|
||||||
|
return ParseStack{
|
||||||
|
container: make([]AST, sizeContainer),
|
||||||
|
list: make([]AST, sizeList),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pop will return and truncate the last container element.
|
||||||
|
func (s *ParseStack) Pop() AST {
|
||||||
|
s.top--
|
||||||
|
return s.container[s.top]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push will add the new AST to the container
|
||||||
|
func (s *ParseStack) Push(ast AST) {
|
||||||
|
s.container[s.top] = ast
|
||||||
|
s.top++
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkComplete will append the AST to the list of completed statements
|
||||||
|
func (s *ParseStack) MarkComplete(ast AST) {
|
||||||
|
s.list[s.index] = ast
|
||||||
|
s.index++
|
||||||
|
}
|
||||||
|
|
||||||
|
// List will return the completed statements
|
||||||
|
func (s ParseStack) List() []AST {
|
||||||
|
return s.list[:s.index]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len will return the length of the container
|
||||||
|
func (s *ParseStack) Len() int {
|
||||||
|
return s.top
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ParseStack) String() string {
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
for i, node := range s.list {
|
||||||
|
buf.WriteString(fmt.Sprintf("%d: %v\n", i+1, node))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
emptyRunes = []rune{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func isSep(b []rune) bool {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch b[0] {
|
||||||
|
case '[', ']':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
openBrace = []rune("[")
|
||||||
|
closeBrace = []rune("]")
|
||||||
|
)
|
||||||
|
|
||||||
|
func newSepToken(b []rune) (Token, int, error) {
|
||||||
|
tok := Token{}
|
||||||
|
|
||||||
|
switch b[0] {
|
||||||
|
case '[':
|
||||||
|
tok = newToken(TokenSep, openBrace, NoneType)
|
||||||
|
case ']':
|
||||||
|
tok = newToken(TokenSep, closeBrace, NoneType)
|
||||||
|
default:
|
||||||
|
return tok, 0, NewParseError(fmt.Sprintf("unexpected sep type, %v", b[0]))
|
||||||
|
}
|
||||||
|
return tok, 1, nil
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// skipper is used to skip certain blocks of an ini file.
|
||||||
|
// Currently skipper is used to skip nested blocks of ini
|
||||||
|
// files. See example below
|
||||||
|
//
|
||||||
|
// [ foo ]
|
||||||
|
// nested = ; this section will be skipped
|
||||||
|
// a=b
|
||||||
|
// c=d
|
||||||
|
// bar=baz ; this will be included
|
||||||
|
type skipper struct {
|
||||||
|
shouldSkip bool
|
||||||
|
TokenSet bool
|
||||||
|
prevTok Token
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSkipper() skipper {
|
||||||
|
return skipper{
|
||||||
|
prevTok: emptyToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *skipper) ShouldSkip(tok Token) bool {
|
||||||
|
if s.shouldSkip &&
|
||||||
|
s.prevTok.Type() == TokenNL &&
|
||||||
|
tok.Type() != TokenWS {
|
||||||
|
|
||||||
|
s.Continue()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.prevTok = tok
|
||||||
|
|
||||||
|
return s.shouldSkip
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *skipper) Skip() {
|
||||||
|
s.shouldSkip = true
|
||||||
|
s.prevTok = emptyToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *skipper) Continue() {
|
||||||
|
s.shouldSkip = false
|
||||||
|
s.prevTok = emptyToken
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// Statement is an empty AST mostly used for transitioning states.
|
||||||
|
func newStatement() AST {
|
||||||
|
return newAST(ASTKindStatement, AST{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SectionStatement represents a section AST
|
||||||
|
func newSectionStatement(tok Token) AST {
|
||||||
|
return newASTWithRootToken(ASTKindSectionStatement, tok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExprStatement represents a completed expression AST
|
||||||
|
func newExprStatement(ast AST) AST {
|
||||||
|
return newAST(ASTKindExprStatement, ast)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommentStatement represents a comment in the ini definition.
|
||||||
|
//
|
||||||
|
// grammar:
|
||||||
|
// comment -> #comment' | ;comment'
|
||||||
|
// comment' -> epsilon | value
|
||||||
|
func newCommentStatement(tok Token) AST {
|
||||||
|
return newAST(ASTKindCommentStatement, newExpression(tok))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletedSectionStatement represents a completed section
|
||||||
|
func newCompletedSectionStatement(ast AST) AST {
|
||||||
|
return newAST(ASTKindCompletedSectionStatement, ast)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipStatement is used to skip whole statements
|
||||||
|
func newSkipStatement(ast AST) AST {
|
||||||
|
return newAST(ASTKindSkipStatement, ast)
|
||||||
|
}
|
|
@ -0,0 +1,284 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getStringValue will return a quoted string and the amount
|
||||||
|
// of bytes read
|
||||||
|
//
|
||||||
|
// an error will be returned if the string is not properly formatted
|
||||||
|
func getStringValue(b []rune) (int, error) {
|
||||||
|
if b[0] != '"' {
|
||||||
|
return 0, NewParseError("strings must start with '\"'")
|
||||||
|
}
|
||||||
|
|
||||||
|
endQuote := false
|
||||||
|
i := 1
|
||||||
|
|
||||||
|
for ; i < len(b) && !endQuote; i++ {
|
||||||
|
if escaped := isEscaped(b[:i], b[i]); b[i] == '"' && !escaped {
|
||||||
|
endQuote = true
|
||||||
|
break
|
||||||
|
} else if escaped {
|
||||||
|
/*c, err := getEscapedByte(b[i])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b[i-1] = c
|
||||||
|
b = append(b[:i], b[i+1:]...)
|
||||||
|
i--*/
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !endQuote {
|
||||||
|
return 0, NewParseError("missing '\"' in string value")
|
||||||
|
}
|
||||||
|
|
||||||
|
return i + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBoolValue will return a boolean and the amount
|
||||||
|
// of bytes read
|
||||||
|
//
|
||||||
|
// an error will be returned if the boolean is not of a correct
|
||||||
|
// value
|
||||||
|
func getBoolValue(b []rune) (int, error) {
|
||||||
|
if len(b) < 4 {
|
||||||
|
return 0, NewParseError("invalid boolean value")
|
||||||
|
}
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
for _, lv := range literalValues {
|
||||||
|
if len(lv) > len(b) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLitValue(lv, b) {
|
||||||
|
n = len(lv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return 0, NewParseError("invalid boolean value")
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNumericalValue will return a numerical string, the amount
|
||||||
|
// of bytes read, and the base of the number
|
||||||
|
//
|
||||||
|
// an error will be returned if the number is not of a correct
|
||||||
|
// value
|
||||||
|
func getNumericalValue(b []rune) (int, int, error) {
|
||||||
|
if !isDigit(b[0]) {
|
||||||
|
return 0, 0, NewParseError("invalid digit value")
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
helper := numberHelper{}
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for negativeIndex := 0; i < len(b); i++ {
|
||||||
|
negativeIndex++
|
||||||
|
|
||||||
|
if !isDigit(b[i]) {
|
||||||
|
switch b[i] {
|
||||||
|
case '-':
|
||||||
|
if helper.IsNegative() || negativeIndex != 1 {
|
||||||
|
return 0, 0, NewParseError("parse error '-'")
|
||||||
|
}
|
||||||
|
|
||||||
|
n := getNegativeNumber(b[i:])
|
||||||
|
i += (n - 1)
|
||||||
|
helper.Determine(b[i])
|
||||||
|
continue
|
||||||
|
case '.':
|
||||||
|
if err := helper.Determine(b[i]); err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
case 'e', 'E':
|
||||||
|
if err := helper.Determine(b[i]); err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
negativeIndex = 0
|
||||||
|
case 'b':
|
||||||
|
if helper.numberFormat == hex {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case 'o', 'x':
|
||||||
|
if i == 0 && b[i] != '0' {
|
||||||
|
return 0, 0, NewParseError("incorrect base format, expected leading '0'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if i != 1 {
|
||||||
|
return 0, 0, NewParseError(fmt.Sprintf("incorrect base format found %s at %d index", string(b[i]), i))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := helper.Determine(b[i]); err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if isWhitespace(b[i]) {
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNewline(b[i:]) {
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
|
||||||
|
if !(helper.numberFormat == hex && isHexByte(b[i])) {
|
||||||
|
if i+2 < len(b) && !isNewline(b[i:i+2]) {
|
||||||
|
return 0, 0, NewParseError("invalid numerical character")
|
||||||
|
} else if !isNewline([]rune{b[i]}) {
|
||||||
|
return 0, 0, NewParseError("invalid numerical character")
|
||||||
|
}
|
||||||
|
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return helper.Base(), i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDigit will return whether or not something is an integer
|
||||||
|
func isDigit(b rune) bool {
|
||||||
|
return b >= '0' && b <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasExponent(v []rune) bool {
|
||||||
|
return contains(v, 'e') || contains(v, 'E')
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBinaryByte(b rune) bool {
|
||||||
|
switch b {
|
||||||
|
case '0', '1':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOctalByte(b rune) bool {
|
||||||
|
switch b {
|
||||||
|
case '0', '1', '2', '3', '4', '5', '6', '7':
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHexByte(b rune) bool {
|
||||||
|
if isDigit(b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return (b >= 'A' && b <= 'F') ||
|
||||||
|
(b >= 'a' && b <= 'f')
|
||||||
|
}
|
||||||
|
|
||||||
|
func getValue(b []rune) (int, error) {
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
for i < len(b) {
|
||||||
|
if isNewline(b[i:]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if isOp(b[i:]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, n, err := isValid(b[i:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
i += n
|
||||||
|
}
|
||||||
|
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNegativeNumber will return a negative number from a
|
||||||
|
// byte slice. This will iterate through all characters until
|
||||||
|
// a non-digit has been found.
|
||||||
|
func getNegativeNumber(b []rune) int {
|
||||||
|
if b[0] != '-' {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 1
|
||||||
|
for ; i < len(b); i++ {
|
||||||
|
if !isDigit(b[i]) {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEscaped will return whether or not the character is an escaped
|
||||||
|
// character.
|
||||||
|
func isEscaped(value []rune, b rune) bool {
|
||||||
|
if len(value) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch b {
|
||||||
|
case '\'': // single quote
|
||||||
|
case '"': // quote
|
||||||
|
case 'n': // newline
|
||||||
|
case 't': // tab
|
||||||
|
case '\\': // backslash
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return value[len(value)-1] == '\\'
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEscapedByte(b rune) (rune, error) {
|
||||||
|
switch b {
|
||||||
|
case '\'': // single quote
|
||||||
|
return '\'', nil
|
||||||
|
case '"': // quote
|
||||||
|
return '"', nil
|
||||||
|
case 'n': // newline
|
||||||
|
return '\n', nil
|
||||||
|
case 't': // table
|
||||||
|
return '\t', nil
|
||||||
|
case '\\': // backslash
|
||||||
|
return '\\', nil
|
||||||
|
default:
|
||||||
|
return b, NewParseError(fmt.Sprintf("invalid escaped character %c", b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeEscapedCharacters(b []rune) []rune {
|
||||||
|
for i := 0; i < len(b); i++ {
|
||||||
|
if isEscaped(b[:i], b[i]) {
|
||||||
|
c, err := getEscapedByte(b[i])
|
||||||
|
if err != nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b[i-1] = c
|
||||||
|
b = append(b[:i], b[i+1:]...)
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
|
@ -0,0 +1,166 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Visitor is an interface used by walkers that will
|
||||||
|
// traverse an array of ASTs.
|
||||||
|
type Visitor interface {
|
||||||
|
VisitExpr(AST) error
|
||||||
|
VisitStatement(AST) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultVisitor is used to visit statements and expressions
|
||||||
|
// and ensure that they are both of the correct format.
|
||||||
|
// In addition, upon visiting this will build sections and populate
|
||||||
|
// the Sections field which can be used to retrieve profile
|
||||||
|
// configuration.
|
||||||
|
type DefaultVisitor struct {
|
||||||
|
scope string
|
||||||
|
Sections Sections
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultVisitor return a DefaultVisitor
|
||||||
|
func NewDefaultVisitor() *DefaultVisitor {
|
||||||
|
return &DefaultVisitor{
|
||||||
|
Sections: Sections{
|
||||||
|
container: map[string]Section{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitExpr visits expressions...
|
||||||
|
func (v *DefaultVisitor) VisitExpr(expr AST) error {
|
||||||
|
t := v.Sections.container[v.scope]
|
||||||
|
if t.values == nil {
|
||||||
|
t.values = values{}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch expr.Kind {
|
||||||
|
case ASTKindExprStatement:
|
||||||
|
opExpr := expr.GetRoot()
|
||||||
|
switch opExpr.Kind {
|
||||||
|
case ASTKindEqualExpr:
|
||||||
|
children := opExpr.GetChildren()
|
||||||
|
if len(children) <= 1 {
|
||||||
|
return NewParseError("unexpected token type")
|
||||||
|
}
|
||||||
|
|
||||||
|
rhs := children[1]
|
||||||
|
|
||||||
|
if rhs.Root.Type() != TokenLit {
|
||||||
|
return NewParseError("unexpected token type")
|
||||||
|
}
|
||||||
|
|
||||||
|
key := EqualExprKey(opExpr)
|
||||||
|
v, err := newValue(rhs.Root.ValueType, rhs.Root.base, rhs.Root.Raw())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.values[key] = v
|
||||||
|
default:
|
||||||
|
return NewParseError(fmt.Sprintf("unsupported expression %v", expr))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return NewParseError(fmt.Sprintf("unsupported expression %v", expr))
|
||||||
|
}
|
||||||
|
|
||||||
|
v.Sections.container[v.scope] = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisitStatement visits statements...
|
||||||
|
func (v *DefaultVisitor) VisitStatement(stmt AST) error {
|
||||||
|
switch stmt.Kind {
|
||||||
|
case ASTKindCompletedSectionStatement:
|
||||||
|
child := stmt.GetRoot()
|
||||||
|
if child.Kind != ASTKindSectionStatement {
|
||||||
|
return NewParseError(fmt.Sprintf("unsupported child statement: %T", child))
|
||||||
|
}
|
||||||
|
|
||||||
|
name := string(child.Root.Raw())
|
||||||
|
v.Sections.container[name] = Section{}
|
||||||
|
v.scope = name
|
||||||
|
default:
|
||||||
|
return NewParseError(fmt.Sprintf("unsupported statement: %s", stmt.Kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sections is a map of Section structures that represent
|
||||||
|
// a configuration.
|
||||||
|
type Sections struct {
|
||||||
|
container map[string]Section
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSection will return section p. If section p does not exist,
|
||||||
|
// false will be returned in the second parameter.
|
||||||
|
func (t Sections) GetSection(p string) (Section, bool) {
|
||||||
|
v, ok := t.container[p]
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// values represents a map of union values.
|
||||||
|
type values map[string]Value
|
||||||
|
|
||||||
|
// List will return a list of all sections that were successfully
|
||||||
|
// parsed.
|
||||||
|
func (t Sections) List() []string {
|
||||||
|
keys := make([]string, len(t.container))
|
||||||
|
i := 0
|
||||||
|
for k := range t.container {
|
||||||
|
keys[i] = k
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(keys)
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// Section contains a name and values. This represent
|
||||||
|
// a sectioned entry in a configuration file.
|
||||||
|
type Section struct {
|
||||||
|
Name string
|
||||||
|
values values
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has will return whether or not an entry exists in a given section
|
||||||
|
func (t Section) Has(k string) bool {
|
||||||
|
_, ok := t.values[k]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValueType will returned what type the union is set to. If
|
||||||
|
// k was not found, the NoneType will be returned.
|
||||||
|
func (t Section) ValueType(k string) (ValueType, bool) {
|
||||||
|
v, ok := t.values[k]
|
||||||
|
return v.Type, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bool returns a bool value at k
|
||||||
|
func (t Section) Bool(k string) bool {
|
||||||
|
return t.values[k].BoolValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int returns an integer value at k
|
||||||
|
func (t Section) Int(k string) int64 {
|
||||||
|
return t.values[k].IntValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float64 returns a float value at k
|
||||||
|
func (t Section) Float64(k string) float64 {
|
||||||
|
return t.values[k].FloatValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string value at k
|
||||||
|
func (t Section) String(k string) string {
|
||||||
|
_, ok := t.values[k]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return t.values[k].StringValue()
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
// Walk will traverse the AST using the v, the Visitor.
|
||||||
|
func Walk(tree []AST, v Visitor) error {
|
||||||
|
for _, node := range tree {
|
||||||
|
switch node.Kind {
|
||||||
|
case ASTKindExpr,
|
||||||
|
ASTKindExprStatement:
|
||||||
|
|
||||||
|
if err := v.VisitExpr(node); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case ASTKindStatement,
|
||||||
|
ASTKindCompletedSectionStatement,
|
||||||
|
ASTKindNestedSectionStatement,
|
||||||
|
ASTKindCompletedNestedSectionStatement:
|
||||||
|
|
||||||
|
if err := v.VisitStatement(node); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package ini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// isWhitespace will return whether or not the character is
|
||||||
|
// a whitespace character.
|
||||||
|
//
|
||||||
|
// Whitespace is defined as a space or tab.
|
||||||
|
func isWhitespace(c rune) bool {
|
||||||
|
return unicode.IsSpace(c) && c != '\n' && c != '\r'
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWSToken(b []rune) (Token, int, error) {
|
||||||
|
i := 0
|
||||||
|
for ; i < len(b); i++ {
|
||||||
|
if !isWhitespace(b[i]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newToken(TokenWS, b[:i], NoneType), i, nil
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue