fix(dcutr): update the DCUtR initiator transport direction to Inbound (#994)

This commit is contained in:
diegomrsantos 2023-11-29 17:38:47 +01:00 committed by GitHub
parent ce0685c272
commit deb72c8580
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 48 additions and 50 deletions

View File

@ -26,7 +26,7 @@ method connect*(
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out) {.async, base.} = dir = Direction.Out) {.async, base.} =
## connect remote peer without negotiating ## connect remote peer without negotiating
## a protocol ## a protocol
## ##

View File

@ -53,7 +53,7 @@ proc dialAndUpgrade(
peerId: Opt[PeerId], peerId: Opt[PeerId],
hostname: string, hostname: string,
address: MultiAddress, address: MultiAddress,
upgradeDir = Direction.Out): dir = Direction.Out):
Future[Muxer] {.async.} = Future[Muxer] {.async.} =
for transport in self.transports: # for each transport for transport in self.transports: # for each transport
@ -75,15 +75,19 @@ proc dialAndUpgrade(
let mux = let mux =
try: try:
dialed.transportDir = upgradeDir # This is for the very specific case of a simultaneous dial during DCUtR. In this case, both sides will have
await transport.upgrade(dialed, upgradeDir, peerId) # an Outbound direction at the transport level. Therefore we update the DCUtR initiator transport direction to Inbound.
# The if below is more general and might handle other use cases in the future.
if dialed.dir != dir:
dialed.dir = dir
await transport.upgrade(dialed, peerId)
except CatchableError as exc: except CatchableError as exc:
# If we failed to establish the connection through one transport, # If we failed to establish the connection through one transport,
# we won't succeeded through another - no use in trying again # we won't succeeded through another - no use in trying again
await dialed.close() await dialed.close()
debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId)) debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
if exc isnot CancelledError: if exc isnot CancelledError:
if upgradeDir == Direction.Out: if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc() libp2p_failed_upgrades_outgoing.inc()
else: else:
libp2p_failed_upgrades_incoming.inc() libp2p_failed_upgrades_incoming.inc()
@ -91,7 +95,7 @@ proc dialAndUpgrade(
# Try other address # Try other address
return nil return nil
doAssert not isNil(mux), "connection died after upgrade " & $upgradeDir doAssert not isNil(mux), "connection died after upgrade " & $dialed.dir
debug "Dial successful", peerId = mux.connection.peerId debug "Dial successful", peerId = mux.connection.peerId
return mux return mux
return nil return nil
@ -128,7 +132,7 @@ proc dialAndUpgrade(
self: Dialer, self: Dialer,
peerId: Opt[PeerId], peerId: Opt[PeerId],
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
upgradeDir = Direction.Out): dir = Direction.Out):
Future[Muxer] {.async.} = Future[Muxer] {.async.} =
debug "Dialing peer", peerId = peerId.get(default(PeerId)) debug "Dialing peer", peerId = peerId.get(default(PeerId))
@ -146,7 +150,7 @@ proc dialAndUpgrade(
else: await self.nameResolver.resolveMAddress(expandedAddress) else: await self.nameResolver.resolveMAddress(expandedAddress)
for resolvedAddress in resolvedAddresses: for resolvedAddress in resolvedAddresses:
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, upgradeDir) result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir)
if not isNil(result): if not isNil(result):
return result return result
@ -164,7 +168,7 @@ proc internalConnect(
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial: bool, forceDial: bool,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): dir = Direction.Out):
Future[Muxer] {.async.} = Future[Muxer] {.async.} =
if Opt.some(self.localPeerId) == peerId: if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!") raise newException(CatchableError, "can't dial self!")
@ -182,7 +186,7 @@ proc internalConnect(
let slot = self.connManager.getOutgoingSlot(forceDial) let slot = self.connManager.getOutgoingSlot(forceDial)
let muxed = let muxed =
try: try:
await self.dialAndUpgrade(peerId, addrs, upgradeDir) await self.dialAndUpgrade(peerId, addrs, dir)
except CatchableError as exc: except CatchableError as exc:
slot.release() slot.release()
raise exc raise exc
@ -209,7 +213,7 @@ method connect*(
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out) {.async.} = dir = Direction.Out) {.async.} =
## connect remote peer without negotiating ## connect remote peer without negotiating
## a protocol ## a protocol
## ##
@ -217,7 +221,7 @@ method connect*(
if self.connManager.connCount(peerId) > 0 and reuseConnection: if self.connManager.connCount(peerId) > 0 and reuseConnection:
return return
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, upgradeDir) discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, dir)
method connect*( method connect*(
self: Dialer, self: Dialer,

View File

@ -66,7 +66,7 @@ proc startSync*(self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs:
if peerDialableAddrs.len > self.maxDialableAddrs: if peerDialableAddrs.len > self.maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<self.maxDialableAddrs] peerDialableAddrs = peerDialableAddrs[0..<self.maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.In)) var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.In))
try: try:
discard await anyCompleted(futs).wait(self.connectTimeout) discard await anyCompleted(futs).wait(self.connectTimeout)
debug "Dcutr initiator has directly connected to the remote peer." debug "Dcutr initiator has directly connected to the remote peer."

View File

@ -56,7 +56,7 @@ proc new*(T: typedesc[Dcutr], switch: Switch, connectTimeout = 15.seconds, maxDi
if peerDialableAddrs.len > maxDialableAddrs: if peerDialableAddrs.len > maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<maxDialableAddrs] peerDialableAddrs = peerDialableAddrs[0..<maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.Out)) var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.Out))
try: try:
discard await anyCompleted(futs).wait(connectTimeout) discard await anyCompleted(futs).wait(connectTimeout)
debug "Dcutr receiver has directly connected to the remote peer." debug "Dcutr receiver has directly connected to the remote peer."

View File

@ -135,10 +135,9 @@ method init*(s: Secure) =
method secure*(s: Secure, method secure*(s: Secure,
conn: Connection, conn: Connection,
initiator: bool,
peerId: Opt[PeerId]): peerId: Opt[PeerId]):
Future[Connection] {.base.} = Future[Connection] {.base.} =
s.handleConn(conn, initiator, peerId) s.handleConn(conn, conn.dir == Direction.Out, peerId)
method readOnce*(s: SecureConn, method readOnce*(s: SecureConn,
pbytes: pointer, pbytes: pointer,

View File

@ -141,10 +141,10 @@ method connect*(
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.public.} = dir = Direction.Out): Future[void] {.public.} =
## Connects to a peer without opening a stream to it ## Connects to a peer without opening a stream to it
s.dialer.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir) s.dialer.connect(peerId, addrs, forceDial, reuseConnection, dir)
method connect*( method connect*(
s: Switch, s: Switch,
@ -213,7 +213,7 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil)
s.peerInfo.protocols.add(proto.codec) s.peerInfo.protocols.add(proto.codec)
proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} = proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} =
let muxed = await trans.upgrade(conn, Direction.In, Opt.none(PeerId)) let muxed = await trans.upgrade(conn, Opt.none(PeerId))
switch.connManager.storeMuxer(muxed) switch.connManager.storeMuxer(muxed)
await switch.peerStore.identify(muxed) await switch.peerStore.identify(muxed)
trace "Connection upgrade succeeded" trace "Connection upgrade succeeded"

View File

@ -83,13 +83,12 @@ proc dial*(
method upgrade*( method upgrade*(
self: Transport, self: Transport,
conn: Connection, conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} = peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform ## base upgrade method that the transport uses to perform
## transport specific upgrades ## transport specific upgrades
## ##
self.upgrader.upgrade(conn, direction, peerId) self.upgrader.upgrade(conn, peerId)
method handles*( method handles*(
self: Transport, self: Transport,

View File

@ -32,8 +32,7 @@ proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider =
proc mux*( proc mux*(
self: MuxedUpgrade, self: MuxedUpgrade,
conn: Connection, conn: Connection): Future[Muxer] {.async, gcsafe.} =
direction: Direction): Future[Muxer] {.async, gcsafe.} =
## mux connection ## mux connection
trace "Muxing connection", conn trace "Muxing connection", conn
@ -42,7 +41,7 @@ proc mux*(
return return
let muxerName = let muxerName =
if direction == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec)) if conn.dir == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec)) else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))
if muxerName.len == 0 or muxerName == "na": if muxerName.len == 0 or muxerName == "na":
@ -62,16 +61,15 @@ proc mux*(
method upgrade*( method upgrade*(
self: MuxedUpgrade, self: MuxedUpgrade,
conn: Connection, conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.async.} = peerId: Opt[PeerId]): Future[Muxer] {.async.} =
trace "Upgrading connection", conn, direction trace "Upgrading connection", conn, direction = conn.dir
let sconn = await self.secure(conn, direction, peerId) # secure the connection let sconn = await self.secure(conn, peerId) # secure the connection
if isNil(sconn): if isNil(sconn):
raise newException(UpgradeFailedError, raise newException(UpgradeFailedError,
"unable to secure connection, stopping upgrade") "unable to secure connection, stopping upgrade")
let muxer = await self.mux(sconn, direction) # mux it if possible let muxer = await self.mux(sconn) # mux it if possible
if muxer == nil: if muxer == nil:
raise newException(UpgradeFailedError, raise newException(UpgradeFailedError,
"a muxer is required for outgoing connections") "a muxer is required for outgoing connections")
@ -84,7 +82,7 @@ method upgrade*(
raise newException(UpgradeFailedError, raise newException(UpgradeFailedError,
"Connection closed or missing peer info, stopping upgrade") "Connection closed or missing peer info, stopping upgrade")
trace "Upgraded connection", conn, sconn, direction trace "Upgraded connection", conn, sconn, direction = conn.dir
return muxer return muxer
proc new*( proc new*(

View File

@ -40,20 +40,18 @@ type
method upgrade*( method upgrade*(
self: Upgrade, self: Upgrade,
conn: Connection, conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base.} = peerId: Opt[PeerId]): Future[Muxer] {.base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
proc secure*( proc secure*(
self: Upgrade, self: Upgrade,
conn: Connection, conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} = peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
if self.secureManagers.len <= 0: if self.secureManagers.len <= 0:
raise newException(UpgradeFailedError, "No secure managers registered!") raise newException(UpgradeFailedError, "No secure managers registered!")
let codec = let codec =
if direction == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec)) if conn.dir == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec)) else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
if codec.len == 0: if codec.len == 0:
raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!") raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!")
@ -65,4 +63,4 @@ proc secure*(
# let's avoid duplicating checks but detect if it fails to do it properly # let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0) doAssert(secureProtocol.len > 0)
return await secureProtocol[0].secure(conn, direction == Out, peerId) return await secureProtocol[0].secure(conn, peerId)

View File

@ -24,7 +24,7 @@ type
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.gcsafe, async.} dir = Direction.Out): Future[void] {.gcsafe, async.}
method connect*( method connect*(
self: SwitchStub, self: SwitchStub,
@ -32,11 +32,11 @@ method connect*(
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out) {.async.} = dir = Direction.Out) {.async.} =
if (self.connectStub != nil): if (self.connectStub != nil):
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, upgradeDir) await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, dir)
else: else:
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir) await self.switch.connect(peerId, addrs, forceDial, reuseConnection, dir)
proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: connectStubType = nil): T = proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: connectStubType = nil): T =
return SwitchStub( return SwitchStub(

View File

@ -96,7 +96,7 @@ suite "Dcutr":
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} = dir = Direction.Out): Future[void] {.async.} =
await sleepAsync(100.millis) await sleepAsync(100.millis)
let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc) let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc)
@ -115,7 +115,7 @@ suite "Dcutr":
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} = dir = Direction.Out): Future[void] {.async.} =
raise newException(CatchableError, "error") raise newException(CatchableError, "error")
let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc) let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc)
@ -163,7 +163,7 @@ suite "Dcutr":
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} = dir = Direction.Out): Future[void] {.async.} =
await sleepAsync(100.millis) await sleepAsync(100.millis)
await ductrServerTest(connectProc) await ductrServerTest(connectProc)
@ -175,7 +175,7 @@ suite "Dcutr":
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} = dir = Direction.Out): Future[void] {.async.} =
raise newException(CatchableError, "error") raise newException(CatchableError, "error")
await ductrServerTest(connectProc) await ductrServerTest(connectProc)

View File

@ -210,7 +210,7 @@ suite "Hole Punching":
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial = false, forceDial = false,
reuseConnection = true, reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} = dir = Direction.Out): Future[void] {.async.} =
self.connectStub = nil # this stub should be called only once self.connectStub = nil # this stub should be called only once
raise newException(CatchableError, "error") raise newException(CatchableError, "error")

View File

@ -100,7 +100,7 @@ suite "Noise":
proc acceptHandler() {.async.} = proc acceptHandler() {.async.} =
let conn = await transport1.accept() let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
try: try:
await sconn.write("Hello!") await sconn.write("Hello!")
finally: finally:
@ -115,7 +115,7 @@ suite "Noise":
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))
var msg = newSeq[byte](6) var msg = newSeq[byte](6)
await sconn.readExactly(addr msg[0], 6) await sconn.readExactly(addr msg[0], 6)
@ -144,7 +144,7 @@ suite "Noise":
var conn: Connection var conn: Connection
try: try:
conn = await transport1.accept() conn = await transport1.accept()
discard await serverNoise.secure(conn, false, Opt.none(PeerId)) discard await serverNoise.secure(conn, Opt.none(PeerId))
except CatchableError: except CatchableError:
discard discard
finally: finally:
@ -160,7 +160,7 @@ suite "Noise":
var sconn: Connection = nil var sconn: Connection = nil
expect(NoiseDecryptTagError): expect(NoiseDecryptTagError):
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId)) sconn = await clientNoise.secure(conn, Opt.some(conn.peerId))
await conn.close() await conn.close()
await handlerWait await handlerWait
@ -180,7 +180,7 @@ suite "Noise":
proc acceptHandler() {.async, gcsafe.} = proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept() let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
defer: defer:
await sconn.close() await sconn.close()
await conn.close() await conn.close()
@ -196,7 +196,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))
await sconn.write("Hello!") await sconn.write("Hello!")
await acceptFut await acceptFut
@ -223,7 +223,7 @@ suite "Noise":
proc acceptHandler() {.async, gcsafe.} = proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept() let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
defer: defer:
await sconn.close() await sconn.close()
let msg = await sconn.readLp(1024*1024) let msg = await sconn.readLp(1024*1024)
@ -237,7 +237,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))
await sconn.writeLp(hugePayload) await sconn.writeLp(hugePayload)
await readTask await readTask