diff --git a/eth/utp/packets.nim b/eth/utp/packets.nim index 16f3f42..c8bcd07 100644 --- a/eth/utp/packets.nim +++ b/eth/utp/packets.nim @@ -17,8 +17,10 @@ export results const minimalHeaderSize = 20 + minimalHeaderSizeWithSelectiveAck = 26 protocolVersion = 1 zeroMoment = Moment.init(0, Nanosecond) + acksArrayLength: uint8 = 4 type PacketType* = enum @@ -46,8 +48,12 @@ type # sequence number the sender of the packet last received in the other direction ackNr*: uint16 + SelectiveAckExtension = object + acks: array[acksArrayLength, byte] + Packet* = object header*: PacketHeaderV1 + eack*: Option[SelectiveAckExtension] payload*: seq[uint8] TimeStampInfo* = object @@ -102,10 +108,22 @@ proc encodeHeaderStream(s: var OutputStream, h: PacketHeaderV1) = # This should not happen in case of in-memory streams raiseAssert e.msg +proc encodeExtensionStream(s: var OutputStream, e: SelectiveAckExtension) = + try: + # writing 0 as there is not further extensions after selectiv ack + s.write(0'u8) + s.write(acksArrayLength) + s.write(e.acks) + except IOError as e: + # This should not happen in case of in-memory streams + raiseAssert e.msg + proc encodePacket*(p: Packet): seq[byte] = var s = memoryOutput().s try: encodeHeaderStream(s, p.header) + if (p.eack.isSome()): + encodeExtensionStream(s, p.eack.unsafeGet()) if (len(p.payload) > 0): s.write(p.payload) s.getOutput() @@ -113,9 +131,9 @@ proc encodePacket*(p: Packet): seq[byte] = # This should not happen in case of in-memory streams raiseAssert e.msg -# TODO for now we do not handle extensions proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] = - if len(bytes) < minimalHeaderSize: + let receivedBytesLength = len(bytes) + if receivedBytesLength < minimalHeaderSize: return err("invalid header size") let version = bytes[0] and 0xf @@ -126,11 +144,16 @@ proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] = if not checkedEnumAssign(kind, (bytes[0] shr 4)): return err("Invalid message type") + let extensionByte = bytes[1] + + if (not (extensionByte == 0 or extensionByte == 1)): + return err("Invalid extension type") + let header = PacketHeaderV1( pType: kind, version: version, - extension: bytes[1], + extension: extensionByte, connection_id: fromBytesBE(uint16, bytes.toOpenArray(2, 3)), timestamp: fromBytesBE(uint32, bytes.toOpenArray(4, 7)), timestamp_diff: fromBytesBE(uint32, bytes.toOpenArray(8, 11)), @@ -139,13 +162,43 @@ proc decodePacket*(bytes: openArray[byte]): Result[Packet, string] = ack_nr: fromBytesBE(uint16, bytes.toOpenArray(18, 19)), ) - let payload = - if (len(bytes) == 20): - @[] - else: - bytes[20..^1] + if extensionByte == 0: + # packet without any extensions + let payload = + if (receivedBytesLength == minimalHeaderSize): + @[] + else: + bytes[20..^1] - ok(Packet(header: header, payload: payload)) + return ok(Packet(header: header, eack: none[SelectiveAckExtension](), payload: payload)) + else: + # packet with selective ack extension + if (receivedBytesLength < minimalHeaderSizeWithSelectiveAck): + return err("Packet too short for selective ack extension") + + let nextExtension = bytes[20] + let extLength = bytes[21] + + # As selective ack is only supported extension the byte for nextExtension + # must be equal to 0. + # As for extLength, specificaiton says that it must be at least 4, and in multiples of 4 + # but reference implementation always uses 4 bytes bit mask which makes sense + # as 4byte bit mask is able to ack 32 packets in the future which is more than enough + if (nextExtension != 0 or extLength != 4): + return err("Bad format of selective ack extension") + + + let extension = SelectiveAckExtension( + acks: [bytes[22], bytes[23], bytes[24], bytes[25]] + ) + + let payload = + if (receivedBytesLength == minimalHeaderSizeWithSelectiveAck): + @[] + else: + bytes[26..^1] + + return ok(Packet(header: header, eack: some(extension), payload: payload)) proc modifyTimeStampAndAckNr*(packetBytes: var seq[byte], newTimestamp: uint32, newAckNr: uint16) = ## Modifies timestamp and ack nr of already encoded packets. Those fields should be @@ -163,7 +216,6 @@ proc synPacket*(seqNr: uint16, rcvConnectionId: uint16, bufferSize: uint32): Pac let h = PacketHeaderV1( pType: ST_SYN, version: protocolVersion, - # TODO for we do not handle extensions extension: 0'u8, connectionId: rcvConnectionId, timestamp: getMonoTimestamp().timestamp, @@ -174,14 +226,27 @@ proc synPacket*(seqNr: uint16, rcvConnectionId: uint16, bufferSize: uint32): Pac ackNr: 0'u16 ) - Packet(header: h, payload: @[]) + Packet(header: h, eack: none[SelectiveAckExtension](), payload: @[]) + +proc ackPacket*( + seqNr: uint16, + sndConnectionId: uint16, + ackNr: uint16, + bufferSize: uint32, + timestampDiff: uint32, + acksBitmask: Option[array[4, byte]] = none[array[4, byte]]() + ): Packet = + + let (extensionByte, extensionData) = + if acksBitmask.isSome(): + (1'u8, some(SelectiveAckExtension(acks: acksBitmask.unsafeGet()))) + else: + (0'u8, none[SelectiveAckExtension]()) -proc ackPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16, bufferSize: uint32, timestampDiff: uint32): Packet = let h = PacketHeaderV1( pType: ST_STATE, version: protocolVersion, - # TODO Handle selective acks - extension: 0'u8, + extension: extensionByte, connectionId: sndConnectionId, timestamp: getMonoTimestamp().timestamp, timestampDiff: timestampDiff, @@ -190,7 +255,8 @@ proc ackPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16, bufferSiz ackNr: ackNr ) - Packet(header: h, payload: @[]) + + Packet(header: h, eack: extensionData, payload: @[]) proc dataPacket*( seqNr: uint16, @@ -213,7 +279,7 @@ proc dataPacket*( ackNr: ackNr ) - Packet(header: h, payload: payload) + Packet(header: h, eack: none[SelectiveAckExtension](), payload: payload) proc resetPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16): Packet = let h = PacketHeaderV1( @@ -231,7 +297,7 @@ proc resetPacket*(seqNr: uint16, sndConnectionId: uint16, ackNr: uint16): Packet ackNr: ackNr ) - Packet(header: h, payload: @[]) + Packet(header: h, eack: none[SelectiveAckExtension](), payload: @[]) proc finPacket*( seqNr: uint16, @@ -253,4 +319,4 @@ proc finPacket*( ackNr: ackNr ) - Packet(header: h, payload: @[]) + Packet(header: h, eack: none[SelectiveAckExtension](), payload: @[]) diff --git a/tests/utp/test_packets.nim b/tests/utp/test_packets.nim index 6527b35..6f4debc 100644 --- a/tests/utp/test_packets.nim +++ b/tests/utp/test_packets.nim @@ -7,6 +7,7 @@ {.used.} import + std/options, unittest2, ../../eth/utp/packets, ../../eth/keys @@ -21,8 +22,109 @@ suite "Utp packets encoding/decoding": let decoded = decodePacket(encoded) check: + len(encoded) == 20 decoded.isOk() - synPacket == decoded.get() + + let synPacketDec = decoded.get() + + check: + synPacketDec == synPacket + + test "Encode/decode fin packet": + let finPacket = finPacket(5, 10, 20, 30, 40) + let encoded = encodePacket(finPacket) + let decoded = decodePacket(encoded) + + check: + len(encoded) == 20 + decoded.isOk() + + let finPacketDec = decoded.get() + + check: + finPacketDec == finPacket + + test "Encode/decode reset packet": + let resetPacket = resetPacket(5, 10, 20) + let encoded = encodePacket(resetPacket) + let decoded = decodePacket(encoded) + + check: + len(encoded) == 20 + decoded.isOk() + + let resetPacketDec = decoded.get() + + check: + resetPacketDec == resetPacket + + test "Encode/decode ack packet without extensions": + let ackPacket = ackPacket(5, 10, 20, 30, 40) + let encoded = encodePacket(ackPacket) + let decoded = decodePacket(encoded) + + check: + len(encoded) == 20 + decoded.isOk() + + let ackPacketDec = decoded.get() + + check: + ackPacketDec == ackPacket + + test "Encode/decode ack packet with extensions": + let bitMask: array[4, byte] = [1'u8, 2, 3, 4] + let ackPacket = ackPacket(5, 10, 20, 30, 40, some(bitMask)) + let encoded = encodePacket(ackPacket) + let decoded = decodePacket(encoded) + + check: + len(encoded) == 26 + decoded.isOk() + + let ackPacketDec = decoded.get() + + check: + ackPacketDec == ackPacket + ackPacketDec.eack.isSome() + + test "Fail to decode packet with malformed extensions": + let bitMask: array[4, byte] = [1'u8, 2, 3, 4] + let ackPacket = ackPacket(5, 10, 20, 30, 40, some(bitMask)) + + var encoded1 = encodePacket(ackPacket) + # change nextExtension to non zero + encoded1[20] = 1 + let err1 = decodePacket(encoded1) + check: + err1.isErr() + err1.error() == "Bad format of selective ack extension" + + var encoded2 = encodePacket(ackPacket) + # change len of extension to value different than 4 + encoded2[21] = 7 + let err2 = decodePacket(encoded2) + check: + err2.isErr() + err2.error() == "Bad format of selective ack extension" + + var encoded3 = encodePacket(ackPacket) + # delete last byte, now packet is to short + encoded3.del(encoded3.high) + let err3 = decodePacket(encoded3) + + check: + err3.isErr() + err3.error() == "Packet too short for selective ack extension" + + + var encoded4 = encodePacket(ackPacket) + # change change extension field to something other than 0 or 1 + encoded4[1] = 2 + let err4 = decodePacket(encoded4) + check: + err4.isErr() + err4.error() == "Invalid extension type" test "Decode state packet": # Packet obtained by interaction with c reference implementation