diff --git a/peer_protocol/int.go b/peer_protocol/int.go index 13bd1ca9..ebcf6035 100644 --- a/peer_protocol/int.go +++ b/peer_protocol/int.go @@ -3,11 +3,18 @@ package peer_protocol import ( "encoding/binary" "io" + "math" "github.com/pkg/errors" ) -type Integer uint32 +type ( + // An alias for the underlying type of Integer. This is needed for fuzzing. + IntegerKind = uint32 + Integer IntegerKind +) + +const IntegerMax = math.MaxUint32 func (i *Integer) UnmarshalBinary(b []byte) error { if len(b) != 4 { diff --git a/peerconn.go b/peerconn.go index 3c515aaf..fdc8236d 100644 --- a/peerconn.go +++ b/peerconn.go @@ -1003,6 +1003,18 @@ func (c *PeerConn) maximumPeerRequestChunkLength() (_ Option[int]) { return Some(uploadRateLimiter.Burst()) } +// Returns whether any part of the chunk would lie outside a piece of the given length. +func chunkOverflowsPiece(cs ChunkSpec, pieceLength pp.Integer) bool { + switch { + default: + return false + case cs.Begin+cs.Length > pieceLength: + // Check for integer overflow + case cs.Begin > pp.IntegerMax-cs.Length: + } + return true +} + // startFetch is for testing purposes currently. func (c *PeerConn) onReadRequest(r Request, startFetch bool) error { requestedChunkLengths.Add(strconv.FormatUint(r.Length.Uint64(), 10), 1) @@ -1045,10 +1057,11 @@ func (c *PeerConn) onReadRequest(r Request, startFetch bool) error { requestsReceivedForMissingPieces.Add(1) return fmt.Errorf("peer requested piece we don't have: %v", r.Index.Int()) } + pieceLength := c.t.pieceLength(pieceIndex(r.Index)) // Check this after we know we have the piece, so that the piece length will be known. - if r.Begin+r.Length > c.t.pieceLength(pieceIndex(r.Index)) { + if chunkOverflowsPiece(r.ChunkSpec, pieceLength) { torrent.Add("bad requests received", 1) - return errors.New("bad Request") + return errors.New("chunk overflows piece") } if c.peerRequests == nil { c.peerRequests = make(map[Request]*peerRequestState, localClientReqq) @@ -1255,6 +1268,9 @@ func (c *PeerConn) mainReadLoop() (err error) { case pp.Request: r := newRequestFromMessage(&msg) err = c.onReadRequest(r, true) + if err != nil { + err = fmt.Errorf("on reading request %v: %w", r, err) + } case pp.Piece: c.doChunkReadStats(int64(len(msg.Piece))) err = c.receiveChunk(&msg) diff --git a/peerconn_test.go b/peerconn_test.go index 23d32286..42f8fe27 100644 --- a/peerconn_test.go +++ b/peerconn_test.go @@ -4,12 +4,13 @@ import ( "encoding/binary" "errors" "fmt" - "golang.org/x/time/rate" "io" "net" "sync" "testing" + "golang.org/x/time/rate" + "github.com/frankban/quicktest" qt "github.com/frankban/quicktest" "github.com/stretchr/testify/require" @@ -317,3 +318,15 @@ func TestReceiveLargeRequest(t *testing.T) { c.Check(pc.onReadRequest(req, false), qt.IsNil) c.Check(pc.messageWriter.writeBuffer.Len(), qt.Equals, 17) } + +func TestChunkOverflowsPiece(t *testing.T) { + c := qt.New(t) + check := func(begin, length, limit pp.Integer, expected bool) { + c.Check(chunkOverflowsPiece(ChunkSpec{begin, length}, limit), qt.Equals, expected) + } + check(2, 3, 1, true) + check(2, pp.IntegerMax, 1, true) + check(2, pp.IntegerMax, 3, true) + check(2, pp.IntegerMax, pp.IntegerMax, true) + check(2, pp.IntegerMax-2, pp.IntegerMax, false) +}