diff --git a/agent/agent.go b/agent/agent.go index f0ddfa5a6b..a27e5cea8b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -32,6 +32,7 @@ import ( autoconf "github.com/hashicorp/consul/agent/auto-config" "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" + certmon "github.com/hashicorp/consul/agent/cert-monitor" "github.com/hashicorp/consul/agent/checks" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" @@ -165,6 +166,8 @@ type notifier interface { type Agent struct { autoConf *autoconf.AutoConfig + certMonitor *certmon.CertMonitor + // config is the agent configuration. config *config.RuntimeConfig @@ -716,17 +719,39 @@ func (a *Agent) Start(ctx context.Context) error { a.registerCache() if a.config.AutoEncryptTLS && !a.config.ServerMode { - reply, err := a.setupClientAutoEncrypt(ctx) + reply, err := a.autoEncryptInitialCertificate(ctx) if err != nil { return fmt.Errorf("AutoEncrypt failed: %s", err) } - rootsReq, leafReq, err := a.setupClientAutoEncryptCache(reply) + + cmConfig := new(certmon.Config). + WithCache(a.cache). + WithLogger(a.logger.Named(logging.AutoEncrypt)). + WithTLSConfigurator(a.tlsConfigurator). + WithTokens(a.tokens). + WithFallback(a.autoEncryptInitialCertificate). + WithDNSSANs(a.config.AutoEncryptDNSSAN). + WithIPSANs(a.config.AutoEncryptIPSAN). + WithDatacenter(a.config.Datacenter). + WithNodeName(a.config.NodeName) + + monitor, err := certmon.New(cmConfig) if err != nil { - return fmt.Errorf("AutoEncrypt failed: %s", err) + return fmt.Errorf("AutoEncrypt failed to setup certificate monitor: %w", err) } - if err = a.setupClientAutoEncryptWatching(rootsReq, leafReq); err != nil { - return fmt.Errorf("AutoEncrypt failed: %s", err) + if err := monitor.Update(reply); err != nil { + return fmt.Errorf("AutoEncrypt failed to setup certificate monitor: %w", err) } + a.certMonitor = monitor + + // we don't need to worry about ever calling Stop as we have tied the go routines + // to the agents lifetime by using the StopCh. Also the agent itself doesn't have + // a need of ensuring that the go routine was stopped before performing any action + // so we can ignore the chan in the return. + if _, err := a.certMonitor.Start(&lib.StopChannelContext{StopCh: a.shutdownCh}); err != nil { + return fmt.Errorf("AutoEncrypt failed to start certificate monitor: %w", err) + } + a.logger.Info("automatically upgraded to TLS") } @@ -829,7 +854,7 @@ func (a *Agent) Start(ctx context.Context) error { return nil } -func (a *Agent) setupClientAutoEncrypt(ctx context.Context) (*structs.SignedResponse, error) { +func (a *Agent) autoEncryptInitialCertificate(ctx context.Context) (*structs.SignedResponse, error) { client := a.delegate.(*consul.Client) addrs := a.config.StartJoinAddrsLAN @@ -839,165 +864,7 @@ func (a *Agent) setupClientAutoEncrypt(ctx context.Context) (*structs.SignedResp } addrs = append(addrs, retryJoinAddrs(disco, retryJoinSerfVariant, "LAN", a.config.RetryJoinLAN, a.logger)...) 
- reply, priv, err := client.RequestAutoEncryptCerts(ctx, addrs, a.config.ServerPort, a.tokens.AgentToken(), a.config.AutoEncryptDNSSAN, a.config.AutoEncryptIPSAN) - if err != nil { - return nil, err - } - - connectCAPems := []string{} - for _, ca := range reply.ConnectCARoots.Roots { - connectCAPems = append(connectCAPems, ca.RootCert) - } - if err := a.tlsConfigurator.UpdateAutoEncrypt(reply.ManualCARoots, connectCAPems, reply.IssuedCert.CertPEM, priv, reply.VerifyServerHostname); err != nil { - return nil, err - } - return reply, nil - -} - -func (a *Agent) setupClientAutoEncryptCache(reply *structs.SignedResponse) (*structs.DCSpecificRequest, *cachetype.ConnectCALeafRequest, error) { - rootsReq := &structs.DCSpecificRequest{ - Datacenter: a.config.Datacenter, - QueryOptions: structs.QueryOptions{Token: a.tokens.AgentToken()}, - } - - // prepolutate roots cache - rootRes := cache.FetchResult{Value: &reply.ConnectCARoots, Index: reply.ConnectCARoots.QueryMeta.Index} - if err := a.cache.Prepopulate(cachetype.ConnectCARootName, rootRes, a.config.Datacenter, a.tokens.AgentToken(), rootsReq.CacheInfo().Key); err != nil { - return nil, nil, err - } - - leafReq := &cachetype.ConnectCALeafRequest{ - Datacenter: a.config.Datacenter, - Token: a.tokens.AgentToken(), - Agent: a.config.NodeName, - DNSSAN: a.config.AutoEncryptDNSSAN, - IPSAN: a.config.AutoEncryptIPSAN, - } - - // prepolutate leaf cache - certRes := cache.FetchResult{ - Value: &reply.IssuedCert, - Index: reply.ConnectCARoots.QueryMeta.Index, - } - - for _, ca := range reply.ConnectCARoots.Roots { - if ca.ID == reply.ConnectCARoots.ActiveRootID { - certRes.State = cachetype.ConnectCALeafSuccess(ca.SigningKeyID) - break - } - } - if err := a.cache.Prepopulate(cachetype.ConnectCALeafName, certRes, a.config.Datacenter, a.tokens.AgentToken(), leafReq.Key()); err != nil { - return nil, nil, err - } - return rootsReq, leafReq, nil -} - -func (a *Agent) setupClientAutoEncryptWatching(rootsReq *structs.DCSpecificRequest, leafReq *cachetype.ConnectCALeafRequest) error { - // setup watches - ch := make(chan cache.UpdateEvent, 10) - ctx, cancel := context.WithCancel(context.Background()) - - // Watch for root changes - err := a.cache.Notify(ctx, cachetype.ConnectCARootName, rootsReq, rootsWatchID, ch) - if err != nil { - cancel() - return err - } - - // Watch the leaf cert - err = a.cache.Notify(ctx, cachetype.ConnectCALeafName, leafReq, leafWatchID, ch) - if err != nil { - cancel() - return err - } - - // Setup actions in case the watches are firing. 
- go func() { - for { - select { - case <-a.shutdownCh: - cancel() - return - case <-ctx.Done(): - return - case u := <-ch: - switch u.CorrelationID { - case rootsWatchID: - roots, ok := u.Result.(*structs.IndexedCARoots) - if !ok { - err := fmt.Errorf("invalid type for roots response: %T", u.Result) - a.logger.Error("watch error for correlation id", - "correlation_id", u.CorrelationID, - "error", err, - ) - continue - } - pems := []string{} - for _, root := range roots.Roots { - pems = append(pems, root.RootCert) - } - a.tlsConfigurator.UpdateAutoEncryptCA(pems) - case leafWatchID: - leaf, ok := u.Result.(*structs.IssuedCert) - if !ok { - err := fmt.Errorf("invalid type for leaf response: %T", u.Result) - a.logger.Error("watch error for correlation id", - "correlation_id", u.CorrelationID, - "error", err, - ) - continue - } - a.tlsConfigurator.UpdateAutoEncryptCert(leaf.CertPEM, leaf.PrivateKeyPEM) - } - } - } - }() - - // Setup safety net in case the auto_encrypt cert doesn't get renewed - // in time. The agent would be stuck in that case because the watches - // never use the AutoEncrypt.Sign endpoint. - go func() { - // Check 10sec after cert expires. The agent cache - // should be handling the expiration and renew before - // it. - // If there is no cert, AutoEncryptCertNotAfter returns - // a value in the past which immediately triggers the - // renew, but this case shouldn't happen because at - // this point, auto_encrypt was just being setup - // successfully. - interval := a.tlsConfigurator.AutoEncryptCertNotAfter().Sub(time.Now().Add(10 * time.Second)) - autoLogger := a.logger.Named(logging.AutoEncrypt) - for { - a.logger.Debug("setting up client certificate expiration check on interval", "interval", interval) - select { - case <-a.shutdownCh: - return - case <-time.After(interval): - // check auto encrypt client cert expiration - if a.tlsConfigurator.AutoEncryptCertExpired() { - autoLogger.Debug("client certificate expired.") - // Background because the context is mainly useful when the agent is first starting up. 
- reply, err := a.setupClientAutoEncrypt(context.Background()) - if err != nil { - autoLogger.Error("client certificate expired, failed to renew", "error", err) - // in case of an error, try again in one minute - interval = time.Minute - continue - } - _, _, err = a.setupClientAutoEncryptCache(reply) - if err != nil { - autoLogger.Error("client certificate expired, failed to populate cache", "error", err) - // in case of an error, try again in one minute - interval = time.Minute - continue - } - } - } - } - }() - - return nil + return client.RequestAutoEncryptCerts(ctx, addrs, a.config.ServerPort, a.tokens.AgentToken(), a.config.AutoEncryptDNSSAN, a.config.AutoEncryptIPSAN) } func (a *Agent) listenAndServeGRPC() error { diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index c9aef60e3f..a2b96e9693 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -5247,6 +5247,8 @@ func TestAgentConnectCALeafCert_good(t *testing.T) { assert.Equal(fmt.Sprintf("%d", issued.ModifyIndex), resp.Header().Get("X-Consul-Index")) + index := resp.Header().Get("X-Consul-Index") + // Test caching { // Fetch it again @@ -5259,39 +5261,25 @@ func TestAgentConnectCALeafCert_good(t *testing.T) { require.Equal("HIT", resp.Header().Get("X-Cache")) } - // Test that caching is updated in the background + // Issue a blocking query to ensure that the cert gets updated appropriately { // Set a new CA ca := connect.TestCAConfigSet(t, a, nil) - retry.Run(t, func(r *retry.R) { - resp := httptest.NewRecorder() - // Try and sign again (note no index/wait arg since cache should update in - // background even if we aren't actively blocking) - obj, err := a.srv.AgentConnectCALeafCert(resp, req) - r.Check(err) + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/v1/agent/connect/ca/leaf/test?index="+index, nil) + obj, err := a.srv.AgentConnectCALeafCert(resp, req) + require.NoError(err) + issued2 := obj.(*structs.IssuedCert) + require.NotEqual(issued.CertPEM, issued2.CertPEM) + require.NotEqual(issued.PrivateKeyPEM, issued2.PrivateKeyPEM) - issued2 := obj.(*structs.IssuedCert) - if issued.CertPEM == issued2.CertPEM { - r.Fatalf("leaf has not updated") - } + // Verify that the cert is signed by the new CA + requireLeafValidUnderCA(t, issued2, ca) - // Got a new leaf. Sanity check it's a whole new key as well as different - // cert. - if issued.PrivateKeyPEM == issued2.PrivateKeyPEM { - r.Fatalf("new leaf has same private key as before") - } - - // Verify that the cert is signed by the new CA - requireLeafValidUnderCA(t, issued2, ca) - - // Should be a cache hit! The data should've updated in the cache - // in the background so this should've been fetched directly from - // the cache. - if resp.Header().Get("X-Cache") != "HIT" { - r.Fatalf("should be a cache hit") - } - }) + // Should not be a cache hit! The data was updated in response to the blocking + // query being made. + require.Equal("MISS", resp.Header().Get("X-Cache")) } } diff --git a/agent/cache-types/connect_ca_leaf.go b/agent/cache-types/connect_ca_leaf.go index c177804893..325423bec6 100644 --- a/agent/cache-types/connect_ca_leaf.go +++ b/agent/cache-types/connect_ca_leaf.go @@ -50,7 +50,7 @@ const caChangeJitterWindow = 30 * time.Second // ConnectCALeaf supports fetching and generating Connect leaf // certificates. 
type ConnectCALeaf struct { - RegisterOptionsBlockingRefresh + RegisterOptionsBlockingNoRefresh caIndex uint64 // Current index for CA roots // rootWatchMu protects access to the rootWatchSubscribers map and diff --git a/agent/cache-types/options.go b/agent/cache-types/options.go index 6864258e8e..5eb6cdd9b6 100644 --- a/agent/cache-types/options.go +++ b/agent/cache-types/options.go @@ -18,7 +18,17 @@ func (r RegisterOptionsBlockingRefresh) RegisterOptions() cache.RegisterOptions Refresh: true, SupportsBlocking: true, RefreshTimer: 0 * time.Second, - RefreshTimeout: 10 * time.Minute, + QueryTimeout: 10 * time.Minute, + } +} + +type RegisterOptionsBlockingNoRefresh struct{} + +func (r RegisterOptionsBlockingNoRefresh) RegisterOptions() cache.RegisterOptions { + return cache.RegisterOptions{ + Refresh: false, + SupportsBlocking: true, + QueryTimeout: 10 * time.Minute, } } diff --git a/agent/cache/cache.go b/agent/cache/cache.go index 2b1f3c1202..a216d4ab8a 100644 --- a/agent/cache/cache.go +++ b/agent/cache/cache.go @@ -172,7 +172,7 @@ type RegisterOptions struct { // If this is zero, then data is refreshed immediately when a fetch // is returned. // - // Using different values for RefreshTimer and RefreshTimeout, various + // Using different values for RefreshTimer and QueryTimeout, various // "refresh" mechanisms can be implemented: // // * With a high timer duration and a low timeout, a timer-based @@ -184,10 +184,10 @@ type RegisterOptions struct { // RefreshTimer time.Duration - // RefreshTimeout is the default value for the maximum query time for a fetch + // QueryTimeout is the default value for the maximum query time for a fetch // operation. It is set as FetchOptions.Timeout so that cache.Type // implementations can use it as the MaxQueryTime. - RefreshTimeout time.Duration + QueryTimeout time.Duration } // RegisterType registers a cacheable type. @@ -475,7 +475,7 @@ func (c *Cache) fetch(key string, r getOptions, allowNew bool, attempt uint, ign // keepalives are every 30 seconds so the RPC should fail if the packets are // being blackholed for more than 30 seconds. 
var connectedTimer *time.Timer - if tEntry.Opts.Refresh && entry.Index > 0 && tEntry.Opts.RefreshTimeout > 31*time.Second { + if tEntry.Opts.Refresh && entry.Index > 0 && tEntry.Opts.QueryTimeout > 31*time.Second { connectedTimer = time.AfterFunc(31*time.Second, func() { c.entriesLock.Lock() defer c.entriesLock.Unlock() @@ -491,7 +491,11 @@ func (c *Cache) fetch(key string, r getOptions, allowNew bool, attempt uint, ign fOpts := FetchOptions{} if tEntry.Opts.SupportsBlocking { fOpts.MinIndex = entry.Index - fOpts.Timeout = tEntry.Opts.RefreshTimeout + fOpts.Timeout = tEntry.Opts.QueryTimeout + + if fOpts.Timeout == 0 { + fOpts.Timeout = 10 * time.Minute + } } if entry.Valid { fOpts.LastResult = &FetchResult{ diff --git a/agent/cache/cache_test.go b/agent/cache/cache_test.go index 5036a12ca5..72422fb258 100644 --- a/agent/cache/cache_test.go +++ b/agent/cache/cache_test.go @@ -464,9 +464,9 @@ func TestCacheGet_periodicRefresh(t *testing.T) { typ := &MockType{} typ.On("RegisterOptions").Return(RegisterOptions{ - Refresh: true, - RefreshTimer: 100 * time.Millisecond, - RefreshTimeout: 5 * time.Minute, + Refresh: true, + RefreshTimer: 100 * time.Millisecond, + QueryTimeout: 5 * time.Minute, }) defer typ.AssertExpectations(t) c := TestCache(t) @@ -504,9 +504,9 @@ func TestCacheGet_periodicRefreshMultiple(t *testing.T) { typ := &MockType{} typ.On("RegisterOptions").Return(RegisterOptions{ - Refresh: true, - RefreshTimer: 0 * time.Millisecond, - RefreshTimeout: 5 * time.Minute, + Refresh: true, + RefreshTimer: 0 * time.Millisecond, + QueryTimeout: 5 * time.Minute, }) defer typ.AssertExpectations(t) c := TestCache(t) @@ -553,9 +553,9 @@ func TestCacheGet_periodicRefreshErrorBackoff(t *testing.T) { typ := &MockType{} typ.On("RegisterOptions").Return(RegisterOptions{ - Refresh: true, - RefreshTimer: 0, - RefreshTimeout: 5 * time.Minute, + Refresh: true, + RefreshTimer: 0, + QueryTimeout: 5 * time.Minute, }) defer typ.AssertExpectations(t) c := TestCache(t) @@ -595,9 +595,9 @@ func TestCacheGet_periodicRefreshBadRPCZeroIndexErrorBackoff(t *testing.T) { typ := &MockType{} typ.On("RegisterOptions").Return(RegisterOptions{ - Refresh: true, - RefreshTimer: 0, - RefreshTimeout: 5 * time.Minute, + Refresh: true, + RefreshTimer: 0, + QueryTimeout: 5 * time.Minute, }) defer typ.AssertExpectations(t) c := TestCache(t) @@ -642,7 +642,7 @@ func TestCacheGet_noIndexSetsOne(t *testing.T) { SupportsBlocking: true, Refresh: true, RefreshTimer: 0, - RefreshTimeout: 5 * time.Minute, + QueryTimeout: 5 * time.Minute, }) defer typ.AssertExpectations(t) c := TestCache(t) @@ -702,7 +702,7 @@ func TestCacheGet_fetchTimeout(t *testing.T) { typ := &MockType{} timeout := 10 * time.Minute typ.On("RegisterOptions").Return(RegisterOptions{ - RefreshTimeout: timeout, + QueryTimeout: timeout, SupportsBlocking: true, }) defer typ.AssertExpectations(t) @@ -954,9 +954,9 @@ func TestCacheGet_refreshAge(t *testing.T) { typ := &MockType{} typ.On("RegisterOptions").Return(RegisterOptions{ - Refresh: true, - RefreshTimer: 0, - RefreshTimeout: 5 * time.Minute, + Refresh: true, + RefreshTimer: 0, + QueryTimeout: 5 * time.Minute, }) defer typ.AssertExpectations(t) c := TestCache(t) diff --git a/agent/cert-monitor/cert_monitor.go b/agent/cert-monitor/cert_monitor.go new file mode 100644 index 0000000000..44fd3b7f27 --- /dev/null +++ b/agent/cert-monitor/cert_monitor.go @@ -0,0 +1,471 @@ +package certmon + +import ( + "context" + "fmt" + "io/ioutil" + "sync" + "time" + + "github.com/hashicorp/consul/agent/cache" + cachetype 
"github.com/hashicorp/consul/agent/cache-types" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/agent/token" + "github.com/hashicorp/consul/tlsutil" + "github.com/hashicorp/go-hclog" +) + +const ( + // ID of the roots watch + rootsWatchID = "roots" + + // ID of the leaf watch + leafWatchID = "leaf" +) + +// Cache is an interface to represent the methods of the +// agent/cache.Cache struct that we care about +type Cache interface { + Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error + Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error +} + +// CertMonitor will setup the proper watches to ensure that +// the Agent's Connect TLS certificate remains up to date +type CertMonitor struct { + logger hclog.Logger + cache Cache + tlsConfigurator *tlsutil.Configurator + tokens *token.Store + leafReq cachetype.ConnectCALeafRequest + rootsReq structs.DCSpecificRequest + fallback FallbackFunc + fallbackLeeway time.Duration + fallbackRetry time.Duration + + l sync.Mutex + running bool + // cancel is used to cancel the entire CertMonitor + // go routine. This is the main field protected + // by the mutex as it being non-nil indicates that + // the go routine has been started and is stoppable. + // note that it doesn't indcate that the go routine + // is currently running. + cancel context.CancelFunc + + // cancelWatches is used to cancel the existing + // cache watches. This is mainly only necessary + // when the Agent token changes + cancelWatches context.CancelFunc + + // cacheUpdates is the chan used to have the cache + // send us back events + cacheUpdates chan cache.UpdateEvent + // tokenUpdates is the struct used to receive + // events from the token store when the Agent + // token is updated. 
+ tokenUpdates token.Notifier +} + +// New creates a new CertMonitor for automatically rotating +// an Agent's Connect Certificate +func New(config *Config) (*CertMonitor, error) { + logger := config.Logger + if logger == nil { + logger = hclog.New(&hclog.LoggerOptions{ + Level: 0, + Output: ioutil.Discard, + }) + } + + if config.FallbackLeeway == 0 { + config.FallbackLeeway = 10 * time.Second + } + if config.FallbackRetry == 0 { + config.FallbackRetry = time.Minute + } + + if config.Cache == nil { + return nil, fmt.Errorf("CertMonitor creation requires a Cache") + } + + if config.TLSConfigurator == nil { + return nil, fmt.Errorf("CertMonitor creation requires a TLS Configurator") + } + + if config.Fallback == nil { + return nil, fmt.Errorf("CertMonitor creation requires specifying a FallbackFunc") + } + + if config.Datacenter == "" { + return nil, fmt.Errorf("CertMonitor creation requires specifying the datacenter") + } + + if config.NodeName == "" { + return nil, fmt.Errorf("CertMonitor creation requires specifying the agent's node name") + } + + if config.Tokens == nil { + return nil, fmt.Errorf("CertMonitor creation requires specifying a token store") + } + + return &CertMonitor{ + logger: logger, + cache: config.Cache, + tokens: config.Tokens, + tlsConfigurator: config.TLSConfigurator, + fallback: config.Fallback, + fallbackLeeway: config.FallbackLeeway, + fallbackRetry: config.FallbackRetry, + rootsReq: structs.DCSpecificRequest{Datacenter: config.Datacenter}, + leafReq: cachetype.ConnectCALeafRequest{ + Datacenter: config.Datacenter, + Agent: config.NodeName, + DNSSAN: config.DNSSANs, + IPSAN: config.IPSANs, + }, + }, nil +} + +// Update is responsible for priming the cache with the certificates +// as well as injecting them into the TLS configurator +func (m *CertMonitor) Update(certs *structs.SignedResponse) error { + if certs == nil { + return nil + } + + if err := m.populateCache(certs); err != nil { + return fmt.Errorf("error populating cache with certificates: %w", err) + } + + connectCAPems := []string{} + for _, ca := range certs.ConnectCARoots.Roots { + connectCAPems = append(connectCAPems, ca.RootCert) + } + + // Note that its expected that the private key be within the IssuedCert in the + // SignedResponse. This isn't how a server would send back the response and requires + // that the recipient of the response who also has access to the private key will + // have filled it in. The Cache definitely does this but auto-encrypt/auto-config + // will need to ensure the original response is setup this way too. 
+	err := m.tlsConfigurator.UpdateAutoEncrypt(
+		certs.ManualCARoots,
+		connectCAPems,
+		certs.IssuedCert.CertPEM,
+		certs.IssuedCert.PrivateKeyPEM,
+		certs.VerifyServerHostname)
+
+	if err != nil {
+		return fmt.Errorf("error updating TLS configurator with certificates: %w", err)
+	}
+
+	return nil
+}
+
+// populateCache is responsible for inserting the certificates into the cache
+func (m *CertMonitor) populateCache(resp *structs.SignedResponse) error {
+	cert, err := connect.ParseCert(resp.IssuedCert.CertPEM)
+	if err != nil {
+		return fmt.Errorf("Failed to parse certificate: %w", err)
+	}
+
+	// prepopulate roots cache
+	rootRes := cache.FetchResult{Value: &resp.ConnectCARoots, Index: resp.ConnectCARoots.QueryMeta.Index}
+	// getting the roots doesn't require a token, so the empty token is used here in order
+	// to potentially share the cache entry with other consumers
+	if err := m.cache.Prepopulate(cachetype.ConnectCARootName, rootRes, m.rootsReq.Datacenter, "", m.rootsReq.CacheInfo().Key); err != nil {
+		return err
+	}
+
+	// copy the template and update the token
+	leafReq := m.leafReq
+	leafReq.Token = m.tokens.AgentToken()
+
+	// prepopulate leaf cache
+	certRes := cache.FetchResult{
+		Value: &resp.IssuedCert,
+		Index: resp.ConnectCARoots.QueryMeta.Index,
+		State: cachetype.ConnectCALeafSuccess(connect.EncodeSigningKeyID(cert.AuthorityKeyId)),
+	}
+	if err := m.cache.Prepopulate(cachetype.ConnectCALeafName, certRes, leafReq.Datacenter, leafReq.Token, leafReq.Key()); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Start spawns the go routine to monitor the certificate and ensure it is
+// rotated/renewed as necessary. The returned chan will be closed once the
+// started go routine has exited
+func (m *CertMonitor) Start(ctx context.Context) (<-chan struct{}, error) {
+	m.l.Lock()
+	defer m.l.Unlock()
+
+	if m.running || m.cancel != nil {
+		return nil, fmt.Errorf("the CertMonitor is already running")
+	}
+
+	// create the top level context to control the go
+	// routine executing the `run` method
+	ctx, cancel := context.WithCancel(ctx)
+
+	// create the channel to get cache update events through
+	// really we should only ever get 10 updates
+	m.cacheUpdates = make(chan cache.UpdateEvent, 10)
+
+	// set up the cache watches
+	cancelWatches, err := m.setupCacheWatches(ctx)
+	if err != nil {
+		cancel()
+		return nil, fmt.Errorf("error setting up cache watches: %w", err)
+	}
+
+	// start the token update notifier
+	m.tokenUpdates = m.tokens.Notify(token.TokenKindAgent)
+
+	// store the cancel funcs
+	m.cancel = cancel
+	m.cancelWatches = cancelWatches
+
+	m.running = true
+	exit := make(chan struct{})
+	go m.run(ctx, exit)
+
+	return exit, nil
+}
+
+// Stop manually stops the go routine spawned by Start and
+// returns whether the go routine was still running before
+// cancelling.
+//
+// Note that cancelling the context passed into Start will
+// also cause the go routine to stop
+func (m *CertMonitor) Stop() bool {
+	m.l.Lock()
+	defer m.l.Unlock()
+
+	if !m.running {
+		return false
+	}
+
+	if m.cancel != nil {
+		m.cancel()
+	}
+
+	return true
+}
+
+// IsRunning returns whether the go routine to perform certificate monitoring
+// is already running.
+func (m *CertMonitor) IsRunning() bool {
+	m.l.Lock()
+	defer m.l.Unlock()
+	return m.running
+}
+
+// setupCacheWatches will start both the roots and leaf cert watches with a new
+// child context and an up-to-date ACL token. The CancelFunc for that child
+// context is also returned.
+func (m *CertMonitor) setupCacheWatches(ctx context.Context) (context.CancelFunc, error) { + notificationCtx, cancel := context.WithCancel(ctx) + + // copy the request + rootsReq := m.rootsReq + + err := m.cache.Notify(notificationCtx, cachetype.ConnectCARootName, &rootsReq, rootsWatchID, m.cacheUpdates) + if err != nil { + cancel() + return nil, err + } + + // copy the request + leafReq := m.leafReq + leafReq.Token = m.tokens.AgentToken() + + err = m.cache.Notify(notificationCtx, cachetype.ConnectCALeafName, &leafReq, leafWatchID, m.cacheUpdates) + if err != nil { + cancel() + return nil, err + } + + return cancel, nil +} + +// handleCacheEvent is used to handle event notifications from the cache for the roots +// or leaf cert watches. +func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error { + switch u.CorrelationID { + case rootsWatchID: + m.logger.Debug("roots watch fired - updating CA certificates") + if u.Err != nil { + return fmt.Errorf("root watch returned an error: %w", u.Err) + } + + roots, ok := u.Result.(*structs.IndexedCARoots) + if !ok { + return fmt.Errorf("invalid type for roots watch response: %T", u.Result) + } + + var pems []string + for _, root := range roots.Roots { + pems = append(pems, root.RootCert) + } + + if err := m.tlsConfigurator.UpdateAutoEncryptCA(pems); err != nil { + return fmt.Errorf("failed to update Connect CA certificates: %w", err) + } + case leafWatchID: + m.logger.Debug("leaf certificate watch fired - updating TLS certificate") + if u.Err != nil { + return fmt.Errorf("leaf watch returned an error: %w", u.Err) + } + + leaf, ok := u.Result.(*structs.IssuedCert) + if !ok { + return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result) + } + if err := m.tlsConfigurator.UpdateAutoEncryptCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil { + return fmt.Errorf("failed to update the agent leaf cert: %w", err) + } + } + + return nil +} + +// handleTokenUpdate is used when a notification about the agent token being updated +// is received and various watches need cancelling/restarting to use the new token. +func (m *CertMonitor) handleTokenUpdate(ctx context.Context) error { + m.logger.Debug("Agent token updated - resetting watches") + + // TODO (autoencrypt) Prepopulate the cache with the new token with + // the existing cache entry with the old token. The certificate doesn't + // need to change just because the token has. However there isn't a + // good way to make that happen and this behavior is benign enough + // that I am going to push off implementing it. + + // the agent token has been updated so we must update our leaf cert watch. + // this cancels the current watches before setting up new ones + m.cancelWatches() + + // recreate the chan for cache updates. This is a precautionary measure to ensure + // that we don't accidentally get notified for the new watches being setup before + // a blocking query in the cache returns and sends data to the old chan. In theory + // the code in agent/cache/watch.go should prevent this where we specifically check + // for context cancellation prior to sending the event. However we could cancel + // it after that check and finish setting up the new watches before getting the old + // events. Both the go routine scheduler and the OS thread scheduler would have to + // be acting up for this to happen. Regardless the way to ensure we don't get events + // for the old watches is to simply replace the chan we are expecting them from. 
+	close(m.cacheUpdates)
+	m.cacheUpdates = make(chan cache.UpdateEvent, 10)
+
+	// restart watches - this will be done with the correct token
+	cancelWatches, err := m.setupCacheWatches(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to restart watches after agent token update: %w", err)
+	}
+	m.cancelWatches = cancelWatches
+	return nil
+}
+
+// handleFallback is used when the current TLS certificate has expired and the normal
+// updating mechanisms have failed to renew it quickly enough. This function will
+// use the configured fallback mechanism to retrieve a new cert and start monitoring
+// that one.
+func (m *CertMonitor) handleFallback(ctx context.Context) error {
+	m.logger.Warn("agent's client certificate has expired")
+	// The context is mainly useful when the agent is first starting up.
+	reply, err := m.fallback(ctx)
+	if err != nil {
+		return fmt.Errorf("error when getting new agent certificate: %w", err)
+	}
+
+	return m.Update(reply)
+}
+
+// run is the private method to be spawned by the Start method for
+// executing the main monitoring loop.
+func (m *CertMonitor) run(ctx context.Context, exit chan struct{}) {
+	// The fallbackTimer is used to notify AFTER the agent's
+	// leaf certificate has expired and we need
+	// to fall back to the less secure RPC endpoint just like
+	// if the agent was starting up new.
+	//
+	// Check 10sec (fallback leeway duration) after cert
+	// expires. The agent cache should be handling the expiration
+	// and renew it before then.
+	//
+	// If there is no cert, AutoEncryptCertNotAfter returns
+	// a value in the past which immediately triggers the
+	// renew, but this case shouldn't happen because at
+	// this point, auto_encrypt was just set up
+	// successfully.
+	calcFallbackInterval := func() time.Duration {
+		certExpiry := m.tlsConfigurator.AutoEncryptCertNotAfter()
+		return certExpiry.Add(m.fallbackLeeway).Sub(time.Now())
+	}
+	fallbackTimer := time.NewTimer(calcFallbackInterval())
+
+	// cleanup for once we are stopped
+	defer func() {
+		// cancel the go routines performing the cache watches
+		m.cancelWatches()
+		// ensure we don't leak the timer's go routine
+		fallbackTimer.Stop()
+		// stop receiving notifications for token updates
+		m.tokens.StopNotify(m.tokenUpdates)
+
+		m.logger.Debug("certificate monitor has been stopped")
+
+		m.l.Lock()
+		m.cancel = nil
+		m.running = false
+		m.l.Unlock()
+
+		// this should be the final cleanup task as it's what notifies
+		// the rest of the world that this go routine has exited.
+		close(exit)
+	}()
+
+	for {
+		select {
+		case <-ctx.Done():
+			m.logger.Debug("stopping the certificate monitor")
+			return
+		case <-m.tokenUpdates.Ch:
+			m.logger.Debug("handling a token update event")
+
+			if err := m.handleTokenUpdate(ctx); err != nil {
+				m.logger.Error("error in handling token update event", "error", err)
+			}
+		case u := <-m.cacheUpdates:
+			m.logger.Debug("handling a cache update event", "correlation_id", u.CorrelationID)
+
+			if err := m.handleCacheEvent(u); err != nil {
+				m.logger.Error("error in handling cache update event", "error", err)
+			}
+
+			// reset the fallback timer as the certificate may have been updated
+			fallbackTimer.Stop()
+			fallbackTimer = time.NewTimer(calcFallbackInterval())
+		case <-fallbackTimer.C:
+			// This is a safety net in case the auto_encrypt cert doesn't get renewed
+			// in time. The agent would be stuck in that case because the watches
+			// never use the AutoEncrypt.Sign endpoint.
+ + // check auto encrypt client cert expiration + if m.tlsConfigurator.AutoEncryptCertExpired() { + if err := m.handleFallback(ctx); err != nil { + m.logger.Error("error when handling a certificate expiry event", "error", err) + fallbackTimer = time.NewTimer(m.fallbackRetry) + } else { + fallbackTimer = time.NewTimer(calcFallbackInterval()) + } + } else { + // this shouldn't be possible. We calculate the timer duration to be the certificate + // expiration time + some leeway (10s default). So whenever we get here the certificate + // should be expired. Regardless its probably worth resetting the timer. + fallbackTimer = time.NewTimer(calcFallbackInterval()) + } + } + } +} diff --git a/agent/cert-monitor/cert_monitor_test.go b/agent/cert-monitor/cert_monitor_test.go new file mode 100644 index 0000000000..dbbd63b5b2 --- /dev/null +++ b/agent/cert-monitor/cert_monitor_test.go @@ -0,0 +1,693 @@ +package certmon + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/hashicorp/consul/agent/cache" + cachetype "github.com/hashicorp/consul/agent/cache-types" + "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/agent/token" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/sdk/testutil/retry" + "github.com/hashicorp/consul/tlsutil" + "github.com/hashicorp/go-uuid" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockFallback struct { + mock.Mock +} + +func (m *mockFallback) fallback(ctx context.Context) (*structs.SignedResponse, error) { + ret := m.Called() + resp, _ := ret.Get(0).(*structs.SignedResponse) + return resp, ret.Error(1) +} + +type mockWatcher struct { + ch chan<- cache.UpdateEvent + done <-chan struct{} +} + +type mockCache struct { + mock.Mock + + lock sync.Mutex + watchers map[string][]mockWatcher +} + +func (m *mockCache) Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error { + m.lock.Lock() + key := r.CacheInfo().Key + m.watchers[key] = append(m.watchers[key], mockWatcher{ch: ch, done: ctx.Done()}) + m.lock.Unlock() + ret := m.Called(t, r, correlationID) + return ret.Error(0) +} + +func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error { + ret := m.Called(t, result, dc, token, key) + return ret.Error(0) +} + +func (m *mockCache) sendNotification(ctx context.Context, key string, u cache.UpdateEvent) bool { + m.lock.Lock() + defer m.lock.Unlock() + + watchers, ok := m.watchers[key] + if !ok || len(m.watchers) < 1 { + return false + } + + var newWatchers []mockWatcher + + for _, watcher := range watchers { + select { + case watcher.ch <- u: + newWatchers = append(newWatchers, watcher) + case <-watcher.done: + // do nothing, this watcher will be removed from the list + case <-ctx.Done(): + // return doesn't matter here really, the test is being cancelled + return true + } + } + + // this removes any already cancelled watches from being sent to + m.watchers[key] = newWatchers + + return true +} + +func newMockCache(t *testing.T) *mockCache { + mcache := mockCache{watchers: make(map[string][]mockWatcher)} + mcache.Test(t) + return &mcache +} + +func waitForChan(timer *time.Timer, ch <-chan struct{}) bool { + select { + case <-timer.C: + return false + case <-ch: + return true + } +} + +func waitForChans(timeout time.Duration, chans ...<-chan struct{}) bool { + timer := time.NewTimer(timeout) + 
defer timer.Stop() + + for _, ch := range chans { + if !waitForChan(timer, ch) { + return false + } + } + return true +} + +func testTLSConfigurator(t *testing.T) *tlsutil.Configurator { + t.Helper() + logger := testutil.Logger(t) + cfg, err := tlsutil.NewConfigurator(tlsutil.Config{AutoEncryptTLS: true}, logger) + require.NoError(t, err) + return cfg +} + +func newLeaf(t *testing.T, ca *structs.CARoot, idx uint64, expiration time.Duration) *structs.IssuedCert { + t.Helper() + + pub, priv, err := connect.TestAgentLeaf(t, "node", "foo", ca, expiration) + require.NoError(t, err) + cert, err := connect.ParseCert(pub) + require.NoError(t, err) + + spiffeID, err := connect.ParseCertURI(cert.URIs[0]) + require.NoError(t, err) + + agentID, ok := spiffeID.(*connect.SpiffeIDAgent) + require.True(t, ok, "certificate doesn't have an agent leaf cert URI") + + return &structs.IssuedCert{ + SerialNumber: cert.SerialNumber.String(), + CertPEM: pub, + PrivateKeyPEM: priv, + ValidAfter: cert.NotBefore, + ValidBefore: cert.NotAfter, + Agent: agentID.Agent, + AgentURI: agentID.URI().String(), + EnterpriseMeta: *structs.DefaultEnterpriseMeta(), + RaftIndex: structs.RaftIndex{ + CreateIndex: idx, + ModifyIndex: idx, + }, + } +} + +type testCertMonitor struct { + monitor *CertMonitor + mcache *mockCache + tls *tlsutil.Configurator + tokens *token.Store + fallback *mockFallback + + extraCACerts []string + initialCert *structs.IssuedCert + initialRoots *structs.IndexedCARoots + + // these are some variables that the CertMonitor was created with + datacenter string + nodeName string + dns []string + ips []net.IP + verifyServerHostname bool +} + +func newTestCertMonitor(t *testing.T) testCertMonitor { + t.Helper() + + tlsConfigurator := testTLSConfigurator(t) + tokens := new(token.Store) + + id, err := uuid.GenerateUUID() + require.NoError(t, err) + tokens.UpdateAgentToken(id, token.TokenSourceConfig) + + ca := connect.TestCA(t, nil) + manualCA := connect.TestCA(t, nil) + // this cert is setup to not expire quickly. this will prevent + // the test from accidentally running the fallback routine + // before we want to force that to happen. + issued := newLeaf(t, ca, 1, 10*time.Minute) + + indexedRoots := structs.IndexedCARoots{ + ActiveRootID: ca.ID, + TrustDomain: connect.TestClusterID, + Roots: []*structs.CARoot{ + ca, + }, + QueryMeta: structs.QueryMeta{ + Index: 1, + }, + } + + initialCerts := &structs.SignedResponse{ + ConnectCARoots: indexedRoots, + IssuedCert: *issued, + ManualCARoots: []string{manualCA.RootCert}, + VerifyServerHostname: true, + } + + dnsSANs := []string{"test.dev"} + ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)} + + // this chan should be unbuffered so we can detect when the fallback func has been called. + fallback := &mockFallback{} + + mcache := newMockCache(t) + rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1} + rootsReq := structs.DCSpecificRequest{Datacenter: "foo"} + mcache.On("Prepopulate", cachetype.ConnectCARootName, rootRes, "foo", "", rootsReq.CacheInfo().Key).Return(nil).Once() + + leafReq := cachetype.ConnectCALeafRequest{ + Token: tokens.AgentToken(), + Agent: "node", + Datacenter: "foo", + DNSSAN: dnsSANs, + IPSAN: ipSANs, + } + leafRes := cache.FetchResult{ + Value: issued, + Index: 1, + State: cachetype.ConnectCALeafSuccess(ca.SigningKeyID), + } + mcache.On("Prepopulate", cachetype.ConnectCALeafName, leafRes, "foo", tokens.AgentToken(), leafReq.Key()).Return(nil).Once() + + // we can assert more later but this should always be done. 
+ defer mcache.AssertExpectations(t) + + cfg := new(Config). + WithCache(mcache). + WithLogger(testutil.Logger(t)). + WithTLSConfigurator(tlsConfigurator). + WithTokens(tokens). + WithFallback(fallback.fallback). + WithDNSSANs(dnsSANs). + WithIPSANs(ipSANs). + WithDatacenter("foo"). + WithNodeName("node"). + WithFallbackLeeway(time.Nanosecond). + WithFallbackRetry(time.Millisecond) + + monitor, err := New(cfg) + require.NoError(t, err) + require.NotNil(t, monitor) + + require.NoError(t, monitor.Update(initialCerts)) + + return testCertMonitor{ + monitor: monitor, + tls: tlsConfigurator, + tokens: tokens, + mcache: mcache, + fallback: fallback, + extraCACerts: []string{manualCA.RootCert}, + initialCert: issued, + initialRoots: &indexedRoots, + datacenter: "foo", + nodeName: "node", + dns: dnsSANs, + ips: ipSANs, + verifyServerHostname: true, + } +} + +func tlsCertificateFromIssued(t *testing.T, issued *structs.IssuedCert) *tls.Certificate { + t.Helper() + + cert, err := tls.X509KeyPair([]byte(issued.CertPEM), []byte(issued.PrivateKeyPEM)) + require.NoError(t, err) + return &cert +} + +// convenience method to get a TLS Certificate from the intial issued certificate and priv key +func (cm *testCertMonitor) initialTLSCertificate(t *testing.T) *tls.Certificate { + t.Helper() + return tlsCertificateFromIssued(t, cm.initialCert) +} + +// just a convenience method to get a list of all the CA pems that we set up regardless +// of manual vs connect. +func (cm *testCertMonitor) initialCACerts() []string { + pems := cm.extraCACerts + for _, root := range cm.initialRoots.Roots { + pems = append(pems, root.RootCert) + } + return pems +} + +func (cm *testCertMonitor) assertExpectations(t *testing.T) { + cm.mcache.AssertExpectations(t) + cm.fallback.AssertExpectations(t) +} + +func TestCertMonitor_InitialCerts(t *testing.T) { + // this also ensures that the cache was prepopulated properly + cm := newTestCertMonitor(t) + + // verify that the certificate was injected into the TLS configurator correctly + require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert()) + // verify that the CA certs (both Connect and manual ones) were injected correctly + require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems()) + // verify that the auto-tls verify server hostname setting was injected correctly + require.Equal(t, cm.verifyServerHostname, cm.tls.VerifyServerHostname()) +} + +func TestCertMonitor_GoRoutineManagement(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cm := newTestCertMonitor(t) + + // ensure that the monitor is not running + require.False(t, cm.monitor.IsRunning()) + + // ensure that nothing bad happens and that it reports as stopped + require.False(t, cm.monitor.Stop()) + + // we will never send notifications so these just ignore everything + cm.mcache.On("Notify", cachetype.ConnectCARootName, &structs.DCSpecificRequest{Datacenter: cm.datacenter}, rootsWatchID).Return(nil).Times(2) + cm.mcache.On("Notify", cachetype.ConnectCALeafName, + &cachetype.ConnectCALeafRequest{ + Token: cm.tokens.AgentToken(), + Datacenter: cm.datacenter, + Agent: cm.nodeName, + DNSSAN: cm.dns, + IPSAN: cm.ips, + }, + leafWatchID, + ).Return(nil).Times(2) + + done, err := cm.monitor.Start(ctx) + require.NoError(t, err) + require.True(t, cm.monitor.IsRunning()) + _, err = cm.monitor.Start(ctx) + testutil.RequireErrorContains(t, err, "the CertMonitor is already running") + require.True(t, cm.monitor.Stop()) + + require.True(t, waitForChans(100*time.Millisecond, done), "monitor 
didn't shut down") + require.False(t, cm.monitor.IsRunning()) + done, err = cm.monitor.Start(ctx) + require.NoError(t, err) + + // ensure that context cancellation causes us to stop as well + cancel() + require.True(t, waitForChans(100*time.Millisecond, done)) + + cm.assertExpectations(t) +} + +func startedCertMonitor(t *testing.T) (context.Context, testCertMonitor) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cm := newTestCertMonitor(t) + + rootsCtx, rootsCancel := context.WithCancel(ctx) + defer rootsCancel() + leafCtx, leafCancel := context.WithCancel(ctx) + defer leafCancel() + + // initial roots watch + cm.mcache.On("Notify", cachetype.ConnectCARootName, + &structs.DCSpecificRequest{ + Datacenter: cm.datacenter, + }, + rootsWatchID). + Return(nil). + Once(). + Run(func(_ mock.Arguments) { + rootsCancel() + }) + // the initial watch after starting the monitor + cm.mcache.On("Notify", cachetype.ConnectCALeafName, + &cachetype.ConnectCALeafRequest{ + Token: cm.tokens.AgentToken(), + Datacenter: cm.datacenter, + Agent: cm.nodeName, + DNSSAN: cm.dns, + IPSAN: cm.ips, + }, + leafWatchID). + Return(nil). + Once(). + Run(func(_ mock.Arguments) { + leafCancel() + }) + + done, err := cm.monitor.Start(ctx) + require.NoError(t, err) + // this prevents logs after the test finishes + t.Cleanup(func() { + cm.monitor.Stop() + <-done + }) + + require.True(t, + waitForChans(100*time.Millisecond, rootsCtx.Done(), leafCtx.Done()), + "not all watches were started within the alotted time") + + return ctx, cm +} + +// This test ensures that the cache watches are restarted with the updated +// token after receiving a token update +func TestCertMonitor_TokenUpdate(t *testing.T) { + ctx, cm := startedCertMonitor(t) + + rootsCtx, rootsCancel := context.WithCancel(ctx) + defer rootsCancel() + leafCtx, leafCancel := context.WithCancel(ctx) + defer leafCancel() + + newToken := "8e4fe8db-162d-42d8-81ca-710fb2280ad0" + + // we expect a new roots watch because when the leaf cert watch is restarted so is the root cert watch + cm.mcache.On("Notify", cachetype.ConnectCARootName, + &structs.DCSpecificRequest{ + Datacenter: cm.datacenter, + }, + rootsWatchID). + Return(nil). + Once(). + Run(func(_ mock.Arguments) { + rootsCancel() + }) + + secondWatch := &cachetype.ConnectCALeafRequest{ + Token: newToken, + Datacenter: cm.datacenter, + Agent: cm.nodeName, + DNSSAN: cm.dns, + IPSAN: cm.ips, + } + // the new watch after updating the token + cm.mcache.On("Notify", cachetype.ConnectCALeafName, secondWatch, leafWatchID). + Return(nil). + Once(). 
+ Run(func(args mock.Arguments) { + leafCancel() + }) + + cm.tokens.UpdateAgentToken(newToken, token.TokenSourceAPI) + + require.True(t, + waitForChans(100*time.Millisecond, rootsCtx.Done(), leafCtx.Done()), + "not all watches were restarted within the alotted time") + + cm.assertExpectations(t) +} + +func TestCertMonitor_RootsUpdate(t *testing.T) { + ctx, cm := startedCertMonitor(t) + + secondCA := connect.TestCA(t, cm.initialRoots.Roots[0]) + secondRoots := structs.IndexedCARoots{ + ActiveRootID: secondCA.ID, + TrustDomain: connect.TestClusterID, + Roots: []*structs.CARoot{ + secondCA, + cm.initialRoots.Roots[0], + }, + QueryMeta: structs.QueryMeta{ + Index: 99, + }, + } + + // assert value of the CA certs prior to updating + require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems()) + + req := structs.DCSpecificRequest{Datacenter: cm.datacenter} + require.True(t, cm.mcache.sendNotification(ctx, req.CacheInfo().Key, cache.UpdateEvent{ + CorrelationID: rootsWatchID, + Result: &secondRoots, + Meta: cache.ResultMeta{ + Index: secondRoots.Index, + }, + })) + + expectedCAs := append(cm.extraCACerts, secondCA.RootCert, cm.initialRoots.Roots[0].RootCert) + + // this will wait up to 200ms (8 x 25 ms waits between the 9 requests) + retry.RunWith(&retry.Counter{Count: 9, Wait: 25 * time.Millisecond}, t, func(r *retry.R) { + require.ElementsMatch(r, expectedCAs, cm.tls.CAPems()) + }) + + cm.assertExpectations(t) +} + +func TestCertMonitor_CertUpdate(t *testing.T) { + ctx, cm := startedCertMonitor(t) + + secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute) + + // assert value of cert prior to updating the leaf + require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert()) + + key := cm.monitor.leafReq.CacheInfo().Key + + // send the new certificate - this notifies only the watchers utilizing + // the new ACL token + require.True(t, cm.mcache.sendNotification(ctx, key, cache.UpdateEvent{ + CorrelationID: leafWatchID, + Result: secondCert, + Meta: cache.ResultMeta{ + Index: secondCert.ModifyIndex, + }, + })) + + tlsCert := tlsCertificateFromIssued(t, secondCert) + + // this will wait up to 200ms (8 x 25 ms waits between the 9 requests) + retry.RunWith(&retry.Counter{Count: 9, Wait: 25 * time.Millisecond}, t, func(r *retry.R) { + require.Equal(r, tlsCert, cm.tls.Cert()) + }) + + cm.assertExpectations(t) +} + +func TestCertMonitor_Fallback(t *testing.T) { + ctx, cm := startedCertMonitor(t) + + // at this point everything is operating normally and the monitor is just + // waiting for events. We are going to send a new cert that is basically + // already expired and then allow the fallback routine to kick in. 
+ secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, time.Nanosecond) + secondCA := connect.TestCA(t, cm.initialRoots.Roots[0]) + secondRoots := structs.IndexedCARoots{ + ActiveRootID: secondCA.ID, + TrustDomain: connect.TestClusterID, + Roots: []*structs.CARoot{ + secondCA, + cm.initialRoots.Roots[0], + }, + QueryMeta: structs.QueryMeta{ + Index: 101, + }, + } + thirdCert := newLeaf(t, secondCA, 102, 10*time.Minute) + + // inject a fallback routine error to check that we rerun it quickly + cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once() + + // expect the fallback routine to be executed and setup the return + cm.fallback.On("fallback").Return(&structs.SignedResponse{ + ConnectCARoots: secondRoots, + IssuedCert: *thirdCert, + ManualCARoots: cm.extraCACerts, + VerifyServerHostname: true, + }, nil).Once() + + // Add another roots cache prepopulation expectation which should happen + // in response to executing the fallback mechanism + rootRes := cache.FetchResult{Value: &secondRoots, Index: 101} + rootsReq := structs.DCSpecificRequest{Datacenter: cm.datacenter} + cm.mcache.On("Prepopulate", cachetype.ConnectCARootName, rootRes, cm.datacenter, "", rootsReq.CacheInfo().Key).Return(nil).Once() + + // add another leaf cert cache prepopulation expectation which should happen + // in response to executing the fallback mechanism + leafReq := cachetype.ConnectCALeafRequest{ + Token: cm.tokens.AgentToken(), + Agent: cm.nodeName, + Datacenter: cm.datacenter, + DNSSAN: cm.dns, + IPSAN: cm.ips, + } + leafRes := cache.FetchResult{ + Value: thirdCert, + Index: 101, + State: cachetype.ConnectCALeafSuccess(secondCA.SigningKeyID), + } + cm.mcache.On("Prepopulate", cachetype.ConnectCALeafName, leafRes, leafReq.Datacenter, leafReq.Token, leafReq.Key()).Return(nil).Once() + + // nothing in the monitor should be looking at this as its only done + // in response to sending token updates, no need to synchronize + key := cm.monitor.leafReq.CacheInfo().Key + // send the new certificate - this notifies only the watchers utilizing + // the new ACL token + require.True(t, cm.mcache.sendNotification(ctx, key, cache.UpdateEvent{ + CorrelationID: leafWatchID, + Result: secondCert, + Meta: cache.ResultMeta{ + Index: secondCert.ModifyIndex, + }, + })) + + // if all went well we would have updated the first certificate which was pretty much expired + // causing the fallback handler to be invoked almost immediately. 
The fallback routine will + // return the response containing the third cert and second CA roots so now we should wait + // a little while and ensure they were applied to the TLS Configurator + tlsCert := tlsCertificateFromIssued(t, thirdCert) + expectedCAs := append(cm.extraCACerts, secondCA.RootCert, cm.initialRoots.Roots[0].RootCert) + + // this will wait up to 200ms (8 x 25 ms waits between the 9 requests) + retry.RunWith(&retry.Counter{Count: 9, Wait: 25 * time.Millisecond}, t, func(r *retry.R) { + require.Equal(r, tlsCert, cm.tls.Cert()) + require.ElementsMatch(r, expectedCAs, cm.tls.CAPems()) + }) + + cm.assertExpectations(t) +} + +func TestCertMonitor_New_Errors(t *testing.T) { + type testCase struct { + cfg Config + err string + } + + fallback := func(_ context.Context) (*structs.SignedResponse, error) { + return nil, fmt.Errorf("Unimplemented") + } + + tokens := new(token.Store) + + cases := map[string]testCase{ + "no-cache": { + cfg: Config{ + TLSConfigurator: testTLSConfigurator(t), + Fallback: fallback, + Tokens: tokens, + Datacenter: "foo", + NodeName: "bar", + }, + err: "CertMonitor creation requires a Cache", + }, + "no-tls-configurator": { + cfg: Config{ + Cache: cache.New(nil), + Fallback: fallback, + Tokens: tokens, + Datacenter: "foo", + NodeName: "bar", + }, + err: "CertMonitor creation requires a TLS Configurator", + }, + "no-fallback": { + cfg: Config{ + Cache: cache.New(nil), + TLSConfigurator: testTLSConfigurator(t), + Tokens: tokens, + Datacenter: "foo", + NodeName: "bar", + }, + err: "CertMonitor creation requires specifying a FallbackFunc", + }, + "no-tokens": { + cfg: Config{ + Cache: cache.New(nil), + TLSConfigurator: testTLSConfigurator(t), + Fallback: fallback, + Datacenter: "foo", + NodeName: "bar", + }, + err: "CertMonitor creation requires specifying a token store", + }, + "no-datacenter": { + cfg: Config{ + Cache: cache.New(nil), + TLSConfigurator: testTLSConfigurator(t), + Fallback: fallback, + Tokens: tokens, + NodeName: "bar", + }, + err: "CertMonitor creation requires specifying the datacenter", + }, + "no-node-name": { + cfg: Config{ + Cache: cache.New(nil), + TLSConfigurator: testTLSConfigurator(t), + Fallback: fallback, + Tokens: tokens, + Datacenter: "foo", + }, + err: "CertMonitor creation requires specifying the agent's node name", + }, + } + + for name, tcase := range cases { + t.Run(name, func(t *testing.T) { + monitor, err := New(&tcase.cfg) + testutil.RequireErrorContains(t, err, tcase.err) + require.Nil(t, monitor) + }) + } +} diff --git a/agent/cert-monitor/config.go b/agent/cert-monitor/config.go new file mode 100644 index 0000000000..a1da2841e6 --- /dev/null +++ b/agent/cert-monitor/config.go @@ -0,0 +1,137 @@ +package certmon + +import ( + "context" + "net" + "time" + + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/agent/token" + "github.com/hashicorp/consul/tlsutil" + "github.com/hashicorp/go-hclog" +) + +// FallbackFunc is used when the normal cache watch based Certificate +// updating fails to update the Certificate in time and a different +// method of updating the certificate is required. +type FallbackFunc func(context.Context) (*structs.SignedResponse, error) + +type Config struct { + // Logger is the logger to be used while running. If not set + // then no logging will be performed. + Logger hclog.Logger + + // TLSConfigurator is where the certificates and roots are set when + // they are updated. This field is required. 
+ TLSConfigurator *tlsutil.Configurator + + // Cache is an object implementing our Cache interface. The Cache + // used at runtime must be able to handle Roots and Leaf Cert watches + Cache Cache + + // Tokens is the shared token store. It is used to retrieve the current + // agent token as well as getting notifications when that token is updated. + // This field is required. + Tokens *token.Store + + // Fallback is a function to run when the normal cache updating of the + // agent's certificates has failed to work for one reason or another. + // This field is required. + Fallback FallbackFunc + + // FallbackLeeway is the amount of time after certificate expiration before + // invoking the fallback routine. If not set this will default to 10s. + FallbackLeeway time.Duration + + // FallbackRetry is the duration between Fallback invocations when the configured + // fallback routine returns an error. If not set this will default to 1m. + FallbackRetry time.Duration + + // DNSSANs is a list of DNS SANs that certificate requests should include. This + // field is optional and no extra DNS SANs will be requested if unset. 'localhost' + // is unconditionally requested by the cache implementation. + DNSSANs []string + + // IPSANs is a list of IP SANs to include in the certificate signing request. This + // field is optional and no extra IP SANs will be requested if unset. Both '127.0.0.1' + // and '::1' IP SANs are unconditionally requested by the cache implementation. + IPSANs []net.IP + + // Datacenter is the datacenter to request certificates within. This filed is required + Datacenter string + + // NodeName is the agent's node name to use when requesting certificates. This field + // is required. + NodeName string +} + +// WithCache will cause the created CertMonitor type to use the provided Cache +func (cfg *Config) WithCache(cache Cache) *Config { + cfg.Cache = cache + return cfg +} + +// WithLogger will cause the created CertMonitor type to use the provided logger +func (cfg *Config) WithLogger(logger hclog.Logger) *Config { + cfg.Logger = logger + return cfg +} + +// WithTLSConfigurator will cause the created CertMonitor type to use the provided configurator +func (cfg *Config) WithTLSConfigurator(tlsConfigurator *tlsutil.Configurator) *Config { + cfg.TLSConfigurator = tlsConfigurator + return cfg +} + +// WithTokens will cause the created CertMonitor type to use the provided token store +func (cfg *Config) WithTokens(tokens *token.Store) *Config { + cfg.Tokens = tokens + return cfg +} + +// WithFallback configures a fallback function to use if the normal update mechanisms +// fail to renew the certificate in time. 
+func (cfg *Config) WithFallback(fallback FallbackFunc) *Config { + cfg.Fallback = fallback + return cfg +} + +// WithDNSSANs configures the CertMonitor to request these DNS SANs when requesting a new +// certificate +func (cfg *Config) WithDNSSANs(sans []string) *Config { + cfg.DNSSANs = sans + return cfg +} + +// WithIPSANs configures the CertMonitor to request these IP SANs when requesting a new +// certificate +func (cfg *Config) WithIPSANs(sans []net.IP) *Config { + cfg.IPSANs = sans + return cfg +} + +// WithDatacenter configures the CertMonitor to request Certificates in this DC +func (cfg *Config) WithDatacenter(dc string) *Config { + cfg.Datacenter = dc + return cfg +} + +// WithNodeName configures the CertMonitor to request Certificates with this agent name +func (cfg *Config) WithNodeName(name string) *Config { + cfg.NodeName = name + return cfg +} + +// WithFallbackLeeway configures how long after a certificate expires before attempting to +// generarte a new certificate using the fallback mechanism. The default is 10s. +func (cfg *Config) WithFallbackLeeway(leeway time.Duration) *Config { + cfg.FallbackLeeway = leeway + return cfg +} + +// WithFallbackRetry controls how quickly we will make subsequent invocations of +// the fallback func in the case of it erroring out. +func (cfg *Config) WithFallbackRetry(after time.Duration) *Config { + cfg.FallbackRetry = after + return cfg +} diff --git a/agent/connect/testing_ca.go b/agent/connect/testing_ca.go index fdd37ad4c1..e623c1872f 100644 --- a/agent/connect/testing_ca.go +++ b/agent/connect/testing_ca.go @@ -168,7 +168,21 @@ func TestCAWithKeyType(t testing.T, xc *structs.CARoot, keyType string, keyBits return testCA(t, xc, keyType, keyBits) } -func testLeaf(t testing.T, service string, namespace string, root *structs.CARoot, keyType string, keyBits int) (string, string, error) { +// testCertID is an interface to be implemented the various spiffe ID / CertURI types +// It adds an addition CommonName method to the CertURI interface to prevent the need +// for any type switching on the actual CertURI's concrete type in order to figure +// out its common name +type testCertID interface { + CommonName() string + CertURI +} + +func testLeafWithID(t testing.T, spiffeId testCertID, root *structs.CARoot, keyType string, keyBits int, expiration time.Duration) (string, string, error) { + + if expiration == 0 { + // this is 10 years + expiration = 10 * 365 * 24 * time.Hour + } // Parse the CA cert and signing key from the root cert := root.SigningCert if cert == "" { @@ -183,14 +197,6 @@ func testLeaf(t testing.T, service string, namespace string, root *structs.CARoo return "", "", fmt.Errorf("error parsing signing key: %s", err) } - // Build the SPIFFE ID - spiffeId := &SpiffeIDService{ - Host: fmt.Sprintf("%s.consul", TestClusterID), - Namespace: namespace, - Datacenter: "dc1", - Service: service, - } - // The serial number for the cert sn, err := testSerialNumber() if err != nil { @@ -211,7 +217,7 @@ func testLeaf(t testing.T, service string, namespace string, root *structs.CARoo // Cert template for generation template := x509.Certificate{ SerialNumber: sn, - Subject: pkix.Name{CommonName: ServiceCN(service, "default", TestClusterID)}, + Subject: pkix.Name{CommonName: spiffeId.CommonName()}, URIs: []*url.URL{spiffeId.URI()}, SignatureAlgorithm: SigAlgoForKeyType(rootKeyType), BasicConstraintsValid: true, @@ -223,7 +229,7 @@ func testLeaf(t testing.T, service string, namespace string, root *structs.CARoo x509.ExtKeyUsageClientAuth, 
 			x509.ExtKeyUsageServerAuth,
 		},
-		NotAfter: time.Now().AddDate(10, 0, 0),
+		NotAfter: time.Now().Add(expiration),
 		NotBefore: time.Now(),
 		AuthorityKeyId: testKeyID(t, caSigner.Public()),
 		SubjectKeyId: testKeyID(t, pkSigner.Public()),
@@ -244,6 +250,29 @@ func testLeaf(t testing.T, service string, namespace string, root *structs.CARoo
 	return buf.String(), pkPEM, nil
 }
 
+func TestAgentLeaf(t testing.T, node string, datacenter string, root *structs.CARoot, expiration time.Duration) (string, string, error) {
+	// Build the SPIFFE ID
+	spiffeId := &SpiffeIDAgent{
+		Host: fmt.Sprintf("%s.consul", TestClusterID),
+		Datacenter: datacenter,
+		Agent: node,
+	}
+
+	return testLeafWithID(t, spiffeId, root, DefaultPrivateKeyType, DefaultPrivateKeyBits, expiration)
+}
+
+func testLeaf(t testing.T, service string, namespace string, root *structs.CARoot, keyType string, keyBits int) (string, string, error) {
+	// Build the SPIFFE ID
+	spiffeId := &SpiffeIDService{
+		Host: fmt.Sprintf("%s.consul", TestClusterID),
+		Namespace: namespace,
+		Datacenter: "dc1",
+		Service: service,
+	}
+
+	return testLeafWithID(t, spiffeId, root, keyType, keyBits, 0)
+}
+
 // TestLeaf returns a valid leaf certificate and it's private key for the named
 // service with the given CA Root.
 func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string) {
diff --git a/agent/connect/uri_agent.go b/agent/connect/uri_agent.go
index 00fb9335b5..86205dbdcd 100644
--- a/agent/connect/uri_agent.go
+++ b/agent/connect/uri_agent.go
@@ -27,3 +27,7 @@ func (id *SpiffeIDAgent) URI() *url.URL {
 func (id *SpiffeIDAgent) Authorize(_ *structs.Intention) (bool, bool) {
 	return false, false
 }
+
+func (id *SpiffeIDAgent) CommonName() string {
+	return AgentCN(id.Agent, id.Host)
+}
diff --git a/agent/connect/uri_service.go b/agent/connect/uri_service.go
index ed22c173e2..405bdcbd96 100644
--- a/agent/connect/uri_service.go
+++ b/agent/connect/uri_service.go
@@ -40,3 +40,7 @@ func (id *SpiffeIDService) Authorize(ixn *structs.Intention) (bool, bool) {
 	// Match, return allow value
 	return ixn.Action == structs.IntentionActionAllow, true
 }
+
+func (id *SpiffeIDService) CommonName() string {
+	return ServiceCN(id.Service, id.Namespace, id.Host)
+}
diff --git a/agent/consul/auto_encrypt.go b/agent/consul/auto_encrypt.go
index e4bca49005..b9c3bbd41e 100644
--- a/agent/consul/auto_encrypt.go
+++ b/agent/consul/auto_encrypt.go
@@ -65,9 +65,9 @@ func (c *Client) autoEncryptCSR(extraDNSSANs []string, extraIPSANs []net.IP) (st
 	return pkPEM, csr, nil
 }
 
-func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string, port int, token string, extraDNSSANs []string, extraIPSANs []net.IP) (*structs.SignedResponse, string, error) {
-	errFn := func(err error) (*structs.SignedResponse, string, error) {
-		return nil, "", err
+func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string, port int, token string, extraDNSSANs []string, extraIPSANs []net.IP) (*structs.SignedResponse, error) {
+	errFn := func(err error) (*structs.SignedResponse, error) {
+		return nil, err
 	}
 
 	// Check if we know about a server already through gossip. Depending on
@@ -120,7 +120,8 @@ func (c *Client) RequestAutoEncryptCerts(ctx context.Context, servers []string,
 			addr := net.TCPAddr{IP: ip, Port: port}
 
 			if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, "AutoEncrypt.Sign", &args, &reply); err == nil {
-				return &reply, pkPEM, nil
+				reply.IssuedCert.PrivateKeyPEM = pkPEM
+				return &reply, nil
 			} else {
 				c.logger.Warn("AutoEncrypt failed", "error", err)
 			}
diff --git a/agent/consul/auto_encrypt_test.go b/agent/consul/auto_encrypt_test.go
index fd8725725f..8dd04e4166 100644
--- a/agent/consul/auto_encrypt_test.go
+++ b/agent/consul/auto_encrypt_test.go
@@ -104,7 +104,7 @@ func TestAutoEncrypt_RequestAutoEncryptCerts(t *testing.T) {
 	doneCh := make(chan struct{})
 	var err error
 	go func() {
-		_, _, err = c1.RequestAutoEncryptCerts(ctx, servers, port, token, nil, nil)
+		_, err = c1.RequestAutoEncryptCerts(ctx, servers, port, token, nil, nil)
 		close(doneCh)
 	}()
 	select {
diff --git a/agent/proxycfg/testing.go b/agent/proxycfg/testing.go
index 2a36eb1929..910398c329 100644
--- a/agent/proxycfg/testing.go
+++ b/agent/proxycfg/testing.go
@@ -1995,7 +1995,7 @@ func (ct *ControllableCacheType) RegisterOptions() cache.RegisterOptions {
 	return cache.RegisterOptions{
 		Refresh: ct.blocking,
 		SupportsBlocking: ct.blocking,
-		RefreshTimeout: 10 * time.Minute,
+		QueryTimeout: 10 * time.Minute,
 	}
 }
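Usage note (not part of the patch): the Config type introduced above is a builder, so a consumer chains the With* setters and hands the result to certmon.New before starting the monitor. The sketch below shows that wiring end to end, including the optional FallbackLeeway/FallbackRetry knobs with their documented defaults (10s and 1m). It is only an illustration: the package alias certmon, the package name and helper name in the sketch, the logger name, and the example SAN/datacenter/node values are assumptions, and the caller is expected to supply a real cache, TLS configurator, token store, and fallback function.

// Illustrative sketch only; not part of this change.
package certmonexample

import (
	"time"

	"github.com/hashicorp/consul/agent/cache"
	certmon "github.com/hashicorp/consul/agent/cert-monitor"
	"github.com/hashicorp/consul/agent/token"
	"github.com/hashicorp/consul/lib"
	"github.com/hashicorp/consul/tlsutil"
	"github.com/hashicorp/go-hclog"
)

// newCertMonitor (hypothetical helper) builds and starts a certificate monitor.
// The fallback is invoked when cache-driven renewal has not produced a usable
// certificate by FallbackLeeway after expiration.
func newCertMonitor(
	c *cache.Cache,
	tlsCfg *tlsutil.Configurator,
	tokens *token.Store,
	fallback certmon.FallbackFunc,
	shutdownCh chan struct{},
) (*certmon.CertMonitor, error) {
	cfg := new(certmon.Config).
		WithCache(c).
		WithLogger(hclog.Default().Named("cert-monitor")).
		WithTLSConfigurator(tlsCfg).
		WithTokens(tokens).
		WithFallback(fallback).
		WithDNSSANs([]string{"client-1.example.internal"}). // extra SANs are optional
		WithDatacenter("dc1").
		WithNodeName("client-1").
		// The defaults are spelled out explicitly here: 10s of leeway after
		// expiry before the fallback runs, retried every minute on error.
		WithFallbackLeeway(10 * time.Second).
		WithFallbackRetry(time.Minute)

	monitor, err := certmon.New(cfg)
	if err != nil {
		return nil, err
	}

	// Tie the monitor's goroutines to the caller's shutdown channel.
	if _, err := monitor.Start(&lib.StopChannelContext{StopCh: shutdownCh}); err != nil {
		return nil, err
	}
	return monitor, nil
}

Per the field docs above, Cache, Tokens, Fallback, Datacenter, and NodeName are required; the SAN lists and the fallback timing knobs may be omitted, in which case the documented defaults apply.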
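Similarly, the new TestAgentLeaf helper lets tests mint agent (rather than service) certificates with an arbitrary expiration, which is what allows cert-monitor tests to drive the expiry and fallback paths. A hypothetical test using it might look like the following; the test name, node/datacenter values, and the 50ms expiration are made up for illustration, while connect.TestCA and testify's require are existing test utilities in this repository.

// Illustrative sketch only; not part of this change.
package connect_test

import (
	"testing"
	"time"

	"github.com/hashicorp/consul/agent/connect"
	"github.com/stretchr/testify/require"
)

func TestAgentLeaf_ShortExpiration(t *testing.T) {
	ca := connect.TestCA(t, nil)

	// Request an agent certificate that expires almost immediately, the kind
	// of input a cert-monitor test needs in order to exercise renewal.
	certPEM, keyPEM, err := connect.TestAgentLeaf(t, "node-1", "dc1", ca, 50*time.Millisecond)
	require.NoError(t, err)
	require.NotEmpty(t, certPEM)
	require.NotEmpty(t, keyPEM)
}

Passing 0 for expiration keeps the previous 10-year default, which is why the refactored testLeaf (and therefore the existing TestLeaf callers) are unaffected by this change.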