diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index a2d94c3..94b243e 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -192,6 +192,12 @@ const # often is proportional to RTT anyway defaultOptRcvBuffer: uint32 = 1024 * 1024 + # rationale from C reference impl: + # Allow a reception window of at least 3 ack_nrs behind seq_nr + # A non-SYN packet with an ack_nr difference greater than this is + # considered suspicious and ignored + allowedAckWindow*: uint16 = 3 + reorderBufferMaxSize = 1024 proc init*[A](T: type UtpSocketKey, remoteAddress: A, rcvId: uint16): T = @@ -592,12 +598,40 @@ proc initializeAckNr(socket: UtpSocket, packetSeqNr: uint16) = if (socket.state == SynSent): socket.ackNr = packetSeqNr - 1 +# compare if lhs is less than rhs, taking wrapping +# into account. i.e high(lhs) < 0 == true +proc wrapCompareLess(lhs: uint16, rhs:uint16): bool = + let distDown = (lhs - rhs) + let distUp = (rhs - lhs) + # if the distance walking up is shorter, lhs + # is less than rhs. If the distance walking down + # is shorter, then rhs is less than lhs + return distUp < distDown + +proc isAckNrInvalid(socket: UtpSocket, packet: Packet): bool = + let ackWindow = max(socket.curWindowPackets + allowedAckWindow, allowedAckWindow) + ( + (packet.header.pType != ST_SYN or socket.state != SynRecv) and + ( + # packet ack number must be smaller than our last send packet i.e + # remote should not ack packets from the future + wrapCompareLess(socket.seqNr - 1, packet.header.ackNr) or + # packet ack number should not be too old + wrapCompareLess(packet.header.ackNr, socket.seqNr - 1 - ackWindow) + ) + ) + # TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make # it enforcable by type system # TODO re-think synchronization of this procedure, as each await inside gives control # to scheduler which means there could be potentialy several processPacket procs # running proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = + + if socket.isAckNrInvalid(p): + notice "Received packet with invalid ack nr" + return + ## Updates socket state based on received packet, and sends ack when necessary. ## Shoyuld be called in main packet receiving loop let pkSeqNr = p.header.seqNr diff --git a/tests/utp/test_utp_socket.nim b/tests/utp/test_utp_socket.nim index 852a947..49987e1 100644 --- a/tests/utp/test_utp_socket.nim +++ b/tests/utp/test_utp_socket.nim @@ -568,3 +568,29 @@ procSuite "Utp socket unit test": # we have read all data from rcv buffer, advertised window should go back to # initial size sentData.header.wndSize == initialRcvBufferSize + + asyncTest "Socket should ignore packets with bad ack number": + let q = newAsyncQueue[Packet]() + let initialRemoteSeq = 10'u16 + let data1 = @[1'u8, 2'u8, 3'u8] + let data2 = @[4'u8, 5'u8, 6'u8] + let data3 = @[7'u8, 7'u8, 9'u8] + + let (outgoingSocket, initialPacket) = connectOutGoingSocket(initialRemoteSeq, q) + + # data packet with ack nr set above our seq nr i.e packet from the future + let dataFuture = dataPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr + 1, testBufferSize, data1) + # data packet wth ack number set below out ack window i.e packet too old + let dataTooOld = dataPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr - allowedAckWindow - 1, testBufferSize, data2) + + let dataOk = dataPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr, testBufferSize, data3) + + await outgoingSocket.processPacket(dataFuture) + await outgoingSocket.processPacket(dataTooOld) + await outgoingSocket.processPacket(dataOk) + + let receivedBytes = await outgoingSocket.read(data3.len) + + check: + # data1 and data2 were sent in bad packets we should only receive data3 + receivedBytes == data3