diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index 041a3a01..604bb924 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -159,6 +159,10 @@ proc localAddress*(transp: StreamTransport): TransportAddress = fromSAddr(addr saddr, slen, transp.local) result = transp.local +proc localAddress*(server: StreamServer): TransportAddress = + ## Returns ``server`` bound local socket address. + result = server.local + template setReadError(t, e: untyped) = (t).state.incl(ReadError) (t).error = getTransportOsError(e) @@ -681,15 +685,28 @@ when defined(windows): sock: AsyncFD povl: RefCustomOverlapped proto: Protocol + raddress: TransportAddress - toSAddr(address, saddr, slen) + ## BSD Sockets on *nix systems are able to perform connections to + ## `0.0.0.0` or `::0` which are equal to `127.0.0.1` or `::1`. + if (address.family == AddressFamily.IPv4 and + address.address_v4 == AnyAddress.address_v4): + raddress = initTAddress("127.0.0.1", address.port) + elif (address.family == AddressFamily.IPv6 and + address.address_v6 == AnyAddress6.address_v6): + raddress = initTAddress("::1", address.port) + else: + raddress = address + + toSAddr(raddress, saddr, slen) proto = Protocol.IPPROTO_TCP - sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) + sock = createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM, + proto) if sock == asyncInvalidSocket: retFuture.fail(getTransportOsError(osLastError())) return retFuture - if not bindToDomain(sock, address.getDomain()): + if not bindToDomain(sock, raddress.getDomain()): let err = wsaGetLastError() sock.closeSocket() retFuture.fail(getTransportOsError(err)) @@ -1420,6 +1437,7 @@ proc createStreamServer*(host: TransportAddress, saddr: Sockaddr_storage slen: SockLen serverSocket: AsyncFD + localAddress: TransportAddress when defined(windows): # Windows @@ -1458,6 +1476,15 @@ proc createStreamServer*(host: TransportAddress, serverSocket.closeSocket() raiseTransportOsError(err) + slen = SockLen(sizeof(saddr)) + if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), + addr slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + fromSAddr(addr saddr, slen, localAddress) + if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: @@ -1513,6 +1540,16 @@ proc createStreamServer*(host: TransportAddress, serverSocket.closeSocket() raiseTransportOsError(err) + # Obtain real address + slen = SockLen(sizeof(saddr)) + if getsockname(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), + addr slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + fromSAddr(addr saddr, slen, localAddress) + if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: @@ -1532,7 +1569,10 @@ proc createStreamServer*(host: TransportAddress, result.status = Starting result.loopFuture = newFuture[void]("stream.transport.server") result.udata = udata - result.local = host + if localAddress.family == AddressFamily.None: + result.local = host + else: + result.local = localAddress when defined(windows): var cb: CallbackFunc diff --git a/tests/teststream.nim b/tests/teststream.nim index 755f19da..2ce9a0e8 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -43,6 +43,7 @@ suite "Stream Transport test suite": m14 = "Closing socket while operation pending test (issue #8)" m15 = "Connection refused test" m16 = "readOnce() read until atEof() test" + m17 = "0.0.0.0/::0 (INADDR_ANY) test" when defined(windows): var addresses = [ @@ -717,6 +718,33 @@ suite "Stream Transport test suite": await ntransp.closeWait() await server.closeWait() + proc testAnyAddress(): Future[bool] {.async.} = + var serverRemote, serverLocal: TransportAddress + var connRemote, connLocal: TransportAddress + + proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = + serverRemote = transp.remoteAddress() + serverLocal = transp.localAddress() + await transp.closeWait() + server.stop() + server.close() + + var ta = initTAddress("0.0.0.0:0") + var server = createStreamServer(ta, serveClient, {ReuseAddr}) + var la = server.localAddress() + server.start() + var connFut = connect(la) + if await withTimeout(connFut, 5.seconds): + var conn = connFut.read() + connRemote = conn.remoteAddress() + connLocal = conn.localAddress() + await server.join() + await conn.closeWait() + result = (connRemote == serverLocal) and (connLocal == serverRemote) + else: + server.stop() + server.close() + for i in 0..