mirror of https://github.com/status-im/consul.git
Ensure certificates retrieved through the cache get persisted with auto-config (#8409)
This commit is contained in:
parent
dbb461a5d3
commit
1a78cf9b4c
|
@ -544,7 +544,8 @@ func New(options ...AgentOption) (*Agent, error) {
|
|||
WithNodeName(a.config.NodeName).
|
||||
WithFallback(a.autoConfigFallbackTLS).
|
||||
WithLogger(a.logger.Named(logging.AutoConfig)).
|
||||
WithTokens(a.tokens)
|
||||
WithTokens(a.tokens).
|
||||
WithPersistence(a.autoConfigPersist)
|
||||
acCertMon, err := certmon.New(cmConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -888,9 +889,19 @@ func (a *Agent) autoEncryptInitialCertificate(ctx context.Context) (*structs.Sig
|
|||
}
|
||||
|
||||
func (a *Agent) autoConfigFallbackTLS(ctx context.Context) (*structs.SignedResponse, error) {
|
||||
if a.autoConf == nil {
|
||||
return nil, fmt.Errorf("AutoConfig manager has not been created yet")
|
||||
}
|
||||
return a.autoConf.FallbackTLS(ctx)
|
||||
}
|
||||
|
||||
func (a *Agent) autoConfigPersist(resp *structs.SignedResponse) error {
|
||||
if a.autoConf == nil {
|
||||
return fmt.Errorf("AutoConfig manager has not been created yet")
|
||||
}
|
||||
return a.autoConf.RecordUpdatedCerts(resp)
|
||||
}
|
||||
|
||||
func (a *Agent) listenAndServeGRPC() error {
|
||||
if len(a.config.GRPCAddrs) < 1 {
|
||||
return nil
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/jsonpb"
|
||||
"github.com/google/tcpproxy"
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
||||
|
@ -32,6 +33,7 @@ import (
|
|||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/proto/pbautoconf"
|
||||
"github.com/hashicorp/consul/sdk/freeport"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
|
@ -4722,21 +4724,28 @@ func TestAutoConfig_Integration(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := StartTestAgent(t, TestAgent{Name: "test-client", HCL: `
|
||||
bootstrap = false
|
||||
server = false
|
||||
ca_file = "` + caFile + `"
|
||||
verify_outgoing = true
|
||||
verify_server_hostname = true
|
||||
node_name = "test-client"
|
||||
ports {
|
||||
server = ` + strconv.Itoa(srv.Config.RPCBindAddr.Port) + `
|
||||
}
|
||||
auto_config {
|
||||
enabled = true
|
||||
intro_token = "` + token + `"
|
||||
server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"]
|
||||
}`})
|
||||
client := StartTestAgent(t, TestAgent{Name: "test-client",
|
||||
Overrides: `
|
||||
connect {
|
||||
test_ca_leaf_root_change_spread = "1ns"
|
||||
}
|
||||
`,
|
||||
HCL: `
|
||||
bootstrap = false
|
||||
server = false
|
||||
ca_file = "` + caFile + `"
|
||||
verify_outgoing = true
|
||||
verify_server_hostname = true
|
||||
node_name = "test-client"
|
||||
ports {
|
||||
server = ` + strconv.Itoa(srv.Config.RPCBindAddr.Port) + `
|
||||
}
|
||||
auto_config {
|
||||
enabled = true
|
||||
intro_token = "` + token + `"
|
||||
server_addresses = ["` + srv.Config.RPCBindAddr.String() + `"]
|
||||
}`,
|
||||
})
|
||||
|
||||
defer client.Shutdown()
|
||||
|
||||
|
@ -4776,6 +4785,21 @@ func TestAutoConfig_Integration(t *testing.T) {
|
|||
// ensure that a new cert gets generated and pushed into the TLS configurator
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
require.NotEqual(r, cert1, client.Agent.tlsConfigurator.Cert())
|
||||
|
||||
// check that the on disk certs match expectations
|
||||
data, err := ioutil.ReadFile(filepath.Join(client.DataDir, "auto-config.json"))
|
||||
require.NoError(r, err)
|
||||
rdr := strings.NewReader(string(data))
|
||||
|
||||
var resp pbautoconf.AutoConfigResponse
|
||||
pbUnmarshaler := &jsonpb.Unmarshaler{
|
||||
AllowUnknownFields: false,
|
||||
}
|
||||
require.NoError(r, pbUnmarshaler.Unmarshal(rdr, &resp), "data: %s", data)
|
||||
|
||||
actual, err := tls.X509KeyPair([]byte(resp.Certificate.CertPEM), []byte(resp.Certificate.PrivateKeyPEM))
|
||||
require.NoError(r, err)
|
||||
require.Equal(r, client.Agent.tlsConfigurator.Cert(), &actual)
|
||||
})
|
||||
|
||||
// spot check that we now have an ACL token
|
||||
|
|
|
@ -55,15 +55,16 @@ var (
|
|||
// then we will need to add some locking here. I am deferring that for now
|
||||
// to help ease the review of this already large PR.
|
||||
type AutoConfig struct {
|
||||
builderOpts config.BuilderOpts
|
||||
logger hclog.Logger
|
||||
directRPC DirectRPC
|
||||
waiter *lib.RetryWaiter
|
||||
overrides []config.Source
|
||||
certMonitor CertMonitor
|
||||
config *config.RuntimeConfig
|
||||
autoConfigData string
|
||||
cancel context.CancelFunc
|
||||
builderOpts config.BuilderOpts
|
||||
logger hclog.Logger
|
||||
directRPC DirectRPC
|
||||
waiter *lib.RetryWaiter
|
||||
overrides []config.Source
|
||||
certMonitor CertMonitor
|
||||
config *config.RuntimeConfig
|
||||
autoConfigResponse *pbautoconf.AutoConfigResponse
|
||||
autoConfigData string
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// New creates a new AutoConfig object for providing automatic
|
||||
|
@ -493,6 +494,8 @@ func (ac *AutoConfig) generateCSR() (csr string, key string, err error) {
|
|||
// config data to be used during a call to ReadConfig, updating the
|
||||
// tls Configurator and prepopulating the cache.
|
||||
func (ac *AutoConfig) update(resp *pbautoconf.AutoConfigResponse) error {
|
||||
ac.autoConfigResponse = resp
|
||||
|
||||
if err := ac.updateConfigFromResponse(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -591,3 +594,18 @@ func (ac *AutoConfig) FallbackTLS(ctx context.Context) (*structs.SignedResponse,
|
|||
|
||||
return extractSignedResponse(resp)
|
||||
}
|
||||
|
||||
func (ac *AutoConfig) RecordUpdatedCerts(resp *structs.SignedResponse) error {
|
||||
var err error
|
||||
ac.autoConfigResponse.ExtraCACertificates = resp.ManualCARoots
|
||||
ac.autoConfigResponse.CARoots, err = translateCARootsToProtobuf(&resp.ConnectCARoots)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ac.autoConfigResponse.Certificate, err = translateIssuedCertToProtobuf(&resp.IssuedCert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ac.recordResponse(ac.autoConfigResponse)
|
||||
}
|
||||
|
|
|
@ -226,3 +226,34 @@ func mapstructureTranslateToStructs(in interface{}, out interface{}) error {
|
|||
|
||||
return decoder.Decode(in)
|
||||
}
|
||||
|
||||
func translateCARootsToProtobuf(in *structs.IndexedCARoots) (*pbconnect.CARoots, error) {
|
||||
var out pbconnect.CARoots
|
||||
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
|
||||
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func translateIssuedCertToProtobuf(in *structs.IssuedCert) (*pbconnect.IssuedCert, error) {
|
||||
var out pbconnect.IssuedCert
|
||||
if err := mapstructureTranslateToProtobuf(in, &out); err != nil {
|
||||
return nil, fmt.Errorf("Failed to re-encode CA Roots: %w", err)
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func mapstructureTranslateToProtobuf(in interface{}, out interface{}) error {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: proto.HookTimeToPBTimestamp,
|
||||
Result: out,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return decoder.Decode(in)
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ type CertMonitor struct {
|
|||
tokens *token.Store
|
||||
leafReq cachetype.ConnectCALeafRequest
|
||||
rootsReq structs.DCSpecificRequest
|
||||
persist PersistFunc
|
||||
fallback FallbackFunc
|
||||
fallbackLeeway time.Duration
|
||||
fallbackRetry time.Duration
|
||||
|
@ -66,6 +67,11 @@ type CertMonitor struct {
|
|||
// events from the token store when the Agent
|
||||
// token is updated.
|
||||
tokenUpdates token.Notifier
|
||||
|
||||
// this is used to keep a local copy of the certs
|
||||
// keys and ca certs. It will be used to persist
|
||||
// all of the local state at once.
|
||||
certs structs.SignedResponse
|
||||
}
|
||||
|
||||
// New creates a new CertMonitor for automatically rotating
|
||||
|
@ -115,6 +121,7 @@ func New(config *Config) (*CertMonitor, error) {
|
|||
cache: config.Cache,
|
||||
tokens: config.Tokens,
|
||||
tlsConfigurator: config.TLSConfigurator,
|
||||
persist: config.Persist,
|
||||
fallback: config.Fallback,
|
||||
fallbackLeeway: config.FallbackLeeway,
|
||||
fallbackRetry: config.FallbackRetry,
|
||||
|
@ -135,6 +142,8 @@ func (m *CertMonitor) Update(certs *structs.SignedResponse) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
m.certs = *certs
|
||||
|
||||
if err := m.populateCache(certs); err != nil {
|
||||
return fmt.Errorf("error populating cache with certificates: %w", err)
|
||||
}
|
||||
|
@ -306,6 +315,8 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
|
|||
return fmt.Errorf("invalid type for roots watch response: %T", u.Result)
|
||||
}
|
||||
|
||||
m.certs.ConnectCARoots = *roots
|
||||
|
||||
var pems []string
|
||||
for _, root := range roots.Roots {
|
||||
pems = append(pems, root.RootCert)
|
||||
|
@ -314,6 +325,13 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
|
|||
if err := m.tlsConfigurator.UpdateAutoTLSCA(pems); err != nil {
|
||||
return fmt.Errorf("failed to update Connect CA certificates: %w", err)
|
||||
}
|
||||
|
||||
if m.persist != nil {
|
||||
copy := m.certs
|
||||
if err := m.persist(©); err != nil {
|
||||
return fmt.Errorf("failed to persist certificate package: %w", err)
|
||||
}
|
||||
}
|
||||
case leafWatchID:
|
||||
m.logger.Debug("leaf certificate watch fired - updating TLS certificate")
|
||||
if u.Err != nil {
|
||||
|
@ -324,9 +342,19 @@ func (m *CertMonitor) handleCacheEvent(u cache.UpdateEvent) error {
|
|||
if !ok {
|
||||
return fmt.Errorf("invalid type for agent leaf cert watch response: %T", u.Result)
|
||||
}
|
||||
|
||||
m.certs.IssuedCert = *leaf
|
||||
|
||||
if err := m.tlsConfigurator.UpdateAutoTLSCert(leaf.CertPEM, leaf.PrivateKeyPEM); err != nil {
|
||||
return fmt.Errorf("failed to update the agent leaf cert: %w", err)
|
||||
}
|
||||
|
||||
if m.persist != nil {
|
||||
copy := m.certs
|
||||
if err := m.persist(©); err != nil {
|
||||
return fmt.Errorf("failed to persist certificate package: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -380,6 +408,11 @@ func (m *CertMonitor) handleFallback(ctx context.Context) error {
|
|||
return fmt.Errorf("error when getting new agent certificate: %w", err)
|
||||
}
|
||||
|
||||
if m.persist != nil {
|
||||
if err := m.persist(reply); err != nil {
|
||||
return fmt.Errorf("failed to persist certificate package: %w", err)
|
||||
}
|
||||
}
|
||||
return m.Update(reply)
|
||||
}
|
||||
|
||||
|
|
|
@ -33,6 +33,14 @@ func (m *mockFallback) fallback(ctx context.Context) (*structs.SignedResponse, e
|
|||
return resp, ret.Error(1)
|
||||
}
|
||||
|
||||
type mockPersist struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockPersist) persist(resp *structs.SignedResponse) error {
|
||||
return m.Called(resp).Error(0)
|
||||
}
|
||||
|
||||
type mockWatcher struct {
|
||||
ch chan<- cache.UpdateEvent
|
||||
done <-chan struct{}
|
||||
|
@ -159,6 +167,7 @@ type testCertMonitor struct {
|
|||
tls *tlsutil.Configurator
|
||||
tokens *token.Store
|
||||
fallback *mockFallback
|
||||
persist *mockPersist
|
||||
|
||||
extraCACerts []string
|
||||
initialCert *structs.IssuedCert
|
||||
|
@ -210,8 +219,10 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
|
|||
dnsSANs := []string{"test.dev"}
|
||||
ipSANs := []net.IP{net.IPv4(198, 18, 0, 1)}
|
||||
|
||||
// this chan should be unbuffered so we can detect when the fallback func has been called.
|
||||
fallback := &mockFallback{}
|
||||
fallback.Test(t)
|
||||
persist := &mockPersist{}
|
||||
persist.Test(t)
|
||||
|
||||
mcache := newMockCache(t)
|
||||
rootRes := cache.FetchResult{Value: &indexedRoots, Index: 1}
|
||||
|
@ -246,7 +257,8 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
|
|||
WithDatacenter("foo").
|
||||
WithNodeName("node").
|
||||
WithFallbackLeeway(time.Nanosecond).
|
||||
WithFallbackRetry(time.Millisecond)
|
||||
WithFallbackRetry(time.Millisecond).
|
||||
WithPersistence(persist.persist)
|
||||
|
||||
monitor, err := New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
@ -259,6 +271,7 @@ func newTestCertMonitor(t *testing.T) testCertMonitor {
|
|||
tls: tlsConfigurator,
|
||||
tokens: tokens,
|
||||
mcache: mcache,
|
||||
persist: persist,
|
||||
fallback: fallback,
|
||||
extraCACerts: []string{manualCA.RootCert},
|
||||
initialCert: issued,
|
||||
|
@ -298,6 +311,7 @@ func (cm *testCertMonitor) initialCACerts() []string {
|
|||
func (cm *testCertMonitor) assertExpectations(t *testing.T) {
|
||||
cm.mcache.AssertExpectations(t)
|
||||
cm.fallback.AssertExpectations(t)
|
||||
cm.persist.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestCertMonitor_InitialCerts(t *testing.T) {
|
||||
|
@ -473,6 +487,13 @@ func TestCertMonitor_RootsUpdate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
cm.persist.On("persist", &structs.SignedResponse{
|
||||
IssuedCert: *cm.initialCert,
|
||||
ManualCARoots: cm.extraCACerts,
|
||||
ConnectCARoots: secondRoots,
|
||||
VerifyServerHostname: cm.verifyServerHostname,
|
||||
}).Return(nil).Once()
|
||||
|
||||
// assert value of the CA certs prior to updating
|
||||
require.ElementsMatch(t, cm.initialCACerts(), cm.tls.CAPems())
|
||||
|
||||
|
@ -500,6 +521,13 @@ func TestCertMonitor_CertUpdate(t *testing.T) {
|
|||
|
||||
secondCert := newLeaf(t, cm.initialRoots.Roots[0], 100, 10*time.Minute)
|
||||
|
||||
cm.persist.On("persist", &structs.SignedResponse{
|
||||
IssuedCert: *secondCert,
|
||||
ManualCARoots: cm.extraCACerts,
|
||||
ConnectCARoots: *cm.initialRoots,
|
||||
VerifyServerHostname: cm.verifyServerHostname,
|
||||
}).Return(nil).Once()
|
||||
|
||||
// assert value of cert prior to updating the leaf
|
||||
require.Equal(t, cm.initialTLSCertificate(t), cm.tls.Cert())
|
||||
|
||||
|
@ -549,13 +577,23 @@ func TestCertMonitor_Fallback(t *testing.T) {
|
|||
// inject a fallback routine error to check that we rerun it quickly
|
||||
cm.fallback.On("fallback").Return(nil, fmt.Errorf("induced error")).Once()
|
||||
|
||||
// expect the fallback routine to be executed and setup the return
|
||||
cm.fallback.On("fallback").Return(&structs.SignedResponse{
|
||||
fallbackResp := &structs.SignedResponse{
|
||||
ConnectCARoots: secondRoots,
|
||||
IssuedCert: *thirdCert,
|
||||
ManualCARoots: cm.extraCACerts,
|
||||
VerifyServerHostname: true,
|
||||
}, nil).Once()
|
||||
}
|
||||
// expect the fallback routine to be executed and setup the return
|
||||
cm.fallback.On("fallback").Return(fallbackResp, nil).Once()
|
||||
|
||||
cm.persist.On("persist", &structs.SignedResponse{
|
||||
IssuedCert: *secondCert,
|
||||
ConnectCARoots: *cm.initialRoots,
|
||||
ManualCARoots: cm.extraCACerts,
|
||||
VerifyServerHostname: cm.verifyServerHostname,
|
||||
}).Return(nil).Once()
|
||||
|
||||
cm.persist.On("persist", fallbackResp).Return(nil).Once()
|
||||
|
||||
// Add another roots cache prepopulation expectation which should happen
|
||||
// in response to executing the fallback mechanism
|
||||
|
|
|
@ -16,6 +16,9 @@ import (
|
|||
// method of updating the certificate is required.
|
||||
type FallbackFunc func(context.Context) (*structs.SignedResponse, error)
|
||||
|
||||
// PersistFunc is used to persist the data from a signed response
|
||||
type PersistFunc func(*structs.SignedResponse) error
|
||||
|
||||
type Config struct {
|
||||
// Logger is the logger to be used while running. If not set
|
||||
// then no logging will be performed.
|
||||
|
@ -34,6 +37,9 @@ type Config struct {
|
|||
// This field is required.
|
||||
Tokens *token.Store
|
||||
|
||||
// Persist is a function to run when there are new certs or keys
|
||||
Persist PersistFunc
|
||||
|
||||
// Fallback is a function to run when the normal cache updating of the
|
||||
// agent's certificates has failed to work for one reason or another.
|
||||
// This field is required.
|
||||
|
@ -135,3 +141,10 @@ func (cfg *Config) WithFallbackRetry(after time.Duration) *Config {
|
|||
cfg.FallbackRetry = after
|
||||
return cfg
|
||||
}
|
||||
|
||||
// WithPersistence will configure the CertMonitor to use this callback for persisting
|
||||
// a new TLS configuration.
|
||||
func (cfg *Config) WithPersistence(persist PersistFunc) *Config {
|
||||
cfg.Persist = persist
|
||||
return cfg
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue