fix(dcutr): update the DCUtR initiator transport direction to Inbound (#994)
This commit is contained in:
parent
6cdd4c911b
commit
061ea21729
|
@ -26,7 +26,7 @@ method connect*(
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out) {.async, base.} =
|
||||
dir = Direction.Out) {.async, base.} =
|
||||
## connect remote peer without negotiating
|
||||
## a protocol
|
||||
##
|
||||
|
|
|
@ -53,7 +53,7 @@ proc dialAndUpgrade(
|
|||
peerId: Opt[PeerId],
|
||||
hostname: string,
|
||||
address: MultiAddress,
|
||||
upgradeDir = Direction.Out):
|
||||
dir = Direction.Out):
|
||||
Future[Muxer] {.async.} =
|
||||
|
||||
for transport in self.transports: # for each transport
|
||||
|
@ -75,15 +75,19 @@ proc dialAndUpgrade(
|
|||
|
||||
let mux =
|
||||
try:
|
||||
dialed.transportDir = upgradeDir
|
||||
await transport.upgrade(dialed, upgradeDir, peerId)
|
||||
# This is for the very specific case of a simultaneous dial during DCUtR. In this case, both sides will have
|
||||
# an Outbound direction at the transport level. Therefore we update the DCUtR initiator transport direction to Inbound.
|
||||
# The if below is more general and might handle other use cases in the future.
|
||||
if dialed.dir != dir:
|
||||
dialed.dir = dir
|
||||
await transport.upgrade(dialed, peerId)
|
||||
except CatchableError as exc:
|
||||
# If we failed to establish the connection through one transport,
|
||||
# we won't succeeded through another - no use in trying again
|
||||
await dialed.close()
|
||||
debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
|
||||
if exc isnot CancelledError:
|
||||
if upgradeDir == Direction.Out:
|
||||
if dialed.dir == Direction.Out:
|
||||
libp2p_failed_upgrades_outgoing.inc()
|
||||
else:
|
||||
libp2p_failed_upgrades_incoming.inc()
|
||||
|
@ -91,7 +95,7 @@ proc dialAndUpgrade(
|
|||
# Try other address
|
||||
return nil
|
||||
|
||||
doAssert not isNil(mux), "connection died after upgrade " & $upgradeDir
|
||||
doAssert not isNil(mux), "connection died after upgrade " & $dialed.dir
|
||||
debug "Dial successful", peerId = mux.connection.peerId
|
||||
return mux
|
||||
return nil
|
||||
|
@ -128,7 +132,7 @@ proc dialAndUpgrade(
|
|||
self: Dialer,
|
||||
peerId: Opt[PeerId],
|
||||
addrs: seq[MultiAddress],
|
||||
upgradeDir = Direction.Out):
|
||||
dir = Direction.Out):
|
||||
Future[Muxer] {.async.} =
|
||||
|
||||
debug "Dialing peer", peerId = peerId.get(default(PeerId))
|
||||
|
@ -146,7 +150,7 @@ proc dialAndUpgrade(
|
|||
else: await self.nameResolver.resolveMAddress(expandedAddress)
|
||||
|
||||
for resolvedAddress in resolvedAddresses:
|
||||
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, upgradeDir)
|
||||
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir)
|
||||
if not isNil(result):
|
||||
return result
|
||||
|
||||
|
@ -164,7 +168,7 @@ proc internalConnect(
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial: bool,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out):
|
||||
dir = Direction.Out):
|
||||
Future[Muxer] {.async.} =
|
||||
if Opt.some(self.localPeerId) == peerId:
|
||||
raise newException(CatchableError, "can't dial self!")
|
||||
|
@ -182,7 +186,7 @@ proc internalConnect(
|
|||
let slot = self.connManager.getOutgoingSlot(forceDial)
|
||||
let muxed =
|
||||
try:
|
||||
await self.dialAndUpgrade(peerId, addrs, upgradeDir)
|
||||
await self.dialAndUpgrade(peerId, addrs, dir)
|
||||
except CatchableError as exc:
|
||||
slot.release()
|
||||
raise exc
|
||||
|
@ -209,7 +213,7 @@ method connect*(
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out) {.async.} =
|
||||
dir = Direction.Out) {.async.} =
|
||||
## connect remote peer without negotiating
|
||||
## a protocol
|
||||
##
|
||||
|
@ -217,7 +221,7 @@ method connect*(
|
|||
if self.connManager.connCount(peerId) > 0 and reuseConnection:
|
||||
return
|
||||
|
||||
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, upgradeDir)
|
||||
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, dir)
|
||||
|
||||
method connect*(
|
||||
self: Dialer,
|
||||
|
|
|
@ -66,7 +66,7 @@ proc startSync*(self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs:
|
|||
|
||||
if peerDialableAddrs.len > self.maxDialableAddrs:
|
||||
peerDialableAddrs = peerDialableAddrs[0..<self.maxDialableAddrs]
|
||||
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.In))
|
||||
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.In))
|
||||
try:
|
||||
discard await anyCompleted(futs).wait(self.connectTimeout)
|
||||
debug "Dcutr initiator has directly connected to the remote peer."
|
||||
|
|
|
@ -56,7 +56,7 @@ proc new*(T: typedesc[Dcutr], switch: Switch, connectTimeout = 15.seconds, maxDi
|
|||
|
||||
if peerDialableAddrs.len > maxDialableAddrs:
|
||||
peerDialableAddrs = peerDialableAddrs[0..<maxDialableAddrs]
|
||||
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.Out))
|
||||
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.Out))
|
||||
try:
|
||||
discard await anyCompleted(futs).wait(connectTimeout)
|
||||
debug "Dcutr receiver has directly connected to the remote peer."
|
||||
|
|
|
@ -135,10 +135,9 @@ method init*(s: Secure) =
|
|||
|
||||
method secure*(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool,
|
||||
peerId: Opt[PeerId]):
|
||||
Future[Connection] {.base.} =
|
||||
s.handleConn(conn, initiator, peerId)
|
||||
s.handleConn(conn, conn.dir == Direction.Out, peerId)
|
||||
|
||||
method readOnce*(s: SecureConn,
|
||||
pbytes: pointer,
|
||||
|
|
|
@ -141,10 +141,10 @@ method connect*(
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.public.} =
|
||||
dir = Direction.Out): Future[void] {.public.} =
|
||||
## Connects to a peer without opening a stream to it
|
||||
|
||||
s.dialer.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
|
||||
s.dialer.connect(peerId, addrs, forceDial, reuseConnection, dir)
|
||||
|
||||
method connect*(
|
||||
s: Switch,
|
||||
|
@ -213,7 +213,7 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil)
|
|||
s.peerInfo.protocols.add(proto.codec)
|
||||
|
||||
proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} =
|
||||
let muxed = await trans.upgrade(conn, Direction.In, Opt.none(PeerId))
|
||||
let muxed = await trans.upgrade(conn, Opt.none(PeerId))
|
||||
switch.connManager.storeMuxer(muxed)
|
||||
await switch.peerStore.identify(muxed)
|
||||
trace "Connection upgrade succeeded"
|
||||
|
|
|
@ -83,13 +83,12 @@ proc dial*(
|
|||
method upgrade*(
|
||||
self: Transport,
|
||||
conn: Connection,
|
||||
direction: Direction,
|
||||
peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} =
|
||||
## base upgrade method that the transport uses to perform
|
||||
## transport specific upgrades
|
||||
##
|
||||
|
||||
self.upgrader.upgrade(conn, direction, peerId)
|
||||
self.upgrader.upgrade(conn, peerId)
|
||||
|
||||
method handles*(
|
||||
self: Transport,
|
||||
|
|
|
@ -32,8 +32,7 @@ proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider =
|
|||
|
||||
proc mux*(
|
||||
self: MuxedUpgrade,
|
||||
conn: Connection,
|
||||
direction: Direction): Future[Muxer] {.async, gcsafe.} =
|
||||
conn: Connection): Future[Muxer] {.async, gcsafe.} =
|
||||
## mux connection
|
||||
|
||||
trace "Muxing connection", conn
|
||||
|
@ -42,7 +41,7 @@ proc mux*(
|
|||
return
|
||||
|
||||
let muxerName =
|
||||
if direction == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
|
||||
if conn.dir == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
|
||||
else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))
|
||||
|
||||
if muxerName.len == 0 or muxerName == "na":
|
||||
|
@ -62,16 +61,15 @@ proc mux*(
|
|||
method upgrade*(
|
||||
self: MuxedUpgrade,
|
||||
conn: Connection,
|
||||
direction: Direction,
|
||||
peerId: Opt[PeerId]): Future[Muxer] {.async.} =
|
||||
trace "Upgrading connection", conn, direction
|
||||
trace "Upgrading connection", conn, direction = conn.dir
|
||||
|
||||
let sconn = await self.secure(conn, direction, peerId) # secure the connection
|
||||
let sconn = await self.secure(conn, peerId) # secure the connection
|
||||
if isNil(sconn):
|
||||
raise newException(UpgradeFailedError,
|
||||
"unable to secure connection, stopping upgrade")
|
||||
|
||||
let muxer = await self.mux(sconn, direction) # mux it if possible
|
||||
let muxer = await self.mux(sconn) # mux it if possible
|
||||
if muxer == nil:
|
||||
raise newException(UpgradeFailedError,
|
||||
"a muxer is required for outgoing connections")
|
||||
|
@ -84,7 +82,7 @@ method upgrade*(
|
|||
raise newException(UpgradeFailedError,
|
||||
"Connection closed or missing peer info, stopping upgrade")
|
||||
|
||||
trace "Upgraded connection", conn, sconn, direction
|
||||
trace "Upgraded connection", conn, sconn, direction = conn.dir
|
||||
return muxer
|
||||
|
||||
proc new*(
|
||||
|
|
|
@ -40,20 +40,18 @@ type
|
|||
method upgrade*(
|
||||
self: Upgrade,
|
||||
conn: Connection,
|
||||
direction: Direction,
|
||||
peerId: Opt[PeerId]): Future[Muxer] {.base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
proc secure*(
|
||||
self: Upgrade,
|
||||
conn: Connection,
|
||||
direction: Direction,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
|
||||
if self.secureManagers.len <= 0:
|
||||
raise newException(UpgradeFailedError, "No secure managers registered!")
|
||||
|
||||
let codec =
|
||||
if direction == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
|
||||
if conn.dir == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
|
||||
else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
|
||||
if codec.len == 0:
|
||||
raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!")
|
||||
|
@ -65,4 +63,4 @@ proc secure*(
|
|||
# let's avoid duplicating checks but detect if it fails to do it properly
|
||||
doAssert(secureProtocol.len > 0)
|
||||
|
||||
return await secureProtocol[0].secure(conn, direction == Out, peerId)
|
||||
return await secureProtocol[0].secure(conn, peerId)
|
||||
|
|
|
@ -24,7 +24,7 @@ type
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.gcsafe, async.}
|
||||
dir = Direction.Out): Future[void] {.gcsafe, async.}
|
||||
|
||||
method connect*(
|
||||
self: SwitchStub,
|
||||
|
@ -32,11 +32,11 @@ method connect*(
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out) {.async.} =
|
||||
dir = Direction.Out) {.async.} =
|
||||
if (self.connectStub != nil):
|
||||
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, upgradeDir)
|
||||
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, dir)
|
||||
else:
|
||||
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
|
||||
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, dir)
|
||||
|
||||
proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: connectStubType = nil): T =
|
||||
return SwitchStub(
|
||||
|
|
|
@ -96,7 +96,7 @@ suite "Dcutr":
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.async.} =
|
||||
dir = Direction.Out): Future[void] {.async.} =
|
||||
await sleepAsync(100.millis)
|
||||
|
||||
let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc)
|
||||
|
@ -115,7 +115,7 @@ suite "Dcutr":
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.async.} =
|
||||
dir = Direction.Out): Future[void] {.async.} =
|
||||
raise newException(CatchableError, "error")
|
||||
|
||||
let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc)
|
||||
|
@ -163,7 +163,7 @@ suite "Dcutr":
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.async.} =
|
||||
dir = Direction.Out): Future[void] {.async.} =
|
||||
await sleepAsync(100.millis)
|
||||
|
||||
await ductrServerTest(connectProc)
|
||||
|
@ -175,7 +175,7 @@ suite "Dcutr":
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.async.} =
|
||||
dir = Direction.Out): Future[void] {.async.} =
|
||||
raise newException(CatchableError, "error")
|
||||
|
||||
await ductrServerTest(connectProc)
|
||||
|
|
|
@ -210,7 +210,7 @@ suite "Hole Punching":
|
|||
addrs: seq[MultiAddress],
|
||||
forceDial = false,
|
||||
reuseConnection = true,
|
||||
upgradeDir = Direction.Out): Future[void] {.async.} =
|
||||
dir = Direction.Out): Future[void] {.async.} =
|
||||
self.connectStub = nil # this stub should be called only once
|
||||
raise newException(CatchableError, "error")
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ suite "Noise":
|
|||
|
||||
proc acceptHandler() {.async.} =
|
||||
let conn = await transport1.accept()
|
||||
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
|
||||
try:
|
||||
await sconn.write("Hello!")
|
||||
finally:
|
||||
|
@ -115,7 +115,7 @@ suite "Noise":
|
|||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
|
||||
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))
|
||||
|
||||
var msg = newSeq[byte](6)
|
||||
await sconn.readExactly(addr msg[0], 6)
|
||||
|
@ -144,7 +144,7 @@ suite "Noise":
|
|||
var conn: Connection
|
||||
try:
|
||||
conn = await transport1.accept()
|
||||
discard await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
discard await serverNoise.secure(conn, Opt.none(PeerId))
|
||||
except CatchableError:
|
||||
discard
|
||||
finally:
|
||||
|
@ -160,7 +160,7 @@ suite "Noise":
|
|||
|
||||
var sconn: Connection = nil
|
||||
expect(NoiseDecryptTagError):
|
||||
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
|
||||
sconn = await clientNoise.secure(conn, Opt.some(conn.peerId))
|
||||
|
||||
await conn.close()
|
||||
await handlerWait
|
||||
|
@ -180,7 +180,7 @@ suite "Noise":
|
|||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let conn = await transport1.accept()
|
||||
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
|
||||
defer:
|
||||
await sconn.close()
|
||||
await conn.close()
|
||||
|
@ -196,7 +196,7 @@ suite "Noise":
|
|||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))
|
||||
|
||||
await sconn.write("Hello!")
|
||||
await acceptFut
|
||||
|
@ -223,7 +223,7 @@ suite "Noise":
|
|||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let conn = await transport1.accept()
|
||||
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
|
||||
defer:
|
||||
await sconn.close()
|
||||
let msg = await sconn.readLp(1024*1024)
|
||||
|
@ -237,7 +237,7 @@ suite "Noise":
|
|||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))
|
||||
|
||||
await sconn.writeLp(hugePayload)
|
||||
await readTask
|
||||
|
|
Loading…
Reference in New Issue