2
0
mirror of synced 2025-02-24 14:48:27 +00:00

dht: Various improvements and removal of cruft

This commit is contained in:
Matt Joiner 2014-12-26 17:21:48 +11:00
parent 8b7548e7a0
commit e0d936e920
5 changed files with 231 additions and 232 deletions

55
dht/bitcount.go Normal file
View File

@ -0,0 +1,55 @@
package dht
import (
"math/big"
)
// TODO: The bitcounting is a relic of the old and incorrect distance
// calculation. It is still useful in some tests but should eventually be
// replaced with actual distances.
// How many bits?
func bitCount(n big.Int) int {
var count int = 0
for _, b := range n.Bytes() {
count += int(bitCounts[b])
}
return count
}
// The bit counts for each byte value (0 - 255).
var bitCounts = []int8{
// Generated by Java BitCount of all values from 0 to 255
0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7,
5, 6, 6, 7, 6, 7, 7, 8,
}

View File

@ -5,14 +5,16 @@ import (
) )
type nodeMaxHeap struct { type nodeMaxHeap struct {
IDs []string IDs []nodeID
Target string Target nodeID
} }
func (me nodeMaxHeap) Len() int { return len(me.IDs) } func (me nodeMaxHeap) Len() int { return len(me.IDs) }
func (me nodeMaxHeap) Less(i, j int) bool { func (me nodeMaxHeap) Less(i, j int) bool {
return idDistance(me.IDs[i], me.Target).Cmp(idDistance(me.IDs[j], me.Target)) > 0 m := me.IDs[i].Distance(&me.Target)
n := me.IDs[j].Distance(&me.Target)
return m.Cmp(&n) > 0
} }
func (me *nodeMaxHeap) Pop() (ret interface{}) { func (me *nodeMaxHeap) Pop() (ret interface{}) {
@ -20,7 +22,7 @@ func (me *nodeMaxHeap) Pop() (ret interface{}) {
return return
} }
func (me *nodeMaxHeap) Push(val interface{}) { func (me *nodeMaxHeap) Push(val interface{}) {
me.IDs = append(me.IDs, val.(string)) me.IDs = append(me.IDs, val.(nodeID))
} }
func (me nodeMaxHeap) Swap(i, j int) { func (me nodeMaxHeap) Swap(i, j int) {
me.IDs[i], me.IDs[j] = me.IDs[j], me.IDs[i] me.IDs[i], me.IDs[j] = me.IDs[j], me.IDs[i]
@ -31,18 +33,18 @@ type closestNodesSelector struct {
k int k int
} }
func (me *closestNodesSelector) Push(id string) { func (me *closestNodesSelector) Push(id nodeID) {
heap.Push(&me.closest, id) heap.Push(&me.closest, id)
if me.closest.Len() > me.k { if me.closest.Len() > me.k {
heap.Pop(&me.closest) heap.Pop(&me.closest)
} }
} }
func (me *closestNodesSelector) IDs() []string { func (me *closestNodesSelector) IDs() []nodeID {
return me.closest.IDs return me.closest.IDs
} }
func newKClosestNodesSelector(k int, targetID string) (ret closestNodesSelector) { func newKClosestNodesSelector(k int, targetID nodeID) (ret closestNodesSelector) {
ret.k = k ret.k = k
ret.closest.Target = targetID ret.closest.Target = targetID
return return

View File

@ -1,11 +1,16 @@
package dht package dht
import ( import (
"bitbucket.org/anacrolix/go.torrent/iplist"
"bitbucket.org/anacrolix/go.torrent/logonce"
"bitbucket.org/anacrolix/go.torrent/util"
"bitbucket.org/anacrolix/sync"
"crypto" "crypto"
_ "crypto/sha1" _ "crypto/sha1"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/anacrolix/libtorgo/bencode"
"io" "io"
"log" "log"
"math/big" "math/big"
@ -13,14 +18,6 @@ import (
"net" "net"
"os" "os"
"time" "time"
"bitbucket.org/anacrolix/sync"
"bitbucket.org/anacrolix/go.torrent/iplist"
"bitbucket.org/anacrolix/go.torrent/logonce"
"bitbucket.org/anacrolix/go.torrent/util"
"github.com/anacrolix/libtorgo/bencode"
) )
const maxNodes = 10000 const maxNodes = 10000
@ -47,11 +44,28 @@ type Server struct {
type dHTAddr interface { type dHTAddr interface {
net.Addr net.Addr
UDPAddr() *net.UDPAddr
} }
func newDHTAddr(addr *net.UDPAddr) (ret dHTAddr) { type cachedAddr struct {
ret = addr a net.Addr
return s string
}
func (ca cachedAddr) Network() string {
return ca.a.Network()
}
func (ca cachedAddr) String() string {
return ca.s
}
func (ca cachedAddr) UDPAddr() *net.UDPAddr {
return ca.a.(*net.UDPAddr)
}
func newDHTAddr(addr *net.UDPAddr) dHTAddr {
return cachedAddr{addr, addr.String()}
} }
type ServerConfig struct { type ServerConfig struct {
@ -134,9 +148,40 @@ func (s *Server) String() string {
return fmt.Sprintf("dht server on %s", s.socket.LocalAddr()) return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
} }
type nodeID struct {
i big.Int
set bool
}
func (nid *nodeID) IsUnset() bool {
return !nid.set
}
func nodeIDFromString(s string) (ret nodeID) {
if s == "" {
return
}
ret.i.SetBytes([]byte(s))
ret.set = true
return
}
func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
if nid0.IsUnset() != nid1.IsUnset() {
ret = maxDistance
return
}
ret.Xor(&nid0.i, &nid1.i)
return
}
func (nid *nodeID) String() string {
return string(nid.i.Bytes())
}
type Node struct { type Node struct {
addr dHTAddr addr dHTAddr
id string id nodeID
announceToken string announceToken string
lastGotQuery time.Time lastGotQuery time.Time
@ -144,16 +189,33 @@ type Node struct {
lastSentQuery time.Time lastSentQuery time.Time
} }
func (n *Node) idString() string {
return n.id.String()
}
func (n *Node) SetIDFromBytes(b []byte) {
n.id.i.SetBytes(b)
n.id.set = true
}
func (n *Node) SetIDFromString(s string) {
n.id.i.SetBytes([]byte(s))
}
func (n *Node) IDNotSet() bool {
return n.id.i.Int64() == 0
}
func (n *Node) NodeInfo() (ret NodeInfo) { func (n *Node) NodeInfo() (ret NodeInfo) {
ret.Addr = n.addr ret.Addr = n.addr
if n := copy(ret.ID[:], n.id); n != 20 { if n := copy(ret.ID[:], n.idString()); n != 20 {
panic(n) panic(n)
} }
return return
} }
func (n *Node) DefinitelyGood() bool { func (n *Node) DefinitelyGood() bool {
if len(n.id) != 20 { if len(n.idString()) != 20 {
return false return false
} }
// No reason to think ill of them if they've never been queried. // No reason to think ill of them if they've never been queried.
@ -184,6 +246,13 @@ func (m Msg) T() (t string) {
return return
} }
func (m Msg) ID() string {
defer func() {
recover()
}()
return m[m["y"].(string)].(map[string]interface{})["id"].(string)
}
func (m Msg) Nodes() []NodeInfo { func (m Msg) Nodes() []NodeInfo {
var r findNodeResponse var r findNodeResponse
if err := r.UnmarshalKRPCMsg(m); err != nil { if err := r.UnmarshalKRPCMsg(m); err != nil {
@ -447,14 +516,14 @@ func (s *Server) AddNode(ni NodeInfo) {
s.nodes = make(map[string]*Node) s.nodes = make(map[string]*Node)
} }
n := s.getNode(ni.Addr) n := s.getNode(ni.Addr)
if n.id == "" { if n.IDNotSet() {
n.id = string(ni.ID[:]) n.SetIDFromBytes(ni.ID[:])
} }
} }
func (s *Server) nodeByID(id string) *Node { func (s *Server) nodeByID(id string) *Node {
for _, node := range s.nodes { for _, node := range s.nodes {
if node.id == id { if node.idString() == id {
return node return node
} }
} }
@ -464,7 +533,7 @@ func (s *Server) nodeByID(id string) *Node {
func (s *Server) handleQuery(source dHTAddr, m Msg) { func (s *Server) handleQuery(source dHTAddr, m Msg) {
args := m["a"].(map[string]interface{}) args := m["a"].(map[string]interface{})
node := s.getNode(source) node := s.getNode(source)
node.id = args["id"].(string) node.SetIDFromString(args["id"].(string))
node.lastGotQuery = time.Now() node.lastGotQuery = time.Now()
// Don't respond. // Don't respond.
if s.passive { if s.passive {
@ -473,7 +542,7 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
switch m["q"] { switch m["q"] {
case "ping": case "ping":
s.reply(source, m["t"].(string), nil) s.reply(source, m["t"].(string), nil)
case "get_peers": case "get_peers": // TODO: Extract common behaviour with find_node.
targetID := args["info_hash"].(string) targetID := args["info_hash"].(string)
if len(targetID) != 20 { if len(targetID) != 20 {
break break
@ -494,7 +563,7 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
"nodes": string(nodesBytes), "nodes": string(nodesBytes),
"token": "hi", "token": "hi",
}) })
case "find_node": case "find_node": // TODO: Extract common behaviour with get_peers.
targetID := args["target"].(string) targetID := args["target"].(string)
if len(targetID) != 20 { if len(targetID) != 20 {
log.Printf("bad DHT query: %v", m) log.Printf("bad DHT query: %v", m)
@ -510,9 +579,14 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
} }
nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes)) nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
for i, ni := range rNodes { for i, ni := range rNodes {
// TODO: Put IPv6 nodes into the correct dict element.
if ni.Addr.UDPAddr().IP.To4() == nil {
continue
}
err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen]) err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
if err != nil { if err != nil {
panic(err) log.Printf("error compacting %#v: %s", ni, err)
continue
} }
} }
s.reply(source, m["t"].(string), map[string]interface{}{ s.reply(source, m["t"].(string), map[string]interface{}{
@ -550,13 +624,14 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
} }
func (s *Server) getNode(addr dHTAddr) (n *Node) { func (s *Server) getNode(addr dHTAddr) (n *Node) {
n = s.nodes[addr.String()] addrStr := addr.String()
n = s.nodes[addrStr]
if n == nil { if n == nil {
n = &Node{ n = &Node{
addr: addr, addr: addr,
} }
if len(s.nodes) < maxNodes { if len(s.nodes) < maxNodes {
s.nodes[addr.String()] = n s.nodes[addrStr] = n
} }
} }
return return
@ -577,12 +652,12 @@ func (s *Server) nodeTimedOut(addr dHTAddr) {
func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) { func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
if list := s.ipBlockList; list != nil { if list := s.ipBlockList; list != nil {
if r := list.Lookup(util.AddrIP(node)); r != nil { if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
err = fmt.Errorf("write to %s blocked: %s", node, r.Description) err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
return return
} }
} }
n, err := s.socket.WriteTo(b, node) n, err := s.socket.WriteTo(b, node.UDPAddr())
if err != nil { if err != nil {
err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err) err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
return return
@ -672,7 +747,7 @@ func (ni *NodeInfo) PutCompact(b []byte) error {
} }
ip := util.AddrIP(ni.Addr).To4() ip := util.AddrIP(ni.Addr).To4()
if len(ip) != 4 { if len(ip) != 4 {
panic(ip) return errors.New("expected ipv4 address")
} }
if n := copy(b[20:], ip); n != 4 { if n := copy(b[20:], ip); n != 4 {
panic(n) panic(n)
@ -707,7 +782,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) { func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for _, node := range s.closestNodes(160, infoHash, func(n *Node) bool { for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
return n.announceToken != "" return n.announceToken != ""
}) { }) {
err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort) err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
@ -841,7 +916,7 @@ func (s *Server) liftNodes(d Msg) {
continue continue
} }
n := s.getNode(cni.Addr) n := s.getNode(cni.Addr)
n.id = string(cni.ID[:]) n.SetIDFromBytes(cni.ID[:])
} }
// log.Printf("lifted %d nodes", len(r.Nodes)) // log.Printf("lifted %d nodes", len(r.Nodes))
} }
@ -1014,7 +1089,7 @@ func (s *Server) Nodes() (nis []NodeInfo) {
ni := NodeInfo{ ni := NodeInfo{
Addr: node.addr, Addr: node.addr,
} }
if n := copy(ni.ID[:], node.id); n != 20 && n != 0 { if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
panic(n) panic(n)
} }
nis = append(nis, ni) nis = append(nis, ni)
@ -1033,95 +1108,6 @@ func (s *Server) Close() {
s.mu.Unlock() s.mu.Unlock()
} }
type distance interface {
Cmp(distance) int
BitCount() int
IsZero() bool
}
type bigIntDistance struct {
big.Int
}
// How many bits?
func bitCount(n *big.Int) int {
var count int = 0
for _, b := range n.Bytes() {
count += int(bitCounts[b])
}
return count
}
// The bit counts for each byte value (0 - 255).
var bitCounts = []int8{
// Generated by Java BitCount of all values from 0 to 255
0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7,
5, 6, 6, 7, 6, 7, 7, 8,
}
func (me bigIntDistance) BitCount() int {
return bitCount(&me.Int)
}
func (me bigIntDistance) Cmp(d bigIntDistance) int {
return me.Int.Cmp(&d.Int)
}
func (me bigIntDistance) IsZero() bool {
var zero big.Int
return me.Int.Cmp(&zero) == 0
}
type bitCountDistance int
func (me bitCountDistance) BitCount() int { return int(me) }
func (me bitCountDistance) Cmp(rhs distance) int {
rhs_ := rhs.(bitCountDistance)
if me < rhs_ {
return -1
} else if me == rhs_ {
return 0
} else {
return 1
}
}
func (me bitCountDistance) IsZero() bool {
return me == 0
}
// Below are 2 versions of idDistance. Only one can be active.
var maxDistance big.Int var maxDistance big.Int
func init() { func init() {
@ -1129,67 +1115,24 @@ func init() {
maxDistance.SetBit(&zero, 160, 1) maxDistance.SetBit(&zero, 160, 1)
} }
// If we don't know the ID for a node, then its distance is more than the
// furthest possible distance otherwise.
func idDistance(a, b string) (ret bigIntDistance) {
if a == "" && b == "" {
return
}
if a == "" {
if len(b) != 20 {
panic(b)
}
ret.Set(&maxDistance)
return
}
if b == "" {
if len(a) != 20 {
panic(a)
}
ret.Set(&maxDistance)
return
}
if len(a) != 20 {
panic(a)
}
if len(b) != 20 {
panic(b)
}
var x, y big.Int
x.SetBytes([]byte(a))
y.SetBytes([]byte(b))
ret.Int.Xor(&x, &y)
return ret
}
// func idDistance(a, b string) bitCountDistance {
// ret := 0
// for i := 0; i < 20; i++ {
// for j := uint(0); j < 8; j++ {
// ret += int(a[i]>>j&1 ^ b[i]>>j&1)
// }
// }
// return bitCountDistance(ret)
// }
func (s *Server) closestGoodNodes(k int, targetID string) []*Node { func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
return s.closestNodes(k, targetID, func(n *Node) bool { return n.DefinitelyGood() }) return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
} }
func (s *Server) closestNodes(k int, targetID string, filter func(*Node) bool) []*Node { func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
sel := newKClosestNodesSelector(k, targetID) sel := newKClosestNodesSelector(k, target)
idNodes := make(map[string]*Node, len(s.nodes)) idNodes := make(map[string]*Node, len(s.nodes))
for _, node := range s.nodes { for _, node := range s.nodes {
if !filter(node) { if !filter(node) {
continue continue
} }
sel.Push(node.id) sel.Push(node.id)
idNodes[node.id] = node idNodes[node.idString()] = node
} }
ids := sel.IDs() ids := sel.IDs()
ret := make([]*Node, 0, len(ids)) ret := make([]*Node, 0, len(ids))
for _, id := range ids { for _, id := range ids {
ret = append(ret, idNodes[id]) ret = append(ret, idNodes[id.String()])
} }
return ret return ret
} }

View File

@ -46,29 +46,34 @@ func recoverPanicOrDie(t *testing.T, f func()) {
const zeroID = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" const zeroID = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
var testIDs = []string{ var testIDs []nodeID
func init() {
for _, s := range []string{
zeroID, zeroID,
"\x03" + zeroID[1:], "\x03" + zeroID[1:],
"\x03" + zeroID[1:18] + "\x55\xf0", "\x03" + zeroID[1:18] + "\x55\xf0",
"\x55" + zeroID[1:17] + "\xff\x55\x0f", "\x55" + zeroID[1:17] + "\xff\x55\x0f",
"\x54" + zeroID[1:18] + "\x50\x0f", "\x54" + zeroID[1:18] + "\x50\x0f",
"", "",
} {
testIDs = append(testIDs, nodeIDFromString(s))
}
} }
func TestDistances(t *testing.T) { func TestDistances(t *testing.T) {
if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 { expectBitcount := func(i big.Int, count int) {
t.FailNow() if bitCount(i) != count {
t.Fatalf("expected bitcount of %d: got %d", count, bitCount(i))
} }
if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 {
t.FailNow()
}
if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 {
t.FailNow()
} }
expectBitcount(testIDs[3].Distance(&testIDs[0]), 4+8+4+4)
expectBitcount(testIDs[3].Distance(&testIDs[1]), 4+8+4+4)
expectBitcount(testIDs[3].Distance(&testIDs[2]), 4+8+8)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
dist := idDistance(testIDs[i], testIDs[5]).Int dist := testIDs[i].Distance(&testIDs[5])
if dist.Cmp(&maxDistance) != 0 { if dist.Cmp(&maxDistance) != 0 {
t.FailNow() t.Fatal("expected max distance for comparison with unset node id")
} }
} }
} }
@ -79,37 +84,6 @@ func TestMaxDistanceString(t *testing.T) {
} }
} }
func TestBadIdStrings(t *testing.T) {
var a, b string
idDistance(a, b)
idDistance(a, zeroID)
idDistance(zeroID, b)
recoverPanicOrDie(t, func() {
idDistance("when", a)
})
recoverPanicOrDie(t, func() {
idDistance(a, "bad")
})
recoverPanicOrDie(t, func() {
idDistance("meets", "evil")
})
for _, id := range testIDs {
if !idDistance(id, id).IsZero() {
t.Fatal("identical IDs should have distance 0")
}
}
a = "\x03" + zeroID[1:]
b = zeroID
if idDistance(a, b).BitCount() != 2 {
t.FailNow()
}
a = "\x03" + zeroID[1:18] + "\x55\xf0"
b = "\x55" + zeroID[1:17] + "\xff\x55\x0f"
if c := idDistance(a, b).BitCount(); c != 20 {
t.Fatal(c)
}
}
func TestClosestNodes(t *testing.T) { func TestClosestNodes(t *testing.T) {
cn := newKClosestNodesSelector(2, testIDs[3]) cn := newKClosestNodesSelector(2, testIDs[3])
for _, i := range rand.Perm(len(testIDs)) { for _, i := range rand.Perm(len(testIDs)) {
@ -120,9 +94,9 @@ func TestClosestNodes(t *testing.T) {
} }
m := map[string]bool{} m := map[string]bool{}
for _, id := range cn.IDs() { for _, id := range cn.IDs() {
m[id] = true m[id.String()] = true
} }
if !m[testIDs[3]] || !m[testIDs[4]] { if !m[testIDs[3].String()] || !m[testIDs[4].String()] {
t.FailNow() t.FailNow()
} }
} }
@ -154,3 +128,28 @@ func TestDHTDefaultConfig(t *testing.T) {
} }
s.Close() s.Close()
} }
func TestPing(t *testing.T) {
srv, err := NewServer(nil)
if err != nil {
t.Fatal(err)
}
defer srv.Close()
srv0, err := NewServer(nil)
if err != nil {
t.Fatal(err)
}
defer srv0.Close()
tn, err := srv.Ping(&net.UDPAddr{
IP: []byte{127, 0, 0, 1},
Port: srv0.LocalAddr().(*net.UDPAddr).Port,
})
if err != nil {
t.Fatal(err)
}
defer tn.Close()
msg := <-tn.Response
if msg.ID() != srv0.IDString() {
t.FailNow()
}
}

View File

@ -1,12 +1,12 @@
package dht package dht
import ( import (
"log"
"time"
"bitbucket.org/anacrolix/go.torrent/util" "bitbucket.org/anacrolix/go.torrent/util"
"bitbucket.org/anacrolix/sync" "bitbucket.org/anacrolix/sync"
"github.com/willf/bloom" "github.com/willf/bloom"
"log"
"net"
"time"
) )
type peerDiscovery struct { type peerDiscovery struct {
@ -19,7 +19,7 @@ type peerDiscovery struct {
func (s *Server) GetPeers(infoHash string) (*peerStream, error) { func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
s.mu.Lock() s.mu.Lock()
startAddrs := func() (ret []net.Addr) { startAddrs := func() (ret []dHTAddr) {
for _, n := range s.closestGoodNodes(160, infoHash) { for _, n := range s.closestGoodNodes(160, infoHash) {
ret = append(ret, n.addr) ret = append(ret, n.addr)
} }
@ -32,7 +32,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
return nil, err return nil, err
} }
for _, addr := range addrs { for _, addr := range addrs {
startAddrs = append(startAddrs, addr) startAddrs = append(startAddrs, newDHTAddr(addr))
} }
} }
disc := &peerDiscovery{ disc := &peerDiscovery{
@ -41,7 +41,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
stop: make(chan struct{}), stop: make(chan struct{}),
values: make(chan peerStreamValue), values: make(chan peerStreamValue),
}, },
triedAddrs: bloom.NewWithEstimates(500000, 0.01), triedAddrs: bloom.NewWithEstimates(10000, 0.01),
server: s, server: s,
infoHash: infoHash, infoHash: infoHash,
} }
@ -72,7 +72,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
return disc.peerStream, nil return disc.peerStream, nil
} }
func (me *peerDiscovery) gotNodeAddr(addr net.Addr) { func (me *peerDiscovery) gotNodeAddr(addr dHTAddr) {
if util.AddrPort(addr) == 0 { if util.AddrPort(addr) == 0 {
// Not a contactable address. // Not a contactable address.
return return
@ -86,7 +86,7 @@ func (me *peerDiscovery) gotNodeAddr(addr net.Addr) {
me.contact(addr) me.contact(addr)
} }
func (me *peerDiscovery) contact(addr net.Addr) { func (me *peerDiscovery) contact(addr dHTAddr) {
me.triedAddrs.Add([]byte(addr.String())) me.triedAddrs.Add([]byte(addr.String()))
if err := me.getPeers(addr); err != nil { if err := me.getPeers(addr); err != nil {
log.Printf("error sending get_peers request to %s: %s", addr, err) log.Printf("error sending get_peers request to %s: %s", addr, err)
@ -111,7 +111,7 @@ func (me *peerDiscovery) closingCh() chan struct{} {
return me.peerStream.stop return me.peerStream.stop
} }
func (me *peerDiscovery) getPeers(addr net.Addr) error { func (me *peerDiscovery) getPeers(addr dHTAddr) error {
me.server.mu.Lock() me.server.mu.Lock()
defer me.server.mu.Unlock() defer me.server.mu.Unlock()
t, err := me.server.getPeers(addr, me.infoHash) t, err := me.server.getPeers(addr, me.infoHash)