dht: Make Msg a struct with bencode tags

This commit is contained in:
Matt Joiner 2015-10-23 12:41:45 +11:00
parent de69976bbf
commit 99a0813d88
10 changed files with 317 additions and 237 deletions

View File

@ -39,7 +39,7 @@ func loadTable() error {
defer f.Close()
added := 0
for {
b := make([]byte, dht.CompactNodeInfoLen)
b := make([]byte, dht.CompactIPv4NodeInfoLen)
_, err := io.ReadFull(f, b)
if err == io.EOF {
break
@ -48,7 +48,7 @@ func loadTable() error {
return fmt.Errorf("error reading table file: %s", err)
}
var ni dht.NodeInfo
err = ni.UnmarshalCompact(b)
err = ni.UnmarshalCompactIPv4(b)
if err != nil {
return fmt.Errorf("error unmarshaling compact node info: %s", err)
}
@ -101,7 +101,7 @@ func saveTable() error {
}
defer f.Close()
for _, nodeInfo := range goodNodes {
var b [dht.CompactNodeInfoLen]byte
var b [dht.CompactIPv4NodeInfoLen]byte
err := nodeInfo.PutCompact(b[:])
if err != nil {
return fmt.Errorf("error compacting node info: %s", err)

View File

@ -68,11 +68,8 @@ pingResponses:
for _ = range pingStrAddrs {
select {
case resp := <-pingResponses:
if resp.krpc == nil {
break
}
responses++
fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc["r"].(map[string]interface{})["id"].(string), resp.addr), resp.rtt)
fmt.Printf("%-65s %s\n", fmt.Sprintf("%x (%s):", resp.krpc.R.ID, resp.addr), resp.rtt)
case <-timeoutChan:
break pingResponses
}

View File

@ -32,7 +32,7 @@ func loadTable() error {
defer f.Close()
added := 0
for {
b := make([]byte, dht.CompactNodeInfoLen)
b := make([]byte, dht.CompactIPv4NodeInfoLen)
_, err := io.ReadFull(f, b)
if err == io.EOF {
break
@ -41,7 +41,7 @@ func loadTable() error {
return fmt.Errorf("error reading table file: %s", err)
}
var ni dht.NodeInfo
err = ni.UnmarshalCompact(b)
err = ni.UnmarshalCompactIPv4(b)
if err != nil {
return fmt.Errorf("error unmarshaling compact node info: %s", err)
}
@ -84,7 +84,7 @@ func saveTable() error {
}
defer f.Close()
for _, nodeInfo := range goodNodes {
var b [dht.CompactNodeInfoLen]byte
var b [dht.CompactIPv4NodeInfoLen]byte
err := nodeInfo.PutCompact(b[:])
if err != nil {
return fmt.Errorf("error compacting node info: %s", err)

View File

@ -160,37 +160,36 @@ func (me *Announce) getPeers(addr dHTAddr) error {
}
t.SetResponseHandler(func(m Msg) {
// Register suggested nodes closer to the target info-hash.
if m.R != nil {
me.mu.Lock()
for _, n := range m.Nodes() {
for _, n := range m.R.Nodes {
me.responseNode(n)
}
me.mu.Unlock()
if vs := m.Values(); vs != nil {
for _, cp := range vs {
if cp.Port == 0 {
me.server.mu.Lock()
me.server.badNode(addr)
me.server.mu.Unlock()
return
}
}
if vs := m.R.Values; len(vs) != 0 {
nodeInfo := NodeInfo{
Addr: t.remoteAddr,
}
copy(nodeInfo.ID[:], m.SenderID())
select {
case me.values <- PeersValues{
Peers: vs,
Peers: func() (ret []Peer) {
for _, cp := range vs {
ret = append(ret, Peer(cp))
}
return
}(),
NodeInfo: nodeInfo,
}:
case <-me.stop:
}
}
if at, ok := m.AnnounceToken(); ok {
if at := m.R.Token; at != "" {
me.announcePeer(addr, at)
}
}
me.mu.Lock()
me.transactionClosed()

49
dht/compactNodeInfo.go Normal file
View File

@ -0,0 +1,49 @@
package dht
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/anacrolix/torrent/bencode"
)
type CompactIPv4NodeInfo []NodeInfo
var _ bencode.Unmarshaler = &CompactIPv4NodeInfo{}
func (me *CompactIPv4NodeInfo) UnmarshalBencode(_b []byte) (err error) {
var b []byte
err = bencode.Unmarshal(_b, &b)
if err != nil {
return
}
if len(b)%CompactIPv4NodeInfoLen != 0 {
err = fmt.Errorf("bad length: %d", len(b))
return
}
for i := 0; i < len(b); i += CompactIPv4NodeInfoLen {
var ni NodeInfo
err = ni.UnmarshalCompactIPv4(b[i : i+CompactIPv4NodeInfoLen])
if err != nil {
return
}
*me = append(*me, ni)
}
return
}
func (me CompactIPv4NodeInfo) MarshalBencode() (ret []byte, err error) {
var buf bytes.Buffer
for _, ni := range me {
buf.Write(ni.ID[:])
if ni.Addr == nil {
err = errors.New("nil addr in node info")
return
}
buf.Write(ni.Addr.IP().To4())
binary.Write(&buf, binary.BigEndian, uint16(ni.Addr.UDPAddr().Port))
}
return bencode.Marshal(buf.Bytes())
}

View File

@ -28,11 +28,13 @@ import (
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/torrent/iplist"
"github.com/anacrolix/torrent/logonce"
"github.com/anacrolix/torrent/util"
)
const (
maxNodes = 320
)
var (
queryResendEvery = 5 * time.Second
)
@ -266,105 +268,6 @@ func (n *node) DefinitelyGood() bool {
return true
}
// A wrapper around the unmarshalled KRPC dict that constitutes messages in
// the DHT. There are various helpers for extracting common data from the
// message. In normal use, Msg is abstracted away for you, but it can be of
// interest.
type Msg map[string]interface{}
var _ fmt.Stringer = Msg{}
func (m Msg) String() string {
return fmt.Sprintf("%#v", m)
}
func (m Msg) T() (t string) {
tif, ok := m["t"]
if !ok {
return
}
t, _ = tif.(string)
return
}
func (m Msg) Args() map[string]interface{} {
defer func() {
recover()
}()
return m["a"].(map[string]interface{})
}
func (m Msg) SenderID() string {
defer func() {
recover()
}()
switch m["y"].(string) {
case "q":
return m.Args()["id"].(string)
case "r":
return m["r"].(map[string]interface{})["id"].(string)
}
return ""
}
// Suggested nodes in a response.
func (m Msg) Nodes() (nodes []NodeInfo) {
b := func() string {
defer func() {
recover()
}()
return m["r"].(map[string]interface{})["nodes"].(string)
}()
if len(b)%26 != 0 {
return
}
for i := 0; i < len(b); i += 26 {
var n NodeInfo
err := n.UnmarshalCompact([]byte(b[i : i+26]))
if err != nil {
continue
}
nodes = append(nodes, n)
}
return
}
type KRPCError struct {
Code int
Msg string
}
func (me KRPCError) Error() string {
return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
}
var _ error = KRPCError{}
func (m Msg) Error() (ret *KRPCError) {
if m["y"] != "e" {
return
}
ret = &KRPCError{}
switch e := m["e"].(type) {
case []interface{}:
ret.Code = int(e[0].(int64))
ret.Msg = e[1].(string)
case string:
ret.Msg = e
default:
logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
}
return
}
// Returns the token given in response to a get_peers request for future
// announce_peer requests to that node.
func (m Msg) AnnounceToken() (token string, ok bool) {
defer func() { recover() }()
token, ok = m["r"].(map[string]interface{})["token"].(string)
return
}
type Transaction struct {
mu sync.Mutex
remoteAddr dHTAddr
@ -648,12 +551,12 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) {
}
s.mu.Lock()
defer s.mu.Unlock()
if d["y"] == "q" {
if d.Y == "q" {
readQuery.Add(1)
s.handleQuery(addr, d)
return
}
t := s.findResponseTransaction(d.T(), addr)
t := s.findResponseTransaction(d.T, addr)
if t == nil {
//log.Printf("unexpected message: %#v", d)
return
@ -722,15 +625,12 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
if s.config.Passive {
return
}
args := m.Args()
if args == nil {
return
}
switch m["q"] {
args := m.A
switch m.Q {
case "ping":
s.reply(source, m["t"].(string), nil)
s.reply(source, m.T, Return{})
case "get_peers": // TODO: Extract common behaviour with find_node.
targetID := args["info_hash"].(string)
targetID := args.InfoHash
if len(targetID) != 20 {
break
}
@ -739,19 +639,13 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
for _, node := range s.closestGoodNodes(8, targetID) {
rNodes = append(rNodes, node.NodeInfo())
}
nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
for i, ni := range rNodes {
err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
if err != nil {
panic(err)
}
}
s.reply(source, m["t"].(string), map[string]interface{}{
"nodes": string(nodesBytes),
"token": "hi",
s.reply(source, m.T, Return{
Nodes: rNodes,
// TODO: Generate this dynamically, and store it for the source.
Token: "hi",
})
case "find_node": // TODO: Extract common behaviour with get_peers.
targetID := args["target"].(string)
targetID := args.Target
if len(targetID) != 20 {
log.Printf("bad DHT query: %v", m)
return
@ -760,24 +654,13 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
if node := s.nodeByID(targetID); node != nil {
rNodes = append(rNodes, node.NodeInfo())
} else {
// This will probably cause a crash for IPv6, but meh.
for _, node := range s.closestGoodNodes(8, targetID) {
rNodes = append(rNodes, node.NodeInfo())
}
}
nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
for i, ni := range rNodes {
// TODO: Put IPv6 nodes into the correct dict element.
if ni.Addr.UDPAddr().IP.To4() == nil {
continue
}
err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
if err != nil {
log.Printf("error compacting %#v: %s", ni, err)
continue
}
}
s.reply(source, m["t"].(string), map[string]interface{}{
"nodes": string(nodesBytes),
s.reply(source, m.T, Return{
Nodes: rNodes,
})
case "announce_peer":
// TODO(anacrolix): Implement this lolz.
@ -785,20 +668,17 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
case "vote":
// TODO(anacrolix): Or reject, I don't think I want this.
default:
log.Printf("%s: not handling received query: q=%s", s, m["q"])
log.Printf("%s: not handling received query: q=%s", s, m.Q)
return
}
}
func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
if r == nil {
r = make(map[string]interface{}, 1)
}
r["id"] = s.ID()
m := map[string]interface{}{
"t": t,
"y": "r",
"r": r,
func (s *Server) reply(addr dHTAddr, t string, r Return) {
r.ID = s.ID()
m := Msg{
T: t,
Y: "r",
R: &r,
}
b, err := bencode.Marshal(m)
if err != nil {
@ -947,7 +827,7 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onRespo
}
// The size in bytes of a NodeInfo in its compact binary representation.
const CompactNodeInfoLen = 26
const CompactIPv4NodeInfoLen = 26
type NodeInfo struct {
ID [20]byte
@ -971,7 +851,7 @@ func (ni *NodeInfo) PutCompact(b []byte) error {
return nil
}
func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
func (cni *NodeInfo) UnmarshalCompactIPv4(b []byte) error {
if len(b) != 26 {
return errors.New("expected 26 bytes")
}
@ -1019,10 +899,10 @@ func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token str
// Add response nodes to node table.
func (s *Server) liftNodes(d Msg) {
if d["y"] != "r" {
if d.Y != "r" {
return
}
for _, cni := range d.Nodes() {
for _, cni := range d.R.Nodes {
if missinggo.AddrPort(cni.Addr) == 0 {
// TODO: Why would people even do this?
continue
@ -1057,44 +937,6 @@ func (me *Peer) String() string {
return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10))
}
// In a get_peers response, the addresses of torrent clients involved with the
// queried info-hash.
func (m Msg) Values() (vs []Peer) {
v := func() interface{} {
defer func() {
recover()
}()
return m["r"].(map[string]interface{})["values"]
}()
if v == nil {
return
}
vl, ok := v.([]interface{})
if !ok {
if missinggo.CryHeard() {
log.Printf(`unexpected krpc "values" field: %#v`, v)
}
return
}
vs = make([]Peer, 0, len(vl))
for _, i := range vl {
s, ok := i.(string)
if !ok {
panic(i)
}
// Because it's a list of strings, we can let the length of the string
// determine the IP version of the compact peer.
var cp util.CompactPeer
err := cp.UnmarshalBinary([]byte(s))
if err != nil {
log.Printf("error decoding values list element: %s", err)
continue
}
vs = append(vs, Peer{cp.IP[:], int(cp.Port)})
}
return
}
func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) {
if len(infoHash) != 20 {
err = fmt.Errorf("infohash has bad length")
@ -1102,10 +944,7 @@ func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err er
}
t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
s.liftNodes(m)
at, ok := m.AnnounceToken()
if ok {
s.getNode(addr, m.SenderID()).announceToken = at
}
s.getNode(addr, m.SenderID()).announceToken = m.R.Token
})
return
}

View File

@ -9,6 +9,10 @@ import (
"github.com/anacrolix/missinggo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/torrent/util"
)
func TestSetNilBigInt(t *testing.T) {
@ -25,7 +29,7 @@ func TestMarshalCompactNodeInfo(t *testing.T) {
t.Fatal(err)
}
cni.Addr = newDHTAddr(addr)
var b [CompactNodeInfoLen]byte
var b [CompactIPv4NodeInfoLen]byte
cni.PutCompact(b[:])
if err != nil {
t.Fatal(err)
@ -106,14 +110,12 @@ func TestClosestNodes(t *testing.T) {
}
func TestUnmarshalGetPeersResponse(t *testing.T) {
gpr := Msg{
"r": map[string]interface{}{
"values": []interface{}{"\x01\x02\x03\x04\x05\x06", "\x07\x08\x09\x0a\x0b\x0c"},
"nodes": "\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07" + "\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07",
},
}
assert.EqualValues(t, 2, len(gpr.Values()))
assert.EqualValues(t, 2, len(gpr.Nodes()))
var msg Msg
err := bencode.Unmarshal([]byte("d1:rd6:valuesl6:\x01\x02\x03\x04\x05\x066:\x07\x08\x09\x0a\x0b\x0ce5:nodes52:\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x02\x03\x04\x05\x06\x07\x08\x09\x02\x03\x04\x05\x06\x07ee"), &msg)
require.NoError(t, err)
assert.Len(t, msg.R.Values, 2)
assert.Len(t, msg.R.Nodes, 2)
assert.Nil(t, msg.E)
}
func TestDHTDefaultConfig(t *testing.T) {
@ -203,3 +205,74 @@ func TestServerDefaultNodeIdSecure(t *testing.T) {
t.Fatal("not secure")
}
}
func testMarshalUnmarshalMsg(t *testing.T, m Msg, expected string) {
b, err := bencode.Marshal(m)
require.NoError(t, err)
assert.Equal(t, expected, string(b))
var _m Msg
err = bencode.Unmarshal([]byte(expected), &_m)
assert.NoError(t, err)
assert.EqualValues(t, m, _m)
assert.EqualValues(t, m.R, _m.R)
}
func TestMarshalUnmarshalMsg(t *testing.T) {
testMarshalUnmarshalMsg(t, Msg{}, "d1:t0:1:y0:e")
testMarshalUnmarshalMsg(t, Msg{
Y: "q",
Q: "ping",
T: "hi",
}, "d1:q4:ping1:t2:hi1:y1:qe")
testMarshalUnmarshalMsg(t, Msg{
Y: "e",
T: "42",
E: &KRPCError{Code: 200, Msg: "fuck"},
}, "d1:eli200e4:fucke1:t2:421:y1:ee")
testMarshalUnmarshalMsg(t, Msg{
Y: "r",
T: "\x8c%",
R: &Return{},
}, "d1:rd2:id0:5:token0:e1:t2:\x8c%1:y1:re")
testMarshalUnmarshalMsg(t, Msg{
Y: "r",
T: "\x8c%",
R: &Return{
Nodes: CompactIPv4NodeInfo{
NodeInfo{
Addr: newDHTAddr(&net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 0x1234,
}),
},
},
},
}, "d1:rd2:id0:5:nodes26:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x1245:token0:e1:t2:\x8c%1:y1:re")
testMarshalUnmarshalMsg(t, Msg{
Y: "r",
T: "\x8c%",
R: &Return{
Values: []util.CompactPeer{
util.CompactPeer{
IP: net.IPv4(1, 2, 3, 4).To4(),
Port: 0x5678,
},
},
},
}, "d1:rd2:id0:5:token0:6:valuesl6:\x01\x02\x03\x04\x56\x78ee1:t2:\x8c%1:y1:re")
}
func TestAnnounceTimeout(t *testing.T) {
s, err := NewServer(&ServerConfig{
BootstrapNodes: []string{"1.2.3.4:5"},
})
require.NoError(t, err)
a, err := s.Announce("12341234123412341234", 0, true)
<-a.Peers
a.Close()
s.Close()
}
func TestEqualPointers(t *testing.T) {
assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}})
}

45
dht/krpcError.go Normal file
View File

@ -0,0 +1,45 @@
package dht
import (
"fmt"
"github.com/anacrolix/torrent/bencode"
)
// Represented as a string or list in bencode.
type KRPCError struct {
Code int
Msg string
}
var (
_ bencode.Unmarshaler = &KRPCError{}
_ bencode.Marshaler = &KRPCError{}
_ error = KRPCError{}
)
func (me *KRPCError) UnmarshalBencode(_b []byte) (err error) {
var _v interface{}
err = bencode.Unmarshal(_b, &_v)
if err != nil {
return
}
switch v := _v.(type) {
case []interface{}:
me.Code = int(v[0].(int64))
me.Msg = v[1].(string)
case string:
me.Msg = v
default:
err = fmt.Errorf(`KRPC error bencode value has unexpected type: %T`, _v)
}
return
}
func (me KRPCError) MarshalBencode() (ret []byte, err error) {
return bencode.Marshal([]interface{}{me.Code, me.Msg})
}
func (me KRPCError) Error() string {
return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
}

52
dht/msg.go Normal file
View File

@ -0,0 +1,52 @@
package dht
import (
"fmt"
"github.com/anacrolix/torrent/util"
)
// The unmarshalled KRPC dict message.
type Msg struct {
Q string `bencode:"q,omitempty"`
A *struct {
ID string `bencode:"id"`
InfoHash string `bencode:"info_hash"`
Target string `bencode:"target"`
} `bencode:"a,omitempty"`
T string `bencode:"t"`
Y string `bencode:"y"`
R *Return `bencode:"r,omitempty"`
E *KRPCError `bencode:"e,omitempty"`
}
type Return struct {
ID string `bencode:"id"`
Nodes CompactIPv4NodeInfo `bencode:"nodes,omitempty"`
Token string `bencode:"token"`
Values []util.CompactPeer `bencode:"values,omitempty"`
}
var _ fmt.Stringer = Msg{}
func (m Msg) String() string {
return fmt.Sprintf("%#v", m)
}
// The node ID of the source of this Msg.
func (m Msg) SenderID() string {
switch m.Y {
case "q":
return m.A.ID
case "r":
return m.R.ID
}
return ""
}
func (m Msg) Error() *KRPCError {
if m.Y != "e" {
return nil
}
return m.E
}

View File

@ -4,6 +4,7 @@ import (
"encoding"
"encoding/binary"
"errors"
"fmt"
"net"
"github.com/bradfitz/iter"
@ -46,6 +47,22 @@ type CompactPeer struct {
Port int
}
var (
_ bencode.Marshaler = &CompactPeer{}
_ bencode.Unmarshaler = &CompactPeer{}
)
func (me CompactPeer) MarshalBencode() (ret []byte, err error) {
ip := me.IP
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
ret = make([]byte, len(ip)+2)
copy(ret, ip)
binary.BigEndian.PutUint16(ret[len(ip):], uint16(me.Port))
return bencode.Marshal(ret)
}
func (me *CompactPeer) UnmarshalBinary(b []byte) error {
switch len(b) {
case 18:
@ -53,7 +70,7 @@ func (me *CompactPeer) UnmarshalBinary(b []byte) error {
case 6:
me.IP = make([]byte, 4)
default:
return errors.New("bad length")
return fmt.Errorf("bad compact peer string: %q", b)
}
copy(me.IP, b)
b = b[len(me.IP):]
@ -61,6 +78,15 @@ func (me *CompactPeer) UnmarshalBinary(b []byte) error {
return nil
}
func (me *CompactPeer) UnmarshalBencode(b []byte) (err error) {
var _b []byte
err = bencode.Unmarshal(b, &_b)
if err != nil {
return
}
return me.UnmarshalBinary(_b)
}
func UnmarshalIPv4CompactPeers(b []byte) (ret []CompactPeer, err error) {
if len(b)%6 != 0 {
err = errors.New("bad length")