2022-04-06 11:48:16 +02:00

1460 lines
36 KiB
Go

package dht
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"net"
"runtime/pprof"
"strings"
"text/tabwriter"
"time"
"github.com/anacrolix/log"
"github.com/anacrolix/missinggo/v2"
"github.com/anacrolix/sync"
"github.com/pkg/errors"
"golang.org/x/time/rate"
"github.com/anacrolix/torrent/iplist"
"github.com/anacrolix/torrent/logonce"
"github.com/anacrolix/torrent/metainfo"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/dht/v2/bep44"
"github.com/anacrolix/dht/v2/int160"
"github.com/anacrolix/dht/v2/krpc"
peer_store "github.com/anacrolix/dht/v2/peer-store"
"github.com/anacrolix/dht/v2/traversal"
"github.com/anacrolix/dht/v2/types"
)
// A Server defines parameters for a DHT node server that is able to send
// queries, and respond to the ones from the network. Each node has a globally
// unique identifier known as the "node ID." Node IDs are chosen at random
// from the same 160-bit space as BitTorrent infohashes and define the
// behaviour of the node. Zero valued Server does not have a valid ID and thus
// is unable to function properly. Use `NewServer(nil)` to initialize a
// default node.
type Server struct {
// Our node ID; NewServer also installs it as the routing table's root ID.
id int160.T
// Socket that serve reads from and writeToNode writes to.
socket net.PacketConn
// Delay between resends of an outbound query, also used as the final wait for
// a reply (see transactionQuerySender).
resendDelay func() time.Duration
// Guards the mutable server state, e.g. transactions, nextT, table, stats,
// ipBlockList and closed.
mu sync.RWMutex
// Outstanding outbound queries, keyed by remote address and "t" value.
transactions map[transactionKey]*Transaction
nextT uint64 // unique "t" field for outbound queries
// Routing table of known nodes.
table table
// Set by Close; checked before processing packets and before writes.
closed missinggo.Event
// Packets to or from addresses in these ranges are dropped; nil disables blocking.
ipBlockList iplist.Ranger
tokenServer tokenServer // Manages tokens we issue to our queriers.
config ServerConfig
stats ServerStats
// NOTE(review): not referenced in this file; writes are limited by
// config.SendLimiter instead — confirm whether this field is still used.
sendLimit *rate.Limiter
// Time of the last bootstrap, consulted by shouldBootstrap.
lastBootstrap time.Time
// NOTE(review): not referenced in this file; presumably set while a
// bootstrap is running — confirm.
bootstrappingNow bool
// BEP 44 item store, built in NewServer from config.Store and config.Exp.
store *bep44.Wrapper
}
// numGoodNodes counts the routing table entries that IsGood considers good.
// Callers in this file hold s.mu.
func (s *Server) numGoodNodes() int {
	good := 0
	s.table.forNodes(func(candidate *node) bool {
		if s.IsGood(candidate) {
			good++
		}
		return true
	})
	return good
}
// prettySince renders the time elapsed since t, truncated toward zero to whole
// seconds, in the form "1m30s ago". The zero time renders as "never".
func prettySince(t time.Time) string {
	if t.IsZero() {
		return "never"
	}
	elapsed := time.Since(t).Truncate(time.Second)
	return fmt.Sprintf("%s ago", elapsed)
}
// WriteStatus writes a human-readable report of the server to w: the listen
// address, node counts, the number of outstanding transactions, our node ID,
// and, for every bucket that is non-empty or has ever changed, a table of its
// nodes with their activity and status flags.
func (s *Server) WriteStatus(w io.Writer) {
	fmt.Fprintf(w, "Listening on %s\n", s.Addr())
	s.mu.Lock()
	defer s.mu.Unlock()
	fmt.Fprintf(w, "Nodes in table: %d good, %d total\n", s.numGoodNodes(), s.numNodes())
	fmt.Fprintf(w, "Ongoing transactions: %d\n", len(s.transactions))
	fmt.Fprintf(w, "Server node ID: %x\n", s.id.Bytes())

	// flagsOf builds the comma-separated status labels for the last column.
	flagsOf := func(n *node) string {
		var labels []string
		if s.IsQuestionable(n) {
			labels = append(labels, "q10e")
		}
		if s.nodeIsBad(n) {
			labels = append(labels, "bad")
		}
		if s.IsGood(n) {
			labels = append(labels, "good")
		}
		if n.IsSecure() {
			labels = append(labels, "sec")
		}
		return strings.Join(labels, ",")
	}

	for idx, bucket := range s.table.buckets {
		if bucket.Len() == 0 && bucket.lastChanged.IsZero() {
			continue
		}
		fmt.Fprintf(w,
			"b# %v: %v nodes, last updated: %v\n",
			idx, bucket.Len(), prettySince(bucket.lastChanged))
		if bucket.Len() == 0 {
			continue
		}
		table := tabwriter.NewWriter(w, 0, 0, 1, ' ', 0)
		fmt.Fprintf(table, " node id\taddr\tlast query\tlast response\trecv\tdiscard\tflags\n")
		bucket.EachNode(func(n *node) bool {
			flags := flagsOf(n)
			fmt.Fprintf(table, " %x\t%s\t%s\t%s\t%d\t%v\t%v\n",
				n.Id.Bytes(),
				n.Addr,
				prettySince(n.lastGotQuery),
				prettySince(n.lastGotResponse),
				n.numReceivesFrom,
				n.failedLastQuestionablePing,
				flags,
			)
			return true
		})
		table.Flush()
	}
	fmt.Fprintln(w)
}
// numNodes counts every entry in the routing table. Callers in this file hold
// s.mu.
func (s *Server) numNodes() int {
	total := 0
	s.table.forNodes(func(*node) bool {
		total++
		return true
	})
	return total
}
// Stats returns a snapshot of the server's statistics. It is filled in with
// the current good and total routing table node counts and the number of
// outstanding outbound transactions.
func (s *Server) Stats() ServerStats {
	s.mu.Lock()
	defer s.mu.Unlock()
	snapshot := s.stats
	snapshot.GoodNodes = s.numGoodNodes()
	snapshot.Nodes = s.numNodes()
	snapshot.OutstandingTransactions = len(s.transactions)
	return snapshot
}
// Addr returns the local address of the server's socket. Packets arriving at
// this address are processed by the server (unless aliens are involved).
func (s *Server) Addr() net.Addr {
	conn := s.socket
	return conn.LocalAddr()
}
// NewDefaultServerConfig returns a ServerConfig with the following defaults:
//   - node-ID security disabled;
//   - the global UDP bootstrap addresses as starting nodes;
//   - both IPv4 and IPv6 nodes wanted;
//   - an in-memory BEP 44 store with Exp set to 2 hours;
//   - DefaultSendLimiter as the send rate limiter.
func NewDefaultServerConfig() *ServerConfig {
	cfg := new(ServerConfig)
	cfg.NoSecurity = true
	cfg.StartingNodes = func() ([]Addr, error) {
		return GlobalBootstrapAddrs("udp")
	}
	cfg.DefaultWant = []krpc.Want{krpc.WantNodes, krpc.WantNodes6}
	cfg.Store = bep44.NewMemory()
	cfg.Exp = 2 * time.Hour
	cfg.SendLimiter = DefaultSendLimiter
	return cfg
}
// InitNodeId fills in a random NodeId when none has been set. If NoSecurity is
// false and a PublicIP is configured, the generated ID is also secured against
// that IP. An already-set NodeId is left untouched.
func (c *ServerConfig) InitNodeId() {
	if !missinggo.IsZeroValue(c.NodeId) {
		return
	}
	c.NodeId = RandomNodeID()
	if c.NoSecurity || c.PublicIP == nil {
		return
	}
	SecureNodeId(&c.NodeId, c.PublicIP)
}
// NewServer initializes a new DHT node server and starts serving packets from
// its socket in a background goroutine.
//
// A nil c means NewDefaultServerConfig(). If c.Conn is nil, a UDP socket on an
// ephemeral port is opened. c is modified in place to fill in defaults: Conn,
// NodeId (via InitNodeId), Logger, Store and SendLimiter. An error is returned
// if the socket cannot be opened or the token secret cannot be generated.
func NewServer(c *ServerConfig) (s *Server, err error) {
	if c == nil {
		c = NewDefaultServerConfig()
	}
	// The token secret backs the tokens we issue to queriers. Generate it before
	// acquiring any resources. If the system random source fails, return an
	// error rather than running with an all-zero, predictable secret.
	secret := make([]byte, 20)
	if _, err = rand.Read(secret); err != nil {
		return nil, fmt.Errorf("generating token secret: %w", err)
	}
	if c.Conn == nil {
		c.Conn, err = net.ListenPacket("udp", ":0")
		if err != nil {
			return nil, err
		}
	}
	c.InitNodeId()
	// If Logger is empty, emulate the old behaviour: Everything is logged to the default location,
	// and there are no debug messages.
	if c.Logger.IsZero() {
		c.Logger = log.Default.FilterLevel(log.Info)
	}
	// Add log.Debug by default.
	c.Logger = c.Logger.WithDefaultLevel(log.Debug)
	if c.Store == nil {
		c.Store = bep44.NewMemory()
	}
	if c.SendLimiter == nil {
		c.SendLimiter = DefaultSendLimiter
	}
	s = &Server{
		config:      *c,
		ipBlockList: c.IPBlocklist,
		tokenServer: tokenServer{
			maxIntervalDelta: 2,
			interval:         5 * time.Minute,
			secret:           secret,
		},
		transactions: make(map[transactionKey]*Transaction),
		table: table{
			k: 8,
		},
		store: bep44.NewWrapper(c.Store, c.Exp),
	}
	s.socket = c.Conn
	s.id = int160.FromByteArray(c.NodeId)
	s.table.rootID = s.id
	s.resendDelay = s.config.QueryResendDelay
	if s.resendDelay == nil {
		s.resendDelay = defaultQueryResendDelay
	}
	go s.serveUntilClosed()
	return s, nil
}
// serveUntilClosed runs serve until it returns. A serve error is fatal (panic)
// unless the server has been marked closed, because Close closing the socket
// is the expected way for serve to stop.
func (s *Server) serveUntilClosed() {
	serveErr := s.serve()
	s.mu.Lock()
	defer s.mu.Unlock()
	if serveErr != nil && !s.closed.IsSet() {
		panic(serveErr)
	}
}
// String describes the Server by its local socket address and node ID.
func (s *Server) String() string {
	local := s.socket.LocalAddr()
	return fmt.Sprintf("dht server on %s (node id %v)", local, s.id)
}
// SetIPBlockList replaces the IP block list. Packets to and from any address
// matching a range in the list are dropped; a nil list disables blocking.
func (s *Server) SetIPBlockList(list iplist.Ranger) {
	s.mu.Lock()
	s.ipBlockList = list
	s.mu.Unlock()
}
// IPBlocklist returns the current IP block list, which may be nil.
//
// NOTE(review): this reads s.ipBlockList without holding s.mu, while
// SetIPBlockList writes it under s.mu, so concurrent use may race.
func (s *Server) IPBlocklist() iplist.Ranger {
	list := s.ipBlockList
	return list
}
// processPacket decodes one inbound datagram b from addr as a KRPC message and
// dispatches it.
//
// Packets that are not bencoded dicts, or that fail to decode, are counted and
// dropped; only some decode failures are logged. Queries (y == "q") go to
// handleQuery. Any other message is matched against our outstanding
// transactions by (addr, "t"): the transaction gets the message, the sender's
// routing table entry is refreshed, and the transaction is removed.
func (s *Server) processPacket(b []byte, addr Addr) {
// log.Printf("got packet %q", b)
if len(b) < 2 || b[0] != 'd' {
// KRPC messages are bencoded dicts.
readNotKRPCDict.Add(1)
return
}
var d krpc.Msg
err := bencode.Unmarshal(b, &d)
// Trailing bytes after a valid dict are tolerated: they are counted and the
// decoded message is still processed below.
if _, ok := err.(bencode.ErrUnusedTrailingBytes); ok {
// log.Printf("%s: received message packet with %d trailing bytes: %q", s, _err.NumUnusedBytes, b[len(b)-_err.NumUnusedBytes:])
expvars.Add("processed packets with trailing bytes", 1)
} else if err != nil {
readUnmarshalError.Add(1)
// log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b)
// Common, uninteresting syntax errors return early without logging; any other
// decode error is logged via the package-level logger.
func() {
if se, ok := err.(*bencode.SyntaxError); ok {
// The message was truncated.
if int(se.Offset) == len(b) {
return
}
// Some messages seem to drop to nul chars abruptly.
if int(se.Offset) < len(b) && b[se.Offset] == 0 {
return
}
// The message isn't bencode from the first byte.
if se.Offset == 0 {
return
}
}
// if missinggo.CryHeard() {
log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b)
// }
}()
return
}
// Everything below reads or mutates server state, so it runs under the write
// lock. handleQuery is therefore called with s.mu held.
s.mu.Lock()
defer s.mu.Unlock()
if s.closed.IsSet() {
return
}
if d.Y == "q" {
expvars.Add("received queries", 1)
s.logger().Printf("received query %q from %v", d.Q, addr)
s.handleQuery(addr, d)
return
}
// Not a query: treat it as a response (or error) to one of our outbound
// transactions, identified by the remote address and the echoed "t" value.
tk := transactionKey{
RemoteAddr: addr.String(),
T: d.T,
}
t, ok := s.transactions[tk]
if !ok {
s.logger().Printf("received response for untracked transaction %q from %v", d.T, addr)
return
}
// s.logger().Printf("received response for transaction %q from %v", d.T, addr)
// Delivered on a separate goroutine, presumably so the transaction's callback
// never runs under s.mu.
go t.handleResponse(d)
// The responder is added to the table unless it declared itself read-only
// (BEP 43); a response also clears any earlier failed questionable ping.
s.updateNode(addr, d.SenderID(), !d.ReadOnly, func(n *node) {
n.lastGotResponse = time.Now()
n.failedLastQuestionablePing = false
n.numReceivesFrom++
})
// Ensure we don't provide more than one response to a transaction.
s.deleteTransaction(tk)
}
// serve reads datagrams from the socket and hands each acceptable one to
// processPacket. Packets that fill the whole read buffer, packets from port 0,
// and packets from blocked IPs are counted and dropped. It returns when a read
// fails or when the server is found to be closed, so it always returns a
// non-nil error.
func (s *Server) serve() error {
	var buf [0x10000]byte
	for {
		n, from, err := s.socket.ReadFrom(buf[:])
		if err != nil {
			return err
		}
		expvars.Add("packets read", 1)
		switch {
		case n == len(buf):
			// A read that fills the entire buffer may have been truncated.
			logonce.Stderr.Printf("received dht packet exceeds buffer size")
			continue
		case missinggo.AddrPort(from) == 0:
			readZeroPort.Add(1)
			continue
		}
		var closed, blocked bool
		func() {
			s.mu.RLock()
			defer s.mu.RUnlock()
			closed = s.closed.IsSet()
			if !closed {
				blocked = s.ipBlocked(missinggo.AddrIP(from))
			}
		}()
		if closed {
			return errors.New("server is closed")
		}
		if blocked {
			readBlocked.Add(1)
			continue
		}
		s.processPacket(buf[:n], NewAddr(from))
	}
}
// ipBlocked reports whether ip falls in a range of the IP block list. When no
// block list is set, nothing is blocked.
func (s *Server) ipBlocked(ip net.IP) (blocked bool) {
	if list := s.ipBlockList; list != nil {
		_, blocked = list.Lookup(ip)
	}
	return blocked
}
// AddNode adds ni directly to the routing table. A node with a zero ID is not
// added. Instead it is pinged in the background, so that a reply can be added
// under the node's real ID by processPacket.
func (s *Server) AddNode(ni krpc.NodeInfo) error {
	if int160.FromByteArray(ni.ID).IsZero() {
		go s.Ping(ni.Addr.UDP())
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.updateNode(NewAddr(ni.Addr.UDP()), (*krpc.ID)(&ni.ID), true, func(*node) {})
}
// wantsContain reports whether w appears in the "want" list ws.
func wantsContain(ws []krpc.Want, w krpc.Want) bool {
	for i := range ws {
		if ws[i] == w {
			return true
		}
	}
	return false
}
// shouldReturnNodes reports whether a response should include IPv4 "nodes"
// (BEP 32). With an explicit "want" list this is true only if it contains
// WantNodes. Otherwise it is true only when the query came from an IPv4 address.
func shouldReturnNodes(queryWants []krpc.Want, querySource net.IP) bool {
	if len(queryWants) == 0 {
		// Is it possible to be over IPv6 with IPv4 endpoints?
		return querySource.To4() != nil
	}
	return wantsContain(queryWants, krpc.WantNodes)
}
// shouldReturnNodes6 reports whether a response should include IPv6 "nodes6"
// (BEP 32). With an explicit "want" list this is true only if it contains
// WantNodes6. Otherwise it is true only when the query did not come from an
// IPv4 address.
func shouldReturnNodes6(queryWants []krpc.Want, querySource net.IP) bool {
	if len(queryWants) == 0 {
		return querySource.To4() == nil
	}
	return wantsContain(queryWants, krpc.WantNodes6)
}
// makeReturnNodes asks closestGoodNodeInfos for the 8 good nodes closest to
// target whose address passes filter, for use in a query response.
func (s *Server) makeReturnNodes(target int160.T, filter func(krpc.NodeAddr) bool) []krpc.NodeInfo {
	const returnNodesCount = 8
	return s.closestGoodNodeInfos(returnNodesCount, target, filter)
}
// krpcErrMissingArguments is the protocol error sent in reply to a query that
// needs arguments but arrived without an arguments dict (m.A == nil).
var krpcErrMissingArguments = krpc.Error{
Code: krpc.ErrorCodeProtocolError,
Msg: "missing arguments dict",
}
// filterPeers selects which of allPeers to return in the "values" field of a
// get_peers response, per BEP 32. IPv4 and/or IPv6 peers are kept as decided by
// shouldReturnNodes and shouldReturnNodes6 for the querier's address and
// "want" list. An address already in a retained family's native length is kept
// as-is. Otherwise it is converted to its 4- or 16-byte form when that family
// is retained, and dropped if neither applies.
func filterPeers(querySourceIp net.IP, queryWants []krpc.Want, allPeers []krpc.NodeAddr) (filtered []krpc.NodeAddr) {
	// The logic here is common with nodes, see BEP 32.
	keep4 := shouldReturnNodes(queryWants, querySourceIp)
	keep6 := shouldReturnNodes6(queryWants, querySourceIp)
	for _, peer := range allPeers {
		ip := peer.IP
		v4, v16 := ip.To4(), ip.To16()
		var chosen net.IP
		switch {
		case keep4 && len(ip) == net.IPv4len, keep6 && len(ip) == net.IPv6len:
			chosen = ip
		case keep4 && v4 != nil:
			// Is it possible that we're converting to an IPv4 address when the transport in use
			// is IPv6?
			chosen = v4
		case keep6 && v16 != nil:
			// Couldn't any IPv4 address be converted to IPv6, but isn't listening over IPv6?
			chosen = v16
		default:
			continue
		}
		filtered = append(filtered, krpc.NodeAddr{IP: chosen, Port: peer.Port})
	}
	return filtered
}
// setReturnNodes fills r.Nodes and/or r.Nodes6 with the good nodes closest to
// the query's lookup key. Address families are chosen per BEP 32 by
// shouldReturnNodes and shouldReturnNodes6.
//
// The lookup key depends on the query type. find_node (BEP 5) and get (BEP 44)
// carry it in the "target" argument, while get_peers (BEP 5) carries it in
// "info_hash". Other queries fall back to "info_hash", as before.
//
// It returns a protocol error if the query has no arguments dict.
func (s *Server) setReturnNodes(r *krpc.Return, queryMsg krpc.Msg, querySource Addr) *krpc.Error {
	args := queryMsg.A
	if args == nil {
		return &krpcErrMissingArguments
	}
	var target int160.T
	switch queryMsg.Q {
	case "find_node", "get":
		target = int160.FromByteArray(args.Target)
	default:
		target = int160.FromByteArray(args.InfoHash)
	}
	if shouldReturnNodes(args.Want, querySource.IP()) {
		r.Nodes = s.makeReturnNodes(target, func(na krpc.NodeAddr) bool { return na.IP.To4() != nil })
	}
	if shouldReturnNodes6(args.Want, querySource.IP()) {
		r.Nodes6 = s.makeReturnNodes(target, func(krpc.NodeAddr) bool { return true })
	}
	return nil
}
// handleQuery processes an inbound KRPC query m received from source.
//
// The sender is recorded in the routing table first; it is only added if the
// query isn't read-only (BEP 43). The optional OnQuery hook may then stop
// further processing, and passive servers never reply. Replies and errors are
// sent asynchronously via reply and sendError.
//
// m comes straight off the network, so every branch that reads arguments first
// checks for a missing arguments dict instead of dereferencing it. processPacket
// calls this with s.mu held.
func (s *Server) handleQuery(source Addr, m krpc.Msg) {
	go func() {
		expvars.Add(fmt.Sprintf("received query %q", m.Q), 1)
		if a := m.A; a != nil {
			if a.NoSeed != 0 {
				expvars.Add("received argument noseed", 1)
			}
			if a.Scrape != 0 {
				expvars.Add("received argument scrape", 1)
			}
		}
	}()
	s.updateNode(source, m.SenderID(), !m.ReadOnly, func(n *node) {
		n.lastGotQuery = time.Now()
		n.numReceivesFrom++
	})
	if s.config.OnQuery != nil {
		propagate := s.config.OnQuery(&m, source.Raw())
		if !propagate {
			return
		}
	}
	// Don't respond.
	if s.config.Passive {
		return
	}
	// TODO: Should we disallow replying to ourself?
	args := m.A
	switch m.Q {
	case "ping":
		s.reply(source, m.T, krpc.Return{})
	case "get_peers":
		// args is dereferenced below.
		if args == nil {
			s.sendError(source, m.T, krpcErrMissingArguments)
			break
		}
		var r krpc.Return
		if ps := s.config.PeerStore; ps != nil {
			r.Values = filterPeers(source.IP(), args.Want, ps.GetPeers(peer_store.InfoHash(args.InfoHash)))
			t := s.createToken(source)
			r.Token = &t
		}
		if len(r.Values) == 0 {
			if err := s.setReturnNodes(&r, m, source); err != nil {
				s.sendError(source, m.T, *err)
				break
			}
		}
		s.reply(source, m.T, r)
	case "find_node":
		var r krpc.Return
		if err := s.setReturnNodes(&r, m, source); err != nil {
			s.sendError(source, m.T, *err)
			break
		}
		s.reply(source, m.T, r)
	case "announce_peer":
		// A missing arguments dict must not crash the server with a nil dereference.
		if args == nil {
			s.sendError(source, m.T, krpcErrMissingArguments)
			break
		}
		if !s.validToken(args.Token, source) {
			expvars.Add("received announce_peer with invalid token", 1)
			return
		}
		expvars.Add("received announce_peer with valid token", 1)
		var port int
		portOk := false
		if args.Port != nil {
			port = *args.Port
			portOk = true
		}
		if args.ImpliedPort {
			expvars.Add("received announce_peer with implied_port", 1)
			port = source.Port()
			portOk = true
		}
		if !portOk {
			expvars.Add("received announce_peer with no derivable port", 1)
		}
		if h := s.config.OnAnnouncePeer; h != nil {
			go h(metainfo.Hash(args.InfoHash), source.IP(), port, portOk)
		}
		if ps := s.config.PeerStore; ps != nil {
			go ps.AddPeer(
				peer_store.InfoHash(args.InfoHash),
				krpc.NodeAddr{IP: source.IP(), Port: port},
			)
		}
		s.reply(source, m.T, krpc.Return{})
	case "put":
		// A missing arguments dict must not crash the server with a nil dereference.
		if args == nil {
			s.sendError(source, m.T, krpcErrMissingArguments)
			break
		}
		if !s.validToken(args.Token, source) {
			expvars.Add("received put with invalid token", 1)
			return
		}
		expvars.Add("received put with valid token", 1)
		i := &bep44.Item{
			V:    args.V,
			K:    args.K,
			Salt: args.Salt,
			Sig:  args.Sig,
			Cas:  args.Cas,
		}
		// "seq" is only sent for mutable items (BEP 44). Leave it zero when it is
		// absent instead of dereferencing a nil pointer.
		if args.Seq != nil {
			i.Seq = *args.Seq
		}
		if err := s.store.Put(i); err != nil {
			// Send exactly one error: the store's own KRPC error if it has one,
			// otherwise a generic error carrying the message, as "get" does.
			kerr, ok := err.(krpc.Error)
			if !ok {
				kerr = krpc.Error{
					Code: krpc.ErrorCodeGenericError,
					Msg:  err.Error(),
				}
			}
			s.sendError(source, m.T, kerr)
			break
		}
		s.reply(source, m.T, krpc.Return{
			ID: s.ID(),
		})
	case "get":
		var r krpc.Return
		// setReturnNodes also rejects a missing arguments dict, so args is
		// non-nil past this point.
		if err := s.setReturnNodes(&r, m, source); err != nil {
			s.sendError(source, m.T, *err)
			break
		}
		t := s.createToken(source)
		r.Token = &t
		item, err := s.store.Get(bep44.Target(args.Target))
		if err == bep44.ErrItemNotFound {
			s.reply(source, m.T, r)
			break
		}
		if kerr, ok := err.(krpc.Error); ok {
			s.sendError(source, m.T, kerr)
			break
		}
		if err != nil {
			s.sendError(source, m.T, krpc.Error{
				Code: krpc.ErrorCodeGenericError,
				Msg:  err.Error(),
			})
			break
		}
		r.Seq = &item.Seq
		// The querier already has this or a newer version: reply without the value.
		if args.Seq != nil && item.Seq <= *args.Seq {
			s.reply(source, m.T, r)
			break
		}
		r.V = item.V
		r.K = item.K
		r.Sig = item.Sig
		s.reply(source, m.T, r)
	// case "sample_infohashes":
	// 	// Nodes supporting this extension should always include the samples field in the response,
	// 	// even when it is zero-length. This lets indexing nodes to distinguish nodes supporting this
	// 	// extension from those that respond to unknown query types which contain a target field [2].
	default:
		// TODO: http://libtorrent.org/dht_extensions.html#forward-compatibility
		s.sendError(source, m.T, krpc.ErrorMethodUnknown)
	}
}
// sendError sends a KRPC error message e to addr for transaction t, in the
// background. The write goes through the rate limiter without waiting: if no
// token is available it fails. Failures are only logged.
func (s *Server) sendError(addr Addr, t string, e krpc.Error) {
	go func() {
		encoded, err := bencode.Marshal(krpc.Msg{
			T: t,
			Y: "e",
			E: &e,
		})
		if err != nil {
			panic(err)
		}
		s.logger().Printf("sending error to %q: %v", addr, e)
		if _, err := s.writeToNode(context.Background(), encoded, addr, false, true); err != nil {
			s.logger().Printf("error replying to %q: %v", addr, err)
		}
	}()
}
// reply sends the response r to addr for transaction t, in the background. It
// stamps r with our node ID and sets the message's IP field to the querier's
// address. Whether to wait for a rate-limit token is controlled by
// config.WaitToReply. Failures are only logged.
func (s *Server) reply(addr Addr, t string, r krpc.Return) {
	go func() {
		r.ID = s.id.AsByteArray()
		encoded := bencode.MustMarshal(krpc.Msg{
			T:  t,
			Y:  "r",
			R:  &r,
			IP: addr.KRPC(),
		})
		log.Fmsg("replying to %q", addr).Log(s.logger())
		wrote, err := s.writeToNode(context.Background(), encoded, addr, s.config.WaitToReply, true)
		if err != nil {
			s.config.Logger.Printf("error replying to %s: %s", addr, err)
		}
		if wrote {
			expvars.Add("replied to peer", 1)
		}
	}()
}
// Adds a node if appropriate.
//
// Bad nodes (see nodeErr) are rejected outright. If n's bucket already holds
// k nodes, it tries to make room by dropping bad nodes and, when n itself is
// good, nodes that have never responded. The callback asks to continue only
// while the bucket is still full. If there is still no room afterwards, n is
// rejected. A failure from table.addNode after these checks is treated as a
// programming error and panics.
func (s *Server) addNode(n *node) error {
if s.nodeIsBad(n) {
return errors.New("node is bad")
}
b := s.table.bucketForID(n.Id)
if b.Len() >= s.table.k {
// NOTE(review): this relies on EachNode returning true only when every
// callback returned true (i.e. the bucket stayed full), and on EachNode
// tolerating dropNode during iteration — confirm against the bucket type.
if b.EachNode(func(bn *node) bool {
// Replace bad and untested nodes with a good one.
if s.nodeIsBad(bn) || (s.IsGood(n) && bn.lastGotResponse.IsZero()) {
s.table.dropNode(bn)
}
// Keep scanning only while the bucket is still full.
return b.Len() >= s.table.k
}) {
return errors.New("no room in bucket")
}
}
if err := s.table.addNode(n); err != nil {
panic(fmt.Sprintf("expected to add node: %s", err))
}
return nil
}
// NodeRespondedToPing records that the node (addr, id) answered a ping by
// stamping its bucket's lastChanged with the current time. Nothing happens for
// our own ID or for a node that isn't in the table.
func (s *Server) NodeRespondedToPing(addr Addr, id int160.T) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if id == s.id {
		return
	}
	if bucket := s.table.bucketForID(id); bucket.GetNode(addr, id) != nil {
		bucket.lastChanged = time.Now()
	}
}
// updateNode applies update to the routing table entry for (addr, id).
//
// If no such entry exists and tryAdd is set, a new node is created, passed to
// update, and then offered to addNode, whose error is returned. Otherwise an
// error is returned. A nil id is rejected, and so is adding our own ID.
func (s *Server) updateNode(addr Addr, id *krpc.ID, tryAdd bool, update func(*node)) error {
	if id == nil {
		return errors.New("id is nil")
	}
	nodeID := int160.FromByteArray(*id)
	if existing := s.table.getNode(addr, nodeID); existing != nil {
		update(existing)
		return nil
	}
	switch {
	case !tryAdd:
		return errors.New("node not present and add flag false")
	case nodeID == s.id:
		return errors.New("can't store own id in routing table")
	}
	fresh := &node{nodeKey: nodeKey{
		Id:   nodeID,
		Addr: addr,
	}}
	update(fresh)
	return s.addNode(fresh)
}
// nodeIsBad reports whether nodeErr finds any reason to reject n.
func (s *Server) nodeIsBad(n *node) bool {
	if err := s.nodeErr(n); err != nil {
		return true
	}
	return false
}
// nodeErr explains why n is unsuitable for the routing table, or returns nil.
// Reasons are checked in order: it is us, it has a zero ID, it is not secure
// (only when security is enabled), or it failed its last questionable-node ping.
func (s *Server) nodeErr(n *node) error {
	switch {
	case n.Id == s.id:
		return errors.New("is self")
	case n.Id.IsZero():
		return errors.New("has zero id")
	case !s.config.NoSecurity && !n.IsSecure():
		return errors.New("not secure")
	case n.failedLastQuestionablePing:
		return errors.New("didn't respond to last questionable node ping")
	default:
		return nil
	}
}
// writeToNode sends b to addr over the server's socket.
//
// It fails without writing if the server is closed or addr is in the IP block
// list. When rateLimited is set, the write first consumes a token from
// config.SendLimiter. With wait it blocks for a token (bounded by ctx);
// otherwise it fails immediately if none is available. wrote reports whether
// the socket write happened. A short write returns wrote == true together with
// io.ErrShortWrite. Socket errors are wrapped with %w so callers can inspect
// them with errors.Is/As.
func (s *Server) writeToNode(ctx context.Context, b []byte, addr Addr, wait, rateLimited bool) (wrote bool, err error) {
	func() {
		// This is a pain. It would be better if the blocklist returned an error if it was closed
		// instead.
		s.mu.RLock()
		defer s.mu.RUnlock()
		if s.closed.IsSet() {
			err = errors.New("server is closed")
			return
		}
		if list := s.ipBlockList; list != nil {
			if r, ok := list.Lookup(addr.IP()); ok {
				err = fmt.Errorf("write to %v blocked by %v", addr, r)
			}
		}
	}()
	if err != nil {
		return false, err
	}
	// s.config.Logger.WithValues(log.Debug).Printf("writing to %s: %q", addr.String(), b)
	if rateLimited {
		if wait {
			if err = s.config.SendLimiter.Wait(ctx); err != nil {
				return false, fmt.Errorf("waiting for rate-limit token: %w", err)
			}
		} else if !s.config.SendLimiter.Allow() {
			return false, errors.New("rate limit exceeded")
		}
	}
	n, err := s.socket.WriteTo(b, addr.Raw())
	writes.Add(1)
	if rateLimited {
		expvars.Add("rated writes", 1)
	} else {
		expvars.Add("unrated writes", 1)
	}
	if err != nil {
		writeErrors.Add(1)
		if rateLimited {
			// Give the token back. NOTE(review): relies on AllowN accepting a
			// negative count to return a token — unverified.
			s.config.SendLimiter.AllowN(time.Now(), -1)
		}
		return false, fmt.Errorf("error writing %d bytes to %s: %w", len(b), addr, err)
	}
	if n != len(b) {
		return true, io.ErrShortWrite
	}
	return true, nil
}
// nextTransactionID returns a fresh "t" value for an outbound query: the
// uvarint encoding of the per-server counter nextT, which is then incremented.
// Callers hold s.mu (see Query).
func (s *Server) nextTransactionID() string {
	var buf [binary.MaxVarintLen64]byte
	id := s.nextT
	s.nextT = id + 1
	size := binary.PutUvarint(buf[:], id)
	return string(buf[:size])
}
// deleteTransaction forgets the outstanding transaction k; unknown keys are
// ignored. Callers hold s.mu.
func (s *Server) deleteTransaction(k transactionKey) {
	pending := s.transactions
	delete(pending, k)
}
// addTransaction registers t as outstanding under k. A duplicate key is a
// programming error and panics. Callers hold s.mu.
func (s *Server) addTransaction(k transactionKey, t *Transaction) {
	if _, duplicate := s.transactions[k]; !duplicate {
		s.transactions[k] = t
		return
	}
	panic("transaction not unique")
}
// ID returns the server's 20-byte node ID, the identity it uses when talking
// to the DHT network.
func (s *Server) ID() (id [20]byte) {
	id = s.id.AsByteArray()
	return id
}
// createToken issues a token for the querier at addr via the token server.
func (s *Server) createToken(addr Addr) string {
	issuer := &s.tokenServer
	return issuer.CreateToken(addr)
}
// validToken asks the token server whether token is valid for the querier at addr.
func (s *Server) validToken(token string, addr Addr) bool {
	issuer := &s.tokenServer
	return issuer.ValidToken(token, addr)
}
// numWrites counts how many times a query's bytes were actually written to the
// socket (incremented in transactionQuerySender).
type numWrites int
// makeQueryBytes bencodes a KRPC query q with arguments a and transaction ID
// t. a.ID is overwritten with our node ID. A bencoding failure panics.
func (s *Server) makeQueryBytes(q string, a krpc.MsgArgs, t string) []byte {
	a.ID = s.ID()
	encoded, err := bencode.Marshal(krpc.Msg{
		T: t,
		Y: "q",
		Q: q,
		A: &a,
		// BEP 43. Outgoing queries from passive nodes should contain "ro":1 in the top level
		// dictionary.
		ReadOnly: s.config.Passive,
	})
	if err != nil {
		panic(err)
	}
	return encoded
}
// QueryResult is the outcome of Server.Query.
type QueryResult struct {
// Reply is the message received for the query, if any. It may itself carry a
// KRPC error; see ToError.
Reply krpc.Msg
// Writes is how many times the query was written to the socket.
Writes numWrites
// Err is set when no reply was received: the context ended, sending failed,
// or the transaction timed out.
Err error
}
// ToError collapses the result into a single error: qr.Err if set, otherwise
// the KRPC error carried by the reply, if any. nil means the query succeeded.
func (qr QueryResult) ToError() error {
	if err := qr.Err; err != nil {
		return err
	}
	// Checked explicitly so that, if Error returns a pointer type, a nil result
	// becomes a nil interface rather than a typed nil.
	if replyErr := qr.Reply.Error(); replyErr != nil {
		return replyErr
	}
	return nil
}
// TraversalQueryResult converts the result of a query sent to addr into a
// traversal.QueryResult. Without a response ("r") dict the zero value is
// returned. Otherwise it carries the responder, the IPv4 and IPv6 nodes it
// returned, and its token (as ClosestData) when present.
func (qr QueryResult) TraversalQueryResult(addr krpc.NodeAddr) (ret traversal.QueryResult) {
	resp := qr.Reply.R
	if resp == nil {
		return ret
	}
	ret.ResponseFrom = &krpc.NodeInfo{
		Addr: addr,
		ID:   resp.ID,
	}
	ret.Nodes, ret.Nodes6 = resp.Nodes, resp.Nodes6
	if tok := resp.Token; tok != nil {
		ret.ClosestData = *tok
	}
	return ret
}
// Rate-limiting to be applied to writes for a given query. Queries occur inside transactions that
// will attempt to send several times. If the STM rate-limiting helpers are used, the first send is
// often already accounted for in the rate-limiting machinery before the query method that does the
// IO is invoked.
//
// The zero value rate-limits every send, waits for a token on the first send, and fails a retry
// immediately when no token is available.
type QueryRateLimiting struct {
// Don't rate-limit the first send for a query.
NotFirst bool
// Don't rate-limit any sends for a query. Note that there's still built-in waits before retries.
NotAny bool
// Wait for a rate-limit token on retries (second and later writes) instead of failing the
// write when none is immediately available.
WaitOnRetries bool
// Don't wait for a rate-limit token on the first send; the write fails with an error if none is
// immediately available.
NoWaitFirst bool
}
// The zero value for this uses reasonable/traditional defaults on Server methods.
type QueryInput struct {
// Query arguments; MsgArgs.ID is overwritten with the server's ID.
MsgArgs krpc.MsgArgs
// How the query's writes are rate-limited.
RateLimiting QueryRateLimiting
// Number of send attempts handed to transactionSender; 0 means defaultMaxQuerySends.
NumTries int
}
// Performs an arbitrary query. `q` is the query value, defined by the DHT BEP. `a` should contain
// the appropriate argument values, if any. `a.ID` is clobbered by the Server. Responses to queries
// made this way are not interpreted by the Server. More specific methods like FindNode and GetPeers
// may make use of the response internally before passing it back to the caller.
//
// Query blocks until one of three things happens: a reply arrives, ctx is done, or the sender gives
// up and returns its error (e.g. TransactionTimeout). It always waits for the sender goroutine to
// finish and removes the transaction before returning.
func (s *Server) Query(ctx context.Context, addr Addr, q string, input QueryInput) (ret QueryResult) {
if input.NumTries == 0 {
input.NumTries = defaultMaxQuerySends
}
defer func(started time.Time) {
s.logger().WithDefaultLevel(log.Debug).WithValues(q).Printf(
"Query(%v) returned after %v (err=%v, reply.Y=%v, reply.E=%v, writes=%v)",
q, time.Since(started), ret.Err, ret.Reply.Y, ret.Reply.E, ret.Writes)
}(time.Now())
// Buffered so the response callback can deliver without blocking even if we have already stopped
// waiting. processPacket deletes the transaction after its first response, so at most one
// message is sent.
replyChan := make(chan krpc.Msg, 1)
t := &Transaction{
onResponse: func(m krpc.Msg) {
replyChan <- m
},
}
tk := transactionKey{
RemoteAddr: addr.String(),
}
// Register the transaction before anything is sent, so a fast reply can be matched.
s.mu.Lock()
tid := s.nextTransactionID()
s.stats.OutboundQueriesAttempted++
tk.T = tid
s.addTransaction(tk, t)
s.mu.Unlock()
// Receives a non-nil error from the sender, and closes when the sender completes.
sendErr := make(chan error, 1)
sendCtx, cancelSend := context.WithCancel(pprof.WithLabels(ctx, pprof.Labels("q", q)))
go func() {
err := s.transactionQuerySender(
sendCtx,
s.makeQueryBytes(q, input.MsgArgs, tid),
&ret.Writes,
addr,
input.RateLimiting,
input.NumTries)
if err != nil {
sendErr <- err
}
close(sendErr)
}()
expvars.Add(fmt.Sprintf("outbound %s queries", q), 1)
select {
case ret.Reply = <-replyChan:
case <-ctx.Done():
ret.Err = ctx.Err()
case ret.Err = <-sendErr:
}
// Make sure the query sender stops.
cancelSend()
// Make sure the query sender has returned, it will either send an error that we didn't catch
// above, or the channel will be closed by the sender completing.
<-sendErr
s.mu.Lock()
s.deleteTransaction(tk)
s.mu.Unlock()
return
}
// transactionQuerySender sends the encoded query b to addr via
// transactionSender, passing it s.resendDelay and numTries. Every write that
// reaches the socket is counted in *writes. Each write's rate-limiting depends
// on rateLimiting and on whether any write has happened yet.
//
// If transactionSender fails, its error is returned unchanged. Otherwise it
// waits one more resend delay for a reply and returns TransactionTimeout, or
// the context's error if sendCtx ends first, prefixed with the try count. It
// therefore never returns nil.
func (s *Server) transactionQuerySender(
	sendCtx context.Context,
	b []byte,
	writes *numWrites,
	addr Addr,
	rateLimiting QueryRateLimiting,
	numTries int,
) error {
	// log.Printf("sending %q", b)
	// We only wait for the first write by default if rate-limiting is enabled for this query.
	waitForToken := func() bool {
		if *writes != 0 {
			return rateLimiting.WaitOnRetries
		}
		return !rateLimiting.NoWaitFirst
	}
	applyRateLimit := func() bool {
		switch {
		case rateLimiting.NotAny:
			return false
		case *writes == 0:
			return !rateLimiting.NotFirst
		default:
			return true
		}
	}
	sendOnce := func() error {
		wrote, err := s.writeToNode(sendCtx, b, addr, waitForToken(), applyRateLimit())
		if wrote {
			*writes++
		}
		return err
	}
	if err := transactionSender(sendCtx, sendOnce, s.resendDelay, numTries); err != nil {
		return err
	}
	var final error
	select {
	case <-sendCtx.Done():
		final = sendCtx.Err()
	case <-time.After(s.resendDelay()):
		final = TransactionTimeout
	}
	return fmt.Errorf("after %v tries: %w", numTries, final)
}
// PingQueryInput sends a ping query to node using the settings in qi. If it
// succeeds and the reply names its sender, that node is recorded as having
// responded to a ping.
func (s *Server) PingQueryInput(node *net.UDPAddr, qi QueryInput) QueryResult {
	target := NewAddr(node)
	res := s.Query(context.TODO(), target, "ping", qi)
	if res.Err != nil {
		return res
	}
	if sender := res.Reply.SenderID(); sender != nil {
		s.NodeRespondedToPing(target, sender.Int160())
	}
	return res
}
// Ping sends a ping query to node using the default query settings.
func (s *Server) Ping(node *net.UDPAddr) QueryResult {
	var defaults QueryInput
	return s.PingQueryInput(node, defaults)
}
// Put adds a new item to node. You need to call Get first for a write token.
//
// The item is first stored locally; if that fails, the error is returned in
// QueryResult.Err and node is not contacted. Otherwise a BEP 44 "put" query is
// sent with token, and rl controls the rate-limiting of its sends.
func (s *Server) Put(ctx context.Context, node Addr, i bep44.Put, token string, rl QueryRateLimiting) QueryResult {
	if err := s.store.Put(i.ToItem()); err != nil {
		return QueryResult{
			Err: err,
		}
	}
	qi := QueryInput{
		MsgArgs: krpc.MsgArgs{
			Cas:   i.Cas,
			ID:    s.ID(),
			Salt:  i.Salt,
			Seq:   &i.Seq,
			Sig:   i.Sig,
			Token: token,
			V:     i.V,
		},
		// Honour the caller's rate-limiting, as announcePeer, FindNode, GetPeers and Get do.
		RateLimiting: rl,
	}
	if i.K != nil {
		qi.MsgArgs.K = *i.K
	}
	return s.Query(ctx, node, "put", qi)
}
// announcePeer sends an announce_peer query for infoHash to node, using a
// token previously obtained from it. Either a non-zero port or impliedPort must
// be given. A KRPC error reply is counted and surfaced in ret.Err. A successful
// announce increments SuccessfulOutboundAnnouncePeerQueries.
func (s *Server) announcePeer(
	ctx context.Context,
	node Addr, infoHash int160.T, port int, token string, impliedPort bool, rl QueryRateLimiting,
) (
	ret QueryResult,
) {
	if !impliedPort && port == 0 {
		ret.Err = errors.New("no port specified")
		return ret
	}
	args := krpc.MsgArgs{
		ImpliedPort: impliedPort,
		InfoHash:    infoHash.AsByteArray(),
		Port:        &port,
		Token:       token,
	}
	ret = s.Query(ctx, node, "announce_peer", QueryInput{MsgArgs: args, RateLimiting: rl})
	if ret.Err != nil {
		return ret
	}
	if krpcError := ret.Reply.Error(); krpcError != nil {
		announceErrors.Add(1)
		ret.Err = krpcError
		return ret
	}
	s.mu.Lock()
	s.stats.SuccessfulOutboundAnnouncePeerQueries++
	s.mu.Unlock()
	return ret
}
// FindNode sends a find_node query to addr. targetID is the node we're looking
// for, and the configured DefaultWant is sent as the "want" list. The Server
// makes use of some of the response fields.
func (s *Server) FindNode(addr Addr, targetID int160.T, rl QueryRateLimiting) (ret QueryResult) {
	input := QueryInput{RateLimiting: rl}
	input.MsgArgs.Target = targetID.AsByteArray()
	input.MsgArgs.Want = s.config.DefaultWant
	ret = s.Query(context.TODO(), addr, "find_node", input)
	return ret
}
// NumNodes returns how many nodes are currently in the routing table.
func (s *Server) NumNodes() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	count := s.numNodes()
	return count
}
// Nodes returns a snapshot of the routing table's non-bad nodes.
func (s *Server) Nodes() (nis []krpc.NodeInfo) {
	s.mu.Lock()
	defer s.mu.Unlock()
	nis = s.notBadNodes()
	return nis
}
// notBadNodes lists, as NodeInfo values, every routing table node for which
// nodeIsBad is false. Callers hold s.mu.
func (s *Server) notBadNodes() (nis []krpc.NodeInfo) {
	s.table.forNodes(func(n *node) bool {
		if !s.nodeIsBad(n) {
			nis = append(nis, krpc.NodeInfo{
				Addr: n.Addr.KRPC(),
				ID:   n.Id.AsByteArray(),
			})
		}
		return true
	})
	return nis
}
// Close stops the server's network activity. It marks the server closed and
// closes its socket on a separate goroutine, which should make the blocking
// read in serve fail. This is all that's required to clean up a Server.
func (s *Server) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.closed.Set()
	sock := s.socket
	go sock.Close()
}
// GetPeers sends a get_peers query for infoHash to addr, optionally asking for
// a BEP 33 scrape. The configured DefaultWant is sent as the "want" list. When
// a response dict arrives, whether it carried a token is tallied in expvars.
func (s *Server) GetPeers(ctx context.Context, addr Addr, infoHash int160.T, scrape bool, rl QueryRateLimiting) (ret QueryResult) {
	args := krpc.MsgArgs{
		InfoHash: infoHash.AsByteArray(),
		// TODO: Maybe IPv4-only Servers won't want IPv6 nodes?
		Want: s.config.DefaultWant,
	}
	if scrape {
		args.Scrape = 1
	}
	ret = s.Query(ctx, addr, "get_peers", QueryInput{MsgArgs: args, RateLimiting: rl})
	resp := ret.Reply.R
	if resp == nil {
		return ret
	}
	switch tok := resp.Token; {
	case tok == nil:
		expvars.Add("get_peers responses with no token", 1)
	case *tok == "":
		expvars.Add("get_peers responses with empty token", 1)
	default:
		expvars.Add("get_peers responses with token", 1)
	}
	return ret
}
// Get fetches item information for target from addr. If seq is set, the
// responder only includes V, K and Sig when its stored seq is greater than
// *seq. Both IPv4 and IPv6 nodes are requested. Get must also be used to obtain
// the write token needed for Put.
func (s *Server) Get(ctx context.Context, addr Addr, target bep44.Target, seq *int64, rl QueryRateLimiting) QueryResult {
	args := krpc.MsgArgs{
		ID:     s.ID(),
		Target: target,
		Seq:    seq,
		Want:   []krpc.Want{krpc.WantNodes, krpc.WantNodes6},
	}
	input := QueryInput{MsgArgs: args, RateLimiting: rl}
	return s.Query(ctx, addr, "get", input)
}
// closestGoodNodeInfos returns NodeInfos for the good nodes closest to
// targetID whose address passes filter, as selected by closestNodes with k.
func (s *Server) closestGoodNodeInfos(
	k int,
	targetID int160.T,
	filter func(krpc.NodeAddr) bool,
) (
	ret []krpc.NodeInfo,
) {
	eligible := func(n *node) bool {
		return s.IsGood(n) && filter(n.NodeInfo().Addr)
	}
	closest := s.closestNodes(k, targetID, eligible)
	for i := range closest {
		ret = append(ret, closest[i].NodeInfo())
	}
	return ret
}
// closestNodes delegates to the routing table's closestNodes with the same
// arguments.
func (s *Server) closestNodes(k int, target int160.T, filter func(*node) bool) []*node {
	tbl := &s.table
	return tbl.closestNodes(k, target, filter)
}
// TraversalStartingNodes returns the nodes a traversal should start from. That
// is every node in the routing table or, if the table is empty, the addresses
// from config.StartingNodes (without IDs). It returns an error if StartingNodes
// fails or if neither source yields any node.
func (s *Server) TraversalStartingNodes() (nodes []addrMaybeId, err error) {
	s.mu.RLock()
	s.table.forNodes(func(n *node) bool {
		nodes = append(nodes, addrMaybeId{Addr: n.Addr.KRPC(), Id: &n.Id})
		return true
	})
	s.mu.RUnlock()
	if len(nodes) != 0 {
		return nodes, nil
	}
	if getStarting := s.config.StartingNodes; getStarting != nil {
		// There seems to be floods on this call on occasion, which may cause a barrage of DNS
		// resolution attempts. This would require that we're unable to get replies because we can't
		// resolve, transmit or receive on the network. Nodes currently don't get expired from the
		// table, so once we have some entries, we should never have to fallback.
		s.logger().Levelf(log.Debug, "falling back on starting nodes")
		addrs, startErr := getStarting()
		if startErr != nil {
			return nil, errors.Wrap(startErr, "getting starting nodes")
		}
		for _, a := range addrs {
			nodes = append(nodes, addrMaybeId{Addr: a.KRPC()})
		}
	}
	if len(nodes) == 0 {
		return nodes, errors.New("no initial nodes")
	}
	return nodes, nil
}
// AddNodesFromFile reads nodes from fileName and adds each via AddNode,
// returning how many were accepted. Only a read failure is reported as an
// error; nodes that AddNode rejects are skipped.
func (s *Server) AddNodesFromFile(fileName string) (added int, err error) {
	ns, err := ReadNodesFromFile(fileName)
	if err != nil {
		return 0, err
	}
	for i := range ns {
		if s.AddNode(ns[i]) != nil {
			continue
		}
		added++
	}
	return added, nil
}
// logger returns the logger stored in the server's config.
func (s *Server) logger() log.Logger {
	cfg := &s.config
	return cfg.Logger
}
// PeerStore returns the configured peer store, which may be nil.
func (s *Server) PeerStore() peer_store.Interface {
	store := s.config.PeerStore
	return store
}
// getQuestionableNode returns the first node, in table iteration order, that
// IsQuestionable flags, or nil if there is none.
func (s *Server) getQuestionableNode() (ret *node) {
	s.table.forNodes(func(n *node) bool {
		if !s.IsQuestionable(n) {
			return true
		}
		ret = n
		return false
	})
	return ret
}
// shouldStopRefreshingBucket reports whether bucket bucketIndex needs no more
// refreshing: it holds exactly K nodes and EachNode finds none of them bad.
func (s *Server) shouldStopRefreshingBucket(bucketIndex int) bool {
	b := &s.table.buckets[bucketIndex]
	if b.Len() != s.table.K() {
		return false
	}
	// Stop if the bucket is full, and none of the nodes are bad.
	return b.EachNode(func(n *node) bool {
		return !s.nodeIsBad(n)
	})
}
// refreshBucket runs a find_node traversal towards a random ID in bucket
// bucketIndex until the bucket no longer needs refreshing
// (shouldStopRefreshingBucket) or the traversal stalls, and returns the
// traversal's stats.
//
// It must be called without s.mu held. It takes the read lock itself and
// releases it while waiting for the bucket to change or the traversal to
// stall, so that responses (which take the write lock) can update the table.
func (s *Server) refreshBucket(bucketIndex int) *traversal.Stats {
s.mu.RLock()
id := s.table.randomIdForBucket(bucketIndex)
op := traversal.Start(traversal.OperationInput{
Target: id.AsByteArray(),
Alpha: 3,
// Running this to completion with K matching the full-bucket size should result in a good,
// full bucket, since the Server will add nodes that respond to its table to replace the bad
// ones we're presumably refreshing. It might be possible to terminate the traversal early
// as soon as the bucket is good.
K: s.table.K(),
DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult {
res := s.FindNode(NewAddr(addr.UDP()), id, QueryRateLimiting{})
return res.TraversalQueryResult(addr)
},
NodeFilter: s.TraversalNodeFilter,
})
// The read lock is held whenever the loop below exits, so release it here and
// then stop the traversal and wait for it to finish.
defer func() {
s.mu.RUnlock()
op.Stop()
<-op.Stopped()
}()
b := &s.table.buckets[bucketIndex]
wait:
for {
if s.shouldStopRefreshingBucket(bucketIndex) {
break wait
}
// Feed the current non-bad nodes to the traversal on every pass.
op.AddNodes(types.AddrMaybeIdSliceFromNodeInfoSlice(s.notBadNodes()))
// Take the change signal before unlocking so a change made while we are
// unlocked is not missed.
bucketChanged := b.changed.Signaled()
s.mu.RUnlock()
select {
case <-op.Stalled():
s.mu.RLock()
break wait
case <-bucketChanged:
}
s.mu.RLock()
}
return op.Stats()
}
// shouldBootstrap reports whether a bootstrap is due: there has never been one,
// or the last one was more than 30 minutes ago. Callers hold s.mu.
func (s *Server) shouldBootstrap() bool {
	if s.lastBootstrap.IsZero() {
		return true
	}
	return time.Since(s.lastBootstrap) > 30*time.Minute
}
// shouldBootstrapUnlocked is shouldBootstrap for callers that do not hold
// s.mu; it takes the read lock itself.
func (s *Server) shouldBootstrapUnlocked() (due bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	due = s.shouldBootstrap()
	return due
}
// pingQuestionableNodesInBucket pings, concurrently, every node in bucket
// bucketIndex that IsQuestionable flags, and waits for all pings to finish.
// Failures are only logged; questionableNodePing itself marks nodes that
// don't answer.
//
// It must be called with s.mu read-locked. It releases the lock while waiting
// and reacquires it before returning. The pings need the write lock (Query
// takes s.mu.Lock), so waiting without releasing it would deadlock.
func (s *Server) pingQuestionableNodesInBucket(bucketIndex int) {
b := &s.table.buckets[bucketIndex]
var wg sync.WaitGroup
b.EachNode(func(n *node) bool {
if s.IsQuestionable(n) {
wg.Add(1)
go func() {
defer wg.Done()
err := s.questionableNodePing(context.TODO(), n.Addr, n.Id.AsByteArray()).Err
if err != nil {
log.Printf("error pinging questionable node in bucket %v: %v", bucketIndex, err)
}
}()
}
return true
})
s.mu.RUnlock()
wg.Wait()
s.mu.RLock()
}
// A routine that maintains the Server's routing table, by pinging questionable nodes, and
// refreshing buckets. This should be invoked on a running Server when the caller is satisfied with
// having set it up. It is not necessary to explicitly Bootstrap the Server once this routine has
// started.
//
// Each pass does the following:
//   - bootstraps if shouldBootstrap says one is due;
//   - walks the buckets in index order, pinging questionable nodes and refreshing any bucket that
//     is not yet full of non-bad nodes;
//   - stops at the first bucket that still isn't full after a refresh;
//   - sleeps a minute.
//
// It returns once the server is closed.
func (s *Server) TableMaintainer() {
for {
if s.shouldBootstrapUnlocked() {
stats, err := s.Bootstrap()
if err != nil {
log.Printf("error bootstrapping during bucket refresh: %v", err)
}
// NOTE(review): stats is logged even when Bootstrap failed; it may be empty/nil then.
log.Printf("bucket refresh bootstrap stats: %v", stats)
}
// Held across the bucket walk; pingQuestionableNodesInBucket and the refresh below release
// and retake it as needed.
s.mu.RLock()
for i := range s.table.buckets {
s.pingQuestionableNodesInBucket(i)
// if time.Since(b.lastChanged) < 15*time.Minute {
// continue
// }
if s.shouldStopRefreshingBucket(i) {
continue
}
s.logger().Levelf(log.Info, "refreshing bucket %v", i)
// refreshBucket takes the lock itself.
s.mu.RUnlock()
stats := s.refreshBucket(i)
s.logger().Levelf(log.Info, "finished refreshing bucket %v: %v", i, stats)
s.mu.RLock()
if !s.shouldStopRefreshingBucket(i) {
// Presumably we couldn't fill the bucket anymore, so assume we're as deep in the
// available node space as we can go.
break
}
}
s.mu.RUnlock()
select {
case <-s.closed.LockedChan(&s.mu):
return
case <-time.After(time.Minute):
}
}
}
// questionableNodePing pings the node (addr, id) with up to 3 send attempts,
// waiting for rate-limit tokens on retries. A reply with a response dict
// records the node as having responded. Otherwise the node's table entry, if
// present, is flagged with failedLastQuestionablePing, which makes nodeErr
// treat it as bad.
func (s *Server) questionableNodePing(ctx context.Context, addr Addr, id krpc.ID) QueryResult {
	res := s.Query(ctx, addr, "ping", QueryInput{
		NumTries: 3,
		RateLimiting: QueryRateLimiting{
			WaitOnRetries: true,
		},
	})
	if res.Err != nil || res.Reply.R == nil {
		s.mu.Lock()
		s.updateNode(addr, &id, false, func(n *node) {
			n.failedLastQuestionablePing = true
		})
		s.mu.Unlock()
		return res
	}
	s.NodeRespondedToPing(addr, res.Reply.R.ID.Int160())
	return res
}
// TraversalNodeFilter reports whether a traversal should contact node, judged
// by its address and, when known, its ID. The address must pass validNodeAddr
// and must not be blocked. A known ID must also be secure for that IP unless
// NoSecurity is set.
func (s *Server) TraversalNodeFilter(node addrMaybeId) bool {
	switch {
	case !validNodeAddr(node.Addr.UDP()):
		return false
	case s.ipBlocked(node.Addr.IP):
		return false
	case node.Id == nil:
		return true
	}
	return s.config.NoSecurity || NodeIdSecure(node.Id.AsByteArray(), node.Addr.IP)
}
// validNodeAddr reports whether addr could plausibly be a reachable DHT node.
//
// A nil address is rejected. For UDP addresses, port 0 and IPv4 addresses in
// 0.0.0.0/8 are rejected: that block is "this network" (RFC 1122 §3.2.1.3) and
// is never a valid destination. Other address types are accepted, since
// nothing is known about what doesn't work for them; previously they caused a
// panic.
func validNodeAddr(addr net.Addr) bool {
	if addr == nil {
		return false
	}
	ua, ok := addr.(*net.UDPAddr)
	if !ok {
		// At least for UDP addresses, we know what doesn't work.
		return true
	}
	if ua == nil || ua.Port == 0 {
		return false
	}
	if ip4 := ua.IP.To4(); ip4 != nil && ip4[0] == 0 {
		return false
	}
	return true
}
// func (s *Server) refreshBucket(bucketIndex int) {
// targetId := s.table.randomIdForBucket(bucketIndex)
// }