dht: Concurrency improvements and fixes to bootstrapping and getting peers

Matt Joiner 2014-07-10 00:13:54 +10:00
parent ae45175015
commit ba83f65ddf
1 changed file with 109 additions and 34 deletions


@@ -24,6 +24,11 @@ type Server struct {
transactionIDInt uint64
nodes map[string]*Node
mu sync.Mutex
closed chan struct{}
}
func (s *Server) String() string {
return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr())
}
type Node struct {
@@ -55,7 +60,15 @@ type transaction struct {
remoteAddr net.Addr
t string
Response chan Msg
response chan Msg
onResponse func(Msg)
}
func (t *transaction) handleResponse(m Msg) {
if t.onResponse != nil {
t.onResponse(m)
}
t.Response <- m
close(t.Response)
}
func (s *Server) setDefaults() (err error) {
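The new handleResponse funnels every reply through an optional per-transaction hook before delivering it on the public Response channel and closing that channel. A minimal standalone sketch of the pattern, assuming a string message type and hypothetical names (reply, hook, out) rather than the package's own:

package main

import "fmt"

type reply struct {
	hook func(string) // optional side effect, runs before delivery
	out  chan string  // public channel the caller reads from
}

func (r *reply) handle(m string) {
	if r.hook != nil {
		r.hook(m) // e.g. lift nodes from the response into a routing table
	}
	r.out <- m
	close(r.out) // exactly one response per transaction
}

func main() {
	r := &reply{out: make(chan string, 1)}
	r.hook = func(m string) { fmt.Println("hook saw:", m) }
	r.handle("pong")
	fmt.Println("caller got:", <-r.out)
}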
@@ -91,9 +104,13 @@ func (s *Server) setDefaults() (err error) {
return
}
func (s *Server) Init() error {
return s.setDefaults()
//s.nodes = make(map[string]*Node)
func (s *Server) Init() (err error) {
err = s.setDefaults()
if err != nil {
return
}
s.closed = make(chan struct{})
return
}
func (s *Server) Serve() error {
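Init now allocates s.closed, a channel that is only ever closed, never sent on; closing it later lets any number of goroutines notice shutdown from inside a select. A minimal sketch of that broadcast idiom, with illustrative names rather than the package's own:

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	closed := make(chan struct{}) // closed once to signal shutdown to everyone
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			select {
			case <-time.After(time.Minute): // stand-in for real work
			case <-closed: // unblocks as soon as the channel is closed
				fmt.Println("worker", i, "stopping")
			}
		}(i)
	}
	close(closed) // one close wakes every waiter
	wg.Wait()
}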
@@ -106,7 +123,7 @@ func (s *Server) Serve() error {
var d map[string]interface{}
err = bencode.Unmarshal(b[:n], &d)
if err != nil {
log.Printf("bad krpc message: %s: %q", err, b[:n])
log.Printf("%s: received bad krpc message: %s: %q", s, err, b[:n])
continue
}
s.mu.Lock()
@@ -121,7 +138,7 @@ func (s *Server) Serve() error {
s.mu.Unlock()
continue
}
t.response <- d
t.handleResponse(d)
s.removeTransaction(t)
id := ""
if d["y"] == "r" {
@@ -143,8 +160,8 @@ func (s *Server) AddNode(ni NodeInfo) {
}
func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
log.Print(m["q"])
if m["q"] != "ping" {
log.Printf("%s: not handling received query: q=%s", s, m["q"])
return
}
s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
@@ -264,7 +281,6 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
t: tid,
Response: make(chan Msg, 1),
}
t.response = t.Response
s.addTransaction(t)
err = s.writeToNode(b, node)
if err != nil {
@@ -346,18 +362,54 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
return nil
}
func (t *transaction) onResponse(f func(m Msg)) {
ch := make(chan Msg)
t.response = ch
go func() {
d, ok := <-t.response
if !ok {
close(t.Response)
func (t *transaction) setOnResponse(f func(m Msg)) {
if t.onResponse != nil {
panic(t.onResponse)
}
t.onResponse = f
}
func unmarshalNodeInfoBinary(b []byte) (ret []NodeInfo, err error) {
if len(b)%26 != 0 {
err = errors.New("bad buffer length")
return
}
ret = make([]NodeInfo, 0, len(b)/26)
for i := 0; i < len(b); i += 26 {
var ni NodeInfo
err = ni.UnmarshalCompact(b[i : i+26])
if err != nil {
return
}
f(d)
t.Response <- d
}()
ret = append(ret, ni)
}
return
}
func extractNodes(d Msg) (nodes []NodeInfo, err error) {
if d["y"] != "r" {
return
}
r, ok := d["r"]
if !ok {
err = errors.New("missing r dict")
return
}
rd, ok := r.(map[string]interface{})
if !ok {
err = errors.New("bad r value type")
return
}
n, ok := rd["nodes"]
if !ok {
return
}
ns, ok := n.(string)
if !ok {
err = errors.New("bad nodes value type")
return
}
return unmarshalNodeInfoBinary([]byte(ns))
}
func (s *Server) liftNodes(d Msg) {
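unmarshalNodeInfoBinary walks the "nodes" string in 26-byte strides, the compact node info layout from BEP 5: a 20-byte node ID, a 4-byte IPv4 address, and a 2-byte big-endian port. A standalone sketch of decoding one such entry, using illustrative types rather than the package's NodeInfo:

package main

import (
	"encoding/binary"
	"fmt"
	"net"
)

// decodeCompactNode splits one 26-byte compact entry into an ID and a UDP address.
func decodeCompactNode(b []byte) (id [20]byte, addr *net.UDPAddr, err error) {
	if len(b) != 26 {
		return id, nil, fmt.Errorf("want 26 bytes, got %d", len(b))
	}
	copy(id[:], b[:20])
	addr = &net.UDPAddr{
		IP:   net.IP(b[20:24]),                       // 4-byte IPv4 address
		Port: int(binary.BigEndian.Uint16(b[24:26])), // port in network byte order
	}
	return id, addr, nil
}

func main() {
	raw := append(make([]byte, 20), 1, 2, 3, 4, 0x1a, 0xe1) // zero ID, then 1.2.3.4:6881
	id, addr, _ := decodeCompactNode(raw)
	fmt.Printf("id=%x addr=%v\n", id, addr)
}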
@@ -369,25 +421,23 @@ func (s *Server) liftNodes(d Msg) {
if err != nil {
// log.Print(err)
} else {
s.mu.Lock()
for _, cni := range r.Nodes {
n := s.getNode(cni.Addr)
n.id = string(cni.ID[:])
}
s.mu.Unlock()
// log.Printf("lifted %d nodes", len(r.Nodes))
}
}
// Sends a find_node query to addr. targetID is the node we're looking for.
func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
if err != nil {
return
}
// Scrape peers from the response to put in the server's table before
// handing the response back to the caller.
t.onResponse(func(d Msg) {
t.setOnResponse(func(d Msg) {
s.liftNodes(d)
})
return
@@ -471,9 +521,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
case m := <-t.Response:
vs := extractValues(m)
if vs != nil {
ps.Values <- vs
// } else {
// log.Print("get_peers response had no values")
select {
case ps.Values <- vs:
case <-ps.stop:
}
}
case <-ps.stop:
}
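Sending the extracted values through a select on ps.stop means the goroutine handling a get_peers response can no longer block forever once the consumer of the peer stream has stopped reading. A minimal sketch of that cancellable-send pattern, with illustrative channel names:

package main

import "fmt"

// trySend delivers v on out unless stop is closed first; it reports whether
// the value was actually delivered.
func trySend(out chan<- int, stop <-chan struct{}, v int) bool {
	select {
	case out <- v:
		return true
	case <-stop:
		return false // consumer is gone; drop the value instead of blocking
	}
}

func main() {
	out := make(chan int) // unbuffered: a plain send would block
	stop := make(chan struct{})
	close(stop)                         // simulate a consumer that already stopped
	fmt.Println(trySend(out, stop, 42)) // false
}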
@@ -484,7 +535,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
s.mu.Unlock()
go func() {
for ; pending > 0; pending-- {
<-done
select {
case <-done:
case <-s.closed:
}
}
ps.Close()
}()
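The goroutine that closes the peer stream still counts down the pending queries, but each wait is now a select that also watches s.closed, so shutting the server down releases it even if some transactions never answer. A small sketch of that drain-or-shutdown loop, with illustrative names:

package main

import "fmt"

// drain waits for pending completion signals on done, but gives up early
// once shutdown is closed, then runs finish exactly once.
func drain(pending int, done <-chan struct{}, shutdown <-chan struct{}, finish func()) {
	for ; pending > 0; pending-- {
		select {
		case <-done:
		case <-shutdown:
		}
	}
	finish()
}

func main() {
	done := make(chan struct{}, 3)
	shutdown := make(chan struct{})
	for i := 0; i < 2; i++ {
		done <- struct{}{} // two queries completed normally
	}
	close(shutdown) // the third never answers; shutdown releases the loop
	drain(3, done, shutdown, func() { fmt.Println("stream closed") })
}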
@@ -500,7 +554,7 @@ func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, e
if err != nil {
return
}
t.onResponse(func(m Msg) {
t.setOnResponse(func(m Msg) {
s.liftNodes(m)
})
return
@@ -523,24 +577,38 @@ func (s *Server) Bootstrap() (err error) {
defer s.mu.Unlock()
if len(s.nodes) == 0 {
err = s.addRootNode()
if err != nil {
return
}
}
if err != nil {
return
}
for {
var outstanding sync.WaitGroup
for _, node := range s.nodes {
var t *transaction
s.mu.Unlock()
t, err = s.FindNode(node.addr, s.ID)
s.mu.Lock()
t, err = s.findNode(node.addr, s.ID)
if err != nil {
return
}
outstanding.Add(1)
go func() {
<-t.Response
outstanding.Done()
}()
}
time.Sleep(5 * time.Second)
noOutstanding := make(chan struct{})
go func() {
outstanding.Wait()
close(noOutstanding)
}()
s.mu.Unlock()
select {
case <-s.closed:
s.mu.Lock()
return
case <-time.After(15 * time.Second):
case <-noOutstanding:
}
s.mu.Lock()
log.Printf("now have %d nodes", len(s.nodes))
if len(s.nodes) >= 8*160 {
break
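Bootstrap replaces the fixed five-second sleep with a select over the shutdown channel, a fifteen-second timeout, and a channel derived from the WaitGroup of outstanding queries, since WaitGroup.Wait cannot be used in a select directly. A minimal sketch of that WaitGroup-to-channel adapter, with illustrative names:

package main

import (
	"fmt"
	"sync"
	"time"
)

// waitChan adapts wg.Wait() into a channel so callers can select on it.
func waitChan(wg *sync.WaitGroup) <-chan struct{} {
	ch := make(chan struct{})
	go func() {
		wg.Wait()
		close(ch)
	}()
	return ch
}

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	go func() { defer wg.Done() }() // stand-in for an in-flight query

	select {
	case <-waitChan(&wg):
		fmt.Println("all queries answered")
	case <-time.After(15 * time.Second): // periodic re-check, as in Bootstrap
		fmt.Println("timed out, looping again")
	}
}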
@@ -569,6 +637,13 @@ func (s *Server) Nodes() (nis []NodeInfo) {
func (s *Server) StopServing() {
s.Socket.Close()
s.mu.Lock()
select {
case <-s.closed:
default:
close(s.closed)
}
s.mu.Unlock()
}
func idDistance(a, b string) (ret int) {
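StopServing wraps the close of s.closed in a select with a default case so that a second call finds the channel already closed and does nothing; closing an already-closed channel would panic. A minimal sketch of that idempotent-close idiom behind a mutex, using a hypothetical wrapper type:

package main

import (
	"fmt"
	"sync"
)

type stopper struct {
	mu     sync.Mutex
	closed chan struct{}
}

func newStopper() *stopper {
	return &stopper{closed: make(chan struct{})}
}

// Stop is safe to call any number of times; only the first call closes.
func (s *stopper) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()
	select {
	case <-s.closed:
		// already closed: nothing to do
	default:
		close(s.closed)
	}
}

func main() {
	s := newStopper()
	s.Stop()
	s.Stop() // no panic on the second call
	fmt.Println("stopped")
}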