Merge pull request #1199 from libp2p/protect-nonsensical-expiration
reject circuitv2 reservations with nonsensical expiration times
This commit is contained in:
commit
4650ccd21b
|
@ -87,8 +87,13 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation,
|
|||
|
||||
result := &Reservation{}
|
||||
result.Expiration = time.Unix(int64(rsvp.GetExpire()), 0)
|
||||
if result.Expiration.Before(time.Now()) {
|
||||
return nil, fmt.Errorf("received reservation with expiration date in the past: %s", result.Expiration)
|
||||
}
|
||||
|
||||
for _, ab := range rsvp.GetAddrs() {
|
||||
addrs := rsvp.GetAddrs()
|
||||
result.Addrs = make([]ma.Multiaddr, 0, len(addrs))
|
||||
for _, ab := range addrs {
|
||||
a, err := ma.NewMultiaddrBytes(ab)
|
||||
if err != nil {
|
||||
log.Warnf("ignoring unparsable relay address: %s", err)
|
||||
|
@ -108,7 +113,6 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation,
|
|||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected voucher record type: %+T", rec)
|
||||
}
|
||||
|
||||
result.Voucher = voucher
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
|
||||
pbv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/pb"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/util"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReservationFailures(t *testing.T) {
|
||||
type testcase struct {
|
||||
name string
|
||||
streamHandler network.StreamHandler
|
||||
err string
|
||||
}
|
||||
testcases := []testcase{
|
||||
{
|
||||
name: "unsupported protocol",
|
||||
streamHandler: nil,
|
||||
err: "protocol not supported",
|
||||
},
|
||||
{
|
||||
name: "wrong message type",
|
||||
streamHandler: func(s network.Stream) {
|
||||
util.NewDelimitedWriter(s).WriteMsg(&pbv2.HopMessage{
|
||||
Type: pbv2.HopMessage_RESERVE.Enum(),
|
||||
})
|
||||
},
|
||||
err: "unexpected relay response: not a status message",
|
||||
},
|
||||
{
|
||||
name: "unknown status",
|
||||
streamHandler: func(s network.Stream) {
|
||||
status := pbv2.Status(1337)
|
||||
util.NewDelimitedWriter(s).WriteMsg(&pbv2.HopMessage{
|
||||
Type: pbv2.HopMessage_STATUS.Enum(),
|
||||
Status: &status,
|
||||
})
|
||||
},
|
||||
err: "reservation failed",
|
||||
},
|
||||
{
|
||||
name: "invalid time",
|
||||
streamHandler: func(s network.Stream) {
|
||||
status := pbv2.Status_OK
|
||||
expire := uint64(math.MaxUint64)
|
||||
util.NewDelimitedWriter(s).WriteMsg(&pbv2.HopMessage{
|
||||
Type: pbv2.HopMessage_STATUS.Enum(),
|
||||
Status: &status,
|
||||
Reservation: &pbv2.Reservation{Expire: &expire},
|
||||
})
|
||||
},
|
||||
err: "received reservation with expiration date in the past",
|
||||
},
|
||||
{
|
||||
name: "invalid voucher",
|
||||
streamHandler: func(s network.Stream) {
|
||||
status := pbv2.Status_OK
|
||||
expire := uint64(time.Now().Add(time.Hour).UnixNano())
|
||||
util.NewDelimitedWriter(s).WriteMsg(&pbv2.HopMessage{
|
||||
Type: pbv2.HopMessage_STATUS.Enum(),
|
||||
Status: &status,
|
||||
Reservation: &pbv2.Reservation{
|
||||
Expire: &expire,
|
||||
Voucher: []byte("foobar"),
|
||||
},
|
||||
})
|
||||
},
|
||||
err: "error consuming voucher envelope: failed when unmarshalling the envelope",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
host, err := libp2p.New()
|
||||
require.NoError(t, err)
|
||||
defer host.Close()
|
||||
if tc.streamHandler != nil {
|
||||
host.SetStreamHandler(proto.ProtoIDv2Hop, tc.streamHandler)
|
||||
}
|
||||
|
||||
cl, err := libp2p.New()
|
||||
require.NoError(t, err)
|
||||
defer cl.Close()
|
||||
_, err = client.Reserve(context.Background(), cl, peer.AddrInfo{ID: host.ID(), Addrs: host.Addrs()})
|
||||
if tc.err == "" {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue