Handle packets with selective acks (#451)

* Handle packets with selective acks
This commit is contained in:
KonradStaniec 2021-12-15 13:35:17 +01:00 committed by GitHub
parent 5655bd035c
commit 0cfe7df817
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 187 additions and 19 deletions

View File

@ -17,8 +17,10 @@ export results
const const
minimalHeaderSize = 20 minimalHeaderSize = 20
minimalHeaderSizeWithSelectiveAck = 26
protocolVersion = 1 protocolVersion = 1
zeroMoment = Moment.init(0, Nanosecond) zeroMoment = Moment.init(0, Nanosecond)
acksArrayLength: uint8 = 4
type type
PacketType* = enum PacketType* = enum
@ -46,8 +48,12 @@ type
# sequence number the sender of the packet last received in the other direction # sequence number the sender of the packet last received in the other direction
ackNr*: uint16 ackNr*: uint16
SelectiveAckExtension = object
acks: array[acksArrayLength, byte]
Packet* = object Packet* = object
header*: PacketHeaderV1 header*: PacketHeaderV1
eack*: Option[SelectiveAckExtension]
payload*: seq[uint8] payload*: seq[uint8]
TimeStampInfo* = object TimeStampInfo* = object
@ -102,10 +108,22 @@ proc encodeHeaderStream(s: var OutputStream, h: PacketHeaderV1) =
# This should not happen in case of in-memory streams # This should not happen in case of in-memory streams
raiseAssert e.msg raiseAssert e.msg
proc encodeExtensionStream(s: var OutputStream, e: SelectiveAckExtension) =
try:
# writing 0 as there is not further extensions after selectiv ack
s.write(0'u8)
s.write(acksArrayLength)
s.write(e.acks)
except IOError as e:
# This should not happen in case of in-memory streams
raiseAssert e.msg
proc encodePacket*(p: Packet): seq[byte] = proc encodePacket*(p: Packet): seq[byte] =
var s = memoryOutput().s var s = memoryOutput().s
try: try:
encodeHeaderStream(s, p.header) encodeHeaderStream(s, p.header)
if (p.eack.isSome()):
encodeExtensionStream(s, p.eack.unsafeGet())
if (len(p.payload) > 0): if (len(p.payload) > 0):
s.write(p.payload) s.write(p.payload)
s.getOutput() s.getOutput()
@ -113,9 +131,9 @@ proc encodePacket*(p: Packet): seq[byte] =
# This should not happen in case of in-memory streams # This should not happen in case of in-memory streams
raiseAssert e.msg raiseAssert e.msg
# TODO for now we do not handle extensions
proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] = proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] =
if len(bytes) < minimalHeaderSize: let receivedBytesLength = len(bytes)
if receivedBytesLength < minimalHeaderSize:
return err("invalid header size") return err("invalid header size")
let version = bytes[0] and 0xf let version = bytes[0] and 0xf
@ -126,11 +144,16 @@ proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] =
if not checkedEnumAssign(kind, (bytes[0] shr 4)): if not checkedEnumAssign(kind, (bytes[0] shr 4)):
return err("Invalid message type") return err("Invalid message type")
let extensionByte = bytes[1]
if (not (extensionByte == 0 or extensionByte == 1)):
return err("Invalid extension type")
let header = let header =
PacketHeaderV1( PacketHeaderV1(
pType: kind, pType: kind,
version: version, version: version,
extension: bytes[1], extension: extensionByte,
connection_id: fromBytesBE(uint16, bytes.toOpenArray(2, 3)), connection_id: fromBytesBE(uint16, bytes.toOpenArray(2, 3)),
timestamp: fromBytesBE(uint32, bytes.toOpenArray(4, 7)), timestamp: fromBytesBE(uint32, bytes.toOpenArray(4, 7)),
timestamp_diff: fromBytesBE(uint32, bytes.toOpenArray(8, 11)), timestamp_diff: fromBytesBE(uint32, bytes.toOpenArray(8, 11)),
@ -139,13 +162,43 @@ proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] =
ack_nr: fromBytesBE(uint16, bytes.toOpenArray(18, 19)), ack_nr: fromBytesBE(uint16, bytes.toOpenArray(18, 19)),
) )
if extensionByte == 0:
# packet without any extensions
let payload = let payload =
if (len(bytes) == 20): if (receivedBytesLength == minimalHeaderSize):
@[] @[]
else: else:
bytes[20..^1] bytes[20..^1]
ok(Packet(header: header, payload: payload)) return ok(Packet(header: header, eack: none[SelectiveAckExtension](), payload: payload))
else:
# packet with selective ack extension
if (receivedBytesLength < minimalHeaderSizeWithSelectiveAck):
return err("Packet too short for selective ack extension")
let nextExtension = bytes[20]
let extLength = bytes[21]
# As selective ack is only supported extension the byte for nextExtension
# must be equal to 0.
# As for extLength, specificaiton says that it must be at least 4, and in multiples of 4
# but reference implementation always uses 4 bytes bit mask which makes sense
# as 4byte bit mask is able to ack 32 packets in the future which is more than enough
if (nextExtension != 0 or extLength != 4):
return err("Bad format of selective ack extension")
let extension = SelectiveAckExtension(
acks: [bytes[22], bytes[23], bytes[24], bytes[25]]
)
let payload =
if (receivedBytesLength == minimalHeaderSizeWithSelectiveAck):
@[]
else:
bytes[26..^1]
return ok(Packet(header: header, eack: some(extension), payload: payload))
proc modifyTimeStampAndAckNr*(packetBytes: var seq[byte], newTimestamp: uint32, newAckNr: uint16) = proc modifyTimeStampAndAckNr*(packetBytes: var seq[byte], newTimestamp: uint32, newAckNr: uint16) =
## Modifies timestamp and ack nr of already encoded packets. Those fields should be ## Modifies timestamp and ack nr of already encoded packets. Those fields should be
@ -163,7 +216,6 @@ proc synPacket*(seqNr: uint16, rcvConnectionId: uint16, bufferSize: uint32): Pac
let h = PacketHeaderV1( let h = PacketHeaderV1(
pType: ST_SYN, pType: ST_SYN,
version: protocolVersion, version: protocolVersion,
# TODO for we do not handle extensions
extension: 0'u8, extension: 0'u8,
connectionId: rcvConnectionId, connectionId: rcvConnectionId,
timestamp: getMonoTimestamp().timestamp, timestamp: getMonoTimestamp().timestamp,
@ -174,14 +226,27 @@ proc synPacket*(seqNr: uint16, rcvConnectionId: uint16, bufferSize: uint32): Pac
ackNr: 0'u16 ackNr: 0'u16
) )
Packet(header: h, payload: @[]) Packet(header: h, eack: none[SelectiveAckExtension](), payload: @[])
proc ackPacket*(
seqNr: uint16,
sndConnectionId: uint16,
ackNr: uint16,
bufferSize: uint32,
timestampDiff: uint32,
acksBitmask: Option[array[4, byte]] = none[array[4, byte]]()
): Packet =
let (extensionByte, extensionData) =
if acksBitmask.isSome():
(1'u8, some(SelectiveAckExtension(acks: acksBitmask.unsafeGet())))
else:
(0'u8, none[SelectiveAckExtension]())
proc ackPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16, bufferSize: uint32, timestampDiff: uint32): Packet =
let h = PacketHeaderV1( let h = PacketHeaderV1(
pType: ST_STATE, pType: ST_STATE,
version: protocolVersion, version: protocolVersion,
# TODO Handle selective acks extension: extensionByte,
extension: 0'u8,
connectionId: sndConnectionId, connectionId: sndConnectionId,
timestamp: getMonoTimestamp().timestamp, timestamp: getMonoTimestamp().timestamp,
timestampDiff: timestampDiff, timestampDiff: timestampDiff,
@ -190,7 +255,8 @@ proc ackPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16, bufferSiz
ackNr: ackNr ackNr: ackNr
) )
Packet(header: h, payload: @[])
Packet(header: h, eack: extensionData, payload: @[])
proc dataPacket*( proc dataPacket*(
seqNr: uint16, seqNr: uint16,
@ -213,7 +279,7 @@ proc dataPacket*(
ackNr: ackNr ackNr: ackNr
) )
Packet(header: h, payload: payload) Packet(header: h, eack: none[SelectiveAckExtension](), payload: payload)
proc resetPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16): Packet = proc resetPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16): Packet =
let h = PacketHeaderV1( let h = PacketHeaderV1(
@ -231,7 +297,7 @@ proc resetPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16): Packet
ackNr: ackNr ackNr: ackNr
) )
Packet(header: h, payload: @[]) Packet(header: h, eack: none[SelectiveAckExtension](), payload: @[])
proc finPacket*( proc finPacket*(
seqNr: uint16, seqNr: uint16,
@ -253,4 +319,4 @@ proc finPacket*(
ackNr: ackNr ackNr: ackNr
) )
Packet(header: h, payload: @[]) Packet(header: h, eack: none[SelectiveAckExtension](), payload: @[])

View File

@ -7,6 +7,7 @@
{.used.} {.used.}
import import
std/options,
unittest2, unittest2,
../../eth/utp/packets, ../../eth/utp/packets,
../../eth/keys ../../eth/keys
@ -21,8 +22,109 @@ suite "Utp packets encoding/decoding":
let decoded = decodePacket(encoded) let decoded = decodePacket(encoded)
check: check:
len(encoded) == 20
decoded.isOk() decoded.isOk()
synPacket == decoded.get()
let synPacketDec = decoded.get()
check:
synPacketDec == synPacket
test "Encode/decode fin packet":
let finPacket = finPacket(5, 10, 20, 30, 40)
let encoded = encodePacket(finPacket)
let decoded = decodePacket(encoded)
check:
len(encoded) == 20
decoded.isOk()
let finPacketDec = decoded.get()
check:
finPacketDec == finPacket
test "Encode/decode reset packet":
let resetPacket = resetPacket(5, 10, 20)
let encoded = encodePacket(resetPacket)
let decoded = decodePacket(encoded)
check:
len(encoded) == 20
decoded.isOk()
let resetPacketDec = decoded.get()
check:
resetPacketDec == resetPacket
test "Encode/decode ack packet without extensions":
let ackPacket = ackPacket(5, 10, 20, 30, 40)
let encoded = encodePacket(ackPacket)
let decoded = decodePacket(encoded)
check:
len(encoded) == 20
decoded.isOk()
let ackPacketDec = decoded.get()
check:
ackPacketDec == ackPacket
test "Encode/decode ack packet with extensions":
let bitMask: array[4, byte] = [1'u8, 2, 3, 4]
let ackPacket = ackPacket(5, 10, 20, 30, 40, some(bitMask))
let encoded = encodePacket(ackPacket)
let decoded = decodePacket(encoded)
check:
len(encoded) == 26
decoded.isOk()
let ackPacketDec = decoded.get()
check:
ackPacketDec == ackPacket
ackPacketDec.eack.isSome()
test "Fail to decode packet with malformed extensions":
let bitMask: array[4, byte] = [1'u8, 2, 3, 4]
let ackPacket = ackPacket(5, 10, 20, 30, 40, some(bitMask))
var encoded1 = encodePacket(ackPacket)
# change nextExtension to non zero
encoded1[20] = 1
let err1 = decodePacket(encoded1)
check:
err1.isErr()
err1.error() == "Bad format of selective ack extension"
var encoded2 = encodePacket(ackPacket)
# change len of extension to value different than 4
encoded2[21] = 7
let err2 = decodePacket(encoded2)
check:
err2.isErr()
err2.error() == "Bad format of selective ack extension"
var encoded3 = encodePacket(ackPacket)
# delete last byte, now packet is to short
encoded3.del(encoded3.high)
let err3 = decodePacket(encoded3)
check:
err3.isErr()
err3.error() == "Packet too short for selective ack extension"
var encoded4 = encodePacket(ackPacket)
# change change extension field to something other than 0 or 1
encoded4[1] = 2
let err4 = decodePacket(encoded4)
check:
err4.isErr()
err4.error() == "Invalid extension type"
test "Decode state packet": test "Decode state packet":
# Packet obtained by interaction with c reference implementation # Packet obtained by interaction with c reference implementation