package grpc

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"

	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
	"github.com/hashicorp/consul/agent/metadata"
	"github.com/hashicorp/consul/agent/pool"
	"github.com/hashicorp/consul/tlsutil"
	"github.com/hashicorp/go-hclog"
)

// testServer is a handle to a gRPC test server started by newTestServer.
type testServer struct {
	// addr is the address the fake RPC listener accepted connections on.
	addr net.Addr
	// name and dc identify the server; Metadata uses them to build its ID,
	// Name, ShortName and Datacenter.
	name string
	dc   string
	// shutdown closes the listener, shuts down the gRPC handler and waits for
	// the background goroutines started by newTestServer.
	shutdown func()
	// rpc is the fake listener that demultiplexes incoming connections.
	rpc *fakeRPCListener
}

// Metadata describes the test server as a metadata.Server. The Name field
// is "<name>.<dc>", and UseTLS is true whenever the fake listener was given
// a TLS configurator.
func (s testServer) Metadata() *metadata.Server {
	tlsEnabled := s.rpc.tlsConf != nil
	fullName := s.name + "." + s.dc

	srv := &metadata.Server{
		ID:         s.name,
		ShortName:  s.name,
		Name:       fullName,
		Datacenter: s.dc,
		Addr:       s.addr,
		UseTLS:     tlsEnabled,
	}
	return srv
}

// newSimpleTestServer starts a test server that serves the well-behaved
// `simple` implementation of the Simple service. It uses the default hclog
// logger.
func newSimpleTestServer(t *testing.T, name, dc string, tlsConf *tlsutil.Configurator) testServer {
	registerSimple := func(srv *grpc.Server) {
		testservice.RegisterSimpleServer(srv, &simple{name: name, dc: dc})
	}
	return newTestServer(t, hclog.Default(), name, dc, tlsConf, registerSimple)
}

// newPanicTestServer starts a test server whose Simple service handlers
// (simplePanic) panic. The given logger is passed through to the gRPC
// handler.
func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator) testServer {
	registerPanicky := func(srv *grpc.Server) {
		testservice.RegisterSimpleServer(srv, &simplePanic{name: name, dc: dc})
	}
	return newTestServer(t, logger, name, dc, tlsConf, registerPanicky)
}

// newTestServer creates a gRPC Handler (services attached via register) and
// serves it through a fakeRPCListener on an ephemeral 127.0.0.1 TCP port.
// The listener loop and handler.Run each run in an errgroup goroutine.
//
// The returned testServer's shutdown func does the following, in order:
//   - marks the fake listener as shutting down;
//   - closes the TCP listener;
//   - shuts down the handler;
//   - waits for both goroutines.
//
// Errors from these steps are logged with t, not reported as test failures.
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer {
	handler := NewHandler(logger, &net.IPAddr{IP: net.ParseIP("127.0.0.1")}, register)

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf}

	var group errgroup.Group
	group.Go(func() error {
		listenErr := rpc.listen(lis)
		if listenErr == nil {
			return nil
		}
		return fmt.Errorf("fake rpc listen error: %w", listenErr)
	})
	group.Go(func() error {
		runErr := handler.Run()
		if runErr == nil {
			return nil
		}
		return fmt.Errorf("grpc server error: %w", runErr)
	})

	stop := func() {
		// Set before closing the listener so listen treats the resulting
		// Accept error as a clean shutdown.
		rpc.shutdown = true
		if closeErr := lis.Close(); closeErr != nil {
			t.Logf("listener closed with error: %v", closeErr)
		}
		if shutdownErr := handler.Shutdown(); shutdownErr != nil {
			t.Logf("grpc server shutdown: %v", shutdownErr)
		}
		if waitErr := group.Wait(); waitErr != nil {
			t.Log(waitErr)
		}
	}

	return testServer{
		addr:     lis.Addr(),
		name:     name,
		dc:       dc,
		rpc:      rpc,
		shutdown: stop,
	}
}

// simple is a well-behaved implementation of the test Simple service that
// reports its configured server name and datacenter.
type simple struct {
	name, dc string
}

// Flow streams responses that identify this server until the stream's
// context is done. It pauses one millisecond between sends. It returns the
// first Send error, or nil once the context reports an error.
func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error {
	for flow.Context().Err() == nil {
		// Report the configured server name, consistent with Something,
		// instead of a hard-coded value.
		resp := &testservice.Resp{ServerName: s.name, Datacenter: s.dc}
		if err := flow.Send(resp); err != nil {
			return err
		}
		time.Sleep(time.Millisecond)
	}
	return nil
}

// Something replies with this server's name and datacenter. The request
// and context are ignored.
func (s *simple) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
	reply := &testservice.Resp{Datacenter: s.dc, ServerName: s.name}
	return reply, nil
}

// simplePanic implements the test Simple service with handlers that panic.
// Its name and dc fields are not read by its handlers.
type simplePanic struct {
	name string
	dc   string
}

// Flow panics after a one-millisecond pause, unless the stream's context is
// already done, in which case it returns nil without sending anything.
func (s *simplePanic) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error {
	if flow.Context().Err() != nil {
		return nil
	}
	time.Sleep(time.Millisecond)
	panic("panic from Flow")
}

// Something always panics after a one-millisecond pause.
func (s *simplePanic) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
	const delay = time.Millisecond
	time.Sleep(delay)
	panic("panic from Something")
}

// fakeRPCListener mimics agent/consul.Server.listen. It handles the leading
// RPCType byte (and native TLS with ALPN) before handing gRPC connections to
// the Handler.
//
// Ideally Server would be refactored so this RPC handling logic could be
// extracted and reused instead of faked. For now that logic lives in
// agent/consul, where Server.listen cannot easily be used from here, so it
// is faked.
type fakeRPCListener struct {
	// t is the owning test. NOTE(review): not read by the methods visible in
	// this file section.
	t *testing.T
	// handler receives every connection identified as gRPC.
	handler *Handler
	// shutdown is set by the testServer shutdown func before the listener is
	// closed. listen then returns nil instead of the Accept error.
	shutdown bool
	// tlsConf enables the RPCTLS and native-TLS paths. When nil, RPCTLS
	// connections are rejected.
	tlsConf *tlsutil.Configurator
	// tlsConnEstablished counts RPCTLS upgrades. It is incremented with
	// atomic.AddInt32.
	tlsConnEstablished int32
	// alpnConnEstablished counts native TLS connections that negotiated
	// pool.ALPN_RPCGRPC. It is incremented with atomic.AddInt32.
	alpnConnEstablished int32
}

// listen accepts connections until Accept fails and handles each one in its
// own goroutine. An Accept failure after shutdown has been flagged is
// treated as a clean stop and returns nil. Any other Accept error is
// returned as-is.
func (f *fakeRPCListener) listen(listener net.Listener) error {
	for {
		c, acceptErr := listener.Accept()
		switch {
		case acceptErr == nil:
			go f.handleConn(c)
		case f.shutdown:
			return nil
		default:
			return acceptErr
		}
	}
}

// handleConn dispatches a newly accepted connection, mirroring
// agent/consul.Server's connection handling:
//   - When mutual TLS is possible, it first peeks for a native TLS
//     handshake.
//   - Otherwise it reads a single pool.RPCType byte and routes on it:
//     gRPC goes to the handler, and RPCTLS is upgraded and re-dispatched.
//
// Failed or unsupported connections are reported on stdout and closed.
func (f *fakeRPCListener) handleConn(conn net.Conn) {
	f.dispatchConn(conn, false)
}

// dispatchConn implements handleConn. isTLS reports whether conn has already
// been upgraded via an RPCTLS byte. In that case native-TLS detection is
// skipped and a second RPCTLS byte is rejected, so a client cannot nest TLS
// sessions indefinitely.
func (f *fakeRPCListener) dispatchConn(conn net.Conn, isTLS bool) {
	if !isTLS && f.tlsConf != nil && f.tlsConf.MutualTLSCapable() {
		// See if this is actually native TLS multiplexed onto the old
		// "type-byte" system.
		peekedConn, nativeTLS, err := pool.PeekForTLS(conn)
		if err != nil {
			if err != io.EOF {
				fmt.Printf("ERROR: failed to read first byte: %v\n", err)
			}
			conn.Close()
			return
		}

		if nativeTLS {
			f.handleNativeTLSConn(peekedConn)
			return
		}
		conn = peekedConn
	}

	// io.ReadFull guarantees the type byte was actually read. A bare Read may
	// legally return (0, nil), which would dispatch on a zero byte.
	buf := make([]byte, 1)
	if _, err := io.ReadFull(conn, buf); err != nil {
		if err != io.EOF {
			fmt.Println("ERROR", err.Error())
		}
		conn.Close()
		return
	}
	typ := pool.RPCType(buf[0])

	switch typ {
	case pool.RPCGRPC:
		f.handler.Handle(conn)

	case pool.RPCTLS:
		// Don't allow a client to stack TLS inside TLS.
		if isTLS {
			fmt.Println("ERROR: TLS connection attempting to establish inner TLS connection")
			conn.Close()
			return
		}
		// Occasionally a test client connects to an RPC listener created by
		// another test, even though none of the tests run in parallel. The
		// cause is unclear (possibly something in gRPC). With no TLS config
		// to upgrade with, reject the connection.
		if f.tlsConf == nil {
			fmt.Println("ERROR: tls is not configured")
			conn.Close()
			return
		}

		atomic.AddInt32(&f.tlsConnEstablished, 1)
		f.dispatchConn(tls.Server(conn, f.tlsConf.IncomingRPCConfig()), true)

	default:
		fmt.Println("ERROR: unexpected byte", typ)
		conn.Close()
	}
}

// handleNativeTLSConn serves a connection whose first byte looked like a TLS
// record rather than an RPC type byte. It completes the handshake using the
// ALPN-enabled incoming config. If the negotiated protocol is
// pool.ALPN_RPCGRPC, it hands the TLS connection to the gRPC handler and
// counts it in alpnConnEstablished. Otherwise it logs to stdout and closes
// conn.
func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) {
	tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos)
	tlsConn := tls.Server(conn, tlscfg)

	// Force the handshake to conclude so the negotiated ALPN protocol is
	// known.
	if err := tlsConn.Handshake(); err != nil {
		fmt.Printf("ERROR: TLS handshake failed: %v\n", err)
		conn.Close()
		return
	}

	// Clear any read deadline on the underlying conn, as the real server
	// does after its handshake. This is best-effort, so a failure here is
	// deliberately ignored.
	_ = conn.SetReadDeadline(time.Time{})

	nextProto := tlsConn.ConnectionState().NegotiatedProtocol

	switch nextProto {
	case pool.ALPN_RPCGRPC:
		atomic.AddInt32(&f.alpnConnEstablished, 1)
		f.handler.Handle(tlsConn)

	default:
		fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto)
		conn.Close()
	}
}