Add `localAddress` support to `stream.connect` (#362)

* Add `localAddress` support to `stream.connect`

* fix windows

* TransportAddress() instead of AnyAddress

* tweak flags

* Better flags

* try to workaround nim 1.2 issue

* Handle ReusePort in createStreamServer and improve tests

* Rename ClientFlags to SocketFlags

---------

Co-authored-by: Diego <diego@status.im>
This commit is contained in:
Tanguy 2023-04-03 14:34:35 +02:00 committed by GitHub
parent 229de5f842
commit ab5a8c2e0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 6 deletions

View File

@ -50,7 +50,12 @@ type
# get stuck on transport `close()`. # get stuck on transport `close()`.
# Please use this flag only if you are making both client and server in # Please use this flag only if you are making both client and server in
# the same thread. # the same thread.
TcpNoDelay TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay
SocketFlags* {.pure.} = enum
TcpNoDelay,
ReuseAddr,
ReusePort
StreamTransportTracker* = ref object of TrackerBase StreamTransportTracker* = ref object of TrackerBase
@ -699,7 +704,9 @@ when defined(windows):
proc connect*(address: TransportAddress, proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize, bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil, child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] = localAddress = TransportAddress(),
flags: set[SocketFlags] = {},
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create ## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection. ## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport. ## ``bufferSize`` is size of internal buffer for transport.
@ -724,7 +731,35 @@ when defined(windows):
retFuture.fail(getTransportOsError(osLastError())) retFuture.fail(getTransportOsError(osLastError()))
return retFuture return retFuture
if not(bindToDomain(sock, raddress.getDomain())): if SocketFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if localAddress != TransportAddress():
if localAddress.family != address.family:
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
"connect local address domain is not equal to target address domain"))
return retFuture
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock),
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
sock.closeSocket()
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
elif not(bindToDomain(sock, raddress.getDomain())):
let err = wsaGetLastError() let err = wsaGetLastError()
sock.closeSocket() sock.closeSocket()
retFuture.fail(getTransportOsError(err)) retFuture.fail(getTransportOsError(err))
@ -1496,7 +1531,9 @@ else:
proc connect*(address: TransportAddress, proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize, bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil, child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] = localAddress = TransportAddress(),
flags: set[SocketFlags] = {},
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create ## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection. ## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` - size of internal buffer for transport. ## ``bufferSize`` - size of internal buffer for transport.
@ -1523,12 +1560,40 @@ else:
return retFuture return retFuture
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if TransportFlags.TcpNoDelay in flags: if SocketFlags.TcpNoDelay in flags:
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)): if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
let err = osLastError() let err = osLastError()
sock.closeSocket() sock.closeSocket()
retFuture.fail(getTransportOsError(err)) retFuture.fail(getTransportOsError(err))
return retFuture return retFuture
if SocketFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if localAddress != TransportAddress():
if localAddress.family != address.family:
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
"connect local address domain is not equal to target address domain"))
return retFuture
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock),
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
sock.closeSocket()
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
proc continuation(udata: pointer) = proc continuation(udata: pointer) =
if not(retFuture.finished()): if not(retFuture.finished()):
@ -1776,6 +1841,16 @@ proc join*(server: StreamServer): Future[void] =
retFuture.complete() retFuture.complete()
return retFuture return retFuture
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags],
localAddress = TransportAddress()): Future[StreamTransport] =
# Retro compatibility with TransportFlags
var mappedFlags: set[SocketFlags]
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
address.connect(bufferSize, child, localAddress, mappedFlags)
proc close*(server: StreamServer) = proc close*(server: StreamServer) =
## Release ``server`` resources. ## Release ``server`` resources.
## ##
@ -1864,6 +1939,13 @@ proc createStreamServer*(host: TransportAddress,
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket)) discard closeFd(SocketHandle(serverSocket))
raiseTransportOsError(err) raiseTransportOsError(err)
if ServerFlags.ReusePort in flags:
if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET,
osdefs.SO_REUSEPORT, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket))
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets. # TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags: if ServerFlags.TcpNoDelay in flags:
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP, if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,

View File

@ -958,7 +958,7 @@ suite "TLSStream test suite":
key = TLSPrivateKey.init(pemkey) key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert) cert = TLSCertificate.init(pemcert)
var server = createStreamServer(address, serveClient, {ReuseAddr}) var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
server.start() server.start()
var conn = await connect(address) var conn = await connect(address)
var creader = newAsyncStreamReader(conn) var creader = newAsyncStreamReader(conn)

View File

@ -1259,6 +1259,47 @@ suite "Stream Transport test suite":
await allFutures(rtransp.closeWait(), wtransp.closeWait()) await allFutures(rtransp.closeWait(), wtransp.closeWait())
return buffer == message return buffer == message
proc testConnectBindLocalAddress() {.async.} =
let dst1 = initTAddress("127.0.0.1:33335")
let dst2 = initTAddress("127.0.0.1:33336")
let dst3 = initTAddress("127.0.0.1:33337")
proc client(server: StreamServer, transp: StreamTransport) {.async.} =
await transp.closeWait()
# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
# running the test multiple times or if a test ran previously used the same port.
let servers =
[createStreamServer(dst1, client, {ReuseAddr}),
createStreamServer(dst2, client, {ReuseAddr}),
createStreamServer(dst3, client, {ReusePort})]
for server in servers:
server.start()
let ta = initTAddress("0.0.0.0:35000")
# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})
# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})
expect(TransportOsError):
var transp2 = await connect(dst3, localAddress = ta)
expect(TransportOsError):
var transp3 = await connect(dst3, localAddress = initTAddress(":::35000"))
await transp1.closeWait()
await transp2.closeWait()
await transp3.closeWait()
for server in servers:
server.stop()
await server.closeWait()
markFD = getCurrentFD() markFD = getCurrentFD()
for i in 0..<len(addresses): for i in 0..<len(addresses):
@ -1346,6 +1387,8 @@ suite "Stream Transport test suite":
check waitFor(testReadOnClose(addresses[i])) == true check waitFor(testReadOnClose(addresses[i])) == true
test "[PIPE] readExactly()/write() test": test "[PIPE] readExactly()/write() test":
check waitFor(testPipe()) == true check waitFor(testPipe()) == true
test "[IP] bind connect to local address":
waitFor(testConnectBindLocalAddress())
test "Servers leak test": test "Servers leak test":
check getTracker("stream.server").isLeaked() == false check getTracker("stream.server").isLeaked() == false
test "Transports leak test": test "Transports leak test":