diff --git a/dht/dht.go b/dht/dht.go index 89b9ee24..836f387f 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -232,7 +232,7 @@ type transaction struct { remoteAddr dHTAddr t string Response chan Msg - onResponse func(Msg) + onResponse func(Msg) // Called with the server locked. done chan struct{} queryPacket []byte timer *time.Timer @@ -326,7 +326,9 @@ func (t *transaction) handleResponse(m Msg) { close(t.done) t.mu.Unlock() if t.onResponse != nil { + t.s.mu.Lock() t.onResponse(m) + t.s.mu.Unlock() } t.queryPacket = nil select { diff --git a/dht/getpeers.go b/dht/getpeers.go index 38295210..6654636b 100644 --- a/dht/getpeers.go +++ b/dht/getpeers.go @@ -4,24 +4,23 @@ import ( "log" "net" "sync" - "time" "bitbucket.org/anacrolix/go.torrent/util" ) type peerDiscovery struct { *peerStream - triedAddrs map[string]struct{} - contactAddrs chan net.Addr - pending int - transactionClosed chan struct{} - server *Server - infoHash string + triedAddrs map[string]struct{} + backlog map[string]net.Addr + pending int + server *Server + infoHash string } +const parallelQueries = 100 + func (me *peerDiscovery) Close() { me.peerStream.Close() - close(me.contactAddrs) } func (s *Server) GetPeers(infoHash string) (*peerStream, error) { @@ -45,63 +44,67 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) { Values: make(chan peerStreamValue), stop: make(chan struct{}), }, - triedAddrs: make(map[string]struct{}, 500), - contactAddrs: make(chan net.Addr), - transactionClosed: make(chan struct{}), - server: s, - infoHash: infoHash, + triedAddrs: make(map[string]struct{}, 500), + backlog: make(map[string]net.Addr, parallelQueries), + server: s, + infoHash: infoHash, } - go disc.loop() + disc.mu.Lock() for _, addr := range startAddrs { disc.contact(addr) } + disc.mu.Unlock() return disc.peerStream, nil } +func (me *peerDiscovery) gotNodeAddr(addr net.Addr) { + if util.AddrPort(addr) == 0 { + // Not a contactable address. + return + } + if me.server.ipBlocked(util.AddrIP(addr)) { + return + } + if _, ok := me.triedAddrs[addr.String()]; ok { + return + } + if _, ok := me.backlog[addr.String()]; ok { + return + } + if me.pending >= parallelQueries { + me.backlog[addr.String()] = addr + } else { + me.contact(addr) + } +} + func (me *peerDiscovery) contact(addr net.Addr) { - select { - case me.contactAddrs <- addr: - case <-me.closingCh(): + me.triedAddrs[addr.String()] = struct{}{} + if err := me.getPeers(addr); err != nil { + log.Printf("error sending get_peers request to %s: %s", addr, err) + return + } + me.pending++ +} + +func (me *peerDiscovery) transactionClosed() { + me.pending-- + // log.Printf("pending: %d", me.pending) + for key, addr := range me.backlog { + if me.pending >= parallelQueries { + break + } + delete(me.backlog, key) + me.contact(addr) + } + if me.pending == 0 { + me.Close() + return } } func (me *peerDiscovery) responseNode(node NodeInfo) { - if util.AddrPort(node.Addr) == 0 { - // Not a contactable address. - return - } - me.contact(node.Addr) -} - -func (me *peerDiscovery) loop() { - for { - select { - case addr := <-me.contactAddrs: - if me.pending >= 1000 { - break - } - if _, ok := me.triedAddrs[addr.String()]; ok { - break - } - me.triedAddrs[addr.String()] = struct{}{} - if me.server.ipBlocked(util.AddrIP(addr)) { - break - } - if err := me.getPeers(addr); err != nil { - log.Printf("error sending get_peers request to %s: %s", addr, err) - break - } - // log.Printf("contacting %s", addr) - me.pending++ - case <-me.transactionClosed: - me.pending-- - // log.Printf("pending: %d", me.pending) - if me.pending == 0 { - me.Close() - return - } - } - } + me.gotNodeAddr(node.Addr) } func (me *peerDiscovery) closingCh() chan struct{} { @@ -118,11 +121,13 @@ func (me *peerDiscovery) getPeers(addr net.Addr) error { go func() { select { case m := <-t.Response: + me.mu.Lock() if nodes := m.Nodes(); len(nodes) != 0 { for _, n := range nodes { me.responseNode(n) } } + me.mu.Unlock() if vs := extractValues(m); vs != nil { nodeInfo := NodeInfo{ Addr: t.remoteAddr, @@ -145,7 +150,9 @@ func (me *peerDiscovery) getPeers(addr net.Addr) error { case <-me.closingCh(): } t.Close() - me.transactionClosed <- struct{}{} + me.mu.Lock() + me.transactionClosed() + me.mu.Unlock() }() return nil }