Add ackNr validation (#424)

This commit is contained in:
KonradStaniec 2021-11-15 11:32:00 +01:00 committed by GitHub
parent 8139aae346
commit 73d9bf4c80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 0 deletions

View File

@ -192,6 +192,12 @@ const
# often is proportional to RTT anyway # often is proportional to RTT anyway
defaultOptRcvBuffer: uint32 = 1024 * 1024 defaultOptRcvBuffer: uint32 = 1024 * 1024
# rationale from C reference impl:
# Allow a reception window of at least 3 ack_nrs behind seq_nr
# A non-SYN packet with an ack_nr difference greater than this is
# considered suspicious and ignored
allowedAckWindow*: uint16 = 3
reorderBufferMaxSize = 1024 reorderBufferMaxSize = 1024
proc init*[A](T: type UtpSocketKey, remoteAddress: A, rcvId: uint16): T = proc init*[A](T: type UtpSocketKey, remoteAddress: A, rcvId: uint16): T =
@ -592,12 +598,40 @@ proc initializeAckNr(socket: UtpSocket, packetSeqNr: uint16) =
if (socket.state == SynSent): if (socket.state == SynSent):
socket.ackNr = packetSeqNr - 1 socket.ackNr = packetSeqNr - 1
# compare if lhs is less than rhs, taking wrapping
# into account. i.e high(lhs) < 0 == true
proc wrapCompareLess(lhs: uint16, rhs:uint16): bool =
let distDown = (lhs - rhs)
let distUp = (rhs - lhs)
# if the distance walking up is shorter, lhs
# is less than rhs. If the distance walking down
# is shorter, then rhs is less than lhs
return distUp < distDown
proc isAckNrInvalid(socket: UtpSocket, packet: Packet): bool =
let ackWindow = max(socket.curWindowPackets + allowedAckWindow, allowedAckWindow)
(
(packet.header.pType != ST_SYN or socket.state != SynRecv) and
(
# packet ack number must be smaller than our last send packet i.e
# remote should not ack packets from the future
wrapCompareLess(socket.seqNr - 1, packet.header.ackNr) or
# packet ack number should not be too old
wrapCompareLess(packet.header.ackNr, socket.seqNr - 1 - ackWindow)
)
)
# TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make # TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make
# it enforcable by type system # it enforcable by type system
# TODO re-think synchronization of this procedure, as each await inside gives control # TODO re-think synchronization of this procedure, as each await inside gives control
# to scheduler which means there could be potentialy several processPacket procs # to scheduler which means there could be potentialy several processPacket procs
# running # running
proc processPacket*(socket: UtpSocket, p: Packet) {.async.} = proc processPacket*(socket: UtpSocket, p: Packet) {.async.} =
if socket.isAckNrInvalid(p):
notice "Received packet with invalid ack nr"
return
## Updates socket state based on received packet, and sends ack when necessary. ## Updates socket state based on received packet, and sends ack when necessary.
## Shoyuld be called in main packet receiving loop ## Shoyuld be called in main packet receiving loop
let pkSeqNr = p.header.seqNr let pkSeqNr = p.header.seqNr

View File

@ -568,3 +568,29 @@ procSuite "Utp socket unit test":
# we have read all data from rcv buffer, advertised window should go back to # we have read all data from rcv buffer, advertised window should go back to
# initial size # initial size
sentData.header.wndSize == initialRcvBufferSize sentData.header.wndSize == initialRcvBufferSize
asyncTest "Socket should ignore packets with bad ack number":
let q = newAsyncQueue[Packet]()
let initialRemoteSeq = 10'u16
let data1 = @[1'u8, 2'u8, 3'u8]
let data2 = @[4'u8, 5'u8, 6'u8]
let data3 = @[7'u8, 7'u8, 9'u8]
let (outgoingSocket, initialPacket) = connectOutGoingSocket(initialRemoteSeq, q)
# data packet with ack nr set above our seq nr i.e packet from the future
let dataFuture = dataPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr + 1, testBufferSize, data1)
# data packet wth ack number set below out ack window i.e packet too old
let dataTooOld = dataPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr - allowedAckWindow - 1, testBufferSize, data2)
let dataOk = dataPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr, testBufferSize, data3)
await outgoingSocket.processPacket(dataFuture)
await outgoingSocket.processPacket(dataTooOld)
await outgoingSocket.processPacket(dataOk)
let receivedBytes = await outgoingSocket.read(data3.len)
check:
# data1 and data2 were sent in bad packets we should only receive data3
receivedBytes == data3