diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 04aded274..5299dc3bd 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -227,7 +227,7 @@ proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = proc peerStartup(c: ConnManager, conn: Connection) {.async.} = try: - trace "Triggering peer and connection events on connect", conn + trace "Triggering connect events", conn let peerId = conn.peerInfo.peerId await c.triggerPeerEvents(peerId, PeerEvent.Joined) await c.triggerConnEvent( @@ -240,7 +240,7 @@ proc peerStartup(c: ConnManager, conn: Connection) {.async.} = proc peerCleanup(c: ConnManager, conn: Connection) {.async.} = try: - trace "Triggering peer and connection events on disconnect", conn + trace "Triggering disconnect events", conn let peerId = conn.peerInfo.peerId await c.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Disconnected)) @@ -418,6 +418,8 @@ proc dropPeer*(c: ConnManager, peerId: PeerID) {.async.} = await conn.close() trace "Dropped peer", peerId + trace "Peer dropped", peerId + proc close*(c: ConnManager) {.async.} = ## cleanup resources for the connection ## manager diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 4d2a50c36..22f2e7322 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,6 +1,6 @@ {.used.} -import unittest, options +import unittest, options, sequtils import chronos import stew/byteutils import nimcrypto/sysrand @@ -248,14 +248,8 @@ suite "Switch": check not switch1.isConnected(switch2.peerInfo) check not switch2.isConnected(switch1.peerInfo) - var channelTracker = getTracker(LPChannelTrackerName) - # echo channelTracker.dump() - check channelTracker.isLeaked() == false - - var connTracker = getTracker(SecureConnTrackerName) - doAssert(not isNil(connTracker)) - # echo connTracker.dump() - check connTracker.isLeaked() == false + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) await allFuturesThrowing( switch1.stop(), @@ -308,13 +302,8 @@ suite "Switch": check not switch1.isConnected(switch2.peerInfo) check not switch2.isConnected(switch1.peerInfo) - var bufferTracker = getTracker(LPChannelTrackerName) - # echo bufferTracker.dump() - check bufferTracker.isLeaked() == false - - var connTracker = getTracker(SecureConnTrackerName) - # echo connTracker.dump() - check connTracker.isLeaked() == false + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) check: kinds == { @@ -373,13 +362,8 @@ suite "Switch": check not switch1.isConnected(switch2.peerInfo) check not switch2.isConnected(switch1.peerInfo) - var bufferTracker = getTracker(LPChannelTrackerName) - # echo bufferTracker.dump() - check bufferTracker.isLeaked() == false - - var connTracker = getTracker(SecureConnTrackerName) - # echo connTracker.dump() - check connTracker.isLeaked() == false + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) check: kinds == { @@ -437,13 +421,8 @@ suite "Switch": check not switch1.isConnected(switch2.peerInfo) check not switch2.isConnected(switch1.peerInfo) - var bufferTracker = getTracker(LPChannelTrackerName) - # echo bufferTracker.dump() - check bufferTracker.isLeaked() == false - - var connTracker = getTracker(SecureConnTrackerName) - # echo connTracker.dump() - check connTracker.isLeaked() == false + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) check: kinds == { @@ -501,13 +480,8 @@ suite "Switch": check not switch1.isConnected(switch2.peerInfo) check not switch2.isConnected(switch1.peerInfo) - var bufferTracker = getTracker(LPChannelTrackerName) - # echo bufferTracker.dump() - check bufferTracker.isLeaked() == false - - var connTracker = getTracker(SecureConnTrackerName) - # echo connTracker.dump() - check connTracker.isLeaked() == false + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) check: kinds == { @@ -580,13 +554,8 @@ suite "Switch": check not switch2.isConnected(switch1.peerInfo) check not switch3.isConnected(switch1.peerInfo) - var bufferTracker = getTracker(LPChannelTrackerName) - # echo bufferTracker.dump() - check bufferTracker.isLeaked() == false - - var connTracker = getTracker(SecureConnTrackerName) - # echo connTracker.dump() - check connTracker.isLeaked() == false + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) check: kinds == { @@ -602,6 +571,105 @@ suite "Switch": waitFor(testSwitch()) + test "e2e should allow dropping peer from connection events": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let rng = newRng() + # use same private keys to emulate two connection from same peer + let peerInfo = PeerInfo.init(PrivateKey.random(rng[]).tryGet()) + + var switches: seq[Switch] + var done = newFuture[void]() + var onConnect: Future[void] + proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + case event.kind: + of ConnEventKind.Connected: + await onConnect + await switches[0].disconnect(peerInfo.peerId) # trigger disconnect + of ConnEventKind.Disconnected: + check not switches[0].isConnected(peerInfo.peerId) + await sleepAsync(1.millis) + done.complete() + + switches.add(newStandardSwitch( + rng = rng, + secureManagers = [SecureProtocol.Secio])) + + switches[0].addConnEventHandler(hook, ConnEventKind.Connected) + switches[0].addConnEventHandler(hook, ConnEventKind.Disconnected) + awaiters.add(await switches[0].start()) + + switches.add(newStandardSwitch( + privKey = some(peerInfo.privateKey), + rng = rng, + secureManagers = [SecureProtocol.Secio])) + onConnect = switches[1].connect(switches[0].peerInfo) + await onConnect + + await done + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) + + await allFuturesThrowing( + switches.mapIt( it.stop() )) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch()) + + test "e2e should allow dropping multiple connections for peer from connection events": + proc testSwitch() {.async, gcsafe.} = + var awaiters: seq[Future[void]] + + let rng = newRng() + # use same private keys to emulate two connection from same peer + let peerInfo = PeerInfo.init(PrivateKey.random(rng[]).tryGet()) + + var conns = 1 + var switches: seq[Switch] + var done = newFuture[void]() + var onConnect: Future[void] + proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + case event.kind: + of ConnEventKind.Connected: + if conns == 5: + await onConnect + await switches[0].disconnect(peerInfo.peerId) # trigger disconnect + return + + conns.inc + of ConnEventKind.Disconnected: + if conns == 1: + check not switches[0].isConnected(peerInfo.peerId) + done.complete() + conns.dec + + switches.add(newStandardSwitch( + rng = rng, + secureManagers = [SecureProtocol.Secio])) + + switches[0].addConnEventHandler(hook, ConnEventKind.Connected) + switches[0].addConnEventHandler(hook, ConnEventKind.Disconnected) + awaiters.add(await switches[0].start()) + + for i in 1..5: + switches.add(newStandardSwitch( + privKey = some(peerInfo.privateKey), + rng = rng, + secureManagers = [SecureProtocol.Secio])) + onConnect = switches[i].connect(switches[0].peerInfo) + await onConnect + + await done + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) + + await allFuturesThrowing( + switches.mapIt( it.stop() )) + await allFuturesThrowing(awaiters) + + waitFor(testSwitch()) + test "connect to inexistent peer": proc testSwitch() {.async, gcsafe.} = let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])