diff --git a/eth/utp/utp.nim b/eth/utp/utp.nim index 3fd4953..0ed768d 100644 --- a/eth/utp/utp.nim +++ b/eth/utp/utp.nim @@ -7,7 +7,7 @@ {.push raises: [Defect].} import - chronos, stew/byteutils, + chronos, stew/[results, byteutils], ./utp_router, ./utp_socket, ./utp_protocol @@ -32,7 +32,8 @@ when isMainModule: let utpProt = UtpProtocol.new(echoIncomingSocketCallBack(), localAddress) let remoteServer = initTAddress("127.0.0.1", 9078) - let soc = waitFor utpProt.connectTo(remoteServer) + let socResult = waitFor utpProt.connectTo(remoteServer) + let soc = socResult.get() doAssert(soc.numPacketsInOutGoingBuffer() == 0) diff --git a/eth/utp/utp_discov5_protocol.nim b/eth/utp/utp_discov5_protocol.nim index 5379144..791e129 100644 --- a/eth/utp/utp_discov5_protocol.nim +++ b/eth/utp/utp_discov5_protocol.nim @@ -50,11 +50,13 @@ proc new*( subProtocolName: seq[byte], acceptConnectionCb: AcceptConnectionCallback[Node], socketConfig: SocketConfig = SocketConfig.init(), + allowConnectionCb: AllowConnectionCallback[Node] = nil, rng = newRng()): UtpDiscv5Protocol {.raises: [Defect, CatchableError].} = doAssert(not(isNil(acceptConnectionCb))) let router = UtpRouter[Node].new( acceptConnectionCb, + allowConnectionCb, socketConfig, rng ) @@ -71,9 +73,12 @@ proc new*( ) prot -proc connectTo*(r: UtpDiscv5Protocol, address: Node): Future[UtpSocket[Node]]= +proc connectTo*(r: UtpDiscv5Protocol, address: Node): Future[ConnectionResult[Node]]= return r.router.connectTo(address) +proc connectTo*(r: UtpDiscv5Protocol, address: Node, connectionId: uint16): Future[ConnectionResult[Node]]= + return r.router.connectTo(address, connectionId) + proc shutdown*(r: UtpDiscv5Protocol) = ## closes all managed utp connections in background (not closed discovery, it is up to user) r.router.shutdown() diff --git a/eth/utp/utp_protocol.nim b/eth/utp/utp_protocol.nim index 9954e40..8ef687e 100644 --- a/eth/utp/utp_protocol.nim +++ b/eth/utp/utp_protocol.nim @@ -77,12 +77,14 @@ proc new*( acceptConnectionCb: AcceptConnectionCallback[TransportAddress], address: TransportAddress, socketConfig: SocketConfig = SocketConfig.init(), + allowConnectionCb: AllowConnectionCallback[TransportAddress] = nil, rng = newRng()): UtpProtocol {.raises: [Defect, CatchableError].} = doAssert(not(isNil(acceptConnectionCb))) let router = UtpRouter[TransportAddress].new( acceptConnectionCb, + allowConnectionCb, socketConfig, rng ) @@ -96,8 +98,11 @@ proc shutdownWait*(p: UtpProtocol): Future[void] {.async.} = await p.utpRouter.shutdownWait() await p.transport.closeWait() -proc connectTo*(r: UtpProtocol, address: TransportAddress): Future[UtpSocket[TransportAddress]] = +proc connectTo*(r: UtpProtocol, address: TransportAddress): Future[ConnectionResult[TransportAddress]] = return r.utpRouter.connectTo(address) +proc connectTo*(r: UtpProtocol, address: TransportAddress, connectionId: uint16): Future[ConnectionResult[TransportAddress]] = + return r.utpRouter.connectTo(address, connectionId) + proc openSockets*(r: UtpProtocol): int = len(r.utpRouter) diff --git a/eth/utp/utp_router.nim b/eth/utp/utp_router.nim index 99ef186..e0a548d 100644 --- a/eth/utp/utp_router.nim +++ b/eth/utp/utp_router.nim @@ -17,6 +17,11 @@ type AcceptConnectionCallback*[A] = proc(server: UtpRouter[A], client: UtpSocket[A]): Future[void] {.gcsafe, raises: [Defect].} + # Callback to act as fire wall for incoming peers. Should return true if peer is allowed + # to connect. + AllowConnectionCallback*[A] = + proc(r: UtpRouter[A], remoteAddress: A, connectionId: uint16): bool {.gcsafe, raises: [Defect], noSideEffect.} + # Oject responsible for creating and maintaing table of of utp sockets. # caller should use `processIncomingBytes` proc to feed it with incoming byte # packets, based this input, proper utp sockets will be created, closed, or will @@ -27,8 +32,14 @@ type acceptConnection: AcceptConnectionCallback[A] closed: bool sendCb*: SendCallback[A] + allowConnection*: AllowConnectionCallback[A] rng*: ref BrHmacDrbgContext +const + # Maximal number of tries to genearte unique socket while establishing outgoing + # connection. + maxSocketGenerationTries = 1000 + # this should probably be in standard lib, it allows lazy composition of options i.e # one can write: O1 orElse O2 orElse O3, and chain will be evaluated to first option # which isSome() @@ -57,25 +68,43 @@ proc len*[A](s: UtpRouter[A]): int = len(s.sockets) proc registerUtpSocket[A](p: UtpRouter, s: UtpSocket[A]) = - # TODO Handle duplicates + ## Register socket, overwriting already existing one p.sockets[s.socketKey] = s # Install deregister handler, so when socket will get closed, in will be promptly # removed from open sockets table s.registerCloseCallback(proc () = p.deRegisterUtpSocket(s)) +proc registerIfAbsent[A](p: UtpRouter, s: UtpSocket[A]): bool = + ## Registers socket only if its not already exsiting in the active sockets table + ## return true is socket has been succesfuly registered + if p.sockets.hasKey(s.socketKey): + false + else: + p.registerUtpSocket(s) + true + proc new*[A]( T: type UtpRouter[A], - acceptConnectionCb: AcceptConnectionCallback[A], + acceptConnectionCb: AcceptConnectionCallback[A], + allowConnectionCb: AllowConnectionCallback[A], socketConfig: SocketConfig = SocketConfig.init(), rng = newRng()): UtpRouter[A] {.raises: [Defect, CatchableError].} = doAssert(not(isNil(acceptConnectionCb))) UtpRouter[A]( sockets: initTable[UtpSocketKey[A], UtpSocket[A]](), acceptConnection: acceptConnectionCb, + allowConnection: allowConnectionCb, socketConfig: socketConfig, rng: rng ) +proc new*[A]( + T: type UtpRouter[A], + acceptConnectionCb: AcceptConnectionCallback[A], + socketConfig: SocketConfig = SocketConfig.init(), + rng = newRng()): UtpRouter[A] {.raises: [Defect, CatchableError].} = + UtpRouter[A].new(acceptConnectionCb, nil, socketConfig, rng) + # There are different possiblites how connection was established, and we need to # check every case proc getSocketOnReset[A](r: UtpRouter[A], sender: A, id: uint16): Option[UtpSocket[A]] = @@ -92,6 +121,13 @@ proc getSocketOnReset[A](r: UtpRouter[A], sender: A, id: uint16): Option[UtpSock .orElse(r.getUtpSocket(sendInitKey).filter(s => s.connectionIdSnd == id)) .orElse(r.getUtpSocket(sendNoInitKey).filter(s => s.connectionIdSnd == id)) +proc shouldAllowConnection[A](r: UtpRouter[A], remoteAddress: A, connectionId: uint16): bool = + if r.allowConnection == nil: + # if the callback is not configured it means all incoming connections are allowed + true + else: + r.allowConnection(r, remoteAddress, connectionId) + proc processPacket[A](r: UtpRouter[A], p: Packet, sender: A) {.async.}= notice "Received packet ", packet = p @@ -116,18 +152,21 @@ proc processPacket[A](r: UtpRouter[A], p: Packet, sender: A) {.async.}= if (maybeSocket.isSome()): notice "Ignoring SYN for already existing connection" else: - notice "Received SYN for not known connection. Initiating incoming connection" - # Initial ackNr is set to incoming packer seqNr - let incomingSocket = initIncomingSocket[A](sender, r.sendCb, r.socketConfig ,p.header.connectionId, p.header.seqNr, r.rng[]) - r.registerUtpSocket(incomingSocket) - await incomingSocket.startIncomingSocket() - # TODO By default (when we have utp over udp) socket here is passed to upper layer - # in SynRecv state, which is not writeable i.e user of socket cannot write - # data to it unless some data will be received. This is counter measure to - # amplification attacks. - # During integration with discovery v5 (i.e utp over discovv5), we must re-think - # this. - asyncSpawn r.acceptConnection(r, incomingSocket) + if (r.shouldAllowConnection(sender, p.header.connectionId)): + notice "Received SYN for not known connection. Initiating incoming connection" + # Initial ackNr is set to incoming packer seqNr + let incomingSocket = initIncomingSocket[A](sender, r.sendCb, r.socketConfig ,p.header.connectionId, p.header.seqNr, r.rng[]) + r.registerUtpSocket(incomingSocket) + await incomingSocket.startIncomingSocket() + # TODO By default (when we have utp over udp) socket here is passed to upper layer + # in SynRecv state, which is not writeable i.e user of socket cannot write + # data to it unless some data will be received. This is counter measure to + # amplification attacks. + # During integration with discovery v5 (i.e utp over discovv5), we must re-think + # this. + asyncSpawn r.acceptConnection(r, incomingSocket) + else: + notice "Connection declined" else: let socketKey = UtpSocketKey[A].init(sender, p.header.connectionId) let maybeSocket = r.getUtpSocket(socketKey) @@ -149,14 +188,60 @@ proc processIncomingBytes*[A](r: UtpRouter[A], bytes: seq[byte], sender: A) {.as else: warn "failed to decode packet from address", address = sender +proc generateNewUniqueSocket[A](r: UtpRouter[A], address: A): Option[UtpSocket[A]] = + ## Tries to generate unique socket, gives up after maxSocketGenerationTries tries + var tryCount = 0 + + while tryCount < maxSocketGenerationTries: + let rcvId = randUint16(r.rng[]) + let socket = initOutgoingSocket[A](address, r.sendCb, r.socketConfig, rcvId, r.rng[]) + + if r.registerIfAbsent(socket): + return some(socket) + + inc tryCount + + return none[UtpSocket[A]]() + +proc connect[A](s: UtpSocket[A]): Future[ConnectionResult[A]] {.async.}= + let startFut = s.startOutgoingSocket() + + startFut.cancelCallback = proc(udata: pointer) {.gcsafe.} = + # if for some reason future will be cancelled, destory socket to clear it from + # active socket list + s.destroy() + + try: + await startFut + return ok(s) + except ConnectionError: + s.destroy() + return err(OutgoingConnectionError(kind: ConnectionTimedOut)) + except CatchableError as e: + s.destroy() + # this may only happen if user provided callback will for some reason fail + return err(OutgoingConnectionError(kind: ErrorWhileSendingSyn, error: e)) + # Connect to provided address # Reference implementation: https://github.com/bittorrent/libutp/blob/master/utp_internal.cpp#L2732 -proc connectTo*[A](r: UtpRouter[A], address: A): Future[UtpSocket[A]] {.async.}= - let socket = initOutgoingSocket[A](address, r.sendCb, r.socketConfig, r.rng[]) - r.registerUtpSocket(socket) - await socket.startOutgoingSocket() - await socket.waitFotSocketToConnect() - return socket +proc connectTo*[A](r: UtpRouter[A], address: A): Future[ConnectionResult[A]] {.async.} = + let maybeSocket = r.generateNewUniqueSocket(address) + + if (maybeSocket.isNone()): + return err(OutgoingConnectionError(kind: SocketAlreadyExists)) + else: + let socket = maybeSocket.unsafeGet() + return await socket.connect() + +# Connect to provided address with provided connection id, if socket with this id +# and address already exsits return error +proc connectTo*[A](r: UtpRouter[A], address: A, connectionId: uint16): Future[ConnectionResult[A]] {.async.} = + let socket = initOutgoingSocket[A](address, r.sendCb, r.socketConfig, connectionId, r.rng[]) + + if (r.registerIfAbsent(socket)): + return await socket.connect() + else: + return err(OutgoingConnectionError(kind: SocketAlreadyExists)) proc shutdown*[A](r: UtpRouter[A]) = # stop processing any new packets and close all sockets in background without diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index 94b243e..8829bfb 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -161,6 +161,18 @@ type WriteResult* = Result[int, WriteError] + OutgoingConnectionErrorType* = enum + SocketAlreadyExists, ConnectionTimedOut, ErrorWhileSendingSyn + + OutgoingConnectionError* = object + case kind*: OutgoingConnectionErrorType + of ErrorWhileSendingSyn: + error*: ref CatchableError + of SocketAlreadyExists, ConnectionTimedOut: + discard + + ConnectionResult*[A] = Result[UtpSocket[A], OutgoingConnectionError] + const # Maximal number of payload bytes per packet. Total packet size will be equal to # mtuSize + sizeof(header) = 600 bytes @@ -258,16 +270,6 @@ proc sendAck(socket: UtpSocket): Future[void] = ) socket.sendData(encodePacket(ackPacket)) -proc sendSyn(socket: UtpSocket): Future[void] = - doAssert(socket.state == SynSent , "syn can only be send when in SynSent state") - let packet = synPacket(socket.seqNr, socket.connectionIdRcv, socket.getRcvWindowSize()) - notice "Sending syn packet packet", packet = packet - # set number of transmissions to 1 as syn packet will be send just after - # initiliazation - let outgoingPacket = OutgoingPacket.init(encodePacket(packet), 1, false) - socket.registerOutgoingPacket(outgoingPacket) - socket.sendData(outgoingPacket.packetBytes) - # Should be called before sending packet proc setSend(p: var OutgoingPacket): seq[byte] = inc p.transmissions @@ -423,10 +425,9 @@ proc initOutgoingSocket*[A]( to: A, snd: SendCallback[A], cfg: SocketConfig, + rcvConnectionId: uint16, rng: var BrHmacDrbgContext ): UtpSocket[A] = - # TODO handle possible clashes and overflows - let rcvConnectionId = randUint16(rng) let sndConnectionId = rcvConnectionId + 1 let initialSeqNr = randUint16(rng) @@ -467,18 +468,19 @@ proc initIncomingSocket*[A]( proc startOutgoingSocket*(socket: UtpSocket): Future[void] {.async.} = doAssert(socket.state == SynSent) - # TODO add callback to handle errors and cancellation i.e unregister socket on - # send error and finish connection future with failure - # sending should be done from UtpSocketContext - await socket.sendSyn() + let packet = synPacket(socket.seqNr, socket.connectionIdRcv, socket.getRcvWindowSize()) + notice "Sending syn packet packet", packet = packet + # set number of transmissions to 1 as syn packet will be send just after + # initiliazation + let outgoingPacket = OutgoingPacket.init(encodePacket(packet), 1, false) + socket.registerOutgoingPacket(outgoingPacket) socket.startTimeoutLoop() - -proc waitFotSocketToConnect*(socket: UtpSocket): Future[void] {.async.} = + await socket.sendData(outgoingPacket.packetBytes) await socket.connectionFuture proc startIncomingSocket*(socket: UtpSocket) {.async.} = doAssert(socket.state == SynRecv) - # Make sure ack was flushed before movig forward + # Make sure ack was flushed before moving forward await socket.sendAck() socket.startTimeoutLoop() @@ -928,3 +930,13 @@ proc numPacketsInReordedBuffer*(socket: UtpSocket): int = inc num doAssert(num == int(socket.reorderCount)) num + +proc connectionId*[A](socket: UtpSocket[A]): uint16 = + ## Connection id is id which is used in first SYN packet which establishes the connection + ## so for Outgoing side it is actually its rcv_id, and for Incoming side it is + ## its snd_id + case socket.direction + of Incoming: + socket.connectionIdSnd + of Outgoing: + socket.connectionIdRcv diff --git a/tests/utp/test_discv5_protocol.nim b/tests/utp/test_discv5_protocol.nim index 454f2b4..1b2f477 100644 --- a/tests/utp/test_discv5_protocol.nim +++ b/tests/utp/test_discv5_protocol.nim @@ -55,6 +55,12 @@ procSuite "Utp protocol over discovery v5 tests": serverSockets.addLast(client) ) + proc allowOneIdCallback(allowedId: uint16): AllowConnectionCallback[Node] = + return ( + proc(r: UtpRouter[Node], remoteAddress: Node, connectionId: uint16): bool = + connectionId == allowedId + ) + # TODO Add more tests to discovery v5 suite, especially those which will differ # from standard utp case asyncTest "Success connect to remote host": @@ -73,8 +79,9 @@ procSuite "Utp protocol over discovery v5 tests": node1.addNode(node2.localNode) node2.addNode(node1.localNode) - let clientSocket = await utp1.connectTo(node2.localNode) - + let clientSocketResult = await utp1.connectTo(node2.localNode) + let clientSocket = clientSocketResult.get() + check: clientSocket.isConnected() @@ -99,7 +106,9 @@ procSuite "Utp protocol over discovery v5 tests": node2.addNode(node1.localNode) let numOfBytes = 5000 - let clientSocket = await utp1.connectTo(node2.localNode) + let clientSocketResult = await utp1.connectTo(node2.localNode) + let clientSocket = clientSocketResult.get() + let serverSocket = await queue.get() let bytesToTransfer = generateByteArray(rng[], numOfBytes) @@ -117,3 +126,48 @@ procSuite "Utp protocol over discovery v5 tests": await serverSocket.destroyWait() await node1.closeWait() await node2.closeWait() + + asyncTest "Accept connection only from allowed peers": + let + allowedId: uint16 = 10 + lowSynTimeout = milliseconds(500) + queue = newAsyncQueue[UtpSocket[Node]]() + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + + utp1 = UtpDiscv5Protocol.new( + node1, + utpProtId, + registerIncomingSocketCallback(queue), + SocketConfig.init(lowSynTimeout)) + utp2 = + UtpDiscv5Protocol.new( + node2, + utpProtId, + registerIncomingSocketCallback(queue), + SocketConfig.init(), + allowOneIdCallback(allowedId)) + + # nodes must know about each other + check: + node1.addNode(node2.localNode) + node2.addNode(node1.localNode) + + let clientSocketResult1 = await utp1.connectTo(node2.localNode, allowedId) + let clientSocketResult2 = await utp1.connectTo(node2.localNode, allowedId + 1) + + check: + clientSocketResult1.isOk() + clientSocketResult2.isErr() + + let clientSocket = clientSocketResult1.get() + let serverSocket = await queue.get() + + check: + clientSocket.connectionId() == allowedId + serverSocket.connectionId() == allowedId + + await node1.closeWait() + await node2.closeWait() diff --git a/tests/utp/test_protocol.nim b/tests/utp/test_protocol.nim index 96a8c2d..1c24641 100644 --- a/tests/utp/test_protocol.nim +++ b/tests/utp/test_protocol.nim @@ -30,6 +30,12 @@ proc registerIncomingSocketCallback(serverSockets: AsyncQueue): AcceptConnection serverSockets.addLast(client) ) +proc allowOneIdCallback(allowedId: uint16): AllowConnectionCallback[TransportAddress] = + return ( + proc(r: UtpRouter[TransportAddress], remoteAddress: TransportAddress, connectionId: uint16): bool = + connectionId == allowedId + ) + proc transferData(sender: UtpSocket[TransportAddress], receiver: UtpSocket[TransportAddress], data: seq[byte]): Future[seq[byte]] {.async.}= let bytesWritten = await sender.write(data) doAssert bytesWritten.get() == len(data) @@ -67,7 +73,7 @@ proc initClientServerScenario(): Future[ClientServerScenario] {.async.} = return ClientServerScenario( utp1: utpProt1, utp2: utpProt2, - clientSocket: clientSocket, + clientSocket: clientSocket.get(), serverSocket: serverSocket ) @@ -102,8 +108,8 @@ proc init2ClientsServerScenario(): Future[TwoClientsServerScenario] {.async.} = utp1: utpProt1, utp2: utpProt2, utp3: utpProt3, - clientSocket1: clientSocket1, - clientSocket2: clientSocket2, + clientSocket1: clientSocket1.get(), + clientSocket2: clientSocket2.get(), serverSocket1: serverSocket1, serverSocket2: serverSocket2 ) @@ -125,8 +131,8 @@ procSuite "Utp protocol over udp tests": let address1 = initTAddress("127.0.0.1", 9080) let utpProt2 = UtpProtocol.new(setAcceptedCallback(server2Called), address1) - let sock = await utpProt1.connectTo(address1) - + let sockResult = await utpProt1.connectTo(address1) + let sock = sockResult.get() # this future will be completed when we called accepted connection callback await server2Called.wait() @@ -148,13 +154,16 @@ procSuite "Utp protocol over udp tests": let address1 = initTAddress("127.0.0.1", 9080) - let fut = utpProt1.connectTo(address1) - - yield fut + let connectionResult = await utpProt1.connectTo(address1) check: - fut.failed() - + connectionResult.isErr() + + let connectionError = connectionResult.error() + + check: + connectionError.kind == ConnectionTimedOut + await waitUntil(proc (): bool = utpProt1.openSockets() == 0) check: @@ -370,3 +379,45 @@ procSuite "Utp protocol over udp tests": s.utp1.openSockets() == 0 await s.close() + + asyncTest "Accept connection only from allowed peers": + let allowedId: uint16 = 10 + let lowSynTimeout = milliseconds(500) + var serverSockets = newAsyncQueue[UtpSocket[TransportAddress]]() + var server1Called = newAsyncEvent() + let address1 = initTAddress("127.0.0.1", 9079) + let utpProt1 = + UtpProtocol.new(setAcceptedCallback(server1Called), address1, SocketConfig.init(lowSynTimeout)) + + let address2 = initTAddress("127.0.0.1", 9080) + let utpProt2 = + UtpProtocol.new(registerIncomingSocketCallback(serverSockets), address2, SocketConfig.init(lowSynTimeout)) + + let address3 = initTAddress("127.0.0.1", 9081) + let utpProt3 = + UtpProtocol.new( + registerIncomingSocketCallback(serverSockets), + address3, + SocketConfig.init(), + allowOneIdCallback(allowedId) + ) + + let allowedSocketRes = await utpProt1.connectTo(address3, allowedId) + let notAllowedSocketRes = await utpProt2.connectTo(address3, allowedId + 1) + + check: + allowedSocketRes.isOk() + notAllowedSocketRes.isErr() + # remote did not allow this connection, and utlimatly it did time out + notAllowedSocketRes.error().kind == ConnectionTimedOut + + let clientSocket = allowedSocketRes.get() + let serverSocket = await serverSockets.get() + + check: + clientSocket.connectionId() == allowedId + serverSocket.connectionId() == allowedId + + await utpProt1.shutdownWait() + await utpProt2.shutdownWait() + await utpProt3.shutdownWait() diff --git a/tests/utp/test_utp_router.nim b/tests/utp/test_utp_router.nim index f43fc0b..905e25b 100644 --- a/tests/utp/test_utp_router.nim +++ b/tests/utp/test_utp_router.nim @@ -21,6 +21,9 @@ proc hash*(x: UtpSocketKey[int]): Hash = h = h !& x.rcvId.hash !$h +type + TestError* = object of CatchableError + procSuite "Utp router unit tests": let rng = newRng() let testSender = 1 @@ -62,7 +65,7 @@ procSuite "Utp router unit tests": await router.processIncomingBytes(encodePacket(responseAck), remote) let outgoingSocket = await connectFuture - (outgoingSocket, initialPacket) + (outgoingSocket.get(), initialPacket) asyncTest "Router should ingnore non utp packets": let q = newAsyncQueue[UtpSocket[int]]() @@ -149,6 +152,98 @@ procSuite "Utp router unit tests": outgoingSocket.isConnected() router.len() == 1 + asyncTest "Router should fail to connect to the same peer with the same connection id": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let initialRemoteSeq = 30'u16 + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = initTestSnd(pq) + + let requestedConnectionId = 1'u16 + let connectFuture = router.connectTo(testSender2, requestedConnectionId) + + let (initialPacket, sender) = await pq.get() + + check: + initialPacket.header.pType == ST_SYN + # connection id of syn packet should be set to requested connection id + initialPacket.header.connectionId == requestedConnectionId + + let responseAck = ackPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr, testBufferSize) + + await router.processIncomingBytes(encodePacket(responseAck), testSender2) + + let outgoingSocket = await connectFuture + + check: + outgoingSocket.get().isConnected() + router.len() == 1 + + let duplicatedConnectionResult = await router.connectTo(testSender2, requestedConnectionId) + + check: + duplicatedConnectionResult.isErr() + duplicatedConnectionResult.error().kind == SocketAlreadyExists + + asyncTest "Router should fail connect when socket syn will not be acked": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(milliseconds(500)), rng) + router.sendCb = initTestSnd(pq) + + let connectFuture = router.connectTo(testSender2) + + let (initialPacket, sender) = await pq.get() + + check: + initialPacket.header.pType == ST_SYN + + let connectResult = await connectFuture + + check: + connectResult.isErr() + connectResult.error().kind == ConnectionTimedOut + router.len() == 0 + + asyncTest "Router should clear all resources when connection future is cancelled": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(milliseconds(500)), rng) + router.sendCb = initTestSnd(pq) + + let connectFuture = router.connectTo(testSender2) + + let (initialPacket, sender) = await pq.get() + + check: + initialPacket.header.pType == ST_SYN + router.len() == 1 + + await connectFuture.cancelAndWait() + + check: + router.len() == 0 + + asyncTest "Router should clear all resources and handle error while sending syn packet": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(milliseconds(500)), rng) + router.sendCb = + proc (to: int, data: seq[byte]): Future[void] = + let f = newFuture[void]() + f.fail(newException(TestError, "faile")) + return f + + let connectResult = await router.connectTo(testSender2) + + await waitUntil(proc (): bool = router.len() == 0) + + check: + connectResult.isErr() + connectResult.error().kind == ErrorWhileSendingSyn + cast[TestError](connectResult.error().error) is TestError + router.len() == 0 + asyncTest "Router should clear closed outgoing connections": let q = newAsyncQueue[UtpSocket[int]]() let pq = newAsyncQueue[(Packet, int)]() @@ -225,4 +320,3 @@ procSuite "Utp router unit tests": check: router.len() == 0 - diff --git a/tests/utp/test_utp_socket.nim b/tests/utp/test_utp_socket.nim index 49987e1..f3dad99 100644 --- a/tests/utp/test_utp_socket.nim +++ b/tests/utp/test_utp_socket.nim @@ -20,6 +20,7 @@ procSuite "Utp socket unit test": let rng = newRng() let testAddress = initTAddress("127.0.0.1", 9079) let testBufferSize = 1024'u32 + let defaultRcvOutgoingId = 314'u16 proc initTestSnd(q: AsyncQueue[Packet]): SendCallback[TransportAddress]= return ( @@ -61,8 +62,8 @@ procSuite "Utp socket unit test": initialRemoteSeq: uint16, q: AsyncQueue[Packet], cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], Packet) = - let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, rng[]) - await sock1.startOutgoingSocket() + let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[]) + asyncSpawn sock1.startOutgoingSocket() let initialPacket = await q.get() check: @@ -80,18 +81,32 @@ procSuite "Utp socket unit test": asyncTest "Starting outgoing socket should send Syn packet": let q = newAsyncQueue[Packet]() let defaultConfig = SocketConfig.init() - let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), defaultConfig, rng[]) - await sock1.startOutgoingSocket() + let sock1 = initOutgoingSocket[TransportAddress]( + testAddress, + initTestSnd(q), + defaultConfig, + defaultRcvOutgoingId, + rng[] + ) + let fut1 = sock1.startOutgoingSocket() let initialPacket = await q.get() check: initialPacket.header.pType == ST_SYN initialPacket.header.wndSize == defaultConfig.optRcvBuffer + fut1.cancel() + asyncTest "Outgoing socket should re-send syn packet 2 times before declaring failure": let q = newAsyncQueue[Packet]() - let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), SocketConfig.init(milliseconds(100)), rng[]) - await sock1.startOutgoingSocket() + let sock1 = initOutgoingSocket[TransportAddress]( + testAddress, + initTestSnd(q), + SocketConfig.init(milliseconds(100)), + defaultRcvOutgoingId, + rng[] + ) + let fut1 = sock1.startOutgoingSocket() let initialPacket = await q.get() check: @@ -113,6 +128,8 @@ procSuite "Utp socket unit test": check: not sock1.isConnected() + fut1.cancel() + asyncTest "Processing in order ack should make socket connected": let q = newAsyncQueue[Packet]() let initialRemoteSeq = 10'u16 @@ -281,8 +298,16 @@ procSuite "Utp socket unit test": let q = newAsyncQueue[Packet]() let initalRemoteSeqNr = 10'u16 - let outgoingSocket = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), SocketConfig.init(milliseconds(50), 2), rng[]) - await outgoingSocket.startOutgoingSocket() + let outgoingSocket = initOutgoingSocket[TransportAddress]( + testAddress, + initTestSnd(q), + SocketConfig.init(milliseconds(3000), 2), + defaultRcvOutgoingId, + rng[] + ) + + let fut1 = outgoingSocket.startOutgoingSocket() + let initialPacket = await q.get() check: