dht: Concurrency improvements and fixes to bootstrapping and getting peers
This commit is contained in:
parent
ae45175015
commit
ba83f65ddf
143
dht/dht.go
143
dht/dht.go
|
@ -24,6 +24,11 @@ type Server struct {
|
|||
transactionIDInt uint64
|
||||
nodes map[string]*Node
|
||||
mu sync.Mutex
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (s *Server) String() string {
|
||||
return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr())
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
|
@ -55,7 +60,15 @@ type transaction struct {
|
|||
remoteAddr net.Addr
|
||||
t string
|
||||
Response chan Msg
|
||||
response chan Msg
|
||||
onResponse func(Msg)
|
||||
}
|
||||
|
||||
func (t *transaction) handleResponse(m Msg) {
|
||||
if t.onResponse != nil {
|
||||
t.onResponse(m)
|
||||
}
|
||||
t.Response <- m
|
||||
close(t.Response)
|
||||
}
|
||||
|
||||
func (s *Server) setDefaults() (err error) {
|
||||
|
@ -91,9 +104,13 @@ func (s *Server) setDefaults() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) Init() error {
|
||||
return s.setDefaults()
|
||||
//s.nodes = make(map[string]*Node)
|
||||
func (s *Server) Init() (err error) {
|
||||
err = s.setDefaults()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.closed = make(chan struct{})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) Serve() error {
|
||||
|
@ -106,7 +123,7 @@ func (s *Server) Serve() error {
|
|||
var d map[string]interface{}
|
||||
err = bencode.Unmarshal(b[:n], &d)
|
||||
if err != nil {
|
||||
log.Printf("bad krpc message: %s: %q", err, b[:n])
|
||||
log.Printf("%s: received bad krpc message: %s: %q", s, err, b[:n])
|
||||
continue
|
||||
}
|
||||
s.mu.Lock()
|
||||
|
@ -121,7 +138,7 @@ func (s *Server) Serve() error {
|
|||
s.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
t.response <- d
|
||||
t.handleResponse(d)
|
||||
s.removeTransaction(t)
|
||||
id := ""
|
||||
if d["y"] == "r" {
|
||||
|
@ -143,8 +160,8 @@ func (s *Server) AddNode(ni NodeInfo) {
|
|||
}
|
||||
|
||||
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
|
||||
log.Print(m["q"])
|
||||
if m["q"] != "ping" {
|
||||
log.Printf("%s: not handling received query: q=%s", s, m["q"])
|
||||
return
|
||||
}
|
||||
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
|
||||
|
@ -264,7 +281,6 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
|
|||
t: tid,
|
||||
Response: make(chan Msg, 1),
|
||||
}
|
||||
t.response = t.Response
|
||||
s.addTransaction(t)
|
||||
err = s.writeToNode(b, node)
|
||||
if err != nil {
|
||||
|
@ -346,18 +362,54 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *transaction) onResponse(f func(m Msg)) {
|
||||
ch := make(chan Msg)
|
||||
t.response = ch
|
||||
go func() {
|
||||
d, ok := <-t.response
|
||||
if !ok {
|
||||
close(t.Response)
|
||||
func (t *transaction) setOnResponse(f func(m Msg)) {
|
||||
if t.onResponse != nil {
|
||||
panic(t.onResponse)
|
||||
}
|
||||
t.onResponse = f
|
||||
}
|
||||
|
||||
func unmarshalNodeInfoBinary(b []byte) (ret []NodeInfo, err error) {
|
||||
if len(b)%26 != 0 {
|
||||
err = errors.New("bad buffer length")
|
||||
return
|
||||
}
|
||||
ret = make([]NodeInfo, 0, len(b)/26)
|
||||
for i := 0; i < len(b); i += 26 {
|
||||
var ni NodeInfo
|
||||
err = ni.UnmarshalCompact(b[i : i+26])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
f(d)
|
||||
t.Response <- d
|
||||
}()
|
||||
ret = append(ret, ni)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func extractNodes(d Msg) (nodes []NodeInfo, err error) {
|
||||
if d["y"] != "r" {
|
||||
return
|
||||
}
|
||||
r, ok := d["r"]
|
||||
if !ok {
|
||||
err = errors.New("missing r dict")
|
||||
return
|
||||
}
|
||||
rd, ok := r.(map[string]interface{})
|
||||
if !ok {
|
||||
err = errors.New("bad r value type")
|
||||
return
|
||||
}
|
||||
n, ok := rd["nodes"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ns, ok := n.(string)
|
||||
if !ok {
|
||||
err = errors.New("bad nodes value type")
|
||||
return
|
||||
}
|
||||
return unmarshalNodeInfoBinary([]byte(ns))
|
||||
}
|
||||
|
||||
func (s *Server) liftNodes(d Msg) {
|
||||
|
@ -369,25 +421,23 @@ func (s *Server) liftNodes(d Msg) {
|
|||
if err != nil {
|
||||
// log.Print(err)
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
for _, cni := range r.Nodes {
|
||||
n := s.getNode(cni.Addr)
|
||||
n.id = string(cni.ID[:])
|
||||
}
|
||||
s.mu.Unlock()
|
||||
// log.Printf("lifted %d nodes", len(r.Nodes))
|
||||
}
|
||||
}
|
||||
|
||||
// Sends a find_node query to addr. targetID is the node we're looking for.
|
||||
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||
func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
|
||||
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Scrape peers from the response to put in the server's table before
|
||||
// handing the response back to the caller.
|
||||
t.onResponse(func(d Msg) {
|
||||
t.setOnResponse(func(d Msg) {
|
||||
s.liftNodes(d)
|
||||
})
|
||||
return
|
||||
|
@ -471,9 +521,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
|
|||
case m := <-t.Response:
|
||||
vs := extractValues(m)
|
||||
if vs != nil {
|
||||
ps.Values <- vs
|
||||
// } else {
|
||||
// log.Print("get_peers response had no values")
|
||||
select {
|
||||
case ps.Values <- vs:
|
||||
case <-ps.stop:
|
||||
}
|
||||
}
|
||||
case <-ps.stop:
|
||||
}
|
||||
|
@ -484,7 +535,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
|
|||
s.mu.Unlock()
|
||||
go func() {
|
||||
for ; pending > 0; pending-- {
|
||||
<-done
|
||||
select {
|
||||
case <-done:
|
||||
case <-s.closed:
|
||||
}
|
||||
}
|
||||
ps.Close()
|
||||
}()
|
||||
|
@ -500,7 +554,7 @@ func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, e
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.onResponse(func(m Msg) {
|
||||
t.setOnResponse(func(m Msg) {
|
||||
s.liftNodes(m)
|
||||
})
|
||||
return
|
||||
|
@ -523,24 +577,38 @@ func (s *Server) Bootstrap() (err error) {
|
|||
defer s.mu.Unlock()
|
||||
if len(s.nodes) == 0 {
|
||||
err = s.addRootNode()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
var outstanding sync.WaitGroup
|
||||
for _, node := range s.nodes {
|
||||
var t *transaction
|
||||
s.mu.Unlock()
|
||||
t, err = s.FindNode(node.addr, s.ID)
|
||||
s.mu.Lock()
|
||||
t, err = s.findNode(node.addr, s.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
outstanding.Add(1)
|
||||
go func() {
|
||||
<-t.Response
|
||||
outstanding.Done()
|
||||
}()
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
noOutstanding := make(chan struct{})
|
||||
go func() {
|
||||
outstanding.Wait()
|
||||
close(noOutstanding)
|
||||
}()
|
||||
s.mu.Unlock()
|
||||
select {
|
||||
case <-s.closed:
|
||||
s.mu.Lock()
|
||||
return
|
||||
case <-time.After(15 * time.Second):
|
||||
case <-noOutstanding:
|
||||
}
|
||||
s.mu.Lock()
|
||||
log.Printf("now have %d nodes", len(s.nodes))
|
||||
if len(s.nodes) >= 8*160 {
|
||||
break
|
||||
|
@ -569,6 +637,13 @@ func (s *Server) Nodes() (nis []NodeInfo) {
|
|||
|
||||
func (s *Server) StopServing() {
|
||||
s.Socket.Close()
|
||||
s.mu.Lock()
|
||||
select {
|
||||
case <-s.closed:
|
||||
default:
|
||||
close(s.closed)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func idDistance(a, b string) (ret int) {
|
||||
|
|
Loading…
Reference in New Issue