diff --git a/libp2p/dial.nim b/libp2p/dial.nim index 422af7018..5d46cf5fc 100644 --- a/libp2p/dial.nim +++ b/libp2p/dial.nim @@ -26,7 +26,7 @@ method connect*( addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out) {.async, base.} = + dir = Direction.Out) {.async, base.} = ## connect remote peer without negotiating ## a protocol ## diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 9fabd3f3d..986f4e370 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -53,7 +53,7 @@ proc dialAndUpgrade( peerId: Opt[PeerId], hostname: string, address: MultiAddress, - upgradeDir = Direction.Out): + dir = Direction.Out): Future[Muxer] {.async.} = for transport in self.transports: # for each transport @@ -75,15 +75,19 @@ proc dialAndUpgrade( let mux = try: - dialed.transportDir = upgradeDir - await transport.upgrade(dialed, upgradeDir, peerId) + # This is for the very specific case of a simultaneous dial during DCUtR. In this case, both sides will have + # an Outbound direction at the transport level. Therefore we update the DCUtR initiator transport direction to Inbound. + # The if below is more general and might handle other use cases in the future. + if dialed.dir != dir: + dialed.dir = dir + await transport.upgrade(dialed, peerId) except CatchableError as exc: # If we failed to establish the connection through one transport, # we won't succeeded through another - no use in trying again await dialed.close() debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId)) if exc isnot CancelledError: - if upgradeDir == Direction.Out: + if dialed.dir == Direction.Out: libp2p_failed_upgrades_outgoing.inc() else: libp2p_failed_upgrades_incoming.inc() @@ -91,7 +95,7 @@ proc dialAndUpgrade( # Try other address return nil - doAssert not isNil(mux), "connection died after upgrade " & $upgradeDir + doAssert not isNil(mux), "connection died after upgrade " & $dialed.dir debug "Dial successful", peerId = mux.connection.peerId return mux return nil @@ -128,7 +132,7 @@ proc dialAndUpgrade( self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], - upgradeDir = Direction.Out): + dir = Direction.Out): Future[Muxer] {.async.} = debug "Dialing peer", peerId = peerId.get(default(PeerId)) @@ -146,7 +150,7 @@ proc dialAndUpgrade( else: await self.nameResolver.resolveMAddress(expandedAddress) for resolvedAddress in resolvedAddresses: - result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, upgradeDir) + result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir) if not isNil(result): return result @@ -164,7 +168,7 @@ proc internalConnect( addrs: seq[MultiAddress], forceDial: bool, reuseConnection = true, - upgradeDir = Direction.Out): + dir = Direction.Out): Future[Muxer] {.async.} = if Opt.some(self.localPeerId) == peerId: raise newException(CatchableError, "can't dial self!") @@ -182,7 +186,7 @@ proc internalConnect( let slot = self.connManager.getOutgoingSlot(forceDial) let muxed = try: - await self.dialAndUpgrade(peerId, addrs, upgradeDir) + await self.dialAndUpgrade(peerId, addrs, dir) except CatchableError as exc: slot.release() raise exc @@ -209,7 +213,7 @@ method connect*( addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out) {.async.} = + dir = Direction.Out) {.async.} = ## connect remote peer without negotiating ## a protocol ## @@ -217,7 +221,7 @@ method connect*( if self.connManager.connCount(peerId) > 0 and reuseConnection: return - discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, upgradeDir) + discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, dir) method connect*( self: Dialer, diff --git a/libp2p/protocols/connectivity/dcutr/client.nim b/libp2p/protocols/connectivity/dcutr/client.nim index 19dda0cef..d76021802 100644 --- a/libp2p/protocols/connectivity/dcutr/client.nim +++ b/libp2p/protocols/connectivity/dcutr/client.nim @@ -66,7 +66,7 @@ proc startSync*(self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs: if peerDialableAddrs.len > self.maxDialableAddrs: peerDialableAddrs = peerDialableAddrs[0.. maxDialableAddrs: peerDialableAddrs = peerDialableAddrs[0.. 0) - return await secureProtocol[0].secure(conn, direction == Out, peerId) + return await secureProtocol[0].secure(conn, peerId) diff --git a/tests/stubs/switchstub.nim b/tests/stubs/switchstub.nim index e93065b01..23b497365 100644 --- a/tests/stubs/switchstub.nim +++ b/tests/stubs/switchstub.nim @@ -24,7 +24,7 @@ type addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out): Future[void] {.gcsafe, async.} + dir = Direction.Out): Future[void] {.gcsafe, async.} method connect*( self: SwitchStub, @@ -32,11 +32,11 @@ method connect*( addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out) {.async.} = + dir = Direction.Out) {.async.} = if (self.connectStub != nil): - await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, upgradeDir) + await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, dir) else: - await self.switch.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir) + await self.switch.connect(peerId, addrs, forceDial, reuseConnection, dir) proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: connectStubType = nil): T = return SwitchStub( diff --git a/tests/testdcutr.nim b/tests/testdcutr.nim index c125373ae..da6fb5e38 100644 --- a/tests/testdcutr.nim +++ b/tests/testdcutr.nim @@ -96,7 +96,7 @@ suite "Dcutr": addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out): Future[void] {.async.} = + dir = Direction.Out): Future[void] {.async.} = await sleepAsync(100.millis) let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc) @@ -115,7 +115,7 @@ suite "Dcutr": addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out): Future[void] {.async.} = + dir = Direction.Out): Future[void] {.async.} = raise newException(CatchableError, "error") let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc) @@ -163,7 +163,7 @@ suite "Dcutr": addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out): Future[void] {.async.} = + dir = Direction.Out): Future[void] {.async.} = await sleepAsync(100.millis) await ductrServerTest(connectProc) @@ -175,7 +175,7 @@ suite "Dcutr": addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out): Future[void] {.async.} = + dir = Direction.Out): Future[void] {.async.} = raise newException(CatchableError, "error") await ductrServerTest(connectProc) diff --git a/tests/testhpservice.nim b/tests/testhpservice.nim index 48f2bbb83..3897d05e7 100644 --- a/tests/testhpservice.nim +++ b/tests/testhpservice.nim @@ -210,7 +210,7 @@ suite "Hole Punching": addrs: seq[MultiAddress], forceDial = false, reuseConnection = true, - upgradeDir = Direction.Out): Future[void] {.async.} = + dir = Direction.Out): Future[void] {.async.} = self.connectStub = nil # this stub should be called only once raise newException(CatchableError, "error") diff --git a/tests/testnoise.nim b/tests/testnoise.nim index f598b4202..ef05a1402 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -100,7 +100,7 @@ suite "Noise": proc acceptHandler() {.async.} = let conn = await transport1.accept() - let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) + let sconn = await serverNoise.secure(conn, Opt.none(PeerId)) try: await sconn.write("Hello!") finally: @@ -115,7 +115,7 @@ suite "Noise": clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) conn = await transport2.dial(transport1.addrs[0]) - let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) + let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId)) var msg = newSeq[byte](6) await sconn.readExactly(addr msg[0], 6) @@ -144,7 +144,7 @@ suite "Noise": var conn: Connection try: conn = await transport1.accept() - discard await serverNoise.secure(conn, false, Opt.none(PeerId)) + discard await serverNoise.secure(conn, Opt.none(PeerId)) except CatchableError: discard finally: @@ -160,7 +160,7 @@ suite "Noise": var sconn: Connection = nil expect(NoiseDecryptTagError): - sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId)) + sconn = await clientNoise.secure(conn, Opt.some(conn.peerId)) await conn.close() await handlerWait @@ -180,7 +180,7 @@ suite "Noise": proc acceptHandler() {.async, gcsafe.} = let conn = await transport1.accept() - let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) + let sconn = await serverNoise.secure(conn, Opt.none(PeerId)) defer: await sconn.close() await conn.close() @@ -196,7 +196,7 @@ suite "Noise": clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) conn = await transport2.dial(transport1.addrs[0]) - let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) + let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId)) await sconn.write("Hello!") await acceptFut @@ -223,7 +223,7 @@ suite "Noise": proc acceptHandler() {.async, gcsafe.} = let conn = await transport1.accept() - let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) + let sconn = await serverNoise.secure(conn, Opt.none(PeerId)) defer: await sconn.close() let msg = await sconn.readLp(1024*1024) @@ -237,7 +237,7 @@ suite "Noise": clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) conn = await transport2.dial(transport1.addrs[0]) - let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) + let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId)) await sconn.writeLp(hugePayload) await readTask