Add `localAddress` support to `stream.connect` (#362)
* Add `localAddress` support to `stream.connect` * fix windows * TransportAddress() instead of AnyAddress * tweak flags * Better flags * try to workaround nim 1.2 issue * Handle ReusePort in createStreamServer and improve tests * Rename ClientFlags to SocketFlags --------- Co-authored-by: Diego <diego@status.im>
This commit is contained in:
parent
229de5f842
commit
ab5a8c2e0f
|
@ -50,7 +50,12 @@ type
|
||||||
# get stuck on transport `close()`.
|
# get stuck on transport `close()`.
|
||||||
# Please use this flag only if you are making both client and server in
|
# Please use this flag only if you are making both client and server in
|
||||||
# the same thread.
|
# the same thread.
|
||||||
TcpNoDelay
|
TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay
|
||||||
|
|
||||||
|
SocketFlags* {.pure.} = enum
|
||||||
|
TcpNoDelay,
|
||||||
|
ReuseAddr,
|
||||||
|
ReusePort
|
||||||
|
|
||||||
|
|
||||||
StreamTransportTracker* = ref object of TrackerBase
|
StreamTransportTracker* = ref object of TrackerBase
|
||||||
|
@ -699,7 +704,9 @@ when defined(windows):
|
||||||
proc connect*(address: TransportAddress,
|
proc connect*(address: TransportAddress,
|
||||||
bufferSize = DefaultStreamBufferSize,
|
bufferSize = DefaultStreamBufferSize,
|
||||||
child: StreamTransport = nil,
|
child: StreamTransport = nil,
|
||||||
flags: set[TransportFlags] = {}): Future[StreamTransport] =
|
localAddress = TransportAddress(),
|
||||||
|
flags: set[SocketFlags] = {},
|
||||||
|
): Future[StreamTransport] =
|
||||||
## Open new connection to remote peer with address ``address`` and create
|
## Open new connection to remote peer with address ``address`` and create
|
||||||
## new transport object ``StreamTransport`` for established connection.
|
## new transport object ``StreamTransport`` for established connection.
|
||||||
## ``bufferSize`` is size of internal buffer for transport.
|
## ``bufferSize`` is size of internal buffer for transport.
|
||||||
|
@ -724,7 +731,35 @@ when defined(windows):
|
||||||
retFuture.fail(getTransportOsError(osLastError()))
|
retFuture.fail(getTransportOsError(osLastError()))
|
||||||
return retFuture
|
return retFuture
|
||||||
|
|
||||||
if not(bindToDomain(sock, raddress.getDomain())):
|
if SocketFlags.ReuseAddr in flags:
|
||||||
|
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
|
||||||
|
let err = osLastError()
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(getTransportOsError(err))
|
||||||
|
return retFuture
|
||||||
|
if SocketFlags.ReusePort in flags:
|
||||||
|
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
|
||||||
|
let err = osLastError()
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(getTransportOsError(err))
|
||||||
|
return retFuture
|
||||||
|
|
||||||
|
if localAddress != TransportAddress():
|
||||||
|
if localAddress.family != address.family:
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(newException(TransportOsError,
|
||||||
|
"connect local address domain is not equal to target address domain"))
|
||||||
|
return retFuture
|
||||||
|
var
|
||||||
|
localAddr: Sockaddr_storage
|
||||||
|
localAddrLen: SockLen
|
||||||
|
localAddress.toSAddr(localAddr, localAddrLen)
|
||||||
|
if bindSocket(SocketHandle(sock),
|
||||||
|
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(getTransportOsError(osLastError()))
|
||||||
|
return retFuture
|
||||||
|
elif not(bindToDomain(sock, raddress.getDomain())):
|
||||||
let err = wsaGetLastError()
|
let err = wsaGetLastError()
|
||||||
sock.closeSocket()
|
sock.closeSocket()
|
||||||
retFuture.fail(getTransportOsError(err))
|
retFuture.fail(getTransportOsError(err))
|
||||||
|
@ -1496,7 +1531,9 @@ else:
|
||||||
proc connect*(address: TransportAddress,
|
proc connect*(address: TransportAddress,
|
||||||
bufferSize = DefaultStreamBufferSize,
|
bufferSize = DefaultStreamBufferSize,
|
||||||
child: StreamTransport = nil,
|
child: StreamTransport = nil,
|
||||||
flags: set[TransportFlags] = {}): Future[StreamTransport] =
|
localAddress = TransportAddress(),
|
||||||
|
flags: set[SocketFlags] = {},
|
||||||
|
): Future[StreamTransport] =
|
||||||
## Open new connection to remote peer with address ``address`` and create
|
## Open new connection to remote peer with address ``address`` and create
|
||||||
## new transport object ``StreamTransport`` for established connection.
|
## new transport object ``StreamTransport`` for established connection.
|
||||||
## ``bufferSize`` - size of internal buffer for transport.
|
## ``bufferSize`` - size of internal buffer for transport.
|
||||||
|
@ -1523,12 +1560,40 @@ else:
|
||||||
return retFuture
|
return retFuture
|
||||||
|
|
||||||
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
|
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
|
||||||
if TransportFlags.TcpNoDelay in flags:
|
if SocketFlags.TcpNoDelay in flags:
|
||||||
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
|
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
|
||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
sock.closeSocket()
|
sock.closeSocket()
|
||||||
retFuture.fail(getTransportOsError(err))
|
retFuture.fail(getTransportOsError(err))
|
||||||
return retFuture
|
return retFuture
|
||||||
|
if SocketFlags.ReuseAddr in flags:
|
||||||
|
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
|
||||||
|
let err = osLastError()
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(getTransportOsError(err))
|
||||||
|
return retFuture
|
||||||
|
if SocketFlags.ReusePort in flags:
|
||||||
|
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
|
||||||
|
let err = osLastError()
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(getTransportOsError(err))
|
||||||
|
return retFuture
|
||||||
|
|
||||||
|
if localAddress != TransportAddress():
|
||||||
|
if localAddress.family != address.family:
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(newException(TransportOsError,
|
||||||
|
"connect local address domain is not equal to target address domain"))
|
||||||
|
return retFuture
|
||||||
|
var
|
||||||
|
localAddr: Sockaddr_storage
|
||||||
|
localAddrLen: SockLen
|
||||||
|
localAddress.toSAddr(localAddr, localAddrLen)
|
||||||
|
if bindSocket(SocketHandle(sock),
|
||||||
|
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
|
||||||
|
sock.closeSocket()
|
||||||
|
retFuture.fail(getTransportOsError(osLastError()))
|
||||||
|
return retFuture
|
||||||
|
|
||||||
proc continuation(udata: pointer) =
|
proc continuation(udata: pointer) =
|
||||||
if not(retFuture.finished()):
|
if not(retFuture.finished()):
|
||||||
|
@ -1776,6 +1841,16 @@ proc join*(server: StreamServer): Future[void] =
|
||||||
retFuture.complete()
|
retFuture.complete()
|
||||||
return retFuture
|
return retFuture
|
||||||
|
|
||||||
|
proc connect*(address: TransportAddress,
|
||||||
|
bufferSize = DefaultStreamBufferSize,
|
||||||
|
child: StreamTransport = nil,
|
||||||
|
flags: set[TransportFlags],
|
||||||
|
localAddress = TransportAddress()): Future[StreamTransport] =
|
||||||
|
# Retro compatibility with TransportFlags
|
||||||
|
var mappedFlags: set[SocketFlags]
|
||||||
|
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
|
||||||
|
address.connect(bufferSize, child, localAddress, mappedFlags)
|
||||||
|
|
||||||
proc close*(server: StreamServer) =
|
proc close*(server: StreamServer) =
|
||||||
## Release ``server`` resources.
|
## Release ``server`` resources.
|
||||||
##
|
##
|
||||||
|
@ -1864,6 +1939,13 @@ proc createStreamServer*(host: TransportAddress,
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
discard closeFd(SocketHandle(serverSocket))
|
discard closeFd(SocketHandle(serverSocket))
|
||||||
raiseTransportOsError(err)
|
raiseTransportOsError(err)
|
||||||
|
if ServerFlags.ReusePort in flags:
|
||||||
|
if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET,
|
||||||
|
osdefs.SO_REUSEPORT, 1)):
|
||||||
|
let err = osLastError()
|
||||||
|
if sock == asyncInvalidSocket:
|
||||||
|
discard closeFd(SocketHandle(serverSocket))
|
||||||
|
raiseTransportOsError(err)
|
||||||
# TCP flags are not useful for Unix domain sockets.
|
# TCP flags are not useful for Unix domain sockets.
|
||||||
if ServerFlags.TcpNoDelay in flags:
|
if ServerFlags.TcpNoDelay in flags:
|
||||||
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,
|
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,
|
||||||
|
|
|
@ -958,7 +958,7 @@ suite "TLSStream test suite":
|
||||||
key = TLSPrivateKey.init(pemkey)
|
key = TLSPrivateKey.init(pemkey)
|
||||||
cert = TLSCertificate.init(pemcert)
|
cert = TLSCertificate.init(pemcert)
|
||||||
|
|
||||||
var server = createStreamServer(address, serveClient, {ReuseAddr})
|
var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
|
||||||
server.start()
|
server.start()
|
||||||
var conn = await connect(address)
|
var conn = await connect(address)
|
||||||
var creader = newAsyncStreamReader(conn)
|
var creader = newAsyncStreamReader(conn)
|
||||||
|
|
|
@ -1259,6 +1259,47 @@ suite "Stream Transport test suite":
|
||||||
await allFutures(rtransp.closeWait(), wtransp.closeWait())
|
await allFutures(rtransp.closeWait(), wtransp.closeWait())
|
||||||
return buffer == message
|
return buffer == message
|
||||||
|
|
||||||
|
proc testConnectBindLocalAddress() {.async.} =
|
||||||
|
let dst1 = initTAddress("127.0.0.1:33335")
|
||||||
|
let dst2 = initTAddress("127.0.0.1:33336")
|
||||||
|
let dst3 = initTAddress("127.0.0.1:33337")
|
||||||
|
|
||||||
|
proc client(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||||
|
await transp.closeWait()
|
||||||
|
|
||||||
|
# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
|
||||||
|
# running the test multiple times or if a test ran previously used the same port.
|
||||||
|
let servers =
|
||||||
|
[createStreamServer(dst1, client, {ReuseAddr}),
|
||||||
|
createStreamServer(dst2, client, {ReuseAddr}),
|
||||||
|
createStreamServer(dst3, client, {ReusePort})]
|
||||||
|
|
||||||
|
for server in servers:
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
let ta = initTAddress("0.0.0.0:35000")
|
||||||
|
|
||||||
|
# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
|
||||||
|
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
|
||||||
|
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})
|
||||||
|
|
||||||
|
# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
|
||||||
|
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})
|
||||||
|
|
||||||
|
expect(TransportOsError):
|
||||||
|
var transp2 = await connect(dst3, localAddress = ta)
|
||||||
|
|
||||||
|
expect(TransportOsError):
|
||||||
|
var transp3 = await connect(dst3, localAddress = initTAddress(":::35000"))
|
||||||
|
|
||||||
|
await transp1.closeWait()
|
||||||
|
await transp2.closeWait()
|
||||||
|
await transp3.closeWait()
|
||||||
|
|
||||||
|
for server in servers:
|
||||||
|
server.stop()
|
||||||
|
await server.closeWait()
|
||||||
|
|
||||||
markFD = getCurrentFD()
|
markFD = getCurrentFD()
|
||||||
|
|
||||||
for i in 0..<len(addresses):
|
for i in 0..<len(addresses):
|
||||||
|
@ -1346,6 +1387,8 @@ suite "Stream Transport test suite":
|
||||||
check waitFor(testReadOnClose(addresses[i])) == true
|
check waitFor(testReadOnClose(addresses[i])) == true
|
||||||
test "[PIPE] readExactly()/write() test":
|
test "[PIPE] readExactly()/write() test":
|
||||||
check waitFor(testPipe()) == true
|
check waitFor(testPipe()) == true
|
||||||
|
test "[IP] bind connect to local address":
|
||||||
|
waitFor(testConnectBindLocalAddress())
|
||||||
test "Servers leak test":
|
test "Servers leak test":
|
||||||
check getTracker("stream.server").isLeaked() == false
|
check getTracker("stream.server").isLeaked() == false
|
||||||
test "Transports leak test":
|
test "Transports leak test":
|
||||||
|
|
Loading…
Reference in New Issue