Add `localAddress` support to `stream.connect` (#362)
* Add `localAddress` support to `stream.connect` * fix windows * TransportAddress() instead of AnyAddress * tweak flags * Better flags * try to workaround nim 1.2 issue * Handle ReusePort in createStreamServer and improve tests * Rename ClientFlags to SocketFlags --------- Co-authored-by: Diego <diego@status.im>
This commit is contained in:
parent
229de5f842
commit
ab5a8c2e0f
|
@ -50,7 +50,12 @@ type
|
|||
# get stuck on transport `close()`.
|
||||
# Please use this flag only if you are making both client and server in
|
||||
# the same thread.
|
||||
TcpNoDelay
|
||||
TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay
|
||||
|
||||
SocketFlags* {.pure.} = enum
|
||||
TcpNoDelay,
|
||||
ReuseAddr,
|
||||
ReusePort
|
||||
|
||||
|
||||
StreamTransportTracker* = ref object of TrackerBase
|
||||
|
@ -699,7 +704,9 @@ when defined(windows):
|
|||
proc connect*(address: TransportAddress,
|
||||
bufferSize = DefaultStreamBufferSize,
|
||||
child: StreamTransport = nil,
|
||||
flags: set[TransportFlags] = {}): Future[StreamTransport] =
|
||||
localAddress = TransportAddress(),
|
||||
flags: set[SocketFlags] = {},
|
||||
): Future[StreamTransport] =
|
||||
## Open new connection to remote peer with address ``address`` and create
|
||||
## new transport object ``StreamTransport`` for established connection.
|
||||
## ``bufferSize`` is size of internal buffer for transport.
|
||||
|
@ -724,7 +731,35 @@ when defined(windows):
|
|||
retFuture.fail(getTransportOsError(osLastError()))
|
||||
return retFuture
|
||||
|
||||
if not(bindToDomain(sock, raddress.getDomain())):
|
||||
if SocketFlags.ReuseAddr in flags:
|
||||
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
|
||||
let err = osLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
return retFuture
|
||||
if SocketFlags.ReusePort in flags:
|
||||
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
|
||||
let err = osLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
return retFuture
|
||||
|
||||
if localAddress != TransportAddress():
|
||||
if localAddress.family != address.family:
|
||||
sock.closeSocket()
|
||||
retFuture.fail(newException(TransportOsError,
|
||||
"connect local address domain is not equal to target address domain"))
|
||||
return retFuture
|
||||
var
|
||||
localAddr: Sockaddr_storage
|
||||
localAddrLen: SockLen
|
||||
localAddress.toSAddr(localAddr, localAddrLen)
|
||||
if bindSocket(SocketHandle(sock),
|
||||
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(osLastError()))
|
||||
return retFuture
|
||||
elif not(bindToDomain(sock, raddress.getDomain())):
|
||||
let err = wsaGetLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
|
@ -1496,7 +1531,9 @@ else:
|
|||
proc connect*(address: TransportAddress,
|
||||
bufferSize = DefaultStreamBufferSize,
|
||||
child: StreamTransport = nil,
|
||||
flags: set[TransportFlags] = {}): Future[StreamTransport] =
|
||||
localAddress = TransportAddress(),
|
||||
flags: set[SocketFlags] = {},
|
||||
): Future[StreamTransport] =
|
||||
## Open new connection to remote peer with address ``address`` and create
|
||||
## new transport object ``StreamTransport`` for established connection.
|
||||
## ``bufferSize`` - size of internal buffer for transport.
|
||||
|
@ -1523,12 +1560,40 @@ else:
|
|||
return retFuture
|
||||
|
||||
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
|
||||
if TransportFlags.TcpNoDelay in flags:
|
||||
if SocketFlags.TcpNoDelay in flags:
|
||||
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
|
||||
let err = osLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
return retFuture
|
||||
if SocketFlags.ReuseAddr in flags:
|
||||
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
|
||||
let err = osLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
return retFuture
|
||||
if SocketFlags.ReusePort in flags:
|
||||
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
|
||||
let err = osLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
return retFuture
|
||||
|
||||
if localAddress != TransportAddress():
|
||||
if localAddress.family != address.family:
|
||||
sock.closeSocket()
|
||||
retFuture.fail(newException(TransportOsError,
|
||||
"connect local address domain is not equal to target address domain"))
|
||||
return retFuture
|
||||
var
|
||||
localAddr: Sockaddr_storage
|
||||
localAddrLen: SockLen
|
||||
localAddress.toSAddr(localAddr, localAddrLen)
|
||||
if bindSocket(SocketHandle(sock),
|
||||
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(osLastError()))
|
||||
return retFuture
|
||||
|
||||
proc continuation(udata: pointer) =
|
||||
if not(retFuture.finished()):
|
||||
|
@ -1776,6 +1841,16 @@ proc join*(server: StreamServer): Future[void] =
|
|||
retFuture.complete()
|
||||
return retFuture
|
||||
|
||||
proc connect*(address: TransportAddress,
|
||||
bufferSize = DefaultStreamBufferSize,
|
||||
child: StreamTransport = nil,
|
||||
flags: set[TransportFlags],
|
||||
localAddress = TransportAddress()): Future[StreamTransport] =
|
||||
# Retro compatibility with TransportFlags
|
||||
var mappedFlags: set[SocketFlags]
|
||||
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
|
||||
address.connect(bufferSize, child, localAddress, mappedFlags)
|
||||
|
||||
proc close*(server: StreamServer) =
|
||||
## Release ``server`` resources.
|
||||
##
|
||||
|
@ -1864,6 +1939,13 @@ proc createStreamServer*(host: TransportAddress,
|
|||
if sock == asyncInvalidSocket:
|
||||
discard closeFd(SocketHandle(serverSocket))
|
||||
raiseTransportOsError(err)
|
||||
if ServerFlags.ReusePort in flags:
|
||||
if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET,
|
||||
osdefs.SO_REUSEPORT, 1)):
|
||||
let err = osLastError()
|
||||
if sock == asyncInvalidSocket:
|
||||
discard closeFd(SocketHandle(serverSocket))
|
||||
raiseTransportOsError(err)
|
||||
# TCP flags are not useful for Unix domain sockets.
|
||||
if ServerFlags.TcpNoDelay in flags:
|
||||
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,
|
||||
|
|
|
@ -958,7 +958,7 @@ suite "TLSStream test suite":
|
|||
key = TLSPrivateKey.init(pemkey)
|
||||
cert = TLSCertificate.init(pemcert)
|
||||
|
||||
var server = createStreamServer(address, serveClient, {ReuseAddr})
|
||||
var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
|
||||
server.start()
|
||||
var conn = await connect(address)
|
||||
var creader = newAsyncStreamReader(conn)
|
||||
|
|
|
@ -1259,6 +1259,47 @@ suite "Stream Transport test suite":
|
|||
await allFutures(rtransp.closeWait(), wtransp.closeWait())
|
||||
return buffer == message
|
||||
|
||||
proc testConnectBindLocalAddress() {.async.} =
|
||||
let dst1 = initTAddress("127.0.0.1:33335")
|
||||
let dst2 = initTAddress("127.0.0.1:33336")
|
||||
let dst3 = initTAddress("127.0.0.1:33337")
|
||||
|
||||
proc client(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||
await transp.closeWait()
|
||||
|
||||
# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
|
||||
# running the test multiple times or if a test ran previously used the same port.
|
||||
let servers =
|
||||
[createStreamServer(dst1, client, {ReuseAddr}),
|
||||
createStreamServer(dst2, client, {ReuseAddr}),
|
||||
createStreamServer(dst3, client, {ReusePort})]
|
||||
|
||||
for server in servers:
|
||||
server.start()
|
||||
|
||||
let ta = initTAddress("0.0.0.0:35000")
|
||||
|
||||
# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
|
||||
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
|
||||
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})
|
||||
|
||||
# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
|
||||
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})
|
||||
|
||||
expect(TransportOsError):
|
||||
var transp2 = await connect(dst3, localAddress = ta)
|
||||
|
||||
expect(TransportOsError):
|
||||
var transp3 = await connect(dst3, localAddress = initTAddress(":::35000"))
|
||||
|
||||
await transp1.closeWait()
|
||||
await transp2.closeWait()
|
||||
await transp3.closeWait()
|
||||
|
||||
for server in servers:
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
markFD = getCurrentFD()
|
||||
|
||||
for i in 0..<len(addresses):
|
||||
|
@ -1346,6 +1387,8 @@ suite "Stream Transport test suite":
|
|||
check waitFor(testReadOnClose(addresses[i])) == true
|
||||
test "[PIPE] readExactly()/write() test":
|
||||
check waitFor(testPipe()) == true
|
||||
test "[IP] bind connect to local address":
|
||||
waitFor(testConnectBindLocalAddress())
|
||||
test "Servers leak test":
|
||||
check getTracker("stream.server").isLeaked() == false
|
||||
test "Transports leak test":
|
||||
|
|
Loading…
Reference in New Issue