diff --git a/libp2p/connection.nim b/libp2p/connection.nim index a979f00..1d5dec2 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -14,19 +14,21 @@ const DefaultReadSize = 1024 type Connection* = ref object of RootObj reader: AsyncStreamReader - writter: AsyncStreamWriter + writer: AsyncStreamWriter server: StreamServer client: StreamTransport + isOpen*: bool proc newConnection*(server: StreamServer, client: StreamTransport): Connection = - ## create a new Connection for the specified async stream reader/writter + ## create a new Connection for the specified async stream reader/writer new result + result.isOpen = false result.server = server result.client = client result.reader = newAsyncStreamReader(client) - result.writter = newAsyncStreamWriter(client) + result.writer = newAsyncStreamWriter(client) method read* (c: Connection, size: int = DefaultReadSize): Future[seq[byte]] {.base, async, gcsafe.} = ## read DefaultReadSize (1024) bytes or `size` bytes if specified @@ -34,12 +36,18 @@ method read* (c: Connection, size: int = DefaultReadSize): Future[seq[byte]] {.b method write* (c: Connection, data: pointer, size: int): Future[void] {.base, async.} = ## write bytes pointed to by `data` up to `size` size - discard c.writter.write(data, size) + discard c.writer.write(data, size) method close* (c: Connection): Future[void] {.base, async.} = ## close connection - ## TODO: figure out how to correctly close the streams and underlying resource - discard + await c.reader.closeWait() + + await c.writer.finish() + await c.writer.closeWait() + + await c.client.closeWait() + c.server.stop() + c.server.close() method getPeerInfo* (c: Connection): Future[PeerInfo] {.base, async.} = ## get up to date peer info diff --git a/libp2p/tcptransport.nim b/libp2p/tcptransport.nim index a514a82..9b63acf 100644 --- a/libp2p/tcptransport.nim +++ b/libp2p/tcptransport.nim @@ -13,25 +13,20 @@ import transport, wire, connection, multiaddress, connection, multicodec type TcpTransport* = ref object of Transport server*: StreamServer -proc connHandler(server: StreamServer, - client: StreamTransport): Future[Connection] {.gcsafe, async.} = - let t: TcpTransport = cast[TcpTransport](server.udata) - let conn: Connection = newConnection(server, client) - let connHolder: ConnHolder = ConnHolder(connection: conn, - connFuture: t.handler(conn)) - t.connections.add(connHolder) - result = conn - proc connCb(server: StreamServer, client: StreamTransport) {.gcsafe, async.} = - discard connHandler(server, client) + let t: Transport = cast[Transport](server.udata) + discard t.connHandler(server, client) method init*(t: TcpTransport) = t.multicodec = multiCodec("tcp") method close*(t: TcpTransport): Future[void] {.async.} = ## start the transport - result = t.server.closeWait() + await procCall Transport(t).close() # call base close + + t.server.stop() + await t.server.closeWait() method listen*(t: TcpTransport): Future[void] {.async.} = let listenFuture: Future[void] = newFuture[void]() @@ -42,8 +37,7 @@ method listen*(t: TcpTransport): Future[void] {.async.} = t.server = server server.start() -method dial*(t: TcpTransport, - address: MultiAddress): Future[Connection] {.async.} = +method dial*(t: TcpTransport, address: MultiAddress): Future[Connection] {.async.} = ## dial a peer let client: StreamTransport = await connect(address) - result = await connHandler(t.server, client) + result = await t.connHandler(t.server, client) diff --git a/libp2p/transport.nim b/libp2p/transport.nim index 14d95a1..68c44fb 100644 --- a/libp2p/transport.nim +++ b/libp2p/transport.nim @@ -23,24 +23,37 @@ type handler*: ConnHandler multicodec*: MultiCodec +method connHandler*(t: Transport, + server: StreamServer, + client: StreamTransport): Future[Connection] {.base, gcsafe, async.} = + let conn: Connection = newConnection(server, client) + let handlerFut = if t.handler == nil: nil else: t.handler(conn) + let connHolder: ConnHolder = ConnHolder(connection: conn, + connFuture: handlerFut) + t.connections.add(connHolder) + result = conn + method init*(t: Transport) {.base.} = ## perform protocol initialization discard proc newTransport*(t: typedesc[Transport], ma: MultiAddress, - handler: ConnHandler): t = + handler: ConnHandler = nil): t = new result result.ma = ma result.handler = handler result.init() method close*(t: Transport) {.base, async.} = - ## start the transport - discard + ## stop and cleanup the transport + ## including all outstanding connections + for c in t.connections: + if c.connection.isOpen: + await c.connection.close() method listen*(t: Transport) {.base, async.} = - ## stop the transport + ## listen for incoming connections discard method dial*(t: Transport, address: MultiAddress): Future[Connection] {.base, async.} = diff --git a/tests/testtransport b/tests/testtransport index 6823eb8..72d2209 100755 Binary files a/tests/testtransport and b/tests/testtransport differ diff --git a/tests/testtransport.nim b/tests/testtransport.nim index 07f2a30..e7fb275 100644 --- a/tests/testtransport.nim +++ b/tests/testtransport.nim @@ -4,17 +4,90 @@ import ../libp2p/connection, ../libp2p/transport, ../libp2p/tcptransport, ../libp2p/multiaddress, ../libp2p/wire suite "TCP transport suite": - test "test listener": + test "test listener: handle write": proc testListener(): Future[bool] {.async.} = let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53335") proc connHandler(conn: Connection): Future[void] {.async ,gcsafe.} = result = conn.write(cstring("Hello!"), 6) + await conn.close() let transport: TcpTransport = newTransport(TcpTransport, ma, connHandler) await transport.listen() let streamTransport: StreamTransport = await connect(ma) let msg = await streamTransport.read(6) + await transport.close() + await streamTransport.closeWait() + result = cast[string](msg) == "Hello!" check: waitFor(testListener()) == true + + test "test listener: handle read": + proc testListener(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53336") + proc connHandler(conn: Connection): Future[void] {.async ,gcsafe.} = + let msg = await conn.read(6) + check cast[string](msg) == "Hello!" + + let transport: TcpTransport = newTransport(TcpTransport, ma, connHandler) + await transport.listen() + let streamTransport: StreamTransport = await connect(ma) + let sent = await streamTransport.write("Hello!", 6) + result = sent == 6 + + check: + waitFor(testListener()) == true + + test "test dialer: handle write": + proc testDialer(address: TransportAddress): Future[bool] {.async.} = + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + await wstream.write("Hello!") + await wstream.finish() + await wstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53337") + let transport: TcpTransport = newTransport(TcpTransport, ma) + let conn = await transport.dial(ma) + let msg = await conn.read(6) + result = cast[string](msg) == "Hello!" + + server.stop() + server.close() + await server.join() + check waitFor(testDialer(initTAddress("127.0.0.1:53337"))) == true + +test "test dialer: handle write": + proc testDialer(address: TransportAddress): Future[bool] {.async.} = + proc serveClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var rstream = newAsyncStreamReader(transp) + let msg = await rstream.read(6) + check cast[string](msg) == "Hello!" + + await rstream.closeWait() + await transp.closeWait() + server.stop() + server.close() + + var server = createStreamServer(address, serveClient, {ReuseAddr}) + server.start() + + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53337") + let transport: TcpTransport = newTransport(TcpTransport, ma) + let conn = await transport.dial(ma) + await conn.write(cstring("Hello!"), 6) + result = true + + server.stop() + server.close() + await server.join() + check waitFor(testDialer(initTAddress("127.0.0.1:53337"))) == true