diff --git a/connection.go b/connection.go index f0747375..f507806a 100644 --- a/connection.go +++ b/connection.go @@ -3,6 +3,7 @@ package torrent import ( "container/list" "encoding" + "errors" "fmt" "io" "log" @@ -36,11 +37,13 @@ type connection struct { Requests map[request]struct{} // Stuff controlled by the remote peer. - PeerId [20]byte - PeerInterested bool - PeerChoked bool - PeerRequests map[request]struct{} - PeerExtensions [8]byte + PeerId [20]byte + PeerInterested bool + PeerChoked bool + PeerRequests map[request]struct{} + PeerExtensions [8]byte + // Whether the peer has the given piece. nil if they've not sent any + // related messages yet. PeerPieces []bool PeerMaxRequests int // Maximum pending requests the peer allows. PeerExtensionIDs map[string]int64 @@ -68,6 +71,32 @@ func (cn *connection) piecesPeerHasCount() (count int) { return } +// Correct the PeerPieces slice length. Return false if the existing slice is +// invalid, such as by receiving badly sized BITFIELD, or invalid HAVE +// messages. +func (cn *connection) setNumPieces(num int) error { + if cn.PeerPieces == nil { + return nil + } + if len(cn.PeerPieces) == num { + } else if len(cn.PeerPieces) < num { + cn.PeerPieces = append(cn.PeerPieces, make([]bool, num-len(cn.PeerPieces))...) + } else if len(cn.PeerPieces) < 8*(num+7)/8 { + for _, have := range cn.PeerPieces[num:] { + if have { + return errors.New("peer has invalid piece") + } + } + cn.PeerPieces = cn.PeerPieces[:num] + } else { + return errors.New("peer bitfield is excessively long") + } + if len(cn.PeerPieces) != num { + panic("wat") + } + return nil +} + func (cn *connection) WriteStatus(w io.Writer) { fmt.Fprintf(w, "%q: %s-%s: %s completed, reqs: %d-%d, flags: ", cn.PeerId, cn.Socket.LocalAddr(), cn.Socket.RemoteAddr(), cn.completedString(), len(cn.Requests), len(cn.PeerRequests)) c := func(b byte) { diff --git a/torrent.go b/torrent.go index a96c48d6..1e842029 100644 --- a/torrent.go +++ b/torrent.go @@ -93,6 +93,7 @@ func infoPieceHashes(info *metainfo.Info) (ret []string) { return } +// Called when metadata for a torrent becomes available. func (t *torrent) setMetadata(md metainfo.Info, dataDir string, infoBytes []byte) (err error) { t.Info = &md t.MetaData = infoBytes @@ -120,6 +121,12 @@ func (t *torrent) setMetadata(md metainfo.Info, dataDir string, infoBytes []byte t.pendAllChunkSpecs(pp.Integer(index)) } t.Priorities = list.New() + for _, conn := range t.Conns { + if err := conn.setNumPieces(t.NumPieces()); err != nil { + log.Printf("closing connection: %s", err) + conn.Close() + } + } return }