package autoconf

import (
	"context"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/consul/agent/cache"
	cachetype "github.com/hashicorp/consul/agent/cache-types"
	"github.com/hashicorp/consul/agent/connect"
	"github.com/hashicorp/consul/agent/metadata"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/agent/token"
	"github.com/hashicorp/consul/proto/pbautoconf"
	"github.com/hashicorp/consul/sdk/testutil"
	"github.com/stretchr/testify/mock"
)

// mockDirectRPC is a testify mock used as the Config's DirectRPC dependency.
type mockDirectRPC struct {
	mock.Mock
}

// newMockDirectRPC returns a mockDirectRPC bound to t, so unexpected calls
// are reported through the test.
func newMockDirectRPC(t *testing.T) *mockDirectRPC {
	m := new(mockDirectRPC)
	m.Test(t)
	return m
}

// RPC records the call against the mock's expectations and returns the error
// configured for it.
//
// For "AutoConfig.InitialConfiguration" and "AutoEncrypt.Sign" the request's
// CSR is blanked while the mock matches the call. Auto-config generates the
// CSR internally, so tests cannot predict it. The CSR is then put back so the
// caller's request is left as it was.
func (m *mockDirectRPC) RPC(dc string, node string, addr net.Addr, method string, args interface{}, reply interface{}) error {
	var csrField *string
	switch method {
	case "AutoConfig.InitialConfiguration":
		csrField = &args.(*pbautoconf.AutoConfigRequest).CSR
	case "AutoEncrypt.Sign":
		csrField = &args.(*structs.CASignRequest).CSR
	}

	if csrField == nil {
		return m.Called(dc, node, addr, method, args, reply).Error(0)
	}

	saved := *csrField
	*csrField = ""
	ret := m.Called(dc, node, addr, method, args, reply)
	*csrField = saved
	return ret.Error(0)
}

// mockTLSConfigurator is a testify mock used as the Config's TLSConfigurator.
type mockTLSConfigurator struct {
	mock.Mock
}

// newMockTLSConfigurator returns a mockTLSConfigurator bound to t.
func newMockTLSConfigurator(t *testing.T) *mockTLSConfigurator {
	m := new(mockTLSConfigurator)
	m.Test(t)
	return m
}

// UpdateAutoTLS records the call and returns the error configured for it.
// Tests cannot know the generated private key, so any non-empty priv is
// presented to the mock as the placeholder "redacted". An empty priv is
// passed through unchanged.
func (m *mockTLSConfigurator) UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error {
	key := priv
	if key != "" {
		key = "redacted"
	}
	return m.Called(manualCAPEMs, connectCAPEMs, pub, key, verifyServerHostname).Error(0)
}

// UpdateAutoTLSCA records the call with the given CA PEMs and returns the
// error configured for it.
func (m *mockTLSConfigurator) UpdateAutoTLSCA(pems []string) error {
	return m.Called(pems).Error(0)
}
// UpdateAutoTLSCert records the call and returns the error configured for it.
// A non-empty priv is presented to the mock as "redacted" so expectations do
// not need the real key.
func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
	key := priv
	if key != "" {
		key = "redacted"
	}
	return m.Called(pub, key).Error(0)
}
// AutoEncryptCertNotAfter returns the time.Time configured on the matching
// expectation. It returns the zero time if that value is not a time.Time.
func (m *mockTLSConfigurator) AutoEncryptCertNotAfter() time.Time {
	if ts, ok := m.Called().Get(0).(time.Time); ok {
		return ts
	}
	return time.Time{}
}
// AutoEncryptCertExpired returns the bool configured on the matching
// expectation.
func (m *mockTLSConfigurator) AutoEncryptCertExpired() bool {
	return m.Called().Bool(0)
}

// mockServerProvider is a testify mock used as the Config's ServerProvider.
type mockServerProvider struct {
	mock.Mock
}

// newMockServerProvider returns a mockServerProvider bound to t.
func newMockServerProvider(t *testing.T) *mockServerProvider {
	m := new(mockServerProvider)
	m.Test(t)
	return m
}

// FindLANServer returns the *metadata.Server configured on the matching
// expectation. It returns nil if that value is not a *metadata.Server.
func (m *mockServerProvider) FindLANServer() *metadata.Server {
	if srv, ok := m.Called().Get(0).(*metadata.Server); ok {
		return srv
	}
	return nil
}

// mockWatcher is one successful registration made through mockCache.Notify.
// It holds the channel that updates are delivered on and the done channel of
// the context it was registered with.
type mockWatcher struct {
	ch   chan<- cache.UpdateEvent
	done <-chan struct{}
}

// mockCache is a testify mock used as the Config's Cache. Besides recording
// calls, it remembers successful Notify registrations keyed by the request's
// cache key, so tests can push updates with sendNotification.
type mockCache struct {
	mock.Mock

	// lock guards watchers.
	lock     sync.Mutex
	watchers map[string][]mockWatcher
}

// newMockCache returns a mockCache bound to t with an empty watcher registry.
func newMockCache(t *testing.T) *mockCache {
	m := new(mockCache)
	m.watchers = map[string][]mockWatcher{}
	m.Test(t)
	return m
}

// Notify records the call and returns the error configured for it. When that
// error is nil, ch is registered as a watcher under r's cache key, together
// with ctx's done channel, so sendNotification can later deliver updates.
func (m *mockCache) Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error {
	if err := m.Called(ctx, t, r, correlationID, ch).Error(0); err != nil {
		return err
	}

	m.lock.Lock()
	defer m.lock.Unlock()

	key := r.CacheInfo().Key
	m.watchers[key] = append(m.watchers[key], mockWatcher{ch: ch, done: ctx.Done()})
	return nil
}

// Prepopulate records the call and returns the error configured for it.
//
// Tests cannot know a private key before it is injected into the cache. So
// when result.Value is a *structs.IssuedCert, its PrivateKeyPEM is replaced
// with "redacted" while the mock matches the call, and mock expectations
// should use that placeholder. The caller's original key is always restored
// before returning.
func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error {
	if cert, ok := result.Value.(*structs.IssuedCert); ok && cert != nil {
		original := cert.PrivateKeyPEM
		cert.PrivateKeyPEM = "redacted"
		// Restore unconditionally. Skipping the restore when the original
		// key was empty would leave "redacted" in the caller's cert.
		defer func() { cert.PrivateKeyPEM = original }()
	}

	return m.Called(t, result, dc, token, key).Error(0)
}

// sendNotification delivers u to every watcher registered under key through
// Notify.
//
// Watchers whose registration context is done are dropped instead of being
// sent to, and are pruned from the registry. It returns false when no
// watchers are registered for key. Otherwise it returns true, including when
// ctx is cancelled part-way through delivery; in that case the registry is
// left unchanged.
//
// NOTE: the mock's lock is held while blocking on each send. A receiver must
// therefore not call Notify on this mock before draining its channel.
func (m *mockCache) sendNotification(ctx context.Context, key string, u cache.UpdateEvent) bool {
	m.lock.Lock()
	defer m.lock.Unlock()

	// Check this key's watcher list, not the size of the whole map. A key
	// whose watchers have all been pruned has nobody to notify.
	watchers := m.watchers[key]
	if len(watchers) == 0 {
		return false
	}

	var live []mockWatcher
	for _, w := range watchers {
		select {
		case w.ch <- u:
			live = append(live, w)
		case <-w.done:
			// The watch was cancelled; drop it from the registry.
		case <-ctx.Done():
			// The test is being cancelled, so the result hardly matters.
			return true
		}
	}

	// Keep only watchers that are still live.
	m.watchers[key] = live
	return true
}

// mockTokenStore is a testify mock used as the Config's Tokens dependency.
type mockTokenStore struct {
	mock.Mock
}

// newMockTokenStore returns a mockTokenStore bound to t.
func newMockTokenStore(t *testing.T) *mockTokenStore {
	m := new(mockTokenStore)
	m.Test(t)
	return m
}

// AgentToken returns the string configured on the matching expectation.
func (m *mockTokenStore) AgentToken() string {
	return m.Called().String(0)
}

// UpdateAgentToken records the secret and source, and returns the bool
// configured on the matching expectation.
func (m *mockTokenStore) UpdateAgentToken(secret string, source token.TokenSource) bool {
	ret := m.Called(secret, source)
	return ret.Bool(0)
}

// Notify records the requested token kind and returns the token.Notifier
// configured on the matching expectation. It returns the zero Notifier if
// that value is not a token.Notifier.
func (m *mockTokenStore) Notify(kind token.TokenKind) token.Notifier {
	notifier, _ := m.Called(kind).Get(0).(token.Notifier)
	return notifier
}

// StopNotify records the call. The mock's return values are not used.
func (m *mockTokenStore) StopNotify(notifier token.Notifier) {
	_ = m.Called(notifier)
}

// mockedConfig bundles an auto-config Config whose dependencies are all
// testify mocks, together with typed handles to each mock so tests can set
// expectations on them directly.
type mockedConfig struct {
	Config

	directRPC      *mockDirectRPC
	serverProvider *mockServerProvider
	cache          *mockCache
	tokens         *mockTokenStore
	tlsCfg         *mockTLSConfigurator
}

// newMockedConfig builds a mockedConfig wired with fresh mocks bound to t.
// When the test finishes, a cleanup hook asserts every mock's expectations,
// unless the test has already failed.
func newMockedConfig(t *testing.T) *mockedConfig {
	mcfg := &mockedConfig{
		directRPC:      newMockDirectRPC(t),
		serverProvider: newMockServerProvider(t),
		cache:          newMockCache(t),
		tokens:         newMockTokenStore(t),
		tlsCfg:         newMockTLSConfigurator(t),
	}

	// Cleanup functions can still fail the test, so asserting expectations
	// here saves every test from repeating these calls. Cleanups run
	// last-in-first-out; this one is registered before the logger is created.
	t.Cleanup(func() {
		if t.Failed() {
			return
		}
		mcfg.directRPC.AssertExpectations(t)
		mcfg.serverProvider.AssertExpectations(t)
		mcfg.cache.AssertExpectations(t)
		mcfg.tokens.AssertExpectations(t)
		mcfg.tlsCfg.AssertExpectations(t)
	})

	mcfg.Config = Config{
		DirectRPC:       mcfg.directRPC,
		ServerProvider:  mcfg.serverProvider,
		Cache:           mcfg.cache,
		Tokens:          mcfg.tokens,
		TLSConfigurator: mcfg.tlsCfg,
		Logger:          testutil.Logger(t),
	}
	return mcfg
}

// expectInitialTLS registers, each with Once, the mock expectations for
// auto-config's initial TLS setup:
//   - UpdateAutoTLS on the TLS configurator, called with:
//     extraCerts as the manual CA PEMs, the RootCert of every root in
//     indexedRoots as the Connect CA PEMs, cert's CertPEM, a "redacted"
//     private key, and verifyServerHostname set to true;
//   - Prepopulate of the CA roots cache entry for datacenter;
//   - Prepopulate of the agent's leaf cert cache entry, keyed by agentName,
//     datacenter and token, with ca's signing key ID in the cache state;
//   - one AgentToken lookup returning token.
//
// t is not used by this helper.
func (m *mockedConfig) expectInitialTLS(t *testing.T, agentName, datacenter, token string, ca *structs.CARoot, indexedRoots *structs.IndexedCARoots, cert *structs.IssuedCert, extraCerts []string) {
	// Left nil when there are no roots, because mock matching compares nil
	// and empty slices as different values.
	var pems []string
	for _, root := range indexedRoots.Roots {
		pems = append(pems, root.RootCert)
	}

	// The TLS configurator should be updated with the proper certs.
	m.tlsCfg.On("UpdateAutoTLS",
		extraCerts,
		pems,
		cert.CertPEM,
		// Auto-config handles the CSR and key, so tests have no way to know
		// whether the key is correct. The mock replaces a non-empty key PEM
		// with "redacted", which at least ensures some key is being sent.
		"redacted",
		true,
	).Return(nil).Once()

	rootRes := cache.FetchResult{Value: indexedRoots, Index: indexedRoots.QueryMeta.Index}
	rootsReq := structs.DCSpecificRequest{Datacenter: datacenter}

	// The cache should be prepopulated with the CA roots.
	m.cache.On("Prepopulate",
		cachetype.ConnectCARootName,
		rootRes,
		datacenter,
		"",
		rootsReq.CacheInfo().Key,
	).Return(nil).Once()

	leafReq := cachetype.ConnectCALeafRequest{
		Token:      token,
		Agent:      agentName,
		Datacenter: datacenter,
	}

	// Match against a copy of the cert with the private key redacted. The
	// real private key will not correspond to this cert, because AutoConfig
	// generates its own key/CSR and sends that up with the request. The copy
	// is named leafCert so it does not shadow the builtin copy.
	leafCert := *cert
	leafCert.PrivateKeyPEM = "redacted"
	leafRes := cache.FetchResult{
		Value: &leafCert,
		Index: leafCert.RaftIndex.ModifyIndex,
		State: cachetype.ConnectCALeafSuccess(ca.SigningKeyID),
	}

	// The cache should be prepopulated with the agent's cert.
	m.cache.On("Prepopulate",
		cachetype.ConnectCALeafName,
		leafRes,
		datacenter,
		token,
		leafReq.Key(),
	).Return(nil).Once()

	// Prepopulating the cert in the cache reads the agent token, so expect
	// that lookup here.
	m.tokens.On("AgentToken").Return(token).Once()
}

// setupInitialTLS generates test certificates for agentName in datacenter,
// plus one unrelated CA root used as an extra manual CA. It registers the
// initial-TLS mock expectations for them via expectInitialTLS, then returns
// the CA roots, the agent's leaf cert and the extra CA PEMs.
func (m *mockedConfig) setupInitialTLS(t *testing.T, agentName, datacenter, token string) (*structs.IndexedCARoots, *structs.IssuedCert, []string) {
	ca, roots, leaf := testCerts(t, agentName, datacenter)
	extra := []string{connect.TestCA(t, nil).RootCert}

	m.expectInitialTLS(t, agentName, datacenter, token, ca, roots, leaf, extra)
	return roots, leaf, extra
}