diff --git a/examples/directchat.nim b/examples/directchat.nim index 8e5d06658..472fc072a 100644 --- a/examples/directchat.nim +++ b/examples/directchat.nim @@ -175,7 +175,7 @@ proc processInput(rfd: AsyncFD) {.async.} = result = newMplex(conn) let mplexProvider = newMuxerProvider(createMplex, MplexCodec) - let transports = @[Transport(newTransport(TcpTransport))] + let transports = @[Transport(TcpTransport.init())] let muxers = [(MplexCodec, mplexProvider)].toTable() let identify = newIdentify(peerInfo) let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable() diff --git a/libp2p/standard_setup.nim b/libp2p/standard_setup.nim index 2709ad3d7..2d53b1f18 100644 --- a/libp2p/standard_setup.nim +++ b/libp2p/standard_setup.nim @@ -5,7 +5,7 @@ const libp2p_pubsub_verify {.booldefine.} = true import - options, tables, + options, tables, chronos, switch, peer, peerinfo, connection, multiaddress, crypto/crypto, transports/[transport, tcptransport], muxers/[muxer, mplex/mplex, mplex/types], @@ -26,7 +26,7 @@ proc newStandardSwitch*(privKey = none(PrivateKey), gossip = false, verifySignature = libp2p_pubsub_verify, sign = libp2p_pubsub_sign, - transportFlags: TransportFlags = {}): Switch = + transportFlags: set[ServerFlags] = {}): Switch = proc createMplex(conn: Connection): Muxer = result = newMplex(conn) @@ -34,7 +34,7 @@ proc newStandardSwitch*(privKey = none(PrivateKey), seckey = privKey.get(otherwise = PrivateKey.random(ECDSA).tryGet()) peerInfo = PeerInfo.init(seckey, [address]) mplexProvider = newMuxerProvider(createMplex, MplexCodec) - transports = @[Transport(newTransport(TcpTransport, transportFlags))] + transports = @[Transport(TcpTransport.init(transportFlags))] muxers = {MplexCodec: mplexProvider}.toTable identify = newIdentify(peerInfo) when libp2p_secure == "noise": diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index ccab6d658..06a55e616 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -26,6 +26,7 @@ type TcpTransport* = ref object of Transport server*: StreamServer clients: seq[StreamTransport] + flags: set[ServerFlags] cleanups*: seq[Future[void]] handlers*: seq[Future[void]] @@ -91,7 +92,11 @@ proc connCb(server: StreamServer, # shouldn't happen but.. warn "Error closing connection", err = err.msg -method init*(t: TcpTransport) = +proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T = + result = T(flags: flags) + result.initTransport() + +method initTransport*(t: TcpTransport) = t.multicodec = multiCodec("tcp") inc getTcpTransportTracker().opened @@ -134,7 +139,7 @@ method listen*(t: TcpTransport, discard await procCall Transport(t).listen(ma, handler) # call base ## listen on the transport - t.server = createStreamServer(t.ma, connCb, transportFlagsToServerFlags(t.flags), t) + t.server = createStreamServer(t.ma, connCb, t.flags, t) t.server.start() # always get the resolved address in case we're bound to 0.0.0.0:0 diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 68f7f85ae..09eaa1880 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -17,34 +17,15 @@ import ../connection, type ConnHandler* = proc (conn: Connection): Future[void] {.gcsafe.} - TransportFlag* {.pure.} = enum - ReuseAddr - - TransportFlags* = set[TransportFlag] - Transport* = ref object of RootObj ma*: Multiaddress handler*: ConnHandler multicodec*: MultiCodec - flags*: TransportFlags -proc transportFlagsToServerFlags*(flags: TransportFlags): set[ServerFlags] {.gcsafe.} = - let transportFlagToServerFlagMapping = { - TransportFlag.ReuseAddr: ServerFlags.ReuseAddr, - }.toTable() - - for flag in flags: - result.incl(transportFlagToServerFlagMapping[flag]) - -method init*(t: Transport) {.base, gcsafe.} = +method initTransport*(t: Transport) {.base, gcsafe, locks: "unknown".} = ## perform protocol initialization discard -proc newTransport*(t: typedesc[Transport], flags: TransportFlags = {}): t {.gcsafe.} = - new result - result.flags = flags - result.init() - method close*(t: Transport) {.base, async, gcsafe.} = ## stop and cleanup the transport ## including all outstanding connections diff --git a/tests/testidentify.nim b/tests/testidentify.nim index 951a5272d..f4cf0378b 100644 --- a/tests/testidentify.nim +++ b/tests/testidentify.nim @@ -34,11 +34,11 @@ suite "Identify": proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) - var transport1 = newTransport(TcpTransport) + var transport1 = TcpTransport.init() serverFut = await transport1.listen(ma, connHandler) let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) var peerInfo = PeerInfo.init(PrivateKey.random(RSA).get(), [ma]) @@ -78,11 +78,11 @@ suite "Identify": await conn.close() done.complete() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) var localPeerInfo = PeerInfo.init(PrivateKey.random(RSA).get(), [ma]) diff --git a/tests/testinterop.nim b/tests/testinterop.nim index 137bcea32..9f04a1db6 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -71,7 +71,7 @@ proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey), var peerInfo = NativePeerInfo.init(seckey.get(), [Multiaddress.init(address)]) proc createMplex(conn: Connection): Muxer = newMplex(conn) let mplexProvider = newMuxerProvider(createMplex, MplexCodec) - let transports = @[Transport(newTransport(TcpTransport))] + let transports = @[Transport(TcpTransport.init())] let muxers = [(MplexCodec, mplexProvider)].toTable() let identify = newIdentify(peerInfo) let secureManagers = [(SecioCodec, Secure(newSecio(seckey.get())))].toTable() diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 6641ff33a..ad13e7cff 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -229,10 +229,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -267,10 +267,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -312,10 +312,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -352,10 +352,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -393,10 +393,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1 = newTransport(TcpTransport) + let transport1 = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -436,10 +436,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -477,10 +477,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) @@ -545,10 +545,10 @@ suite "Mplex": await mplexListen.handle() await mplexListen.close() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() let listenFut = await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let mplexDial = newMplex(conn) diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 38c4caf26..bbee017b6 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -260,11 +260,11 @@ suite "Multistream select": await conn.close() handlerWait2.complete() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) check (await msDial.select(conn, "/test/proto/1.0.0")) == true @@ -304,7 +304,7 @@ suite "Multistream select": msListen.addHandler("/test/proto1/1.0.0", protocol) msListen.addHandler("/test/proto2/1.0.0", protocol) - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) handlerWait.complete() @@ -312,7 +312,7 @@ suite "Multistream select": asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) let ls = await msDial.list(conn) @@ -348,11 +348,11 @@ suite "Multistream select": proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) check (await msDial.select(conn, @@ -388,11 +388,11 @@ suite "Multistream select": proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() asyncCheck transport1.listen(ma, connHandler) let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0" diff --git a/tests/testnoise.nim b/tests/testnoise.nim index 937cd5371..59b9e4eab 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -58,7 +58,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) = result = newMplex(conn) let mplexProvider = newMuxerProvider(createMplex, MplexCodec) - let transports = @[Transport(newTransport(TcpTransport))] + let transports = @[Transport(TcpTransport.init())] let muxers = [(MplexCodec, mplexProvider)].toTable() let secureManagers = [(NoiseCodec, Secure(newNoise(peerInfo.privateKey, outgoing = outgoing)))].toTable() let switch = newSwitch(peerInfo, @@ -88,11 +88,11 @@ suite "Noise": await sconn.write(cstring("Hello!"), 6) let - transport1: TcpTransport = newTransport(TcpTransport) + transport1: TcpTransport = TcpTransport.init() asyncCheck await transport1.listen(server, connHandler) let - transport2: TcpTransport = newTransport(TcpTransport) + transport2: TcpTransport = TcpTransport.init() clientInfo = PeerInfo.init(PrivateKey.random(RSA).get(), [transport1.ma]) clientNoise = newNoise(clientInfo.privateKey, outgoing = true) conn = await transport2.dial(transport1.ma) @@ -130,11 +130,11 @@ suite "Noise": readTask.complete() let - transport1: TcpTransport = newTransport(TcpTransport) + transport1: TcpTransport = TcpTransport.init() asyncCheck await transport1.listen(server, connHandler) let - transport2: TcpTransport = newTransport(TcpTransport) + transport2: TcpTransport = TcpTransport.init() clientInfo = PeerInfo.init(PrivateKey.random(RSA).get(), [transport1.ma]) clientNoise = newNoise(clientInfo.privateKey, outgoing = true) conn = await transport2.dial(transport1.ma) @@ -173,11 +173,11 @@ suite "Noise": readTask.complete() let - transport1: TcpTransport = newTransport(TcpTransport) + transport1: TcpTransport = TcpTransport.init() asyncCheck await transport1.listen(server, connHandler) let - transport2: TcpTransport = newTransport(TcpTransport) + transport2: TcpTransport = TcpTransport.init() clientInfo = PeerInfo.init(PrivateKey.random(RSA).get(), [transport1.ma]) clientNoise = newNoise(clientInfo.privateKey, outgoing = true) conn = await transport2.dial(transport1.ma) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index caa4d7ebe..31cf31c34 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -39,7 +39,7 @@ proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = result = newMplex(conn) let mplexProvider = newMuxerProvider(createMplex, MplexCodec) - let transports = @[Transport(newTransport(TcpTransport))] + let transports = @[Transport(TcpTransport.init())] let muxers = [(MplexCodec, mplexProvider)].toTable() let secureManagers = [(SecioCodec, Secure(newSecio(peerInfo.privateKey)))].toTable() let switch = newSwitch(peerInfo, @@ -174,11 +174,11 @@ suite "Switch": # readTask.complete() # let - # transport1: TcpTransport = newTransport(TcpTransport) + # transport1: TcpTransport = TcpTransport.init() # asyncCheck await transport1.listen(server, connHandler) # let - # transport2: TcpTransport = newTransport(TcpTransport) + # transport2: TcpTransport = TcpTransport.init() # clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) # clientNoise = newSecio(clientInfo.privateKey) # conn = await transport2.dial(transport1.ma) diff --git a/tests/testtransport.nim b/tests/testtransport.nim index 72376a712..2cdb5d6d6 100644 --- a/tests/testtransport.nim +++ b/tests/testtransport.nim @@ -23,7 +23,7 @@ suite "TCP transport": await conn.close() handlerWait.complete() - let transport: TcpTransport = newTransport(TcpTransport) + let transport: TcpTransport = TcpTransport.init() asyncCheck transport.listen(ma, connHandler) @@ -51,7 +51,7 @@ suite "TCP transport": await conn.close() handlerWait.complete() - let transport: TcpTransport = newTransport(TcpTransport) + let transport: TcpTransport = TcpTransport.init() asyncCheck await transport.listen(ma, connHandler) let streamTransport: StreamTransport = await connect(transport.ma) let sent = await streamTransport.write("Hello!", 6) @@ -83,7 +83,7 @@ suite "TCP transport": server.start() let ma: MultiAddress = MultiAddress.init(server.sock.getLocalAddress()) - let transport: TcpTransport = newTransport(TcpTransport) + let transport: TcpTransport = TcpTransport.init() let conn = await transport.dial(ma) var msg = newSeq[byte](6) await conn.readExactly(addr msg[0], 6) @@ -120,7 +120,7 @@ suite "TCP transport": server.start() let ma: MultiAddress = MultiAddress.init(server.sock.getLocalAddress()) - let transport: TcpTransport = newTransport(TcpTransport) + let transport: TcpTransport = TcpTransport.init() let conn = await transport.dial(ma) await conn.write(cstring("Hello!"), 6) result = true @@ -145,10 +145,10 @@ suite "TCP transport": await conn.close() handlerWait.complete() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() asyncCheck transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) var msg = newSeq[byte](6) await conn.readExactly(addr msg[0], 6) @@ -175,10 +175,10 @@ suite "TCP transport": await conn.close() handlerWait.complete() - let transport1: TcpTransport = newTransport(TcpTransport) + let transport1: TcpTransport = TcpTransport.init() asyncCheck transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) + let transport2: TcpTransport = TcpTransport.init() let conn = await transport2.dial(transport1.ma) await conn.write(cstring("Hello!"), 6)