More optimizations in peer protocol message decoding
This commit is contained in:
parent
73696fd215
commit
2027028539
@ -5,7 +5,6 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@ -31,27 +30,20 @@ func (d *Decoder) Decode(msg *Message) (err error) {
|
||||
msg.Keepalive = true
|
||||
return
|
||||
}
|
||||
msg.Keepalive = false
|
||||
r := &io.LimitedReader{R: d.R, N: int64(length)}
|
||||
// Check that all of r was utilized.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if r.N != 0 {
|
||||
err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
|
||||
}
|
||||
}()
|
||||
msg.Keepalive = false
|
||||
c, err := readByte(r)
|
||||
r := d.R
|
||||
readByte := func() (byte, error) {
|
||||
length--
|
||||
return d.R.ReadByte()
|
||||
}
|
||||
c, err := readByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
msg.Type = MessageType(c)
|
||||
switch msg.Type {
|
||||
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
|
||||
return
|
||||
case Have, AllowedFast, Suggest:
|
||||
length -= 4
|
||||
err = msg.Index.Read(r)
|
||||
case Request, Cancel, Reject:
|
||||
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
|
||||
@ -60,9 +52,11 @@ func (d *Decoder) Decode(msg *Message) (err error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
length -= 12
|
||||
case Bitfield:
|
||||
b := make([]byte, length-1)
|
||||
b := make([]byte, length)
|
||||
_, err = io.ReadFull(r, b)
|
||||
length = 0
|
||||
msg.Bitfield = unmarshalBitfield(b)
|
||||
case Piece:
|
||||
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
|
||||
@ -71,7 +65,8 @@ func (d *Decoder) Decode(msg *Message) (err error) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
dataLen := r.N
|
||||
length -= 8
|
||||
dataLen := int64(length)
|
||||
msg.Piece = (*d.Pool.Get().(*[]byte))
|
||||
if int64(cap(msg.Piece)) < dataLen {
|
||||
return errors.New("piece data longer than expected")
|
||||
@ -79,21 +74,31 @@ func (d *Decoder) Decode(msg *Message) (err error) {
|
||||
msg.Piece = msg.Piece[:dataLen]
|
||||
_, err := io.ReadFull(r, msg.Piece)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "reading piece data")
|
||||
return fmt.Errorf("reading piece data: %w", err)
|
||||
}
|
||||
length = 0
|
||||
case Extended:
|
||||
var b byte
|
||||
b, err = readByte(r)
|
||||
b, err = readByte()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
msg.ExtendedID = ExtensionNumber(b)
|
||||
msg.ExtendedPayload, err = ioutil.ReadAll(r)
|
||||
msg.ExtendedPayload = make([]byte, length)
|
||||
_, err = io.ReadFull(r, msg.ExtendedPayload)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
length = 0
|
||||
case Port:
|
||||
err = binary.Read(r, binary.BigEndian, &msg.Port)
|
||||
length -= 2
|
||||
default:
|
||||
err = fmt.Errorf("unknown message type %#v", c)
|
||||
}
|
||||
if err == nil && length != 0 {
|
||||
err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -82,7 +82,7 @@ func TestShortRead(t *testing.T) {
|
||||
}
|
||||
msg := new(Message)
|
||||
err := dec.Decode(msg)
|
||||
if !strings.Contains(err.Error(), "1 bytes unused in message type 0") {
|
||||
if !strings.Contains(err.Error(), "1 unused bytes in message type Choke") {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user