From 0e20fd65658a357961d4d82b8c8b519205861c29 Mon Sep 17 00:00:00 2001 From: KonradStaniec Date: Fri, 18 Mar 2022 08:13:17 +0100 Subject: [PATCH] Utp improvements (#489) * Move connection finalization to separate function * Do not process data unless in correct state --- eth/utp/utp_socket.nim | 328 +++++++++++++++++----------------- tests/utp/test_utp_socket.nim | 58 ++++++ 2 files changed, 225 insertions(+), 161 deletions(-) diff --git a/eth/utp/utp_socket.nim b/eth/utp/utp_socket.nim index 10acbd5..1ccfbe9 100644 --- a/eth/utp/utp_socket.nim +++ b/eth/utp/utp_socket.nim @@ -1039,6 +1039,25 @@ proc sendAck(socket: UtpSocket) = socket.sendData(encodePacket(ackPacket)) + +proc tryfinalizeConnection(socket: UtpSocket, p: Packet) = + # To avoid amplification attacks, server socket is in SynRecv state until + # it receices first data transfer + # https://www.usenix.org/system/files/conference/woot15/woot15-paper-adamsky.pdf + # Socket is in SynRecv state only when recv timeout is configured + if (socket.state == SynRecv and p.header.pType == ST_DATA): + socket.state = Connected + + if (socket.state == SynSent and p.header.pType == ST_STATE): + socket.state = Connected + socket.ackNr = p.header.seqNr - 1 + + debug "Received Syn-Ack finalizing connection", + socketAckNr = socker.ackNr + + if (not socket.connectionFuture.finished()): + socket.connectionFuture.complete() + # TODO at socket level we should handle only FIN/DATA/ACK packets. Refactor to make # it enforcable by type system proc processPacketInternal(socket: UtpSocket, p: Packet) = @@ -1227,6 +1246,8 @@ proc processPacketInternal(socket: UtpSocket, p: Packet) = resetZeroWindowTime = socket.zeroWindowTimer, currentPacketSize = currentPacketSize + socket.tryfinalizeConnection(p) + # socket.curWindowPackets == acks means that this packet acked all remaining packets # including the sent fin packets if (socket.finSent and socket.curWindowPackets == acks): @@ -1280,177 +1301,162 @@ proc processPacketInternal(socket: UtpSocket, p: Packet) = if (p.eack.isSome()): socket.selectiveAckPackets(pkAckNr, p.eack.unsafeGet(), timestampInfo.moment) - case p.header.pType - of ST_DATA, ST_FIN: - # To avoid amplification attacks, server socket is in SynRecv state until - # it receices first data transfer - # https://www.usenix.org/system/files/conference/woot15/woot15-paper-adamsky.pdf - # Socket is in SynRecv state only when recv timeout is configured - if (socket.state == SynRecv and p.header.pType == ST_DATA): - socket.state = Connected + if p.header.pType == ST_DATA or p.header.pType == ST_FIN: + if socket.state != Connected: + debug "Unexpected packet", + socketState = socket.state, + packetType = p.header.pType - if (p.header.pType == ST_FIN and (not socket.gotFin)): - debug "Received FIN packet", - eofPktNr = pkSeqNr, - curAckNr = socket.ackNr + # we have received user generated packet (DATA or FIN), in not connected + # state. Stop processing it. + return - socket.gotFin = true - socket.eofPktNr = pkSeqNr + if (p.header.pType == ST_FIN and (not socket.gotFin)): + debug "Received FIN packet", + eofPktNr = pkSeqNr, + curAckNr = socket.ackNr - # we got in order packet - if (pastExpected == 0 and (not socket.reachedFin)): - debug "Received in order packet" - let payloadLength = len(p.payload) - if (payloadLength > 0 and (not socket.readShutdown)): - # we need to sum both rcv buffer and reorder buffer - if (uint32(socket.offset) + socket.inBufferBytes + uint32(payloadLength) > socket.socketConfig.optRcvBuffer): - # even though packet is in order and passes all the checks, it would - # overflow our receive buffer, it means that we are receiving data - # faster than we are reading it. Do not ack this packet, and drop received - # data - debug "Recevied packet would overflow receive buffer dropping it", - pkSeqNr = p.header.seqNr, - bytesReceived = payloadLength, - rcvbufferSize = socket.offset, - reorderBufferSize = socket.inBufferBytes - return - - debug "Received data packet", - bytesReceived = payloadLength - # we are getting in order data packet, we can flush data directly to the incoming buffer - # await upload(addr socket.buffer, unsafeAddr p.payload[0], p.payload.len()) - moveMem(addr socket.rcvBuffer[socket.offset], unsafeAddr p.payload[0], payloadLength) - socket.offset = socket.offset + payloadLength - - # Bytes have been passed to upper layer, we can increase number of last - # acked packet - inc socket.ackNr - - # check if the following packets are in reorder buffer - - debug "Looking for packets in re-order buffer", - reorderCount = socket.reorderCount - - while true: - # We are doing this in reoreder loop, to handle the case when we already received - # fin but there were some gaps before eof - # we have reached remote eof, and should not receive more packets from remote - if ((not socket.reachedFin) and socket.gotFin and socket.eofPktNr == socket.ackNr): - debug "Reached socket EOF" - # In case of reaching eof, it is up to user of library what to to with - # it. With the current implementation, the most apropriate way would be to - # destory it (as with our implementation we know that remote is destroying its acked fin) - # as any other send will either generate timeout, or socket will be forcefully - # closed by reset - socket.reachedFin = true - # this is not necessarily true, but as we have already reached eof we can - # ignore following packets - socket.reorderCount = 0 - - if socket.reorderCount == 0: - break - - let nextPacketNum = socket.ackNr + 1 - - let maybePacket = socket.inBuffer.get(nextPacketNum) - - if maybePacket.isNone(): - break - - let packet = maybePacket.unsafeGet() - let reorderPacketPayloadLength = len(packet.payload) - - if (reorderPacketPayloadLength > 0 and (not socket.readShutdown)): - debug "Got packet from reorder buffer", - packetBytes = len(packet.payload), - packetSeqNr = packet.header.seqNr, - packetAckNr = packet.header.ackNr, - socketSeqNr = socket.seqNr, - socektAckNr = socket.ackNr, - rcvbufferSize = socket.offset, - reorderBufferSize = socket.inBufferBytes - - # Rcv buffer and reorder buffer are sized that it is always possible to - # move data from reorder buffer to rcv buffer without overflow - moveMem(addr socket.rcvBuffer[socket.offset], unsafeAddr packet.payload[0], reorderPacketPayloadLength) - socket.offset = socket.offset + reorderPacketPayloadLength - - debug "Deleting packet", - seqNr = nextPacketNum - - socket.inBuffer.delete(nextPacketNum) - inc socket.ackNr - dec socket.reorderCount - socket.inBufferBytes = socket.inBufferBytes - uint32(reorderPacketPayloadLength) - - debug "Socket state after processing in order packet", - socketKey = socket.socketKey, - socketAckNr = socket.ackNr, - reorderCount = socket.reorderCount, - windowPackets = socket.curWindowPackets - - # TODO for now we just schedule concurrent task with ack sending. It may - # need improvement, as with this approach there is no direct control over - # how many concurrent tasks there are and how to cancel them when socket - # is closed - socket.sendAck() - - # we got packet out of order - else: - debug "Got out of order packet" - - if (socket.gotFin and pkSeqNr > socket.eofPktNr): - debug "Got packet past eof", - pkSeqNr = pkSeqNr, - eofPktNr = socket.eofPktNr + socket.gotFin = true + socket.eofPktNr = pkSeqNr + # we got in order packet + if (pastExpected == 0 and (not socket.reachedFin)): + debug "Received in order packet" + let payloadLength = len(p.payload) + if (payloadLength > 0 and (not socket.readShutdown)): + # we need to sum both rcv buffer and reorder buffer + if (uint32(socket.offset) + socket.inBufferBytes + uint32(payloadLength) > socket.socketConfig.optRcvBuffer): + # even though packet is in order and passes all the checks, it would + # overflow our receive buffer, it means that we are receiving data + # faster than we are reading it. Do not ack this packet, and drop received + # data + debug "Recevied packet would overflow receive buffer dropping it", + pkSeqNr = p.header.seqNr, + bytesReceived = payloadLength, + rcvbufferSize = socket.offset, + reorderBufferSize = socket.inBufferBytes return - # growing buffer before checking the packet is already there to avoid - # looking at older packet due to indices wrap aroud - socket.inBuffer.ensureSize(pkSeqNr + 1, pastExpected + 1) + debug "Received data packet", + bytesReceived = payloadLength + # we are getting in order data packet, we can flush data directly to the incoming buffer + # await upload(addr socket.buffer, unsafeAddr p.payload[0], p.payload.len()) + moveMem(addr socket.rcvBuffer[socket.offset], unsafeAddr p.payload[0], payloadLength) + socket.offset = socket.offset + payloadLength + + # Bytes have been passed to upper layer, we can increase number of last + # acked packet + inc socket.ackNr - if (socket.inBuffer.get(pkSeqNr).isSome()): - debug "Packet with seqNr already received", - seqNr = pkSeqNr - else: - let payloadLength = uint32(len(p.payload)) - if (socket.inBufferBytes + payloadLength <= socket.socketConfig.maxSizeOfReorderBuffer and - socket.inBufferBytes + uint32(socket.offset) + payloadLength <= socket.socketConfig.optRcvBuffer): - - debug "store packet in reorder buffer", - packetBytes = payloadLength, - packetSeqNr = p.header.seqNr, - packetAckNr = p.header.ackNr, - socketSeqNr = socket.seqNr, - socektAckNr = socket.ackNr, - rcvbufferSize = socket.offset, - reorderBufferSize = socket.inBufferBytes + # check if the following packets are in reorder buffer - socket.inBuffer.put(pkSeqNr, p) - inc socket.reorderCount - socket.inBufferBytes = socket.inBufferBytes + payloadLength - debug "added out of order packet to reorder buffer", - reorderCount = socket.reorderCount - # we send ack packet, as we reoreder count is > 0, so the eack bitmask will be - # generated - socket.sendAck() + debug "Looking for packets in re-order buffer", + reorderCount = socket.reorderCount - of ST_STATE: - if (socket.state == SynSent and (not socket.connectionFuture.finished())): - socket.state = Connected - # TODO reference implementation sets ackNr (p.header.seqNr - 1), although - # spec mention that it should be equal p.header.seqNr. For now follow the - # reference impl to be compatible with it. Later investigate trin compatibility. - socket.ackNr = p.header.seqNr - 1 - # In case of SynSent complate the future as last thing to make sure user of libray will - # receive socket in correct state - socket.connectionFuture.complete() + while true: + # We are doing this in reoreder loop, to handle the case when we already received + # fin but there were some gaps before eof + # we have reached remote eof, and should not receive more packets from remote + if ((not socket.reachedFin) and socket.gotFin and socket.eofPktNr == socket.ackNr): + debug "Reached socket EOF" + # In case of reaching eof, it is up to user of library what to to with + # it. With the current implementation, the most apropriate way would be to + # destory it (as with our implementation we know that remote is destroying its acked fin) + # as any other send will either generate timeout, or socket will be forcefully + # closed by reset + socket.reachedFin = true + # this is not necessarily true, but as we have already reached eof we can + # ignore following packets + socket.reorderCount = 0 - of ST_RESET: - debug "Received ST_RESET on known socket, ignoring" - of ST_SYN: - debug "Received ST_SYN on known socket, ignoring" + if socket.reorderCount == 0: + break + + let nextPacketNum = socket.ackNr + 1 + + let maybePacket = socket.inBuffer.get(nextPacketNum) + + if maybePacket.isNone(): + break + + let packet = maybePacket.unsafeGet() + let reorderPacketPayloadLength = len(packet.payload) + + if (reorderPacketPayloadLength > 0 and (not socket.readShutdown)): + debug "Got packet from reorder buffer", + packetBytes = len(packet.payload), + packetSeqNr = packet.header.seqNr, + packetAckNr = packet.header.ackNr, + socketSeqNr = socket.seqNr, + socektAckNr = socket.ackNr, + rcvbufferSize = socket.offset, + reorderBufferSize = socket.inBufferBytes + + # Rcv buffer and reorder buffer are sized that it is always possible to + # move data from reorder buffer to rcv buffer without overflow + moveMem(addr socket.rcvBuffer[socket.offset], unsafeAddr packet.payload[0], reorderPacketPayloadLength) + socket.offset = socket.offset + reorderPacketPayloadLength + + debug "Deleting packet", + seqNr = nextPacketNum + + socket.inBuffer.delete(nextPacketNum) + inc socket.ackNr + dec socket.reorderCount + socket.inBufferBytes = socket.inBufferBytes - uint32(reorderPacketPayloadLength) + + debug "Socket state after processing in order packet", + socketKey = socket.socketKey, + socketAckNr = socket.ackNr, + reorderCount = socket.reorderCount, + windowPackets = socket.curWindowPackets + + # TODO for now we just schedule concurrent task with ack sending. It may + # need improvement, as with this approach there is no direct control over + # how many concurrent tasks there are and how to cancel them when socket + # is closed + socket.sendAck() + + # we got packet out of order + else: + debug "Got out of order packet" + + if (socket.gotFin and pkSeqNr > socket.eofPktNr): + debug "Got packet past eof", + pkSeqNr = pkSeqNr, + eofPktNr = socket.eofPktNr + + return + + # growing buffer before checking the packet is already there to avoid + # looking at older packet due to indices wrap aroud + socket.inBuffer.ensureSize(pkSeqNr + 1, pastExpected + 1) + + if (socket.inBuffer.get(pkSeqNr).isSome()): + debug "Packet with seqNr already received", + seqNr = pkSeqNr + else: + let payloadLength = uint32(len(p.payload)) + if (socket.inBufferBytes + payloadLength <= socket.socketConfig.maxSizeOfReorderBuffer and + socket.inBufferBytes + uint32(socket.offset) + payloadLength <= socket.socketConfig.optRcvBuffer): + + debug "store packet in reorder buffer", + packetBytes = payloadLength, + packetSeqNr = p.header.seqNr, + packetAckNr = p.header.ackNr, + socketSeqNr = socket.seqNr, + socektAckNr = socket.ackNr, + rcvbufferSize = socket.offset, + reorderBufferSize = socket.inBufferBytes + + socket.inBuffer.put(pkSeqNr, p) + inc socket.reorderCount + socket.inBufferBytes = socket.inBufferBytes + payloadLength + debug "added out of order packet to reorder buffer", + reorderCount = socket.reorderCount + # we send ack packet, as we reoreder count is > 0, so the eack bitmask will be + # generated + socket.sendAck() proc processPacket*(socket: UtpSocket, p: Packet): Future[void] = socket.eventQueue.put(SocketEvent(kind: NewPacket, packet: p)) diff --git a/tests/utp/test_utp_socket.nim b/tests/utp/test_utp_socket.nim index fbae5f8..cf17e3c 100644 --- a/tests/utp/test_utp_socket.nim +++ b/tests/utp/test_utp_socket.nim @@ -1407,3 +1407,61 @@ procSuite "Utp socket unit test": resent3.header.seqNr == sent3.header.seqNr await outgoingSocket.destroyWait() + + asyncTest "Socket should accept data only in connected state": + let q = newAsyncQueue[Packet]() + let initialRemoteSeq = 10'u16 + let cfg = SocketConfig.init() + let remoteReciveBuffer = 1024'u32 + + let dataDropped = @[1'u8] + let dataRecived = @[2'u8] + + let sock1 = newOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[]) + + asyncSpawn sock1.startOutgoingSocket() + + let initialPacket = await q.get() + + check: + initialPacket.header.pType == ST_SYN + + let dpDropped = dataPacket( + initialRemoteSeq, + initialPacket.header.connectionId, + initialPacket.header.seqNr, + testBufferSize, + dataDropped, + 0 + ) + + let dpReceived = dataPacket( + initialRemoteSeq, + initialPacket.header.connectionId, + initialPacket.header.seqNr, + testBufferSize, + dataRecived, + 0 + ) + + let responseAck = + ackPacket( + initialRemoteSeq, + initialPacket.header.connectionId, + initialPacket.header.seqNr, + remoteReciveBuffer, + 0 + ) + + # even though @[1'u8] is received first, it should be dropped as socket is not + # yet in connected state + await sock1.processPacket(dpDropped) + await sock1.processPacket(responseAck) + await sock1.processPacket(dpReceived) + + let receivedData = await sock1.read(1) + + check: + receivedData == dataRecived + + await sock1.destroyWait()