diff --git a/protocol/protocol.go b/protocol/protocol.go new file mode 100644 index 00000000..c4f2cb1a --- /dev/null +++ b/protocol/protocol.go @@ -0,0 +1,116 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + // "errors" + "fmt" + "io" + "io/ioutil" + // "os" +) + +type ( + MessageType byte + Integer uint32 + PieceIndex Integer + PieceOffset Integer +) + +const ( + Choke MessageType = iota + Unchoke + Interested + NotInterested + Have + Bitfield + Request + Piece + Cancel +) + +type Message struct { + KeepAlive bool + Type MessageType + Bitfield []bool + Index PieceIndex + Begin PieceOffset + Length Integer + Piece []byte +} + +func (msg *Message) UnmarshalReader(r io.Reader) (err error) { + err = binary.Read(r, binary.BigEndian, &msg.Type) + switch err { + case nil: + msg.KeepAlive = false + case io.EOF: + msg.KeepAlive = true + err = nil + return + default: + return + } + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested: + case Have: + err = binary.Read(r, binary.BigEndian, &msg.Index) + case Request, Cancel: + err = binary.Read(r, binary.BigEndian, &msg.Index) + if err != nil { + return + } + err = binary.Read(r, binary.BigEndian, &msg.Begin) + if err != nil { + return + } + err = binary.Read(r, binary.BigEndian, &msg.Length) + case Bitfield: + // var bf []byte + _, err = ioutil.ReadAll(r) + if err != nil { + return + } + case Piece: + default: + return fmt.Errorf("unknown type: %v", msg.Type) + } + return +} + +func (msg *Message) MarshalBinary() (b []byte, err error) { + w := &bytes.Buffer{} + if msg.KeepAlive { + b = w.Bytes() + return + } + err = w.WriteByte(byte(msg.Type)) + if err != nil { + return + } + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested: + case Have: + err = binary.Write(w, binary.BigEndian, msg.Index) + case Request, Cancel: + err = binary.Write(w, binary.BigEndian, msg.Index) + if err != nil { + return + } + err = binary.Write(w, binary.BigEndian, msg.Begin) + if err != nil { + return + } + err = binary.Write(w, binary.BigEndian, msg.Length) + case Bitfield, Piece: + panic("unimplemented") + default: + err = errors.New("unknown type") + return + } + if err == nil { + b = w.Bytes() + } + return +} diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go new file mode 100644 index 00000000..32db2a62 --- /dev/null +++ b/protocol/protocol_test.go @@ -0,0 +1,12 @@ +package protocol + +import ( + "testing" +) + +func TestConstants(t *testing.T) { + // check that iota works as expected in the const block + if NotInterested != 3 { + t.FailNow() + } +}