Decoding of Piece messages, and checking entire message is consumed
This commit is contained in:
parent
28531a4fcc
commit
beb599698f
|
@ -7,6 +7,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -90,30 +91,51 @@ func (d *Decoder) Decode(msg *Message) (err error) {
|
||||||
if length > d.MaxLength {
|
if length > d.MaxLength {
|
||||||
return errors.New("message too long")
|
return errors.New("message too long")
|
||||||
}
|
}
|
||||||
|
r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
|
||||||
if length == 0 {
|
if length == 0 {
|
||||||
msg.Keepalive = true
|
msg.Keepalive = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg.Keepalive = false
|
msg.Keepalive = false
|
||||||
c, err := d.R.ReadByte()
|
c, err := r.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg.Type = MessageType(c)
|
msg.Type = MessageType(c)
|
||||||
|
defer func() {
|
||||||
|
written, _ := io.Copy(ioutil.Discard, r)
|
||||||
|
if written != 0 && err != nil {
|
||||||
|
err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
|
||||||
|
}
|
||||||
|
}()
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case Choke, Unchoke, Interested, NotInterested:
|
case Choke, Unchoke, Interested, NotInterested:
|
||||||
return
|
return
|
||||||
case Have:
|
case Have:
|
||||||
err = msg.Index.Read(d.R)
|
err = msg.Index.Read(r)
|
||||||
case Request, Cancel:
|
case Request, Cancel:
|
||||||
err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
|
err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
|
||||||
case Bitfield:
|
case Bitfield:
|
||||||
b := make([]byte, length-1)
|
b := make([]byte, length-1)
|
||||||
_, err = io.ReadFull(d.R, b)
|
_, err = io.ReadFull(r, b)
|
||||||
msg.Bitfield = unmarshalBitfield(b)
|
msg.Bitfield = unmarshalBitfield(b)
|
||||||
|
case Piece:
|
||||||
|
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
|
||||||
|
err = pi.Read(r)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
msg.Piece, err = ioutil.ReadAll(r)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unknown message type %#v", c)
|
err = fmt.Errorf("unknown message type %#v", c)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("decoding type %d: %s", msg.Type, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue