diff --git a/client.go b/client.go index 5e4b5d84..a257cf84 100644 --- a/client.go +++ b/client.go @@ -1267,16 +1267,6 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connect return } -func (cl *Client) peerHasAll(t *torrent, cn *connection) { - cn.peerHasAll = true - cn.PeerPieces = nil - if t.haveInfo() { - for i := 0; i < t.numPieces(); i++ { - cn.peerGotPiece(i) - } - } -} - func (me *Client) upload(t *torrent, c *connection) { if me.config.NoUpload { return @@ -1380,7 +1370,7 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error { c.PeerInterested = false c.Choke() case pp.Have: - c.peerGotPiece(int(msg.Index)) + err = c.peerSentHave(int(msg.Index)) case pp.Request: if c.Choked { break @@ -1408,41 +1398,11 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error { unexpectedCancels.Add(1) } case pp.Bitfield: - if c.PeerPieces != nil || c.peerHasAll { - err = errors.New("received unexpected bitfield") - break - } - if t.haveInfo() { - if len(msg.Bitfield) < t.numPieces() { - err = errors.New("received invalid bitfield") - break - } - msg.Bitfield = msg.Bitfield[:t.numPieces()] - } - c.PeerPieces = msg.Bitfield - for index, has := range c.PeerPieces { - if has { - c.peerGotPiece(index) - } - } + err = c.peerSentBitfield(msg.Bitfield) case pp.HaveAll: - if c.PeerPieces != nil || c.peerHasAll { - err = errors.New("unexpected have-all") - break - } - me.peerHasAll(t, c) + err = c.peerSentHaveAll() case pp.HaveNone: - if c.peerHasAll || c.PeerPieces != nil { - err = errors.New("unexpected have-none") - break - } - c.PeerPieces = make([]bool, func() int { - if t.haveInfo() { - return t.numPieces() - } else { - return 0 - } - }()) + err = c.peerSentHaveNone() case pp.Piece: me.downloadedChunk(t, c, &msg) case pp.Extended: diff --git a/connection.go b/connection.go index fd647002..51db2248 100644 --- a/connection.go +++ b/connection.go @@ -14,7 +14,9 @@ import ( "time" "github.com/anacrolix/missinggo" + "github.com/anacrolix/missinggo/bitmap" "github.com/anacrolix/missinggo/prioritybitmap" + "github.com/bradfitz/iter" "github.com/anacrolix/torrent/bencode" pp "github.com/anacrolix/torrent/peer_protocol" @@ -68,10 +70,15 @@ type connection struct { PeerChoked bool PeerRequests map[request]struct{} PeerExtensionBytes peerExtensionBytes - // Whether the peer has the given piece. nil if they've not sent any - // related messages yet. - PeerPieces []bool + // The pieces the peer has claimed to have. + peerPieces bitmap.Bitmap + // The peer has everything. This can occur due to a special message, when + // we may not even know the number of pieces in the torrent yet. peerHasAll bool + // The highest possible number of pieces the torrent could have based on + // communication with the peer. Generally only useful until we have the + // torrent info. + peerMinPieces int // Pieces we've accepted chunks for from the peer. peerTouchedPieces map[int]struct{} @@ -108,61 +115,24 @@ func (cn *connection) supportsExtension(ext string) bool { return ok } -func (cn *connection) completedString(t *torrent) string { - if cn.PeerPieces == nil && !cn.peerHasAll { - return "?" +// The best guess at number of pieces in the torrent for this peer. +func (cn *connection) bestPeerNumPieces() int { + if cn.t.haveInfo() { + return cn.t.numPieces() } - return fmt.Sprintf("%d/%d", func() int { - if cn.peerHasAll { - if t.haveInfo() { - return t.numPieces() - } - return -1 - } - ret := 0 - for _, b := range cn.PeerPieces { - if b { - ret++ - } - } - return ret - }(), func() int { - if cn.peerHasAll || cn.PeerPieces == nil { - if t.haveInfo() { - return t.numPieces() - } - return -1 - } - return len(cn.PeerPieces) - }()) + return cn.peerMinPieces +} + +func (cn *connection) completedString() string { + return fmt.Sprintf("%d/%d", cn.peerPieces.Len(), cn.bestPeerNumPieces()) } // Correct the PeerPieces slice length. Return false if the existing slice is // invalid, such as by receiving badly sized BITFIELD, or invalid HAVE // messages. func (cn *connection) setNumPieces(num int) error { - if cn.peerHasAll { - return nil - } - if cn.PeerPieces == nil { - return nil - } - if len(cn.PeerPieces) == num { - } else if len(cn.PeerPieces) < num { - cn.PeerPieces = append(cn.PeerPieces, make([]bool, num-len(cn.PeerPieces))...) - } else if len(cn.PeerPieces) <= (num+7)/8*8 { - for _, have := range cn.PeerPieces[num:] { - if have { - return errors.New("peer has invalid piece") - } - } - cn.PeerPieces = cn.PeerPieces[:num] - } else { - return fmt.Errorf("peer bitfield is excessively long: expected %d, have %d", num, len(cn.PeerPieces)) - } - if len(cn.PeerPieces) != num { - panic("wat") - } + cn.peerPieces.RemoveRange(num, -1) + cn.peerPiecesChanged() return nil } @@ -227,7 +197,7 @@ func (cn *connection) WriteStatus(w io.Writer, t *torrent) { eventAgeString(cn.lastUsefulChunkReceived)) fmt.Fprintf(w, " %s completed, %d pieces touched, good chunks: %d/%d-%d reqq: %d-%d, flags: %s\n", - cn.completedString(t), + cn.completedString(), len(cn.peerTouchedPieces), cn.UsefulChunksReceived, cn.UnwantedChunksReceived+cn.UsefulChunksReceived, @@ -247,13 +217,7 @@ func (c *connection) Close() { } func (c *connection) PeerHasPiece(piece int) bool { - if c.peerHasAll { - return true - } - if piece >= len(c.PeerPieces) { - return false - } - return c.PeerPieces[piece] + return c.peerHasAll || c.peerPieces.Contains(piece) } func (c *connection) Post(msg pp.Message) { @@ -626,22 +590,69 @@ func (c *connection) discardPieceInclination() { c.pieceInclination = nil } -func (c *connection) peerGotPiece(piece int) error { - if !c.peerHasAll { - if c.t.haveInfo() { - if c.PeerPieces == nil { - c.PeerPieces = make([]bool, c.t.numPieces()) - } - } else { - for piece >= len(c.PeerPieces) { - c.PeerPieces = append(c.PeerPieces, false) - } - } - if piece >= len(c.PeerPieces) { - return errors.New("peer got out of range piece index") - } - c.PeerPieces[piece] = true - } +func (c *connection) peerHasPieceChanged(piece int) { c.updatePiecePriority(piece) +} + +func (c *connection) peerPiecesChanged() { + if c.t.haveInfo() { + for i := range iter.N(c.t.numPieces()) { + c.peerHasPieceChanged(i) + } + } +} + +func (c *connection) raisePeerMinPieces(newMin int) { + if newMin > c.peerMinPieces { + c.peerMinPieces = newMin + } +} + +func (c *connection) peerSentHave(piece int) error { + if c.t.haveInfo() && piece >= c.t.numPieces() { + return errors.New("invalid piece") + } + if c.PeerHasPiece(piece) { + return nil + } + c.raisePeerMinPieces(piece + 1) + c.peerPieces.Set(piece, true) + c.peerHasPieceChanged(piece) + return nil +} + +func (c *connection) peerSentBitfield(bf []bool) error { + c.peerHasAll = false + if len(bf)%8 != 0 { + panic("expected bitfield length divisible by 8") + } + // We know that the last byte means that at most the last 7 bits are + // wasted. + c.raisePeerMinPieces(len(bf) - 7) + if c.t.haveInfo() { + // Ignore known excess pieces. + bf = bf[:c.t.numPieces()] + } + for i, have := range bf { + if have { + c.raisePeerMinPieces(i + 1) + } + c.peerPieces.Set(i, have) + } + c.peerPiecesChanged() + return nil +} + +func (cn *connection) peerSentHaveAll() error { + cn.peerHasAll = true + cn.peerPieces.Clear() + cn.peerPiecesChanged() + return nil +} + +func (c *connection) peerSentHaveNone() error { + c.peerPieces.Clear() + c.peerHasAll = false + c.peerPiecesChanged() return nil } diff --git a/connection_test.go b/connection_test.go index ea5d1f7d..dc225ad5 100644 --- a/connection_test.go +++ b/connection_test.go @@ -4,19 +4,24 @@ import ( "testing" "time" + "github.com/anacrolix/missinggo/bitmap" + "github.com/stretchr/testify/assert" + "github.com/anacrolix/torrent/peer_protocol" ) func TestCancelRequestOptimized(t *testing.T) { c := &connection{ PeerMaxRequests: 1, - PeerPieces: []bool{false, true}, - post: make(chan peer_protocol.Message), - writeCh: make(chan []byte), - } - if len(c.Requests) != 0 { - t.FailNow() + peerPieces: func() bitmap.Bitmap { + var bm bitmap.Bitmap + bm.Set(1, true) + return bm + }(), + post: make(chan peer_protocol.Message), + writeCh: make(chan []byte), } + assert.Len(t, c.Requests, 0) // Keepalive timeout of 0 works because I'm just that good. go c.writeOptimizer(0 * time.Millisecond) c.Request(newRequest(1, 2, 3))