mirror of https://github.com/status-im/consul.git
Clean up Vault renew tests and shutdown
This commit is contained in:
parent
844e9ffe16
commit
9496780ab4
|
@ -2,13 +2,13 @@ package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect"
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
|
@ -27,9 +27,7 @@ type VaultProvider struct {
|
||||||
config *structs.VaultCAProviderConfig
|
config *structs.VaultCAProviderConfig
|
||||||
client *vaultapi.Client
|
client *vaultapi.Client
|
||||||
|
|
||||||
shutdown bool
|
shutdown func()
|
||||||
shutdownCh chan struct{}
|
|
||||||
shutdownLock sync.RWMutex
|
|
||||||
|
|
||||||
isPrimary bool
|
isPrimary bool
|
||||||
clusterID string
|
clusterID string
|
||||||
|
@ -38,6 +36,10 @@ type VaultProvider struct {
|
||||||
logger hclog.Logger
|
logger hclog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewVaultProvider() *VaultProvider {
|
||||||
|
return &VaultProvider{shutdown: func() {}}
|
||||||
|
}
|
||||||
|
|
||||||
func vaultTLSConfig(config *structs.VaultCAProviderConfig) *vaultapi.TLSConfig {
|
func vaultTLSConfig(config *structs.VaultCAProviderConfig) *vaultapi.TLSConfig {
|
||||||
return &vaultapi.TLSConfig{
|
return &vaultapi.TLSConfig{
|
||||||
CACert: config.CAFile,
|
CACert: config.CAFile,
|
||||||
|
@ -74,7 +76,6 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error {
|
||||||
v.isPrimary = cfg.IsPrimary
|
v.isPrimary = cfg.IsPrimary
|
||||||
v.clusterID = cfg.ClusterID
|
v.clusterID = cfg.ClusterID
|
||||||
v.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: v.clusterID})
|
v.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: v.clusterID})
|
||||||
v.shutdownCh = make(chan struct{}, 0)
|
|
||||||
|
|
||||||
// Look up the token to see if we can auto-renew its lease.
|
// Look up the token to see if we can auto-renew its lease.
|
||||||
secret, err := client.Auth().Token().Lookup(config.Token)
|
secret, err := client.Auth().Token().Lookup(config.Token)
|
||||||
|
@ -99,25 +100,28 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error {
|
||||||
LeaseDuration: secret.LeaseDuration,
|
LeaseDuration: secret.LeaseDuration,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Increment: int(token.TTL),
|
Increment: token.TTL,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error beginning Vault provider token renewal: %v", err)
|
return fmt.Errorf("Error beginning Vault provider token renewal: %v", err)
|
||||||
}
|
}
|
||||||
go v.renewToken(renewer)
|
|
||||||
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
v.shutdown = cancel
|
||||||
|
go v.renewToken(ctx, renewer)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease.
|
// renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease.
|
||||||
func (v *VaultProvider) renewToken(renewer *vaultapi.Renewer) {
|
func (v *VaultProvider) renewToken(ctx context.Context, renewer *vaultapi.Renewer) {
|
||||||
go renewer.Renew()
|
go renewer.Renew()
|
||||||
|
defer renewer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-v.shutdownCh:
|
case <-ctx.Done():
|
||||||
renewer.Stop()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
case err := <-renewer.DoneCh():
|
case err := <-renewer.DoneCh():
|
||||||
|
@ -125,6 +129,9 @@ func (v *VaultProvider) renewToken(renewer *vaultapi.Renewer) {
|
||||||
v.logger.Error(fmt.Sprintf("Error renewing token for Vault provider: %v", err))
|
v.logger.Error(fmt.Sprintf("Error renewing token for Vault provider: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Renewer routine has finished, so start it again.
|
||||||
|
go renewer.Renew()
|
||||||
|
|
||||||
case <-renewer.RenewCh():
|
case <-renewer.RenewCh():
|
||||||
v.logger.Error("Successfully renewed token for Vault provider")
|
v.logger.Error("Successfully renewed token for Vault provider")
|
||||||
}
|
}
|
||||||
|
@ -508,13 +515,7 @@ func (v *VaultProvider) Cleanup() error {
|
||||||
|
|
||||||
// Stop shuts down the token renew goroutine.
|
// Stop shuts down the token renew goroutine.
|
||||||
func (v *VaultProvider) Stop() {
|
func (v *VaultProvider) Stop() {
|
||||||
v.shutdownLock.Lock()
|
v.shutdown()
|
||||||
defer v.shutdownLock.Unlock()
|
|
||||||
|
|
||||||
if !v.shutdown && v.shutdownCh != nil {
|
|
||||||
close(v.shutdownCh)
|
|
||||||
v.shutdown = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
|
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
|
||||||
|
|
|
@ -55,14 +55,10 @@ func TestVaultCAProvider_SecondaryActiveIntermediate(t *testing.T) {
|
||||||
|
|
||||||
func TestVaultCAProvider_RenewToken(t *testing.T) {
|
func TestVaultCAProvider_RenewToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
require := require.New(t)
|
|
||||||
skipIfVaultNotPresent(t)
|
skipIfVaultNotPresent(t)
|
||||||
|
|
||||||
testVault, err := runTestVault()
|
testVault, err := runTestVault(t)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testVault.WaitUntilReady(t)
|
testVault.WaitUntilReady(t)
|
||||||
|
|
||||||
// Create a token with a short TTL to be renewed by the provider.
|
// Create a token with a short TTL to be renewed by the provider.
|
||||||
|
@ -71,26 +67,26 @@ func TestVaultCAProvider_RenewToken(t *testing.T) {
|
||||||
TTL: ttl.String(),
|
TTL: ttl.String(),
|
||||||
}
|
}
|
||||||
secret, err := testVault.client.Auth().Token().Create(tcr)
|
secret, err := testVault.client.Auth().Token().Create(tcr)
|
||||||
require.NoError(err)
|
require.NoError(t, err)
|
||||||
providerToken := secret.Auth.ClientToken
|
providerToken := secret.Auth.ClientToken
|
||||||
|
|
||||||
_, err = createVaultProvider(true, testVault.addr, providerToken, nil)
|
_, err = createVaultProvider(t, true, testVault.addr, providerToken, nil)
|
||||||
require.NoError(err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Check the last renewal time.
|
// Check the last renewal time.
|
||||||
secret, err = testVault.client.Auth().Token().Lookup(providerToken)
|
secret, err = testVault.client.Auth().Token().Lookup(providerToken)
|
||||||
require.NoError(err)
|
require.NoError(t, err)
|
||||||
firstRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
|
firstRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
|
||||||
require.NoError(err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(ttl * 2)
|
|
||||||
|
|
||||||
// Wait past the TTL and make sure the token has been renewed.
|
// Wait past the TTL and make sure the token has been renewed.
|
||||||
|
retry.Run(t, func(r *retry.R) {
|
||||||
secret, err = testVault.client.Auth().Token().Lookup(providerToken)
|
secret, err = testVault.client.Auth().Token().Lookup(providerToken)
|
||||||
require.NoError(err)
|
require.NoError(r, err)
|
||||||
lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
|
lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
|
||||||
require.NoError(err)
|
require.NoError(r, err)
|
||||||
require.Greater(lastRenewal, firstRenewal)
|
require.Greater(r, lastRenewal, firstRenewal)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVaultCAProvider_Bootstrap(t *testing.T) {
|
func TestVaultCAProvider_Bootstrap(t *testing.T) {
|
||||||
|
@ -391,14 +387,14 @@ func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) {
|
func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) {
|
||||||
testVault, err := runTestVault()
|
testVault, err := runTestVault(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
testVault.WaitUntilReady(t)
|
testVault.WaitUntilReady(t)
|
||||||
|
|
||||||
provider, err := createVaultProvider(isPrimary, testVault.addr, testVault.rootToken, rawConf)
|
provider, err := createVaultProvider(t, isPrimary, testVault.addr, testVault.rootToken, rawConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
testVault.Stop()
|
testVault.Stop()
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
|
@ -406,7 +402,7 @@ func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[strin
|
||||||
return provider, testVault
|
return provider, testVault
|
||||||
}
|
}
|
||||||
|
|
||||||
func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) {
|
func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) {
|
||||||
conf := map[string]interface{}{
|
conf := map[string]interface{}{
|
||||||
"Address": addr,
|
"Address": addr,
|
||||||
"Token": token,
|
"Token": token,
|
||||||
|
@ -419,7 +415,7 @@ func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]
|
||||||
conf[k] = v
|
conf[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
provider := &VaultProvider{}
|
provider := NewVaultProvider()
|
||||||
|
|
||||||
cfg := ProviderConfig{
|
cfg := ProviderConfig{
|
||||||
ClusterID: connect.TestClusterID,
|
ClusterID: connect.TestClusterID,
|
||||||
|
@ -438,16 +434,11 @@ func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]
|
||||||
cfg.Datacenter = "dc2"
|
cfg.Datacenter = "dc2"
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := provider.Configure(cfg); err != nil {
|
require.NoError(t, provider.Configure(cfg))
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if isPrimary {
|
if isPrimary {
|
||||||
if err := provider.GenerateRoot(); err != nil {
|
require.NoError(t, provider.GenerateRoot())
|
||||||
return nil, err
|
_, err := provider.GenerateIntermediate()
|
||||||
}
|
require.NoError(t, err)
|
||||||
if _, err := provider.GenerateIntermediate(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return provider, nil
|
return provider, nil
|
||||||
|
@ -469,7 +460,7 @@ func skipIfVaultNotPresent(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTestVault() (*testVaultServer, error) {
|
func runTestVault(t *testing.T) (*testVaultServer, error) {
|
||||||
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
|
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
|
||||||
if vaultBinaryName == "" {
|
if vaultBinaryName == "" {
|
||||||
vaultBinaryName = "vault"
|
vaultBinaryName = "vault"
|
||||||
|
@ -520,13 +511,17 @@ func runTestVault() (*testVaultServer, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &testVaultServer{
|
testVault := &testVaultServer{
|
||||||
rootToken: token,
|
rootToken: token,
|
||||||
addr: "http://" + clientAddr,
|
addr: "http://" + clientAddr,
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
client: client,
|
client: client,
|
||||||
returnPortsFn: returnPortsFn,
|
returnPortsFn: returnPortsFn,
|
||||||
}, nil
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
testVault.Stop()
|
||||||
|
})
|
||||||
|
return testVault, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type testVaultServer struct {
|
type testVaultServer struct {
|
||||||
|
|
|
@ -158,7 +158,7 @@ func (s *ConnectCA) ConfigurationSet(
|
||||||
defer func() {
|
defer func() {
|
||||||
if cleanupNewProvider {
|
if cleanupNewProvider {
|
||||||
if err := newProvider.Cleanup(); err != nil {
|
if err := newProvider.Cleanup(); err != nil {
|
||||||
s.logger.Warn("failed to clean up temporary new CA provider", "provider", newProvider)
|
s.logger.Warn("failed to clean up CA provider while handling startup failure", "provider", newProvider, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -116,7 +116,7 @@ func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, e
|
||||||
case structs.ConsulCAProvider:
|
case structs.ConsulCAProvider:
|
||||||
p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}}
|
p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}}
|
||||||
case structs.VaultCAProvider:
|
case structs.VaultCAProvider:
|
||||||
p = &ca.VaultProvider{}
|
p = ca.NewVaultProvider()
|
||||||
case structs.AWSCAProvider:
|
case structs.AWSCAProvider:
|
||||||
p = &ca.AWSProvider{}
|
p = &ca.AWSProvider{}
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in New Issue