From c9b66157a128158c7b3e3800db2ca54d2b7d9c3f Mon Sep 17 00:00:00 2001 From: Matt Keeler Date: Thu, 30 Jul 2020 11:37:18 -0400 Subject: [PATCH] Ensure certificates retrieved through the cache get persisted with auto-config (#8409) --- agent/agent.go | 13 +++++- agent/agent_test.go | 54 ++++++++++++++++++------- agent/auto-config/auto_config.go | 36 ++++++++++++----- agent/auto-config/config_translate.go | 31 ++++++++++++++ agent/cert-monitor/cert_monitor.go | 33 +++++++++++++++ agent/cert-monitor/cert_monitor_test.go | 48 +++++++++++++++++++--- agent/cert-monitor/config.go | 13 ++++++ 7 files changed, 198 insertions(+), 30 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 96546c16be..77051e640b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -545,7 +545,8 @@ func New(options ...AgentOption) (*Agent, error) { WithNodeName(a.config.NodeName). WithFallback(a.autoConfigFallbackTLS). WithLogger(a.logger.Named(logging.AutoConfig)). - WithTokens(a.tokens) + WithTokens(a.tokens). + WithPersistence(a.autoConfigPersist) acCertMon, err := certmon.New(cmConf) if err != nil { return nil, err @@ -889,9 +890,19 @@ func (a *Agent) autoEncryptInitialCertificate(ctx context.Context) (*structs.Sig } func (a *Agent) autoConfigFallbackTLS(ctx context.Context) (*structs.SignedResponse, error) { + if a.autoConf == nil { + return nil, fmt.Errorf("AutoConfig manager has not been created yet") + } return a.autoConf.FallbackTLS(ctx) } +func (a *Agent) autoConfigPersist(resp *structs.SignedResponse) error { + if a.autoConf == nil { + return fmt.Errorf("AutoConfig manager has not been created yet") + } + return a.autoConf.RecordUpdatedCerts(resp) +} + func (a *Agent) listenAndServeGRPC() error { if len(a.config.GRPCAddrs) < 1 { return nil diff --git a/agent/agent_test.go b/agent/agent_test.go index 0c7e6d23d0..d2eeb8da5a 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/golang/protobuf/jsonpb" "github.com/google/tcpproxy" "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" @@ -31,6 +32,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" "github.com/hashicorp/consul/ipaddr" + "github.com/hashicorp/consul/proto/pbautoconf" "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" @@ -4728,21 +4730,28 @@ func TestAutoConfig_Integration(t *testing.T) { }) require.NoError(t, err) - client := StartTestAgent(t, TestAgent{Name: "test-client", HCL: ` - bootstrap = false - server = false - ca_file = "` + caFile + `" - verify_outgoing = true - verify_server_hostname = true - node_name = "test-client" - ports { - server = ` + strconv.Itoa(srv.Config.RPCBindAddr.Port) + ` - } - auto_config { - enabled = true - intro_token = "` + token + `" - server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"] - }`}) + client := StartTestAgent(t, TestAgent{Name: "test-client", + Overrides: ` + connect { + test_ca_leaf_root_change_spread = "1ns" + } + `, + HCL: ` + bootstrap = false + server = false + ca_file = "` + caFile + `" + verify_outgoing = true + verify_server_hostname = true + node_name = "test-client" + ports { + server = ` + strconv.Itoa(srv.Config.RPCBindAddr.Port) + ` + } + auto_config { + enabled = true + intro_token = "` + token + `" + server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"] + }`, + }) defer client.Shutdown() @@ -4782,6 +4791,21 @@ func TestAutoConfig_Integration(t *testing.T) { // ensure that a new cert gets generated and pushed into the TLS configurator retry.Run(t, func(r *retry.R) { require.NotEqual(r, cert1, client.Agent.tlsConfigurator.Cert()) + + // check that the on disk certs match expectations + data, err := ioutil.ReadFile(filepath.Join(client.DataDir, "auto-config.json")) + require.NoError(r, err) + rdr := strings.NewReader(string(data)) + + var resp pbautoconf.AutoConfigResponse + pbUnmarshaler := &jsonpb.Unmarshaler{ + AllowUnknownFields: false, + } + require.NoError(r, pbUnmarshaler.Unmarshal(rdr, &resp), "data: %s", data) + + actual, err := tls.X509KeyPair([]byte(resp.Certificate.CertPEM), []byte(resp.Certificate.PrivateKeyPEM)) + require.NoError(r, err) + require.Equal(r, client.Agent.tlsConfigurator.Cert(), &actual) }) // spot check that we now have an ACL token diff --git a/agent/auto-config/auto_config.go b/agent/auto-config/auto_config.go index 61587c39a8..bf2c348e93 100644 --- a/agent/auto-config/auto_config.go +++ b/agent/auto-config/auto_config.go @@ -55,15 +55,16 @@ var ( // then we will need to add some locking here. I am deferring that for now // to help ease the review of this already large PR. type AutoConfig struct { - builderOpts config.BuilderOpts - logger hclog.Logger - directRPC DirectRPC - waiter *lib.RetryWaiter - overrides []config.Source - certMonitor CertMonitor - config *config.RuntimeConfig - autoConfigData string - cancel context.CancelFunc + builderOpts config.BuilderOpts + logger hclog.Logger + directRPC DirectRPC + waiter *lib.RetryWaiter + overrides []config.Source + certMonitor CertMonitor + config *config.RuntimeConfig + autoConfigResponse *pbautoconf.AutoConfigResponse + autoConfigData string + cancel context.CancelFunc } // New creates a new AutoConfig object for providing automatic @@ -493,6 +494,8 @@ func (ac *AutoConfig) generateCSR() (csr string, key string, err error) { // config data to be used during a call to ReadConfig, updating the // tls Configurator and prepopulating the cache. func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error { + ac.autoConfigResponse = resp + if err := ac.updateConfigFromResponse(resp); err != nil { return err } @@ -591,3 +594,18 @@ func (ac *AutoConfig) FallbackTLS(ctx context.Context) (*structs.SignedResponse, return extractSignedResponse(resp) } + +func (ac *AutoConfig) RecordUpdatedCerts(resp *structs.SignedResponse) error { + var err error + ac.autoConfigResponse.ExtraCACertificates = resp.ManualCARoots + ac.autoConfigResponse.CARoots, err = translateCARootsToProtobuf(&resp.ConnectCARoots) + if err != nil { + return err + } + ac.autoConfigResponse.Certificate, err = translateIssuedCertToProtobuf(&resp.IssuedCert) + if err != nil { + return err + } + + return ac.recordResponse(ac.autoConfigResponse) +} diff --git a/agent/auto-config/config_translate.go b/agent/auto-config/config_translate.go index 313b37a47f..ff02a63262 100644 --- a/agent/auto-config/config_translate.go +++ b/agent/auto-config/config_translate.go @@ -226,3 +226,34 @@ func mapstructureTranslateToStructs(in interface{}, out interface{}) error { return decoder.Decode(in) } + +func translateCARootsToProtobuf(in *structs.IndexedCARoots) (*pbconnect.CARoots, error) { + var out pbconnect.CARoots + if err := mapstructureTranslateToProtobuf(in, &out); err != nil { + return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err) + } + + return &out, nil +} + +func translateIssuedCertToProtobuf(in *structs.IssuedCert) (*pbconnect.IssuedCert, error) { + var out pbconnect.IssuedCert + if err := mapstructureTranslateToProtobuf(in, &out); err != nil { + return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err) + } + + return &out, nil +} + +func mapstructureTranslateToProtobuf(in interface{}, out interface{}) error { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: proto.HookTimeToPBTimestamp, + Result: out, + }) + + if err != nil { + return err + } + + return decoder.Decode(in) +} diff --git a/agent/cert-monitor/cert_monitor.go b/agent/cert-monitor/cert_monitor.go index c23f08a256..0ad50e8a11 100644 --- a/agent/cert-monitor/cert_monitor.go +++ b/agent/cert-monitor/cert_monitor.go @@ -40,6 +40,7 @@ type CertMonitor struct { tokens *token.Store leafReq cachetype.ConnectCALeafRequest rootsReq structs.DCSpecificRequest + persist PersistFunc fallback FallbackFunc fallbackLeeway time.Duration fallbackRetry time.Duration @@ -66,6 +67,11 @@ type CertMonitor struct { // events from the token store when the Agent // token is updated. tokenUpdates token.Notifier + + // this is used to keep a local copy of the certs + // keys and ca certs. It will be used to persist + // all of the local state at once. + certs structs.SignedResponse } // New creates a new CertMonitor for automatically rotating @@ -115,6 +121,7 @@ func New(config *Config) (*CertMonitor, error) { cache: config.Cache, tokens: config.Tokens, tlsConfigurator: config.TLSConfigurator, + persist: config.Persist, fallback: config.Fallback, fallbackLeeway: config.FallbackLeeway, fallbackRetry: config.FallbackRetry, @@ -135,6 +142,8 @@ func (m *CertMonitor) Update(certs *structs.SignedResponse) error { return nil } + m.certs = *certs + if err := m.populateCache(certs); err != nil { return fmt.Errorf("error populating cache with certificates: %w", err) } @@ -306,6 +315,8 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error { return fmt.Errorf("invalid type for roots watch response: %T", u.Result) } + m.certs.ConnectCARoots = *roots + var pems []string for _, root := range roots.Roots { pems = append(pems, root.RootCert) @@ -314,6 +325,13 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error { if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil { return fmt.Errorf("failed to update Connect CA certificates: %w", err) } + + if m.persist != nil { + copy := m.certs + if err := m.persist(©); err != nil { + return fmt.Errorf("failed to persist certificate package: %w", err) + } + } case leafWatchID: m.logger.Debug("leaf certificate watch fired - updating TLS certificate") if u.Err != nil { @@ -324,9 +342,19 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error { if !ok { return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result) } + + m.certs.IssuedCert = *leaf + if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil { return fmt.Errorf("failed to update the agent leaf cert: %w", err) } + + if m.persist != nil { + copy := m.certs + if err := m.persist(©); err != nil { + return fmt.Errorf("failed to persist certificate package: %w", err) + } + } } return nil @@ -380,6 +408,11 @@ func (m *CertMonitor) handleFallback(ctx context.Context) error { return fmt.Errorf("error when getting new agent certificate: %w", err) } + if m.persist != nil { + if err := m.persist(reply); err != nil { + return fmt.Errorf("failed to persist certificate package: %w", err) + } + } return m.Update(reply) } diff --git a/agent/cert-monitor/cert_monitor_test.go b/agent/cert-monitor/cert_monitor_test.go index 500de549ab..2b6ea76d86 100644 --- a/agent/cert-monitor/cert_monitor_test.go +++ b/agent/cert-monitor/cert_monitor_test.go @@ -33,6 +33,14 @@ func (m *mockFallback) fallback(ctx context.Context) (*structs.SignedResponse, e return resp, ret.Error(1) } +type mockPersist struct { + mock.Mock +} + +func (m *mockPersist) persist(resp *structs.SignedResponse) error { + return m.Called(resp).Error(0) +} + type mockWatcher struct { ch chan<- cache.UpdateEvent done <-chan struct{} @@ -159,6 +167,7 @@ type testCertMonitor struct { tls *tlsutil.Configurator tokens *token.Store fallback *mockFallback + persist *mockPersist extraCACerts []string initialCert *structs.IssuedCert @@ -210,8 +219,10 @@ func newTestCertMonitor(t *testing.T) testCertMonitor { dnsSANs := []string{"test.dev"} ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)} - // this chan should be unbuffered so we can detect when the fallback func has been called. fallback := &mockFallback{} + fallback.Test(t) + persist := &mockPersist{} + persist.Test(t) mcache := newMockCache(t) rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1} @@ -246,7 +257,8 @@ func newTestCertMonitor(t *testing.T) testCertMonitor { WithDatacenter("foo"). WithNodeName("node"). WithFallbackLeeway(time.Nanosecond). - WithFallbackRetry(time.Millisecond) + WithFallbackRetry(time.Millisecond). + WithPersistence(persist.persist) monitor, err := New(cfg) require.NoError(t, err) @@ -259,6 +271,7 @@ func newTestCertMonitor(t *testing.T) testCertMonitor { tls: tlsConfigurator, tokens: tokens, mcache: mcache, + persist: persist, fallback: fallback, extraCACerts: []string{manualCA.RootCert}, initialCert: issued, @@ -298,6 +311,7 @@ func (cm *testCertMonitor) initialCACerts() []string { func (cm *testCertMonitor) assertExpectations(t *testing.T) { cm.mcache.AssertExpectations(t) cm.fallback.AssertExpectations(t) + cm.persist.AssertExpectations(t) } func TestCertMonitor_InitialCerts(t *testing.T) { @@ -473,6 +487,13 @@ func TestCertMonitor_RootsUpdate(t *testing.T) { }, } + cm.persist.On("persist", &structs.SignedResponse{ + IssuedCert: *cm.initialCert, + ManualCARoots: cm.extraCACerts, + ConnectCARoots: secondRoots, + VerifyServerHostname: cm.verifyServerHostname, + }).Return(nil).Once() + // assert value of the CA certs prior to updating require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems()) @@ -500,6 +521,13 @@ func TestCertMonitor_CertUpdate(t *testing.T) { secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute) + cm.persist.On("persist", &structs.SignedResponse{ + IssuedCert: *secondCert, + ManualCARoots: cm.extraCACerts, + ConnectCARoots: *cm.initialRoots, + VerifyServerHostname: cm.verifyServerHostname, + }).Return(nil).Once() + // assert value of cert prior to updating the leaf require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert()) @@ -549,13 +577,23 @@ func TestCertMonitor_Fallback(t *testing.T) { // inject a fallback routine error to check that we rerun it quickly cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once() - // expect the fallback routine to be executed and setup the return - cm.fallback.On("fallback").Return(&structs.SignedResponse{ + fallbackResp := &structs.SignedResponse{ ConnectCARoots: secondRoots, IssuedCert: *thirdCert, ManualCARoots: cm.extraCACerts, VerifyServerHostname: true, - }, nil).Once() + } + // expect the fallback routine to be executed and setup the return + cm.fallback.On("fallback").Return(fallbackResp, nil).Once() + + cm.persist.On("persist", &structs.SignedResponse{ + IssuedCert: *secondCert, + ConnectCARoots: *cm.initialRoots, + ManualCARoots: cm.extraCACerts, + VerifyServerHostname: cm.verifyServerHostname, + }).Return(nil).Once() + + cm.persist.On("persist", fallbackResp).Return(nil).Once() // Add another roots cache prepopulation expectation which should happen // in response to executing the fallback mechanism diff --git a/agent/cert-monitor/config.go b/agent/cert-monitor/config.go index a1da2841e6..2e4bcc57ca 100644 --- a/agent/cert-monitor/config.go +++ b/agent/cert-monitor/config.go @@ -16,6 +16,9 @@ import ( // method of updating the certificate is required. type FallbackFunc func(context.Context) (*structs.SignedResponse, error) +// PersistFunc is used to persist the data from a signed response +type PersistFunc func(*structs.SignedResponse) error + type Config struct { // Logger is the logger to be used while running. If not set // then no logging will be performed. @@ -34,6 +37,9 @@ type Config struct { // This field is required. Tokens *token.Store + // Persist is a function to run when there are new certs or keys + Persist PersistFunc + // Fallback is a function to run when the normal cache updating of the // agent's certificates has failed to work for one reason or another. // This field is required. @@ -135,3 +141,10 @@ func (cfg *Config) WithFallbackRetry(after time.Duration) *Config { cfg.FallbackRetry = after return cfg } + +// WithPersistence will configure the CertMonitor to use this callback for persisting +// a new TLS configuration. +func (cfg *Config) WithPersistence(persist PersistFunc) *Config { + cfg.Persist = persist + return cfg +}