From 779d767b024175a51cf74c79ec7513301ebe2f46 Mon Sep 17 00:00:00 2001 From: KonradStaniec Date: Thu, 10 Feb 2022 08:05:44 +0100 Subject: [PATCH] Add more tests stressing conccurent reading and writing on utp socket (#474) * Add more tests stressing concurrent reading and writing * Fix bug when remote window dropped below packet size --- eth/utp/utp_protocol.nim | 10 +- eth/utp/utp_socket.nim | 142 ++++++++------- tests/utp/all_utp_tests.nim | 3 +- tests/utp/test_protocol_integration.nim | 231 ++++++++++++++++++++++++ 4 files changed, 320 insertions(+), 66 deletions(-) create mode 100644 tests/utp/test_protocol_integration.nim diff --git a/eth/utp/utp_protocol.nim b/eth/utp/utp_protocol.nim index 38ddd93..ab85841 100644 --- a/eth/utp/utp_protocol.nim +++ b/eth/utp/utp_protocol.nim @@ -23,6 +23,8 @@ type transport: DatagramTransport utpRouter: UtpRouter[TransportAddress] + SendCallbackBuilder* = proc (d: DatagramTransport): SendCallback[TransportAddress] {.gcsafe, raises: [Defect].} + # This should probably be defined in TransportAddress module, as hash function should # be consitent with equality function # in nim zero arrays always have hash equal to 0, irrespectively of array size, to @@ -78,6 +80,7 @@ proc new*( address: TransportAddress, socketConfig: SocketConfig = SocketConfig.init(), allowConnectionCb: AllowConnectionCallback[TransportAddress] = nil, + sendCallbackBuilder: SendCallbackBuilder = nil, rng = newRng()): UtpProtocol {.raises: [Defect, CatchableError].} = doAssert(not(isNil(acceptConnectionCb))) @@ -90,7 +93,12 @@ proc new*( ) let ta = newDatagramTransport(processDatagram, udata = router, local = address) - router.sendCb = initSendCallback(ta) + + if (sendCallbackBuilder == nil): + router.sendCb = initSendCallback(ta) + else: + router.sendCb = sendCallbackBuilder(ta) + UtpProtocol(transport: ta, utpRouter: router) proc shutdownWait*(p: UtpProtocol): Future[void] {.async.} = diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index 121d2b0..ea41d41 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -191,7 +191,8 @@ type writeLoop: Future[void] - zeroWindowTimer: Moment + # timer which is started when peer max window drops below current packet size + zeroWindowTimer: Option[Moment] # last measured delay between current local timestamp, and remote sent # timestamp. In microseconds @@ -287,7 +288,8 @@ const allowedAckWindow*: uint16 = 3 # Timeout after which the send window will be reset to its minimal value after it dropped - # to zero. i.e when we received a packet from remote peer with `wndSize` set to 0. + # lower than our current packet size. i.e when we received a packet + # from remote peer with `wndSize` set to number <= current packet size defaultResetWindowTimeout = seconds(15) # If remote peer window drops to zero, then after some time we will reset it @@ -446,10 +448,15 @@ proc checkTimeouts(socket: UtpSocket) {.async.} = await socket.flushPackets() if socket.isOpened(): + let currentPacketSize = uint32(socket.getPacketSize()) - if (socket.sendBufferTracker.maxRemoteWindow == 0 and currentTime > socket.zeroWindowTimer): - debug "Reset remote window to minimal value" - socket.sendBufferTracker.updateMaxRemote(minimalRemoteWindow) + if (socket.zeroWindowTimer.isSome() and currentTime > socket.zeroWindowTimer.unsafeGet()): + if socket.sendBufferTracker.maxRemoteWindow <= currentPacketSize: + socket.sendBufferTracker.updateMaxRemote(minimalRemoteWindow) + socket.zeroWindowTimer = none[Moment]() + debug "Reset remote window to minimal value", + minRemote = minimalRemoteWindow + if (currentTime > socket.rtoTimeout): debug "CheckTimeouts rto timeout", @@ -487,7 +494,7 @@ proc checkTimeouts(socket: UtpSocket) {.async.} = # on timeout reset duplicate ack counter socket.duplicateAck = 0 - let currentPacketSize = uint32(socket.getPacketSize()) + if (socket.curWindowPackets == 0 and socket.sendBufferTracker.maxWindow > currentPacketSize): # there are no packets in flight even though there is place for more than whole packet @@ -566,57 +573,59 @@ proc resetSendTimeout(socket: UtpSocket) = socket.rtoTimeout = getMonoTimestamp().moment + socket.retransmitTimeout proc handleDataWrite(socket: UtpSocket, data: seq[byte], writeFut: Future[WriteResult]): Future[void] {.async.} = - if writeFut.finished(): - # write future was cancelled befere we got chance to process it, short circuit - # processing and move to next loop iteration - return + if writeFut.finished(): + # write future was cancelled befere we got chance to process it, short circuit + # processing and move to next loop iteration + return + + let pSize = socket.getPacketSize() + let endIndex = data.high() + var i = 0 + var bytesWritten = 0 + + while i <= endIndex: + let lastIndex = i + pSize - 1 + let lastOrEnd = min(lastIndex, endIndex) + let dataSlice = data[i..lastOrEnd] + let payloadLength = uint32(len(dataSlice)) + try: + await socket.sendBufferTracker.reserveNBytesWait(payloadLength) + + if socket.curWindowPackets == 0: + socket.resetSendTimeout() - let pSize = socket.getPacketSize() - let endIndex = data.high() - var i = 0 - var bytesWritten = 0 let wndSize = socket.getRcvWindowSize() - while i <= endIndex: - let lastIndex = i + pSize - 1 - let lastOrEnd = min(lastIndex, endIndex) - let dataSlice = data[i..lastOrEnd] - let payloadLength = uint32(len(dataSlice)) - try: - await socket.sendBufferTracker.reserveNBytesWait(payloadLength) - if socket.curWindowPackets == 0: - socket.resetSendTimeout() - - let dataPacket = - dataPacket( - socket.seqNr, - socket.connectionIdSnd, - socket.ackNr, - wndSize, - dataSlice, - socket.replayMicro - ) - let outgoingPacket = OutgoingPacket.init(encodePacket(dataPacket), 1, false, payloadLength) - socket.registerOutgoingPacket(outgoingPacket) - await socket.sendData(outgoingPacket.packetBytes) - except CancelledError as exc: - # write loop has been cancelled in the middle of processing due to the - # socket closing - # this approach can create partial write in case destroyin socket in the - # the middle of the write - doAssert(socket.state == Destroy) - if (not writeFut.finished()): - let res = Result[int, WriteError].err(WriteError(kind: SocketNotWriteable, currentState: socket.state)) - writeFut.complete(res) - # we need to re-raise exception so the outer loop will be properly cancelled too - raise exc - bytesWritten = bytesWritten + len(dataSlice) - i = lastOrEnd + 1 - - # Before completeing future with success (as all data was sent sucessfuly) - # we need to check if user did not cancel write on his end + let dataPacket = + dataPacket( + socket.seqNr, + socket.connectionIdSnd, + socket.ackNr, + wndSize, + dataSlice, + socket.replayMicro + ) + let outgoingPacket = OutgoingPacket.init(encodePacket(dataPacket), 1, false, payloadLength) + socket.registerOutgoingPacket(outgoingPacket) + await socket.sendData(outgoingPacket.packetBytes) + except CancelledError as exc: + # write loop has been cancelled in the middle of processing due to the + # socket closing + # this approach can create partial write in when destroying the socket in the + # the middle of the write + doAssert(socket.state == Destroy) if (not writeFut.finished()): - writeFut.complete(Result[int, WriteError].ok(bytesWritten)) + let res = Result[int, WriteError].err(WriteError(kind: SocketNotWriteable, currentState: socket.state)) + writeFut.complete(res) + # we need to re-raise exception so the outer loop will be properly cancelled too + raise exc + bytesWritten = bytesWritten + len(dataSlice) + i = lastOrEnd + 1 + + # Before completing the future with success (as all data was sent successfully) + # we need to check if user did not cancel write on his end + if (not writeFut.finished()): + writeFut.complete(Result[int, WriteError].ok(bytesWritten)) proc handleClose(socket: UtpSocket): Future[void] {.async.} = try: @@ -706,7 +715,7 @@ proc new[A]( sendBufferTracker: SendBufferTracker.new(0, 1024 * 1024, cfg.optSndBuffer, startMaxWindow), # queue with infinite size writeQueue: newAsyncQueue[WriteRequest](), - zeroWindowTimer: currentTime + cfg.remoteWindowResetTimeout, + zeroWindowTimer: none[Moment](), socketKey: UtpSocketKey.init(to, rcvId), slowStart: true, fastTimeout: false, @@ -1131,11 +1140,13 @@ proc generateAckPacket*(socket: UtpSocket): Packet = else: none[array[4, byte]]() + let bufferSize = socket.getRcvWindowSize() + ackPacket( socket.seqNr, socket.connectionIdSnd, socket.ackNr, - socket.getRcvWindowSize(), + bufferSize, socket.replayMicro, bitmask ) @@ -1175,7 +1186,8 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = seqNr = p.header.seqNr, ackNr = p.header.ackNr, timestamp = p.header.timestamp, - timestampDiff = p.header.timestampDiff + timestampDiff = p.header.timestampDiff, + remoteWindow = p.header.wndSize let timestampInfo = getMonoTimestamp() @@ -1255,7 +1267,7 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = let isPossibleDuplicatedOldPacket = pastExpected >= (int(uint16.high) + 1) - reorderBufferMaxSize if (isPossibleDuplicatedOldPacket and p.header.pType != ST_STATE): - asyncSpawn socket.sendAck() + discard socket.sendAck() debug "Got an invalid packet sequence number, too far off", pastExpected = pastExpected @@ -1311,13 +1323,14 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = let diff = uint32((socket.ourHistogram.getValue() - minRtt).microseconds()) socket.ourHistogram.shift(diff) + let currentPacketSize = uint32(socket.getPacketSize()) let (newMaxWindow, newSlowStartTreshold, newSlowStart) = applyCongestionControl( socket.sendBufferTracker.maxWindow, socket.slowStart, socket.slowStartTreshold, socket.socketConfig.optSndBuffer, - uint32(socket.getPacketSize()), + currentPacketSize, microseconds(actualDelay), ackedBytes, minRtt, @@ -1336,14 +1349,15 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = slowStartTreshold = newSlowStartTreshold, slowstart = newSlowStart - if (socket.sendBufferTracker.maxRemoteWindow == 0): + if (socket.zeroWindowTimer.isNone() and socket.sendBufferTracker.maxRemoteWindow <= currentPacketSize): # when zeroWindowTimer will be hit and maxRemoteWindow still will be equal to 0 # then it will be reset to minimal value - socket.zeroWindowTimer = timestampInfo.moment + socket.socketConfig.remoteWindowResetTimeout + socket.zeroWindowTimer = some(timestampInfo.moment + socket.socketConfig.remoteWindowResetTimeout) - debug "Remote window size dropped to 0", + debug "Remote window size dropped below packet size", currentTime = timestampInfo.moment, - resetZeroWindowTime = socket.zeroWindowTimer + resetZeroWindowTime = socket.zeroWindowTimer, + currentPacketSize = currentPacketSize # socket.curWindowPackets == acks means that this packet acked all remaining packets # including the sent fin packets @@ -1488,7 +1502,7 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = # need improvement, as with this approach there is no direct control over # how many concurrent tasks there are and how to cancel them when socket # is closed - asyncSpawn socket.sendAck() + discard socket.sendAck() # we got packet out of order else: @@ -1515,7 +1529,7 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = reorderCount = socket.reorderCount # we send ack packet, as we reoreder count is > 0, so the eack bitmask will be # generated - asyncSpawn socket.sendAck() + discard socket.sendAck() of ST_STATE: if (socket.state == SynSent and (not socket.connectionFuture.finished())): diff --git a/tests/utp/all_utp_tests.nim b/tests/utp/all_utp_tests.nim index c0e86a4..fb001ec 100644 --- a/tests/utp/all_utp_tests.nim +++ b/tests/utp/all_utp_tests.nim @@ -15,4 +15,5 @@ import ./test_utp_socket, ./test_utp_socket_sack, ./test_utp_router, - ./test_clock_drift_calculator + ./test_clock_drift_calculator, + ./test_protocol_integration diff --git a/tests/utp/test_protocol_integration.nim b/tests/utp/test_protocol_integration.nim new file mode 100644 index 0000000..43e712f --- /dev/null +++ b/tests/utp/test_protocol_integration.nim @@ -0,0 +1,231 @@ +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +import + std/[sequtils, tables, options, sugar], + chronos, bearssl, + testutils/unittests, + ./test_utils, + ../../eth/utp/utp_router, + ../../eth/utp/utp_protocol, + ../../eth/keys, + ../../eth/p2p/discoveryv5/random2 + + +proc connectTillSuccess(p: UtpProtocol, to: TransportAddress, maxTries: int = 20): Future[UtpSocket[TransportAddress]] {.async.} = + var i = 0 + while true: + let res = await p.connectTo(to) + + if res.isOk(): + return res.unsafeGet() + else: + inc i + if i >= maxTries: + raise newException(CatchableError, "Connection failed") + +proc buildAcceptConnection( + t: ref Table[UtpSocketKey[TransportAddress], UtpSocket[TransportAddress]] + ): AcceptConnectionCallback[TransportAddress] = + return ( + proc (server: UtpRouter[TransportAddress], client: UtpSocket[TransportAddress]): Future[void] = + let fut = newFuture[void]() + let key = client.socketKey + t[key] = client + fut.complete() + return fut + ) + +proc getServerSocket( + t: ref Table[UtpSocketKey[TransportAddress], UtpSocket[TransportAddress]], + clientAddress: TransportAddress, + clientConnectionId: uint16): Option[UtpSocket[TransportAddress]] = + let serverSocketKey = UtpSocketKey[TransportAddress](remoteAddress: clientAddress, rcvId: clientConnectionId + 1) + let srvSocket = t.getOrDefault(serverSocketKey) + if srvSocket == nil: + return none[UtpSocket[TransportAddress]]() + else: + return some(srvSocket) + +procSuite "Utp protocol over udp tests with loss and delays": + let rng = newRng() + + proc sendBuilder(maxDelay: int, packetDropRate: int): SendCallbackBuilder = + return ( + proc (d: DatagramTransport): SendCallback[TransportAddress] = + return ( + proc (to: TransportAddress, data: seq[byte]): Future[void] {.async.} = + let i = rand(rng[], 99) + if i >= packetDropRate: + let delay = milliseconds(rand(rng[], maxDelay)) + await sleepAsync(delay) + await d.sendTo(to, data) + ) + ) + + proc testScenario(maxDelay: int, dropRate: int, cfg: SocketConfig = SocketConfig.init()): + Future[( + UtpProtocol, + UtpSocket[TransportAddress], + UtpProtocol, + UtpSocket[TransportAddress]) + ] {.async.} = + + var connections1 = newTable[UtpSocketKey[TransportAddress], UtpSocket[TransportAddress]]() + let address1 = initTAddress("127.0.0.1", 9080) + let utpProt1 = + UtpProtocol.new( + buildAcceptConnection(connections1), + address1, + socketConfig = cfg, + sendCallbackBuilder = sendBuilder(maxDelay, dropRate), + rng = rng) + + var connections2 = newTable[UtpSocketKey[TransportAddress], UtpSocket[TransportAddress]]() + let address2 = initTAddress("127.0.0.1", 9081) + let utpProt2 = + UtpProtocol.new( + buildAcceptConnection(connections2), + address2, + socketConfig = cfg, + sendCallbackBuilder = sendBuilder(maxDelay, dropRate), + rng = rng) + + let clientSocket = await utpProt1.connectTillSuccess(address2) + let maybeServerSocket = connections2.getServerSocket(address1, clientSocket.socketKey.rcvId) + + let serverSocket = maybeServerSocket.unsafeGet() + + return (utpProt1, clientSocket, utpProt2, serverSocket) + + type TestCase = object + # in miliseconds + maxDelay: int + dropRate: int + bytesToTransfer: int + bytesPerRead: int + cfg: SocketConfig + + proc init( + T: type TestCase, + maxDelay: int, + dropRate: int, + bytesToTransfer: int, + cfg: SocketConfig = SocketConfig.init(), + bytesPerRead: int = 0): TestCase = + TestCase(maxDelay: maxDelay, dropRate: dropRate, bytesToTransfer: bytesToTransfer, cfg: cfg, bytesPerRead: bytesPerRead) + + + let testCases = @[ + TestCase.init(45, 10, 40000), + TestCase.init(45, 15, 40000), + TestCase.init(50, 20, 20000), + # super small recv buffer which will be constantly on the brink of being full + TestCase.init(15, 5, 80000, SocketConfig.init(optRcvBuffer = uint32(2000), remoteWindowResetTimeout = seconds(5))), + TestCase.init(15, 10, 80000, SocketConfig.init(optRcvBuffer = uint32(2000), remoteWindowResetTimeout = seconds(5))) + ] + + asyncTest "Write and Read large data in different network conditions": + for testCase in testCases: + + let ( + clientProtocol, + clientSocket, + serverProtocol, + serverSocket) = await testScenario(testCase.maxDelay, testCase.dropRate, testcase.cfg) + + let smallBytes = 10 + let smallBytesToTransfer = generateByteArray(rng[], smallBytes) + # first transfer and read to make server socket connecteced + let write1 = await clientSocket.write(smallBytesToTransfer) + let read1 = await serverSocket.read(smallBytes) + + check: + write1.isOk() + read1 == smallBytesToTransfer + + let numBytes = testCase.bytesToTransfer + let bytesToTransfer = generateByteArray(rng[], numBytes) + + discard clientSocket.write(bytesToTransfer) + discard serverSocket.write(bytesToTransfer) + + let serverReadFut = serverSocket.read(numBytes) + let clientReadFut = clientSocket.read(numBytes) + + yield serverReadFut + yield clientReadFut + + let clientRead = clientReadFut.read() + let serverRead = serverReadFut.read() + + check: + clientRead == bytesToTransfer + serverRead == bytesToTransfer + + await clientProtocol.shutdownWait() + await serverProtocol.shutdownWait() + + let testCases1 = @[ + # small buffers so it will fill up between reads + TestCase.init(15, 5, 60000, SocketConfig.init(optRcvBuffer = uint32(2000), remoteWindowResetTimeout = seconds(5)), 10000), + TestCase.init(15, 10, 60000, SocketConfig.init(optRcvBuffer = uint32(2000), remoteWindowResetTimeout = seconds(5)), 10000), + TestCase.init(15, 15, 60000, SocketConfig.init(optRcvBuffer = uint32(2000), remoteWindowResetTimeout = seconds(5)), 10000) + ] + + proc readWithMultipleReads(s: UtpSocket[TransportAddress], numOfReads: int, bytesPerRead: int): Future[seq[byte]] {.async.}= + var i = 0 + var res: seq[byte] = @[] + while i < numOfReads: + let bytes = await s.read(bytesPerRead) + res.add(bytes) + inc i + return res + + asyncTest "Write and Read large data in different network conditions split over several reads": + for testCase in testCases1: + + let ( + clientProtocol, + clientSocket, + serverProtocol, + serverSocket) = await testScenario(testCase.maxDelay, testCase.dropRate, testcase.cfg) + + let smallBytes = 10 + let smallBytesToTransfer = generateByteArray(rng[], smallBytes) + # first transfer and read to make server socket connecteced + let write1 = await clientSocket.write(smallBytesToTransfer) + let read1 = await serverSocket.read(smallBytes) + + check: + read1 == smallBytesToTransfer + + let numBytes = testCase.bytesToTransfer + let bytesToTransfer = generateByteArray(rng[], numBytes) + + discard clientSocket.write(bytesToTransfer) + discard serverSocket.write(bytesToTransfer) + + let numOfReads = int(testCase.bytesToTransfer / testCase.bytesPerRead) + let serverReadFut = serverSocket.readWithMultipleReads(numOfReads, testCase.bytesPerRead) + let clientReadFut = clientSocket.readWithMultipleReads(numOfReads, testCase.bytesPerRead) + + yield serverReadFut + + yield clientReadFut + + let clientRead = clientReadFut.read() + let serverRead = serverReadFut.read() + + check: + clientRead == bytesToTransfer + serverRead == bytesToTransfer + + await clientProtocol.shutdownWait() + await serverProtocol.shutdownWait() +