make it possible to use a custom tls.Config for listening and dialing (#22)
This commit is contained in:
parent
2823159a99
commit
d74921df0a
|
@ -27,11 +27,13 @@ const queueLen = 16
|
|||
const handshakeTimeout = 10 * time.Second
|
||||
|
||||
type listener struct {
|
||||
transport tpt.Transport
|
||||
noise *noise.Transport
|
||||
certManager *certManager
|
||||
rcmgr network.ResourceManager
|
||||
gater connmgr.ConnectionGater
|
||||
transport tpt.Transport
|
||||
noise *noise.Transport
|
||||
certManager *certManager
|
||||
staticTLSConf *tls.Config
|
||||
|
||||
rcmgr network.ResourceManager
|
||||
gater connmgr.ConnectionGater
|
||||
|
||||
server webtransport.Server
|
||||
|
||||
|
@ -48,7 +50,7 @@ type listener struct {
|
|||
|
||||
var _ tpt.Listener = &listener{}
|
||||
|
||||
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
|
||||
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
|
||||
network, addr, err := manet.DialArgs(laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -65,23 +67,23 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return certManager.GetConfig(), nil
|
||||
}}
|
||||
}
|
||||
ln := &listener{
|
||||
transport: transport,
|
||||
noise: noise,
|
||||
certManager: certManager,
|
||||
rcmgr: rcmgr,
|
||||
gater: gater,
|
||||
queue: make(chan tpt.CapableConn, queueLen),
|
||||
serverClosed: make(chan struct{}),
|
||||
addr: udpConn.LocalAddr(),
|
||||
multiaddr: localMultiaddr,
|
||||
server: webtransport.Server{
|
||||
H3: http3.Server{
|
||||
TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return certManager.GetConfig(), nil
|
||||
}},
|
||||
},
|
||||
},
|
||||
transport: transport,
|
||||
noise: noise,
|
||||
certManager: certManager,
|
||||
staticTLSConf: tlsConf,
|
||||
rcmgr: rcmgr,
|
||||
gater: gater,
|
||||
queue: make(chan tpt.CapableConn, queueLen),
|
||||
serverClosed: make(chan struct{}),
|
||||
addr: udpConn.LocalAddr(),
|
||||
multiaddr: localMultiaddr,
|
||||
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}},
|
||||
}
|
||||
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background())
|
||||
mux := http.NewServeMux()
|
||||
|
@ -198,6 +200,9 @@ func (l *listener) Addr() net.Addr {
|
|||
}
|
||||
|
||||
func (l *listener) Multiaddr() ma.Multiaddr {
|
||||
if l.certManager == nil {
|
||||
return l.multiaddr
|
||||
}
|
||||
return l.multiaddr.Encapsulate(l.certManager.AddrComponent())
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
@ -43,6 +44,27 @@ func WithClock(cl clock.Clock) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig sets a tls.Config used for listening.
|
||||
// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr.
|
||||
// This is most useful when running a listener that has a valid (CA-signed) certificate.
|
||||
func WithTLSConfig(c *tls.Config) Option {
|
||||
return func(t *transport) error {
|
||||
t.staticTLSConf = c
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSClientConfig sets a custom tls.Config used for dialing.
|
||||
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
|
||||
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
|
||||
// overwrite the VerifyPeerCertificate callback.
|
||||
func WithTLSClientConfig(c *tls.Config) Option {
|
||||
return func(t *transport) error {
|
||||
t.tlsClientConf = c
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type transport struct {
|
||||
privKey ic.PrivKey
|
||||
pid peer.ID
|
||||
|
@ -54,6 +76,8 @@ type transport struct {
|
|||
listenOnce sync.Once
|
||||
listenOnceErr error
|
||||
certManager *certManager
|
||||
staticTLSConf *tls.Config
|
||||
tlsClientConf *tls.Config
|
||||
|
||||
noise *noise.Transport
|
||||
}
|
||||
|
@ -129,15 +153,21 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
|
|||
|
||||
func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
|
||||
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint)
|
||||
var tlsConf *tls.Config
|
||||
if t.tlsClientConf != nil {
|
||||
tlsConf = t.tlsClientConf.Clone()
|
||||
} else {
|
||||
tlsConf = &tls.Config{}
|
||||
}
|
||||
|
||||
if len(certHashes) > 0 {
|
||||
tlsConf.InsecureSkipVerify = true // this is not insecure. We verify the certificate ourselves.
|
||||
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return verifyRawCerts(rawCerts, certHashes)
|
||||
}
|
||||
}
|
||||
dialer := webtransport.Dialer{
|
||||
RoundTripper: &http3.RoundTripper{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, // this is not insecure. We verify the certificate ourselves.
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return verifyRawCerts(rawCerts, certHashes)
|
||||
},
|
||||
},
|
||||
},
|
||||
RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf},
|
||||
}
|
||||
rsp, sess, err := dialer.Dial(ctx, url, nil)
|
||||
if err != nil {
|
||||
|
@ -193,6 +223,14 @@ func (t *transport) checkEarlyData(b []byte) error {
|
|||
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
|
||||
}
|
||||
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
|
||||
|
||||
if t.staticTLSConf != nil {
|
||||
if len(hashes) > 0 {
|
||||
return errors.New("using static TLS config, didn't expect any certificate hashes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, h := range msg.CertHashes {
|
||||
dh, err := multihash.Decode(h)
|
||||
if err != nil {
|
||||
|
@ -224,13 +262,15 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
|
|||
if !webtransportMatcher.Matches(laddr) {
|
||||
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
|
||||
}
|
||||
t.listenOnce.Do(func() {
|
||||
t.certManager, t.listenOnceErr = newCertManager(t.clock)
|
||||
})
|
||||
if t.listenOnceErr != nil {
|
||||
return nil, t.listenOnceErr
|
||||
if t.staticTLSConf == nil {
|
||||
t.listenOnce.Do(func() {
|
||||
t.certManager, t.listenOnceErr = newCertManager(t.clock)
|
||||
})
|
||||
if t.listenOnceErr != nil {
|
||||
return nil, t.listenOnceErr
|
||||
}
|
||||
}
|
||||
return newListener(laddr, t, t.noise, t.certManager, t.gater, t.rcmgr)
|
||||
return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr)
|
||||
}
|
||||
|
||||
func (t *transport) Protocols() []int {
|
||||
|
|
|
@ -2,12 +2,19 @@ package libp2pwebtransport_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -64,6 +71,19 @@ func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
|
|||
}
|
||||
}
|
||||
|
||||
// create a /certhash multiaddr component using the SHA256 of foobar
|
||||
func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr {
|
||||
t.Helper()
|
||||
h := sha256.Sum256(b)
|
||||
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
|
||||
require.NoError(t, err)
|
||||
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
|
||||
require.NoError(t, err)
|
||||
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
|
||||
require.NoError(t, err)
|
||||
return ha
|
||||
}
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
serverID, serverKey := newIdentity(t)
|
||||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
|
||||
|
@ -129,29 +149,11 @@ func TestHashVerification(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer tr2.(io.Closer).Close()
|
||||
|
||||
// create a hash component using the SHA256 of foobar
|
||||
h := sha256.Sum256([]byte("foobar"))
|
||||
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
|
||||
require.NoError(t, err)
|
||||
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
|
||||
require.NoError(t, err)
|
||||
foobarHash, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
|
||||
require.NoError(t, err)
|
||||
foobarHash := getCerthashComponent(t, []byte("foobar"))
|
||||
|
||||
t.Run("fails using only a wrong hash", func(t *testing.T) {
|
||||
// replace the certificate hash in the multiaddr with a fake hash
|
||||
addr := ln.Multiaddr()
|
||||
// strip off all certhash components
|
||||
for {
|
||||
a, comp := ma.SplitLast(addr)
|
||||
if comp.Protocol().Code != ma.P_CERTHASH {
|
||||
break
|
||||
}
|
||||
addr = a
|
||||
}
|
||||
|
||||
addr = addr.Encapsulate(foobarHash)
|
||||
|
||||
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash)
|
||||
_, err := tr2.Dial(context.Background(), addr, serverID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found")
|
||||
|
@ -424,3 +426,84 @@ func TestConnectionGaterInterceptSecured(t *testing.T) {
|
|||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
|
||||
t.Helper()
|
||||
certTempl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1234),
|
||||
Subject: pkix.Name{Organization: []string{"webtransport"}},
|
||||
NotBefore: start,
|
||||
NotAfter: end,
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{ip},
|
||||
}
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
cert, err := x509.ParseCertificate(caBytes)
|
||||
require.NoError(t, err)
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
PrivateKey: priv,
|
||||
Leaf: cert,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTLSConf(t *testing.T) {
|
||||
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour))
|
||||
|
||||
serverID, serverKey := newIdentity(t)
|
||||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf))
|
||||
require.NoError(t, err)
|
||||
defer tr.(io.Closer).Close()
|
||||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash")
|
||||
|
||||
t.Run("fails when the certificate is invalid", func(t *testing.T) {
|
||||
_, key := newIdentity(t)
|
||||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
|
||||
require.NoError(t, err)
|
||||
defer cl.(io.Closer).Close()
|
||||
|
||||
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.Error(t, err)
|
||||
if !strings.Contains(err.Error(), "certificate is not trusted") &&
|
||||
!strings.Contains(err.Error(), "certificate signed by unknown authority") {
|
||||
t.Fatalf("expected a certificate error, got %+v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails when dialing with a wrong certhash", func(t *testing.T) {
|
||||
_, key := newIdentity(t)
|
||||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
|
||||
require.NoError(t, err)
|
||||
defer cl.(io.Closer).Close()
|
||||
|
||||
addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo")))
|
||||
_, err = cl.Dial(context.Background(), addr, serverID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cert hash not found")
|
||||
})
|
||||
|
||||
t.Run("accepts a valid TLS certificate", func(t *testing.T) {
|
||||
_, key := newIdentity(t)
|
||||
store := x509.NewCertPool()
|
||||
store.AddCert(tlsConf.Certificates[0].Leaf)
|
||||
tlsConf := &tls.Config{RootCAs: store}
|
||||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf))
|
||||
require.NoError(t, err)
|
||||
defer cl.(io.Closer).Close()
|
||||
|
||||
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue