From 99a0813d8872c274f52b6b297a1939f1e99ca5eb Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Fri, 23 Oct 2015 12:41:45 +1100 Subject: [PATCH] dht: Make Msg a struct with bencode tags --- cmd/dht-get-peers/main.go | 6 +- cmd/dht-ping/main.go | 5 +- cmd/dht-server/main.go | 6 +- dht/announce.go | 51 +++++---- dht/compactNodeInfo.go | 49 +++++++++ dht/dht.go | 221 ++++++-------------------------------- dht/dht_test.go | 91 ++++++++++++++-- dht/krpcError.go | 45 ++++++++ dht/msg.go | 52 +++++++++ util/types.go | 28 ++++- 10 files changed, 317 insertions(+), 237 deletions(-) create mode 100644 dht/compactNodeInfo.go create mode 100644 dht/krpcError.go create mode 100644 dht/msg.go diff --git a/cmd/dht-get-peers/main.go b/cmd/dht-get-peers/main.go index 646104d9..f03e9fd1 100644 --- a/cmd/dht-get-peers/main.go +++ b/cmd/dht-get-peers/main.go @@ -39,7 +39,7 @@ func loadTable() error { defer f.Close() added := 0 for { - b := make([]byte, dht.CompactNodeInfoLen) + b := make([]byte, dht.CompactIPv4NodeInfoLen) _, err := io.ReadFull(f, b) if err == io.EOF { break @@ -48,7 +48,7 @@ func loadTable() error { return fmt.Errorf("error reading table file: %s", err) } var ni dht.NodeInfo - err = ni.UnmarshalCompact(b) + err = ni.UnmarshalCompactIPv4(b) if err != nil { return fmt.Errorf("error unmarshaling compact node info: %s", err) } @@ -101,7 +101,7 @@ func saveTable() error { } defer f.Close() for _, nodeInfo := range goodNodes { - var b [dht.CompactNodeInfoLen]byte + var b [dht.CompactIPv4NodeInfoLen]byte err := nodeInfo.PutCompact(b[:]) if err != nil { return fmt.Errorf("error compacting node info: %s", err) diff --git a/cmd/dht-ping/main.go b/cmd/dht-ping/main.go index 0588a848..bae58431 100644 --- a/cmd/dht-ping/main.go +++ b/cmd/dht-ping/main.go @@ -68,11 +68,8 @@ pingResponses: for _ = range pingStrAddrs { select { case resp := <-pingResponses: - if resp.krpc == nil { - break - } responses++ - fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc["r"].(map[string]interface{})["id"].(string), resp.addr), resp.rtt) + fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc.R.ID, resp.addr), resp.rtt) case <-timeoutChan: break pingResponses } diff --git a/cmd/dht-server/main.go b/cmd/dht-server/main.go index 19799ce8..f0792b8d 100644 --- a/cmd/dht-server/main.go +++ b/cmd/dht-server/main.go @@ -32,7 +32,7 @@ func loadTable() error { defer f.Close() added := 0 for { - b := make([]byte, dht.CompactNodeInfoLen) + b := make([]byte, dht.CompactIPv4NodeInfoLen) _, err := io.ReadFull(f, b) if err == io.EOF { break @@ -41,7 +41,7 @@ func loadTable() error { return fmt.Errorf("error reading table file: %s", err) } var ni dht.NodeInfo - err = ni.UnmarshalCompact(b) + err = ni.UnmarshalCompactIPv4(b) if err != nil { return fmt.Errorf("error unmarshaling compact node info: %s", err) } @@ -84,7 +84,7 @@ func saveTable() error { } defer f.Close() for _, nodeInfo := range goodNodes { - var b [dht.CompactNodeInfoLen]byte + var b [dht.CompactIPv4NodeInfoLen]byte err := nodeInfo.PutCompact(b[:]) if err != nil { return fmt.Errorf("error compacting node info: %s", err) diff --git a/dht/announce.go b/dht/announce.go index a7c55868..f87b76f2 100644 --- a/dht/announce.go +++ b/dht/announce.go @@ -160,36 +160,35 @@ func (me *Announce) getPeers(addr dHTAddr) error { } t.SetResponseHandler(func(m Msg) { // Register suggested nodes closer to the target info-hash. - me.mu.Lock() - for _, n := range m.Nodes() { - me.responseNode(n) - } - me.mu.Unlock() + if m.R != nil { + me.mu.Lock() + for _, n := range m.R.Nodes { + me.responseNode(n) + } + me.mu.Unlock() - if vs := m.Values(); vs != nil { - for _, cp := range vs { - if cp.Port == 0 { - me.server.mu.Lock() - me.server.badNode(addr) - me.server.mu.Unlock() - return + if vs := m.R.Values; len(vs) != 0 { + nodeInfo := NodeInfo{ + Addr: t.remoteAddr, + } + copy(nodeInfo.ID[:], m.SenderID()) + select { + case me.values <- PeersValues{ + Peers: func() (ret []Peer) { + for _, cp := range vs { + ret = append(ret, Peer(cp)) + } + return + }(), + NodeInfo: nodeInfo, + }: + case <-me.stop: } } - nodeInfo := NodeInfo{ - Addr: t.remoteAddr, - } - copy(nodeInfo.ID[:], m.SenderID()) - select { - case me.values <- PeersValues{ - Peers: vs, - NodeInfo: nodeInfo, - }: - case <-me.stop: - } - } - if at, ok := m.AnnounceToken(); ok { - me.announcePeer(addr, at) + if at := m.R.Token; at != "" { + me.announcePeer(addr, at) + } } me.mu.Lock() diff --git a/dht/compactNodeInfo.go b/dht/compactNodeInfo.go new file mode 100644 index 00000000..e12bb0fc --- /dev/null +++ b/dht/compactNodeInfo.go @@ -0,0 +1,49 @@ +package dht + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/anacrolix/torrent/bencode" +) + +type CompactIPv4NodeInfo []NodeInfo + +var _ bencode.Unmarshaler = &CompactIPv4NodeInfo{} + +func (me *CompactIPv4NodeInfo) UnmarshalBencode(_b []byte) (err error) { + var b []byte + err = bencode.Unmarshal(_b, &b) + if err != nil { + return + } + if len(b)%CompactIPv4NodeInfoLen != 0 { + err = fmt.Errorf("bad length: %d", len(b)) + return + } + for i := 0; i < len(b); i += CompactIPv4NodeInfoLen { + var ni NodeInfo + err = ni.UnmarshalCompactIPv4(b[i : i+CompactIPv4NodeInfoLen]) + if err != nil { + return + } + *me = append(*me, ni) + } + return +} + +func (me CompactIPv4NodeInfo) MarshalBencode() (ret []byte, err error) { + var buf bytes.Buffer + for _, ni := range me { + buf.Write(ni.ID[:]) + if ni.Addr == nil { + err = errors.New("nil addr in node info") + return + } + buf.Write(ni.Addr.IP().To4()) + binary.Write(&buf, binary.BigEndian, uint16(ni.Addr.UDPAddr().Port)) + } + return bencode.Marshal(buf.Bytes()) +} diff --git a/dht/dht.go b/dht/dht.go index 51845755..bea3f8dd 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -28,11 +28,13 @@ import ( "github.com/anacrolix/torrent/bencode" "github.com/anacrolix/torrent/iplist" "github.com/anacrolix/torrent/logonce" - "github.com/anacrolix/torrent/util" ) const ( - maxNodes = 320 + maxNodes = 320 +) + +var ( queryResendEvery = 5 * time.Second ) @@ -266,105 +268,6 @@ func (n *node) DefinitelyGood() bool { return true } -// A wrapper around the unmarshalled KRPC dict that constitutes messages in -// the DHT. There are various helpers for extracting common data from the -// message. In normal use, Msg is abstracted away for you, but it can be of -// interest. -type Msg map[string]interface{} - -var _ fmt.Stringer = Msg{} - -func (m Msg) String() string { - return fmt.Sprintf("%#v", m) -} - -func (m Msg) T() (t string) { - tif, ok := m["t"] - if !ok { - return - } - t, _ = tif.(string) - return -} - -func (m Msg) Args() map[string]interface{} { - defer func() { - recover() - }() - return m["a"].(map[string]interface{}) -} - -func (m Msg) SenderID() string { - defer func() { - recover() - }() - switch m["y"].(string) { - case "q": - return m.Args()["id"].(string) - case "r": - return m["r"].(map[string]interface{})["id"].(string) - } - return "" -} - -// Suggested nodes in a response. -func (m Msg) Nodes() (nodes []NodeInfo) { - b := func() string { - defer func() { - recover() - }() - return m["r"].(map[string]interface{})["nodes"].(string) - }() - if len(b)%26 != 0 { - return - } - for i := 0; i < len(b); i += 26 { - var n NodeInfo - err := n.UnmarshalCompact([]byte(b[i : i+26])) - if err != nil { - continue - } - nodes = append(nodes, n) - } - return -} - -type KRPCError struct { - Code int - Msg string -} - -func (me KRPCError) Error() string { - return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg) -} - -var _ error = KRPCError{} - -func (m Msg) Error() (ret *KRPCError) { - if m["y"] != "e" { - return - } - ret = &KRPCError{} - switch e := m["e"].(type) { - case []interface{}: - ret.Code = int(e[0].(int64)) - ret.Msg = e[1].(string) - case string: - ret.Msg = e - default: - logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e) - } - return -} - -// Returns the token given in response to a get_peers request for future -// announce_peer requests to that node. -func (m Msg) AnnounceToken() (token string, ok bool) { - defer func() { recover() }() - token, ok = m["r"].(map[string]interface{})["token"].(string) - return -} - type Transaction struct { mu sync.Mutex remoteAddr dHTAddr @@ -648,12 +551,12 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) { } s.mu.Lock() defer s.mu.Unlock() - if d["y"] == "q" { + if d.Y == "q" { readQuery.Add(1) s.handleQuery(addr, d) return } - t := s.findResponseTransaction(d.T(), addr) + t := s.findResponseTransaction(d.T, addr) if t == nil { //log.Printf("unexpected message: %#v", d) return @@ -722,15 +625,12 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) { if s.config.Passive { return } - args := m.Args() - if args == nil { - return - } - switch m["q"] { + args := m.A + switch m.Q { case "ping": - s.reply(source, m["t"].(string), nil) + s.reply(source, m.T, Return{}) case "get_peers": // TODO: Extract common behaviour with find_node. - targetID := args["info_hash"].(string) + targetID := args.InfoHash if len(targetID) != 20 { break } @@ -739,19 +639,13 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) { for _, node := range s.closestGoodNodes(8, targetID) { rNodes = append(rNodes, node.NodeInfo()) } - nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes)) - for i, ni := range rNodes { - err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen]) - if err != nil { - panic(err) - } - } - s.reply(source, m["t"].(string), map[string]interface{}{ - "nodes": string(nodesBytes), - "token": "hi", + s.reply(source, m.T, Return{ + Nodes: rNodes, + // TODO: Generate this dynamically, and store it for the source. + Token: "hi", }) case "find_node": // TODO: Extract common behaviour with get_peers. - targetID := args["target"].(string) + targetID := args.Target if len(targetID) != 20 { log.Printf("bad DHT query: %v", m) return @@ -760,24 +654,13 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) { if node := s.nodeByID(targetID); node != nil { rNodes = append(rNodes, node.NodeInfo()) } else { + // This will probably cause a crash for IPv6, but meh. for _, node := range s.closestGoodNodes(8, targetID) { rNodes = append(rNodes, node.NodeInfo()) } } - nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes)) - for i, ni := range rNodes { - // TODO: Put IPv6 nodes into the correct dict element. - if ni.Addr.UDPAddr().IP.To4() == nil { - continue - } - err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen]) - if err != nil { - log.Printf("error compacting %#v: %s", ni, err) - continue - } - } - s.reply(source, m["t"].(string), map[string]interface{}{ - "nodes": string(nodesBytes), + s.reply(source, m.T, Return{ + Nodes: rNodes, }) case "announce_peer": // TODO(anacrolix): Implement this lolz. @@ -785,20 +668,17 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) { case "vote": // TODO(anacrolix): Or reject, I don't think I want this. default: - log.Printf("%s: not handling received query: q=%s", s, m["q"]) + log.Printf("%s: not handling received query: q=%s", s, m.Q) return } } -func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) { - if r == nil { - r = make(map[string]interface{}, 1) - } - r["id"] = s.ID() - m := map[string]interface{}{ - "t": t, - "y": "r", - "r": r, +func (s *Server) reply(addr dHTAddr, t string, r Return) { + r.ID = s.ID() + m := Msg{ + T: t, + Y: "r", + R: &r, } b, err := bencode.Marshal(m) if err != nil { @@ -947,7 +827,7 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onRespo } // The size in bytes of a NodeInfo in its compact binary representation. -const CompactNodeInfoLen = 26 +const CompactIPv4NodeInfoLen = 26 type NodeInfo struct { ID [20]byte @@ -971,7 +851,7 @@ func (ni *NodeInfo) PutCompact(b []byte) error { return nil } -func (cni *NodeInfo) UnmarshalCompact(b []byte) error { +func (cni *NodeInfo) UnmarshalCompactIPv4(b []byte) error { if len(b) != 26 { return errors.New("expected 26 bytes") } @@ -1019,10 +899,10 @@ func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token str // Add response nodes to node table. func (s *Server) liftNodes(d Msg) { - if d["y"] != "r" { + if d.Y != "r" { return } - for _, cni := range d.Nodes() { + for _, cni := range d.R.Nodes { if missinggo.AddrPort(cni.Addr) == 0 { // TODO: Why would people even do this? continue @@ -1057,44 +937,6 @@ func (me *Peer) String() string { return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10)) } -// In a get_peers response, the addresses of torrent clients involved with the -// queried info-hash. -func (m Msg) Values() (vs []Peer) { - v := func() interface{} { - defer func() { - recover() - }() - return m["r"].(map[string]interface{})["values"] - }() - if v == nil { - return - } - vl, ok := v.([]interface{}) - if !ok { - if missinggo.CryHeard() { - log.Printf(`unexpected krpc "values" field: %#v`, v) - } - return - } - vs = make([]Peer, 0, len(vl)) - for _, i := range vl { - s, ok := i.(string) - if !ok { - panic(i) - } - // Because it's a list of strings, we can let the length of the string - // determine the IP version of the compact peer. - var cp util.CompactPeer - err := cp.UnmarshalBinary([]byte(s)) - if err != nil { - log.Printf("error decoding values list element: %s", err) - continue - } - vs = append(vs, Peer{cp.IP[:], int(cp.Port)}) - } - return -} - func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) { if len(infoHash) != 20 { err = fmt.Errorf("infohash has bad length") @@ -1102,10 +944,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err er } t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) { s.liftNodes(m) - at, ok := m.AnnounceToken() - if ok { - s.getNode(addr, m.SenderID()).announceToken = at - } + s.getNode(addr, m.SenderID()).announceToken = m.R.Token }) return } diff --git a/dht/dht_test.go b/dht/dht_test.go index f84ccb68..b8a8897c 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -9,6 +9,10 @@ import ( "github.com/anacrolix/missinggo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/util" ) func TestSetNilBigInt(t *testing.T) { @@ -25,7 +29,7 @@ func TestMarshalCompactNodeInfo(t *testing.T) { t.Fatal(err) } cni.Addr = newDHTAddr(addr) - var b [CompactNodeInfoLen]byte + var b [CompactIPv4NodeInfoLen]byte cni.PutCompact(b[:]) if err != nil { t.Fatal(err) @@ -106,14 +110,12 @@ func TestClosestNodes(t *testing.T) { } func TestUnmarshalGetPeersResponse(t *testing.T) { - gpr := Msg{ - "r": map[string]interface{}{ - "values": []interface{}{"\x01\x02\x03\x04\x05\x06", "\x07\x08\x09\x0a\x0b\x0c"}, - "nodes": "\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07" + "\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07", - }, - } - assert.EqualValues(t, 2, len(gpr.Values())) - assert.EqualValues(t, 2, len(gpr.Nodes())) + var msg Msg + err := bencode.Unmarshal([]byte("d1:rd6:valuesl6:\x01\x02\x03\x04\x05\x066:\x07\x08\x09\x0a\x0b\x0ce5:nodes52:\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07ee"), &msg) + require.NoError(t, err) + assert.Len(t, msg.R.Values, 2) + assert.Len(t, msg.R.Nodes, 2) + assert.Nil(t, msg.E) } func TestDHTDefaultConfig(t *testing.T) { @@ -203,3 +205,74 @@ func TestServerDefaultNodeIdSecure(t *testing.T) { t.Fatal("not secure") } } + +func testMarshalUnmarshalMsg(t *testing.T, m Msg, expected string) { + b, err := bencode.Marshal(m) + require.NoError(t, err) + assert.Equal(t, expected, string(b)) + var _m Msg + err = bencode.Unmarshal([]byte(expected), &_m) + assert.NoError(t, err) + assert.EqualValues(t, m, _m) + assert.EqualValues(t, m.R, _m.R) +} + +func TestMarshalUnmarshalMsg(t *testing.T) { + testMarshalUnmarshalMsg(t, Msg{}, "d1:t0:1:y0:e") + testMarshalUnmarshalMsg(t, Msg{ + Y: "q", + Q: "ping", + T: "hi", + }, "d1:q4:ping1:t2:hi1:y1:qe") + testMarshalUnmarshalMsg(t, Msg{ + Y: "e", + T: "42", + E: &KRPCError{Code: 200, Msg: "fuck"}, + }, "d1:eli200e4:fucke1:t2:421:y1:ee") + testMarshalUnmarshalMsg(t, Msg{ + Y: "r", + T: "\x8c%", + R: &Return{}, + }, "d1:rd2:id0:5:token0:e1:t2:\x8c%1:y1:re") + testMarshalUnmarshalMsg(t, Msg{ + Y: "r", + T: "\x8c%", + R: &Return{ + Nodes: CompactIPv4NodeInfo{ + NodeInfo{ + Addr: newDHTAddr(&net.UDPAddr{ + IP: net.IPv4(1, 2, 3, 4), + Port: 0x1234, + }), + }, + }, + }, + }, "d1:rd2:id0:5:nodes26:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x1245:token0:e1:t2:\x8c%1:y1:re") + testMarshalUnmarshalMsg(t, Msg{ + Y: "r", + T: "\x8c%", + R: &Return{ + Values: []util.CompactPeer{ + util.CompactPeer{ + IP: net.IPv4(1, 2, 3, 4).To4(), + Port: 0x5678, + }, + }, + }, + }, "d1:rd2:id0:5:token0:6:valuesl6:\x01\x02\x03\x04\x56\x78ee1:t2:\x8c%1:y1:re") +} + +func TestAnnounceTimeout(t *testing.T) { + s, err := NewServer(&ServerConfig{ + BootstrapNodes: []string{"1.2.3.4:5"}, + }) + require.NoError(t, err) + a, err := s.Announce("12341234123412341234", 0, true) + <-a.Peers + a.Close() + s.Close() +} + +func TestEqualPointers(t *testing.T) { + assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}}) +} diff --git a/dht/krpcError.go b/dht/krpcError.go new file mode 100644 index 00000000..c52023e9 --- /dev/null +++ b/dht/krpcError.go @@ -0,0 +1,45 @@ +package dht + +import ( + "fmt" + + "github.com/anacrolix/torrent/bencode" +) + +// Represented as a string or list in bencode. +type KRPCError struct { + Code int + Msg string +} + +var ( + _ bencode.Unmarshaler = &KRPCError{} + _ bencode.Marshaler = &KRPCError{} + _ error = KRPCError{} +) + +func (me *KRPCError) UnmarshalBencode(_b []byte) (err error) { + var _v interface{} + err = bencode.Unmarshal(_b, &_v) + if err != nil { + return + } + switch v := _v.(type) { + case []interface{}: + me.Code = int(v[0].(int64)) + me.Msg = v[1].(string) + case string: + me.Msg = v + default: + err = fmt.Errorf(`KRPC error bencode value has unexpected type: %T`, _v) + } + return +} + +func (me KRPCError) MarshalBencode() (ret []byte, err error) { + return bencode.Marshal([]interface{}{me.Code, me.Msg}) +} + +func (me KRPCError) Error() string { + return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg) +} diff --git a/dht/msg.go b/dht/msg.go new file mode 100644 index 00000000..991a354e --- /dev/null +++ b/dht/msg.go @@ -0,0 +1,52 @@ +package dht + +import ( + "fmt" + + "github.com/anacrolix/torrent/util" +) + +// The unmarshalled KRPC dict message. +type Msg struct { + Q string `bencode:"q,omitempty"` + A *struct { + ID string `bencode:"id"` + InfoHash string `bencode:"info_hash"` + Target string `bencode:"target"` + } `bencode:"a,omitempty"` + T string `bencode:"t"` + Y string `bencode:"y"` + R *Return `bencode:"r,omitempty"` + E *KRPCError `bencode:"e,omitempty"` +} + +type Return struct { + ID string `bencode:"id"` + Nodes CompactIPv4NodeInfo `bencode:"nodes,omitempty"` + Token string `bencode:"token"` + Values []util.CompactPeer `bencode:"values,omitempty"` +} + +var _ fmt.Stringer = Msg{} + +func (m Msg) String() string { + return fmt.Sprintf("%#v", m) +} + +// The node ID of the source of this Msg. +func (m Msg) SenderID() string { + switch m.Y { + case "q": + return m.A.ID + case "r": + return m.R.ID + } + return "" +} + +func (m Msg) Error() *KRPCError { + if m.Y != "e" { + return nil + } + return m.E +} diff --git a/util/types.go b/util/types.go index 942b9d73..d0655f14 100644 --- a/util/types.go +++ b/util/types.go @@ -4,6 +4,7 @@ import ( "encoding" "encoding/binary" "errors" + "fmt" "net" "github.com/bradfitz/iter" @@ -46,6 +47,22 @@ type CompactPeer struct { Port int } +var ( + _ bencode.Marshaler = &CompactPeer{} + _ bencode.Unmarshaler = &CompactPeer{} +) + +func (me CompactPeer) MarshalBencode() (ret []byte, err error) { + ip := me.IP + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + ret = make([]byte, len(ip)+2) + copy(ret, ip) + binary.BigEndian.PutUint16(ret[len(ip):], uint16(me.Port)) + return bencode.Marshal(ret) +} + func (me *CompactPeer) UnmarshalBinary(b []byte) error { switch len(b) { case 18: @@ -53,7 +70,7 @@ func (me *CompactPeer) UnmarshalBinary(b []byte) error { case 6: me.IP = make([]byte, 4) default: - return errors.New("bad length") + return fmt.Errorf("bad compact peer string: %q", b) } copy(me.IP, b) b = b[len(me.IP):] @@ -61,6 +78,15 @@ func (me *CompactPeer) UnmarshalBinary(b []byte) error { return nil } +func (me *CompactPeer) UnmarshalBencode(b []byte) (err error) { + var _b []byte + err = bencode.Unmarshal(b, &_b) + if err != nil { + return + } + return me.UnmarshalBinary(_b) +} + func UnmarshalIPv4CompactPeers(b []byte) (ret []CompactPeer, err error) { if len(b)%6 != 0 { err = errors.New("bad length")