consul/agent/leafcert/leafcert_test_helpers.go

367 lines
9.1 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package leafcert
import (
"bytes"
"context"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"math/big"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/consul/agent/cacheshim"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
)
// NewTestManager returns a *Manager wired to an in-memory TestSigner, which
// both signs leaf certificates and serves CA roots through its RootsReader.
// The optional mut callback may adjust the Config before the Manager is
// constructed. The Manager is stopped automatically when the test finishes.
func NewTestManager(t *testing.T, mut func(*Config)) (*Manager, *TestSigner) {
	signer := newTestSigner(t, nil, nil)
	logger := testutil.Logger(t)

	cfg := Config{
		// Shrink the root-change spread so tests don't wait up to 20 seconds
		// to observe a root rotation. Tests that exercise the spread itself
		// can change this back via mut. A non-zero value is required because
		// zero selects the default delay; one microsecond is effectively zero.
		TestOverrideCAChangeInitialDelay: 1 * time.Microsecond,
	}
	if mut != nil {
		mut(&cfg)
	}

	manager := NewManager(Deps{
		Logger:      logger,
		RootsReader: signer.RootsReader,
		CertSigner:  signer,
		Config:      cfg,
	})
	t.Cleanup(manager.Stop)
	return manager, signer
}
// TestSigner is an in-memory CertSigner for tests. It signs leaf certificates
// with the CA most recently installed via UpdateCA, publishes each root change
// to its RootsReader, and lets tests inject sign failures (SetSignCallErrors)
// and inspect the requests it received (GetCapture).
type TestSigner struct {
	// caLock guards ca and prevRoots.
	caLock    sync.Mutex
	ca        *structs.CARoot   // active CA used by SignCert; nil until UpdateCA is called
	prevRoots []*structs.CARoot // remember prior ones: inactive copies of every CA installed so far

	// IDGenerator supplies the values used for raft indexes and certificate
	// serial numbers. Callers of newTestSigner may pass in a shared generator.
	IDGenerator *atomic.Uint64
	// RootsReader receives the updated root set on every UpdateCA call.
	RootsReader *testRootsReader

	// signCallLock guards the sign-call fields below.
	signCallLock       sync.Mutex
	signCallErrors     []error                  // queued outcomes; SignCert consumes one per call
	signCallErrorCount uint64                   // number of queued errors SignCert actually returned
	signCallCapture    []*structs.CASignRequest // every request passed to SignCert, in call order
}
var (
	// Compile-time check that TestSigner satisfies CertSigner.
	_ CertSigner = (*TestSigner)(nil)

	// ReplyWithExpiredCert, when queued via SetSignCallErrors, makes the
	// corresponding SignCert call succeed with an already-expired leaf
	// certificate instead of returning an error.
	ReplyWithExpiredCert = errors.New("reply with expired cert")
)
// newTestSigner builds a TestSigner. A nil idGenerator is replaced with a
// fresh counter, and a nil rootsReader with a new testRootsReader tied to t.
func newTestSigner(t *testing.T, idGenerator *atomic.Uint64, rootsReader *testRootsReader) *TestSigner {
	signer := &TestSigner{
		IDGenerator: idGenerator,
		RootsReader: rootsReader,
	}
	if signer.IDGenerator == nil {
		signer.IDGenerator = new(atomic.Uint64)
	}
	if signer.RootsReader == nil {
		signer.RootsReader = newTestRootsReader(t)
	}
	return signer
}
// SetSignCallErrors queues errs so that subsequent SignCert calls consume
// them in order, one per call. Despite the name, it appends to any errors
// already queued rather than replacing them. A nil entry lets its call
// succeed normally, and ReplyWithExpiredCert yields an expired certificate.
func (s *TestSigner) SetSignCallErrors(errs ...error) {
	s.signCallLock.Lock()
	s.signCallErrors = append(s.signCallErrors, errs...)
	s.signCallLock.Unlock()
}
// GetSignCallErrorCount reports how many queued errors SignCert has returned
// so far. Queued nil entries and ReplyWithExpiredCert are not counted.
func (s *TestSigner) GetSignCallErrorCount() uint64 {
	s.signCallLock.Lock()
	count := s.signCallErrorCount
	s.signCallLock.Unlock()
	return count
}
// UpdateCA installs ca as the signer's active CA and returns it. A nil ca is
// replaced by a freshly generated test CA. The published root set contains
// the new CA first, followed by inactive copies of every previously installed
// CA. It is pushed to RootsReader after the signer state is updated, which
// wakes any pending Notify watchers.
func (s *TestSigner) UpdateCA(t *testing.T, ca *structs.CARoot) *structs.CARoot {
	if ca == nil {
		ca = connect.TestCA(t, nil)
	}
	index := s.nextIndex()

	// Update the signer first so SignCert sees the new CA before the roots
	// change is announced.
	s.caLock.Lock()
	s.ca = ca
	published := make([]*structs.CARoot, 0, 1+len(s.prevRoots))
	published = append(published, ca)
	published = append(published, s.prevRoots...)
	// Remember an inactive copy for subsequent rotations.
	retired := ca.Clone()
	retired.Active = false
	s.prevRoots = append(s.prevRoots, retired)
	s.caLock.Unlock()

	s.RootsReader.Set(&structs.IndexedCARoots{
		ActiveRootID: ca.ID,
		TrustDomain:  connect.TestTrustDomain,
		Roots:        published,
		QueryMeta:    structs.QueryMeta{Index: index},
	})
	return ca
}
// nextIndex atomically increments IDGenerator and returns the new value.
func (s *TestSigner) nextIndex() uint64 {
	next := s.IDGenerator.Add(1)
	return next
}
// getCA returns the active CA under caLock; it is nil before UpdateCA runs.
func (s *TestSigner) getCA() *structs.CARoot {
	s.caLock.Lock()
	current := s.ca
	s.caLock.Unlock()
	return current
}
// GetCapture returns the idx-th (zero-based) request passed to SignCert, or
// nil if idx is out of range. Negative indexes return nil instead of
// panicking.
func (s *TestSigner) GetCapture(idx int) *structs.CASignRequest {
	s.signCallLock.Lock()
	defer s.signCallLock.Unlock()
	if idx < 0 || idx >= len(s.signCallCapture) {
		return nil
	}
	return s.signCallCapture[idx]
}
// SignCert implements CertSigner. It issues a leaf certificate for the CSR in
// req, signed by the CA most recently installed with UpdateCA.
//
// Every request is recorded (retrievable with GetCapture) before any other
// work. If errors were queued with SetSignCallErrors, the oldest one is
// consumed first:
//   - ReplyWithExpiredCert (or an error wrapping it) makes the call succeed
//     with a certificate whose validity window lies entirely in the past;
//   - any other non-nil error is returned unchanged and counted (see
//     GetSignCallErrorCount);
//   - a nil entry lets the call proceed normally.
//
// The CSR's first URI SAN must be a SPIFFE ID for either a service or a
// workload identity, and the IssuedCert is populated accordingly. ctx is
// currently unused.
func (s *TestSigner) SignCert(ctx context.Context, req *structs.CASignRequest) (*structs.IssuedCert, error) {
	useExpiredCert := false

	s.signCallLock.Lock()
	s.signCallCapture = append(s.signCallCapture, req)
	if len(s.signCallErrors) > 0 {
		queued := s.signCallErrors[0]
		s.signCallErrors = s.signCallErrors[1:]
		switch {
		case errors.Is(queued, ReplyWithExpiredCert):
			useExpiredCert = true
		case queued != nil:
			s.signCallErrorCount++
			s.signCallLock.Unlock()
			return nil, queued
		}
	}
	s.signCallLock.Unlock()

	// Parts of this were inlined from CAManager and the connect CA provider.
	ca := s.getCA()
	if ca == nil {
		return nil, fmt.Errorf("must call UpdateCA at least once")
	}

	csr, err := connect.ParseCSR(req.CSR)
	if err != nil {
		return nil, fmt.Errorf("error parsing CSR: %w", err)
	}
	connect.HackSANExtensionForCSR(csr)

	// Without this guard a CSR lacking a URI SAN would panic on csr.URIs[0].
	if len(csr.URIs) == 0 {
		return nil, fmt.Errorf("CSR has no URI SAN to parse a SPIFFE ID from")
	}
	spiffeID, err := connect.ParseCertURI(csr.URIs[0])
	if err != nil {
		return nil, fmt.Errorf("error parsing CSR URI: %w", err)
	}

	var (
		isService  bool
		serviceID  *connect.SpiffeIDService
		workloadID *connect.SpiffeIDWorkloadIdentity
	)
	switch id := spiffeID.(type) {
	case *connect.SpiffeIDService:
		isService = true
		serviceID = id
	case *connect.SpiffeIDWorkloadIdentity:
		workloadID = id
	default:
		return nil, fmt.Errorf("unexpected spiffeID type %T", spiffeID)
	}

	signer, err := connect.ParseSigner(ca.SigningKey)
	if err != nil {
		return nil, fmt.Errorf("error parsing CA signing key: %w", err)
	}
	keyID, err := connect.KeyId(signer.Public())
	if err != nil {
		return nil, fmt.Errorf("error forming CA key id from public key: %w", err)
	}
	subjectKeyID, err := connect.KeyId(csr.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("error forming subject key id from public key: %w", err)
	}
	caCert, err := connect.ParseCert(ca.RootCert)
	if err != nil {
		return nil, fmt.Errorf("error parsing CA root cert pem: %w", err)
	}

	const expiration = 10 * time.Minute
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(int64(s.nextIndex())),
		URIs:         csr.URIs,
		Signature:    csr.Signature,
		// Use the signature algorithm matching the CA key we sign with,
		// regardless of the algorithm used for the CSR signature above, since
		// the leaf might use a different key type.
		SignatureAlgorithm:    connect.SigAlgoForKey(signer),
		PublicKeyAlgorithm:    csr.PublicKeyAlgorithm,
		PublicKey:             csr.PublicKey,
		BasicConstraintsValid: true,
		KeyUsage: x509.KeyUsageDataEncipherment |
			x509.KeyUsageKeyAgreement |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageClientAuth,
			x509.ExtKeyUsageServerAuth,
		},
		NotAfter:       now.Add(expiration),
		NotBefore:      now,
		AuthorityKeyId: keyID,
		SubjectKeyId:   subjectKeyID,
		DNSNames:       csr.DNSNames,
		IPAddresses:    csr.IPAddresses,
	}
	if useExpiredCert {
		// Both bounds come from the same instant so the window is exactly
		// 12 hours long and ended 1 hour ago.
		template.NotBefore = now.Add(-13 * time.Hour)
		template.NotAfter = now.Add(-1 * time.Hour)
	}

	// Create the certificate, PEM encode it, and parse it back so the returned
	// metadata reflects what was actually issued.
	der, err := x509.CreateCertificate(rand.Reader, &template, caCert, csr.PublicKey, signer)
	if err != nil {
		return nil, fmt.Errorf("error creating cert pem from CSR: %w", err)
	}
	var buf bytes.Buffer
	if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
		return nil, fmt.Errorf("error encoding cert pem into text: %w", err)
	}
	leafPEM := buf.String()
	leafCert, err := connect.ParseCert(leafPEM)
	if err != nil {
		return nil, fmt.Errorf("error parsing cert from generated leaf pem: %w", err)
	}

	index := s.nextIndex()
	issued := &structs.IssuedCert{
		SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber),
		CertPEM:      leafPEM,
		ValidAfter:   leafCert.NotBefore,
		ValidBefore:  leafCert.NotAfter,
		RaftIndex: structs.RaftIndex{
			CreateIndex: index,
			ModifyIndex: index,
		},
	}
	if isService {
		// Service SPIFFE ID case.
		issued.Service = serviceID.Service
		issued.ServiceURI = leafCert.URIs[0].String()
	} else {
		// Workload identity SPIFFE ID case.
		issued.WorkloadIdentity = workloadID.WorkloadIdentity
		issued.WorkloadIdentityURI = leafCert.URIs[0].String()
	}
	return issued, nil
}
// testRootsReader is an in-memory RootsReader whose roots are replaced via
// Set. Each Set closes the current watcher channel, which wakes goroutines
// started by Notify, and installs a fresh channel for the next change.
type testRootsReader struct {
	mu      sync.Mutex              // guards the fields below
	index   uint64                  // index reported with the current roots (1 after Set(nil))
	roots   *structs.IndexedCARoots // current roots; nil until Set is first called
	watcher chan struct{}           // closed to signal the next change (or test cleanup)
}
// newTestRootsReader returns an empty testRootsReader. When the test ends it
// closes whichever watcher channel is current, releasing any goroutines still
// waiting on it.
func newTestRootsReader(t *testing.T) *testRootsReader {
	reader := &testRootsReader{watcher: make(chan struct{})}
	t.Cleanup(func() {
		reader.mu.Lock()
		current := reader.watcher
		reader.mu.Unlock()
		close(current)
	})
	return reader
}
// Compile-time check that testRootsReader satisfies RootsReader.
var _ RootsReader = (*testRootsReader)(nil)
// Set replaces the current roots. It records roots.Index as the current index,
// or 1 when roots is nil. It then closes the previous watcher channel to wake
// goroutines waiting in Notify.
func (r *testRootsReader) Set(roots *structs.IndexedCARoots) {
	index := uint64(1)
	if roots != nil {
		index = roots.Index
	}

	r.mu.Lock()
	previous := r.watcher
	r.watcher = make(chan struct{})
	r.roots = roots
	r.index = index
	r.mu.Unlock()

	close(previous)
}
// Get returns the roots most recently passed to Set (nil if none yet). The
// error is always nil.
func (r *testRootsReader) Get() (*structs.IndexedCARoots, error) {
	r.mu.Lock()
	current := r.roots
	r.mu.Unlock()
	return current, nil
}
// Notify starts a goroutine that waits for the next Set call (or for the test
// cleanup that closes the current watcher). It then delivers one UpdateEvent
// to ch, tagged with correlationID and carrying the roots and index current
// at that moment. Notify always returns nil.
//
// The goroutine exits early if ctx is cancelled, either while it waits for a
// change or while it is blocked sending on ch, so an abandoned subscription
// does not leak. The event is snapshotted under r.mu but sent after the lock
// is released, so a slow or absent receiver cannot block Set or Get.
func (r *testRootsReader) Notify(ctx context.Context, correlationID string, ch chan<- cacheshim.UpdateEvent) error {
	r.mu.Lock()
	watcher := r.watcher
	r.mu.Unlock()

	go func() {
		select {
		case <-watcher:
		case <-ctx.Done():
			return
		}

		r.mu.Lock()
		event := cacheshim.UpdateEvent{
			CorrelationID: correlationID,
			Result:        r.roots,
			Meta:          cacheshim.ResultMeta{Index: r.index},
			Err:           nil,
		}
		r.mu.Unlock()

		select {
		case ch <- event:
		case <-ctx.Done():
		}
	}()
	return nil
}