From e0f70b71778bf9613421261b22a56499c913f806 Mon Sep 17 00:00:00 2001 From: diegomrsantos Date: Fri, 9 Feb 2024 11:51:27 +0100 Subject: [PATCH] improvement: enhanced checkExpiring macro with custom timeout (#1023) --- tests/helpers.nim | 88 +++++++++++++++++++++++++---- tests/pubsub/testfloodsub.nim | 4 +- tests/pubsub/testgossipinternal.nim | 8 +-- tests/pubsub/testgossipsub.nim | 34 +++++------ tests/testconnmngr.nim | 6 +- tests/testdcutr.nim | 6 +- tests/testhelpers.nim | 42 ++++++++++++++ tests/testhpservice.nim | 8 +-- tests/testidentify.nim | 4 +- tests/testmplex.nim | 8 +-- tests/testnative.nim | 3 +- tests/testrendezvousinterface.nim | 4 +- tests/testswitch.nim | 18 +++--- 13 files changed, 170 insertions(+), 63 deletions(-) create mode 100644 tests/testhelpers.nim diff --git a/tests/helpers.nim b/tests/helpers.nim index efa88c05f..ad4404af2 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -1,6 +1,7 @@ {.push raises: [].} import chronos +import macros import algorithm import ../libp2p/transports/tcptransport @@ -110,20 +111,83 @@ proc bridgedConnections*: (Connection, Connection) = await connA.pushData(data) return (connA, connB) - -proc checkExpiringInternal(cond: proc(): bool {.raises: [], gcsafe.} ): Future[bool] {.async.} = - let start = Moment.now() - while true: - if Moment.now() > (start + chronos.seconds(10)): - return false - elif cond(): - return true +macro checkUntilCustomTimeout*(timeout: Duration, code: untyped): untyped = + ## Periodically checks a given condition until it is true or a timeout occurs. + ## + ## `code`: untyped - A condition expression that should eventually evaluate to true. + ## `timeout`: Duration - The maximum duration to wait for the condition to be true. + ## + ## Examples: + ## ```nim + ## # Example 1: + ## asyncTest "checkUntilCustomTimeout should pass if the condition is true": + ## let a = 2 + ## let b = 2 + ## checkUntilCustomTimeout(2.seconds): + ## a == b + ## + ## # Example 2: Multiple conditions + ## asyncTest "checkUntilCustomTimeout should pass if the conditions are true": + ## let a = 2 + ## let b = 2 + ## checkUntilCustomTimeout(5.seconds):: + ## a == b + ## a == 2 + ## b == 1 + ## ``` + # Helper proc to recursively build a combined boolean expression + proc buildAndExpr(n: NimNode): NimNode = + if n.kind == nnkStmtList and n.len > 0: + var combinedExpr = n[0] # Start with the first expression + for i in 1.. (start + `timeout`): + checkpoint("[TIMEOUT] Timeout was reached and the conditions were not true. Check if the code is working as " & + "expected or consider increasing the timeout param.") + check `code` + return + else: + if `combinedBoolExpr`: + return + else: + await sleepAsync(1.millis) + await checkExpiringInternal() + +macro checkUntilTimeout*(code: untyped): untyped = + ## Same as `checkUntilCustomTimeout` but with a default timeout of 10 seconds. + ## + ## Examples: + ## ```nim + ## # Example 1: + ## asyncTest "checkUntilTimeout should pass if the condition is true": + ## let a = 2 + ## let b = 2 + ## checkUntilTimeout: + ## a == b + ## + ## # Example 2: Multiple conditions + ## asyncTest "checkUntilTimeout should pass if the conditions are true": + ## let a = 2 + ## let b = 2 + ## checkUntilTimeout: + ## a == b + ## a == 2 + ## b == 1 + ## ``` + result = quote do: + checkUntilCustomTimeout(10.seconds, `code`) proc unorderedCompare*[T](a, b: seq[T]): bool = if a == b: diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index bb4355009..e182255b6 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -361,7 +361,7 @@ suite "FloodSub": check (await smallNode[0].publish("foo", smallMessage1)) > 0 check (await bigNode[0].publish("foo", smallMessage2)) > 0 - checkExpiring: messageReceived == 2 + checkUntilTimeout: messageReceived == 2 check (await smallNode[0].publish("foo", bigMessage)) > 0 check (await bigNode[0].publish("foo", bigMessage)) > 0 @@ -396,7 +396,7 @@ suite "FloodSub": check (await bigNode1[0].publish("foo", bigMessage)) > 0 - checkExpiring: messageReceived == 1 + checkUntilTimeout: messageReceived == 1 await allFuturesThrowing( bigNode1[0].switch.stop(), diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index c97ac8a7c..51ae65ff9 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -781,7 +781,7 @@ suite "GossipSub internal": ihave: @[ControlIHave(topicId: "foobar", messageIds: iwantMessageIds)] )))) - checkExpiring: receivedMessages[] == sentMessages + checkUntilTimeout: receivedMessages[] == sentMessages check receivedMessages[].len == 2 await teardownTest(gossip0, gossip1) @@ -799,7 +799,7 @@ suite "GossipSub internal": )))) await sleepAsync(300.milliseconds) - checkExpiring: receivedMessages[].len == 0 + checkUntilTimeout: receivedMessages[].len == 0 await teardownTest(gossip0, gossip1) @@ -815,7 +815,7 @@ suite "GossipSub internal": ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)] )))) - checkExpiring: receivedMessages[] == sentMessages + checkUntilTimeout: receivedMessages[] == sentMessages check receivedMessages[].len == 2 await teardownTest(gossip0, gossip1) @@ -840,7 +840,7 @@ suite "GossipSub internal": else: smallestSet.incl(seqs[1]) - checkExpiring: receivedMessages[] == smallestSet + checkUntilTimeout: receivedMessages[] == smallestSet check receivedMessages[].len == 1 await teardownTest(gossip0, gossip1) diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index c6e1ed321..1081fe255 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -310,9 +310,9 @@ suite "GossipSub": let gossip1 = GossipSub(nodes[0]) let gossip2 = GossipSub(nodes[1]) - checkExpiring: - "foobar" in gossip2.topics and - "foobar" in gossip1.gossipsub and + checkUntilTimeout: + "foobar" in gossip2.topics + "foobar" in gossip1.gossipsub gossip1.gossipsub.hasPeerId("foobar", gossip2.peerInfo.peerId) await allFuturesThrowing( @@ -454,9 +454,9 @@ suite "GossipSub": nodes[1].subscribe("foobar", handler) let gsNode = GossipSub(nodes[1]) - checkExpiring: - gsNode.mesh.getOrDefault("foobar").len == 0 and - GossipSub(nodes[0]).mesh.getOrDefault("foobar").len == 0 and + checkUntilTimeout: + gsNode.mesh.getOrDefault("foobar").len == 0 + GossipSub(nodes[0]).mesh.getOrDefault("foobar").len == 0 ( GossipSub(nodes[0]).gossipsub.getOrDefault("foobar").len == 1 or GossipSub(nodes[0]).fanout.getOrDefault("foobar").len == 1 @@ -572,16 +572,16 @@ suite "GossipSub": gossip1.seen = TimedCache[MessageId].init() gossip3.seen = TimedCache[MessageId].init() let msgId = toSeq(gossip2.validationSeen.keys)[0] - checkExpiring(try: gossip2.validationSeen[msgId].len > 0 except: false) + checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false) result = ValidationResult.Accept bFinished.complete() nodes[1].addValidator("foobar", slowValidator) - checkExpiring( - gossip1.mesh.getOrDefault("foobar").len == 2 and - gossip2.mesh.getOrDefault("foobar").len == 2 and - gossip3.mesh.getOrDefault("foobar").len == 2) + checkUntilTimeout: + gossip1.mesh.getOrDefault("foobar").len == 2 + gossip2.mesh.getOrDefault("foobar").len == 2 + gossip3.mesh.getOrDefault("foobar").len == 2 tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 2 await bFinished @@ -676,7 +676,7 @@ suite "GossipSub": # Now try with a mesh gossip1.subscribe("foobar", handler) - checkExpiring: gossip1.mesh.peers("foobar") > 5 + checkUntilTimeout: gossip1.mesh.peers("foobar") > 5 # use a different length so that the message is not equal to the last check (await nodes[0].publish("foobar", newSeq[byte](500_000))) == numPeersSecondMsg @@ -913,13 +913,13 @@ suite "GossipSub": gossip3.broadcast(gossip3.mesh["foobar"], RPCMsg(control: some(ControlMessage( idontwant: @[ControlIWant(messageIds: @[newSeq[byte](10)])] )))) - checkExpiring: gossip2.mesh.getOrDefault("foobar").anyIt(it.heDontWants[^1].len == 1) + checkUntilTimeout: gossip2.mesh.getOrDefault("foobar").anyIt(it.heDontWants[^1].len == 1) tryPublish await nodes[0].publish("foobar", newSeq[byte](10000)), 1 await bFinished - checkExpiring: toSeq(gossip3.mesh.getOrDefault("foobar")).anyIt(it.heDontWants[^1].len == 1) + checkUntilTimeout: toSeq(gossip3.mesh.getOrDefault("foobar")).anyIt(it.heDontWants[^1].len == 1) check: toSeq(gossip1.mesh.getOrDefault("foobar")).anyIt(it.heDontWants[^1].len == 0) await allFuturesThrowing( @@ -1000,7 +1000,7 @@ suite "GossipSub": gossip1.parameters.disconnectPeerAboveRateLimit = true await gossip0.peers[gossip1.switch.peerInfo.peerId].sendEncoded(newSeqWith[byte](35, 1.byte)) - checkExpiring gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false + checkUntilTimeout gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false check currentRateLimitHits() == rateLimitHits + 2 await stopNodes(nodes) @@ -1029,7 +1029,7 @@ suite "GossipSub": ]))) gossip0.broadcast(gossip0.mesh["foobar"], msg2) - checkExpiring gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false + checkUntilTimeout gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false check currentRateLimitHits() == rateLimitHits + 2 await stopNodes(nodes) @@ -1059,7 +1059,7 @@ suite "GossipSub": gossip1.parameters.disconnectPeerAboveRateLimit = true gossip0.broadcast(gossip0.mesh[topic], RPCMsg(messages: @[Message(topicIDs: @[topic], data: newSeq[byte](35))])) - checkExpiring gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false + checkUntilTimeout gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false check currentRateLimitHits() == rateLimitHits + 2 await stopNodes(nodes) diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 00257e8b2..d81366e18 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -215,7 +215,7 @@ suite "Connection Manager": await connMngr.close() - checkExpiring: waitedConn3.cancelled() + checkUntilTimeout: waitedConn3.cancelled() await allFuturesThrowing( allFutures(muxs.mapIt( it.close() ))) @@ -231,7 +231,7 @@ suite "Connection Manager": await muxer.close() - checkExpiring: muxer notin connMngr + checkUntilTimeout: muxer notin connMngr await connMngr.close() @@ -254,7 +254,7 @@ suite "Connection Manager": check peerId in connMngr await connMngr.dropPeer(peerId) - checkExpiring: peerId notin connMngr + checkUntilTimeout: peerId notin connMngr check isNil(connMngr.selectMuxer(peerId, Direction.In)) check isNil(connMngr.selectMuxer(peerId, Direction.Out)) diff --git a/tests/testdcutr.nim b/tests/testdcutr.nim index da6fb5e38..a970311d6 100644 --- a/tests/testdcutr.nim +++ b/tests/testdcutr.nim @@ -64,7 +64,7 @@ suite "Dcutr": await DcutrClient.new().startSync(behindNATSwitch, publicSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs) .wait(300.millis) - checkExpiring: + checkUntilTimeout: # we still expect a new connection to be open by the receiver peer acting as the dcutr server behindNATSwitch.connManager.connCount(publicSwitch.peerInfo.peerId) == 2 @@ -83,7 +83,7 @@ suite "Dcutr": body - checkExpiring: + checkUntilTimeout: # we still expect a new connection to be open by the receiver peer acting as the dcutr server behindNATSwitch.connManager.connCount(publicSwitch.peerInfo.peerId) == 2 @@ -150,7 +150,7 @@ suite "Dcutr": await DcutrClient.new().startSync(behindNATSwitch, publicSwitch.peerInfo.peerId, behindNATSwitch.peerInfo.addrs) .wait(300.millis) - checkExpiring: + checkUntilTimeout: # we still expect a new connection to be open by the receiver peer acting as the dcutr server behindNATSwitch.connManager.connCount(publicSwitch.peerInfo.peerId) == 1 diff --git a/tests/testhelpers.nim b/tests/testhelpers.nim new file mode 100644 index 000000000..0898b394e --- /dev/null +++ b/tests/testhelpers.nim @@ -0,0 +1,42 @@ +{.used.} + +# Nim-Libp2p +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import ./helpers + +suite "Helpers": + + asyncTest "checkUntilTimeout should pass if the condition is true": + let a = 2 + let b = 2 + checkUntilTimeout: + a == b + + asyncTest "checkUntilTimeout should pass if the conditions are true": + let a = 2 + let b = 2 + checkUntilTimeout: + a == b + a == 2 + b == 2 + + asyncTest "checkUntilCustomTimeout should pass when the condition is true": + let a = 2 + let b = 2 + checkUntilCustomTimeout(2.seconds): + a == b + + asyncTest "checkUntilCustomTimeout should pass when the conditions are true": + let a = 2 + let b = 2 + checkUntilCustomTimeout(5.seconds): + a == b + a == 2 + b == 2 diff --git a/tests/testhpservice.nim b/tests/testhpservice.nim index 789c46bf6..7cf59a745 100644 --- a/tests/testhpservice.nim +++ b/tests/testhpservice.nim @@ -89,8 +89,8 @@ suite "Hole Punching": await publicPeerSwitch.connect(privatePeerSwitch.peerInfo.peerId, (await privatePeerRelayAddr)) - checkExpiring: - privatePeerSwitch.connManager.connCount(publicPeerSwitch.peerInfo.peerId) == 1 and + checkUntilTimeout: + privatePeerSwitch.connManager.connCount(publicPeerSwitch.peerInfo.peerId) == 1 not isRelayed(privatePeerSwitch.connManager.selectMuxer(publicPeerSwitch.peerInfo.peerId).connection) await allFuturesThrowing( @@ -127,8 +127,8 @@ suite "Hole Punching": await publicPeerSwitch.connect(privatePeerSwitch.peerInfo.peerId, (await privatePeerRelayAddr)) - checkExpiring: - privatePeerSwitch.connManager.connCount(publicPeerSwitch.peerInfo.peerId) == 1 and + checkUntilTimeout: + privatePeerSwitch.connManager.connCount(publicPeerSwitch.peerInfo.peerId) == 1 not isRelayed(privatePeerSwitch.connManager.selectMuxer(publicPeerSwitch.peerInfo.peerId).connection) await allFuturesThrowing( diff --git a/tests/testidentify.nim b/tests/testidentify.nim index 3faded8df..38605d9e5 100644 --- a/tests/testidentify.nim +++ b/tests/testidentify.nim @@ -219,8 +219,8 @@ suite "Identify": await identifyPush2.push(switch2.peerInfo, conn) - checkExpiring: switch1.peerStore[ProtoBook][switch2.peerInfo.peerId] == switch2.peerInfo.protocols - checkExpiring: switch1.peerStore[AddressBook][switch2.peerInfo.peerId] == switch2.peerInfo.addrs + checkUntilTimeout: switch1.peerStore[ProtoBook][switch2.peerInfo.peerId] == switch2.peerInfo.protocols + checkUntilTimeout: switch1.peerStore[AddressBook][switch2.peerInfo.peerId] == switch2.peerInfo.addrs await closeAll() diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 00400f0de..7ffef9c4d 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -829,7 +829,7 @@ suite "Mplex": check: unorderedCompare(dialStreams, mplexDial.getStreams()) - checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 + checkUntilTimeout: listenStreams.len == 10 and dialStreams.len == 10 await mplexListen.close() await allFuturesThrowing( @@ -876,7 +876,7 @@ suite "Mplex": check: unorderedCompare(dialStreams, mplexDial.getStreams()) - checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 + checkUntilTimeout: listenStreams.len == 10 and dialStreams.len == 10 mplexHandle.cancel() await allFuturesThrowing( @@ -920,7 +920,7 @@ suite "Mplex": check: unorderedCompare(dialStreams, mplexDial.getStreams()) - checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 + checkUntilTimeout: listenStreams.len == 10 and dialStreams.len == 10 await conn.close() await allFuturesThrowing( @@ -967,7 +967,7 @@ suite "Mplex": check: unorderedCompare(dialStreams, mplexDial.getStreams()) - checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 + checkUntilTimeout: listenStreams.len == 10 and dialStreams.len == 10 await listenConn.closeWithEOF() await allFuturesThrowing( diff --git a/tests/testnative.nim b/tests/testnative.nim index 9c6b2dc0e..f6e933c4e 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -57,4 +57,5 @@ import testtcptransport, testautorelay, testdcutr, testhpservice, - testutility + testutility, + testhelpers diff --git a/tests/testrendezvousinterface.nim b/tests/testrendezvousinterface.nim index b612094ef..f9730347c 100644 --- a/tests/testrendezvousinterface.nim +++ b/tests/testrendezvousinterface.nim @@ -62,8 +62,8 @@ suite "RendezVous Interface": dm.advertise(RdvNamespace("ns1")) dm.advertise(RdvNamespace("ns2")) - checkExpiring: rdv.numAdvertiseNs1 >= 5 - checkExpiring: rdv.numAdvertiseNs2 >= 5 + checkUntilTimeout: rdv.numAdvertiseNs1 >= 5 + checkUntilTimeout: rdv.numAdvertiseNs2 >= 5 await client.stop() asyncTest "Check timeToAdvertise interval": diff --git a/tests/testswitch.nim b/tests/testswitch.nim index be2f7935c..b9638cafe 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -283,12 +283,12 @@ suite "Switch": await switch2.disconnect(switch1.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) - checkExpiring: + checkUntilTimeout: startCounts == @[ switch1.connManager.inSema.count, switch1.connManager.outSema.count, @@ -336,7 +336,7 @@ suite "Switch": await switch2.disconnect(switch1.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) @@ -388,7 +388,7 @@ suite "Switch": await switch2.disconnect(switch1.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) @@ -439,7 +439,7 @@ suite "Switch": await switch2.disconnect(switch1.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) @@ -490,7 +490,7 @@ suite "Switch": await switch2.disconnect(switch1.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) @@ -554,8 +554,8 @@ suite "Switch": check not switch2.isConnected(switch1.peerInfo.peerId) check not switch3.isConnected(switch1.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) - checkExpiring: not switch1.isConnected(switch3.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch3.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) @@ -711,7 +711,7 @@ suite "Switch": await allFuturesThrowing(readers) await switch2.stop() #Otherwise this leaks - checkExpiring: not switch1.isConnected(switch2.peerInfo.peerId) + checkUntilTimeout: not switch1.isConnected(switch2.peerInfo.peerId) checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName)