package ca

import (
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"sync"

	"github.com/hashicorp/consul/agent/connect"
	"github.com/hashicorp/consul/agent/consul/state"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/sdk/freeport"
	"github.com/hashicorp/consul/sdk/testutil/retry"
	"github.com/hashicorp/go-hclog"
	vaultapi "github.com/hashicorp/vault/api"
	"github.com/mitchellh/go-testing-interface"
)

// KeyTestCases lists the CA key types that signing tests must cover. Leaf keys
// are currently always EC P256, while a CA key may be EC (any NIST curve) or
// RSA (2048 or 4096 bits). Every provider has to complete all signing
// operations with each of these key types:
//
//   - Sign must be able to sign an EC P256 leaf with every listed CA key type.
//   - CrossSignCA must be able to sign every listed type of new CA key with
//     every listed type of old CA key.
//   - SignIntermediate must be able to sign every listed type of secondary
//     intermediate CA key with every listed type of primary CA key.
var KeyTestCases = []struct {
	Desc    string
	KeyType string
	KeyBits int
}{
	{Desc: "Default Key Type (EC 256)", KeyType: connect.DefaultPrivateKeyType, KeyBits: connect.DefaultPrivateKeyBits},
	{Desc: "RSA 2048", KeyType: "rsa", KeyBits: 2048},
}

// CASigningKeyTypes describes one table-test case in which a CA CSR with one
// key type is signed by a CA key of (possibly) another key type.
type CASigningKeyTypes struct {
	// Desc is a human-readable name for the case.
	Desc string
	// SigningKeyType and SigningKeyBits describe the key of the signing CA.
	SigningKeyType string
	SigningKeyBits int
	// CSRKeyType and CSRKeyBits describe the key in the CSR being signed.
	CSRKeyType string
	CSRKeyBits int
}

// CASigningKeyTypeCases builds the table for CA-signs-CA tests (CrossSignCA and
// SignIntermediate). Every entry of KeyTestCases is paired, as the signing key,
// with every entry of KeyTestCases as the CSR key. This yields
// len(KeyTestCases)^2 cases, grouped by signing key in KeyTestCases order.
func CASigningKeyTypeCases() []CASigningKeyTypes {
	n := len(KeyTestCases)
	result := make([]CASigningKeyTypes, n*n)
	for i := range result {
		// i walks the n×n grid row-major: the row picks the signer and the
		// column picks the CSR key.
		signer, csr := KeyTestCases[i/n], KeyTestCases[i%n]
		result[i] = CASigningKeyTypes{
			Desc:           fmt.Sprintf("%s-%d signing %s-%d", signer.KeyType, signer.KeyBits, csr.KeyType, csr.KeyBits),
			SigningKeyType: signer.KeyType,
			SigningKeyBits: signer.KeyBits,
			CSRKeyType:     csr.KeyType,
			CSRKeyBits:     csr.KeyBits,
		}
	}
	return result
}

// TestConsulProvider returns a ConsulProvider backed by the given state
// delegate. Its logger is stubbed out to write to ioutil.Discard so that
// logging calls don't panic. If a test needs the log output, it can call
// SetLogger again with a logger that captures it.
func TestConsulProvider(t testing.T, d ConsulProviderStateDelegate) *ConsulProvider {
	discardLogger := hclog.New(&hclog.LoggerOptions{Output: ioutil.Discard})

	p := &ConsulProvider{Delegate: d}
	p.SetLogger(discardLogger)
	return p
}

// SkipIfVaultNotPresent marks the test as skipped unless a Vault binary can be
// resolved on $PATH. The binary name is taken from $VAULT_BINARY_NAME and
// defaults to "vault".
//
// These tests may be skipped in CI. They are run as part of a separate
// integration test suite.
func SkipIfVaultNotPresent(t testing.T) {
	name := "vault"
	if fromEnv := os.Getenv("VAULT_BINARY_NAME"); fromEnv != "" {
		name = fromEnv
	}

	if resolved, err := exec.LookPath(name); err == nil && resolved != "" {
		return
	}
	t.Skipf("%q not found on $PATH - download and install to run this test", name)
}

// NewTestVaultServer launches a Vault dev server via runTestVault and blocks
// (via WaitUntilReady) until it reports itself initialized and unsealed. It
// fails the test immediately if the server process cannot be launched.
// runTestVault registers a t.Cleanup hook that stops the server.
func NewTestVaultServer(t testing.T) *TestVaultServer {
	server, startErr := runTestVault(t)
	if startErr != nil {
		t.Fatalf("err: %v", startErr)
	}

	server.WaitUntilReady(t)
	return server
}

// runTestVault starts `vault server -dev` on two freshly reserved loopback
// ports and returns a handle to it. It does not wait for the server to become
// healthy.
//
// The binary name comes from $VAULT_BINARY_NAME, falling back to "vault", and
// must resolve through exec.LookPath. The server is started with the fixed dev
// root token "root", and its stdout/stderr are discarded. If anything fails
// before the process starts, the reserved ports are handed back to freeport.
// Once the process is running, a t.Cleanup hook calls Stop, which is then
// responsible for releasing them.
func runTestVault(t testing.T) (*TestVaultServer, error) {
	binary := os.Getenv("VAULT_BINARY_NAME")
	if binary == "" {
		binary = "vault"
	}
	if resolved, lookErr := exec.LookPath(binary); lookErr != nil || resolved == "" {
		return nil, fmt.Errorf("%q not found on $PATH", binary)
	}

	ports := freeport.MustTake(2)
	releasePorts := func() {
		freeport.Return(ports)
	}

	const rootToken = "root"
	clientAddr := fmt.Sprintf("127.0.0.1:%d", ports[0])
	clusterAddr := fmt.Sprintf("127.0.0.1:%d", ports[1])
	apiURL := "http://" + clientAddr

	client, err := vaultapi.NewClient(&vaultapi.Config{Address: apiURL})
	if err != nil {
		releasePorts()
		return nil, err
	}
	client.SetToken(rootToken)

	cmd := exec.Command(binary,
		"server",
		"-dev",
		"-dev-root-token-id", rootToken,
		"-dev-listen-address", clientAddr,
		"-address", clusterAddr,
	)
	cmd.Stdout = ioutil.Discard
	cmd.Stderr = ioutil.Discard
	if err = cmd.Start(); err != nil {
		releasePorts()
		return nil, err
	}

	server := &TestVaultServer{
		RootToken:     rootToken,
		Addr:          apiURL,
		cmd:           cmd,
		client:        client,
		returnPortsFn: releasePorts,
	}
	t.Cleanup(func() {
		server.Stop()
	})
	return server, nil
}

// TestVaultServer is a handle to a Vault dev server process started for a test
// by runTestVault.
type TestVaultServer struct {
	// RootToken is the dev-mode root token the server was started with.
	RootToken string
	// Addr is the http:// URL of the server's API listener.
	Addr string
	// cmd is the running `vault server -dev` process.
	cmd *exec.Cmd
	// client is a Vault API client pointed at Addr and using RootToken.
	client *vaultapi.Client

	// returnPortsFn hands the ports claimed for this server back to the
	// freeport pool so later tests can reuse them. Stop invokes it.
	returnPortsFn func()
}

// printedVaultVersion guards the "testing with vault server version" message
// written by WaitUntilReady. The message appears at most once per test binary,
// however many test servers are started.
var (
	printedVaultVersion sync.Once
)

// Client returns the Vault API client that was configured for this server's
// address and authenticated with its root token when the server was started.
func (v *TestVaultServer) Client() *vaultapi.Client {
	apiClient := v.client
	return apiClient
}

// WaitUntilReady polls the server's health endpoint through retry.Run until it
// reports the server as both initialized and unsealed. Retry limits and failure
// handling come from the retry package (NOTE(review): confirm its default
// timeout). Afterwards, the Vault version reported by the server is printed to
// stderr, at most once per process.
func (v *TestVaultServer) WaitUntilReady(t testing.T) {
	var reportedVersion string
	retry.Run(t, func(r *retry.R) {
		health, err := v.client.Sys().Health()
		// Each r.Fatalf ends this attempt, so only the first failing check is
		// reported and the version is recorded only when all checks pass.
		switch {
		case err != nil:
			r.Fatalf("err: %v", err)
		case !health.Initialized:
			r.Fatalf("vault server is not initialized")
		case health.Sealed:
			r.Fatalf("vault server is sealed")
		}
		reportedVersion = health.Version
	})
	printedVaultVersion.Do(func() {
		fmt.Fprintf(os.Stderr, "[INFO] agent/connect/ca: testing with vault server version: %s\n", reportedVersion)
	})
}

// Stop sends os.Interrupt to the Vault server process, waits for it to exit,
// and hands the ports reserved for it back to freeport.
//
// It returns nil when no process was ever attached (cmd is nil). Otherwise it
// returns an error in two cases: the interrupt could not be delivered, or
// cmd.Wait reports one (for example a non-zero exit status). Once a process is
// attached, the reserved ports are released on every return path, including
// the error paths. They are released at most once, so calling Stop again (for
// example an explicit call followed by the t.Cleanup hook) does not hand the
// same ports back twice.
func (v *TestVaultServer) Stop() error {
	// There was no process
	if v.cmd == nil {
		return nil
	}

	// Release the ports however the shutdown below ends; returning early on an
	// error must not leak them for the rest of the test run. Clearing the func
	// afterwards makes a repeated Stop a no-op here.
	// NOTE(review): if the signal could not be delivered the process may still
	// hold the ports; this relies on freeport re-checking that a returned port
	// is actually free before handing it out again — confirm.
	defer func() {
		if v.returnPortsFn != nil {
			v.returnPortsFn()
			v.returnPortsFn = nil
		}
	}()

	if v.cmd.Process != nil {
		if err := v.cmd.Process.Signal(os.Interrupt); err != nil {
			return fmt.Errorf("failed to kill vault server: %w", err)
		}
	}

	// Wait for the process to exit so anything it holds on disk can be
	// cleaned up on all platforms.
	if err := v.cmd.Wait(); err != nil {
		return err
	}

	return nil
}

// ApplyCARequestToStore applies req directly to store and returns the result a
// caller would receive for it. NOTE(review): presumably this stands in for the
// Raft FSM's CA handling in tests; confirm against the FSM.
//
// The current CA config index is read first for every op, and any error from
// that read is returned unchanged. Then:
//
//   - CAOpSetProviderState writes req.ProviderState at that index + 1 and
//     returns true.
//   - CAOpDeleteProviderState deletes the provider state with ID
//     req.ProviderState.ID and returns true.
//   - CAOpIncrementProviderSerialNumber leaves the store untouched and always
//     returns uint64(2).
//   - Any other op returns an "Invalid CA operation" error.
//
// Store errors from the set/delete operations are returned as-is.
func ApplyCARequestToStore(store *state.Store, req *structs.CARequest) (interface{}, error) {
	idx, _, cfgErr := store.CAConfig(nil)
	if cfgErr != nil {
		return nil, cfgErr
	}

	switch op := req.Op; op {
	case structs.CAOpSetProviderState:
		if _, setErr := store.CASetProviderState(idx+1, req.ProviderState); setErr != nil {
			return nil, setErr
		}
	case structs.CAOpDeleteProviderState:
		if delErr := store.CADeleteProviderState(req.ProviderState.ID); delErr != nil {
			return nil, delErr
		}
	case structs.CAOpIncrementProviderSerialNumber:
		return uint64(2), nil
	default:
		return nil, fmt.Errorf("Invalid CA operation '%s'", op)
	}

	// Both successful store mutations report plain success.
	return true, nil
}