diff --git a/eth/utp/packets.nim b/eth/utp/packets.nim index fd5118d..34cea2c 100644 --- a/eth/utp/packets.nim +++ b/eth/utp/packets.nim @@ -48,8 +48,8 @@ type # sequence number the sender of the packet last received in the other direction ackNr*: uint16 - SelectiveAckExtension = object - acks: array[acksArrayLength, byte] + SelectiveAckExtension* = object + acks*: array[4, byte] Packet* = object header*: PacketHeaderV1 diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index 0778ed1..b0f7154 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -9,7 +9,7 @@ import std/sugar, chronos, chronicles, bearssl, - stew/results, + stew/[results, bitops2], ./send_buffer_tracker, ./growable_buffer, ./packets, @@ -341,20 +341,6 @@ proc registerOutgoingPacket(socket: UtpSocket, oPacket: OutgoingPacket) = proc sendData(socket: UtpSocket, data: seq[byte]): Future[void] = socket.send(socket.remoteAddress, data) -proc sendAck(socket: UtpSocket): Future[void] = - ## Creates and sends ack, based on current socket state. Acks are different from - ## other packets as we do not track them in outgoing buffet - - let ackPacket = - ackPacket( - socket.seqNr, - socket.connectionIdSnd, - socket.ackNr, - socket.getRcvWindowSize(), - socket.replayMicro - ) - socket.sendData(encodePacket(ackPacket)) - # Should be called before sending packet proc setSend(s: UtpSocket, p: var OutgoingPacket): seq[byte] = let timestampInfo = getMonoTimestamp() @@ -725,12 +711,6 @@ proc startOutgoingSocket*(socket: UtpSocket): Future[void] {.async.} = await socket.sendData(outgoingPacket.packetBytes) await socket.connectionFuture -proc startIncomingSocket*(socket: UtpSocket) {.async.} = - # Make sure ack was flushed before moving forward - await socket.sendAck() - socket.startWriteLoop() - socket.startTimeoutLoop() - proc isConnected*(socket: UtpSocket): bool = socket.state == Connected @@ -881,6 +861,121 @@ proc isAckNrInvalid(socket: UtpSocket, packet: Packet): bool = ) ) +# counts the number of bytes acked by selective ack header +proc calculateSelectiveAckBytes*(socket: UtpSocket, receivedPackedAckNr: uint16, ext: SelectiveAckExtension): uint32 = + # we add 2, as the first bit in the mask therefore represents ackNr + 2 becouse + # ackNr + 1 (i.e next expected packet) is considered lost. + let base = receivedPackedAckNr + 2 + + if socket.curWindowPackets == 0: + return 0 + + var ackedBytes = 0'u32 + + var bits = (len(ext.acks)) * 8 - 1 + + while bits >= 0: + let v = base + uint16(bits) + + if (socket.seqNr - v - 1) >= socket.curWindowPackets - 1: + dec bits + continue + + let maybePacket = socket.outBuffer.get(v) + + if (maybePacket.isNone() or maybePacket.unsafeGet().transmissions == 0): + dec bits + continue + + let pkt = maybePacket.unsafeGet() + + if (getBit(ext.acks, bits)): + ackedBytes = ackedBytes + pkt.payloadLength + + dec bits + + return ackedBytes + +# ack packets (removes them from out going buffer) based on selective ack extension header +proc selectiveAckPackets(socket: UtpSocket, receivedPackedAckNr: uint16, ext: SelectiveAckExtension, currentTime: Moment): void = + # we add 2, as the first bit in the mask therefore represents ackNr + 2 becouse + # ackNr + 1 (i.e next expected packet) is considered lost. + let base = receivedPackedAckNr + 2 + + if socket.curWindowPackets == 0: + return + + var bits = (len(ext.acks)) * 8 - 1 + + while bits >= 0: + let v = base + uint16(bits) + + if (socket.seqNr - v - 1) >= socket.curWindowPackets - 1: + dec bits + continue + + let maybePacket = socket.outBuffer.get(v) + + if (maybePacket.isNone() or maybePacket.unsafeGet().transmissions == 0): + dec bits + continue + + let pkt = maybePacket.unsafeGet() + + if (getBit(ext.acks, bits)): + discard socket.ackPacket(v, currentTime) + + dec bits + + # TODO Add handling of fast timeouts and duplicate acks counting + +# Public mainly for test purposes +# generates bit mask which indicates which packets are already in socket +# reorder buffer +# from speck: +# The bitmask has reverse byte order. The first byte represents packets [ack_nr + 2, ack_nr + 2 + 7] in reverse order +# The least significant bit in the byte represents ack_nr + 2, the most significant bit in the byte represents ack_nr + 2 + 7 +# The next byte in the mask represents [ack_nr + 2 + 8, ack_nr + 2 + 15] in reverse order, and so on +proc generateSelectiveAckBitMask*(socket: UtpSocket): array[4, byte] = + let window = min(32, socket.inBuffer.len()) + var arr: array[4, uint8] = [0'u8, 0, 0, 0] + var i = 0 + while i < window: + if (socket.inBuffer.get(socket.ackNr + uint16(i) + 2).isSome()): + setBit(arr, i) + inc i + return arr + +# Generates ack packet based on current state of the socket. +proc generateAckPacket*(socket: UtpSocket): Packet = + let bitmask = + if (socket.reorderCount != 0 and (not socket.reachedFin)): + some(socket.generateSelectiveAckBitMask()) + else: + none[array[4, byte]]() + + ackPacket( + socket.seqNr, + socket.connectionIdSnd, + socket.ackNr, + socket.getRcvWindowSize(), + socket.replayMicro, + bitmask + ) + +proc sendAck(socket: UtpSocket): Future[void] = + ## Creates and sends ack, based on current socket state. Acks are different from + ## other packets as we do not track them in outgoing buffet + + let ackPacket = socket.generateAckPacket() + socket.sendData(encodePacket(ackPacket)) + +proc startIncomingSocket*(socket: UtpSocket) {.async.} = + # Make sure ack was flushed before moving forward + await socket.sendAck() + socket.startWriteLoop() + socket.startTimeoutLoop() + # TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make # it enforcable by type system # TODO re-think synchronization of this procedure, as each await inside gives control @@ -888,7 +983,7 @@ proc isAckNrInvalid(socket: UtpSocket, packet: Packet): bool = # running proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = let timestampInfo = getMonoTimestamp() - + if socket.isAckNrInvalid(p): notice "Received packet with invalid ack nr" return @@ -923,6 +1018,10 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = var (ackedBytes, minRtt) = socket.calculateAckedbytes(acks, timestampInfo.moment) # TODO caluclate bytes acked by selective acks here (if thats the case) + if (p.eack.isSome()): + let selectiveAckedBytes = socket.calculateSelectiveAckBytes(pkAckNr, p.eack.unsafeGet()) + ackedBytes = ackedBytes + selectiveAckedBytes + let sentTimeRemote = p.header.timestamp # we are using uint32 not a Duration, to wrap a round in case of @@ -999,6 +1098,14 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = socket.ackPackets(acks, timestampInfo.moment) + # packets in front may have been acked by selective ack, decrease window until we hit + # a packet that is still waiting to be acked + while (socket.curWindowPackets > 0 and socket.outBuffer.get(socket.seqNr - socket.curWindowPackets).isNone()): + dec socket.curWindowPackets + + if (p.eack.isSome()): + socket.selectiveAckPackets(pkAckNr, p.eack.unsafeGet(), timestampInfo.moment) + case p.header.pType of ST_DATA, ST_FIN: # To avoid amplification attacks, server socket is in SynRecv state until @@ -1014,6 +1121,7 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = # we got in order packet if (pastExpected == 0 and (not socket.reachedFin)): + notice "Got in order packet" if (len(p.payload) > 0 and (not socket.readShutdown)): # we are getting in order data packet, we can flush data directly to the incoming buffer await upload(addr socket.buffer, unsafeAddr p.payload[0], p.payload.len()) @@ -1085,8 +1193,10 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = socket.inBuffer.put(pkSeqNr, p) inc socket.reorderCount notice "added out of order packet in reorder buffer" - # TODO for now we do not sent any ack as we do not handle selective acks - # add sending of selective acks + # we send ack packet, as we reoreder count is > 0, so the eack bitmask will be + # generated + asyncSpawn socket.sendAck() + of ST_STATE: if (socket.state == SynSent and (not socket.connectionFuture.finished())): socket.state = Connected @@ -1220,13 +1330,11 @@ proc read*(socket: UtpSocket): Future[seq[byte]] {.async.}= # Check how many packets are still in the out going buffer, usefull for tests or # debugging. -# It throws assertion error when number of elements in buffer do not equal kept counter proc numPacketsInOutGoingBuffer*(socket: UtpSocket): int = var num = 0 for e in socket.outBuffer.items(): if e.isSome(): inc num - doAssert(num == int(socket.curWindowPackets)) num # Check how many payload bytes are still in flight diff --git a/tests/utp/all_utp_tests.nim b/tests/utp/all_utp_tests.nim index 63d76af..e8beaa4 100644 --- a/tests/utp/all_utp_tests.nim +++ b/tests/utp/all_utp_tests.nim @@ -12,5 +12,6 @@ import ./test_discv5_protocol, ./test_buffer, ./test_utp_socket, + ./test_utp_socket_sack, ./test_utp_router, ./test_clock_drift_calculator diff --git a/tests/utp/test_utils.nim b/tests/utp/test_utils.nim index ed53310..30209ce 100644 --- a/tests/utp/test_utils.nim +++ b/tests/utp/test_utils.nim @@ -1,9 +1,14 @@ import chronos, + ../../eth/utp/utp_socket, + ../../eth/utp/packets, ../../eth/keys type AssertionCallback = proc(): bool {.gcsafe, raises: [Defect].} +let testBufferSize = 1024'u32 +let defaultRcvOutgoingId = 314'u16 + proc generateByteArray*(rng: var BrHmacDrbgContext, length: int): seq[byte] = var bytes = newSeq[byte](length) brHmacDrbgGenerate(rng, bytes) @@ -16,3 +21,96 @@ proc waitUntil*(f: AssertionCallback): Future[void] {.async.} = break else: await sleepAsync(milliseconds(50)) + +template connectOutGoingSocket*( + initialRemoteSeq: uint16, + q: AsyncQueue[Packet], + remoteReceiveBuffer: uint32 = testBufferSize, + cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], Packet) = + let sock1 = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[]) + asyncSpawn sock1.startOutgoingSocket() + let initialPacket = await q.get() + + check: + initialPacket.header.pType == ST_SYN + + let responseAck = + ackPacket( + initialRemoteSeq, + initialPacket.header.connectionId, + initialPacket.header.seqNr, + remoteReceiveBuffer, + 0 + ) + + await sock1.processPacket(responseAck) + + check: + sock1.isConnected() + + (sock1, initialPacket) + +template connectOutGoingSocketWithIncoming*( + initialRemoteSeq: uint16, + outgoingQueue: AsyncQueue[Packet], + incomingQueue: AsyncQueue[Packet], + remoteReceiveBuffer: uint32 = testBufferSize, + cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], UtpSocket[TransportAddress]) = + let outgoingSocket = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(outgoingQueue), cfg, defaultRcvOutgoingId, rng[]) + asyncSpawn outgoingSocket.startOutgoingSocket() + let initialPacket = await outgoingQueue.get() + + check: + initialPacket.header.pType == ST_SYN + + let incomingSocket = newIncomingSocket[TransportAddress]( + testAddress, + initTestSnd(incomingQueue), + cfg, + initialPacket.header.connectionId, + initialPacket.header.seqNr, + rng[] + ) + + await incomingSocket.startIncomingSocket() + + let responseAck = await incomingQueue.get() + + await outgoingSocket.processPacket(responseAck) + + check: + outgoingSocket.isConnected() + + (outgoingSocket, incomingSocket) + + +proc generateDataPackets*( + numberOfPackets: uint16, + initialSeqNr: uint16, + connectionId: uint16, + ackNr: uint16, + rng: var BrHmacDrbgContext): seq[Packet] = + let packetSize = 100 + var packets = newSeq[Packet]() + var i = 0'u16 + while i < numberOfPackets: + let packet = dataPacket( + initialSeqNr + i, + connectionId, + ackNr, + testBufferSize, + generateByteArray(rng, packetSize), + 0 + ) + packets.add(packet) + + inc i + + packets + +proc initTestSnd*(q: AsyncQueue[Packet]): SendCallback[TransportAddress]= + return ( + proc (to: TransportAddress, bytes: seq[byte]): Future[void] = + let p = decodePacket(bytes).get() + q.addLast(p) + ) diff --git a/tests/utp/test_utp_socket.nim b/tests/utp/test_utp_socket.nim index dae7eb0..fdf617f 100644 --- a/tests/utp/test_utp_socket.nim +++ b/tests/utp/test_utp_socket.nim @@ -7,7 +7,7 @@ {.used.} import - std/[algorithm, random, sequtils], + std/[algorithm, random, sequtils, options], chronos, bearssl, chronicles, testutils/unittests, ./test_utils, @@ -22,71 +22,12 @@ procSuite "Utp socket unit test": let testBufferSize = 1024'u32 let defaultRcvOutgoingId = 314'u16 - proc initTestSnd(q: AsyncQueue[Packet]): SendCallback[TransportAddress]= - return ( - proc (to: TransportAddress, bytes: seq[byte]): Future[void] = - let p = decodePacket(bytes).get() - q.addLast(p) - ) - - proc generateDataPackets( - numberOfPackets: uint16, - initialSeqNr: uint16, - connectionId: uint16, - ackNr: uint16, - rng: var BrHmacDrbgContext): seq[Packet] = - let packetSize = 100 - var packets = newSeq[Packet]() - var i = 0'u16 - while i < numberOfPackets: - let packet = dataPacket( - initialSeqNr + i, - connectionId, - ackNr, - testBufferSize, - generateByteArray(rng, packetSize), - 0 - ) - packets.add(packet) - - inc i - - packets - proc packetsToBytes(packets: seq[Packet]): seq[byte] = var resultBytes = newSeq[byte]() for p in packets: resultBytes.add(p.payload) return resultBytes - template connectOutGoingSocket( - initialRemoteSeq: uint16, - q: AsyncQueue[Packet], - remoteReceiveBuffer: uint32 = testBufferSize, - cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], Packet) = - let sock1 = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[]) - asyncSpawn sock1.startOutgoingSocket() - let initialPacket = await q.get() - - check: - initialPacket.header.pType == ST_SYN - - let responseAck = - ackPacket( - initialRemoteSeq, - initialPacket.header.connectionId, - initialPacket.header.seqNr, - remoteReceiveBuffer, - 0 - ) - - await sock1.processPacket(responseAck) - - check: - sock1.isConnected() - - (sock1, initialPacket) - asyncTest "Starting outgoing socket should send Syn packet": let q = newAsyncQueue[Packet]() let defaultConfig = SocketConfig.init() @@ -187,10 +128,11 @@ procSuite "Utp socket unit test": # TODO test is valid until implementing selective acks let q = newAsyncQueue[Packet]() let initalRemoteSeqNr = 10'u16 + let numOfPackets = 10'u16 let (outgoingSocket, initialPacket) = connectOutGoingSocket(initalRemoteSeqNr, q) - var packets = generateDataPackets(10, initalRemoteSeqNr, initialPacket.header.connectionId, initialPacket.header.seqNr, rng[]) + var packets = generateDataPackets(numOfPackets, initalRemoteSeqNr, initialPacket.header.connectionId, initialPacket.header.seqNr, rng[]) let data = packetsToBytes(packets) @@ -200,12 +142,28 @@ procSuite "Utp socket unit test": for p in packets: await outgoingSocket.processPacket(p) - let ack2 = await q.get() + var sentAcks: seq[Packet] = @[] + + for i in 0'u16..