Fix issue with Windows connect(0.0.0.0).
This commit is contained in:
parent
ae128b0f65
commit
3b8874a9e8
|
@ -159,6 +159,10 @@ proc localAddress*(transp: StreamTransport): TransportAddress =
|
|||
fromSAddr(addr saddr, slen, transp.local)
|
||||
result = transp.local
|
||||
|
||||
proc localAddress*(server: StreamServer): TransportAddress =
|
||||
## Returns ``server`` bound local socket address.
|
||||
result = server.local
|
||||
|
||||
template setReadError(t, e: untyped) =
|
||||
(t).state.incl(ReadError)
|
||||
(t).error = getTransportOsError(e)
|
||||
|
@ -681,15 +685,28 @@ when defined(windows):
|
|||
sock: AsyncFD
|
||||
povl: RefCustomOverlapped
|
||||
proto: Protocol
|
||||
raddress: TransportAddress
|
||||
|
||||
toSAddr(address, saddr, slen)
|
||||
## BSD Sockets on *nix systems are able to perform connections to
|
||||
## `0.0.0.0` or `::0` which are equal to `127.0.0.1` or `::1`.
|
||||
if (address.family == AddressFamily.IPv4 and
|
||||
address.address_v4 == AnyAddress.address_v4):
|
||||
raddress = initTAddress("127.0.0.1", address.port)
|
||||
elif (address.family == AddressFamily.IPv6 and
|
||||
address.address_v6 == AnyAddress6.address_v6):
|
||||
raddress = initTAddress("::1", address.port)
|
||||
else:
|
||||
raddress = address
|
||||
|
||||
toSAddr(raddress, saddr, slen)
|
||||
proto = Protocol.IPPROTO_TCP
|
||||
sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto)
|
||||
sock = createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM,
|
||||
proto)
|
||||
if sock == asyncInvalidSocket:
|
||||
retFuture.fail(getTransportOsError(osLastError()))
|
||||
return retFuture
|
||||
|
||||
if not bindToDomain(sock, address.getDomain()):
|
||||
if not bindToDomain(sock, raddress.getDomain()):
|
||||
let err = wsaGetLastError()
|
||||
sock.closeSocket()
|
||||
retFuture.fail(getTransportOsError(err))
|
||||
|
@ -1420,6 +1437,7 @@ proc createStreamServer*(host: TransportAddress,
|
|||
saddr: Sockaddr_storage
|
||||
slen: SockLen
|
||||
serverSocket: AsyncFD
|
||||
localAddress: TransportAddress
|
||||
|
||||
when defined(windows):
|
||||
# Windows
|
||||
|
@ -1458,6 +1476,15 @@ proc createStreamServer*(host: TransportAddress,
|
|||
serverSocket.closeSocket()
|
||||
raiseTransportOsError(err)
|
||||
|
||||
slen = SockLen(sizeof(saddr))
|
||||
if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
|
||||
addr slen) != 0:
|
||||
let err = osLastError()
|
||||
if sock == asyncInvalidSocket:
|
||||
serverSocket.closeSocket()
|
||||
raiseTransportOsError(err)
|
||||
fromSAddr(addr saddr, slen, localAddress)
|
||||
|
||||
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
|
||||
let err = osLastError()
|
||||
if sock == asyncInvalidSocket:
|
||||
|
@ -1513,6 +1540,16 @@ proc createStreamServer*(host: TransportAddress,
|
|||
serverSocket.closeSocket()
|
||||
raiseTransportOsError(err)
|
||||
|
||||
# Obtain real address
|
||||
slen = SockLen(sizeof(saddr))
|
||||
if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
|
||||
addr slen) != 0:
|
||||
let err = osLastError()
|
||||
if sock == asyncInvalidSocket:
|
||||
serverSocket.closeSocket()
|
||||
raiseTransportOsError(err)
|
||||
fromSAddr(addr saddr, slen, localAddress)
|
||||
|
||||
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
|
||||
let err = osLastError()
|
||||
if sock == asyncInvalidSocket:
|
||||
|
@ -1532,7 +1569,10 @@ proc createStreamServer*(host: TransportAddress,
|
|||
result.status = Starting
|
||||
result.loopFuture = newFuture[void]("stream.transport.server")
|
||||
result.udata = udata
|
||||
if localAddress.family == AddressFamily.None:
|
||||
result.local = host
|
||||
else:
|
||||
result.local = localAddress
|
||||
|
||||
when defined(windows):
|
||||
var cb: CallbackFunc
|
||||
|
|
|
@ -43,6 +43,7 @@ suite "Stream Transport test suite":
|
|||
m14 = "Closing socket while operation pending test (issue #8)"
|
||||
m15 = "Connection refused test"
|
||||
m16 = "readOnce() read until atEof() test"
|
||||
m17 = "0.0.0.0/::0 (INADDR_ANY) test"
|
||||
|
||||
when defined(windows):
|
||||
var addresses = [
|
||||
|
@ -717,6 +718,33 @@ suite "Stream Transport test suite":
|
|||
await ntransp.closeWait()
|
||||
await server.closeWait()
|
||||
|
||||
proc testAnyAddress(): Future[bool] {.async.} =
|
||||
var serverRemote, serverLocal: TransportAddress
|
||||
var connRemote, connLocal: TransportAddress
|
||||
|
||||
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||
serverRemote = transp.remoteAddress()
|
||||
serverLocal = transp.localAddress()
|
||||
await transp.closeWait()
|
||||
server.stop()
|
||||
server.close()
|
||||
|
||||
var ta = initTAddress("0.0.0.0:0")
|
||||
var server = createStreamServer(ta, serveClient, {ReuseAddr})
|
||||
var la = server.localAddress()
|
||||
server.start()
|
||||
var connFut = connect(la)
|
||||
if await withTimeout(connFut, 5.seconds):
|
||||
var conn = connFut.read()
|
||||
connRemote = conn.remoteAddress()
|
||||
connLocal = conn.localAddress()
|
||||
await server.join()
|
||||
await conn.closeWait()
|
||||
result = (connRemote == serverLocal) and (connLocal == serverRemote)
|
||||
else:
|
||||
server.stop()
|
||||
server.close()
|
||||
|
||||
for i in 0..<len(addresses):
|
||||
test prefixes[i] & "close(transport) test":
|
||||
check waitFor(testCloseTransport(addresses[i])) == 1
|
||||
|
@ -747,7 +775,7 @@ suite "Stream Transport test suite":
|
|||
if addresses[i].family == AddressFamily.IPv4:
|
||||
check waitFor(testSendFile(addresses[i])) == FilesCount
|
||||
else:
|
||||
discard
|
||||
skip()
|
||||
else:
|
||||
check waitFor(testSendFile(addresses[i])) == FilesCount
|
||||
test prefixes[i] & m15:
|
||||
|
@ -761,6 +789,11 @@ suite "Stream Transport test suite":
|
|||
check waitFor(test16(addresses[i])) == 1
|
||||
test prefixes[i] & "Connection reset test on send() only":
|
||||
check waitFor(testWriteConnReset(addresses[i])) == 1
|
||||
test prefixes[i] & m17:
|
||||
if addresses[i].family == AddressFamily.IPv4:
|
||||
check waitFor(testAnyAddress()) == true
|
||||
else:
|
||||
skip()
|
||||
|
||||
test "Servers leak test":
|
||||
check getTracker("stream.server").isLeaked() == false
|
||||
|
|
Loading…
Reference in New Issue