// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package leafcert import ( "bytes" "context" "crypto/rand" "crypto/x509" "encoding/pem" "errors" "fmt" "math/big" "sync" "sync/atomic" "testing" "time" "github.com/hashicorp/consul/agent/cacheshim" "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/sdk/testutil" ) // NewTestManager returns a *Manager that is pre-configured to use a mock RPC // implementation that can sign certs, and an in-memory CA roots reader that // interacts well with it. func NewTestManager(t *testing.T, mut func(*Config)) (*Manager, *TestSigner) { signer := newTestSigner(t, nil, nil) deps := Deps{ Logger: testutil.Logger(t), RootsReader: signer.RootsReader, CertSigner: signer, Config: Config{ // Override the root-change spread so we don't have to wait up to 20 seconds // to see root changes work. Can be changed back for specific tests that // need to test this, Note it's not 0 since that used default but is // effectively the same. TestOverrideCAChangeInitialDelay: 1 * time.Microsecond, }, } if mut != nil { mut(&deps.Config) } m := NewManager(deps) t.Cleanup(m.Stop) return m, signer } // TestSigner implements NetRPC and handles leaf signing operations type TestSigner struct { caLock sync.Mutex ca *structs.CARoot prevRoots []*structs.CARoot // remember prior ones IDGenerator *atomic.Uint64 RootsReader *testRootsReader signCallLock sync.Mutex signCallErrors []error signCallErrorCount uint64 signCallCapture []*structs.CASignRequest } var _ CertSigner = (*TestSigner)(nil) var ReplyWithExpiredCert = errors.New("reply with expired cert") func newTestSigner(t *testing.T, idGenerator *atomic.Uint64, rootsReader *testRootsReader) *TestSigner { if idGenerator == nil { idGenerator = &atomic.Uint64{} } if rootsReader == nil { rootsReader = newTestRootsReader(t) } s := &TestSigner{ IDGenerator: idGenerator, RootsReader: rootsReader, } return s } func (s *TestSigner) SetSignCallErrors(errs ...error) { s.signCallLock.Lock() defer s.signCallLock.Unlock() s.signCallErrors = append(s.signCallErrors, errs...) } func (s *TestSigner) GetSignCallErrorCount() uint64 { s.signCallLock.Lock() defer s.signCallLock.Unlock() return s.signCallErrorCount } func (s *TestSigner) UpdateCA(t *testing.T, ca *structs.CARoot) *structs.CARoot { if ca == nil { ca = connect.TestCA(t, nil) } roots := &structs.IndexedCARoots{ ActiveRootID: ca.ID, TrustDomain: connect.TestTrustDomain, Roots: []*structs.CARoot{ca}, QueryMeta: structs.QueryMeta{Index: s.nextIndex()}, } // Update the signer first. s.caLock.Lock() { s.ca = ca roots.Roots = append(roots.Roots, s.prevRoots...) // Remember for the next rotation. dup := ca.Clone() dup.Active = false s.prevRoots = append(s.prevRoots, dup) } s.caLock.Unlock() // Then trigger an event when updating the roots. s.RootsReader.Set(roots) return ca } func (s *TestSigner) nextIndex() uint64 { return s.IDGenerator.Add(1) } func (s *TestSigner) getCA() *structs.CARoot { s.caLock.Lock() defer s.caLock.Unlock() return s.ca } func (s *TestSigner) GetCapture(idx int) *structs.CASignRequest { s.signCallLock.Lock() defer s.signCallLock.Unlock() if len(s.signCallCapture) > idx { return s.signCallCapture[idx] } return nil } func (s *TestSigner) SignCert(ctx context.Context, req *structs.CASignRequest) (*structs.IssuedCert, error) { useExpiredCert := false s.signCallLock.Lock() s.signCallCapture = append(s.signCallCapture, req) if len(s.signCallErrors) > 0 { err := s.signCallErrors[0] s.signCallErrors = s.signCallErrors[1:] if err == ReplyWithExpiredCert { useExpiredCert = true } else if err != nil { s.signCallErrorCount++ s.signCallLock.Unlock() return nil, err } } s.signCallLock.Unlock() // parts of this were inlined from CAManager and the connect ca provider ca := s.getCA() if ca == nil { return nil, fmt.Errorf("must call UpdateCA at least once") } csr, err := connect.ParseCSR(req.CSR) if err != nil { return nil, fmt.Errorf("error parsing CSR: %w", err) } connect.HackSANExtensionForCSR(csr) spiffeID, err := connect.ParseCertURI(csr.URIs[0]) if err != nil { return nil, fmt.Errorf("error parsing CSR URI: %w", err) } var isService bool var serviceID *connect.SpiffeIDService var workloadID *connect.SpiffeIDWorkloadIdentity switch spiffeID.(type) { case *connect.SpiffeIDService: isService = true serviceID = spiffeID.(*connect.SpiffeIDService) case *connect.SpiffeIDWorkloadIdentity: workloadID = spiffeID.(*connect.SpiffeIDWorkloadIdentity) default: return nil, fmt.Errorf("unexpected spiffeID type %T", spiffeID) } signer, err := connect.ParseSigner(ca.SigningKey) if err != nil { return nil, fmt.Errorf("error parsing CA signing key: %w", err) } keyId, err := connect.KeyId(signer.Public()) if err != nil { return nil, fmt.Errorf("error forming CA key id from public key: %w", err) } subjectKeyID, err := connect.KeyId(csr.PublicKey) if err != nil { return nil, fmt.Errorf("error forming subject key id from public key: %w", err) } caCert, err := connect.ParseCert(ca.RootCert) if err != nil { return nil, fmt.Errorf("error parsing CA root cert pem: %w", err) } const expiration = 10 * time.Minute now := time.Now() template := x509.Certificate{ SerialNumber: big.NewInt(int64(s.nextIndex())), URIs: csr.URIs, Signature: csr.Signature, // We use the correct signature algorithm for the CA key we are signing with // regardless of the algorithm used to sign the CSR signature above since // the leaf might use a different key type. SignatureAlgorithm: connect.SigAlgoForKey(signer), PublicKeyAlgorithm: csr.PublicKeyAlgorithm, PublicKey: csr.PublicKey, BasicConstraintsValid: true, KeyUsage: x509.KeyUsageDataEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{ x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth, }, NotAfter: now.Add(expiration), NotBefore: now, AuthorityKeyId: keyId, SubjectKeyId: subjectKeyID, DNSNames: csr.DNSNames, IPAddresses: csr.IPAddresses, } if useExpiredCert { template.NotBefore = time.Now().Add(-13 * time.Hour) template.NotAfter = time.Now().Add(-1 * time.Hour) } // Create the certificate, PEM encode it and return that value. var buf bytes.Buffer bs, err := x509.CreateCertificate( rand.Reader, &template, caCert, csr.PublicKey, signer) if err != nil { return nil, fmt.Errorf("error creating cert pem from CSR: %w", err) } err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs}) if err != nil { return nil, fmt.Errorf("error encoding cert pem into text: %w", err) } leafPEM := buf.String() leafCert, err := connect.ParseCert(leafPEM) if err != nil { return nil, fmt.Errorf("error parsing cert from generated leaf pem: %w", err) } index := s.nextIndex() if isService { // Service Spiffe ID case return &structs.IssuedCert{ SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber), CertPEM: leafPEM, Service: serviceID.Service, ServiceURI: leafCert.URIs[0].String(), ValidAfter: leafCert.NotBefore, ValidBefore: leafCert.NotAfter, RaftIndex: structs.RaftIndex{ CreateIndex: index, ModifyIndex: index, }, }, nil } else { // Workload identity Spiffe ID case return &structs.IssuedCert{ SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber), CertPEM: leafPEM, WorkloadIdentity: workloadID.WorkloadIdentity, WorkloadIdentityURI: leafCert.URIs[0].String(), ValidAfter: leafCert.NotBefore, ValidBefore: leafCert.NotAfter, RaftIndex: structs.RaftIndex{ CreateIndex: index, ModifyIndex: index, }, }, nil } } type testRootsReader struct { mu sync.Mutex index uint64 roots *structs.IndexedCARoots watcher chan struct{} } func newTestRootsReader(t *testing.T) *testRootsReader { r := &testRootsReader{ watcher: make(chan struct{}), } t.Cleanup(func() { r.mu.Lock() watcher := r.watcher r.mu.Unlock() close(watcher) }) return r } var _ RootsReader = (*testRootsReader)(nil) func (r *testRootsReader) Set(roots *structs.IndexedCARoots) { r.mu.Lock() oldWatcher := r.watcher r.watcher = make(chan struct{}) r.roots = roots if roots == nil { r.index = 1 } else { r.index = roots.Index } r.mu.Unlock() close(oldWatcher) } func (r *testRootsReader) Get() (*structs.IndexedCARoots, error) { r.mu.Lock() defer r.mu.Unlock() return r.roots, nil } func (r *testRootsReader) Notify(ctx context.Context, correlationID string, ch chan<- cacheshim.UpdateEvent) error { r.mu.Lock() watcher := r.watcher r.mu.Unlock() go func() { <-watcher r.mu.Lock() defer r.mu.Unlock() ch <- cacheshim.UpdateEvent{ CorrelationID: correlationID, Result: r.roots, Meta: cacheshim.ResultMeta{Index: r.index}, Err: nil, } }() return nil }