diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 2ce4d7251..ac3a0b69a 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -32,6 +32,7 @@ const type TooManyConnectionsError* = object of LPError + AlreadyExpectingConnectionError* = object of LPError ConnEventKind* {.pure.} = enum Connected, # A connection was made and securely upgraded - there may be @@ -79,7 +80,7 @@ type muxed: Table[Connection, MuxerHolder] connEvents: array[ConnEventKind, OrderedSet[ConnEventHandler]] peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]] - expectedConnections: Table[PeerId, Future[Connection]] + expectedConnectionsOverLimit*: Table[(PeerId, Direction), Future[Connection]] peerStore*: PeerStore ConnectionSlot* = object @@ -221,18 +222,19 @@ proc triggerPeerEvents*(c: ConnManager, except CatchableError as exc: # handlers should not raise! warn "Exception in triggerPeerEvents", exc = exc.msg, peer = peerId -proc expectConnection*(c: ConnManager, p: PeerId): Future[Connection] {.async.} = +proc expectConnection*(c: ConnManager, p: PeerId, dir: Direction): Future[Connection] {.async.} = ## Wait for a peer to connect to us. This will bypass the `MaxConnectionsPerPeer` - if p in c.expectedConnections: - raise LPError.newException("Already expecting a connection from that peer") + let key = (p, dir) + if key in c.expectedConnectionsOverLimit: + raise newException(AlreadyExpectingConnectionError, "Already expecting an incoming connection from that peer") let future = newFuture[Connection]() - c.expectedConnections[p] = future + c.expectedConnectionsOverLimit[key] = future try: return await future finally: - c.expectedConnections.del(p) + c.expectedConnectionsOverLimit.del(key) proc contains*(c: ConnManager, conn: Connection): bool = ## checks if a connection is being tracked by the @@ -412,14 +414,16 @@ proc storeConn*(c: ConnManager, conn: Connection) let peerId = conn.peerId # we use getOrDefault in the if below instead of [] to avoid the KeyError - if peerId in c.expectedConnections and - not(c.expectedConnections.getOrDefault(peerId).finished): - c.expectedConnections.getOrDefault(peerId).complete(conn) - elif c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer: - debug "Too many connections for peer", - conn, conns = c.conns.getOrDefault(peerId).len + if c.conns.getOrDefault(peerId).len > c.maxConnsPerPeer: + let key = (peerId, conn.dir) + let expectedConn = c.expectedConnectionsOverLimit.getOrDefault(key) + if expectedConn != nil and not expectedConn.finished: + expectedConn.complete(conn) + else: + debug "Too many connections for peer", + conn, conns = c.conns.getOrDefault(peerId).len - raise newTooManyConnectionsError() + raise newTooManyConnectionsError() c.conns.mgetOrPut(peerId, HashSet[Connection]()).incl(conn) libp2p_peers.set(c.conns.len.int64) @@ -562,8 +566,8 @@ proc close*(c: ConnManager) {.async.} = let muxed = c.muxed c.muxed.clear() - let expected = c.expectedConnections - c.expectedConnections.clear() + let expected = c.expectedConnectionsOverLimit + c.expectedConnectionsOverLimit.clear() for _, fut in expected: await fut.cancelAndWait() diff --git a/libp2p/protocols/connectivity/autonat/client.nim b/libp2p/protocols/connectivity/autonat/client.nim index 55e0183c4..13176259c 100644 --- a/libp2p/protocols/connectivity/autonat/client.nim +++ b/libp2p/protocols/connectivity/autonat/client.nim @@ -58,10 +58,14 @@ method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[Mult raise newException(AutonatError, "Unexpected error when dialling: " & err.msg, err) # To bypass maxConnectionsPerPeer - let incomingConnection = switch.connManager.expectConnection(pid) + let incomingConnection = switch.connManager.expectConnection(pid, In) + if incomingConnection.failed() and incomingConnection.error of AlreadyExpectingConnectionError: + raise newException(AutonatError, incomingConnection.error.msg) defer: await conn.close() incomingConnection.cancel() # Safer to always try to cancel cause we aren't sure if the peer dialled us or not + if incomingConnection.completed(): + await (await incomingConnection).close() trace "sending Dial", addrs = switch.peerInfo.addrs await conn.sendDial(switch.peerInfo.peerId, switch.peerInfo.addrs) let response = getResponseOrRaise(AutonatMsg.decode(await conn.readLp(1024))) diff --git a/libp2p/protocols/connectivity/autonat/server.nim b/libp2p/protocols/connectivity/autonat/server.nim index e8f1077e8..1440393e7 100644 --- a/libp2p/protocols/connectivity/autonat/server.nim +++ b/libp2p/protocols/connectivity/autonat/server.nim @@ -63,7 +63,10 @@ proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.asy var futs: seq[Future[Opt[MultiAddress]]] try: # This is to bypass the per peer max connections limit - let outgoingConnection = autonat.switch.connManager.expectConnection(conn.peerId) + let outgoingConnection = autonat.switch.connManager.expectConnection(conn.peerId, Out) + if outgoingConnection.failed() and outgoingConnection.error of AlreadyExpectingConnectionError: + await conn.sendResponseError(DialRefused, outgoingConnection.error.msg) + return # Safer to always try to cancel cause we aren't sure if the connection was established defer: outgoingConnection.cancel() # tryDial is to bypass the global max connections limit diff --git a/libp2p/protocols/connectivity/autonat/service.nim b/libp2p/protocols/connectivity/autonat/service.nim index afbf3ca81..6d6ca7143 100644 --- a/libp2p/protocols/connectivity/autonat/service.nim +++ b/libp2p/protocols/connectivity/autonat/service.nim @@ -80,6 +80,9 @@ proc hasEnoughIncomingSlots(switch: Switch): bool = # we leave some margin instead of comparing to 0 as a peer could connect to us while we are asking for the dial back return switch.connManager.slotsAvailable(In) >= 2 +proc doesPeerHaveIncomingConn(switch: Switch, peerId: PeerId): bool = + return switch.connManager.selectConn(peerId, In) != nil + proc handleAnswer(self: AutonatService, ans: NetworkReachability) {.async.} = if ans == Unknown: @@ -104,6 +107,10 @@ proc handleAnswer(self: AutonatService, ans: NetworkReachability) {.async.} = proc askPeer(self: AutonatService, switch: Switch, peerId: PeerId): Future[NetworkReachability] {.async.} = logScope: peerId = $peerId + + if doesPeerHaveIncomingConn(switch, peerId): + return Unknown + if not hasEnoughIncomingSlots(switch): debug "No incoming slots available, not asking peer", incomingSlotsAvailable=switch.connManager.slotsAvailable(In) return Unknown @@ -152,10 +159,7 @@ method setup*(self: AutonatService, switch: Switch): Future[bool] {.async.} = if hasBeenSetup: if self.askNewConnectedPeers: self.newConnectedPeerHandler = proc (peerId: PeerId, event: PeerEvent): Future[void] {.async.} = - if switch.connManager.selectConn(peerId, In) != nil: # no need to ask an incoming peer - return discard askPeer(self, switch, peerId) - await self.callHandler() switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined) if self.scheduleInterval.isSome(): self.scheduleHandle = schedule(self, switch, self.scheduleInterval.get()) diff --git a/tests/testautonatservice.nim b/tests/testautonatservice.nim index d7e684abe..4a0d35d63 100644 --- a/tests/testautonatservice.nim +++ b/tests/testautonatservice.nim @@ -277,6 +277,94 @@ suite "Autonat Service": await allFuturesThrowing( switch1.stop(), switch2.stop()) + asyncTest "Must work when peers ask each other at the same time with max 1 conn per peer": + let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) + let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) + let autonatService3 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) + + let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0) + let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0) + let switch3 = createSwitch(autonatService2, maxConnsPerPeer = 0) + + let awaiter1 = newFuture[void]() + let awaiter2 = newFuture[void]() + let awaiter3 = newFuture[void]() + + proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = + if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: + if not awaiter1.finished: + awaiter1.complete() + + proc statusAndConfidenceHandler2(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = + if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: + if not awaiter2.finished: + awaiter2.complete() + + check autonatService1.networkReachability() == NetworkReachability.Unknown + check autonatService2.networkReachability() == NetworkReachability.Unknown + + autonatService1.statusAndConfidenceHandler(statusAndConfidenceHandler1) + autonatService2.statusAndConfidenceHandler(statusAndConfidenceHandler2) + + await switch1.start() + await switch2.start() + await switch3.start() + + await switch1.connect(switch2.peerInfo.peerId, switch2.peerInfo.addrs) + await switch2.connect(switch1.peerInfo.peerId, switch1.peerInfo.addrs) + await switch2.connect(switch3.peerInfo.peerId, switch3.peerInfo.addrs) + + await awaiter1 + await awaiter2 + + check autonatService1.networkReachability() == NetworkReachability.Reachable + check autonatService2.networkReachability() == NetworkReachability.Reachable + check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1 + + await allFuturesThrowing( + switch1.stop(), switch2.stop(), switch3.stop()) + + asyncTest "Must work for one peer when two peers ask each other at the same time with max 1 conn per peer": + let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) + let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) + + let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0) + let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0) + + let awaiter1 = newFuture[void]() + + proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = + if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: + if not awaiter1.finished: + awaiter1.complete() + + check autonatService1.networkReachability() == NetworkReachability.Unknown + + autonatService1.statusAndConfidenceHandler(statusAndConfidenceHandler1) + + await switch1.start() + await switch2.start() + + await switch1.connect(switch2.peerInfo.peerId, switch2.peerInfo.addrs) + try: + # We allow a temp conn for the peer to dial us. It could use this conn to just connect to us and not dial. + # We don't care if it fails at this point or not. But this conn must be closed eventually. + # Bellow we check that there's only one connection between the peers + await switch2.connect(switch1.peerInfo.peerId, switch1.peerInfo.addrs, reuseConnection = false) + except CatchableError: + discard + + await awaiter1 + + check autonatService1.networkReachability() == NetworkReachability.Reachable + check libp2p_autonat_reachability_confidence.value(["Reachable"]) == 1 + + # Make sure remote peer can't create a connection to us + check switch1.connManager.connCount(switch2.peerInfo.peerId) == 1 + + await allFuturesThrowing( + switch1.stop(), switch2.stop()) + asyncTest "Must work with low maxConnections": let autonatService = AutonatService.new(AutonatClient.new(), newRng(), some(1.seconds), maxQueueSize = 1) diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index fa3c665b8..f15215766 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -208,15 +208,15 @@ suite "Connection Manager": expect TooManyConnectionsError: connMngr.storeConn(conns[0]) - let waitedConn1 = connMngr.expectConnection(peerId) + let waitedConn1 = connMngr.expectConnection(peerId, In) - expect LPError: - discard await connMngr.expectConnection(peerId) + expect AlreadyExpectingConnectionError: + discard await connMngr.expectConnection(peerId, In) await waitedConn1.cancelAndWait() let - waitedConn2 = connMngr.expectConnection(peerId) - waitedConn3 = connMngr.expectConnection(PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()) + waitedConn2 = connMngr.expectConnection(peerId, In) + waitedConn3 = connMngr.expectConnection(PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), In) conn = getConnection(peerId) connMngr.storeConn(conn) check (await waitedConn2) == conn