webtransport: have the server send the certificates (#1757)

This commit is contained in:
Marten Seemann 2022-09-19 21:52:35 +03:00 committed by GitHub
parent 214b337106
commit 131e5bd828
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 93 additions and 55 deletions

View File

@ -1,7 +1,6 @@
package libp2pwebtransport
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
@ -9,6 +8,8 @@ import (
"sync"
"time"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
@ -54,6 +55,8 @@ type certManager struct {
currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
addrComp ma.Multiaddr
protobuf []byte
}
func newCertManager(clock clock.Clock) (*certManager, error) {
@ -88,6 +91,9 @@ func (m *certManager) rollConfig() error {
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
if err := m.cacheProtobuf(); err != nil {
return err
}
return m.cacheAddrComponent()
}
@ -131,17 +137,33 @@ func (m *certManager) AddrComponent() ma.Multiaddr {
return m.addrComp
}
func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error {
for _, h := range hashes {
if h.Code != multihash.SHA2_256 {
return fmt.Errorf("expected SHA256 hash, got %d", h.Code)
func (m *certManager) Protobuf() []byte {
return m.protobuf
}
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) &&
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) &&
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) {
return fmt.Errorf("found unexpected hash: %+x", h.Digest)
func (m *certManager) cacheProtobuf() error {
hashes := make([][32]byte, 0, 3)
if m.lastConfig != nil {
hashes = append(hashes, m.lastConfig.sha256)
}
hashes = append(hashes, m.currentConfig.sha256)
if m.nextConfig != nil {
hashes = append(hashes, m.nextConfig.sha256)
}
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(hashes))}
for _, certHash := range hashes {
h, err := multihash.Encode(certHash[:], multihash.SHA2_256)
if err != nil {
return fmt.Errorf("failed to encode certificate hash: %w", err)
}
msg.CertHashes = append(msg.CertHashes, h)
}
msgBytes, err := msg.Marshal()
if err != nil {
return fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
m.protobuf = msgBytes
return nil
}

View File

@ -9,17 +9,17 @@ import (
"net/http"
"time"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
)
var errClosed = errors.New("closed")
@ -197,7 +197,19 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
if err != nil {
return nil, err
}
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData)))
var earlyData []byte
if l.isStaticTLSConf {
var msg pb.WebTransport
var err error
earlyData, err = msg.Marshal()
if err != nil {
return nil, err
}
} else {
earlyData = l.certManager.Protobuf()
}
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataSender(earlyData)))
if err != nil {
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
}
@ -212,31 +224,6 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
}, nil
}
func (l *listener) checkEarlyData(b []byte) error {
var msg pb.WebTransport
if err := msg.Unmarshal(b); err != nil {
fmt.Println(1)
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}
if l.isStaticTLSConf {
if len(msg.CertHashes) > 0 {
return errors.New("using static TLS config, didn't expect any certificate hashes")
}
return nil
}
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
for _, h := range msg.CertHashes {
dh, err := multihash.Decode(h)
if err != nil {
return fmt.Errorf("failed to decode hash: %w", err)
}
hashes = append(hashes, *dh)
}
return l.certManager.Verify(hashes)
}
func (l *listener) Addr() net.Addr {
return l.addr
}

View File

@ -1,9 +1,11 @@
package libp2pwebtransport
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"sync"
@ -196,19 +198,27 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
// The server will verify that it advertised all of these certificate hashes.
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))}
for _, certHash := range certHashes {
h, err := multihash.Encode(certHash.Digest, certHash.Code)
var verified bool
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error {
decodedCertHashes, err := decodeCertHashesFromProtobuf(b)
if err != nil {
return nil, fmt.Errorf("failed to encode certificate hash: %w", err)
return err
}
msg.CertHashes = append(msg.CertHashes, h)
for _, sent := range certHashes {
var found bool
for _, rcvd := range decodedCertHashes {
if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) {
found = true
break
}
msgBytes, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil))
if !found {
return fmt.Errorf("missing cert hash: %v", sent)
}
}
verified = true
return nil
}), nil))
if err != nil {
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
}
@ -216,12 +226,34 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
if err != nil {
return nil, err
}
// The Noise handshake _should_ guarantee that our verification callback is called.
// Double-check just in case.
if !verified {
return nil, errors.New("didn't verify")
}
return &connSecurityMultiaddrs{
ConnSecurity: c,
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
}, nil
}
func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) {
var msg pb.WebTransport
if err := msg.Unmarshal(b); err != nil {
return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
for _, h := range msg.CertHashes {
dh, err := multihash.Decode(h)
if err != nil {
return nil, fmt.Errorf("failed to decode hash: %w", err)
}
hashes = append(hashes, *dh)
}
return hashes, nil
}
func (t *transport) CanDial(addr ma.Multiaddr) bool {
var numHashes int
ma.ForEach(addr, func(c ma.Component) bool {

View File

@ -162,11 +162,8 @@ func TestHashVerification(t *testing.T) {
})
t.Run("fails when adding a wrong hash", func(t *testing.T) {
conn, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
if err != nil {
_, err = conn.AcceptStream()
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
require.Error(t, err)
}
})
require.NoError(t, ln.Close())