diff --git a/libp2p/tcptransport.nim b/libp2p/tcptransport.nim index 9b63acf..83f1fe1 100644 --- a/libp2p/tcptransport.nim +++ b/libp2p/tcptransport.nim @@ -23,19 +23,22 @@ method init*(t: TcpTransport) = method close*(t: TcpTransport): Future[void] {.async.} = ## start the transport - await procCall Transport(t).close() # call base close + await procCall Transport(t).close() # call base t.server.stop() await t.server.closeWait() -method listen*(t: TcpTransport): Future[void] {.async.} = +method listen*(t: TcpTransport, ma: MultiAddress, handler: ConnHandler): Future[void] {.async.} = + await procCall Transport(t).listen(ma, handler) # call base + + ## listen on the transport let listenFuture: Future[void] = newFuture[void]() result = listenFuture - ## listen on the transport let server = createStreamServer(t.ma, connCb, {}, t) t.server = server server.start() + listenFuture.complete() method dial*(t: TcpTransport, address: MultiAddress): Future[Connection] {.async.} = ## dial a peer diff --git a/libp2p/transport.nim b/libp2p/transport.nim index 68c44fb..8dffc6e 100644 --- a/libp2p/transport.nim +++ b/libp2p/transport.nim @@ -24,8 +24,8 @@ type multicodec*: MultiCodec method connHandler*(t: Transport, - server: StreamServer, - client: StreamTransport): Future[Connection] {.base, gcsafe, async.} = + server: StreamServer, + client: StreamTransport): Future[Connection] {.base, gcsafe, async.} = let conn: Connection = newConnection(server, client) let handlerFut = if t.handler == nil: nil else: t.handler(conn) let connHolder: ConnHolder = ConnHolder(connection: conn, @@ -37,12 +37,8 @@ method init*(t: Transport) {.base.} = ## perform protocol initialization discard -proc newTransport*(t: typedesc[Transport], - ma: MultiAddress, - handler: ConnHandler = nil): t = +proc newTransport*(t: typedesc[Transport]): t = new result - result.ma = ma - result.handler = handler result.init() method close*(t: Transport) {.base, async.} = @@ -52,9 +48,10 @@ method close*(t: Transport) {.base, async.} = if c.connection.isOpen: await c.connection.close() -method listen*(t: Transport) {.base, async.} = +method listen*(t: Transport, ma: MultiAddress, handler: ConnHandler) {.base, async.} = ## listen for incoming connections - discard + t.ma = ma + t.handler = handler method dial*(t: Transport, address: MultiAddress): Future[Connection] {.base, async.} = ## dial a peer diff --git a/tests/testtransport b/tests/testtransport index 72d2209..274a28d 100755 Binary files a/tests/testtransport and b/tests/testtransport differ diff --git a/tests/testtransport.nim b/tests/testtransport.nim index e7fb275..4625a09 100644 --- a/tests/testtransport.nim +++ b/tests/testtransport.nim @@ -11,8 +11,8 @@ suite "TCP transport suite": result = conn.write(cstring("Hello!"), 6) await conn.close() - let transport: TcpTransport = newTransport(TcpTransport, ma, connHandler) - await transport.listen() + let transport: TcpTransport = newTransport(TcpTransport) + await transport.listen(ma, connHandler) let streamTransport: StreamTransport = await connect(ma) let msg = await streamTransport.read(6) await transport.close() @@ -30,8 +30,8 @@ suite "TCP transport suite": let msg = await conn.read(6) check cast[string](msg) == "Hello!" - let transport: TcpTransport = newTransport(TcpTransport, ma, connHandler) - await transport.listen() + let transport: TcpTransport = newTransport(TcpTransport) + await transport.listen(ma, connHandler) let streamTransport: StreamTransport = await connect(ma) let sent = await streamTransport.write("Hello!", 6) result = sent == 6 @@ -55,7 +55,7 @@ suite "TCP transport suite": server.start() let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53337") - let transport: TcpTransport = newTransport(TcpTransport, ma) + let transport: TcpTransport = newTransport(TcpTransport) let conn = await transport.dial(ma) let msg = await conn.read(6) result = cast[string](msg) == "Hello!" @@ -65,29 +65,68 @@ suite "TCP transport suite": await server.join() check waitFor(testDialer(initTAddress("127.0.0.1:53337"))) == true -test "test dialer: handle write": - proc testDialer(address: TransportAddress): Future[bool] {.async.} = - proc serveClient(server: StreamServer, - transp: StreamTransport) {.async.} = - var rstream = newAsyncStreamReader(transp) - let msg = await rstream.read(6) - check cast[string](msg) == "Hello!" + test "test dialer: handle write": + proc testDialer(address: TransportAddress): Future[bool] {.async.} = + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var rstream = newAsyncStreamReader(transp) + let msg = await rstream.read(6) + check cast[string](msg) == "Hello!" + + await rstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53337") + let transport: TcpTransport = newTransport(TcpTransport) + let conn = await transport.dial(ma) + await conn.write(cstring("Hello!"), 6) + result = true - await rstream.closeWait() - await transp.closeWait() server.stop() server.close() + await server.join() + check waitFor(testDialer(initTAddress("127.0.0.1:53337"))) == true - var server = createStreamServer(address, serveClient, {ReuseAddr}) - server.start() + test "test listener - dialer: handle write": + proc testListenerDialer(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53339") + proc connHandler(conn: Connection): Future[void] {.async ,gcsafe.} = + result = conn.write(cstring("Hello!"), 6) + await conn.close() - let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53337") - let transport: TcpTransport = newTransport(TcpTransport, ma) - let conn = await transport.dial(ma) + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) + let msg = await conn.read(6) + await transport1.close() + + result = cast[string](msg) == "Hello!" + + check: + waitFor(testListenerDialer()) == true + + test "test listener - dialer: handle read": + proc testListenerDialer(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53340") + proc connHandler(conn: Connection): Future[void] {.async ,gcsafe.} = + let msg = await conn.read(6) + check cast[string](msg) == "Hello!" + + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) await conn.write(cstring("Hello!"), 6) + await transport1.close() result = true - server.stop() - server.close() - await server.join() - check waitFor(testDialer(initTAddress("127.0.0.1:53337"))) == true + check: + waitFor(testListenerDialer()) == true