diff --git a/eth/utp/utp_discv5_protocol.nim b/eth/utp/utp_discv5_protocol.nim index 81b79e7..84df816 100644 --- a/eth/utp/utp_discv5_protocol.nim +++ b/eth/utp/utp_discv5_protocol.nim @@ -89,6 +89,7 @@ proc new*( p: protocol.Protocol, subProtocolName: seq[byte], acceptConnectionCb: AcceptConnectionCallback[NodeAddress], + udata: pointer = nil, allowConnectionCb: AllowConnectionCallback[NodeAddress] = nil, socketConfig: SocketConfig = SocketConfig.init()): UtpDiscv5Protocol = doAssert(not(isNil(acceptConnectionCb))) @@ -96,6 +97,7 @@ proc new*( let router = UtpRouter[NodeAddress].new( acceptConnectionCb, allowConnectionCb, + udata, socketConfig, p.rng ) @@ -112,6 +114,24 @@ proc new*( ) prot +proc new*( + T: type UtpDiscv5Protocol, + p: protocol.Protocol, + subProtocolName: seq[byte], + acceptConnectionCb: AcceptConnectionCallback[NodeAddress], + udata: ref, + allowConnectionCb: AllowConnectionCallback[NodeAddress] = nil, + socketConfig: SocketConfig = SocketConfig.init()): UtpDiscv5Protocol = + GC_ref(udata) + UtpDiscv5Protocol.new( + p, + subProtocolName, + acceptConnectionCb, + cast[pointer](udata), + allowConnectionCb, + socketConfig + ) + proc connectTo*(r: UtpDiscv5Protocol, address: NodeAddress): Future[ConnectionResult[NodeAddress]] = return r.router.connectTo(address) diff --git a/eth/utp/utp_protocol.nim b/eth/utp/utp_protocol.nim index db2d358..332b462 100644 --- a/eth/utp/utp_protocol.nim +++ b/eth/utp/utp_protocol.nim @@ -75,19 +75,21 @@ proc initSendCallback(t: DatagramTransport): SendCallback[TransportAddress] = ) proc new*( - T: type UtpProtocol, - acceptConnectionCb: AcceptConnectionCallback[TransportAddress], - address: TransportAddress, - socketConfig: SocketConfig = SocketConfig.init(), - allowConnectionCb: AllowConnectionCallback[TransportAddress] = nil, - sendCallbackBuilder: SendCallbackBuilder = nil, - rng = newRng()): UtpProtocol {.raises: [Defect, CatchableError].} = + T: type UtpProtocol, + acceptConnectionCb: AcceptConnectionCallback[TransportAddress], + address: TransportAddress, + udata: pointer = nil, + socketConfig: SocketConfig = SocketConfig.init(), + allowConnectionCb: AllowConnectionCallback[TransportAddress] = nil, + sendCallbackBuilder: SendCallbackBuilder = nil, + rng = newRng()): UtpProtocol {.raises: [Defect, CatchableError].} = doAssert(not(isNil(acceptConnectionCb))) let router = UtpRouter[TransportAddress].new( acceptConnectionCb, allowConnectionCb, + udata, socketConfig, rng ) @@ -101,6 +103,26 @@ proc new*( UtpProtocol(transport: ta, utpRouter: router) +proc new*( + T: type UtpProtocol, + acceptConnectionCb: AcceptConnectionCallback[TransportAddress], + address: TransportAddress, + udata: ref, + socketConfig: SocketConfig = SocketConfig.init(), + allowConnectionCb: AllowConnectionCallback[TransportAddress] = nil, + sendCallbackBuilder: SendCallbackBuilder = nil, + rng = newRng()): UtpProtocol {.raises: [Defect, CatchableError].} = + GC_ref(udata) + UtpProtocol.new( + acceptConnectionCb, + address, + cast[pointer](udata), + socketConfig, + allowConnectionCb, + sendCallbackBuilder, + rng + ) + proc shutdownWait*(p: UtpProtocol): Future[void] {.async.} = ## closes all managed utp sockets and then underlying transport await p.utpRouter.shutdownWait() diff --git a/eth/utp/utp_router.nim b/eth/utp/utp_router.nim index 317a881..e6d690a 100644 --- a/eth/utp/utp_router.nim +++ b/eth/utp/utp_router.nim @@ -56,6 +56,7 @@ type closed: bool sendCb*: SendCallback[A] allowConnection*: AllowConnectionCallback[A] + udata: pointer rng*: ref HmacDrbgContext const @@ -114,6 +115,7 @@ proc new*[A]( T: type UtpRouter[A], acceptConnectionCb: AcceptConnectionCallback[A], allowConnectionCb: AllowConnectionCallback[A], + udata: pointer, socketConfig: SocketConfig = SocketConfig.init(), rng = newRng()): UtpRouter[A] = doAssert(not(isNil(acceptConnectionCb))) @@ -122,6 +124,7 @@ proc new*[A]( acceptConnection: acceptConnectionCb, allowConnection: allowConnectionCb, socketConfig: socketConfig, + udata: udata, rng: rng ) @@ -130,7 +133,30 @@ proc new*[A]( acceptConnectionCb: AcceptConnectionCallback[A], socketConfig: SocketConfig = SocketConfig.init(), rng = newRng()): UtpRouter[A] = - UtpRouter[A].new(acceptConnectionCb, nil, socketConfig, rng) + UtpRouter[A].new(acceptConnectionCb, nil, nil, socketConfig, rng) + +proc new*[A]( + T: type UtpRouter[A], + acceptConnectionCb: AcceptConnectionCallback[A], + allowConnectionCb: AllowConnectionCallback[A], + udata: ref, + socketConfig: SocketConfig = SocketConfig.init(), + rng = newRng()): UtpRouter[A] = + doAssert(not(isNil(acceptConnectionCb))) + GC_ref(udata) + UtpRouter[A].new(acceptConnectionCb, allowConnectionCb, cast[pointer](udata), socketConfig, rng) + +proc new*[A]( + T: type UtpRouter[A], + acceptConnectionCb: AcceptConnectionCallback[A], + udata: ref, + socketConfig: SocketConfig = SocketConfig.init(), + rng = newRng()): UtpRouter[A] = + UtpRouter[A].new(acceptConnectionCb, nil, udata, socketConfig, rng) + +proc getUserData*[A, T](router: UtpRouter[A]): T = + ## Obtain user data stored in ``router`` object. + cast[T](router.udata) # There are different possibilities on how the connection got established, need # to check every case. diff --git a/tests/utp/test_discv5_protocol.nim b/tests/utp/test_discv5_protocol.nim index d4e4545..d35cb51 100644 --- a/tests/utp/test_discv5_protocol.nim +++ b/tests/utp/test_discv5_protocol.nim @@ -15,6 +15,7 @@ import ../../eth/p2p/discoveryv5/protocol as discv5_protocol, ../../eth/utp/utp_discv5_protocol, ../../eth/keys, + ../../eth/utp/utp_router as rt, ../p2p/discv5_test_helper procSuite "Utp protocol over discovery v5 tests": @@ -65,6 +66,41 @@ procSuite "Utp protocol over discovery v5 tests": await node1.closeWait() await node2.closeWait() + proc cbUserData(server: UtpRouter[NodeAddress], client: UtpSocket[NodeAddress]): Future[void] = + let queue = rt.getUserData[NodeAddress, AsyncQueue[UtpSocket[NodeAddress]]](server) + queue.addLast(client) + + asyncTest "Provide user data pointer and use it in callback": + let + queue = newAsyncQueue[UtpSocket[NodeAddress]]() + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + + # constructor which uses connection callback and user data pointer as ref + utp1 = UtpDiscv5Protocol.new(node1, utpProtId, cbUserData, queue) + utp2 = UtpDiscv5Protocol.new(node2, utpProtId, cbUserData, queue) + + # nodes must have session between each other + check: + (await node1.ping(node2.localNode)).isOk() + + let clientSocketResult = await utp1.connectTo(NodeAddress.init(node2.localNode).unsafeGet()) + let clientSocket = clientSocketResult.get() + let serverSocket = await queue.get() + + check: + clientSocket.isConnected() + # in this test we do not configure the socket to be connected just after + # accepting incoming connection + not serverSocket.isConnected() + + await clientSocket.destroyWait() + await serverSocket.destroyWait() + await node1.closeWait() + await node2.closeWait() + asyncTest "Success write data over packet size to remote host": let queue = newAsyncQueue[UtpSocket[NodeAddress]]() @@ -122,6 +158,7 @@ procSuite "Utp protocol over discovery v5 tests": node2, utpProtId, registerIncomingSocketCallback(queue), + nil, allowOneIdCallback(allowedId), SocketConfig.init()) diff --git a/tests/utp/test_protocol.nim b/tests/utp/test_protocol.nim index f24790d..d7dd1b8 100644 --- a/tests/utp/test_protocol.nim +++ b/tests/utp/test_protocol.nim @@ -11,7 +11,7 @@ import chronos, testutils/unittests, ./test_utils, - ../../eth/utp/utp_router, + ../../eth/utp/utp_router as rt, ../../eth/utp/utp_protocol, ../../eth/keys @@ -147,10 +147,43 @@ procSuite "Utp protocol over udp tests": await utpProt1.shutdownWait() await utpProt2.shutdownWait() + + proc cbUserData(server: UtpRouter[TransportAddress], client: UtpSocket[TransportAddress]): Future[void] = + let q = rt.getUserData[TransportAddress, AsyncQueue[UtpSocket[TransportAddress]]](server) + q.addLast(client) + + asyncTest "Provide user data pointer and use it in callback": + let incomingConnections = newAsyncQueue[UtpSocket[TransportAddress]]() + let address = initTAddress("127.0.0.1", 9079) + let utpProt1 = UtpProtocol.new(cbUserData, address, incomingConnections) + + let address1 = initTAddress("127.0.0.1", 9080) + let utpProt2 = UtpProtocol.new(cbUserData, address1, incomingConnections) + + let connResult = await utpProt1.connectTo(address1) + + check: + connResult.isOk() + + let clientSocket = connResult.get() + # this future will be completed when we called accepted connection callback + let serverSocket = await incomingConnections.get() + + check: + clientSocket.isConnected() + # after successful connection outgoing buffer should be empty as syn packet + # should be correctly acked + clientSocket.numPacketsInOutGoingBuffer() == 0 + + not serverSocket.isConnected() + + await utpProt1.shutdownWait() + await utpProt2.shutdownWait() + asyncTest "Fail to connect to offline remote host": let server1Called = newAsyncEvent() let address = initTAddress("127.0.0.1", 9079) - let utpProt1 = UtpProtocol.new(setAcceptedCallback(server1Called), address , SocketConfig.init(milliseconds(200))) + let utpProt1 = UtpProtocol.new(setAcceptedCallback(server1Called), address , nil, SocketConfig.init(milliseconds(200))) let address1 = initTAddress("127.0.0.1", 9080) @@ -174,7 +207,7 @@ procSuite "Utp protocol over udp tests": asyncTest "Success connect to remote host which initialy was offline": let server1Called = newAsyncEvent() let address = initTAddress("127.0.0.1", 9079) - let utpProt1 = UtpProtocol.new(setAcceptedCallback(server1Called), address, SocketConfig.init(milliseconds(500))) + let utpProt1 = UtpProtocol.new(setAcceptedCallback(server1Called), address, nil, SocketConfig.init(milliseconds(500))) let address1 = initTAddress("127.0.0.1", 9080) @@ -387,17 +420,18 @@ procSuite "Utp protocol over udp tests": var server1Called = newAsyncEvent() let address1 = initTAddress("127.0.0.1", 9079) let utpProt1 = - UtpProtocol.new(setAcceptedCallback(server1Called), address1, SocketConfig.init(lowSynTimeout)) + UtpProtocol.new(setAcceptedCallback(server1Called), address1, nil, SocketConfig.init(lowSynTimeout)) let address2 = initTAddress("127.0.0.1", 9080) let utpProt2 = - UtpProtocol.new(registerIncomingSocketCallback(serverSockets), address2, SocketConfig.init(lowSynTimeout)) + UtpProtocol.new(registerIncomingSocketCallback(serverSockets), address2, nil, SocketConfig.init(lowSynTimeout)) let address3 = initTAddress("127.0.0.1", 9081) let utpProt3 = UtpProtocol.new( registerIncomingSocketCallback(serverSockets), address3, + nil, SocketConfig.init(), allowOneIdCallback(allowedId) ) diff --git a/tests/utp/test_utp_router.nim b/tests/utp/test_utp_router.nim index 764f2a9..6cea6ec 100644 --- a/tests/utp/test_utp_router.nim +++ b/tests/utp/test_utp_router.nim @@ -87,7 +87,7 @@ procSuite "Utp router unit tests": asyncTest "Router should create new incoming socket when receiving not known syn packet": let q = newAsyncQueue[UtpSocket[int]]() - let router = UtpRouter[int].new(registerIncomingSocketCallback(q), SocketConfig.init(), rng) + let router = UtpRouter[int].new(registerIncomingSocketCallback(q), nil, nil, SocketConfig.init(), rng) router.sendCb = testSend let encodedSyn = encodePacket(synPacket(10, 10, 10))