From ab5a8c2e0f6941fe3debd61dff0293790079d1b0 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Mon, 3 Apr 2023 14:34:35 +0200 Subject: [PATCH] Add `localAddress` support to `stream.connect` (#362) * Add `localAddress` support to `stream.connect` * fix windows * TransportAddress() instead of AnyAddress * tweak flags * Better flags * try to workaround nim 1.2 issue * Handle ReusePort in createStreamServer and improve tests * Rename ClientFlags to SocketFlags --------- Co-authored-by: Diego --- chronos/transports/stream.nim | 92 +++++++++++++++++++++++++++++++++-- tests/testasyncstream.nim | 2 +- tests/teststream.nim | 43 ++++++++++++++++ 3 files changed, 131 insertions(+), 6 deletions(-) diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index ef964169..2c74085a 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -50,7 +50,12 @@ type # get stuck on transport `close()`. # Please use this flag only if you are making both client and server in # the same thread. - TcpNoDelay + TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay + + SocketFlags* {.pure.} = enum + TcpNoDelay, + ReuseAddr, + ReusePort StreamTransportTracker* = ref object of TrackerBase @@ -699,7 +704,9 @@ when defined(windows): proc connect*(address: TransportAddress, bufferSize = DefaultStreamBufferSize, child: StreamTransport = nil, - flags: set[TransportFlags] = {}): Future[StreamTransport] = + localAddress = TransportAddress(), + flags: set[SocketFlags] = {}, + ): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` is size of internal buffer for transport. @@ -724,7 +731,35 @@ when defined(windows): retFuture.fail(getTransportOsError(osLastError())) return retFuture - if not(bindToDomain(sock, raddress.getDomain())): + if SocketFlags.ReuseAddr in flags: + if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)): + let err = osLastError() + sock.closeSocket() + retFuture.fail(getTransportOsError(err)) + return retFuture + if SocketFlags.ReusePort in flags: + if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)): + let err = osLastError() + sock.closeSocket() + retFuture.fail(getTransportOsError(err)) + return retFuture + + if localAddress != TransportAddress(): + if localAddress.family != address.family: + sock.closeSocket() + retFuture.fail(newException(TransportOsError, + "connect local address domain is not equal to target address domain")) + return retFuture + var + localAddr: Sockaddr_storage + localAddrLen: SockLen + localAddress.toSAddr(localAddr, localAddrLen) + if bindSocket(SocketHandle(sock), + cast[ptr SockAddr](addr localAddr), localAddrLen) != 0: + sock.closeSocket() + retFuture.fail(getTransportOsError(osLastError())) + return retFuture + elif not(bindToDomain(sock, raddress.getDomain())): let err = wsaGetLastError() sock.closeSocket() retFuture.fail(getTransportOsError(err)) @@ -1496,7 +1531,9 @@ else: proc connect*(address: TransportAddress, bufferSize = DefaultStreamBufferSize, child: StreamTransport = nil, - flags: set[TransportFlags] = {}): Future[StreamTransport] = + localAddress = TransportAddress(), + flags: set[SocketFlags] = {}, + ): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` - size of internal buffer for transport. @@ -1523,12 +1560,40 @@ else: return retFuture if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: - if TransportFlags.TcpNoDelay in flags: + if SocketFlags.TcpNoDelay in flags: if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)): let err = osLastError() sock.closeSocket() retFuture.fail(getTransportOsError(err)) return retFuture + if SocketFlags.ReuseAddr in flags: + if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)): + let err = osLastError() + sock.closeSocket() + retFuture.fail(getTransportOsError(err)) + return retFuture + if SocketFlags.ReusePort in flags: + if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)): + let err = osLastError() + sock.closeSocket() + retFuture.fail(getTransportOsError(err)) + return retFuture + + if localAddress != TransportAddress(): + if localAddress.family != address.family: + sock.closeSocket() + retFuture.fail(newException(TransportOsError, + "connect local address domain is not equal to target address domain")) + return retFuture + var + localAddr: Sockaddr_storage + localAddrLen: SockLen + localAddress.toSAddr(localAddr, localAddrLen) + if bindSocket(SocketHandle(sock), + cast[ptr SockAddr](addr localAddr), localAddrLen) != 0: + sock.closeSocket() + retFuture.fail(getTransportOsError(osLastError())) + return retFuture proc continuation(udata: pointer) = if not(retFuture.finished()): @@ -1776,6 +1841,16 @@ proc join*(server: StreamServer): Future[void] = retFuture.complete() return retFuture +proc connect*(address: TransportAddress, + bufferSize = DefaultStreamBufferSize, + child: StreamTransport = nil, + flags: set[TransportFlags], + localAddress = TransportAddress()): Future[StreamTransport] = + # Retro compatibility with TransportFlags + var mappedFlags: set[SocketFlags] + if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay) + address.connect(bufferSize, child, localAddress, mappedFlags) + proc close*(server: StreamServer) = ## Release ``server`` resources. ## @@ -1864,6 +1939,13 @@ proc createStreamServer*(host: TransportAddress, if sock == asyncInvalidSocket: discard closeFd(SocketHandle(serverSocket)) raiseTransportOsError(err) + if ServerFlags.ReusePort in flags: + if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET, + osdefs.SO_REUSEPORT, 1)): + let err = osLastError() + if sock == asyncInvalidSocket: + discard closeFd(SocketHandle(serverSocket)) + raiseTransportOsError(err) # TCP flags are not useful for Unix domain sockets. if ServerFlags.TcpNoDelay in flags: if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP, diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index fd581cb9..47a6c942 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -958,7 +958,7 @@ suite "TLSStream test suite": key = TLSPrivateKey.init(pemkey) cert = TLSCertificate.init(pemcert) - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr}) server.start() var conn = await connect(address) var creader = newAsyncStreamReader(conn) diff --git a/tests/teststream.nim b/tests/teststream.nim index 90fd55de..c76ccf6f 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -1259,6 +1259,47 @@ suite "Stream Transport test suite": await allFutures(rtransp.closeWait(), wtransp.closeWait()) return buffer == message + proc testConnectBindLocalAddress() {.async.} = + let dst1 = initTAddress("127.0.0.1:33335") + let dst2 = initTAddress("127.0.0.1:33336") + let dst3 = initTAddress("127.0.0.1:33337") + + proc client(server: StreamServer, transp: StreamTransport) {.async.} = + await transp.closeWait() + + # We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when + # running the test multiple times or if a test ran previously used the same port. + let servers = + [createStreamServer(dst1, client, {ReuseAddr}), + createStreamServer(dst2, client, {ReuseAddr}), + createStreamServer(dst3, client, {ReusePort})] + + for server in servers: + server.start() + + let ta = initTAddress("0.0.0.0:35000") + + # It works cause there's no active listening socket bound to ta and we are using ReuseAddr + var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr}) + var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr}) + + # It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort + var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort}) + + expect(TransportOsError): + var transp2 = await connect(dst3, localAddress = ta) + + expect(TransportOsError): + var transp3 = await connect(dst3, localAddress = initTAddress(":::35000")) + + await transp1.closeWait() + await transp2.closeWait() + await transp3.closeWait() + + for server in servers: + server.stop() + await server.closeWait() + markFD = getCurrentFD() for i in 0..