diff --git a/client.go b/client.go index ecbbcc4c..d1ae457f 100644 --- a/client.go +++ b/client.go @@ -54,7 +54,11 @@ var ( postedCancels = expvar.NewInt("postedCancels") ) -const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x00" +// Justification for set bits follows. +// +// Extension protocol: http://www.bittorrent.org/beps/bep_0010.html +// DHT: http://www.bittorrent.org/beps/bep_0005.html +const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x01" // Currently doesn't really queue, but should in the future. func (cl *Client) queuePieceCheck(t *torrent, pieceIndex pp.Integer) { @@ -531,6 +535,13 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS Bitfield: torrent.bitfield(), }) } + if conn.PeerExtensionBytes[7]&0x01 != 0 && me.dHT != nil { + addr, _ := me.dHT.LocalAddr().(*net.UDPAddr) + conn.Post(pp.Message{ + Type: pp.Port, + Port: uint16(addr.Port), + }) + } err = me.connectionLoop(torrent, conn) if err != nil { err = fmt.Errorf("during Connection loop with peer %q: %s", conn.PeerID, err) @@ -860,6 +871,16 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error { if err != nil { log.Printf("peer extension map: %#v", c.PeerExtensionIDs) } + case pp.Port: + if me.dHT == nil { + break + } + addr, _ := c.Socket.RemoteAddr().(*net.TCPAddr) + _, err = me.dHT.Ping(&net.UDPAddr{ + IP: addr.IP, + Zone: addr.Zone, + Port: int(msg.Port), + }) default: err = fmt.Errorf("received unknown message type: %#v", msg.Type) } diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index db32ac59..607edf59 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -33,6 +33,7 @@ const ( Request // 6 Piece // 7 Cancel // 8 + Port // 9 Extended = 20 HandshakeExtendedID = 0 @@ -50,6 +51,7 @@ type Message struct { Bitfield []bool ExtendedID byte ExtendedPayload []byte + Port uint16 } func (msg Message) MarshalBinary() (data []byte, err error) { @@ -92,6 +94,8 @@ func (msg Message) MarshalBinary() (data []byte, err error) { return } _, err = buf.Write(msg.ExtendedPayload) + case Port: + err = binary.Write(buf, binary.BigEndian, msg.Port) default: err = fmt.Errorf("unknown message type: %v", msg.Type) } @@ -187,6 +191,8 @@ func (d *Decoder) Decode(msg *Message) (err error) { break } msg.ExtendedPayload, err = ioutil.ReadAll(r) + case Port: + err = binary.Read(r, binary.BigEndian, &msg.Port) default: err = fmt.Errorf("unknown message type %#v", c) } diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index d2fc8d4b..580d26b2 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -121,3 +121,34 @@ func TestMarshalKeepalive(t *testing.T) { t.Fatalf("marshalled keepalive is %q, expected %q", bs, expected) } } + +func TestMarshalPortMsg(t *testing.T) { + b, err := (Message{ + Type: Port, + Port: 0xaabb, + }).MarshalBinary() + if err != nil { + t.Fatal(err) + } + if string(b) != "\x00\x00\x00\x03\x09\xaa\xbb" { + t.FailNow() + } +} + +func TestUnmarshalPortMsg(t *testing.T) { + var m Message + d := Decoder{ + R: bufio.NewReader(bytes.NewBufferString("\x00\x00\x00\x03\x09\xaa\xbb")), + MaxLength: 8, + } + err := d.Decode(&m) + if err != nil { + t.Fatal(err) + } + if m.Type != Port { + t.FailNow() + } + if m.Port != 0xaabb { + t.FailNow() + } +}