Decoding of Piece messages, and checking entire message is consumed

This commit is contained in:
Matt Joiner 2013-10-02 17:57:19 +10:00
parent 28531a4fcc
commit beb599698f
1 changed files with 26 additions and 4 deletions

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
) )
type ( type (
@ -90,30 +91,51 @@ func (d *Decoder) Decode(msg *Message) (err error) {
if length > d.MaxLength { if length > d.MaxLength {
return errors.New("message too long") return errors.New("message too long")
} }
r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
if length == 0 { if length == 0 {
msg.Keepalive = true msg.Keepalive = true
return return
} }
msg.Keepalive = false msg.Keepalive = false
c, err := d.R.ReadByte() c, err := r.ReadByte()
if err != nil { if err != nil {
return return
} }
msg.Type = MessageType(c) msg.Type = MessageType(c)
defer func() {
written, _ := io.Copy(ioutil.Discard, r)
if written != 0 && err != nil {
err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
}
}()
switch msg.Type { switch msg.Type {
case Choke, Unchoke, Interested, NotInterested: case Choke, Unchoke, Interested, NotInterested:
return return
case Have: case Have:
err = msg.Index.Read(d.R) err = msg.Index.Read(r)
case Request, Cancel: case Request, Cancel:
err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length}) err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
case Bitfield: case Bitfield:
b := make([]byte, length-1) b := make([]byte, length-1)
_, err = io.ReadFull(d.R, b) _, err = io.ReadFull(r, b)
msg.Bitfield = unmarshalBitfield(b) msg.Bitfield = unmarshalBitfield(b)
case Piece:
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
err = pi.Read(r)
if err != nil {
break
}
}
if err != nil {
break
}
msg.Piece, err = ioutil.ReadAll(r)
default: default:
err = fmt.Errorf("unknown message type %#v", c) err = fmt.Errorf("unknown message type %#v", c)
} }
if err != nil {
err = fmt.Errorf("decoding type %d: %s", msg.Type, err)
}
return return
} }