Fix issue with Windows connect(0.0.0.0).

This commit is contained in:
cheatfate 2019-10-09 15:12:19 +03:00
parent ae128b0f65
commit 3b8874a9e8
No known key found for this signature in database
GPG Key ID: 46ADD633A7201F95
2 changed files with 78 additions and 5 deletions

View File

@ -159,6 +159,10 @@ proc localAddress*(transp: StreamTransport): TransportAddress =
fromSAddr(addr saddr, slen, transp.local) fromSAddr(addr saddr, slen, transp.local)
result = transp.local result = transp.local
proc localAddress*(server: StreamServer): TransportAddress =
## Returns ``server`` bound local socket address.
result = server.local
template setReadError(t, e: untyped) = template setReadError(t, e: untyped) =
(t).state.incl(ReadError) (t).state.incl(ReadError)
(t).error = getTransportOsError(e) (t).error = getTransportOsError(e)
@ -681,15 +685,28 @@ when defined(windows):
sock: AsyncFD sock: AsyncFD
povl: RefCustomOverlapped povl: RefCustomOverlapped
proto: Protocol proto: Protocol
raddress: TransportAddress
toSAddr(address, saddr, slen) ## BSD Sockets on *nix systems are able to perform connections to
## `0.0.0.0` or `::0` which are equal to `127.0.0.1` or `::1`.
if (address.family == AddressFamily.IPv4 and
address.address_v4 == AnyAddress.address_v4):
raddress = initTAddress("127.0.0.1", address.port)
elif (address.family == AddressFamily.IPv6 and
address.address_v6 == AnyAddress6.address_v6):
raddress = initTAddress("::1", address.port)
else:
raddress = address
toSAddr(raddress, saddr, slen)
proto = Protocol.IPPROTO_TCP proto = Protocol.IPPROTO_TCP
sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) sock = createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM,
proto)
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
retFuture.fail(getTransportOsError(osLastError())) retFuture.fail(getTransportOsError(osLastError()))
return retFuture return retFuture
if not bindToDomain(sock, address.getDomain()): if not bindToDomain(sock, raddress.getDomain()):
let err = wsaGetLastError() let err = wsaGetLastError()
sock.closeSocket() sock.closeSocket()
retFuture.fail(getTransportOsError(err)) retFuture.fail(getTransportOsError(err))
@ -1420,6 +1437,7 @@ proc createStreamServer*(host: TransportAddress,
saddr: Sockaddr_storage saddr: Sockaddr_storage
slen: SockLen slen: SockLen
serverSocket: AsyncFD serverSocket: AsyncFD
localAddress: TransportAddress
when defined(windows): when defined(windows):
# Windows # Windows
@ -1458,6 +1476,15 @@ proc createStreamServer*(host: TransportAddress,
serverSocket.closeSocket() serverSocket.closeSocket()
raiseTransportOsError(err) raiseTransportOsError(err)
slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
fromSAddr(addr saddr, slen, localAddress)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
@ -1513,6 +1540,16 @@ proc createStreamServer*(host: TransportAddress,
serverSocket.closeSocket() serverSocket.closeSocket()
raiseTransportOsError(err) raiseTransportOsError(err)
# Obtain real address
slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
fromSAddr(addr saddr, slen, localAddress)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
@ -1532,7 +1569,10 @@ proc createStreamServer*(host: TransportAddress,
result.status = Starting result.status = Starting
result.loopFuture = newFuture[void]("stream.transport.server") result.loopFuture = newFuture[void]("stream.transport.server")
result.udata = udata result.udata = udata
if localAddress.family == AddressFamily.None:
result.local = host result.local = host
else:
result.local = localAddress
when defined(windows): when defined(windows):
var cb: CallbackFunc var cb: CallbackFunc

View File

@ -43,6 +43,7 @@ suite "Stream Transport test suite":
m14 = "Closing socket while operation pending test (issue #8)" m14 = "Closing socket while operation pending test (issue #8)"
m15 = "Connection refused test" m15 = "Connection refused test"
m16 = "readOnce() read until atEof() test" m16 = "readOnce() read until atEof() test"
m17 = "0.0.0.0/::0 (INADDR_ANY) test"
when defined(windows): when defined(windows):
var addresses = [ var addresses = [
@ -717,6 +718,33 @@ suite "Stream Transport test suite":
await ntransp.closeWait() await ntransp.closeWait()
await server.closeWait() await server.closeWait()
proc testAnyAddress(): Future[bool] {.async.} =
var serverRemote, serverLocal: TransportAddress
var connRemote, connLocal: TransportAddress
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
serverRemote = transp.remoteAddress()
serverLocal = transp.localAddress()
await transp.closeWait()
server.stop()
server.close()
var ta = initTAddress("0.0.0.0:0")
var server = createStreamServer(ta, serveClient, {ReuseAddr})
var la = server.localAddress()
server.start()
var connFut = connect(la)
if await withTimeout(connFut, 5.seconds):
var conn = connFut.read()
connRemote = conn.remoteAddress()
connLocal = conn.localAddress()
await server.join()
await conn.closeWait()
result = (connRemote == serverLocal) and (connLocal == serverRemote)
else:
server.stop()
server.close()
for i in 0..<len(addresses): for i in 0..<len(addresses):
test prefixes[i] & "close(transport) test": test prefixes[i] & "close(transport) test":
check waitFor(testCloseTransport(addresses[i])) == 1 check waitFor(testCloseTransport(addresses[i])) == 1
@ -747,7 +775,7 @@ suite "Stream Transport test suite":
if addresses[i].family == AddressFamily.IPv4: if addresses[i].family == AddressFamily.IPv4:
check waitFor(testSendFile(addresses[i])) == FilesCount check waitFor(testSendFile(addresses[i])) == FilesCount
else: else:
discard skip()
else: else:
check waitFor(testSendFile(addresses[i])) == FilesCount check waitFor(testSendFile(addresses[i])) == FilesCount
test prefixes[i] & m15: test prefixes[i] & m15:
@ -761,6 +789,11 @@ suite "Stream Transport test suite":
check waitFor(test16(addresses[i])) == 1 check waitFor(test16(addresses[i])) == 1
test prefixes[i] & "Connection reset test on send() only": test prefixes[i] & "Connection reset test on send() only":
check waitFor(testWriteConnReset(addresses[i])) == 1 check waitFor(testWriteConnReset(addresses[i])) == 1
test prefixes[i] & m17:
if addresses[i].family == AddressFamily.IPv4:
check waitFor(testAnyAddress()) == true
else:
skip()
test "Servers leak test": test "Servers leak test":
check getTracker("stream.server").isLeaked() == false check getTracker("stream.server").isLeaked() == false