Add `localAddress` support to `stream.connect` (#362)

* Add `localAddress` support to `stream.connect`

* fix windows

* TransportAddress() instead of AnyAddress

* tweak flags

* Better flags

* try to workaround nim 1.2 issue

* Handle ReusePort in createStreamServer and improve tests

* Rename ClientFlags to SocketFlags

---------

Co-authored-by: Diego <diego@status.im>
This commit is contained in:
Tanguy 2023-04-03 14:34:35 +02:00 committed by GitHub
parent 229de5f842
commit ab5a8c2e0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 6 deletions

View File

@ -50,7 +50,12 @@ type
# get stuck on transport `close()`.
# Please use this flag only if you are making both client and server in
# the same thread.
TcpNoDelay
TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay
SocketFlags* {.pure.} = enum
TcpNoDelay,
ReuseAddr,
ReusePort
StreamTransportTracker* = ref object of TrackerBase
@ -699,7 +704,9 @@ when defined(windows):
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
localAddress = TransportAddress(),
flags: set[SocketFlags] = {},
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport.
@ -724,7 +731,35 @@ when defined(windows):
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
if not(bindToDomain(sock, raddress.getDomain())):
if SocketFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if localAddress != TransportAddress():
if localAddress.family != address.family:
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
"connect local address domain is not equal to target address domain"))
return retFuture
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock),
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
sock.closeSocket()
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
elif not(bindToDomain(sock, raddress.getDomain())):
let err = wsaGetLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
@ -1496,7 +1531,9 @@ else:
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
localAddress = TransportAddress(),
flags: set[SocketFlags] = {},
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` - size of internal buffer for transport.
@ -1523,12 +1560,40 @@ else:
return retFuture
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if TransportFlags.TcpNoDelay in flags:
if SocketFlags.TcpNoDelay in flags:
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if localAddress != TransportAddress():
if localAddress.family != address.family:
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
"connect local address domain is not equal to target address domain"))
return retFuture
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock),
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
sock.closeSocket()
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
proc continuation(udata: pointer) =
if not(retFuture.finished()):
@ -1776,6 +1841,16 @@ proc join*(server: StreamServer): Future[void] =
retFuture.complete()
return retFuture
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags],
localAddress = TransportAddress()): Future[StreamTransport] =
# Retro compatibility with TransportFlags
var mappedFlags: set[SocketFlags]
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
address.connect(bufferSize, child, localAddress, mappedFlags)
proc close*(server: StreamServer) =
## Release ``server`` resources.
##
@ -1864,6 +1939,13 @@ proc createStreamServer*(host: TransportAddress,
if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket))
raiseTransportOsError(err)
if ServerFlags.ReusePort in flags:
if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET,
osdefs.SO_REUSEPORT, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket))
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags:
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,

View File

@ -958,7 +958,7 @@ suite "TLSStream test suite":
key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert)
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
server.start()
var conn = await connect(address)
var creader = newAsyncStreamReader(conn)

View File

@ -1259,6 +1259,47 @@ suite "Stream Transport test suite":
await allFutures(rtransp.closeWait(), wtransp.closeWait())
return buffer == message
proc testConnectBindLocalAddress() {.async.} =
let dst1 = initTAddress("127.0.0.1:33335")
let dst2 = initTAddress("127.0.0.1:33336")
let dst3 = initTAddress("127.0.0.1:33337")
proc client(server: StreamServer, transp: StreamTransport) {.async.} =
await transp.closeWait()
# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
# running the test multiple times or if a test ran previously used the same port.
let servers =
[createStreamServer(dst1, client, {ReuseAddr}),
createStreamServer(dst2, client, {ReuseAddr}),
createStreamServer(dst3, client, {ReusePort})]
for server in servers:
server.start()
let ta = initTAddress("0.0.0.0:35000")
# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})
# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})
expect(TransportOsError):
var transp2 = await connect(dst3, localAddress = ta)
expect(TransportOsError):
var transp3 = await connect(dst3, localAddress = initTAddress(":::35000"))
await transp1.closeWait()
await transp2.closeWait()
await transp3.closeWait()
for server in servers:
server.stop()
await server.closeWait()
markFD = getCurrentFD()
for i in 0..<len(addresses):
@ -1346,6 +1387,8 @@ suite "Stream Transport test suite":
check waitFor(testReadOnClose(addresses[i])) == true
test "[PIPE] readExactly()/write() test":
check waitFor(testPipe()) == true
test "[IP] bind connect to local address":
waitFor(testConnectBindLocalAddress())
test "Servers leak test":
check getTracker("stream.server").isLeaked() == false
test "Transports leak test":