diff --git a/LICENSE-APACHEv2 b/LICENSE-APACHEv2 new file mode 100644 index 00000000..782d1bff --- /dev/null +++ b/LICENSE-APACHEv2 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018 Status Research & Development GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/LICENSE b/LICENSE-MIT similarity index 93% rename from LICENSE rename to LICENSE-MIT index b4c21de0..8766e65d 100644 --- a/LICENSE +++ b/LICENSE-MIT @@ -1,6 +1,6 @@ -MIT License +The MIT License (MIT) -Copyright (c) 2018 Status +Copyright (c) 2018 Status Research & Development GmbH Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/asyncdispatch2.nim b/asyncdispatch2.nim new file mode 100644 index 00000000..c4e34b47 --- /dev/null +++ b/asyncdispatch2.nim @@ -0,0 +1,10 @@ +# Asyncdispatch2 +# (c) Copyright 2018 +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +import asyncdispatch2/[asyncloop, asyncfutures2, asyncsync, handles, transport] +export asyncloop, asyncfutures2, asyncsync, handles, transport diff --git a/asyncdispatch2.nimble b/asyncdispatch2.nimble new file mode 100644 index 00000000..17eafecf --- /dev/null +++ b/asyncdispatch2.nimble @@ -0,0 +1,13 @@ +packageName = "asyncdispatch2" +version = "0.1.0" +author = "Status Research & Development GmbH" +description = "Asyncdispatch2" +license = "Apache License 2.0 or MIT" +skipDirs = @["tests", "Nim", "nim"] + +### Dependencies + +requires "nim > 0.18.0" + +task test, "Run all tests": + exec "nim c -r tests/test1" diff --git a/asyncdispatch2/asyncfutures2.nim b/asyncdispatch2/asyncfutures2.nim new file mode 100644 index 00000000..66b2a58e --- /dev/null +++ b/asyncdispatch2/asyncfutures2.nim @@ -0,0 +1,455 @@ +# +# Asyncdispatch2 +# +# (c) Copyright 2015 Dominik Picheta +# (c) Copyright 2018 Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +import os, tables, strutils, times, heapqueue, options, deques, cstrutils + +type + CallbackFunc* = proc (arg: pointer = nil) {.gcsafe.} + 
CallSoonProc* = proc (c: CallbackFunc, u: pointer = nil) {.gcsafe.} + + AsyncCallback* = object + function*: CallbackFunc + udata*: pointer + + FutureBase* = ref object of RootObj ## Untyped future. + callbacks: Deque[AsyncCallback] + + finished: bool + error*: ref Exception ## Stored exception + errorStackTrace*: string + when not defined(release): + stackTrace: string ## For debugging purposes only. + id: int + fromProc: string + + Future*[T] = ref object of FutureBase ## Typed future. + value: T ## Stored value + + FutureVar*[T] = distinct Future[T] + + FutureError* = object of Exception + cause*: FutureBase + +{.deprecated: [PFutureBase: FutureBase, PFuture: Future].} + +when not defined(release): + var currentID = 0 + +var callSoonHolder {.threadvar.}: CallSoonProc + +proc getCallSoonProc*(): CallSoonProc {.gcsafe.} = + ## Get current implementation of ``callSoon``. + return callSoonHolder + +proc setCallSoonProc*(p: CallSoonProc) = + ## Change current implementation of ``callSoon``. + callSoonHolder = p + +proc callSoon*(c: CallbackFunc, u: pointer = nil) = + ## Call ``cbproc`` "soon". + callSoonHolder(c, u) + +template setupFutureBase(fromProc: string) = + new(result) + result.finished = false + when not defined(release): + result.stackTrace = getStackTrace() + result.id = currentID + result.fromProc = fromProc + currentID.inc() + +proc newFuture*[T](fromProc: string = "unspecified"): Future[T] = + ## Creates a new future. + ## + ## Specifying ``fromProc``, which is a string specifying the name of the proc + ## that this future belongs to, is a good habit as it helps with debugging. + setupFutureBase(fromProc) + +proc newFutureVar*[T](fromProc = "unspecified"): FutureVar[T] = + ## Create a new ``FutureVar``. This Future type is ideally suited for + ## situations where you want to avoid unnecessary allocations of Futures. 
+ ## + ## Specifying ``fromProc``, which is a string specifying the name of the proc + ## that this future belongs to, is a good habit as it helps with debugging. + result = FutureVar[T](newFuture[T](fromProc)) + +proc clean*[T](future: FutureVar[T]) = + ## Resets the ``finished`` status of ``future``. + Future[T](future).finished = false + Future[T](future).error = nil + +proc checkFinished[T](future: Future[T]) = + ## Checks whether `future` is finished. If it is then raises a + ## ``FutureError``. + when not defined(release): + if future.finished: + var msg = "" + msg.add("An attempt was made to complete a Future more than once. ") + msg.add("Details:") + msg.add("\n Future ID: " & $future.id) + msg.add("\n Created in proc: " & future.fromProc) + msg.add("\n Stack trace to moment of creation:") + msg.add("\n" & indent(future.stackTrace.strip(), 4)) + when T is string: + msg.add("\n Contents (string): ") + msg.add("\n" & indent(future.value.repr, 4)) + msg.add("\n Stack trace to moment of secondary completion:") + msg.add("\n" & indent(getStackTrace().strip(), 4)) + var err = newException(FutureError, msg) + err.cause = future + raise err + +proc call(callbacks: var Deque[AsyncCallback]) = + var count = len(callbacks) + if count > 0: + while count > 0: + var item = callbacks.popFirst() + callSoon(item.function, item.udata) + dec(count) + +proc add(callbacks: var Deque[AsyncCallback], item: AsyncCallback) = + if len(callbacks) == 0: + callbacks = initDeque[AsyncCallback]() + callbacks.addLast(item) + +proc remove(callbacks: var Deque[AsyncCallback], item: AsyncCallback) = + if len(callbacks) > 0: + var count = len(callbacks) + while count > 0: + var p = callbacks.popFirst() + if p.function != item.function or p.udata != item.udata: + callbacks.addLast(p) + dec(count) + +proc complete*[T](future: Future[T], val: T) = + ## Completes ``future`` with value ``val``. 
+ #assert(not future.finished, "Future already finished, cannot finish twice.") + checkFinished(future) + assert(future.error == nil) + future.value = val + future.finished = true + future.callbacks.call() + +proc complete*(future: Future[void]) = + ## Completes a void ``future``. + #assert(not future.finished, "Future already finished, cannot finish twice.") + checkFinished(future) + assert(future.error == nil) + future.finished = true + future.callbacks.call() + +proc complete*[T](future: FutureVar[T]) = + ## Completes a ``FutureVar``. + template fut: untyped = Future[T](future) + checkFinished(fut) + assert(fut.error == nil) + fut.finished = true + fut.callbacks.call() + +proc complete*[T](future: FutureVar[T], val: T) = + ## Completes a ``FutureVar`` with value ``val``. + ## + ## Any previously stored value will be overwritten. + template fut: untyped = Future[T](future) + checkFinished(fut) + assert(fut.error.isNil()) + fut.finished = true + fut.value = val + fut.callbacks.call() + +proc fail*[T](future: Future[T], error: ref Exception) = + ## Completes ``future`` with ``error``. + #assert(not future.finished, "Future already finished, cannot finish twice.") + checkFinished(future) + future.finished = true + future.error = error + future.errorStackTrace = + if getStackTrace(error) == "": getStackTrace() else: getStackTrace(error) + future.callbacks.call() + +proc clearCallbacks(future: FutureBase) = + if len(future.callbacks) > 0: + var count = len(future.callbacks) + while count > 0: + discard future.callbacks.popFirst() + dec(count) + +proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) = + ## Adds the callbacks proc to be called when the future completes. + ## + ## If future has already completed then ``cb`` will be called immediately. 
+ assert cb != nil + if future.finished: + callSoon(cb, udata) + else: + let acb = AsyncCallback(function: cb, udata: udata) + future.callbacks.add acb + +proc addCallback*[T](future: Future[T], cb: CallbackFunc) = + ## Adds the callbacks proc to be called when the future completes. + ## + ## If future has already completed then ``cb`` will be called immediately. + future.addCallback(cb, cast[pointer](unsafeAddr future)) + +proc removeCallback*(future: FutureBase, cb: CallbackFunc, + udata: pointer = nil) = + assert cb != nil + let acb = AsyncCallback(function: cb, udata: udata) + future.callbacks.remove acb + +proc removeCallback*[T](future: Future[T], cb: CallbackFunc) = + future.removeCallback(cb, cast[pointer](unsafeAddr future)) + +proc `callback=`*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) = + ## Clears the list of callbacks and sets the callback proc to be called when + ## the future completes. + ## + ## If future has already completed then ``cb`` will be called immediately. + ## + ## It's recommended to use ``addCallback`` or ``then`` instead. + future.clearCallbacks + future.addCallback(cb, udata) + +proc `callback=`*[T](future: Future[T], cb: CallbackFunc) = + ## Sets the callback proc to be called when the future completes. + ## + ## If future has already completed then ``cb`` will be called immediately. + `callback=`(future, cb, cast[pointer](future)) + +proc getHint(entry: StackTraceEntry): string = + ## We try to provide some hints about stack trace entries that the user + ## may not be familiar with, in particular calls inside the stdlib. 
+ result = "" + if entry.procname == "processPendingCallbacks": + if cmpIgnoreStyle(entry.filename, "asyncdispatch.nim") == 0: + return "Executes pending callbacks" + elif entry.procname == "poll": + if cmpIgnoreStyle(entry.filename, "asyncdispatch.nim") == 0: + return "Processes asynchronous completion events" + + if entry.procname.endsWith("_continue"): + if cmpIgnoreStyle(entry.filename, "asyncmacro.nim") == 0: + return "Resumes an async procedure" + +proc `$`*(entries: seq[StackTraceEntry]): string = + result = "" + # Find longest filename & line number combo for alignment purposes. + var longestLeft = 0 + for entry in entries: + if entry.procName.isNil: continue + + let left = $entry.filename & $entry.line + if left.len > longestLeft: + longestLeft = left.len + + var indent = 2 + # Format the entries. + for entry in entries: + if entry.procName.isNil: + if entry.line == -10: + result.add(spaces(indent) & "#[\n") + indent.inc(2) + else: + indent.dec(2) + result.add(spaces(indent)& "]#\n") + continue + + let left = "$#($#)" % [$entry.filename, $entry.line] + result.add((spaces(indent) & "$#$# $#\n") % [ + left, + spaces(longestLeft - left.len + 2), + $entry.procName + ]) + let hint = getHint(entry) + if hint.len > 0: + result.add(spaces(indent+2) & "## " & hint & "\n") + +proc injectStacktrace[T](future: Future[T]) = + when not defined(release): + const header = "\nAsync traceback:\n" + + var exceptionMsg = future.error.msg + if header in exceptionMsg: + # This is messy: extract the original exception message from the msg + # containing the async traceback. + let start = exceptionMsg.find(header) + exceptionMsg = exceptionMsg[0..`_. +## +## Limitations/Bugs +## ---------------- +## +## * The effect system (``raises: []``) does not work with async procedures. +## * Can't await in a ``except`` body +## * Forward declarations for async procs are broken, +## link includes workaround: https://github.com/nim-lang/Nim/issues/3182. 
+ +# TODO: Check if yielded future is nil and throw a more meaningful exception + +type + TimerCallback* = object + finishAt*: uint64 + function*: AsyncCallback + + PDispatcherBase = ref object of RootRef + timers*: HeapQueue[TimerCallback] + callbacks*: Deque[AsyncCallback] + +proc `<`(a, b: TimerCallback): bool = + result = a.finishAt < b.finishAt + +proc callSoon(cbproc: CallbackFunc, data: pointer = nil) {.gcsafe.} + +proc initCallSoonProc = + if asyncfutures2.getCallSoonProc().isNil: + asyncfutures2.setCallSoonProc(callSoon) + +when defined(windows) or defined(nimdoc): + import winlean, sets, hashes + type + WSAPROC_TRANSMITFILE = proc(hSocket: SocketHandle, hFile: Handle, + nNumberOfBytesToWrite: DWORD, + nNumberOfBytesPerSend: DWORD, + lpOverlapped: POVERLAPPED, + lpTransmitBuffers: pointer, + dwReserved: DWORD): cint {. + stdcall.} + + CompletionKey = ULONG_PTR + + CompletionData* = object + fd*: AsyncFD + cb*: CallbackFunc + errCode*: OSErrorCode + bytesCount*: int32 + udata*: pointer + cell*: ForeignCell # we need this `cell` to protect our `cb` environment, + # when using RegisterWaitForSingleObject, because + # waiting is done in different thread. + + PDispatcher* = ref object of PDispatcherBase + ioPort: Handle + handles: HashSet[AsyncFD] + connectEx*: WSAPROC_CONNECTEX + acceptEx*: WSAPROC_ACCEPTEX + getAcceptExSockAddrs*: WSAPROC_GETACCEPTEXSOCKADDRS + transmitFile*: WSAPROC_TRANSMITFILE + + CustomOverlapped* = object of OVERLAPPED + data*: CompletionData + + PCustomOverlapped* = ptr CustomOverlapped + + RefCustomOverlapped* = ref CustomOverlapped + + AsyncFD* = distinct int + + # PostCallbackData = object + # ioPort: Handle + # handleFd: AsyncFD + # waitFd: Handle + # ovl: PCustomOverlapped + # PostCallbackDataPtr = ptr PostCallbackData + + proc hash(x: AsyncFD): Hash {.borrow.} + proc `==`*(x: AsyncFD, y: AsyncFD): bool {.borrow.} + + proc newDispatcher*(): PDispatcher = + ## Creates a new Dispatcher instance. 
+ new result + result.ioPort = createIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 1) + result.handles = initSet[AsyncFD]() + result.timers.newHeapQueue() + result.callbacks = initDeque[AsyncCallback](64) + + var gDisp{.threadvar.}: PDispatcher ## Global dispatcher + + proc setGlobalDispatcher*(disp: PDispatcher) = + if not gDisp.isNil: + assert gDisp.callbacks.len == 0 + gDisp = disp + initCallSoonProc() + + proc getGlobalDispatcher*(): PDispatcher = + if gDisp.isNil: + setGlobalDispatcher(newDispatcher()) + result = gDisp + + proc getIoHandler*(disp: PDispatcher): Handle = + ## Returns the underlying IO Completion Port handle (Windows) or selector + ## (Unix) for the specified dispatcher. + return disp.ioPort + + proc register*(fd: AsyncFD) = + ## Registers ``fd`` with the dispatcher. + let p = getGlobalDispatcher() + if createIoCompletionPort(fd.Handle, p.ioPort, + cast[CompletionKey](fd), 1) == 0: + raiseOSError(osLastError()) + p.handles.incl(fd) + + proc poll*() = + let loop = getGlobalDispatcher() + var curTime = fastEpochTime() + var curTimeout = DWORD(0) + + # Moving expired timers to `loop.callbacks` and calculate timeout + var count = len(loop.timers) + if count > 0: + var lastFinish = curTime + while count > 0: + lastFinish = loop.timers[0].finishAt + if curTime < lastFinish: + break + loop.callbacks.addLast(loop.timers.pop().function) + dec(count) + if count > 0: + curTimeout = DWORD(lastFinish - curTime) + + if curTimeout == 0: + if len(loop.callbacks) == 0: + curTimeout = INFINITE + + # Processing handles + var lpNumberOfBytesTransferred: Dword + var lpCompletionKey: ULONG_PTR + var customOverlapped: PCustomOverlapped + let res = getQueuedCompletionStatus( + loop.ioPort, addr lpNumberOfBytesTransferred, addr lpCompletionKey, + cast[ptr POVERLAPPED](addr customOverlapped), curTimeout).bool + if res: + customOverlapped.data.bytesCount = lpNumberOfBytesTransferred + customOverlapped.data.errCode = OSErrorCode(-1) + let acb = AsyncCallback(function: 
customOverlapped.data.cb, + udata: cast[pointer](customOverlapped)) + loop.callbacks.addLast(acb) + else: + let errCode = osLastError() + if customOverlapped != nil: + assert customOverlapped.data.fd == lpCompletionKey.AsyncFD + customOverlapped.data.errCode = errCode + let acb = AsyncCallback(function: customOverlapped.data.cb, + udata: cast[pointer](customOverlapped)) + loop.callbacks.addLast(acb) + else: + if int32(errCode) != WAIT_TIMEOUT: + raiseOSError(errCode) + + # Moving expired timers to `loop.callbacks`. + curTime = fastEpochTime() + count = len(loop.timers) + if count > 0: + while count > 0: + if curTime < loop.timers[0].finishAt: + break + loop.callbacks.addLast(loop.timers.pop().function) + dec(count) + + # All callbacks which will be added in process will be processed on next + # poll() call. + count = len(loop.callbacks) + for i in 0.. 0: + var lastFinish = curTime + while count > 0: + lastFinish = loop.timers[0].finishAt + if curTime < lastFinish: + break + loop.callbacks.addLast(loop.timers.pop().function) + dec(count) + if count > 0: + curTimeout = int(lastFinish - curTime) + + if curTimeout == 0: + if len(loop.callbacks) == 0: + curTimeout = -1 + + count = loop.selector.selectInto(curTimeout, loop.keys) + for i in 0.. 0: + while count > 0: + if curTime < loop.timers[0].finishAt: + break + loop.callbacks.addLast(loop.timers.pop().function) + dec(count) + + # All callbacks which will be added in process will be processed on next + # poll() call. + count = len(loop.callbacks) + for i in 0.. -> else: raise futSym.error + exceptionChecks.add((newIdentNode("true"), + newNimNode(nnkRaiseStmt).add(errorNode))) + # Read the future if there is no error. 
+ # -> else: futSym.read + let elseNode = newNimNode(nnkElse, fromNode) + elseNode.add newNimNode(nnkStmtList, fromNode) + elseNode[0].add rootReceiver + + let ifBody = newStmtList() + ifBody.add newCall(newIdentNode("setCurrentException"), errorNode) + ifBody.add newIfStmt(exceptionChecks) + ifBody.add newCall(newIdentNode("setCurrentException"), newNilLit()) + + result = newIfStmt( + (newDotExpr(futSym, newIdentNode("failed")), ifBody) + ) + result.add elseNode + +template useVar(result: var NimNode, futureVarNode: NimNode, valueReceiver, + rootReceiver: untyped, fromNode: NimNode) = + ## Params: + ## futureVarNode: The NimNode which is a symbol identifying the Future[T] + ## variable to yield. + ## fromNode: Used for better debug information (to give context). + ## valueReceiver: The node which defines an expression that retrieves the + ## future's value. + ## + ## rootReceiver: ??? TODO + # -> yield future + result.add newNimNode(nnkYieldStmt, fromNode).add(futureVarNode) + # -> future.read + valueReceiver = newDotExpr(futureVarNode, newIdentNode("read")) + result.add generateExceptionCheck(futureVarNode, tryStmt, rootReceiver, + fromNode) + +template createVar(result: var NimNode, futSymName: string, + asyncProc: NimNode, + valueReceiver, rootReceiver: untyped, + fromNode: NimNode) = + result = newNimNode(nnkStmtList, fromNode) + var futSym = genSym(nskVar, "future") + result.add newVarStmt(futSym, asyncProc) # -> var future = y + useVar(result, futSym, valueReceiver, rootReceiver, fromNode) + +proc createFutureVarCompletions(futureVarIdents: seq[NimNode], + fromNode: NimNode): NimNode {.compileTime.} = + result = newNimNode(nnkStmtList, fromNode) + # Add calls to complete each FutureVar parameter. + for ident in futureVarIdents: + # Only complete them if they have not been completed already by the user. + # TODO: Once https://github.com/nim-lang/Nim/issues/5617 is fixed. + # TODO: Add line info to the complete() call! 
+ # In the meantime, this was really useful for debugging :) + #result.add(newCall(newIdentNode("echo"), newStrLitNode(fromNode.lineinfo))) + result.add newIfStmt( + ( + newCall(newIdentNode("not"), + newDotExpr(ident, newIdentNode("finished"))), + newCall(newIdentNode("complete"), ident) + ) + ) + +proc processBody(node, retFutureSym: NimNode, + subTypeIsVoid: bool, futureVarIdents: seq[NimNode], + tryStmt: NimNode): NimNode {.compileTime.} = + #echo(node.treeRepr) + result = node + case node.kind + of nnkReturnStmt: + result = newNimNode(nnkStmtList, node) + + # As I've painfully found out, the order here really DOES matter. + result.add createFutureVarCompletions(futureVarIdents, node) + + if node[0].kind == nnkEmpty: + if not subTypeIsVoid: + result.add newCall(newIdentNode("complete"), retFutureSym, + newIdentNode("result")) + else: + result.add newCall(newIdentNode("complete"), retFutureSym) + else: + let x = node[0].processBody(retFutureSym, subTypeIsVoid, + futureVarIdents, tryStmt) + if x.kind == nnkYieldStmt: result.add x + else: + result.add newCall(newIdentNode("complete"), retFutureSym, x) + + result.add newNimNode(nnkReturnStmt, node).add(newNilLit()) + return # Don't process the children of this return stmt + of nnkCommand, nnkCall: + if node[0].kind == nnkIdent and node[0].eqIdent("await"): + case node[1].kind + of nnkIdent, nnkInfix, nnkDotExpr, nnkCall, nnkCommand: + # await x + # await x or y + # await foo(p, x) + # await foo p, x + var futureValue: NimNode + result.createVar("future" & $node[1][0].toStrLit, node[1], futureValue, + futureValue, node) + else: + error("Invalid node kind in 'await', got: " & $node[1].kind) + elif node.len > 1 and node[1].kind == nnkCommand and + node[1][0].kind == nnkIdent and node[1][0].eqIdent("await"): + # foo await x + var newCommand = node + result.createVar("future" & $node[0].toStrLit, node[1][1], newCommand[1], + newCommand, node) + + of nnkVarSection, nnkLetSection: + case node[0][2].kind + of nnkCommand: + 
if node[0][2][0].kind == nnkIdent and node[0][2][0].eqIdent("await"): + # var x = await y + var newVarSection = node # TODO: Should this use copyNimNode? + result.createVar("future" & node[0][0].strVal, node[0][2][1], + newVarSection[0][2], newVarSection, node) + else: discard + of nnkAsgn: + case node[1].kind + of nnkCommand: + if node[1][0].eqIdent("await"): + # x = await y + var newAsgn = node + result.createVar("future" & $node[0].toStrLit, node[1][1], newAsgn[1], newAsgn, node) + else: discard + of nnkDiscardStmt: + # discard await x + if node[0].kind == nnkCommand and node[0][0].kind == nnkIdent and + node[0][0].eqIdent("await"): + var newDiscard = node + result.createVar("futureDiscard_" & $toStrLit(node[0][1]), node[0][1], + newDiscard[0], newDiscard, node) + of nnkTryStmt: + # try: await x; except: ... + result = newNimNode(nnkStmtList, node) + template wrapInTry(n, tryBody: untyped) = + var temp = n + n[0] = tryBody + tryBody = temp + + # Transform ``except`` body. + # TODO: Could we perform some ``await`` transformation here to get it + # working in ``except``? + tryBody[1] = processBody(n[1], retFutureSym, subTypeIsVoid, + futureVarIdents, nil) + + proc processForTry(n: NimNode, i: var int, + res: NimNode): bool {.compileTime.} = + ## Transforms the body of the tryStmt. Does not transform the + ## body in ``except``. + ## Returns true if the tryStmt node was transformed into an ifStmt. + result = false + var skipped = n.skipStmtList() + while i < skipped.len: + var processed = processBody(skipped[i], retFutureSym, + subTypeIsVoid, futureVarIdents, n) + + # Check if we transformed the node into an exception check. + # This suggests skipped[i] contains ``await``. 
+ if processed.kind != skipped[i].kind or processed.len != skipped[i].len: + processed = processed.skipUntilStmtList() + expectKind(processed, nnkStmtList) + expectKind(processed[2][1], nnkElse) + i.inc + + if not processForTry(n, i, processed[2][1][0]): + # We need to wrap the nnkElse nodes back into a tryStmt. + # As they are executed if an exception does not happen + # inside the awaited future. + # The following code will wrap the nodes inside the + # original tryStmt. + wrapInTry(n, processed[2][1][0]) + + res.add processed + result = true + else: + res.add skipped[i] + i.inc + var i = 0 + if not processForTry(node, i, result): + # If the tryStmt hasn't been transformed we can just put the body + # back into it. + wrapInTry(node, result) + return + else: discard + + for i in 0 ..< result.len: + result[i] = processBody(result[i], retFutureSym, subTypeIsVoid, + futureVarIdents, nil) + +proc getName(node: NimNode): string {.compileTime.} = + case node.kind + of nnkPostfix: + return node[1].strVal + of nnkIdent: + return node.strVal + of nnkEmpty: + return "anonymous" + else: + error("Unknown name.") + +proc getFutureVarIdents(params: NimNode): seq[NimNode] {.compileTime.} = + result = @[] + for i in 1 ..< len(params): + expectKind(params[i], nnkIdentDefs) + if params[i][1].kind == nnkBracketExpr and + params[i][1][0].eqIdent("futurevar"): + result.add(params[i][0]) + +proc isInvalidReturnType(typeName: string): bool = + return typeName notin ["Future"] #, "FutureStream"] + +proc verifyReturnType(typeName: string) {.compileTime.} = + if typeName.isInvalidReturnType: + error("Expected return type of 'Future' got '$1'" % + typeName) + +proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} = + ## This macro transforms a single procedure into a closure iterator. + ## The ``async`` macro supports a stmtList holding multiple async procedures. 
+ if prc.kind notin {nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}: + error("Cannot transform this node kind into an async proc." & + " proc/method definition or lambda node expected.") + + let prcName = prc.name.getName + + let returnType = prc.params[0] + var baseType: NimNode + # Verify that the return type is a Future[T] + if returnType.kind == nnkBracketExpr: + let fut = repr(returnType[0]) + verifyReturnType(fut) + baseType = returnType[1] + elif returnType.kind in nnkCallKinds and returnType[0].eqIdent("[]"): + let fut = repr(returnType[1]) + verifyReturnType(fut) + baseType = returnType[2] + elif returnType.kind == nnkEmpty: + baseType = returnType + else: + verifyReturnType(repr(returnType)) + + let subtypeIsVoid = returnType.kind == nnkEmpty or + (baseType.kind == nnkIdent and returnType[1].eqIdent("void")) + + let futureVarIdents = getFutureVarIdents(prc.params) + + var outerProcBody = newNimNode(nnkStmtList, prc.body) + + # -> var retFuture = newFuture[T]() + var retFutureSym = genSym(nskVar, "retFuture") + var subRetType = + if returnType.kind == nnkEmpty: newIdentNode("void") + else: baseType + outerProcBody.add( + newVarStmt(retFutureSym, + newCall( + newNimNode(nnkBracketExpr, prc.body).add( + newIdentNode("newFuture"), + subRetType), + newLit(prcName)))) # Get type from return type of this proc + + # -> iterator nameIter(): FutureBase {.closure.} = + # -> {.push warning[resultshadowed]: off.} + # -> var result: T + # -> {.pop.} + # -> + # -> complete(retFuture, result) + var iteratorNameSym = genSym(nskIterator, $prcName & "Iter") + var procBody = prc.body.processBody(retFutureSym, subtypeIsVoid, + futureVarIdents, nil) + # don't do anything with forward bodies (empty) + if procBody.kind != nnkEmpty: + procBody.add(createFutureVarCompletions(futureVarIdents, nil)) + + if not subtypeIsVoid: + procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"), + newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add( + newIdentNode("warning"), 
newIdentNode("resultshadowed")), + newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.} + + procBody.insert(1, newNimNode(nnkVarSection, prc.body).add( + newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T + + procBody.insert(2, newNimNode(nnkPragma).add( + newIdentNode("pop"))) # -> {.pop.}) + + procBody.add( + newCall(newIdentNode("complete"), + retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result) + else: + # -> complete(retFuture) + procBody.add(newCall(newIdentNode("complete"), retFutureSym)) + + var closureIterator = newProc(iteratorNameSym, [newIdentNode("FutureBase")], + procBody, nnkIteratorDef) + closureIterator.pragma = newNimNode(nnkPragma, lineInfoFrom=prc.body) + closureIterator.addPragma(newIdentNode("closure")) + + # If proc has an explicit gcsafe pragma, we add it to iterator as well. + if prc.pragma.findChild(it.kind in {nnkSym, nnkIdent} and $it == "gcsafe") != nil: + closureIterator.addPragma(newIdentNode("gcsafe")) + outerProcBody.add(closureIterator) + + # -> createCb(retFuture) + # NOTE: The "_continue" suffix is checked for in asyncfutures.nim to produce + # friendlier stack traces: + var cbName = genSym(nskProc, prcName & "_continue") + var procCb = getAst createCb(retFutureSym, iteratorNameSym, + newStrLitNode(prcName), + cbName, + createFutureVarCompletions(futureVarIdents, nil)) + outerProcBody.add procCb + + # -> return retFuture + outerProcBody.add newNimNode(nnkReturnStmt, prc.body[^1]).add(retFutureSym) + + result = prc + + if subtypeIsVoid: + # Add discardable pragma. + if returnType.kind == nnkEmpty: + # Add Future[void] + result.params[0] = parseExpr("Future[void]") + if procBody.kind != nnkEmpty: + result.body = outerProcBody + #echo(treeRepr(result)) + #if prcName == "recvLineInto": + # echo(toStrLit(result)) + +macro async*(prc: untyped): untyped = + ## Macro which processes async procedures into the appropriate + ## iterators and yield statements. 
+  if prc.kind == nnkStmtList:
+    # BUGFIX: create the statement list *once*, before the loop. The
+    # original re-created it on every iteration, discarding every
+    # transformed proc except the last one in the stmt list.
+    result = newStmtList()
+    for oneProc in prc:
+      result.add asyncSingleProc(oneProc)
+  else:
+    result = asyncSingleProc(prc)
+  when defined(nimDumpAsync):
+    echo repr result
+
+
+# Multisync
+proc emptyNoop[T](x: T): T =
+  # The ``await``s are replaced by a call to this for simplicity.
+  when T isnot void:
+    return x
+
+proc stripAwait(node: NimNode): NimNode =
+  ## Strips out all ``await`` commands from a procedure body, replaces them
+  ## with ``emptyNoop`` for simplicity.
+  result = node
+
+  let emptyNoopSym = bindSym("emptyNoop")
+
+  case node.kind
+  of nnkCommand, nnkCall:
+    if node[0].kind == nnkIdent and node[0].eqIdent("await"):
+      node[0] = emptyNoopSym
+    elif node.len > 1 and node[1].kind == nnkCommand and
+         node[1][0].kind == nnkIdent and node[1][0].eqIdent("await"):
+      # foo await x
+      node[1][0] = emptyNoopSym
+  of nnkVarSection, nnkLetSection:
+    case node[0][2].kind
+    of nnkCommand:
+      if node[0][2][0].kind == nnkIdent and node[0][2][0].eqIdent("await"):
+        # var x = await y
+        node[0][2][0] = emptyNoopSym
+    else: discard
+  of nnkAsgn:
+    case node[1].kind
+    of nnkCommand:
+      if node[1][0].eqIdent("await"):
+        # x = await y
+        node[1][0] = emptyNoopSym
+    else: discard
+  of nnkDiscardStmt:
+    # discard await x
+    if node[0].kind == nnkCommand and node[0][0].kind == nnkIdent and
+       node[0][0].eqIdent("await"):
+      node[0][0] = emptyNoopSym
+  else: discard
+
+  for i in 0 ..< result.len:
+    result[i] = stripAwait(result[i])
+
+proc splitParamType(paramType: NimNode, async: bool): NimNode =
+  ## For a union parameter type like ``Socket | AsyncSocket`` returns the
+  ## async half when ``async`` is true, the sync half otherwise. The half
+  ## whose type name contains "async" (case-insensitive) is considered the
+  ## async one. Non-union types are returned unchanged.
+  result = paramType
+  if paramType.kind == nnkInfix and paramType[0].strVal in ["|", "or"]:
+    let firstAsync = "async" in paramType[1].strVal.normalize
+    let secondAsync = "async" in paramType[2].strVal.normalize
+
+    if firstAsync:
+      result = paramType[if async: 1 else: 2]
+    elif secondAsync:
+      result = paramType[if async: 2 else: 1]
+
+proc stripReturnType(returnType: NimNode): NimNode =
+  # Strip out the 'Future' from 'Future[T]'.
+ result = returnType + if returnType.kind == nnkBracketExpr: + let fut = repr(returnType[0]) + verifyReturnType(fut) + result = returnType[1] + +proc splitProc(prc: NimNode): (NimNode, NimNode) = + ## Takes a procedure definition which takes a generic union of arguments, + ## for example: proc (socket: Socket | AsyncSocket). + ## It transforms them so that ``proc (socket: Socket)`` and + ## ``proc (socket: AsyncSocket)`` are returned. + + result[0] = prc.copyNimTree() + # Retrieve the `T` inside `Future[T]`. + let returnType = stripReturnType(result[0][3][0]) + result[0][3][0] = splitParamType(returnType, async=false) + for i in 1 ..< result[0][3].len: + # Sync proc (0) -> FormalParams (3) -> IdentDefs, the parameter (i) -> + # parameter type (1). + result[0][3][i][1] = splitParamType(result[0][3][i][1], async=false) + result[0][6] = stripAwait(result[0][6]) + + result[1] = prc.copyNimTree() + if result[1][3][0].kind == nnkBracketExpr: + result[1][3][0][1] = splitParamType(result[1][3][0][1], async=true) + for i in 1 ..< result[1][3].len: + # Async proc (1) -> FormalParams (3) -> IdentDefs, the parameter (i) -> + # parameter type (1). + result[1][3][i][1] = splitParamType(result[1][3][i][1], async=true) + +macro multisync*(prc: untyped): untyped = + ## Macro which processes async procedures into both asynchronous and + ## synchronous procedures. + ## + ## The generated async procedures use the ``async`` macro, whereas the + ## generated synchronous procedures simply strip off the ``await`` calls. 
+  let (sync, asyncPrc) = splitProc(prc)
+  result = newStmtList()
+  result.add(asyncSingleProc(asyncPrc))
+  result.add(sync)
diff --git a/asyncdispatch2/asyncsync.nim b/asyncdispatch2/asyncsync.nim
new file mode 100644
index 00000000..19d243df
--- /dev/null
+++ b/asyncdispatch2/asyncsync.nim
@@ -0,0 +1,330 @@
+#
+#        Asyncdispatch2 synchronization primitives
+#
+#           (c) Copyright 2018 Eugene Kabanov
+#  (c) Copyright 2018 Status Research & Development GmbH
+#
+#                Licensed under either of
+#  Apache License, version 2.0, (LICENSE-APACHEv2)
+#              MIT license (LICENSE-MIT)
+
+
+## This module implements some core synchronization primitives, which
+## `asyncdispatch` is really lacking.
+import asyncloop, deques
+
+type
+  AsyncLock* = ref object of RootRef
+    ## A primitive lock is a synchronization primitive that is not owned by
+    ## a particular coroutine when locked. A primitive lock is in one of two
+    ## states, ``locked`` or ``unlocked``.
+    ##
+    ## When more than one coroutine is blocked in ``acquire()`` waiting for
+    ## the state to turn to unlocked, only one coroutine proceeds when a
+    ## ``release()`` call resets the state to unlocked; the first coroutine
+    ## which is blocked in ``acquire()`` is being processed.
+    locked: bool                  # `true` while the lock is held
+    waiters: Deque[Future[void]]  # pending ``acquire()`` futures, FIFO
+
+  AsyncEvent* = ref object of RootRef
+    ## A primitive event object.
+    ##
+    ## An event manages a flag that can be set to `true` with the ``fire()``
+    ## procedure and reset to `false` with the ``clear()`` procedure.
+    ## The ``wait()`` coroutine blocks until the flag is `true`.
+    ##
+    ## If more than one coroutine is blocked in ``wait()`` waiting for the
+    ## event state to be signaled, then when the event gets fired all of the
+    ## coroutines continue in the order in which they entered the waiting
+    ## state.
+
+    flag: bool                    # current state of the event flag
+    waiters: Deque[Future[void]]  # pending ``wait()`` futures, FIFO
+
+  AsyncQueue*[T] = ref object of RootRef
+    ## A queue, useful for coordinating producer and consumer coroutines.
+ ## + ## If ``maxsize`` is less than or equal to zero, the queue size is + ## infinite. If it is an integer greater than ``0``, then "await put()" + ## will block when the queue reaches ``maxsize``, until an item is + ## removed by "await get()". + getters: Deque[Future[void]] + putters: Deque[Future[void]] + queue: Deque[T] + maxsize: int + + AsyncQueueEmptyError* = object of Exception + ## ``AsyncQueue`` is empty. + AsyncQueueFullError* = object of Exception + ## ``AsyncQueue`` is full. + AsyncLockError* = object of Exception + ## ``AsyncLock`` is either locked or unlocked. + +proc newAsyncLock*(): AsyncLock = + ## Creates new asynchronous lock ``AsyncLock``. + ## + ## Lock is created in the unlocked state. When the state is unlocked, + ## ``acquire()`` changes the state to locked and returns immediately. + ## When the state is locked, ``acquire()`` blocks until a call to + ## ``release()`` in another coroutine changes it to unlocked. + ## + ## The ``release()`` procedure changes the state to unlocked and returns + ## immediately. + + # Workaround for callSoon() not worked correctly before + # getGlobalDispatcher() call. + discard getGlobalDispatcher() + result = new AsyncLock + result.waiters = initDeque[Future[void]]() + result.locked = false + +proc acquire*(lock: AsyncLock) {.async.} = + ## Acquire a lock ``lock``. + ## + ## This procedure blocks until the lock ``lock`` is unlocked, then sets it + ## to locked and returns. + if not lock.locked: + lock.locked = true + else: + var w = newFuture[void]("asynclock.acquire") + lock.waiters.addLast(w) + yield w + lock.locked = true + +proc own*(lock: AsyncLock) = + ## Acquire a lock ``lock``. + ## + ## This procedure not blocks, if ``lock`` is locked, then ``AsyncLockError`` + ## exception would be raised. 
+ if lock.locked: + raise newException(AsyncLockError, "AsyncLock is already acquired!") + lock.locked = true + +proc locked*(lock: AsyncLock): bool = + ## Return `true` if the lock ``lock`` is acquired, `false` otherwise. + result = lock.locked + +proc release*(lock: AsyncLock) = + ## Release a lock ``lock``. + ## + ## When the ``lock`` is locked, reset it to unlocked, and return. If any + ## other coroutines are blocked waiting for the lock to become unlocked, + ## allow exactly one of them to proceed. + var w: Future[void] + proc wakeup(udata: pointer) {.gcsafe.} = w.complete() + + if lock.locked: + lock.locked = false + while len(lock.waiters) > 0: + w = lock.waiters.popFirst() + if not w.finished: + callSoon(wakeup) + break + else: + raise newException(AsyncLockError, "AsyncLock is not acquired!") + +proc newAsyncEvent*(): AsyncEvent = + ## Creates new asyncronous event ``AsyncEvent``. + ## + ## An event manages a flag that can be set to `true` with the `fire()` + ## procedure and reset to `false` with the `clear()` procedure. + ## The `wait()` procedure blocks until the flag is `true`. The flag is + ## initially `false`. + + # Workaround for callSoon() not worked correctly before + # getGlobalDispatcher() call. + discard getGlobalDispatcher() + result = new AsyncEvent + result.waiters = initDeque[Future[void]]() + result.flag = false + +proc wait*(event: AsyncEvent) {.async.} = + ## Block until the internal flag of ``event`` is `true`. + ## If the internal flag is `true` on entry, return immediately. Otherwise, + ## block until another task calls `fire()` to set the flag to `true`, + ## then return. + if event.flag: + discard + else: + var w = newFuture[void]("asyncevent.wait") + event.waiters.addLast(w) + yield w + +proc fire*(event: AsyncEvent) = + ## Set the internal flag of ``event`` to `true`. All tasks waiting for it + ## to become `true` are awakened. Task that call `wait()` once the flag is + ## `true` will not block at all. 
+ proc wakeupAll(udata: pointer) {.gcsafe.} = + if len(event.waiters) > 0: + var w = event.waiters.popFirst() + if not w.finished: + w.complete() + callSoon(wakeupAll) + + if not event.flag: + event.flag = true + callSoon(wakeupAll) + +proc clear*(event: AsyncEvent) = + ## Reset the internal flag of ``event`` to `false`. Subsequently, tasks + ## calling `wait()` will block until `fire()` is called to set the internal + ## flag to `true` again. + event.flag = false + +proc isSet*(event: AsyncEvent): bool = + ## Return `true` if and only if the internal flag of ``event`` is `true`. + result = event.flag + +proc newAsyncQueue*[T](maxsize: int = 0): AsyncQueue[T] = + ## Creates a new asynchronous queue ``AsyncQueue``. + + # Workaround for callSoon() not worked correctly before + # getGlobalDispatcher() call. + discard getGlobalDispatcher() + result = new AsyncQueue[T] + result.getters = initDeque[Future[void]]() + result.putters = initDeque[Future[void]]() + result.queue = initDeque[T]() + result.maxsize = maxsize + +proc full*[T](aq: AsyncQueue[T]): bool {.inline.} = + ## Return ``true`` if there are ``maxsize`` items in the queue. + ## + ## Note: If the ``aq`` was initialized with ``maxsize = 0`` (default), + ## then ``full()`` is never ``true``. + if aq.maxsize <= 0: + result = false + else: + result = len(aq.queue) >= aq.maxsize + +proc empty*[T](aq: AsyncQueue[T]): bool {.inline.} = + ## Return ``true`` if the queue is empty, ``false`` otherwise. + result = (len(aq.queue) == 0) + +proc putNoWait*[T](aq: AsyncQueue[T], item: T) = + ## Put an item into the queue ``aq`` immediately. 
+  ##
+  ## If queue ``aq`` is full, then ``AsyncQueueFullError`` exception raised
+  var w: Future[void]
+  proc wakeup(udata: pointer) {.gcsafe.} = w.complete()
+
+  if aq.full():
+    raise newException(AsyncQueueFullError, "AsyncQueue is full!")
+  aq.queue.addLast(item)
+
+  # BUGFIX: only one item was added, so wake exactly one pending getter
+  # (mirrors ``release()`` on AsyncLock). The ``wakeup`` closure captures
+  # ``w`` by reference; scheduling it for several popped futures would
+  # complete the same (last-assigned) future more than once and raise.
+  while len(aq.getters) > 0:
+    w = aq.getters.popFirst()
+    if not w.finished:
+      callSoon(wakeup)
+      break
+
+proc getNoWait*[T](aq: AsyncQueue[T]): T =
+  ## Remove and return ``item`` from the queue immediately.
+  ##
+  ## If queue ``aq`` is empty, then ``AsyncQueueEmptyError`` exception raised.
+  var w: Future[void]
+  proc wakeup(udata: pointer) {.gcsafe.} = w.complete()
+
+  if aq.empty():
+    raise newException(AsyncQueueEmptyError, "AsyncQueue is empty!")
+  result = aq.queue.popFirst()
+  # BUGFIX: one slot was freed, wake exactly one pending putter
+  # (see the matching comment in ``putNoWait``).
+  while len(aq.putters) > 0:
+    w = aq.putters.popFirst()
+    if not w.finished:
+      callSoon(wakeup)
+      break
+
+proc put*[T](aq: AsyncQueue[T], item: T) {.async.} =
+  ## Put an ``item`` into the queue ``aq``. If the queue is full, wait until
+  ## a free slot is available before adding item.
+  while aq.full():
+    var putter = newFuture[void]("asyncqueue.putter")
+    aq.putters.addLast(putter)
+    yield putter
+  aq.putNoWait(item)
+
+proc get*[T](aq: AsyncQueue[T]): Future[T] {.async.} =
+  ## Remove and return an item from the queue ``aq``.
+  ##
+  ## If queue is empty, wait until an item is available.
+  while aq.empty():
+    var getter = newFuture[void]("asyncqueue.getter")
+    aq.getters.addLast(getter)
+    yield getter
+  result = aq.getNoWait()
+
+proc len*[T](aq: AsyncQueue[T]): int {.inline.} =
+  ## Return the number of elements in ``aq``.
+  result = len(aq.queue)
+
+proc size*[T](aq: AsyncQueue[T]): int {.inline.} =
+  ## Return the maximum number of elements in ``aq``.
+  # BUGFIX: ``maxsize`` is an ``int`` — calling ``len`` on it is wrong;
+  # just return the configured limit itself (0 means "unbounded").
+  result = aq.maxsize
+
+when isMainModule:
+  # Locks test
+  block:
+    var test = ""
+    var lock = newAsyncLock()
+
+    proc testLock(n: int, lock: AsyncLock) {.async.} =
+      await lock.acquire()
+      test = test & $n
+      lock.release()
+
+    lock.own()
+    asyncCheck testLock(0, lock)
+    asyncCheck testLock(1, lock)
+    asyncCheck testLock(2, lock)
+    asyncCheck testLock(3, lock)
+    asyncCheck testLock(4, lock)
+    asyncCheck testLock(5, lock)
+    asyncCheck testLock(6, lock)
+    asyncCheck testLock(7, lock)
+    asyncCheck testLock(8, lock)
+    asyncCheck testLock(9, lock)
+    lock.release()
+    poll()
+    doAssert(test == "0123456789")
+
+  # Events test
+  block:
+    var test = ""
+    var event = newAsyncEvent()
+
+    proc testEvent(n: int, ev: AsyncEvent) {.async.} =
+      await ev.wait()
+      test = test & $n
+
+    event.clear()
+    asyncCheck testEvent(0, event)
+    asyncCheck testEvent(1, event)
+    asyncCheck testEvent(2, event)
+    asyncCheck testEvent(3, event)
+    asyncCheck testEvent(4, event)
+    asyncCheck testEvent(5, event)
+    asyncCheck testEvent(6, event)
+    asyncCheck testEvent(7, event)
+    asyncCheck testEvent(8, event)
+    asyncCheck testEvent(9, event)
+    event.fire()
+    poll()
+    doAssert(test == "0123456789")
+
+  # Queues test
+  block:
+    const queueSize = 10
+    const testsCount = 1000
+    var test = 0
+
+    proc task1(aq: AsyncQueue[int]) {.async.} =
+      for i in 1..(testsCount - 1):
+        var item = await aq.get()
+        test -= item
+
+    proc task2(aq: AsyncQueue[int]) {.async.} =
+      for i in 1..testsCount:
+        await aq.put(i)
+        test += i
+
+    var queue = newAsyncQueue[int](queueSize)
+    discard task1(queue) or task2(queue)
+    poll()
+    doAssert(test == testsCount)
diff --git a/asyncdispatch2/handles.nim b/asyncdispatch2/handles.nim
new file mode 100644
index 00000000..b822b85f
--- /dev/null
+++ b/asyncdispatch2/handles.nim
@@ -0,0 +1,98 @@
+#
+#                 Asyncdispatch2 Handles
+#                 (c) Copyright 2018
+#         Status Research & Development GmbH
+#
+#              Licensed under either of
+#  Apache License, version 2.0, (LICENSE-APACHEv2)
+#              MIT license
(LICENSE-MIT) + +import net, nativesockets, asyncloop + +when defined(windows): + import winlean + const + asyncInvalidSocket* = AsyncFD(SocketHandle(-1)) +else: + import posix + const + asyncInvalidSocket* = AsyncFD(posix.INVALID_SOCKET) + +proc setSocketBlocking*(s: SocketHandle, blocking: bool): bool = + ## Sets blocking mode on socket. + when defined(windows): + result = true + var mode = clong(ord(not blocking)) + if ioctlsocket(s, FIONBIO, addr(mode)) == -1: + result = false + else: + result = true + var x: int = fcntl(s, F_GETFL, 0) + if x == -1: + result = false + else: + var mode = if blocking: x and not O_NONBLOCK else: x or O_NONBLOCK + if fcntl(s, F_SETFL, mode) == -1: + result = false + +proc setSockOpt*(socket: SocketHandle | AsyncFD, level, optname, + optval: int): bool = + ## `setsockopt()` for integer options. + ## Returns ``true`` on success, ``false`` on error. + result = true + var value = cint(optval) + if setsockopt(SocketHandle(socket), cint(level), cint(optname), addr(value), + sizeof(value).SockLen) < 0'i32: + result = false + +proc getSockOpt*(socket: SocketHandle | AsyncFD, level, optname: int, + value: var int): bool = + ## `getsockopt()` for integer options. + var res: cint + var size = sizeof(res).SockLen + result = true + if getsockopt(SocketHandle(socket), cint(level), cint(optname), + addr(res), addr(size)) < 0'i32: + return false + value = int(res) + +proc getSocketError*(socket: SocketHandle | AsyncFD, + err: var int): bool = + if not getSockOpt(socket, cint(SOL_SOCKET), cint(SO_ERROR), err): + result = false + else: + result = true + +proc createAsyncSocket*(domain: Domain, sockType: SockType, + protocol: Protocol): AsyncFD = + ## Creates new asynchronous socket. + ## Returns ``asyncInvalidSocket`` on error. 
+ let handle = createNativeSocket(domain, sockType, protocol) + if handle == osInvalidSocket: + return asyncInvalidSocket + if not setSocketBlocking(handle, false): + close(handle) + return asyncInvalidSocket + when defined(macosx) and not defined(nimdoc): + if not handle.setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, 1): + close(handle) + return asyncInvalidSocket + result = AsyncFD(handle) + register(result) + +proc wrapAsyncSocket*(sock: SocketHandle): AsyncFD = + ## Wraps normal socket to asynchronous socket. + ## Return ``asyncInvalidSocket`` on error. + if not setSocketBlocking(sock, false): + close(sock) + return asyncInvalidSocket + when defined(macosx) and not defined(nimdoc): + if not sock.setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, 1): + close(sock) + return asyncInvalidSocket + result = AsyncFD(sock) + register(result) + +proc closeAsyncSocket*(s: AsyncFD) {.inline.} = + unregister(s) + close(SocketHandle(s)) diff --git a/asyncdispatch2/hexdump.nim b/asyncdispatch2/hexdump.nim new file mode 100644 index 00000000..38498d94 --- /dev/null +++ b/asyncdispatch2/hexdump.nim @@ -0,0 +1,92 @@ +# +# Copyright (c) 2016 Eugene Kabanov +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from strutils import toHex, repeat + +proc dumpHex*(pbytes: pointer, nbytes: int, items = 1, ascii = true): string = + ## Return hexadecimal memory dump representation pointed by ``p``. + ## ``nbytes`` - number of bytes to show + ## ``items`` - number of bytes in group (supported ``items`` count is + ## 1, 2, 4, 8) + ## ``ascii`` - if ``true`` show ASCII representation of memory dump. + result = "" + let hexSize = items * 2 + var i = 0 + var slider = pbytes + var asciiText = "" + while i < nbytes: + if i %% 16 == 0: + result = result & toHex(cast[BiggestInt](slider), + sizeof(BiggestInt) * 2) & ": " + var k = 0 + while k < items: + var ch = cast[ptr char](cast[uint](slider) + k.uint)[] + if ord(ch) > 31 and ord(ch) < 127: asciiText &= ch else: asciiText &= "." 
+ inc(k) + case items: + of 1: + result = result & toHex(cast[BiggestInt](cast[ptr uint8](slider)[]), + hexSize) + of 2: + result = result & toHex(cast[BiggestInt](cast[ptr uint16](slider)[]), + hexSize) + of 4: + result = result & toHex(cast[BiggestInt](cast[ptr uint32](slider)[]), + hexSize) + of 8: + result = result & toHex(cast[BiggestInt](cast[ptr uint64](slider)[]), + hexSize) + else: + raise newException(ValueError, "Wrong items size!") + result = result & " " + slider = cast[pointer](cast[uint](slider) + items.uint) + i = i + items + if i %% 16 == 0: + result = result & " " & asciiText + asciiText.setLen(0) + result = result & "\n" + + if i %% 16 != 0: + var spacesCount = ((16 - (i %% 16)) div items) * (hexSize + 1) + 1 + result = result & repeat(' ', spacesCount) + result = result & asciiText + result = result & "\n" + +proc dumpHex*[T](v: openarray[T], items: int = 0, ascii = true): string = + ## Return hexadecimal memory dump representation of openarray[T] ``v``. + ## ``items`` - number of bytes in group (supported ``items`` count is + ## 0, 1, 2, 4, 8). If ``items`` is ``0`` group size will depend on + ## ``sizeof(T)``. + ## ``ascii`` - if ``true`` show ASCII representation of memory dump. + var i = 0 + if items == 0: + when sizeof(T) == 2: + i = 2 + elif sizeof(T) == 4: + i = 4 + elif sizeof(T) == 8: + i = 8 + else: + i = 1 + else: + i = items + result = dumpHex(unsafeAddr v[0], sizeof(T) * len(v), i, ascii) diff --git a/asyncdispatch2/sendfile.nim b/asyncdispatch2/sendfile.nim new file mode 100644 index 00000000..02b326fa --- /dev/null +++ b/asyncdispatch2/sendfile.nim @@ -0,0 +1,86 @@ +# +# Asyncdispatch2 SendFile +# (c) Copyright 2018 +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module provides cross-platform wrapper for ``sendfile()`` syscall. 
+ +when defined(nimdoc): + proc sendfile*(outfd, infd: int, offset: int, count: int): int = + ## Copies data between file descriptor ``infd`` and ``outfd``. Because this + ## copying is done within the kernel, ``sendfile()`` is more efficient than + ## the combination of ``read(2)`` and ``write(2)``, which would require + ## transferring data to and from user space. + ## + ## ``infd`` should be a file descriptor opened for reading and + ## ``outfd`` should be a descriptor opened for writing. + ## + ## The ``infd`` argument must correspond to a file which supports + ## ``mmap(2)``-like operations (i.e., it cannot be a socket). + ## + ## ``offset`` the file offset from which ``sendfile()`` will start reading + ## data from ``infd``. + ## + ## ``count`` is the number of bytes to copy between the file descriptors. + ## + ## If the transfer was successful, the number of bytes written to ``outfd`` + ## is returned. Note that a successful call to ``sendfile()`` may write + ## fewer bytes than requested; the caller should be prepared to retry the + ## call if there were unsent bytes. + ## + ## On error, ``-1`` is returned. 
+
+when defined(linux) or defined(android):
+
+  # NOTE(review): the ``header:`` strings below were empty in the original
+  # (the angle-bracket includes were apparently stripped); restored from
+  # the sendfile(2) man pages — confirm against the upstream source.
+  proc osSendFile*(outfd, infd: cint, offset: ptr int, count: int): int
+       {.importc: "sendfile", header: "<sys/sendfile.h>".}
+
+  proc sendfile*(outfd, infd: int, offset: int, count: int): int =
+    # BUGFIX: Linux sendfile(2) writes the updated offset back through the
+    # pointer, so pass the address of the mutable local copy ``o`` — the
+    # original took ``addr offset`` of the immutable parameter, which does
+    # not compile in Nim.
+    var o = offset
+    result = osSendFile(cint(outfd), cint(infd), addr o, count)
+
+elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
+     defined(dragonflybsd):
+
+  type
+    sendfileHeader* = object {.importc: "sf_hdtr",
+                               header: """#include <sys/types.h>
+                                          #include <sys/socket.h>
+                                          #include <sys/uio.h>""",
+                               pure, final.}
+
+  proc osSendFile*(outfd, infd: cint, offset: int, size: int,
+                   hdtr: ptr sendfileHeader, sbytes: ptr int,
+                   flags: int): int {.importc: "sendfile",
+                                      header: """#include <sys/types.h>
+                                                 #include <sys/socket.h>
+                                                 #include <sys/uio.h>""".}
+
+  proc sendfile*(outfd, infd: int, offset: int, count: int): int =
+    # BUGFIX: BSD sendfile(2) takes the *file* descriptor first and the
+    # socket second (the original passed them the other way around), and
+    # it returns 0 on success with the byte count reported via ``sbytes``
+    # — so translate that into "bytes written" to honor this proc's
+    # documented contract.
+    var o = 0
+    result = osSendFile(cint(infd), cint(outfd), offset, count, nil,
+                        addr o, 0)
+    if result == 0:
+      result = o
+
+elif defined(macosx):
+
+  type
+    sendfileHeader* = object {.importc: "sf_hdtr",
+                               header: """#include <sys/types.h>
+                                          #include <sys/socket.h>
+                                          #include <sys/uio.h>""",
+                               pure, final.}
+
+  proc osSendFile*(fd, s: cint, offset: int, size: ptr int,
+                   hdtr: ptr sendfileHeader,
+                   flags: int): int {.importc: "sendfile",
+                                      header: """#include <sys/types.h>
+                                                 #include <sys/socket.h>
+                                                 #include <sys/uio.h>""".}
+
+  proc sendfile*(outfd, infd: int, offset: int, count: int): int =
+    # BUGFIX: the original referenced undefined identifiers ``fd``/``s``
+    # and did not compile. Darwin sendfile(2): ``fd`` is the file, ``s``
+    # is the socket, and ``size`` is in/out — pass ``count`` in (0 would
+    # mean "send until EOF") and read back the bytes actually written.
+    var o = count
+    result = osSendFile(cint(infd), cint(outfd), offset, addr o, nil, 0)
+    if result == 0:
+      result = o
diff --git a/asyncdispatch2/timer.nim b/asyncdispatch2/timer.nim
new file mode 100644
index 00000000..3b0e228f
--- /dev/null
+++ b/asyncdispatch2/timer.nim
@@ -0,0 +1,47 @@
+#
+#
+#                       nAIO
+#        (c) Copyright 2017 Eugene Kabanov
+#
+#  See the file "LICENSE", included in this
+#  distribution, for details about the copyright.
+#
+
+## This module implements cross-platform system timer with
+## milliseconds resolution.
+
+when defined(windows):
+
+  from winlean import DWORD, getSystemTimeAsFileTime, FILETIME
+
+  proc fastEpochTime*(): uint64 {.inline.} =
+    ## Current system time in milliseconds (FILETIME is in 100ns units,
+    ## hence the ``div 10_000``).
+    var t: FILETIME
+    getSystemTimeAsFileTime(t)
+    result = ((uint64(t.dwHighDateTime) shl 32) or
+              uint64(t.dwLowDateTime)) div 10_000
+
+elif defined(macosx):
+
+  from posix import posix_gettimeofday, Timeval
+
+  proc fastEpochTime*(): uint64 {.inline.} =
+    ## Current system time in milliseconds.
+    var t: Timeval
+    posix_gettimeofday(t)
+    # BUGFIX: the original read from an undefined identifier ``a`` and
+    # lacked the uint64 conversions used by the other branches.
+    result = (uint64(t.tv_sec) * 1_000 + uint64(t.tv_usec) div 1_000)
+
+elif defined(posix):
+
+  from posix import clock_gettime, Timespec, CLOCK_REALTIME
+
+  proc fastEpochTime*(): uint64 {.inline.} =
+    ## Current system time in milliseconds.
+    var t: Timespec
+    discard clock_gettime(CLOCK_REALTIME, t)
+    result = (uint64(t.tv_sec) * 1_000 + uint64(t.tv_nsec) div 1_000_000)
+
+elif defined(nimdoc):
+
+  proc fastEpochTime*(): uint64
+    ## Returns system's timer in milliseconds.
+
+else:
+  # BUGFIX: a bare ``error(...)`` call is not defined at module scope;
+  # use the ``error`` pragma so unsupported platforms fail at compile time.
+  {.error: "Sorry, your operating system is not yet supported!".}
diff --git a/asyncdispatch2/transport.nim b/asyncdispatch2/transport.nim
new file mode 100644
index 00000000..0394bc13
--- /dev/null
+++ b/asyncdispatch2/transport.nim
@@ -0,0 +1,11 @@
+#
+#             Asyncdispatch2 Transport
+#                 (c) Copyright 2018
+#         Status Research & Development GmbH
+#
+#              Licensed under either of
+#  Apache License, version 2.0, (LICENSE-APACHEv2)
+#              MIT license (LICENSE-MIT)
+
+import transports/[datagram, stream, common]
+export datagram, common, stream
diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim
new file mode 100644
index 00000000..2bb846c0
--- /dev/null
+++ b/asyncdispatch2/transports/common.nim
@@ -0,0 +1,117 @@
+#
+#        Asyncdispatch2 Transport Common Types
+#                 (c) Copyright 2018
+#         Status Research & Development GmbH
+#
+#              Licensed under either of
+#  Apache License, version 2.0, (LICENSE-APACHEv2)
+#              MIT license (LICENSE-MIT)
+
+import net
+import ../asyncloop, ../asyncsync
+
+const
+  DefaultStreamBufferSize* = 4096     ## Default buffer size for stream
+                                      ## transports
+  DefaultDatagramBufferSize* = 65536
## Default buffer size for datagram + ## transports +type + ServerFlags* = enum + ## Server's flags + ReuseAddr, ReusePort + + TransportAddress* = object + ## Transport network address + address*: IpAddress # IP Address + port*: Port # IP port + + ServerCommand* = enum + ## Server's commands + Start, # Start server + Pause, # Pause server + Stop # Stop server + + ServerStatus* = enum + ## Server's statuses + Starting, # Server created + Stopped, # Server stopped + Running, # Server running + Paused # Server paused + + SocketServer* = ref object of RootRef + ## Socket server object + sock*: AsyncFD # Socket + local*: TransportAddress # Address + actEvent*: AsyncEvent # Activation event + action*: ServerCommand # Activation command + status*: ServerStatus # Current server status + udata*: pointer # User-defined pointer + flags*: set[ServerFlags] # Flags + bufferSize*: int # Buffer Size for transports + loopFuture*: Future[void] # Server's main Future + + TransportError* = object of Exception + ## Transport's specific exception + TransportOsError* = object of TransportError + ## Transport's OS specific exception + TransportIncompleteError* = object of TransportError + ## Transport's `incomplete data received` exception + TransportLimitError* = object of TransportError + ## Transport's `data limit reached` exception + + TransportState* = enum + ## Transport's state + ReadPending, # Read operation pending (Windows) + ReadPaused, # Read operations paused + ReadClosed, # Read operations closed + ReadEof, # Read at EOF + ReadError, # Read error + WritePending, # Writer operation pending (Windows) + WritePaused, # Writer operations paused + WriteClosed, # Writer operations closed + WriteError # Write error + +var + AnyAddress* = TransportAddress( + address: IpAddress(family: IpAddressFamily.IPv4), port: Port(0) + ) ## Default INADDR_ANY address for IPv4 + AnyAddress6* = TransportAddress( + address: IpAddress(family: IpAddressFamily.IPv6), port: Port(0) + ) ## Default 
INADDR_ANY address for IPv6 + +proc getDomain*(address: IpAddress): Domain = + ## Returns OS specific Domain from IP Address. + case address.family + of IpAddressFamily.IPv4: + result = Domain.AF_INET + of IpAddressFamily.IPv6: + result = Domain.AF_INET6 + +proc `$`*(address: TransportAddress): string = + ## Returns string representation of ``address``. + case address.address.family + of IpAddressFamily.IPv4: + result = $address.address + result.add(":") + of IpAddressFamily.IPv6: + result = "[" & $address.address & "]" + result.add(":") + result.add($int(address.port)) + +## TODO: string -> TransportAddress conversion + +template checkClosed*(t: untyped) = + if (ReadClosed in (t).state) or (WriteClosed in (t).state): + raise newException(TransportError, "Transport is already closed!") + +template getError*(t: untyped): ref Exception = + var err = (t).error + (t).error = nil + err + +when defined(windows): + import winlean + + const ERROR_OPERATION_ABORTED* = 995 + proc cancelIo*(hFile: HANDLE): WINBOOL + {.stdcall, dynlib: "kernel32", importc: "CancelIo".} diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim new file mode 100644 index 00000000..f2b37add --- /dev/null +++ b/asyncdispatch2/transports/datagram.nim @@ -0,0 +1,491 @@ +# +# Asyncdispatch2 Datagram Transport +# (c) Copyright 2018 +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +import net, nativesockets, os, deques +import ../asyncloop, ../handles +import common + +type + VectorKind = enum + WithoutAddress, WithAddress + +when defined(windows): + import winlean + type + GramVector = object + kind: VectorKind # Vector kind (with address/without address) + buf: TWSABuf # Writer vector buffer + address: TransportAddress # Destination address + writer: Future[void] # Writer vector completion Future + +else: + import posix + + type + GramVector = object + kind: 
VectorKind # Vector kind (with address/without address) + buf: pointer # Writer buffer pointer + buflen: int # Writer buffer size + address: TransportAddress # Destination address + writer: Future[void] # Writer vector completion Future + +type + DatagramCallback* = proc(transp: DatagramTransport, + pbytes: pointer, + nbytes: int, + remote: TransportAddress, + udata: pointer): Future[void] {.gcsafe.} + ## Datagram asynchronous receive callback. + ## ``transp`` - transport object + ## ``pbytes`` - pointer to data received + ## ``nbytes`` - number of bytes received + ## ``remote`` - remote peer address + ## ``udata`` - user-defined pointer, specified at Transport creation. + ## + ## ``pbytes`` will be `nil` and ``nbytes`` will be ``0``, if there an error + ## happens. + + DatagramTransport* = ref object of RootRef + fd: AsyncFD # File descriptor + state: set[TransportState] # Current Transport state + buffer: seq[byte] # Reading buffer + error: ref Exception # Current error + queue: Deque[GramVector] # Writer queue + local: TransportAddress # Local address + remote: TransportAddress # Remote address + udata: pointer # User-driven pointer + function: DatagramCallback # Receive data callback + future: Future[void] # Transport's life future + +template setReadError(t, e: untyped) = + (t).state.incl(ReadError) + (t).error = newException(TransportOsError, osErrorMsg((e))) + +template setWriteError(t, e: untyped) = + (t).state.incl(WriteError) + (t).error = newException(TransportOsError, osErrorMsg((e))) + +when defined(windows): + type + WindowsDatagramTransport* = ref object of DatagramTransport + rovl: CustomOverlapped + wovl: CustomOverlapped + raddr: Sockaddr_storage + ralen: SockLen + rflag: int32 + wsabuf: TWSABuf + + template finishWriter(t: untyped) = + var vv = (t).queue.popFirst() + vv.writer.complete() + + proc writeDatagramLoop(udata: pointer) = + var bytesCount: int32 + if isNil(udata): + return + var ovl = cast[PCustomOverlapped](udata) + var transp = 
cast[WindowsDatagramTransport](ovl.data.udata) + while len(transp.queue) > 0: + if WritePending in transp.state: + ## Continuation + transp.state.excl(WritePending) + let err = transp.wovl.data.errCode + if err == OSErrorCode(-1): + var vector = transp.queue.popFirst() + vector.writer.complete() + else: + transp.setWriteError(err) + transp.finishWriter() + else: + ## Initiation + var saddr: Sockaddr_storage + var slen: SockLen + transp.state.incl(WritePending) + let fd = SocketHandle(ovl.data.fd) + var vector = transp.queue[0] + if vector.kind == WithAddress: + toSockAddr(vector.address.address, vector.address.port, saddr, slen) + else: + toSockAddr(transp.remote.address, transp.remote.port, saddr, slen) + let ret = WSASendTo(fd, addr vector.buf, DWORD(1), addr bytesCount, + DWORD(0), cast[ptr SockAddr](addr saddr), + cint(slen), + cast[POVERLAPPED](addr transp.wovl), nil) + if ret != 0: + let err = osLastError() + if int(err) == ERROR_OPERATION_ABORTED: + transp.state.incl(WritePaused) + elif int(err) != ERROR_IO_PENDING: + transp.state.excl(WritePending) + transp.setWriteError(err) + transp.finishWriter() + break + + if len(transp.queue) == 0: + transp.state.incl(WritePaused) + + proc readDatagramLoop(udata: pointer) = + var + bytesCount: int32 + raddr: TransportAddress + if isNil(udata): + return + var ovl = cast[PCustomOverlapped](udata) + var transp = cast[WindowsDatagramTransport](ovl.data.udata) + while true: + if ReadPending in transp.state: + ## Continuation + if ReadClosed in transp.state: + break + transp.state.excl(ReadPending) + let err = transp.rovl.data.errCode + if err == OSErrorCode(-1): + let bytesCount = transp.rovl.data.bytesCount + if bytesCount == 0: + transp.state.incl(ReadEof) + transp.state.incl(ReadPaused) + fromSockAddr(transp.raddr, transp.ralen, raddr.address, raddr.port) + discard transp.function(transp, addr transp.buffer[0], bytesCount, + raddr, transp.udata) + else: + transp.setReadError(err) + transp.state.incl(ReadPaused) + 
discard transp.function(transp, nil, 0, raddr, transp.udata) + else: + ## Initiation + if (ReadEof notin transp.state) and (ReadClosed notin transp.state): + transp.state.incl(ReadPending) + let fd = SocketHandle(ovl.data.fd) + transp.rflag = 0 + transp.ralen = SockLen(sizeof(Sockaddr_storage)) + let ret = WSARecvFrom(fd, + addr transp.wsabuf, + DWORD(1), + addr bytesCount, + addr transp.rflag, + cast[ptr SockAddr](addr transp.raddr), + cast[ptr cint](addr transp.ralen), + cast[POVERLAPPED](addr transp.rovl), nil) + if ret != 0: + let err = osLastError() + if int(err) == ERROR_OPERATION_ABORTED: + transp.state.incl(ReadPaused) + elif int(err) != ERROR_IO_PENDING: + transp.state.excl(ReadPending) + transp.setReadError(err) + discard transp.function(transp, nil, 0, raddr, transp.udata) + break + + proc resumeRead(transp: DatagramTransport) {.inline.} = + var wtransp = cast[WindowsDatagramTransport](transp) + wtransp.state.excl(ReadPaused) + readDatagramLoop(cast[pointer](addr wtransp.rovl)) + + proc resumeWrite(transp: DatagramTransport) {.inline.} = + var wtransp = cast[WindowsDatagramTransport](transp) + wtransp.state.excl(WritePaused) + writeDatagramLoop(cast[pointer](addr wtransp.wovl)) + + proc newDatagramTransportCommon(cbproc: DatagramCallback, + remote: TransportAddress, + local: TransportAddress, + sock: AsyncFD, + flags: set[ServerFlags], + udata: pointer, + bufferSize: int): DatagramTransport = + var localSock: AsyncFD + assert(remote.address.family == local.address.family) + assert(not isNil(cbproc)) + + var wresult = new WindowsDatagramTransport + + if sock == asyncInvalidSocket: + if local.address.family == IpAddressFamily.IPv4: + localSock = createAsyncSocket(Domain.AF_INET, SockType.SOCK_DGRAM, + Protocol.IPPROTO_UDP) + else: + localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM, + Protocol.IPPROTO_UDP) + if localSock == asyncInvalidSocket: + raiseOsError(osLastError()) + else: + if not setSocketBlocking(SocketHandle(sock), false): + 
raiseOsError(osLastError()) + localSock = sock + register(localSock) + + if local.port != Port(0): + var saddr: Sockaddr_storage + var slen: SockLen + toSockAddr(local.address, local.port, saddr, slen) + if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), + slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(localSock) + raiseOsError(err) + wresult.local = local + else: + var saddr: Sockaddr_storage + var slen: SockLen + if local.address.family == IpAddressFamily.IPv4: + saddr.ss_family = winlean.AF_INET + slen = SockLen(sizeof(SockAddr_in)) + else: + saddr.ss_family = winlean.AF_INET6 + slen = SockLen(sizeof(SockAddr_in6)) + if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), + slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(localSock) + raiseOsError(err) + if remote.port != Port(0): + wresult.remote = remote + + ## TODO: Apply server flags + + wresult.fd = localSock + wresult.function = cbproc + wresult.buffer = newSeq[byte](bufferSize) + wresult.queue = initDeque[GramVector]() + wresult.udata = udata + wresult.state = {WritePaused} + wresult.future = newFuture[void]("datagram.transport") + wresult.rovl.data = CompletionData(fd: localSock, cb: readDatagramLoop, + udata: cast[pointer](wresult)) + wresult.wovl.data = CompletionData(fd: localSock, cb: writeDatagramLoop, + udata: cast[pointer](wresult)) + wresult.wsabuf = TWSABuf(buf: cast[cstring](addr wresult.buffer[0]), + len: int32(len(wresult.buffer))) + result = cast[DatagramTransport](wresult) + result.resumeRead() + + proc close*(transp: DatagramTransport) = + ## Closes and frees resources of transport ``transp``. 
+ if ReadClosed notin transp.state and WriteClosed notin transp.state: + discard cancelIo(Handle(transp.fd)) + closeAsyncSocket(transp.fd) + transp.state.incl(WriteClosed) + transp.state.incl(ReadClosed) + transp.future.complete() + +else: + + proc readDatagramLoop(udata: pointer) = + var + saddr: Sockaddr_storage + slen: SockLen + raddr: TransportAddress + + var cdata = cast[ptr CompletionData](udata) + var transp = cast[DatagramTransport](cdata.udata) + let fd = SocketHandle(cdata.fd) + if not isNil(transp): + while true: + slen = SockLen(sizeof(Sockaddr_storage)) + var res = posix.recvfrom(fd, addr transp.buffer[0], + cint(len(transp.buffer)), cint(0), + cast[ptr SockAddr](addr saddr), + addr slen) + if res >= 0: + fromSockAddr(saddr, slen, raddr.address, raddr.port) + discard transp.function(transp, addr transp.buffer[0], res, + raddr, transp.udata) + else: + let err = osLastError() + if int(err) == EINTR: + continue + else: + transp.setReadError(err) + discard transp.function(transp, nil, 0, raddr, transp.udata) + break + + proc writeDatagramLoop(udata: pointer) = + var res: int = 0 + var cdata = cast[ptr CompletionData](udata) + var transp = cast[DatagramTransport](cdata.udata) + var saddr: Sockaddr_storage + var slen: SockLen + let fd = SocketHandle(cdata.fd) + if not isNil(transp): + if len(transp.queue) > 0: + var vector = transp.queue.popFirst() + while true: + if vector.kind == WithAddress: + toSockAddr(vector.address.address, vector.address.port, saddr, slen) + res = posix.sendto(fd, vector.buf, vector.buflen, MSG_NOSIGNAL, + cast[ptr SockAddr](addr saddr), + slen) + elif vector.kind == WithoutAddress: + res = posix.send(fd, vector.buf, vector.buflen, MSG_NOSIGNAL) + if res >= 0: + vector.writer.complete() + else: + let err = osLastError() + if int(err) == EINTR: + continue + else: + transp.setWriteError(err) + vector.writer.complete() + break + else: + transp.state.incl(WritePaused) + transp.fd.removeWriter() + + proc resumeWrite(transp: 
DatagramTransport) {.inline.} = + transp.state.excl(WritePaused) + addWriter(transp.fd, writeDatagramLoop, cast[pointer](transp)) + + proc resumeRead(transp: DatagramTransport) {.inline.} = + transp.state.excl(ReadPaused) + addReader(transp.fd, readDatagramLoop, cast[pointer](transp)) + + proc newDatagramTransportCommon(cbproc: DatagramCallback, + remote: TransportAddress, + local: TransportAddress, + sock: AsyncFD, + flags: set[ServerFlags], + udata: pointer, + bufferSize: int): DatagramTransport = + var localSock: AsyncFD + assert(remote.address.family == local.address.family) + assert(not isNil(cbproc)) + + result = new DatagramTransport + + if sock == asyncInvalidSocket: + if local.address.family == IpAddressFamily.IPv4: + localSock = createAsyncSocket(Domain.AF_INET, SockType.SOCK_DGRAM, + Protocol.IPPROTO_UDP) + else: + localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM, + Protocol.IPPROTO_UDP) + if localSock == asyncInvalidSocket: + raiseOsError(osLastError()) + else: + if not setSocketBlocking(SocketHandle(sock), false): + raiseOsError(osLastError()) + localSock = sock + register(localSock) + + ## Apply ServerFlags here + if ServerFlags.ReuseAddr in flags: + if not setSockOpt(localSock, SOL_SOCKET, SO_REUSEADDR, 1): + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(localSock) + raiseOsError(err) + + if local.port != Port(0): + var saddr: Sockaddr_storage + var slen: SockLen + toSockAddr(local.address, local.port, saddr, slen) + if bindAddr(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), + slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(localSock) + raiseOsError(err) + result.local = local + + if remote.port != Port(0): + var saddr: Sockaddr_storage + var slen: SockLen + toSockAddr(remote.address, remote.port, saddr, slen) + if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), + slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + 
closeAsyncSocket(localSock) + raiseOsError(err) + result.remote = remote + + result.fd = localSock + result.function = cbproc + result.buffer = newSeq[byte](bufferSize) + result.queue = initDeque[GramVector]() + result.udata = udata + result.state = {WritePaused} + result.future = newFuture[void]("datagram.transport") + result.resumeRead() + + proc close*(transp: DatagramTransport) = + ## Closes and frees resources of transport ``transp``. + if ReadClosed notin transp.state and WriteClosed notin transp.state: + closeAsyncSocket(transp.fd) + transp.state.incl(WriteClosed) + transp.state.incl(ReadClosed) + transp.future.complete() + +proc newDatagramTransport*(cbproc: DatagramCallback, + remote: TransportAddress = AnyAddress, + local: TransportAddress = AnyAddress, + sock: AsyncFD = asyncInvalidSocket, + flags: set[ServerFlags] = {}, + udata: pointer = nil, + bufSize: int = DefaultDatagramBufferSize + ): DatagramTransport = + result = newDatagramTransportCommon(cbproc, remote, local, sock, + flags, udata, bufSize) + +proc newDatagramTransport6*(cbproc: DatagramCallback, + remote: TransportAddress = AnyAddress6, + local: TransportAddress = AnyAddress6, + sock: AsyncFD = asyncInvalidSocket, + flags: set[ServerFlags] = {}, + udata: pointer = nil, + bufSize: int = DefaultDatagramBufferSize + ): DatagramTransport = + result = newDatagramTransportCommon(cbproc, remote, local, sock, + flags, udata, bufSize) + +proc join*(transp: DatagramTransport) {.async.} = + await transp.future + +proc send*(transp: DatagramTransport, pbytes: pointer, + nbytes: int): Future[void] {.async.} = + checkClosed(transp) + if transp.remote.port == Port(0): + raise newException(TransportError, "Remote peer is not set!") + var waitFuture = newFuture[void]("datagram.transport.send") + when defined(windows): + let wsabuf = TWSABuf(buf: cast[cstring](pbytes), len: int32(nbytes)) + var vector = GramVector(kind: WithoutAddress, buf: wsabuf, + writer: waitFuture) + else: + var vector = GramVector(kind: 
WithoutAddress, buf: pbytes, buflen: nbytes, + writer: waitFuture) + transp.queue.addLast(vector) + if WritePaused in transp.state: + transp.resumeWrite() + await vector.writer + if WriteError in transp.state: + raise transp.getError() + +proc sendTo*(transp: DatagramTransport, pbytes: pointer, nbytes: int, + remote: TransportAddress): Future[void] {.async.} = + checkClosed(transp) + var saddr: Sockaddr_storage + var slen: SockLen + toSockAddr(remote.address, remote.port, saddr, slen) + var waitFuture = newFuture[void]("datagram.transport.sendto") + when defined(windows): + let wsabuf = TWSABuf(buf: cast[cstring](pbytes), len: int32(nbytes)) + var vector = GramVector(kind: WithAddress, buf: wsabuf, + address: remote, writer: waitFuture) + else: + var vector = GramVector(kind: WithAddress, buf: pbytes, buflen: nbytes, + address: remote, writer: waitFuture) + transp.queue.addLast(vector) + if WritePaused in transp.state: + transp.resumeWrite() + await vector.writer + if WriteError in transp.state: + raise transp.getError() diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim new file mode 100644 index 00000000..aa078346 --- /dev/null +++ b/asyncdispatch2/transports/stream.nim @@ -0,0 +1,1032 @@ +# +# Asyncdispatch2 Stream Transport +# (c) Copyright 2018 +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +import ../asyncloop, ../asyncsync, ../handles +import common +import net, nativesockets, os, deques, strutils + +type + VectorKind = enum + DataBuffer, # Simple buffer pointer/length + DataFile # File handle for sendfile/TransmitFile + +when defined(windows): + import winlean + type + StreamVector = object + kind: VectorKind # Writer vector source kind + dataBuf: TWSABuf # Writer vector buffer + offset: uint # Writer vector offset + writer: Future[void] # Writer vector completion Future + +else: + import posix + type + StreamVector = 
object + kind: VectorKind # Writer vector source kind + buf: pointer # Writer buffer pointer + buflen: int # Writer buffer size + offset: uint # Writer vector offset + writer: Future[void] # Writer vector completion Future + +type + TransportKind* {.pure.} = enum + Socket, # Socket transport + Pipe, # Pipe transport + File # File transport + +type + StreamTransport* = ref object of RootRef + fd: AsyncFD # File descriptor + state: set[TransportState] # Current Transport state + reader: Future[void] # Current reader Future + buffer: seq[byte] # Reading buffer + offset: int # Reading buffer offset + error: ref Exception # Current error + queue: Deque[StreamVector] # Writer queue + future: Future[void] # Stream life future + case kind*: TransportKind + of TransportKind.Socket: + domain: Domain # Socket transport domain (IPv4/IPv6) + local: TransportAddress # Local address + remote: TransportAddress # Remote address + of TransportKind.Pipe: + fd0: AsyncFD + fd1: AsyncFD + of TransportKind.File: + length: int + + StreamCallback* = proc(t: StreamTransport, + udata: pointer): Future[void] {.gcsafe.} + + StreamServer* = ref object of SocketServer + function*: StreamCallback + +proc remoteAddress*(transp: StreamTransport): TransportAddress = + ## Returns ``transp`` remote socket address. + if transp.kind != TransportKind.Socket: + raise newException(TransportError, "Socket required!") + if transp.remote.port == Port(0): + var saddr: Sockaddr_storage + var slen = SockLen(sizeof(saddr)) + if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), + addr slen) != 0: + raiseOsError(osLastError()) + fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port) + result = transp.remote + +proc localAddress*(transp: StreamTransport): TransportAddress = + ## Returns ``transp`` local socket address. 
+ if transp.kind != TransportKind.Socket: + raise newException(TransportError, "Socket required!") + if transp.local.port == Port(0): + var saddr: Sockaddr_storage + var slen = SockLen(sizeof(saddr)) + if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), + addr slen) != 0: + raiseOsError(osLastError()) + fromSockAddr(saddr, slen, transp.local.address, transp.local.port) + result = transp.local + +template setReadError(t, e: untyped) = + (t).state.incl(ReadError) + (t).error = newException(TransportOsError, osErrorMsg((e))) + +template setWriteError(t, e: untyped) = + (t).state.incl(WriteError) + (t).error = newException(TransportOsError, osErrorMsg((e))) + +template finishReader(t: untyped) = + var reader = (t).reader + (t).reader = nil + reader.complete() + +template checkPending(t: untyped) = + if not isNil((t).reader): + raise newException(TransportError, "Read operation already pending!") + +# template shiftBuffer(t, c: untyped) = +# let length = len((t).buffer) +# if length > c: +# moveMem(addr((t).buffer[0]), addr((t).buffer[(c)]), length - (c)) +# (t).offset = (t).offset - (c) +# else: +# (t).offset = 0 + +template shiftBuffer(t, c: untyped) = + if (t).offset > c: + echo "moveMem(" & $int((t).offset) & ", " & $int(c) & ")" + moveMem(addr((t).buffer[0]), addr((t).buffer[(c)]), (t).offset - (c)) + (t).offset = (t).offset - (c) + else: + (t).offset = 0 + +when defined(windows): + import winlean + type + WindowsStreamTransport = ref object of StreamTransport + wsabuf: TWSABuf # Reader WSABUF + rovl: CustomOverlapped # Reader OVERLAPPED structure + wovl: CustomOverlapped # Writer OVERLAPPED structure + roffset: int # Pending reading offset + + WindowsStreamServer* = ref object of RootRef + server: SocketServer # Server object + domain: Domain + abuffer: array[128, byte] + aovl: CustomOverlapped + + const SO_UPDATE_CONNECT_CONTEXT = 0x7010 + + template finishWriter(t: untyped) = + var vv = (t).queue.popFirst() + vv.writer.complete() + + template 
zeroOvelappedOffset(t: untyped) = + (t).offset = 0 + (t).offsetHigh = 0 + + template setOverlappedOffset(t, o: untyped) = + (t).offset = cast[int32](cast[uint64](o) and 0xFFFFFFFF'u64) + (t).offsetHigh = cast[int32](cast[uint64](o) shr 32) + + template getFileSize(t: untyped): uint = + cast[uint]((t).dataBuf.buf) + + template getFileHandle(t: untyped): Handle = + cast[Handle]((t).dataBuf.len) + + template slideOffset(v, o: untyped) = + let s = cast[uint]((v).dataBuf.buf) - cast[uint]((o)) + (v).dataBuf.buf = cast[cstring](s) + (v).offset = (v).offset + cast[uint]((o)) + + template slideBuffer(t, o: untyped) = + (t).dataBuf.buf = cast[cstring](cast[uint]((t).dataBuf.buf) + uint(o)) + (t).dataBuf.len -= int32(o) + + template setWSABuffer(t: untyped) = + (t).wsabuf.buf = cast[cstring]( + cast[uint](addr t.buffer[0]) + uint((t).roffset)) + (t).wsabuf.len = int32(len((t).buffer) - (t).roffset) + + template initBufferStreamVector(v, p, n, t: untyped) = + (v).kind = DataBuffer + (v).dataBuf.buf = cast[cstring]((p)) + (v).dataBuf.len = cast[int32](n) + (v).writer = (t) + + template initTransmitStreamVector(v, h, o, n, t: untyped) = + (v).kind = DataFile + (v).dataBuf.buf = cast[cstring]((n)) + (v).dataBuf.len = cast[int32]((h)) + (v).offset = cast[uint]((o)) + (v).writer = (t) + + proc writeStreamLoop(udata: pointer) {.gcsafe.} = + var bytesCount: int32 + if isNil(udata): + return + var ovl = cast[PCustomOverlapped](udata) + var transp = cast[WindowsStreamTransport](ovl.data.udata) + + while len(transp.queue) > 0: + if WritePending in transp.state: + ## Continuation + transp.state.excl(WritePending) + let err = transp.wovl.data.errCode + if err == OSErrorCode(-1): + bytesCount = transp.wovl.data.bytesCount + var vector = transp.queue.popFirst() + if bytesCount == 0: + vector.writer.complete() + else: + if transp.kind == TransportKind.Socket: + if vector.kind == VectorKind.DataBuffer: + if bytesCount < vector.dataBuf.len: + vector.slideBuffer(bytesCount) + 
transp.queue.addFirst(vector) + else: + vector.writer.complete() + else: + if uint(bytesCount) < getFileSize(vector): + vector.slideOffset(bytesCount) + transp.queue.addFirst(vector) + else: + vector.writer.complete() + elif transp.kind in {TransportKind.Pipe, TransportKind.File}: + if bytesCount < vector.dataBuf.len: + vector.slideBuffer(bytesCount) + transp.queue.addFirst(vector) + else: + vector.writer.complete() + else: + transp.setWriteError(err) + transp.finishWriter() + else: + ## Initiation + transp.state.incl(WritePending) + if transp.kind == TransportKind.Socket: + let sock = SocketHandle(transp.wovl.data.fd) + if transp.queue[0].kind == VectorKind.DataBuffer: + transp.wovl.zeroOvelappedOffset() + let ret = WSASend(sock, addr transp.queue[0].dataBuf, 1, + addr bytesCount, DWORD(0), + cast[POVERLAPPED](addr transp.wovl), nil) + if ret != 0: + let err = osLastError() + if int32(err) != ERROR_IO_PENDING: + transp.state.excl(WritePending) + transp.setWriteError(err) + transp.finishWriter() + else: + let loop = getGlobalDispatcher() + var vector = transp.queue[0] + var size: int32 + var flags: int32 + + if getFileSize(vector) > 2_147_483_646'u: + size = 2_147_483_646 + else: + size = int32(getFileSize(vector)) + + transp.wovl.setOverlappedOffset(vector.offset) + + var ret = loop.transmitFile(sock, getFileHandle(vector), size, 0, + cast[POVERLAPPED](addr transp.wovl), + nil, flags) + if ret == 0: + let err = osLastError() + if int32(err) != ERROR_IO_PENDING: + transp.state.excl(WritePending) + transp.setWriteError(err) + transp.finishWriter() + elif transp.kind in {TransportKind.Pipe, TransportKind.File}: + let fd = Handle(transp.wovl.data.fd) + var vector = transp.queue[0] + + if transp.kind == TransportKind.File: + transp.wovl.setOverlappedOffset(vector.offset) + else: + transp.wovl.zeroOvelappedOffset() + + var ret = writeFile(fd, vector.dataBuf.buf, vector.dataBuf.len, nil, + cast[POVERLAPPED](addr transp.wovl)) + if ret == 0: + let err = osLastError() + if 
int32(err) != ERROR_IO_PENDING: + transp.state.excl(WritePending) + transp.setWriteError(err) + transp.finishWriter() + break + + if len(transp.queue) == 0: + transp.state.incl(WritePaused) + + proc readStreamLoop(udata: pointer) {.gcsafe.} = + if isNil(udata): + return + var ovl = cast[PCustomOverlapped](udata) + var transp = cast[WindowsStreamTransport](ovl.data.udata) + + while true: + if ReadPending in transp.state: + ## Continuation + if ReadClosed in transp.state: + break + transp.state.excl(ReadPending) + let err = transp.rovl.data.errCode + if err == OSErrorCode(-1): + let bytesCount = transp.rovl.data.bytesCount + if bytesCount == 0: + transp.state.incl(ReadEof) + transp.state.incl(ReadPaused) + else: + if transp.offset != transp.roffset: + moveMem(addr transp.buffer[transp.offset], + addr transp.buffer[transp.roffset], + bytesCount) + transp.offset += bytesCount + transp.roffset = transp.offset + if transp.offset == len(transp.buffer): + transp.state.incl(ReadPaused) + else: + transp.setReadError(err) + if not isNil(transp.reader): + transp.finishReader() + else: + ## Initiation + if (ReadEof notin transp.state) and (ReadClosed notin transp.state): + var flags = DWORD(0) + var bytesCount: int32 = 0 + transp.state.excl(ReadPaused) + transp.state.incl(ReadPending) + if transp.kind == TransportKind.Socket: + let sock = SocketHandle(transp.rovl.data.fd) + transp.setWSABuffer() + let ret = WSARecv(sock, addr transp.wsabuf, 1, + addr bytesCount, addr flags, + cast[POVERLAPPED](addr transp.rovl), nil) + if ret != 0: + let err = osLastError() + if int(err) == ERROR_OPERATION_ABORTED: + transp.state.incl(ReadPaused) + elif int32(err) != ERROR_IO_PENDING: + transp.setReadError(err) + if not isNil(transp.reader): + transp.finishReader() + ## Finish Loop + break + + proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = + var t = WindowsStreamTransport(kind: TransportKind.Socket) + t.fd = sock + t.rovl.data = CompletionData(fd: sock, cb: 
readStreamLoop, + udata: cast[pointer](t)) + t.wovl.data = CompletionData(fd: sock, cb: writeStreamLoop, + udata: cast[pointer](t)) + t.buffer = newSeq[byte](bufsize) + t.state = {ReadPaused, WritePaused} + t.queue = initDeque[StreamVector]() + t.future = newFuture[void]("stream.socket.transport") + result = cast[StreamTransport](t) + + proc bindToDomain(handle: AsyncFD, domain: Domain): bool = + result = true + if domain == Domain.AF_INET6: + var saddr: Sockaddr_in6 + saddr.sin6_family = int16(toInt(domain)) + if bindAddr(SocketHandle(handle), cast[ptr SockAddr](addr(saddr)), + sizeof(saddr).SockLen) != 0'i32: + result = false + else: + var saddr: Sockaddr_in + saddr.sin_family = int16(toInt(domain)) + if bindAddr(SocketHandle(handle), cast[ptr SockAddr](addr(saddr)), + sizeof(saddr).SockLen) != 0'i32: + result = false + + proc connect*(address: TransportAddress, + bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = + let loop = getGlobalDispatcher() + var + saddr: Sockaddr_storage + slen: SockLen + sock: AsyncFD + povl: RefCustomOverlapped + + var retFuture = newFuture[StreamTransport]("stream.transport.connect") + toSockAddr(address.address, address.port, saddr, slen) + sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if sock == asyncInvalidSocket: + result.fail(newException(OSError, osErrorMsg(osLastError()))) + + if not bindToDomain(sock, address.address.getDomain()): + sock.closeAsyncSocket() + result.fail(newException(OSError, osErrorMsg(osLastError()))) + + proc continuation(udata: pointer) = + var ovl = cast[PCustomOverlapped](udata) + if not retFuture.finished: + if ovl.data.errCode == OSErrorCode(-1): + if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), + cint(SO_UPDATE_CONNECT_CONTEXT), nil, + SockLen(0)) != 0'i32: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + else: + retFuture.complete(newStreamSocketTransport(povl.data.fd, + bufferSize)) 
+ else: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(ovl.data.errCode))) + + povl = RefCustomOverlapped() + povl.data = CompletionData(fd: sock, cb: continuation) + var res = loop.connectEx(SocketHandle(sock), + cast[ptr SockAddr](addr saddr), + DWORD(slen), nil, 0, nil, + cast[POVERLAPPED](povl)) + # We will not process immediate completion, to avoid undefined behavior. + if not res: + let err = osLastError() + if int32(err) != ERROR_IO_PENDING: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(err))) + return retFuture + + proc acceptAddr(server: WindowsStreamServer): Future[AsyncFD] = + var retFuture = newFuture[AsyncFD]("transport.acceptAddr") + let loop = getGlobalDispatcher() + var sock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if sock == asyncInvalidSocket: + retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + + var dwBytesReceived = DWORD(0) + let dwReceiveDataLength = DWORD(0) + let dwLocalAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) + let dwRemoteAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) + + proc continuation(udata: pointer) = + var ovl = cast[PCustomOverlapped](udata) + if not retFuture.finished: + if server.server.status in {Stopped, Paused}: + sock.closeAsyncSocket() + retFuture.complete(asyncInvalidSocket) + else: + if ovl.data.errCode == OSErrorCode(-1): + if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), + cint(SO_UPDATE_ACCEPT_CONTEXT), + addr server.server.sock, + SockLen(sizeof(SocketHandle))) != 0'i32: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + else: + retFuture.complete(sock) + else: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(ovl.data.errCode))) + + server.aovl.data.fd = server.server.sock + server.aovl.data.cb = continuation + + let res = loop.acceptEx(SocketHandle(server.server.sock), + SocketHandle(sock), addr server.abuffer[0], + 
dwReceiveDataLength, dwLocalAddressLength, + dwRemoteAddressLength, addr dwBytesReceived, + cast[POVERLAPPED](addr server.aovl)) + + if not res: + let err = osLastError() + if int32(err) != ERROR_IO_PENDING: + retFuture.fail(newException(OSError, osErrorMsg(err))) + return retFuture + + proc serverLoop(server: StreamServer): Future[void] {.async.} = + ## TODO: This procedure must be reviewed, when cancellation support + ## will be added + var wserver = new WindowsStreamServer + wserver.server = server + wserver.domain = server.local.address.getDomain() + await server.actEvent.wait() + server.actEvent.clear() + if server.action == ServerCommand.Start: + var eventFut = server.actEvent.wait() + while true: + var acceptFut = acceptAddr(wserver) + await eventFut or acceptFut + if eventFut.finished: + if server.action == ServerCommand.Start: + if server.status in {Stopped, Paused}: + server.status = Running + elif server.action == ServerCommand.Stop: + if server.status in {Running}: + server.status = Stopped + break + elif server.status in {Paused}: + server.status = Stopped + break + elif server.action == ServerCommand.Pause: + if server.status in {Running}: + server.status = Paused + if acceptFut.finished: + if not acceptFut.failed: + var sock = acceptFut.read() + if sock != asyncInvalidSocket: + discard server.function( + newStreamSocketTransport(sock, server.bufferSize), + server.udata) + + proc resumeRead(transp: StreamTransport) {.inline.} = + var wtransp = cast[WindowsStreamTransport](transp) + wtransp.state.excl(ReadPaused) + readStreamLoop(cast[pointer](addr wtransp.rovl)) + + proc resumeWrite(transp: StreamTransport) {.inline.} = + var wtransp = cast[WindowsStreamTransport](transp) + wtransp.state.excl(WritePaused) + writeStreamLoop(cast[pointer](addr wtransp.wovl)) + +else: + import posix + + type + UnixStreamTransport* = ref object of StreamTransport + + template getVectorBuffer(v: untyped): pointer = + cast[pointer](cast[uint]((v).buf) + uint((v).boffset)) + 
+ template getVectorLength(v: untyped): int = + cast[int]((v).buflen - int((v).boffset)) + + template shiftVectorBuffer(t, o: untyped) = + (t).buf = cast[pointer](cast[uint]((t).buf) + uint(o)) + (t).buflen -= int(o) + + template initBufferStreamVector(v, p, n, t: untyped) = + (v).kind = DataBuffer + (v).buf = cast[pointer]((p)) + (v).buflen = int(n) + (v).writer = (t) + + proc writeStreamLoop(udata: pointer) {.gcsafe.} = + var cdata = cast[ptr CompletionData](udata) + var transp = cast[UnixStreamTransport](cdata.udata) + let fd = SocketHandle(cdata.fd) + if not isNil(transp): + if len(transp.queue) > 0: + echo "len(transp.queue) = ", len(transp.queue) + var vector = transp.queue.popFirst() + while true: + if transp.kind == TransportKind.Socket: + if vector.kind == VectorKind.DataBuffer: + let res = posix.send(fd, vector.buf, vector.buflen, MSG_NOSIGNAL) + if res >= 0: + if vector.buflen - res == 0: + vector.writer.complete() + else: + vector.shiftVectorBuffer(res) + transp.queue.addFirst(vector) + else: + let err = osLastError() + if int(err) == EINTR: + continue + else: + transp.setWriteError(err) + vector.writer.complete() + break + else: + discard + else: + transp.state.incl(WritePaused) + transp.fd.removeWriter() + + proc readStreamLoop(udata: pointer) {.gcsafe.} = + var cdata = cast[ptr CompletionData](udata) + var transp = cast[UnixStreamTransport](cdata.udata) + let fd = SocketHandle(cdata.fd) + if not isNil(transp): + while true: + var res = posix.recv(fd, addr transp.buffer[transp.offset], + len(transp.buffer) - transp.offset, cint(0)) + if res < 0: + let err = osLastError() + if int(err) == EINTR: + continue + elif int(err) in {ECONNRESET}: + transp.state.incl(ReadEof) + transp.state.incl(ReadPaused) + cdata.fd.removeReader() + else: + transp.setReadError(err) + cdata.fd.removeReader() + elif res == 0: + transp.state.incl(ReadEof) + transp.state.incl(ReadPaused) + cdata.fd.removeReader() + else: + transp.offset += res + if transp.offset == 
len(transp.buffer): + transp.state.incl(ReadPaused) + cdata.fd.removeReader() + if not isNil(transp.reader): + transp.finishReader() + break + + proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = + var t = UnixStreamTransport(kind: TransportKind.Socket) + t.fd = sock + t.buffer = newSeq[byte](bufsize) + t.state = {ReadPaused, WritePaused} + t.queue = initDeque[StreamVector]() + t.future = newFuture[void]("socket.stream.transport") + result = cast[StreamTransport](t) + + proc connect*(address: TransportAddress, + bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = + var + saddr: Sockaddr_storage + slen: SockLen + sock: AsyncFD + var retFuture = newFuture[StreamTransport]("transport.connect") + toSockAddr(address.address, address.port, saddr, slen) + sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if sock == asyncInvalidSocket: + result.fail(newException(OSError, osErrorMsg(osLastError()))) + + proc continuation(udata: pointer) = + var data = cast[ptr CompletionData](udata) + var err = 0 + if not data.fd.getSocketError(err): + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + return + if err != 0: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(err)))) + return + data.fd.removeWriter() + retFuture.complete(newStreamSocketTransport(data.fd, bufferSize)) + + while true: + var res = posix.connect(SocketHandle(sock), + cast[ptr SockAddr](addr saddr), slen) + if res == 0: + retFuture.complete(newStreamSocketTransport(sock, bufferSize)) + break + else: + let err = osLastError() + if int(err) == EINTR: + continue + elif int(err) == EINPROGRESS: + sock.addWriter(continuation) + break + else: + sock.closeAsyncSocket() + retFuture.fail(newException(OSError, osErrorMsg(err))) + break + return retFuture + + proc serverCallback(udata: pointer) = + var + saddr: Sockaddr_storage + slen: SockLen + + var 
server = cast[StreamServer](cast[ptr CompletionData](udata).udata) + while true: + let res = posix.accept(SocketHandle(server.sock), + cast[ptr SockAddr](addr saddr), addr slen) + if int(res) > 0: + let sock = wrapAsyncSocket(res) + if sock != asyncInvalidSocket: + discard server.function( + newStreamSocketTransport(sock, server.bufferSize), + server.udata) + break + else: + let err = osLastError() + if int(err) == EINTR: + continue + elif int(err) in {EBADF, EINVAL, ENOTSOCK, EOPNOTSUPP, EPROTO}: + ## Critical unrecoverable error + raiseOsError(err) + + proc serverLoop(server: SocketServer): Future[void] {.async.} = + while true: + await server.actEvent.wait() + server.actEvent.clear() + if server.action == ServerCommand.Start: + if server.status in {Stopped, Paused, Starting}: + addReader(server.sock, serverCallback, + cast[pointer](server)) + server.status = Running + elif server.action == ServerCommand.Stop: + if server.status in {Running}: + removeReader(server.sock) + server.status = Stopped + break + elif server.status in {Paused}: + server.status = Stopped + break + elif server.action == ServerCommand.Pause: + if server.status in {Running}: + removeReader(server.sock) + server.status = Paused + + proc resumeRead(transp: StreamTransport) {.inline.} = + transp.state.excl(ReadPaused) + addReader(transp.fd, readStreamLoop, cast[pointer](transp)) + + proc resumeWrite(transp: StreamTransport) {.inline.} = + transp.state.excl(WritePaused) + addWriter(transp.fd, writeStreamLoop, cast[pointer](transp)) + +proc start*(server: SocketServer) = + server.action = Start + server.actEvent.fire() + +proc stop*(server: SocketServer) = + server.action = Stop + server.actEvent.fire() + +proc pause*(server: SocketServer) = + server.action = Pause + server.actEvent.fire() + +proc join*(server: SocketServer) {.async.} = + await server.loopFuture + +proc createStreamServer*(host: TransportAddress, + flags: set[ServerFlags], + cbproc: StreamCallback, + sock: AsyncFD = 
asyncInvalidSocket, + backlog: int = 100, + bufferSize: int = DefaultStreamBufferSize, + udata: pointer = nil): StreamServer = + var + saddr: Sockaddr_storage + slen: SockLen + serverSocket: AsyncFD + if sock == asyncInvalidSocket: + serverSocket = createAsyncSocket(host.address.getDomain(), + SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if serverSocket == asyncInvalidSocket: + raiseOsError(osLastError()) + else: + if not setSocketBlocking(SocketHandle(sock), false): + raiseOsError(osLastError()) + register(sock) + serverSocket = sock + + ## TODO: Set socket options here + if ServerFlags.ReuseAddr in flags: + if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(serverSocket) + raiseOsError(err) + + toSockAddr(host.address, host.port, saddr, slen) + if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), + slen) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(serverSocket) + raiseOsError(err) + + if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: + let err = osLastError() + if sock == asyncInvalidSocket: + closeAsyncSocket(serverSocket) + raiseOsError(err) + + result = StreamServer() + result.sock = serverSocket + result.function = cbproc + result.bufferSize = bufferSize + result.status = Starting + result.actEvent = newAsyncEvent() + result.udata = udata + result.local = host + result.loopFuture = serverLoop(result) + +proc write*(transp: StreamTransport, pbytes: pointer, + nbytes: int): Future[int] {.async.} = + checkClosed(transp) + var waitFuture = newFuture[void]("transport.write") + var vector: StreamVector + vector.initBufferStreamVector(pbytes, nbytes, waitFuture) + transp.queue.addLast(vector) + if WritePaused in transp.state: + transp.resumeWrite() + await vector.writer + if WriteError in transp.state: + raise transp.getError() + result = nbytes + +# proc writeFile*(transp: StreamTransport, 
handle: int, +# offset: uint = 0, +# size: uint = 0): Future[void] {.async.} = +# if transp.kind != TransportKind.Socket: +# raise newException(TransportError, "You can transmit files only to sockets") +# checkClosed(transp) +# var waitFuture = newFuture[void]("transport.writeFile") +# var vector: StreamVector +# vector.initTransmitStreamVector(handle, offset, size, waitFuture) +# transp.queue.addLast(vector) + +# if WritePaused in transp.state: +# transp.resumeWrite() +# await vector.writer +# if WriteError in transp.state: +# raise transp.getError() + +proc readExactly*(transp: StreamTransport, pbytes: pointer, + nbytes: int): Future[int] {.async.} = + ## Read exactly ``nbytes`` bytes from transport ``transp``. + checkClosed(transp) + checkPending(transp) + var index = 0 + while true: + if transp.offset == 0: + if (ReadError in transp.state): + raise transp.getError() + if (ReadEof in transp.state) or (ReadClosed in transp.state): + raise newException(TransportIncompleteError, "Data incomplete!") + + if transp.offset >= (nbytes - index): + copyMem(cast[pointer](cast[uint](pbytes) + uint(index)), + addr(transp.buffer[0]), nbytes - index) + transp.shiftBuffer(nbytes - index) + result = nbytes + break + else: + copyMem(cast[pointer](cast[uint](pbytes) + uint(index)), + addr(transp.buffer[0]), transp.offset) + index += transp.offset + transp.reader = newFuture[void]("transport.readExactly") + transp.offset = 0 + if ReadPaused in transp.state: + transp.resumeRead() + await transp.reader + # we are no longer need data + transp.reader = nil + +proc readOnce*(transp: StreamTransport, pbytes: pointer, + nbytes: int): Future[int] {.async.} = + ## Perform one read operation on transport ``transp``. 
+ checkClosed(transp) + checkPending(transp) + while true: + if transp.offset == 0: + if (ReadError in transp.state): + raise transp.getError() + if (ReadEof in transp.state) or (ReadClosed in transp.state): + result = 0 + break + transp.reader = newFuture[void]("transport.readOnce") + if ReadPaused in transp.state: + transp.resumeRead() + await transp.reader + transp.reader = nil + else: + if transp.offset > nbytes: + copyMem(pbytes, addr(transp.buffer[0]), nbytes) + transp.shiftBuffer(nbytes) + result = nbytes + else: + copyMem(pbytes, addr(transp.buffer[0]), transp.offset) + result = transp.offset + break + +proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int, + sep: seq[byte]): Future[int] {.async.} = + checkClosed(transp) + checkPending(transp) + + var dest = cast[ptr UncheckedArray[byte]](pbytes) + var state = 0 + var k = 0 + var index = 0 + + while true: + if (transp.offset - index) == 0: + if ReadError in transp.state: + transp.shiftBuffer(index) + raise transp.getError() + if (ReadEof in transp.state) or (ReadClosed in transp.state): + transp.shiftBuffer(index) + raise newException(TransportIncompleteError, "Data incomplete!") + + index = 0 + while index < transp.offset: + let ch = transp.buffer[index] + if sep[state] == ch: + inc(state) + else: + state = 0 + if k < nbytes: + dest[k] = ch + inc(k) + else: + raise newException(TransportLimitError, "Limit reached!") + + if state == len(sep): + transp.shiftBuffer(index + 1) + break + + inc(index) + + if state == len(sep): + result = k + break + else: + if (transp.offset - index) == 0: + transp.reader = newFuture[void]("transport.readUntil") + if ReadPaused in transp.state: + transp.resumeRead() + await transp.reader + + # we are no longer need data + transp.reader = nil + +proc readLine*(transp: StreamTransport, limit = 0, + sep = "\r\n"): Future[string] {.async.} = + checkClosed(transp) + checkPending(transp) + + result = "" + var lim = if limit <= 0: -1 else: limit + var state = 0 + var 
index = 0 + + while true: + if (transp.offset - index) == 0: + if (ReadError in transp.state): + transp.shiftBuffer(index) + raise transp.getError() + if (ReadEof in transp.state) or (ReadClosed in transp.state): + transp.shiftBuffer(index) + break + + index = 0 + while index < transp.offset: + let ch = char(transp.buffer[index]) + if sep[state] == ch: + inc(state) + if state == len(sep): + transp.shiftBuffer(index + 1) + break + else: + state = 0 + result.add(ch) + if len(result) == lim: + transp.shiftBuffer(index + 1) + break + inc(index) + + if (state == len(sep)) or (lim == len(result)): + break + else: + if (transp.offset - index) == 0: + transp.reader = newFuture[void]("transport.readLine") + if ReadPaused in transp.state: + transp.resumeRead() + await transp.reader + + # we are no longer need data + transp.reader = nil + +proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = + checkClosed(transp) + checkPending(transp) + + result = newSeq[byte]() + + while true: + if (ReadError in transp.state): + raise transp.getError() + if (ReadEof in transp.state) or (ReadClosed in transp.state): + break + + if transp.offset > 0: + let s = len(result) + let o = s + transp.offset + if n == -1: + # grabbing all incoming data, until EOF + result.setLen(o) + copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]), + transp.offset) + transp.offset = 0 + else: + if transp.offset >= (n - s): + # size of buffer data is more then we need, grabbing only part + let part = transp.offset - (n - s) + result.setLen(n) + copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]), + part) + transp.shiftBuffer(part) + break + else: + # there not enough data in buffer, grabbing all + result.setLen(o) + copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]), + transp.offset) + transp.offset = 0 + + transp.reader = newFuture[void]("transport.read") + if ReadPaused in transp.state: + transp.resumeRead() + await transp.reader + + # we are no longer need 
data + transp.reader = nil + +proc atEof*(transp: StreamTransport): bool {.inline.} = + ## Returns ``true`` if ``transp`` is at EOF. + result = (transp.offset == 0) and (ReadEof in transp.state) and + (ReadPaused in transp.state) + +proc join*(transp: StreamTransport) {.async.} = + ## Wait until ``transp`` will not be closed. + await transp.future + +proc close*(transp: StreamTransport) = + ## Closes and frees resources of transport ``transp``. + if ReadClosed notin transp.state and WriteClosed notin transp.state: + when defined(windows): + discard cancelIo(Handle(transp.fd)) + closeAsyncSocket(transp.fd) + transp.state.incl(WriteClosed) + transp.state.incl(ReadClosed) + transp.future.complete() diff --git a/tests/test1.nim b/tests/test1.nim new file mode 100644 index 00000000..df3c25bc --- /dev/null +++ b/tests/test1.nim @@ -0,0 +1,18 @@ +import asyncdispatch2 + +proc testProc() {.async.} = + for i in 1..1_000: + await sleepAsync(1000) + echo "Timeout event " & $i + +proc callbackProc(udata: pointer) {.gcsafe.} = + echo "Callback event" + callSoon(callbackProc) + +when isMainModule: + discard getGlobalDispatcher() + asyncCheck testProc() + callSoon(callbackProc) + for i in 1..100: + echo "Iteration " & $i + poll() diff --git a/tests/test2.nim b/tests/test2.nim new file mode 100644 index 00000000..6dcf8274 --- /dev/null +++ b/tests/test2.nim @@ -0,0 +1,13 @@ +import asyncdispatch2 + +proc task() {.async.} = + await sleepAsync(10) + +when isMainModule: + var counter = 0 + var f = task() + while not f.finished: + inc(counter) + poll() + +echo counter diff --git a/tests/test3.nim b/tests/test3.nim new file mode 100644 index 00000000..94c165fa --- /dev/null +++ b/tests/test3.nim @@ -0,0 +1,10 @@ +import asyncdispatch2 + +proc task() {.async.} = + await sleepAsync(1000) + +proc waitTask() {.async.} = + echo await withTimeout(task(), 100) + +when isMainModule: + waitFor waitTask() diff --git a/tests/test4.nim b/tests/test4.nim new file mode 100644 index 
00000000..f7a7d52e --- /dev/null +++ b/tests/test4.nim @@ -0,0 +1,11 @@ +import ../asyncdispatch2 + +proc task() {.async.} = + if true: + raise newException(ValueError, "Test Error") + +proc waitTask() {.async.} = + await task() + +when isMainModule: + waitFor waitTask() diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim new file mode 100644 index 00000000..a2ba6579 --- /dev/null +++ b/tests/testdatagram.nim @@ -0,0 +1,173 @@ +import strutils, net, unittest +import ../asyncdispatch2 + +const + TestsCount = 5000 + ClientsCount = 10 + MessagesCount = 50 + +proc client1(transp: DatagramTransport, pbytes: pointer, nbytes: int, + raddr: TransportAddress, udata: pointer): Future[void] {.async.} = + if not isNil(pbytes): + var data = newString(nbytes + 1) + copyMem(addr data[0], pbytes, nbytes) + data.setLen(nbytes) + if data.startsWith("REQUEST"): + var numstr = data[7..^1] + var num = parseInt(numstr) + var ans = "ANSWER" & $num + await transp.sendTo(addr ans[0], len(ans), raddr) + else: + var err = "ERROR" + await transp.sendTo(addr err[0], len(err), raddr) + else: + ## Read operation failed with error + var counterPtr = cast[ptr int](udata) + counterPtr[] = -1 + transp.close() + +proc client2(transp: DatagramTransport, pbytes: pointer, nbytes: int, + raddr: TransportAddress, udata: pointer): Future[void] {.async.} = + if not isNil(pbytes): + var data = newString(nbytes + 1) + copyMem(addr data[0], pbytes, nbytes) + data.setLen(nbytes) + if data.startsWith("ANSWER"): + var counterPtr = cast[ptr int](udata) + counterPtr[] = counterPtr[] + 1 + if counterPtr[] == TestsCount: + transp.close() + else: + var ta: TransportAddress + ta.address = parseIpAddress("127.0.0.1") + ta.port = Port(33336) + var req = "REQUEST" & $counterPtr[] + await transp.sendTo(addr req[0], len(req), ta) + else: + var counterPtr = cast[ptr int](udata) + counterPtr[] = -1 + transp.close() + else: + ## Read operation failed with error + var counterPtr = cast[ptr int](udata) + counterPtr[] = 
-1 + transp.close() + +proc client3(transp: DatagramTransport, pbytes: pointer, nbytes: int, + raddr: TransportAddress, udata: pointer): Future[void] {.async.} = + if not isNil(pbytes): + var data = newString(nbytes + 1) + copyMem(addr data[0], pbytes, nbytes) + data.setLen(nbytes) + if data.startsWith("ANSWER"): + var counterPtr = cast[ptr int](udata) + counterPtr[] = counterPtr[] + 1 + if counterPtr[] == TestsCount: + transp.close() + else: + var req = "REQUEST" & $counterPtr[] + await transp.send(addr req[0], len(req)) + else: + echo "ERROR" + var counterPtr = cast[ptr int](udata) + counterPtr[] = -1 + transp.close() + else: + ## Read operation failed with error + echo "ERROR" + var counterPtr = cast[ptr int](udata) + counterPtr[] = -1 + transp.close() + +proc client4(transp: DatagramTransport, pbytes: pointer, nbytes: int, + raddr: TransportAddress, udata: pointer): Future[void] {.async.} = + if not isNil(pbytes): + var data = newString(nbytes + 1) + copyMem(addr data[0], pbytes, nbytes) + data.setLen(nbytes) + if data.startsWith("ANSWER"): + var counterPtr = cast[ptr int](udata) + counterPtr[] = counterPtr[] + 1 + if counterPtr[] == MessagesCount: + transp.close() + else: + var req = "REQUEST" & $counterPtr[] + await transp.send(addr req[0], len(req)) + else: + echo "ERROR" + var counterPtr = cast[ptr int](udata) + counterPtr[] = -1 + transp.close() + else: + ## Read operation failed with error + echo "ERROR" + var counterPtr = cast[ptr int](udata) + counterPtr[] = -1 + transp.close() + +proc test1(): Future[int] {.async.} = + var ta: TransportAddress + var counter = 0 + ta.address = parseIpAddress("127.0.0.1") + ta.port = Port(33336) + var dgram1 = newDatagramTransport(client1, udata = addr counter, local = ta) + var dgram2 = newDatagramTransport(client2, udata = addr counter) + var data = "REQUEST0" + await dgram2.sendTo(addr data[0], len(data), ta) + await dgram2.join() + dgram1.close() + result = counter + +proc test2(): Future[int] {.async.} = + var ta: 
TransportAddress + var counter = 0 + ta.address = parseIpAddress("127.0.0.1") + ta.port = Port(33337) + var dgram1 = newDatagramTransport(client1, udata = addr counter, local = ta) + var dgram2 = newDatagramTransport(client3, udata = addr counter, remote = ta) + var data = "REQUEST0" + await dgram2.send(addr data[0], len(data)) + await dgram2.join() + dgram1.close() + result = counter + +proc waitAll(futs: seq[Future[void]]): Future[void] = + var counter = len(futs) + var retFuture = newFuture[void]("waitAll") + proc cb(udata: pointer) = + dec(counter) + if counter == 0: + retFuture.complete() + for fut in futs: + fut.addCallback(cb) + return retFuture + +proc test3(): Future[int] {.async.} = + var ta: TransportAddress + ta.address = parseIpAddress("127.0.0.1") + ta.port = Port(33337) + var counter = 0 + var dgram1 = newDatagramTransport(client1, udata = addr counter, local = ta) + var clients = newSeq[Future[void]](ClientsCount) + var counters = newSeq[int](ClientsCount) + for i in 0..