mirror of https://github.com/status-im/nim-eth.git
Add possibility to connect with requested conneciton id (#425)
* Improve error handling when initiating connection * Add api to connect with requested id * Add callback to allow only specific incoming peers
This commit is contained in:
parent
22757db83b
commit
d5e5ec9f90
|
@ -7,7 +7,7 @@
|
|||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
chronos, stew/byteutils,
|
||||
chronos, stew/[results, byteutils],
|
||||
./utp_router,
|
||||
./utp_socket,
|
||||
./utp_protocol
|
||||
|
@ -32,7 +32,8 @@ when isMainModule:
|
|||
let utpProt = UtpProtocol.new(echoIncomingSocketCallBack(), localAddress)
|
||||
|
||||
let remoteServer = initTAddress("127.0.0.1", 9078)
|
||||
let soc = waitFor utpProt.connectTo(remoteServer)
|
||||
let socResult = waitFor utpProt.connectTo(remoteServer)
|
||||
let soc = socResult.get()
|
||||
|
||||
doAssert(soc.numPacketsInOutGoingBuffer() == 0)
|
||||
|
||||
|
|
|
@ -50,11 +50,13 @@ proc new*(
|
|||
subProtocolName: seq[byte],
|
||||
acceptConnectionCb: AcceptConnectionCallback[Node],
|
||||
socketConfig: SocketConfig = SocketConfig.init(),
|
||||
allowConnectionCb: AllowConnectionCallback[Node] = nil,
|
||||
rng = newRng()): UtpDiscv5Protocol {.raises: [Defect, CatchableError].} =
|
||||
doAssert(not(isNil(acceptConnectionCb)))
|
||||
|
||||
let router = UtpRouter[Node].new(
|
||||
acceptConnectionCb,
|
||||
allowConnectionCb,
|
||||
socketConfig,
|
||||
rng
|
||||
)
|
||||
|
@ -71,9 +73,12 @@ proc new*(
|
|||
)
|
||||
prot
|
||||
|
||||
proc connectTo*(r: UtpDiscv5Protocol, address: Node): Future[UtpSocket[Node]]=
|
||||
proc connectTo*(r: UtpDiscv5Protocol, address: Node): Future[ConnectionResult[Node]]=
|
||||
return r.router.connectTo(address)
|
||||
|
||||
proc connectTo*(r: UtpDiscv5Protocol, address: Node, connectionId: uint16): Future[ConnectionResult[Node]]=
|
||||
return r.router.connectTo(address, connectionId)
|
||||
|
||||
proc shutdown*(r: UtpDiscv5Protocol) =
|
||||
## closes all managed utp connections in background (not closed discovery, it is up to user)
|
||||
r.router.shutdown()
|
||||
|
|
|
@ -77,12 +77,14 @@ proc new*(
|
|||
acceptConnectionCb: AcceptConnectionCallback[TransportAddress],
|
||||
address: TransportAddress,
|
||||
socketConfig: SocketConfig = SocketConfig.init(),
|
||||
allowConnectionCb: AllowConnectionCallback[TransportAddress] = nil,
|
||||
rng = newRng()): UtpProtocol {.raises: [Defect, CatchableError].} =
|
||||
|
||||
doAssert(not(isNil(acceptConnectionCb)))
|
||||
|
||||
let router = UtpRouter[TransportAddress].new(
|
||||
acceptConnectionCb,
|
||||
allowConnectionCb,
|
||||
socketConfig,
|
||||
rng
|
||||
)
|
||||
|
@ -96,8 +98,11 @@ proc shutdownWait*(p: UtpProtocol): Future[void] {.async.} =
|
|||
await p.utpRouter.shutdownWait()
|
||||
await p.transport.closeWait()
|
||||
|
||||
proc connectTo*(r: UtpProtocol, address: TransportAddress): Future[UtpSocket[TransportAddress]] =
|
||||
proc connectTo*(r: UtpProtocol, address: TransportAddress): Future[ConnectionResult[TransportAddress]] =
|
||||
return r.utpRouter.connectTo(address)
|
||||
|
||||
proc connectTo*(r: UtpProtocol, address: TransportAddress, connectionId: uint16): Future[ConnectionResult[TransportAddress]] =
|
||||
return r.utpRouter.connectTo(address, connectionId)
|
||||
|
||||
proc openSockets*(r: UtpProtocol): int =
|
||||
len(r.utpRouter)
|
||||
|
|
|
@ -17,6 +17,11 @@ type
|
|||
AcceptConnectionCallback*[A] = proc(server: UtpRouter[A],
|
||||
client: UtpSocket[A]): Future[void] {.gcsafe, raises: [Defect].}
|
||||
|
||||
# Callback to act as fire wall for incoming peers. Should return true if peer is allowed
|
||||
# to connect.
|
||||
AllowConnectionCallback*[A] =
|
||||
proc(r: UtpRouter[A], remoteAddress: A, connectionId: uint16): bool {.gcsafe, raises: [Defect], noSideEffect.}
|
||||
|
||||
# Oject responsible for creating and maintaing table of of utp sockets.
|
||||
# caller should use `processIncomingBytes` proc to feed it with incoming byte
|
||||
# packets, based this input, proper utp sockets will be created, closed, or will
|
||||
|
@ -27,8 +32,14 @@ type
|
|||
acceptConnection: AcceptConnectionCallback[A]
|
||||
closed: bool
|
||||
sendCb*: SendCallback[A]
|
||||
allowConnection*: AllowConnectionCallback[A]
|
||||
rng*: ref BrHmacDrbgContext
|
||||
|
||||
const
|
||||
# Maximal number of tries to genearte unique socket while establishing outgoing
|
||||
# connection.
|
||||
maxSocketGenerationTries = 1000
|
||||
|
||||
# this should probably be in standard lib, it allows lazy composition of options i.e
|
||||
# one can write: O1 orElse O2 orElse O3, and chain will be evaluated to first option
|
||||
# which isSome()
|
||||
|
@ -57,25 +68,43 @@ proc len*[A](s: UtpRouter[A]): int =
|
|||
len(s.sockets)
|
||||
|
||||
proc registerUtpSocket[A](p: UtpRouter, s: UtpSocket[A]) =
|
||||
# TODO Handle duplicates
|
||||
## Register socket, overwriting already existing one
|
||||
p.sockets[s.socketKey] = s
|
||||
# Install deregister handler, so when socket will get closed, in will be promptly
|
||||
# removed from open sockets table
|
||||
s.registerCloseCallback(proc () = p.deRegisterUtpSocket(s))
|
||||
|
||||
proc registerIfAbsent[A](p: UtpRouter, s: UtpSocket[A]): bool =
|
||||
## Registers socket only if its not already exsiting in the active sockets table
|
||||
## return true is socket has been succesfuly registered
|
||||
if p.sockets.hasKey(s.socketKey):
|
||||
false
|
||||
else:
|
||||
p.registerUtpSocket(s)
|
||||
true
|
||||
|
||||
proc new*[A](
|
||||
T: type UtpRouter[A],
|
||||
acceptConnectionCb: AcceptConnectionCallback[A],
|
||||
acceptConnectionCb: AcceptConnectionCallback[A],
|
||||
allowConnectionCb: AllowConnectionCallback[A],
|
||||
socketConfig: SocketConfig = SocketConfig.init(),
|
||||
rng = newRng()): UtpRouter[A] {.raises: [Defect, CatchableError].} =
|
||||
doAssert(not(isNil(acceptConnectionCb)))
|
||||
UtpRouter[A](
|
||||
sockets: initTable[UtpSocketKey[A], UtpSocket[A]](),
|
||||
acceptConnection: acceptConnectionCb,
|
||||
allowConnection: allowConnectionCb,
|
||||
socketConfig: socketConfig,
|
||||
rng: rng
|
||||
)
|
||||
|
||||
proc new*[A](
|
||||
T: type UtpRouter[A],
|
||||
acceptConnectionCb: AcceptConnectionCallback[A],
|
||||
socketConfig: SocketConfig = SocketConfig.init(),
|
||||
rng = newRng()): UtpRouter[A] {.raises: [Defect, CatchableError].} =
|
||||
UtpRouter[A].new(acceptConnectionCb, nil, socketConfig, rng)
|
||||
|
||||
# There are different possiblites how connection was established, and we need to
|
||||
# check every case
|
||||
proc getSocketOnReset[A](r: UtpRouter[A], sender: A, id: uint16): Option[UtpSocket[A]] =
|
||||
|
@ -92,6 +121,13 @@ proc getSocketOnReset[A](r: UtpRouter[A], sender: A, id: uint16): Option[UtpSock
|
|||
.orElse(r.getUtpSocket(sendInitKey).filter(s => s.connectionIdSnd == id))
|
||||
.orElse(r.getUtpSocket(sendNoInitKey).filter(s => s.connectionIdSnd == id))
|
||||
|
||||
proc shouldAllowConnection[A](r: UtpRouter[A], remoteAddress: A, connectionId: uint16): bool =
|
||||
if r.allowConnection == nil:
|
||||
# if the callback is not configured it means all incoming connections are allowed
|
||||
true
|
||||
else:
|
||||
r.allowConnection(r, remoteAddress, connectionId)
|
||||
|
||||
proc processPacket[A](r: UtpRouter[A], p: Packet, sender: A) {.async.}=
|
||||
notice "Received packet ", packet = p
|
||||
|
||||
|
@ -116,18 +152,21 @@ proc processPacket[A](r: UtpRouter[A], p: Packet, sender: A) {.async.}=
|
|||
if (maybeSocket.isSome()):
|
||||
notice "Ignoring SYN for already existing connection"
|
||||
else:
|
||||
notice "Received SYN for not known connection. Initiating incoming connection"
|
||||
# Initial ackNr is set to incoming packer seqNr
|
||||
let incomingSocket = initIncomingSocket[A](sender, r.sendCb, r.socketConfig ,p.header.connectionId, p.header.seqNr, r.rng[])
|
||||
r.registerUtpSocket(incomingSocket)
|
||||
await incomingSocket.startIncomingSocket()
|
||||
# TODO By default (when we have utp over udp) socket here is passed to upper layer
|
||||
# in SynRecv state, which is not writeable i.e user of socket cannot write
|
||||
# data to it unless some data will be received. This is counter measure to
|
||||
# amplification attacks.
|
||||
# During integration with discovery v5 (i.e utp over discovv5), we must re-think
|
||||
# this.
|
||||
asyncSpawn r.acceptConnection(r, incomingSocket)
|
||||
if (r.shouldAllowConnection(sender, p.header.connectionId)):
|
||||
notice "Received SYN for not known connection. Initiating incoming connection"
|
||||
# Initial ackNr is set to incoming packer seqNr
|
||||
let incomingSocket = initIncomingSocket[A](sender, r.sendCb, r.socketConfig ,p.header.connectionId, p.header.seqNr, r.rng[])
|
||||
r.registerUtpSocket(incomingSocket)
|
||||
await incomingSocket.startIncomingSocket()
|
||||
# TODO By default (when we have utp over udp) socket here is passed to upper layer
|
||||
# in SynRecv state, which is not writeable i.e user of socket cannot write
|
||||
# data to it unless some data will be received. This is counter measure to
|
||||
# amplification attacks.
|
||||
# During integration with discovery v5 (i.e utp over discovv5), we must re-think
|
||||
# this.
|
||||
asyncSpawn r.acceptConnection(r, incomingSocket)
|
||||
else:
|
||||
notice "Connection declined"
|
||||
else:
|
||||
let socketKey = UtpSocketKey[A].init(sender, p.header.connectionId)
|
||||
let maybeSocket = r.getUtpSocket(socketKey)
|
||||
|
@ -149,14 +188,60 @@ proc processIncomingBytes*[A](r: UtpRouter[A], bytes: seq[byte], sender: A) {.as
|
|||
else:
|
||||
warn "failed to decode packet from address", address = sender
|
||||
|
||||
proc generateNewUniqueSocket[A](r: UtpRouter[A], address: A): Option[UtpSocket[A]] =
|
||||
## Tries to generate unique socket, gives up after maxSocketGenerationTries tries
|
||||
var tryCount = 0
|
||||
|
||||
while tryCount < maxSocketGenerationTries:
|
||||
let rcvId = randUint16(r.rng[])
|
||||
let socket = initOutgoingSocket[A](address, r.sendCb, r.socketConfig, rcvId, r.rng[])
|
||||
|
||||
if r.registerIfAbsent(socket):
|
||||
return some(socket)
|
||||
|
||||
inc tryCount
|
||||
|
||||
return none[UtpSocket[A]]()
|
||||
|
||||
proc connect[A](s: UtpSocket[A]): Future[ConnectionResult[A]] {.async.}=
|
||||
let startFut = s.startOutgoingSocket()
|
||||
|
||||
startFut.cancelCallback = proc(udata: pointer) {.gcsafe.} =
|
||||
# if for some reason future will be cancelled, destory socket to clear it from
|
||||
# active socket list
|
||||
s.destroy()
|
||||
|
||||
try:
|
||||
await startFut
|
||||
return ok(s)
|
||||
except ConnectionError:
|
||||
s.destroy()
|
||||
return err(OutgoingConnectionError(kind: ConnectionTimedOut))
|
||||
except CatchableError as e:
|
||||
s.destroy()
|
||||
# this may only happen if user provided callback will for some reason fail
|
||||
return err(OutgoingConnectionError(kind: ErrorWhileSendingSyn, error: e))
|
||||
|
||||
# Connect to provided address
|
||||
# Reference implementation: https://github.com/bittorrent/libutp/blob/master/utp_internal.cpp#L2732
|
||||
proc connectTo*[A](r: UtpRouter[A], address: A): Future[UtpSocket[A]] {.async.}=
|
||||
let socket = initOutgoingSocket[A](address, r.sendCb, r.socketConfig, r.rng[])
|
||||
r.registerUtpSocket(socket)
|
||||
await socket.startOutgoingSocket()
|
||||
await socket.waitFotSocketToConnect()
|
||||
return socket
|
||||
proc connectTo*[A](r: UtpRouter[A], address: A): Future[ConnectionResult[A]] {.async.} =
|
||||
let maybeSocket = r.generateNewUniqueSocket(address)
|
||||
|
||||
if (maybeSocket.isNone()):
|
||||
return err(OutgoingConnectionError(kind: SocketAlreadyExists))
|
||||
else:
|
||||
let socket = maybeSocket.unsafeGet()
|
||||
return await socket.connect()
|
||||
|
||||
# Connect to provided address with provided connection id, if socket with this id
|
||||
# and address already exsits return error
|
||||
proc connectTo*[A](r: UtpRouter[A], address: A, connectionId: uint16): Future[ConnectionResult[A]] {.async.} =
|
||||
let socket = initOutgoingSocket[A](address, r.sendCb, r.socketConfig, connectionId, r.rng[])
|
||||
|
||||
if (r.registerIfAbsent(socket)):
|
||||
return await socket.connect()
|
||||
else:
|
||||
return err(OutgoingConnectionError(kind: SocketAlreadyExists))
|
||||
|
||||
proc shutdown*[A](r: UtpRouter[A]) =
|
||||
# stop processing any new packets and close all sockets in background without
|
||||
|
|
|
@ -161,6 +161,18 @@ type
|
|||
|
||||
WriteResult* = Result[int, WriteError]
|
||||
|
||||
OutgoingConnectionErrorType* = enum
|
||||
SocketAlreadyExists, ConnectionTimedOut, ErrorWhileSendingSyn
|
||||
|
||||
OutgoingConnectionError* = object
|
||||
case kind*: OutgoingConnectionErrorType
|
||||
of ErrorWhileSendingSyn:
|
||||
error*: ref CatchableError
|
||||
of SocketAlreadyExists, ConnectionTimedOut:
|
||||
discard
|
||||
|
||||
ConnectionResult*[A] = Result[UtpSocket[A], OutgoingConnectionError]
|
||||
|
||||
const
|
||||
# Maximal number of payload bytes per packet. Total packet size will be equal to
|
||||
# mtuSize + sizeof(header) = 600 bytes
|
||||
|
@ -258,16 +270,6 @@ proc sendAck(socket: UtpSocket): Future[void] =
|
|||
)
|
||||
socket.sendData(encodePacket(ackPacket))
|
||||
|
||||
proc sendSyn(socket: UtpSocket): Future[void] =
|
||||
doAssert(socket.state == SynSent , "syn can only be send when in SynSent state")
|
||||
let packet = synPacket(socket.seqNr, socket.connectionIdRcv, socket.getRcvWindowSize())
|
||||
notice "Sending syn packet packet", packet = packet
|
||||
# set number of transmissions to 1 as syn packet will be send just after
|
||||
# initiliazation
|
||||
let outgoingPacket = OutgoingPacket.init(encodePacket(packet), 1, false)
|
||||
socket.registerOutgoingPacket(outgoingPacket)
|
||||
socket.sendData(outgoingPacket.packetBytes)
|
||||
|
||||
# Should be called before sending packet
|
||||
proc setSend(p: var OutgoingPacket): seq[byte] =
|
||||
inc p.transmissions
|
||||
|
@ -423,10 +425,9 @@ proc initOutgoingSocket*[A](
|
|||
to: A,
|
||||
snd: SendCallback[A],
|
||||
cfg: SocketConfig,
|
||||
rcvConnectionId: uint16,
|
||||
rng: var BrHmacDrbgContext
|
||||
): UtpSocket[A] =
|
||||
# TODO handle possible clashes and overflows
|
||||
let rcvConnectionId = randUint16(rng)
|
||||
let sndConnectionId = rcvConnectionId + 1
|
||||
let initialSeqNr = randUint16(rng)
|
||||
|
||||
|
@ -467,18 +468,19 @@ proc initIncomingSocket*[A](
|
|||
|
||||
proc startOutgoingSocket*(socket: UtpSocket): Future[void] {.async.} =
|
||||
doAssert(socket.state == SynSent)
|
||||
# TODO add callback to handle errors and cancellation i.e unregister socket on
|
||||
# send error and finish connection future with failure
|
||||
# sending should be done from UtpSocketContext
|
||||
await socket.sendSyn()
|
||||
let packet = synPacket(socket.seqNr, socket.connectionIdRcv, socket.getRcvWindowSize())
|
||||
notice "Sending syn packet packet", packet = packet
|
||||
# set number of transmissions to 1 as syn packet will be send just after
|
||||
# initiliazation
|
||||
let outgoingPacket = OutgoingPacket.init(encodePacket(packet), 1, false)
|
||||
socket.registerOutgoingPacket(outgoingPacket)
|
||||
socket.startTimeoutLoop()
|
||||
|
||||
proc waitFotSocketToConnect*(socket: UtpSocket): Future[void] {.async.} =
|
||||
await socket.sendData(outgoingPacket.packetBytes)
|
||||
await socket.connectionFuture
|
||||
|
||||
proc startIncomingSocket*(socket: UtpSocket) {.async.} =
|
||||
doAssert(socket.state == SynRecv)
|
||||
# Make sure ack was flushed before movig forward
|
||||
# Make sure ack was flushed before moving forward
|
||||
await socket.sendAck()
|
||||
socket.startTimeoutLoop()
|
||||
|
||||
|
@ -928,3 +930,13 @@ proc numPacketsInReordedBuffer*(socket: UtpSocket): int =
|
|||
inc num
|
||||
doAssert(num == int(socket.reorderCount))
|
||||
num
|
||||
|
||||
proc connectionId*[A](socket: UtpSocket[A]): uint16 =
|
||||
## Connection id is id which is used in first SYN packet which establishes the connection
|
||||
## so for Outgoing side it is actually its rcv_id, and for Incoming side it is
|
||||
## its snd_id
|
||||
case socket.direction
|
||||
of Incoming:
|
||||
socket.connectionIdSnd
|
||||
of Outgoing:
|
||||
socket.connectionIdRcv
|
||||
|
|
|
@ -55,6 +55,12 @@ procSuite "Utp protocol over discovery v5 tests":
|
|||
serverSockets.addLast(client)
|
||||
)
|
||||
|
||||
proc allowOneIdCallback(allowedId: uint16): AllowConnectionCallback[Node] =
|
||||
return (
|
||||
proc(r: UtpRouter[Node], remoteAddress: Node, connectionId: uint16): bool =
|
||||
connectionId == allowedId
|
||||
)
|
||||
|
||||
# TODO Add more tests to discovery v5 suite, especially those which will differ
|
||||
# from standard utp case
|
||||
asyncTest "Success connect to remote host":
|
||||
|
@ -73,8 +79,9 @@ procSuite "Utp protocol over discovery v5 tests":
|
|||
node1.addNode(node2.localNode)
|
||||
node2.addNode(node1.localNode)
|
||||
|
||||
let clientSocket = await utp1.connectTo(node2.localNode)
|
||||
|
||||
let clientSocketResult = await utp1.connectTo(node2.localNode)
|
||||
let clientSocket = clientSocketResult.get()
|
||||
|
||||
check:
|
||||
clientSocket.isConnected()
|
||||
|
||||
|
@ -99,7 +106,9 @@ procSuite "Utp protocol over discovery v5 tests":
|
|||
node2.addNode(node1.localNode)
|
||||
|
||||
let numOfBytes = 5000
|
||||
let clientSocket = await utp1.connectTo(node2.localNode)
|
||||
let clientSocketResult = await utp1.connectTo(node2.localNode)
|
||||
let clientSocket = clientSocketResult.get()
|
||||
|
||||
let serverSocket = await queue.get()
|
||||
|
||||
let bytesToTransfer = generateByteArray(rng[], numOfBytes)
|
||||
|
@ -117,3 +126,48 @@ procSuite "Utp protocol over discovery v5 tests":
|
|||
await serverSocket.destroyWait()
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
||||
asyncTest "Accept connection only from allowed peers":
|
||||
let
|
||||
allowedId: uint16 = 10
|
||||
lowSynTimeout = milliseconds(500)
|
||||
queue = newAsyncQueue[UtpSocket[Node]]()
|
||||
node1 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20302))
|
||||
node2 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20303))
|
||||
|
||||
utp1 = UtpDiscv5Protocol.new(
|
||||
node1,
|
||||
utpProtId,
|
||||
registerIncomingSocketCallback(queue),
|
||||
SocketConfig.init(lowSynTimeout))
|
||||
utp2 =
|
||||
UtpDiscv5Protocol.new(
|
||||
node2,
|
||||
utpProtId,
|
||||
registerIncomingSocketCallback(queue),
|
||||
SocketConfig.init(),
|
||||
allowOneIdCallback(allowedId))
|
||||
|
||||
# nodes must know about each other
|
||||
check:
|
||||
node1.addNode(node2.localNode)
|
||||
node2.addNode(node1.localNode)
|
||||
|
||||
let clientSocketResult1 = await utp1.connectTo(node2.localNode, allowedId)
|
||||
let clientSocketResult2 = await utp1.connectTo(node2.localNode, allowedId + 1)
|
||||
|
||||
check:
|
||||
clientSocketResult1.isOk()
|
||||
clientSocketResult2.isErr()
|
||||
|
||||
let clientSocket = clientSocketResult1.get()
|
||||
let serverSocket = await queue.get()
|
||||
|
||||
check:
|
||||
clientSocket.connectionId() == allowedId
|
||||
serverSocket.connectionId() == allowedId
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
|
|
@ -30,6 +30,12 @@ proc registerIncomingSocketCallback(serverSockets: AsyncQueue): AcceptConnection
|
|||
serverSockets.addLast(client)
|
||||
)
|
||||
|
||||
proc allowOneIdCallback(allowedId: uint16): AllowConnectionCallback[TransportAddress] =
|
||||
return (
|
||||
proc(r: UtpRouter[TransportAddress], remoteAddress: TransportAddress, connectionId: uint16): bool =
|
||||
connectionId == allowedId
|
||||
)
|
||||
|
||||
proc transferData(sender: UtpSocket[TransportAddress], receiver: UtpSocket[TransportAddress], data: seq[byte]): Future[seq[byte]] {.async.}=
|
||||
let bytesWritten = await sender.write(data)
|
||||
doAssert bytesWritten.get() == len(data)
|
||||
|
@ -67,7 +73,7 @@ proc initClientServerScenario(): Future[ClientServerScenario] {.async.} =
|
|||
return ClientServerScenario(
|
||||
utp1: utpProt1,
|
||||
utp2: utpProt2,
|
||||
clientSocket: clientSocket,
|
||||
clientSocket: clientSocket.get(),
|
||||
serverSocket: serverSocket
|
||||
)
|
||||
|
||||
|
@ -102,8 +108,8 @@ proc init2ClientsServerScenario(): Future[TwoClientsServerScenario] {.async.} =
|
|||
utp1: utpProt1,
|
||||
utp2: utpProt2,
|
||||
utp3: utpProt3,
|
||||
clientSocket1: clientSocket1,
|
||||
clientSocket2: clientSocket2,
|
||||
clientSocket1: clientSocket1.get(),
|
||||
clientSocket2: clientSocket2.get(),
|
||||
serverSocket1: serverSocket1,
|
||||
serverSocket2: serverSocket2
|
||||
)
|
||||
|
@ -125,8 +131,8 @@ procSuite "Utp protocol over udp tests":
|
|||
let address1 = initTAddress("127.0.0.1", 9080)
|
||||
let utpProt2 = UtpProtocol.new(setAcceptedCallback(server2Called), address1)
|
||||
|
||||
let sock = await utpProt1.connectTo(address1)
|
||||
|
||||
let sockResult = await utpProt1.connectTo(address1)
|
||||
let sock = sockResult.get()
|
||||
# this future will be completed when we called accepted connection callback
|
||||
await server2Called.wait()
|
||||
|
||||
|
@ -148,13 +154,16 @@ procSuite "Utp protocol over udp tests":
|
|||
|
||||
let address1 = initTAddress("127.0.0.1", 9080)
|
||||
|
||||
let fut = utpProt1.connectTo(address1)
|
||||
|
||||
yield fut
|
||||
let connectionResult = await utpProt1.connectTo(address1)
|
||||
|
||||
check:
|
||||
fut.failed()
|
||||
|
||||
connectionResult.isErr()
|
||||
|
||||
let connectionError = connectionResult.error()
|
||||
|
||||
check:
|
||||
connectionError.kind == ConnectionTimedOut
|
||||
|
||||
await waitUntil(proc (): bool = utpProt1.openSockets() == 0)
|
||||
|
||||
check:
|
||||
|
@ -370,3 +379,45 @@ procSuite "Utp protocol over udp tests":
|
|||
s.utp1.openSockets() == 0
|
||||
|
||||
await s.close()
|
||||
|
||||
asyncTest "Accept connection only from allowed peers":
|
||||
let allowedId: uint16 = 10
|
||||
let lowSynTimeout = milliseconds(500)
|
||||
var serverSockets = newAsyncQueue[UtpSocket[TransportAddress]]()
|
||||
var server1Called = newAsyncEvent()
|
||||
let address1 = initTAddress("127.0.0.1", 9079)
|
||||
let utpProt1 =
|
||||
UtpProtocol.new(setAcceptedCallback(server1Called), address1, SocketConfig.init(lowSynTimeout))
|
||||
|
||||
let address2 = initTAddress("127.0.0.1", 9080)
|
||||
let utpProt2 =
|
||||
UtpProtocol.new(registerIncomingSocketCallback(serverSockets), address2, SocketConfig.init(lowSynTimeout))
|
||||
|
||||
let address3 = initTAddress("127.0.0.1", 9081)
|
||||
let utpProt3 =
|
||||
UtpProtocol.new(
|
||||
registerIncomingSocketCallback(serverSockets),
|
||||
address3,
|
||||
SocketConfig.init(),
|
||||
allowOneIdCallback(allowedId)
|
||||
)
|
||||
|
||||
let allowedSocketRes = await utpProt1.connectTo(address3, allowedId)
|
||||
let notAllowedSocketRes = await utpProt2.connectTo(address3, allowedId + 1)
|
||||
|
||||
check:
|
||||
allowedSocketRes.isOk()
|
||||
notAllowedSocketRes.isErr()
|
||||
# remote did not allow this connection, and utlimatly it did time out
|
||||
notAllowedSocketRes.error().kind == ConnectionTimedOut
|
||||
|
||||
let clientSocket = allowedSocketRes.get()
|
||||
let serverSocket = await serverSockets.get()
|
||||
|
||||
check:
|
||||
clientSocket.connectionId() == allowedId
|
||||
serverSocket.connectionId() == allowedId
|
||||
|
||||
await utpProt1.shutdownWait()
|
||||
await utpProt2.shutdownWait()
|
||||
await utpProt3.shutdownWait()
|
||||
|
|
|
@ -21,6 +21,9 @@ proc hash*(x: UtpSocketKey[int]): Hash =
|
|||
h = h !& x.rcvId.hash
|
||||
!$h
|
||||
|
||||
type
|
||||
TestError* = object of CatchableError
|
||||
|
||||
procSuite "Utp router unit tests":
|
||||
let rng = newRng()
|
||||
let testSender = 1
|
||||
|
@ -62,7 +65,7 @@ procSuite "Utp router unit tests":
|
|||
await router.processIncomingBytes(encodePacket(responseAck), remote)
|
||||
|
||||
let outgoingSocket = await connectFuture
|
||||
(outgoingSocket, initialPacket)
|
||||
(outgoingSocket.get(), initialPacket)
|
||||
|
||||
asyncTest "Router should ingnore non utp packets":
|
||||
let q = newAsyncQueue[UtpSocket[int]]()
|
||||
|
@ -149,6 +152,98 @@ procSuite "Utp router unit tests":
|
|||
outgoingSocket.isConnected()
|
||||
router.len() == 1
|
||||
|
||||
asyncTest "Router should fail to connect to the same peer with the same connection id":
|
||||
let q = newAsyncQueue[UtpSocket[int]]()
|
||||
let pq = newAsyncQueue[(Packet, int)]()
|
||||
let initialRemoteSeq = 30'u16
|
||||
let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng)
|
||||
router.sendCb = initTestSnd(pq)
|
||||
|
||||
let requestedConnectionId = 1'u16
|
||||
let connectFuture = router.connectTo(testSender2, requestedConnectionId)
|
||||
|
||||
let (initialPacket, sender) = await pq.get()
|
||||
|
||||
check:
|
||||
initialPacket.header.pType == ST_SYN
|
||||
# connection id of syn packet should be set to requested connection id
|
||||
initialPacket.header.connectionId == requestedConnectionId
|
||||
|
||||
let responseAck = ackPacket(initialRemoteSeq, initialPacket.header.connectionId, initialPacket.header.seqNr, testBufferSize)
|
||||
|
||||
await router.processIncomingBytes(encodePacket(responseAck), testSender2)
|
||||
|
||||
let outgoingSocket = await connectFuture
|
||||
|
||||
check:
|
||||
outgoingSocket.get().isConnected()
|
||||
router.len() == 1
|
||||
|
||||
let duplicatedConnectionResult = await router.connectTo(testSender2, requestedConnectionId)
|
||||
|
||||
check:
|
||||
duplicatedConnectionResult.isErr()
|
||||
duplicatedConnectionResult.error().kind == SocketAlreadyExists
|
||||
|
||||
asyncTest "Router should fail connect when socket syn will not be acked":
|
||||
let q = newAsyncQueue[UtpSocket[int]]()
|
||||
let pq = newAsyncQueue[(Packet, int)]()
|
||||
let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(milliseconds(500)), rng)
|
||||
router.sendCb = initTestSnd(pq)
|
||||
|
||||
let connectFuture = router.connectTo(testSender2)
|
||||
|
||||
let (initialPacket, sender) = await pq.get()
|
||||
|
||||
check:
|
||||
initialPacket.header.pType == ST_SYN
|
||||
|
||||
let connectResult = await connectFuture
|
||||
|
||||
check:
|
||||
connectResult.isErr()
|
||||
connectResult.error().kind == ConnectionTimedOut
|
||||
router.len() == 0
|
||||
|
||||
asyncTest "Router should clear all resources when connection future is cancelled":
|
||||
let q = newAsyncQueue[UtpSocket[int]]()
|
||||
let pq = newAsyncQueue[(Packet, int)]()
|
||||
let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(milliseconds(500)), rng)
|
||||
router.sendCb = initTestSnd(pq)
|
||||
|
||||
let connectFuture = router.connectTo(testSender2)
|
||||
|
||||
let (initialPacket, sender) = await pq.get()
|
||||
|
||||
check:
|
||||
initialPacket.header.pType == ST_SYN
|
||||
router.len() == 1
|
||||
|
||||
await connectFuture.cancelAndWait()
|
||||
|
||||
check:
|
||||
router.len() == 0
|
||||
|
||||
asyncTest "Router should clear all resources and handle error while sending syn packet":
|
||||
let q = newAsyncQueue[UtpSocket[int]]()
|
||||
let pq = newAsyncQueue[(Packet, int)]()
|
||||
let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(milliseconds(500)), rng)
|
||||
router.sendCb =
|
||||
proc (to: int, data: seq[byte]): Future[void] =
|
||||
let f = newFuture[void]()
|
||||
f.fail(newException(TestError, "faile"))
|
||||
return f
|
||||
|
||||
let connectResult = await router.connectTo(testSender2)
|
||||
|
||||
await waitUntil(proc (): bool = router.len() == 0)
|
||||
|
||||
check:
|
||||
connectResult.isErr()
|
||||
connectResult.error().kind == ErrorWhileSendingSyn
|
||||
cast[TestError](connectResult.error().error) is TestError
|
||||
router.len() == 0
|
||||
|
||||
asyncTest "Router should clear closed outgoing connections":
|
||||
let q = newAsyncQueue[UtpSocket[int]]()
|
||||
let pq = newAsyncQueue[(Packet, int)]()
|
||||
|
@ -225,4 +320,3 @@ procSuite "Utp router unit tests":
|
|||
|
||||
check:
|
||||
router.len() == 0
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ procSuite "Utp socket unit test":
|
|||
let rng = newRng()
|
||||
let testAddress = initTAddress("127.0.0.1", 9079)
|
||||
let testBufferSize = 1024'u32
|
||||
let defaultRcvOutgoingId = 314'u16
|
||||
|
||||
proc initTestSnd(q: AsyncQueue[Packet]): SendCallback[TransportAddress]=
|
||||
return (
|
||||
|
@ -61,8 +62,8 @@ procSuite "Utp socket unit test":
|
|||
initialRemoteSeq: uint16,
|
||||
q: AsyncQueue[Packet],
|
||||
cfg: SocketConfig = SocketConfig.init()): (UtpSocket[TransportAddress], Packet) =
|
||||
let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, rng[])
|
||||
await sock1.startOutgoingSocket()
|
||||
let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), cfg, defaultRcvOutgoingId, rng[])
|
||||
asyncSpawn sock1.startOutgoingSocket()
|
||||
let initialPacket = await q.get()
|
||||
|
||||
check:
|
||||
|
@ -80,18 +81,32 @@ procSuite "Utp socket unit test":
|
|||
asyncTest "Starting outgoing socket should send Syn packet":
|
||||
let q = newAsyncQueue[Packet]()
|
||||
let defaultConfig = SocketConfig.init()
|
||||
let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), defaultConfig, rng[])
|
||||
await sock1.startOutgoingSocket()
|
||||
let sock1 = initOutgoingSocket[TransportAddress](
|
||||
testAddress,
|
||||
initTestSnd(q),
|
||||
defaultConfig,
|
||||
defaultRcvOutgoingId,
|
||||
rng[]
|
||||
)
|
||||
let fut1 = sock1.startOutgoingSocket()
|
||||
let initialPacket = await q.get()
|
||||
|
||||
check:
|
||||
initialPacket.header.pType == ST_SYN
|
||||
initialPacket.header.wndSize == defaultConfig.optRcvBuffer
|
||||
|
||||
fut1.cancel()
|
||||
|
||||
asyncTest "Outgoing socket should re-send syn packet 2 times before declaring failure":
|
||||
let q = newAsyncQueue[Packet]()
|
||||
let sock1 = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), SocketConfig.init(milliseconds(100)), rng[])
|
||||
await sock1.startOutgoingSocket()
|
||||
let sock1 = initOutgoingSocket[TransportAddress](
|
||||
testAddress,
|
||||
initTestSnd(q),
|
||||
SocketConfig.init(milliseconds(100)),
|
||||
defaultRcvOutgoingId,
|
||||
rng[]
|
||||
)
|
||||
let fut1 = sock1.startOutgoingSocket()
|
||||
let initialPacket = await q.get()
|
||||
|
||||
check:
|
||||
|
@ -113,6 +128,8 @@ procSuite "Utp socket unit test":
|
|||
check:
|
||||
not sock1.isConnected()
|
||||
|
||||
fut1.cancel()
|
||||
|
||||
asyncTest "Processing in order ack should make socket connected":
|
||||
let q = newAsyncQueue[Packet]()
|
||||
let initialRemoteSeq = 10'u16
|
||||
|
@ -281,8 +298,16 @@ procSuite "Utp socket unit test":
|
|||
let q = newAsyncQueue[Packet]()
|
||||
let initalRemoteSeqNr = 10'u16
|
||||
|
||||
let outgoingSocket = initOutgoingSocket[TransportAddress](testAddress, initTestSnd(q), SocketConfig.init(milliseconds(50), 2), rng[])
|
||||
await outgoingSocket.startOutgoingSocket()
|
||||
let outgoingSocket = initOutgoingSocket[TransportAddress](
|
||||
testAddress,
|
||||
initTestSnd(q),
|
||||
SocketConfig.init(milliseconds(3000), 2),
|
||||
defaultRcvOutgoingId,
|
||||
rng[]
|
||||
)
|
||||
|
||||
let fut1 = outgoingSocket.startOutgoingSocket()
|
||||
|
||||
let initialPacket = await q.get()
|
||||
|
||||
check:
|
||||
|
|
Loading…
Reference in New Issue