mirror of
https://github.com/status-im/op-geth.git
synced 2025-01-15 01:04:11 +00:00
p2p: improve disconnect signaling at handshake time
As of this commit, p2p will disconnect nodes directly after the encryption handshake if too many peer connections are active. Errors in the protocol handshake packet are now handled more politely by sending a disconnect packet before closing the connection.
This commit is contained in:
parent
99a1db2d40
commit
b3c058a9e4
@ -68,50 +68,61 @@ type protoHandshake struct {
|
|||||||
// setupConn starts a protocol session on the given connection.
|
// setupConn starts a protocol session on the given connection.
|
||||||
// It runs the encryption handshake and the protocol handshake.
|
// It runs the encryption handshake and the protocol handshake.
|
||||||
// If dial is non-nil, the connection the local node is the initiator.
|
// If dial is non-nil, the connection the local node is the initiator.
|
||||||
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
// If atcap is true, the connection will be disconnected with DiscTooManyPeers
|
||||||
|
// after the key exchange.
|
||||||
|
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
|
||||||
if dial == nil {
|
if dial == nil {
|
||||||
return setupInboundConn(fd, prv, our)
|
return setupInboundConn(fd, prv, our, atcap)
|
||||||
} else {
|
} else {
|
||||||
return setupOutboundConn(fd, prv, our, dial)
|
return setupOutboundConn(fd, prv, our, dial, atcap)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
|
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool) (*conn, error) {
|
||||||
secrets, err := receiverEncHandshake(fd, prv, nil)
|
secrets, err := receiverEncHandshake(fd, prv, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
|
||||||
rw := newRlpxFrameRW(fd, secrets)
|
rw := newRlpxFrameRW(fd, secrets)
|
||||||
rhs, err := readProtocolHandshake(rw, our)
|
if atcap {
|
||||||
|
SendItems(rw, discMsg, DiscTooManyPeers)
|
||||||
|
return nil, errors.New("we have too many peers")
|
||||||
|
}
|
||||||
|
// Run the protocol handshake using authenticated messages.
|
||||||
|
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if rhs.ID != secrets.RemoteID {
|
|
||||||
return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
|
|
||||||
}
|
|
||||||
// TODO: validate that handshake node ID matches
|
|
||||||
if err := Send(rw, handshakeMsg, our); err != nil {
|
if err := Send(rw, handshakeMsg, our); err != nil {
|
||||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
||||||
}
|
}
|
||||||
return &conn{rw, rhs}, nil
|
return &conn{rw, rhs}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
|
||||||
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
|
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
|
||||||
rw := newRlpxFrameRW(fd, secrets)
|
rw := newRlpxFrameRW(fd, secrets)
|
||||||
if err := Send(rw, handshakeMsg, our); err != nil {
|
if atcap {
|
||||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
SendItems(rw, discMsg, DiscTooManyPeers)
|
||||||
|
return nil, errors.New("we have too many peers")
|
||||||
}
|
}
|
||||||
rhs, err := readProtocolHandshake(rw, our)
|
// Run the protocol handshake using authenticated messages.
|
||||||
|
//
|
||||||
|
// Note that even though writing the handshake is first, we prefer
|
||||||
|
// returning the handshake read error. If the remote side
|
||||||
|
// disconnects us early with a valid reason, we should return it
|
||||||
|
// as the error so it can be tracked elsewhere.
|
||||||
|
werr := make(chan error)
|
||||||
|
go func() { werr <- Send(rw, handshakeMsg, our) }()
|
||||||
|
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("protocol handshake read error: %v", err)
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := <-werr; err != nil {
|
||||||
|
return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
||||||
}
|
}
|
||||||
if rhs.ID != dial.ID {
|
if rhs.ID != dial.ID {
|
||||||
return nil, errors.New("dialed node id mismatch")
|
return nil, errors.New("dialed node id mismatch")
|
||||||
@ -398,18 +409,17 @@ func xor(one, other []byte) (xor []byte) {
|
|||||||
return xor
|
return xor
|
||||||
}
|
}
|
||||||
|
|
||||||
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
|
||||||
// read and handle remote handshake
|
msg, err := rw.ReadMsg()
|
||||||
msg, err := r.ReadMsg()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if msg.Code == discMsg {
|
if msg.Code == discMsg {
|
||||||
// disconnect before protocol handshake is valid according to the
|
// disconnect before protocol handshake is valid according to the
|
||||||
// spec and we send it ourself if Server.addPeer fails.
|
// spec and we send it ourself if Server.addPeer fails.
|
||||||
var reason DiscReason
|
var reason [1]DiscReason
|
||||||
rlp.Decode(msg.Payload, &reason)
|
rlp.Decode(msg.Payload, &reason)
|
||||||
return nil, reason
|
return nil, reason[0]
|
||||||
}
|
}
|
||||||
if msg.Code != handshakeMsg {
|
if msg.Code != handshakeMsg {
|
||||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||||
@ -423,10 +433,16 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e
|
|||||||
}
|
}
|
||||||
// validate handshake info
|
// validate handshake info
|
||||||
if hs.Version != our.Version {
|
if hs.Version != our.Version {
|
||||||
return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
SendItems(rw, discMsg, DiscIncompatibleVersion)
|
||||||
|
return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
||||||
}
|
}
|
||||||
if (hs.ID == discover.NodeID{}) {
|
if (hs.ID == discover.NodeID{}) {
|
||||||
return nil, newPeerError(errPubkeyInvalid, "missing")
|
SendItems(rw, discMsg, DiscInvalidIdentity)
|
||||||
|
return nil, errors.New("invalid public key in handshake")
|
||||||
|
}
|
||||||
|
if hs.ID != wantID {
|
||||||
|
SendItems(rw, discMsg, DiscUnexpectedIdentity)
|
||||||
|
return nil, errors.New("handshake node ID does not match encryption handshake")
|
||||||
}
|
}
|
||||||
return &hs, nil
|
return &hs, nil
|
||||||
}
|
}
|
||||||
|
@ -143,7 +143,7 @@ func TestSetupConn(t *testing.T) {
|
|||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
conn0, err := setupConn(fd0, prv0, hs0, node1)
|
conn0, err := setupConn(fd0, prv0, hs0, node1, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("outbound side error: %v", err)
|
t.Errorf("outbound side error: %v", err)
|
||||||
return
|
return
|
||||||
@ -156,7 +156,7 @@ func TestSetupConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn1, err := setupConn(fd1, prv1, hs1, nil)
|
conn1, err := setupConn(fd1, prv1, hs1, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("inbound side error: %v", err)
|
t.Fatalf("inbound side error: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -99,7 +99,7 @@ type Server struct {
|
|||||||
peerConnect chan *discover.Node
|
peerConnect chan *discover.Node
|
||||||
}
|
}
|
||||||
|
|
||||||
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
|
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, bool) (*conn, error)
|
||||||
type newPeerHook func(*Peer)
|
type newPeerHook func(*Peer)
|
||||||
|
|
||||||
// Peers returns all connected peers.
|
// Peers returns all connected peers.
|
||||||
@ -261,6 +261,11 @@ func (srv *Server) Stop() {
|
|||||||
srv.peerWG.Wait()
|
srv.peerWG.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Self returns the local node's endpoint information.
|
||||||
|
func (srv *Server) Self() *discover.Node {
|
||||||
|
return srv.ntab.Self()
|
||||||
|
}
|
||||||
|
|
||||||
// main loop for adding connections via listening
|
// main loop for adding connections via listening
|
||||||
func (srv *Server) listenLoop() {
|
func (srv *Server) listenLoop() {
|
||||||
defer srv.loopWG.Done()
|
defer srv.loopWG.Done()
|
||||||
@ -354,10 +359,6 @@ func (srv *Server) dialNode(dest *discover.Node) {
|
|||||||
srv.startPeer(conn, dest)
|
srv.startPeer(conn, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Self() *discover.Node {
|
|
||||||
return srv.ntab.Self()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
||||||
// TODO: handle/store session token
|
// TODO: handle/store session token
|
||||||
|
|
||||||
@ -366,7 +367,10 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
|||||||
// returns during that exchange need to call peerWG.Done because
|
// returns during that exchange need to call peerWG.Done because
|
||||||
// the callers of startPeer added the peer to the wait group already.
|
// the callers of startPeer added the peer to the wait group already.
|
||||||
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||||
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
|
srv.lock.RLock()
|
||||||
|
atcap := len(srv.peers) == srv.MaxPeers
|
||||||
|
srv.lock.RUnlock()
|
||||||
|
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, atcap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fd.Close()
|
fd.Close()
|
||||||
glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
||||||
|
@ -22,7 +22,7 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
|||||||
ListenAddr: "127.0.0.1:0",
|
ListenAddr: "127.0.0.1:0",
|
||||||
PrivateKey: newkey(),
|
PrivateKey: newkey(),
|
||||||
newPeerHook: pf,
|
newPeerHook: pf,
|
||||||
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
|
||||||
id := randomID()
|
id := randomID()
|
||||||
rw := newRlpxFrameRW(fd, secrets{
|
rw := newRlpxFrameRW(fd, secrets{
|
||||||
MAC: zero16,
|
MAC: zero16,
|
||||||
@ -163,6 +163,62 @@ func TestServerBroadcast(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test checks that connections are disconnected
|
||||||
|
// just after the encryption handshake when the server is
|
||||||
|
// at capacity.
|
||||||
|
//
|
||||||
|
// It also serves as a light-weight integration test.
|
||||||
|
func TestServerDisconnectAtCap(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
started := make(chan *Peer)
|
||||||
|
srv := &Server{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
PrivateKey: newkey(),
|
||||||
|
MaxPeers: 10,
|
||||||
|
NoDial: true,
|
||||||
|
// This hook signals that the peer was actually started. We
|
||||||
|
// need to wait for the peer to be started before dialing the
|
||||||
|
// next connection to get a deterministic peer count.
|
||||||
|
newPeerHook: func(p *Peer) { started <- p },
|
||||||
|
}
|
||||||
|
if err := srv.Start(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer srv.Stop()
|
||||||
|
|
||||||
|
nconns := srv.MaxPeers + 1
|
||||||
|
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
||||||
|
for i := 0; i < nconns; i++ {
|
||||||
|
conn, err := dialer.Dial("tcp", srv.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn %d: dial error: %v", i, err)
|
||||||
|
}
|
||||||
|
// Close the connection when the test ends, before
|
||||||
|
// shutting down the server.
|
||||||
|
defer conn.Close()
|
||||||
|
// Run the handshakes just like a real peer would.
|
||||||
|
key := newkey()
|
||||||
|
hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
||||||
|
_, err = setupConn(conn, key, hs, srv.Self(), false)
|
||||||
|
if i == nconns-1 {
|
||||||
|
// When handling the last connection, the server should
|
||||||
|
// disconnect immediately instead of running the protocol
|
||||||
|
// handshake.
|
||||||
|
if err != DiscTooManyPeers {
|
||||||
|
t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For all earlier connections, the handshake should go through.
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
// Wait for runPeer to be started.
|
||||||
|
<-started
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newkey() *ecdsa.PrivateKey {
|
func newkey() *ecdsa.PrivateKey {
|
||||||
key, err := crypto.GenerateKey()
|
key, err := crypto.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user