consul/agent/leafcert/signer_test.go

247 lines
6.2 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package leafcert
import (
"bytes"
"context"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"math/big"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
// testSigner is an in-memory test double for leaf certificate signing. It
// satisfies CertSigner through its SignCert method, lets tests rotate the CA
// via UpdateCA, queue canned outcomes via SetSignCallErrors, and inspect the
// requests it has received via GetCapture.
type testSigner struct {
// caLock guards ca and prevRoots.
caLock sync.Mutex
// ca is the CA root currently used to sign leaves; it stays nil until
// UpdateCA has been called.
ca *structs.CARoot
// prevRoots holds inactive clones of every root passed to UpdateCA so that
// later root sets still include the earlier roots.
prevRoots []*structs.CARoot // remember prior ones
// IDGenerator is the counter behind nextIndex; it supplies root-set
// indexes, certificate serial numbers and raft indexes.
IDGenerator *atomic.Uint64
// RootsReader receives the updated root set on every UpdateCA call.
RootsReader *testRootsReader
// signCallLock guards the three signCall* fields below.
signCallLock sync.Mutex
// signCallErrors is a FIFO queue of canned outcomes; each SignCert call
// consumes at most one entry.
signCallErrors []error
// signCallErrorCount counts the SignCert calls that returned a queued error.
signCallErrorCount uint64
// signCallCapture records every request passed to SignCert, in call order.
signCallCapture []*structs.CASignRequest
}
// Compile-time assertion that *testSigner satisfies CertSigner.
var _ CertSigner = (*testSigner)(nil)
// ReplyWithExpiredCert is a sentinel for SetSignCallErrors. When SignCert
// dequeues it, the call does not fail; instead it issues a certificate whose
// validity window lies entirely in the past.
var ReplyWithExpiredCert = errors.New("reply with expired cert")
// newTestSigner builds a testSigner around the supplied ID generator and
// roots reader. Either argument may be nil, in which case a fresh
// atomic.Uint64 or a new test roots reader (bound to t) is used instead.
func newTestSigner(t *testing.T, idGenerator *atomic.Uint64, rootsReader *testRootsReader) *testSigner {
	signer := &testSigner{
		IDGenerator: idGenerator,
		RootsReader: rootsReader,
	}
	// Fill in defaults for whichever collaborators the caller left out.
	if signer.IDGenerator == nil {
		signer.IDGenerator = new(atomic.Uint64)
	}
	if signer.RootsReader == nil {
		signer.RootsReader = newTestRootsReader(t)
	}
	return signer
}
// SetSignCallErrors appends errs to the queue of canned outcomes. Each
// subsequent SignCert call consumes one entry from the front of the queue;
// entries already queued are kept.
func (s *testSigner) SetSignCallErrors(errs ...error) {
	s.signCallLock.Lock()
	defer s.signCallLock.Unlock()

	for _, queued := range errs {
		s.signCallErrors = append(s.signCallErrors, queued)
	}
}
// GetSignCallErrorCount reports how many SignCert calls have returned a
// queued error so far.
func (s *testSigner) GetSignCallErrorCount() uint64 {
	s.signCallLock.Lock()
	count := s.signCallErrorCount
	s.signCallLock.Unlock()
	return count
}
// UpdateCA makes ca the signer's active CA root and publishes a new root set
// to the RootsReader. A nil ca is replaced by a freshly generated test CA.
// The published set lists the new root first, followed by inactive copies of
// every root installed by earlier calls. The root that was installed is
// returned.
func (s *testSigner) UpdateCA(t *testing.T, ca *structs.CARoot) *structs.CARoot {
	if ca == nil {
		ca = connect.TestCA(t, nil)
	}

	published := &structs.IndexedCARoots{
		ActiveRootID: ca.ID,
		TrustDomain:  connect.TestTrustDomain,
		Roots:        []*structs.CARoot{ca},
		QueryMeta:    structs.QueryMeta{Index: s.nextIndex()},
	}

	// Swap the signing CA before announcing the new roots.
	s.caLock.Lock()
	s.ca = ca
	published.Roots = append(published.Roots, s.prevRoots...)
	// Keep an inactive copy around so the next rotation still lists this root.
	retired := ca.Clone()
	retired.Active = false
	s.prevRoots = append(s.prevRoots, retired)
	s.caLock.Unlock()

	// Publishing the roots is what notifies anything watching the reader.
	s.RootsReader.Set(published)
	return ca
}
// nextIndex atomically increments the ID generator and returns the new value.
func (s *testSigner) nextIndex() uint64 {
	next := s.IDGenerator.Add(1)
	return next
}
// getCA returns the currently installed CA root, or nil if UpdateCA has not
// been called yet.
func (s *testSigner) getCA() *structs.CARoot {
	s.caLock.Lock()
	current := s.ca
	s.caLock.Unlock()
	return current
}
// GetCapture returns the idx-th (zero-based) request that was passed to
// SignCert, in call order. It returns nil when idx is out of range, including
// when it is negative.
func (s *testSigner) GetCapture(idx int) *structs.CASignRequest {
	s.signCallLock.Lock()
	defer s.signCallLock.Unlock()

	// Checking only the upper bound would let a negative index through and
	// panic on the slice access below.
	if idx < 0 || idx >= len(s.signCallCapture) {
		return nil
	}
	return s.signCallCapture[idx]
}
// SignCert implements CertSigner for tests by signing the CSR in req with the
// CA most recently installed through UpdateCA.
//
// Every call records req (retrievable via GetCapture) and then consumes at
// most one entry from the queue filled by SetSignCallErrors:
//   - a nil entry lets the call proceed normally;
//   - ReplyWithExpiredCert lets the call proceed but back-dates the
//     certificate so that it is already expired;
//   - any other error is counted and returned immediately.
//
// Certificates are otherwise valid for ten minutes from the time of the call.
// ctx is accepted for interface compatibility and is not consulted.
//
// An error is returned when req is nil, when no CA has been installed, when
// the CSR cannot be parsed or carries no URI SAN, when its SPIFFE ID is not a
// service ID, or when any signing or encoding step fails.
func (s *testSigner) SignCert(ctx context.Context, req *structs.CASignRequest) (*structs.IssuedCert, error) {
	useExpiredCert := false

	s.signCallLock.Lock()
	s.signCallCapture = append(s.signCallCapture, req)
	if len(s.signCallErrors) > 0 {
		err := s.signCallErrors[0]
		s.signCallErrors = s.signCallErrors[1:]
		switch {
		case errors.Is(err, ReplyWithExpiredCert):
			useExpiredCert = true
		case err != nil:
			s.signCallErrorCount++
			s.signCallLock.Unlock()
			return nil, err
		}
	}
	s.signCallLock.Unlock()

	// parts of this were inlined from CAManager and the connect ca provider
	ca := s.getCA()
	if ca == nil {
		return nil, fmt.Errorf("must call UpdateCA at least once")
	}

	if req == nil {
		return nil, fmt.Errorf("sign request is nil")
	}

	csr, err := connect.ParseCSR(req.CSR)
	if err != nil {
		return nil, fmt.Errorf("error parsing CSR: %w", err)
	}

	connect.HackSANExtensionForCSR(csr)

	// The SPIFFE ID is read from the first URI SAN; fail cleanly rather than
	// panic when the CSR has none.
	if len(csr.URIs) == 0 {
		return nil, fmt.Errorf("CSR does not contain a URI SAN")
	}

	spiffeID, err := connect.ParseCertURI(csr.URIs[0])
	if err != nil {
		return nil, fmt.Errorf("error parsing CSR URI: %w", err)
	}

	serviceID, isService := spiffeID.(*connect.SpiffeIDService)
	if !isService {
		return nil, fmt.Errorf("unexpected spiffeID type %T", spiffeID)
	}

	signer, err := connect.ParseSigner(ca.SigningKey)
	if err != nil {
		return nil, fmt.Errorf("error parsing CA signing key: %w", err)
	}

	keyId, err := connect.KeyId(signer.Public())
	if err != nil {
		return nil, fmt.Errorf("error forming CA key id from public key: %w", err)
	}

	subjectKeyID, err := connect.KeyId(csr.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("error forming subject key id from public key: %w", err)
	}

	caCert, err := connect.ParseCert(ca.RootCert)
	if err != nil {
		return nil, fmt.Errorf("error parsing CA root cert pem: %w", err)
	}

	const expiration = 10 * time.Minute

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(int64(s.nextIndex())),
		URIs:         csr.URIs,
		Signature:    csr.Signature,
		// We use the correct signature algorithm for the CA key we are signing with
		// regardless of the algorithm used to sign the CSR signature above since
		// the leaf might use a different key type.
		SignatureAlgorithm:    connect.SigAlgoForKey(signer),
		PublicKeyAlgorithm:    csr.PublicKeyAlgorithm,
		PublicKey:             csr.PublicKey,
		BasicConstraintsValid: true,
		KeyUsage: x509.KeyUsageDataEncipherment |
			x509.KeyUsageKeyAgreement |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageClientAuth,
			x509.ExtKeyUsageServerAuth,
		},
		NotAfter:       now.Add(expiration),
		NotBefore:      now,
		AuthorityKeyId: keyId,
		SubjectKeyId:   subjectKeyID,
		DNSNames:       csr.DNSNames,
		IPAddresses:    csr.IPAddresses,
	}

	if useExpiredCert {
		// Shift the whole validity window into the past, measured from the
		// same instant as the defaults above.
		template.NotBefore = now.Add(-13 * time.Hour)
		template.NotAfter = now.Add(-1 * time.Hour)
	}

	// Create the certificate, PEM encode it and return that value.
	var buf bytes.Buffer
	bs, err := x509.CreateCertificate(
		rand.Reader, &template, caCert, csr.PublicKey, signer)
	if err != nil {
		return nil, fmt.Errorf("error creating cert pem from CSR: %w", err)
	}

	err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
	if err != nil {
		return nil, fmt.Errorf("error encoding cert pem into text: %w", err)
	}

	// Parse the PEM back so the reply reflects what was actually encoded.
	leafPEM := buf.String()
	leafCert, err := connect.ParseCert(leafPEM)
	if err != nil {
		return nil, fmt.Errorf("error parsing cert from generated leaf pem: %w", err)
	}

	index := s.nextIndex()
	return &structs.IssuedCert{
		SerialNumber: connect.EncodeSerialNumber(leafCert.SerialNumber),
		CertPEM:      leafPEM,
		Service:      serviceID.Service,
		ServiceURI:   leafCert.URIs[0].String(),
		ValidAfter:   leafCert.NotBefore,
		ValidBefore:  leafCert.NotAfter,
		RaftIndex: structs.RaftIndex{
			CreateIndex: index,
			ModifyIndex: index,
		},
	}, nil
}