diff --git a/libp2p/protocols/connectivity/autonat.nim b/libp2p/protocols/connectivity/autonat.nim index c63d7e531..47bd3a4d7 100644 --- a/libp2p/protocols/connectivity/autonat.nim +++ b/libp2p/protocols/connectivity/autonat.nim @@ -46,21 +46,21 @@ type InternalError = 300 AutonatPeerInfo* = object - id: Option[PeerId] - addrs: seq[MultiAddress] + id*: Option[PeerId] + addrs*: seq[MultiAddress] AutonatDial* = object - peerInfo: Option[AutonatPeerInfo] + peerInfo*: Option[AutonatPeerInfo] AutonatDialResponse* = object status*: ResponseStatus text*: Option[string] ma*: Option[MultiAddress] - AutonatMsg = object - msgType: MsgType - dial: Option[AutonatDial] - response: Option[AutonatDialResponse] + AutonatMsg* = object + msgType*: MsgType + dial*: Option[AutonatDial] + response*: Option[AutonatDialResponse] proc encode*(msg: AutonatMsg): ProtoBuffer = result = initProtoBuffer() @@ -120,7 +120,7 @@ proc encode*(r: AutonatDialResponse): ProtoBuffer = result.write(3, bufferResponse.buffer) result.finish() -proc decode(_: typedesc[AutonatMsg], buf: seq[byte]): Option[AutonatMsg] = +proc decode*(_: typedesc[AutonatMsg], buf: seq[byte]): Option[AutonatMsg] = var msgTypeOrd: uint32 pbDial: ProtoBuffer @@ -203,6 +203,7 @@ type Autonat* = ref object of LPProtocol sem: AsyncSemaphore switch*: Switch + dialTimeout: Duration method dialMe*(a: Autonat, pid: PeerId, addrs: seq[MultiAddress] = newSeq[MultiAddress]()): Future[MultiAddress] {.base, async.} = @@ -240,7 +241,7 @@ method dialMe*(a: Autonat, pid: PeerId, addrs: seq[MultiAddress] = newSeq[MultiA proc tryDial(a: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.async.} = try: await a.sem.acquire() - let ma = await a.switch.dialer.tryDial(conn.peerId, addrs) + let ma = await a.switch.dialer.tryDial(conn.peerId, addrs).wait(a.dialTimeout) if ma.isSome: await conn.sendResponseOk(ma.get()) else: @@ -298,8 +299,8 @@ proc handleDial(a: Autonat, conn: Connection, msg: AutonatMsg): Future[void] = return conn.sendResponseError(DialRefused, "No dialable address") return a.tryDial(conn, toSeq(addrs)) -proc new*(T: typedesc[Autonat], switch: Switch, semSize: int = 1): T = - let autonat = T(switch: switch, sem: newAsyncSemaphore(semSize)) +proc new*(T: typedesc[Autonat], switch: Switch, semSize: int = 1, dialTimeout = 15.seconds): T = + let autonat = T(switch: switch, sem: newAsyncSemaphore(semSize), dialTimeout: dialTimeout) autonat.init() autonat diff --git a/libp2p/services/autonatservice.nim b/libp2p/services/autonatservice.nim index a70aa0001..5d13e4233 100644 --- a/libp2p/services/autonatservice.nim +++ b/libp2p/services/autonatservice.nim @@ -52,7 +52,7 @@ proc new*( numPeersToAsk: int = 5, maxQueueSize: int = 10, minConfidence: float = 0.3, - dialTimeout = 5.seconds): T = + dialTimeout = 30.seconds): T = return T( scheduleInterval: scheduleInterval, networkReachability: Unknown, diff --git a/tests/testautonat.nim b/tests/testautonat.nim index e1cde99fb..7542cddf5 100644 --- a/tests/testautonat.nim +++ b/tests/testautonat.nim @@ -2,6 +2,8 @@ import std/options import chronos import ../libp2p/[ + transports/tcptransport, + upgrademngrs/upgrade, builders, protocols/connectivity/autonat ], @@ -58,3 +60,30 @@ suite "Autonat": expect AutonatUnreachableError: discard await Autonat.new(src).dialMe(dst.peerInfo.peerId, dst.peerInfo.addrs) await allFutures(src.stop(), dst.stop()) + + asyncTest "Timeout is triggered in autonat handle": + let + src = newStandardSwitch() + dst = newStandardSwitch() + autonat = Autonat.new(dst, dialTimeout = 1.seconds) + doesNothingListener = TcpTransport.new(upgrade = Upgrade()) + + dst.mount(autonat) + await src.start() + await dst.start() + await doesNothingListener.start(@[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]) + + await src.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + let conn = await src.dial(dst.peerInfo.peerId, @[AutonatCodec]) + let buffer = AutonatDial(peerInfo: some(AutonatPeerInfo( + id: some(src.peerInfo.peerId), + # we ask to be dialed in the does nothing listener instead + addrs: doesNothingListener.addrs + ))).encode().buffer + await conn.writeLp(buffer) + let response = AutonatMsg.decode(await conn.readLp(1024)).get().response.get() + check: + response.status == DialError + response.text.get() == "Timeout exceeded!" + response.ma.isNone() + await allFutures(doesNothingListener.stop(), src.stop(), dst.stop())