2022-04-06 11:48:16 +02:00

140 lines
2.8 KiB
Go

package peer_protocol
import (
"bufio"
"bytes"
"encoding"
"encoding/binary"
"fmt"
)
// Message is a lazy union holding every field that any peer-protocol message can carry; which
// fields are meaningful depends on Type. Go doesn't have ADTs, and type assertions were
// deliberately not used here.
type Message struct {
// Keepalive marks a keepalive message: MarshalBinary then emits only a zero length prefix and
// does not look at Type or any other field.
Keepalive bool
// Type is written as the first body byte and selects which of the fields below are encoded.
Type MessageType
// Index, Begin and Length are the integer fields. Have encodes Index; Request, Cancel and
// Reject encode all three in this order; Piece encodes Index and Begin.
Index, Begin, Length Integer
// Piece is the raw data written after Index and Begin in a Piece message.
Piece []byte
// Bitfield is the payload of a Bitfield message, one entry per bit.
Bitfield []bool
// ExtendedID is the single byte written after the type byte of an Extended message.
ExtendedID ExtensionNumber
// ExtendedPayload is written verbatim after ExtendedID in an Extended message.
ExtendedPayload []byte
// Port is the payload of a Port message, written big-endian.
Port uint16
}
// Compile-time checks that *Message implements both standard binary
// encoding interfaces.
var (
	_ encoding.BinaryMarshaler   = (*Message)(nil)
	_ encoding.BinaryUnmarshaler = (*Message)(nil)
)
// MakeCancelMessage builds a Cancel message for the chunk identified by the
// given piece index, byte offset within the piece, and length. All other
// fields are left at their zero values.
func MakeCancelMessage(piece, offset, length Integer) Message {
	var cancel Message
	cancel.Type = Cancel
	cancel.Index = piece
	cancel.Begin = offset
	cancel.Length = length
	return cancel
}
// RequestSpec returns the request specification described by msg: its Index,
// its Begin, and a length. For a Piece message the length is the number of
// bytes in msg.Piece; for every other type it is msg.Length.
func (msg Message) RequestSpec() (ret RequestSpec) {
	size := msg.Length
	if msg.Type == Piece {
		// Piece messages carry the data itself, so the length is implied by it.
		size = Integer(len(msg.Piece))
	}
	return RequestSpec{msg.Index, msg.Begin, size}
}
// MustMarshalBinary is like MarshalBinary but panics with the returned error
// instead of handing it back to the caller.
func (msg Message) MustMarshalBinary() []byte {
	encoded, err := msg.MarshalBinary()
	if err == nil {
		return encoded
	}
	panic(err)
}
// MarshalBinary encodes msg as it appears on the wire: a 4-byte big-endian
// length prefix followed by the message body. A keepalive has an empty body;
// every other message starts with its type byte, followed by the payload for
// that type.
//
// It returns an error, and nil data, when msg.Type is not one of the handled
// message types or when writing any field fails.
func (msg Message) MarshalBinary() (data []byte, err error) {
	var buf bytes.Buffer
	if !msg.Keepalive {
		if err = buf.WriteByte(byte(msg.Type)); err != nil {
			return nil, err
		}
		switch msg.Type {
		case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
			// These messages consist of the type byte alone.
		case Have:
			err = binary.Write(&buf, binary.BigEndian, msg.Index)
		case Request, Cancel, Reject:
			for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
				if err = binary.Write(&buf, binary.BigEndian, i); err != nil {
					return nil, err
				}
			}
		case Bitfield:
			_, err = buf.Write(marshalBitfield(msg.Bitfield))
		case Piece:
			for _, i := range []Integer{msg.Index, msg.Begin} {
				if err = binary.Write(&buf, binary.BigEndian, i); err != nil {
					return nil, err
				}
			}
			// Assign to the named result instead of shadowing it with :=, so a
			// failed write is reported rather than silently dropped.
			var n int
			if n, err = buf.Write(msg.Piece); err != nil {
				return nil, err
			}
			if n != len(msg.Piece) {
				// A short write without an error breaks the io.Writer contract.
				panic(n)
			}
		case Extended:
			if err = buf.WriteByte(byte(msg.ExtendedID)); err != nil {
				return nil, err
			}
			_, err = buf.Write(msg.ExtendedPayload)
		case Port:
			err = binary.Write(&buf, binary.BigEndian, msg.Port)
		default:
			err = fmt.Errorf("unknown message type: %v", msg.Type)
		}
		// Never hand back a partially built frame together with an error.
		if err != nil {
			return nil, err
		}
	}
	data = make([]byte, 4+buf.Len())
	binary.BigEndian.PutUint32(data, uint32(buf.Len()))
	if buf.Len() != copy(data[4:], buf.Bytes()) {
		panic("bad copy")
	}
	return data, nil
}
// marshalBitfield packs bf into bytes, eight entries per byte, with the first
// entry of each group stored in the most significant bit. The result has
// ceil(len(bf)/8) bytes; unused low bits of the final byte are zero.
func marshalBitfield(bf []bool) (b []byte) {
	b = make([]byte, (len(bf)+7)/8)
	for pos := range bf {
		if bf[pos] {
			// pos>>3 selects the byte, pos&7 the bit counted from the top.
			b[pos>>3] |= 0x80 >> uint(pos&7)
		}
	}
	return b
}
// UnmarshalBinary decodes b into me by running a Decoder over it.
//
// It returns the Decoder's error if decoding fails, and an error stating the
// number of unread bytes if any input remains after the decoded message.
func (me *Message) UnmarshalBinary(b []byte) error {
	src := bytes.NewReader(b)
	d := Decoder{
		R: bufio.NewReader(src),
	}
	if err := d.Decode(me); err != nil {
		return err
	}
	// Buffered only counts bytes bufio has already pulled into its own buffer.
	// Large reads can bypass that buffer, leaving unread input in src, so both
	// must be counted to detect trailing data reliably.
	if trailing := d.R.Buffered() + src.Len(); trailing != 0 {
		return fmt.Errorf("%d trailing bytes", trailing)
	}
	return nil
}