diff --git a/dht/dht.go b/dht/dht.go index 2b132b6e..372f7848 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -24,6 +24,11 @@ type Server struct { transactionIDInt uint64 nodes map[string]*Node mu sync.Mutex + closed chan struct{} +} + +func (s *Server) String() string { + return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr()) } type Node struct { @@ -55,7 +60,15 @@ type transaction struct { remoteAddr net.Addr t string Response chan Msg - response chan Msg + onResponse func(Msg) +} + +func (t *transaction) handleResponse(m Msg) { + if t.onResponse != nil { + t.onResponse(m) + } + t.Response <- m + close(t.Response) } func (s *Server) setDefaults() (err error) { @@ -91,9 +104,13 @@ func (s *Server) setDefaults() (err error) { return } -func (s *Server) Init() error { - return s.setDefaults() - //s.nodes = make(map[string]*Node) +func (s *Server) Init() (err error) { + err = s.setDefaults() + if err != nil { + return + } + s.closed = make(chan struct{}) + return } func (s *Server) Serve() error { @@ -106,7 +123,7 @@ func (s *Server) Serve() error { var d map[string]interface{} err = bencode.Unmarshal(b[:n], &d) if err != nil { - log.Printf("bad krpc message: %s: %q", err, b[:n]) + log.Printf("%s: received bad krpc message: %s: %q", s, err, b[:n]) continue } s.mu.Lock() @@ -121,7 +138,7 @@ func (s *Server) Serve() error { s.mu.Unlock() continue } - t.response <- d + t.handleResponse(d) s.removeTransaction(t) id := "" if d["y"] == "r" { @@ -143,8 +160,8 @@ func (s *Server) AddNode(ni NodeInfo) { } func (s *Server) handleQuery(source *net.UDPAddr, m Msg) { - log.Print(m["q"]) if m["q"] != "ping" { + log.Printf("%s: not handling received query: q=%s", s, m["q"]) return } s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string)) @@ -264,7 +281,6 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra t: tid, Response: make(chan Msg, 1), } - t.response = t.Response s.addTransaction(t) err = s.writeToNode(b, node) if err != nil { @@ -346,18 +362,54 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { return nil } -func (t *transaction) onResponse(f func(m Msg)) { - ch := make(chan Msg) - t.response = ch - go func() { - d, ok := <-t.response - if !ok { - close(t.Response) +func (t *transaction) setOnResponse(f func(m Msg)) { + if t.onResponse != nil { + panic(t.onResponse) + } + t.onResponse = f +} + +func unmarshalNodeInfoBinary(b []byte) (ret []NodeInfo, err error) { + if len(b)%26 != 0 { + err = errors.New("bad buffer length") + return + } + ret = make([]NodeInfo, 0, len(b)/26) + for i := 0; i < len(b); i += 26 { + var ni NodeInfo + err = ni.UnmarshalCompact(b[i : i+26]) + if err != nil { return } - f(d) - t.Response <- d - }() + ret = append(ret, ni) + } + return +} + +func extractNodes(d Msg) (nodes []NodeInfo, err error) { + if d["y"] != "r" { + return + } + r, ok := d["r"] + if !ok { + err = errors.New("missing r dict") + return + } + rd, ok := r.(map[string]interface{}) + if !ok { + err = errors.New("bad r value type") + return + } + n, ok := rd["nodes"] + if !ok { + return + } + ns, ok := n.(string) + if !ok { + err = errors.New("bad nodes value type") + return + } + return unmarshalNodeInfoBinary([]byte(ns)) } func (s *Server) liftNodes(d Msg) { @@ -369,25 +421,23 @@ func (s *Server) liftNodes(d Msg) { if err != nil { // log.Print(err) } else { - s.mu.Lock() for _, cni := range r.Nodes { n := s.getNode(cni.Addr) n.id = string(cni.ID[:]) } - s.mu.Unlock() // log.Printf("lifted %d nodes", len(r.Nodes)) } } // Sends a find_node query to addr. targetID is the node we're looking for. -func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { +func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { t, err = s.query(addr, "find_node", map[string]string{"target": targetID}) if err != nil { return } // Scrape peers from the response to put in the server's table before // handing the response back to the caller. - t.onResponse(func(d Msg) { + t.setOnResponse(func(d Msg) { s.liftNodes(d) }) return @@ -471,9 +521,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) { case m := <-t.Response: vs := extractValues(m) if vs != nil { - ps.Values <- vs - // } else { - // log.Print("get_peers response had no values") + select { + case ps.Values <- vs: + case <-ps.stop: + } } case <-ps.stop: } @@ -484,7 +535,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) { s.mu.Unlock() go func() { for ; pending > 0; pending-- { - <-done + select { + case <-done: + case <-s.closed: + } } ps.Close() }() @@ -500,7 +554,7 @@ func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, e if err != nil { return } - t.onResponse(func(m Msg) { + t.setOnResponse(func(m Msg) { s.liftNodes(m) }) return @@ -523,24 +577,38 @@ func (s *Server) Bootstrap() (err error) { defer s.mu.Unlock() if len(s.nodes) == 0 { err = s.addRootNode() - if err != nil { - return - } + } + if err != nil { + return } for { + var outstanding sync.WaitGroup for _, node := range s.nodes { var t *transaction - s.mu.Unlock() - t, err = s.FindNode(node.addr, s.ID) - s.mu.Lock() + t, err = s.findNode(node.addr, s.ID) if err != nil { return } + outstanding.Add(1) go func() { <-t.Response + outstanding.Done() }() } - time.Sleep(5 * time.Second) + noOutstanding := make(chan struct{}) + go func() { + outstanding.Wait() + close(noOutstanding) + }() + s.mu.Unlock() + select { + case <-s.closed: + s.mu.Lock() + return + case <-time.After(15 * time.Second): + case <-noOutstanding: + } + s.mu.Lock() log.Printf("now have %d nodes", len(s.nodes)) if len(s.nodes) >= 8*160 { break @@ -569,6 +637,13 @@ func (s *Server) Nodes() (nis []NodeInfo) { func (s *Server) StopServing() { s.Socket.Close() + s.mu.Lock() + select { + case <-s.closed: + default: + close(s.closed) + } + s.mu.Unlock() } func idDistance(a, b string) (ret int) {