From a4c27806ea34f2d746ac0228e005b44a95dcf836 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 25 Oct 2018 13:19:19 +0300 Subject: [PATCH] Add AF_UNIX sockets support. Add Windows emulation of AF_UNIX sockets via Named Pipes. Add tests for AF_UNIX sockets. TransportAddress object change. --- asyncdispatch2/asyncloop.nim | 11 +- asyncdispatch2/transports/common.nim | 316 +++++++++++---- asyncdispatch2/transports/datagram.nim | 56 ++- asyncdispatch2/transports/stream.nim | 515 +++++++++++++++++++------ tests/teststream.nim | 444 ++++++++++----------- 5 files changed, 902 insertions(+), 440 deletions(-) diff --git a/asyncdispatch2/asyncloop.nim b/asyncdispatch2/asyncloop.nim index e989d965..7301355d 100644 --- a/asyncdispatch2/asyncloop.nim +++ b/asyncdispatch2/asyncloop.nim @@ -17,7 +17,7 @@ import asyncfutures2 except callSoon import nativesockets, net, deques export Port, SocketFlag -export asyncfutures2 +export asyncfutures2, timer #{.injectStmt: newGcInvariant().} @@ -409,6 +409,15 @@ when defined(windows) or defined(nimdoc): var acb = AsyncCallback(function: aftercb) loop.callbacks.addLast(acb) + proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) = + ## Closes a (pipe/file) handle and ensures that it is unregistered. + let loop = getGlobalDispatcher() + loop.handles.excl(fd) + doAssert closeHandle(Handle(fd)) == 1 + if not isNil(aftercb): + var acb = AsyncCallback(function: aftercb) + loop.callbacks.addLast(acb) + proc unregister*(fd: AsyncFD) = ## Unregisters ``fd``. getGlobalDispatcher().handles.excl(fd) diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index cec63951..dbe96201 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -6,11 +6,9 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) - -import os, net, strutils -from nativesockets import toInt +from net import IpAddressFamily, IpAddress, `$`, parseIpAddress +import os, strutils, nativesockets import ../asyncloop -export net when defined(windows): import winlean @@ -25,12 +23,24 @@ const type ServerFlags* = enum ## Server's flags - ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData + ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData, FirstPipe, + NoPipeFlash + + AddressFamily* {.pure.} = enum + None, IPv4, IPv6, Unix TransportAddress* = object ## Transport network address - address*: IpAddress # IP Address - port*: Port # IP port + case family*: AddressFamily + of AddressFamily.None: + discard + of AddressFamily.IPv4: + address_v4*: array[4, uint8] + of AddressFamily.IPv6: + address_v6*: array[16, uint8] + of AddressFamily.Unix: + address_un*: array[108, uint8] + port*: Port # Port number ServerCommand* = enum ## Server's commands @@ -94,6 +104,8 @@ type TransportAddressError* = object of TransportError ## Transport's address specific exception code*: OSErrorCode + TransportNoSupport* = object of TransportError + ## Transport's capability not supported exception TransportState* = enum ## Transport's state @@ -108,85 +120,154 @@ type WriteError # Write error var - AnyAddress* = TransportAddress( - address: IpAddress(family: IpAddressFamily.IPv4), port: Port(0) - ) ## Default INADDR_ANY address for IPv4 - AnyAddress6* = TransportAddress( - address: IpAddress(family: IpAddressFamily.IPv6), port: Port(0) - ) ## Default INADDR_ANY address for IPv6 + AnyAddress* = TransportAddress(family: AddressFamily.IPv4, port: Port(0)) + ## Default INADDR_ANY address for IPv4 + AnyAddress6* = TransportAddress(family: AddressFamily.IPv6, port: Port(0)) + ## Default INADDR_ANY address for IPv6 -proc getDomain*(address: IpAddress): Domain = - ## Returns OS specific Domain from IP Address. - case address.family - of IpAddressFamily.IPv4: - result = Domain.AF_INET - of IpAddressFamily.IPv6: - result = Domain.AF_INET6 +proc `==`*(lhs, rhs: TransportAddress): bool = + ## Compare two transport addresses ``lhs`` and ``rhs``. Return ``true`` if + ## addresses are equal. + if lhs.family != lhs.family: + return false + if lhs.family == AddressFamily.IPv4: + result = equalMem(unsafeAddr lhs.address_v4[0], + unsafeAddr rhs.address_v4[0], sizeof(lhs.address_v4)) and + (lhs.port == rhs.port) + elif lhs.family == AddressFamily.IPv6: + result = equalMem(unsafeAddr lhs.address_v6[0], + unsafeAddr rhs.address_v6[0], sizeof(lhs.address_v6)) and + (lhs.port == rhs.port) + elif lhs.family == AddressFamily.Unix: + result = equalMem(unsafeAddr lhs.address_un[0], + unsafeAddr rhs.address_un[0], sizeof(lhs.address_un)) proc getDomain*(address: TransportAddress): Domain = ## Returns OS specific Domain from TransportAddress. - result = address.address.getDomain() + case address.family + of AddressFamily.IPv4: + result = Domain.AF_INET + of AddressFamily.IPv6: + result = Domain.AF_INET6 + of AddressFamily.Unix: + when defined(windows): + result = cast[Domain](1) + else: + result = Domain.AF_UNIX + else: + result = cast[Domain](0) proc `$`*(address: TransportAddress): string = ## Returns string representation of ``address``. - case address.address.family - of IpAddressFamily.IPv4: - result = $address.address + case address.family + of AddressFamily.IPv4: + var a = IpAddress( + family: IpAddressFamily.IPv4, + address_v4: address.address_v4 + ) + result = $a result.add(":") - of IpAddressFamily.IPv6: - result = "[" & $address.address & "]" - result.add(":") - result.add($int(address.port)) + result.add($int(address.port)) + of AddressFamily.IPv6: + var a = IpAddress(family: IpAddressFamily.IPv6, + address_v6: address.address_v6) + result = "[" & $a & "]:" + result.add($(int(address.port))) + of AddressFamily.Unix: + const length = sizeof(address.address_un) + 1 + var buffer: array[length, char] + if not equalMem(addr buffer[0], unsafeAddr address.address_un[0], + sizeof(address.address_un)): + copyMem(addr buffer[0], unsafeAddr address.address_un[0], + sizeof(address.address_un)) + result = $cast[cstring](addr buffer) + else: + result = "" + else: + raise newException(TransportAddressError, "Unknown address family!") proc initTAddress*(address: string): TransportAddress = - ## Parses string representation of ``address``. + ## Parses string representation of ``address``. ``address`` can be IPv4, IPv6 + ## or Unix domain address. ## ## IPv4 transport address format is ``a.b.c.d:port``. ## IPv6 transport address format is ``[::]:port``. - var parts = address.rsplit(":", maxsplit = 1) - if len(parts) != 2: - raise newException(TransportAddressError, "Format is
:!") - - try: - let port = parseInt(parts[1]) - doAssert(port > 0 and port < 65536) - result.port = Port(port) - except: - raise newException(TransportAddressError, "Illegal port number!") - - try: - if parts[0][0] == '[' and parts[0][^1] == ']': - result.address = parseIpAddress(parts[0][1..^2]) + ## Unix transport address format is ``/address``. + if len(address) > 0: + if address[0] == '/': + result = TransportAddress(family: AddressFamily.Unix, port: Port(1)) + let size = if len(address) < (sizeof(result.address_un) - 1): len(address) + else: (sizeof(result.address_un) - 1) + copyMem(addr result.address_un[0], unsafeAddr address[0], size) else: - result.address = parseIpAddress(parts[0]) - except: - raise newException(TransportAddressError, getCurrentException().msg) + var port: Port + var parts = address.rsplit(":", maxsplit = 1) + if len(parts) != 2: + raise newException(TransportAddressError, + "Format is
: or
!") + + try: + let portint = parseInt(parts[1]) + doAssert(portint > 0 and portint < 65536) + port = Port(portint) + except: + raise newException(TransportAddressError, "Illegal port number!") + + try: + var ipaddr: IpAddress + if parts[0][0] == '[' and parts[0][^1] == ']': + ipaddr = parseIpAddress(parts[0][1..^2]) + else: + ipaddr = parseIpAddress(parts[0]) + if ipaddr.family == IpAddressFamily.IPv4: + result = TransportAddress(family: AddressFamily.IPv4) + result.address_v4 = ipaddr.address_v4 + elif ipaddr.family == IpAddressFamily.IPv6: + result = TransportAddress(family: AddressFamily.IPv6) + result.address_v6 = ipaddr.address_v6 + else: + raise newException(TransportAddressError, "Incorrect address family!") + result.port = port + except: + raise newException(TransportAddressError, getCurrentException().msg) + else: + result = TransportAddress(family: AddressFamily.Unix) proc initTAddress*(address: string, port: Port): TransportAddress = - ## Initialize ``TransportAddress`` with IP address ``address`` and - ## port number ``port``. + ## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address`` + ## and port number ``port``. try: - result.address = parseIpAddress(address) - result.port = port + var ipaddr = parseIpAddress(address) + if ipaddr.family == IpAddressFamily.IPv4: + result = TransportAddress(family: AddressFamily.IPv4, port: port) + result.address_v4 = ipaddr.address_v4 + elif ipaddr.family == IpAddressFamily.IPv6: + result = TransportAddress(family: AddressFamily.IPv6, port: port) + result.address_v6 = ipaddr.address_v6 + else: + raise newException(TransportAddressError, "Incorrect address family!") except: raise newException(TransportAddressError, getCurrentException().msg) -proc initTAddress*(address: string, port: int): TransportAddress = - ## Initialize ``TransportAddress`` with IP address ``address`` and - ## port number ``port``. +proc initTAddress*(address: string, port: int): TransportAddress {.inline.} = + ## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address`` + ## and port number ``port``. if port < 0 or port >= 65536: raise newException(TransportAddressError, "Illegal port number!") - try: - result.address = parseIpAddress(address) - result.port = Port(port) - except: - raise newException(TransportAddressError, getCurrentException().msg) + else: + result = initTAddress(address, Port(port)) proc initTAddress*(address: IpAddress, port: Port): TransportAddress = ## Initialize ``TransportAddress`` with net.nim ``IpAddress`` and ## port number ``port``. - result.address = address - result.port = port + if address.family == IpAddressFamily.IPv4: + result = TransportAddress(family: AddressFamily.IPv4, port: port) + result.address_v4 = address.address_v4 + elif address.family == IpAddressFamily.IPv6: + result = TransportAddress(family: AddressFamily.IPv6, port: port) + result.address_v6 = address.address_v6 + else: + raise newException(TransportAddressError, "Incorrect address family!") proc getAddrInfo(address: string, port: Port, domain: Domain, sockType: SockType = SockType.SOCK_STREAM, @@ -205,8 +286,72 @@ proc getAddrInfo(address: string, port: Port, domain: Domain, else: raise newException(TransportAddressError, $gai_strerror(gaiResult)) +proc fromSAddr*(sa: ptr Sockaddr_storage, sl: Socklen, + address: var TransportAddress) = + ## Set transport address ``address`` with value from OS specific socket + ## address storage. + if int(sa.ss_family) == toInt(Domain.AF_INET) and + int(sl) == sizeof(Sockaddr_in): + address = TransportAddress(family: AddressFamily.IPv4) + let s = cast[ptr Sockaddr_in](sa) + copyMem(addr address.address_v4[0], addr s.sin_addr, + sizeof(address.address_v4)) + address.port = Port(nativesockets.ntohs(s.sin_port)) + elif int(sa.ss_family) == toInt(Domain.AF_INET6) and + int(sl) == sizeof(Sockaddr_in6): + address = TransportAddress(family: AddressFamily.IPv6) + let s = cast[ptr Sockaddr_in6](sa) + copyMem(addr address.address_v6[0], addr s.sin6_addr, + sizeof(address.address_v6)) + address.port = Port(nativesockets.ntohs(s.sin6_port)) + elif int(sa.ss_family) == toInt(Domain.AF_UNIX): + when not defined(windows): + address = TransportAddress(family: AddressFamily.Unix) + if int(sl) > sizeof(sa.ss_family): + var length = int(sl) - sizeof(sa.ss_family) + if length > (sizeof(address.address_un) - 1): + length = sizeof(address.address_un) - 1 + let s = cast[ptr Sockaddr_un](sa) + copyMem(addr address.address_un[0], addr s.sun_path[0], length) + address.port = Port(1) + else: + discard + +proc toSAddr*(address: TransportAddress, sa: var Sockaddr_storage, + sl: var Socklen) = + ## Set socket OS specific socket address storage with address from transport + ## address ``address``. + case address.family + of AddressFamily.IPv4: + sl = Socklen(sizeof(Sockaddr_in)) + let s = cast[ptr Sockaddr_in](addr sa) + s.sin_family = type(s.sin_family)(toInt(Domain.AF_INET)) + s.sin_port = nativesockets.htons(uint16(address.port)) + copyMem(addr s.sin_addr, unsafeAddr address.address_v4[0], + sizeof(s.sin_addr)) + of AddressFamily.IPv6: + sl = Socklen(sizeof(Sockaddr_in6)) + let s = cast[ptr Sockaddr_in6](addr sa) + s.sin6_family = type(s.sin6_family)(toInt(Domain.AF_INET6)) + s.sin6_port = nativesockets.htons(uint16(address.port)) + copyMem(addr s.sin6_addr, unsafeAddr address.address_v6[0], + sizeof(s.sin6_addr)) + of AddressFamily.Unix: + when not defined(windows): + if address.port == Port(0): + sl = Socklen(sizeof(sa.ss_family)) + else: + let s = cast[ptr Sockaddr_un](addr sa) + var name = cast[cstring](unsafeAddr address.address_un[0]) + sl = Socklen(sizeof(sa.ss_family) + len(name) + 1) + s.sun_family = type(s.sun_family)(toInt(Domain.AF_UNIX)) + copyMem(addr s.sun_path, unsafeAddr address.address_un[0], + len(name) + 1) + else: + discard + proc resolveTAddress*(address: string, - family = IpAddressFamily.IPv4): seq[TransportAddress] = + family = AddressFamily.IPv4): seq[TransportAddress] = ## Resolve string representation of ``address``. ## ## Supported formats are: @@ -220,6 +365,8 @@ proc resolveTAddress*(address: string, hostname: string port: int + doAssert(family in {AddressFamily.IPv4, AddressFamily.IPv6}) + result = newSeq[TransportAddress]() var parts = address.rsplit(":", maxsplit = 1) if len(parts) != 2: @@ -237,14 +384,14 @@ proc resolveTAddress*(address: string, else: hostname = parts[0] - var domain = if family == IpAddressFamily.IPv4: Domain.AF_INET else: + var domain = if family == AddressFamily.IPv4: Domain.AF_INET else: Domain.AF_INET6 var aiList = getAddrInfo(hostname, Port(port), domain) var it = aiList while it != nil: var ta: TransportAddress - fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[], - SockLen(it.ai_addrlen), ta.address, ta.port) + fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr), + SockLen(it.ai_addrlen), ta) # For some reason getAddrInfo() sometimes returns duplicate addresses, # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. if ta notin result: @@ -253,22 +400,24 @@ proc resolveTAddress*(address: string, freeAddrInfo(aiList) proc resolveTAddress*(address: string, port: Port, - family = IpAddressFamily.IPv4): seq[TransportAddress] = + family = AddressFamily.IPv4): seq[TransportAddress] = ## Resolve string representation of ``address``. ## ## ``address`` could be dot IPv4/IPv6 address or hostname. ## ## If hostname address is detected, then network address translation via DNS ## will be performed. + assert(family in {AddressFamily.IPv4, AddressFamily.IPv6}) + result = newSeq[TransportAddress]() - var domain = if family == IpAddressFamily.IPv4: Domain.AF_INET else: + var domain = if family == AddressFamily.IPv4: Domain.AF_INET else: Domain.AF_INET6 var aiList = getAddrInfo(address, port, domain) var it = aiList while it != nil: var ta: TransportAddress - fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[], - SockLen(it.ai_addrlen), ta.address, ta.port) + fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr), + SockLen(it.ai_addrlen), ta) # For some reason getAddrInfo() sometimes returns duplicate addresses, # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. if ta notin result: @@ -290,13 +439,6 @@ template getError*(t: untyped): ref Exception = (t).error = nil err -proc raiseTransportOsError*(err: OSErrorCode) = - ## Raises transport specific OS error. - var msg = "(" & $int(err) & ") " & osErrorMsg(err) - var tre = newException(TransportOsError, msg) - tre.code = err - raise tre - template getTransportOsError*(err: OSErrorCode): ref TransportOsError = var msg = "(" & $int(err) & ") " & osErrorMsg(err) var tre = newException(TransportOsError, msg) @@ -306,6 +448,10 @@ template getTransportOsError*(err: OSErrorCode): ref TransportOsError = template getTransportOsError*(err: cint): ref TransportOsError = getTransportOsError(OSErrorCode(err)) +proc raiseTransportOsError*(err: OSErrorCode) = + ## Raises transport specific OS error. + raise getTransportOsError(err) + type SeqHeader = object length, reserved: int @@ -321,8 +467,28 @@ when defined(windows): const ERROR_OPERATION_ABORTED* = 995 + ERROR_PIPE_CONNECTED* = 535 + ERROR_PIPE_BUSY* = 231 ERROR_SUCCESS* = 0 ERROR_CONNECTION_REFUSED* = 1225 + PIPE_TYPE_BYTE* = 0 + PIPE_READMODE_BYTE* = 0 + PIPE_TYPE_MESSAGE* = 0x4 + PIPE_READMODE_MESSAGE* = 0x2 + PIPE_WAIT* = 0 + PIPE_UNLIMITED_INSTANCES* = 255 + ERROR_BROKEN_PIPE* = 109 + ERROR_PIPE_NOT_CONNECTED* = 233 + ERROR_NO_DATA* = 232 proc cancelIo*(hFile: HANDLE): WINBOOL {.stdcall, dynlib: "kernel32", importc: "CancelIo".} + proc connectNamedPipe*(hPipe: HANDLE, lpOverlapped: ptr OVERLAPPED): WINBOOL + {.stdcall, dynlib: "kernel32", importc: "ConnectNamedPipe".} + proc disconnectNamedPipe*(hPipe: HANDLE): WINBOOL + {.stdcall, dynlib: "kernel32", importc: "DisconnectNamedPipe".} + proc setNamedPipeHandleState*(hPipe: HANDLE, lpMode, lpMaxCollectionCount, + lpCollectDataTimeout: ptr DWORD): WINBOOL + {.stdcall, dynlib: "kernel32", importc: "SetNamedPipeHandleState".} + proc resetEvent*(hEvent: HANDLE): WINBOOL + {.stdcall, dynlib: "kernel32", importc: "ResetEvent".} diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim index 47451da7..4cca7965 100644 --- a/asyncdispatch2/transports/datagram.nim +++ b/asyncdispatch2/transports/datagram.nim @@ -94,8 +94,7 @@ when defined(windows): transp.setWriterWSABuffer(vector) var ret: cint if vector.kind == WithAddress: - toSockAddr(vector.address.address, vector.address.port, - transp.waddr, transp.walen) + toSAddr(vector.address, transp.waddr, transp.walen) ret = WSASendTo(fd, addr transp.wwsabuf, DWORD(1), addr bytesCount, DWORD(0), cast[ptr SockAddr](addr transp.waddr), cint(transp.walen), @@ -139,7 +138,7 @@ when defined(windows): let bytesCount = transp.rovl.data.bytesCount if bytesCount == 0: transp.state.incl({ReadEof, ReadPaused}) - fromSockAddr(transp.raddr, transp.ralen, raddr.address, raddr.port) + fromSAddr(addr transp.raddr, transp.ralen, raddr) transp.buflen = bytesCount asyncCheck transp.function(transp, raddr) elif int(err) == ERROR_OPERATION_ABORTED: @@ -200,7 +199,7 @@ when defined(windows): child: DatagramTransport, bufferSize: int): DatagramTransport = var localSock: AsyncFD - assert(remote.address.family == local.address.family) + assert(remote.family == local.family) assert(not isNil(cbproc)) if isNil(child): @@ -209,12 +208,8 @@ when defined(windows): result = child if sock == asyncInvalidSocket: - if local.address.family == IpAddressFamily.IPv4: - localSock = createAsyncSocket(Domain.AF_INET, SockType.SOCK_DGRAM, - Protocol.IPPROTO_UDP) - else: - localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM, - Protocol.IPPROTO_UDP) + localSock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM, + Protocol.IPPROTO_UDP) if localSock == asyncInvalidSocket: raiseTransportOsError(osLastError()) else: @@ -239,10 +234,10 @@ when defined(windows): addr bytesRet, nil, nil) != 0: raiseTransportOsError(osLastError()) - if local.port != Port(0): + if local.family != AddressFamily.None: var saddr: Sockaddr_storage var slen: SockLen - toSockAddr(local.address, local.port, saddr, slen) + toSAddr(local, saddr, slen) if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), slen) != 0: let err = osLastError() @@ -253,12 +248,7 @@ when defined(windows): else: var saddr: Sockaddr_storage var slen: SockLen - if local.address.family == IpAddressFamily.IPv4: - saddr.ss_family = winlean.AF_INET - slen = SockLen(sizeof(SockAddr_in)) - else: - saddr.ss_family = winlean.AF_INET6 - slen = SockLen(sizeof(SockAddr_in6)) + saddr.ss_family = type(saddr.ss_family)(local.getDomain()) if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), slen) != 0: let err = osLastError() @@ -269,7 +259,7 @@ when defined(windows): if remote.port != Port(0): var saddr: Sockaddr_storage var slen: SockLen - toSockAddr(remote.address, remote.port, saddr, slen) + toSAddr(remote, saddr, slen) if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), slen) != 0: let err = osLastError() @@ -320,7 +310,7 @@ else: cast[ptr SockAddr](addr transp.raddr), addr transp.ralen) if res >= 0: - fromSockAddr(transp.raddr, transp.ralen, raddr.address, raddr.port) + fromSAddr(addr transp.raddr, transp.ralen, raddr) transp.buflen = res asyncCheck transp.function(transp, raddr) else: @@ -350,8 +340,7 @@ else: var vector = transp.queue.popFirst() while true: if vector.kind == WithAddress: - toSockAddr(vector.address.address, vector.address.port, - transp.waddr, transp.walen) + toSAddr(vector.address, transp.waddr, transp.walen) res = posix.sendto(fd, vector.buf, vector.buflen, MSG_NOSIGNAL, cast[ptr SockAddr](addr transp.waddr), transp.walen) @@ -387,7 +376,7 @@ else: child: DatagramTransport = nil, bufferSize: int): DatagramTransport = var localSock: AsyncFD - assert(remote.address.family == local.address.family) + assert(remote.family == local.family) assert(not isNil(cbproc)) if isNil(child): @@ -396,12 +385,13 @@ else: result = child if sock == asyncInvalidSocket: - if local.address.family == IpAddressFamily.IPv4: - localSock = createAsyncSocket(Domain.AF_INET, SockType.SOCK_DGRAM, - Protocol.IPPROTO_UDP) - else: - localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM, - Protocol.IPPROTO_UDP) + var proto = Protocol.IPPROTO_UDP + if local.family == AddressFamily.Unix: + # `Protocol` enum is missing `0` value, so we making here cast, until + # `Protocol` enum will not support IPPROTO_IP == 0. + proto = cast[Protocol](0) + localSock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM, + proto) if localSock == asyncInvalidSocket: raiseTransportOsError(osLastError()) else: @@ -418,10 +408,10 @@ else: closeSocket(localSock) raiseTransportOsError(err) - if local.port != Port(0): + if local.family != AddressFamily.None: var saddr: Sockaddr_storage var slen: SockLen - toSockAddr(local.address, local.port, saddr, slen) + toSAddr(local, saddr, slen) if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), slen) != 0: let err = osLastError() @@ -430,10 +420,10 @@ else: raiseTransportOsError(err) result.local = local - if remote.port != Port(0): + if remote.family != AddressFamily.None: var saddr: Sockaddr_storage var slen: SockLen - toSockAddr(remote.address, remote.port, saddr, slen) + toSAddr(remote, saddr, slen) if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), slen) != 0: let err = osLastError() diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index 7f5af885..0a0aaaca 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -11,6 +11,8 @@ import net, nativesockets, os, deques import ../asyncloop, ../handles, ../sendfile import common +{.deadCodeElim: on.} + when defined(windows): import winlean else: @@ -33,6 +35,23 @@ type Pipe, # Pipe transport File # File transport + TransportFlags* = enum + None, + # Default value + WinServerPipe, + # This is internal flag which used to differentiate between server pipe + # handle and client pipe handle. + WinNoPipeFlash + # By default `AddressFamily.Unix` transports in Windows are using + # `FlushFileBuffers()` when transport closing. + # This flag disables usage of `FlushFileBuffers()` on `AddressFamily.Unix` + # transport shutdown. If both server and client are running in the same + # thread, because of `FlushFileBuffers()` will ensure that all bytes + # or messages written to the pipe are read by the client, it is possible to + # get stuck on transport `close()`. + # Please use this flag only if you are making both client and server in + # the same thread. + when defined(windows): const SO_UPDATE_CONNECT_CONTEXT = 0x7010 @@ -52,6 +71,7 @@ when defined(windows): rovl: CustomOverlapped # Reader OVERLAPPED structure wovl: CustomOverlapped # Writer OVERLAPPED structure roffset: int # Pending reading offset + flags: set[TransportFlags] # Internal flags case kind*: TransportKind of TransportKind.Socket: domain: Domain # Socket transport domain (IPv4/IPv6) @@ -83,7 +103,6 @@ else: todo2: int type - StreamCallback* = proc(server: StreamServer, client: StreamTransport): Future[void] {.gcsafe.} ## New remote client connection callback @@ -92,7 +111,7 @@ type TransportInitCallback* = proc(server: StreamServer, fd: AsyncFD): StreamTransport {.gcsafe.} - ## Custom transport initialization procedure, which can allocated inherited + ## Custom transport initialization procedure, which can allocate inherited ## StreamTransport object. StreamServer* = ref object of SocketServer @@ -106,26 +125,26 @@ proc remoteAddress*(transp: StreamTransport): TransportAddress = ## Returns ``transp`` remote socket address. if transp.kind != TransportKind.Socket: raise newException(TransportError, "Socket required!") - if transp.remote.port == Port(0): + if transp.remote.family == AddressFamily.None: var saddr: Sockaddr_storage var slen = SockLen(sizeof(saddr)) if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), addr slen) != 0: raiseTransportOsError(osLastError()) - fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port) + fromSAddr(addr saddr, slen, transp.remote) result = transp.remote proc localAddress*(transp: StreamTransport): TransportAddress = ## Returns ``transp`` local socket address. if transp.kind != TransportKind.Socket: raise newException(TransportError, "Socket required!") - if transp.local.port == Port(0): + if transp.local.family == AddressFamily.None: var saddr: Sockaddr_storage var slen = SockLen(sizeof(saddr)) if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), addr slen) != 0: raiseTransportOsError(osLastError()) - fromSockAddr(saddr, slen, transp.local.address, transp.local.port) + fromSAddr(addr saddr, slen, transp.local) result = transp.local template setReadError(t, e: untyped) = @@ -209,6 +228,13 @@ when defined(windows): transp.queue.addFirst(vector) else: vector.writer.complete(int(getFileSize(vector))) + elif transp.kind == TransportKind.Pipe: + if vector.kind == VectorKind.DataBuffer: + if bytesCount < transp.wwsabuf.len: + vector.shiftVectorBuffer(bytesCount) + transp.queue.addFirst(vector) + else: + vector.writer.complete(transp.wwsabuf.len) elif int(err) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt transp.state.incl(WritePaused) @@ -275,6 +301,35 @@ when defined(windows): vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) + elif transp.kind == TransportKind.Pipe: + let pipe = Handle(transp.wovl.data.fd) + var vector = transp.queue.popFirst() + if vector.kind == VectorKind.DataBuffer: + transp.wovl.zeroOvelappedOffset() + transp.setWriterWSABuffer(vector) + let ret = writeFile(pipe, cast[pointer](transp.wwsabuf.buf), + DWORD(transp.wwsabuf.len), addr bytesCount, + cast[POVERLAPPED](addr transp.wovl)) + if ret == 0: + let err = osLastError() + if int(err) == ERROR_OPERATION_ABORTED: + # CancelIO() interrupt + transp.state.excl(WritePending) + transp.state.incl(WritePaused) + vector.writer.complete(0) + elif int(err) == ERROR_IO_PENDING: + transp.queue.addFirst(vector) + elif int(err) == ERROR_NO_DATA: + # The pipe is being closed. + transp.state.excl(WritePending) + transp.state.incl(WritePaused) + vector.writer.complete(0) + else: + transp.state.excl(WritePending) + transp.state = transp.state + {WritePaused, WriteError} + vector.writer.fail(getTransportOsError(err)) + else: + transp.queue.addFirst(vector) break if len(transp.queue) == 0: @@ -283,7 +338,6 @@ when defined(windows): proc readStreamLoop(udata: pointer) {.gcsafe, nimcall.} = var ovl = cast[PtrCustomOverlapped](udata) var transp = cast[StreamTransport](ovl.data.udata) - while true: if ReadPending in transp.state: ## Continuation @@ -312,7 +366,11 @@ when defined(windows): elif int(err) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt transp.state.incl(ReadPaused) - elif int(err) in {ERROR_NETNAME_DELETED, WSAECONNABORTED}: + elif transp.kind == TransportKind.Socket and + (int(err) in {ERROR_NETNAME_DELETED, WSAECONNABORTED}): + transp.state.incl({ReadEof, ReadPaused}) + elif transp.kind == TransportKind.Pipe and + (int(err) in {ERROR_BROKEN_PIPE, ERROR_PIPE_NOT_CONNECTED}): transp.state.incl({ReadEof, ReadPaused}) else: transp.setReadError(err) @@ -339,7 +397,7 @@ when defined(windows): cast[POVERLAPPED](addr transp.rovl), nil) if ret != 0: let err = osLastError() - if int(err) == ERROR_OPERATION_ABORTED: + if int32(err) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt transp.state.excl(ReadPending) transp.state.incl(ReadPaused) @@ -356,6 +414,32 @@ when defined(windows): if not isNil(transp.reader): transp.reader.complete() transp.reader = nil + elif transp.kind == TransportKind.Pipe: + let pipe = Handle(transp.rovl.data.fd) + transp.roffset = transp.offset + transp.setReaderWSABuffer() + let ret = readFile(pipe, cast[pointer](transp.rwsabuf.buf), + DWORD(transp.rwsabuf.len), addr bytesCount, + cast[POVERLAPPED](addr transp.rovl)) + if ret == 0: + let err = osLastError() + if int32(err) == ERROR_OPERATION_ABORTED: + # CancelIO() interrupt + transp.state.excl(ReadPending) + transp.state.incl(ReadPaused) + elif int32(err) in {ERROR_BROKEN_PIPE, ERROR_PIPE_NOT_CONNECTED}: + transp.state.excl(ReadPending) + transp.state.incl({ReadEof, ReadPaused}) + if not isNil(transp.reader): + transp.reader.complete() + transp.reader = nil + elif int32(err) != ERROR_IO_PENDING: + transp.state.excl(ReadPending) + transp.state.incl(ReadPaused) + transp.setReadError(err) + if not isNil(transp.reader): + transp.reader.complete() + transp.reader = nil else: transp.state.incl(ReadPaused) if not isNil(transp.reader): @@ -383,6 +467,27 @@ when defined(windows): GC_ref(transp) result = transp + proc newStreamPipeTransport(fd: AsyncFD, bufsize: int, + child: StreamTransport, + flags: set[TransportFlags] = {}): StreamTransport = + var transp: StreamTransport + if not isNil(child): + transp = child + else: + transp = StreamTransport(kind: TransportKind.Pipe) + transp.fd = fd + transp.rovl.data = CompletionData(fd: fd, cb: readStreamLoop, + udata: cast[pointer](transp)) + transp.wovl.data = CompletionData(fd: fd, cb: writeStreamLoop, + udata: cast[pointer](transp)) + transp.buffer = newSeq[byte](bufsize) + transp.flags = flags + transp.state = {ReadPaused, WritePaused} + transp.queue = initDeque[StreamVector]() + transp.future = newFuture[void]("stream.pipe.transport") + GC_ref(transp) + result = transp + proc bindToDomain(handle: AsyncFD, domain: Domain): bool = result = true if domain == Domain.AF_INET6: @@ -391,7 +496,7 @@ when defined(windows): if bindAddr(SocketHandle(handle), cast[ptr SockAddr](addr(saddr)), sizeof(saddr).SockLen) != 0'i32: result = false - else: + elif domain == Domain.AF_INET: var saddr: Sockaddr_in saddr.sin_family = type(saddr.sin_family)(toInt(domain)) if bindAddr(SocketHandle(handle), cast[ptr SockAddr](addr(saddr)), @@ -400,66 +505,161 @@ when defined(windows): proc connect*(address: TransportAddress, bufferSize = DefaultStreamBufferSize, - child: StreamTransport = nil): Future[StreamTransport] = + child: StreamTransport = nil, + flags: set[TransportFlags] = {}): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` is size of internal buffer for transport. let loop = getGlobalDispatcher() - var - saddr: Sockaddr_storage - slen: SockLen - sock: AsyncFD - povl: RefCustomOverlapped var retFuture = newFuture[StreamTransport]("stream.transport.connect") - toSockAddr(address.address, address.port, saddr, slen) - sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, - Protocol.IPPROTO_TCP) + if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + ## Socket handling part + var + saddr: Sockaddr_storage + slen: SockLen + sock: AsyncFD + povl: RefCustomOverlapped + proto: Protocol - if sock == asyncInvalidSocket: - retFuture.fail(getTransportOsError(OSErrorCode(wsaGetLastError()))) - return retFuture + toSAddr(address, saddr, slen) + proto = Protocol.IPPROTO_TCP + sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) + if sock == asyncInvalidSocket: + result.fail(getTransportOsError(osLastError())) - if not bindToDomain(sock, address.address.getDomain()): - let err = wsaGetLastError() - sock.closeSocket() - retFuture.fail(getTransportOsError(err)) - return retFuture - - proc continuation(udata: pointer) = - var ovl = cast[RefCustomOverlapped](udata) - if not retFuture.finished: - if ovl.data.errCode == OSErrorCode(-1): - if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), - cint(SO_UPDATE_CONNECT_CONTEXT), nil, - SockLen(0)) != 0'i32: - sock.closeSocket() - retFuture.fail(getTransportOsError(wsaGetLastError())) - else: - retFuture.complete(newStreamSocketTransport(povl.data.fd, - bufferSize, - child)) - else: - sock.closeSocket() - retFuture.fail(getTransportOsError(ovl.data.errCode)) - GC_unref(ovl) - - povl = RefCustomOverlapped() - GC_ref(povl) - povl.data = CompletionData(fd: sock, cb: continuation) - var res = loop.connectEx(SocketHandle(sock), - cast[ptr SockAddr](addr saddr), - DWORD(slen), nil, 0, nil, - cast[POVERLAPPED](povl)) - # We will not process immediate completion, to avoid undefined behavior. - if not res: - let err = osLastError() - if int32(err) != ERROR_IO_PENDING: - GC_unref(povl) + if not bindToDomain(sock, address.getDomain()): + let err = wsaGetLastError() sock.closeSocket() retFuture.fail(getTransportOsError(err)) + return retFuture + + proc socketContinuation(udata: pointer) = + var ovl = cast[RefCustomOverlapped](udata) + if not retFuture.finished: + if ovl.data.errCode == OSErrorCode(-1): + if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), + cint(SO_UPDATE_CONNECT_CONTEXT), nil, + SockLen(0)) != 0'i32: + let err = wsaGetLastError() + sock.closeSocket() + retFuture.fail(getTransportOsError(err)) + else: + retFuture.complete(newStreamSocketTransport(povl.data.fd, + bufferSize, + child)) + else: + sock.closeSocket() + retFuture.fail(getTransportOsError(ovl.data.errCode)) + GC_unref(ovl) + + povl = RefCustomOverlapped() + GC_ref(povl) + povl.data = CompletionData(fd: sock, cb: socketContinuation) + if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + var res = loop.connectEx(SocketHandle(sock), + cast[ptr SockAddr](addr saddr), + DWORD(slen), nil, 0, nil, + cast[POVERLAPPED](povl)) + # We will not process immediate completion, to avoid undefined behavior. + if not res: + let err = osLastError() + if int32(err) != ERROR_IO_PENDING: + GC_unref(povl) + sock.closeSocket() + retFuture.fail(getTransportOsError(err)) + + elif address.family == AddressFamily.Unix: + ## Unix domain socket emulation with Windows Named Pipes. + proc pipeContinuation(udata: pointer) {.gcsafe.} = + var pipeSuffix = $cast[cstring](unsafeAddr address.address_un[0]) + var pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1]) + var pipeHandle = createFileW(pipeName, GENERIC_READ or GENERIC_WRITE, + FILE_SHARE_READ or FILE_SHARE_WRITE, + nil, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, Handle(0)) + if pipeHandle == INVALID_HANDLE_VALUE: + let err = osLastError() + if int32(err) == ERROR_PIPE_BUSY: + addTimer(fastEpochTime() + 50, pipeContinuation, nil) + else: + retFuture.fail(getTransportOsError(err)) + else: + register(AsyncFD(pipeHandle)) + retFuture.complete(newStreamPipeTransport(AsyncFD(pipeHandle), + bufferSize, child)) + pipeContinuation(nil) + return retFuture + proc acceptPipeLoop(udata: pointer) {.gcsafe, nimcall.} = + var ovl = cast[PtrCustomOverlapped](udata) + var server = cast[StreamServer](ovl.data.udata) + var loop = getGlobalDispatcher() + + while true: + if server.apending: + ## Continuation + server.apending = false + if server.status in {ServerStatus.Stopped, ServerStatus.Closed}: + break + else: + if ovl.data.errCode == OSErrorCode(-1): + var ntransp: StreamTransport + var flags = {WinServerPipe} + if NoPipeFlash in server.flags: + flags.incl(WinNoPipeFlash) + if not isNil(server.init): + var transp = server.init(server, server.sock) + ntransp = newStreamPipeTransport(server.sock, server.bufferSize, + transp, flags) + else: + ntransp = newStreamPipeTransport(server.sock, server.bufferSize, + nil, flags) + asyncCheck server.function(server, ntransp) + elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: + # CancelIO() interrupt + break + else: + doAssert disconnectNamedPipe(Handle(server.sock)) == 1 + doAssert closeHandle(HANDLE(server.sock)) == 1 + raiseTransportOsError(osLastError()) + else: + ## Initiation + server.apending = true + if server.status in {ServerStatus.Stopped, ServerStatus.Closed}: + ## Server was already stopped/closed exiting + break + + var pipeSuffix = $cast[cstring](addr server.local.address_un) + var pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1]) + var openMode = PIPE_ACCESS_DUPLEX or FILE_FLAG_OVERLAPPED + if FirstPipe notin server.flags: + openMode = openMode or FILE_FLAG_FIRST_PIPE_INSTANCE + server.flags.incl(FirstPipe) + let pipeMode = int32(PIPE_TYPE_BYTE or PIPE_READMODE_BYTE or PIPE_WAIT) + let pipeHandle = createNamedPipe(pipeName, openMode, pipeMode, + PIPE_UNLIMITED_INSTANCES, + DWORD(server.bufferSize), + DWORD(server.bufferSize), + DWORD(0), nil) + if pipeHandle == INVALID_HANDLE_VALUE: + raiseTransportOsError(osLastError()) + server.sock = AsyncFD(pipeHandle) + server.aovl.data.fd = AsyncFD(pipeHandle) + register(server.sock) + let res = connectNamedPipe(pipeHandle, + cast[POVERLAPPED](addr server.aovl)) + if res == 0: + let err = osLastError() + if int32(err) == ERROR_IO_PENDING: + discard + elif int32(err) == ERROR_PIPE_CONNECTED: + discard + else: + raiseTransportOsError(err) + break + proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} = var ovl = cast[PtrCustomOverlapped](udata) var server = cast[StreamServer](ovl.data.udata) @@ -469,36 +669,37 @@ when defined(windows): if server.apending: ## Continuation server.apending = false - if server.status == ServerStatus.Stopped: + if server.status in {ServerStatus.Stopped, ServerStatus.Closed}: + ## Server was already stopped/closed exiting server.asock.closeSocket() + break else: if ovl.data.errCode == OSErrorCode(-1): if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET), - cint(SO_UPDATE_ACCEPT_CONTEXT), - addr server.sock, + cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock, SockLen(sizeof(SocketHandle))) != 0'i32: let err = OSErrorCode(wsaGetLastError()) server.asock.closeSocket() raiseTransportOsError(err) else: + var ntransp: StreamTransport if not isNil(server.init): - var transp = server.init(server, server.asock) - let ntransp = newStreamSocketTransport(server.asock, - server.bufferSize, - transp) - asyncCheck server.function(server, ntransp) + let transp = server.init(server, server.asock) + ntransp = newStreamSocketTransport(server.asock, + server.bufferSize, + transp) else: - let ntransp = newStreamSocketTransport(server.asock, - server.bufferSize, nil) - asyncCheck server.function(server, ntransp) + ntransp = newStreamSocketTransport(server.asock, + server.bufferSize, nil) + asyncCheck server.function(server, ntransp) + elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt server.asock.closeSocket() break else: - let err = OSErrorCode(wsaGetLastError()) server.asock.closeSocket() - raiseTransportOsError(err) + raiseTransportOsError(ovl.data.errCode) else: ## Initiation if server.status in {ServerStatus.Stopped, ServerStatus.Closed}: @@ -547,7 +748,7 @@ when defined(windows): proc resumeAccept(server: StreamServer) {.inline.} = if not server.apending: - acceptLoop(cast[pointer](addr server.aovl)) + server.aovl.data.cb(addr server.aovl) else: @@ -681,10 +882,16 @@ else: saddr: Sockaddr_storage slen: SockLen sock: AsyncFD + proto: Protocol var retFuture = newFuture[StreamTransport]("transport.connect") - toSockAddr(address.address, address.port, saddr, slen) - sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, - Protocol.IPPROTO_TCP) + address.toSAddr(saddr, slen) + proto = Protocol.IPPROTO_TCP + if address.family == AddressFamily.Unix: + # `Protocol` enum is missing `0` value, so we making here cast, until + # `Protocol` enum will not support IPPROTO_IP == 0. + proto = cast[Protocol](0) + sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, + proto) if sock == asyncInvalidSocket: retFuture.fail(getTransportOsError(osLastError())) return retFuture @@ -800,7 +1007,16 @@ proc close*(server: StreamServer) = GC_unref(server) if server.status == ServerStatus.Stopped: server.status = ServerStatus.Closed - server.sock.closeSocket(continuation) + when defined(windows): + if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + server.sock.closeSocket(continuation) + elif server.local.family in {AddressFamily.Unix}: + if NoPipeFlash notin server.flags: + discard flushFileBuffers(Handle(server.sock)) + doAssert disconnectNamedPipe(Handle(server.sock)) == 1 + closeHandle(server.sock, continuation) + else: + server.sock.closeSocket(continuation) proc closeWait*(server: StreamServer): Future[void] = ## Close server ``server`` and release all resources. @@ -833,53 +1049,112 @@ proc createStreamServer*(host: TransportAddress, saddr: Sockaddr_storage slen: SockLen serverSocket: AsyncFD - if sock == asyncInvalidSocket: - serverSocket = createAsyncSocket(host.address.getDomain(), - SockType.SOCK_STREAM, - Protocol.IPPROTO_TCP) - if serverSocket == asyncInvalidSocket: - raiseTransportOsError(osLastError()) + + when defined(windows): + # Windows + if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + if sock == asyncInvalidSocket: + serverSocket = createAsyncSocket(host.getDomain(), + SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if serverSocket == asyncInvalidSocket: + raiseTransportOsError(osLastError()) + else: + if not setSocketBlocking(SocketHandle(sock), false): + raiseTransportOsError(osLastError()) + register(sock) + serverSocket = sock + # SO_REUSEADDR is not useful for Unix domain sockets. + if ServerFlags.ReuseAddr in flags: + if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + # TCP flags are not useful for Unix domain sockets. + if ServerFlags.TcpNoDelay in flags: + if not setSockOpt(serverSocket, handles.IPPROTO_TCP, + handles.TCP_NODELAY, 1): + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + host.toSAddr(saddr, slen) + if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), + slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + + if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + elif host.family == AddressFamily.Unix: + serverSocket = AsyncFD(0) else: - if not setSocketBlocking(SocketHandle(sock), false): - raiseTransportOsError(osLastError()) - register(sock) - serverSocket = sock + # Posix + if sock == asyncInvalidSocket: + var proto = Protocol.IPPROTO_TCP + if host.family == AddressFamily.Unix: + # `Protocol` enum is missing `0` value, so we making here cast, until + # `Protocol` enum will not support IPPROTO_IP == 0. + proto = cast[Protocol](0) + serverSocket = createAsyncSocket(host.getDomain(), + SockType.SOCK_STREAM, + proto) + if serverSocket == asyncInvalidSocket: + raiseTransportOsError(osLastError()) + else: + if not setSocketBlocking(SocketHandle(sock), false): + raiseTransportOsError(osLastError()) + register(sock) + serverSocket = sock - if ServerFlags.ReuseAddr in flags: - if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): + if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + # SO_REUSEADDR is not useful for Unix domain sockets. + if ServerFlags.ReuseAddr in flags: + if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + # TCP flags are not useful for Unix domain sockets. + if ServerFlags.TcpNoDelay in flags: + if not setSockOpt(serverSocket, handles.IPPROTO_TCP, + handles.TCP_NODELAY, 1): + let err = osLastError() + if sock == asyncInvalidSocket: + serverSocket.closeSocket() + raiseTransportOsError(err) + elif host.family in {AddressFamily.Unix}: + # We do not care about result here, because if file cannot be removed, + # `bindAddr` will return EADDRINUSE. + discard posix.unlink(cast[cstring](unsafeAddr host.address_un[0])) + + host.toSAddr(saddr, slen) + if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), + slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() raiseTransportOsError(err) - if ServerFlags.TcpNoDelay in flags: - if not setSockOpt(serverSocket, handles.IPPROTO_TCP, - handles.TCP_NODELAY, 1): + if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() raiseTransportOsError(err) - toSockAddr(host.address, host.port, saddr, slen) - if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), - slen) != 0: - let err = osLastError() - if sock == asyncInvalidSocket: - serverSocket.closeSocket() - raiseTransportOsError(err) - - if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: - let err = osLastError() - if sock == asyncInvalidSocket: - serverSocket.closeSocket() - raiseTransportOsError(err) - if not isNil(child): result = child else: result = StreamServer() result.sock = serverSocket + result.flags = flags result.function = cbproc result.init = init result.bufferSize = bufferSize @@ -889,10 +1164,17 @@ proc createStreamServer*(host: TransportAddress, result.local = host when defined(windows): - result.aovl.data = CompletionData(fd: serverSocket, cb: acceptLoop, + var cb: CallbackFunc + if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + cb = acceptLoop + elif host.family == AddressFamily.Unix: + cb = acceptPipeLoop + + result.aovl.data = CompletionData(fd: serverSocket, cb: cb, udata: cast[pointer](result)) - result.domain = host.address.getDomain() + result.domain = host.getDomain() result.apending = false + GC_ref(result) proc createStreamServer*[T](host: TransportAddress, @@ -967,6 +1249,9 @@ proc writeFile*(transp: StreamTransport, handle: int, ## ## You can specify starting ``offset`` in opened file and number of bytes ## to transfer from file to transport via ``size``. + when defined(windows): + if transp.kind != TransportKind.Socket: + raise newException(TransportNoSupport, "writeFile() is not supported!") var retFuture = newFuture[int]("transport.writeFile") transp.checkClosed(retFuture) var vector = StreamVector(kind: DataFile, writer: retFuture, @@ -1172,7 +1457,7 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = if transp.offset > 0: let s = len(result) let o = s + transp.offset - if n == -1: + if n < 0: # grabbing all incoming data, until EOF result.setLen(o) copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]), @@ -1259,7 +1544,19 @@ proc close*(transp: StreamTransport) = transp.state.incl({WriteClosed, ReadClosed}) when defined(windows): discard cancelIo(Handle(transp.fd)) - closeSocket(transp.fd, continuation) + if transp.kind == TransportKind.Pipe: + if WinServerPipe in transp.flags: + if WinNoPipeFlash notin transp.flags: + discard flushFileBuffers(Handle(transp.fd)) + doAssert disconnectNamedPipe(Handle(transp.fd)) == 1 + else: + if WinNoPipeFlash notin transp.flags: + discard flushFileBuffers(Handle(transp.fd)) + closeHandle(transp.fd, continuation) + elif transp.kind == TransportKind.Socket: + closeSocket(transp.fd, continuation) + else: + closeSocket(transp.fd, continuation) proc closeWait*(transp: StreamTransport): Future[void] = ## Close and frees resources of transport ``transp``. diff --git a/tests/teststream.nim b/tests/teststream.nim index 7401fbf0..dfbcdea8 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -115,36 +115,6 @@ proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} = transp.close() await transp.join() -proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} = - var data = await transp.read() - doAssert(len(data) == len(ConstantMessage) * MessagesCount) - transp.close() - var expect = "" - for i in 0..