make it possible to use a custom tls.Config for listening and dialing (#22)

This commit is contained in:
Marten Seemann 2022-07-16 13:55:57 +00:00 committed by GitHub
parent 2823159a99
commit d74921df0a
3 changed files with 184 additions and 56 deletions

View File

@ -30,6 +30,8 @@ type listener struct {
transport tpt.Transport
noise *noise.Transport
certManager *certManager
staticTLSConf *tls.Config
rcmgr network.ResourceManager
gater connmgr.ConnectionGater
@ -48,7 +50,7 @@ type listener struct {
var _ tpt.Listener = &listener{}
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
network, addr, err := manet.DialArgs(laddr)
if err != nil {
return nil, err
@ -65,23 +67,23 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans
if err != nil {
return nil, err
}
if tlsConf == nil {
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return certManager.GetConfig(), nil
}}
}
ln := &listener{
transport: transport,
noise: noise,
certManager: certManager,
staticTLSConf: tlsConf,
rcmgr: rcmgr,
gater: gater,
queue: make(chan tpt.CapableConn, queueLen),
serverClosed: make(chan struct{}),
addr: udpConn.LocalAddr(),
multiaddr: localMultiaddr,
server: webtransport.Server{
H3: http3.Server{
TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return certManager.GetConfig(), nil
}},
},
},
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}},
}
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background())
mux := http.NewServeMux()
@ -198,6 +200,9 @@ func (l *listener) Addr() net.Addr {
}
func (l *listener) Multiaddr() ma.Multiaddr {
if l.certManager == nil {
return l.multiaddr
}
return l.multiaddr.Encapsulate(l.certManager.AddrComponent())
}

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"sync"
@ -43,6 +44,27 @@ func WithClock(cl clock.Clock) Option {
}
}
// WithTLSConfig sets a tls.Config used for listening.
// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr.
// This is most useful when running a listener that has a valid (CA-signed) certificate.
func WithTLSConfig(c *tls.Config) Option {
return func(t *transport) error {
t.staticTLSConf = c
return nil
}
}
// WithTLSClientConfig sets a custom tls.Config used for dialing.
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
// overwrite the VerifyPeerCertificate callback.
func WithTLSClientConfig(c *tls.Config) Option {
return func(t *transport) error {
t.tlsClientConf = c
return nil
}
}
type transport struct {
privKey ic.PrivKey
pid peer.ID
@ -54,6 +76,8 @@ type transport struct {
listenOnce sync.Once
listenOnceErr error
certManager *certManager
staticTLSConf *tls.Config
tlsClientConf *tls.Config
noise *noise.Transport
}
@ -129,15 +153,21 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint)
dialer := webtransport.Dialer{
RoundTripper: &http3.RoundTripper{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, // this is not insecure. We verify the certificate ourselves.
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
var tlsConf *tls.Config
if t.tlsClientConf != nil {
tlsConf = t.tlsClientConf.Clone()
} else {
tlsConf = &tls.Config{}
}
if len(certHashes) > 0 {
tlsConf.InsecureSkipVerify = true // this is not insecure. We verify the certificate ourselves.
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return verifyRawCerts(rawCerts, certHashes)
},
},
},
}
}
dialer := webtransport.Dialer{
RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf},
}
rsp, sess, err := dialer.Dial(ctx, url, nil)
if err != nil {
@ -193,6 +223,14 @@ func (t *transport) checkEarlyData(b []byte) error {
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
if t.staticTLSConf != nil {
if len(hashes) > 0 {
return errors.New("using static TLS config, didn't expect any certificate hashes")
}
return nil
}
for _, h := range msg.CertHashes {
dh, err := multihash.Decode(h)
if err != nil {
@ -224,13 +262,15 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
if !webtransportMatcher.Matches(laddr) {
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
}
if t.staticTLSConf == nil {
t.listenOnce.Do(func() {
t.certManager, t.listenOnceErr = newCertManager(t.clock)
})
if t.listenOnceErr != nil {
return nil, t.listenOnceErr
}
return newListener(laddr, t, t.noise, t.certManager, t.gater, t.rcmgr)
}
return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr)
}
func (t *transport) Protocols() []int {

View File

@ -2,12 +2,19 @@ package libp2pwebtransport_test
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"strings"
"testing"
"time"
@ -64,6 +71,19 @@ func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
}
}
// create a /certhash multiaddr component using the SHA256 of foobar
func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr {
t.Helper()
h := sha256.Sum256(b)
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
require.NoError(t, err)
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
require.NoError(t, err)
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
require.NoError(t, err)
return ha
}
func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
@ -129,29 +149,11 @@ func TestHashVerification(t *testing.T) {
require.NoError(t, err)
defer tr2.(io.Closer).Close()
// create a hash component using the SHA256 of foobar
h := sha256.Sum256([]byte("foobar"))
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
require.NoError(t, err)
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
require.NoError(t, err)
foobarHash, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
require.NoError(t, err)
foobarHash := getCerthashComponent(t, []byte("foobar"))
t.Run("fails using only a wrong hash", func(t *testing.T) {
// replace the certificate hash in the multiaddr with a fake hash
addr := ln.Multiaddr()
// strip off all certhash components
for {
a, comp := ma.SplitLast(addr)
if comp.Protocol().Code != ma.P_CERTHASH {
break
}
addr = a
}
addr = addr.Encapsulate(foobarHash)
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash)
_, err := tr2.Dial(context.Background(), addr, serverID)
require.Error(t, err)
require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found")
@ -424,3 +426,84 @@ func TestConnectionGaterInterceptSecured(t *testing.T) {
t.Fatal("timeout")
}
}
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
t.Helper()
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1234),
Subject: pkix.Name{Organization: []string{"webtransport"}},
NotBefore: start,
NotAfter: end,
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IPAddresses: []net.IP{ip},
}
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv)
require.NoError(t, err)
cert, err := x509.ParseCertificate(caBytes)
require.NoError(t, err)
return &tls.Config{
Certificates: []tls.Certificate{{
Certificate: [][]byte{cert.Raw},
PrivateKey: priv,
Leaf: cert,
}},
}
}
func TestStaticTLSConf(t *testing.T) {
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour))
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf))
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash")
t.Run("fails when the certificate is invalid", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.Error(t, err)
if !strings.Contains(err.Error(), "certificate is not trusted") &&
!strings.Contains(err.Error(), "certificate signed by unknown authority") {
t.Fatalf("expected a certificate error, got %+v", err)
}
})
t.Run("fails when dialing with a wrong certhash", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo")))
_, err = cl.Dial(context.Background(), addr, serverID)
require.Error(t, err)
require.Contains(t, err.Error(), "cert hash not found")
})
t.Run("accepts a valid TLS certificate", func(t *testing.T) {
_, key := newIdentity(t)
store := x509.NewCertPool()
store.AddCert(tlsConf.Certificates[0].Leaf)
tlsConf := &tls.Config{RootCAs: store}
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf))
require.NoError(t, err)
defer cl.(io.Closer).Close()
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
})
}