diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index 02cc7da..e3c525a 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -36,6 +36,7 @@ type packetBytes: seq[byte] transmissions: uint16 needResend: bool + payloadLength: uint32 timeSent: Moment AckResult = enum @@ -144,6 +145,12 @@ type # sequence number of remoted fin packet eofPktNr: uint16 + # number payload bytes in-flight (i.e not countig header sizes) + # packets that have not yet been sent do not count, packets + # that are marked as needing to be re-sent (due to a timeout) + # don't count either + currentWindow: uint32 + # socket identifier socketKey*: UtpSocketKey[A] @@ -227,11 +234,13 @@ proc init( packetBytes: seq[byte], transmissions: uint16, needResend: bool, + payloadLength: uint32, timeSent: Moment = Moment.now()): T = OutgoingPacket( packetBytes: packetBytes, transmissions: transmissions, needResend: needResend, + payloadLength: payloadLength, timeSent: timeSent ) @@ -280,7 +289,11 @@ proc sendAck(socket: UtpSocket): Future[void] = socket.sendData(encodePacket(ackPacket)) # Should be called before sending packet -proc setSend(p: var OutgoingPacket): seq[byte] = +proc setSend(s: UtpSocket, p: var OutgoingPacket): seq[byte] = + + if (p.transmissions == 0 or p.needResend): + s.currentWindow = s.currentWindow + p.payloadLength + inc p.transmissions p.needResend = false p.timeSent = Moment.now() @@ -292,7 +305,7 @@ proc flushPackets(socket: UtpSocket) {.async.} = # sending only packet which were not transmitted yet or need a resend let shouldSendPacket = socket.outBuffer.exists(i, (p: OutgoingPacket) => (p.transmissions == 0 or p.needResend == true)) if (shouldSendPacket): - let toSend = setSend(socket.outBuffer[i]) + let toSend = socket.setSend(socket.outBuffer[i]) await socket.sendData(toSend) inc i @@ -303,8 +316,9 @@ proc markAllPacketAsLost(s: UtpSocket) = let packetSeqNr = s.seqNr - 1 - i if (s.outBuffer.exists(packetSeqNr, (p: OutgoingPacket) => p. transmissions > 0 and p.needResend == false)): s.outBuffer[packetSeqNr].needResend = true - # TODO here we should also decrease number of bytes in flight. This should be - # done when working on congestion control + let packetPayloadLength = s.outBuffer[packetSeqNr].payloadLength + doAssert(s.currentWindow >= packetPayloadLength) + s.currentWindow = s.currentWindow - packetPayloadLength inc i @@ -368,7 +382,7 @@ proc checkTimeouts(socket: UtpSocket) {.async.} = socket.outBuffer.get(oldestPacketSeqNr).isSome(), "oldest packet should always be available when there is data in flight" ) - let dataToSend = setSend(socket.outBuffer[oldestPacketSeqNr]) + let dataToSend = socket.setSend(socket.outBuffer[oldestPacketSeqNr]) await socket.sendData(dataToSend) # TODO add sending keep alives when necessary @@ -486,7 +500,7 @@ proc startOutgoingSocket*(socket: UtpSocket): Future[void] {.async.} = notice "Sending syn packet packet", packet = packet # set number of transmissions to 1 as syn packet will be send just after # initiliazation - let outgoingPacket = OutgoingPacket.init(encodePacket(packet), 1, false) + let outgoingPacket = OutgoingPacket.init(encodePacket(packet), 1, false, 0) socket.registerOutgoingPacket(outgoingPacket) socket.startTimeoutLoop() await socket.sendData(outgoingPacket.packetBytes) @@ -588,7 +602,12 @@ proc ackPacket(socket: UtpSocket, seqNr: uint16): AckResult = socket.retransmitTimeout = socket.rto socket.rtoTimeout = currentTime + socket.rto - # TODO Add handlig of decreasing bytes window, whenadding handling of congestion control + # if need_resend is set, this packet has already + # been considered timed-out, and is not included in + # the cur_window anymore + if (not packet.needResend): + doAssert(socket.currentWindow >= packet.payloadLength) + socket.currentWindow = socket.currentWindow - packet.payloadLength socket.retransmitCount = 0 PacketAcked @@ -822,7 +841,7 @@ proc close*(socket: UtpSocket) {.async.} = socket.resetSendTimeout() let finEncoded = encodePacket(finPacket(socket.seqNr, socket.connectionIdSnd, socket.ackNr, socket.getRcvWindowSize())) - socket.registerOutgoingPacket(OutgoingPacket.init(finEncoded, 1, true)) + socket.registerOutgoingPacket(OutgoingPacket.init(finEncoded, 1, true, 0)) socket.finSent = true await socket.sendData(finEncoded) else: @@ -869,7 +888,8 @@ proc write*(socket: UtpSocket, data: seq[byte]): Future[WriteResult] {.async.} = let lastOrEnd = min(lastIndex, endIndex) let dataSlice = data[i..lastOrEnd] let dataPacket = dataPacket(socket.seqNr, socket.connectionIdSnd, socket.ackNr, wndSize, dataSlice) - socket.registerOutgoingPacket(OutgoingPacket.init(encodePacket(dataPacket), 0, false)) + let payloadLength = uint32(len(dataSlice)) + socket.registerOutgoingPacket(OutgoingPacket.init(encodePacket(dataPacket), 0, false, payloadLength)) bytesWritten = bytesWritten + len(dataSlice) i = lastOrEnd + 1 await socket.flushPackets() @@ -932,6 +952,9 @@ proc numPacketsInOutGoingBuffer*(socket: UtpSocket): int = doAssert(num == int(socket.curWindowPackets)) num +# Check how many payload bytes are still in flight +proc numOfBytesInFlight*(socket: UtpSocket): uint32 = socket.currentWindow + # Check how many packets are still in the reorder buffer, usefull for tests or # debugging. # It throws assertion error when number of elements in buffer do not equal kept counter diff --git a/tests/utp/test_utp_socket.nim b/tests/utp/test_utp_socket.nim index 0c32cdb..3377a52 100644 --- a/tests/utp/test_utp_socket.nim +++ b/tests/utp/test_utp_socket.nim @@ -658,3 +658,91 @@ procSuite "Utp socket unit test": receivedBytes == data3 await outgoingSocket.destroyWait() + + asyncTest "Writing data should increase current bytes window": + let q = newAsyncQueue[Packet]() + let initialRemoteSeq = 10'u16 + + let dataToWrite = @[1'u8, 2, 3, 4, 5] + + let (outgoingSocket, initialPacket) = connectOutGoingSocket(initialRemoteSeq, q) + + discard await outgoingSocket.write(dataToWrite) + + check: + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + + discard await outgoingSocket.write(dataToWrite) + + check: + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + len(dataToWrite) + + await outgoingSocket.destroyWait() + + asyncTest "Acking data packet should decrease current bytes window": + let q = newAsyncQueue[Packet]() + let initialRemoteSeq = 10'u16 + + let dataToWrite = @[1'u8, 2, 3, 4, 5] + + let (outgoingSocket, initialPacket) = connectOutGoingSocket(initialRemoteSeq, q) + + discard await outgoingSocket.write(dataToWrite) + + let sentPacket = await q.get() + + check: + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + + + discard await outgoingSocket.write(dataToWrite) + + check: + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + len(dataToWrite) + + let responseAck = ackPacket(initialRemoteSeq, initialPacket.header.connectionId, sentPacket.header.seqNr, testBufferSize) + + await outgoingSocket.processPacket(responseAck) + + check: + # only first packet has been acked so there should still by 5 bytes left + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + + await outgoingSocket.destroyWait() + + asyncTest "Timeout packets should decrease bytes window": + let q = newAsyncQueue[Packet]() + let initialRemoteSeq = 10'u16 + + let dataToWrite = @[1'u8, 2, 3] + let dataToWrite1 = @[6'u8, 7, 8, 9, 10] + + let (outgoingSocket, initialPacket) = connectOutGoingSocket(initialRemoteSeq, q) + + discard await outgoingSocket.write(dataToWrite) + + let sentPacket = await q.get() + + check: + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + + + discard await outgoingSocket.write(dataToWrite1) + + let sentPacket1 = await q.get() + + check: + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + len(dataToWrite1) + + # after timeout oldest packet will be immediatly re-sent + let reSentFirstPacket = await q.get() + + check: + reSentFirstPacket.payload == sentPacket.payload + + # first packet has been re-sent so its payload still counts to bytes in flight + # second packet has been marked as missing, therefore its bytes are not counting + # to bytes in flight + int(outgoingSocket.numOfBytesInFlight) == len(dataToWrite) + + await outgoingSocket.destroyWait()