mirror of https://github.com/status-im/op-geth.git
p2p/discover: implement node bonding
This a fix for an attack vector where the discovery protocol could be used to amplify traffic in a DDOS attack. A malicious actor would send a findnode request with the IP address and UDP port of the target as the source address. The recipient of the findnode packet would then send a neighbors packet (which is 16x the size of findnode) to the victim. Our solution is to require a 'bond' with the sender of findnode. If no bond exists, the findnode packet is not processed. A bond between nodes α and β is created when α replies to a ping from β. This (initial) version of the bonding implementation might still be vulnerable against replay attacks during the expiration time window. We will add stricter source address validation later.
This commit is contained in:
parent
92928309b2
commit
de7af720d6
|
@ -13,6 +13,8 @@ import (
|
|||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
|
@ -30,7 +32,8 @@ type Node struct {
|
|||
DiscPort int // UDP listening port for discovery protocol
|
||||
TCPPort int // TCP listening port for RLPx
|
||||
|
||||
active time.Time
|
||||
// this must be set/read using atomic load and store.
|
||||
activeStamp int64
|
||||
}
|
||||
|
||||
func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
||||
|
@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
|||
IP: addr.IP,
|
||||
DiscPort: addr.Port,
|
||||
TCPPort: addr.Port,
|
||||
active: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
|
|||
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
|
||||
}
|
||||
|
||||
func (n *Node) bumpActive() {
|
||||
stamp := time.Now().Unix()
|
||||
atomic.StoreInt64(&n.activeStamp, stamp)
|
||||
}
|
||||
|
||||
func (n *Node) active() time.Time {
|
||||
stamp := atomic.LoadInt64(&n.activeStamp)
|
||||
return time.Unix(stamp, 0)
|
||||
}
|
||||
|
||||
func (n *Node) addr() *net.UDPAddr {
|
||||
return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
|
||||
}
|
||||
|
||||
// The string representation of a Node is a URL.
|
||||
// Please see ParseNode for a description of the format.
|
||||
func (n *Node) String() string {
|
||||
|
@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
|
|||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// nodeDB stores all nodes we know about.
|
||||
type nodeDB struct {
|
||||
mu sync.RWMutex
|
||||
byID map[NodeID]*Node
|
||||
}
|
||||
|
||||
func (db *nodeDB) get(id NodeID) *Node {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
return db.byID[id]
|
||||
}
|
||||
|
||||
func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
if db.byID == nil {
|
||||
db.byID = make(map[NodeID]*Node)
|
||||
}
|
||||
n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
|
||||
db.byID[n.ID] = n
|
||||
return n
|
||||
}
|
||||
|
|
|
@ -14,9 +14,10 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
alpha = 3 // Kademlia concurrency factor
|
||||
bucketSize = 16 // Kademlia bucket size
|
||||
nBuckets = nodeIDBits + 1 // Number of buckets
|
||||
alpha = 3 // Kademlia concurrency factor
|
||||
bucketSize = 16 // Kademlia bucket size
|
||||
nBuckets = nodeIDBits + 1 // Number of buckets
|
||||
maxBondingPingPongs = 10
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
|
@ -24,27 +25,50 @@ type Table struct {
|
|||
buckets [nBuckets]*bucket // index of known nodes by distance
|
||||
nursery []*Node // bootstrap nodes
|
||||
|
||||
bondmu sync.Mutex
|
||||
bonding map[NodeID]*bondproc
|
||||
bondslots chan struct{} // limits total number of active bonding processes
|
||||
|
||||
net transport
|
||||
self *Node // metadata of the local node
|
||||
db *nodeDB
|
||||
}
|
||||
|
||||
type bondproc struct {
|
||||
err error
|
||||
n *Node
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// transport is implemented by the UDP transport.
|
||||
// it is an interface so we can test without opening lots of UDP
|
||||
// sockets and without generating a private key.
|
||||
type transport interface {
|
||||
ping(*Node) error
|
||||
findnode(e *Node, target NodeID) ([]*Node, error)
|
||||
ping(NodeID, *net.UDPAddr) error
|
||||
waitping(NodeID) error
|
||||
findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
|
||||
close()
|
||||
}
|
||||
|
||||
// bucket contains nodes, ordered by their last activity.
|
||||
// the entry that was most recently active is the last element
|
||||
// in entries.
|
||||
type bucket struct {
|
||||
lastLookup time.Time
|
||||
entries []*Node
|
||||
}
|
||||
|
||||
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
|
||||
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
|
||||
tab := &Table{
|
||||
net: t,
|
||||
db: new(nodeDB),
|
||||
self: newNode(ourID, ourAddr),
|
||||
bonding: make(map[NodeID]*bondproc),
|
||||
bondslots: make(chan struct{}, maxBondingPingPongs),
|
||||
}
|
||||
for i := 0; i < cap(tab.bondslots); i++ {
|
||||
tab.bondslots <- struct{}{}
|
||||
}
|
||||
for i := range tab.buckets {
|
||||
tab.buckets[i] = new(bucket)
|
||||
}
|
||||
|
@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
|
|||
asked[n.ID] = true
|
||||
pendingQueries++
|
||||
go func() {
|
||||
result, _ := tab.net.findnode(n, target)
|
||||
reply <- result
|
||||
r, _ := tab.net.findnode(n.ID, n.addr(), target)
|
||||
reply <- tab.bondall(r)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
|
|||
// we have asked all closest nodes, stop the search
|
||||
break
|
||||
}
|
||||
|
||||
// wait for the next reply
|
||||
for _, n := range <-reply {
|
||||
cn := n
|
||||
if !seen[n.ID] {
|
||||
if n != nil && !seen[n.ID] {
|
||||
seen[n.ID] = true
|
||||
result.push(cn, bucketSize)
|
||||
result.push(n, bucketSize)
|
||||
}
|
||||
}
|
||||
pendingQueries--
|
||||
|
@ -145,8 +167,9 @@ func (tab *Table) refresh() {
|
|||
result := tab.Lookup(randomID(tab.self.ID, ld))
|
||||
if len(result) == 0 {
|
||||
// bootstrap the table with a self lookup
|
||||
all := tab.bondall(tab.nursery)
|
||||
tab.mutex.Lock()
|
||||
tab.add(tab.nursery)
|
||||
tab.add(all)
|
||||
tab.mutex.Unlock()
|
||||
tab.Lookup(tab.self.ID)
|
||||
// TODO: the Kademlia paper says that we're supposed to perform
|
||||
|
@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
|
|||
return n
|
||||
}
|
||||
|
||||
// bumpOrAdd updates the activity timestamp for the given node and
|
||||
// attempts to insert the node into a bucket. The returned Node might
|
||||
// not be part of the table. The caller must hold tab.mutex.
|
||||
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
|
||||
b := tab.buckets[logdist(tab.self.ID, node)]
|
||||
if n = b.bump(node); n == nil {
|
||||
n = newNode(node, from)
|
||||
if len(b.entries) == bucketSize {
|
||||
tab.pingReplace(n, b)
|
||||
} else {
|
||||
b.entries = append(b.entries, n)
|
||||
// bondall bonds with all given nodes concurrently and returns
|
||||
// those nodes for which bonding has probably succeeded.
|
||||
func (tab *Table) bondall(nodes []*Node) (result []*Node) {
|
||||
rc := make(chan *Node, len(nodes))
|
||||
for i := range nodes {
|
||||
go func(n *Node) {
|
||||
nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
|
||||
rc <- nn
|
||||
}(nodes[i])
|
||||
}
|
||||
for _ = range nodes {
|
||||
if n := <-rc; n != nil {
|
||||
result = append(result, n)
|
||||
}
|
||||
}
|
||||
return n
|
||||
return result
|
||||
}
|
||||
|
||||
func (tab *Table) pingReplace(n *Node, b *bucket) {
|
||||
old := b.entries[bucketSize-1]
|
||||
go func() {
|
||||
if err := tab.net.ping(old); err == nil {
|
||||
// it responded, we don't need to replace it.
|
||||
// bond ensures the local node has a bond with the given remote node.
|
||||
// It also attempts to insert the node into the table if bonding succeeds.
|
||||
// The caller must not hold tab.mutex.
|
||||
//
|
||||
// A bond is must be established before sending findnode requests.
|
||||
// Both sides must have completed a ping/pong exchange for a bond to
|
||||
// exist. The total number of active bonding processes is limited in
|
||||
// order to restrain network use.
|
||||
//
|
||||
// bond is meant to operate idempotently in that bonding with a remote
|
||||
// node which still remembers a previously established bond will work.
|
||||
// The remote node will simply not send a ping back, causing waitping
|
||||
// to time out.
|
||||
//
|
||||
// If pinged is true, the remote node has just pinged us and one half
|
||||
// of the process can be skipped.
|
||||
func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
|
||||
var n *Node
|
||||
if n = tab.db.get(id); n == nil {
|
||||
tab.bondmu.Lock()
|
||||
w := tab.bonding[id]
|
||||
if w != nil {
|
||||
// Wait for an existing bonding process to complete.
|
||||
tab.bondmu.Unlock()
|
||||
<-w.done
|
||||
} else {
|
||||
// Register a new bonding process.
|
||||
w = &bondproc{done: make(chan struct{})}
|
||||
tab.bonding[id] = w
|
||||
tab.bondmu.Unlock()
|
||||
// Do the ping/pong. The result goes into w.
|
||||
tab.pingpong(w, pinged, id, addr, tcpPort)
|
||||
// Unregister the process after it's done.
|
||||
tab.bondmu.Lock()
|
||||
delete(tab.bonding, id)
|
||||
tab.bondmu.Unlock()
|
||||
}
|
||||
n = w.n
|
||||
if w.err != nil {
|
||||
return nil, w.err
|
||||
}
|
||||
}
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
|
||||
tab.pingreplace(n, b)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
|
||||
<-tab.bondslots
|
||||
defer func() { tab.bondslots <- struct{}{} }()
|
||||
if w.err = tab.net.ping(id, addr); w.err != nil {
|
||||
close(w.done)
|
||||
return
|
||||
}
|
||||
if !pinged {
|
||||
// Give the remote node a chance to ping us before we start
|
||||
// sending findnode requests. If they still remember us,
|
||||
// waitping will simply time out.
|
||||
tab.net.waitping(id)
|
||||
}
|
||||
w.n = tab.db.add(id, addr, tcpPort)
|
||||
close(w.done)
|
||||
}
|
||||
|
||||
func (tab *Table) pingreplace(new *Node, b *bucket) {
|
||||
if len(b.entries) == bucketSize {
|
||||
oldest := b.entries[bucketSize-1]
|
||||
if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
|
||||
// The node responded, we don't need to replace it.
|
||||
return
|
||||
}
|
||||
// it didn't respond, replace the node if it is still the oldest node.
|
||||
tab.mutex.Lock()
|
||||
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
|
||||
// slide down other entries and put the new one in front.
|
||||
// TODO: insert in correct position to keep the order
|
||||
copy(b.entries[1:], b.entries)
|
||||
b.entries[0] = n
|
||||
}
|
||||
tab.mutex.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// bump updates the activity timestamp for the given node.
|
||||
// The caller must hold tab.mutex.
|
||||
func (tab *Table) bump(node NodeID) {
|
||||
tab.buckets[logdist(tab.self.ID, node)].bump(node)
|
||||
} else {
|
||||
// Add a slot at the end so the last entry doesn't
|
||||
// fall off when adding the new node.
|
||||
b.entries = append(b.entries, nil)
|
||||
}
|
||||
copy(b.entries[1:], b.entries)
|
||||
b.entries[0] = new
|
||||
}
|
||||
|
||||
// add puts the entries into the table if their corresponding
|
||||
|
@ -240,17 +323,17 @@ outer:
|
|||
}
|
||||
}
|
||||
|
||||
func (b *bucket) bump(id NodeID) *Node {
|
||||
for i, n := range b.entries {
|
||||
if n.ID == id {
|
||||
n.active = time.Now()
|
||||
func (b *bucket) bump(n *Node) bool {
|
||||
for i := range b.entries {
|
||||
if b.entries[i].ID == n.ID {
|
||||
n.bumpActive()
|
||||
// move it to the front
|
||||
copy(b.entries[1:], b.entries[:i+1])
|
||||
b.entries[0] = n
|
||||
return n
|
||||
return true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
// nodesByDistance is a list of nodes, ordered by
|
||||
|
|
|
@ -2,79 +2,68 @@ package discover
|
|||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
)
|
||||
|
||||
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||
for i := 1; i < len(tab.buckets); i++ {
|
||||
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
|
||||
}
|
||||
for i, b := range tab.buckets {
|
||||
if i > 0 && len(b.entries) != 1 {
|
||||
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
|
||||
func TestTable_pingReplace(t *testing.T) {
|
||||
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
|
||||
transport := newPingRecorder()
|
||||
tab := newTable(transport, NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
pingSender := randomID(tab.self.ID, 200)
|
||||
|
||||
// this gotPing should replace the last node
|
||||
// if the last node is not responding.
|
||||
transport.responding[last.ID] = lastInBucketIsResponding
|
||||
transport.responding[pingSender] = newNodeIsResponding
|
||||
tab.bond(true, pingSender, &net.UDPAddr{}, 0)
|
||||
|
||||
// first ping goes to sender (bonding pingback)
|
||||
if !transport.pinged[pingSender] {
|
||||
t.Error("table did not ping back sender")
|
||||
}
|
||||
if newNodeIsResponding {
|
||||
// second ping goes to oldest node in bucket
|
||||
// to see whether it is still alive.
|
||||
if !transport.pinged[last.ID] {
|
||||
t.Error("table did not ping last node in bucket")
|
||||
}
|
||||
}
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after gotPing: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
|
||||
if lastInBucketIsResponding || !newNodeIsResponding {
|
||||
if !contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was removed")
|
||||
}
|
||||
if contains(tab.buckets[200].entries, pingSender) {
|
||||
t.Error("new entry was added")
|
||||
}
|
||||
} else {
|
||||
if contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was not removed")
|
||||
}
|
||||
if !contains(tab.buckets[200].entries, pingSender) {
|
||||
t.Error("new entry was not added")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_bumpOrAddPingReplace(t *testing.T) {
|
||||
pingC := make(pingC)
|
||||
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
|
||||
// this bumpOrAdd should not replace the last node
|
||||
// because the node replies to ping.
|
||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||
|
||||
pinged := <-pingC
|
||||
if pinged != last.ID {
|
||||
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
|
||||
}
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
if !contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was removed")
|
||||
}
|
||||
if contains(tab.buckets[200].entries, new.ID) {
|
||||
t.Error("new entry was added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
|
||||
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
|
||||
// this bumpOrAdd should replace the last node
|
||||
// because the node does not reply to ping.
|
||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||
|
||||
// wait for async bucket update. damn. this needs to go away.
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
if contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was not removed")
|
||||
}
|
||||
if !contains(tab.buckets[200].entries, new.ID) {
|
||||
t.Error("new entry was not added")
|
||||
}
|
||||
doit(true, true)
|
||||
doit(false, true)
|
||||
doit(false, true)
|
||||
doit(false, false)
|
||||
}
|
||||
|
||||
func fillBucket(tab *Table, ld int) (last *Node) {
|
||||
|
@ -85,44 +74,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
|
|||
return b.entries[bucketSize-1]
|
||||
}
|
||||
|
||||
type pingC chan NodeID
|
||||
type pingRecorder struct{ responding, pinged map[NodeID]bool }
|
||||
|
||||
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
func newPingRecorder() *pingRecorder {
|
||||
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
|
||||
}
|
||||
|
||||
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
panic("findnode called on pingRecorder")
|
||||
}
|
||||
func (t pingC) close() {
|
||||
func (t *pingRecorder) close() {
|
||||
panic("close called on pingRecorder")
|
||||
}
|
||||
func (t pingC) ping(n *Node) error {
|
||||
if t == nil {
|
||||
return errTimeout
|
||||
}
|
||||
t <- n.ID
|
||||
return nil
|
||||
func (t *pingRecorder) waitping(from NodeID) error {
|
||||
return nil // remote always pings
|
||||
}
|
||||
|
||||
func TestTable_bump(t *testing.T) {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||
|
||||
// add an old entry and two recent ones
|
||||
oldactive := time.Now().Add(-2 * time.Minute)
|
||||
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
|
||||
others := []*Node{
|
||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||
}
|
||||
tab.add(append(others, old))
|
||||
if tab.buckets[200].entries[0] == old {
|
||||
t.Fatal("old entry is at front of bucket")
|
||||
}
|
||||
|
||||
// bumping the old entry should move it to the front
|
||||
tab.bump(old.ID)
|
||||
if old.active == oldactive {
|
||||
t.Error("activity timestamp not updated")
|
||||
}
|
||||
if tab.buckets[200].entries[0] != old {
|
||||
t.Errorf("bumped entry did not move to the front of bucket")
|
||||
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||
t.pinged[toid] = true
|
||||
if t.responding[toid] {
|
||||
return nil
|
||||
} else {
|
||||
return errTimeout
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -210,7 +182,7 @@ func TestTable_Lookup(t *testing.T) {
|
|||
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
|
||||
}
|
||||
// seed table with initial node (otherwise lookup will terminate immediately)
|
||||
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
|
||||
tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})
|
||||
|
||||
results := tab.Lookup(target)
|
||||
t.Logf("results:")
|
||||
|
@ -238,16 +210,16 @@ type findnodeOracle struct {
|
|||
target NodeID
|
||||
}
|
||||
|
||||
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
t.t.Logf("findnode query at dist %d", n.DiscPort)
|
||||
func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
t.t.Logf("findnode query at dist %d", toaddr.Port)
|
||||
// current log distance is encoded in port number
|
||||
var result []*Node
|
||||
switch n.DiscPort {
|
||||
switch toaddr.Port {
|
||||
case 0:
|
||||
panic("query to node at distance 0")
|
||||
default:
|
||||
// TODO: add more randomness to distances
|
||||
next := n.DiscPort - 1
|
||||
next := toaddr.Port - 1
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
|
||||
}
|
||||
|
@ -255,11 +227,9 @@ func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (t findnodeOracle) close() {}
|
||||
|
||||
func (t findnodeOracle) ping(n *Node) error {
|
||||
return errors.New("ping is not supported by this transport")
|
||||
}
|
||||
func (t findnodeOracle) close() {}
|
||||
func (t findnodeOracle) waitping(from NodeID) error { return nil }
|
||||
func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }
|
||||
|
||||
func hasDuplicates(slice []*Node) bool {
|
||||
seen := make(map[NodeID]bool)
|
||||
|
|
|
@ -20,12 +20,14 @@ const Version = 3
|
|||
|
||||
// Errors
|
||||
var (
|
||||
errPacketTooSmall = errors.New("too small")
|
||||
errBadHash = errors.New("bad hash")
|
||||
errExpired = errors.New("expired")
|
||||
errBadVersion = errors.New("version mismatch")
|
||||
errTimeout = errors.New("RPC timeout")
|
||||
errClosed = errors.New("socket closed")
|
||||
errPacketTooSmall = errors.New("too small")
|
||||
errBadHash = errors.New("bad hash")
|
||||
errExpired = errors.New("expired")
|
||||
errBadVersion = errors.New("version mismatch")
|
||||
errUnsolicitedReply = errors.New("unsolicited reply")
|
||||
errUnknownNode = errors.New("unknown node")
|
||||
errTimeout = errors.New("RPC timeout")
|
||||
errClosed = errors.New("socket closed")
|
||||
)
|
||||
|
||||
// Timeouts
|
||||
|
@ -80,14 +82,27 @@ type rpcNode struct {
|
|||
ID NodeID
|
||||
}
|
||||
|
||||
type packet interface {
|
||||
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||
}
|
||||
|
||||
type conn interface {
|
||||
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
|
||||
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
|
||||
// udp implements the RPC protocol.
|
||||
type udp struct {
|
||||
conn *net.UDPConn
|
||||
priv *ecdsa.PrivateKey
|
||||
conn conn
|
||||
priv *ecdsa.PrivateKey
|
||||
|
||||
addpending chan *pending
|
||||
replies chan reply
|
||||
closing chan struct{}
|
||||
nat nat.Interface
|
||||
gotreply chan reply
|
||||
|
||||
closing chan struct{}
|
||||
nat nat.Interface
|
||||
|
||||
*Table
|
||||
}
|
||||
|
@ -124,6 +139,9 @@ type reply struct {
|
|||
from NodeID
|
||||
ptype byte
|
||||
data interface{}
|
||||
// loop indicates whether there was
|
||||
// a matching request by sending on this channel.
|
||||
matched chan<- bool
|
||||
}
|
||||
|
||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||
|
@ -136,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tab, _ := newUDP(priv, conn, natm)
|
||||
log.Infoln("Listening,", tab.self)
|
||||
return tab, nil
|
||||
}
|
||||
|
||||
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
|
||||
udp := &udp{
|
||||
conn: conn,
|
||||
conn: c,
|
||||
priv: priv,
|
||||
closing: make(chan struct{}),
|
||||
gotreply: make(chan reply),
|
||||
addpending: make(chan *pending),
|
||||
replies: make(chan reply),
|
||||
}
|
||||
|
||||
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
realaddr := c.LocalAddr().(*net.UDPAddr)
|
||||
if natm != nil {
|
||||
if !realaddr.IP.IsLoopback() {
|
||||
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
||||
|
@ -155,11 +178,9 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table
|
|||
}
|
||||
}
|
||||
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
|
||||
|
||||
go udp.loop()
|
||||
go udp.readLoop()
|
||||
log.Infoln("Listening, ", udp.self)
|
||||
return udp.Table, nil
|
||||
return udp.Table, udp
|
||||
}
|
||||
|
||||
func (t *udp) close() {
|
||||
|
@ -169,10 +190,10 @@ func (t *udp) close() {
|
|||
}
|
||||
|
||||
// ping sends a ping message to the given node and waits for a reply.
|
||||
func (t *udp) ping(e *Node) error {
|
||||
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
||||
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
|
||||
t.send(e, pingPacket, ping{
|
||||
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
|
||||
t.send(toaddr, pingPacket, ping{
|
||||
Version: Version,
|
||||
IP: t.self.IP.String(),
|
||||
Port: uint16(t.self.TCPPort),
|
||||
|
@ -181,12 +202,16 @@ func (t *udp) ping(e *Node) error {
|
|||
return <-errc
|
||||
}
|
||||
|
||||
func (t *udp) waitping(from NodeID) error {
|
||||
return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
|
||||
}
|
||||
|
||||
// findnode sends a findnode request to the given node and waits until
|
||||
// the node has sent up to k neighbors.
|
||||
func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
||||
func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
nodes := make([]*Node, 0, bucketSize)
|
||||
nreceived := 0
|
||||
errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
|
||||
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
|
||||
reply := r.(*neighbors)
|
||||
for _, n := range reply.Nodes {
|
||||
nreceived++
|
||||
|
@ -196,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
|||
}
|
||||
return nreceived >= bucketSize
|
||||
})
|
||||
|
||||
t.send(to, findnodePacket, findnode{
|
||||
t.send(toaddr, findnodePacket, findnode{
|
||||
Target: target,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
|
@ -219,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-
|
|||
return ch
|
||||
}
|
||||
|
||||
func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
|
||||
matched := make(chan bool)
|
||||
select {
|
||||
case t.gotreply <- reply{from, ptype, req, matched}:
|
||||
// loop will handle it
|
||||
return <-matched
|
||||
case <-t.closing:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// loop runs in its own goroutin. it keeps track of
|
||||
// the refresh timer and the pending reply queue.
|
||||
func (t *udp) loop() {
|
||||
|
@ -249,6 +284,7 @@ func (t *udp) loop() {
|
|||
for _, p := range pending {
|
||||
p.errc <- errClosed
|
||||
}
|
||||
pending = nil
|
||||
return
|
||||
|
||||
case p := <-t.addpending:
|
||||
|
@ -256,18 +292,21 @@ func (t *udp) loop() {
|
|||
pending = append(pending, p)
|
||||
rearmTimeout()
|
||||
|
||||
case reply := <-t.replies:
|
||||
// run matching callbacks, remove if they return false.
|
||||
case r := <-t.gotreply:
|
||||
var matched bool
|
||||
for i := 0; i < len(pending); i++ {
|
||||
p := pending[i]
|
||||
if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
|
||||
p.errc <- nil
|
||||
copy(pending[i:], pending[i+1:])
|
||||
pending = pending[:len(pending)-1]
|
||||
i--
|
||||
if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
|
||||
matched = true
|
||||
if p.callback(r.data) {
|
||||
// callback indicates the request is done, remove it.
|
||||
p.errc <- nil
|
||||
copy(pending[i:], pending[i+1:])
|
||||
pending = pending[:len(pending)-1]
|
||||
i--
|
||||
}
|
||||
}
|
||||
}
|
||||
rearmTimeout()
|
||||
r.matched <- matched
|
||||
|
||||
case now := <-timeout.C:
|
||||
// notify and remove callbacks whose deadline is in the past.
|
||||
|
@ -292,28 +331,11 @@ const (
|
|||
|
||||
var headSpace = make([]byte, headSize)
|
||||
|
||||
func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
||||
b := new(bytes.Buffer)
|
||||
b.Write(headSpace)
|
||||
b.WriteByte(ptype)
|
||||
if err := rlp.Encode(b, req); err != nil {
|
||||
log.Errorln("error encoding packet:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
packet := b.Bytes()
|
||||
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
|
||||
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
|
||||
packet, err := encodePacket(t.priv, ptype, req)
|
||||
if err != nil {
|
||||
log.Errorln("could not sign packet:", err)
|
||||
return err
|
||||
}
|
||||
copy(packet[macSize:], sig)
|
||||
// add the hash to the front. Note: this doesn't protect the
|
||||
// packet in any way. Our public key will be part of this hash in
|
||||
// the future.
|
||||
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||
|
||||
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
|
||||
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
|
||||
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
|
||||
log.DebugDetailln("UDP send failed:", err)
|
||||
|
@ -321,6 +343,28 @@ func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
|
||||
b := new(bytes.Buffer)
|
||||
b.Write(headSpace)
|
||||
b.WriteByte(ptype)
|
||||
if err := rlp.Encode(b, req); err != nil {
|
||||
log.Errorln("error encoding packet:", err)
|
||||
return nil, err
|
||||
}
|
||||
packet := b.Bytes()
|
||||
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
|
||||
if err != nil {
|
||||
log.Errorln("could not sign packet:", err)
|
||||
return nil, err
|
||||
}
|
||||
copy(packet[macSize:], sig)
|
||||
// add the hash to the front. Note: this doesn't protect the
|
||||
// packet in any way. Our public key will be part of this hash in
|
||||
// The future.
|
||||
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||
func (t *udp) readLoop() {
|
||||
defer t.conn.Close()
|
||||
|
@ -330,29 +374,34 @@ func (t *udp) readLoop() {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := t.packetIn(from, buf[:nbytes]); err != nil {
|
||||
packet, fromID, hash, err := decodePacket(buf[:nbytes])
|
||||
if err != nil {
|
||||
log.Debugf("Bad packet from %v: %v\n", from, err)
|
||||
continue
|
||||
}
|
||||
log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
|
||||
go func() {
|
||||
if err := packet.handle(t, from, fromID, hash); err != nil {
|
||||
log.Debugf("error handling %T from %v: %v", packet, from, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
||||
func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
|
||||
if len(buf) < headSize+1 {
|
||||
return errPacketTooSmall
|
||||
return nil, NodeID{}, nil, errPacketTooSmall
|
||||
}
|
||||
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
|
||||
shouldhash := crypto.Sha3(buf[macSize:])
|
||||
if !bytes.Equal(hash, shouldhash) {
|
||||
return errBadHash
|
||||
return nil, NodeID{}, nil, errBadHash
|
||||
}
|
||||
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req interface {
|
||||
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||
return nil, NodeID{}, hash, err
|
||||
}
|
||||
var req packet
|
||||
switch ptype := sigdata[0]; ptype {
|
||||
case pingPacket:
|
||||
req = new(ping)
|
||||
|
@ -363,13 +412,10 @@ func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
|||
case neighborsPacket:
|
||||
req = new(neighbors)
|
||||
default:
|
||||
return fmt.Errorf("unknown type: %d", ptype)
|
||||
return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
|
||||
}
|
||||
if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
|
||||
return err
|
||||
}
|
||||
log.DebugDetailf("<<< %v %T %v\n", from, req, req)
|
||||
return req.handle(t, from, fromID, hash)
|
||||
err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
|
||||
return req, fromID, hash, err
|
||||
}
|
||||
|
||||
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||
|
@ -379,18 +425,14 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
|
|||
if req.Version != Version {
|
||||
return errBadVersion
|
||||
}
|
||||
t.mutex.Lock()
|
||||
// Note: we're ignoring the provided IP address right now
|
||||
n := t.bumpOrAdd(fromID, from)
|
||||
if req.Port != 0 {
|
||||
n.TCPPort = int(req.Port)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.send(n, pongPacket, pong{
|
||||
t.send(from, pongPacket, pong{
|
||||
ReplyTok: mac,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
if !t.handleReply(fromID, pingPacket, req) {
|
||||
// Note: we're ignoring the provided IP address right now
|
||||
t.bond(true, fromID, from, req.Port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -398,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
|
|||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.bump(fromID)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.replies <- reply{fromID, pongPacket, req}
|
||||
if !t.handleReply(fromID, pongPacket, req) {
|
||||
return errUnsolicitedReply
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -410,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
|
|||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
if t.db.get(fromID) == nil {
|
||||
// No bond exists, we don't process the packet. This prevents
|
||||
// an attack vector where the discovery protocol could be used
|
||||
// to amplify traffic in a DDOS attack. A malicious actor
|
||||
// would send a findnode request with the IP address and UDP
|
||||
// port of the target as the source address. The recipient of
|
||||
// the findnode packet would then send a neighbors packet
|
||||
// (which is a much bigger packet than findnode) to the victim.
|
||||
return errUnknownNode
|
||||
}
|
||||
t.mutex.Lock()
|
||||
e := t.bumpOrAdd(fromID, from)
|
||||
closest := t.closest(req.Target, bucketSize).entries
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.send(e, neighborsPacket, neighbors{
|
||||
t.send(from, neighborsPacket, neighbors{
|
||||
Nodes: closest,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
|
@ -426,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byt
|
|||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.bump(fromID)
|
||||
t.add(req.Nodes)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.replies <- reply{fromID, neighborsPacket, req}
|
||||
if !t.handleReply(fromID, neighborsPacket, req) {
|
||||
return errUnsolicitedReply
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
logpkg "log"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -15,22 +23,243 @@ func init() {
|
|||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
|
||||
}
|
||||
|
||||
func TestUDP_ping(t *testing.T) {
|
||||
type udpTest struct {
|
||||
t *testing.T
|
||||
pipe *dgramPipe
|
||||
table *Table
|
||||
udp *udp
|
||||
sent [][]byte
|
||||
localkey, remotekey *ecdsa.PrivateKey
|
||||
remoteaddr *net.UDPAddr
|
||||
}
|
||||
|
||||
func newUDPTest(t *testing.T) *udpTest {
|
||||
test := &udpTest{
|
||||
t: t,
|
||||
pipe: newpipe(),
|
||||
localkey: newkey(),
|
||||
remotekey: newkey(),
|
||||
remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
|
||||
}
|
||||
test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
|
||||
return test
|
||||
}
|
||||
|
||||
// handles a packet as if it had been sent to the transport.
|
||||
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
|
||||
enc, err := encodePacket(test.remotekey, ptype, data)
|
||||
if err != nil {
|
||||
return test.errorf("packet (%d) encode error: %v", err)
|
||||
}
|
||||
test.sent = append(test.sent, enc)
|
||||
err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
|
||||
if err != wantError {
|
||||
return test.errorf("error mismatch: got %q, want %q", err, wantError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// waits for a packet to be sent by the transport.
|
||||
// validate should have type func(*udpTest, X) error, where X is a packet type.
|
||||
func (test *udpTest) waitPacketOut(validate interface{}) error {
|
||||
dgram := test.pipe.waitPacketOut()
|
||||
p, _, _, err := decodePacket(dgram)
|
||||
if err != nil {
|
||||
return test.errorf("sent packet decode error: %v", err)
|
||||
}
|
||||
fn := reflect.ValueOf(validate)
|
||||
exptype := fn.Type().In(0)
|
||||
if reflect.TypeOf(p) != exptype {
|
||||
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
|
||||
}
|
||||
fn.Call([]reflect.Value{reflect.ValueOf(p)})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (test *udpTest) errorf(format string, args ...interface{}) error {
|
||||
_, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
|
||||
if ok {
|
||||
file = path.Base(file)
|
||||
} else {
|
||||
file = "???"
|
||||
line = 1
|
||||
}
|
||||
err := fmt.Errorf(format, args...)
|
||||
fmt.Printf("\t%s:%d: %v\n", file, line, err)
|
||||
test.t.Fail()
|
||||
return err
|
||||
}
|
||||
|
||||
// shared test variables
|
||||
var (
|
||||
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
|
||||
testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
|
||||
)
|
||||
|
||||
func TestUDP_packetErrors(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
|
||||
test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
|
||||
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
|
||||
test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
|
||||
test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
|
||||
}
|
||||
|
||||
func TestUDP_pingTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
|
||||
toid := NodeID{1, 2, 3, 4}
|
||||
if err := test.udp.ping(toid, toaddr); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := n1.net.ping(n2.self); err != nil {
|
||||
t.Fatalf("ping error: %v", err)
|
||||
func TestUDP_findnodeTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
|
||||
toid := NodeID{1, 2, 3, 4}
|
||||
target := NodeID{4, 5, 6, 7}
|
||||
result, err := test.udp.findnode(toid, toaddr, target)
|
||||
if err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
if find(n2, n1.self.ID) == nil {
|
||||
t.Errorf("node 2 does not contain id of node 1")
|
||||
if len(result) > 0 {
|
||||
t.Error("expected empty result, got", result)
|
||||
}
|
||||
if e := find(n1, n2.self.ID); e != nil {
|
||||
t.Errorf("node 1 does contains id of node 2: %v", e)
|
||||
}
|
||||
|
||||
func TestUDP_findnode(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
// put a few nodes into the table. their exact
|
||||
// distribution shouldn't matter much, altough we need to
|
||||
// take care not to overflow any bucket.
|
||||
target := testTarget
|
||||
nodes := &nodesByDistance{target: target}
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
nodes.push(&Node{
|
||||
IP: net.IP{1, 2, 3, byte(i)},
|
||||
DiscPort: i + 2,
|
||||
TCPPort: i + 2,
|
||||
ID: randomID(test.table.self.ID, i+2),
|
||||
}, bucketSize)
|
||||
}
|
||||
test.table.add(nodes.entries)
|
||||
|
||||
// ensure there's a bond with the test node,
|
||||
// findnode won't be accepted otherwise.
|
||||
test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
|
||||
|
||||
// check that closest neighbors are returned.
|
||||
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
|
||||
test.waitPacketOut(func(p *neighbors) {
|
||||
expected := test.table.closest(testTarget, bucketSize)
|
||||
if len(p.Nodes) != bucketSize {
|
||||
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
|
||||
}
|
||||
for i := range p.Nodes {
|
||||
if p.Nodes[i].ID != expected.entries[i].ID {
|
||||
t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
// queue a pending findnode request
|
||||
resultc, errc := make(chan []*Node), make(chan error)
|
||||
go func() {
|
||||
rid := PubkeyID(&test.remotekey.PublicKey)
|
||||
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
|
||||
if err != nil && len(ns) == 0 {
|
||||
errc <- err
|
||||
} else {
|
||||
resultc <- ns
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for the findnode to be sent.
|
||||
// after it is sent, the transport is waiting for a reply
|
||||
test.waitPacketOut(func(p *findnode) {
|
||||
if p.Target != testTarget {
|
||||
t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
|
||||
}
|
||||
})
|
||||
|
||||
// send the reply as two packets.
|
||||
list := []*Node{
|
||||
MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
|
||||
MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
|
||||
MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
|
||||
MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
|
||||
}
|
||||
test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
|
||||
test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
|
||||
|
||||
// check that the sent neighbors are all returned by findnode
|
||||
select {
|
||||
case result := <-resultc:
|
||||
if !reflect.DeepEqual(result, list) {
|
||||
t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list)
|
||||
}
|
||||
case err := <-errc:
|
||||
t.Errorf("findnode error: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("findnode did not return within 5 seconds")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP_successfulPing(t *testing.T) {
|
||||
test := newUDPTest(t)
|
||||
defer test.table.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// the ping is replied to.
|
||||
test.waitPacketOut(func(p *pong) {
|
||||
pinghash := test.sent[0][:macSize]
|
||||
if !bytes.Equal(p.ReplyTok, pinghash) {
|
||||
t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
|
||||
}
|
||||
})
|
||||
|
||||
// remote is unknown, the table pings back.
|
||||
test.waitPacketOut(func(p *ping) error { return nil })
|
||||
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
|
||||
|
||||
// ping should return shortly after getting the pong packet.
|
||||
<-done
|
||||
|
||||
// check that the node was added.
|
||||
rid := PubkeyID(&test.remotekey.PublicKey)
|
||||
rnode := find(test.table, rid)
|
||||
if rnode == nil {
|
||||
t.Fatalf("node %v not found in table", rid)
|
||||
}
|
||||
if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
|
||||
t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
|
||||
}
|
||||
if rnode.DiscPort != test.remoteaddr.Port {
|
||||
t.Errorf("node has wrong Port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
|
||||
}
|
||||
if rnode.TCPPort != 99 {
|
||||
t.Errorf("node has wrong Port: got %v, want: %v", rnode.TCPPort, 99)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,167 +274,66 @@ func find(tab *Table, id NodeID) *Node {
|
|||
return nil
|
||||
}
|
||||
|
||||
func TestUDP_findnode(t *testing.T) {
|
||||
t.Parallel()
|
||||
// dgramPipe is a fake UDP socket. It queues all sent datagrams.
|
||||
type dgramPipe struct {
|
||||
mu *sync.Mutex
|
||||
cond *sync.Cond
|
||||
closing chan struct{}
|
||||
closed bool
|
||||
queue [][]byte
|
||||
}
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
// put a few nodes into n2. the exact distribution shouldn't
|
||||
// matter much, altough we need to take care not to overflow
|
||||
// any bucket.
|
||||
target := randomID(n1.self.ID, 100)
|
||||
nodes := &nodesByDistance{target: target}
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
n2.add([]*Node{&Node{
|
||||
IP: net.IP{1, 2, 3, byte(i)},
|
||||
DiscPort: i + 2,
|
||||
TCPPort: i + 2,
|
||||
ID: randomID(n2.self.ID, i+2),
|
||||
}})
|
||||
}
|
||||
n2.add(nodes.entries)
|
||||
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
|
||||
expected := n2.closest(target, bucketSize)
|
||||
|
||||
err := runUDP(10, func() error {
|
||||
result, _ := n1.net.findnode(n2.self, target)
|
||||
if len(result) != bucketSize {
|
||||
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
|
||||
}
|
||||
for i := range result {
|
||||
if result[i].ID != expected.entries[i].ID {
|
||||
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
func newpipe() *dgramPipe {
|
||||
mu := new(sync.Mutex)
|
||||
return &dgramPipe{
|
||||
closing: make(chan struct{}),
|
||||
cond: &sync.Cond{L: mu},
|
||||
mu: mu,
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP_replytimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// reserve a port so we don't talk to an existing service by accident
|
||||
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
fd, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fd.Close()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
|
||||
|
||||
if err := n1.net.ping(n2); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
|
||||
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
} else if len(result) > 0 {
|
||||
t.Error("expected empty result, got", result)
|
||||
// WriteToUDP queues a datagram.
|
||||
func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
|
||||
msg := make([]byte, len(b))
|
||||
copy(msg, b)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return 0, errors.New("closed")
|
||||
}
|
||||
c.queue = append(c.queue, msg)
|
||||
c.cond.Signal()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
udp2 := n2.net.(*udp)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
err := runUDP(10, func() error {
|
||||
nodes := make([]*Node, bucketSize)
|
||||
for i := range nodes {
|
||||
nodes[i] = &Node{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
DiscPort: i + 1,
|
||||
TCPPort: i + 1,
|
||||
ID: randomID(n2.self.ID, i+1),
|
||||
}
|
||||
}
|
||||
|
||||
// ask N2 for neighbors. it will send an empty reply back.
|
||||
// the request will wait for up to bucketSize replies.
|
||||
resultc := make(chan []*Node)
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
ns, err := n1.net.findnode(n2.self, n1.self.ID)
|
||||
if err != nil {
|
||||
errc <- err
|
||||
} else {
|
||||
resultc <- ns
|
||||
}
|
||||
}()
|
||||
|
||||
// send a few more neighbors packets to N1.
|
||||
// it should collect those.
|
||||
for end := 0; end < len(nodes); {
|
||||
off := end
|
||||
if end = end + 5; end > len(nodes) {
|
||||
end = len(nodes)
|
||||
}
|
||||
udp2.send(n1.self, neighborsPacket, neighbors{
|
||||
Nodes: nodes[off:end],
|
||||
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
|
||||
})
|
||||
}
|
||||
|
||||
// check that they are all returned. we cannot just check for
|
||||
// equality because they might not be returned in the order they
|
||||
// were sent.
|
||||
var result []*Node
|
||||
select {
|
||||
case result = <-resultc:
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
if hasDuplicates(result) {
|
||||
return fmt.Errorf("result slice contains duplicates")
|
||||
}
|
||||
if len(result) != len(nodes) {
|
||||
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
|
||||
}
|
||||
matched := make(map[NodeID]bool)
|
||||
for _, n := range result {
|
||||
for _, expn := range nodes {
|
||||
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
|
||||
matched[n.ID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matched) != len(nodes) {
|
||||
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// ReadFromUDP just hangs until the pipe is closed.
|
||||
func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
|
||||
<-c.closing
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
// runUDP runs a test n times and returns an error if the test failed
|
||||
// in all n runs. This is necessary because UDP is unreliable even for
|
||||
// connections on the local machine, causing test failures.
|
||||
func runUDP(n int, test func() error) error {
|
||||
errcount := 0
|
||||
errors := ""
|
||||
for i := 0; i < n; i++ {
|
||||
if err := test(); err != nil {
|
||||
errors += fmt.Sprintf("\n#%d: %v", i, err)
|
||||
errcount++
|
||||
}
|
||||
}
|
||||
if errcount == n {
|
||||
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
|
||||
func (c *dgramPipe) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if !c.closed {
|
||||
close(c.closing)
|
||||
c.closed = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dgramPipe) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{}
|
||||
}
|
||||
|
||||
func (c *dgramPipe) waitPacketOut() []byte {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for len(c.queue) == 0 {
|
||||
c.cond.Wait()
|
||||
}
|
||||
p := c.queue[0]
|
||||
copy(c.queue, c.queue[1:])
|
||||
c.queue = c.queue[:len(c.queue)-1]
|
||||
return p
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue