dht: Retry queries twice before timing out

This commit is contained in:
Matt Joiner 2014-12-06 21:21:20 -06:00
parent 1e3a00979d
commit 35ba3c44e1
1 changed files with 80 additions and 62 deletions

View File

@ -9,6 +9,7 @@ import (
"io"
"log"
"math/big"
"math/rand"
"net"
"os"
"sync"
@ -127,9 +128,11 @@ func (s *Server) String() string {
type Node struct {
addr dHTAddr
id string
lastHeardFrom time.Time
lastSentTo time.Time
announceToken string
lastGotQuery time.Time
lastGotResponse time.Time
lastSentQuery time.Time
}
func (n *Node) NodeInfo() (ret NodeInfo) {
@ -144,13 +147,15 @@ func (n *Node) Good() bool {
if len(n.id) != 20 {
return false
}
if n.lastSentTo.IsZero() {
// No reason to think ill of them if they've never responded.
if n.lastSentQuery.IsZero() {
return true
}
if n.lastSentTo.Before(n.lastHeardFrom) {
// They answered our last query.
if n.lastSentQuery.Before(n.lastGotResponse) {
return true
}
if time.Now().Sub(n.lastHeardFrom) >= 1*time.Minute {
if time.Now().Sub(n.lastSentQuery) >= 2*time.Minute {
return false
}
return true
@ -217,16 +222,63 @@ func (m Msg) AnnounceToken() string {
}
type transaction struct {
mu sync.Mutex
remoteAddr dHTAddr
t string
Response chan Msg
onResponse func(Msg)
done chan struct{}
mu sync.Mutex
remoteAddr dHTAddr
t string
Response chan Msg
onResponse func(Msg)
done chan struct{}
queryPacket []byte
timer *time.Timer
s *Server
retries int
}
func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
}
func (t *transaction) startTimer() {
t.timer = time.AfterFunc(jitterDuration(20*time.Second, time.Second), t.timerCallback)
}
func (t *transaction) timerCallback() {
t.mu.Lock()
defer t.mu.Unlock()
select {
case <-t.done:
return
default:
}
if t.retries == 2 {
t.timeout()
return
}
t.retries++
t.sendQuery()
if t.timer.Reset(jitterDuration(20*time.Second, time.Second)) {
panic("timer should have fired to get here")
}
}
func (t *transaction) sendQuery() error {
return t.s.writeToNode(t.queryPacket, t.remoteAddr)
}
func (t *transaction) timeout() {
t.Close()
t.close()
}
func (t *transaction) close() {
if t.closing() {
return
}
close(t.Response)
close(t.done)
t.timer.Stop()
t.s.mu.Lock()
defer t.s.mu.Unlock()
t.s.removeTransaction(t)
}
func (t *transaction) closing() bool {
@ -241,11 +293,7 @@ func (t *transaction) closing() bool {
func (t *transaction) Close() {
t.mu.Lock()
defer t.mu.Unlock()
if t.closing() {
return
}
close(t.Response)
close(t.done)
t.close()
}
func (t *transaction) handleResponse(m Msg) {
@ -338,13 +386,9 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) {
//log.Printf("unexpected message: %#v", d)
return
}
s.getNode(addr).lastGotResponse = time.Now()
t.handleResponse(d)
s.removeTransaction(t)
id := ""
if d["y"] == "r" {
id = d["r"].(map[string]interface{})["id"].(string)
}
s.heardFromNode(addr, id)
}
func (s *Server) serve() error {
@ -392,7 +436,10 @@ func (s *Server) nodeByID(id string) *Node {
func (s *Server) handleQuery(source dHTAddr, m Msg) {
args := m["a"].(map[string]interface{})
s.heardFromNode(source, args["id"].(string))
node := s.getNode(source)
node.id = args["id"].(string)
node.lastGotQuery = time.Now()
// Don't respond.
if s.passive {
return
}
@ -472,14 +519,6 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
}
}
func (s *Server) heardFromNode(addr dHTAddr, id string) {
n := s.getNode(addr)
if len(id) == 20 {
n.id = id
}
n.lastHeardFrom = time.Now()
}
func (s *Server) getNode(addr dHTAddr) (n *Node) {
n = s.nodes[addr.String()]
if n == nil {
@ -507,15 +546,9 @@ func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
err = io.ErrShortWrite
return
}
s.sentToNode(node)
return
}
func (s *Server) sentToNode(addr dHTAddr) {
n := s.getNode(addr)
n.lastSentTo = time.Now()
}
func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction {
for _, t := range s.transactions {
if t.t == transactionID && t.remoteAddr.String() == sourceNode.String() {
@ -555,23 +588,6 @@ func (s *Server) IDString() string {
return s.id
}
func (s *Server) timeoutTransaction(t *transaction) {
select {
case <-t.done:
return
case <-time.After(time.Minute):
}
s.mu.Lock()
defer s.mu.Unlock()
select {
case <-t.done:
return
default:
}
t.timeout()
s.removeTransaction(t)
}
func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) {
tid := s.nextTransactionID()
if a == nil {
@ -589,18 +605,20 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *tra
return
}
t = &transaction{
remoteAddr: node,
t: tid,
Response: make(chan Msg, 1),
done: make(chan struct{}),
remoteAddr: node,
t: tid,
Response: make(chan Msg, 1),
done: make(chan struct{}),
queryPacket: b,
s: s,
}
s.addTransaction(t)
err = s.writeToNode(b, node)
err = t.sendQuery()
if err != nil {
s.removeTransaction(t)
return
}
go s.timeoutTransaction(t)
s.getNode(node).lastSentQuery = time.Now()
t.startTimer()
s.addTransaction(t)
return
}