diff --git a/dht/dht.go b/dht/dht.go index 827a6046..65447144 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -9,6 +9,7 @@ import ( "io" "log" "math/big" + "math/rand" "net" "os" "sync" @@ -127,9 +128,11 @@ func (s *Server) String() string { type Node struct { addr dHTAddr id string - lastHeardFrom time.Time - lastSentTo time.Time announceToken string + + lastGotQuery time.Time + lastGotResponse time.Time + lastSentQuery time.Time } func (n *Node) NodeInfo() (ret NodeInfo) { @@ -144,13 +147,15 @@ func (n *Node) Good() bool { if len(n.id) != 20 { return false } - if n.lastSentTo.IsZero() { + // No reason to think ill of them if they've never responded. + if n.lastSentQuery.IsZero() { return true } - if n.lastSentTo.Before(n.lastHeardFrom) { + // They answered our last query. + if n.lastSentQuery.Before(n.lastGotResponse) { return true } - if time.Now().Sub(n.lastHeardFrom) >= 1*time.Minute { + if time.Now().Sub(n.lastSentQuery) >= 2*time.Minute { return false } return true @@ -217,16 +222,63 @@ func (m Msg) AnnounceToken() string { } type transaction struct { - mu sync.Mutex - remoteAddr dHTAddr - t string - Response chan Msg - onResponse func(Msg) - done chan struct{} + mu sync.Mutex + remoteAddr dHTAddr + t string + Response chan Msg + onResponse func(Msg) + done chan struct{} + queryPacket []byte + timer *time.Timer + s *Server + retries int +} + +func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration { + return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus))) +} + +func (t *transaction) startTimer() { + t.timer = time.AfterFunc(jitterDuration(20*time.Second, time.Second), t.timerCallback) +} + +func (t *transaction) timerCallback() { + t.mu.Lock() + defer t.mu.Unlock() + select { + case <-t.done: + return + default: + } + if t.retries == 2 { + t.timeout() + return + } + t.retries++ + t.sendQuery() + if t.timer.Reset(jitterDuration(20*time.Second, time.Second)) { + panic("timer should have fired to get here") + } +} + +func (t *transaction) sendQuery() error { + return t.s.writeToNode(t.queryPacket, t.remoteAddr) } func (t *transaction) timeout() { - t.Close() + t.close() +} + +func (t *transaction) close() { + if t.closing() { + return + } + close(t.Response) + close(t.done) + t.timer.Stop() + t.s.mu.Lock() + defer t.s.mu.Unlock() + t.s.removeTransaction(t) } func (t *transaction) closing() bool { @@ -241,11 +293,7 @@ func (t *transaction) closing() bool { func (t *transaction) Close() { t.mu.Lock() defer t.mu.Unlock() - if t.closing() { - return - } - close(t.Response) - close(t.done) + t.close() } func (t *transaction) handleResponse(m Msg) { @@ -338,13 +386,9 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) { //log.Printf("unexpected message: %#v", d) return } + s.getNode(addr).lastGotResponse = time.Now() t.handleResponse(d) s.removeTransaction(t) - id := "" - if d["y"] == "r" { - id = d["r"].(map[string]interface{})["id"].(string) - } - s.heardFromNode(addr, id) } func (s *Server) serve() error { @@ -392,7 +436,10 @@ func (s *Server) nodeByID(id string) *Node { func (s *Server) handleQuery(source dHTAddr, m Msg) { args := m["a"].(map[string]interface{}) - s.heardFromNode(source, args["id"].(string)) + node := s.getNode(source) + node.id = args["id"].(string) + node.lastGotQuery = time.Now() + // Don't respond. if s.passive { return } @@ -472,14 +519,6 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) { } } -func (s *Server) heardFromNode(addr dHTAddr, id string) { - n := s.getNode(addr) - if len(id) == 20 { - n.id = id - } - n.lastHeardFrom = time.Now() -} - func (s *Server) getNode(addr dHTAddr) (n *Node) { n = s.nodes[addr.String()] if n == nil { @@ -507,15 +546,9 @@ func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) { err = io.ErrShortWrite return } - s.sentToNode(node) return } -func (s *Server) sentToNode(addr dHTAddr) { - n := s.getNode(addr) - n.lastSentTo = time.Now() -} - func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction { for _, t := range s.transactions { if t.t == transactionID && t.remoteAddr.String() == sourceNode.String() { @@ -555,23 +588,6 @@ func (s *Server) IDString() string { return s.id } -func (s *Server) timeoutTransaction(t *transaction) { - select { - case <-t.done: - return - case <-time.After(time.Minute): - } - s.mu.Lock() - defer s.mu.Unlock() - select { - case <-t.done: - return - default: - } - t.timeout() - s.removeTransaction(t) -} - func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) { tid := s.nextTransactionID() if a == nil { @@ -589,18 +605,20 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *tra return } t = &transaction{ - remoteAddr: node, - t: tid, - Response: make(chan Msg, 1), - done: make(chan struct{}), + remoteAddr: node, + t: tid, + Response: make(chan Msg, 1), + done: make(chan struct{}), + queryPacket: b, + s: s, } - s.addTransaction(t) - err = s.writeToNode(b, node) + err = t.sendQuery() if err != nil { - s.removeTransaction(t) return } - go s.timeoutTransaction(t) + s.getNode(node).lastSentQuery = time.Now() + t.startTimer() + s.addTransaction(t) return }