Refactoring, more tests.

This commit is contained in:
Your Name 2018-05-22 00:52:57 +03:00
parent 5c6c723cb9
commit 23a81b6492
8 changed files with 474 additions and 238 deletions

View File

@ -202,9 +202,6 @@ when defined(windows) or defined(nimdoc):
errCode*: OSErrorCode
bytesCount*: int32
udata*: pointer
cell*: ForeignCell # we need this `cell` to protect our `cb` environment,
# when using RegisterWaitForSingleObject, because
# waiting is done in different thread.
PDispatcher* = ref object of PDispatcherBase
ioPort: Handle
@ -217,19 +214,12 @@ when defined(windows) or defined(nimdoc):
CustomOverlapped* = object of OVERLAPPED
data*: CompletionData
PCustomOverlapped* = ptr CustomOverlapped
PtrCustomOverlapped* = ptr CustomOverlapped
RefCustomOverlapped* = ref CustomOverlapped
AsyncFD* = distinct int
# PostCallbackData = object
# ioPort: Handle
# handleFd: AsyncFD
# waitFd: Handle
# ovl: PCustomOverlapped
# PostCallbackDataPtr = ptr PostCallbackData
proc hash(x: AsyncFD): Hash {.borrow.}
proc `==`*(x: AsyncFD, y: AsyncFD): bool {.borrow.}
@ -292,7 +282,7 @@ when defined(windows) or defined(nimdoc):
# Processing handles
var lpNumberOfBytesTransferred: Dword
var lpCompletionKey: ULONG_PTR
var customOverlapped: PCustomOverlapped
var customOverlapped: PtrCustomOverlapped
let res = getQueuedCompletionStatus(
loop.ioPort, addr lpNumberOfBytesTransferred, addr lpCompletionKey,
cast[ptr POVERLAPPED](addr customOverlapped), curTimeout).bool

View File

@ -12,7 +12,7 @@ import net, nativesockets, asyncloop
when defined(windows):
import winlean
const
asyncInvalidSocket* = AsyncFD(SocketHandle(-1))
asyncInvalidSocket* = AsyncFD(-1)
else:
import posix
const
@ -35,8 +35,7 @@ proc setSocketBlocking*(s: SocketHandle, blocking: bool): bool =
if fcntl(s, F_SETFL, mode) == -1:
result = false
proc setSockOpt*(socket: SocketHandle | AsyncFD, level, optname,
optval: int): bool =
proc setSockOpt*(socket: AsyncFD, level, optname, optval: int): bool =
## `setsockopt()` for integer options.
## Returns ``true`` on success, ``false`` on error.
result = true
@ -44,9 +43,8 @@ proc setSockOpt*(socket: SocketHandle | AsyncFD, level, optname,
if setsockopt(SocketHandle(socket), cint(level), cint(optname), addr(value),
sizeof(value).SockLen) < 0'i32:
result = false
proc getSockOpt*(socket: SocketHandle | AsyncFD, level, optname: int,
value: var int): bool =
proc getSockOpt*(socket: AsyncFD, level, optname: int, value: var int): bool =
## `getsockopt()` for integer options.
var res: cint
var size = sizeof(res).SockLen
@ -56,8 +54,8 @@ proc getSockOpt*(socket: SocketHandle | AsyncFD, level, optname: int,
return false
value = int(res)
proc getSocketError*(socket: SocketHandle | AsyncFD,
err: var int): bool =
proc getSocketError*(socket: AsyncFD, err: var int): bool =
## Recover error code associated with socket handle ``socket``.
if not getSockOpt(socket, cint(SOL_SOCKET), cint(SO_ERROR), err):
result = false
else:
@ -74,25 +72,26 @@ proc createAsyncSocket*(domain: Domain, sockType: SockType,
close(handle)
return asyncInvalidSocket
when defined(macosx) and not defined(nimdoc):
if not handle.setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, 1):
if not setSockOpt(AsyncFD(handle), SOL_SOCKET, SO_NOSIGPIPE, 1):
close(handle)
return asyncInvalidSocket
result = AsyncFD(handle)
register(result)
proc wrapAsyncSocket*(sock: SocketHandle): AsyncFD =
## Wraps normal socket to asynchronous socket.
## Wraps socket to asynchronous socket handle.
## Return ``asyncInvalidSocket`` on error.
if not setSocketBlocking(sock, false):
close(sock)
return asyncInvalidSocket
when defined(macosx) and not defined(nimdoc):
if not sock.setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, 1):
if not setSockOpt(AsyncFD(sock), SOL_SOCKET, SO_NOSIGPIPE, 1):
close(sock)
return asyncInvalidSocket
result = AsyncFD(sock)
register(result)
proc closeAsyncSocket*(s: AsyncFD) {.inline.} =
## Closes asynchronous socket handle ``s``.
unregister(s)
close(SocketHandle(s))

View File

@ -41,46 +41,52 @@ when defined(linux) or defined(android):
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
var o = offset
result = osSendFile(cint(outfd), cint(infd), addr offset, count)
result = osSendFile(cint(outfd), cint(infd), addr o, count)
elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
defined(dragonflybsd):
type
sendfileHeader* = object {.importc: "sf_hdtr",
SendfileHeader* = object {.importc: "sf_hdtr",
header: """#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>""",
pure, final.}
proc osSendFile*(outfd, infd: cint, offset: int, size: int,
hdtr: ptr sendfileHeader, sbytes: ptr int,
proc osSendFile*(outfd, infd: cint, offset: uint, size: uint,
hdtr: ptr SendfileHeader, sbytes: ptr uint,
flags: int): int {.importc: "sendfile",
header: """#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>""".}
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
var o = 0
result = osSendFile(cint(outfd), cint(infd), offset, count, nil,
addr o, 0)
var o = 0'u
if osSendFile(cint(infd), cint(outfd), uint(offset), uint(count), nil,
addr o, 0) == 0:
result = int(o)
else:
result = -1
elif defined(macosx):
import posix
type
sendfileHeader* = object {.importc: "sf_hdtr",
SendfileHeader* = object {.importc: "sf_hdtr",
header: """#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>""",
pure, final.}
proc osSendFile*(fd, s: cint, offset: int, size: ptr int,
hdtr: ptr sendfileHeader,
hdtr: ptr SendfileHeader,
flags: int): int {.importc: "sendfile",
header: """#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>""".}
proc sendfile*(outfd, infd: int, offset: int, count: int): int =
var o = 0
result = osSendFile(cint(fd), cint(s), offset, addr o, nil, 0)
var o = count
if osSendFile(cint(infd), cint(outfd), offset, addr o, nil, 0) == 0:
result = o
else:
result = -1

View File

@ -7,7 +7,7 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import net
import net, strutils
import ../asyncloop, ../asyncsync
const
@ -98,7 +98,20 @@ proc `$`*(address: TransportAddress): string =
result.add(":")
result.add($int(address.port))
## TODO: string -> TransportAddress conversion
proc strAddress*(address: string): TransportAddress =
## Parses string representation of ``address``.
##
## IPv4 transport address format is ``a.b.c.d:port``.
## IPv6 transport address format is ``[::]:port``.
var parts = address.rsplit(":", maxsplit = 1)
doAssert(len(parts) == 2, "Format is <address>:<port>!")
let port = parseInt(parts[1])
doAssert(port > 0 and port < 65536, "Illegal port number!")
result.port = Port(port)
if parts[0][0] == '[' and parts[0][^1] == ']':
result.address = parseIpAddress(parts[0][1..^2])
else:
result.address = parseIpAddress(parts[0])
template checkClosed*(t: untyped) =
if (ReadClosed in (t).state) or (WriteClosed in (t).state):

View File

@ -87,7 +87,7 @@ when defined(windows):
proc writeDatagramLoop(udata: pointer) =
var bytesCount: int32
var ovl = cast[PCustomOverlapped](udata)
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsDatagramTransport](ovl.data.udata)
while len(transp.queue) > 0:
if WritePending in transp.state:
@ -135,7 +135,7 @@ when defined(windows):
var
bytesCount: int32
raddr: TransportAddress
var ovl = cast[PCustomOverlapped](udata)
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsDatagramTransport](ovl.data.udata)
while true:
if ReadPending in transp.state:

View File

@ -7,35 +7,28 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import ../asyncloop, ../asyncsync, ../handles
import ../asyncloop, ../asyncsync, ../handles, ../sendfile
import common
import net, nativesockets, os, deques, strutils
when defined(windows):
import winlean
else:
import posix
type
VectorKind = enum
DataBuffer, # Simple buffer pointer/length
DataFile # File handle for sendfile/TransmitFile
when defined(windows):
import winlean
type
StreamVector = object
kind: VectorKind # Writer vector source kind
dataBuf: ptr TWSABuf # Writer vector buffer
offset: uint # Writer vector offset
writer: Future[void] # Writer vector completion Future
else:
import posix
type
StreamVector = object
kind: VectorKind # Writer vector source kind
buf: pointer # Writer buffer pointer
buflen: int # Writer buffer size
offset: uint # Writer vector offset
writer: Future[void] # Writer vector completion Future
type
StreamVector = object
kind: VectorKind # Writer vector source kind
buf: pointer # Writer buffer pointer
buflen: int # Writer buffer size
offset: uint # Writer vector offset
writer: Future[void] # Writer vector completion Future
TransportKind* {.pure.} = enum
Socket, # Socket transport
Pipe, # Pipe transport
@ -51,19 +44,20 @@ type
error: ref Exception # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
transferred: int
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
local: TransportAddress # Local address
remote: TransportAddress # Remote address
of TransportKind.Pipe:
fd0: AsyncFD
fd1: AsyncFD
todo1: int
of TransportKind.File:
length: int
todo2: int
StreamCallback* = proc(t: StreamTransport,
udata: pointer): Future[void] {.gcsafe.}
## New connection callback
StreamServer* = ref object of SocketServer
function*: StreamCallback
@ -111,14 +105,6 @@ template checkPending(t: untyped) =
if not isNil((t).reader):
raise newException(TransportError, "Read operation already pending!")
# template shiftBuffer(t, c: untyped) =
# let length = len((t).buffer)
# if length > c:
# moveMem(addr((t).buffer[0]), addr((t).buffer[(c)]), length - (c))
# (t).offset = (t).offset - (c)
# else:
# (t).offset = 0
template shiftBuffer(t, c: untyped) =
if (t).offset > c:
moveMem(addr((t).buffer[0]), addr((t).buffer[(c)]), (t).offset - (c))
@ -126,20 +112,29 @@ template shiftBuffer(t, c: untyped) =
else:
(t).offset = 0
template shiftVectorBuffer(v, o: untyped) =
(v).buf = cast[pointer](cast[uint]((v).buf) + uint(o))
(v).buflen -= int(o)
template shiftVectorFile(v, o: untyped) =
(v).buf = cast[pointer](cast[uint]((v).buf) - cast[uint](o))
(v).offset += cast[uint]((o))
when defined(windows):
import winlean
type
WindowsStreamTransport = ref object of StreamTransport
wsabuf: TWSABuf # Reader WSABUF
rwsabuf: TWSABuf # Reader WSABUF
wwsabuf: TWSABuf # Writer WSABUF
rovl: CustomOverlapped # Reader OVERLAPPED structure
wovl: CustomOverlapped # Writer OVERLAPPED structure
roffset: int # Pending reading offset
WindowsStreamServer* = ref object of RootRef
server: SocketServer # Server object
domain: Domain
abuffer: array[128, byte]
aovl: CustomOverlapped
domain: Domain # Current server domain (IPv4 or IPv6)
abuffer: array[128, byte] # Windows AcceptEx buffer
aovl: CustomOverlapped # AcceptEx OVERLAPPED structure
const SO_UPDATE_CONNECT_CONTEXT = 0x7010
@ -155,38 +150,30 @@ when defined(windows):
(t).offset = cast[int32](cast[uint64](o) and 0xFFFFFFFF'u64)
(t).offsetHigh = cast[int32](cast[uint64](o) shr 32)
template getFileSize(t: untyped): uint =
cast[uint]((t).dataBuf.buf)
template getFileSize(v: untyped): uint =
cast[uint]((v).buf)
template getFileHandle(t: untyped): Handle =
cast[Handle]((t).dataBuf.len)
template slideOffset(v, o: untyped) =
let s = cast[uint]((v).dataBuf.buf) - cast[uint]((o))
(v).dataBuf.buf = cast[cstring](s)
(v).offset = (v).offset + cast[uint]((o))
template getFileHandle(v: untyped): Handle =
cast[Handle]((v).buflen)
template slideBuffer(t, o: untyped) =
(t).dataBuf.buf = cast[cstring](cast[uint]((t).dataBuf.buf) + uint(o))
(t).dataBuf.len -= int32(o)
(t).wwsabuf.buf = cast[cstring](cast[uint]((t).wwsabuf.buf) + uint(o))
(t).wwsabuf.len -= int32(o)
template setWSABuffer(t: untyped) =
(t).wsabuf.buf = cast[cstring](
template setReaderWSABuffer(t: untyped) =
(t).rwsabuf.buf = cast[cstring](
cast[uint](addr t.buffer[0]) + uint((t).roffset))
(t).wsabuf.len = int32(len((t).buffer) - (t).roffset)
(t).rwsabuf.len = int32(len((t).buffer) - (t).roffset)
# template initTransmitStreamVector(v, h, o, n, t: untyped) =
# (v).kind = DataFile
# (v).dataBuf.buf = cast[cstring]((n))
# (v).dataBuf.len = cast[int32]((h))
# (v).offset = cast[uint]((o))
# (v).writer = (t)
template setWriterWSABuffer(t, v: untyped) =
(t).wwsabuf.buf = cast[cstring](v.buf)
(t).wwsabuf.len = cast[int32](v.buflen)
proc writeStreamLoop(udata: pointer) {.gcsafe.} =
var bytesCount: int32
if isNil(udata):
return
var ovl = cast[PCustomOverlapped](udata)
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsStreamTransport](ovl.data.udata)
while len(transp.queue) > 0:
@ -202,14 +189,14 @@ when defined(windows):
else:
if transp.kind == TransportKind.Socket:
if vector.kind == VectorKind.DataBuffer:
if bytesCount < vector.dataBuf.len:
vector.slideBuffer(bytesCount)
if bytesCount < transp.wwsabuf.len:
vector.shiftVectorBuffer(bytesCount)
transp.queue.addFirst(vector)
else:
vector.writer.complete()
else:
if uint(bytesCount) < getFileSize(vector):
vector.slideOffset(bytesCount)
vector.shiftVectorFile(bytesCount)
transp.queue.addFirst(vector)
else:
vector.writer.complete()
@ -224,19 +211,21 @@ when defined(windows):
var vector = transp.queue.popFirst()
if vector.kind == VectorKind.DataBuffer:
transp.wovl.zeroOvelappedOffset()
let ret = WSASend(sock, vector.dataBuf, 1,
transp.setWriterWSABuffer(vector)
let ret = WSASend(sock, addr transp.wwsabuf, 1,
addr bytesCount, DWORD(0),
cast[POVERLAPPED](addr transp.wovl), nil)
if ret != 0:
let err = osLastError()
if int(err) == ERROR_OPERATION_ABORTED:
transp.state.excl(WritePending)
transp.state.incl(WritePaused)
elif int(err) == ERROR_IO_PENDING:
transp.queue.addFirst(vector)
else:
transp.state.excl(WritePending)
transp.setWriteError(err)
transp.finishWriter()
vector.writer.complete()
else:
transp.queue.addFirst(vector)
else:
@ -256,13 +245,14 @@ when defined(windows):
if ret == 0:
let err = osLastError()
if int(err) == ERROR_OPERATION_ABORTED:
transp.state.excl(WritePending)
transp.state.incl(WritePaused)
elif int(err) == ERROR_IO_PENDING:
transp.queue.addFirst(vector)
else:
transp.state.excl(WritePending)
transp.setWriteError(err)
transp.finishWriter()
vector.writer.complete()
else:
transp.queue.addFirst(vector)
break
@ -273,18 +263,19 @@ when defined(windows):
proc readStreamLoop(udata: pointer) {.gcsafe.} =
if isNil(udata):
return
var ovl = cast[PCustomOverlapped](udata)
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsStreamTransport](ovl.data.udata)
while true:
if ReadPending in transp.state:
## Continuation
transp.state.excl(ReadPending)
if ReadClosed in transp.state:
break
transp.state.excl(ReadPending)
let err = transp.rovl.data.errCode
if err == OSErrorCode(-1):
let bytesCount = transp.rovl.data.bytesCount
transp.transferred += bytesCount
if bytesCount == 0:
transp.state.incl(ReadEof)
transp.state.incl(ReadPaused)
@ -301,22 +292,27 @@ when defined(windows):
transp.setReadError(err)
if not isNil(transp.reader):
transp.finishReader()
if ReadPaused in transp.state:
# Transport buffer is full, so we will not continue on reading.
break
else:
## Initiation
if (ReadEof notin transp.state) and (ReadClosed notin transp.state):
if transp.state * {ReadEof, ReadClosed, ReadError} == {}:
var flags = DWORD(0)
var bytesCount: int32 = 0
transp.state.excl(ReadPaused)
transp.state.incl(ReadPending)
if transp.kind == TransportKind.Socket:
let sock = SocketHandle(transp.rovl.data.fd)
transp.setWSABuffer()
let ret = WSARecv(sock, addr transp.wsabuf, 1,
transp.roffset = transp.offset
transp.setReaderWSABuffer()
let ret = WSARecv(sock, addr transp.rwsabuf, 1,
addr bytesCount, addr flags,
cast[POVERLAPPED](addr transp.rovl), nil)
if ret != 0:
let err = osLastError()
if int(err) == ERROR_OPERATION_ABORTED:
transp.state.excl(ReadPending)
transp.state.incl(ReadPaused)
elif int32(err) != ERROR_IO_PENDING:
transp.setReadError(err)
@ -326,17 +322,18 @@ when defined(windows):
break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport =
var t = WindowsStreamTransport(kind: TransportKind.Socket)
t.fd = sock
t.rovl.data = CompletionData(fd: sock, cb: readStreamLoop,
udata: cast[pointer](t))
t.wovl.data = CompletionData(fd: sock, cb: writeStreamLoop,
udata: cast[pointer](t))
t.buffer = newSeq[byte](bufsize)
t.state = {ReadPaused, WritePaused}
t.queue = initDeque[StreamVector]()
t.future = newFuture[void]("stream.socket.transport")
result = cast[StreamTransport](t)
var transp = WindowsStreamTransport(kind: TransportKind.Socket)
transp.fd = sock
transp.rovl.data = CompletionData(fd: sock, cb: readStreamLoop,
udata: cast[pointer](transp))
transp.wovl.data = CompletionData(fd: sock, cb: writeStreamLoop,
udata: cast[pointer](transp))
transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("stream.socket.transport")
GC_ref(transp)
result = cast[StreamTransport](transp)
proc bindToDomain(handle: AsyncFD, domain: Domain): bool =
result = true
@ -374,7 +371,7 @@ when defined(windows):
result.fail(newException(OSError, osErrorMsg(osLastError())))
proc continuation(udata: pointer) =
var ovl = cast[PCustomOverlapped](udata)
var ovl = cast[PtrCustomOverlapped](udata)
if not retFuture.finished:
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(sock), cint(SOL_SOCKET),
@ -417,7 +414,7 @@ when defined(windows):
let dwRemoteAddressLength = DWORD(sizeof(Sockaddr_in6) + 16)
proc continuation(udata: pointer) =
var ovl = cast[PCustomOverlapped](udata)
var ovl = cast[PtrCustomOverlapped](udata)
if not retFuture.finished:
if server.server.status in {Stopped, Paused}:
sock.closeAsyncSocket()
@ -482,7 +479,7 @@ when defined(windows):
if not acceptFut.failed:
var sock = acceptFut.read()
if sock != asyncInvalidSocket:
spawn server.function(
discard server.function(
newStreamSocketTransport(sock, server.bufferSize),
server.udata)
@ -508,10 +505,6 @@ else:
template getVectorLength(v: untyped): int =
cast[int]((v).buflen - int((v).boffset))
template shiftVectorBuffer(t, o: untyped) =
(t).buf = cast[pointer](cast[uint]((t).buf) + uint(o))
(t).buflen -= int(o)
template initBufferStreamVector(v, p, n, t: untyped) =
(v).kind = DataBuffer
(v).buf = cast[pointer]((p))
@ -524,7 +517,6 @@ else:
let fd = SocketHandle(cdata.fd)
if not isNil(transp):
if len(transp.queue) > 0:
echo "len(transp.queue) = ", len(transp.queue)
var vector = transp.queue.popFirst()
while true:
if transp.kind == TransportKind.Socket:
@ -543,9 +535,24 @@ else:
else:
transp.setWriteError(err)
vector.writer.complete()
break
else:
discard
let res = sendfile(int(fd), cast[int](vector.buflen),
int(vector.offset),
cast[int](vector.buf))
if res >= 0:
if cast[int](vector.buf) - res == 0:
vector.writer.complete()
else:
vector.shiftVectorFile(res)
transp.queue.addFirst(vector)
else:
let err = osLastError()
if int(err) == EINTR:
continue
else:
transp.setWriteError(err)
vector.writer.complete()
break
else:
transp.state.incl(WritePaused)
transp.fd.removeWriter()
@ -583,13 +590,14 @@ else:
break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport =
var t = UnixStreamTransport(kind: TransportKind.Socket)
t.fd = sock
t.buffer = newSeq[byte](bufsize)
t.state = {ReadPaused, WritePaused}
t.queue = initDeque[StreamVector]()
t.future = newFuture[void]("socket.stream.transport")
result = cast[StreamTransport](t)
var transp = UnixStreamTransport(kind: TransportKind.Socket)
transp.fd = sock
transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("socket.stream.transport")
GC_ref(transp)
result = cast[StreamTransport](transp)
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize): Future[StreamTransport] =
@ -641,7 +649,6 @@ else:
var
saddr: Sockaddr_storage
slen: SockLen
var server = cast[StreamServer](cast[ptr CompletionData](udata).udata)
while true:
let res = posix.accept(SocketHandle(server.sock),
@ -649,7 +656,7 @@ else:
if int(res) > 0:
let sock = wrapAsyncSocket(res)
if sock != asyncInvalidSocket:
spawn server.function(
discard server.function(
newStreamSocketTransport(sock, server.bufferSize),
server.udata)
break
@ -692,19 +699,28 @@ else:
addWriter(transp.fd, writeStreamLoop, cast[pointer](transp))
proc start*(server: SocketServer) =
## Starts ``server``.
server.action = Start
server.actEvent.fire()
proc stop*(server: SocketServer) =
## Stops ``server``
server.action = Stop
server.actEvent.fire()
proc pause*(server: SocketServer) =
## Pause ``server``
server.action = Pause
server.actEvent.fire()
proc join*(server: SocketServer) {.async.} =
await server.loopFuture
## Waits until ``server`` is not stopped.
if not server.loopFuture.finished:
await server.loopFuture
proc close*(server: SocketServer) =
## Release ``server`` resources.
GC_unref(server)
proc createStreamServer*(host: TransportAddress,
flags: set[ServerFlags],
@ -729,7 +745,6 @@ proc createStreamServer*(host: TransportAddress,
register(sock)
serverSocket = sock
## TODO: Set socket options here
if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
let err = osLastError()
@ -744,7 +759,7 @@ proc createStreamServer*(host: TransportAddress,
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
raiseOsError(err)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
@ -759,19 +774,15 @@ proc createStreamServer*(host: TransportAddress,
result.actEvent = newAsyncEvent()
result.udata = udata
result.local = host
GC_ref(result)
result.loopFuture = serverLoop(result)
proc write*(transp: StreamTransport, pbytes: pointer,
nbytes: int): Future[int] {.async.} =
checkClosed(transp)
var waitFuture = newFuture[void]("transport.write")
var vector = StreamVector(kind: DataBuffer, writer: waitFuture)
when defined(windows):
var wsabuf = TWSABuf(buf: cast[cstring](pbytes), len: cast[int32](nbytes))
vector.dataBuf = addr wsabuf
else:
vector.buf = pbytes
vector.buflen = nbytes
var vector = StreamVector(kind: DataBuffer, writer: waitFuture,
buf: pbytes, buflen: nbytes)
transp.queue.addLast(vector)
if WritePaused in transp.state:
transp.resumeWrite()
@ -780,25 +791,25 @@ proc write*(transp: StreamTransport, pbytes: pointer,
raise transp.getError()
result = nbytes
# proc writeFile*(transp: StreamTransport, handle: int,
# offset: uint = 0,
# size: uint = 0): Future[void] {.async.} =
# if transp.kind != TransportKind.Socket:
# raise newException(TransportError, "You can transmit files only to sockets")
# checkClosed(transp)
# var waitFuture = newFuture[void]("transport.writeFile")
# var vector: StreamVector
# vector.initTransmitStreamVector(handle, offset, size, waitFuture)
# transp.queue.addLast(vector)
# if WritePaused in transp.state:
# transp.resumeWrite()
# await vector.writer
# if WriteError in transp.state:
# raise transp.getError()
proc writeFile*(transp: StreamTransport, handle: int,
offset: uint = 0,
size: int = 0): Future[void] {.async.} =
if transp.kind != TransportKind.Socket:
raise newException(TransportError, "You can transmit files only to sockets")
checkClosed(transp)
var waitFuture = newFuture[void]("transport.writeFile")
var vector = StreamVector(kind: DataFile, writer: waitFuture,
buf: cast[pointer](size), offset: offset,
buflen: handle)
transp.queue.addLast(vector)
if WritePaused in transp.state:
transp.resumeWrite()
await vector.writer
if WriteError in transp.state:
raise transp.getError()
proc readExactly*(transp: StreamTransport, pbytes: pointer,
nbytes: int): Future[int] {.async.} =
nbytes: int) {.async.} =
## Read exactly ``nbytes`` bytes from transport ``transp``.
checkClosed(transp)
checkPending(transp)
@ -814,17 +825,19 @@ proc readExactly*(transp: StreamTransport, pbytes: pointer,
copyMem(cast[pointer](cast[uint](pbytes) + uint(index)),
addr(transp.buffer[0]), nbytes - index)
transp.shiftBuffer(nbytes - index)
result = nbytes
break
else:
copyMem(cast[pointer](cast[uint](pbytes) + uint(index)),
addr(transp.buffer[0]), transp.offset)
index += transp.offset
transp.reader = newFuture[void]("transport.readExactly")
if transp.offset != 0:
copyMem(cast[pointer](cast[uint](pbytes) + uint(index)),
addr(transp.buffer[0]), transp.offset)
index += transp.offset
transp.reader = newFuture[void]("stream.transport.readExactly")
transp.offset = 0
if ReadPaused in transp.state:
transp.resumeRead()
await transp.reader
# we are no longer need data
transp.reader = nil
@ -840,7 +853,7 @@ proc readOnce*(transp: StreamTransport, pbytes: pointer,
if (ReadEof in transp.state) or (ReadClosed in transp.state):
result = 0
break
transp.reader = newFuture[void]("transport.readOnce")
transp.reader = newFuture[void]("stream.transport.readOnce")
if ReadPaused in transp.state:
transp.resumeRead()
await transp.reader
@ -898,7 +911,7 @@ proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int,
break
else:
if (transp.offset - index) == 0:
transp.reader = newFuture[void]("transport.readUntil")
transp.reader = newFuture[void]("stream.transport.readUntil")
if ReadPaused in transp.state:
transp.resumeRead()
await transp.reader
@ -945,7 +958,7 @@ proc readLine*(transp: StreamTransport, limit = 0,
break
else:
if (transp.offset - index) == 0:
transp.reader = newFuture[void]("transport.readLine")
transp.reader = newFuture[void]("stream.transport.readLine")
if ReadPaused in transp.state:
transp.resumeRead()
await transp.reader
@ -990,7 +1003,7 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} =
transp.offset)
transp.offset = 0
transp.reader = newFuture[void]("transport.read")
transp.reader = newFuture[void]("stream.transport.read")
if ReadPaused in transp.state:
transp.resumeRead()
await transp.reader
@ -1005,7 +1018,8 @@ proc atEof*(transp: StreamTransport): bool {.inline.} =
proc join*(transp: StreamTransport) {.async.} =
## Wait until ``transp`` will not be closed.
await transp.future
if not transp.future.finished:
await transp.future
proc close*(transp: StreamTransport) =
## Closes and frees resources of transport ``transp``.
@ -1016,3 +1030,4 @@ proc close*(transp: StreamTransport) =
transp.state.incl(WriteClosed)
transp.state.incl(ReadClosed)
transp.future.complete()
GC_unref(transp)

View File

@ -1,4 +1,4 @@
# Asyncdispatch2
# Asyncdispatch2 Test Suite
# (c) Copyright 2018
# Status Research & Development GmbH
#

View File

@ -1,82 +1,295 @@
import strutils, net, unittest
# Asyncdispatch2 Test Suite
# (c) Copyright 2018
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import strutils, net, unittest, os
import ../asyncdispatch2
when defined(windows):
import winlean
else:
import posix
const
ClientsCount = 1
MessagesCount = 100000
ClientsCount = 50
MessagesCount = 50
MessageSize = 20
FilesCount = 50
FilesTestName = "teststream.nim"
proc serveClient1(transp: StreamTransport, udata: pointer) {.async.} =
echo "SERVER STARTING (0x" & toHex[uint](cast[uint](transp)) & ")"
while not transp.atEof():
var data = await transp.readLine()
echo "SERVER READ [" & data & "]"
if data.startsWith("REQUEST"):
var numstr = data[7..^1]
var num = parseInt(numstr)
var ans = "ANSWER" & $num & "\r\n"
var res = await transp.write(cast[pointer](addr ans[0]), len(ans))
# doAssert(res == len(ans))
echo "SERVER EXITING (0x" & toHex[uint](cast[uint](transp)) & ")"
proc swarmWorker(address: TransportAddress) {.async.} =
echo "CONNECTING TO " & $address
var transp = await connect(address)
echo "CONNECTED"
for i in 0..<MessagesCount:
echo "MESSAGE " & $i
var data = "REQUEST" & $i & "\r\n"
var res = await transp.write(cast[pointer](addr data[0]), len(data))
echo "CLIENT WRITE COMPLETED"
assert(res == len(data))
var ans = await transp.readLine()
if ans.startsWith("ANSWER"):
var numstr = ans[6..^1]
var num = parseInt(numstr)
doAssert(num == i)
if len(data) == 0:
doAssert(transp.atEof())
break
doAssert(data.startsWith("REQUEST"))
var numstr = data[7..^1]
var num = parseInt(numstr)
var ans = "ANSWER" & $num & "\r\n"
var res = await transp.write(cast[pointer](addr ans[0]), len(ans))
doAssert(res == len(ans))
transp.close()
proc swarmManager(address: TransportAddress): Future[void] =
var retFuture = newFuture[void]("swarm.manager")
var workers = newSeq[Future[void]](ClientsCount)
var count = ClientsCount
proc cb(data: pointer) {.gcsafe.} =
if not retFuture.finished:
dec(count)
if count == 0:
retFuture.complete()
for i in 0..<ClientsCount:
workers[i] = swarmWorker(address)
workers[i].addCallback(cb)
proc serveClient2(transp: StreamTransport, udata: pointer) {.async.} =
var buffer: array[20, char]
var check = "REQUEST"
while not transp.atEof():
zeroMem(addr buffer[0], MessageSize)
try:
await transp.readExactly(addr buffer[0], MessageSize)
except TransportIncompleteError:
break
doAssert(equalMem(addr buffer[0], addr check[0], len(check)))
var numstr = ""
var i = 7
while i < MessageSize and (buffer[i] in {'0'..'9'}):
numstr.add(buffer[i])
inc(i)
var num = parseInt(numstr)
var ans = "ANSWER" & $num
zeroMem(addr buffer[0], MessageSize)
copyMem(addr buffer[0], addr ans[0], len(ans))
var res = await transp.write(cast[pointer](addr buffer[0]), MessageSize)
doAssert(res == MessageSize)
transp.close()
proc serveClient3(transp: StreamTransport, udata: pointer) {.async.} =
var buffer: array[20, char]
var check = "REQUEST"
var suffixStr = "SUFFIX"
var suffix = newSeq[byte](6)
copyMem(addr suffix[0], addr suffixStr[0], len(suffixStr))
while not transp.atEof():
zeroMem(addr buffer[0], MessageSize)
var res = await transp.readUntil(addr buffer[0], MessageSize, suffix)
doAssert(equalMem(addr buffer[0], addr check[0], len(check)))
var numstr = ""
var i = 7
while i < MessageSize and (buffer[i] in {'0'..'9'}):
numstr.add(buffer[i])
inc(i)
var num = parseInt(numstr)
doAssert(len(numstr) < 8)
var ans = "ANSWER" & $num & "SUFFIX"
zeroMem(addr buffer[0], MessageSize)
copyMem(addr buffer[0], addr ans[0], len(ans))
res = await transp.write(cast[pointer](addr buffer[0]), len(ans))
doAssert(res == len(ans))
transp.close()
proc serveClient4(transp: StreamTransport, udata: pointer) {.async.} =
var pathname = await transp.readLine()
var size = await transp.readLine()
var sizeNum = parseInt(size)
doAssert(sizeNum >= 0)
var rbuffer = newSeq[byte](sizeNum)
await transp.readExactly(addr rbuffer[0], sizeNum)
var lbuffer = readFile(pathname)
doAssert(len(lbuffer) == sizeNum)
doAssert(equalMem(addr rbuffer[0], addr lbuffer[0], sizeNum))
var answer = "OK\r\n"
var res = await transp.write(cast[pointer](addr answer[0]), len(answer))
doAssert(res == len(answer))
proc swarmWorker1(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
for i in 0..<MessagesCount:
var data = "REQUEST" & $i & "\r\n"
var res = await transp.write(cast[pointer](addr data[0]), len(data))
assert(res == len(data))
var ans = await transp.readLine()
doAssert(ans.startsWith("ANSWER"))
var numstr = ans[6..^1]
var num = parseInt(numstr)
doAssert(num == i)
inc(result)
transp.close()
proc swarmWorker2(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var buffer: array[MessageSize, char]
var check = "ANSWER"
for i in 0..<MessagesCount:
var data = "REQUEST" & $i & "\r\n"
zeroMem(addr buffer[0], MessageSize)
copyMem(addr buffer[0], addr data[0], min(MessageSize, len(data)))
var res = await transp.write(cast[pointer](addr buffer[0]), MessageSize)
doAssert(res == MessageSize)
zeroMem(addr buffer[0], MessageSize)
await transp.readExactly(addr buffer[0], MessageSize)
doAssert(equalMem(addr buffer[0], addr check[0], len(check)))
var numstr = ""
var k = 6
while k < MessageSize and (buffer[k] in {'0'..'9'}):
numstr.add(buffer[k])
inc(k)
var num = parseInt(numstr)
doAssert(num == i)
inc(result)
transp.close()
proc swarmWorker3(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var buffer: array[MessageSize, char]
var check = "ANSWER"
var suffixStr = "SUFFIX"
var suffix = newSeq[byte](6)
copyMem(addr suffix[0], addr suffixStr[0], len(suffixStr))
for i in 0..<MessagesCount:
var data = "REQUEST" & $i & "SUFFIX"
doAssert(len(data) <= MessageSize)
zeroMem(addr buffer[0], MessageSize)
copyMem(addr buffer[0], addr data[0], len(data))
var res = await transp.write(cast[pointer](addr buffer[0]), len(data))
doAssert(res == len(data))
zeroMem(addr buffer[0], MessageSize)
res = await transp.readUntil(addr buffer[0], MessageSize, suffix)
doAssert(equalMem(addr buffer[0], addr check[0], len(check)))
var numstr = ""
var k = 6
while k < MessageSize and (buffer[k] in {'0'..'9'}):
numstr.add(buffer[k])
inc(k)
var num = parseInt(numstr)
doAssert(num == i)
inc(result)
transp.close()
proc swarmWorker4(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var ssize: string
var handle = 0
var name = FilesTestName
var size = int(getFileSize(FilesTestName))
var fhandle = open(FilesTestName)
when defined(windows):
handle = int(get_osfhandle(getFileHandle(fhandle)))
else:
handle = int(getFileHandle(fhandle))
doAssert(handle > 0)
name = name & "\r\n"
ssize = $size & "\r\n"
var res = await transp.write(cast[pointer](addr name[0]), len(name))
doAssert(res == len(name))
res = await transp.write(cast[pointer](addr ssize[0]), len(ssize))
doAssert(res == len(ssize))
await transp.writeFile(handle, 0'u, size)
var ans = await transp.readLine()
doAssert(ans == "OK")
result = 1
transp.close()
proc waitAll[T](futs: seq[Future[T]]): Future[void] =
var counter = len(futs)
var retFuture = newFuture[void]("waitAll")
proc cb(udata: pointer) =
dec(counter)
if counter == 0:
retFuture.complete()
for fut in futs:
fut.addCallback(cb)
return retFuture
when isMainModule:
var ta: TransportAddress
ta.address = parseIpAddress("127.0.0.1")
ta.port = Port(31344)
proc swarmManager1(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.readLine")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
workers[i] = swarmWorker1(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
proc swarmManager2(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.readExactly")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
workers[i] = swarmWorker2(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
proc swarmManager3(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.readUntil")
var workers = newSeq[Future[int]](ClientsCount)
var count = ClientsCount
for i in 0..<ClientsCount:
workers[i] = swarmWorker3(address)
await waitAll(workers)
for i in 0..<ClientsCount:
var res = workers[i].read()
result += res
proc swarmManager4(address: TransportAddress): Future[int] {.async.} =
var retFuture = newFuture[void]("swarm.manager.writeFile")
var workers = newSeq[Future[int]](FilesCount)
var count = FilesCount
for i in 0..<FilesCount:
workers[i] = swarmWorker4(address)
await waitAll(workers)
for i in 0..<FilesCount:
var res = workers[i].read()
result += res
proc test1(): Future[int] {.async.} =
var ta = strAddress("127.0.0.1:31344")
var server = createStreamServer(ta, {ReuseAddr}, serveClient1)
server.start()
waitFor(swarmManager(ta))
result = await swarmManager1(ta)
server.stop()
server.close()
proc test2(): Future[int] {.async.} =
var ta = strAddress("127.0.0.1:31345")
var counter = 0
var server = createStreamServer(ta, {ReuseAddr}, serveClient2)
server.start()
result = await swarmManager2(ta)
server.stop()
server.close()
proc test3(): Future[int] {.async.} =
var ta = strAddress("127.0.0.1:31346")
var counter = 0
var server = createStreamServer(ta, {ReuseAddr}, serveClient3)
server.start()
result = await swarmManager3(ta)
server.stop()
server.close()
# proc processClient*(t: StreamTransport, udata: pointer) {.async.} =
# var data = newSeq[byte](10)
# var f: File
# echo "CONNECTED FROM ", $t.remoteAddress()
# if not f.open("timer.nim"):
# echo "ERROR OPENING FILE"
# echo f.getFileHandle()
# # try:
# when defined(windows):
# await t.writeFile(int(get_osfhandle(f.getFileHandle())))
# else:
# await t.writeFile(int(f.getFileHandle()))
proc test4(): Future[int] {.async.} =
var ta = strAddress("127.0.0.1:31347")
var counter = 0
var server = createStreamServer(ta, {ReuseAddr}, serveClient4)
server.start()
result = await swarmManager4(ta)
server.stop()
server.close()
# proc test2() {.async.} =
# var s = createStreamServer(parseIpAddress("0.0.0.0"), Port(31337),
# {ReusePort}, processClient)
# s.start()
# await s.join()
when isMainModule:
const
m1 = "readLine() multiple clients with messages (" & $ClientsCount &
" clients x " & $MessagesCount & " messages)"
m2 = "readExactly() multiple clients with messages (" & $ClientsCount &
" clients x " & $MessagesCount & " messages)"
m3 = "readUntil() multiple clients with messages (" & $ClientsCount &
" clients x " & $MessagesCount & " messages)"
m4 = "writeFile() multiple clients (" & $FilesCount & " files)"
# when isMainModule:
# waitFor(test2())
suite "Stream Transport test suite":
test m1:
check waitFor(test1()) == ClientsCount * MessagesCount
test m2:
check waitFor(test2()) == ClientsCount * MessagesCount
test m3:
check waitFor(test3()) == ClientsCount * MessagesCount
test m4:
check waitFor(test4()) == FilesCount