// Package oidcauthtest exposes tools to assist in writing unit tests of OIDC
// and JWT authentication workflows.
//
// When the package is loaded it will randomly generate an ECDSA signing
// keypair used to sign JWTs both via the Server and the SignJWT method.
package oidcauthtest

import (
	"bytes"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"sync"
	"time"

	"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
	"github.com/mitchellh/go-testing-interface"
	"github.com/stretchr/testify/require"
	"gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/jwt"
)

// Server is a local server that mocks the endpoints used by the OIDC and
// JWKS process. Instances are created with Start and serve requests through
// ServeHTTP.
type Server struct {
	// httpServer is the underlying TLS test server; caCert is its
	// certificate in PEM form as captured by Start.
	httpServer *httptest.Server
	caCert     string
	// NOTE(review): returnFunc is not assigned anywhere in the visible code
	// of this file; confirm whether it is still needed.
	returnFunc func()

	// jwks is the key set served from /certs. allowedRedirectURIs is checked
	// by /token, replySubject becomes the JWT subject, and replyUserinfo is
	// the body served from /userinfo.
	jwks                *jose.JSONWebKeySet
	allowedRedirectURIs []string
	replySubject        string
	replyUserinfo       map[string]interface{}

	// mu is held by every setter and for the full duration of ServeHTTP, so
	// the configuration below can be changed while the server is running.
	mu                sync.Mutex
	clientID          string
	clientSecret      string
	expectedAuthCode  string
	expectedAuthNonce string
	customClaims      map[string]interface{}
	customAudience    string
	omitIDToken       bool
	disableUserInfo   bool
}

// startOption carries the optional settings accepted by Start. Values are
// built with WithPort.
type startOption struct {
	// port is the TCP port to listen on. Start only honors options whose port
	// is greater than zero.
	port       int
	// returnFunc, when non-nil, is registered by Start as a test cleanup.
	returnFunc func()
}

// WithPort is an option for Start that lets the caller pick the port the
// server listens on. returnFunc is invoked during test cleanup so the caller
// can hand the port back to whatever bookkeeping system allocated it.
func WithPort(port int, returnFunc func()) startOption {
	var opt startOption
	opt.port = port
	opt.returnFunc = returnFunc
	return opt
}

// Start creates a disposable Server listening over TLS. Without a WithPort
// option carrying a positive port the server binds to a random free port;
// otherwise the requested port is used. The server is closed automatically
// when the test finishes.
func Start(t testing.T, options ...startOption) *Server {
	keySet, err := newJWKS(ecdsaPublicKey)
	require.NoError(t, err)

	s := &Server{
		jwks:                keySet,
		allowedRedirectURIs: []string{"https://example.com"},
		replySubject:        "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
		replyUserinfo: map[string]interface{}{
			"color":       "red",
			"temperature": "76",
			"flavor":      "umami",
		},
	}

	// Only options with a positive port are honored; when several are given
	// the last one wins.
	var chosen startOption
	for _, opt := range options {
		if opt.port <= 0 {
			continue
		}
		chosen = opt
	}

	srv := httptestNewUnstartedServerWithPort(s, chosen.port)
	s.httpServer = srv
	// Silence the HTTP server's own error logging.
	srv.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
	srv.StartTLS()

	// Cleanups are registered in this order on purpose: the port-return
	// callback first, then the server shutdown.
	if chosen.returnFunc != nil {
		t.Cleanup(chosen.returnFunc)
	}
	t.Cleanup(srv.Close)

	// Keep the server certificate in PEM form so callers can trust it.
	var pemBuf bytes.Buffer
	certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: srv.Certificate().Raw}
	require.NoError(t, pem.Encode(&pemBuf, certBlock))
	s.caCert = pemBuf.String()

	return s
}

// SetClientCreds stores the client ID and secret used by the OIDC workflows.
// The client ID is the default audience of JWTs issued by /token.
func (s *Server) SetClientCreds(clientID, clientSecret string) {
	s.mu.Lock()
	s.clientID, s.clientSecret = clientID, clientSecret
	s.mu.Unlock()
}

// SetExpectedAuthCode sets the auth code that /auth hands out and that
// /token accepts.
func (s *Server) SetExpectedAuthCode(code string) {
	s.mu.Lock()
	s.expectedAuthCode = code
	s.mu.Unlock()
}

// SetExpectedAuthNonce sets the nonce value that requests to /auth must
// carry. An empty value disables the nonce check.
func (s *Server) SetExpectedAuthNonce(nonce string) {
	s.mu.Lock()
	s.expectedAuthNonce = nonce
	s.mu.Unlock()
}

// SetAllowedRedirectURIs replaces the list of redirect URIs accepted by the
// OIDC workflow. Until it is called, the single sample value
// "https://example.com" is allowed.
func (s *Server) SetAllowedRedirectURIs(uris []string) {
	s.mu.Lock()
	s.allowedRedirectURIs = uris
	s.mu.Unlock()
}

// SetCustomClaims sets the additional claims embedded in the JWT issued by
// the OIDC workflow.
func (s *Server) SetCustomClaims(customClaims map[string]interface{}) {
	s.mu.Lock()
	s.customClaims = customClaims
	s.mu.Unlock()
}

// SetCustomAudience overrides the audience embedded in the JWT issued by the
// OIDC workflow. When left empty the client ID is used instead.
func (s *Server) SetCustomAudience(customAudience string) {
	s.mu.Lock()
	s.customAudience = customAudience
	s.mu.Unlock()
}

// OmitIDTokens puts the server into an error state in which the /token
// endpoint leaves id_token out of its response.
func (s *Server) OmitIDTokens() {
	s.mu.Lock()
	s.omitIDToken = true
	s.mu.Unlock()
}

// DisableUserInfo makes /userinfo answer with 404 and removes the userinfo
// endpoint from the discovery document.
func (s *Server) DisableUserInfo() {
	s.mu.Lock()
	s.disableUserInfo = true
	s.mu.Unlock()
}

// Stop shuts down the running Server by closing its underlying HTTP test
// server.
func (s *Server) Stop() {
	srv := s.httpServer
	srv.Close()
}

// Addr reports the base URL of the running webserver.
func (s *Server) Addr() string {
	return s.httpServer.URL
}

// CACert reports the PEM-encoded CA certificate used by the HTTPS server.
func (s *Server) CACert() string {
	return s.caCert
}

// SigningKeys reports the PEM-encoded keys used to sign JWTs. It is a
// convenience wrapper around the package-level SigningKeys.
func (s *Server) SigningKeys() (pub, priv string) {
	pub, priv = SigningKeys()
	return pub, priv
}

// ServeHTTP implements http.Handler. It routes on the exact request path and
// mocks the following endpoints:
//
//	/.well-known/openid-configuration  discovery document (GET)
//	/auth                              authorization endpoint (GET)
//	/certs                             JWKS document (GET)
//	/certs_missing                     always 404
//	/certs_invalid                     a body that is not a key set
//	/token                             token endpoint (POST)
//	/userinfo                          userinfo document (GET)
//
// Any other path yields 404. The server mutex is held for the whole request,
// so requests are handled one at a time and see a consistent configuration.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Set up front for every response, including redirects and error
	// statuses.
	w.Header().Set("Content-Type", "application/json")

	switch req.URL.Path {
	case "/.well-known/openid-configuration":
		if req.Method != "GET" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		// All advertised endpoints are derived from the server's base URL.
		reply := struct {
			Issuer           string `json:"issuer"`
			AuthEndpoint     string `json:"authorization_endpoint"`
			TokenEndpoint    string `json:"token_endpoint"`
			JWKSURI          string `json:"jwks_uri"`
			UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
		}{
			Issuer:           s.Addr(),
			AuthEndpoint:     s.Addr() + "/auth",
			TokenEndpoint:    s.Addr() + "/token",
			JWKSURI:          s.Addr() + "/certs",
			UserinfoEndpoint: s.Addr() + "/userinfo",
		}
		// An empty value is dropped from the JSON thanks to omitempty.
		if s.disableUserInfo {
			reply.UserinfoEndpoint = ""
		}

		if err := writeJSON(w, &reply); err != nil {
			return
		}

	case "/auth":
		if req.Method != "GET" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		qv := req.URL.Query()

		// The checks below run in a fixed order; the first failure decides
		// which error code is sent back via redirect. Note that the error
		// redirects are issued before redirect_uri is checked for presence.
		if qv.Get("response_type") != "code" {
			writeAuthErrorResponse(w, req, "unsupported_response_type", "")
			return
		}
		if qv.Get("scope") != "openid" {
			writeAuthErrorResponse(w, req, "invalid_scope", "")
			return
		}

		// No auth code configured means every authorization is denied.
		if s.expectedAuthCode == "" {
			writeAuthErrorResponse(w, req, "access_denied", "")
			return
		}

		// The nonce is only enforced when one has been configured.
		nonce := qv.Get("nonce")
		if s.expectedAuthNonce != "" && s.expectedAuthNonce != nonce {
			writeAuthErrorResponse(w, req, "access_denied", "")
			return
		}

		state := qv.Get("state")
		if state == "" {
			writeAuthErrorResponse(w, req, "invalid_request", "missing state parameter")
			return
		}

		// NOTE: unlike /token, this handler does not compare redirect_uri
		// against allowedRedirectURIs; it only requires it to be non-empty.
		redirectURI := qv.Get("redirect_uri")
		if redirectURI == "" {
			writeAuthErrorResponse(w, req, "invalid_request", "missing redirect_uri parameter")
			return
		}

		// The query string is appended with "?" unconditionally, so the
		// supplied redirect_uri is expected not to carry its own query.
		redirectURI += "?state=" + url.QueryEscape(state) +
			"&code=" + url.QueryEscape(s.expectedAuthCode)

		http.Redirect(w, req, redirectURI, http.StatusFound)

		return

	case "/certs":
		if req.Method != "GET" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		if err := writeJSON(w, s.jwks); err != nil {
			return
		}

	case "/certs_missing":
		// Simulates a JWKS URL that does not exist.
		w.WriteHeader(http.StatusNotFound)

	case "/certs_invalid":
		// Simulates a JWKS URL that returns something other than a key set.
		w.Write([]byte("It's not a keyset!"))

	case "/token":
		if req.Method != "POST" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		// Validate the form in order: grant type, redirect URI, auth code.
		switch {
		case req.FormValue("grant_type") != "authorization_code":
			_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "bad grant_type")
			return
		case !strutil.StrListContains(s.allowedRedirectURIs, req.FormValue("redirect_uri")):
			_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed")
			return
		case req.FormValue("code") != s.expectedAuthCode:
			_ = writeTokenErrorResponse(w, req, http.StatusUnauthorized, "invalid_grant", "unexpected auth code")
			return
		}

		// The issued token is valid from five seconds ago until five seconds
		// from now.
		stdClaims := jwt.Claims{
			Subject:   s.replySubject,
			Issuer:    s.Addr(),
			NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
			Expiry:    jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
			Audience:  jwt.Audience{s.clientID},
		}
		// A configured custom audience replaces the client ID.
		if s.customAudience != "" {
			stdClaims.Audience = jwt.Audience{s.customAudience}
		}

		// An empty key makes SignJWT fall back to the package default key.
		jwtData, err := SignJWT("", stdClaims, s.customClaims)
		if err != nil {
			_ = writeTokenErrorResponse(w, req, http.StatusInternalServerError, "server_error", err.Error())
			return
		}

		// The same signed JWT is returned as both access token and ID token.
		reply := struct {
			AccessToken string `json:"access_token"`
			IDToken     string `json:"id_token,omitempty"`
		}{
			AccessToken: jwtData,
			IDToken:     jwtData,
		}
		// An empty value is dropped from the JSON thanks to omitempty.
		if s.omitIDToken {
			reply.IDToken = ""
		}
		if err := writeJSON(w, &reply); err != nil {
			return
		}

	case "/userinfo":
		// The disabled check comes first, so a disabled endpoint answers 404
		// regardless of the request method.
		if s.disableUserInfo {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		if req.Method != "GET" {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		if err := writeJSON(w, s.replyUserinfo); err != nil {
			return
		}

	default:
		w.WriteHeader(http.StatusNotFound)
	}
}

// writeAuthErrorResponse reports an authorization failure by redirecting
// (302 Found) to the request's redirect_uri. The request's state and the
// given error code are attached as query parameters; error_description is
// added only when errorMessage is non-empty.
func writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, errorCode, errorMessage string) {
	params := req.URL.Query()

	target := params.Get("redirect_uri")
	target += "?state=" + url.QueryEscape(params.Get("state"))
	target += "&error=" + url.QueryEscape(errorCode)
	if len(errorMessage) > 0 {
		target += "&error_description=" + url.QueryEscape(errorMessage)
	}

	http.Redirect(w, req, target, http.StatusFound)
}

// writeTokenErrorResponse sends a token-endpoint error: it writes the given
// HTTP status code followed by a JSON body with an "error" code and, when
// errorMessage is non-empty, an "error_description". The error from encoding
// the body is returned.
func writeTokenErrorResponse(w http.ResponseWriter, req *http.Request, statusCode int, errorCode, errorMessage string) error {
	type tokenError struct {
		Code string `json:"error"`
		Desc string `json:"error_description,omitempty"`
	}

	// The status must be written before any part of the body.
	w.WriteHeader(statusCode)
	return writeJSON(w, &tokenError{Code: errorCode, Desc: errorMessage})
}

// newJWKS turns a PEM-encoded PKIX public key into a single-entry JWKS
// suitable for a verification endpoint response. It fails when the input is
// not valid PEM or the decoded bytes are not a parsable public key.
func newJWKS(pubKey string) (*jose.JSONWebKeySet, error) {
	decoded, _ := pem.Decode([]byte(pubKey))
	if decoded == nil {
		return nil, fmt.Errorf("unable to decode public key")
	}

	parsed, err := x509.ParsePKIXPublicKey(decoded.Bytes)
	if err != nil {
		return nil, err
	}

	keySet := &jose.JSONWebKeySet{
		Keys: []jose.JSONWebKey{{Key: parsed}},
	}
	return keySet, nil
}

// writeJSON streams out to w as JSON and returns any encoding or write
// error.
func writeJSON(w http.ResponseWriter, out interface{}) error {
	return json.NewEncoder(w).Encode(out)
}

// SignJWT will bundle the provided claims into a signed JWT. The provided key
// is assumed to be a PEM-encoded ECDSA private key; the token is signed with
// the ES256 algorithm using signer options typed "JWT".
//
// If no private key is provided, the default package keys are used. These can
// be retrieved via the SigningKeys() method.
//
// An error is returned when the key cannot be PEM-decoded or parsed as an EC
// private key, when the signer cannot be created, or when the claims cannot
// be serialized.
func SignJWT(privKey string, claims jwt.Claims, privateClaims interface{}) (string, error) {
	if privKey == "" {
		privKey = ecdsaPrivateKey
	}

	// Fail fast on input that is not PEM rather than handing a nil key to the
	// signer; this mirrors the handling of bad public keys in newJWKS.
	block, _ := pem.Decode([]byte(privKey))
	if block == nil {
		return "", fmt.Errorf("unable to decode private key")
	}
	key, err := x509.ParseECPrivateKey(block.Bytes)
	if err != nil {
		return "", err
	}

	sig, err := jose.NewSigner(
		jose.SigningKey{Algorithm: jose.ES256, Key: key},
		(&jose.SignerOptions{}).WithType("JWT"),
	)
	if err != nil {
		return "", err
	}

	// Standard claims are applied first, then the caller's private claims.
	raw, err := jwt.Signed(sig).
		Claims(claims).
		Claims(privateClaims).
		CompactSerialize()
	if err != nil {
		return "", err
	}

	return raw, nil
}

// httptestNewUnstartedServerWithPort behaves roughly like
// httptest.NewUnstartedServer() but lets the caller pin the listening port.
// A port of zero defers entirely to httptest.NewUnstartedServer; any other
// value is bound on 127.0.0.1. Like its httptest counterpart, it panics when
// the listener cannot be created.
func httptestNewUnstartedServerWithPort(handler http.Handler, port int) *httptest.Server {
	if port != 0 {
		hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
		listener, err := net.Listen("tcp", hostPort)
		if err != nil {
			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
		}

		srv := new(httptest.Server)
		srv.Listener = listener
		srv.Config = &http.Server{Handler: handler}
		return srv
	}

	return httptest.NewUnstartedServer(handler)
}

// SigningKeys reports the PEM-encoded public and private keys that are used
// to sign JWTs by default.
func SigningKeys() (pub, priv string) {
	pub, priv = ecdsaPublicKey, ecdsaPrivateKey
	return pub, priv
}

// PEM-encoded ECDSA key pair used as the package default for signing JWTs and
// for building the JWKS served by Server. Both values are populated once by
// init.
var (
	ecdsaPublicKey  string
	ecdsaPrivateKey string
)

func init() {
	// A fresh key pair is generated once, when the package is loaded. Every
	// test in the same binary shares it, and nothing is persisted, so each
	// new run of the test binary gets different keys.
	pub, priv, err := generateKey()
	if err != nil {
		panic(err)
	}
	ecdsaPublicKey, ecdsaPrivateKey = pub, priv
}

// generateKey creates a new P-256 ECDSA key pair and returns both halves
// PEM-encoded: the public key as a PKIX "PUBLIC KEY" block and the private
// key as an "EC PRIVATE KEY" block.
func generateKey() (pub, priv string, err error) {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return "", "", fmt.Errorf("error generating private key: %v", err)
	}

	privDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return "", "", fmt.Errorf("error marshaling private key: %v", err)
	}

	pubDER, err := x509.MarshalPKIXPublicKey(key.Public())
	if err != nil {
		return "", "", fmt.Errorf("error marshaling public key: %v", err)
	}

	priv = string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}))
	pub = string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}))

	return pub, priv, nil
}