Handle selective acks (#456)

* Handle selective acks
This commit is contained in:
KonradStaniec 2022-01-04 09:52:38 +01:00 committed by GitHub
parent 664072fff7
commit 9c8e9d9f64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 522 additions and 97 deletions

View File

@ -48,8 +48,8 @@ type
# sequence number the sender of the packet last received in the other direction # sequence number the sender of the packet last received in the other direction
ackNr*: uint16 ackNr*: uint16
SelectiveAckExtension = object SelectiveAckExtension* = object
acks: array[acksArrayLength, byte] acks*: array[4, byte]
Packet* = object Packet* = object
header*: PacketHeaderV1 header*: PacketHeaderV1

View File

@ -9,7 +9,7 @@
import import
std/sugar, std/sugar,
chronos, chronicles, bearssl, chronos, chronicles, bearssl,
stew/results, stew/[results, bitops2],
./send_buffer_tracker, ./send_buffer_tracker,
./growable_buffer, ./growable_buffer,
./packets, ./packets,
@ -341,20 +341,6 @@ proc registerOutgoingPacket(socket: UtpSocket, oPacket: OutgoingPacket) =
proc sendData(socket: UtpSocket, data: seq[byte]): Future[void] = proc sendData(socket: UtpSocket, data: seq[byte]): Future[void] =
socket.send(socket.remoteAddress, data) socket.send(socket.remoteAddress, data)
proc sendAck(socket: UtpSocket): Future[void] =
## Creates and sends ack, based on current socket state. Acks are different from
## other packets as we do not track them in outgoing buffet
let ackPacket =
ackPacket(
socket.seqNr,
socket.connectionIdSnd,
socket.ackNr,
socket.getRcvWindowSize(),
socket.replayMicro
)
socket.sendData(encodePacket(ackPacket))
# Should be called before sending packet # Should be called before sending packet
proc setSend(s: UtpSocket, p: var OutgoingPacket): seq[byte] = proc setSend(s: UtpSocket, p: var OutgoingPacket): seq[byte] =
let timestampInfo = getMonoTimestamp() let timestampInfo = getMonoTimestamp()
@ -725,12 +711,6 @@ proc startOutgoingSocket*(socket: UtpSocket): Future[void] {.async.} =
await socket.sendData(outgoingPacket.packetBytes) await socket.sendData(outgoingPacket.packetBytes)
await socket.connectionFuture await socket.connectionFuture
proc startIncomingSocket*(socket: UtpSocket) {.async.} =
# Make sure ack was flushed before moving forward
await socket.sendAck()
socket.startWriteLoop()
socket.startTimeoutLoop()
proc isConnected*(socket: UtpSocket): bool = proc isConnected*(socket: UtpSocket): bool =
socket.state == Connected socket.state == Connected
@ -881,6 +861,121 @@ proc isAckNrInvalid(socket: UtpSocket, packet: Packet): bool =
) )
) )
# counts the number of bytes acked by selective ack header
proc calculateSelectiveAckBytes*(socket: UtpSocket, receivedPackedAckNr: uint16, ext: SelectiveAckExtension): uint32 =
# we add 2, as the first bit in the mask therefore represents ackNr + 2 becouse
# ackNr + 1 (i.e next expected packet) is considered lost.
let base = receivedPackedAckNr + 2
if socket.curWindowPackets == 0:
return 0
var ackedBytes = 0'u32
var bits = (len(ext.acks)) * 8 - 1
while bits >= 0:
let v = base + uint16(bits)
if (socket.seqNr - v - 1) >= socket.curWindowPackets - 1:
dec bits
continue
let maybePacket = socket.outBuffer.get(v)
if (maybePacket.isNone() or maybePacket.unsafeGet().transmissions == 0):
dec bits
continue
let pkt = maybePacket.unsafeGet()
if (getBit(ext.acks, bits)):
ackedBytes = ackedBytes + pkt.payloadLength
dec bits
return ackedBytes
# ack packets (removes them from out going buffer) based on selective ack extension header
proc selectiveAckPackets(socket: UtpSocket, receivedPackedAckNr: uint16, ext: SelectiveAckExtension, currentTime: Moment): void =
# we add 2, as the first bit in the mask therefore represents ackNr + 2 becouse
# ackNr + 1 (i.e next expected packet) is considered lost.
let base = receivedPackedAckNr + 2
if socket.curWindowPackets == 0:
return
var bits = (len(ext.acks)) * 8 - 1
while bits >= 0:
let v = base + uint16(bits)
if (socket.seqNr - v - 1) >= socket.curWindowPackets - 1:
dec bits
continue
let maybePacket = socket.outBuffer.get(v)
if (maybePacket.isNone() or maybePacket.unsafeGet().transmissions == 0):
dec bits
continue
let pkt = maybePacket.unsafeGet()
if (getBit(ext.acks, bits)):
discard socket.ackPacket(v, currentTime)
dec bits
# TODO Add handling of fast timeouts and duplicate acks counting
# Public mainly for test purposes
# generates bit mask which indicates which packets are already in socket
# reorder buffer
# from speck:
# The bitmask has reverse byte order. The first byte represents packets [ack_nr + 2, ack_nr + 2 + 7] in reverse order
# The least significant bit in the byte represents ack_nr + 2, the most significant bit in the byte represents ack_nr + 2 + 7
# The next byte in the mask represents [ack_nr + 2 + 8, ack_nr + 2 + 15] in reverse order, and so on
proc generateSelectiveAckBitMask*(socket: UtpSocket): array[4, byte] =
let window = min(32, socket.inBuffer.len())
var arr: array[4, uint8] = [0'u8, 0, 0, 0]
var i = 0
while i < window:
if (socket.inBuffer.get(socket.ackNr + uint16(i) + 2).isSome()):
setBit(arr, i)
inc i
return arr
# Generates ack packet based on current state of the socket.
proc generateAckPacket*(socket: UtpSocket): Packet =
let bitmask =
if (socket.reorderCount != 0 and (not socket.reachedFin)):
some(socket.generateSelectiveAckBitMask())
else:
none[array[4, byte]]()
ackPacket(
socket.seqNr,
socket.connectionIdSnd,
socket.ackNr,
socket.getRcvWindowSize(),
socket.replayMicro,
bitmask
)
proc sendAck(socket: UtpSocket): Future[void] =
## Creates and sends ack, based on current socket state. Acks are different from
## other packets as we do not track them in outgoing buffet
let ackPacket = socket.generateAckPacket()
socket.sendData(encodePacket(ackPacket))
proc startIncomingSocket*(socket: UtpSocket) {.async.} =
# Make sure ack was flushed before moving forward
await socket.sendAck()
socket.startWriteLoop()
socket.startTimeoutLoop()
# TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make # TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make
# it enforcable by type system # it enforcable by type system
# TODO re-think synchronization of this procedure, as each await inside gives control # TODO re-think synchronization of this procedure, as each await inside gives control
@ -888,7 +983,7 @@ proc isAckNrInvalid(socket: UtpSocket, packet: Packet): bool =
# running # running
proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = proc processPacket*(socket: UtpSocket, p: Packet) {.async.} =
let timestampInfo = getMonoTimestamp() let timestampInfo = getMonoTimestamp()
if socket.isAckNrInvalid(p): if socket.isAckNrInvalid(p):
notice "Received packet with invalid ack nr" notice "Received packet with invalid ack nr"
return return
@ -923,6 +1018,10 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} =
var (ackedBytes, minRtt) = socket.calculateAckedbytes(acks, timestampInfo.moment) var (ackedBytes, minRtt) = socket.calculateAckedbytes(acks, timestampInfo.moment)
# TODO caluclate bytes acked by selective acks here (if thats the case) # TODO caluclate bytes acked by selective acks here (if thats the case)
if (p.eack.isSome()):
let selectiveAckedBytes = socket.calculateSelectiveAckBytes(pkAckNr, p.eack.unsafeGet())
ackedBytes = ackedBytes + selectiveAckedBytes
let sentTimeRemote = p.header.timestamp let sentTimeRemote = p.header.timestamp
# we are using uint32 not a Duration, to wrap a round in case of # we are using uint32 not a Duration, to wrap a round in case of
@ -999,6 +1098,14 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} =
socket.ackPackets(acks, timestampInfo.moment) socket.ackPackets(acks, timestampInfo.moment)
# packets in front may have been acked by selective ack, decrease window until we hit
# a packet that is still waiting to be acked
while (socket.curWindowPackets > 0 and socket.outBuffer.get(socket.seqNr - socket.curWindowPackets).isNone()):
dec socket.curWindowPackets
if (p.eack.isSome()):
socket.selectiveAckPackets(pkAckNr, p.eack.unsafeGet(), timestampInfo.moment)
case p.header.pType case p.header.pType
of ST_DATA, ST_FIN: of ST_DATA, ST_FIN:
# To avoid amplification attacks, server socket is in SynRecv state until # To avoid amplification attacks, server socket is in SynRecv state until
@ -1014,6 +1121,7 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} =
# we got in order packet # we got in order packet
if (pastExpected == 0 and (not socket.reachedFin)): if (pastExpected == 0 and (not socket.reachedFin)):
notice "Got in order packet"
if (len(p.payload) > 0 and (not socket.readShutdown)): if (len(p.payload) > 0 and (not socket.readShutdown)):
# we are getting in order data packet, we can flush data directly to the incoming buffer # we are getting in order data packet, we can flush data directly to the incoming buffer
await upload(addr socket.buffer, unsafeAddr p.payload[0], p.payload.len()) await upload(addr socket.buffer, unsafeAddr p.payload[0], p.payload.len())
@ -1085,8 +1193,10 @@ proc processPacket*(socket: UtpSocket, p: Packet) {.async.} =
socket.inBuffer.put(pkSeqNr, p) socket.inBuffer.put(pkSeqNr, p)
inc socket.reorderCount inc socket.reorderCount
notice "added out of order packet in reorder buffer" notice "added out of order packet in reorder buffer"
# TODO for now we do not sent any ack as we do not handle selective acks # we send ack packet, as we reoreder count is > 0, so the eack bitmask will be
# add sending of selective acks # generated
asyncSpawn socket.sendAck()
of ST_STATE: of ST_STATE:
if (socket.state == SynSent and (not socket.connectionFuture.finished())): if (socket.state == SynSent and (not socket.connectionFuture.finished())):
socket.state = Connected socket.state = Connected
@ -1220,13 +1330,11 @@ proc read*(socket: UtpSocket): Future[seq[byte]] {.async.}=
# Check how many packets are still in the out going buffer, usefull for tests or # Check how many packets are still in the out going buffer, usefull for tests or
# debugging. # debugging.
# It throws assertion error when number of elements in buffer do not equal kept counter
proc numPacketsInOutGoingBuffer*(socket: UtpSocket): int = proc numPacketsInOutGoingBuffer*(socket: UtpSocket): int =
var num = 0 var num = 0
for e in socket.outBuffer.items(): for e in socket.outBuffer.items():
if e.isSome(): if e.isSome():
inc num inc num
doAssert(num == int(socket.curWindowPackets))
num num
# Check how many payload bytes are still in flight # Check how many payload bytes are still in flight

View File

@ -12,5 +12,6 @@ import
./test_discv5_protocol, ./test_discv5_protocol,
./test_buffer, ./test_buffer,
./test_utp_socket, ./test_utp_socket,
./test_utp_socket_sack,
./test_utp_router, ./test_utp_router,
./test_clock_drift_calculator ./test_clock_drift_calculator

View File

@ -1,9 +1,14 @@
import import
chronos, chronos,
../../eth/utp/utp_socket,
../../eth/utp/packets,
../../eth/keys ../../eth/keys
type AssertionCallback = proc(): bool {.gcsafe, raises: [Defect].} type AssertionCallback = proc(): bool {.gcsafe, raises: [Defect].}
let testBufferSize = 1024'u32
let defaultRcvOutgoingId = 314'u16
proc generateByteArray*(rng: var BrHmacDrbgContext, length: int): seq[byte] = proc generateByteArray*(rng: var BrHmacDrbgContext, length: int): seq[byte] =
var bytes = newSeq[byte](length) var bytes = newSeq[byte](length)
brHmacDrbgGenerate(rng, bytes) brHmacDrbgGenerate(rng, bytes)
@ -16,3 +21,96 @@ proc waitUntil*(f: AssertionCallback): Future[void] {.async.} =
break break
else: else:
await sleepAsync(milliseconds(50)) await sleepAsync(milliseconds(50))
template connectOutGoingSocket*(
initialRemoteSeq: uint16,
q: AsyncQueue[Packet],
remoteReceiveBuffer: uint32 = testBufferSize,
cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], Packet) =
let sock1 = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[])
asyncSpawn sock1.startOutgoingSocket()
let initialPacket = await q.get()
check:
initialPacket.header.pType == ST_SYN
let responseAck =
ackPacket(
initialRemoteSeq,
initialPacket.header.connectionId,
initialPacket.header.seqNr,
remoteReceiveBuffer,
0
)
await sock1.processPacket(responseAck)
check:
sock1.isConnected()
(sock1, initialPacket)
template connectOutGoingSocketWithIncoming*(
initialRemoteSeq: uint16,
outgoingQueue: AsyncQueue[Packet],
incomingQueue: AsyncQueue[Packet],
remoteReceiveBuffer: uint32 = testBufferSize,
cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], UtpSocket[TransportAddress]) =
let outgoingSocket = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(outgoingQueue), cfg, defaultRcvOutgoingId, rng[])
asyncSpawn outgoingSocket.startOutgoingSocket()
let initialPacket = await outgoingQueue.get()
check:
initialPacket.header.pType == ST_SYN
let incomingSocket = newIncomingSocket[TransportAddress](
testAddress,
initTestSnd(incomingQueue),
cfg,
initialPacket.header.connectionId,
initialPacket.header.seqNr,
rng[]
)
await incomingSocket.startIncomingSocket()
let responseAck = await incomingQueue.get()
await outgoingSocket.processPacket(responseAck)
check:
outgoingSocket.isConnected()
(outgoingSocket, incomingSocket)
proc generateDataPackets*(
numberOfPackets: uint16,
initialSeqNr: uint16,
connectionId: uint16,
ackNr: uint16,
rng: var BrHmacDrbgContext): seq[Packet] =
let packetSize = 100
var packets = newSeq[Packet]()
var i = 0'u16
while i < numberOfPackets:
let packet = dataPacket(
initialSeqNr + i,
connectionId,
ackNr,
testBufferSize,
generateByteArray(rng, packetSize),
0
)
packets.add(packet)
inc i
packets
proc initTestSnd*(q: AsyncQueue[Packet]): SendCallback[TransportAddress]=
return (
proc (to: TransportAddress, bytes: seq[byte]): Future[void] =
let p = decodePacket(bytes).get()
q.addLast(p)
)

View File

@ -7,7 +7,7 @@
{.used.} {.used.}
import import
std/[algorithm, random, sequtils], std/[algorithm, random, sequtils, options],
chronos, bearssl, chronicles, chronos, bearssl, chronicles,
testutils/unittests, testutils/unittests,
./test_utils, ./test_utils,
@ -22,71 +22,12 @@ procSuite "Utp socket unit test":
let testBufferSize = 1024'u32 let testBufferSize = 1024'u32
let defaultRcvOutgoingId = 314'u16 let defaultRcvOutgoingId = 314'u16
proc initTestSnd(q: AsyncQueue[Packet]): SendCallback[TransportAddress]=
return (
proc (to: TransportAddress, bytes: seq[byte]): Future[void] =
let p = decodePacket(bytes).get()
q.addLast(p)
)
proc generateDataPackets(
numberOfPackets: uint16,
initialSeqNr: uint16,
connectionId: uint16,
ackNr: uint16,
rng: var BrHmacDrbgContext): seq[Packet] =
let packetSize = 100
var packets = newSeq[Packet]()
var i = 0'u16
while i < numberOfPackets:
let packet = dataPacket(
initialSeqNr + i,
connectionId,
ackNr,
testBufferSize,
generateByteArray(rng, packetSize),
0
)
packets.add(packet)
inc i
packets
proc packetsToBytes(packets: seq[Packet]): seq[byte] = proc packetsToBytes(packets: seq[Packet]): seq[byte] =
var resultBytes = newSeq[byte]() var resultBytes = newSeq[byte]()
for p in packets: for p in packets:
resultBytes.add(p.payload) resultBytes.add(p.payload)
return resultBytes return resultBytes
template connectOutGoingSocket(
initialRemoteSeq: uint16,
q: AsyncQueue[Packet],
remoteReceiveBuffer: uint32 = testBufferSize,
cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], Packet) =
let sock1 = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[])
asyncSpawn sock1.startOutgoingSocket()
let initialPacket = await q.get()
check:
initialPacket.header.pType == ST_SYN
let responseAck =
ackPacket(
initialRemoteSeq,
initialPacket.header.connectionId,
initialPacket.header.seqNr,
remoteReceiveBuffer,
0
)
await sock1.processPacket(responseAck)
check:
sock1.isConnected()
(sock1, initialPacket)
asyncTest "Starting outgoing socket should send Syn packet": asyncTest "Starting outgoing socket should send Syn packet":
let q = newAsyncQueue[Packet]() let q = newAsyncQueue[Packet]()
let defaultConfig = SocketConfig.init() let defaultConfig = SocketConfig.init()
@ -187,10 +128,11 @@ procSuite "Utp socket unit test":
# TODO test is valid until implementing selective acks # TODO test is valid until implementing selective acks
let q = newAsyncQueue[Packet]() let q = newAsyncQueue[Packet]()
let initalRemoteSeqNr = 10'u16 let initalRemoteSeqNr = 10'u16
let numOfPackets = 10'u16
let (outgoingSocket, initialPacket) = connectOutGoingSocket(initalRemoteSeqNr, q) let (outgoingSocket, initialPacket) = connectOutGoingSocket(initalRemoteSeqNr, q)
var packets = generateDataPackets(10, initalRemoteSeqNr, initialPacket.header.connectionId, initialPacket.header.seqNr, rng[]) var packets = generateDataPackets(numOfPackets, initalRemoteSeqNr, initialPacket.header.connectionId, initialPacket.header.seqNr, rng[])
let data = packetsToBytes(packets) let data = packetsToBytes(packets)
@ -200,12 +142,28 @@ procSuite "Utp socket unit test":
for p in packets: for p in packets:
await outgoingSocket.processPacket(p) await outgoingSocket.processPacket(p)
let ack2 = await q.get() var sentAcks: seq[Packet] = @[]
for i in 0'u16..<numOfPackets:
let ack = await q.get()
sentAcks.add(ack)
# all packets except last one should be selective acks, without bumped ackNr
for i in 0'u16..<numOfPackets - 1:
check:
sentAcks[i].header.ackNr == initalRemoteSeqNr - 1
sentAcks[i].eack.isSome()
# last ack should be normal ack packet (not selective one), and it should ack
# all remaining packets
let lastAck = sentAcks[numOfPackets - 1]
check: check:
ack2.header.pType == ST_STATE lastAck.header.pType == ST_STATE
# we are acking in one shot whole 10 packets # we are acking in one shot whole 10 packets
ack2.header.ackNr == initalRemoteSeqNr + uint16(len(packets) - 1) lastAck.header.ackNr == initalRemoteSeqNr + uint16(len(packets) - 1)
lastAck.eack.isNone()
let receivedData = await outgoingSocket.read(len(data)) let receivedData = await outgoingSocket.read(len(data))
@ -218,10 +176,11 @@ procSuite "Utp socket unit test":
# TODO test is valid until implementing selective acks # TODO test is valid until implementing selective acks
let q = newAsyncQueue[Packet]() let q = newAsyncQueue[Packet]()
let initalRemoteSeqNr = 10'u16 let initalRemoteSeqNr = 10'u16
let numOfPackets = 3'u16
let (outgoingSocket, initialPacket) = connectOutGoingSocket(initalRemoteSeqNr, q) let (outgoingSocket, initialPacket) = connectOutGoingSocket(initalRemoteSeqNr, q)
var packets = generateDataPackets(3, initalRemoteSeqNr, initialPacket.header.connectionId, initialPacket.header.seqNr, rng[]) var packets = generateDataPackets(numOfPackets, initalRemoteSeqNr, initialPacket.header.connectionId, initialPacket.header.seqNr, rng[])
let data = packetsToBytes(packets) let data = packetsToBytes(packets)
@ -235,12 +194,28 @@ procSuite "Utp socket unit test":
for p in packets: for p in packets:
await outgoingSocket.processPacket(p) await outgoingSocket.processPacket(p)
let ack2 = await q.get() var sentAcks: seq[Packet] = @[]
for i in 0'u16..<numOfPackets:
let ack = await q.get()
sentAcks.add(ack)
# all packets except last one should be selective acks, without bumped ackNr
for i in 0'u16..<numOfPackets - 1:
check:
sentAcks[i].header.ackNr == initalRemoteSeqNr - 1
sentAcks[i].eack.isSome()
# last ack should be normal ack packet (not selective one), and it should ack
# all remaining packets
let lastAck = sentAcks[numOfPackets - 1]
check: check:
ack2.header.pType == ST_STATE lastAck.header.pType == ST_STATE
# we are acking in one shot whole 10 packets # we are acking in one shot whole 10 packets
ack2.header.ackNr == initalRemoteSeqNr + uint16(len(packets) - 1) lastAck.header.ackNr == initalRemoteSeqNr + uint16(len(packets) - 1)
lastAck.eack.isNone()
let receivedData = await outgoingSocket.read(len(data)) let receivedData = await outgoingSocket.read(len(data))
@ -250,7 +225,6 @@ procSuite "Utp socket unit test":
await outgoingSocket.destroyWait() await outgoingSocket.destroyWait()
asyncTest "Processing packets in random order": asyncTest "Processing packets in random order":
# TODO test is valid until implementing selective acks
let q = newAsyncQueue[Packet]() let q = newAsyncQueue[Packet]()
let initalRemoteSeqNr = 10'u16 let initalRemoteSeqNr = 10'u16

View File

@ -0,0 +1,244 @@
# Copyright (c) 2021 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
{.used.}
import
std/[options, sequtils],
chronos, bearssl, chronicles,
stew/bitops2,
testutils/unittests,
./test_utils,
../../eth/utp/utp_router,
../../eth/utp/utp_socket,
../../eth/utp/packets,
../../eth/keys
procSuite "Utp socket selective acks unit test":
let rng = newRng()
let testAddress = initTAddress("127.0.0.1", 9079)
let defaultBufferSize = 1024'u32
proc connectAndProcessMissingPacketWithIndexes(idxs: seq[int]): Future[array[4, uint8]] {.async.} =
let initialRemoteSeq = 1'u16
let q = newAsyncQueue[Packet]()
let data = @[0'u8]
let (outgoingSocket, initialPacket) = connectOutGoingSocket(initialRemoteSeq, q)
var dataPackets: seq[Packet] = @[]
for i in idxs:
let dataP =
dataPacket(
# initialRemoteSeq is next expected packet, so n represent how far from the
# future is this packet
initialRemoteSeq + uint16(i),
initialPacket.header.connectionId,
initialPacket.header.seqNr,
defaultBufferSize,
data,
0
)
dataPackets.add(dataP)
for p in dataPackets:
await outgoingSocket.processPacket(p)
let extArray = outgoingSocket.generateSelectiveAckBitMask()
await outgoingSocket.destroyWait()
return extArray
proc numOfSetBits(arr: openArray[byte]): int =
var numOfSetBits = 0
for b in arr:
numOfSetBits = numOfSetBits + countOnes(b)
return numOfSetBits
proc hasOnlyOneBitSet(arr: openArray[byte]): bool =
return numOfSetBits(arr) == 1
asyncTest "Socket with empty buffer should generate array with only zeros":
let q = newAsyncQueue[Packet]()
let initialRemoteSeq = 10'u16
let (outgoingSocket, packet) = connectOutGoingSocket(initialRemoteSeq, q)
let extArray = outgoingSocket.generateSelectiveAckBitMask()
check:
extArray == [0'u8, 0, 0, 0]
asyncTest "Socket should generate correct bit mask for each missing packet":
# 1 means that received packet is packet just after expected packet i.e
# packet.seqNr - receivingSocket.ackNr = 2
# 32 means that received packet is 32 packets after expected one i.e
# packet.seqNr - receivingSocket.ackNr = 32
# First byte represents packets [ack_nr + 2, ack_nr + 9] in reverse order
# Second byte represents packets [ack_nr + 10, ack_nr + 17] in reverse order
# Third byte represents packets [ack_nr + 18, ack_nr + 25] in reverse order
# Fourth byte represents packets [ack_nr + 26, ack_nr + 33] in reverse order
let afterExpected = 1..32
for i in afterExpected:
# bit mask should have max 4 bytes
let bitMask = await connectAndProcessMissingPacketWithIndexes(@[i])
check:
# only one bit should have been set as only one packet has been processed
# out of order
hasOnlyOneBitSet(bitMask)
getBit(bitMask, i - 1)
asyncTest "Socket should generate correct bit mask if there is more than one missing packet":
# Each testcase defines which out of order packets should be processed i.e
# @[1] - packet just after expected will be processed
# @[3, 5] - packet three packets after will be processed and then packet 5 packets
# after expected will be processed
let testCases = @[
@[1],
@[1, 2],
@[1, 9, 11, 18],
@[1, 3, 8, 15, 18, 22, 27, 32]
]
for missingIndexes in testCases:
let bitMask = await connectAndProcessMissingPacketWithIndexes(missingIndexes)
check:
numOfSetBits(bitMask) == len(missingIndexes)
for idx in missingIndexes:
check:
getBit(bitMask, idx - 1)
asyncTest "Socket should generate max 4 bytes bit mask even if there is more missing packets":
let testCases = @[
toSeq(1..40)
]
for missingIndexes in testCases:
let bitMask = await connectAndProcessMissingPacketWithIndexes(missingIndexes)
check:
numOfSetBits(bitMask) == 32
len(bitMask) == 4
type TestCase = object
# number of packet to generate by writitng side
numOfPackets: int
# indexes of packets which should be delivered to remote
packetsDelivered: seq[int]
let selectiveAckTestCases = @[
TestCase(numOfPackets: 2, packetsDelivered: @[1]),
TestCase(numOfPackets: 10, packetsDelivered: @[1, 3, 5, 7, 9]),
TestCase(numOfPackets: 10, packetsDelivered: @[1, 2, 3, 4, 5, 6, 7, 8, 9]),
TestCase(numOfPackets: 15, packetsDelivered: @[1, 3, 5, 7, 9, 10, 11, 12, 14]),
TestCase(numOfPackets: 20, packetsDelivered: @[1, 3, 5, 7, 9, 11, 13, 15, 17, 19]),
TestCase(numOfPackets: 33, packetsDelivered: @[32]),
TestCase(numOfPackets: 33, packetsDelivered: @[25, 26, 27, 28, 29, 30, 31, 32]),
TestCase(numOfPackets: 33, packetsDelivered: toSeq(1..32))
]
asyncTest "Socket should calculate number of bytes acked by selective acks":
let dataSize = 10
let initialRemoteSeq = 10'u16
let smallData = generateByteArray(rng[], 10)
for testCase in selectiveAckTestCases:
let outgoingQueue = newAsyncQueue[Packet]()
let incomingQueue = newAsyncQueue[Packet]()
let (outgoingSocket, incomingSocket) =
connectOutGoingSocketWithIncoming(
initialRemoteSeq,
outgoingQueue,
incomingQueue
)
var packets: seq[Packet] = @[]
for _ in 0..<testCase.numOfPackets:
discard await outgoingSocket.write(smallData)
let packet = await outgoingQueue.get()
packets.add(packet)
for toDeliver in testCase.packetsDelivered:
await incomingSocket.processPacket(packets[toDeliver])
let finalAck = incomingSocket.generateAckPacket()
check:
finalAck.eack.isSome()
let mask = finalAck.eack.unsafeGet().acks
check:
numOfSetBits(mask) == len(testCase.packetsDelivered)
for idx in testCase.packetsDelivered:
check:
getBit(mask, idx - 1)
let ackedBytes = outgoingSocket.calculateSelectiveAckBytes(finalAck.header.ackNr, finalAck.eack.unsafeGet())
check:
int(ackedBytes) == len(testCase.packetsDelivered) * dataSize
await outgoingSocket.destroyWait()
await incomingSocket.destroyWait()
asyncTest "Socket should ack packets based on selective ack packet":
let dataSize = 10
let initialRemoteSeq = 10'u16
let smallData = generateByteArray(rng[], 10)
for testCase in selectiveAckTestCases:
let outgoingQueue = newAsyncQueue[Packet]()
let incomingQueue = newAsyncQueue[Packet]()
let (outgoingSocket, incomingSocket) =
connectOutGoingSocketWithIncoming(
initialRemoteSeq,
outgoingQueue,
incomingQueue
)
var packets: seq[Packet] = @[]
for _ in 0..<testCase.numOfPackets:
discard await outgoingSocket.write(smallData)
let packet = await outgoingQueue.get()
packets.add(packet)
for toDeliver in testCase.packetsDelivered:
await incomingSocket.processPacket(packets[toDeliver])
let finalAck = incomingSocket.generateAckPacket()
check:
finalAck.eack.isSome()
let mask = finalAck.eack.unsafeGet().acks
check:
numOfSetBits(mask) == len(testCase.packetsDelivered)
for idx in testCase.packetsDelivered:
check:
getBit(mask, idx - 1)
check:
outgoingSocket.numPacketsInOutGoingBuffer() == testCase.numOfPackets
await outgoingSocket.processPacket(finalAck)
check:
outgoingSocket.numPacketsInOutGoingBuffer() == testCase.numOfPackets - len(testCase.packetsDelivered)
await outgoingSocket.destroyWait()
await incomingSocket.destroyWait()