mirror of
https://github.com/status-im/consul.git
synced 2025-01-09 13:26:07 +00:00
338 lines
8.5 KiB
Go
338 lines
8.5 KiB
Go
|
package autoconf
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/hashicorp/consul/agent/cache"
|
||
|
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
||
|
"github.com/hashicorp/consul/agent/connect"
|
||
|
"github.com/hashicorp/consul/agent/metadata"
|
||
|
"github.com/hashicorp/consul/agent/structs"
|
||
|
"github.com/hashicorp/consul/agent/token"
|
||
|
"github.com/hashicorp/consul/proto/pbautoconf"
|
||
|
"github.com/hashicorp/consul/sdk/testutil"
|
||
|
"github.com/stretchr/testify/mock"
|
||
|
)
|
||
|
|
||
|
type mockDirectRPC struct {
|
||
|
mock.Mock
|
||
|
}
|
||
|
|
||
|
func newMockDirectRPC(t *testing.T) *mockDirectRPC {
|
||
|
m := mockDirectRPC{}
|
||
|
m.Test(t)
|
||
|
return &m
|
||
|
}
|
||
|
|
||
|
func (m *mockDirectRPC) RPC(dc string, node string, addr net.Addr, method string, args interface{}, reply interface{}) error {
|
||
|
var retValues mock.Arguments
|
||
|
if method == "AutoConfig.InitialConfiguration" {
|
||
|
req := args.(*pbautoconf.AutoConfigRequest)
|
||
|
csr := req.CSR
|
||
|
req.CSR = ""
|
||
|
retValues = m.Called(dc, node, addr, method, args, reply)
|
||
|
req.CSR = csr
|
||
|
} else if method == "AutoEncrypt.Sign" {
|
||
|
req := args.(*structs.CASignRequest)
|
||
|
csr := req.CSR
|
||
|
req.CSR = ""
|
||
|
retValues = m.Called(dc, node, addr, method, args, reply)
|
||
|
req.CSR = csr
|
||
|
} else {
|
||
|
retValues = m.Called(dc, node, addr, method, args, reply)
|
||
|
}
|
||
|
|
||
|
return retValues.Error(0)
|
||
|
}
|
||
|
|
||
|
type mockTLSConfigurator struct {
|
||
|
mock.Mock
|
||
|
}
|
||
|
|
||
|
func newMockTLSConfigurator(t *testing.T) *mockTLSConfigurator {
|
||
|
m := mockTLSConfigurator{}
|
||
|
m.Test(t)
|
||
|
return &m
|
||
|
}
|
||
|
|
||
|
func (m *mockTLSConfigurator) UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error {
|
||
|
if priv != "" {
|
||
|
priv = "redacted"
|
||
|
}
|
||
|
|
||
|
ret := m.Called(manualCAPEMs, connectCAPEMs, pub, priv, verifyServerHostname)
|
||
|
return ret.Error(0)
|
||
|
}
|
||
|
|
||
|
func (m *mockTLSConfigurator) UpdateAutoTLSCA(pems []string) error {
|
||
|
ret := m.Called(pems)
|
||
|
return ret.Error(0)
|
||
|
}
|
||
|
func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
|
||
|
if priv != "" {
|
||
|
priv = "redacted"
|
||
|
}
|
||
|
ret := m.Called(pub, priv)
|
||
|
return ret.Error(0)
|
||
|
}
|
||
|
func (m *mockTLSConfigurator) AutoEncryptCertNotAfter() time.Time {
|
||
|
ret := m.Called()
|
||
|
ts, _ := ret.Get(0).(time.Time)
|
||
|
|
||
|
return ts
|
||
|
}
|
||
|
func (m *mockTLSConfigurator) AutoEncryptCertExpired() bool {
|
||
|
ret := m.Called()
|
||
|
return ret.Bool(0)
|
||
|
}
|
||
|
|
||
|
type mockServerProvider struct {
|
||
|
mock.Mock
|
||
|
}
|
||
|
|
||
|
func newMockServerProvider(t *testing.T) *mockServerProvider {
|
||
|
m := mockServerProvider{}
|
||
|
m.Test(t)
|
||
|
return &m
|
||
|
}
|
||
|
|
||
|
func (m *mockServerProvider) FindLANServer() *metadata.Server {
|
||
|
ret := m.Called()
|
||
|
srv, _ := ret.Get(0).(*metadata.Server)
|
||
|
return srv
|
||
|
}
|
||
|
|
||
|
type mockWatcher struct {
|
||
|
ch chan<- cache.UpdateEvent
|
||
|
done <-chan struct{}
|
||
|
}
|
||
|
|
||
|
type mockCache struct {
|
||
|
mock.Mock
|
||
|
|
||
|
lock sync.Mutex
|
||
|
watchers map[string][]mockWatcher
|
||
|
}
|
||
|
|
||
|
func newMockCache(t *testing.T) *mockCache {
|
||
|
m := mockCache{
|
||
|
watchers: make(map[string][]mockWatcher),
|
||
|
}
|
||
|
m.Test(t)
|
||
|
return &m
|
||
|
}
|
||
|
|
||
|
func (m *mockCache) Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error {
|
||
|
ret := m.Called(ctx, t, r, correlationID, ch)
|
||
|
|
||
|
err := ret.Error(0)
|
||
|
if err == nil {
|
||
|
m.lock.Lock()
|
||
|
key := r.CacheInfo().Key
|
||
|
m.watchers[key] = append(m.watchers[key], mockWatcher{ch: ch, done: ctx.Done()})
|
||
|
m.lock.Unlock()
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error {
|
||
|
var restore string
|
||
|
cert, ok := result.Value.(*structs.IssuedCert)
|
||
|
if ok {
|
||
|
// we cannot know what the private key is prior to it being injected into the cache.
|
||
|
// therefore redact it here and all mock expectations should take that into account
|
||
|
restore = cert.PrivateKeyPEM
|
||
|
cert.PrivateKeyPEM = "redacted"
|
||
|
}
|
||
|
|
||
|
ret := m.Called(t, result, dc, token, key)
|
||
|
|
||
|
if ok && restore != "" {
|
||
|
cert.PrivateKeyPEM = restore
|
||
|
}
|
||
|
return ret.Error(0)
|
||
|
}
|
||
|
|
||
|
func (m *mockCache) sendNotification(ctx context.Context, key string, u cache.UpdateEvent) bool {
|
||
|
m.lock.Lock()
|
||
|
defer m.lock.Unlock()
|
||
|
|
||
|
watchers, ok := m.watchers[key]
|
||
|
if !ok || len(m.watchers) < 1 {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
var newWatchers []mockWatcher
|
||
|
|
||
|
for _, watcher := range watchers {
|
||
|
select {
|
||
|
case watcher.ch <- u:
|
||
|
newWatchers = append(newWatchers, watcher)
|
||
|
case <-watcher.done:
|
||
|
// do nothing, this watcher will be removed from the list
|
||
|
case <-ctx.Done():
|
||
|
// return doesn't matter here really, the test is being cancelled
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// this removes any already cancelled watches from being sent to
|
||
|
m.watchers[key] = newWatchers
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
type mockTokenStore struct {
|
||
|
mock.Mock
|
||
|
}
|
||
|
|
||
|
func newMockTokenStore(t *testing.T) *mockTokenStore {
|
||
|
m := mockTokenStore{}
|
||
|
m.Test(t)
|
||
|
return &m
|
||
|
}
|
||
|
|
||
|
func (m *mockTokenStore) AgentToken() string {
|
||
|
ret := m.Called()
|
||
|
return ret.String(0)
|
||
|
}
|
||
|
|
||
|
func (m *mockTokenStore) UpdateAgentToken(secret string, source token.TokenSource) bool {
|
||
|
return m.Called(secret, source).Bool(0)
|
||
|
}
|
||
|
|
||
|
func (m *mockTokenStore) Notify(kind token.TokenKind) token.Notifier {
|
||
|
ret := m.Called(kind)
|
||
|
n, _ := ret.Get(0).(token.Notifier)
|
||
|
return n
|
||
|
}
|
||
|
|
||
|
func (m *mockTokenStore) StopNotify(notifier token.Notifier) {
|
||
|
m.Called(notifier)
|
||
|
}
|
||
|
|
||
|
type mockedConfig struct {
|
||
|
Config
|
||
|
|
||
|
directRPC *mockDirectRPC
|
||
|
serverProvider *mockServerProvider
|
||
|
cache *mockCache
|
||
|
tokens *mockTokenStore
|
||
|
tlsCfg *mockTLSConfigurator
|
||
|
}
|
||
|
|
||
|
func newMockedConfig(t *testing.T) *mockedConfig {
|
||
|
directRPC := newMockDirectRPC(t)
|
||
|
serverProvider := newMockServerProvider(t)
|
||
|
mcache := newMockCache(t)
|
||
|
tokens := newMockTokenStore(t)
|
||
|
tlsCfg := newMockTLSConfigurator(t)
|
||
|
|
||
|
// I am not sure it is well defined behavior but in testing it
|
||
|
// out it does appear like Cleanup functions can fail tests
|
||
|
// Adding in the mock expectations assertions here saves us
|
||
|
// a bunch of code in the other test functions.
|
||
|
t.Cleanup(func() {
|
||
|
if !t.Failed() {
|
||
|
directRPC.AssertExpectations(t)
|
||
|
serverProvider.AssertExpectations(t)
|
||
|
mcache.AssertExpectations(t)
|
||
|
tokens.AssertExpectations(t)
|
||
|
tlsCfg.AssertExpectations(t)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
return &mockedConfig{
|
||
|
Config: Config{
|
||
|
DirectRPC: directRPC,
|
||
|
ServerProvider: serverProvider,
|
||
|
Cache: mcache,
|
||
|
Tokens: tokens,
|
||
|
TLSConfigurator: tlsCfg,
|
||
|
Logger: testutil.Logger(t),
|
||
|
},
|
||
|
directRPC: directRPC,
|
||
|
serverProvider: serverProvider,
|
||
|
cache: mcache,
|
||
|
tokens: tokens,
|
||
|
tlsCfg: tlsCfg,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *mockedConfig) expectInitialTLS(t *testing.T, agentName, datacenter, token string, ca *structs.CARoot, indexedRoots *structs.IndexedCARoots, cert *structs.IssuedCert, extraCerts []string) {
|
||
|
var pems []string
|
||
|
for _, root := range indexedRoots.Roots {
|
||
|
pems = append(pems, root.RootCert)
|
||
|
}
|
||
|
|
||
|
// we should update the TLS configurator with the proper certs
|
||
|
m.tlsCfg.On("UpdateAutoTLS",
|
||
|
extraCerts,
|
||
|
pems,
|
||
|
cert.CertPEM,
|
||
|
// auto-config handles the CSR and Key so our tests don't have
|
||
|
// a way to know that the key is correct or not. We do replace
|
||
|
// a non empty PEM with "redacted" so we can ensure that some
|
||
|
// certificate is being sent
|
||
|
"redacted",
|
||
|
true,
|
||
|
).Return(nil).Once()
|
||
|
|
||
|
rootRes := cache.FetchResult{Value: indexedRoots, Index: indexedRoots.QueryMeta.Index}
|
||
|
rootsReq := structs.DCSpecificRequest{Datacenter: datacenter}
|
||
|
|
||
|
// we should prepopulate the cache with the CA roots
|
||
|
m.cache.On("Prepopulate",
|
||
|
cachetype.ConnectCARootName,
|
||
|
rootRes,
|
||
|
datacenter,
|
||
|
"",
|
||
|
rootsReq.CacheInfo().Key,
|
||
|
).Return(nil).Once()
|
||
|
|
||
|
leafReq := cachetype.ConnectCALeafRequest{
|
||
|
Token: token,
|
||
|
Agent: agentName,
|
||
|
Datacenter: datacenter,
|
||
|
}
|
||
|
|
||
|
// copy the cert and redact the private key for the mock expectation
|
||
|
// the actual private key will not correspond to the cert but thats
|
||
|
// because AutoConfig is generated a key/csr internally and sending that
|
||
|
// on up with the request.
|
||
|
copy := *cert
|
||
|
copy.PrivateKeyPEM = "redacted"
|
||
|
leafRes := cache.FetchResult{
|
||
|
Value: ©,
|
||
|
Index: copy.RaftIndex.ModifyIndex,
|
||
|
State: cachetype.ConnectCALeafSuccess(ca.SigningKeyID),
|
||
|
}
|
||
|
|
||
|
// we should prepopulate the cache with the agents cert
|
||
|
m.cache.On("Prepopulate",
|
||
|
cachetype.ConnectCALeafName,
|
||
|
leafRes,
|
||
|
datacenter,
|
||
|
token,
|
||
|
leafReq.Key(),
|
||
|
).Return(nil).Once()
|
||
|
|
||
|
// when prepopulating the cert in the cache we grab the token so
|
||
|
// we should expec that here
|
||
|
m.tokens.On("AgentToken").Return(token).Once()
|
||
|
}
|
||
|
|
||
|
func (m *mockedConfig) setupInitialTLS(t *testing.T, agentName, datacenter, token string) (*structs.IndexedCARoots, *structs.IssuedCert, []string) {
|
||
|
ca, indexedRoots, cert := testCerts(t, agentName, datacenter)
|
||
|
|
||
|
ca2 := connect.TestCA(t, nil)
|
||
|
extraCerts := []string{ca2.RootCert}
|
||
|
|
||
|
m.expectInitialTLS(t, agentName, datacenter, token, ca, indexedRoots, cert, extraCerts)
|
||
|
return indexedRoots, cert, extraCerts
|
||
|
}
|