p2p: new dialer, peer management without locks

The most visible change is event-based dialing, which should be an
improvement over the timer-based system that we have at the moment.
The dialer gets a chance to compute new tasks whenever peers change or
dials complete. This is better than checking peers on a timer because
dials happen faster. The dialer can now make more precise decisions
about whom to dial based on the peer set and we can test those
decisions without actually opening any sockets.

Peer management is easier to test because the tests can inject
connections at checkpoints (after enc handshake, after protocol
handshake).

Most of the handshake stuff is now part of the RLPx code. It could be
exported or move to its own package because it is no longer entangled
with Server logic.
This commit is contained in:
Felix Lange 2015-05-16 00:38:28 +02:00
parent 9f38ef5d97
commit 1440f9a37a
11 changed files with 2142 additions and 1353 deletions

276
p2p/dial.go Normal file
View File

@ -0,0 +1,276 @@
package p2p
import (
"container/heap"
"crypto/rand"
"fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover"
)
const (
// This is the amount of time spent waiting in between
// redialing a certain node.
dialHistoryExpiration = 30 * time.Second
// Discovery lookup tasks will wait for this long when
// no results are returned. This can happen if the table
// becomes empty (i.e. not often).
emptyLookupDelay = 10 * time.Second
)
// dialstate schedules dials and discovery lookups.
// it get's a chance to compute new tasks on every iteration
// of the main loop in Server.run.
type dialstate struct {
maxDynDials int
ntab discoverTable
lookupRunning bool
bootstrapped bool
dialing map[discover.NodeID]connFlag
lookupBuf []*discover.Node // current discovery lookup results
randomNodes []*discover.Node // filled from Table
static map[discover.NodeID]*discover.Node
hist *dialHistory
}
type discoverTable interface {
Self() *discover.Node
Close()
Bootstrap([]*discover.Node)
Lookup(target discover.NodeID) []*discover.Node
ReadRandomNodes([]*discover.Node) int
}
// the dial history remembers recent dials.
type dialHistory []pastDial
// pastDial is an entry in the dial history.
type pastDial struct {
id discover.NodeID
exp time.Time
}
type task interface {
Do(*Server)
}
// A dialTask is generated for each node that is dialed.
type dialTask struct {
flags connFlag
dest *discover.Node
}
// discoverTask runs discovery table operations.
// Only one discoverTask is active at any time.
//
// If bootstrap is true, the task runs Table.Bootstrap,
// otherwise it performs a random lookup and leaves the
// results in the task.
type discoverTask struct {
bootstrap bool
results []*discover.Node
}
// A waitExpireTask is generated if there are no other tasks
// to keep the loop in Server.run ticking.
type waitExpireTask struct {
time.Duration
}
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
static: make(map[discover.NodeID]*discover.Node),
dialing: make(map[discover.NodeID]connFlag),
randomNodes: make([]*discover.Node, maxdyn/2),
hist: new(dialHistory),
}
for _, n := range static {
s.static[n.ID] = n
}
return s
}
func (s *dialstate) addStatic(n *discover.Node) {
s.static[n.ID] = n
}
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
var newtasks []task
addDial := func(flag connFlag, n *discover.Node) bool {
_, dialing := s.dialing[n.ID]
if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) {
return false
}
s.dialing[n.ID] = flag
newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
return true
}
// Compute number of dynamic dials necessary at this point.
needDynDials := s.maxDynDials
for _, p := range peers {
if p.rw.is(dynDialedConn) {
needDynDials--
}
}
for _, flag := range s.dialing {
if flag&dynDialedConn != 0 {
needDynDials--
}
}
// Expire the dial history on every invocation.
s.hist.expire(now)
// Create dials for static nodes if they are not connected.
for _, n := range s.static {
addDial(staticDialedConn, n)
}
// Use random nodes from the table for half of the necessary
// dynamic dials.
randomCandidates := needDynDials / 2
if randomCandidates > 0 && s.bootstrapped {
n := s.ntab.ReadRandomNodes(s.randomNodes)
for i := 0; i < randomCandidates && i < n; i++ {
if addDial(dynDialedConn, s.randomNodes[i]) {
needDynDials--
}
}
}
// Create dynamic dials from random lookup results, removing tried
// items from the result buffer.
i := 0
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
if addDial(dynDialedConn, s.lookupBuf[i]) {
needDynDials--
}
}
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
// Launch a discovery lookup if more candidates are needed. The
// first discoverTask bootstraps the table and won't return any
// results.
if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
s.lookupRunning = true
newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped})
}
// Launch a timer to wait for the next node to expire if all
// candidates have been tried and no task is currently active.
// This should prevent cases where the dialer logic is not ticked
// because there are no pending events.
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
newtasks = append(newtasks, t)
}
return newtasks
}
func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) {
case *dialTask:
s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
delete(s.dialing, t.dest.ID)
case *discoverTask:
if t.bootstrap {
s.bootstrapped = true
}
s.lookupRunning = false
s.lookupBuf = append(s.lookupBuf, t.results...)
}
}
func (t *dialTask) Do(srv *Server) {
addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}
glog.V(logger.Debug).Infof("dialing %v\n", t.dest)
fd, err := srv.Dialer.Dial("tcp", addr.String())
if err != nil {
glog.V(logger.Detail).Infof("dial error: %v", err)
return
}
srv.setupConn(fd, t.flags, t.dest)
}
func (t *dialTask) String() string {
return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
}
func (t *discoverTask) Do(srv *Server) {
if t.bootstrap {
srv.ntab.Bootstrap(srv.BootstrapNodes)
} else {
var target discover.NodeID
rand.Read(target[:])
t.results = srv.ntab.Lookup(target)
// newTasks generates a lookup task whenever dynamic dials are
// necessary. Lookups need to take some time, otherwise the
// event loop spins too fast. An empty result can only be
// returned if the table is empty.
if len(t.results) == 0 {
time.Sleep(emptyLookupDelay)
}
}
}
func (t *discoverTask) String() (s string) {
if t.bootstrap {
s = "discovery bootstrap"
} else {
s = "discovery lookup"
}
if len(t.results) > 0 {
s += fmt.Sprintf(" (%d results)", len(t.results))
}
return s
}
func (t waitExpireTask) Do(*Server) {
time.Sleep(t.Duration)
}
func (t waitExpireTask) String() string {
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
}
// Use only these methods to access or modify dialHistory.
func (h dialHistory) min() pastDial {
return h[0]
}
func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
}
func (h dialHistory) contains(id discover.NodeID) bool {
for _, v := range h {
if v.id == id {
return true
}
}
return false
}
func (h *dialHistory) expire(now time.Time) {
for h.Len() > 0 && h.min().exp.Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h dialHistory) Len() int { return len(h) }
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *dialHistory) Push(x interface{}) {
*h = append(*h, x.(pastDial))
}
func (h *dialHistory) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

482
p2p/dial_test.go Normal file
View File

@ -0,0 +1,482 @@
package p2p
import (
"encoding/binary"
"reflect"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/p2p/discover"
)
func init() {
spew.Config.Indent = "\t"
}
type dialtest struct {
init *dialstate // state before and after the test.
rounds []round
}
type round struct {
peers []*Peer // current peer set
done []task // tasks that got done this round
new []task // the result must match this one
}
func runDialTest(t *testing.T, test dialtest) {
var (
vtime time.Time
running int
)
pm := func(ps []*Peer) map[discover.NodeID]*Peer {
m := make(map[discover.NodeID]*Peer)
for _, p := range ps {
m[p.rw.id] = p
}
return m
}
for i, round := range test.rounds {
for _, task := range round.done {
running--
if running < 0 {
panic("running task counter underflow")
}
test.init.taskDone(task, vtime)
}
new := test.init.newTasks(running, pm(round.peers), vtime)
if !sametasks(new, round.new) {
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
}
// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second)
running += len(new)
}
}
type fakeTable []*discover.Node
func (t fakeTable) Self() *discover.Node { return new(discover.Node) }
func (t fakeTable) Close() {}
func (t fakeTable) Bootstrap([]*discover.Node) {}
func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node {
return nil
}
func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int {
return copy(buf, t)
}
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{
init: newDialState(nil, fakeTable{}, 5),
rounds: []round{
// A discovery query is launched.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
new: []task{&discoverTask{bootstrap: true}},
},
// Dynamic dials are launched when it completes.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
done: []task{
&discoverTask{bootstrap: true, results: []*discover.Node{
{ID: uintID(2)}, // this one is already connected and not dialed.
{ID: uintID(3)},
{ID: uintID(4)},
{ID: uintID(5)},
{ID: uintID(6)}, // these are not tried because max dyn dials is 5
{ID: uintID(7)}, // ...
}},
},
new: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
},
},
// Some of the dials complete but no new ones are launched yet because
// the sum of active dial count and dynamic peer count is == maxDynDials.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
},
done: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
},
},
// No new dial tasks are launched in the this round because
// maxDynDials has been reached.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
},
done: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
},
},
// In this round, the peer with id 2 drops off. The query
// results from last discovery lookup are reused.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
},
new: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
},
},
// More peers (3,4) drop off and dial for ID 6 completes.
// The last query result from the discovery lookup is reused
// and a new one is spawned because more candidates are needed.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
},
done: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
},
new: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
&discoverTask{},
},
},
// Peer 7 is connected, but there still aren't enough dynamic peers
// (4 out of 5). However, a discovery is already running, so ensure
// no new is started.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
{rw: &conn{flags: dynDialedConn, id: uintID(7)}},
},
done: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
},
},
// Finish the running node discovery with an empty set. A new lookup
// should be immediately requested.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
{rw: &conn{flags: dynDialedConn, id: uintID(7)}},
},
done: []task{
&discoverTask{},
},
new: []task{
&discoverTask{},
},
},
},
})
}
func TestDialStateDynDialFromTable(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
table := fakeTable{
{ID: uintID(1)},
{ID: uintID(2)},
{ID: uintID(3)},
{ID: uintID(4)},
{ID: uintID(5)},
{ID: uintID(6)},
{ID: uintID(7)},
{ID: uintID(8)},
}
runDialTest(t, dialtest{
init: newDialState(nil, table, 10),
rounds: []round{
// Discovery bootstrap is launched.
{
new: []task{&discoverTask{bootstrap: true}},
},
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
done: []task{
&discoverTask{bootstrap: true},
},
new: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
&discoverTask{bootstrap: false},
},
},
// Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
done: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
&discoverTask{results: []*discover.Node{
{ID: uintID(10)},
{ID: uintID(11)},
{ID: uintID(12)},
}},
},
new: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
&discoverTask{bootstrap: false},
},
},
// Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
},
done: []task{
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
},
},
// Waiting for expiry. No waitExpireTask is launched because the
// discovery query is still running.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
},
},
// Nodes 3,4 are not tried again because only the first two
// returned random nodes (nodes 1,2) are tried and they're
// already connected.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
},
},
},
})
}
// This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) {
wantStatic := []*discover.Node{
{ID: uintID(1)},
{ID: uintID(2)},
{ID: uintID(3)},
{ID: uintID(4)},
{ID: uintID(5)},
}
runDialTest(t, dialtest{
init: newDialState(wantStatic, fakeTable{}, 0),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
new: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
},
},
// No new tasks are launched in this round because all static
// nodes are either connected or still being dialed.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
},
done: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
},
},
// No new dial tasks are launched because all static
// nodes are now connected.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
{rw: &conn{flags: staticDialedConn, id: uintID(4)}},
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
},
done: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
},
},
// Wait a round for dial history to expire, no new tasks should spawn.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
{rw: &conn{flags: staticDialedConn, id: uintID(4)}},
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
},
},
// If a static node is dropped, it should be immediately redialed,
// irrespective whether it was originally static or dynamic.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
},
new: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
},
},
},
})
}
// This test checks that past dials are not retried for some time.
func TestDialStateCache(t *testing.T) {
wantStatic := []*discover.Node{
{ID: uintID(1)},
{ID: uintID(2)},
{ID: uintID(3)},
}
runDialTest(t, dialtest{
init: newDialState(wantStatic, fakeTable{}, 0),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
{
peers: nil,
new: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
},
},
// No new tasks are launched in this round because all static
// nodes are either connected or still being dialed.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, id: uintID(1)}},
{rw: &conn{flags: staticDialedConn, id: uintID(2)}},
},
done: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
},
},
// A salvage task is launched to wait for node 3's history
// entry to expire.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
done: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
},
},
// Still waiting for node 3's entry to expire in the cache.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
},
// The cache entry for node 3 has expired and is retried.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
},
new: []task{
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
},
},
},
})
}
// compares task lists but doesn't care about the order.
func sametasks(a, b []task) bool {
if len(a) != len(b) {
return false
}
next:
for _, ta := range a {
for _, tb := range b {
if reflect.DeepEqual(ta, tb) {
continue next
}
}
return false
}
return true
}
func uintID(i uint32) discover.NodeID {
var id discover.NodeID
binary.BigEndian.PutUint32(id[:], i)
return id
}

View File

@ -1,448 +0,0 @@
package p2p
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"fmt"
"hash"
"io"
"net"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
)
const (
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen = 65 // elliptic S256
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
shaLen = 32 // hash length (for nonce etc)
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
authRespLen = pubLen + shaLen + 1
eciesBytes = 65 + 16 + 32
encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
)
// conn represents a remote connection after encryption handshake
// and protocol handshake have completed.
//
// The MsgReadWriter is usually layered as follows:
//
// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg)
// rlpxFrameRW (message encoding, encryption, authentication)
// bufio.ReadWriter (buffering)
// net.Conn (network I/O)
//
type conn struct {
MsgReadWriter
*protoHandshake
}
// secrets represents the connection secrets
// which are negotiated during the encryption handshake.
type secrets struct {
RemoteID discover.NodeID
AES, MAC []byte
EgressMAC, IngressMAC hash.Hash
Token []byte
}
// protoHandshake is the RLP structure of the protocol handshake.
type protoHandshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
ID discover.NodeID
}
// setupConn starts a protocol session on the given connection. It
// runs the encryption handshake and the protocol handshake. If dial
// is non-nil, the connection the local node is the initiator. If
// keepconn returns false, the connection will be disconnected with
// DiscTooManyPeers after the key exchange.
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
if dial == nil {
return setupInboundConn(fd, prv, our, keepconn)
} else {
return setupOutboundConn(fd, prv, our, dial, keepconn)
}
}
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, keepconn func(discover.NodeID) bool) (*conn, error) {
secrets, err := receiverEncHandshake(fd, prv, nil)
if err != nil {
return nil, fmt.Errorf("encryption handshake failed: %v", err)
}
rw := newRlpxFrameRW(fd, secrets)
if !keepconn(secrets.RemoteID) {
SendItems(rw, discMsg, DiscTooManyPeers)
return nil, errors.New("we have too many peers")
}
// Run the protocol handshake using authenticated messages.
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
if err != nil {
return nil, err
}
if err := Send(rw, handshakeMsg, our); err != nil {
return nil, fmt.Errorf("protocol handshake write error: %v", err)
}
return &conn{rw, rhs}, nil
}
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
if err != nil {
return nil, fmt.Errorf("encryption handshake failed: %v", err)
}
rw := newRlpxFrameRW(fd, secrets)
if !keepconn(secrets.RemoteID) {
SendItems(rw, discMsg, DiscTooManyPeers)
return nil, errors.New("we have too many peers")
}
// Run the protocol handshake using authenticated messages.
//
// Note that even though writing the handshake is first, we prefer
// returning the handshake read error. If the remote side
// disconnects us early with a valid reason, we should return it
// as the error so it can be tracked elsewhere.
werr := make(chan error, 1)
go func() { werr <- Send(rw, handshakeMsg, our) }()
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
if err != nil {
return nil, err
}
if err := <-werr; err != nil {
return nil, fmt.Errorf("protocol handshake write error: %v", err)
}
if rhs.ID != dial.ID {
return nil, errors.New("dialed node id mismatch")
}
return &conn{rw, rhs}, nil
}
// encHandshake contains the state of the encryption handshake.
type encHandshake struct {
initiator bool
remoteID discover.NodeID
remotePub *ecies.PublicKey // remote-pubk
initNonce, respNonce []byte // nonce
randomPrivKey *ecies.PrivateKey // ecdhe-random
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
}
// secrets is called after the handshake is completed.
// It extracts the connection secrets from the handshake values.
func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
if err != nil {
return secrets{}, err
}
// derive base secrets from ephemeral key agreement
sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
s := secrets{
RemoteID: h.remoteID,
AES: aesSecret,
MAC: crypto.Sha3(ecdheSecret, aesSecret),
Token: crypto.Sha3(sharedSecret),
}
// setup sha3 instances for the MACs
mac1 := sha3.NewKeccak256()
mac1.Write(xor(s.MAC, h.respNonce))
mac1.Write(auth)
mac2 := sha3.NewKeccak256()
mac2.Write(xor(s.MAC, h.initNonce))
mac2.Write(authResp)
if h.initiator {
s.EgressMAC, s.IngressMAC = mac1, mac2
} else {
s.EgressMAC, s.IngressMAC = mac2, mac1
}
return s, nil
}
func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
}
// initiatorEncHandshake negotiates a session token on conn.
// it should be called on the dialing side of the connection.
//
// prv is the local client's private key.
// token is the token from a previous session with this node.
func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
h, err := newInitiatorHandshake(remoteID)
if err != nil {
return s, err
}
auth, err := h.authMsg(prv, token)
if err != nil {
return s, err
}
if _, err = conn.Write(auth); err != nil {
return s, err
}
response := make([]byte, encAuthRespLen)
if _, err = io.ReadFull(conn, response); err != nil {
return s, err
}
if err := h.decodeAuthResp(response, prv); err != nil {
return s, err
}
return h.secrets(auth, response)
}
func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
// generate random initiator nonce
n := make([]byte, shaLen)
if _, err := rand.Read(n); err != nil {
return nil, err
}
// generate random keypair to use for signing
randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
if err != nil {
return nil, err
}
rpub, err := remoteID.Pubkey()
if err != nil {
return nil, fmt.Errorf("bad remoteID: %v", err)
}
h := &encHandshake{
initiator: true,
remoteID: remoteID,
remotePub: ecies.ImportECDSAPublic(rpub),
initNonce: n,
randomPrivKey: randpriv,
}
return h, nil
}
// authMsg creates an encrypted initiator handshake message.
func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
var tokenFlag byte
if token == nil {
// no session token found means we need to generate shared secret.
// ecies shared secret is used as initial session token for new peers
// generate shared key from prv and remote pubkey
var err error
if token, err = h.ecdhShared(prv); err != nil {
return nil, err
}
} else {
// for known peers, we use stored token from the previous session
tokenFlag = 0x01
}
// sign known message:
// ecdh-shared-secret^nonce for new peers
// token^nonce for old peers
signed := xor(token, h.initNonce)
signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
if err != nil {
return nil, err
}
// encode auth message
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
msg := make([]byte, authMsgLen)
n := copy(msg, signature)
n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
n += copy(msg[n:], h.initNonce)
msg[n] = tokenFlag
// encrypt auth message using remote-pubk
return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
}
// decodeAuthResp decode an encrypted authentication response message.
func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
msg, err := crypto.Decrypt(prv, auth)
if err != nil {
return fmt.Errorf("could not decrypt auth response (%v)", err)
}
h.respNonce = msg[pubLen : pubLen+shaLen]
h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
if err != nil {
return err
}
// ignore token flag for now
return nil
}
// receiverEncHandshake negotiates a session token on conn.
// it should be called on the listening side of the connection.
//
// prv is the local client's private key.
// token is the token from a previous session with this node.
func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
// read remote auth sent by initiator.
auth := make([]byte, encAuthMsgLen)
if _, err := io.ReadFull(conn, auth); err != nil {
return s, err
}
h, err := decodeAuthMsg(prv, token, auth)
if err != nil {
return s, err
}
// send auth response
resp, err := h.authResp(prv, token)
if err != nil {
return s, err
}
if _, err = conn.Write(resp); err != nil {
return s, err
}
return h.secrets(auth, resp)
}
func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
var err error
h := new(encHandshake)
// generate random keypair for session
h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
if err != nil {
return nil, err
}
// generate random nonce
h.respNonce = make([]byte, shaLen)
if _, err = rand.Read(h.respNonce); err != nil {
return nil, err
}
msg, err := crypto.Decrypt(prv, auth)
if err != nil {
return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
}
// decode message parameters
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
rpub, err := h.remoteID.Pubkey()
if err != nil {
return nil, fmt.Errorf("bad remoteID: %#v", err)
}
h.remotePub = ecies.ImportECDSAPublic(rpub)
// recover remote random pubkey from signed message.
if token == nil {
// TODO: it is an error if the initiator has a token and we don't. check that.
// no session token means we need to generate shared secret.
// ecies shared secret is used as initial session token for new peers.
// generate shared key from prv and remote pubkey.
if token, err = h.ecdhShared(prv); err != nil {
return nil, err
}
}
signedMsg := xor(token, h.initNonce)
remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
if err != nil {
return nil, err
}
h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
return h, nil
}
// authResp generates the encrypted authentication response message.
func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
// responder auth message
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
resp := make([]byte, authRespLen)
n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
n += copy(resp[n:], h.respNonce)
if token == nil {
resp[n] = 0
} else {
resp[n] = 1
}
// encrypt using remote-pubk
return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
}
// importPublicKey unmarshals 512 bit public keys.
func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
var pubKey65 []byte
switch len(pubKey) {
case 64:
// add 'uncompressed key' flag
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
// TODO: fewer pointless conversions
return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
}
func exportPubkey(pub *ecies.PublicKey) []byte {
if pub == nil {
panic("nil pubkey")
}
return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
}
func xor(one, other []byte) (xor []byte) {
xor = make([]byte, len(one))
for i := 0; i < len(one); i++ {
xor[i] = one[i] ^ other[i]
}
return xor
}
func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
msg, err := rw.ReadMsg()
if err != nil {
return nil, err
}
if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason)
return nil, reason[0]
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
}
var hs protoHandshake
if err := msg.Decode(&hs); err != nil {
return nil, err
}
// validate handshake info
if hs.Version != our.Version {
SendItems(rw, discMsg, DiscIncompatibleVersion)
return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
}
if (hs.ID == discover.NodeID{}) {
SendItems(rw, discMsg, DiscInvalidIdentity)
return nil, errors.New("invalid public key in handshake")
}
if hs.ID != wantID {
SendItems(rw, discMsg, DiscUnexpectedIdentity)
return nil, errors.New("handshake node ID does not match encryption handshake")
}
return &hs, nil
}

View File

@ -1,172 +0,0 @@
package p2p
import (
"bytes"
"crypto/rand"
"fmt"
"net"
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/p2p/discover"
)
func TestSharedSecret(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
if err != nil {
return
}
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
if err != nil {
return
}
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
if !bytes.Equal(ss0, ss1) {
t.Errorf("dont match :(")
}
}
func TestEncHandshake(t *testing.T) {
for i := 0; i < 20; i++ {
start := time.Now()
if err := testEncHandshake(nil); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
}
for i := 0; i < 20; i++ {
tok := make([]byte, shaLen)
rand.Reader.Read(tok)
start := time.Now()
if err := testEncHandshake(tok); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
}
}
func testEncHandshake(token []byte) error {
type result struct {
side string
s secrets
err error
}
var (
prv0, _ = crypto.GenerateKey()
prv1, _ = crypto.GenerateKey()
rw0, rw1 = net.Pipe()
output = make(chan result)
)
go func() {
r := result{side: "initiator"}
defer func() { output <- r }()
pub1s := discover.PubkeyID(&prv1.PublicKey)
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
if r.err != nil {
return
}
id1 := discover.PubkeyID(&prv1.PublicKey)
if r.s.RemoteID != id1 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
}
}()
go func() {
r := result{side: "receiver"}
defer func() { output <- r }()
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
if r.err != nil {
return
}
id0 := discover.PubkeyID(&prv0.PublicKey)
if r.s.RemoteID != id0 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
}
}()
// wait for results from both sides
r1, r2 := <-output, <-output
if r1.err != nil {
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
}
if r2.err != nil {
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
}
// don't compare remote node IDs
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
// flip MACs on one of them so they compare equal
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
if !reflect.DeepEqual(r1.s, r2.s) {
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
}
return nil
}
func TestSetupConn(t *testing.T) {
prv0, _ := crypto.GenerateKey()
prv1, _ := crypto.GenerateKey()
node0 := &discover.Node{
ID: discover.PubkeyID(&prv0.PublicKey),
IP: net.IP{1, 2, 3, 4},
TCP: 33,
}
node1 := &discover.Node{
ID: discover.PubkeyID(&prv1.PublicKey),
IP: net.IP{5, 6, 7, 8},
TCP: 44,
}
hs0 := &protoHandshake{
Version: baseProtocolVersion,
ID: node0.ID,
Caps: []Cap{{"a", 0}, {"b", 2}},
}
hs1 := &protoHandshake{
Version: baseProtocolVersion,
ID: node1.ID,
Caps: []Cap{{"c", 1}, {"d", 3}},
}
fd0, fd1 := net.Pipe()
done := make(chan struct{})
keepalways := func(discover.NodeID) bool { return true }
go func() {
defer close(done)
conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
if err != nil {
t.Errorf("outbound side error: %v", err)
return
}
if conn0.ID != node1.ID {
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
}
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
}
}()
conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
if err != nil {
t.Fatalf("inbound side error: %v", err)
}
if conn1.ID != node0.ID {
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
}
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
}
<-done
}

View File

@ -33,9 +33,17 @@ const (
peersMsg = 0x05 peersMsg = 0x05
) )
// protoHandshake is the RLP structure of the protocol handshake.
type protoHandshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
ID discover.NodeID
}
// Peer represents a connected remote node. // Peer represents a connected remote node.
type Peer struct { type Peer struct {
conn net.Conn
rw *conn rw *conn
running map[string]*protoRW running map[string]*protoRW
@ -48,37 +56,36 @@ type Peer struct {
// NewPeer returns a peer for testing purposes. // NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe() pipe, _ := net.Pipe()
msgpipe, _ := MsgPipe() conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}} peer := newPeer(conn, nil)
peer := newPeer(pipe, conn, nil)
close(peer.closed) // ensures Disconnect doesn't block close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
// ID returns the node's public key. // ID returns the node's public key.
func (p *Peer) ID() discover.NodeID { func (p *Peer) ID() discover.NodeID {
return p.rw.ID return p.rw.id
} }
// Name returns the node name that the remote node advertised. // Name returns the node name that the remote node advertised.
func (p *Peer) Name() string { func (p *Peer) Name() string {
return p.rw.Name return p.rw.name
} }
// Caps returns the capabilities (supported subprotocols) of the remote peer. // Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap { func (p *Peer) Caps() []Cap {
// TODO: maybe return copy // TODO: maybe return copy
return p.rw.Caps return p.rw.caps
} }
// RemoteAddr returns the remote address of the network connection. // RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr { func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr() return p.rw.fd.RemoteAddr()
} }
// LocalAddr returns the local address of the network connection. // LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr { func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr() return p.rw.fd.LocalAddr()
} }
// Disconnect terminates the peer connection with the given reason. // Disconnect terminates the peer connection with the given reason.
@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (p *Peer) String() string { func (p *Peer) String() string {
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
} }
func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { func newPeer(conn *conn, protocols []Protocol) *Peer {
protomap := matchProtocols(protocols, conn.Caps, conn) protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{ p := &Peer{
conn: fd,
rw: conn, rw: conn,
running: protomap, running: protomap,
disc: make(chan DiscReason), disc: make(chan DiscReason),
@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason {
p.startProtocols() p.startProtocols()
// Wait for an error or disconnect. // Wait for an error or disconnect.
var reason DiscReason var (
reason DiscReason
requested bool
)
select { select {
case err := <-readErr: case err := <-readErr:
if r, ok := err.(DiscReason); ok { if r, ok := err.(DiscReason); ok {
@ -131,23 +140,19 @@ func (p *Peer) run() DiscReason {
case err := <-p.protoErr: case err := <-p.protoErr:
reason = discReasonForError(err) reason = discReasonForError(err)
case reason = <-p.disc: case reason = <-p.disc:
p.politeDisconnect(reason) requested = true
}
close(p.closed)
p.rw.close(reason)
p.wg.Wait()
if requested {
reason = DiscRequested reason = DiscRequested
} }
close(p.closed)
p.wg.Wait()
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
return reason return reason
} }
func (p *Peer) politeDisconnect(reason DiscReason) {
if reason != DiscNetworkError {
SendItems(p.rw, discMsg, uint(reason))
}
p.conn.Close()
}
func (p *Peer) pingLoop() { func (p *Peer) pingLoop() {
ping := time.NewTicker(pingInterval) ping := time.NewTicker(pingInterval)
defer p.wg.Done() defer p.wg.Done()
@ -254,7 +259,7 @@ func (p *Peer) startProtocols() {
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version) glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
err = errors.New("protocol returned") err = errors.New("protocol returned")
} else if err != io.EOF { } else if err != io.EOF {
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err) glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err)
} }
p.protoErr <- err p.protoErr <- err
p.wg.Done() p.wg.Done()

View File

@ -5,39 +5,17 @@ import (
) )
const ( const (
errMagicTokenMismatch = iota errInvalidMsgCode = iota
errRead
errWrite
errMisc
errInvalidMsgCode
errInvalidMsg errInvalidMsg
errP2PVersionMismatch
errPubkeyInvalid
errPubkeyForbidden
errProtocolBreach
errPingTimeout
errInvalidNetworkId
errInvalidProtocolVersion
) )
var errorToString = map[int]string{ var errorToString = map[int]string{
errMagicTokenMismatch: "magic token mismatch",
errRead: "read error",
errWrite: "write error",
errMisc: "misc error",
errInvalidMsgCode: "invalid message code", errInvalidMsgCode: "invalid message code",
errInvalidMsg: "invalid message", errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyInvalid: "public key invalid",
errPubkeyForbidden: "public key forbidden",
errProtocolBreach: "protocol Breach",
errPingTimeout: "ping timeout",
errInvalidNetworkId: "invalid network id",
errInvalidProtocolVersion: "invalid protocol version",
} }
type peerError struct { type peerError struct {
Code int code int
message string message string
} }
@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason {
return reason return reason
} }
peerError, ok := err.(*peerError) peerError, ok := err.(*peerError)
if !ok { if ok {
return DiscSubprotocolError switch peerError.code {
} case errInvalidMsgCode, errInvalidMsg:
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite:
return DiscNetworkError
default: default:
return DiscSubprotocolError return DiscSubprotocolError
} }
} }
return DiscSubprotocolError
}

View File

@ -28,24 +28,20 @@ var discard = Protocol{
} }
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) { func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) {
fd1, _ := net.Pipe() fd1, fd2 := net.Pipe()
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)}
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)}
for _, p := range protos { for _, p := range protos {
hs1.Caps = append(hs1.Caps, p.cap()) c1.caps = append(c1.caps, p.cap())
hs2.Caps = append(hs2.Caps, p.cap()) c2.caps = append(c2.caps, p.cap())
} }
p1, p2 := MsgPipe() peer := newPeer(c1, protos)
peer := newPeer(fd1, &conn{p1, hs1}, protos)
errc := make(chan DiscReason, 1) errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }() go func() { errc <- peer.run() }()
closer := func() { closer := func() { c2.close(errors.New("close func called")) }
p1.Close() return closer, c2, peer, errc
fd1.Close()
}
return closer, &conn{p2, hs2}, peer, errc
} }
func TestPeerProtoReadMsg(t *testing.T) { func TestPeerProtoReadMsg(t *testing.T) {

View File

@ -4,23 +4,459 @@ import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"errors" "errors"
"fmt"
"hash" "hash"
"io" "io"
"net"
"sync"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
const (
maxUint24 = ^uint32(0) >> 8
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen = 65 // elliptic S256
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
shaLen = 32 // hash length (for nonce etc)
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
authRespLen = pubLen + shaLen + 1
eciesBytes = 65 + 16 + 32
encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
// total timeout for encryption handshake and protocol
// handshake in both directions.
handshakeTimeout = 5 * time.Second
// This is the timeout for sending the disconnect reason.
// This is shorter than the usual timeout because we don't want
// to wait if the connection is known to be bad anyway.
discWriteTimeout = 1 * time.Second
)
// rlpx is the transport protocol used by actual (non-test) connections.
// It wraps the frame encoder with locks and read/write deadlines.
type rlpx struct {
fd net.Conn
rmu, wmu sync.Mutex
rw *rlpxFrameRW
}
func newRLPX(fd net.Conn) transport {
fd.SetDeadline(time.Now().Add(handshakeTimeout))
return &rlpx{fd: fd}
}
func (t *rlpx) ReadMsg() (Msg, error) {
t.rmu.Lock()
defer t.rmu.Unlock()
t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout))
return t.rw.ReadMsg()
}
func (t *rlpx) WriteMsg(msg Msg) error {
t.wmu.Lock()
defer t.wmu.Unlock()
t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout))
return t.rw.WriteMsg(msg)
}
func (t *rlpx) close(err error) {
t.wmu.Lock()
defer t.wmu.Unlock()
// Tell the remote end why we're disconnecting if possible.
if t.rw != nil {
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
SendItems(t.rw, discMsg, r)
}
}
t.fd.Close()
}
// doEncHandshake runs the protocol handshake using authenticated
// messages. the protocol handshake is the first authenticated message
// and also verifies whether the encryption handshake 'worked' and the
// remote side actually provided the right public key.
func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) {
// Writing our handshake happens concurrently, we prefer
// returning the handshake read error. If the remote side
// disconnects us early with a valid reason, we should return it
// as the error so it can be tracked elsewhere.
werr := make(chan error, 1)
go func() { werr <- Send(t.rw, handshakeMsg, our) }()
if their, err = readProtocolHandshake(t.rw, our); err != nil {
return nil, err
}
if err := <-werr; err != nil {
return nil, fmt.Errorf("write error: %v", err)
}
return their, nil
}
func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, error) {
msg, err := rw.ReadMsg()
if err != nil {
return nil, err
}
if msg.Size > baseProtocolMaxMsgSize {
return nil, fmt.Errorf("message too big")
}
if msg.Code == discMsg {
// Disconnect before protocol handshake is valid according to the
// spec and we send it ourself if the posthanshake checks fail.
// We can't return the reason directly, though, because it is echoed
// back otherwise. Wrap it in a string instead.
var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason)
return nil, reason[0]
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
}
var hs protoHandshake
if err := msg.Decode(&hs); err != nil {
return nil, err
}
// validate handshake info
if hs.Version != our.Version {
return nil, DiscIncompatibleVersion
}
if (hs.ID == discover.NodeID{}) {
return nil, DiscInvalidIdentity
}
return &hs, nil
}
func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) {
var (
sec secrets
err error
)
if dial == nil {
sec, err = receiverEncHandshake(t.fd, prv, nil)
} else {
sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil)
}
if err != nil {
return discover.NodeID{}, err
}
t.wmu.Lock()
t.rw = newRLPXFrameRW(t.fd, sec)
t.wmu.Unlock()
return sec.RemoteID, nil
}
// encHandshake contains the state of the encryption handshake.
type encHandshake struct {
initiator bool
remoteID discover.NodeID
remotePub *ecies.PublicKey // remote-pubk
initNonce, respNonce []byte // nonce
randomPrivKey *ecies.PrivateKey // ecdhe-random
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
}
// secrets represents the connection secrets
// which are negotiated during the encryption handshake.
type secrets struct {
RemoteID discover.NodeID
AES, MAC []byte
EgressMAC, IngressMAC hash.Hash
Token []byte
}
// secrets is called after the handshake is completed.
// It extracts the connection secrets from the handshake values.
func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
if err != nil {
return secrets{}, err
}
// derive base secrets from ephemeral key agreement
sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
s := secrets{
RemoteID: h.remoteID,
AES: aesSecret,
MAC: crypto.Sha3(ecdheSecret, aesSecret),
Token: crypto.Sha3(sharedSecret),
}
// setup sha3 instances for the MACs
mac1 := sha3.NewKeccak256()
mac1.Write(xor(s.MAC, h.respNonce))
mac1.Write(auth)
mac2 := sha3.NewKeccak256()
mac2.Write(xor(s.MAC, h.initNonce))
mac2.Write(authResp)
if h.initiator {
s.EgressMAC, s.IngressMAC = mac1, mac2
} else {
s.EgressMAC, s.IngressMAC = mac2, mac1
}
return s, nil
}
func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
}
// initiatorEncHandshake negotiates a session token on conn.
// it should be called on the dialing side of the connection.
//
// prv is the local client's private key.
// token is the token from a previous session with this node.
func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
h, err := newInitiatorHandshake(remoteID)
if err != nil {
return s, err
}
auth, err := h.authMsg(prv, token)
if err != nil {
return s, err
}
if _, err = conn.Write(auth); err != nil {
return s, err
}
response := make([]byte, encAuthRespLen)
if _, err = io.ReadFull(conn, response); err != nil {
return s, err
}
if err := h.decodeAuthResp(response, prv); err != nil {
return s, err
}
return h.secrets(auth, response)
}
func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
// generate random initiator nonce
n := make([]byte, shaLen)
if _, err := rand.Read(n); err != nil {
return nil, err
}
// generate random keypair to use for signing
randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
if err != nil {
return nil, err
}
rpub, err := remoteID.Pubkey()
if err != nil {
return nil, fmt.Errorf("bad remoteID: %v", err)
}
h := &encHandshake{
initiator: true,
remoteID: remoteID,
remotePub: ecies.ImportECDSAPublic(rpub),
initNonce: n,
randomPrivKey: randpriv,
}
return h, nil
}
// authMsg creates an encrypted initiator handshake message.
func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
var tokenFlag byte
if token == nil {
// no session token found means we need to generate shared secret.
// ecies shared secret is used as initial session token for new peers
// generate shared key from prv and remote pubkey
var err error
if token, err = h.ecdhShared(prv); err != nil {
return nil, err
}
} else {
// for known peers, we use stored token from the previous session
tokenFlag = 0x01
}
// sign known message:
// ecdh-shared-secret^nonce for new peers
// token^nonce for old peers
signed := xor(token, h.initNonce)
signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
if err != nil {
return nil, err
}
// encode auth message
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
msg := make([]byte, authMsgLen)
n := copy(msg, signature)
n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
n += copy(msg[n:], h.initNonce)
msg[n] = tokenFlag
// encrypt auth message using remote-pubk
return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
}
// decodeAuthResp decode an encrypted authentication response message.
func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
msg, err := crypto.Decrypt(prv, auth)
if err != nil {
return fmt.Errorf("could not decrypt auth response (%v)", err)
}
h.respNonce = msg[pubLen : pubLen+shaLen]
h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
if err != nil {
return err
}
// ignore token flag for now
return nil
}
// receiverEncHandshake negotiates a session token on conn.
// it should be called on the listening side of the connection.
//
// prv is the local client's private key.
// token is the token from a previous session with this node.
func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
// read remote auth sent by initiator.
auth := make([]byte, encAuthMsgLen)
if _, err := io.ReadFull(conn, auth); err != nil {
return s, err
}
h, err := decodeAuthMsg(prv, token, auth)
if err != nil {
return s, err
}
// send auth response
resp, err := h.authResp(prv, token)
if err != nil {
return s, err
}
if _, err = conn.Write(resp); err != nil {
return s, err
}
return h.secrets(auth, resp)
}
func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
var err error
h := new(encHandshake)
// generate random keypair for session
h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
if err != nil {
return nil, err
}
// generate random nonce
h.respNonce = make([]byte, shaLen)
if _, err = rand.Read(h.respNonce); err != nil {
return nil, err
}
msg, err := crypto.Decrypt(prv, auth)
if err != nil {
return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
}
// decode message parameters
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
rpub, err := h.remoteID.Pubkey()
if err != nil {
return nil, fmt.Errorf("bad remoteID: %#v", err)
}
h.remotePub = ecies.ImportECDSAPublic(rpub)
// recover remote random pubkey from signed message.
if token == nil {
// TODO: it is an error if the initiator has a token and we don't. check that.
// no session token means we need to generate shared secret.
// ecies shared secret is used as initial session token for new peers.
// generate shared key from prv and remote pubkey.
if token, err = h.ecdhShared(prv); err != nil {
return nil, err
}
}
signedMsg := xor(token, h.initNonce)
remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
if err != nil {
return nil, err
}
h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
return h, nil
}
// authResp generates the encrypted authentication response message.
func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
// responder auth message
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
resp := make([]byte, authRespLen)
n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
n += copy(resp[n:], h.respNonce)
if token == nil {
resp[n] = 0
} else {
resp[n] = 1
}
// encrypt using remote-pubk
return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
}
// importPublicKey unmarshals 512 bit public keys.
func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
var pubKey65 []byte
switch len(pubKey) {
case 64:
// add 'uncompressed key' flag
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
// TODO: fewer pointless conversions
return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
}
func exportPubkey(pub *ecies.PublicKey) []byte {
if pub == nil {
panic("nil pubkey")
}
return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
}
func xor(one, other []byte) (xor []byte) {
xor = make([]byte, len(one))
for i := 0; i < len(one); i++ {
xor[i] = one[i] ^ other[i]
}
return xor
}
var ( var (
// this is used in place of actual frame header data. // this is used in place of actual frame header data.
// TODO: replace this when Msg contains the protocol type code. // TODO: replace this when Msg contains the protocol type code.
zeroHeader = []byte{0xC2, 0x80, 0x80} zeroHeader = []byte{0xC2, 0x80, 0x80}
// sixteen zero bytes // sixteen zero bytes
zero16 = make([]byte, 16) zero16 = make([]byte, 16)
maxUint24 = ^uint32(0) >> 8
) )
// rlpxFrameRW implements a simplified version of RLPx framing. // rlpxFrameRW implements a simplified version of RLPx framing.
@ -38,7 +474,7 @@ type rlpxFrameRW struct {
ingressMAC hash.Hash ingressMAC hash.Hash
} }
func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
macc, err := aes.NewCipher(s.MAC) macc, err := aes.NewCipher(s.MAC)
if err != nil { if err != nil {
panic("invalid MAC secret: " + err.Error()) panic("invalid MAC secret: " + err.Error())

View File

@ -3,19 +3,253 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"fmt"
"io/ioutil" "io/ioutil"
"net"
"reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
func TestRlpxFrameFake(t *testing.T) { func TestSharedSecret(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
if err != nil {
return
}
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
if err != nil {
return
}
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
if !bytes.Equal(ss0, ss1) {
t.Errorf("dont match :(")
}
}
func TestEncHandshake(t *testing.T) {
for i := 0; i < 10; i++ {
start := time.Now()
if err := testEncHandshake(nil); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
}
for i := 0; i < 10; i++ {
tok := make([]byte, shaLen)
rand.Reader.Read(tok)
start := time.Now()
if err := testEncHandshake(tok); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
}
}
func testEncHandshake(token []byte) error {
type result struct {
side string
id discover.NodeID
err error
}
var (
prv0, _ = crypto.GenerateKey()
prv1, _ = crypto.GenerateKey()
fd0, fd1 = net.Pipe()
c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
output = make(chan result)
)
go func() {
r := result{side: "initiator"}
defer func() { output <- r }()
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
r.id, r.err = c0.doEncHandshake(prv0, dest)
if r.err != nil {
return
}
id1 := discover.PubkeyID(&prv1.PublicKey)
if r.id != id1 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
}
}()
go func() {
r := result{side: "receiver"}
defer func() { output <- r }()
r.id, r.err = c1.doEncHandshake(prv1, nil)
if r.err != nil {
return
}
id0 := discover.PubkeyID(&prv0.PublicKey)
if r.id != id0 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
}
}()
// wait for results from both sides
r1, r2 := <-output, <-output
if r1.err != nil {
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
}
if r2.err != nil {
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
}
// compare derived secrets
if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
}
if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
}
if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
}
if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
}
return nil
}
func TestProtocolHandshake(t *testing.T) {
var (
prv0, _ = crypto.GenerateKey()
node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
prv1, _ = crypto.GenerateKey()
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
fd0, fd1 = net.Pipe()
wg sync.WaitGroup
)
wg.Add(2)
go func() {
defer wg.Done()
rlpx := newRLPX(fd0)
remid, err := rlpx.doEncHandshake(prv0, node1)
if err != nil {
t.Errorf("dial side enc handshake failed: %v", err)
return
}
if remid != node1.ID {
t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
return
}
phs, err := rlpx.doProtoHandshake(hs0)
if err != nil {
t.Errorf("dial side proto handshake error: %v", err)
return
}
if !reflect.DeepEqual(phs, hs1) {
t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
return
}
rlpx.close(DiscQuitting)
}()
go func() {
defer wg.Done()
rlpx := newRLPX(fd1)
remid, err := rlpx.doEncHandshake(prv1, nil)
if err != nil {
t.Errorf("listen side enc handshake failed: %v", err)
return
}
if remid != node0.ID {
t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
return
}
phs, err := rlpx.doProtoHandshake(hs1)
if err != nil {
t.Errorf("listen side proto handshake error: %v", err)
return
}
if !reflect.DeepEqual(phs, hs0) {
t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
return
}
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
t.Errorf("error receiving disconnect: %v", err)
}
}()
wg.Wait()
}
func TestProtocolHandshakeErrors(t *testing.T) {
our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
id := randomID()
tests := []struct {
code uint64
msg interface{}
err error
}{
{
code: discMsg,
msg: []DiscReason{DiscQuitting},
err: DiscQuitting,
},
{
code: 0x989898,
msg: []byte{1},
err: errors.New("expected handshake, got 989898"),
},
{
code: handshakeMsg,
msg: make([]byte, baseProtocolMaxMsgSize+2),
err: errors.New("message too big"),
},
{
code: handshakeMsg,
msg: []byte{1, 2, 3},
err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
},
{
code: handshakeMsg,
msg: &protoHandshake{Version: 9944, ID: id},
err: DiscIncompatibleVersion,
},
{
code: handshakeMsg,
msg: &protoHandshake{Version: 3},
err: DiscInvalidIdentity,
},
}
for i, test := range tests {
p1, p2 := MsgPipe()
go Send(p1, test.code, test.msg)
_, err := readProtocolHandshake(p2, our)
if !reflect.DeepEqual(err, test.err) {
t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
}
}
}
func TestRLPXFrameFake(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
rw := newRlpxFrameRW(buf, secrets{ rw := newRLPXFrameRW(buf, secrets{
AES: crypto.Sha3(), AES: crypto.Sha3(),
MAC: crypto.Sha3(), MAC: crypto.Sha3(),
IngressMAC: hash, IngressMAC: hash,
@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 }
func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Size() int { return len(h) }
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
func TestRlpxFrameRW(t *testing.T) { func TestRLPXFrameRW(t *testing.T) {
var ( var (
aesSecret = make([]byte, 16) aesSecret = make([]byte, 16)
macSecret = make([]byte, 16) macSecret = make([]byte, 16)
@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) {
} }
s1.EgressMAC.Write(egressMACinit) s1.EgressMAC.Write(egressMACinit)
s1.IngressMAC.Write(ingressMACinit) s1.IngressMAC.Write(ingressMACinit)
rw1 := newRlpxFrameRW(conn, s1) rw1 := newRLPXFrameRW(conn, s1)
s2 := secrets{ s2 := secrets{
AES: aesSecret, AES: aesSecret,
@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) {
} }
s2.EgressMAC.Write(ingressMACinit) s2.EgressMAC.Write(ingressMACinit)
s2.IngressMAC.Write(egressMACinit) s2.IngressMAC.Write(egressMACinit)
rw2 := newRlpxFrameRW(conn, s2) rw2 := newRLPXFrameRW(conn, s2)
// send some messages // send some messages
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {

View File

@ -2,7 +2,6 @@ package p2p
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -24,11 +23,8 @@ const (
maxAcceptConns = 50 maxAcceptConns = 50
// Maximum number of concurrently dialing outbound connections. // Maximum number of concurrently dialing outbound connections.
maxDialingConns = 10 maxActiveDialTasks = 16
// total timeout for encryption handshake and protocol
// handshake in both directions.
handshakeTimeout = 5 * time.Second
// maximum time allowed for reading a complete message. // maximum time allowed for reading a complete message.
// this is effectively the amount of time a connection can be idle. // this is effectively the amount of time a connection can be idle.
frameReadTimeout = 1 * time.Minute frameReadTimeout = 1 * time.Minute
@ -36,6 +32,8 @@ const (
frameWriteTimeout = 5 * time.Second frameWriteTimeout = 5 * time.Second
) )
var errServerStopped = errors.New("server stopped")
var srvjslog = logger.NewJsonLogger() var srvjslog = logger.NewJsonLogger()
// Server manages all peer connections. // Server manages all peer connections.
@ -103,68 +101,173 @@ type Server struct {
// Hooks for testing. These are useful because we can inhibit // Hooks for testing. These are useful because we can inhibit
// the whole protocol stack. // the whole protocol stack.
setupFunc newTransport func(net.Conn) transport
newPeerHook newPeerHook func(*Peer)
lock sync.Mutex // protects running
running bool
ntab discoverTable
listener net.Listener
ourHandshake *protoHandshake ourHandshake *protoHandshake
lock sync.RWMutex // protects running, peers and the trust fields // These are for Peers, PeerCount (and nothing else).
running bool peerOp chan peerOpFunc
peers map[discover.NodeID]*Peer peerOpDone chan struct{}
staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes
staticDial chan *discover.Node // Dial request channel reserved for the static nodes
staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing
trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes
ntab *discover.Table
listener net.Listener
quit chan struct{} quit chan struct{}
loopWG sync.WaitGroup // {dial,listen,nat}Loop addstatic chan *discover.Node
peerWG sync.WaitGroup // active peer goroutines posthandshake chan *conn
addpeer chan *conn
delpeer chan *Peer
loopWG sync.WaitGroup // loop, listenLoop
} }
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error) type peerOpFunc func(map[discover.NodeID]*Peer)
type newPeerHook func(*Peer)
type connFlag int
const (
dynDialedConn connFlag = 1 << iota
staticDialedConn
inboundConn
trustedConn
)
// conn wraps a network connection with information gathered
// during the two handshakes.
type conn struct {
fd net.Conn
transport
flags connFlag
cont chan error // The run loop uses cont to signal errors to setupConn.
id discover.NodeID // valid after the encryption handshake
caps []Cap // valid after the protocol handshake
name string // valid after the protocol handshake
}
type transport interface {
// The two handshakes.
doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error)
doProtoHandshake(our *protoHandshake) (*protoHandshake, error)
// The MsgReadWriter can only be used after the encryption
// handshake has completed. The code uses conn.id to track this
// by setting it to a non-nil value after the encryption handshake.
MsgReadWriter
// transports must provide Close because we use MsgPipe in some of
// the tests. Closing the actual network connection doesn't do
// anything in those tests because NsgPipe doesn't use it.
close(err error)
}
func (c *conn) String() string {
s := c.flags.String() + " conn"
if (c.id != discover.NodeID{}) {
s += fmt.Sprintf(" %x", c.id[:8])
}
s += " " + c.fd.RemoteAddr().String()
return s
}
func (f connFlag) String() string {
s := ""
if f&trustedConn != 0 {
s += " trusted"
}
if f&dynDialedConn != 0 {
s += " dyn dial"
}
if f&staticDialedConn != 0 {
s += " static dial"
}
if f&inboundConn != 0 {
s += " inbound"
}
if s != "" {
s = s[1:]
}
return s
}
func (c *conn) is(f connFlag) bool {
return c.flags&f != 0
}
// Peers returns all connected peers. // Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) { func (srv *Server) Peers() []*Peer {
srv.lock.RLock() var ps []*Peer
defer srv.lock.RUnlock() select {
for _, peer := range srv.peers { // Note: We'd love to put this function into a variable but
if peer != nil { // that seems to cause a weird compiler error in some
peers = append(peers, peer) // environments.
case srv.peerOp <- func(peers map[discover.NodeID]*Peer) {
for _, p := range peers {
ps = append(ps, p)
} }
}:
<-srv.peerOpDone
case <-srv.quit:
} }
return return ps
} }
// PeerCount returns the number of connected peers. // PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int { func (srv *Server) PeerCount() int {
srv.lock.RLock() var count int
n := len(srv.peers) select {
srv.lock.RUnlock() case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }:
return n <-srv.peerOpDone
case <-srv.quit:
}
return count
} }
// AddPeer connects to the given node and maintains the connection until the // AddPeer connects to the given node and maintains the connection until the
// server is shut down. If the connection fails for any reason, the server will // server is shut down. If the connection fails for any reason, the server will
// attempt to reconnect the peer. // attempt to reconnect the peer.
func (srv *Server) AddPeer(node *discover.Node) { func (srv *Server) AddPeer(node *discover.Node) {
select {
case srv.addstatic <- node:
case <-srv.quit:
}
}
// Self returns the local node's endpoint information.
func (srv *Server) Self() *discover.Node {
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() defer srv.lock.Unlock()
if !srv.running {
return &discover.Node{IP: net.ParseIP("0.0.0.0")}
}
return srv.ntab.Self()
}
srv.staticNodes[node.ID] = node // Stop terminates the server and all active peer connections.
// It blocks until all active connections have been closed.
func (srv *Server) Stop() {
srv.lock.Lock()
defer srv.lock.Unlock()
if !srv.running {
return
}
srv.running = false
if srv.listener != nil {
// this unblocks listener Accept
srv.listener.Close()
}
close(srv.quit)
srv.loopWG.Wait()
} }
// Start starts running the server. // Start starts running the server.
// Servers can be re-used and started again after stopping. // Servers can not be re-used after stopping.
func (srv *Server) Start() (err error) { func (srv *Server) Start() (err error) {
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() defer srv.lock.Unlock()
if srv.running { if srv.running {
return errors.New("server already running") return errors.New("server already running")
} }
srv.running = true
glog.V(logger.Info).Infoln("Starting Server") glog.V(logger.Info).Infoln("Starting Server")
// static fields // static fields
@ -174,23 +277,19 @@ func (srv *Server) Start() (err error) {
if srv.MaxPeers <= 0 { if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0") return fmt.Errorf("Server.MaxPeers must be > 0")
} }
if srv.newTransport == nil {
srv.newTransport = newRLPX
}
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
srv.quit = make(chan struct{}) srv.quit = make(chan struct{})
srv.peers = make(map[discover.NodeID]*Peer) srv.addpeer = make(chan *conn)
srv.delpeer = make(chan *Peer)
// Create the current trust maps, and the associated dialing channel srv.posthandshake = make(chan *conn)
srv.trustedNodes = make(map[discover.NodeID]bool) srv.addstatic = make(chan *discover.Node)
for _, node := range srv.TrustedNodes { srv.peerOp = make(chan peerOpFunc)
srv.trustedNodes[node.ID] = true srv.peerOpDone = make(chan struct{})
}
srv.staticNodes = make(map[discover.NodeID]*discover.Node)
for _, node := range srv.StaticNodes {
srv.staticNodes[node.ID] = node
}
srv.staticDial = make(chan *discover.Node)
if srv.setupFunc == nil {
srv.setupFunc = setupConn
}
// node table // node table
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
@ -198,37 +297,31 @@ func (srv *Server) Start() (err error) {
return err return err
} }
srv.ntab = ntab srv.ntab = ntab
dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2)
// handshake // handshake
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID} srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
for _, p := range srv.Protocols { for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
} }
// listen/dial // listen/dial
if srv.ListenAddr != "" { if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil { if err := srv.startListening(); err != nil {
return err return err
} }
} }
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if !srv.NoDial {
srv.loopWG.Add(1)
go srv.dialLoop()
}
if srv.NoDial && srv.ListenAddr == "" { if srv.NoDial && srv.ListenAddr == "" {
glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.") glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.")
} }
// maintain the static peers
go srv.staticNodesLoop()
srv.loopWG.Add(1)
go srv.run(dialer)
srv.running = true srv.running = true
return nil return nil
} }
func (srv *Server) startListening() error { func (srv *Server) startListening() error {
// Launch the TCP listener.
listener, err := net.Listen("tcp", srv.ListenAddr) listener, err := net.Listen("tcp", srv.ListenAddr)
if err != nil { if err != nil {
return err return err
@ -238,6 +331,7 @@ func (srv *Server) startListening() error {
srv.listener = listener srv.listener = listener
srv.loopWG.Add(1) srv.loopWG.Add(1)
go srv.listenLoop() go srv.listenLoop()
// Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil { if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1) srv.loopWG.Add(1)
go func() { go func() {
@ -248,50 +342,164 @@ func (srv *Server) startListening() error {
return nil return nil
} }
// Stop terminates the server and all active peer connections. type dialer interface {
// It blocks until all active connections have been closed. newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task
func (srv *Server) Stop() { taskDone(task, time.Time)
srv.lock.Lock() addStatic(*discover.Node)
if !srv.running {
srv.lock.Unlock()
return
} }
srv.running = false
srv.lock.Unlock()
glog.V(logger.Info).Infoln("Stopping Server") func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done()
var (
peers = make(map[discover.NodeID]*Peer)
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
tasks []task
pendingTasks []task
taskdone = make(chan task, maxActiveDialTasks)
)
// Put trusted nodes into a map to speed up checks.
// Trusted peers are loaded on startup and cannot be
// modified while the server is running.
for _, n := range srv.TrustedNodes {
trusted[n.ID] = true
}
// Some task list helpers.
delTask := func(t task) {
for i := range tasks {
if tasks[i] == t {
tasks = append(tasks[:i], tasks[i+1:]...)
break
}
}
}
scheduleTasks := func(new []task) {
pt := append(pendingTasks, new...)
start := maxActiveDialTasks - len(tasks)
if len(pt) < start {
start = len(pt)
}
if start > 0 {
tasks = append(tasks, pt[:start]...)
for _, t := range pt[:start] {
t := t
glog.V(logger.Detail).Infoln("new task:", t)
go func() { t.Do(srv); taskdone <- t }()
}
copy(pt, pt[start:])
pendingTasks = pt[:len(pt)-start]
}
}
running:
for {
// Query the dialer for new tasks and launch them.
now := time.Now()
nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now)
scheduleTasks(nt)
select {
case <-srv.quit:
// The server was stopped. Run the cleanup logic.
glog.V(logger.Detail).Infoln("<-quit: spinning down")
break running
case n := <-srv.addstatic:
// This channel is used by AddPeer to add to the
// ephemeral static peer list. Add it to the dialer,
// it will keep the node connected.
glog.V(logger.Detail).Infoln("<-addstatic:", n)
dialstate.addStatic(n)
case op := <-srv.peerOp:
// This channel is used by Peers and PeerCount.
op(peers)
srv.peerOpDone <- struct{}{}
case t := <-taskdone:
// A task got done. Tell dialstate about it so it
// can update its state and remove it from the active
// tasks list.
glog.V(logger.Detail).Infoln("<-taskdone:", t)
dialstate.taskDone(t, now)
delTask(t)
case c := <-srv.posthandshake:
// A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet).
if trusted[c.id] {
// Ensure that the trusted flag is set before checking against MaxPeers.
c.flags |= trustedConn
}
glog.V(logger.Detail).Infoln("<-posthandshake:", c)
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
c.cont <- srv.encHandshakeChecks(peers, c)
case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
glog.V(logger.Detail).Infoln("<-addpeer:", c)
err := srv.protoHandshakeChecks(peers, c)
if err != nil {
glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err)
} else {
// The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols)
peers[c.id] = p
go srv.runPeer(p)
}
// The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or
// discarded. Unblock the task last.
c.cont <- err
case p := <-srv.delpeer:
// A peer disconnected.
glog.V(logger.Detail).Infoln("<-delpeer:", p)
delete(peers, p.ID())
}
}
// Terminate discovery. If there is a running lookup it will terminate soon.
srv.ntab.Close() srv.ntab.Close()
if srv.listener != nil { // Disconnect all peers.
// this unblocks listener Accept for _, p := range peers {
srv.listener.Close() p.Disconnect(DiscQuitting)
} }
close(srv.quit) // Wait for peers to shut down. Pending connections and tasks are
srv.loopWG.Wait() // not handled here and will terminate soon-ish because srv.quit
// is closed.
// No new peers can be added at this point because dialLoop and glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks))
// listenLoop are down. It is safe to call peerWG.Wait because for len(peers) > 0 {
// peerWG.Add is not called outside of those loops. p := <-srv.delpeer
srv.lock.Lock() glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p)
for _, peer := range srv.peers { delete(peers, p.ID())
peer.Disconnect(DiscQuitting)
} }
srv.lock.Unlock()
srv.peerWG.Wait()
} }
// Self returns the local node's endpoint information. func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
func (srv *Server) Self() *discover.Node { // Drop connections with no matching protocols.
srv.lock.RLock() if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
defer srv.lock.RUnlock() return DiscUselessPeer
if !srv.running {
return &discover.Node{IP: net.ParseIP("0.0.0.0")}
} }
return srv.ntab.Self() // Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes.
return srv.encHandshakeChecks(peers, c)
} }
// main loop for adding connections via listening func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
case peers[c.id] != nil:
return DiscAlreadyConnected
case c.id == srv.ntab.Self().ID:
return DiscSelf
default:
return nil
}
}
// listenLoop runs in its own goroutine and accepts
// inbound connections.
func (srv *Server) listenLoop() { func (srv *Server) listenLoop() {
defer srv.loopWG.Done() defer srv.loopWG.Done()
glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
// This channel acts as a semaphore limiting // This channel acts as a semaphore limiting
// active inbound connections that are lingering pre-handshake. // active inbound connections that are lingering pre-handshake.
@ -305,204 +513,92 @@ func (srv *Server) listenLoop() {
slots <- struct{}{} slots <- struct{}{}
} }
glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
for { for {
<-slots <-slots
conn, err := srv.listener.Accept() fd, err := srv.listener.Accept()
if err != nil { if err != nil {
return return
} }
glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr()) glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
srv.peerWG.Add(1)
go func() { go func() {
srv.startPeer(conn, nil) srv.setupConn(fd, inboundConn, nil)
slots <- struct{}{} slots <- struct{}{}
}() }()
} }
} }
// staticNodesLoop is responsible for periodically checking that static // setupConn runs the handshakes and attempts to add the connection
// connections are actually live, and requests dialing if not. // as a peer. It returns when the connection has been added as a peer
func (srv *Server) staticNodesLoop() { // or the handshakes have failed.
// Create a default maintenance ticker, but override it requested func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) {
cycle := staticPeerCheckInterval // Prevent leftover pending conns from entering the handshake.
if srv.staticCycle != 0 { srv.lock.Lock()
cycle = srv.staticCycle running := srv.running
} srv.lock.Unlock()
tick := time.NewTicker(cycle) c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
if !running {
for { c.close(errServerStopped)
select {
case <-srv.quit:
return
case <-tick.C:
// Collect all the non-connected static nodes
needed := []*discover.Node{}
srv.lock.RLock()
for id, node := range srv.staticNodes {
if _, ok := srv.peers[id]; !ok {
needed = append(needed, node)
}
}
srv.lock.RUnlock()
// Try to dial each of them (don't hang if server terminates)
for _, node := range needed {
glog.V(logger.Debug).Infof("Dialing static peer %v", node)
select {
case srv.staticDial <- node:
case <-srv.quit:
return return
} }
} // Run the encryption handshake.
} var err error
} if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil {
} glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err)
c.close(err)
func (srv *Server) dialLoop() {
var (
dialed = make(chan *discover.Node)
dialing = make(map[discover.NodeID]bool)
findresults = make(chan []*discover.Node)
refresh = time.NewTimer(0)
)
defer srv.loopWG.Done()
defer refresh.Stop()
// Limit the number of concurrent dials
tokens := maxDialingConns
if srv.MaxPendingPeers > 0 {
tokens = srv.MaxPendingPeers
}
slots := make(chan struct{}, tokens)
for i := 0; i < tokens; i++ {
slots <- struct{}{}
}
dial := func(dest *discover.Node) {
// Don't dial nodes that would fail the checks in addPeer.
// This is important because the connection handshake is a lot
// of work and we'd rather avoid doing that work for peers
// that can't be added.
srv.lock.RLock()
ok, _ := srv.checkPeer(dest.ID)
srv.lock.RUnlock()
if !ok || dialing[dest.ID] {
return return
} }
// Request a dial slot to prevent CPU exhaustion // For dialed connections, check that the remote public key matches.
<-slots if dialDest != nil && c.id != dialDest.ID {
c.close(DiscUnexpectedIdentity)
dialing[dest.ID] = true glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8])
srv.peerWG.Add(1)
go func() {
srv.dialNode(dest)
slots <- struct{}{}
dialed <- dest
}()
}
srv.ntab.Bootstrap(srv.BootstrapNodes)
for {
select {
case <-refresh.C:
// Grab some nodes to connect to if we're not at capacity.
srv.lock.RLock()
needpeers := len(srv.peers) < srv.MaxPeers/2
srv.lock.RUnlock()
if needpeers {
go func() {
var target discover.NodeID
rand.Read(target[:])
findresults <- srv.ntab.Lookup(target)
}()
} else {
// Make sure we check again if the peer count falls
// below MaxPeers.
refresh.Reset(refreshPeersInterval)
}
case dest := <-srv.staticDial:
dial(dest)
case dests := <-findresults:
for _, dest := range dests {
dial(dest)
}
refresh.Reset(refreshPeersInterval)
case dest := <-dialed:
delete(dialing, dest.ID)
if len(dialing) == 0 {
// Check again immediately after dialing all current candidates.
refresh.Reset(0)
}
case <-srv.quit:
// TODO: maybe wait for active dials
return return
} }
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err)
c.close(err)
return
} }
} // Run the protocol handshake
phs, err := c.doProtoHandshake(srv.ourHandshake)
func (srv *Server) dialNode(dest *discover.Node) {
addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
glog.V(logger.Debug).Infof("Dialing %v\n", dest)
conn, err := srv.Dialer.Dial("tcp", addr.String())
if err != nil { if err != nil {
// dialLoop adds to the wait group counter when launching glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err)
// dialNode, so we need to count it down again. startPeer also c.close(err)
// does that when an error occurs.
srv.peerWG.Done()
glog.V(logger.Detail).Infof("dial error: %v", err)
return return
} }
srv.startPeer(conn, dest) if phs.ID != c.id {
} glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8])
c.close(DiscUnexpectedIdentity)
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
// TODO: handle/store session token
// Run setupFunc, which should create an authenticated connection
// and run the capability exchange. Note that any early error
// returns during that exchange need to call peerWG.Done because
// the callers of startPeer added the peer to the wait group already.
fd.SetDeadline(time.Now().Add(handshakeTimeout))
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn)
if err != nil {
fd.Close()
glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
srv.peerWG.Done()
return return
} }
conn.MsgReadWriter = &netWrapper{ c.caps, c.name = phs.Caps, phs.Name
wrapped: conn.MsgReadWriter, if err := srv.checkpoint(c, srv.addpeer); err != nil {
conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout, glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err)
} c.close(err)
p := newPeer(fd, conn, srv.Protocols)
if ok, reason := srv.addPeer(conn, p); !ok {
glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason)
p.politeDisconnect(reason)
srv.peerWG.Done()
return return
} }
// The handshakes are done and it passed all checks. // If the checks completed successfully, runPeer has now been
// Spawn the Peer loops. // launched by run.
go srv.runPeer(p)
} }
// preflight checks whether a connection should be kept. it runs // checkpoint sends the conn to run, which performs the
// after the encryption handshake, as soon as the remote identity is // post-handshake checks for the stage (posthandshake, addpeer).
// known. func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
func (srv *Server) keepconn(id discover.NodeID) bool { select {
srv.lock.RLock() case stage <- c:
defer srv.lock.RUnlock() case <-srv.quit:
if _, ok := srv.staticNodes[id]; ok { return errServerStopped
return true // static nodes are always allowed
} }
if _, ok := srv.trustedNodes[id]; ok { select {
return true // trusted nodes are always allowed case err := <-c.cont:
return err
case <-srv.quit:
return errServerStopped
} }
return len(srv.peers) < srv.MaxPeers
} }
// runPeer runs in its own goroutine for each peer.
// it waits until the Peer logic returns and removes
// the peer.
func (srv *Server) runPeer(p *Peer) { func (srv *Server) runPeer(p *Peer) {
glog.V(logger.Debug).Infof("Added %v\n", p) glog.V(logger.Debug).Infof("Added %v\n", p)
srvjslog.LogJson(&logger.P2PConnected{ srvjslog.LogJson(&logger.P2PConnected{
@ -511,58 +607,18 @@ func (srv *Server) runPeer(p *Peer) {
RemoteVersionString: p.Name(), RemoteVersionString: p.Name(),
NumConnections: srv.PeerCount(), NumConnections: srv.PeerCount(),
}) })
if srv.newPeerHook != nil { if srv.newPeerHook != nil {
srv.newPeerHook(p) srv.newPeerHook(p)
} }
discreason := p.run() discreason := p.run()
srv.removePeer(p) // Note: run waits for existing peers to be sent on srv.delpeer
// before returning, so this send should not select on srv.quit.
srv.delpeer <- p
glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason) glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason)
srvjslog.LogJson(&logger.P2PDisconnected{ srvjslog.LogJson(&logger.P2PDisconnected{
RemoteId: p.ID().String(), RemoteId: p.ID().String(),
NumConnections: srv.PeerCount(), NumConnections: srv.PeerCount(),
}) })
} }
func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) {
// drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 {
return false, DiscUselessPeer
}
// add the peer if it passes the other checks.
srv.lock.Lock()
defer srv.lock.Unlock()
if ok, reason := srv.checkPeer(conn.ID); !ok {
return false, reason
}
srv.peers[conn.ID] = p
return true, 0
}
// checkPeer verifies whether a peer looks promising and should be allowed/kept
// in the pool, or if it's of no use.
func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) {
// First up, figure out if the peer is static or trusted
_, static := srv.staticNodes[id]
trusted := srv.trustedNodes[id]
// Make sure the peer passes all required checks
switch {
case !srv.running:
return false, DiscQuitting
case !static && !trusted && len(srv.peers) >= srv.MaxPeers:
return false, DiscTooManyPeers
case srv.peers[id] != nil:
return false, DiscAlreadyConnected
case id == srv.ntab.Self().ID:
return false, DiscSelf
default:
return true, 0
}
}
func (srv *Server) removePeer(p *Peer) {
srv.lock.Lock()
delete(srv.peers, p.ID())
srv.lock.Unlock()
srv.peerWG.Done()
}

View File

@ -2,8 +2,10 @@ package p2p
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"errors"
"math/rand" "math/rand"
"net" "net"
"reflect"
"testing" "testing"
"time" "time"
@ -12,29 +14,50 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
) )
func startTestServer(t *testing.T, pf newPeerHook) *Server { func init() {
// glog.SetV(6)
// glog.SetToStderr(true)
}
type testTransport struct {
id discover.NodeID
*rlpx
closeErr error
}
func newTestTransport(id discover.NodeID, fd net.Conn) transport {
wrapped := newRLPX(fd).(*rlpx)
wrapped.rw = newRLPXFrameRW(fd, secrets{
MAC: zero16,
AES: zero16,
IngressMAC: sha3.NewKeccak256(),
EgressMAC: sha3.NewKeccak256(),
})
return &testTransport{id: id, rlpx: wrapped}
}
func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
return c.id, nil
}
func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
return &protoHandshake{ID: c.id, Name: "test"}, nil
}
func (c *testTransport) close(err error) {
c.rlpx.fd.Close()
c.closeErr = err
}
func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
server := &Server{ server := &Server{
Name: "test", Name: "test",
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(), PrivateKey: newkey(),
newPeerHook: pf, newPeerHook: pf,
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
id := randomID()
if !keepconn(id) {
return nil, DiscAlreadyConnected
}
rw := newRlpxFrameRW(fd, secrets{
MAC: zero16,
AES: zero16,
IngressMAC: sha3.NewKeccak256(),
EgressMAC: sha3.NewKeccak256(),
})
return &conn{
MsgReadWriter: rw,
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
}, nil
},
} }
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err) t.Fatalf("Could not start server: %v", err)
@ -45,7 +68,11 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
func TestServerListen(t *testing.T) { func TestServerListen(t *testing.T) {
// start the test server // start the test server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(p *Peer) { remid := randomID()
srv := startTestServer(t, remid, func(p *Peer) {
if p.ID() != remid {
t.Error("peer func called with wrong node id")
}
if p == nil { if p == nil {
t.Error("peer func called with nil conn") t.Error("peer func called with nil conn")
} }
@ -67,6 +94,10 @@ func TestServerListen(t *testing.T) {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.LocalAddr(), conn.RemoteAddr()) peer.LocalAddr(), conn.RemoteAddr())
} }
peers := srv.Peers()
if !reflect.DeepEqual(peers, []*Peer{peer}) {
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
}
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not accept within one second") t.Error("server did not accept within one second")
} }
@ -92,23 +123,33 @@ func TestServerDial(t *testing.T) {
// start the server // start the server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(p *Peer) { connected <- p }) remid := randomID()
srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
// tell the server to connect // tell the server to connect
tcpAddr := listener.Addr().(*net.TCPAddr) tcpAddr := listener.Addr().(*net.TCPAddr)
srv.staticDial <- &discover.Node{IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)} srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
select { select {
case conn := <-accepted: case conn := <-accepted:
select { select {
case peer := <-connected: case peer := <-connected:
if peer.ID() != remid {
t.Errorf("peer has wrong id")
}
if peer.Name() != "test" {
t.Errorf("peer has wrong name")
}
if peer.RemoteAddr().String() != conn.LocalAddr().String() { if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.RemoteAddr(), conn.LocalAddr()) peer.RemoteAddr(), conn.LocalAddr())
} }
// TODO: validate more fields peers := srv.Peers()
if !reflect.DeepEqual(peers, []*Peer{peer}) {
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
}
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second") t.Error("server did not launch peer within one second")
} }
@ -118,331 +159,250 @@ func TestServerDial(t *testing.T) {
} }
} }
// This test checks that tasks generated by dialstate are
// actually executed and taskdone is called for them.
func TestServerTaskScheduling(t *testing.T) {
var (
done = make(chan *testTask)
quit, returned = make(chan struct{}), make(chan struct{})
tc = 0
tg = taskgen{
newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
tc++
return []task{&testTask{index: tc - 1}}
},
doneFunc: func(t task) {
select {
case done <- t.(*testTask):
case <-quit:
}
},
}
)
// The Server in this test isn't actually running
// because we're only interested in what run does.
srv := &Server{
MaxPeers: 10,
quit: make(chan struct{}),
ntab: fakeTable{},
running: true,
}
srv.loopWG.Add(1)
go func() {
srv.run(tg)
close(returned)
}()
var gotdone []*testTask
for i := 0; i < 100; i++ {
gotdone = append(gotdone, <-done)
}
for i, task := range gotdone {
if task.index != i {
t.Errorf("task %d has wrong index, got %d", i, task.index)
break
}
if !task.called {
t.Errorf("task %d was not called", i)
break
}
}
close(quit)
srv.Stop()
select {
case <-returned:
case <-time.After(500 * time.Millisecond):
t.Error("Server.run did not return within 500ms")
}
}
type taskgen struct {
newFunc func(running int, peers map[discover.NodeID]*Peer) []task
doneFunc func(task)
}
func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
return tg.newFunc(running, peers)
}
func (tg taskgen) taskDone(t task, now time.Time) {
tg.doneFunc(t)
}
func (tg taskgen) addStatic(*discover.Node) {
}
type testTask struct {
index int
called bool
}
func (t *testTask) Do(srv *Server) {
t.called = true
}
// This test checks that connections are disconnected // This test checks that connections are disconnected
// just after the encryption handshake when the server is // just after the encryption handshake when the server is
// at capacity. // at capacity. Trusted connections should still be accepted.
// func TestServerAtCap(t *testing.T) {
// It also serves as a light-weight integration test. trustedID := randomID()
func TestServerDisconnectAtCap(t *testing.T) {
started := make(chan *Peer)
srv := &Server{ srv := &Server{
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(), PrivateKey: newkey(),
MaxPeers: 10, MaxPeers: 10,
NoDial: true, NoDial: true,
// This hook signals that the peer was actually started. We TrustedNodes: []*discover.Node{{ID: trustedID}},
// need to wait for the peer to be started before dialing the
// next connection to get a deterministic peer count.
newPeerHook: func(p *Peer) { started <- p },
} }
if err := srv.Start(); err != nil { if err := srv.Start(); err != nil {
t.Fatal(err) t.Fatalf("could not start: %v", err)
} }
defer srv.Stop() defer srv.Stop()
nconns := srv.MaxPeers + 1 newconn := func(id discover.NodeID) *conn {
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} fd, _ := net.Pipe()
for i := 0; i < nconns; i++ { tx := newTestTransport(id, fd)
conn, err := dialer.Dial("tcp", srv.ListenAddr) return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
if err != nil {
t.Fatalf("conn %d: dial error: %v", i, err)
}
// Close the connection when the test ends, before
// shutting down the server.
defer conn.Close()
// Run the handshakes just like a real peer would.
key := newkey()
hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
_, err = setupConn(conn, key, hs, srv.Self(), keepalways)
if i == nconns-1 {
// When handling the last connection, the server should
// disconnect immediately instead of running the protocol
// handshake.
if err != DiscTooManyPeers {
t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers)
}
} else {
// For all earlier connections, the handshake should go through.
if err != nil {
t.Fatalf("conn %d: unexpected error: %v", i, err)
}
// Wait for runPeer to be started.
<-started
}
}
} }
// Tests that static peers are (re)connected, and done so even above max peers. // Inject a few connections to fill up the peer set.
func TestServerStaticPeers(t *testing.T) { for i := 0; i < 10; i++ {
// Create a test server with limited connection slots c := newconn(randomID())
started := make(chan *Peer) if err := srv.checkpoint(c, srv.addpeer); err != nil {
server := &Server{ t.Fatalf("could not add conn %d: %v", i, err)
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
MaxPeers: 3,
newPeerHook: func(p *Peer) { started <- p },
staticCycle: time.Second,
}
if err := server.Start(); err != nil {
t.Fatal(err)
}
defer server.Stop()
// Fill up all the slots on the server
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
for i := 0; i < server.MaxPeers; i++ {
// Establish a new connection
conn, err := dialer.Dial("tcp", server.ListenAddr)
if err != nil {
t.Fatalf("conn %d: dial error: %v", i, err)
}
defer conn.Close()
// Run the handshakes just like a real peer would, and wait for completion
key := newkey()
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
t.Fatalf("conn %d: unexpected error: %v", i, err)
}
<-started
}
// Open a TCP listener to accept static connections
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to setup listener: %v", err)
}
defer listener.Close()
connected := make(chan net.Conn)
go func() {
for i := 0; i < 3; i++ {
conn, err := listener.Accept()
if err == nil {
connected <- conn
} }
} }
}() // Try inserting a non-trusted connection.
// Inject a static node and wait for a remote dial, then redial, then nothing c := newconn(randomID())
addr := listener.Addr().(*net.TCPAddr) if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
static := &discover.Node{ t.Error("wrong error for insert:", err)
ID: discover.PubkeyID(&newkey().PublicKey),
IP: addr.IP,
TCP: uint16(addr.Port),
} }
server.AddPeer(static) // Try inserting a trusted connection.
c = newconn(trustedID)
select { if err := srv.checkpoint(c, srv.posthandshake); err != nil {
case conn := <-connected: t.Error("unexpected error for trusted conn @posthandshake:", err)
// Close the first connection, expect redial }
conn.Close() if !c.is(trustedConn) {
t.Error("Server did not set trusted flag")
case <-time.After(2 * server.staticCycle):
t.Fatalf("remote dial timeout")
} }
select {
case conn := <-connected:
// Keep the second connection, don't expect redial
defer conn.Close()
case <-time.After(2 * server.staticCycle):
t.Fatalf("remote re-dial timeout")
} }
select { func TestServerSetupConn(t *testing.T) {
case <-time.After(2 * server.staticCycle): id := randomID()
// Timeout as no dial occurred srvkey := newkey()
srvid := discover.PubkeyID(&srvkey.PublicKey)
tests := []struct {
dontstart bool
tt *setupTransport
flags connFlag
dialDest *discover.Node
case <-connected: wantCloseErr error
t.Fatalf("connected node dialed") wantCalls string
} }{
{
dontstart: true,
tt: &setupTransport{id: id},
wantCalls: "close,",
wantCloseErr: errServerStopped,
},
{
tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
flags: inboundConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: errors.New("read error"),
},
{
tt: &setupTransport{id: id},
dialDest: &discover.Node{ID: randomID()},
flags: dynDialedConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: DiscUnexpectedIdentity,
},
{
tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
dialDest: &discover.Node{ID: id},
flags: dynDialedConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: DiscUnexpectedIdentity,
},
{
tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
dialDest: &discover.Node{ID: id},
flags: dynDialedConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: errors.New("foo"),
},
{
tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
flags: inboundConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: DiscSelf,
},
{
tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
flags: inboundConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: DiscUselessPeer,
},
} }
// Tests that trusted peers and can connect above max peer caps. for i, test := range tests {
func TestServerTrustedPeers(t *testing.T) { srv := &Server{
PrivateKey: srvkey,
// Create a trusted peer to accept connections from
key := newkey()
trusted := &discover.Node{
ID: discover.PubkeyID(&key.PublicKey),
}
// Create a test server with limited connection slots
started := make(chan *Peer)
server := &Server{
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
MaxPeers: 3,
NoDial: true,
TrustedNodes: []*discover.Node{trusted},
newPeerHook: func(p *Peer) { started <- p },
}
if err := server.Start(); err != nil {
t.Fatal(err)
}
defer server.Stop()
// Fill up all the slots on the server
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
for i := 0; i < server.MaxPeers; i++ {
// Establish a new connection
conn, err := dialer.Dial("tcp", server.ListenAddr)
if err != nil {
t.Fatalf("conn %d: dial error: %v", i, err)
}
defer conn.Close()
// Run the handshakes just like a real peer would, and wait for completion
key := newkey()
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
t.Fatalf("conn %d: unexpected error: %v", i, err)
}
<-started
}
// Dial from the trusted peer, ensure connection is accepted
conn, err := dialer.Dial("tcp", server.ListenAddr)
if err != nil {
t.Fatalf("trusted node: dial error: %v", err)
}
defer conn.Close()
shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID}
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
t.Fatalf("trusted node: unexpected error: %v", err)
}
select {
case <-started:
// Ok, trusted peer accepted
case <-time.After(100 * time.Millisecond):
t.Fatalf("trusted node timeout")
}
}
// Tests that a failed dial will temporarily throttle a peer.
func TestServerMaxPendingDials(t *testing.T) {
// Start a simple test server
server := &Server{
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
MaxPeers: 10, MaxPeers: 10,
MaxPendingPeers: 1,
}
if err := server.Start(); err != nil {
t.Fatal("failed to start test server: %v", err)
}
defer server.Stop()
// Simulate two separate remote peers
peers := make(chan *discover.Node, 2)
conns := make(chan net.Conn, 2)
for i := 0; i < 2; i++ {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listener %d: failed to setup: %v", i, err)
}
defer listener.Close()
addr := listener.Addr().(*net.TCPAddr)
peers <- &discover.Node{
ID: discover.PubkeyID(&newkey().PublicKey),
IP: addr.IP,
TCP: uint16(addr.Port),
}
go func() {
conn, err := listener.Accept()
if err == nil {
conns <- conn
}
}()
}
// Request a dial for both peers
go func() {
for i := 0; i < 2; i++ {
server.staticDial <- <-peers // hack piggybacking the static implementation
}
}()
// Make sure only one outbound connection goes through
var conn net.Conn
select {
case conn = <-conns:
case <-time.After(100 * time.Millisecond):
t.Fatalf("first dial timeout")
}
select {
case conn = <-conns:
t.Fatalf("second dial completed prematurely")
case <-time.After(100 * time.Millisecond):
}
// Finish the first dial, check the second
conn.Close()
select {
case conn = <-conns:
conn.Close()
case <-time.After(100 * time.Millisecond):
t.Fatalf("second dial timeout")
}
}
func TestServerMaxPendingAccepts(t *testing.T) {
// Start a test server and a peer sink for synchronization
started := make(chan *Peer)
server := &Server{
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
MaxPeers: 10,
MaxPendingPeers: 1,
NoDial: true, NoDial: true,
newPeerHook: func(p *Peer) { started <- p }, Protocols: []Protocol{discard},
newTransport: func(fd net.Conn) transport { return test.tt },
}
if !test.dontstart {
if err := srv.Start(); err != nil {
t.Fatalf("couldn't start server: %v", err)
}
}
p1, _ := net.Pipe()
srv.setupConn(p1, test.flags, test.dialDest)
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
}
if test.tt.calls != test.wantCalls {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
}
} }
if err := server.Start(); err != nil {
t.Fatal("failed to start test server: %v", err)
} }
defer server.Stop()
// Try and connect to the server on multiple threads concurrently type setupTransport struct {
conns := make([]net.Conn, 2) id discover.NodeID
for i := 0; i < 2; i++ { encHandshakeErr error
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
conn, err := dialer.Dial("tcp", server.ListenAddr) phs *protoHandshake
if err != nil { protoHandshakeErr error
t.Fatalf("failed to dial server: %v", err)
}
conns[i] = conn
}
// Check that a handshake on the second doesn't pass
go func() {
key := newkey()
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil {
t.Fatalf("failed to run handshake: %v", err)
}
}()
select {
case <-started:
t.Fatalf("handshake on second connection accepted")
case <-time.After(time.Second): calls string
closeErr error
} }
// Shake on first, check that both go through
go func() { func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
key := newkey() c.calls += "doEncHandshake,"
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} return c.id, c.encHandshakeErr
if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil {
t.Fatalf("failed to run handshake: %v", err)
} }
}() func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
for i := 0; i < 2; i++ { c.calls += "doProtoHandshake,"
select { if c.protoHandshakeErr != nil {
case <-started: return nil, c.protoHandshakeErr
case <-time.After(time.Second):
t.Fatalf("peer %d: handshake timeout", i)
} }
return c.phs, nil
} }
func (c *setupTransport) close(err error) {
c.calls += "close,"
c.closeErr = err
}
// setupConn shouldn't write to/read from the connection.
func (c *setupTransport) WriteMsg(Msg) error {
panic("WriteMsg called on setupTransport")
}
func (c *setupTransport) ReadMsg() (Msg, error) {
panic("ReadMsg called on setupTransport")
} }
func newkey() *ecdsa.PrivateKey { func newkey() *ecdsa.PrivateKey {
@ -459,7 +419,3 @@ func randomID() (id discover.NodeID) {
} }
return id return id
} }
func keepalways(id discover.NodeID) bool {
return true
}