diff --git a/eth/utp/packets.nim b/eth/utp/packets.nim index 38a20c5..0cdf19a 100644 --- a/eth/utp/packets.nim +++ b/eth/utp/packets.nim @@ -188,3 +188,21 @@ proc dataPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16, bufferSi ) Packet(header: h, payload: payload) + +proc resetPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16): Packet = + let h = PacketHeaderV1( + pType: ST_RESET, + version: protocolVersion, + # data packets always have extension field set to 0 + extension: 0'u8, + connectionId: sndConnectionId, + timestamp: getMonoTimeTimeStamp(), + # TODO for not we are using 0, but this value should be calculated on socket + # level + timestampDiff: 0'u32, + wndSize: 0, + seqNr: seqNr, + ackNr: ackNr + ) + + Packet(header: h, payload: @[]) diff --git a/eth/utp/utp_router.nim b/eth/utp/utp_router.nim index 7fcb266..ba0fd8b 100644 --- a/eth/utp/utp_router.nim +++ b/eth/utp/utp_router.nim @@ -1,5 +1,5 @@ import - std/[tables, options], + std/[tables, options, sugar], chronos, bearssl, chronicles, ../keys, ./utp_socket, @@ -28,6 +28,15 @@ type sendCb*: SendCallback[A] rng*: ref BrHmacDrbgContext +# this should probably be in standard lib, it allows lazy composition of options i.e +# one can write: O1 orElse O2 orElse O3, and chain will be evaluated to first option +# which isSome() +template orElse[A](a: Option[A], b: Option[A]): Option[A] = + if (a.isSome()): + a + else: + b + proc getUtpSocket[A](s: UtpRouter[A], k: UtpSocketKey[A]): Option[UtpSocket[A]] = let s = s.sockets.getOrDefault(k) if s == nil: @@ -43,6 +52,7 @@ iterator allSockets[A](s: UtpRouter[A]): UtpSocket[A] = yield socket proc len*[A](s: UtpRouter[A]): int = + ## returns number of active sockets len(s.sockets) proc registerUtpSocket[A](p: UtpRouter, s: UtpSocket[A]) = @@ -65,15 +75,39 @@ proc new*[A]( rng: rng ) +# There are different possiblites how connection was established, and we need to +# check every case +proc getSocketOnReset[A](r: UtpRouter[A], sender: A, id: uint16): Option[UtpSocket[A]] = + # id is our recv id + let recvKey = UtpSocketKey[A].init(sender, id) + + # id is our send id, and we did nitiate the connection, our recv id is id - 1 + let sendInitKey = UtpSocketKey[A].init(sender, id - 1) + + # id is our send id, and we did not initiate the connection, so our recv id is id + 1 + let sendNoInitKey = UtpSocketKey[A].init(sender, id + 1) + + r.getUtpSocket(recvKey) + .orElse(r.getUtpSocket(sendInitKey).filter(s => s.connectionIdSnd == id)) + .orElse(r.getUtpSocket(sendNoInitKey).filter(s => s.connectionIdSnd == id)) + proc processPacket[A](r: UtpRouter[A], p: Packet, sender: A) {.async.}= notice "Received packet ", packet = p - let socketKey = UtpSocketKey[A].init(sender, p.header.connectionId) - let maybeSocket = r.getUtpSocket(socketKey) case p.header.pType of ST_RESET: - # TODO Properly handle Reset packet, and close socket - notice "Received RESET packet" + let maybeSocket = r.getSocketOnReset(sender, p.header.connectionId) + if maybeSocket.isSome(): + notice "Received rst packet on known connection closing" + let socket = maybeSocket.unsafeGet() + # reference implementation acutally changes the socket state to reset state unless + # user explicitly closed socket before. The only difference between reset and destroy + # state is that socket in destroy state is ultimatly deleted from active connection + # list but socket in reset state lingers there until user of library closes it + # explictly. + socket.close() + else: + notice "Received rst packet for not known connection" of ST_SYN: # Syn packet are special, and we need to add 1 to header connectionId let socketKey = UtpSocketKey[A].init(sender, p.header.connectionId + 1) @@ -100,8 +134,11 @@ proc processPacket[A](r: UtpRouter[A], p: Packet, sender: A) {.async.}= let socket = maybeSocket.unsafeGet() await socket.processPacket(p) else: - # TODO add handling of respondig with reset - notice "Recevied FIN/DATA/ACK on not known socket" + # TODO add keeping track of recently send reset packets and do not send reset + # to peers which we recently send reset to. + notice "Recevied FIN/DATA/ACK on not known socket sending reset" + let rstPacket = resetPacket(randUint16(r.rng[]), p.header.connectionId, p.header.seqNr) + await r.sendCb(sender, encodePacket(rstPacket)) proc processIncomingBytes*[A](r: UtpRouter[A], bytes: seq[byte], sender: A) {.async.} = let dec = decodePacket(bytes) diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index 75de7b6..dbf2b63 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -60,9 +60,9 @@ type socketConfig: SocketConfig # Connection id for packets we receive - connectionIdRcv: uint16 + connectionIdRcv*: uint16 # Connection id for packets we send - connectionIdSnd: uint16 + connectionIdSnd*: uint16 # Sequence number for the next packet to be sent. seqNr: uint16 # All seq number up to this havve been correctly acked by us @@ -254,8 +254,7 @@ proc checkTimeouts(socket: UtpSocket) {.async.} = # client initiated connections, but did not send following data packet in rto # time. TODO this should be configurable if (socket.state == SynRecv): - socket.state = Destroy - socket.closeEvent.fire() + socket.close() return if socket.shouldDisconnectFromFailedRemote(): @@ -264,8 +263,7 @@ proc checkTimeouts(socket: UtpSocket) {.async.} = # but maybe it would be more clean to use result socket.connectionFuture.fail(newException(ConnectionError, "Connection to peer timed out")) - socket.state = Destroy - socket.closeEvent.fire() + socket.close() return let newTimeout = socket.retransmitTimeout * 2 @@ -419,11 +417,18 @@ proc isConnected*(socket: UtpSocket): bool = socket.state == Connected or socket.state == ConnectedFull proc close*(s: UtpSocket) = - # TODO Rething all this when working on FIN and RESET packets and proper handling + # TODO Rething all this when working on FIN packets and proper handling # of resources + s.state = Destroy s.checkTimeoutsLoop.cancel() s.closeEvent.fire() +proc closeWait*(s: UtpSocket) {.async.} = + # TODO Rething all this when working on FIN packets and proper handling + # of resources + s.close() + await allFutures(s.closeCallbacks) + proc setCloseCallback(s: UtpSocket, cb: SocketCloseCallback) {.async.} = ## Set callback which will be called whenever the socket is permanently closed try: diff --git a/tests/utp/all_utp_tests.nim b/tests/utp/all_utp_tests.nim index d1a1feb..392c05a 100644 --- a/tests/utp/all_utp_tests.nim +++ b/tests/utp/all_utp_tests.nim @@ -11,4 +11,5 @@ import ./test_protocol, ./test_discv5_protocol, ./test_buffer, - ./test_utp_socket + ./test_utp_socket, + ./test_utp_router diff --git a/tests/utp/test_utp_router.nim b/tests/utp/test_utp_router.nim new file mode 100644 index 0000000..3ad208b --- /dev/null +++ b/tests/utp/test_utp_router.nim @@ -0,0 +1,228 @@ +# Copyright (c) 2020-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +import + std/hashes, + chronos, bearssl, chronicles, + testutils/unittests, + ./test_utils, + ../../eth/utp/utp_router, + ../../eth/utp/packets, + ../../eth/keys + +proc hash*(x: UtpSocketKey[int]): Hash = + var h = 0 + h = h !& x.remoteAddress.hash + h = h !& x.rcvId.hash + !$h + +procSuite "Utp router unit tests": + let rng = newRng() + let testSender = 1 + let testSender2 = 2 + let testBufferSize = 1024'u32 + + proc registerIncomingSocketCallback(serverSockets: AsyncQueue): AcceptConnectionCallback[int] = + return ( + proc(server: UtpRouter[int], client: UtpSocket[int]): Future[void] = + serverSockets.addLast(client) + ) + + proc testSend(to: int, bytes: seq[byte]): Future[void] = + let f = newFuture[void]() + f.complete() + f + + proc initTestSnd(q: AsyncQueue[(Packet, int)]): SendCallback[int]= + return ( + proc (to: int, bytes: seq[byte]): Future[void] = + let p = decodePacket(bytes).get() + q.addLast((p, to)) + ) + + template connectOutgoing( + r: UtpRouter[int], + remote: int, + pq: AsyncQueue[(Packet, int)], + initialRemoteSeq: uint16): (UtpSocket[int], Packet)= + let connectFuture = router.connectTo(remote) + + let (initialPacket, sender) = await pq.get() + + check: + initialPacket.header.pType == ST_SYN + + let responseAck = ackPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr, testBufferSize) + + await router.processIncomingBytes(encodePacket(responseAck), remote) + + let outgoingSocket = await connectFuture + (outgoingSocket, initialPacket) + + asyncTest "Router should ingnore non utp packets": + let q = newAsyncQueue[UtpSocket[int]]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = testSend + + await router.processIncomingBytes(@[1'u8, 2, 3], testSender) + + check: + router.len() == 0 + q.len() == 0 + + asyncTest "Router should create new incoming socket when receiving not known syn packet": + let q = newAsyncQueue[UtpSocket[int]]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = testSend + let encodedSyn = encodePacket(synPacket(10, 10, 10)) + + await router.processIncomingBytes(encodedSyn, testSender) + + check: + router.len() == 1 + + asyncTest "Router should create new incoming socket when receiving same syn packet from diffrent sender": + let q = newAsyncQueue[UtpSocket[int]]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = testSend + let encodedSyn = encodePacket(synPacket(10, 10, 10)) + + await router.processIncomingBytes(encodedSyn, testSender) + + check: + router.len() == 1 + + await router.processIncomingBytes(encodedSyn, testSender2) + + check: + router.len() == 2 + + asyncTest "Router should ignore duplicated syn packet": + let q = newAsyncQueue[UtpSocket[int]]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = testSend + let encodedSyn = encodePacket(synPacket(10, 10, 10)) + + await router.processIncomingBytes(encodedSyn, testSender) + + check: + router.len() == 1 + + await router.processIncomingBytes(encodedSyn, testSender) + + check: + router.len() == 1 + + asyncTest "Router should clear closed incoming sockets": + let q = newAsyncQueue[UtpSocket[int]]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = testSend + let encodedSyn = encodePacket(synPacket(10, 10, 10)) + + await router.processIncomingBytes(encodedSyn, testSender) + + let socket = await q.get() + + check: + router.len() == 1 + + await socket.closeWait() + + check: + not socket.isConnected() + router.len() == 0 + + asyncTest "Router should connect to out going peer": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = initTestSnd(pq) + + let (outgoingSocket, initialSyn) = router.connectOutgoing(testSender2, pq, 30'u16) + + check: + outgoingSocket.isConnected() + router.len() == 1 + + asyncTest "Router should clear closed outgoing connections": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = initTestSnd(pq) + + let (outgoingSocket, initialSyn) = router.connectOutgoing(testSender2, pq, 30'u16) + + check: + outgoingSocket.isConnected() + router.len() == 1 + + await outgoingSocket.closeWait() + + check: + not outgoingSocket.isConnected() + router.len() == 0 + + asyncTest "Router should respond with Reset when receiving packet for not known connection": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = initTestSnd(pq) + + let sndId = 10'u16 + let dp = dataPacket(10'u16, sndId, 10'u16, 10'u32, @[1'u8]) + + await router.processIncomingBytes(encodePacket(dp), testSender2) + + let (packet, sender) = await pq.get() + check: + packet.header.pType == ST_RESET + packet.header.connectionId == sndId + sender == testSender2 + + asyncTest "Router close incoming connection which receives reset": + let q = newAsyncQueue[UtpSocket[int]]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = testSend + let recvId = 10'u16 + let encodedSyn = encodePacket(synPacket(10, recvId, 10)) + + await router.processIncomingBytes(encodedSyn, testSender) + + check: + router.len() == 1 + + let rstPacket = resetPacket(10, recvId, 10) + + await router.processIncomingBytes(encodePacket(rstPacket), testSender) + + await waitUntil(proc (): bool = router.len() == 0) + + check: + router.len() == 0 + + asyncTest "Router close outgoing connection which receives reset": + let q = newAsyncQueue[UtpSocket[int]]() + let pq = newAsyncQueue[(Packet, int)]() + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + router.sendCb = initTestSnd(pq) + + let (outgoingSocket, initialSyn) = router.connectOutgoing(testSender2, pq, 30'u16) + + check: + router.len() == 1 + + # remote side sendId is syn.header.connectionId + 1 + let rstPacket = resetPacket(10, initialSyn.header.connectionId + 1, 10) + + await router.processIncomingBytes(encodePacket(rstPacket), testSender2) + + await waitUntil(proc (): bool = router.len() == 0) + + check: + router.len() == 0 +