Merge pull request #1762 from libp2p/noise-extensions

noise: switch to proto2, use the new NoiseExtensions protobuf
This commit is contained in:
Marten Seemann 2022-09-20 21:42:57 +03:00 committed by GitHub
commit bbf3159100
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 300 additions and 446 deletions

View File

@ -99,7 +99,7 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
var ed []byte
var ed *pb.NoiseExtensions
if s.initiatorEarlyDataHandler != nil {
ed = s.initiatorEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
@ -120,7 +120,7 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
var ed []byte
var ed *pb.NoiseExtensions
if s.responderEarlyDataHandler != nil {
ed = s.responderEarlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
@ -224,7 +224,7 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte,
// generateHandshakePayload creates a libp2p handshake payload with a
// signature of our static noise key.
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data []byte) ([]byte, error) {
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, ext *pb.NoiseExtensions) ([]byte, error) {
// obtain the public key from the handshake session, so we can sign it with
// our libp2p secret key.
localKeyRaw, err := crypto.MarshalPublicKey(s.LocalPublicKey())
@ -243,7 +243,7 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data [
payloadEnc, err := proto.Marshal(&pb.NoiseHandshakePayload{
IdentityKey: localKeyRaw,
IdentitySig: signedPayload,
Data: data,
Extensions: ext,
})
if err != nil {
return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
@ -254,7 +254,7 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data [
// handleRemoteHandshakePayload unmarshals the handshake payload object sent
// by the remote peer and validates the signature against the peer's static Noise key.
// It returns the data attached to the payload.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) ([]byte, error) {
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) (*pb.NoiseExtensions, error) {
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
@ -293,5 +293,5 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
// set remote peer key and id
s.remoteID = id
s.remoteKey = remotePubKey
return nhp.Data, nil
return nhp.Extensions, nil
}

View File

@ -22,17 +22,61 @@ var _ = math.Inf
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type NoiseExtensions struct {
WebtransportCerthashes [][]byte `protobuf:"bytes,1,rep,name=webtransport_certhashes,json=webtransportCerthashes" json:"webtransport_certhashes,omitempty"`
}
func (m *NoiseExtensions) Reset() { *m = NoiseExtensions{} }
func (m *NoiseExtensions) String() string { return proto.CompactTextString(m) }
func (*NoiseExtensions) ProtoMessage() {}
func (*NoiseExtensions) Descriptor() ([]byte, []int) {
return fileDescriptor_678c914f1bee6d56, []int{0}
}
func (m *NoiseExtensions) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *NoiseExtensions) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_NoiseExtensions.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *NoiseExtensions) XXX_Merge(src proto.Message) {
xxx_messageInfo_NoiseExtensions.Merge(m, src)
}
func (m *NoiseExtensions) XXX_Size() int {
return m.Size()
}
func (m *NoiseExtensions) XXX_DiscardUnknown() {
xxx_messageInfo_NoiseExtensions.DiscardUnknown(m)
}
var xxx_messageInfo_NoiseExtensions proto.InternalMessageInfo
func (m *NoiseExtensions) GetWebtransportCerthashes() [][]byte {
if m != nil {
return m.WebtransportCerthashes
}
return nil
}
type NoiseHandshakePayload struct {
IdentityKey []byte `protobuf:"bytes,1,opt,name=identity_key,json=identityKey,proto3" json:"identity_key,omitempty"`
IdentitySig []byte `protobuf:"bytes,2,opt,name=identity_sig,json=identitySig,proto3" json:"identity_sig,omitempty"`
Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"`
IdentityKey []byte `protobuf:"bytes,1,opt,name=identity_key,json=identityKey" json:"identity_key"`
IdentitySig []byte `protobuf:"bytes,2,opt,name=identity_sig,json=identitySig" json:"identity_sig"`
Extensions *NoiseExtensions `protobuf:"bytes,4,opt,name=extensions" json:"extensions,omitempty"`
}
func (m *NoiseHandshakePayload) Reset() { *m = NoiseHandshakePayload{} }
func (m *NoiseHandshakePayload) String() string { return proto.CompactTextString(m) }
func (*NoiseHandshakePayload) ProtoMessage() {}
func (*NoiseHandshakePayload) Descriptor() ([]byte, []int) {
return fileDescriptor_678c914f1bee6d56, []int{0}
return fileDescriptor_678c914f1bee6d56, []int{1}
}
func (m *NoiseHandshakePayload) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
@ -75,31 +119,68 @@ func (m *NoiseHandshakePayload) GetIdentitySig() []byte {
return nil
}
func (m *NoiseHandshakePayload) GetData() []byte {
func (m *NoiseHandshakePayload) GetExtensions() *NoiseExtensions {
if m != nil {
return m.Data
return m.Extensions
}
return nil
}
func init() {
proto.RegisterType((*NoiseExtensions)(nil), "pb.NoiseExtensions")
proto.RegisterType((*NoiseHandshakePayload)(nil), "pb.NoiseHandshakePayload")
}
func init() { proto.RegisterFile("payload.proto", fileDescriptor_678c914f1bee6d56) }
var fileDescriptor_678c914f1bee6d56 = []byte{
// 152 bytes of a gzipped FileDescriptorProto
// 221 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x48, 0xac, 0xcc,
0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0x2a, 0xe4,
0x12, 0xf5, 0xcb, 0xcf, 0x2c, 0x4e, 0xf5, 0x48, 0xcc, 0x4b, 0x29, 0xce, 0x48, 0xcc, 0x4e, 0x0d,
0x80, 0x28, 0x11, 0x52, 0xe4, 0xe2, 0xc9, 0x4c, 0x49, 0xcd, 0x2b, 0xc9, 0x2c, 0xa9, 0x8c, 0xcf,
0x4e, 0xad, 0x94, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0xe2, 0x86, 0x89, 0x79, 0xa7, 0x56, 0xa2,
0x28, 0x29, 0xce, 0x4c, 0x97, 0x60, 0x42, 0x55, 0x12, 0x9c, 0x99, 0x2e, 0x24, 0xc4, 0xc5, 0x92,
0x92, 0x58, 0x92, 0x28, 0xc1, 0x0c, 0x96, 0x02, 0xb3, 0x9d, 0x24, 0x4e, 0x3c, 0x92, 0x63, 0xbc,
0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63,
0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x89, 0x0d, 0xec, 0x2e, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
0x51, 0x37, 0xd7, 0x40, 0xa8, 0x00, 0x00, 0x00,
0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xf2, 0xe2,
0xe2, 0xf7, 0xcb, 0xcf, 0x2c, 0x4e, 0x75, 0xad, 0x28, 0x49, 0xcd, 0x2b, 0xce, 0xcc, 0xcf, 0x2b,
0x16, 0x32, 0xe7, 0x12, 0x2f, 0x4f, 0x4d, 0x2a, 0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a,
0x89, 0x4f, 0x4e, 0x2d, 0x2a, 0xc9, 0x48, 0x2c, 0xce, 0x48, 0x2d, 0x96, 0x60, 0x54, 0x60, 0xd6,
0xe0, 0x09, 0x12, 0x43, 0x96, 0x76, 0x86, 0xcb, 0x2a, 0xcd, 0x63, 0xe4, 0x12, 0x05, 0x1b, 0xe6,
0x91, 0x98, 0x97, 0x52, 0x9c, 0x91, 0x98, 0x9d, 0x1a, 0x00, 0xb1, 0x4f, 0x48, 0x9d, 0x8b, 0x27,
0x33, 0x25, 0x35, 0xaf, 0x24, 0xb3, 0xa4, 0x32, 0x3e, 0x3b, 0xb5, 0x52, 0x82, 0x51, 0x81, 0x51,
0x83, 0xc7, 0x89, 0xe5, 0xc4, 0x3d, 0x79, 0x86, 0x20, 0x6e, 0x98, 0x8c, 0x77, 0x6a, 0x25, 0x8a,
0xc2, 0xe2, 0xcc, 0x74, 0x09, 0x26, 0x6c, 0x0a, 0x83, 0x33, 0xd3, 0x85, 0x8c, 0xb9, 0xb8, 0x52,
0xe1, 0x4e, 0x96, 0x60, 0x51, 0x60, 0xd4, 0xe0, 0x36, 0x12, 0xd6, 0x2b, 0x48, 0xd2, 0x43, 0xf3,
0x4d, 0x10, 0x92, 0x32, 0x27, 0x89, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0,
0x48, 0x8e, 0x71, 0xc2, 0x63, 0x39, 0x86, 0x0b, 0x8f, 0xe5, 0x18, 0x6e, 0x3c, 0x96, 0x63, 0x00,
0x04, 0x00, 0x00, 0xff, 0xff, 0xb2, 0xb0, 0x39, 0x45, 0x1a, 0x01, 0x00, 0x00,
}
func (m *NoiseExtensions) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *NoiseExtensions) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *NoiseExtensions) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.WebtransportCerthashes) > 0 {
for iNdEx := len(m.WebtransportCerthashes) - 1; iNdEx >= 0; iNdEx-- {
i -= len(m.WebtransportCerthashes[iNdEx])
copy(dAtA[i:], m.WebtransportCerthashes[iNdEx])
i = encodeVarintPayload(dAtA, i, uint64(len(m.WebtransportCerthashes[iNdEx])))
i--
dAtA[i] = 0xa
}
}
return len(dAtA) - i, nil
}
func (m *NoiseHandshakePayload) Marshal() (dAtA []byte, err error) {
@ -122,21 +203,26 @@ func (m *NoiseHandshakePayload) MarshalToSizedBuffer(dAtA []byte) (int, error) {
_ = i
var l int
_ = l
if len(m.Data) > 0 {
i -= len(m.Data)
copy(dAtA[i:], m.Data)
i = encodeVarintPayload(dAtA, i, uint64(len(m.Data)))
if m.Extensions != nil {
{
size, err := m.Extensions.MarshalToSizedBuffer(dAtA[:i])
if err != nil {
return 0, err
}
i -= size
i = encodeVarintPayload(dAtA, i, uint64(size))
}
i--
dAtA[i] = 0x1a
dAtA[i] = 0x22
}
if len(m.IdentitySig) > 0 {
if m.IdentitySig != nil {
i -= len(m.IdentitySig)
copy(dAtA[i:], m.IdentitySig)
i = encodeVarintPayload(dAtA, i, uint64(len(m.IdentitySig)))
i--
dAtA[i] = 0x12
}
if len(m.IdentityKey) > 0 {
if m.IdentityKey != nil {
i -= len(m.IdentityKey)
copy(dAtA[i:], m.IdentityKey)
i = encodeVarintPayload(dAtA, i, uint64(len(m.IdentityKey)))
@ -157,22 +243,37 @@ func encodeVarintPayload(dAtA []byte, offset int, v uint64) int {
dAtA[offset] = uint8(v)
return base
}
func (m *NoiseExtensions) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if len(m.WebtransportCerthashes) > 0 {
for _, b := range m.WebtransportCerthashes {
l = len(b)
n += 1 + l + sovPayload(uint64(l))
}
}
return n
}
func (m *NoiseHandshakePayload) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.IdentityKey)
if l > 0 {
if m.IdentityKey != nil {
l = len(m.IdentityKey)
n += 1 + l + sovPayload(uint64(l))
}
l = len(m.IdentitySig)
if l > 0 {
if m.IdentitySig != nil {
l = len(m.IdentitySig)
n += 1 + l + sovPayload(uint64(l))
}
l = len(m.Data)
if l > 0 {
if m.Extensions != nil {
l = m.Extensions.Size()
n += 1 + l + sovPayload(uint64(l))
}
return n
@ -184,6 +285,88 @@ func sovPayload(x uint64) (n int) {
func sozPayload(x uint64) (n int) {
return sovPayload(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *NoiseExtensions) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPayload
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: NoiseExtensions: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: NoiseExtensions: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field WebtransportCerthashes", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPayload
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthPayload
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthPayload
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.WebtransportCerthashes = append(m.WebtransportCerthashes, make([]byte, postIndex-iNdEx))
copy(m.WebtransportCerthashes[len(m.WebtransportCerthashes)-1], dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipPayload(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthPayload
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *NoiseHandshakePayload) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
@ -281,11 +464,11 @@ func (m *NoiseHandshakePayload) Unmarshal(dAtA []byte) error {
m.IdentitySig = []byte{}
}
iNdEx = postIndex
case 3:
case 4:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType)
return fmt.Errorf("proto: wrong wireType = %d for field Extensions", wireType)
}
var byteLen int
var msglen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowPayload
@ -295,24 +478,26 @@ func (m *NoiseHandshakePayload) Unmarshal(dAtA []byte) error {
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
msglen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
if msglen < 0 {
return ErrInvalidLengthPayload
}
postIndex := iNdEx + byteLen
postIndex := iNdEx + msglen
if postIndex < 0 {
return ErrInvalidLengthPayload
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...)
if m.Data == nil {
m.Data = []byte{}
if m.Extensions == nil {
m.Extensions = &NoiseExtensions{}
}
if err := m.Extensions.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
return err
}
iNdEx = postIndex
default:

View File

@ -1,8 +1,12 @@
syntax = "proto3";
syntax = "proto2";
package pb;
message NoiseHandshakePayload {
bytes identity_key = 1;
bytes identity_sig = 2;
bytes data = 3;
message NoiseExtensions {
repeated bytes webtransport_certhashes = 1;
}
message NoiseHandshakePayload {
optional bytes identity_key = 1;
optional bytes identity_sig = 2;
optional NoiseExtensions extensions = 4;
}

View File

@ -7,6 +7,8 @@ import (
"github.com/libp2p/go-libp2p/core/canonicallog"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
manet "github.com/multiformats/go-multiaddr/net"
)
@ -30,12 +32,12 @@ type EarlyDataHandler interface {
// Send for the initiator is called for the client before sending the third
// handshake message. Defines the application payload for the third message.
// Send for the responder is called before sending the second handshake message.
Send(context.Context, net.Conn, peer.ID) []byte
Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions
// Received for the initiator is called when the second handshake message
// from the responder is received.
// Received for the responder is called when the third handshake message
// from the initiator is received.
Received(context.Context, net.Conn, []byte) error
Received(context.Context, net.Conn, *pb.NoiseExtensions) error
}
// EarlyData sets the `EarlyDataHandler` for the initiator and responder roles.

View File

@ -16,6 +16,7 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -428,22 +429,22 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) {
}
type earlyDataHandler struct {
send func(context.Context, net.Conn, peer.ID) []byte
received func(context.Context, net.Conn, []byte) error
send func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions
received func(context.Context, net.Conn, *pb.NoiseExtensions) error
}
func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) []byte {
func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) *pb.NoiseExtensions {
if e.send == nil {
return nil
}
return e.send(ctx, conn, id)
}
func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []byte) error {
func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, ext *pb.NoiseExtensions) error {
if e.received == nil {
return nil
}
return e.received(ctx, conn, data)
return e.received(ctx, conn, ext)
}
func TestEarlyDataAccepted(t *testing.T) {
@ -474,27 +475,29 @@ func TestEarlyDataAccepted(t *testing.T) {
defer conn.Close()
}
var receivedEarlyData []byte
var receivedExtensions *pb.NoiseExtensions
receivingEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error {
receivedEarlyData = data
received: func(_ context.Context, _ net.Conn, ext *pb.NoiseExtensions) error {
receivedExtensions = ext
return nil
},
}
sendingEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
send: func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions {
return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}}
},
}
t.Run("client sending", func(t *testing.T) {
handshake(t, sendingEDH, receivingEDH)
require.Equal(t, []byte("foobar"), receivedEarlyData)
receivedEarlyData = nil
require.Equal(t, [][]byte{[]byte("foobar")}, receivedExtensions.WebtransportCerthashes)
receivedExtensions = nil
})
t.Run("server sending", func(t *testing.T) {
handshake(t, receivingEDH, sendingEDH)
require.Equal(t, []byte("foobar"), receivedEarlyData)
receivedEarlyData = nil
require.Equal(t, [][]byte{[]byte("foobar")}, receivedExtensions.WebtransportCerthashes)
receivedExtensions = nil
})
}
@ -532,10 +535,12 @@ func TestEarlyDataRejected(t *testing.T) {
}
receivingEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") },
received: func(context.Context, net.Conn, *pb.NoiseExtensions) error { return errors.New("nope") },
}
sendingEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
send: func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions {
return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}}
},
}
t.Run("client sending", func(t *testing.T) {
@ -554,7 +559,9 @@ func TestEarlyDataRejected(t *testing.T) {
func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) {
clientEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
send: func(ctx context.Context, conn net.Conn, id peer.ID) *pb.NoiseExtensions {
return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}}
},
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH, nil))
require.NoError(t, err)

View File

@ -8,8 +8,6 @@ import (
"sync"
"time"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
@ -56,7 +54,7 @@ type certManager struct {
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
addrComp ma.Multiaddr
protobuf []byte
serializedCertHashes [][]byte
}
func newCertManager(clock clock.Clock) (*certManager, error) {
@ -91,7 +89,7 @@ func (m *certManager) rollConfig() error {
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
if err := m.cacheProtobuf(); err != nil {
if err := m.cacheSerializedCertHashes(); err != nil {
return err
}
return m.cacheAddrComponent()
@ -137,11 +135,11 @@ func (m *certManager) AddrComponent() ma.Multiaddr {
return m.addrComp
}
func (m *certManager) Protobuf() []byte {
return m.protobuf
func (m *certManager) SerializedCertHashes() [][]byte {
return m.serializedCertHashes
}
func (m *certManager) cacheProtobuf() error {
func (m *certManager) cacheSerializedCertHashes() error {
hashes := make([][32]byte, 0, 3)
if m.lastConfig != nil {
hashes = append(hashes, m.lastConfig.sha256)
@ -151,19 +149,14 @@ func (m *certManager) cacheProtobuf() error {
hashes = append(hashes, m.nextConfig.sha256)
}
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(hashes))}
m.serializedCertHashes = m.serializedCertHashes[:0]
for _, certHash := range hashes {
h, err := multihash.Encode(certHash[:], multihash.SHA2_256)
if err != nil {
return fmt.Errorf("failed to encode certificate hash: %w", err)
}
msg.CertHashes = append(msg.CertHashes, h)
m.serializedCertHashes = append(m.serializedCertHashes, h)
}
msgBytes, err := msg.Marshal()
if err != nil {
return fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
m.protobuf = msgBytes
return nil
}

View File

@ -9,7 +9,7 @@ import (
"net/http"
"time"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
@ -197,19 +197,15 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
if err != nil {
return nil, err
}
var earlyData []byte
if l.isStaticTLSConf {
var msg pb.WebTransport
var err error
earlyData, err = msg.Marshal()
if err != nil {
return nil, err
}
} else {
earlyData = l.certManager.Protobuf()
var earlyData [][]byte
if !l.isStaticTLSConf {
earlyData = l.certManager.SerializedCertHashes()
}
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataSender(earlyData)))
n, err := l.noise.WithSessionOptions(noise.EarlyData(
nil,
newEarlyDataSender(&pb.NoiseExtensions{WebtransportCerthashes: earlyData}),
))
if err != nil {
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
}

View File

@ -2,33 +2,36 @@ package libp2pwebtransport
import (
"context"
"net"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"net"
)
type earlyDataHandler struct {
earlyData []byte
receive func([]byte) error
earlyData *pb.NoiseExtensions
receive func(extensions *pb.NoiseExtensions) error
}
var _ noise.EarlyDataHandler = &earlyDataHandler{}
func newEarlyDataSender(earlyData []byte) noise.EarlyDataHandler {
func newEarlyDataSender(earlyData *pb.NoiseExtensions) noise.EarlyDataHandler {
return &earlyDataHandler{earlyData: earlyData}
}
func newEarlyDataReceiver(receive func([]byte) error) noise.EarlyDataHandler {
func newEarlyDataReceiver(receive func(*pb.NoiseExtensions) error) noise.EarlyDataHandler {
return &earlyDataHandler{receive: receive}
}
func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) []byte {
func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions {
return e.earlyData
}
func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, b []byte) error {
func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, ext *pb.NoiseExtensions) error {
if e.receive == nil {
return nil
}
return e.receive(b)
return e.receive(ext)
}

View File

@ -1,11 +0,0 @@
PB = $(wildcard *.proto)
GO = $(PB:.proto=.pb.go)
all: $(GO)
%.pb.go: %.proto
protoc --proto_path=$(PWD)/../..:. --gogofaster_out=. $<
clean:
rm -f *.pb.go
rm -f *.go

View File

@ -1,315 +0,0 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: webtransport.proto
package webtransport
import (
fmt "fmt"
proto "github.com/gogo/protobuf/proto"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type WebTransport struct {
CertHashes [][]byte `protobuf:"bytes,1,rep,name=cert_hashes,json=certHashes" json:"cert_hashes,omitempty"`
}
func (m *WebTransport) Reset() { *m = WebTransport{} }
func (m *WebTransport) String() string { return proto.CompactTextString(m) }
func (*WebTransport) ProtoMessage() {}
func (*WebTransport) Descriptor() ([]byte, []int) {
return fileDescriptor_db878920ab41a4f3, []int{0}
}
func (m *WebTransport) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *WebTransport) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_WebTransport.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *WebTransport) XXX_Merge(src proto.Message) {
xxx_messageInfo_WebTransport.Merge(m, src)
}
func (m *WebTransport) XXX_Size() int {
return m.Size()
}
func (m *WebTransport) XXX_DiscardUnknown() {
xxx_messageInfo_WebTransport.DiscardUnknown(m)
}
var xxx_messageInfo_WebTransport proto.InternalMessageInfo
func (m *WebTransport) GetCertHashes() [][]byte {
if m != nil {
return m.CertHashes
}
return nil
}
func init() {
proto.RegisterType((*WebTransport)(nil), "WebTransport")
}
func init() { proto.RegisterFile("webtransport.proto", fileDescriptor_db878920ab41a4f3) }
var fileDescriptor_db878920ab41a4f3 = []byte{
// 109 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x4f, 0x4d, 0x2a,
0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x57, 0xd2,
0xe7, 0xe2, 0x09, 0x4f, 0x4d, 0x0a, 0x81, 0x89, 0x0a, 0xc9, 0x73, 0x71, 0x27, 0xa7, 0x16, 0x95,
0xc4, 0x67, 0x24, 0x16, 0x67, 0xa4, 0x16, 0x4b, 0x30, 0x2a, 0x30, 0x6b, 0xf0, 0x04, 0x71, 0x81,
0x84, 0x3c, 0xc0, 0x22, 0x4e, 0x12, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0,
0x91, 0x1c, 0xe3, 0x84, 0xc7, 0x72, 0x0c, 0x17, 0x1e, 0xcb, 0x31, 0xdc, 0x78, 0x2c, 0xc7, 0x00,
0x08, 0x00, 0x00, 0xff, 0xff, 0x50, 0x77, 0xe5, 0x52, 0x5f, 0x00, 0x00, 0x00,
}
func (m *WebTransport) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *WebTransport) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *WebTransport) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.CertHashes) > 0 {
for iNdEx := len(m.CertHashes) - 1; iNdEx >= 0; iNdEx-- {
i -= len(m.CertHashes[iNdEx])
copy(dAtA[i:], m.CertHashes[iNdEx])
i = encodeVarintWebtransport(dAtA, i, uint64(len(m.CertHashes[iNdEx])))
i--
dAtA[i] = 0xa
}
}
return len(dAtA) - i, nil
}
func encodeVarintWebtransport(dAtA []byte, offset int, v uint64) int {
offset -= sovWebtransport(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *WebTransport) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if len(m.CertHashes) > 0 {
for _, b := range m.CertHashes {
l = len(b)
n += 1 + l + sovWebtransport(uint64(l))
}
}
return n
}
func sovWebtransport(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozWebtransport(x uint64) (n int) {
return sovWebtransport(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *WebTransport) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowWebtransport
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: WebTransport: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: WebTransport: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field CertHashes", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowWebtransport
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthWebtransport
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthWebtransport
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.CertHashes = append(m.CertHashes, make([]byte, postIndex-iNdEx))
copy(m.CertHashes[len(m.CertHashes)-1], dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipWebtransport(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthWebtransport
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipWebtransport(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowWebtransport
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowWebtransport
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowWebtransport
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthWebtransport
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupWebtransport
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthWebtransport
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthWebtransport = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowWebtransport = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupWebtransport = fmt.Errorf("proto: unexpected end of group")
)

View File

@ -1,5 +0,0 @@
syntax = "proto2";
message WebTransport {
repeated bytes cert_hashes = 1;
}

View File

@ -11,16 +11,16 @@ import (
"sync"
"time"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/benbjohnson/clock"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
"github.com/benbjohnson/clock"
logging "github.com/ipfs/go-log/v2"
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
@ -206,8 +206,8 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
// The server will verify that it advertised all of these certificate hashes.
var verified bool
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error {
decodedCertHashes, err := decodeCertHashesFromProtobuf(b)
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b *pb.NoiseExtensions) error {
decodedCertHashes, err := decodeCertHashesFromProtobuf(b.WebtransportCerthashes)
if err != nil {
return err
}
@ -244,14 +244,9 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
}, nil
}
func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) {
var msg pb.WebTransport
if err := msg.Unmarshal(b); err != nil {
return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
for _, h := range msg.CertHashes {
func decodeCertHashesFromProtobuf(b [][]byte) ([]multihash.DecodedMultihash, error) {
hashes := make([]multihash.DecodedMultihash, 0, len(b))
for _, h := range b {
dh, err := multihash.Decode(h)
if err != nil {
return nil, fmt.Errorf("failed to decode hash: %w", err)