webtransport: have the server send the certificates (#1757)
This commit is contained in:
parent
214b337106
commit
131e5bd828
|
@ -1,7 +1,6 @@
|
|||
package libp2pwebtransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
|
@ -9,6 +8,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/multiformats/go-multihash"
|
||||
|
@ -54,6 +55,8 @@ type certManager struct {
|
|||
currentConfig *certConfig
|
||||
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
|
||||
addrComp ma.Multiaddr
|
||||
|
||||
protobuf []byte
|
||||
}
|
||||
|
||||
func newCertManager(clock clock.Clock) (*certManager, error) {
|
||||
|
@ -88,6 +91,9 @@ func (m *certManager) rollConfig() error {
|
|||
m.lastConfig = m.currentConfig
|
||||
m.currentConfig = m.nextConfig
|
||||
m.nextConfig = c
|
||||
if err := m.cacheProtobuf(); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.cacheAddrComponent()
|
||||
}
|
||||
|
||||
|
@ -131,17 +137,33 @@ func (m *certManager) AddrComponent() ma.Multiaddr {
|
|||
return m.addrComp
|
||||
}
|
||||
|
||||
func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error {
|
||||
for _, h := range hashes {
|
||||
if h.Code != multihash.SHA2_256 {
|
||||
return fmt.Errorf("expected SHA256 hash, got %d", h.Code)
|
||||
}
|
||||
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) &&
|
||||
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) &&
|
||||
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) {
|
||||
return fmt.Errorf("found unexpected hash: %+x", h.Digest)
|
||||
}
|
||||
func (m *certManager) Protobuf() []byte {
|
||||
return m.protobuf
|
||||
}
|
||||
|
||||
func (m *certManager) cacheProtobuf() error {
|
||||
hashes := make([][32]byte, 0, 3)
|
||||
if m.lastConfig != nil {
|
||||
hashes = append(hashes, m.lastConfig.sha256)
|
||||
}
|
||||
hashes = append(hashes, m.currentConfig.sha256)
|
||||
if m.nextConfig != nil {
|
||||
hashes = append(hashes, m.nextConfig.sha256)
|
||||
}
|
||||
|
||||
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(hashes))}
|
||||
for _, certHash := range hashes {
|
||||
h, err := multihash.Encode(certHash[:], multihash.SHA2_256)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode certificate hash: %w", err)
|
||||
}
|
||||
msg.CertHashes = append(msg.CertHashes, h)
|
||||
}
|
||||
msgBytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
|
||||
}
|
||||
m.protobuf = msgBytes
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -9,17 +9,17 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/connmgr"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
tpt "github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
||||
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/http3"
|
||||
"github.com/marten-seemann/webtransport-go"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
"github.com/multiformats/go-multihash"
|
||||
)
|
||||
|
||||
var errClosed = errors.New("closed")
|
||||
|
@ -197,7 +197,19 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData)))
|
||||
var earlyData []byte
|
||||
if l.isStaticTLSConf {
|
||||
var msg pb.WebTransport
|
||||
var err error
|
||||
earlyData, err = msg.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
earlyData = l.certManager.Protobuf()
|
||||
}
|
||||
|
||||
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataSender(earlyData)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
|
||||
}
|
||||
|
@ -212,31 +224,6 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (l *listener) checkEarlyData(b []byte) error {
|
||||
var msg pb.WebTransport
|
||||
if err := msg.Unmarshal(b); err != nil {
|
||||
fmt.Println(1)
|
||||
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
|
||||
}
|
||||
|
||||
if l.isStaticTLSConf {
|
||||
if len(msg.CertHashes) > 0 {
|
||||
return errors.New("using static TLS config, didn't expect any certificate hashes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
|
||||
for _, h := range msg.CertHashes {
|
||||
dh, err := multihash.Decode(h)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode hash: %w", err)
|
||||
}
|
||||
hashes = append(hashes, *dh)
|
||||
}
|
||||
return l.certManager.Verify(hashes)
|
||||
}
|
||||
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package libp2pwebtransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
@ -196,19 +198,27 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
|
|||
|
||||
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
|
||||
// The server will verify that it advertised all of these certificate hashes.
|
||||
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))}
|
||||
for _, certHash := range certHashes {
|
||||
h, err := multihash.Encode(certHash.Digest, certHash.Code)
|
||||
var verified bool
|
||||
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error {
|
||||
decodedCertHashes, err := decodeCertHashesFromProtobuf(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode certificate hash: %w", err)
|
||||
return err
|
||||
}
|
||||
msg.CertHashes = append(msg.CertHashes, h)
|
||||
}
|
||||
msgBytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
|
||||
}
|
||||
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil))
|
||||
for _, sent := range certHashes {
|
||||
var found bool
|
||||
for _, rcvd := range decodedCertHashes {
|
||||
if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("missing cert hash: %v", sent)
|
||||
}
|
||||
}
|
||||
verified = true
|
||||
return nil
|
||||
}), nil))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
|
||||
}
|
||||
|
@ -216,12 +226,34 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The Noise handshake _should_ guarantee that our verification callback is called.
|
||||
// Double-check just in case.
|
||||
if !verified {
|
||||
return nil, errors.New("didn't verify")
|
||||
}
|
||||
return &connSecurityMultiaddrs{
|
||||
ConnSecurity: c,
|
||||
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) {
|
||||
var msg pb.WebTransport
|
||||
if err := msg.Unmarshal(b); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
|
||||
}
|
||||
|
||||
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
|
||||
for _, h := range msg.CertHashes {
|
||||
dh, err := multihash.Decode(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode hash: %w", err)
|
||||
}
|
||||
hashes = append(hashes, *dh)
|
||||
}
|
||||
return hashes, nil
|
||||
}
|
||||
|
||||
func (t *transport) CanDial(addr ma.Multiaddr) bool {
|
||||
var numHashes int
|
||||
ma.ForEach(addr, func(c ma.Component) bool {
|
||||
|
|
|
@ -162,11 +162,8 @@ func TestHashVerification(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("fails when adding a wrong hash", func(t *testing.T) {
|
||||
conn, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
|
||||
if err != nil {
|
||||
_, err = conn.AcceptStream()
|
||||
require.Error(t, err)
|
||||
}
|
||||
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
require.NoError(t, ln.Close())
|
||||
|
|
Loading…
Reference in New Issue