diff --git a/codexvalidator/network.nim b/codexvalidator/network.nim index 0648021..b07c282 100644 --- a/codexvalidator/network.nim +++ b/codexvalidator/network.nim @@ -1,7 +1,9 @@ import ./network/address import ./network/server import ./network/connection +import ./network/error export address export server export connection +export error.NetworkError diff --git a/codexvalidator/network/connection.nim b/codexvalidator/network/connection.nim index 10f0e16..b712f67 100644 --- a/codexvalidator/network/connection.nim +++ b/codexvalidator/network/connection.nim @@ -1,14 +1,38 @@ import ../basics import ./address +import ./error type NetworkConnection* = distinct StreamTransport proc connect*( _: type NetworkConnection, address: NetworkAddress -): Future[?!NetworkConnection] {.async:(raises:[]).} = - NetworkConnection(await TransportAddress(address).connect()).catch() +): Future[NetworkConnection] {.async:(raises:[NetworkError, CancelledError]).} = + convertNetworkErrors: + NetworkConnection(await TransportAddress(address).connect()) + +proc sendPacket*(connection: NetworkConnection, packet: seq[byte]) {. + async:(raises:[NetworkError, CancelledError]) +.} = + convertNetworkErrors: + let transport = StreamTransport(connection) + let header = @[packet.len.uint32] + discard await transport.write(header) + if packet.len > 0: + discard await transport.write(packet) + +proc receivePacket*(connection: NetworkConnection): Future[?seq[byte]] {. + async:(raises:[NetworkError, CancelledError]) +.} = + convertNetworkErrors: + let transport = StreamTransport(connection) + let header = await transport.read(sizeof(uint32)) + if header.len != sizeof(uint32): + return none seq[byte] + let length = (cast[ptr uint32](addr header[0]))[] + if length == 0: + return some seq[byte].default + some await transport.read(length.int) proc close*(connection: NetworkConnection) {.async:(raises:[]).} = - StreamTransport(connection).close() - await noCancel StreamTransport(connection).join() + await StreamTransport(connection).closeWait() diff --git a/codexvalidator/network/error.nim b/codexvalidator/network/error.nim new file mode 100644 index 0000000..2701f56 --- /dev/null +++ b/codexvalidator/network/error.nim @@ -0,0 +1,9 @@ +import ../basics + +type NetworkError* = object of IOError + +template convertNetworkErrors*(body): untyped = + try: + body + except TransportError as error: + raise newException(NetworkError, error.msg, error) diff --git a/codexvalidator/network/server.nim b/codexvalidator/network/server.nim index 5b8c9de..f7cf4cb 100644 --- a/codexvalidator/network/server.nim +++ b/codexvalidator/network/server.nim @@ -1,14 +1,46 @@ import ../basics import ./address +import ./connection +import ./error -type NetworkServer* = distinct StreamServer +type ConnectionQueue = AsyncQueue[NetworkConnection] -proc open*(_: type NetworkServer): Future[?!NetworkServer] {.async:(raises:[]).} = - NetworkServer(createStreamServer(Port(0))).catch() +func new(_: type ConnectionQueue, maxSize: int): ConnectionQueue = + newAsyncQueue[NetworkConnection](maxSize) + +func createStreamCallback(queue: ConnectionQueue): auto = + proc(_: StreamServer, stream: StreamTransport) {.async:(raises:[]).} = + try: + await queue.addLast(NetworkConnection(stream)) + except CancelledError: + discard + +type NetworkServer* = ref object + implementation: StreamServer + incoming: ConnectionQueue + +proc open*(_: type NetworkServer): Future[NetworkServer] {. + async:(raises:[NetworkError]) +.} = + convertNetworkErrors: + let incoming = ConnectionQueue.new(1) + let callback = incoming.createStreamCallback() + let server = createStreamServer(callback, Port(0)) + server.start() + NetworkServer( + implementation: server, + incoming: incoming + ) proc address*(server: NetworkServer): ?!NetworkAddress = - NetworkAddress(StreamServer(server).localAddress()).catch() + catch NetworkAddress(server.implementation.localAddress()) -proc close*(server: NetworkServer) {.async:(raises:[]).} = - StreamServer(server).close() - await noCancel StreamServer(server).join() +proc accept*(server: NetworkServer): Future[NetworkConnection] {. + async:(raises:[CancelledError]) +.} = + await server.incoming.popFirst() + +proc close*(server: NetworkServer) {.async:(raises:[NetworkError]).} = + convertNetworkErrors: + server.implementation.stop() + await server.implementation.closeWait() diff --git a/tests/codexvalidator/testNetwork.nim b/tests/codexvalidator/testNetwork.nim index 899395e..d893260 100644 --- a/tests/codexvalidator/testNetwork.nim +++ b/tests/codexvalidator/testNetwork.nim @@ -1,17 +1,65 @@ import ./basics import codexvalidator/network -suite "Network communication": +suite "Network connections": - test "a connection can be made to a server": - let server = !await NetworkServer.open() - let address = !server.address - let connection = !await NetworkConnection.connect(address) - await connection.close() + test "connections to a server can be made": + let server = await NetworkServer.open() + let outgoing = await NetworkConnection.connect(!server.address) + let incoming = await server.accept() + await outgoing.close() + await incoming.close() await server.close() - test "connect can fail": + test "outgoing connections can fail": let address = !NetworkAddress.init("127.0.0.1:1011") # port reserved by IANA - let connection = await NetworkConnection.connect(address) - check connection.isFailure - check connection.error.msg.contains("Connection refused") + expect NetworkError: + discard await NetworkConnection.connect(address) + +suite "Network packets": + + var server: NetworkServer + + setup: + server = await NetworkServer.open() + + teardown: + await server.close() + + test "packets of bytes can be exchanged over a network connection": + let packet = seq[byte].example + var received: seq[byte] + + proc send {.async.} = + let outgoing = await NetworkConnection.connect(!server.address) + await outgoing.sendPacket(packet) + await outgoing.close() + + proc receive {.async.} = + let incoming = await server.accept() + received = !await incoming.receivePacket() + await incoming.close() + + await allFutures(send(), receive()) + + check received == packet + + test "connection handles multiple packets of different size": + let packets = newSeqWith(100, seq[byte].example) + var received: seq[seq[byte]] + + proc send {.async.} = + let outgoing = await NetworkConnection.connect(!server.address) + for packet in packets: + await outgoing.sendPacket(packet) + await outgoing.close() + + proc receive {.async.} = + let incoming = await server.accept() + while packet =? await incoming.receivePacket(): + received.add(packet) + await incoming.close() + + await allFutures(send(), receive()) + + check received == packets