Merge pull request #11 from status-im/unixSockets

Add AF_UNIX sockets support.
This commit is contained in:
Eugene Kabanov 2018-11-19 05:02:39 +02:00 committed by GitHub
commit d1ff27ade3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1007 additions and 504 deletions

View File

@ -1,9 +1,9 @@
version: '{build}'
cache:
- x86_64-4.9.2-release-win32-seh-rt_v4-rev4.7z -> .appveyor.yml
- i686-4.9.2-release-win32-dwarf-rt_v4-rev4.7z -> .appveyor.yml
- Nim -> .appveyor.yml
- x86_64-4.9.2-release-win32-seh-rt_v4-rev4.7z
- i686-4.9.2-release-win32-dwarf-rt_v4-rev4.7z
- Nim
matrix:
# We always want 32 and 64-bit compilation

View File

@ -1,5 +1,5 @@
packageName = "asyncdispatch2"
version = "2.1.4"
version = "2.1.5"
author = "Status Research & Development GmbH"
description = "Asyncdispatch2"
license = "Apache License 2.0 or MIT"

View File

@ -17,7 +17,7 @@ import asyncfutures2 except callSoon
import nativesockets, net, deques
export Port, SocketFlag
export asyncfutures2
export asyncfutures2, timer
#{.injectStmt: newGcInvariant().}
@ -409,6 +409,15 @@ when defined(windows) or defined(nimdoc):
var acb = AsyncCallback(function: aftercb)
loop.callbacks.addLast(acb)
proc closeHandle*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
## Closes a (pipe/file) handle and ensures that it is unregistered.
let loop = getGlobalDispatcher()
loop.handles.excl(fd)
doAssert closeHandle(Handle(fd)) == 1
if not isNil(aftercb):
var acb = AsyncCallback(function: aftercb)
loop.callbacks.addLast(acb)
proc unregister*(fd: AsyncFD) =
## Unregisters ``fd``.
getGlobalDispatcher().handles.excl(fd)
@ -736,6 +745,7 @@ include asyncmacro2
proc callSoon(cbproc: CallbackFunc, data: pointer = nil) =
## Schedule `cbproc` to be called as soon as possible.
## The callback is called when control returns to the event loop.
assert cbproc != nil
let acb = AsyncCallback(function: cbproc, udata: data)
getGlobalDispatcher().callbacks.addLast(acb)

View File

@ -10,7 +10,7 @@
## This module provides cross-platform wrapper for ``sendfile()`` syscall.
when defined(nimdoc):
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
proc sendfile*(outfd, infd: int, offset: int, count: var int): int =
## Copies data between file descriptor ``infd`` and ``outfd``. Because this
## copying is done within the kernel, ``sendfile()`` is more efficient than
## the combination of ``read(2)`` and ``write(2)``, which would require
@ -26,11 +26,13 @@ when defined(nimdoc):
## data from ``infd``.
##
## ``count`` is the number of bytes to copy between the file descriptors.
## On exit ``count`` will hold number of bytes actually transferred between
## file descriptors.
##
## If the transfer was successful, the number of bytes written to ``outfd``
## is returned. Note that a successful call to ``sendfile()`` may write
## fewer bytes than requested; the caller should be prepared to retry the
## call if there were unsent bytes.
## is stored in ``count``, and ``0`` returned. Note that a successful call to
## ``sendfile()`` may write fewer bytes than requested; the caller should
## be prepared to retry the call if there were unsent bytes.
##
## On error, ``-1`` is returned.
@ -39,13 +41,16 @@ when defined(linux) or defined(android):
proc osSendFile*(outfd, infd: cint, offset: ptr int, count: int): int
{.importc: "sendfile", header: "<sys/sendfile.h>".}
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
proc sendfile*(outfd, infd: int, offset: int, count: var int): int =
var o = offset
result = osSendFile(cint(outfd), cint(infd), addr o, count)
if result >= 0:
count = result
result = 0
elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
defined(dragonflybsd):
import posix, os
type
SendfileHeader* = object {.importc: "sf_hdtr",
header: """#include <sys/types.h>
@ -60,16 +65,23 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
#include <sys/socket.h>
#include <sys/uio.h>""".}
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
proc sendfile*(outfd, infd: int, offset: int, count: var int): int =
var o = 0'u
if osSendFile(cint(infd), cint(outfd), uint(offset), uint(count), nil,
addr o, 0) == 0:
result = int(o)
result = osSendFile(cint(infd), cint(outfd), uint(offset), uint(count), nil,
addr o, 0)
if result >= 0:
count = int(o)
result = 0
else:
result = -1
let err = osLastError()
if int(err) == EAGAIN:
count = int(o)
result = 0
else:
result = -1
elif defined(macosx):
import posix
import posix, os
type
SendfileHeader* = object {.importc: "sf_hdtr",
header: """#include <sys/types.h>
@ -84,9 +96,16 @@ elif defined(macosx):
#include <sys/socket.h>
#include <sys/uio.h>""".}
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
proc sendfile*(outfd, infd: int, offset: int, count: var int): int =
var o = count
if osSendFile(cint(infd), cint(outfd), offset, addr o, nil, 0) == 0:
result = o
result = osSendFile(cint(infd), cint(outfd), offset, addr o, nil, 0)
if result >= 0:
count = int(o)
result = 0
else:
result = -1
let err = osLastError()
if int(err) == EAGAIN:
count = int(o)
result = 0
else:
result = -1

View File

@ -6,9 +6,7 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import os, net, strutils
from nativesockets import toInt
import os, strutils, nativesockets, net
import ../asyncloop
export net
@ -25,12 +23,24 @@ const
type
ServerFlags* = enum
## Server's flags
ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData
ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData, FirstPipe,
NoPipeFlash
AddressFamily* {.pure.} = enum
None, IPv4, IPv6, Unix
TransportAddress* = object
## Transport network address
address*: IpAddress # IP Address
port*: Port # IP port
case family*: AddressFamily
of AddressFamily.None:
discard
of AddressFamily.IPv4:
address_v4*: array[4, uint8]
of AddressFamily.IPv6:
address_v6*: array[16, uint8]
of AddressFamily.Unix:
address_un*: array[108, uint8]
port*: Port # Port number
ServerCommand* = enum
## Server's commands
@ -94,6 +104,8 @@ type
TransportAddressError* = object of TransportError
## Transport's address specific exception
code*: OSErrorCode
TransportNoSupport* = object of TransportError
## Transport's capability not supported exception
TransportState* = enum
## Transport's state
@ -108,85 +120,154 @@ type
WriteError # Write error
var
AnyAddress* = TransportAddress(
address: IpAddress(family: IpAddressFamily.IPv4), port: Port(0)
) ## Default INADDR_ANY address for IPv4
AnyAddress6* = TransportAddress(
address: IpAddress(family: IpAddressFamily.IPv6), port: Port(0)
) ## Default INADDR_ANY address for IPv6
AnyAddress* = TransportAddress(family: AddressFamily.IPv4, port: Port(0))
## Default INADDR_ANY address for IPv4
AnyAddress6* = TransportAddress(family: AddressFamily.IPv6, port: Port(0))
## Default INADDR_ANY address for IPv6
proc getDomain*(address: IpAddress): Domain =
## Returns OS specific Domain from IP Address.
case address.family
of IpAddressFamily.IPv4:
result = Domain.AF_INET
of IpAddressFamily.IPv6:
result = Domain.AF_INET6
proc `==`*(lhs, rhs: TransportAddress): bool =
## Compare two transport addresses ``lhs`` and ``rhs``. Return ``true`` if
## addresses are equal.
if lhs.family != lhs.family:
return false
if lhs.family == AddressFamily.IPv4:
result = equalMem(unsafeAddr lhs.address_v4[0],
unsafeAddr rhs.address_v4[0], sizeof(lhs.address_v4)) and
(lhs.port == rhs.port)
elif lhs.family == AddressFamily.IPv6:
result = equalMem(unsafeAddr lhs.address_v6[0],
unsafeAddr rhs.address_v6[0], sizeof(lhs.address_v6)) and
(lhs.port == rhs.port)
elif lhs.family == AddressFamily.Unix:
result = equalMem(unsafeAddr lhs.address_un[0],
unsafeAddr rhs.address_un[0], sizeof(lhs.address_un))
proc getDomain*(address: TransportAddress): Domain =
## Returns OS specific Domain from TransportAddress.
result = address.address.getDomain()
case address.family
of AddressFamily.IPv4:
result = Domain.AF_INET
of AddressFamily.IPv6:
result = Domain.AF_INET6
of AddressFamily.Unix:
when defined(windows):
result = cast[Domain](1)
else:
result = Domain.AF_UNIX
else:
result = cast[Domain](0)
proc `$`*(address: TransportAddress): string =
## Returns string representation of ``address``.
case address.address.family
of IpAddressFamily.IPv4:
result = $address.address
case address.family
of AddressFamily.IPv4:
var a = IpAddress(
family: IpAddressFamily.IPv4,
address_v4: address.address_v4
)
result = $a
result.add(":")
of IpAddressFamily.IPv6:
result = "[" & $address.address & "]"
result.add(":")
result.add($int(address.port))
result.add($int(address.port))
of AddressFamily.IPv6:
var a = IpAddress(family: IpAddressFamily.IPv6,
address_v6: address.address_v6)
result = "[" & $a & "]:"
result.add($(int(address.port)))
of AddressFamily.Unix:
const length = sizeof(address.address_un) + 1
var buffer: array[length, char]
if not equalMem(addr buffer[0], unsafeAddr address.address_un[0],
sizeof(address.address_un)):
copyMem(addr buffer[0], unsafeAddr address.address_un[0],
sizeof(address.address_un))
result = $cast[cstring](addr buffer)
else:
result = ""
else:
raise newException(TransportAddressError, "Unknown address family!")
proc initTAddress*(address: string): TransportAddress =
## Parses string representation of ``address``.
## Parses string representation of ``address``. ``address`` can be IPv4, IPv6
## or Unix domain address.
##
## IPv4 transport address format is ``a.b.c.d:port``.
## IPv6 transport address format is ``[::]:port``.
var parts = address.rsplit(":", maxsplit = 1)
if len(parts) != 2:
raise newException(TransportAddressError, "Format is <address>:<port>!")
try:
let port = parseInt(parts[1])
doAssert(port > 0 and port < 65536)
result.port = Port(port)
except:
raise newException(TransportAddressError, "Illegal port number!")
try:
if parts[0][0] == '[' and parts[0][^1] == ']':
result.address = parseIpAddress(parts[0][1..^2])
## Unix transport address format is ``/address``.
if len(address) > 0:
if address[0] == '/':
result = TransportAddress(family: AddressFamily.Unix, port: Port(1))
let size = if len(address) < (sizeof(result.address_un) - 1): len(address)
else: (sizeof(result.address_un) - 1)
copyMem(addr result.address_un[0], unsafeAddr address[0], size)
else:
result.address = parseIpAddress(parts[0])
except:
raise newException(TransportAddressError, getCurrentException().msg)
var port: Port
var parts = address.rsplit(":", maxsplit = 1)
if len(parts) != 2:
raise newException(TransportAddressError,
"Format is <address>:<port> or </address>!")
try:
let portint = parseInt(parts[1])
doAssert(portint > 0 and portint < 65536)
port = Port(portint)
except:
raise newException(TransportAddressError, "Illegal port number!")
try:
var ipaddr: IpAddress
if parts[0][0] == '[' and parts[0][^1] == ']':
ipaddr = parseIpAddress(parts[0][1..^2])
else:
ipaddr = parseIpAddress(parts[0])
if ipaddr.family == IpAddressFamily.IPv4:
result = TransportAddress(family: AddressFamily.IPv4)
result.address_v4 = ipaddr.address_v4
elif ipaddr.family == IpAddressFamily.IPv6:
result = TransportAddress(family: AddressFamily.IPv6)
result.address_v6 = ipaddr.address_v6
else:
raise newException(TransportAddressError, "Incorrect address family!")
result.port = port
except:
raise newException(TransportAddressError, getCurrentException().msg)
else:
result = TransportAddress(family: AddressFamily.Unix)
proc initTAddress*(address: string, port: Port): TransportAddress =
## Initialize ``TransportAddress`` with IP address ``address`` and
## port number ``port``.
## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address``
## and port number ``port``.
try:
result.address = parseIpAddress(address)
result.port = port
var ipaddr = parseIpAddress(address)
if ipaddr.family == IpAddressFamily.IPv4:
result = TransportAddress(family: AddressFamily.IPv4, port: port)
result.address_v4 = ipaddr.address_v4
elif ipaddr.family == IpAddressFamily.IPv6:
result = TransportAddress(family: AddressFamily.IPv6, port: port)
result.address_v6 = ipaddr.address_v6
else:
raise newException(TransportAddressError, "Incorrect address family!")
except:
raise newException(TransportAddressError, getCurrentException().msg)
proc initTAddress*(address: string, port: int): TransportAddress =
## Initialize ``TransportAddress`` with IP address ``address`` and
## port number ``port``.
proc initTAddress*(address: string, port: int): TransportAddress {.inline.} =
## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address``
## and port number ``port``.
if port < 0 or port >= 65536:
raise newException(TransportAddressError, "Illegal port number!")
try:
result.address = parseIpAddress(address)
result.port = Port(port)
except:
raise newException(TransportAddressError, getCurrentException().msg)
else:
result = initTAddress(address, Port(port))
proc initTAddress*(address: IpAddress, port: Port): TransportAddress =
## Initialize ``TransportAddress`` with net.nim ``IpAddress`` and
## port number ``port``.
result.address = address
result.port = port
if address.family == IpAddressFamily.IPv4:
result = TransportAddress(family: AddressFamily.IPv4, port: port)
result.address_v4 = address.address_v4
elif address.family == IpAddressFamily.IPv6:
result = TransportAddress(family: AddressFamily.IPv6, port: port)
result.address_v6 = address.address_v6
else:
raise newException(TransportAddressError, "Incorrect address family!")
proc getAddrInfo(address: string, port: Port, domain: Domain,
sockType: SockType = SockType.SOCK_STREAM,
@ -205,8 +286,86 @@ proc getAddrInfo(address: string, port: Port, domain: Domain,
else:
raise newException(TransportAddressError, $gai_strerror(gaiResult))
proc fromSAddr*(sa: ptr Sockaddr_storage, sl: Socklen,
address: var TransportAddress) =
## Set transport address ``address`` with value from OS specific socket
## address storage.
if int(sa.ss_family) == toInt(Domain.AF_INET) and
int(sl) == sizeof(Sockaddr_in):
address = TransportAddress(family: AddressFamily.IPv4)
let s = cast[ptr Sockaddr_in](sa)
copyMem(addr address.address_v4[0], addr s.sin_addr,
sizeof(address.address_v4))
address.port = Port(nativesockets.ntohs(s.sin_port))
elif int(sa.ss_family) == toInt(Domain.AF_INET6) and
int(sl) == sizeof(Sockaddr_in6):
address = TransportAddress(family: AddressFamily.IPv6)
let s = cast[ptr Sockaddr_in6](sa)
copyMem(addr address.address_v6[0], addr s.sin6_addr,
sizeof(address.address_v6))
address.port = Port(nativesockets.ntohs(s.sin6_port))
elif int(sa.ss_family) == toInt(Domain.AF_UNIX):
when not defined(windows):
address = TransportAddress(family: AddressFamily.Unix)
if int(sl) > sizeof(sa.ss_family):
var length = int(sl) - sizeof(sa.ss_family)
if length > (sizeof(address.address_un) - 1):
length = sizeof(address.address_un) - 1
let s = cast[ptr Sockaddr_un](sa)
copyMem(addr address.address_un[0], addr s.sun_path[0], length)
address.port = Port(1)
else:
discard
proc toSAddr*(address: TransportAddress, sa: var Sockaddr_storage,
sl: var Socklen) =
## Set socket OS specific socket address storage with address from transport
## address ``address``.
case address.family
of AddressFamily.IPv4:
sl = Socklen(sizeof(Sockaddr_in))
let s = cast[ptr Sockaddr_in](addr sa)
s.sin_family = type(s.sin_family)(toInt(Domain.AF_INET))
s.sin_port = nativesockets.htons(uint16(address.port))
copyMem(addr s.sin_addr, unsafeAddr address.address_v4[0],
sizeof(s.sin_addr))
of AddressFamily.IPv6:
sl = Socklen(sizeof(Sockaddr_in6))
let s = cast[ptr Sockaddr_in6](addr sa)
s.sin6_family = type(s.sin6_family)(toInt(Domain.AF_INET6))
s.sin6_port = nativesockets.htons(uint16(address.port))
copyMem(addr s.sin6_addr, unsafeAddr address.address_v6[0],
sizeof(s.sin6_addr))
of AddressFamily.Unix:
when not defined(windows):
if address.port == Port(0):
sl = Socklen(sizeof(sa.ss_family))
else:
let s = cast[ptr Sockaddr_un](addr sa)
var name = cast[cstring](unsafeAddr address.address_un[0])
sl = Socklen(sizeof(sa.ss_family) + len(name) + 1)
s.sun_family = type(s.sun_family)(toInt(Domain.AF_UNIX))
copyMem(addr s.sun_path, unsafeAddr address.address_un[0],
len(name) + 1)
else:
discard
proc address*(ta: TransportAddress): IpAddress =
## Converts ``TransportAddress`` to ``net.IpAddress`` object.
##
## Note its impossible to convert ``TransportAddress`` of ``Unix`` family,
## because ``IpAddress`` supports only IPv4, IPv6 addresses.
if ta.family == AddressFamily.IPv4:
result = IpAddress(family: IpAddressFamily.IPv4)
result.address_v4 = ta.address_v4
elif ta.family == AddressFamily.IPv6:
result = IpAddress(family: IpAddressFamily.IPv6)
result.address_v6 = ta.address_v6
else:
raise newException(ValueError, "IpAddress supports only IPv4/IPv6!")
proc resolveTAddress*(address: string,
family = IpAddressFamily.IPv4): seq[TransportAddress] =
family = AddressFamily.IPv4): seq[TransportAddress] =
## Resolve string representation of ``address``.
##
## Supported formats are:
@ -220,6 +379,8 @@ proc resolveTAddress*(address: string,
hostname: string
port: int
doAssert(family in {AddressFamily.IPv4, AddressFamily.IPv6})
result = newSeq[TransportAddress]()
var parts = address.rsplit(":", maxsplit = 1)
if len(parts) != 2:
@ -237,14 +398,14 @@ proc resolveTAddress*(address: string,
else:
hostname = parts[0]
var domain = if family == IpAddressFamily.IPv4: Domain.AF_INET else:
var domain = if family == AddressFamily.IPv4: Domain.AF_INET else:
Domain.AF_INET6
var aiList = getAddrInfo(hostname, Port(port), domain)
var it = aiList
while it != nil:
var ta: TransportAddress
fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[],
SockLen(it.ai_addrlen), ta.address, ta.port)
fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr),
SockLen(it.ai_addrlen), ta)
# For some reason getAddrInfo() sometimes returns duplicate addresses,
# for example getAddrInfo(`localhost`) returns `127.0.0.1` twice.
if ta notin result:
@ -253,22 +414,24 @@ proc resolveTAddress*(address: string,
freeAddrInfo(aiList)
proc resolveTAddress*(address: string, port: Port,
family = IpAddressFamily.IPv4): seq[TransportAddress] =
family = AddressFamily.IPv4): seq[TransportAddress] =
## Resolve string representation of ``address``.
##
## ``address`` could be dot IPv4/IPv6 address or hostname.
##
## If hostname address is detected, then network address translation via DNS
## will be performed.
assert(family in {AddressFamily.IPv4, AddressFamily.IPv6})
result = newSeq[TransportAddress]()
var domain = if family == IpAddressFamily.IPv4: Domain.AF_INET else:
var domain = if family == AddressFamily.IPv4: Domain.AF_INET else:
Domain.AF_INET6
var aiList = getAddrInfo(address, port, domain)
var it = aiList
while it != nil:
var ta: TransportAddress
fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[],
SockLen(it.ai_addrlen), ta.address, ta.port)
fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr),
SockLen(it.ai_addrlen), ta)
# For some reason getAddrInfo() sometimes returns duplicate addresses,
# for example getAddrInfo(`localhost`) returns `127.0.0.1` twice.
if ta notin result:
@ -276,6 +439,22 @@ proc resolveTAddress*(address: string, port: Port,
it = it.ai_next
freeAddrInfo(aiList)
proc resolveTAddress*(address: string,
family: IpAddressFamily): seq[TransportAddress] {.
deprecated.} =
if family == IpAddressFamily.IPv4:
result = resolveTAddress(address, AddressFamily.IPv4)
elif family == IpAddressFamily.IPv6:
result = resolveTAddress(address, AddressFamily.IPv6)
proc resolveTAddress*(address: string, port: Port,
family: IpAddressFamily): seq[TransportAddress] {.
deprecated.} =
if family == IpAddressFamily.IPv4:
result = resolveTAddress(address, port, AddressFamily.IPv4)
elif family == IpAddressFamily.IPv6:
result = resolveTAddress(address, port, AddressFamily.IPv6)
template checkClosed*(t: untyped) =
if (ReadClosed in (t).state) or (WriteClosed in (t).state):
raise newException(TransportError, "Transport is already closed!")
@ -290,13 +469,6 @@ template getError*(t: untyped): ref Exception =
(t).error = nil
err
proc raiseTransportOsError*(err: OSErrorCode) =
## Raises transport specific OS error.
var msg = "(" & $int(err) & ") " & osErrorMsg(err)
var tre = newException(TransportOsError, msg)
tre.code = err
raise tre
template getTransportOsError*(err: OSErrorCode): ref TransportOsError =
var msg = "(" & $int(err) & ") " & osErrorMsg(err)
var tre = newException(TransportOsError, msg)
@ -306,6 +478,10 @@ template getTransportOsError*(err: OSErrorCode): ref TransportOsError =
template getTransportOsError*(err: cint): ref TransportOsError =
getTransportOsError(OSErrorCode(err))
proc raiseTransportOsError*(err: OSErrorCode) =
## Raises transport specific OS error.
raise getTransportOsError(err)
type
SeqHeader = object
length, reserved: int
@ -321,8 +497,28 @@ when defined(windows):
const
ERROR_OPERATION_ABORTED* = 995
ERROR_PIPE_CONNECTED* = 535
ERROR_PIPE_BUSY* = 231
ERROR_SUCCESS* = 0
ERROR_CONNECTION_REFUSED* = 1225
PIPE_TYPE_BYTE* = 0
PIPE_READMODE_BYTE* = 0
PIPE_TYPE_MESSAGE* = 0x4
PIPE_READMODE_MESSAGE* = 0x2
PIPE_WAIT* = 0
PIPE_UNLIMITED_INSTANCES* = 255
ERROR_BROKEN_PIPE* = 109
ERROR_PIPE_NOT_CONNECTED* = 233
ERROR_NO_DATA* = 232
proc cancelIo*(hFile: HANDLE): WINBOOL
{.stdcall, dynlib: "kernel32", importc: "CancelIo".}
proc connectNamedPipe*(hPipe: HANDLE, lpOverlapped: ptr OVERLAPPED): WINBOOL
{.stdcall, dynlib: "kernel32", importc: "ConnectNamedPipe".}
proc disconnectNamedPipe*(hPipe: HANDLE): WINBOOL
{.stdcall, dynlib: "kernel32", importc: "DisconnectNamedPipe".}
proc setNamedPipeHandleState*(hPipe: HANDLE, lpMode, lpMaxCollectionCount,
lpCollectDataTimeout: ptr DWORD): WINBOOL
{.stdcall, dynlib: "kernel32", importc: "SetNamedPipeHandleState".}
proc resetEvent*(hEvent: HANDLE): WINBOOL
{.stdcall, dynlib: "kernel32", importc: "ResetEvent".}

View File

@ -84,7 +84,7 @@ when defined(windows):
transp.state.incl(WritePaused)
vector.writer.complete()
else:
transp.state = transp.state + {WritePaused, WriteError}
transp.state.incl({WritePaused, WriteError})
vector.writer.fail(getTransportOsError(err))
else:
## Initiation
@ -94,8 +94,7 @@ when defined(windows):
transp.setWriterWSABuffer(vector)
var ret: cint
if vector.kind == WithAddress:
toSockAddr(vector.address.address, vector.address.port,
transp.waddr, transp.walen)
toSAddr(vector.address, transp.waddr, transp.walen)
ret = WSASendTo(fd, addr transp.wwsabuf, DWORD(1), addr bytesCount,
DWORD(0), cast[ptr SockAddr](addr transp.waddr),
cint(transp.walen),
@ -107,13 +106,14 @@ when defined(windows):
let err = osLastError()
if int(err) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
transp.state.excl(WritePending)
transp.state.incl(WritePaused)
vector.writer.complete()
elif int(err) == ERROR_IO_PENDING:
transp.queue.addFirst(vector)
else:
transp.state.excl(WritePending)
transp.state = transp.state + {WritePaused, WriteError}
transp.state.incl({WritePaused, WriteError})
vector.writer.fail(getTransportOsError(err))
else:
transp.queue.addFirst(vector)
@ -131,15 +131,13 @@ when defined(windows):
while true:
if ReadPending in transp.state:
## Continuation
if ReadClosed in transp.state:
break
transp.state.excl(ReadPending)
let err = transp.rovl.data.errCode
if err == OSErrorCode(-1):
let bytesCount = transp.rovl.data.bytesCount
if bytesCount == 0:
transp.state.incl({ReadEof, ReadPaused})
fromSockAddr(transp.raddr, transp.ralen, raddr.address, raddr.port)
fromSAddr(addr transp.raddr, transp.ralen, raddr)
transp.buflen = bytesCount
asyncCheck transp.function(transp, raddr)
elif int(err) == ERROR_OPERATION_ABORTED:
@ -200,8 +198,9 @@ when defined(windows):
child: DatagramTransport,
bufferSize: int): DatagramTransport =
var localSock: AsyncFD
assert(remote.address.family == local.address.family)
assert(remote.family == local.family)
assert(not isNil(cbproc))
assert(remote.family in {AddressFamily.IPv4, AddressFamily.IPv6})
if isNil(child):
result = DatagramTransport()
@ -209,12 +208,8 @@ when defined(windows):
result = child
if sock == asyncInvalidSocket:
if local.address.family == IpAddressFamily.IPv4:
localSock = createAsyncSocket(Domain.AF_INET, SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
else:
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
localSock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
if localSock == asyncInvalidSocket:
raiseTransportOsError(osLastError())
else:
@ -239,10 +234,10 @@ when defined(windows):
addr bytesRet, nil, nil) != 0:
raiseTransportOsError(osLastError())
if local.port != Port(0):
if local.family != AddressFamily.None:
var saddr: Sockaddr_storage
var slen: SockLen
toSockAddr(local.address, local.port, saddr, slen)
toSAddr(local, saddr, slen)
if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
@ -253,12 +248,7 @@ when defined(windows):
else:
var saddr: Sockaddr_storage
var slen: SockLen
if local.address.family == IpAddressFamily.IPv4:
saddr.ss_family = winlean.AF_INET
slen = SockLen(sizeof(SockAddr_in))
else:
saddr.ss_family = winlean.AF_INET6
slen = SockLen(sizeof(SockAddr_in6))
saddr.ss_family = type(saddr.ss_family)(local.getDomain())
if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
@ -269,7 +259,7 @@ when defined(windows):
if remote.port != Port(0):
var saddr: Sockaddr_storage
var slen: SockLen
toSockAddr(remote.address, remote.port, saddr, slen)
toSAddr(remote, saddr, slen)
if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
@ -320,7 +310,7 @@ else:
cast[ptr SockAddr](addr transp.raddr),
addr transp.ralen)
if res >= 0:
fromSockAddr(transp.raddr, transp.ralen, raddr.address, raddr.port)
fromSAddr(addr transp.raddr, transp.ralen, raddr)
transp.buflen = res
asyncCheck transp.function(transp, raddr)
else:
@ -350,8 +340,7 @@ else:
var vector = transp.queue.popFirst()
while true:
if vector.kind == WithAddress:
toSockAddr(vector.address.address, vector.address.port,
transp.waddr, transp.walen)
toSAddr(vector.address, transp.waddr, transp.walen)
res = posix.sendto(fd, vector.buf, vector.buflen, MSG_NOSIGNAL,
cast[ptr SockAddr](addr transp.waddr),
transp.walen)
@ -387,7 +376,7 @@ else:
child: DatagramTransport = nil,
bufferSize: int): DatagramTransport =
var localSock: AsyncFD
assert(remote.address.family == local.address.family)
assert(remote.family == local.family)
assert(not isNil(cbproc))
if isNil(child):
@ -396,12 +385,13 @@ else:
result = child
if sock == asyncInvalidSocket:
if local.address.family == IpAddressFamily.IPv4:
localSock = createAsyncSocket(Domain.AF_INET, SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
else:
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
var proto = Protocol.IPPROTO_UDP
if local.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)
localSock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM,
proto)
if localSock == asyncInvalidSocket:
raiseTransportOsError(osLastError())
else:
@ -421,7 +411,7 @@ else:
if local.port != Port(0):
var saddr: Sockaddr_storage
var slen: SockLen
toSockAddr(local.address, local.port, saddr, slen)
toSAddr(local, saddr, slen)
if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
@ -433,7 +423,7 @@ else:
if remote.port != Port(0):
var saddr: Sockaddr_storage
var slen: SockLen
toSockAddr(remote.address, remote.port, saddr, slen)
toSAddr(remote, saddr, slen)
if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()

View File

@ -11,6 +11,8 @@ import net, nativesockets, os, deques
import ../asyncloop, ../handles, ../sendfile
import common
{.deadCodeElim: on.}
when defined(windows):
import winlean
else:
@ -26,6 +28,7 @@ type
buf: pointer # Writer buffer pointer
buflen: int # Writer buffer size
offset: uint # Writer vector offset
size: int # Original size
writer: Future[int] # Writer vector completion Future
TransportKind* {.pure.} = enum
@ -33,6 +36,23 @@ type
Pipe, # Pipe transport
File # File transport
TransportFlags* = enum
None,
# Default value
WinServerPipe,
# This is internal flag which used to differentiate between server pipe
# handle and client pipe handle.
WinNoPipeFlash
# By default `AddressFamily.Unix` transports in Windows are using
# `FlushFileBuffers()` when transport closing.
# This flag disables usage of `FlushFileBuffers()` on `AddressFamily.Unix`
# transport shutdown. If both server and client are running in the same
# thread, because of `FlushFileBuffers()` will ensure that all bytes
# or messages written to the pipe are read by the client, it is possible to
# get stuck on transport `close()`.
# Please use this flag only if you are making both client and server in
# the same thread.
when defined(windows):
const SO_UPDATE_CONNECT_CONTEXT = 0x7010
@ -52,6 +72,7 @@ when defined(windows):
rovl: CustomOverlapped # Reader OVERLAPPED structure
wovl: CustomOverlapped # Writer OVERLAPPED structure
roffset: int # Pending reading offset
flags: set[TransportFlags] # Internal flags
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
@ -83,7 +104,6 @@ else:
todo2: int
type
StreamCallback* = proc(server: StreamServer,
client: StreamTransport): Future[void] {.gcsafe.}
## New remote client connection callback
@ -92,7 +112,7 @@ type
TransportInitCallback* = proc(server: StreamServer,
fd: AsyncFD): StreamTransport {.gcsafe.}
## Custom transport initialization procedure, which can allocated inherited
## Custom transport initialization procedure, which can allocate inherited
## StreamTransport object.
StreamServer* = ref object of SocketServer
@ -106,26 +126,26 @@ proc remoteAddress*(transp: StreamTransport): TransportAddress =
## Returns ``transp`` remote socket address.
if transp.kind != TransportKind.Socket:
raise newException(TransportError, "Socket required!")
if transp.remote.port == Port(0):
if transp.remote.family == AddressFamily.None:
var saddr: Sockaddr_storage
var slen = SockLen(sizeof(saddr))
if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseTransportOsError(osLastError())
fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port)
fromSAddr(addr saddr, slen, transp.remote)
result = transp.remote
proc localAddress*(transp: StreamTransport): TransportAddress =
## Returns ``transp`` local socket address.
if transp.kind != TransportKind.Socket:
raise newException(TransportError, "Socket required!")
if transp.local.port == Port(0):
if transp.local.family == AddressFamily.None:
var saddr: Sockaddr_storage
var slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseTransportOsError(osLastError())
fromSockAddr(saddr, slen, transp.local.address, transp.local.port)
fromSAddr(addr saddr, slen, transp.local)
result = transp.local
template setReadError(t, e: untyped) =
@ -209,6 +229,13 @@ when defined(windows):
transp.queue.addFirst(vector)
else:
vector.writer.complete(int(getFileSize(vector)))
elif transp.kind == TransportKind.Pipe:
if vector.kind == VectorKind.DataBuffer:
if bytesCount < transp.wwsabuf.len:
vector.shiftVectorBuffer(bytesCount)
transp.queue.addFirst(vector)
else:
vector.writer.complete(transp.wwsabuf.len)
elif int(err) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
transp.state.incl(WritePaused)
@ -275,6 +302,35 @@ when defined(windows):
vector.writer.fail(getTransportOsError(err))
else:
transp.queue.addFirst(vector)
elif transp.kind == TransportKind.Pipe:
let pipe = Handle(transp.wovl.data.fd)
var vector = transp.queue.popFirst()
if vector.kind == VectorKind.DataBuffer:
transp.wovl.zeroOvelappedOffset()
transp.setWriterWSABuffer(vector)
let ret = writeFile(pipe, cast[pointer](transp.wwsabuf.buf),
DWORD(transp.wwsabuf.len), addr bytesCount,
cast[POVERLAPPED](addr transp.wovl))
if ret == 0:
let err = osLastError()
if int(err) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
transp.state.excl(WritePending)
transp.state.incl(WritePaused)
vector.writer.complete(0)
elif int(err) == ERROR_IO_PENDING:
transp.queue.addFirst(vector)
elif int(err) == ERROR_NO_DATA:
# The pipe is being closed.
transp.state.excl(WritePending)
transp.state.incl(WritePaused)
vector.writer.complete(0)
else:
transp.state.excl(WritePending)
transp.state = transp.state + {WritePaused, WriteError}
vector.writer.fail(getTransportOsError(err))
else:
transp.queue.addFirst(vector)
break
if len(transp.queue) == 0:
@ -283,7 +339,6 @@ when defined(windows):
proc readStreamLoop(udata: pointer) {.gcsafe, nimcall.} =
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[StreamTransport](ovl.data.udata)
while true:
if ReadPending in transp.state:
## Continuation
@ -312,7 +367,11 @@ when defined(windows):
elif int(err) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
transp.state.incl(ReadPaused)
elif int(err) in {ERROR_NETNAME_DELETED, WSAECONNABORTED}:
elif transp.kind == TransportKind.Socket and
(int(err) in {ERROR_NETNAME_DELETED, WSAECONNABORTED}):
transp.state.incl({ReadEof, ReadPaused})
elif transp.kind == TransportKind.Pipe and
(int(err) in {ERROR_BROKEN_PIPE, ERROR_PIPE_NOT_CONNECTED}):
transp.state.incl({ReadEof, ReadPaused})
else:
transp.setReadError(err)
@ -339,7 +398,7 @@ when defined(windows):
cast[POVERLAPPED](addr transp.rovl), nil)
if ret != 0:
let err = osLastError()
if int(err) == ERROR_OPERATION_ABORTED:
if int32(err) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
transp.state.excl(ReadPending)
transp.state.incl(ReadPaused)
@ -356,6 +415,32 @@ when defined(windows):
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
elif transp.kind == TransportKind.Pipe:
let pipe = Handle(transp.rovl.data.fd)
transp.roffset = transp.offset
transp.setReaderWSABuffer()
let ret = readFile(pipe, cast[pointer](transp.rwsabuf.buf),
DWORD(transp.rwsabuf.len), addr bytesCount,
cast[POVERLAPPED](addr transp.rovl))
if ret == 0:
let err = osLastError()
if int32(err) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
transp.state.excl(ReadPending)
transp.state.incl(ReadPaused)
elif int32(err) in {ERROR_BROKEN_PIPE, ERROR_PIPE_NOT_CONNECTED}:
transp.state.excl(ReadPending)
transp.state.incl({ReadEof, ReadPaused})
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
elif int32(err) != ERROR_IO_PENDING:
transp.state.excl(ReadPending)
transp.state.incl(ReadPaused)
transp.setReadError(err)
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
else:
transp.state.incl(ReadPaused)
if not isNil(transp.reader):
@ -383,6 +468,27 @@ when defined(windows):
GC_ref(transp)
result = transp
proc newStreamPipeTransport(fd: AsyncFD, bufsize: int,
child: StreamTransport,
flags: set[TransportFlags] = {}): StreamTransport =
var transp: StreamTransport
if not isNil(child):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Pipe)
transp.fd = fd
transp.rovl.data = CompletionData(fd: fd, cb: readStreamLoop,
udata: cast[pointer](transp))
transp.wovl.data = CompletionData(fd: fd, cb: writeStreamLoop,
udata: cast[pointer](transp))
transp.buffer = newSeq[byte](bufsize)
transp.flags = flags
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("stream.pipe.transport")
GC_ref(transp)
result = transp
proc bindToDomain(handle: AsyncFD, domain: Domain): bool =
result = true
if domain == Domain.AF_INET6:
@ -391,7 +497,7 @@ when defined(windows):
if bindAddr(SocketHandle(handle), cast[ptr SockAddr](addr(saddr)),
sizeof(saddr).SockLen) != 0'i32:
result = false
else:
elif domain == Domain.AF_INET:
var saddr: Sockaddr_in
saddr.sin_family = type(saddr.sin_family)(toInt(domain))
if bindAddr(SocketHandle(handle), cast[ptr SockAddr](addr(saddr)),
@ -400,66 +506,161 @@ when defined(windows):
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport] =
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport.
let loop = getGlobalDispatcher()
var
saddr: Sockaddr_storage
slen: SockLen
sock: AsyncFD
povl: RefCustomOverlapped
var retFuture = newFuture[StreamTransport]("stream.transport.connect")
toSockAddr(address.address, address.port, saddr, slen)
sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
## Socket handling part
var
saddr: Sockaddr_storage
slen: SockLen
sock: AsyncFD
povl: RefCustomOverlapped
proto: Protocol
if sock == asyncInvalidSocket:
retFuture.fail(getTransportOsError(OSErrorCode(wsaGetLastError())))
return retFuture
toSAddr(address, saddr, slen)
proto = Protocol.IPPROTO_TCP
sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto)
if sock == asyncInvalidSocket:
result.fail(getTransportOsError(osLastError()))
if not bindToDomain(sock, address.address.getDomain()):
let err = wsaGetLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
proc continuation(udata: pointer) =
var ovl = cast[RefCustomOverlapped](udata)
if not retFuture.finished:
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(sock), cint(SOL_SOCKET),
cint(SO_UPDATE_CONNECT_CONTEXT), nil,
SockLen(0)) != 0'i32:
sock.closeSocket()
retFuture.fail(getTransportOsError(wsaGetLastError()))
else:
retFuture.complete(newStreamSocketTransport(povl.data.fd,
bufferSize,
child))
else:
sock.closeSocket()
retFuture.fail(getTransportOsError(ovl.data.errCode))
GC_unref(ovl)
povl = RefCustomOverlapped()
GC_ref(povl)
povl.data = CompletionData(fd: sock, cb: continuation)
var res = loop.connectEx(SocketHandle(sock),
cast[ptr SockAddr](addr saddr),
DWORD(slen), nil, 0, nil,
cast[POVERLAPPED](povl))
# We will not process immediate completion, to avoid undefined behavior.
if not res:
let err = osLastError()
if int32(err) != ERROR_IO_PENDING:
GC_unref(povl)
if not bindToDomain(sock, address.getDomain()):
let err = wsaGetLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
proc socketContinuation(udata: pointer) =
var ovl = cast[RefCustomOverlapped](udata)
if not retFuture.finished:
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(sock), cint(SOL_SOCKET),
cint(SO_UPDATE_CONNECT_CONTEXT), nil,
SockLen(0)) != 0'i32:
let err = wsaGetLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
else:
retFuture.complete(newStreamSocketTransport(povl.data.fd,
bufferSize,
child))
else:
sock.closeSocket()
retFuture.fail(getTransportOsError(ovl.data.errCode))
GC_unref(ovl)
povl = RefCustomOverlapped()
GC_ref(povl)
povl.data = CompletionData(fd: sock, cb: socketContinuation)
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
var res = loop.connectEx(SocketHandle(sock),
cast[ptr SockAddr](addr saddr),
DWORD(slen), nil, 0, nil,
cast[POVERLAPPED](povl))
# We will not process immediate completion, to avoid undefined behavior.
if not res:
let err = osLastError()
if int32(err) != ERROR_IO_PENDING:
GC_unref(povl)
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
elif address.family == AddressFamily.Unix:
## Unix domain socket emulation with Windows Named Pipes.
proc pipeContinuation(udata: pointer) {.gcsafe.} =
var pipeSuffix = $cast[cstring](unsafeAddr address.address_un[0])
var pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1])
var pipeHandle = createFileW(pipeName, GENERIC_READ or GENERIC_WRITE,
FILE_SHARE_READ or FILE_SHARE_WRITE,
nil, OPEN_EXISTING,
FILE_FLAG_OVERLAPPED, Handle(0))
if pipeHandle == INVALID_HANDLE_VALUE:
let err = osLastError()
if int32(err) == ERROR_PIPE_BUSY:
addTimer(fastEpochTime() + 50, pipeContinuation, nil)
else:
retFuture.fail(getTransportOsError(err))
else:
register(AsyncFD(pipeHandle))
retFuture.complete(newStreamPipeTransport(AsyncFD(pipeHandle),
bufferSize, child))
pipeContinuation(nil)
return retFuture
proc acceptPipeLoop(udata: pointer) {.gcsafe, nimcall.} =
var ovl = cast[PtrCustomOverlapped](udata)
var server = cast[StreamServer](ovl.data.udata)
var loop = getGlobalDispatcher()
while true:
if server.apending:
## Continuation
server.apending = false
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
break
else:
if ovl.data.errCode == OSErrorCode(-1):
var ntransp: StreamTransport
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
else:
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
nil, flags)
asyncCheck server.function(server, ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
break
else:
doAssert disconnectNamedPipe(Handle(server.sock)) == 1
doAssert closeHandle(HANDLE(server.sock)) == 1
raiseTransportOsError(osLastError())
else:
## Initiation
server.apending = true
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
## Server was already stopped/closed exiting
break
var pipeSuffix = $cast[cstring](addr server.local.address_un)
var pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1])
var openMode = PIPE_ACCESS_DUPLEX or FILE_FLAG_OVERLAPPED
if FirstPipe notin server.flags:
openMode = openMode or FILE_FLAG_FIRST_PIPE_INSTANCE
server.flags.incl(FirstPipe)
let pipeMode = int32(PIPE_TYPE_BYTE or PIPE_READMODE_BYTE or PIPE_WAIT)
let pipeHandle = createNamedPipe(pipeName, openMode, pipeMode,
PIPE_UNLIMITED_INSTANCES,
DWORD(server.bufferSize),
DWORD(server.bufferSize),
DWORD(0), nil)
if pipeHandle == INVALID_HANDLE_VALUE:
raiseTransportOsError(osLastError())
server.sock = AsyncFD(pipeHandle)
server.aovl.data.fd = AsyncFD(pipeHandle)
register(server.sock)
let res = connectNamedPipe(pipeHandle,
cast[POVERLAPPED](addr server.aovl))
if res == 0:
let err = osLastError()
if int32(err) == ERROR_IO_PENDING:
discard
elif int32(err) == ERROR_PIPE_CONNECTED:
discard
else:
raiseTransportOsError(err)
break
proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} =
var ovl = cast[PtrCustomOverlapped](udata)
var server = cast[StreamServer](ovl.data.udata)
@ -469,36 +670,37 @@ when defined(windows):
if server.apending:
## Continuation
server.apending = false
if server.status == ServerStatus.Stopped:
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
## Server was already stopped/closed exiting
server.asock.closeSocket()
break
else:
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET),
cint(SO_UPDATE_ACCEPT_CONTEXT),
addr server.sock,
cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
let err = OSErrorCode(wsaGetLastError())
server.asock.closeSocket()
raiseTransportOsError(err)
else:
var ntransp: StreamTransport
if not isNil(server.init):
var transp = server.init(server, server.asock)
let ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
transp)
asyncCheck server.function(server, ntransp)
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
transp)
else:
let ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
asyncCheck server.function(server, ntransp)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
asyncCheck server.function(server, ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
server.asock.closeSocket()
break
else:
let err = OSErrorCode(wsaGetLastError())
server.asock.closeSocket()
raiseTransportOsError(err)
raiseTransportOsError(ovl.data.errCode)
else:
## Initiation
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
@ -547,16 +749,10 @@ when defined(windows):
proc resumeAccept(server: StreamServer) {.inline.} =
if not server.apending:
acceptLoop(cast[pointer](addr server.aovl))
server.aovl.data.cb(addr server.aovl)
else:
template getVectorBuffer(v: untyped): pointer =
cast[pointer](cast[uint]((v).buf) + uint((v).boffset))
template getVectorLength(v: untyped): int =
cast[int]((v).buflen - int((v).boffset))
template initBufferStreamVector(v, p, n, t: untyped) =
(v).kind = DataBuffer
(v).buf = cast[pointer]((p))
@ -592,14 +788,17 @@ else:
else:
vector.writer.fail(getTransportOsError(err))
else:
var nbytes = cast[int](vector.buf)
let res = sendfile(int(fd), cast[int](vector.buflen),
int(vector.offset),
cast[int](vector.buf))
nbytes)
if res >= 0:
if cast[int](vector.buf) - res == 0:
vector.writer.complete(cast[int](vector.buf))
if cast[int](vector.buf) - nbytes == 0:
vector.size += nbytes
vector.writer.complete(vector.size)
else:
vector.shiftVectorFile(res)
vector.size += nbytes
vector.shiftVectorFile(nbytes)
transp.queue.addFirst(vector)
else:
let err = osLastError()
@ -681,10 +880,16 @@ else:
saddr: Sockaddr_storage
slen: SockLen
sock: AsyncFD
proto: Protocol
var retFuture = newFuture[StreamTransport]("transport.connect")
toSockAddr(address.address, address.port, saddr, slen)
sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
address.toSAddr(saddr, slen)
proto = Protocol.IPPROTO_TCP
if address.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)
sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM,
proto)
if sock == asyncInvalidSocket:
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
@ -741,7 +946,7 @@ else:
else:
asyncCheck server.function(server,
newStreamSocketTransport(sock, server.bufferSize, nil))
break
break
else:
let err = osLastError()
if int(err) == EINTR:
@ -800,7 +1005,16 @@ proc close*(server: StreamServer) =
GC_unref(server)
if server.status == ServerStatus.Stopped:
server.status = ServerStatus.Closed
server.sock.closeSocket(continuation)
when defined(windows):
if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
server.sock.closeSocket(continuation)
elif server.local.family in {AddressFamily.Unix}:
if NoPipeFlash notin server.flags:
discard flushFileBuffers(Handle(server.sock))
doAssert disconnectNamedPipe(Handle(server.sock)) == 1
closeHandle(server.sock, continuation)
else:
server.sock.closeSocket(continuation)
proc closeWait*(server: StreamServer): Future[void] =
## Close server ``server`` and release all resources.
@ -833,53 +1047,112 @@ proc createStreamServer*(host: TransportAddress,
saddr: Sockaddr_storage
slen: SockLen
serverSocket: AsyncFD
if sock == asyncInvalidSocket:
serverSocket = createAsyncSocket(host.address.getDomain(),
SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError())
when defined(windows):
# Windows
if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if sock == asyncInvalidSocket:
serverSocket = createAsyncSocket(host.getDomain(),
SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
raiseTransportOsError(osLastError())
register(sock)
serverSocket = sock
# SO_REUSEADDR is not useful for Unix domain sockets.
if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
host.toSAddr(saddr, slen)
if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
elif host.family == AddressFamily.Unix:
serverSocket = AsyncFD(0)
else:
if not setSocketBlocking(SocketHandle(sock), false):
raiseTransportOsError(osLastError())
register(sock)
serverSocket = sock
# Posix
if sock == asyncInvalidSocket:
var proto = Protocol.IPPROTO_TCP
if host.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)
serverSocket = createAsyncSocket(host.getDomain(),
SockType.SOCK_STREAM,
proto)
if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
raiseTransportOsError(osLastError())
register(sock)
serverSocket = sock
if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
# SO_REUSEADDR is not useful for Unix domain sockets.
if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
elif host.family in {AddressFamily.Unix}:
# We do not care about result here, because if file cannot be removed,
# `bindAddr` will return EADDRINUSE.
discard posix.unlink(cast[cstring](unsafeAddr host.address_un[0]))
host.toSAddr(saddr, slen)
if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1):
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
toSockAddr(host.address, host.port, saddr, slen)
if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
if not isNil(child):
result = child
else:
result = StreamServer()
result.sock = serverSocket
result.flags = flags
result.function = cbproc
result.init = init
result.bufferSize = bufferSize
@ -889,10 +1162,17 @@ proc createStreamServer*(host: TransportAddress,
result.local = host
when defined(windows):
result.aovl.data = CompletionData(fd: serverSocket, cb: acceptLoop,
var cb: CallbackFunc
if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
cb = acceptLoop
elif host.family == AddressFamily.Unix:
cb = acceptPipeLoop
result.aovl.data = CompletionData(fd: serverSocket, cb: cb,
udata: cast[pointer](result))
result.domain = host.address.getDomain()
result.domain = host.getDomain()
result.apending = false
GC_ref(result)
proc createStreamServer*[T](host: TransportAddress,
@ -967,6 +1247,9 @@ proc writeFile*(transp: StreamTransport, handle: int,
##
## You can specify starting ``offset`` in opened file and number of bytes
## to transfer from file to transport via ``size``.
when defined(windows):
if transp.kind != TransportKind.Socket:
raise newException(TransportNoSupport, "writeFile() is not supported!")
var retFuture = newFuture[int]("transport.writeFile")
transp.checkClosed(retFuture)
var vector = StreamVector(kind: DataFile, writer: retFuture,
@ -1172,7 +1455,7 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} =
if transp.offset > 0:
let s = len(result)
let o = s + transp.offset
if n == -1:
if n < 0:
# grabbing all incoming data, until EOF
result.setLen(o)
copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]),
@ -1259,7 +1542,19 @@ proc close*(transp: StreamTransport) =
transp.state.incl({WriteClosed, ReadClosed})
when defined(windows):
discard cancelIo(Handle(transp.fd))
closeSocket(transp.fd, continuation)
if transp.kind == TransportKind.Pipe:
if WinServerPipe in transp.flags:
if WinNoPipeFlash notin transp.flags:
discard flushFileBuffers(Handle(transp.fd))
doAssert disconnectNamedPipe(Handle(transp.fd)) == 1
else:
if WinNoPipeFlash notin transp.flags:
discard flushFileBuffers(Handle(transp.fd))
closeHandle(transp.fd, continuation)
elif transp.kind == TransportKind.Socket:
closeSocket(transp.fd, continuation)
else:
closeSocket(transp.fd, continuation)
proc closeWait*(transp: StreamTransport): Future[void] =
## Close and frees resources of transport ``transp``.

View File

@ -9,16 +9,10 @@
import strutils, net, unittest
import ../asyncdispatch2
when sizeof(int) == 8:
const
TestsCount = 10000
ClientsCount = 100
MessagesCount = 100
elif sizeof(int) == 4:
const
TestsCount = 2000
ClientsCount = 20
MessagesCount = 20
const
TestsCount = 2000
ClientsCount = 20
MessagesCount = 20
proc client1(transp: DatagramTransport,
raddr: TransportAddress): Future[void] {.async.} =
@ -139,9 +133,8 @@ proc client5(transp: DatagramTransport,
if counterPtr[] == MessagesCount:
transp.close()
else:
var ta = initTAddress("127.0.0.1:33341")
var req = "REQUEST" & $counterPtr[]
await transp.sendTo(ta, addr req[0], len(req))
await transp.sendTo(raddr, addr req[0], len(req))
else:
var counterPtr = cast[ptr int](transp.udata)
counterPtr[] = -1
@ -190,9 +183,8 @@ proc client7(transp: DatagramTransport,
if counterPtr[] == TestsCount:
transp.close()
else:
var ta = initTAddress("127.0.0.1:33336")
var req = "REQUEST" & $counterPtr[]
await transp.sendTo(ta, req)
await transp.sendTo(raddr, req)
else:
var counterPtr = cast[ptr int](transp.udata)
counterPtr[] = -1
@ -272,11 +264,10 @@ proc client10(transp: DatagramTransport,
if counterPtr[] == TestsCount:
transp.close()
else:
var ta = initTAddress("127.0.0.1:33338")
var req = "REQUEST" & $counterPtr[]
var reqseq = newSeq[byte](len(req))
copyMem(addr reqseq[0], addr req[0], len(req))
await transp.sendTo(ta, reqseq)
await transp.sendTo(raddr, reqseq)
else:
var counterPtr = cast[ptr int](transp.udata)
counterPtr[] = -1
@ -326,7 +317,7 @@ proc testPointerSendTo(): Future[int] {.async.} =
await dgram2.sendTo(ta, addr data[0], len(data))
await dgram2.join()
dgram1.close()
dgram2.close()
await dgram1.join()
result = counter
proc testPointerSend(): Future[int] {.async.} =
@ -339,12 +330,12 @@ proc testPointerSend(): Future[int] {.async.} =
await dgram2.send(addr data[0], len(data))
await dgram2.join()
dgram1.close()
dgram2.close()
await dgram1.join()
result = counter
proc testStringSendTo(): Future[int] {.async.} =
## sendTo(string) test
var ta = initTAddress("127.0.0.1:33336")
var ta = initTAddress("127.0.0.1:33338")
var counter = 0
var dgram1 = newDatagramTransport(client6, udata = addr counter, local = ta)
var dgram2 = newDatagramTransport(client7, udata = addr counter)
@ -352,12 +343,12 @@ proc testStringSendTo(): Future[int] {.async.} =
await dgram2.sendTo(ta, data)
await dgram2.join()
dgram1.close()
dgram2.close()
await dgram1.join()
result = counter
proc testStringSend(): Future[int] {.async.} =
## send(string) test
var ta = initTAddress("127.0.0.1:33337")
var ta = initTAddress("127.0.0.1:33339")
var counter = 0
var dgram1 = newDatagramTransport(client6, udata = addr counter, local = ta)
var dgram2 = newDatagramTransport(client8, udata = addr counter, remote = ta)
@ -365,12 +356,12 @@ proc testStringSend(): Future[int] {.async.} =
await dgram2.send(data)
await dgram2.join()
dgram1.close()
dgram2.close()
await dgram1.join()
result = counter
proc testSeqSendTo(): Future[int] {.async.} =
## sendTo(string) test
var ta = initTAddress("127.0.0.1:33338")
var ta = initTAddress("127.0.0.1:33340")
var counter = 0
var dgram1 = newDatagramTransport(client9, udata = addr counter, local = ta)
var dgram2 = newDatagramTransport(client10, udata = addr counter)
@ -380,12 +371,12 @@ proc testSeqSendTo(): Future[int] {.async.} =
await dgram2.sendTo(ta, dataseq)
await dgram2.join()
dgram1.close()
dgram2.close()
await dgram1.join()
result = counter
proc testSeqSend(): Future[int] {.async.} =
## send(string) test
var ta = initTAddress("127.0.0.1:33339")
var ta = initTAddress("127.0.0.1:33341")
var counter = 0
var dgram1 = newDatagramTransport(client9, udata = addr counter, local = ta)
var dgram2 = newDatagramTransport(client11, udata = addr counter, remote = ta)
@ -395,7 +386,7 @@ proc testSeqSend(): Future[int] {.async.} =
await dgram2.send(data)
await dgram2.join()
dgram1.close()
dgram2.close()
await dgram1.join()
result = counter
#
@ -414,9 +405,9 @@ proc waitAll(futs: seq[Future[void]]): Future[void] =
proc test3(bounded: bool): Future[int] {.async.} =
var ta: TransportAddress
if bounded:
ta = initTAddress("127.0.0.1:33340")
ta = initTAddress("127.0.0.1:33240")
else:
ta = initTAddress("127.0.0.1:33341")
ta = initTAddress("127.0.0.1:33241")
var counter = 0
var dgram1 = newDatagramTransport(client1, udata = addr counter, local = ta)
var clients = newSeq[Future[void]](ClientsCount)
@ -435,6 +426,7 @@ proc test3(bounded: bool): Future[int] {.async.} =
await waitAll(clients)
dgram1.close()
await dgram1.join()
result = 0
for i in 0..<ClientsCount:
result += counters[i]
@ -448,11 +440,14 @@ proc testConnReset(): Future[bool] {.async.} =
transp.close()
var dgram1 = newDatagramTransport(client1, local = ta)
dgram1.close()
await dgram1.join()
var dgram2 = newDatagramTransport(clientMark)
var data = "MESSAGE"
asyncCheck dgram2.sendTo(ta, data)
await sleepAsync(1000)
result = (counter == 0)
dgram2.close()
await dgram2.join()
when isMainModule:
const

View File

@ -20,11 +20,11 @@ const
FilesTestName = "tests/teststream.nim"
when sizeof(int) == 8:
const
BigMessageCount = 1000
ClientsCount = 100
BigMessageCount = 500
ClientsCount = 50
MessagesCount = 100
MessageSize = 20
FilesCount = 100
FilesCount = 50
elif sizeof(int) == 4:
const
BigMessageCount = 200
@ -115,36 +115,6 @@ proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} =
transp.close()
await transp.join()
proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} =
var data = await transp.read()
doAssert(len(data) == len(ConstantMessage) * MessagesCount)
transp.close()
var expect = ""
for i in 0..<MessagesCount:
expect.add(ConstantMessage)
doAssert(equalMem(addr expect[0], addr data[0], len(data)))
var counter = cast[ptr int](server.udata)
dec(counter[])
if counter[] == 0:
server.stop()
server.close()
await server.join()
proc serveClient6(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ConstantMessage
var skip = await transp.consume(len(ConstantMessage) * (MessagesCount - 1))
doAssert(skip == len(ConstantMessage) * (MessagesCount - 1))
var data = await transp.read()
doAssert(len(data) == len(ConstantMessage))
transp.close()
doAssert(equalMem(addr data[0], addr expect[0], len(expect)))
var counter = cast[ptr int](server.udata)
dec(counter[])
if counter[] == 0:
server.stop()
server.close()
await server.join()
proc serveClient7(server: StreamServer, transp: StreamTransport) {.async.} =
var answer = "DONE\r\n"
var expect = ""
@ -157,6 +127,8 @@ proc serveClient7(server: StreamServer, transp: StreamTransport) {.async.} =
doAssert(res == len(answer))
transp.close()
await transp.join()
server.stop()
server.close()
proc serveClient8(server: StreamServer, transp: StreamTransport) {.async.} =
var answer = "DONE\r\n"
@ -176,9 +148,9 @@ proc serveClient8(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(answer)
doAssert(res == len(answer))
transp.close()
await transp.join()
server.stop()
server.close()
await server.join()
proc swarmWorker1(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -274,26 +246,6 @@ proc swarmWorker4(address: TransportAddress): Future[int] {.async.} =
transp.close()
await transp.join()
proc swarmWorker5(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = ConstantMessage
for i in 0..<MessagesCount:
var res = await transp.write(data)
result = MessagesCount
transp.close()
await transp.join()
proc swarmWorker6(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = ConstantMessage
var seqdata = newSeq[byte](len(data))
copyMem(addr seqdata[0], addr data[0], len(data))
for i in 0..<MessagesCount:
var res = await transp.write(seqdata)
result = MessagesCount
transp.close()
await transp.join()
proc swarmWorker7(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = BigMessagePattern
@ -334,7 +286,6 @@ proc waitAll[T](futs: seq[Future[T]]): Future[void] =
return retFuture
proc swarmManager1(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.readLine")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
@ -345,7 +296,6 @@ proc swarmManager1(address: TransportAddress): Future[int] {.async.} =
result += res
proc swarmManager2(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.readExactly")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
@ -356,7 +306,6 @@ proc swarmManager2(address: TransportAddress): Future[int] {.async.} =
result += res
proc swarmManager3(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.readUntil")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
@ -367,7 +316,6 @@ proc swarmManager3(address: TransportAddress): Future[int] {.async.} =
result += res
proc swarmManager4(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.writeFile")
var workers = newSeq[Future[int]](FilesCount)
var count = FilesCount
for i in 0..<FilesCount:
@ -377,161 +325,195 @@ proc swarmManager4(address: TransportAddress): Future[int] {.async.} =
var res = workers[i].read()
result += res
proc swarmManager5(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.read")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
workers[i] = swarmWorker5(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
proc swarmManager6(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.consume")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
workers[i] = swarmWorker6(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
proc test1(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31344")
var server = createStreamServer(ta, serveClient1, {ReuseAddr})
proc test1(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient1, {ReuseAddr})
server.start()
result = await swarmManager1(ta)
result = await swarmManager1(address)
server.stop()
server.close()
await server.join()
proc test2(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31345")
proc test2(address: TransportAddress): Future[int] {.async.} =
var counter = 0
var server = createStreamServer(ta, serveClient2, {ReuseAddr})
var server = createStreamServer(address, serveClient2, {ReuseAddr})
server.start()
result = await swarmManager2(ta)
result = await swarmManager2(address)
server.stop()
server.close()
await server.join()
proc test3(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31346")
proc test3(address: TransportAddress): Future[int] {.async.} =
var counter = 0
var server = createStreamServer(ta, serveClient3, {ReuseAddr})
var server = createStreamServer(address, serveClient3, {ReuseAddr})
server.start()
result = await swarmManager3(ta)
result = await swarmManager3(address)
server.stop()
server.close()
await server.join()
proc test4(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31347")
var server = createStreamServer(ta, serveClient4, {ReuseAddr})
proc testSendFile(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient4, {ReuseAddr})
server.start()
result = await swarmManager4(ta)
result = await swarmManager4(address)
server.stop()
server.close()
await server.join()
proc test5(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31348")
proc testWR(address: TransportAddress): Future[int] {.async.} =
var counter = ClientsCount
var server = createStreamServer(ta, serveClient5, {ReuseAddr},
udata = cast[pointer](addr counter))
server.start()
result = await swarmManager5(ta)
proc test6(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31349")
proc swarmWorker(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = ConstantMessage
for i in 0..<MessagesCount:
var res = await transp.write(data)
result = MessagesCount
transp.close()
await transp.join()
proc swarmManager(address: TransportAddress): Future[int] {.async.} =
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
workers[i] = swarmWorker(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
var data = await transp.read()
doAssert(len(data) == len(ConstantMessage) * MessagesCount)
transp.close()
var expect = ""
for i in 0..<MessagesCount:
expect.add(ConstantMessage)
doAssert(equalMem(addr expect[0], addr data[0], len(data)))
dec(counter)
if counter == 0:
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start()
result = await swarmManager(address)
await server.join()
proc testWCR(address: TransportAddress): Future[int] {.async.} =
var counter = ClientsCount
var server = createStreamServer(ta, serveClient6, {ReuseAddr},
udata = cast[pointer](addr counter))
server.start()
result = await swarmManager6(ta)
proc test7(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31350")
var server = createStreamServer(ta, serveClient7, {ReuseAddr})
proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ConstantMessage
var skip = await transp.consume(len(ConstantMessage) * (MessagesCount - 1))
doAssert(skip == len(ConstantMessage) * (MessagesCount - 1))
var data = await transp.read()
doAssert(len(data) == len(ConstantMessage))
transp.close()
doAssert(equalMem(addr data[0], addr expect[0], len(expect)))
dec(counter)
if counter == 0:
server.stop()
server.close()
proc swarmWorker(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = ConstantMessage
var seqdata = newSeq[byte](len(data))
copyMem(addr seqdata[0], addr data[0], len(data))
for i in 0..<MessagesCount:
var res = await transp.write(seqdata)
result = MessagesCount
transp.close()
await transp.join()
proc swarmManager(address: TransportAddress): Future[int] {.async.} =
var workers = newSeq[Future[int]](ClientsCount)
for i in 0..<ClientsCount:
workers[i] = swarmWorker(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start()
result = await swarmWorker7(ta)
result = await swarmManager(address)
await server.join()
proc test7(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient7, {ReuseAddr})
server.start()
result = await swarmWorker7(address)
server.stop()
server.close()
await server.join()
proc test8(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31350")
var server = createStreamServer(ta, serveClient8, {ReuseAddr})
proc test8(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient8, {ReuseAddr})
server.start()
result = await swarmWorker8(ta)
server.stop()
server.close()
result = await swarmWorker8(address)
await server.join()
proc serveClient9(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var res = await transp.write(expect)
doAssert(res == len(expect))
transp.close()
await transp.join()
# proc serveClient9(server: StreamServer, transp: StreamTransport) {.async.} =
# var expect = ""
# for i in 0..<BigMessageCount:
# expect.add(BigMessagePattern)
# var res = await transp.write(expect)
# doAssert(res == len(expect))
# transp.close()
# await transp.join()
proc swarmWorker9(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var line = await transp.readLine()
if line == expect:
result = 1
else:
result = 0
transp.close()
await transp.join()
# proc swarmWorker9(address: TransportAddress): Future[int] {.async.} =
# var transp = await connect(address)
# var expect = ""
# for i in 0..<BigMessageCount:
# expect.add(BigMessagePattern)
# var line = await transp.readLine()
# if line == expect:
# result = 1
# else:
# result = 0
# transp.close()
# await transp.join()
proc test9(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31351")
var server = createStreamServer(ta, serveClient9, {ReuseAddr})
server.start()
result = await swarmWorker9(ta)
server.stop()
server.close()
await server.join()
# proc test9(address: TransportAddress): Future[int] {.async.} =
# let flags = {ReuseAddr, NoPipeFlash}
# var server = createStreamServer(address, serveClient9, flags)
# server.start()
# result = await swarmWorker9(address)
# server.stop()
# server.close()
# await server.join()
proc serveClient10(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var res = await transp.write(expect)
doAssert(res == len(expect))
transp.close()
await transp.join()
# proc serveClient10(server: StreamServer, transp: StreamTransport) {.async.} =
# var expect = ""
# for i in 0..<BigMessageCount:
# expect.add(BigMessagePattern)
# var res = await transp.write(expect)
# doAssert(res == len(expect))
# transp.close()
# await transp.join()
proc swarmWorker10(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var line = await transp.read()
if equalMem(addr line[0], addr expect[0], len(expect)):
result = 1
else:
result = 0
transp.close()
await transp.join()
# proc swarmWorker10(address: TransportAddress): Future[int] {.async.} =
# var transp = await connect(address)
# var expect = ""
# for i in 0..<BigMessageCount:
# expect.add(BigMessagePattern)
# var line = await transp.read()
# if equalMem(addr line[0], addr expect[0], len(expect)):
# result = 1
# else:
# result = 0
# transp.close()
# await transp.join()
proc test10(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31351")
var server = createStreamServer(ta, serveClient10, {ReuseAddr})
server.start()
result = await swarmWorker10(ta)
server.stop()
server.close()
await server.join()
# proc test10(address: TransportAddress): Future[int] {.async.} =
# var server = createStreamServer(address, serveClient10, {ReuseAddr})
# server.start()
# result = await swarmWorker10(address)
# server.stop()
# server.close()
# await server.join()
proc serveClient11(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(BigMessagePattern)
@ -549,11 +531,10 @@ proc swarmWorker11(address: TransportAddress): Future[int] {.async.} =
transp.close()
await transp.join()
proc test11(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31352")
var server = createStreamServer(ta, serveClient11, {ReuseAddr})
proc test11(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient11, {ReuseAddr})
server.start()
result = await swarmWorker11(ta)
result = await swarmWorker11(address)
server.stop()
server.close()
await server.join()
@ -575,11 +556,10 @@ proc swarmWorker12(address: TransportAddress): Future[int] {.async.} =
transp.close()
await transp.join()
proc test12(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31353")
var server = createStreamServer(ta, serveClient12, {ReuseAddr})
proc test12(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient12, {ReuseAddr})
server.start()
result = await swarmWorker12(ta)
result = await swarmWorker12(address)
server.stop()
server.close()
await server.join()
@ -598,11 +578,10 @@ proc swarmWorker13(address: TransportAddress): Future[int] {.async.} =
transp.close()
await transp.join()
proc test13(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31354")
var server = createStreamServer(ta, serveClient13, {ReuseAddr})
proc test13(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient13, {ReuseAddr})
server.start()
result = await swarmWorker13(ta)
result = await swarmWorker13(address)
server.stop()
server.close()
await server.join()
@ -610,10 +589,9 @@ proc test13(): Future[int] {.async.} =
proc serveClient14(server: StreamServer, transp: StreamTransport) {.async.} =
discard
proc test14(): Future[int] {.async.} =
proc test14(address: TransportAddress): Future[int] {.async.} =
var subres = 0
var ta = initTAddress("127.0.0.1:31354")
var server = createStreamServer(ta, serveClient13, {ReuseAddr})
var server = createStreamServer(address, serveClient13, {ReuseAddr})
proc swarmWorker(transp: StreamTransport): Future[void] {.async.} =
var line = await transp.readLine()
@ -623,7 +601,7 @@ proc test14(): Future[int] {.async.} =
subres = 0
server.start()
var transp = await connect(ta)
var transp = await connect(address)
var fut = swarmWorker(transp)
transp.close()
await fut
@ -632,14 +610,16 @@ proc test14(): Future[int] {.async.} =
await server.join()
result = subres
proc testConnectionRefused(): Future[bool] {.async.} =
proc testConnectionRefused(address: TransportAddress): Future[bool] {.async.} =
try:
var transp = await connect(initTAddress("127.0.0.1:1"))
var transp = await connect(address)
except TransportOsError as e:
let ecode = int(e.code)
when defined(windows):
result = (int(e.code) == ERROR_CONNECTION_REFUSED)
result = (ecode == ERROR_FILE_NOT_FOUND) or
(ecode == ERROR_CONNECTION_REFUSED)
else:
result = (int(e.code) == ECONNREFUSED)
result = (ecode == ECONNREFUSED) or (ecode == ENOENT)
when isMainModule:
const
@ -656,41 +636,59 @@ when isMainModule:
$ClientsCount & " clients x " & $MessagesCount & " messages)"
m7 = "readLine() buffer overflow test"
m8 = "readUntil() buffer overflow test"
m9 = "readLine() unexpected disconnect test"
m10 = "read() unexpected disconnect test"
m11 = "readExactly() unexpected disconnect test"
m12 = "readUntil() unexpected disconnect test"
m13 = "readLine() unexpected disconnect empty string test"
m14 = "Closing socket while operation pending test (issue #8)"
m15 = "Connection refused test"
when defined(windows):
var addresses = [
initTAddress("127.0.0.1:33335"),
initTAddress(r"/LOCAL\testpipe")
]
else:
var addresses = [
initTAddress("127.0.0.1:33335"),
initTAddress(r"/tmp/testpipe")
]
var prefixes = ["[IP] ", "[UNIX] "]
suite "Stream Transport test suite":
test m8:
check waitFor(test8()) == 1
test m7:
check waitFor(test7()) == 1
test m9:
check waitFor(test9()) == 1
test m10:
check waitFor(test10()) == 1
test m11:
check waitFor(test11()) == 1
test m12:
check waitFor(test12()) == 1
test m13:
check waitFor(test13()) == 1
test m14:
check waitFor(test14()) == 1
test m1:
check waitFor(test1()) == ClientsCount * MessagesCount
test m2:
check waitFor(test2()) == ClientsCount * MessagesCount
test m3:
check waitFor(test3()) == ClientsCount * MessagesCount
test m5:
check waitFor(test5()) == ClientsCount * MessagesCount
test m6:
check waitFor(test6()) == ClientsCount * MessagesCount
test m4:
check waitFor(test4()) == FilesCount
test m15:
check waitFor(testConnectionRefused()) == true
for i in 0..<len(addresses):
test prefixes[i] & m8:
check waitFor(test8(addresses[i])) == 1
test prefixes[i] & m7:
check waitFor(test7(addresses[i])) == 1
test prefixes[i] & m11:
check waitFor(test11(addresses[i])) == 1
test prefixes[i] & m12:
check waitFor(test12(addresses[i])) == 1
test prefixes[i] & m13:
check waitFor(test13(addresses[i])) == 1
test prefixes[i] & m14:
check waitFor(test14(addresses[i])) == 1
test prefixes[i] & m1:
check waitFor(test1(addresses[i])) == ClientsCount * MessagesCount
test prefixes[i] & m2:
check waitFor(test2(addresses[i])) == ClientsCount * MessagesCount
test prefixes[i] & m3:
check waitFor(test3(addresses[i])) == ClientsCount * MessagesCount
test prefixes[i] & m5:
check waitFor(testWR(addresses[i])) == ClientsCount * MessagesCount
test prefixes[i] & m6:
check waitFor(testWCR(addresses[i])) == ClientsCount * MessagesCount
test prefixes[i] & m4:
when defined(windows):
if addresses[i].family == AddressFamily.IPv4:
check waitFor(testSendFile(addresses[i])) == FilesCount
else:
discard
else:
check waitFor(testSendFile(addresses[i])) == FilesCount
test prefixes[i] & m15:
var address: TransportAddress
if addresses[i].family == AddressFamily.Unix:
address = initTAddress("/tmp/notexistingtestpipe")
else:
address = initTAddress("127.0.0.1:43335")
check waitFor(testConnectionRefused(address)) == true