mirror of https://github.com/status-im/consul.git
Ensure certificates retrieved through the cache get persisted with auto-config (#8409)
This commit is contained in:
parent
4f98af0724
commit
c9b66157a1
|
@ -545,7 +545,8 @@ func New(options ...AgentOption) (*Agent, error) {
|
||||||
WithNodeName(a.config.NodeName).
|
WithNodeName(a.config.NodeName).
|
||||||
WithFallback(a.autoConfigFallbackTLS).
|
WithFallback(a.autoConfigFallbackTLS).
|
||||||
WithLogger(a.logger.Named(logging.AutoConfig)).
|
WithLogger(a.logger.Named(logging.AutoConfig)).
|
||||||
WithTokens(a.tokens)
|
WithTokens(a.tokens).
|
||||||
|
WithPersistence(a.autoConfigPersist)
|
||||||
acCertMon, err := certmon.New(cmConf)
|
acCertMon, err := certmon.New(cmConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -889,9 +890,19 @@ func (a *Agent) autoEncryptInitialCertificate(ctx context.Context) (*structs.Sig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) autoConfigFallbackTLS(ctx context.Context) (*structs.SignedResponse, error) {
|
func (a *Agent) autoConfigFallbackTLS(ctx context.Context) (*structs.SignedResponse, error) {
|
||||||
|
if a.autoConf == nil {
|
||||||
|
return nil, fmt.Errorf("AutoConfig manager has not been created yet")
|
||||||
|
}
|
||||||
return a.autoConf.FallbackTLS(ctx)
|
return a.autoConf.FallbackTLS(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Agent) autoConfigPersist(resp *structs.SignedResponse) error {
|
||||||
|
if a.autoConf == nil {
|
||||||
|
return fmt.Errorf("AutoConfig manager has not been created yet")
|
||||||
|
}
|
||||||
|
return a.autoConf.RecordUpdatedCerts(resp)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Agent) listenAndServeGRPC() error {
|
func (a *Agent) listenAndServeGRPC() error {
|
||||||
if len(a.config.GRPCAddrs) < 1 {
|
if len(a.config.GRPCAddrs) < 1 {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/jsonpb"
|
||||||
"github.com/google/tcpproxy"
|
"github.com/google/tcpproxy"
|
||||||
"github.com/hashicorp/consul/agent/cache"
|
"github.com/hashicorp/consul/agent/cache"
|
||||||
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
||||||
|
@ -31,6 +32,7 @@ import (
|
||||||
"github.com/hashicorp/consul/api"
|
"github.com/hashicorp/consul/api"
|
||||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||||
"github.com/hashicorp/consul/ipaddr"
|
"github.com/hashicorp/consul/ipaddr"
|
||||||
|
"github.com/hashicorp/consul/proto/pbautoconf"
|
||||||
"github.com/hashicorp/consul/sdk/freeport"
|
"github.com/hashicorp/consul/sdk/freeport"
|
||||||
"github.com/hashicorp/consul/sdk/testutil"
|
"github.com/hashicorp/consul/sdk/testutil"
|
||||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||||
|
@ -4728,21 +4730,28 @@ func TestAutoConfig_Integration(t *testing.T) {
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
client := StartTestAgent(t, TestAgent{Name: "test-client", HCL: `
|
client := StartTestAgent(t, TestAgent{Name: "test-client",
|
||||||
bootstrap = false
|
Overrides: `
|
||||||
server = false
|
connect {
|
||||||
ca_file = "` + caFile + `"
|
test_ca_leaf_root_change_spread = "1ns"
|
||||||
verify_outgoing = true
|
}
|
||||||
verify_server_hostname = true
|
`,
|
||||||
node_name = "test-client"
|
HCL: `
|
||||||
ports {
|
bootstrap = false
|
||||||
server = ` + strconv.Itoa(srv.Config.RPCBindAddr.Port) + `
|
server = false
|
||||||
}
|
ca_file = "` + caFile + `"
|
||||||
auto_config {
|
verify_outgoing = true
|
||||||
enabled = true
|
verify_server_hostname = true
|
||||||
intro_token = "` + token + `"
|
node_name = "test-client"
|
||||||
server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"]
|
ports {
|
||||||
}`})
|
server = ` + strconv.Itoa(srv.Config.RPCBindAddr.Port) + `
|
||||||
|
}
|
||||||
|
auto_config {
|
||||||
|
enabled = true
|
||||||
|
intro_token = "` + token + `"
|
||||||
|
server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"]
|
||||||
|
}`,
|
||||||
|
})
|
||||||
|
|
||||||
defer client.Shutdown()
|
defer client.Shutdown()
|
||||||
|
|
||||||
|
@ -4782,6 +4791,21 @@ func TestAutoConfig_Integration(t *testing.T) {
|
||||||
// ensure that a new cert gets generated and pushed into the TLS configurator
|
// ensure that a new cert gets generated and pushed into the TLS configurator
|
||||||
retry.Run(t, func(r *retry.R) {
|
retry.Run(t, func(r *retry.R) {
|
||||||
require.NotEqual(r, cert1, client.Agent.tlsConfigurator.Cert())
|
require.NotEqual(r, cert1, client.Agent.tlsConfigurator.Cert())
|
||||||
|
|
||||||
|
// check that the on disk certs match expectations
|
||||||
|
data, err := ioutil.ReadFile(filepath.Join(client.DataDir, "auto-config.json"))
|
||||||
|
require.NoError(r, err)
|
||||||
|
rdr := strings.NewReader(string(data))
|
||||||
|
|
||||||
|
var resp pbautoconf.AutoConfigResponse
|
||||||
|
pbUnmarshaler := &jsonpb.Unmarshaler{
|
||||||
|
AllowUnknownFields: false,
|
||||||
|
}
|
||||||
|
require.NoError(r, pbUnmarshaler.Unmarshal(rdr, &resp), "data: %s", data)
|
||||||
|
|
||||||
|
actual, err := tls.X509KeyPair([]byte(resp.Certificate.CertPEM), []byte(resp.Certificate.PrivateKeyPEM))
|
||||||
|
require.NoError(r, err)
|
||||||
|
require.Equal(r, client.Agent.tlsConfigurator.Cert(), &actual)
|
||||||
})
|
})
|
||||||
|
|
||||||
// spot check that we now have an ACL token
|
// spot check that we now have an ACL token
|
||||||
|
|
|
@ -55,15 +55,16 @@ var (
|
||||||
// then we will need to add some locking here. I am deferring that for now
|
// then we will need to add some locking here. I am deferring that for now
|
||||||
// to help ease the review of this already large PR.
|
// to help ease the review of this already large PR.
|
||||||
type AutoConfig struct {
|
type AutoConfig struct {
|
||||||
builderOpts config.BuilderOpts
|
builderOpts config.BuilderOpts
|
||||||
logger hclog.Logger
|
logger hclog.Logger
|
||||||
directRPC DirectRPC
|
directRPC DirectRPC
|
||||||
waiter *lib.RetryWaiter
|
waiter *lib.RetryWaiter
|
||||||
overrides []config.Source
|
overrides []config.Source
|
||||||
certMonitor CertMonitor
|
certMonitor CertMonitor
|
||||||
config *config.RuntimeConfig
|
config *config.RuntimeConfig
|
||||||
autoConfigData string
|
autoConfigResponse *pbautoconf.AutoConfigResponse
|
||||||
cancel context.CancelFunc
|
autoConfigData string
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new AutoConfig object for providing automatic
|
// New creates a new AutoConfig object for providing automatic
|
||||||
|
@ -493,6 +494,8 @@ func (ac *AutoConfig) generateCSR() (csr string, key string, err error) {
|
||||||
// config data to be used during a call to ReadConfig, updating the
|
// config data to be used during a call to ReadConfig, updating the
|
||||||
// tls Configurator and prepopulating the cache.
|
// tls Configurator and prepopulating the cache.
|
||||||
func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error {
|
func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error {
|
||||||
|
ac.autoConfigResponse = resp
|
||||||
|
|
||||||
if err := ac.updateConfigFromResponse(resp); err != nil {
|
if err := ac.updateConfigFromResponse(resp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -591,3 +594,18 @@ func (ac *AutoConfig) FallbackTLS(ctx context.Context) (*structs.SignedResponse,
|
||||||
|
|
||||||
return extractSignedResponse(resp)
|
return extractSignedResponse(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ac *AutoConfig) RecordUpdatedCerts(resp *structs.SignedResponse) error {
|
||||||
|
var err error
|
||||||
|
ac.autoConfigResponse.ExtraCACertificates = resp.ManualCARoots
|
||||||
|
ac.autoConfigResponse.CARoots, err = translateCARootsToProtobuf(&resp.ConnectCARoots)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ac.autoConfigResponse.Certificate, err = translateIssuedCertToProtobuf(&resp.IssuedCert)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ac.recordResponse(ac.autoConfigResponse)
|
||||||
|
}
|
||||||
|
|
|
@ -226,3 +226,34 @@ func mapstructureTranslateToStructs(in interface{}, out interface{}) error {
|
||||||
|
|
||||||
return decoder.Decode(in)
|
return decoder.Decode(in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func translateCARootsToProtobuf(in *structs.IndexedCARoots) (*pbconnect.CARoots, error) {
|
||||||
|
var out pbconnect.CARoots
|
||||||
|
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateIssuedCertToProtobuf(in *structs.IssuedCert) (*pbconnect.IssuedCert, error) {
|
||||||
|
var out pbconnect.IssuedCert
|
||||||
|
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapstructureTranslateToProtobuf(in interface{}, out interface{}) error {
|
||||||
|
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||||
|
DecodeHook: proto.HookTimeToPBTimestamp,
|
||||||
|
Result: out,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoder.Decode(in)
|
||||||
|
}
|
||||||
|
|
|
@ -40,6 +40,7 @@ type CertMonitor struct {
|
||||||
tokens *token.Store
|
tokens *token.Store
|
||||||
leafReq cachetype.ConnectCALeafRequest
|
leafReq cachetype.ConnectCALeafRequest
|
||||||
rootsReq structs.DCSpecificRequest
|
rootsReq structs.DCSpecificRequest
|
||||||
|
persist PersistFunc
|
||||||
fallback FallbackFunc
|
fallback FallbackFunc
|
||||||
fallbackLeeway time.Duration
|
fallbackLeeway time.Duration
|
||||||
fallbackRetry time.Duration
|
fallbackRetry time.Duration
|
||||||
|
@ -66,6 +67,11 @@ type CertMonitor struct {
|
||||||
// events from the token store when the Agent
|
// events from the token store when the Agent
|
||||||
// token is updated.
|
// token is updated.
|
||||||
tokenUpdates token.Notifier
|
tokenUpdates token.Notifier
|
||||||
|
|
||||||
|
// this is used to keep a local copy of the certs
|
||||||
|
// keys and ca certs. It will be used to persist
|
||||||
|
// all of the local state at once.
|
||||||
|
certs structs.SignedResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new CertMonitor for automatically rotating
|
// New creates a new CertMonitor for automatically rotating
|
||||||
|
@ -115,6 +121,7 @@ func New(config *Config) (*CertMonitor, error) {
|
||||||
cache: config.Cache,
|
cache: config.Cache,
|
||||||
tokens: config.Tokens,
|
tokens: config.Tokens,
|
||||||
tlsConfigurator: config.TLSConfigurator,
|
tlsConfigurator: config.TLSConfigurator,
|
||||||
|
persist: config.Persist,
|
||||||
fallback: config.Fallback,
|
fallback: config.Fallback,
|
||||||
fallbackLeeway: config.FallbackLeeway,
|
fallbackLeeway: config.FallbackLeeway,
|
||||||
fallbackRetry: config.FallbackRetry,
|
fallbackRetry: config.FallbackRetry,
|
||||||
|
@ -135,6 +142,8 @@ func (m *CertMonitor) Update(certs *structs.SignedResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.certs = *certs
|
||||||
|
|
||||||
if err := m.populateCache(certs); err != nil {
|
if err := m.populateCache(certs); err != nil {
|
||||||
return fmt.Errorf("error populating cache with certificates: %w", err)
|
return fmt.Errorf("error populating cache with certificates: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -306,6 +315,8 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
|
||||||
return fmt.Errorf("invalid type for roots watch response: %T", u.Result)
|
return fmt.Errorf("invalid type for roots watch response: %T", u.Result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.certs.ConnectCARoots = *roots
|
||||||
|
|
||||||
var pems []string
|
var pems []string
|
||||||
for _, root := range roots.Roots {
|
for _, root := range roots.Roots {
|
||||||
pems = append(pems, root.RootCert)
|
pems = append(pems, root.RootCert)
|
||||||
|
@ -314,6 +325,13 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
|
||||||
if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil {
|
if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil {
|
||||||
return fmt.Errorf("failed to update Connect CA certificates: %w", err)
|
return fmt.Errorf("failed to update Connect CA certificates: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.persist != nil {
|
||||||
|
copy := m.certs
|
||||||
|
if err := m.persist(©); err != nil {
|
||||||
|
return fmt.Errorf("failed to persist certificate package: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
case leafWatchID:
|
case leafWatchID:
|
||||||
m.logger.Debug("leaf certificate watch fired - updating TLS certificate")
|
m.logger.Debug("leaf certificate watch fired - updating TLS certificate")
|
||||||
if u.Err != nil {
|
if u.Err != nil {
|
||||||
|
@ -324,9 +342,19 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result)
|
return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.certs.IssuedCert = *leaf
|
||||||
|
|
||||||
if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil {
|
if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil {
|
||||||
return fmt.Errorf("failed to update the agent leaf cert: %w", err)
|
return fmt.Errorf("failed to update the agent leaf cert: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.persist != nil {
|
||||||
|
copy := m.certs
|
||||||
|
if err := m.persist(©); err != nil {
|
||||||
|
return fmt.Errorf("failed to persist certificate package: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -380,6 +408,11 @@ func (m *CertMonitor) handleFallback(ctx context.Context) error {
|
||||||
return fmt.Errorf("error when getting new agent certificate: %w", err)
|
return fmt.Errorf("error when getting new agent certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.persist != nil {
|
||||||
|
if err := m.persist(reply); err != nil {
|
||||||
|
return fmt.Errorf("failed to persist certificate package: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return m.Update(reply)
|
return m.Update(reply)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,14 @@ func (m *mockFallback) fallback(ctx context.Context) (*structs.SignedResponse, e
|
||||||
return resp, ret.Error(1)
|
return resp, ret.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockPersist struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPersist) persist(resp *structs.SignedResponse) error {
|
||||||
|
return m.Called(resp).Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
type mockWatcher struct {
|
type mockWatcher struct {
|
||||||
ch chan<- cache.UpdateEvent
|
ch chan<- cache.UpdateEvent
|
||||||
done <-chan struct{}
|
done <-chan struct{}
|
||||||
|
@ -159,6 +167,7 @@ type testCertMonitor struct {
|
||||||
tls *tlsutil.Configurator
|
tls *tlsutil.Configurator
|
||||||
tokens *token.Store
|
tokens *token.Store
|
||||||
fallback *mockFallback
|
fallback *mockFallback
|
||||||
|
persist *mockPersist
|
||||||
|
|
||||||
extraCACerts []string
|
extraCACerts []string
|
||||||
initialCert *structs.IssuedCert
|
initialCert *structs.IssuedCert
|
||||||
|
@ -210,8 +219,10 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
|
||||||
dnsSANs := []string{"test.dev"}
|
dnsSANs := []string{"test.dev"}
|
||||||
ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)}
|
ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)}
|
||||||
|
|
||||||
// this chan should be unbuffered so we can detect when the fallback func has been called.
|
|
||||||
fallback := &mockFallback{}
|
fallback := &mockFallback{}
|
||||||
|
fallback.Test(t)
|
||||||
|
persist := &mockPersist{}
|
||||||
|
persist.Test(t)
|
||||||
|
|
||||||
mcache := newMockCache(t)
|
mcache := newMockCache(t)
|
||||||
rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1}
|
rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1}
|
||||||
|
@ -246,7 +257,8 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
|
||||||
WithDatacenter("foo").
|
WithDatacenter("foo").
|
||||||
WithNodeName("node").
|
WithNodeName("node").
|
||||||
WithFallbackLeeway(time.Nanosecond).
|
WithFallbackLeeway(time.Nanosecond).
|
||||||
WithFallbackRetry(time.Millisecond)
|
WithFallbackRetry(time.Millisecond).
|
||||||
|
WithPersistence(persist.persist)
|
||||||
|
|
||||||
monitor, err := New(cfg)
|
monitor, err := New(cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -259,6 +271,7 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
|
||||||
tls: tlsConfigurator,
|
tls: tlsConfigurator,
|
||||||
tokens: tokens,
|
tokens: tokens,
|
||||||
mcache: mcache,
|
mcache: mcache,
|
||||||
|
persist: persist,
|
||||||
fallback: fallback,
|
fallback: fallback,
|
||||||
extraCACerts: []string{manualCA.RootCert},
|
extraCACerts: []string{manualCA.RootCert},
|
||||||
initialCert: issued,
|
initialCert: issued,
|
||||||
|
@ -298,6 +311,7 @@ func (cm *testCertMonitor) initialCACerts() []string {
|
||||||
func (cm *testCertMonitor) assertExpectations(t *testing.T) {
|
func (cm *testCertMonitor) assertExpectations(t *testing.T) {
|
||||||
cm.mcache.AssertExpectations(t)
|
cm.mcache.AssertExpectations(t)
|
||||||
cm.fallback.AssertExpectations(t)
|
cm.fallback.AssertExpectations(t)
|
||||||
|
cm.persist.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertMonitor_InitialCerts(t *testing.T) {
|
func TestCertMonitor_InitialCerts(t *testing.T) {
|
||||||
|
@ -473,6 +487,13 @@ func TestCertMonitor_RootsUpdate(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cm.persist.On("persist", &structs.SignedResponse{
|
||||||
|
IssuedCert: *cm.initialCert,
|
||||||
|
ManualCARoots: cm.extraCACerts,
|
||||||
|
ConnectCARoots: secondRoots,
|
||||||
|
VerifyServerHostname: cm.verifyServerHostname,
|
||||||
|
}).Return(nil).Once()
|
||||||
|
|
||||||
// assert value of the CA certs prior to updating
|
// assert value of the CA certs prior to updating
|
||||||
require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems())
|
require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems())
|
||||||
|
|
||||||
|
@ -500,6 +521,13 @@ func TestCertMonitor_CertUpdate(t *testing.T) {
|
||||||
|
|
||||||
secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute)
|
secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute)
|
||||||
|
|
||||||
|
cm.persist.On("persist", &structs.SignedResponse{
|
||||||
|
IssuedCert: *secondCert,
|
||||||
|
ManualCARoots: cm.extraCACerts,
|
||||||
|
ConnectCARoots: *cm.initialRoots,
|
||||||
|
VerifyServerHostname: cm.verifyServerHostname,
|
||||||
|
}).Return(nil).Once()
|
||||||
|
|
||||||
// assert value of cert prior to updating the leaf
|
// assert value of cert prior to updating the leaf
|
||||||
require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert())
|
require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert())
|
||||||
|
|
||||||
|
@ -549,13 +577,23 @@ func TestCertMonitor_Fallback(t *testing.T) {
|
||||||
// inject a fallback routine error to check that we rerun it quickly
|
// inject a fallback routine error to check that we rerun it quickly
|
||||||
cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once()
|
cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once()
|
||||||
|
|
||||||
// expect the fallback routine to be executed and setup the return
|
fallbackResp := &structs.SignedResponse{
|
||||||
cm.fallback.On("fallback").Return(&structs.SignedResponse{
|
|
||||||
ConnectCARoots: secondRoots,
|
ConnectCARoots: secondRoots,
|
||||||
IssuedCert: *thirdCert,
|
IssuedCert: *thirdCert,
|
||||||
ManualCARoots: cm.extraCACerts,
|
ManualCARoots: cm.extraCACerts,
|
||||||
VerifyServerHostname: true,
|
VerifyServerHostname: true,
|
||||||
}, nil).Once()
|
}
|
||||||
|
// expect the fallback routine to be executed and setup the return
|
||||||
|
cm.fallback.On("fallback").Return(fallbackResp, nil).Once()
|
||||||
|
|
||||||
|
cm.persist.On("persist", &structs.SignedResponse{
|
||||||
|
IssuedCert: *secondCert,
|
||||||
|
ConnectCARoots: *cm.initialRoots,
|
||||||
|
ManualCARoots: cm.extraCACerts,
|
||||||
|
VerifyServerHostname: cm.verifyServerHostname,
|
||||||
|
}).Return(nil).Once()
|
||||||
|
|
||||||
|
cm.persist.On("persist", fallbackResp).Return(nil).Once()
|
||||||
|
|
||||||
// Add another roots cache prepopulation expectation which should happen
|
// Add another roots cache prepopulation expectation which should happen
|
||||||
// in response to executing the fallback mechanism
|
// in response to executing the fallback mechanism
|
||||||
|
|
|
@ -16,6 +16,9 @@ import (
|
||||||
// method of updating the certificate is required.
|
// method of updating the certificate is required.
|
||||||
type FallbackFunc func(context.Context) (*structs.SignedResponse, error)
|
type FallbackFunc func(context.Context) (*structs.SignedResponse, error)
|
||||||
|
|
||||||
|
// PersistFunc is used to persist the data from a signed response
|
||||||
|
type PersistFunc func(*structs.SignedResponse) error
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// Logger is the logger to be used while running. If not set
|
// Logger is the logger to be used while running. If not set
|
||||||
// then no logging will be performed.
|
// then no logging will be performed.
|
||||||
|
@ -34,6 +37,9 @@ type Config struct {
|
||||||
// This field is required.
|
// This field is required.
|
||||||
Tokens *token.Store
|
Tokens *token.Store
|
||||||
|
|
||||||
|
// Persist is a function to run when there are new certs or keys
|
||||||
|
Persist PersistFunc
|
||||||
|
|
||||||
// Fallback is a function to run when the normal cache updating of the
|
// Fallback is a function to run when the normal cache updating of the
|
||||||
// agent's certificates has failed to work for one reason or another.
|
// agent's certificates has failed to work for one reason or another.
|
||||||
// This field is required.
|
// This field is required.
|
||||||
|
@ -135,3 +141,10 @@ func (cfg *Config) WithFallbackRetry(after time.Duration) *Config {
|
||||||
cfg.FallbackRetry = after
|
cfg.FallbackRetry = after
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPersistence will configure the CertMonitor to use this callback for persisting
|
||||||
|
// a new TLS configuration.
|
||||||
|
func (cfg *Config) WithPersistence(persist PersistFunc) *Config {
|
||||||
|
cfg.Persist = persist
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue