dht: Various improvements and removal of cruft
This commit is contained in:
parent
8b7548e7a0
commit
e0d936e920
55
dht/bitcount.go
Normal file
55
dht/bitcount.go
Normal file
@ -0,0 +1,55 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// TODO: The bitcounting is a relic of the old and incorrect distance
|
||||
// calculation. It is still useful in some tests but should eventually be
|
||||
// replaced with actual distances.
|
||||
|
||||
// How many bits?
|
||||
func bitCount(n big.Int) int {
|
||||
var count int = 0
|
||||
for _, b := range n.Bytes() {
|
||||
count += int(bitCounts[b])
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// The bit counts for each byte value (0 - 255).
|
||||
var bitCounts = []int8{
|
||||
// Generated by Java BitCount of all values from 0 to 255
|
||||
0, 1, 1, 2, 1, 2, 2, 3,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
5, 6, 6, 7, 6, 7, 7, 8,
|
||||
}
|
@ -5,14 +5,16 @@ import (
|
||||
)
|
||||
|
||||
type nodeMaxHeap struct {
|
||||
IDs []string
|
||||
Target string
|
||||
IDs []nodeID
|
||||
Target nodeID
|
||||
}
|
||||
|
||||
func (me nodeMaxHeap) Len() int { return len(me.IDs) }
|
||||
|
||||
func (me nodeMaxHeap) Less(i, j int) bool {
|
||||
return idDistance(me.IDs[i], me.Target).Cmp(idDistance(me.IDs[j], me.Target)) > 0
|
||||
m := me.IDs[i].Distance(&me.Target)
|
||||
n := me.IDs[j].Distance(&me.Target)
|
||||
return m.Cmp(&n) > 0
|
||||
}
|
||||
|
||||
func (me *nodeMaxHeap) Pop() (ret interface{}) {
|
||||
@ -20,7 +22,7 @@ func (me *nodeMaxHeap) Pop() (ret interface{}) {
|
||||
return
|
||||
}
|
||||
func (me *nodeMaxHeap) Push(val interface{}) {
|
||||
me.IDs = append(me.IDs, val.(string))
|
||||
me.IDs = append(me.IDs, val.(nodeID))
|
||||
}
|
||||
func (me nodeMaxHeap) Swap(i, j int) {
|
||||
me.IDs[i], me.IDs[j] = me.IDs[j], me.IDs[i]
|
||||
@ -31,18 +33,18 @@ type closestNodesSelector struct {
|
||||
k int
|
||||
}
|
||||
|
||||
func (me *closestNodesSelector) Push(id string) {
|
||||
func (me *closestNodesSelector) Push(id nodeID) {
|
||||
heap.Push(&me.closest, id)
|
||||
if me.closest.Len() > me.k {
|
||||
heap.Pop(&me.closest)
|
||||
}
|
||||
}
|
||||
|
||||
func (me *closestNodesSelector) IDs() []string {
|
||||
func (me *closestNodesSelector) IDs() []nodeID {
|
||||
return me.closest.IDs
|
||||
}
|
||||
|
||||
func newKClosestNodesSelector(k int, targetID string) (ret closestNodesSelector) {
|
||||
func newKClosestNodesSelector(k int, targetID nodeID) (ret closestNodesSelector) {
|
||||
ret.k = k
|
||||
ret.closest.Target = targetID
|
||||
return
|
||||
|
275
dht/dht.go
275
dht/dht.go
@ -1,11 +1,16 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"bitbucket.org/anacrolix/go.torrent/iplist"
|
||||
"bitbucket.org/anacrolix/go.torrent/logonce"
|
||||
"bitbucket.org/anacrolix/go.torrent/util"
|
||||
"bitbucket.org/anacrolix/sync"
|
||||
"crypto"
|
||||
_ "crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/anacrolix/libtorgo/bencode"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
@ -13,14 +18,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"bitbucket.org/anacrolix/sync"
|
||||
|
||||
"bitbucket.org/anacrolix/go.torrent/iplist"
|
||||
|
||||
"bitbucket.org/anacrolix/go.torrent/logonce"
|
||||
"bitbucket.org/anacrolix/go.torrent/util"
|
||||
"github.com/anacrolix/libtorgo/bencode"
|
||||
)
|
||||
|
||||
const maxNodes = 10000
|
||||
@ -47,11 +44,28 @@ type Server struct {
|
||||
|
||||
type dHTAddr interface {
|
||||
net.Addr
|
||||
UDPAddr() *net.UDPAddr
|
||||
}
|
||||
|
||||
func newDHTAddr(addr *net.UDPAddr) (ret dHTAddr) {
|
||||
ret = addr
|
||||
return
|
||||
type cachedAddr struct {
|
||||
a net.Addr
|
||||
s string
|
||||
}
|
||||
|
||||
func (ca cachedAddr) Network() string {
|
||||
return ca.a.Network()
|
||||
}
|
||||
|
||||
func (ca cachedAddr) String() string {
|
||||
return ca.s
|
||||
}
|
||||
|
||||
func (ca cachedAddr) UDPAddr() *net.UDPAddr {
|
||||
return ca.a.(*net.UDPAddr)
|
||||
}
|
||||
|
||||
func newDHTAddr(addr *net.UDPAddr) dHTAddr {
|
||||
return cachedAddr{addr, addr.String()}
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@ -134,9 +148,40 @@ func (s *Server) String() string {
|
||||
return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
|
||||
}
|
||||
|
||||
type nodeID struct {
|
||||
i big.Int
|
||||
set bool
|
||||
}
|
||||
|
||||
func (nid *nodeID) IsUnset() bool {
|
||||
return !nid.set
|
||||
}
|
||||
|
||||
func nodeIDFromString(s string) (ret nodeID) {
|
||||
if s == "" {
|
||||
return
|
||||
}
|
||||
ret.i.SetBytes([]byte(s))
|
||||
ret.set = true
|
||||
return
|
||||
}
|
||||
|
||||
func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
|
||||
if nid0.IsUnset() != nid1.IsUnset() {
|
||||
ret = maxDistance
|
||||
return
|
||||
}
|
||||
ret.Xor(&nid0.i, &nid1.i)
|
||||
return
|
||||
}
|
||||
|
||||
func (nid *nodeID) String() string {
|
||||
return string(nid.i.Bytes())
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
addr dHTAddr
|
||||
id string
|
||||
id nodeID
|
||||
announceToken string
|
||||
|
||||
lastGotQuery time.Time
|
||||
@ -144,16 +189,33 @@ type Node struct {
|
||||
lastSentQuery time.Time
|
||||
}
|
||||
|
||||
func (n *Node) idString() string {
|
||||
return n.id.String()
|
||||
}
|
||||
|
||||
func (n *Node) SetIDFromBytes(b []byte) {
|
||||
n.id.i.SetBytes(b)
|
||||
n.id.set = true
|
||||
}
|
||||
|
||||
func (n *Node) SetIDFromString(s string) {
|
||||
n.id.i.SetBytes([]byte(s))
|
||||
}
|
||||
|
||||
func (n *Node) IDNotSet() bool {
|
||||
return n.id.i.Int64() == 0
|
||||
}
|
||||
|
||||
func (n *Node) NodeInfo() (ret NodeInfo) {
|
||||
ret.Addr = n.addr
|
||||
if n := copy(ret.ID[:], n.id); n != 20 {
|
||||
if n := copy(ret.ID[:], n.idString()); n != 20 {
|
||||
panic(n)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (n *Node) DefinitelyGood() bool {
|
||||
if len(n.id) != 20 {
|
||||
if len(n.idString()) != 20 {
|
||||
return false
|
||||
}
|
||||
// No reason to think ill of them if they've never been queried.
|
||||
@ -184,6 +246,13 @@ func (m Msg) T() (t string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Msg) ID() string {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
return m[m["y"].(string)].(map[string]interface{})["id"].(string)
|
||||
}
|
||||
|
||||
func (m Msg) Nodes() []NodeInfo {
|
||||
var r findNodeResponse
|
||||
if err := r.UnmarshalKRPCMsg(m); err != nil {
|
||||
@ -447,14 +516,14 @@ func (s *Server) AddNode(ni NodeInfo) {
|
||||
s.nodes = make(map[string]*Node)
|
||||
}
|
||||
n := s.getNode(ni.Addr)
|
||||
if n.id == "" {
|
||||
n.id = string(ni.ID[:])
|
||||
if n.IDNotSet() {
|
||||
n.SetIDFromBytes(ni.ID[:])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) nodeByID(id string) *Node {
|
||||
for _, node := range s.nodes {
|
||||
if node.id == id {
|
||||
if node.idString() == id {
|
||||
return node
|
||||
}
|
||||
}
|
||||
@ -464,7 +533,7 @@ func (s *Server) nodeByID(id string) *Node {
|
||||
func (s *Server) handleQuery(source dHTAddr, m Msg) {
|
||||
args := m["a"].(map[string]interface{})
|
||||
node := s.getNode(source)
|
||||
node.id = args["id"].(string)
|
||||
node.SetIDFromString(args["id"].(string))
|
||||
node.lastGotQuery = time.Now()
|
||||
// Don't respond.
|
||||
if s.passive {
|
||||
@ -473,7 +542,7 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
|
||||
switch m["q"] {
|
||||
case "ping":
|
||||
s.reply(source, m["t"].(string), nil)
|
||||
case "get_peers":
|
||||
case "get_peers": // TODO: Extract common behaviour with find_node.
|
||||
targetID := args["info_hash"].(string)
|
||||
if len(targetID) != 20 {
|
||||
break
|
||||
@ -494,7 +563,7 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
|
||||
"nodes": string(nodesBytes),
|
||||
"token": "hi",
|
||||
})
|
||||
case "find_node":
|
||||
case "find_node": // TODO: Extract common behaviour with get_peers.
|
||||
targetID := args["target"].(string)
|
||||
if len(targetID) != 20 {
|
||||
log.Printf("bad DHT query: %v", m)
|
||||
@ -510,9 +579,14 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
|
||||
}
|
||||
nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
|
||||
for i, ni := range rNodes {
|
||||
// TODO: Put IPv6 nodes into the correct dict element.
|
||||
if ni.Addr.UDPAddr().IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Printf("error compacting %#v: %s", ni, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.reply(source, m["t"].(string), map[string]interface{}{
|
||||
@ -550,13 +624,14 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
|
||||
}
|
||||
|
||||
func (s *Server) getNode(addr dHTAddr) (n *Node) {
|
||||
n = s.nodes[addr.String()]
|
||||
addrStr := addr.String()
|
||||
n = s.nodes[addrStr]
|
||||
if n == nil {
|
||||
n = &Node{
|
||||
addr: addr,
|
||||
}
|
||||
if len(s.nodes) < maxNodes {
|
||||
s.nodes[addr.String()] = n
|
||||
s.nodes[addrStr] = n
|
||||
}
|
||||
}
|
||||
return
|
||||
@ -577,12 +652,12 @@ func (s *Server) nodeTimedOut(addr dHTAddr) {
|
||||
|
||||
func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
|
||||
if list := s.ipBlockList; list != nil {
|
||||
if r := list.Lookup(util.AddrIP(node)); r != nil {
|
||||
if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
|
||||
err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
|
||||
return
|
||||
}
|
||||
}
|
||||
n, err := s.socket.WriteTo(b, node)
|
||||
n, err := s.socket.WriteTo(b, node.UDPAddr())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
|
||||
return
|
||||
@ -672,7 +747,7 @@ func (ni *NodeInfo) PutCompact(b []byte) error {
|
||||
}
|
||||
ip := util.AddrIP(ni.Addr).To4()
|
||||
if len(ip) != 4 {
|
||||
panic(ip)
|
||||
return errors.New("expected ipv4 address")
|
||||
}
|
||||
if n := copy(b[20:], ip); n != 4 {
|
||||
panic(n)
|
||||
@ -707,7 +782,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
|
||||
func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, node := range s.closestNodes(160, infoHash, func(n *Node) bool {
|
||||
for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
|
||||
return n.announceToken != ""
|
||||
}) {
|
||||
err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
|
||||
@ -841,7 +916,7 @@ func (s *Server) liftNodes(d Msg) {
|
||||
continue
|
||||
}
|
||||
n := s.getNode(cni.Addr)
|
||||
n.id = string(cni.ID[:])
|
||||
n.SetIDFromBytes(cni.ID[:])
|
||||
}
|
||||
// log.Printf("lifted %d nodes", len(r.Nodes))
|
||||
}
|
||||
@ -1014,7 +1089,7 @@ func (s *Server) Nodes() (nis []NodeInfo) {
|
||||
ni := NodeInfo{
|
||||
Addr: node.addr,
|
||||
}
|
||||
if n := copy(ni.ID[:], node.id); n != 20 && n != 0 {
|
||||
if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
|
||||
panic(n)
|
||||
}
|
||||
nis = append(nis, ni)
|
||||
@ -1033,95 +1108,6 @@ func (s *Server) Close() {
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
type distance interface {
|
||||
Cmp(distance) int
|
||||
BitCount() int
|
||||
IsZero() bool
|
||||
}
|
||||
|
||||
type bigIntDistance struct {
|
||||
big.Int
|
||||
}
|
||||
|
||||
// How many bits?
|
||||
func bitCount(n *big.Int) int {
|
||||
var count int = 0
|
||||
for _, b := range n.Bytes() {
|
||||
count += int(bitCounts[b])
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// The bit counts for each byte value (0 - 255).
|
||||
var bitCounts = []int8{
|
||||
// Generated by Java BitCount of all values from 0 to 255
|
||||
0, 1, 1, 2, 1, 2, 2, 3,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
1, 2, 2, 3, 2, 3, 3, 4,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
2, 3, 3, 4, 3, 4, 4, 5,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
3, 4, 4, 5, 4, 5, 5, 6,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
4, 5, 5, 6, 5, 6, 6, 7,
|
||||
5, 6, 6, 7, 6, 7, 7, 8,
|
||||
}
|
||||
|
||||
func (me bigIntDistance) BitCount() int {
|
||||
return bitCount(&me.Int)
|
||||
}
|
||||
|
||||
func (me bigIntDistance) Cmp(d bigIntDistance) int {
|
||||
return me.Int.Cmp(&d.Int)
|
||||
}
|
||||
|
||||
func (me bigIntDistance) IsZero() bool {
|
||||
var zero big.Int
|
||||
return me.Int.Cmp(&zero) == 0
|
||||
}
|
||||
|
||||
type bitCountDistance int
|
||||
|
||||
func (me bitCountDistance) BitCount() int { return int(me) }
|
||||
|
||||
func (me bitCountDistance) Cmp(rhs distance) int {
|
||||
rhs_ := rhs.(bitCountDistance)
|
||||
if me < rhs_ {
|
||||
return -1
|
||||
} else if me == rhs_ {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (me bitCountDistance) IsZero() bool {
|
||||
return me == 0
|
||||
}
|
||||
|
||||
// Below are 2 versions of idDistance. Only one can be active.
|
||||
var maxDistance big.Int
|
||||
|
||||
func init() {
|
||||
@ -1129,67 +1115,24 @@ func init() {
|
||||
maxDistance.SetBit(&zero, 160, 1)
|
||||
}
|
||||
|
||||
// If we don't know the ID for a node, then its distance is more than the
|
||||
// furthest possible distance otherwise.
|
||||
func idDistance(a, b string) (ret bigIntDistance) {
|
||||
if a == "" && b == "" {
|
||||
return
|
||||
}
|
||||
if a == "" {
|
||||
if len(b) != 20 {
|
||||
panic(b)
|
||||
}
|
||||
ret.Set(&maxDistance)
|
||||
return
|
||||
}
|
||||
if b == "" {
|
||||
if len(a) != 20 {
|
||||
panic(a)
|
||||
}
|
||||
ret.Set(&maxDistance)
|
||||
return
|
||||
}
|
||||
if len(a) != 20 {
|
||||
panic(a)
|
||||
}
|
||||
if len(b) != 20 {
|
||||
panic(b)
|
||||
}
|
||||
var x, y big.Int
|
||||
x.SetBytes([]byte(a))
|
||||
y.SetBytes([]byte(b))
|
||||
ret.Int.Xor(&x, &y)
|
||||
return ret
|
||||
}
|
||||
|
||||
// func idDistance(a, b string) bitCountDistance {
|
||||
// ret := 0
|
||||
// for i := 0; i < 20; i++ {
|
||||
// for j := uint(0); j < 8; j++ {
|
||||
// ret += int(a[i]>>j&1 ^ b[i]>>j&1)
|
||||
// }
|
||||
// }
|
||||
// return bitCountDistance(ret)
|
||||
// }
|
||||
|
||||
func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
|
||||
return s.closestNodes(k, targetID, func(n *Node) bool { return n.DefinitelyGood() })
|
||||
return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
|
||||
}
|
||||
|
||||
func (s *Server) closestNodes(k int, targetID string, filter func(*Node) bool) []*Node {
|
||||
sel := newKClosestNodesSelector(k, targetID)
|
||||
func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
|
||||
sel := newKClosestNodesSelector(k, target)
|
||||
idNodes := make(map[string]*Node, len(s.nodes))
|
||||
for _, node := range s.nodes {
|
||||
if !filter(node) {
|
||||
continue
|
||||
}
|
||||
sel.Push(node.id)
|
||||
idNodes[node.id] = node
|
||||
idNodes[node.idString()] = node
|
||||
}
|
||||
ids := sel.IDs()
|
||||
ret := make([]*Node, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
ret = append(ret, idNodes[id])
|
||||
ret = append(ret, idNodes[id.String()])
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -46,29 +46,34 @@ func recoverPanicOrDie(t *testing.T, f func()) {
|
||||
|
||||
const zeroID = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
|
||||
var testIDs = []string{
|
||||
var testIDs []nodeID
|
||||
|
||||
func init() {
|
||||
for _, s := range []string{
|
||||
zeroID,
|
||||
"\x03" + zeroID[1:],
|
||||
"\x03" + zeroID[1:18] + "\x55\xf0",
|
||||
"\x55" + zeroID[1:17] + "\xff\x55\x0f",
|
||||
"\x54" + zeroID[1:18] + "\x50\x0f",
|
||||
"",
|
||||
} {
|
||||
testIDs = append(testIDs, nodeIDFromString(s))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistances(t *testing.T) {
|
||||
if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 {
|
||||
t.FailNow()
|
||||
expectBitcount := func(i big.Int, count int) {
|
||||
if bitCount(i) != count {
|
||||
t.Fatalf("expected bitcount of %d: got %d", count, bitCount(i))
|
||||
}
|
||||
if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 {
|
||||
t.FailNow()
|
||||
}
|
||||
if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 {
|
||||
t.FailNow()
|
||||
}
|
||||
expectBitcount(testIDs[3].Distance(&testIDs[0]), 4+8+4+4)
|
||||
expectBitcount(testIDs[3].Distance(&testIDs[1]), 4+8+4+4)
|
||||
expectBitcount(testIDs[3].Distance(&testIDs[2]), 4+8+8)
|
||||
for i := 0; i < 5; i++ {
|
||||
dist := idDistance(testIDs[i], testIDs[5]).Int
|
||||
dist := testIDs[i].Distance(&testIDs[5])
|
||||
if dist.Cmp(&maxDistance) != 0 {
|
||||
t.FailNow()
|
||||
t.Fatal("expected max distance for comparison with unset node id")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -79,37 +84,6 @@ func TestMaxDistanceString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadIdStrings(t *testing.T) {
|
||||
var a, b string
|
||||
idDistance(a, b)
|
||||
idDistance(a, zeroID)
|
||||
idDistance(zeroID, b)
|
||||
recoverPanicOrDie(t, func() {
|
||||
idDistance("when", a)
|
||||
})
|
||||
recoverPanicOrDie(t, func() {
|
||||
idDistance(a, "bad")
|
||||
})
|
||||
recoverPanicOrDie(t, func() {
|
||||
idDistance("meets", "evil")
|
||||
})
|
||||
for _, id := range testIDs {
|
||||
if !idDistance(id, id).IsZero() {
|
||||
t.Fatal("identical IDs should have distance 0")
|
||||
}
|
||||
}
|
||||
a = "\x03" + zeroID[1:]
|
||||
b = zeroID
|
||||
if idDistance(a, b).BitCount() != 2 {
|
||||
t.FailNow()
|
||||
}
|
||||
a = "\x03" + zeroID[1:18] + "\x55\xf0"
|
||||
b = "\x55" + zeroID[1:17] + "\xff\x55\x0f"
|
||||
if c := idDistance(a, b).BitCount(); c != 20 {
|
||||
t.Fatal(c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClosestNodes(t *testing.T) {
|
||||
cn := newKClosestNodesSelector(2, testIDs[3])
|
||||
for _, i := range rand.Perm(len(testIDs)) {
|
||||
@ -120,9 +94,9 @@ func TestClosestNodes(t *testing.T) {
|
||||
}
|
||||
m := map[string]bool{}
|
||||
for _, id := range cn.IDs() {
|
||||
m[id] = true
|
||||
m[id.String()] = true
|
||||
}
|
||||
if !m[testIDs[3]] || !m[testIDs[4]] {
|
||||
if !m[testIDs[3].String()] || !m[testIDs[4].String()] {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
@ -154,3 +128,28 @@ func TestDHTDefaultConfig(t *testing.T) {
|
||||
}
|
||||
s.Close()
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
srv, err := NewServer(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srv.Close()
|
||||
srv0, err := NewServer(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srv0.Close()
|
||||
tn, err := srv.Ping(&net.UDPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: srv0.LocalAddr().(*net.UDPAddr).Port,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tn.Close()
|
||||
msg := <-tn.Response
|
||||
if msg.ID() != srv0.IDString() {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,12 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"bitbucket.org/anacrolix/go.torrent/util"
|
||||
"bitbucket.org/anacrolix/sync"
|
||||
"github.com/willf/bloom"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type peerDiscovery struct {
|
||||
@ -19,7 +19,7 @@ type peerDiscovery struct {
|
||||
|
||||
func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
|
||||
s.mu.Lock()
|
||||
startAddrs := func() (ret []net.Addr) {
|
||||
startAddrs := func() (ret []dHTAddr) {
|
||||
for _, n := range s.closestGoodNodes(160, infoHash) {
|
||||
ret = append(ret, n.addr)
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
|
||||
return nil, err
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
startAddrs = append(startAddrs, addr)
|
||||
startAddrs = append(startAddrs, newDHTAddr(addr))
|
||||
}
|
||||
}
|
||||
disc := &peerDiscovery{
|
||||
@ -41,7 +41,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
|
||||
stop: make(chan struct{}),
|
||||
values: make(chan peerStreamValue),
|
||||
},
|
||||
triedAddrs: bloom.NewWithEstimates(500000, 0.01),
|
||||
triedAddrs: bloom.NewWithEstimates(10000, 0.01),
|
||||
server: s,
|
||||
infoHash: infoHash,
|
||||
}
|
||||
@ -72,7 +72,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
|
||||
return disc.peerStream, nil
|
||||
}
|
||||
|
||||
func (me *peerDiscovery) gotNodeAddr(addr net.Addr) {
|
||||
func (me *peerDiscovery) gotNodeAddr(addr dHTAddr) {
|
||||
if util.AddrPort(addr) == 0 {
|
||||
// Not a contactable address.
|
||||
return
|
||||
@ -86,7 +86,7 @@ func (me *peerDiscovery) gotNodeAddr(addr net.Addr) {
|
||||
me.contact(addr)
|
||||
}
|
||||
|
||||
func (me *peerDiscovery) contact(addr net.Addr) {
|
||||
func (me *peerDiscovery) contact(addr dHTAddr) {
|
||||
me.triedAddrs.Add([]byte(addr.String()))
|
||||
if err := me.getPeers(addr); err != nil {
|
||||
log.Printf("error sending get_peers request to %s: %s", addr, err)
|
||||
@ -111,7 +111,7 @@ func (me *peerDiscovery) closingCh() chan struct{} {
|
||||
return me.peerStream.stop
|
||||
}
|
||||
|
||||
func (me *peerDiscovery) getPeers(addr net.Addr) error {
|
||||
func (me *peerDiscovery) getPeers(addr dHTAddr) error {
|
||||
me.server.mu.Lock()
|
||||
defer me.server.mu.Unlock()
|
||||
t, err := me.server.getPeers(addr, me.infoHash)
|
||||
|
Loading…
x
Reference in New Issue
Block a user