Switch dht-server to bootstrapping
This commit is contained in:
parent
7c3d919cfb
commit
1b69e69461
|
@ -0,0 +1,58 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bitbucket.org/anacrolix/go.torrent/dht"
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
type pingResponse struct {
|
||||
addr string
|
||||
krpc dht.Msg
|
||||
}
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
flag.Parse()
|
||||
pingStrAddrs := flag.Args()
|
||||
if len(pingStrAddrs) == 0 {
|
||||
os.Stderr.WriteString("u must specify addrs of nodes to ping e.g. router.bittorrent.com:6881\n")
|
||||
os.Exit(2)
|
||||
}
|
||||
s := dht.Server{}
|
||||
var err error
|
||||
s.Socket, err = net.ListenPacket("udp4", "")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("dht server on %s", s.Socket.LocalAddr())
|
||||
s.Init()
|
||||
go func() {
|
||||
err := s.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
pingResponses := make(chan pingResponse)
|
||||
for _, netloc := range pingStrAddrs {
|
||||
addr, err := net.ResolveUDPAddr("udp4", netloc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
t, err := s.Ping(addr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
go func(addr string) {
|
||||
pingResponses <- pingResponse{
|
||||
addr: addr,
|
||||
krpc: <-t.Response,
|
||||
}
|
||||
}(netloc)
|
||||
}
|
||||
for _ = range pingStrAddrs {
|
||||
log.Print(<-pingResponses)
|
||||
}
|
||||
}
|
|
@ -15,10 +15,11 @@ func main() {
|
|||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
s := dht.Server{}
|
||||
var err error
|
||||
s.Socket, err = net.ListenPacket("udp4", "")
|
||||
s.Socket, err = net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
s.Init()
|
||||
log.Printf("dht server on %s", s.Socket.LocalAddr())
|
||||
go func() {
|
||||
err := s.Serve()
|
||||
|
@ -26,28 +27,9 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
pingResponses := make(chan pingResponse)
|
||||
pingStrAddrs := []string{
|
||||
"router.utorrent.com:6881",
|
||||
"router.bittorrent.com:6881",
|
||||
}
|
||||
for _, netloc := range pingStrAddrs {
|
||||
addr, err := net.ResolveUDPAddr("udp4", netloc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
t, err := s.Ping(addr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
go func(addr string) {
|
||||
pingResponses <- pingResponse{
|
||||
addr: addr,
|
||||
krpc: <-t.Response,
|
||||
}
|
||||
}(netloc)
|
||||
}
|
||||
for _ = range pingStrAddrs {
|
||||
log.Print(<-pingResponses)
|
||||
err = s.Bootstrap()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
select {}
|
||||
}
|
||||
|
|
162
dht/dht.go
162
dht/dht.go
|
@ -3,6 +3,7 @@ package dht
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/nsf/libtorgo/bencode"
|
||||
"io"
|
||||
|
@ -13,14 +14,14 @@ import (
|
|||
|
||||
type Server struct {
|
||||
ID string
|
||||
Socket net.PacketConn
|
||||
Socket *net.UDPConn
|
||||
transactions []*transaction
|
||||
transactionIDInt uint64
|
||||
nodes map[string]*Node
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
addr net.Addr
|
||||
addr *net.UDPAddr
|
||||
id string
|
||||
lastHeardFrom time.Time
|
||||
lastSentTo time.Time
|
||||
|
@ -38,6 +39,27 @@ type transaction struct {
|
|||
remoteAddr net.Addr
|
||||
t string
|
||||
Response chan Msg
|
||||
response chan Msg
|
||||
}
|
||||
|
||||
func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
|
||||
for _, node := range s.nodes {
|
||||
cni := compactNodeInfo{
|
||||
Addr: node.addr,
|
||||
}
|
||||
if n := copy(cni.ID[:], node.id); n != 20 {
|
||||
panic(n)
|
||||
}
|
||||
var b [26]byte
|
||||
cni.PutBinary(b[:])
|
||||
var nn int
|
||||
nn, err = w.Write(b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += nn
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) setDefaults() {
|
||||
|
@ -51,16 +73,15 @@ func (s *Server) setDefaults() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) init() {
|
||||
func (s *Server) Init() {
|
||||
s.setDefaults()
|
||||
s.nodes = make(map[string]*Node, 1000)
|
||||
}
|
||||
|
||||
func (s *Server) Serve() error {
|
||||
s.setDefaults()
|
||||
s.init()
|
||||
for {
|
||||
var b [1500]byte
|
||||
n, addr, err := s.Socket.ReadFrom(b[:])
|
||||
n, addr, err := s.Socket.ReadFromUDP(b[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -71,7 +92,7 @@ func (s *Server) Serve() error {
|
|||
continue
|
||||
}
|
||||
t := s.findResponseTransaction(d["t"].(string), addr)
|
||||
t.Response <- d
|
||||
t.response <- d
|
||||
s.removeTransaction(t)
|
||||
id := ""
|
||||
if d["y"] == "r" {
|
||||
|
@ -81,13 +102,13 @@ func (s *Server) Serve() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) heardFromNode(addr net.Addr, id string) {
|
||||
func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
|
||||
n := s.getNode(addr)
|
||||
n.id = id
|
||||
n.lastHeardFrom = time.Now()
|
||||
}
|
||||
|
||||
func (s *Server) getNode(addr net.Addr) (n *Node) {
|
||||
func (s *Server) getNode(addr *net.UDPAddr) (n *Node) {
|
||||
n = s.nodes[addr.String()]
|
||||
if n == nil {
|
||||
n = &Node{
|
||||
|
@ -98,7 +119,7 @@ func (s *Server) getNode(addr net.Addr) (n *Node) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) sentToNode(addr net.Addr) {
|
||||
func (s *Server) sentToNode(addr *net.UDPAddr) {
|
||||
n := s.getNode(addr)
|
||||
n.lastSentTo = time.Now()
|
||||
}
|
||||
|
@ -142,7 +163,7 @@ func (s *Server) IDString() string {
|
|||
return s.ID
|
||||
}
|
||||
|
||||
func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transaction, err error) {
|
||||
func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *transaction, err error) {
|
||||
tid := s.nextTransactionID()
|
||||
if a == nil {
|
||||
a = make(map[string]string, 1)
|
||||
|
@ -163,6 +184,7 @@ func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transac
|
|||
t: tid,
|
||||
Response: make(chan Msg, 1),
|
||||
}
|
||||
t.response = t.Response
|
||||
s.addTransaction(t)
|
||||
n, err := s.Socket.WriteTo(b, node)
|
||||
if err != nil {
|
||||
|
@ -178,10 +200,124 @@ func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transac
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) GetPeers(node *net.UDPAddr, targetInfoHash [20]byte) {
|
||||
const compactAddrInfoLen = 26
|
||||
|
||||
type compactAddrInfo *net.UDPAddr
|
||||
|
||||
type compactNodeInfo struct {
|
||||
ID [20]byte
|
||||
Addr compactAddrInfo
|
||||
}
|
||||
|
||||
func (s *Server) Ping(node net.Addr) (*transaction, error) {
|
||||
func (cni *compactNodeInfo) PutBinary(b []byte) {
|
||||
if n := copy(b[:], cni.ID[:]); n != 20 {
|
||||
panic(n)
|
||||
}
|
||||
ip := cni.Addr.IP.To4()
|
||||
if len(ip) != 4 {
|
||||
panic(ip)
|
||||
}
|
||||
if n := copy(b[20:], ip); n != 4 {
|
||||
panic(n)
|
||||
}
|
||||
binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port))
|
||||
}
|
||||
|
||||
func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error {
|
||||
if len(b) != 26 {
|
||||
return errors.New("expected 26 bytes")
|
||||
}
|
||||
if 20 != copy(cni.ID[:], b[:20]) {
|
||||
panic("impossibru!")
|
||||
}
|
||||
if cni.Addr == nil {
|
||||
cni.Addr = &net.UDPAddr{}
|
||||
}
|
||||
cni.Addr.IP = net.IPv4(b[20], b[21], b[22], b[23])
|
||||
cni.Addr.Port = int(binary.BigEndian.Uint16(b[24:26]))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
|
||||
return s.query(node, "ping", nil)
|
||||
}
|
||||
|
||||
type findNodeResponse struct {
|
||||
Nodes []compactNodeInfo
|
||||
}
|
||||
|
||||
func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
||||
b := m["r"].(map[string]interface{})["nodes"].(string)
|
||||
log.Printf("%q", b)
|
||||
for i := 0; i < len(b); i += 26 {
|
||||
var n compactNodeInfo
|
||||
err := n.UnmarshalBinary([]byte(b[i : i+26]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
me.Nodes = append(me.Nodes, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||
log.Print(addr)
|
||||
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch := make(chan Msg)
|
||||
t.response = ch
|
||||
go func() {
|
||||
d, ok := <-t.response
|
||||
if !ok {
|
||||
close(t.Response)
|
||||
return
|
||||
}
|
||||
if d["y"] == "r" {
|
||||
var r findNodeResponse
|
||||
err = r.UnmarshalKRPCMsg(d)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
} else {
|
||||
for _, cni := range r.Nodes {
|
||||
n := s.getNode(cni.Addr)
|
||||
n.id = string(cni.ID[:])
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Response <- d
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) Bootstrap() error {
|
||||
if len(s.nodes) == 0 {
|
||||
addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.nodes[addr.String()] = &Node{
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
queriedNodes := make(map[string]bool, 1000)
|
||||
for {
|
||||
for _, node := range s.nodes {
|
||||
if queriedNodes[node.addr.String()] {
|
||||
log.Printf("skipping already queried: %s", node.addr)
|
||||
continue
|
||||
}
|
||||
t, err := s.FindNode(node.addr, s.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
queriedNodes[node.addr.String()] = true
|
||||
go func() {
|
||||
log.Print(<-t.Response)
|
||||
}()
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalCompactNodeInfo(t *testing.T) {
|
||||
cni := compactNodeInfo{
|
||||
ID: [20]byte{'a', 'b', 'c'},
|
||||
}
|
||||
var err error
|
||||
cni.Addr, err = net.ResolveUDPAddr("udp4", "1.2.3.4:5")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var b [compactAddrInfoLen]byte
|
||||
cni.PutBinary(b[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var bb [26]byte
|
||||
copy(bb[:], []byte("abc"))
|
||||
copy(bb[20:], []byte("\x01\x02\x03\x04\x00\x05"))
|
||||
if b != bb {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue