From 23a81b6492609405ce437d60a3d2e582a613396d Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 May 2018 00:52:57 +0300 Subject: [PATCH] Refactoring, more tests. --- asyncdispatch2/asyncloop.nim | 14 +- asyncdispatch2/handles.nim | 21 +- asyncdispatch2/sendfile.nim | 30 ++- asyncdispatch2/transports/common.nim | 17 +- asyncdispatch2/transports/datagram.nim | 4 +- asyncdispatch2/transports/stream.nim | 283 ++++++++++---------- tests/testdatagram.nim | 2 +- tests/teststream.nim | 341 ++++++++++++++++++++----- 8 files changed, 474 insertions(+), 238 deletions(-) diff --git a/asyncdispatch2/asyncloop.nim b/asyncdispatch2/asyncloop.nim index fea5b8a8..06cff7fb 100644 --- a/asyncdispatch2/asyncloop.nim +++ b/asyncdispatch2/asyncloop.nim @@ -202,9 +202,6 @@ when defined(windows) or defined(nimdoc): errCode*: OSErrorCode bytesCount*: int32 udata*: pointer - cell*: ForeignCell # we need this `cell` to protect our `cb` environment, - # when using RegisterWaitForSingleObject, because - # waiting is done in different thread. PDispatcher* = ref object of PDispatcherBase ioPort: Handle @@ -217,19 +214,12 @@ when defined(windows) or defined(nimdoc): CustomOverlapped* = object of OVERLAPPED data*: CompletionData - PCustomOverlapped* = ptr CustomOverlapped + PtrCustomOverlapped* = ptr CustomOverlapped RefCustomOverlapped* = ref CustomOverlapped AsyncFD* = distinct int - # PostCallbackData = object - # ioPort: Handle - # handleFd: AsyncFD - # waitFd: Handle - # ovl: PCustomOverlapped - # PostCallbackDataPtr = ptr PostCallbackData - proc hash(x: AsyncFD): Hash {.borrow.} proc `==`*(x: AsyncFD, y: AsyncFD): bool {.borrow.} @@ -292,7 +282,7 @@ when defined(windows) or defined(nimdoc): # Processing handles var lpNumberOfBytesTransferred: Dword var lpCompletionKey: ULONG_PTR - var customOverlapped: PCustomOverlapped + var customOverlapped: PtrCustomOverlapped let res = getQueuedCompletionStatus( loop.ioPort, addr lpNumberOfBytesTransferred, addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped), curTimeout).bool diff --git a/asyncdispatch2/handles.nim b/asyncdispatch2/handles.nim index b822b85f..6ff514c9 100644 --- a/asyncdispatch2/handles.nim +++ b/asyncdispatch2/handles.nim @@ -12,7 +12,7 @@ import net, nativesockets, asyncloop when defined(windows): import winlean const - asyncInvalidSocket* = AsyncFD(SocketHandle(-1)) + asyncInvalidSocket* = AsyncFD(-1) else: import posix const @@ -35,8 +35,7 @@ proc setSocketBlocking*(s: SocketHandle, blocking: bool): bool = if fcntl(s, F_SETFL, mode) == -1: result = false -proc setSockOpt*(socket: SocketHandle | AsyncFD, level, optname, - optval: int): bool = +proc setSockOpt*(socket: AsyncFD, level, optname, optval: int): bool = ## `setsockopt()` for integer options. ## Returns ``true`` on success, ``false`` on error. result = true @@ -44,9 +43,8 @@ proc setSockOpt*(socket: SocketHandle | AsyncFD, level, optname, if setsockopt(SocketHandle(socket), cint(level), cint(optname), addr(value), sizeof(value).SockLen) < 0'i32: result = false - -proc getSockOpt*(socket: SocketHandle | AsyncFD, level, optname: int, - value: var int): bool = + +proc getSockOpt*(socket: AsyncFD, level, optname: int, value: var int): bool = ## `getsockopt()` for integer options. var res: cint var size = sizeof(res).SockLen @@ -56,8 +54,8 @@ proc getSockOpt*(socket: SocketHandle | AsyncFD, level, optname: int, return false value = int(res) -proc getSocketError*(socket: SocketHandle | AsyncFD, - err: var int): bool = +proc getSocketError*(socket: AsyncFD, err: var int): bool = + ## Recover error code associated with socket handle ``socket``. if not getSockOpt(socket, cint(SOL_SOCKET), cint(SO_ERROR), err): result = false else: @@ -74,25 +72,26 @@ proc createAsyncSocket*(domain: Domain, sockType: SockType, close(handle) return asyncInvalidSocket when defined(macosx) and not defined(nimdoc): - if not handle.setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, 1): + if not setSockOpt(AsyncFD(handle), SOL_SOCKET, SO_NOSIGPIPE, 1): close(handle) return asyncInvalidSocket result = AsyncFD(handle) register(result) proc wrapAsyncSocket*(sock: SocketHandle): AsyncFD = - ## Wraps normal socket to asynchronous socket. + ## Wraps socket to asynchronous socket handle. ## Return ``asyncInvalidSocket`` on error. if not setSocketBlocking(sock, false): close(sock) return asyncInvalidSocket when defined(macosx) and not defined(nimdoc): - if not sock.setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, 1): + if not setSockOpt(AsyncFD(sock), SOL_SOCKET, SO_NOSIGPIPE, 1): close(sock) return asyncInvalidSocket result = AsyncFD(sock) register(result) proc closeAsyncSocket*(s: AsyncFD) {.inline.} = + ## Closes asynchronous socket handle ``s``. unregister(s) close(SocketHandle(s)) diff --git a/asyncdispatch2/sendfile.nim b/asyncdispatch2/sendfile.nim index 02b326fa..2d2c2007 100644 --- a/asyncdispatch2/sendfile.nim +++ b/asyncdispatch2/sendfile.nim @@ -41,46 +41,52 @@ when defined(linux) or defined(android): proc sendfile*(outfd, infd: int, offset: int, count: int): int = var o = offset - result = osSendFile(cint(outfd), cint(infd), addr offset, count) + result = osSendFile(cint(outfd), cint(infd), addr o, count) elif defined(freebsd) or defined(openbsd) or defined(netbsd) or defined(dragonflybsd): type - sendfileHeader* = object {.importc: "sf_hdtr", + SendfileHeader* = object {.importc: "sf_hdtr", header: """#include #include #include """, pure, final.} - proc osSendFile*(outfd, infd: cint, offset: int, size: int, - hdtr: ptr sendfileHeader, sbytes: ptr int, + proc osSendFile*(outfd, infd: cint, offset: uint, size: uint, + hdtr: ptr SendfileHeader, sbytes: ptr uint, flags: int): int {.importc: "sendfile", header: """#include #include #include """.} proc sendfile*(outfd, infd: int, offset: int, count: int): int = - var o = 0 - result = osSendFile(cint(outfd), cint(infd), offset, count, nil, - addr o, 0) + var o = 0'u + if osSendFile(cint(infd), cint(outfd), uint(offset), uint(count), nil, + addr o, 0) == 0: + result = int(o) + else: + result = -1 elif defined(macosx): - + import posix type - sendfileHeader* = object {.importc: "sf_hdtr", + SendfileHeader* = object {.importc: "sf_hdtr", header: """#include #include #include """, pure, final.} proc osSendFile*(fd, s: cint, offset: int, size: ptr int, - hdtr: ptr sendfileHeader, + hdtr: ptr SendfileHeader, flags: int): int {.importc: "sendfile", header: """#include #include #include """.} proc sendfile*(outfd, infd: int, offset: int, count: int): int = - var o = 0 - result = osSendFile(cint(fd), cint(s), offset, addr o, nil, 0) + var o = count + if osSendFile(cint(infd), cint(outfd), offset, addr o, nil, 0) == 0: + result = o + else: + result = -1 diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index a740e200..b5da3ef1 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -7,7 +7,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import net +import net, strutils import ../asyncloop, ../asyncsync const @@ -98,7 +98,20 @@ proc `$`*(address: TransportAddress): string = result.add(":") result.add($int(address.port)) -## TODO: string -> TransportAddress conversion +proc strAddress*(address: string): TransportAddress = + ## Parses string representation of ``address``. + ## + ## IPv4 transport address format is ``a.b.c.d:port``. + ## IPv6 transport address format is ``[::]:port``. + var parts = address.rsplit(":", maxsplit = 1) + doAssert(len(parts) == 2, "Format is
:!") + let port = parseInt(parts[1]) + doAssert(port > 0 and port < 65536, "Illegal port number!") + result.port = Port(port) + if parts[0][0] == '[' and parts[0][^1] == ']': + result.address = parseIpAddress(parts[0][1..^2]) + else: + result.address = parseIpAddress(parts[0]) template checkClosed*(t: untyped) = if (ReadClosed in (t).state) or (WriteClosed in (t).state): diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim index c148e1c4..f7c6b26c 100644 --- a/asyncdispatch2/transports/datagram.nim +++ b/asyncdispatch2/transports/datagram.nim @@ -87,7 +87,7 @@ when defined(windows): proc writeDatagramLoop(udata: pointer) = var bytesCount: int32 - var ovl = cast[PCustomOverlapped](udata) + var ovl = cast[PtrCustomOverlapped](udata) var transp = cast[WindowsDatagramTransport](ovl.data.udata) while len(transp.queue) > 0: if WritePending in transp.state: @@ -135,7 +135,7 @@ when defined(windows): var bytesCount: int32 raddr: TransportAddress - var ovl = cast[PCustomOverlapped](udata) + var ovl = cast[PtrCustomOverlapped](udata) var transp = cast[WindowsDatagramTransport](ovl.data.udata) while true: if ReadPending in transp.state: diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index 3452c448..d068ae79 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -7,35 +7,28 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import ../asyncloop, ../asyncsync, ../handles +import ../asyncloop, ../asyncsync, ../handles, ../sendfile import common import net, nativesockets, os, deques, strutils +when defined(windows): + import winlean +else: + import posix + type VectorKind = enum DataBuffer, # Simple buffer pointer/length DataFile # File handle for sendfile/TransmitFile -when defined(windows): - import winlean - type - StreamVector = object - kind: VectorKind # Writer vector source kind - dataBuf: ptr TWSABuf # Writer vector buffer - offset: uint # Writer vector offset - writer: Future[void] # Writer vector completion Future - -else: - import posix - type - StreamVector = object - kind: VectorKind # Writer vector source kind - buf: pointer # Writer buffer pointer - buflen: int # Writer buffer size - offset: uint # Writer vector offset - writer: Future[void] # Writer vector completion Future - type + StreamVector = object + kind: VectorKind # Writer vector source kind + buf: pointer # Writer buffer pointer + buflen: int # Writer buffer size + offset: uint # Writer vector offset + writer: Future[void] # Writer vector completion Future + TransportKind* {.pure.} = enum Socket, # Socket transport Pipe, # Pipe transport @@ -51,19 +44,20 @@ type error: ref Exception # Current error queue: Deque[StreamVector] # Writer queue future: Future[void] # Stream life future + transferred: int case kind*: TransportKind of TransportKind.Socket: domain: Domain # Socket transport domain (IPv4/IPv6) local: TransportAddress # Local address remote: TransportAddress # Remote address of TransportKind.Pipe: - fd0: AsyncFD - fd1: AsyncFD + todo1: int of TransportKind.File: - length: int + todo2: int StreamCallback* = proc(t: StreamTransport, udata: pointer): Future[void] {.gcsafe.} + ## New connection callback StreamServer* = ref object of SocketServer function*: StreamCallback @@ -111,14 +105,6 @@ template checkPending(t: untyped) = if not isNil((t).reader): raise newException(TransportError, "Read operation already pending!") -# template shiftBuffer(t, c: untyped) = -# let length = len((t).buffer) -# if length > c: -# moveMem(addr((t).buffer[0]), addr((t).buffer[(c)]), length - (c)) -# (t).offset = (t).offset - (c) -# else: -# (t).offset = 0 - template shiftBuffer(t, c: untyped) = if (t).offset > c: moveMem(addr((t).buffer[0]), addr((t).buffer[(c)]), (t).offset - (c)) @@ -126,20 +112,29 @@ template shiftBuffer(t, c: untyped) = else: (t).offset = 0 +template shiftVectorBuffer(v, o: untyped) = + (v).buf = cast[pointer](cast[uint]((v).buf) + uint(o)) + (v).buflen -= int(o) + +template shiftVectorFile(v, o: untyped) = + (v).buf = cast[pointer](cast[uint]((v).buf) - cast[uint](o)) + (v).offset += cast[uint]((o)) + when defined(windows): import winlean type WindowsStreamTransport = ref object of StreamTransport - wsabuf: TWSABuf # Reader WSABUF + rwsabuf: TWSABuf # Reader WSABUF + wwsabuf: TWSABuf # Writer WSABUF rovl: CustomOverlapped # Reader OVERLAPPED structure wovl: CustomOverlapped # Writer OVERLAPPED structure roffset: int # Pending reading offset WindowsStreamServer* = ref object of RootRef server: SocketServer # Server object - domain: Domain - abuffer: array[128, byte] - aovl: CustomOverlapped + domain: Domain # Current server domain (IPv4 or IPv6) + abuffer: array[128, byte] # Windows AcceptEx buffer + aovl: CustomOverlapped # AcceptEx OVERLAPPED structure const SO_UPDATE_CONNECT_CONTEXT = 0x7010 @@ -155,38 +150,30 @@ when defined(windows): (t).offset = cast[int32](cast[uint64](o) and 0xFFFFFFFF'u64) (t).offsetHigh = cast[int32](cast[uint64](o) shr 32) - template getFileSize(t: untyped): uint = - cast[uint]((t).dataBuf.buf) + template getFileSize(v: untyped): uint = + cast[uint]((v).buf) - template getFileHandle(t: untyped): Handle = - cast[Handle]((t).dataBuf.len) - - template slideOffset(v, o: untyped) = - let s = cast[uint]((v).dataBuf.buf) - cast[uint]((o)) - (v).dataBuf.buf = cast[cstring](s) - (v).offset = (v).offset + cast[uint]((o)) + template getFileHandle(v: untyped): Handle = + cast[Handle]((v).buflen) template slideBuffer(t, o: untyped) = - (t).dataBuf.buf = cast[cstring](cast[uint]((t).dataBuf.buf) + uint(o)) - (t).dataBuf.len -= int32(o) + (t).wwsabuf.buf = cast[cstring](cast[uint]((t).wwsabuf.buf) + uint(o)) + (t).wwsabuf.len -= int32(o) - template setWSABuffer(t: untyped) = - (t).wsabuf.buf = cast[cstring]( + template setReaderWSABuffer(t: untyped) = + (t).rwsabuf.buf = cast[cstring]( cast[uint](addr t.buffer[0]) + uint((t).roffset)) - (t).wsabuf.len = int32(len((t).buffer) - (t).roffset) + (t).rwsabuf.len = int32(len((t).buffer) - (t).roffset) - # template initTransmitStreamVector(v, h, o, n, t: untyped) = - # (v).kind = DataFile - # (v).dataBuf.buf = cast[cstring]((n)) - # (v).dataBuf.len = cast[int32]((h)) - # (v).offset = cast[uint]((o)) - # (v).writer = (t) + template setWriterWSABuffer(t, v: untyped) = + (t).wwsabuf.buf = cast[cstring](v.buf) + (t).wwsabuf.len = cast[int32](v.buflen) proc writeStreamLoop(udata: pointer) {.gcsafe.} = var bytesCount: int32 if isNil(udata): return - var ovl = cast[PCustomOverlapped](udata) + var ovl = cast[PtrCustomOverlapped](udata) var transp = cast[WindowsStreamTransport](ovl.data.udata) while len(transp.queue) > 0: @@ -202,14 +189,14 @@ when defined(windows): else: if transp.kind == TransportKind.Socket: if vector.kind == VectorKind.DataBuffer: - if bytesCount < vector.dataBuf.len: - vector.slideBuffer(bytesCount) + if bytesCount < transp.wwsabuf.len: + vector.shiftVectorBuffer(bytesCount) transp.queue.addFirst(vector) else: vector.writer.complete() else: if uint(bytesCount) < getFileSize(vector): - vector.slideOffset(bytesCount) + vector.shiftVectorFile(bytesCount) transp.queue.addFirst(vector) else: vector.writer.complete() @@ -224,19 +211,21 @@ when defined(windows): var vector = transp.queue.popFirst() if vector.kind == VectorKind.DataBuffer: transp.wovl.zeroOvelappedOffset() - let ret = WSASend(sock, vector.dataBuf, 1, + transp.setWriterWSABuffer(vector) + let ret = WSASend(sock, addr transp.wwsabuf, 1, addr bytesCount, DWORD(0), cast[POVERLAPPED](addr transp.wovl), nil) if ret != 0: let err = osLastError() if int(err) == ERROR_OPERATION_ABORTED: + transp.state.excl(WritePending) transp.state.incl(WritePaused) elif int(err) == ERROR_IO_PENDING: transp.queue.addFirst(vector) else: transp.state.excl(WritePending) transp.setWriteError(err) - transp.finishWriter() + vector.writer.complete() else: transp.queue.addFirst(vector) else: @@ -256,13 +245,14 @@ when defined(windows): if ret == 0: let err = osLastError() if int(err) == ERROR_OPERATION_ABORTED: + transp.state.excl(WritePending) transp.state.incl(WritePaused) elif int(err) == ERROR_IO_PENDING: transp.queue.addFirst(vector) else: transp.state.excl(WritePending) transp.setWriteError(err) - transp.finishWriter() + vector.writer.complete() else: transp.queue.addFirst(vector) break @@ -273,18 +263,19 @@ when defined(windows): proc readStreamLoop(udata: pointer) {.gcsafe.} = if isNil(udata): return - var ovl = cast[PCustomOverlapped](udata) + var ovl = cast[PtrCustomOverlapped](udata) var transp = cast[WindowsStreamTransport](ovl.data.udata) while true: if ReadPending in transp.state: ## Continuation + transp.state.excl(ReadPending) if ReadClosed in transp.state: break - transp.state.excl(ReadPending) let err = transp.rovl.data.errCode if err == OSErrorCode(-1): let bytesCount = transp.rovl.data.bytesCount + transp.transferred += bytesCount if bytesCount == 0: transp.state.incl(ReadEof) transp.state.incl(ReadPaused) @@ -301,22 +292,27 @@ when defined(windows): transp.setReadError(err) if not isNil(transp.reader): transp.finishReader() + if ReadPaused in transp.state: + # Transport buffer is full, so we will not continue on reading. + break else: ## Initiation - if (ReadEof notin transp.state) and (ReadClosed notin transp.state): + if transp.state * {ReadEof, ReadClosed, ReadError} == {}: var flags = DWORD(0) var bytesCount: int32 = 0 transp.state.excl(ReadPaused) transp.state.incl(ReadPending) if transp.kind == TransportKind.Socket: let sock = SocketHandle(transp.rovl.data.fd) - transp.setWSABuffer() - let ret = WSARecv(sock, addr transp.wsabuf, 1, + transp.roffset = transp.offset + transp.setReaderWSABuffer() + let ret = WSARecv(sock, addr transp.rwsabuf, 1, addr bytesCount, addr flags, cast[POVERLAPPED](addr transp.rovl), nil) if ret != 0: let err = osLastError() if int(err) == ERROR_OPERATION_ABORTED: + transp.state.excl(ReadPending) transp.state.incl(ReadPaused) elif int32(err) != ERROR_IO_PENDING: transp.setReadError(err) @@ -326,17 +322,18 @@ when defined(windows): break proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = - var t = WindowsStreamTransport(kind: TransportKind.Socket) - t.fd = sock - t.rovl.data = CompletionData(fd: sock, cb: readStreamLoop, - udata: cast[pointer](t)) - t.wovl.data = CompletionData(fd: sock, cb: writeStreamLoop, - udata: cast[pointer](t)) - t.buffer = newSeq[byte](bufsize) - t.state = {ReadPaused, WritePaused} - t.queue = initDeque[StreamVector]() - t.future = newFuture[void]("stream.socket.transport") - result = cast[StreamTransport](t) + var transp = WindowsStreamTransport(kind: TransportKind.Socket) + transp.fd = sock + transp.rovl.data = CompletionData(fd: sock, cb: readStreamLoop, + udata: cast[pointer](transp)) + transp.wovl.data = CompletionData(fd: sock, cb: writeStreamLoop, + udata: cast[pointer](transp)) + transp.buffer = newSeq[byte](bufsize) + transp.state = {ReadPaused, WritePaused} + transp.queue = initDeque[StreamVector]() + transp.future = newFuture[void]("stream.socket.transport") + GC_ref(transp) + result = cast[StreamTransport](transp) proc bindToDomain(handle: AsyncFD, domain: Domain): bool = result = true @@ -374,7 +371,7 @@ when defined(windows): result.fail(newException(OSError, osErrorMsg(osLastError()))) proc continuation(udata: pointer) = - var ovl = cast[PCustomOverlapped](udata) + var ovl = cast[PtrCustomOverlapped](udata) if not retFuture.finished: if ovl.data.errCode == OSErrorCode(-1): if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), @@ -417,7 +414,7 @@ when defined(windows): let dwRemoteAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) proc continuation(udata: pointer) = - var ovl = cast[PCustomOverlapped](udata) + var ovl = cast[PtrCustomOverlapped](udata) if not retFuture.finished: if server.server.status in {Stopped, Paused}: sock.closeAsyncSocket() @@ -482,7 +479,7 @@ when defined(windows): if not acceptFut.failed: var sock = acceptFut.read() if sock != asyncInvalidSocket: - spawn server.function( + discard server.function( newStreamSocketTransport(sock, server.bufferSize), server.udata) @@ -508,10 +505,6 @@ else: template getVectorLength(v: untyped): int = cast[int]((v).buflen - int((v).boffset)) - template shiftVectorBuffer(t, o: untyped) = - (t).buf = cast[pointer](cast[uint]((t).buf) + uint(o)) - (t).buflen -= int(o) - template initBufferStreamVector(v, p, n, t: untyped) = (v).kind = DataBuffer (v).buf = cast[pointer]((p)) @@ -524,7 +517,6 @@ else: let fd = SocketHandle(cdata.fd) if not isNil(transp): if len(transp.queue) > 0: - echo "len(transp.queue) = ", len(transp.queue) var vector = transp.queue.popFirst() while true: if transp.kind == TransportKind.Socket: @@ -543,9 +535,24 @@ else: else: transp.setWriteError(err) vector.writer.complete() - break else: - discard + let res = sendfile(int(fd), cast[int](vector.buflen), + int(vector.offset), + cast[int](vector.buf)) + if res >= 0: + if cast[int](vector.buf) - res == 0: + vector.writer.complete() + else: + vector.shiftVectorFile(res) + transp.queue.addFirst(vector) + else: + let err = osLastError() + if int(err) == EINTR: + continue + else: + transp.setWriteError(err) + vector.writer.complete() + break else: transp.state.incl(WritePaused) transp.fd.removeWriter() @@ -583,13 +590,14 @@ else: break proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = - var t = UnixStreamTransport(kind: TransportKind.Socket) - t.fd = sock - t.buffer = newSeq[byte](bufsize) - t.state = {ReadPaused, WritePaused} - t.queue = initDeque[StreamVector]() - t.future = newFuture[void]("socket.stream.transport") - result = cast[StreamTransport](t) + var transp = UnixStreamTransport(kind: TransportKind.Socket) + transp.fd = sock + transp.buffer = newSeq[byte](bufsize) + transp.state = {ReadPaused, WritePaused} + transp.queue = initDeque[StreamVector]() + transp.future = newFuture[void]("socket.stream.transport") + GC_ref(transp) + result = cast[StreamTransport](transp) proc connect*(address: TransportAddress, bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = @@ -641,7 +649,6 @@ else: var saddr: Sockaddr_storage slen: SockLen - var server = cast[StreamServer](cast[ptr CompletionData](udata).udata) while true: let res = posix.accept(SocketHandle(server.sock), @@ -649,7 +656,7 @@ else: if int(res) > 0: let sock = wrapAsyncSocket(res) if sock != asyncInvalidSocket: - spawn server.function( + discard server.function( newStreamSocketTransport(sock, server.bufferSize), server.udata) break @@ -692,19 +699,28 @@ else: addWriter(transp.fd, writeStreamLoop, cast[pointer](transp)) proc start*(server: SocketServer) = + ## Starts ``server``. server.action = Start server.actEvent.fire() proc stop*(server: SocketServer) = + ## Stops ``server`` server.action = Stop server.actEvent.fire() proc pause*(server: SocketServer) = + ## Pause ``server`` server.action = Pause server.actEvent.fire() proc join*(server: SocketServer) {.async.} = - await server.loopFuture + ## Waits until ``server`` is not stopped. + if not server.loopFuture.finished: + await server.loopFuture + +proc close*(server: SocketServer) = + ## Release ``server`` resources. + GC_unref(server) proc createStreamServer*(host: TransportAddress, flags: set[ServerFlags], @@ -729,7 +745,6 @@ proc createStreamServer*(host: TransportAddress, register(sock) serverSocket = sock - ## TODO: Set socket options here if ServerFlags.ReuseAddr in flags: if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): let err = osLastError() @@ -744,7 +759,7 @@ proc createStreamServer*(host: TransportAddress, if sock == asyncInvalidSocket: closeAsyncSocket(serverSocket) raiseOsError(err) - + if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: @@ -759,19 +774,15 @@ proc createStreamServer*(host: TransportAddress, result.actEvent = newAsyncEvent() result.udata = udata result.local = host + GC_ref(result) result.loopFuture = serverLoop(result) proc write*(transp: StreamTransport, pbytes: pointer, nbytes: int): Future[int] {.async.} = checkClosed(transp) var waitFuture = newFuture[void]("transport.write") - var vector = StreamVector(kind: DataBuffer, writer: waitFuture) - when defined(windows): - var wsabuf = TWSABuf(buf: cast[cstring](pbytes), len: cast[int32](nbytes)) - vector.dataBuf = addr wsabuf - else: - vector.buf = pbytes - vector.buflen = nbytes + var vector = StreamVector(kind: DataBuffer, writer: waitFuture, + buf: pbytes, buflen: nbytes) transp.queue.addLast(vector) if WritePaused in transp.state: transp.resumeWrite() @@ -780,25 +791,25 @@ proc write*(transp: StreamTransport, pbytes: pointer, raise transp.getError() result = nbytes -# proc writeFile*(transp: StreamTransport, handle: int, -# offset: uint = 0, -# size: uint = 0): Future[void] {.async.} = -# if transp.kind != TransportKind.Socket: -# raise newException(TransportError, "You can transmit files only to sockets") -# checkClosed(transp) -# var waitFuture = newFuture[void]("transport.writeFile") -# var vector: StreamVector -# vector.initTransmitStreamVector(handle, offset, size, waitFuture) -# transp.queue.addLast(vector) - -# if WritePaused in transp.state: -# transp.resumeWrite() -# await vector.writer -# if WriteError in transp.state: -# raise transp.getError() +proc writeFile*(transp: StreamTransport, handle: int, + offset: uint = 0, + size: int = 0): Future[void] {.async.} = + if transp.kind != TransportKind.Socket: + raise newException(TransportError, "You can transmit files only to sockets") + checkClosed(transp) + var waitFuture = newFuture[void]("transport.writeFile") + var vector = StreamVector(kind: DataFile, writer: waitFuture, + buf: cast[pointer](size), offset: offset, + buflen: handle) + transp.queue.addLast(vector) + if WritePaused in transp.state: + transp.resumeWrite() + await vector.writer + if WriteError in transp.state: + raise transp.getError() proc readExactly*(transp: StreamTransport, pbytes: pointer, - nbytes: int): Future[int] {.async.} = + nbytes: int) {.async.} = ## Read exactly ``nbytes`` bytes from transport ``transp``. checkClosed(transp) checkPending(transp) @@ -814,17 +825,19 @@ proc readExactly*(transp: StreamTransport, pbytes: pointer, copyMem(cast[pointer](cast[uint](pbytes) + uint(index)), addr(transp.buffer[0]), nbytes - index) transp.shiftBuffer(nbytes - index) - result = nbytes break else: - copyMem(cast[pointer](cast[uint](pbytes) + uint(index)), - addr(transp.buffer[0]), transp.offset) - index += transp.offset - transp.reader = newFuture[void]("transport.readExactly") + if transp.offset != 0: + copyMem(cast[pointer](cast[uint](pbytes) + uint(index)), + addr(transp.buffer[0]), transp.offset) + index += transp.offset + + transp.reader = newFuture[void]("stream.transport.readExactly") transp.offset = 0 if ReadPaused in transp.state: transp.resumeRead() await transp.reader + # we are no longer need data transp.reader = nil @@ -840,7 +853,7 @@ proc readOnce*(transp: StreamTransport, pbytes: pointer, if (ReadEof in transp.state) or (ReadClosed in transp.state): result = 0 break - transp.reader = newFuture[void]("transport.readOnce") + transp.reader = newFuture[void]("stream.transport.readOnce") if ReadPaused in transp.state: transp.resumeRead() await transp.reader @@ -898,7 +911,7 @@ proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int, break else: if (transp.offset - index) == 0: - transp.reader = newFuture[void]("transport.readUntil") + transp.reader = newFuture[void]("stream.transport.readUntil") if ReadPaused in transp.state: transp.resumeRead() await transp.reader @@ -945,7 +958,7 @@ proc readLine*(transp: StreamTransport, limit = 0, break else: if (transp.offset - index) == 0: - transp.reader = newFuture[void]("transport.readLine") + transp.reader = newFuture[void]("stream.transport.readLine") if ReadPaused in transp.state: transp.resumeRead() await transp.reader @@ -990,7 +1003,7 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = transp.offset) transp.offset = 0 - transp.reader = newFuture[void]("transport.read") + transp.reader = newFuture[void]("stream.transport.read") if ReadPaused in transp.state: transp.resumeRead() await transp.reader @@ -1005,7 +1018,8 @@ proc atEof*(transp: StreamTransport): bool {.inline.} = proc join*(transp: StreamTransport) {.async.} = ## Wait until ``transp`` will not be closed. - await transp.future + if not transp.future.finished: + await transp.future proc close*(transp: StreamTransport) = ## Closes and frees resources of transport ``transp``. @@ -1016,3 +1030,4 @@ proc close*(transp: StreamTransport) = transp.state.incl(WriteClosed) transp.state.incl(ReadClosed) transp.future.complete() + GC_unref(transp) diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim index 8165d691..dd47cce6 100644 --- a/tests/testdatagram.nim +++ b/tests/testdatagram.nim @@ -1,4 +1,4 @@ -# Asyncdispatch2 +# Asyncdispatch2 Test Suite # (c) Copyright 2018 # Status Research & Development GmbH # diff --git a/tests/teststream.nim b/tests/teststream.nim index ddb7b0b5..9fa03999 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -1,82 +1,295 @@ -import strutils, net, unittest +# Asyncdispatch2 Test Suite +# (c) Copyright 2018 +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +import strutils, net, unittest, os import ../asyncdispatch2 +when defined(windows): + import winlean +else: + import posix + const - ClientsCount = 1 - MessagesCount = 100000 + ClientsCount = 50 + MessagesCount = 50 + MessageSize = 20 + FilesCount = 50 + FilesTestName = "teststream.nim" proc serveClient1(transp: StreamTransport, udata: pointer) {.async.} = - echo "SERVER STARTING (0x" & toHex[uint](cast[uint](transp)) & ")" while not transp.atEof(): var data = await transp.readLine() - echo "SERVER READ [" & data & "]" - if data.startsWith("REQUEST"): - var numstr = data[7..^1] - var num = parseInt(numstr) - var ans = "ANSWER" & $num & "\r\n" - var res = await transp.write(cast[pointer](addr ans[0]), len(ans)) - # doAssert(res == len(ans)) - echo "SERVER EXITING (0x" & toHex[uint](cast[uint](transp)) & ")" - -proc swarmWorker(address: TransportAddress) {.async.} = - echo "CONNECTING TO " & $address - var transp = await connect(address) - echo "CONNECTED" - for i in 0..= 0) + var rbuffer = newSeq[byte](sizeNum) + await transp.readExactly(addr rbuffer[0], sizeNum) + var lbuffer = readFile(pathname) + doAssert(len(lbuffer) == sizeNum) + doAssert(equalMem(addr rbuffer[0], addr lbuffer[0], sizeNum)) + var answer = "OK\r\n" + var res = await transp.write(cast[pointer](addr answer[0]), len(answer)) + doAssert(res == len(answer)) + +proc swarmWorker1(address: TransportAddress): Future[int] {.async.} = + var transp = await connect(address) + for i in 0.. 0) + name = name & "\r\n" + ssize = $size & "\r\n" + var res = await transp.write(cast[pointer](addr name[0]), len(name)) + doAssert(res == len(name)) + res = await transp.write(cast[pointer](addr ssize[0]), len(ssize)) + doAssert(res == len(ssize)) + await transp.writeFile(handle, 0'u, size) + var ans = await transp.readLine() + doAssert(ans == "OK") + result = 1 + transp.close() + +proc waitAll[T](futs: seq[Future[T]]): Future[void] = + var counter = len(futs) + var retFuture = newFuture[void]("waitAll") + proc cb(udata: pointer) = + dec(counter) + if counter == 0: + retFuture.complete() + for fut in futs: + fut.addCallback(cb) return retFuture -when isMainModule: - var ta: TransportAddress - ta.address = parseIpAddress("127.0.0.1") - ta.port = Port(31344) +proc swarmManager1(address: TransportAddress): Future[int] {.async.} = + var retFuture = newFuture[void]("swarm.manager.readLine") + var workers = newSeq[Future[int]](ClientsCount) + var count = ClientsCount + for i in 0..