Fix issue with Windows connect(0.0.0.0).
This commit is contained in:
parent
ae128b0f65
commit
3b8874a9e8
|
@ -159,6 +159,10 @@ proc localAddress*(transp: StreamTransport): TransportAddress =
|
||||||
fromSAddr(addr saddr, slen, transp.local)
|
fromSAddr(addr saddr, slen, transp.local)
|
||||||
result = transp.local
|
result = transp.local
|
||||||
|
|
||||||
|
proc localAddress*(server: StreamServer): TransportAddress =
|
||||||
|
## Returns ``server`` bound local socket address.
|
||||||
|
result = server.local
|
||||||
|
|
||||||
template setReadError(t, e: untyped) =
|
template setReadError(t, e: untyped) =
|
||||||
(t).state.incl(ReadError)
|
(t).state.incl(ReadError)
|
||||||
(t).error = getTransportOsError(e)
|
(t).error = getTransportOsError(e)
|
||||||
|
@ -681,15 +685,28 @@ when defined(windows):
|
||||||
sock: AsyncFD
|
sock: AsyncFD
|
||||||
povl: RefCustomOverlapped
|
povl: RefCustomOverlapped
|
||||||
proto: Protocol
|
proto: Protocol
|
||||||
|
raddress: TransportAddress
|
||||||
|
|
||||||
toSAddr(address, saddr, slen)
|
## BSD Sockets on *nix systems are able to perform connections to
|
||||||
|
## `0.0.0.0` or `::0` which are equal to `127.0.0.1` or `::1`.
|
||||||
|
if (address.family == AddressFamily.IPv4 and
|
||||||
|
address.address_v4 == AnyAddress.address_v4):
|
||||||
|
raddress = initTAddress("127.0.0.1", address.port)
|
||||||
|
elif (address.family == AddressFamily.IPv6 and
|
||||||
|
address.address_v6 == AnyAddress6.address_v6):
|
||||||
|
raddress = initTAddress("::1", address.port)
|
||||||
|
else:
|
||||||
|
raddress = address
|
||||||
|
|
||||||
|
toSAddr(raddress, saddr, slen)
|
||||||
proto = Protocol.IPPROTO_TCP
|
proto = Protocol.IPPROTO_TCP
|
||||||
sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto)
|
sock = createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM,
|
||||||
|
proto)
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
retFuture.fail(getTransportOsError(osLastError()))
|
retFuture.fail(getTransportOsError(osLastError()))
|
||||||
return retFuture
|
return retFuture
|
||||||
|
|
||||||
if not bindToDomain(sock, address.getDomain()):
|
if not bindToDomain(sock, raddress.getDomain()):
|
||||||
let err = wsaGetLastError()
|
let err = wsaGetLastError()
|
||||||
sock.closeSocket()
|
sock.closeSocket()
|
||||||
retFuture.fail(getTransportOsError(err))
|
retFuture.fail(getTransportOsError(err))
|
||||||
|
@ -1420,6 +1437,7 @@ proc createStreamServer*(host: TransportAddress,
|
||||||
saddr: Sockaddr_storage
|
saddr: Sockaddr_storage
|
||||||
slen: SockLen
|
slen: SockLen
|
||||||
serverSocket: AsyncFD
|
serverSocket: AsyncFD
|
||||||
|
localAddress: TransportAddress
|
||||||
|
|
||||||
when defined(windows):
|
when defined(windows):
|
||||||
# Windows
|
# Windows
|
||||||
|
@ -1458,6 +1476,15 @@ proc createStreamServer*(host: TransportAddress,
|
||||||
serverSocket.closeSocket()
|
serverSocket.closeSocket()
|
||||||
raiseTransportOsError(err)
|
raiseTransportOsError(err)
|
||||||
|
|
||||||
|
slen = SockLen(sizeof(saddr))
|
||||||
|
if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
|
||||||
|
addr slen) != 0:
|
||||||
|
let err = osLastError()
|
||||||
|
if sock == asyncInvalidSocket:
|
||||||
|
serverSocket.closeSocket()
|
||||||
|
raiseTransportOsError(err)
|
||||||
|
fromSAddr(addr saddr, slen, localAddress)
|
||||||
|
|
||||||
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
|
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
|
||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
|
@ -1513,6 +1540,16 @@ proc createStreamServer*(host: TransportAddress,
|
||||||
serverSocket.closeSocket()
|
serverSocket.closeSocket()
|
||||||
raiseTransportOsError(err)
|
raiseTransportOsError(err)
|
||||||
|
|
||||||
|
# Obtain real address
|
||||||
|
slen = SockLen(sizeof(saddr))
|
||||||
|
if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
|
||||||
|
addr slen) != 0:
|
||||||
|
let err = osLastError()
|
||||||
|
if sock == asyncInvalidSocket:
|
||||||
|
serverSocket.closeSocket()
|
||||||
|
raiseTransportOsError(err)
|
||||||
|
fromSAddr(addr saddr, slen, localAddress)
|
||||||
|
|
||||||
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
|
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
|
||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
|
@ -1532,7 +1569,10 @@ proc createStreamServer*(host: TransportAddress,
|
||||||
result.status = Starting
|
result.status = Starting
|
||||||
result.loopFuture = newFuture[void]("stream.transport.server")
|
result.loopFuture = newFuture[void]("stream.transport.server")
|
||||||
result.udata = udata
|
result.udata = udata
|
||||||
result.local = host
|
if localAddress.family == AddressFamily.None:
|
||||||
|
result.local = host
|
||||||
|
else:
|
||||||
|
result.local = localAddress
|
||||||
|
|
||||||
when defined(windows):
|
when defined(windows):
|
||||||
var cb: CallbackFunc
|
var cb: CallbackFunc
|
||||||
|
|
|
@ -43,6 +43,7 @@ suite "Stream Transport test suite":
|
||||||
m14 = "Closing socket while operation pending test (issue #8)"
|
m14 = "Closing socket while operation pending test (issue #8)"
|
||||||
m15 = "Connection refused test"
|
m15 = "Connection refused test"
|
||||||
m16 = "readOnce() read until atEof() test"
|
m16 = "readOnce() read until atEof() test"
|
||||||
|
m17 = "0.0.0.0/::0 (INADDR_ANY) test"
|
||||||
|
|
||||||
when defined(windows):
|
when defined(windows):
|
||||||
var addresses = [
|
var addresses = [
|
||||||
|
@ -717,6 +718,33 @@ suite "Stream Transport test suite":
|
||||||
await ntransp.closeWait()
|
await ntransp.closeWait()
|
||||||
await server.closeWait()
|
await server.closeWait()
|
||||||
|
|
||||||
|
proc testAnyAddress(): Future[bool] {.async.} =
|
||||||
|
var serverRemote, serverLocal: TransportAddress
|
||||||
|
var connRemote, connLocal: TransportAddress
|
||||||
|
|
||||||
|
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||||
|
serverRemote = transp.remoteAddress()
|
||||||
|
serverLocal = transp.localAddress()
|
||||||
|
await transp.closeWait()
|
||||||
|
server.stop()
|
||||||
|
server.close()
|
||||||
|
|
||||||
|
var ta = initTAddress("0.0.0.0:0")
|
||||||
|
var server = createStreamServer(ta, serveClient, {ReuseAddr})
|
||||||
|
var la = server.localAddress()
|
||||||
|
server.start()
|
||||||
|
var connFut = connect(la)
|
||||||
|
if await withTimeout(connFut, 5.seconds):
|
||||||
|
var conn = connFut.read()
|
||||||
|
connRemote = conn.remoteAddress()
|
||||||
|
connLocal = conn.localAddress()
|
||||||
|
await server.join()
|
||||||
|
await conn.closeWait()
|
||||||
|
result = (connRemote == serverLocal) and (connLocal == serverRemote)
|
||||||
|
else:
|
||||||
|
server.stop()
|
||||||
|
server.close()
|
||||||
|
|
||||||
for i in 0..<len(addresses):
|
for i in 0..<len(addresses):
|
||||||
test prefixes[i] & "close(transport) test":
|
test prefixes[i] & "close(transport) test":
|
||||||
check waitFor(testCloseTransport(addresses[i])) == 1
|
check waitFor(testCloseTransport(addresses[i])) == 1
|
||||||
|
@ -747,7 +775,7 @@ suite "Stream Transport test suite":
|
||||||
if addresses[i].family == AddressFamily.IPv4:
|
if addresses[i].family == AddressFamily.IPv4:
|
||||||
check waitFor(testSendFile(addresses[i])) == FilesCount
|
check waitFor(testSendFile(addresses[i])) == FilesCount
|
||||||
else:
|
else:
|
||||||
discard
|
skip()
|
||||||
else:
|
else:
|
||||||
check waitFor(testSendFile(addresses[i])) == FilesCount
|
check waitFor(testSendFile(addresses[i])) == FilesCount
|
||||||
test prefixes[i] & m15:
|
test prefixes[i] & m15:
|
||||||
|
@ -761,6 +789,11 @@ suite "Stream Transport test suite":
|
||||||
check waitFor(test16(addresses[i])) == 1
|
check waitFor(test16(addresses[i])) == 1
|
||||||
test prefixes[i] & "Connection reset test on send() only":
|
test prefixes[i] & "Connection reset test on send() only":
|
||||||
check waitFor(testWriteConnReset(addresses[i])) == 1
|
check waitFor(testWriteConnReset(addresses[i])) == 1
|
||||||
|
test prefixes[i] & m17:
|
||||||
|
if addresses[i].family == AddressFamily.IPv4:
|
||||||
|
check waitFor(testAnyAddress()) == true
|
||||||
|
else:
|
||||||
|
skip()
|
||||||
|
|
||||||
test "Servers leak test":
|
test "Servers leak test":
|
||||||
check getTracker("stream.server").isLeaked() == false
|
check getTracker("stream.server").isLeaked() == false
|
||||||
|
|
Loading…
Reference in New Issue