consul/connect/testing.go
Ronald b64674623e
Copyright headers for missing files/folders (#16708)
* copyright headers for agent folder
2023-03-28 18:48:58 -04:00

197 lines
5.1 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package connect
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"log"
"net"
"net/http"
"sync/atomic"
"github.com/hashicorp/go-hclog"
testing "github.com/mitchellh/go-testing-interface"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport"
)
// TestService builds a *Service for the named service whose TLS configuration
// is a static test config derived from the given CA. Any construction error
// fails the test immediately via t.Fatal.
func TestService(t testing.T, service string, ca *structs.CARoot) *Service {
	t.Helper()

	// The TLS config is supplied directly here, so the service is built from a
	// bare logger plus that config.
	devLogger := hclog.New(&hclog.LoggerOptions{})
	tlsCfg := TestTLSConfig(t, service, ca)

	result, err := NewDevServiceWithTLSConfig(service, devLogger, tlsCfg)
	if err != nil {
		t.Fatal(err)
	}
	return result
}
// TestTLSConfig returns a *tls.Config suitable for use during tests. It starts
// from defaultTLSConfig, presents a leaf certificate for the given service
// signed by ca, and uses ca for both the root pool and the client CA pool.
func TestTLSConfig(t testing.T, service string, ca *structs.CARoot) *tls.Config {
	t.Helper()

	tlsCfg := defaultTLSConfig()

	leaf := TestSvcKeyPair(t, service, ca)
	tlsCfg.Certificates = []tls.Certificate{leaf}

	// Each direction gets its own pool instance rather than sharing one.
	serverRoots := TestCAPool(t, ca)
	tlsCfg.RootCAs = serverRoots
	clientRoots := TestCAPool(t, ca)
	tlsCfg.ClientCAs = clientRoots

	return tlsCfg
}
// TestCAPool returns an *x509.CertPool holding the root certificate(s) of every
// CA passed in. The result of AppendCertsFromPEM is not checked, so a root
// whose PEM fails to parse is silently left out of the pool.
func TestCAPool(t testing.T, cas ...*structs.CARoot) *x509.CertPool {
	t.Helper()

	roots := x509.NewCertPool()
	for i := range cas {
		pemBytes := []byte(cas[i].RootCert)
		roots.AppendCertsFromPEM(pemBytes)
	}
	return roots
}
// TestSvcKeyPair returns a tls.Certificate containing both the certificate and
// the private key for the given service under the given CA, as produced by
// connect.TestLeaf. The test fails via t.Fatal if the pair cannot be parsed.
func TestSvcKeyPair(t testing.T, service string, ca *structs.CARoot) tls.Certificate {
	t.Helper()

	leafCertPEM, leafKeyPEM := connect.TestLeaf(t, service, ca)

	pair, parseErr := tls.X509KeyPair([]byte(leafCertPEM), []byte(leafKeyPEM))
	if parseErr != nil {
		t.Fatal(parseErr)
	}
	return pair
}
// TestPeerCertificates returns a []*x509.Certificate shaped like what
// tls.Conn.ConnectionState().PeerCertificates would hold: a single-element
// slice containing the leaf certificate for the named service under ca. The
// test fails via t.Fatal if the leaf cannot be parsed.
func TestPeerCertificates(t testing.T, service string, ca *structs.CARoot) []*x509.Certificate {
	t.Helper()

	// Only the certificate is needed here; the private key is discarded.
	leafPEM, _ := connect.TestLeaf(t, service, ca)

	leaf, parseErr := connect.ParseCert(leafPEM)
	if parseErr != nil {
		t.Fatal(parseErr)
	}

	chain := []*x509.Certificate{leaf}
	return chain
}
// TestServer runs a service listener that can be used to test clients. Its
// behavior can be controlled by the struct members.
type TestServer struct {
// Service is the service name to serve.
Service string
// CA is the (test) CA to use for generating certs.
CA *structs.CARoot
// TimeoutHandshake, when true, makes Serve accept TCP connections but never
// upgrade them to TLS or echo on them, so a client's handshake against
// this server does not complete.
TimeoutHandshake bool
// TLSCfg is the tls.Config that will be used. NewTestServer sets it up from
// the service and CA via TestTLSConfig.
TLSCfg *tls.Config
// Addr is the listen address. NewTestServer sets it to 127.0.0.1 with a port
// obtained from freeport.GetOne.
Addr string
// Listening is closed by Serve and ServeHTTPS once the listener has been
// opened.
Listening chan struct{}
// l is the active listener. It is set by Serve or ServeHTTPS and closed by
// Close.
l net.Listener
// stopFlag is set to 1 (atomically) by Close. Serve and ServeHTTPS check it
// to return nil instead of the error produced by the listener shutting down.
stopFlag int32
// stopChan is closed by Close. Serve's per-connection goroutines wait on it
// before closing the connections they accepted.
stopChan chan struct{}
}
// NewTestServer returns a TestServer for the given service and CA. Its TLS
// config comes from TestTLSConfig and its address is 127.0.0.1 with a port
// obtained from freeport.GetOne. The server is not listening yet; call Serve or
// ServeHTTPS to start it. It should be closed when the test is complete.
func NewTestServer(t testing.T, service string, ca *structs.CARoot) *TestServer {
	// Mark this function as a helper, like the other helpers in this file that
	// take t, so failures raised while building the TLS config or allocating
	// the port are reported against the calling test's line.
	t.Helper()

	return &TestServer{
		Service:   service,
		CA:        ca,
		stopChan:  make(chan struct{}),
		TLSCfg:    TestTLSConfig(t, service, ca),
		Addr:      fmt.Sprintf("127.0.0.1:%d", freeport.GetOne(t)),
		Listening: make(chan struct{}),
	}
}
// Serve listens on s.Addr and runs a TLS echo service on every accepted
// connection, blocking until the listener fails or the server is closed. It
// returns nil when Accept fails after Close has been called, and the Accept (or
// Listen) error otherwise. If TimeoutHandshake is set, accepted connections are
// neither upgraded to TLS nor echoed; they are simply held open until Close.
func (s *TestServer) Serve() error {
	// A plain TCP listener is used (rather than a TLS one) so that the timing
	// of accept and handshake stays under our control.
	ln, listenErr := net.Listen("tcp", s.Addr)
	if listenErr != nil {
		return listenErr
	}
	s.l = ln
	close(s.Listening)
	log.Printf("test connect service listening on %s", s.Addr)

	// closeOnStop parks until the server is stopped, then closes the conn.
	closeOnStop := func(c net.Conn) {
		<-s.stopChan
		c.Close()
	}

	for {
		conn, acceptErr := s.l.Accept()
		if acceptErr != nil {
			// A failure caused by Close shutting the listener is a clean exit.
			stopped := atomic.LoadInt32(&s.stopFlag) == 1
			if stopped {
				return nil
			}
			return acceptErr
		}

		// When simulating a handshake timeout the raw conn is left untouched.
		if !s.TimeoutHandshake {
			// Upgrade to TLS and echo everything back to the peer.
			conn = tls.Server(conn, s.TLSCfg)
			log.Printf("test connect service accepted conn from %s, "+
				" running echo service", conn.RemoteAddr())
			go io.Copy(conn, conn)
		}

		// Every accepted conn is closed once the server stops.
		go closeOnStop(conn)
	}
}
// ServeHTTPS runs an HTTPS server on s.Addr with s.TLSCfg, invoking the passed
// Handler for all requests. It blocks until the server stops. It returns nil
// when the stop was requested through Close, and otherwise the error from
// opening the listener or from serving.
func (s *TestServer) ServeHTTPS(h http.Handler) error {
	srv := http.Server{
		Addr:      s.Addr,
		TLSConfig: s.TLSCfg,
		Handler:   h,
	}
	log.Printf("starting test connect HTTPS server on %s", s.Addr)

	// Use our own listener so we can signal when it's ready.
	l, err := net.Listen("tcp", s.Addr)
	if err != nil {
		return err
	}

	// Publish the listener BEFORE signalling readiness. A caller that waits on
	// s.Listening and then calls Close must be guaranteed to observe a non-nil
	// s.l; otherwise Close would skip closing the listener and this call would
	// never return.
	s.l = l
	close(s.Listening)
	log.Printf("test connect service listening on %s", s.Addr)

	// Certificates come from TLSConfig, so no cert/key files are passed.
	err = srv.ServeTLS(l, "", "")
	if atomic.LoadInt32(&s.stopFlag) == 1 {
		// The listener was closed deliberately by Close; not an error.
		return nil
	}
	return err
}
// Close stops a TestServer: it closes the listener (if one was opened) and
// closes stopChan so that connections accepted by Serve get closed. Only the
// first call does any work; later calls are no-ops. It always returns nil.
func (s *TestServer) Close() error {
	// Swap the flag atomically so exactly one caller performs the shutdown and
	// stopChan is never closed twice.
	if previous := atomic.SwapInt32(&s.stopFlag, 1); previous != 0 {
		return nil
	}

	if s.l != nil {
		s.l.Close()
	}
	close(s.stopChan)
	return nil
}