From 351bda2b56e4389a91fb42199621b3617bdbbdd3 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Mon, 23 Jan 2023 23:28:39 +0100 Subject: [PATCH] Add expected connections to connmngr (#845) Co-authored-by: diegomrsantos --- libp2p/connmanager.nim | 27 ++++++++++++++++++++++++++- tests/testconnmngr.nim | 41 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index fa163d6..a82c424 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -79,6 +79,7 @@ type muxed: Table[Connection, MuxerHolder] connEvents: array[ConnEventKind, OrderedSet[ConnEventHandler]] peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]] + expectedConnections: Table[PeerId, Future[Connection]] peerStore*: PeerStore ConnectionSlot* = object @@ -220,6 +221,19 @@ proc triggerPeerEvents*(c: ConnManager, except CatchableError as exc: # handlers should not raise! warn "Exception in triggerPeerEvents", exc = exc.msg, peer = peerId +proc expectConnection*(c: ConnManager, p: PeerId): Future[Connection] {.async.} = + ## Wait for a peer to connect to us. This will bypass the `MaxConnectionsPerPeer` + if p in c.expectedConnections: + raise LPError.newException("Already expecting a connection from that peer") + + let future = newFuture[Connection]() + c.expectedConnections[p] = future + + try: + return await future + finally: + c.expectedConnections.del(p) + proc contains*(c: ConnManager, conn: Connection): bool = ## checks if a connection is being tracked by the ## connection manager @@ -396,7 +410,12 @@ proc storeConn*(c: ConnManager, conn: Connection) raise newException(LPError, "Connection closed or EOF") let peerId = conn.peerId - if c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer: + + # we use getOrDefault in the if below instead of [] to avoid the KeyError + if peerId in c.expectedConnections and + not(c.expectedConnections.getOrDefault(peerId).finished): + c.expectedConnections.getOrDefault(peerId).complete(conn) + elif c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer: debug "Too many connections for peer", conn, conns = c.conns.getOrDefault(peerId).len @@ -536,6 +555,12 @@ proc close*(c: ConnManager) {.async.} = let muxed = c.muxed c.muxed.clear() + let expected = c.expectedConnections + c.expectedConnections.clear() + + for _, fut in expected: + await fut.cancelAndWait() + for _, muxer in muxed: await closeMuxerHolder(muxer) diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 9a59e61..fa3c665 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -96,7 +96,8 @@ suite "Connection Manager": await connMngr.close() asyncTest "get conn with direction": - let connMngr = ConnManager.new() + # This would work with 1 as well cause of a bug in connmanager that will get fixed soon + let connMngr = ConnManager.new(maxConnsPerPeer = 2) let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() let conn1 = getConnection(peerId, Direction.Out) let conn2 = getConnection(peerId) @@ -176,7 +177,7 @@ suite "Connection Manager": await stream.close() asyncTest "should raise on too many connections": - let connMngr = ConnManager.new(maxConnsPerPeer = 1) + let connMngr = ConnManager.new(maxConnsPerPeer = 0) let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() connMngr.storeConn(getConnection(peerId)) @@ -187,10 +188,46 @@ suite "Connection Manager": expect TooManyConnectionsError: connMngr.storeConn(conns[0]) + + await connMngr.close() + + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "expect connection from peer": + # FIXME This should be 1 instead of 0, it will get fixed soon + let connMngr = ConnManager.new(maxConnsPerPeer = 0) + let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() + + connMngr.storeConn(getConnection(peerId)) + + let conns = @[ + getConnection(peerId), + getConnection(peerId)] + + expect TooManyConnectionsError: + connMngr.storeConn(conns[0]) + + let waitedConn1 = connMngr.expectConnection(peerId) + + expect LPError: + discard await connMngr.expectConnection(peerId) + + await waitedConn1.cancelAndWait() + let + waitedConn2 = connMngr.expectConnection(peerId) + waitedConn3 = connMngr.expectConnection(PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()) + conn = getConnection(peerId) + connMngr.storeConn(conn) + check (await waitedConn2) == conn + + expect TooManyConnectionsError: connMngr.storeConn(conns[1]) await connMngr.close() + checkExpiring: waitedConn3.cancelled() + await allFuturesThrowing( allFutures(conns.mapIt( it.close() )))