Implement the DHT Port message

This commit is contained in:
Matt Joiner 2014-08-25 22:12:16 +10:00
parent 28e80062b6
commit c9bc892789
3 changed files with 59 additions and 1 deletions

View File

@ -54,7 +54,11 @@ var (
postedCancels = expvar.NewInt("postedCancels")
)
const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x00"
// Justification for set bits follows.
//
// Extension protocol: http://www.bittorrent.org/beps/bep_0010.html
// DHT: http://www.bittorrent.org/beps/bep_0005.html
const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x01"
// Currently doesn't really queue, but should in the future.
func (cl *Client) queuePieceCheck(t *torrent, pieceIndex pp.Integer) {
@ -531,6 +535,13 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS
Bitfield: torrent.bitfield(),
})
}
if conn.PeerExtensionBytes[7]&0x01 != 0 && me.dHT != nil {
addr, _ := me.dHT.LocalAddr().(*net.UDPAddr)
conn.Post(pp.Message{
Type: pp.Port,
Port: uint16(addr.Port),
})
}
err = me.connectionLoop(torrent, conn)
if err != nil {
err = fmt.Errorf("during Connection loop with peer %q: %s", conn.PeerID, err)
@ -860,6 +871,16 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error {
if err != nil {
log.Printf("peer extension map: %#v", c.PeerExtensionIDs)
}
case pp.Port:
if me.dHT == nil {
break
}
addr, _ := c.Socket.RemoteAddr().(*net.TCPAddr)
_, err = me.dHT.Ping(&net.UDPAddr{
IP: addr.IP,
Zone: addr.Zone,
Port: int(msg.Port),
})
default:
err = fmt.Errorf("received unknown message type: %#v", msg.Type)
}

View File

@ -33,6 +33,7 @@ const (
Request // 6
Piece // 7
Cancel // 8
Port // 9
Extended = 20
HandshakeExtendedID = 0
@ -50,6 +51,7 @@ type Message struct {
Bitfield []bool
ExtendedID byte
ExtendedPayload []byte
Port uint16
}
func (msg Message) MarshalBinary() (data []byte, err error) {
@ -92,6 +94,8 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
return
}
_, err = buf.Write(msg.ExtendedPayload)
case Port:
err = binary.Write(buf, binary.BigEndian, msg.Port)
default:
err = fmt.Errorf("unknown message type: %v", msg.Type)
}
@ -187,6 +191,8 @@ func (d *Decoder) Decode(msg *Message) (err error) {
break
}
msg.ExtendedPayload, err = ioutil.ReadAll(r)
case Port:
err = binary.Read(r, binary.BigEndian, &msg.Port)
default:
err = fmt.Errorf("unknown message type %#v", c)
}

View File

@ -121,3 +121,34 @@ func TestMarshalKeepalive(t *testing.T) {
t.Fatalf("marshalled keepalive is %q, expected %q", bs, expected)
}
}
func TestMarshalPortMsg(t *testing.T) {
b, err := (Message{
Type: Port,
Port: 0xaabb,
}).MarshalBinary()
if err != nil {
t.Fatal(err)
}
if string(b) != "\x00\x00\x00\x03\x09\xaa\xbb" {
t.FailNow()
}
}
func TestUnmarshalPortMsg(t *testing.T) {
var m Message
d := Decoder{
R: bufio.NewReader(bytes.NewBufferString("\x00\x00\x00\x03\x09\xaa\xbb")),
MaxLength: 8,
}
err := d.Decode(&m)
if err != nil {
t.Fatal(err)
}
if m.Type != Port {
t.FailNow()
}
if m.Port != 0xaabb {
t.FailNow()
}
}