From 59aaee8354549246c3000625b8ada8184cd5524f Mon Sep 17 00:00:00 2001 From: Eric Mastro Date: Mon, 13 Dec 2021 18:09:08 +1100 Subject: [PATCH] listenError async, add tcp/ws trasnport tests for listenError - listenError async - add tcp/ws trasnport tests for listenError - wstransport: in start, remove unhandled addresses from self.addrs (may need to be refactored) --- libp2p/switch.nim | 17 ++-- libp2p/transports/tcptransport.nim | 7 +- libp2p/transports/transport.nim | 16 ++-- libp2p/transports/wstransport.nim | 57 ++++++++----- tests/testswitch.nim | 129 +++++++++++++++++++---------- tests/testtcptransport.nim | 109 ++++++++++++++++++++++++ tests/testwstransport.nim | 109 ++++++++++++++++++++++++ 7 files changed, 359 insertions(+), 85 deletions(-) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 1097d23d2..d9a0c65b3 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -61,7 +61,7 @@ type transport*: Transport ListenErrorCallback = proc ( t: Transport, - err: ref TransportListenError): ref SwitchListenError + err: ref TransportListenError): Future[ref SwitchListenError] {.gcsafe, raises: [Defect].} Switch* = ref object of Dial @@ -258,7 +258,7 @@ proc start*(s: Switch) {.async, gcsafe.} = try: await t.start(addrs) except TransportListenError as e: - let err = s.listenError(t, e) + let err = await s.listenError(t, e) if not err.isNil: raise err s.acceptFuts.add(s.accept(t)) @@ -275,7 +275,10 @@ proc newSwitchListenError*( parent: parent) const ListenErrorDefault = - proc(t: Transport, e: ref TransportListenError): ref SwitchListenError = + proc( + t: Transport, + e: ref TransportListenError): Future[ref SwitchListenError] {.async.}= + error "Failed to start one transport", error = e.msg return newSwitchListenError(t, e) @@ -287,7 +290,7 @@ proc newSwitch*(peerInfo: PeerInfo, connManager: ConnManager, ms: MultistreamSelect, nameResolver: NameResolver = nil, - listenError: ListenErrorCallback = nil): Switch + listenError: ListenErrorCallback = ListenErrorDefault): Switch {.raises: [Defect, LPError].} = if secureManagers.len == 0: @@ -303,12 +306,6 @@ proc newSwitch*(peerInfo: PeerInfo, nameResolver: nameResolver, listenError: listenError) - if switch.listenError.isNil: - switch.listenError = ListenErrorDefault - # switch.listenError = proc(ma: MultiAddress, e: ref TransportListenError): ref SwitchError = - # error "Failed to start one transport", error = e.msg - # return nil - switch.connManager.peerStore = switch.peerStore switch.mount(identity) return switch diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index bda189282..1fb4269f1 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -119,16 +119,13 @@ proc new*( T: typedesc[TcpTransport], flags: set[ServerFlags] = {}, upgrade: Upgrade, - listenError: ListenErrorCallback = nil): T = + listenError: ListenErrorCallback = ListenErrorDefault): T = let transport = T( flags: flags, upgrader: upgrade, listenError: listenError) - if transport.listenError.isNil: - transport.listenError = ListenErrorDefault - inc getTcpTransportTracker().opened return transport @@ -164,7 +161,7 @@ method start*( self.servers &= server except CatchableError as ex: - let err = self.listenError(ma, ex) + let err = await self.listenError(ma, ex) if not err.isNil: raise err diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 000630020..828f00d9d 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -23,7 +23,7 @@ logScope: type ListenErrorCallback* = proc ( ma: MultiAddress, - err: ref CatchableError): ref TransportListenError + err: ref CatchableError): Future[ref TransportListenError] {.gcsafe, raises: [Defect].} TransportError* = object of LPError TransportInvalidAddrError* = object of TransportError @@ -41,12 +41,18 @@ proc newTransportClosedError*(parent: ref Exception = nil): ref LPError = newException(TransportClosedError, "Transport closed, no more connections!", parent) -proc newTransportListenError*(ma: MultiAddress, parent: ref Exception = nil): ref TransportListenError = - (ref TransportListenError)(msg: "Transport failed to start", parent: parent, ma: ma) +proc newTransportListenError*( + ma: MultiAddress, + parent: ref Exception = nil): ref TransportListenError = + + return (ref TransportListenError)(msg: "Transport failed to start", parent: parent, ma: ma) const ListenErrorDefault* = - proc(ma: MultiAddress, err: ref CatchableError): ref TransportListenError = - newTransportListenError(ma, err) + proc( + ma: MultiAddress, + err: ref CatchableError): Future[ref TransportListenError] {.async.} = + + return newTransportListenError(ma, err) method start*( self: Transport, diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index c0b14278e..0f8450fe9 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -107,6 +107,11 @@ method start*( for i, ma in addrs: + if not self.handles(ma): + trace "Invalid address detected, skipping!", address = ma + self.addrs.del i + continue + let isWss = if WSS.match(ma): if self.secure: true @@ -115,26 +120,32 @@ method start*( false else: false - let httpserver = - if isWss: - TlsHttpServer.create( - address = ma.initTAddress().tryGet(), - tlsPrivateKey = self.tlsPrivateKey, - tlsCertificate = self.tlsCertificate, - flags = self.flags) - else: - HttpServer.create(ma.initTAddress().tryGet()) + try: + let httpserver = + if isWss: + TlsHttpServer.create( + address = ma.initTAddress().tryGet(), + tlsPrivateKey = self.tlsPrivateKey, + tlsCertificate = self.tlsCertificate, + flags = self.flags) + else: + HttpServer.create(ma.initTAddress().tryGet()) - self.httpservers &= httpserver + self.httpservers &= httpserver - let codec = if isWss: - MultiAddress.init("/wss") - else: - MultiAddress.init("/ws") + let codec = if isWss: + MultiAddress.init("/wss") + else: + MultiAddress.init("/ws") - # always get the resolved address in case we're bound to 0.0.0.0:0 - self.addrs[i] = MultiAddress.init( - httpserver.localAddress()).tryGet() & codec.tryGet() + # always get the resolved address in case we're bound to 0.0.0.0:0 + self.addrs[i] = MultiAddress.init( + httpserver.localAddress()).tryGet() & codec.tryGet() + + except CatchableError as ex: + let err = await self.listenError(ma, ex) + if not err.isNil: + raise err trace "Listening on", addresses = self.addrs @@ -291,7 +302,8 @@ proc new*( tlsFlags: set[TLSFlags] = {}, flags: set[ServerFlags] = {}, factories: openArray[ExtFactory] = [], - rng: Rng = nil): T = + rng: Rng = nil, + listenError: ListenErrorCallback = ListenErrorDefault): T = T( upgrader: upgrade, @@ -300,14 +312,16 @@ proc new*( tlsFlags: tlsFlags, flags: flags, factories: @factories, - rng: rng) + rng: rng, + listenError: listenError) proc new*( T: typedesc[WsTransport], upgrade: Upgrade, flags: set[ServerFlags] = {}, factories: openArray[ExtFactory] = [], - rng: Rng = nil): T = + rng: Rng = nil, + listenError: ListenErrorCallback = ListenErrorDefault): T = T.new( upgrade = upgrade, @@ -315,4 +329,5 @@ proc new*( tlsCertificate = nil, flags = flags, factories = @factories, - rng = rng) + rng = rng, + listenError = listenError) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index f42c6317a..f210743d7 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -976,19 +976,24 @@ suite "Switch": await srcWsSwitch.stop() await srcTcpSwitch.stop() - asyncTest "listenError callback default returns TransportListenError (pessimistic)": + asyncTest "pessimistic: default listenError callback returns SwitchListenError": let switch = newStandardSwitch() transport = Transport() check switch.listenError.isNil.not - let exc = newException(TransportListenError, "test") - check not switch.listenError(transport, exc).isNil + let + exc = newException(TransportListenError, "test") + listenErrResult = await switch.listenError(transport, exc) + + check: + not listenErrResult.isNil + listenErrResult is (ref SwitchListenError) await switch.stop() - asyncTest "listenError callback assignable and callable": + asyncTest "switch listenError callback assignable and callable": let switch = newStandardSwitch() transportListenError = newException(TransportListenError, "test1") @@ -996,7 +1001,8 @@ suite "Switch": transport = Transport() switch.listenError = proc( - t: Transport, exc: ref TransportListenError): ref SwitchListenError = + t: Transport, + exc: ref TransportListenError): Future[ref SwitchListenError] {.async.} = check: exc == transportListenError @@ -1004,8 +1010,8 @@ suite "Switch": return switchListenError - check: - switch.listenError(transport, transportListenError) == switchListenError + let listenErrResult = await switch.listenError(transport, transportListenError) + check listenErrResult == switchListenError await switch.stop() @@ -1016,7 +1022,7 @@ suite "Switch": exc2 = newException(CatchableError, "test2") transportStartMock = - proc(self: MockTransport, addrs: seq[MultiAddress]): Future[void] = + proc(self: MockTransport, addrs: seq[MultiAddress]): Future[void] {.async.} = for i, ma in addrs: try: if i == 0: @@ -1031,9 +1037,10 @@ suite "Switch": echo "[test.startMock] raising exc2" raise exc2 except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) if not err.isNil: raise err + # return fail() # should not get this far mockTransport = @@ -1067,8 +1074,9 @@ suite "Switch": exc = TransportListenError(e.parent[]) ma = exc.ma - check ma == ma0 - check exc.parent == exc0 + check: + ma == ma0 + exc.parent == exc0 await switch.stop() @@ -1080,7 +1088,10 @@ suite "Switch": exc2 = newException(CatchableError, "test2") transportListenError = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + handledTransportErrs[ma] = ex return nil # optimistic transport multiaddress failure @@ -1098,7 +1109,7 @@ suite "Switch": echo "[test.startMock] raising exc2" raise exc2 except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) # check err == nil if not err.isNil: raise err @@ -1129,7 +1140,8 @@ suite "Switch": switch.listenError = proc( - t: Transport, exc: ref TransportListenError): ref SwitchListenError = + t: Transport, + exc: ref TransportListenError): Future[ref SwitchListenError] {.async.} = let ma = exc.ma if ma == ma0: @@ -1160,7 +1172,10 @@ suite "Switch": exc2 = newException(CatchableError, "test2") transportListenError = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + handledTransportErrs[ma] = ex return nil # optimistic transport multiaddress failure @@ -1177,7 +1192,7 @@ suite "Switch": echo "[test.startMock] raising exc2" raise exc2 except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) # check err == nil if not err.isNil: raise err @@ -1208,7 +1223,8 @@ suite "Switch": switch.listenError = proc( - t: Transport, exc: ref TransportListenError): ref SwitchListenError = + t: Transport, + exc: ref TransportListenError): Future[ref SwitchListenError] {.async.} = fail() try: @@ -1235,13 +1251,19 @@ suite "Switch": transportListenError0 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + handledTransportErrs0[ma] = ex # pessimistic transport multiaddress failure return newTransportListenError(ma, ex) transportListenError1 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + handledTransportErrs1[ma] = ex # pessimistic transport multiaddress failure return newTransportListenError(ma, ex) @@ -1258,7 +1280,7 @@ suite "Switch": else: fail() except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) if not err.isNil: raise err @@ -1271,7 +1293,7 @@ suite "Switch": else: fail() except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) if not err.isNil: raise err @@ -1314,15 +1336,17 @@ suite "Switch": switch.listenError = proc( t: Transport, - exc: ref TransportListenError): ref SwitchListenError = + exc: ref TransportListenError): Future[ref SwitchListenError] {.async.} = let ma = exc.ma if ma == ma0: - check exc.parent == exc10 - check t == switch.transports[1] + check: + exc.parent == exc10 + t == switch.transports[1] elif ma == ma1: - check exc.parent == exc01 - check t == switch.transports[0] + check: + exc.parent == exc01 + t == switch.transports[0] else: fail() # switch optimistic, continue with all transports @@ -1343,8 +1367,9 @@ suite "Switch": for ma, ex in handledTransportErrs1: echo "ma: ", $ma, ", ex: ", ex.msg - check handledTransportErrs0 == [(ma1, exc01)].toTable() - check handledTransportErrs1 == [(ma0, exc10)].toTable() + check: + handledTransportErrs0 == [(ma1, exc01)].toTable() + handledTransportErrs1 == [(ma0, exc10)].toTable() await switch.stop() @@ -1358,18 +1383,27 @@ suite "Switch": exc10 = newException(CatchableError, "test10") transportListenError0 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + handledTransportErrs0[ma] = ex # optimistic transport multiaddress failure return nil transportListenError1 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + handledTransportErrs1[ma] = ex return newTransportListenError(ma, ex) transportListenError2 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = + # should not get here as switch is pessimistic so will stop at first # failed transport (transpor1) fail() @@ -1383,7 +1417,7 @@ suite "Switch": else: continue except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) if not err.isNil: raise err @@ -1396,7 +1430,7 @@ suite "Switch": else: continue except CatchableError as e: - let err = self.listenError(ma, e) + let err = await self.listenError(ma, e) if not err.isNil: raise err @@ -1457,10 +1491,11 @@ suite "Switch": switch.listenError = proc( t: Transport, - exc: ref TransportListenError): ref SwitchListenError = + exc: ref TransportListenError): Future[ref SwitchListenError] {.async.} = - check t == switch.transports[1] - check exc.ma == ma0 + check: + t == switch.transports[1] + exc.ma == ma0 # pessimistic return newSwitchListenError(t, exc) @@ -1469,9 +1504,10 @@ suite "Switch": await switch.start() except SwitchListenError as e: let tListenEx = (ref TransportListenError)(e.parent) - check tListenEx.ma == ma0 - check tListenEx.parent == exc10 - check e.transport == switch.transports[1] + check: + tListenEx.ma == ma0 + tListenEx.parent == exc10 + e.transport == switch.transports[1] echo "handledTransportErrs0:" for ma, ex in handledTransportErrs0: @@ -1481,19 +1517,24 @@ suite "Switch": for ma, ex in handledTransportErrs1: echo "ma: ", $ma, ", ex: ", ex.msg - check handledTransportErrs0 == [(ma1, exc01)].toTable() - check handledTransportErrs1 == [(ma0, exc10)].toTable() + check: + handledTransportErrs0 == [(ma1, exc01)].toTable() + handledTransportErrs1 == [(ma0, exc10)].toTable() await switch.stop() asyncTest "no exceptions raised, listenError should not be called": let transportListenError0 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = fail() transportListenError1 = - proc(ma: MultiAddress, ex: ref CatchableError): ref TransportListenError = + proc( + ma: MultiAddress, + ex: ref CatchableError): Future[ref TransportListenError] {.async.} = fail() transportStartMock0 = @@ -1546,7 +1587,7 @@ suite "Switch": switch.listenError = proc( t: Transport, - exc: ref TransportListenError): ref SwitchListenError = + exc: ref TransportListenError): Future[ref SwitchListenError] {.async.} = fail() diff --git a/tests/testtcptransport.nim b/tests/testtcptransport.nim index 64d5d8f6b..251ca733c 100644 --- a/tests/testtcptransport.nim +++ b/tests/testtcptransport.nim @@ -125,6 +125,115 @@ suite "TCP transport": server.close() await server.join() + asyncTest "pessimistic: default listenError callback returns TransportListenError": + let + transport = TcpTransport.new(upgrade = Upgrade()) + + check not transport.listenError.isNil + + let + exc = newException(CatchableError, "test") + ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + listenErrResult = await transport.listenError(ma, exc) + + check: + not listenErrResult.isNil + listenErrResult is (ref TransportListenError) + + await transport.stop() + + asyncTest "listenError callback assignable and callable": + let + failListenErr = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + fail() + transport = TcpTransport.new( + upgrade = Upgrade(), + listenError = failListenErr) + ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + catchableError = newException(CatchableError, "test1") + transportListenError = newTransportListenError(ma, catchableError) + + transport.listenError = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + + check: + exc == catchableError + maErr == ma + + return transportListenError + + let listenErrResult = await transport.listenError(ma, catchableError) + + check: + listenErrResult == transportListenError + + await transport.stop() + + asyncTest "pessimistic: default listenError re-raises exception": + let + # use a bad MultiAddress to throw an error during transport.start + ma = Multiaddress.init("/ip4/1.0.0.0/tcp/0").tryGet() + + transport = TcpTransport.new(upgrade = Upgrade()) + + expect TransportListenError: + await transport.start(@[ma]) + + await transport.stop() + + asyncTest "pessimistic: overridden listenError re-raises exception": + var transportListenErr: ref TransportListenError + + let + # use a bad MultiAddress to throw an error during transport.start + ma = Multiaddress.init("/ip4/1.0.0.0/tcp/0").tryGet() + listenError = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + + transportListenErr = newTransportListenError(maErr, exc) + check maErr == ma + return transportListenErr + + transport = TcpTransport.new( + upgrade = Upgrade(), + listenError = listenError) + + try: + await transport.start(@[ma]) + fail() + except TransportListenError as e: + check e == transportListenErr + + await transport.stop() + + asyncTest "optimistic: overridden listenError does not re-raise exception": + var transportListenErr: ref TransportListenError + + let + # use a bad MultiAddress to throw an error during transport.start + ma = Multiaddress.init("/ip4/1.0.0.0/tcp/0").tryGet() + listenError = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + + check maErr == ma + return nil + + transport = TcpTransport.new( + upgrade = Upgrade(), + listenError = listenError) + + try: + await transport.start(@[ma]) + except TransportListenError as e: + fail() + + await transport.stop() + commonTransportTest( "TcpTransport", proc (): Transport = TcpTransport.new(upgrade = Upgrade()), diff --git a/tests/testwstransport.nim b/tests/testwstransport.nim index bf8b9ce87..e7f222e6e 100644 --- a/tests/testwstransport.nim +++ b/tests/testwstransport.nim @@ -55,6 +55,115 @@ suite "WebSocket transport": teardown: checkTrackers() + asyncTest "pessimistic: default listenError callback returns TransportListenError": + let + transport = WsTransport.new(upgrade = Upgrade()) + + check not transport.listenError.isNil + + let + exc = newException(CatchableError, "test") + ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0/ws").tryGet() + listenErrResult = await transport.listenError(ma, exc) + + check: + not listenErrResult.isNil + listenErrResult is (ref TransportListenError) + + await transport.stop() + + asyncTest "listenError callback assignable and callable": + let + failListenErr = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + fail() + transport = WsTransport.new( + upgrade = Upgrade(), + listenError = failListenErr) + ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0/ws").tryGet() + catchableError = newException(CatchableError, "test1") + transportListenError = newTransportListenError(ma, catchableError) + + transport.listenError = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + + check: + exc == catchableError + maErr == ma + + return transportListenError + + let listenErrResult = await transport.listenError(ma, catchableError) + + check: + listenErrResult == transportListenError + + await transport.stop() + + asyncTest "pessimistic: default listenError re-raises exception": + let + # use a bad MultiAddress to throw an error during transport.start + ma = Multiaddress.init("/ip4/1.0.0.0/tcp/0/ws").tryGet() + + transport = WsTransport.new(upgrade = Upgrade()) + + expect TransportListenError: + await transport.start(@[ma]) + + await transport.stop() + + asyncTest "pessimistic: overridden listenError re-raises exception": + var transportListenErr: ref TransportListenError + + let + # use a bad MultiAddress to throw an error during transport.start + ma = Multiaddress.init("/ip4/1.0.0.0/tcp/0/ws").tryGet() + listenError = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + + transportListenErr = newTransportListenError(maErr, exc) + check maErr == ma + return transportListenErr + + transport = WsTransport.new( + upgrade = Upgrade(), + listenError = listenError) + + try: + await transport.start(@[ma]) + fail() + except TransportListenError as e: + check e == transportListenErr + + await transport.stop() + + asyncTest "optimistic: overridden listenError does not re-raise exception": + var transportListenErr: ref TransportListenError + + let + # use a bad MultiAddress to throw an error during transport.start + ma = Multiaddress.init("/ip4/1.0.0.0/tcp/0/ws").tryGet() + listenError = proc( + maErr: MultiAddress, + exc: ref CatchableError): Future[ref TransportListenError] {.async.} = + + check maErr == ma + return nil + + transport = WsTransport.new( + upgrade = Upgrade(), + listenError = listenError) + + try: + await transport.start(@[ma]) + except TransportListenError as e: + fail() + + await transport.stop() + commonTransportTest( "WebSocket", proc (): Transport = WsTransport.new(Upgrade()),