Compare commits

...

110 Commits

Author SHA1 Message Date
Eugene Kabanov
b55e2816eb
Allow cancelAndWait() to accept multiple Futures for cancellation. (#572)
* Allow cancelAndWait() to accept multiple Futures for cancellation.
Add tests.

* Add 2 more tests.

* Update test name.

* Fix sporadic test.
2025-03-21 14:46:18 +02:00
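For illustration, a minimal sketch of the varargs form added here; the surrounding helpers (`sleepAsync`, `cancelled()`) are standard chronos calls and only the multi-future `cancelAndWait()` comes from this PR:
```nim
import chronos

proc main() {.async.} =
  let
    fut1 = sleepAsync(10.minutes)
    fut2 = sleepAsync(10.minutes)
  # With #572, several futures can be cancelled and awaited in one call.
  await cancelAndWait(fut1, fut2)
  echo fut1.cancelled(), " ", fut2.cancelled()

waitFor main()
```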
Eugene Kabanov
0646c444fc
Bump nimble file. (#567) 2025-02-06 17:12:34 +02:00
Eugene Kabanov
36d8ee5617
Fix: waitpid() should wait for the finished process. (#566) 2025-01-30 13:28:25 +02:00
Eugene Kabanov
7c5cbf04a6
Fix: baseUri should provide the correct value for ANY_ADDRESS. (#563) 2025-01-13 18:31:36 +02:00
Miran
70cbe346e2
use the common CI workflow (#561) 2025-01-13 11:15:41 +01:00
Eugene Kabanov
03f4328de6
Fix possible issues with byte order in ipnet. (#562)
* Fix issues with byte order.

* Some polishing.
2024-11-28 10:10:12 +02:00
Etan Kissling
9186950e03
Replace apt-fast with apt-get (#558)
`apt-fast` was removed from GitHub with Ubuntu 24.04:

- https://github.com/actions/runner-images/issues/10003

For compatibility, switch back to `apt-get`.
2024-10-15 15:19:42 +00:00
Eugene Kabanov
c04576d829
Bump version to 4.0.3. (#555) 2024-08-22 02:53:48 +03:00
diegomrsantos
dc3847e4d6
add ubuntu 24 and gcc 14 (#553)
* add ubuntu 24 and gcc 14

* upgrade bearssl

* Fix nim-1-6 gcc-14 issue.

* rename target to linux-gcc-14

* Bump bearssl.

---------

Co-authored-by: cheatfate <eugene.kabanov@status.im>
2024-07-18 20:59:03 +03:00
c-blake
8f609b6c17
Fix tests to be string hash order independent (#551) 2024-07-09 11:42:20 +02:00
Miran
13d28a5b71
update ci.yml and be more explicit in .nimble (#549) 2024-07-03 12:57:58 +02:00
Jacek Sieka
4ad38079de
pretty-printer for Duration (#547) 2024-06-20 09:52:23 +02:00
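Assuming the pretty-printer is reachable through the usual `$` on `Duration`, a quick check might look like this sketch:
```nim
import chronos

# Durations compose with `+`; `$` renders a human-readable value.
echo $(2.seconds + 250.milliseconds)
```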
Giuliano Mega
7630f39471
Fixes compilation issues in v3 compatibility mode (-d:chronosHandleException) (#545)
* add missing calls to await

* add test run in v3 compatibility

* fix semantics for chronosHandleException so it does not override local raises/handleException annotations

* distinguish between explicit override and default setting; fix test

* re-enable wrongly disabled check

* make implementation simpler/clearer

* update docs

* reflow long line

* word swap
2024-06-10 10:18:42 +02:00
Jacek Sieka
c44406594f
fix results import 2024-06-07 12:05:15 +02:00
Eugene Kabanov
1b9d9253e8
Fix GCC-14 [-Wincompatible-pointer-types] issues. (#546)
* Fix class assignment.

* One more fix.

* Bump bearssl version.
2024-06-02 18:05:22 +03:00
Jacek Sieka
8a306763ce
docs for join and noCancel 2024-05-07 17:19:35 +02:00
Jacek Sieka
1ff81c60ea
avoid warning in noCancel with non-raising future (#540) 2024-05-06 08:56:48 +00:00
Jacek Sieka
52b02b9977
remove unnecessary impl overloads (#539) 2024-05-04 11:52:42 +02:00
Eugene Kabanov
72f560f049
Fix RangeError defect happening when using the Android toolchain. (#538)
* Fix RangeError defect happening when using the Android toolchain.

* Set proper type for `Tnfds`.

* Update comment.
2024-04-25 19:08:53 +03:00
Eugene Kabanov
bb96f02ae8
Fix wait(future) declaration signature. (#537) 2024-04-24 03:16:23 +03:00
Eugene Kabanov
0f0ed1d654
Add wait(deadline future) implementation. (#535)
* Add waitUntil(deadline) implementation.

* Add one more test.

* Fix rare race condition and tests for it.

* Rename waitUntil() to wait().
2024-04-20 03:49:07 +03:00
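A sketch of how the deadline-based overload might be used, assuming it mirrors the timeout-based `wait()` and fails with `AsyncTimeoutError` once the deadline future finishes first; `fetchValue` is an illustrative placeholder:
```nim
import chronos

proc fetchValue(): Future[int] {.async.} =
  await sleepAsync(100.milliseconds)
  return 42

proc main() {.async.} =
  # The deadline is itself a future, e.g. a sleep or any other event.
  let deadline = sleepAsync(1.seconds)
  let value = await fetchValue().wait(deadline)
  echo value

waitFor main()
```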
Eugene Kabanov
d184a92227
Fix rare cancellation race issue on timeout for wait/withTimeout. (#536)
Add tests.
2024-04-19 16:43:34 +03:00
Eugene Kabanov
7a3eaffa4f
Fix English spelling for readed variable. (#534) 2024-04-17 23:08:19 +00:00
Eugene Kabanov
bd7d84fbcb
Fix AsyncStreamReader constructor declaration mistypes. (#533) 2024-04-17 14:41:36 +00:00
Eugene Kabanov
e4cb48088c
Fix inability to change httpclient's internal buffer size. (#531)
Add test.
Address #529.
2024-04-17 17:27:14 +03:00
Eugene Kabanov
0d050d5823
Add automatic constructors for TCP and UDP transports. (#512)
* Add automatic constructors for TCP and UDP transports.

* Add port number argument.
Add some documentation comments.
Fix tests.

* Make datagram test use request/response scheme.

* Add helper.

* Fix issue with non-zero port setups.
Add test.

* Fix tests to probe ports.

* Attempt to fix MacOS issue.

* Add Opt[IpAddress].
Make IPv4 mapping to IPv6 space automatic.

* Add tests.

* Add stream capabilities.

* Fix Linux issues.

* Make getTransportFlags() available for all OSes.

* Fix one more compilation issue.

* Workaround weird compiler bug.

* Fix forgotten typed version of constructor.

* Make single source for addresses calculation.

* Add one more check into tests.

* Fix flags not being set in transport constructor.

* Fix post-rebase issues with flags not being set.

* Address review comments.
2024-04-13 03:04:42 +03:00
Eugene Kabanov
8e49df1400
Ensure that all buffers used inside HTTP client will follow original buffer size. (#530)
Ensure that buffer size cannot be lower than default size.
2024-04-07 07:03:12 +03:00
Eugene Kabanov
2d85229dce
Add join() operation to wait for future completion. (#525)
* Add `join()` operation to wait for future completion without cancelling it when `join()` got cancelled.

* Start using join() operation.
2024-04-04 00:30:01 +03:00
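For illustration, a minimal sketch of the behaviour described above; only `join()` comes from this PR, the worker proc is an assumption:
```nim
import chronos

proc worker() {.async.} =
  await sleepAsync(500.milliseconds)

proc main() {.async.} =
  let fut = worker()
  # join() only waits: cancelling the join() itself leaves `fut` running,
  # unlike awaiting `fut` directly.
  await fut.join()
  echo fut.completed()

waitFor main()
```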
Eugene Kabanov
402914f4cf
Add custom ring buffer into chronos streams and transports. (#485)
* Add custom ring buffer into chronos stream transport.

* Rename BipBuffer.decommit() to BipBuffer.consume()
Make asyncstreams use BipBuffer.

* Address review comments part 1.

* Address review comments part 2.

* Address review comments.

* Remove unused import results.

* Address review comments.
2024-03-26 22:33:19 +02:00
Jacek Sieka
ef1b077adf
v4.0.2 2024-03-25 10:38:37 +01:00
Jacek Sieka
b8b4e1fc47
make Raising compatible with 2.0 (#526)
* make `Raising` compatible with 2.0

See https://github.com/nim-lang/Nim/issues/23432

* Update tests/testfut.nim

* Update tests/testfut.nim
2024-03-25 10:37:42 +01:00
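For reference, a minimal sketch of the `Raising` helper with an explicit raises list, assuming the 2.0-compatible form keeps this shape:
```nim
import chronos

proc producer(): Future[int] {.async: (raises: [CancelledError]).} =
  return 1

# Raising[] spells out the future's raises set in user code.
let fut: Future[int].Raising([CancelledError]) = producer()
echo waitFor(fut)
```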
Jacek Sieka
0e806d59ae
v4.0.1 2024-03-21 09:21:51 +01:00
Jacek Sieka
d5bc90fef2
Work around type resolution with empty generic (#522)
* Work around type resolution with empty generic

* workaround
2024-03-20 12:08:26 +01:00
Eugene Kabanov
035288f3f0
Remove sink and chronosMoveSink() usage. (#524) 2024-03-20 07:47:59 +01:00
Eugene Kabanov
d4f1487b0c
Disable libbacktrace enabled test on X86 platforms. (#523)
* Disable libbacktrace enabled test on X86 platforms.

* Fix mistype.

* Use macos-12 workers from now.
2024-03-19 16:28:52 +00:00
Jacek Sieka
47cc17719f
print warning when calling failed (#521)
`failed` cannot return true for futures that don't forward exceptions
2024-03-08 14:43:42 +01:00
Etan Kissling
17b7a76c7e
Ensure transp.reader is reset to nil on error (#508)
In `stream.readLoop`, a finished `Future` was left in `transp.reader`
if there was an error in `resumeRead`. Set it to `nil` as well.

Co-authored-by: Jacek Sieka <jacek@status.im>
2024-03-07 08:09:16 +01:00
Jacek Sieka
c5a5ece487
fix circular reference in timer (#510) 2024-03-07 08:07:53 +01:00
Jacek Sieka
03d82475d9
Avoid ValueError effect in varargs race/one (#520)
We can check at compile-time that at least one parameter is passed

* clean up closure environment explicitly in some callbacks to release
memory earlier
2024-03-06 06:42:22 +01:00
Eugene Kabanov
f6c7ecfa0a
Add missing parts of defaults buffer size increase. (#513) 2024-03-06 01:56:40 +02:00
Eugene Kabanov
4ed0cd6be7
Ensure that OwnCancelSchedule flag will not be removed from wait() and withTimeout(). (#519) 2024-03-05 17:34:53 +01:00
Eugene Kabanov
1eb834a2f9
Fix `or` deadlock issue. (#517)
* Fix `or` should not create future with OwnCancelSchedule flag set.

* Fix `CancelledError` missing from the raises list when both futures have empty raises lists.

* Fix macros tests.
2024-03-05 17:33:46 +01:00
Etan Kissling
5dfa3fd7fa
fix conversion error with or on futures with {.async: (raises: []).} (#515)
```nim
import chronos

proc f(): Future[void] {.async: (raises: []).} =
  discard

discard f() or f() or f()
```

```
/Users/etan/Documents/Repos/nimbus-eth2/vendor/nim-chronos/chronos/internal/raisesfutures.nim(145, 44) union
/Users/etan/Documents/Repos/nimbus-eth2/vendor/nimbus-build-system/vendor/Nim/lib/core/macros.nim(185, 28) []
/Users/etan/Documents/Repos/nimbus-eth2/test.nim(6, 13) template/generic instantiation of `or` from here
/Users/etan/Documents/Repos/nimbus-eth2/vendor/nim-chronos/chronos/internal/asyncfutures.nim(1668, 39) template/generic instantiation of `union` from here
/Users/etan/Documents/Repos/nimbus-eth2/vendor/nimbus-build-system/vendor/Nim/lib/core/macros.nim(185, 28) Error: illegal conversion from '-1' to '[0..9223372036854775807]'
```

Fix by checking for `void` before trying to access `raises`
2024-03-05 13:53:12 +01:00
Eugene Kabanov
7b02247ce7
Add --mm:refc to libbacktrace test. (#505)
* Add `--mm:refc` to `libbacktrace` test.

* Make tests with `refc` run before tests with the default memory manager.
2024-02-14 19:23:15 +02:00
Eugene Kabanov
2e37a6e26c
Increase AsyncStream and Transport default buffer size from 4096 to 16384 bytes. (#506)
Make buffer sizes configurable at compile time.
2024-02-14 19:23:01 +02:00
cheatfate
be4923be19
Strip debugging echo in threadsync tests. 2024-02-14 14:09:01 +02:00
Eugene Kabanov
a81961a3c6
Fix HTTP server accept() loop exiting under heavy load. (#502)
* Add more specific accept() exceptions raised.
Add some refactoring to HTTP server code.

* Refactor acceptLoop.

* Print GC statistics in every failing test.

* Try to disable failing tests.
2024-02-14 14:05:19 +02:00
Jacek Sieka
8cf2d69aaa
Minimal threading docs (#493)
* Minimal threading docs

* compile examples with threads

* links
2024-02-14 08:27:09 +01:00
Eugene Kabanov
08db79fe63
Disable memory hungry tests in 32bit tests. (#503)
* Disable memory hungry tests in 32bit tests.

* Limit threadsync tests for 32bit.
2024-02-14 00:03:12 +02:00
Jacek Sieka
672db137b7
v4.0.0 (#494)
Features:

* Exception effects / raises for async procedures helping you write more
efficient leak-free code
* Cross-thread notification mechanism suitable for building channels,
queues and other multithreaded primitives
* Async process I/O
* IPv6 dual stack support
* HTTP middleware support allowing multiple services to share a single
HTTP server
* A new [documentation web
site](https://status-im.github.io/nim-chronos/) covering the basics,
with several simple examples for getting started
* Implicit returns, support for `results.?` and other conveniences
* Rate limiter
* Revamped cancellation support with more control over the cancellation
process
* Efficiency improvements with `lent` and `sink`

See the [porting](https://status-im.github.io/nim-chronos/porting.html)
guides for porting code from earlier chronos releases (as well as
asyncdispatch)
2024-01-24 19:33:13 +02:00
Eugene Kabanov
09a0b11719
Make asyncproc use asyncraises. (#497)
* Make asyncproc use asyncraises.

* Fix missing asyncraises for waitForExit().
2024-01-23 08:34:10 +01:00
Jacek Sieka
e296ae30c8
asyncraises for threadsync (#495)
* asyncraises for threadsync

* missing bracket

* missing exception
2024-01-20 16:56:57 +01:00
Jacek Sieka
3ca2c5e6b5
deprecate callback=, UDP fixes (fixes #491) (#492)
Using the callback setter may lead to callbacks owned by others being
reset, which is unexpected.

* don't crash on zero-length UDP writes
2024-01-19 09:21:10 +01:00
Jacek Sieka
1021a7d294
check leaks after every test (#487) 2024-01-18 14:34:16 +02:00
cheatfate
92acf68b04
Fix examples documentation. 2024-01-12 15:39:45 +02:00
Eugene Kabanov
b02b9608c3
HTTP server middleware implementation. (#483)
* HTTP server middleware implementation and test.

* Address review comments.

* Address review comments.
2024-01-12 15:27:36 +02:00
Jacek Sieka
f0a2d4df61
Feature flag for raises support (#488)
Feature flags allow consumers of chronos to target versions with and
without certain features via compile-time selection. The first feature
flag added is for raise tracking support.
2024-01-08 14:54:50 +01:00
Jacek Sieka
e15dc3b41f
prevent http closeWait future from being cancelled (#486)
* simplify `closeWait` implementations
  * remove redundant cancellation callbacks
  * use `noCancel` to avoid forgetting the right future flags
* add a few missing raises trackers
* enforce `OwnCancelSchedule` on manually created futures that don't raise `CancelledError`
* ensure cancellations don't reach internal futures
2024-01-04 16:17:42 +01:00
Jacek Sieka
41f77d261e
Better line information on effect violation
We can capture the line info from the original future source and direct
violation errors there
2023-12-27 20:57:39 +01:00
Jacek Sieka
1598471ed2
add a test for results.? compatibility (#484)
Finally! (haha)
2023-12-21 15:52:16 +01:00
Eugene Kabanov
c41599a6d6
Asyncraises HTTP layer V3 (#482)
* No Critical and Recoverable errors anymore.

* Recover raiseHttpCriticalError()

* Post-rebase fixes.

* Remove deprecated ResponseFence and getResponseFence().

* HttpProcessCallback and HttpProcessCallback2.

* Fix callback holder.

* Fix test issue.

* Fix backwards compatibility of `HttpResponse.state` field.
2023-12-09 06:50:35 +02:00
Jacek Sieka
e38ceb5378
fix v3 backwards compatibility for callbacks (#481)
Because the callback types were used explicitly in some consumers of
chronos, the change of type introduces a backwards incompatibility
preventing a smooth transition to v4 for code that doesn't use
`raises`.

This PR restores backwards compatibility at the expense of introducing a
new type with a potentially ugly name - that said, there is already
precedence for using numbered names to provide new error handling
strategy in chronos.
2023-12-04 14:19:29 +01:00
Jacek Sieka
48b2b08cfb
Update docs (#480)
* new mdbook version with built-in Nim highlighting support
* describe examples in a dedicated page
* fixes
2023-12-01 12:33:28 +01:00
Eugene Kabanov
28a100b135
Fix processing callback missing asyncraises. (#479) 2023-11-28 18:57:13 +02:00
Eugene Kabanov
b18d471629
Asyncraises HTTP client/server. (#476)
* Fixes.

* Make httpcommon no-raises.

* Make httpbodyrw no-raises.

* Make multipart no-raises.

* Make httpdebug no-raises.

* Make httpagent no-raises.

* Make httpclient no-raises.

* Make httpserver/shttpserver no-raises.

* fix prepend/remove when E is noraises

---------

Co-authored-by: Jacek Sieka <jacek@status.im>
2023-11-21 11:01:44 +01:00
Jacek Sieka
fa0bf405e6
varargs overloads (#477)
* varargs overloads

for convenience and compatibility

* no parameterless varargs calls with generic overloads
2023-11-20 12:04:28 +02:00
Jacek Sieka
f03cdfcc40
futures: sinkify (#475)
This avoids copies here and there throughout the pipeline - ie
`copyString` and friends can often be avoided when moving things into
and out of futures

Annoyingly, one has to sprinkle the codebase liberally with `sink` and
`move` for the pipeline to work well - sink stuff _generally_ works
better in orc/arc

Looking at nim 1.6/refc, sink + local variable + move generates the best
code:

msg directly:
```nim
	T1_ = (*colonenv_).msg1; (*colonenv_).msg1 = copyStringRC1(msg);
```

local copy without move:
```nim
	T60_ = (*colonenv_).localCopy1; (*colonenv_).localCopy1 =
copyStringRC1(msg);
```

local copy with move:
```nim
	asgnRef((void**) (&(*colonenv_).localCopy1), msg);
```

Annoyingly, sink is also broken for refc+literals as it tries to
changes the refcount of the literal as part of the move (which shouldn't
be happening, but here we are), so we have to use a hack to find
literals and avoid moving them.
2023-11-19 18:29:09 +01:00
Eugene Kabanov
0b136b33c8
Asyncstreams asyncraises. (#472)
* Fix transport address functions so they do not raise so many exceptions.

* Add raising `Defect` functions to AsyncQueue.

* Add raises/asyncraises into async streams.

* Remove `Safe` primitives.
Make AsyncStreamError the ancestor of AsyncError.
Require AsyncStreamReader/Writer loops to not raise any exceptions.

* Remove `par` fields.

* Remove `par` fields from TLSStream.

* Attempt to lower memory usage.
2023-11-17 23:18:09 +01:00
Jacek Sieka
1306170255
dedicated exceptions for Future.read failures (#474)
Dedicated exceptions for `read` failures reduce the risk of mixing up
"user" exceptions with those of Future itself. The risk still exists, if
the user allows a chronos exception to bubble up explicitly.

Because `await` structurally guarantees that the Future is not `pending`
at the time of `read`, it does not raise this new exception.

* introduce `FuturePendingError` and `FutureCompletedError` when
`read`:ing a future of uncertain state
* fix `waitFor` / `read` to return `lent` values
* simplify code generation for `void`-returning async procs
* document `Raising` type helper
2023-11-17 13:45:17 +01:00
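A sketch of the new failure mode when `read()`:ing a still-pending future, assuming `FuturePendingError` is catchable as described in this PR:
```nim
import chronos

let fut = sleepAsync(1.seconds)
try:
  # The future has not completed yet, so read() fails with the new error.
  fut.read()
except FuturePendingError:
  echo "future is still pending"

waitFor fut
```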
Jacek Sieka
f5ff9e32ca
introduce asyncraises in transports/asyncsync (#470)
With these fixes, `transports`/`asyncsync` correctly propagate and document their raises information - generally, most transport functions (send etc) raise `TransportError` and `CancelledError` - `closeWait` is special in that it generally doesn't fail.

This PR introduces the syntax `Future[void].Raises([types])` to create the `InternalRaisesFuture` type with the correct encoding for the types - this allows it to be used in user code while retaining the possibility to change the internal representation down the line.

* introduce raising constraints on stream callbacks - these constraints now give a warning when called with a callback that can raise exceptions (raising callbacks would crash)
* fix fail and its tests, which wasn't always given a good generic match
* work around nim bugs related to macro expansion of generic types
* make sure transports raise only `TransportError`-derived exceptions (and `CancelledError`)
2023-11-15 09:38:48 +01:00
Jacek Sieka
24be151cf3
move docs to docs (#466)
* introduce user guide based on `mdbook`
* set up structure for adding simple `chronos` usage examples
* move most readme content to book
* ci deploys book and api guide automatically
* remove most of existing engine docs (obsolete)
2023-11-15 09:06:37 +01:00
Eugene Kabanov
9c93ab48de
Attempt to fix CI crash at Windows. (#465)
* Attempt to fix CI crash at Windows.
Remove all cast[string] and cast[seq[byte]] from the codebase.

* Address review comments.
2023-11-13 13:14:21 +02:00
Jacek Sieka
0d55475c29
stew/results -> results (#468) 2023-11-13 10:56:19 +01:00
Jacek Sieka
f0eb7a0ae9
simplify tests (#469)
* simplify tests

`chronosPreviewV4` is obsolete

* oops
2023-11-13 10:54:37 +01:00
Eugene Kabanov
8156e2997a
Fix not enough memory on i386. (#467)
* Fix waitFor() so it does not exit before the last callback is scheduled.

* Tune tests to use less memory.

* Fix `testutils`. There is no more last poll() needed.

* Update chronos/internal/asyncfutures.nim

---------

Co-authored-by: Jacek Sieka <jacek@status.im>
2023-11-10 07:42:36 +01:00
Eugene Kabanov
9896316599
Remove deprecated AsyncEventBus. (#461)
* Remove deprecated AsyncEventBus.
Change number of tests for ThreadSignal.

* Recover 1000 tests count.
2023-11-09 18:01:43 +02:00
Jacek Sieka
9759f01016
doc generation fixes (#464)
* doc generation fixes

* fix
2023-11-08 21:20:24 +01:00
Jacek Sieka
c252ce68d8
verbose test output on actions rerun (#462) 2023-11-08 16:15:11 +01:00
Jacek Sieka
53690f4717
run tests outside of nim compilation (#463)
else we need memory for both compiler and test
2023-11-08 16:14:33 +01:00
Jacek Sieka
5ebd771d35
per-function Exception handling (#457)
This PR replaces the global strict exception mode with an option to
handle `Exception` per function while at the same time enabling strict
exception checking globally by default as has been planned for v4.

`handleException` mode raises `AsyncExceptionError` to distinguish it
from `ValueError` which may originate from user code.

* remove obsolete 1.2 config options
2023-11-08 15:12:32 +01:00
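A sketch of the per-function opt-in described above; treating `handleException: true` as the async parameter spelling and `AsyncExceptionError` as an exported type are assumptions based on this message:
```nim
import chronos

# With handleException, a plain Exception escaping the body is converted
# into AsyncExceptionError instead of being rejected at compile time.
proc risky() {.async: (handleException: true).} =
  raise newException(Exception, "not derived from CatchableError")

proc main() {.async.} =
  try:
    await risky()
  except AsyncExceptionError as exc:
    echo "converted: ", exc.msg

waitFor main()
```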
Jacek Sieka
cd6369c048
asyncraises -> async: (raises: ..., raw: ...) (#455)
Per discussion in
https://github.com/status-im/nim-chronos/pull/251#issuecomment-1559233139,
`async: (parameters..)` is introduced as a way to customize the async
transformation instead of relying on separate keywords (like
asyncraises).

Two parameters are available as of now:

`raises`: controls the exception effect tracking
`raw`: disables body transformation

Parameters are added to `async` as a tuple allowing more params to be
added easily in the future:
```nim
proc f() {.async: (name: value, ...).}
```
2023-11-07 12:12:59 +02:00
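For illustration, a concrete instance of the parameter tuple (a sketch, not code from the PR):
```nim
import chronos

# `raises` constrains the async proc's exception effect; CancelledError is
# typically kept because awaits may be cancelled.
proc ping(): Future[void] {.async: (raises: [CancelledError]).} =
  await sleepAsync(10.milliseconds)

waitFor ping()
```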
Eugene Kabanov
be2edab3ac
Consider ERROR_NETNAME_DELETED as ConnectionAbortedError. (#460) 2023-10-31 03:43:58 +02:00
Eugene Kabanov
a70b145964
IPv4/IPv6 dualstack (#456)
* Initial commit.

* Fix tests.

* Fix linux compilation issue.

* Add getDomain() implementation.
Add getDomain() tests.
Add datagram tests.

* Fix style errors.

* Deprecate NetFlag.
Deprecate new flags in ServerFlags.
Add isAvailable().
Fix setDualstack() to ignore errors on `Auto`.
Update tests.

* Deprecate some old procedures.
Improve datagram transport a bit.

* Address review comments, and fix tests.

* Fix setDescriptorBlocking() issue.
Recover connect() dualstack behavior.
Add test for connect() IPv6-[IPv4 mapped] addresses.

* Fix alignment code issue.
Fix TcpNoDelay was not available on Windows.

* Add dualstack support to HTTP/HTTPS client/server.
2023-10-30 15:27:50 +02:00
Eugene Kabanov
8375770fe5
Fix unreachable code places. (#459)
* Fix unreachable code.

* Use implicit returns instead.
2023-10-30 15:27:25 +02:00
Tanguy
12dc36cfee
Update README regarding cancellation (#450)
* Update README regarding cancellation

* Apply suggestions from code review

Co-authored-by: Eugene Kabanov <eugene.kabanov@status.im>

---------

Co-authored-by: Jacek Sieka <jacek@status.im>
Co-authored-by: Eugene Kabanov <eugene.kabanov@status.im>
2023-10-25 15:16:10 +02:00
Jacek Sieka
f56d286687
introduce asyncraises to core future utilities (#454)
* introduce `asyncraises` to core future utilities

Similar to the introduction of `raises` into a codebase, `asyncraises`
needs to be introduced gradually across all functionality before
deriving benefit.

This is a first introduction along with utilities to manage raises lists
and transform them at compile time.

Several scenarios ensue:

* for trivial cases, adding `asyncraises` is enough and the framework
deduces the rest
* some functions "add" new asyncraises (similar to what `raise` does in
"normal" code) - for example `wait` may raise all exceptions of the
future passed to it and additionally a few of its own - this requires
extending the raises list
* some functions "remove" raises (similar to what `try/except` does) such
as `noCancel`, which blocks cancellations and therefore reduces the raising
set

Both of the above cases are currently handled by a macro, but depending
on the situation lead to code organisation issues around return types
and pragma limitations - in particular, to keep `asyncraises`
backwards-compatibility, some code needs to exist in two versions which
somewhat complicates the implementation.

* add `asyncraises` versions for several `asyncfutures` utilities
* when assigning exceptions to a `Future` via `fail`, check at compile
time if possible and at runtime if not that the exception matches
constraints
* fix `waitFor` comments
* move async raises to separate module, implement `or`
2023-10-24 16:21:07 +02:00
Jacek Sieka
e3c5a86a14
Introduce chronos/internals, move some code (#453)
* Introduce chronos/internals, move some code

This PR breaks the include dependencies between `asyncfutures2` and
`asyncmacros2` by moving the dispatcher and some other code to a new
module.

This step makes it easier to implement `asyncraises` support for future
utilities like `allFutures` etc avoiding the need to play tricks with
include order etc.

Future PR:s may further articulate the difference between "internal"
stuff subject to API breakage and regular public API intended for end
users (rather than advanced integrators).

* names

* windows fix
2023-10-17 20:25:25 +02:00
Jacek Sieka
be9eef7a09
move test data to c file (#448)
* move test data to c file

allows compiling with nlvm

* more nlvm compat
2023-10-17 14:19:20 +02:00
Tanguy
a759c11ce4
Raise tracking (#251)
* Exception tracking v2

* some fixes

* Nim 1.2 compat

* simpler things

* Fixes for libp2p

* Fixes for strictException

* better await exception check

* Fix for template async proc

* make async work with procTy

* FuturEx is now a ref object type

* add tests

* update test

* update readme

* Switch to asyncraises pragma

* Address tests review comments

* Rename FuturEx to RaiseTrackingFuture

* Fix typo

* Split asyncraises into async, asyncraises

* Add -d:chronosWarnMissingRaises

* Add comment to RaiseTrackingFuture

* Allow standalone asyncraises

* CheckedFuture.fail type checking

* First cleanup

* Remove useless line

* Review comments

* nimble: Remove #head from unittest2

* Remove implict raises: CancelledError

* Move checkFutureExceptions to asyncfutures2

* Small refacto

* small cleanup

* Complete in closure finally

* cleanup tests, add comment

* bump

* chronos is not compatible with nim 1.2 anymore

* re-add readme modifications

* fix special exception handlers

* also propagate exception type in `read`

* `RaiseTrackingFuture` -> `InternalRaisesFuture`

Use internal naming scheme for RTF (this type should only be accessed
via asyncraises)

* use `internalError` for error reading

* oops

* 2.0 workarounds

* again

* remove try/finally for non-raising functions

* Revert "remove try/finally for non-raising functions"

This reverts commit 86bfeb5c972ef379a3bd34e4a16cd158a7455721.

`finally` is needed if code returns early :/

* fixes

* avoid exposing `newInternalRaisesFuture` in manual macro code
* avoid unnecessary codegen for `Future[void]`
* avoid redundant block around async proc body
* simplify body generation for forward declarations with comment but no
body
* avoid duplicate `gcsafe` annotations
* line info for return at end of async proc

* expand tests

* fix comments, add defer test

---------

Co-authored-by: Jacek Sieka <jacek@status.im>
2023-10-17 14:18:14 +02:00
Tanguy
253bc3cfc0
Complete futures in closure finally (fix #415) (#449)
* Complete in closure finally

* cleanup tests, add comment

* handle defects

* don't complete future on defect

* complete future in test to avoid failure

* fix with strict exceptions

* fix regressions

* fix nim 1.6
2023-10-16 10:38:11 +02:00
Eugene Kabanov
2e8551b0d9
Cancellation fixes and tests. (#445)
* Add callTick and stream cancellation tests.

* Fix stepsAsync() test.

* Cancellation changes.

* Update and add more cancellation tests.

* Fix Posix shutdown call to handle ENOTCONN error.

* With the new changes to cancellation it's now possible.

* Refactor testsoon.nim to not produce artifacts after tests are finished.

* Debugging MacOS issue.

* Adjust flaky test times.

* Fix issue.

* Add test for issue #334 which was also addressed in this PR.
Avoid `break` in problematic test.

* Add noCancelWait() call which prohibits cancellation.
Fix closeWait() calls to use noCancelWait() predicate.
Adding sleep to flaky MacOS test.

* Remove all debugging echoes.

* Fix cancelAndWait() which now could perform multiple attempts to cancel target Future (mustCancel behavior).

* Fix issues revealed by switch to different cancelAndWait().

* Address review comments.

* Fix testutils compilation warning.

* Rename callTick() to internalCallTick().

* Add some documentation comments.

* Disable flaky ratelimit test.

* Rename noCancelWait() to noCancel().
Address review comments.
2023-09-15 19:38:39 +03:00
Eugene Kabanov
00614476c6
Address issue #443. (#447)
* Address issue #443.

* Address review comments.
2023-09-07 16:25:25 +03:00
cheatfate
db6410f835
Fix CI badge status. 2023-09-05 13:48:09 +03:00
Jacek Sieka
e706167a53
add connect cancellation test (#444) 2023-09-05 13:41:52 +03:00
Eugene Kabanov
300fbaaf09
HttpAddress errors should not only be critical. (#446)
* Distinguish between resolve errors and check errors.

* Fix issues and add test for getHttpAddress() procedure.

* Address review comments.
2023-09-04 21:49:45 +03:00
Eugene Kabanov
60e6fc55bf
Fix #431. (#441) 2023-08-11 00:31:47 +03:00
Jacek Sieka
a7f708bea8
futures: lentify (#413)
sometimes avoid copies when reading from `Future`
2023-08-09 17:27:17 +03:00
Eugene Kabanov
6c2ea67512
Unroll defers and remove breaks. (#440)
* Unpack `finally/defer` blocks and introduce explicit cleaning of objects.
Add request query to debug information.

* Unroll one more loop to avoid `break`.
Add test for query debug string.

* Fix cancellation behavior.

* Address review comments.
2023-08-09 10:57:49 +03:00
diegomrsantos
466241aa95
Remove reuseaddr (#438)
* Remove hard-coded ports when non-windows

* Remove ReuseAddr from test
2023-08-08 03:11:35 +03:00
diegomrsantos
194226a0e0
Remove hard-coded ports when non-windows (#437) 2023-08-08 03:10:28 +03:00
andri lim
c4b066a2c4
ci: upgrade github actions/cache to v3 (#434) 2023-08-04 14:32:12 +07:00
andri lim
38c31e21d3
fix type mismatch error in asyncstream join (#433) 2023-08-04 09:27:01 +02:00
Jacek Sieka
a1eb30360b
fix invalid protocol casts (#430) 2023-08-04 08:08:34 +02:00
diegomrsantos
c546a4329c
Use random ports (#429) 2023-08-02 21:04:30 +02:00
Eugene Kabanov
6b4f5a1d23
Recover poll engine and add tests. (#421)
* Initial commit.

* Fix one more place with deprecated constant.

* Fix testall and nimble file.

* Fix poll issue.

* Workaround Nim's faulty declaration of `poll()` and types on MacOS.

* Fix syntax errors.

* Fix MacOS post-rebase issue.

* Add more conditionals.

* Address review comments.

* Fix Nim 1.2 configuration defaults.
2023-08-01 12:56:08 +03:00
rockcavera
5c39bf47be
fixing unfreed memory leak with freeAddrInfo() (#425)
* fixing unfreed memory leak with `freeAddrInfo()`

* `freeaddrinfo` to `freeAddrInfo()`
2023-08-01 01:28:34 +03:00
Eugene Kabanov
d214bcfb4f
Increase backlog defaults to maximum possible values. (#428) 2023-07-31 22:40:00 +03:00
Eugene Kabanov
926956bcbe
Add time used to establish HTTP client connection. (#427) 2023-07-30 12:43:25 +03:00
Eugene Kabanov
53e9f75735
Add some helpers for asyncproc. (#424)
* Initial commit.

* Adjust posix tests.

* Fix compilation issue.

* Attempt to fix flaky addProcess() test.
2023-07-28 11:54:53 +03:00
Eugene Kabanov
f91ac169dc
Fix NoVerifyServerName not actually disabling the SNI extension. (#423)
HTTP client SSL/TLS error information is now part of the connection error exception.
2023-07-23 19:40:57 +03:00
96 changed files with 15055 additions and 8016 deletions


@ -1,40 +0,0 @@
version: '{build}'
image: Visual Studio 2015
cache:
- NimBinaries
matrix:
# We always want 32 and 64-bit compilation
fast_finish: false
platform:
- x86
- x64
# when multiple CI builds are queued, the tested commit needs to be in the last X commits cloned with "--depth X"
clone_depth: 10
install:
# use the newest versions documented here: https://www.appveyor.com/docs/windows-images-software/#mingw-msys-cygwin
- IF "%PLATFORM%" == "x86" SET PATH=C:\mingw-w64\i686-6.3.0-posix-dwarf-rt_v5-rev1\mingw32\bin;%PATH%
- IF "%PLATFORM%" == "x64" SET PATH=C:\mingw-w64\x86_64-8.1.0-posix-seh-rt_v6-rev0\mingw64\bin;%PATH%
# build nim from our own branch - this to avoid the day-to-day churn and
# regressions of the fast-paced Nim development while maintaining the
# flexibility to apply patches
- curl -O -L -s -S https://raw.githubusercontent.com/status-im/nimbus-build-system/master/scripts/build_nim.sh
- env MAKE="mingw32-make -j2" ARCH_OVERRIDE=%PLATFORM% bash build_nim.sh Nim csources dist/nimble NimBinaries
- SET PATH=%CD%\Nim\bin;%PATH%
build_script:
- cd C:\projects\%APPVEYOR_PROJECT_SLUG%
- nimble install -y --depsOnly
- nimble install -y libbacktrace
test_script:
- nimble test
deploy: off


@ -6,156 +6,12 @@ on:
pull_request:
workflow_dispatch:
concurrency: # Cancel stale PR builds (but not push builds)
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
jobs:
build:
uses: status-im/nimbus-common-workflow/.github/workflows/common.yml@main
with:
test-command: |
nimble install -y libbacktrace
nimble test
nimble test_libbacktrace
nimble examples
strategy:
fail-fast: false
matrix:
target:
- os: linux
cpu: amd64
- os: linux
cpu: i386
- os: macos
cpu: amd64
- os: windows
cpu: amd64
#- os: windows
#cpu: i386
branch: [version-1-6, version-2-0, devel]
include:
- target:
os: linux
builder: ubuntu-20.04
shell: bash
- target:
os: macos
builder: macos-11
shell: bash
- target:
os: windows
builder: windows-2019
shell: msys2 {0}
defaults:
run:
shell: ${{ matrix.shell }}
name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (Nim ${{ matrix.branch }})'
runs-on: ${{ matrix.builder }}
continue-on-error: ${{ matrix.branch == 'devel' }}
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install build dependencies (Linux i386)
if: runner.os == 'Linux' && matrix.target.cpu == 'i386'
run: |
sudo dpkg --add-architecture i386
sudo apt-fast update -qq
sudo DEBIAN_FRONTEND='noninteractive' apt-fast install \
--no-install-recommends -yq gcc-multilib g++-multilib \
libssl-dev:i386
mkdir -p external/bin
cat << EOF > external/bin/gcc
#!/bin/bash
exec $(which gcc) -m32 "\$@"
EOF
cat << EOF > external/bin/g++
#!/bin/bash
exec $(which g++) -m32 "\$@"
EOF
chmod 755 external/bin/gcc external/bin/g++
echo '${{ github.workspace }}/external/bin' >> $GITHUB_PATH
- name: MSYS2 (Windows i386)
if: runner.os == 'Windows' && matrix.target.cpu == 'i386'
uses: msys2/setup-msys2@v2
with:
path-type: inherit
msystem: MINGW32
install: >-
base-devel
git
mingw-w64-i686-toolchain
- name: MSYS2 (Windows amd64)
if: runner.os == 'Windows' && matrix.target.cpu == 'amd64'
uses: msys2/setup-msys2@v2
with:
path-type: inherit
install: >-
base-devel
git
mingw-w64-x86_64-toolchain
- name: Restore Nim DLLs dependencies (Windows) from cache
if: runner.os == 'Windows'
id: windows-dlls-cache
uses: actions/cache@v2
with:
path: external/dlls-${{ matrix.target.cpu }}
key: 'dlls-${{ matrix.target.cpu }}'
- name: Install DLLs dependencies (Windows)
if: >
steps.windows-dlls-cache.outputs.cache-hit != 'true' &&
runner.os == 'Windows'
run: |
mkdir -p external
curl -L "https://nim-lang.org/download/windeps.zip" -o external/windeps.zip
7z x -y external/windeps.zip -oexternal/dlls-${{ matrix.target.cpu }}
- name: Path to cached dependencies (Windows)
if: >
runner.os == 'Windows'
run: |
echo "${{ github.workspace }}/external/dlls-${{ matrix.target.cpu }}" >> $GITHUB_PATH
- name: Derive environment variables
run: |
if [[ '${{ matrix.target.cpu }}' == 'amd64' ]]; then
PLATFORM=x64
else
PLATFORM=x86
fi
echo "PLATFORM=$PLATFORM" >> $GITHUB_ENV
ncpu=
MAKE_CMD="make"
case '${{ runner.os }}' in
'Linux')
ncpu=$(nproc)
;;
'macOS')
ncpu=$(sysctl -n hw.ncpu)
;;
'Windows')
ncpu=$NUMBER_OF_PROCESSORS
MAKE_CMD="mingw32-make"
;;
esac
[[ -z "$ncpu" || $ncpu -le 0 ]] && ncpu=1
echo "ncpu=$ncpu" >> $GITHUB_ENV
echo "MAKE_CMD=${MAKE_CMD}" >> $GITHUB_ENV
- name: Build Nim and Nimble
run: |
curl -O -L -s -S https://raw.githubusercontent.com/status-im/nimbus-build-system/master/scripts/build_nim.sh
env MAKE="${MAKE_CMD} -j${ncpu}" ARCH_OVERRIDE=${PLATFORM} NIM_COMMIT=${{ matrix.branch }} \
QUICK_AND_DIRTY_COMPILER=1 QUICK_AND_DIRTY_NIMBLE=1 CC=gcc \
bash build_nim.sh nim csources dist/nimble NimBinaries
echo '${{ github.workspace }}/nim/bin' >> $GITHUB_PATH
- name: Run tests
run: |
nim --version
nimble --version
nimble install -y --depsOnly
nimble install -y libbacktrace
nimble test
nimble test_libbacktrace


@ -15,48 +15,44 @@ jobs:
continue-on-error: true
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v4
with:
submodules: true
- uses: actions-rs/install@v0.1
with:
crate: mdbook
use-tool-cache: true
version: "0.4.36"
- uses: actions-rs/install@v0.1
with:
crate: mdbook-toc
use-tool-cache: true
version: "0.14.1"
- uses: actions-rs/install@v0.1
with:
crate: mdbook-open-on-gh
use-tool-cache: true
version: "2.4.1"
- uses: actions-rs/install@v0.1
with:
crate: mdbook-admonish
use-tool-cache: true
version: "1.14.0"
- uses: jiro4989/setup-nim-action@v1
with:
nim-version: '1.6.6'
nim-version: '1.6.20'
- name: Generate doc
run: |
nim --version
nimble --version
nimble install -dy
nimble docs || true
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./docs/book
force_orphan: true
# nim doc can "fail", but the doc is still generated
nim doc --git.url:https://github.com/status-im/nim-chronos --git.commit:master --outdir:docs --project chronos || true
# check that the folder exists
ls docs
- name: Clone the gh-pages branch
uses: actions/checkout@v2
with:
repository: status-im/nim-chronos
ref: gh-pages
path: subdoc
submodules: true
fetch-depth: 0
- name: Commit & push
run: |
cd subdoc
# Update / create this branch doc
rm -rf docs
mv ../docs .
# Remove .idx files
# NOTE: git also uses idx files in his
# internal folder, hence the `*` instead of `.`
find * -name "*.idx" -delete
git add .
git config --global user.email "${{ github.actor }}@users.noreply.github.com"
git config --global user.name = "${{ github.actor }}"
git commit -a -m "update docs"
git push origin gh-pages


@ -1,27 +0,0 @@
language: c
# https://docs.travis-ci.com/user/caching/
cache:
directories:
- NimBinaries
git:
# when multiple CI builds are queued, the tested commit needs to be in the last X commits cloned with "--depth X"
depth: 10
os:
- linux
- osx
install:
# build nim from our own branch - this to avoid the day-to-day churn and
# regressions of the fast-paced Nim development while maintaining the
# flexibility to apply patches
- curl -O -L -s -S https://raw.githubusercontent.com/status-im/nimbus-build-system/master/scripts/build_nim.sh
- env MAKE="make -j2" bash build_nim.sh Nim csources dist/nimble NimBinaries
- export PATH="$PWD/Nim/bin:$PATH"
script:
- nimble install -y
- nimble test

README.md (340 changed lines)

@ -1,6 +1,6 @@
# Chronos - An efficient library for asynchronous programming
[![Github action](https://github.com/status-im/nim-chronos/workflows/nim-chronos%20CI/badge.svg)](https://github.com/status-im/nim-chronos/actions/workflows/ci.yml)
[![Github action](https://github.com/status-im/nim-chronos/workflows/CI/badge.svg)](https://github.com/status-im/nim-chronos/actions/workflows/ci.yml)
[![License: Apache](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
![Stability: experimental](https://img.shields.io/badge/stability-experimental-orange.svg)
@ -9,16 +9,16 @@
Chronos is an efficient [async/await](https://en.wikipedia.org/wiki/Async/await) framework for Nim. Features include:
* Efficient dispatch pipeline for asynchronous execution
* HTTP server with SSL/TLS support out of the box (no OpenSSL needed)
* Cancellation support
* Synchronization primitivies like queues, events and locks
* FIFO processing order of dispatch queue
* Minimal exception effect support (see [exception effects](#exception-effects))
## Installation
You can use Nim's official package manager Nimble to install Chronos:
* Asynchronous socket and process I/O
* HTTP server with SSL/TLS support out of the box (no OpenSSL needed)
* Synchronization primitivies like queues, events and locks
* Cancellation
* Efficient dispatch pipeline with excellent multi-platform support
* Exceptional error handling features, including `raises` tracking
## Getting started
Install `chronos` using `nimble`:
```text
nimble install chronos
@ -30,6 +30,30 @@ or add a dependency to your `.nimble` file:
requires "chronos" requires "chronos"
``` ```
and start using it:
```nim
import chronos/apps/http/httpclient
proc retrievePage(uri: string): Future[string] {.async.} =
# Create a new HTTP session
let httpSession = HttpSessionRef.new()
try:
# Fetch page contents
let resp = await httpSession.fetch(parseUri(uri))
# Convert response to a string, assuming its encoding matches the terminal!
bytesToString(resp.data)
finally: # Close the session
await noCancel(httpSession.closeWait())
echo waitFor retrievePage(
"https://raw.githubusercontent.com/status-im/nim-chronos/master/README.md")
```
## Documentation
See the [user guide](https://status-im.github.io/nim-chronos/).
## Projects using `chronos`
* [libp2p](https://github.com/status-im/nim-libp2p) - Peer-to-Peer networking stack implemented in many languages * [libp2p](https://github.com/status-im/nim-libp2p) - Peer-to-Peer networking stack implemented in many languages
@ -42,305 +66,7 @@ requires "chronos"
Submit a PR to add yours!
## Documentation
### Concepts
Chronos implements the async/await paradigm in a self-contained library, using
macros, with no specific helpers from the compiler.
Our event loop is called a "dispatcher" and a single instance per thread is
created, as soon as one is needed.
To trigger a dispatcher's processing step, we need to call `poll()` - either
directly or through a wrapper like `runForever()` or `waitFor()`. This step
handles any file descriptors, timers and callbacks that are ready to be
processed.
`Future` objects encapsulate the result of an async procedure, upon successful
completion, and a list of callbacks to be scheduled after any type of
completion - be that success, failure or cancellation.
(These explicit callbacks are rarely used outside Chronos, being replaced by
implicit ones generated by async procedure execution and `await` chaining.)
Async procedures (those using the `{.async.}` pragma) return `Future` objects.
Inside an async procedure, you can `await` the future returned by another async
procedure. At this point, control will be handled to the event loop until that
future is completed.
Future completion is tested with `Future.finished()` and is defined as success,
failure or cancellation. This means that a future is either pending or completed.
To differentiate between completion states, we have `Future.failed()` and
`Future.cancelled()`.
### Dispatcher
You can run the "dispatcher" event loop forever, with `runForever()` which is defined as:
```nim
proc runForever*() =
while true:
poll()
```
You can also run it until a certain future is completed, with `waitFor()` which
will also call `Future.read()` on it:
```nim
proc p(): Future[int] {.async.} =
await sleepAsync(100.milliseconds)
return 1
echo waitFor p() # prints "1"
```
`waitFor()` is defined like this:
```nim
proc waitFor*[T](fut: Future[T]): T =
while not(fut.finished()):
poll()
return fut.read()
```
### Async procedures and methods
The `{.async.}` pragma will transform a procedure (or a method) returning a
specialised `Future` type into a closure iterator. If there is no return type
specified, a `Future[void]` is returned.
```nim
proc p() {.async.} =
await sleepAsync(100.milliseconds)
echo p().type # prints "Future[system.void]"
```
Whenever `await` is encountered inside an async procedure, control is passed
back to the dispatcher for as many steps as it's necessary for the awaited
future to complete successfully, fail or be cancelled. `await` calls the
equivalent of `Future.read()` on the completed future and returns the
encapsulated value.
```nim
proc p1() {.async.} =
await sleepAsync(1.seconds)
proc p2() {.async.} =
await sleepAsync(1.seconds)
proc p3() {.async.} =
let
fut1 = p1()
fut2 = p2()
# Just by executing the async procs, both resulting futures entered the
# dispatcher's queue and their "clocks" started ticking.
await fut1
await fut2
# Only one second passed while awaiting them both, not two.
waitFor p3()
```
Don't let `await`'s behaviour of giving back control to the dispatcher surprise
you. If an async procedure modifies global state, and you can't predict when it
will start executing, the only way to avoid that state changing underneath your
feet, in a certain section, is to not use `await` in it.
### Error handling
Exceptions inheriting from `CatchableError` are caught by hidden `try` blocks
and placed in the `Future.error` field, changing the future's status to
`Failed`.
When a future is awaited, that exception is re-raised, only to be caught again
by a hidden `try` block in the calling async procedure. That's how these
exceptions move up the async chain.
A failed future's callbacks will still be scheduled, but it's not possible to
resume execution from the point an exception was raised.
```nim
proc p1() {.async.} =
await sleepAsync(1.seconds)
raise newException(ValueError, "ValueError inherits from CatchableError")
proc p2() {.async.} =
await sleepAsync(1.seconds)
proc p3() {.async.} =
let
fut1 = p1()
fut2 = p2()
await fut1
echo "unreachable code here"
await fut2
# `waitFor()` would call `Future.read()` unconditionally, which would raise the
# exception in `Future.error`.
let fut3 = p3()
while not(fut3.finished()):
poll()
echo "fut3.state = ", fut3.state # "Failed"
if fut3.failed():
echo "p3() failed: ", fut3.error.name, ": ", fut3.error.msg
# prints "p3() failed: ValueError: ValueError inherits from CatchableError"
```
You can put the `await` in a `try` block, to deal with that exception sooner:
```nim
proc p3() {.async.} =
let
fut1 = p1()
fut2 = p2()
try:
await fut1
except CachableError:
echo "p1() failed: ", fut1.error.name, ": ", fut1.error.msg
echo "reachable code here"
await fut2
```
Chronos does not allow that future continuations and other callbacks raise
`CatchableError` - as such, calls to `poll` will never raise exceptions caused
originating from tasks on the dispatcher queue. It is however possible that
`Defect` that happen in tasks bubble up through `poll` as these are not caught
by the transformation.
### Platform independence
Several functions in `chronos` are backed by the operating system, such as
waiting for network events, creating files and sockets etc. The specific
exceptions that are raised by the OS is platform-dependent, thus such functions
are declared as raising `CatchableError` but will in general raise something
more specific. In particular, it's possible that some functions that are
annotated as raising `CatchableError` only raise on _some_ platforms - in order
to work on all platforms, calling code must assume that they will raise even
when they don't seem to do so on one platform.
### Exception effects
`chronos` currently offers minimal support for exception effects and `raises`
annotations. In general, during the `async` transformation, a generic
`except CatchableError` handler is added around the entire function being
transformed, in order to catch any exceptions and transfer them to the `Future`.
Because of this, the effect system thinks no exceptions are "leaking" because in
fact, exception _handling_ is deferred to when the future is being read.
Effectively, this means that while code can be compiled with
`{.push raises: [Defect]}`, the intended effect propagation and checking is
**disabled** for `async` functions.
To enable checking exception effects in `async` code, enable strict mode with
`-d:chronosStrictException`.
In the strict mode, `async` functions are checked such that they only raise
`CatchableError` and thus must make sure to explicitly specify exception
effects on forward declarations, callbacks and methods using
`{.raises: [CatchableError].}` (or more strict) annotations.
### Cancellation support
Any running `Future` can be cancelled. This can be used to launch multiple
futures, and wait for one of them to finish, and cancel the rest of them,
to add timeout, or to let the user cancel a running task.
```nim
# Simple cancellation
let future = sleepAsync(10.minutes)
future.cancel()
# Wait for cancellation
let future2 = sleepAsync(10.minutes)
await future2.cancelAndWait()
# Race between futures
proc retrievePage(uri: string): Future[string] {.async.} =
# requires to import uri, chronos/apps/http/httpclient, stew/byteutils
let httpSession = HttpSessionRef.new()
try:
resp = await httpSession.fetch(parseUri(uri))
result = string.fromBytes(resp.data)
finally:
# be sure to always close the session
await httpSession.closeWait()
let
futs =
@[
retrievePage("https://duckduckgo.com/?q=chronos"),
retrievePage("https://www.google.fr/search?q=chronos")
]
let finishedFut = await one(futs)
for fut in futs:
if not fut.finished:
fut.cancel()
echo "Result: ", await finishedFut
```
When an `await` is cancelled, it will raise a `CancelledError`:
```nim
proc c1 {.async.} =
echo "Before sleep"
try:
await sleepAsync(10.minutes)
echo "After sleep" # not reach due to cancellation
except CancelledError as exc:
echo "We got cancelled!"
raise exc
proc c2 {.async.} =
await c1()
echo "Never reached, since the CancelledError got re-raised"
let work = c2()
waitFor(work.cancelAndWait())
```
The `CancelledError` will now travel up the stack like any other exception.
It can be caught and handled (for instance, freeing some resources)
### Multiple async backend support
Thanks to its powerful macro support, Nim allows `async`/`await` to be
implemented in libraries with only minimal support from the language - as such,
multiple `async` libraries exist, including `chronos` and `asyncdispatch`, and
more may come to be developed in the futures.
Libraries built on top of `async`/`await` may wish to support multiple async
backends - the best way to do so is to create separate modules for each backend
that may be imported side-by-side - see [nim-metrics](https://github.com/status-im/nim-metrics/blob/master/metrics/)
for an example.
An alternative way is to select backend using a global compile flag - this
method makes it diffucult to compose applications that use both backends as may
happen with transitive dependencies, but may be appropriate in some cases -
libraries choosing this path should call the flag `asyncBackend`, allowing
applications to choose the backend with `-d:asyncBackend=<backend_name>`.
Known `async` backends include:
* `chronos` - this library (`-d:asyncBackend=chronos`)
* `asyncdispatch` the standard library `asyncdispatch` [module](https://nim-lang.org/docs/asyncdispatch.html) (`-d:asyncBackend=asyncdispatch`)
* `none` - ``-d:asyncBackend=none`` - disable ``async`` support completely
``none`` can be used when a library supports both a synchronous and
asynchronous API, to disable the latter.
### Compile-time configuration
`chronos` contains several compile-time [configuration options](./chronos/config.nim) enabling stricter compile-time checks and debugging helpers whose runtime cost may be significant.
Strictness options generally will become default in future chronos releases and allow adapting existing code without changing the new version - see the [`config.nim`](./chronos/config.nim) module for more information.
## TODO
* Pipe/Subprocess Transports.
* Multithreading Stream/Datagram servers
## Contributing
@ -349,10 +75,6 @@ When submitting pull requests, please add test cases for any new features or fix
`chronos` follows the [Status Nim Style Guide](https://status-im.github.io/nim-style-guide/).
## Other resources
* [Historical differences with asyncdispatch](https://github.com/status-im/nim-chronos/wiki/AsyncDispatch-comparison)
## License
Licensed and distributed under either of


@ -5,6 +5,10 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import chronos/[asyncloop, asyncsync, handles, transport, timer,
asyncproc, debugutils]
export asyncloop, asyncsync, handles, transport, timer, asyncproc, debugutils
## `async`/`await` framework for [Nim](https://nim-lang.org)
##
## See https://status-im.github.io/nim-chronos/ for documentation
import chronos/[asyncloop, asyncsync, handles, transport, timer, debugutils]
export asyncloop, asyncsync, handles, transport, timer, debugutils


@ -1,50 +1,85 @@
mode = ScriptMode.Verbose
packageName = "chronos"
version = "3.2.0"
version = "4.0.4"
author = "Status Research & Development GmbH"
description = "Networking framework with async/await support"
license = "MIT or Apache License 2.0"
skipDirs = @["tests"]
requires "nim >= 1.2.0",
requires "nim >= 1.6.16",
"results",
"stew",
"bearssl",
"bearssl >= 0.2.5",
"httputils",
"unittest2"
import os, strutils
let nimc = getEnv("NIMC", "nim") # Which nim compiler to use
let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js)
let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler
let verbose = getEnv("V", "") notin ["", "0"]
let platform = getEnv("PLATFORM", "")
let testArguments =
when defined(windows):
[
"-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert",
"-d:release",
]
else:
[
"-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert",
"-d:debug -d:chronosDebug -d:chronosEventEngine=poll -d:useSysAssert -d:useGcAssert",
"-d:release",
]
let styleCheckStyle = if (NimMajor, NimMinor) < (1, 6): "hint" else: "error"
let cfg =
" --styleCheck:usages --styleCheck:" & styleCheckStyle &
(if verbose: "" else: " --verbosity:0 --hints:off") &
" --skipParentCfg --skipUserCfg --outdir:build --nimcache:build/nimcache -f"
let cfg =
" --styleCheck:usages --styleCheck:error" &
(if verbose: "" else: " --verbosity:0 --hints:off") &
" --skipParentCfg --skipUserCfg --outdir:build " &
quoteShell("--nimcache:build/nimcache/$projectName")
proc build(args, path: string) =
exec nimc & " " & lang & " " & cfg & " " & flags & " " & args & " " & path
proc run(args, path: string) =
build args & " -r", path
build args, path
exec "build/" & path.splitPath[1]
task examples, "Build examples":
# Build book examples
for file in listFiles("docs/examples"):
if file.endsWith(".nim"):
build "--threads:on", file
task test, "Run all tests": task test, "Run all tests":
for args in [ for args in testArguments:
"-d:debug -d:chronosDebug", # First run tests with `refc` memory manager.
"-d:debug -d:chronosPreviewV4", run args & " --mm:refc", "tests/testall"
"-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert",
"-d:release",
"-d:release -d:chronosPreviewV4"]:
run args, "tests/testall"
if (NimMajor, NimMinor) > (1, 6): if (NimMajor, NimMinor) > (1, 6):
run args & " --mm:refc", "tests/testall" run args & " --mm:orc", "tests/testall"
task test_v3_compat, "Run all tests in v3 compatibility mode":
for args in testArguments:
if (NimMajor, NimMinor) > (1, 6):
# First run tests with `refc` memory manager.
run args & " --mm:refc -d:chronosHandleException", "tests/testall"
run args & " -d:chronosHandleException", "tests/testall"
task test_libbacktrace, "test with libbacktrace":
var allArgs = @[
"-d:release --debugger:native -d:chronosStackTrace -d:nimStackTraceOverride --import:libbacktrace",
]
for args in allArgs:
run args, "tests/testall"
task test_libbacktrace, "test with libbacktrace":
if platform != "x86":
let allArgs = @[
"-d:release --debugger:native -d:chronosStackTrace -d:nimStackTraceOverride --import:libbacktrace",
]
for args in allArgs:
# First run tests with `refc` memory manager.
run args & " --mm:refc", "tests/testall"
if (NimMajor, NimMinor) > (1, 6):
run args & " --mm:orc", "tests/testall"
task docs, "Generate API documentation":
exec "mdbook build docs"
exec nimc & " doc " & "--git.url:https://github.com/status-im/nim-chronos --git.commit:master --outdir:docs/book/api --project chronos"


@@ -6,6 +6,9 @@
 # Licensed under either of
 # Apache License, version 2.0, (LICENSE-APACHEv2)
 # MIT license (LICENSE-MIT)
+
+{.push raises: [].}
+
 import strutils

 const


@@ -6,6 +6,9 @@
 # Licensed under either of
 # Apache License, version 2.0, (LICENSE-APACHEv2)
 # MIT license (LICENSE-MIT)
+
+{.push raises: [].}
+
 import ../../asyncloop, ../../asyncsync
 import ../../streams/[asyncstream, boundstream]
 import httpcommon
@@ -36,17 +39,17 @@ proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader =
   trackCounter(HttpBodyReaderTrackerName)
   res

-proc closeWait*(bstream: HttpBodyReader) {.async.} =
+proc closeWait*(bstream: HttpBodyReader) {.async: (raises: []).} =
   ## Close and free resource allocated by body reader.
   if bstream.bstate == HttpState.Alive:
     bstream.bstate = HttpState.Closing
-    var res = newSeq[Future[void]]()
+    var res = newSeq[Future[void].Raising([])]()
     # We closing streams in reversed order because stream at position [0], uses
     # data from stream at position [1].
     for index in countdown((len(bstream.streams) - 1), 0):
       res.add(bstream.streams[index].closeWait())
-    await allFutures(res)
-    await procCall(closeWait(AsyncStreamReader(bstream)))
+    res.add(procCall(closeWait(AsyncStreamReader(bstream))))
+    await noCancel(allFutures(res))
     bstream.bstate = HttpState.Closed
   untrackCounter(HttpBodyReaderTrackerName)
@@ -61,19 +64,19 @@ proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter =
   trackCounter(HttpBodyWriterTrackerName)
   res

-proc closeWait*(bstream: HttpBodyWriter) {.async.} =
+proc closeWait*(bstream: HttpBodyWriter) {.async: (raises: []).} =
   ## Close and free all the resources allocated by body writer.
   if bstream.bstate == HttpState.Alive:
     bstream.bstate = HttpState.Closing
-    var res = newSeq[Future[void]]()
+    var res = newSeq[Future[void].Raising([])]()
     for index in countdown(len(bstream.streams) - 1, 0):
       res.add(bstream.streams[index].closeWait())
-    await allFutures(res)
+    await noCancel(allFutures(res))
     await procCall(closeWait(AsyncStreamWriter(bstream)))
     bstream.bstate = HttpState.Closed
   untrackCounter(HttpBodyWriterTrackerName)

-proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [].} =
+proc hasOverflow*(bstream: HttpBodyReader): bool =
   if len(bstream.streams) == 1:
     # If HttpBodyReader has only one stream it has ``BoundedStreamReader``, in
     # such case its impossible to get more bytes then expected amount.
@@ -89,6 +92,5 @@ proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [].} =
   else:
     false

-proc closed*(bstream: HttpBodyReader | HttpBodyWriter): bool {.
-    raises: [].} =
+proc closed*(bstream: HttpBodyReader | HttpBodyWriter): bool =
   bstream.bstate != HttpState.Alive
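The `closeWait` implementations above switch to collecting `Future[void].Raising([])` handles and draining them with `noCancel(allFutures(...))`, so cancellation can no longer interrupt resource cleanup. A minimal, hedged sketch of the same pattern outside the diff (it assumes a chronos version that provides `noCancel`; `sleepAsync` merely stands in for real `closeWait()` calls):

import chronos

proc cleanupAll() {.async.} =
  # Stand-ins for the closeWait() futures of several streams.
  var pending: seq[Future[void]]
  for _ in 0 ..< 3:
    pending.add(sleepAsync(10.milliseconds))
  # noCancel shields the cleanup itself from outside cancellation:
  # all collected futures complete even if cleanupAll() is cancelled.
  await noCancel(allFutures(pending))

when isMainModule:
  waitFor cleanupAll()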


@@ -6,14 +6,17 @@
 # Licensed under either of
 # Apache License, version 2.0, (LICENSE-APACHEv2)
 # MIT license (LICENSE-MIT)
+
+{.push raises: [].}
+
 import std/[uri, tables, sequtils]
-import stew/[results, base10, base64, byteutils], httputils
+import stew/[base10, base64, byteutils], httputils, results
 import ../../asyncloop, ../../asyncsync
 import ../../streams/[asyncstream, tlsstream, chunkstream, boundstream]
 import httptable, httpcommon, httpagent, httpbodyrw, multipart
 export results, asyncloop, asyncsync, asyncstream, tlsstream, chunkstream,
        boundstream, httptable, httpcommon, httpagent, httpbodyrw, multipart,
-       httputils
+       httputils, uri, results
 export SocketFlags

 const
@@ -108,6 +111,7 @@ type
     remoteHostname*: string
     flags*: set[HttpClientConnectionFlag]
     timestamp*: Moment
+    duration*: Duration

   HttpClientConnectionRef* = ref HttpClientConnection
@@ -119,12 +123,13 @@ type
     headersTimeout*: Duration
     idleTimeout: Duration
     idlePeriod: Duration
-    watcherFut: Future[void]
+    watcherFut: Future[void].Raising([])
     connectionBufferSize*: int
     maxConnections*: int
     connectionsCount*: int
     socketFlags*: set[SocketFlags]
     flags*: HttpClientFlags
+    dualstack*: DualStackType

   HttpAddress* = object
     id*: string
@@ -154,6 +159,7 @@ type
     redirectCount: int
     timestamp*: Moment
     duration*: Duration
+    headersBuffer: seq[byte]

   HttpClientRequestRef* = ref HttpClientRequest
@@ -194,6 +200,8 @@ type
     name*: string
     data*: string

+  HttpAddressResult* = Result[HttpAddress, HttpAddressErrorType]
+
 # HttpClientRequestRef valid states are:
 # Ready -> Open -> (Finished, Error) -> (Closing, Closed)
 #
@@ -233,6 +241,12 @@ template setDuration(
   reqresp.duration = timestamp - reqresp.timestamp
   reqresp.connection.setTimestamp(timestamp)

+template setDuration(conn: HttpClientConnectionRef): untyped =
+  if not(isNil(conn)):
+    let timestamp = Moment.now()
+    conn.duration = timestamp - conn.timestamp
+    conn.setTimestamp(timestamp)
+
 template isReady(conn: HttpClientConnectionRef): bool =
   (conn.state == HttpClientConnectionState.Ready) and
   (HttpClientConnectionFlag.KeepAlive in conn.flags) and
@@ -243,7 +257,7 @@ template isIdle(conn: HttpClientConnectionRef, timestamp: Moment,
                 timeout: Duration): bool =
   (timestamp - conn.timestamp) >= timeout

-proc sessionWatcher(session: HttpSessionRef) {.async.}
+proc sessionWatcher(session: HttpSessionRef) {.async: (raises: []).}

 proc new*(t: typedesc[HttpSessionRef],
           flags: HttpClientFlags = {},
@@ -254,8 +268,8 @@ proc new*(t: typedesc[HttpSessionRef],
           maxConnections = -1,
           idleTimeout = HttpConnectionIdleTimeout,
           idlePeriod = HttpConnectionCheckPeriod,
-          socketFlags: set[SocketFlags] = {}): HttpSessionRef {.
-     raises: [] .} =
+          socketFlags: set[SocketFlags] = {},
+          dualstack = DualStackType.Auto): HttpSessionRef =
   ## Create new HTTP session object.
   ##
   ## ``maxRedirections`` - maximum number of HTTP 3xx redirections
@@ -274,16 +288,17 @@
     idleTimeout: idleTimeout,
     idlePeriod: idlePeriod,
     connections: initTable[string, seq[HttpClientConnectionRef]](),
-    socketFlags: socketFlags
+    socketFlags: socketFlags,
+    dualstack: dualstack
   )
   res.watcherFut =
     if HttpClientFlag.Http11Pipeline in flags:
       sessionWatcher(res)
     else:
-      newFuture[void]("session.watcher.placeholder")
+      nil
   res

-proc getTLSFlags(flags: HttpClientFlags): set[TLSFlags] {.raises: [] .} =
+proc getTLSFlags(flags: HttpClientFlags): set[TLSFlags] =
   var res: set[TLSFlags]
   if HttpClientFlag.NoVerifyHost in flags:
     res.incl(TLSFlags.NoVerifyHost)
@@ -291,8 +306,90 @@ proc getTLSFlags(flags: HttpClientFlags): set[TLSFlags] {.raises: [] .} =
     res.incl(TLSFlags.NoVerifyServerName)
   res

-proc getAddress*(session: HttpSessionRef, url: Uri): HttpResult[HttpAddress] {.
-     raises: [] .} =
+proc getHttpAddress*(
+    url: Uri,
+    flags: HttpClientFlags = {}
+): HttpAddressResult =
+  let
+    scheme =
+      if len(url.scheme) == 0:
+        HttpClientScheme.NonSecure
+      else:
+        case toLowerAscii(url.scheme)
+        of "http":
+          HttpClientScheme.NonSecure
+        of "https":
+          HttpClientScheme.Secure
+        else:
+          return err(HttpAddressErrorType.InvalidUrlScheme)
+    port =
+      if len(url.port) == 0:
+        case scheme
+        of HttpClientScheme.NonSecure:
+          80'u16
+        of HttpClientScheme.Secure:
+          443'u16
+      else:
+        Base10.decode(uint16, url.port).valueOr:
+          return err(HttpAddressErrorType.InvalidPortNumber)
+    hostname =
+      block:
+        if len(url.hostname) == 0:
+          return err(HttpAddressErrorType.MissingHostname)
+        url.hostname
+    id = hostname & ":" & Base10.toString(port)
+    addresses =
+      if (HttpClientFlag.NoInet4Resolution in flags) and
+         (HttpClientFlag.NoInet6Resolution in flags):
+        # DNS resolution is disabled.
+        try:
+          @[initTAddress(hostname, Port(port))]
+        except TransportAddressError:
+          return err(HttpAddressErrorType.InvalidIpHostname)
+      else:
+        try:
+          if (HttpClientFlag.NoInet4Resolution notin flags) and
+             (HttpClientFlag.NoInet6Resolution notin flags):
+            # DNS resolution for both IPv4 and IPv6 addresses.
+            resolveTAddress(hostname, Port(port))
+          else:
+            if HttpClientFlag.NoInet6Resolution in flags:
+              # DNS resolution only for IPv4 addresses.
+              resolveTAddress(hostname, Port(port), AddressFamily.IPv4)
+            else:
+              # DNS resolution only for IPv6 addresses
+              resolveTAddress(hostname, Port(port), AddressFamily.IPv6)
+        except TransportAddressError:
+          return err(HttpAddressErrorType.NameLookupFailed)
+
+  if len(addresses) == 0:
+    return err(HttpAddressErrorType.NoAddressResolved)
+
+  ok(HttpAddress(id: id, scheme: scheme, hostname: hostname, port: port,
+                 path: url.path, query: url.query, anchor: url.anchor,
+                 username: url.username, password: url.password,
+                 addresses: addresses))
+
+proc getHttpAddress*(
+    url: string,
+    flags: HttpClientFlags = {}
+): HttpAddressResult =
+  getHttpAddress(parseUri(url), flags)
+
+proc getHttpAddress*(
+    session: HttpSessionRef,
+    url: Uri
+): HttpAddressResult =
+  getHttpAddress(url, session.flags)
+
+proc getHttpAddress*(
+    session: HttpSessionRef,
+    url: string
+): HttpAddressResult =
+  ## Create new HTTP address using URL string ``url`` and .
+  getHttpAddress(parseUri(url), session.flags)
+
+proc getAddress*(session: HttpSessionRef, url: Uri): HttpResult[HttpAddress] =
   let scheme =
     if len(url.scheme) == 0:
       HttpClientScheme.NonSecure
@@ -356,13 +453,13 @@ proc getAddress*(session: HttpSessionRef, url: Uri): HttpResult[HttpAddress] {.
                  addresses: addresses))

 proc getAddress*(session: HttpSessionRef,
-                 url: string): HttpResult[HttpAddress] {.raises: [].} =
+                 url: string): HttpResult[HttpAddress] =
   ## Create new HTTP address using URL string ``url`` and .
   session.getAddress(parseUri(url))

 proc getAddress*(address: TransportAddress,
                  ctype: HttpClientScheme = HttpClientScheme.NonSecure,
-                 queryString: string = "/"): HttpAddress {.raises: [].} =
+                 queryString: string = "/"): HttpAddress =
   ## Create new HTTP address using Transport address ``address``, connection
   ## type ``ctype`` and query string ``queryString``.
   let uri = parseUri(queryString)
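Assuming the `getHttpAddress` overloads and the `HttpAddressErrorType` helpers land as shown in this hunk (and in the httpcommon hunk near the end of this page), a caller can resolve a URL without a session and branch on the typed error. A hedged sketch with a placeholder URL, not taken from the repository:

import chronos/apps/http/httpclient

let res = getHttpAddress("https://example.com/api/v1",
                         flags = {HttpClientFlag.NoInet6Resolution})
if res.isErr():
  # toString()/isCriticalError() come from the httpcommon changes below.
  echo "address error: ", res.error().toString(),
       " (critical: ", res.error().isCriticalError(), ")"
else:
  let ha = res.get()
  echo "resolved ", ha.hostname, ":", ha.port,
       " -> ", len(ha.addresses), " address(es)"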
@@ -445,8 +542,12 @@ proc getUniqueConnectionId(session: HttpSessionRef): uint64 =
   inc(session.counter)
   session.counter

-proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef,
-         ha: HttpAddress, transp: StreamTransport): HttpClientConnectionRef =
+proc new(
+    t: typedesc[HttpClientConnectionRef],
+    session: HttpSessionRef,
+    ha: HttpAddress,
+    transp: StreamTransport
+): Result[HttpClientConnectionRef, string] =
   case ha.scheme
   of HttpClientScheme.NonSecure:
     let res = HttpClientConnectionRef(
@@ -459,108 +560,123 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef,
       remoteHostname: ha.id
     )
     trackCounter(HttpClientConnectionTrackerName)
-    res
+    ok(res)
   of HttpClientScheme.Secure:
-    let treader = newAsyncStreamReader(transp)
-    let twriter = newAsyncStreamWriter(transp)
-    let tls = newTLSClientAsyncStream(treader, twriter, ha.hostname,
-                                      flags = session.flags.getTLSFlags())
-    let res = HttpClientConnectionRef(
-      id: session.getUniqueConnectionId(),
-      kind: HttpClientScheme.Secure,
-      transp: transp,
-      treader: treader,
-      twriter: twriter,
-      reader: tls.reader,
-      writer: tls.writer,
-      tls: tls,
-      state: HttpClientConnectionState.Connecting,
-      remoteHostname: ha.id
-    )
-    trackCounter(HttpClientConnectionTrackerName)
-    res
+    let
+      treader = newAsyncStreamReader(transp)
+      twriter = newAsyncStreamWriter(transp)
+      tls =
+        try:
+          newTLSClientAsyncStream(treader, twriter, ha.hostname,
+                                  flags = session.flags.getTLSFlags(),
+                                  bufferSize = session.connectionBufferSize)
+        except TLSStreamInitError as exc:
+          return err(exc.msg)
+      res = HttpClientConnectionRef(
+        id: session.getUniqueConnectionId(),
+        kind: HttpClientScheme.Secure,
+        transp: transp,
+        treader: treader,
+        twriter: twriter,
+        reader: tls.reader,
+        writer: tls.writer,
+        tls: tls,
+        state: HttpClientConnectionState.Connecting,
+        remoteHostname: ha.id
+      )
+    trackCounter(HttpClientConnectionTrackerName)
+    ok(res)

-proc setError(request: HttpClientRequestRef, error: ref HttpError) {.
-     raises: [] .} =
+proc setError(request: HttpClientRequestRef, error: ref HttpError) =
   request.error = error
   request.state = HttpReqRespState.Error
   if not(isNil(request.connection)):
     request.connection.state = HttpClientConnectionState.Error
     request.connection.error = error

-proc setError(response: HttpClientResponseRef, error: ref HttpError) {.
-     raises: [] .} =
+proc setError(response: HttpClientResponseRef, error: ref HttpError) =
   response.error = error
   response.state = HttpReqRespState.Error
   if not(isNil(response.connection)):
     response.connection.state = HttpClientConnectionState.Error
     response.connection.error = error

-proc closeWait(conn: HttpClientConnectionRef) {.async.} =
+proc closeWait(conn: HttpClientConnectionRef) {.async: (raises: []).} =
   ## Close HttpClientConnectionRef instance ``conn`` and free all the resources.
   if conn.state notin {HttpClientConnectionState.Closing,
                        HttpClientConnectionState.Closed}:
     conn.state = HttpClientConnectionState.Closing
     let pending =
       block:
-        var res: seq[Future[void]]
+        var res: seq[Future[void].Raising([])]
         if not(isNil(conn.reader)) and not(conn.reader.closed()):
           res.add(conn.reader.closeWait())
         if not(isNil(conn.writer)) and not(conn.writer.closed()):
           res.add(conn.writer.closeWait())
+        if conn.kind == HttpClientScheme.Secure:
+          res.add(conn.treader.closeWait())
+          res.add(conn.twriter.closeWait())
+        res.add(conn.transp.closeWait())
         res
-    if len(pending) > 0: await allFutures(pending)
-    case conn.kind
-    of HttpClientScheme.Secure:
-      await allFutures(conn.treader.closeWait(), conn.twriter.closeWait())
-    of HttpClientScheme.NonSecure:
-      discard
-    await conn.transp.closeWait()
+    if len(pending) > 0: await noCancel(allFutures(pending))
    conn.state = HttpClientConnectionState.Closed
   untrackCounter(HttpClientConnectionTrackerName)
 proc connect(session: HttpSessionRef,
-             ha: HttpAddress): Future[HttpClientConnectionRef] {.async.} =
+             ha: HttpAddress): Future[HttpClientConnectionRef] {.
+     async: (raises: [CancelledError, HttpConnectionError]).} =
   ## Establish new connection with remote server using ``url`` and ``flags``.
   ## On success returns ``HttpClientConnectionRef`` object.
+  var lastError = ""
   # Here we trying to connect to every possible remote host address we got after
   # DNS resolution.
   for address in ha.addresses:
     let transp =
       try:
         await connect(address, bufferSize = session.connectionBufferSize,
-                      flags = session.socketFlags)
+                      flags = session.socketFlags,
+                      dualstack = session.dualstack)
       except CancelledError as exc:
         raise exc
-      except CatchableError:
+      except TransportError:
         nil
     if not(isNil(transp)):
       let conn =
         block:
-          let res = HttpClientConnectionRef.new(session, ha, transp)
-          case res.kind
-          of HttpClientScheme.Secure:
+          let res = HttpClientConnectionRef.new(session, ha, transp).valueOr:
+            raiseHttpConnectionError(
+              "Could not connect to remote host, reason: " & error)
+          if res.kind == HttpClientScheme.Secure:
             try:
               await res.tls.handshake()
               res.state = HttpClientConnectionState.Ready
             except CancelledError as exc:
               await res.closeWait()
               raise exc
-            except AsyncStreamError:
+            except TLSStreamProtocolError as exc:
               await res.closeWait()
               res.state = HttpClientConnectionState.Error
-          of HttpClientScheme.Nonsecure:
+              lastError = $exc.msg
+            except AsyncStreamError as exc:
+              await res.closeWait()
+              res.state = HttpClientConnectionState.Error
+              lastError = $exc.msg
+          else:
             res.state = HttpClientConnectionState.Ready
           res
       if conn.state == HttpClientConnectionState.Ready:
         return conn
   # If all attempts to connect to the remote host have failed.
-  raiseHttpConnectionError("Could not connect to remote host")
+  if len(lastError) > 0:
+    raiseHttpConnectionError("Could not connect to remote host, reason: " &
+                             lastError)
+  else:
+    raiseHttpConnectionError("Could not connect to remote host")
 proc removeConnection(session: HttpSessionRef,
-                      conn: HttpClientConnectionRef) {.async.} =
+                      conn: HttpClientConnectionRef) {.async: (raises: []).} =
   let removeHost =
     block:
       var res = false
@@ -584,12 +700,13 @@ proc acquireConnection(
     session: HttpSessionRef,
     ha: HttpAddress,
     flags: set[HttpClientRequestFlag]
-  ): Future[HttpClientConnectionRef] {.async.} =
+  ): Future[HttpClientConnectionRef] {.
+     async: (raises: [CancelledError, HttpConnectionError]).} =
   ## Obtain connection from ``session`` or establish a new one.
   var default: seq[HttpClientConnectionRef]
+  let timestamp = Moment.now()
   if session.connectionPoolEnabled(flags):
     # Trying to reuse existing connection from our connection's pool.
-    let timestamp = Moment.now()
     # We looking for non-idle connection at `Ready` state, all idle connections
     # will be freed by sessionWatcher().
     for connection in session.connections.getOrDefault(ha.id):
@@ -606,10 +723,13 @@
   connection.state = HttpClientConnectionState.Acquired
   session.connections.mgetOrPut(ha.id, default).add(connection)
   inc(session.connectionsCount)
-  return connection
+  connection.setTimestamp(timestamp)
+  connection.setDuration()
+  connection

 proc releaseConnection(session: HttpSessionRef,
-                       connection: HttpClientConnectionRef) {.async.} =
+                       connection: HttpClientConnectionRef) {.
+     async: (raises: []).} =
   ## Return connection back to the ``session``.
   let removeConnection =
     if HttpClientFlag.Http11Pipeline notin session.flags:
@@ -647,7 +767,7 @@ proc releaseConnection(session: HttpSessionRef,
                        HttpClientConnectionFlag.Response,
                        HttpClientConnectionFlag.NoBody})

-proc releaseConnection(request: HttpClientRequestRef) {.async.} =
+proc releaseConnection(request: HttpClientRequestRef) {.async: (raises: []).} =
   let
     session = request.session
     connection = request.connection
@@ -659,7 +779,8 @@ proc releaseConnection(request: HttpClientRequestRef) {.async.} =
   if HttpClientConnectionFlag.Response notin connection.flags:
     await session.releaseConnection(connection)

-proc releaseConnection(response: HttpClientResponseRef) {.async.} =
+proc releaseConnection(response: HttpClientResponseRef) {.
+     async: (raises: []).} =
   let
     session = response.session
     connection = response.connection
@@ -671,7 +792,7 @@ proc releaseConnection(response: HttpClientResponseRef) {.async.} =
   if HttpClientConnectionFlag.Request notin connection.flags:
     await session.releaseConnection(connection)

-proc closeWait*(session: HttpSessionRef) {.async.} =
+proc closeWait*(session: HttpSessionRef) {.async: (raises: []).} =
   ## Closes HTTP session object.
   ##
   ## This closes all the connections opened to remote servers.
@@ -682,9 +803,9 @@ proc closeWait*(session: HttpSessionRef) {.async.} =
   for connections in session.connections.values():
     for conn in connections:
       pending.add(closeWait(conn))
-  await allFutures(pending)
+  await noCancel(allFutures(pending))

-proc sessionWatcher(session: HttpSessionRef) {.async.} =
+proc sessionWatcher(session: HttpSessionRef) {.async: (raises: []).} =
   while true:
     let firstBreak =
       try:
@@ -715,45 +836,53 @@ proc sessionWatcher(session: HttpSessionRef) {.async.} =
     var pending: seq[Future[void]]
     let secondBreak =
       try:
-        pending = idleConnections.mapIt(it.closeWait())
+        for conn in idleConnections:
+          pending.add(conn.closeWait())
         await allFutures(pending)
         false
       except CancelledError:
         # We still want to close connections to avoid socket leaks.
-        await allFutures(pending)
+        await noCancel(allFutures(pending))
         true
     if secondBreak:
       break

-proc closeWait*(request: HttpClientRequestRef) {.async.} =
+proc closeWait*(request: HttpClientRequestRef) {.async: (raises: []).} =
+  var pending: seq[Future[void].Raising([])]
   if request.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}:
     request.state = HttpReqRespState.Closing
     if not(isNil(request.writer)):
       if not(request.writer.closed()):
-        await request.writer.closeWait()
+        pending.add(request.writer.closeWait())
       request.writer = nil
-    await request.releaseConnection()
+    pending.add(request.releaseConnection())
+    await noCancel(allFutures(pending))
     request.session = nil
     request.error = nil
+    request.headersBuffer.reset()
     request.state = HttpReqRespState.Closed
   untrackCounter(HttpClientRequestTrackerName)

-proc closeWait*(response: HttpClientResponseRef) {.async.} =
+proc closeWait*(response: HttpClientResponseRef) {.async: (raises: []).} =
+  var pending: seq[Future[void].Raising([])]
   if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}:
     response.state = HttpReqRespState.Closing
     if not(isNil(response.reader)):
       if not(response.reader.closed()):
-        await response.reader.closeWait()
+        pending.add(response.reader.closeWait())
      response.reader = nil
-    await response.releaseConnection()
+    pending.add(response.releaseConnection())
+    await noCancel(allFutures(pending))
     response.session = nil
     response.error = nil
     response.state = HttpReqRespState.Closed
   untrackCounter(HttpClientResponseTrackerName)
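With `closeWait` on sessions, requests and responses annotated `raises: []`, shutdown can sit in a `finally` block without extra error handling, and the idle-connection watcher only runs when HTTP/1.1 pipelining is requested. A hedged usage sketch (the timeout value is illustrative):

import chronos, chronos/apps/http/httpclient

proc withSession() {.async.} =
  let session = HttpSessionRef.new(
    flags = {HttpClientFlag.Http11Pipeline},  # starts the idle-connection watcher
    idleTimeout = 30.seconds)
  try:
    # ... issue requests against `session` here ...
    discard
  finally:
    # raises: [] - safe to await during teardown, closes pooled connections
    await session.closeWait()

when isMainModule:
  waitFor withSession()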
-proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
-                    ): HttpResult[HttpClientResponseRef] {.raises: [] .} =
+proc prepareResponse(
+    request: HttpClientRequestRef,
+    data: openArray[byte]
+): HttpResult[HttpClientResponseRef] =
   ## Process response headers.
   let resp = parseResponse(data, false)
   if resp.failed():
@@ -864,39 +993,41 @@
   ok(res)

 proc getResponse(req: HttpClientRequestRef): Future[HttpClientResponseRef] {.
-     async.} =
-  var buffer: array[HttpMaxHeadersSize, byte]
+     async: (raises: [CancelledError, HttpError]).} =
   let timestamp = Moment.now()
   req.connection.setTimestamp(timestamp)
   let
     bytesRead =
       try:
-        await req.connection.reader.readUntil(addr buffer[0],
-                                              len(buffer), HeadersMark).wait(
+        await req.connection.reader.readUntil(addr req.headersBuffer[0],
+                                              len(req.headersBuffer),
+                                              HeadersMark).wait(
                                                 req.session.headersTimeout)
       except AsyncTimeoutError:
         raiseHttpReadError("Reading response headers timed out")
-      except AsyncStreamError:
-        raiseHttpReadError("Could not read response headers")
+      except AsyncStreamError as exc:
+        raiseHttpReadError(
+          "Could not read response headers, reason: " & $exc.msg)

-  let response = prepareResponse(req, buffer.toOpenArray(0, bytesRead - 1))
-  if response.isErr():
-    raiseHttpProtocolError(response.error())
-  let res = response.get()
-  res.setTimestamp(timestamp)
-  return res
+  let response =
+    prepareResponse(req,
+      req.headersBuffer.toOpenArray(0, bytesRead - 1)).valueOr:
+        raiseHttpProtocolError(error)
+  response.setTimestamp(timestamp)
+  response

 proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
           ha: HttpAddress, meth: HttpMethod = MethodGet,
           version: HttpVersion = HttpVersion11,
           flags: set[HttpClientRequestFlag] = {},
+          maxResponseHeadersSize: int = HttpMaxHeadersSize,
           headers: openArray[HttpHeaderTuple] = [],
-          body: openArray[byte] = []): HttpClientRequestRef {.
-     raises: [].} =
+          body: openArray[byte] = []): HttpClientRequestRef =
   let res = HttpClientRequestRef(
     state: HttpReqRespState.Ready, session: session, meth: meth,
     version: version, flags: flags, headers: HttpTable.init(headers),
-    address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body
+    address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body,
+    headersBuffer: newSeq[byte](max(maxResponseHeadersSize, HttpMaxHeadersSize))
   )
   trackCounter(HttpClientRequestTrackerName)
   res
@@ -905,14 +1036,15 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
           url: string, meth: HttpMethod = MethodGet,
           version: HttpVersion = HttpVersion11,
           flags: set[HttpClientRequestFlag] = {},
+          maxResponseHeadersSize: int = HttpMaxHeadersSize,
           headers: openArray[HttpHeaderTuple] = [],
-          body: openArray[byte] = []): HttpResult[HttpClientRequestRef] {.
-     raises: [].} =
+          body: openArray[byte] = []): HttpResult[HttpClientRequestRef] =
   let address = ? session.getAddress(parseUri(url))
   let res = HttpClientRequestRef(
     state: HttpReqRespState.Ready, session: session, meth: meth,
     version: version, flags: flags, headers: HttpTable.init(headers),
-    address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body
+    address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body,
+    headersBuffer: newSeq[byte](max(maxResponseHeadersSize, HttpMaxHeadersSize))
   )
   trackCounter(HttpClientRequestTrackerName)
   ok(res)
@@ -920,55 +1052,61 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
 proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
           url: string, version: HttpVersion = HttpVersion11,
           flags: set[HttpClientRequestFlag] = {},
+          maxResponseHeadersSize: int = HttpMaxHeadersSize,
           headers: openArray[HttpHeaderTuple] = []
-         ): HttpResult[HttpClientRequestRef] {.raises: [].} =
-  HttpClientRequestRef.new(session, url, MethodGet, version, flags, headers)
+         ): HttpResult[HttpClientRequestRef] =
+  HttpClientRequestRef.new(session, url, MethodGet, version, flags,
+                           maxResponseHeadersSize, headers)

 proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
           ha: HttpAddress, version: HttpVersion = HttpVersion11,
           flags: set[HttpClientRequestFlag] = {},
+          maxResponseHeadersSize: int = HttpMaxHeadersSize,
           headers: openArray[HttpHeaderTuple] = []
-         ): HttpClientRequestRef {.raises: [].} =
-  HttpClientRequestRef.new(session, ha, MethodGet, version, flags, headers)
+         ): HttpClientRequestRef =
+  HttpClientRequestRef.new(session, ha, MethodGet, version, flags,
+                           maxResponseHeadersSize, headers)

 proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
            url: string, version: HttpVersion = HttpVersion11,
            flags: set[HttpClientRequestFlag] = {},
+           maxResponseHeadersSize: int = HttpMaxHeadersSize,
            headers: openArray[HttpHeaderTuple] = [],
            body: openArray[byte] = []
-          ): HttpResult[HttpClientRequestRef] {.raises: [].} =
-  HttpClientRequestRef.new(session, url, MethodPost, version, flags, headers,
-                           body)
+          ): HttpResult[HttpClientRequestRef] =
+  HttpClientRequestRef.new(session, url, MethodPost, version, flags,
+                           maxResponseHeadersSize, headers, body)

 proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
            url: string, version: HttpVersion = HttpVersion11,
            flags: set[HttpClientRequestFlag] = {},
+           maxResponseHeadersSize: int = HttpMaxHeadersSize,
            headers: openArray[HttpHeaderTuple] = [],
-           body: openArray[char] = []): HttpResult[HttpClientRequestRef] {.
-     raises: [].} =
-  HttpClientRequestRef.new(session, url, MethodPost, version, flags, headers,
+           body: openArray[char] = []): HttpResult[HttpClientRequestRef] =
+  HttpClientRequestRef.new(session, url, MethodPost, version, flags,
+                           maxResponseHeadersSize, headers,
                            body.toOpenArrayByte(0, len(body) - 1))

 proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
            ha: HttpAddress, version: HttpVersion = HttpVersion11,
            flags: set[HttpClientRequestFlag] = {},
+           maxResponseHeadersSize: int = HttpMaxHeadersSize,
            headers: openArray[HttpHeaderTuple] = [],
-           body: openArray[byte] = []): HttpClientRequestRef {.
-     raises: [].} =
-  HttpClientRequestRef.new(session, ha, MethodPost, version, flags, headers,
-                           body)
+           body: openArray[byte] = []): HttpClientRequestRef =
+  HttpClientRequestRef.new(session, ha, MethodPost, version, flags,
+                           maxResponseHeadersSize, headers, body)

 proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
            ha: HttpAddress, version: HttpVersion = HttpVersion11,
            flags: set[HttpClientRequestFlag] = {},
+           maxResponseHeadersSize: int = HttpMaxHeadersSize,
            headers: openArray[HttpHeaderTuple] = [],
-           body: openArray[char] = []): HttpClientRequestRef {.
-     raises: [].} =
-  HttpClientRequestRef.new(session, ha, MethodPost, version, flags, headers,
+           body: openArray[char] = []): HttpClientRequestRef =
+  HttpClientRequestRef.new(session, ha, MethodPost, version, flags,
+                           maxResponseHeadersSize, headers,
                            body.toOpenArrayByte(0, len(body) - 1))

-proc prepareRequest(request: HttpClientRequestRef): string {.
-     raises: [].} =
+proc prepareRequest(request: HttpClientRequestRef): string =
   template hasChunkedEncoding(request: HttpClientRequestRef): bool =
     toLowerAscii(request.headers.getString(TransferEncodingHeader)) == "chunked"
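This hunk threads the new `maxResponseHeadersSize` parameter through the `get`/`post` convenience constructors, sizing the per-request `headersBuffer` introduced earlier. A hedged sketch of raising the limit for a server that sends unusually large headers (URL and size are made up):

import chronos, chronos/apps/http/httpclient

proc requestWithBigHeaders(session: HttpSessionRef) {.async.} =
  let request = HttpClientRequestRef.get(
    session, "https://example.com/large-cookies",
    maxResponseHeadersSize = 32768).get()   # default remains HttpMaxHeadersSize
  let response = await request.send()
  echo "status: ", response.status
  await response.closeWait()
  await request.closeWait()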
@@ -1043,7 +1181,7 @@ proc prepareRequest(request: HttpClientRequestRef): string {.
   res

 proc send*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {.
-     async.} =
+     async: (raises: [CancelledError, HttpError]).} =
   doAssert(request.state == HttpReqRespState.Ready,
            "Request's state is " & $request.state)
   let connection =
@@ -1076,25 +1214,24 @@ proc send*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {.
     request.setDuration()
     request.setError(newHttpInterruptError())
     raise exc
-  except AsyncStreamError:
+  except AsyncStreamError as exc:
     request.setDuration()
-    let error = newHttpWriteError("Could not send request headers")
+    let error = newHttpWriteError(
+      "Could not send request headers, reason: " & $exc.msg)
     request.setError(error)
     raise error

-  let resp =
-    try:
-      await request.getResponse()
-    except CancelledError as exc:
-      request.setError(newHttpInterruptError())
-      raise exc
-    except HttpError as exc:
-      request.setError(exc)
-      raise exc
-  return resp
+  try:
+    await request.getResponse()
+  except CancelledError as exc:
+    request.setError(newHttpInterruptError())
+    raise exc
+  except HttpError as exc:
+    request.setError(exc)
+    raise exc

 proc open*(request: HttpClientRequestRef): Future[HttpBodyWriter] {.
-     async.} =
+     async: (raises: [CancelledError, HttpError]).} =
   ## Start sending request's headers and return `HttpBodyWriter`, which can be
   ## used to send request's body.
   doAssert(request.state == HttpReqRespState.Ready,
@@ -1124,8 +1261,9 @@ proc open*(request: HttpClientRequestRef): Future[HttpBodyWriter] {.
     request.setDuration()
     request.setError(newHttpInterruptError())
     raise exc
-  except AsyncStreamError:
-    let error = newHttpWriteError("Could not send request headers")
+  except AsyncStreamError as exc:
+    let error = newHttpWriteError(
+      "Could not send request headers, reason: " & $exc.msg)
     request.setDuration()
     request.setError(error)
     raise error
@@ -1147,10 +1285,10 @@ proc open*(request: HttpClientRequestRef): Future[HttpBodyWriter] {.
   request.writer = writer
   request.state = HttpReqRespState.Open
   request.connection.state = HttpClientConnectionState.RequestBodySending
-  return writer
+  writer

 proc finish*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {.
-     async.} =
+     async: (raises: [CancelledError, HttpError]).} =
   ## Finish sending request and receive response.
   doAssert(not(isNil(request.connection)),
            "Request missing connection instance")
@@ -1187,7 +1325,8 @@ proc getNewLocation*(resp: HttpClientResponseRef): HttpResult[HttpAddress] =
   else:
     err("Location header is missing")

-proc getBodyReader*(response: HttpClientResponseRef): HttpBodyReader =
+proc getBodyReader*(response: HttpClientResponseRef): HttpBodyReader {.
+     raises: [HttpUseClosedError].} =
   ## Returns stream's reader instance which can be used to read response's body.
   ##
   ## Streams which was obtained using this procedure must be closed to avoid
@@ -1205,18 +1344,24 @@ proc getBodyReader*(response: HttpClientResponseRef): HttpBodyReader =
   let reader =
     case response.bodyFlag
     of HttpClientBodyFlag.Sized:
-      let bstream = newBoundedStreamReader(response.connection.reader,
-                                           response.contentLength)
-      newHttpBodyReader(bstream)
+      newHttpBodyReader(
+        newBoundedStreamReader(
+          response.connection.reader, response.contentLength,
+          bufferSize = response.session.connectionBufferSize))
     of HttpClientBodyFlag.Chunked:
-      newHttpBodyReader(newChunkedStreamReader(response.connection.reader))
+      newHttpBodyReader(
+        newChunkedStreamReader(
+          response.connection.reader,
+          bufferSize = response.session.connectionBufferSize))
     of HttpClientBodyFlag.Custom:
-      newHttpBodyReader(newAsyncStreamReader(response.connection.reader))
+      newHttpBodyReader(
+        newAsyncStreamReader(response.connection.reader))
   response.connection.state = HttpClientConnectionState.ResponseBodyReceiving
   response.reader = reader
   response.reader

-proc finish*(response: HttpClientResponseRef) {.async.} =
+proc finish*(response: HttpClientResponseRef) {.
+     async: (raises: [HttpUseClosedError]).} =
   ## Finish receiving response.
   ##
   ## Because ``finish()`` returns nothing, this operation become NOP for
@@ -1235,7 +1380,7 @@ proc finish*(response: HttpClientResponseRef) {.async.} =
   response.setDuration()

 proc getBodyBytes*(response: HttpClientResponseRef): Future[seq[byte]] {.
-     async.} =
+     async: (raises: [CancelledError, HttpError]).} =
   ## Read all bytes from response ``response``.
   ##
   ## Note: This procedure performs automatic finishing for ``response``.
@@ -1245,21 +1390,22 @@ proc getBodyBytes*(response: HttpClientResponseRef): Future[seq[byte]] {.
     await reader.closeWait()
     reader = nil
     await response.finish()
-    return data
+    data
   except CancelledError as exc:
     if not(isNil(reader)):
       await reader.closeWait()
     response.setError(newHttpInterruptError())
     raise exc
-  except AsyncStreamError:
+  except AsyncStreamError as exc:
+    let error = newHttpReadError("Could not read response, reason: " & $exc.msg)
     if not(isNil(reader)):
       await reader.closeWait()
-    let error = newHttpReadError("Could not read response")
     response.setError(error)
     raise error

 proc getBodyBytes*(response: HttpClientResponseRef,
-                   nbytes: int): Future[seq[byte]] {.async.} =
+                   nbytes: int): Future[seq[byte]] {.
+     async: (raises: [CancelledError, HttpError]).} =
   ## Read all bytes (nbytes <= 0) or exactly `nbytes` bytes from response
   ## ``response``.
   ##
@@ -1270,20 +1416,21 @@ proc getBodyBytes*(response: HttpClientResponseRef,
     await reader.closeWait()
     reader = nil
     await response.finish()
-    return data
+    data
   except CancelledError as exc:
     if not(isNil(reader)):
       await reader.closeWait()
     response.setError(newHttpInterruptError())
     raise exc
-  except AsyncStreamError:
+  except AsyncStreamError as exc:
+    let error = newHttpReadError("Could not read response, reason: " & $exc.msg)
     if not(isNil(reader)):
       await reader.closeWait()
-    let error = newHttpReadError("Could not read response")
     response.setError(error)
     raise error

-proc consumeBody*(response: HttpClientResponseRef): Future[int] {.async.} =
+proc consumeBody*(response: HttpClientResponseRef): Future[int] {.
+     async: (raises: [CancelledError, HttpError]).} =
   ## Consume/discard response and return number of bytes consumed.
   ##
   ## Note: This procedure performs automatic finishing for ``response``.
@@ -1293,16 +1440,17 @@ proc consumeBody*(response: HttpClientResponseRef): Future[int] {.async.} =
     await reader.closeWait()
     reader = nil
     await response.finish()
-    return res
+    res
   except CancelledError as exc:
     if not(isNil(reader)):
       await reader.closeWait()
     response.setError(newHttpInterruptError())
     raise exc
-  except AsyncStreamError:
+  except AsyncStreamError as exc:
+    let error = newHttpReadError(
+      "Could not consume response, reason: " & $exc.msg)
     if not(isNil(reader)):
       await reader.closeWait()
-    let error = newHttpReadError("Could not read response")
     response.setError(error)
     raise error
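`getBodyBytes` and `consumeBody` now declare `raises: [CancelledError, HttpError]` and include the underlying stream error in the message. A hedged sketch showing the two ways to drain a response (the selection logic is illustrative):

import chronos, chronos/apps/http/httpclient

proc readOrDiscard(response: HttpClientResponseRef,
                   wanted: bool): Future[seq[byte]] {.async.} =
  if wanted:
    # Reads the whole body and finishes the response automatically.
    return await response.getBodyBytes()
  # Drains and discards the body, keeping the connection reusable.
  let consumed = await response.consumeBody()
  echo "discarded ", consumed, " byte(s)"
  return newSeq[byte]()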
@@ -1317,8 +1465,15 @@ proc redirect*(request: HttpClientRequestRef,
   if redirectCount > request.session.maxRedirections:
     err("Maximum number of redirects exceeded")
   else:
-    var res = HttpClientRequestRef.new(request.session, ha, request.meth,
-      request.version, request.flags, request.headers.toList(), request.buffer)
+    let headers =
+      block:
+        var res = request.headers
+        res.set(HostHeader, ha.hostname)
+        res
+    var res =
+      HttpClientRequestRef.new(request.session, ha, request.meth,
+        request.version, request.flags, headers = headers.toList(),
+        body = request.buffer)
     res.redirectCount = redirectCount
     ok(res)
@@ -1335,13 +1490,21 @@ proc redirect*(request: HttpClientRequestRef,
     err("Maximum number of redirects exceeded")
   else:
     let address = ? request.session.redirect(request.address, uri)
-    var res = HttpClientRequestRef.new(request.session, address, request.meth,
-      request.version, request.flags, request.headers.toList(), request.buffer)
+    # Update Host header to redirected URL hostname
+    let headers =
+      block:
+        var res = request.headers
+        res.set(HostHeader, address.hostname)
+        res
+    var res =
+      HttpClientRequestRef.new(request.session, address, request.meth,
+        request.version, request.flags, headers = headers.toList(),
+        body = request.buffer)
     res.redirectCount = redirectCount
     ok(res)

 proc fetch*(request: HttpClientRequestRef): Future[HttpResponseTuple] {.
-     async.} =
+     async: (raises: [CancelledError, HttpError]).} =
   var response: HttpClientResponseRef
   try:
     response = await request.send()
@@ -1349,7 +1512,7 @@ proc fetch*(request: HttpClientRequestRef): Future[HttpResponseTuple] {.
     let status = response.status
     await response.closeWait()
     response = nil
-    return (status, buffer)
+    (status, buffer)
   except HttpError as exc:
     if not(isNil(response)): await response.closeWait()
     raise exc
@@ -1358,7 +1521,7 @@ proc fetch*(request: HttpClientRequestRef): Future[HttpResponseTuple] {.
     raise exc

 proc fetch*(session: HttpSessionRef, url: Uri): Future[HttpResponseTuple] {.
-     async.} =
+     async: (raises: [CancelledError, HttpError]).} =
   ## Fetch resource pointed by ``url`` using HTTP GET method and ``session``
   ## parameters.
   ##
@@ -1400,28 +1563,34 @@ proc fetch*(session: HttpSessionRef, url: Uri): Future[HttpResponseTuple] {.
       request = redirect
       redirect = nil
     else:
-      let data = await response.getBodyBytes()
-      let code = response.status
+      let
+        data = await response.getBodyBytes()
+        code = response.status
       await response.closeWait()
       response = nil
       await request.closeWait()
       request = nil
       return (code, data)
   except CancelledError as exc:
-    if not(isNil(response)): await closeWait(response)
-    if not(isNil(request)): await closeWait(request)
-    if not(isNil(redirect)): await closeWait(redirect)
+    var pending: seq[Future[void]]
+    if not(isNil(response)): pending.add(closeWait(response))
+    if not(isNil(request)): pending.add(closeWait(request))
+    if not(isNil(redirect)): pending.add(closeWait(redirect))
+    await noCancel(allFutures(pending))
     raise exc
   except HttpError as exc:
-    if not(isNil(response)): await closeWait(response)
-    if not(isNil(request)): await closeWait(request)
-    if not(isNil(redirect)): await closeWait(redirect)
+    var pending: seq[Future[void]]
+    if not(isNil(response)): pending.add(closeWait(response))
+    if not(isNil(request)): pending.add(closeWait(request))
+    if not(isNil(redirect)): pending.add(closeWait(redirect))
+    await noCancel(allFutures(pending))
     raise exc

 proc getServerSentEvents*(
     response: HttpClientResponseRef,
     maxEventSize: int = -1
-  ): Future[seq[ServerSentEvent]] {.async.} =
+  ): Future[seq[ServerSentEvent]] {.
+     async: (raises: [CancelledError, HttpError]).} =
   ## Read number of server-sent events (SSE) from HTTP response ``response``.
   ##
   ## ``maxEventSize`` - maximum size of events chunk in one message, use
@@ -1509,8 +1678,14 @@ proc getServerSentEvents*(
     (i, false)

-  await reader.readMessage(predicate)
+  try:
+    await reader.readMessage(predicate)
+  except CancelledError as exc:
+    raise exc
+  except AsyncStreamError as exc:
+    raiseHttpReadError($exc.msg)
+
   if not isNil(error):
     raise error
-  else:
-    return res
+  res
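Taken together, the simplest consumer of this module is `fetch` on a session: it follows redirects and, after this change, cleans up request, response and redirect objects via `noCancel(allFutures(...))` on both cancellation and HTTP errors. A hedged end-to-end sketch with a placeholder URL:

import chronos, chronos/apps/http/httpclient
from std/uri import parseUri

proc fetchOnce(url: string) {.async.} =
  let session = HttpSessionRef.new()
  try:
    let (status, body) = await session.fetch(parseUri(url))
    echo "status ", status, ", body of ", len(body), " byte(s)"
  finally:
    await session.closeWait()

when isMainModule:
  waitFor fetchOnce("https://example.com/")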


@ -6,8 +6,11 @@
# Licensed under either of # Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
{.push raises: [].}
import std/[strutils, uri] import std/[strutils, uri]
import stew/results, httputils import results, httputils
import ../../asyncloop, ../../asyncsync import ../../asyncloop, ../../asyncsync
import ../../streams/[asyncstream, boundstream] import ../../streams/[asyncstream, boundstream]
export asyncloop, asyncsync, results, httputils, strutils export asyncloop, asyncsync, results, httputils, strutils
@ -40,30 +43,48 @@ const
ServerHeader* = "server" ServerHeader* = "server"
LocationHeader* = "location" LocationHeader* = "location"
AuthorizationHeader* = "authorization" AuthorizationHeader* = "authorization"
ContentDispositionHeader* = "content-disposition"
UrlEncodedContentType* = MediaType.init("application/x-www-form-urlencoded") UrlEncodedContentType* = MediaType.init("application/x-www-form-urlencoded")
MultipartContentType* = MediaType.init("multipart/form-data") MultipartContentType* = MediaType.init("multipart/form-data")
type type
HttpMessage* = object
code*: HttpCode
contentType*: MediaType
message*: string
HttpResult*[T] = Result[T, string] HttpResult*[T] = Result[T, string]
HttpResultCode*[T] = Result[T, HttpCode] HttpResultCode*[T] = Result[T, HttpCode]
HttpResultMessage*[T] = Result[T, HttpMessage]
HttpDefect* = object of Defect HttpError* = object of AsyncError
HttpError* = object of CatchableError
HttpCriticalError* = object of HttpError
code*: HttpCode
HttpRecoverableError* = object of HttpError
code*: HttpCode
HttpDisconnectError* = object of HttpError
HttpConnectionError* = object of HttpError
HttpInterruptError* = object of HttpError HttpInterruptError* = object of HttpError
HttpReadError* = object of HttpError
HttpWriteError* = object of HttpError HttpTransportError* = object of HttpError
HttpProtocolError* = object of HttpError HttpAddressError* = object of HttpTransportError
HttpRedirectError* = object of HttpError HttpRedirectError* = object of HttpTransportError
HttpAddressError* = object of HttpError HttpConnectionError* = object of HttpTransportError
HttpUseClosedError* = object of HttpError HttpReadError* = object of HttpTransportError
HttpReadLimitError* = object of HttpReadError HttpReadLimitError* = object of HttpReadError
HttpDisconnectError* = object of HttpReadError
HttpWriteError* = object of HttpTransportError
HttpProtocolError* = object of HttpError
code*: HttpCode
HttpCriticalError* = object of HttpProtocolError # deprecated
HttpRecoverableError* = object of HttpProtocolError # deprecated
HttpRequestError* = object of HttpProtocolError
HttpRequestHeadersError* = object of HttpRequestError
HttpRequestBodyError* = object of HttpRequestError
HttpRequestHeadersTooLargeError* = object of HttpRequestHeadersError
HttpRequestBodyTooLargeError* = object of HttpRequestBodyError
HttpResponseError* = object of HttpProtocolError
HttpInvalidUsageError* = object of HttpError
HttpUseClosedError* = object of HttpInvalidUsageError
KeyValueTuple* = tuple KeyValueTuple* = tuple
key: string key: string
@ -82,35 +103,95 @@ type
HttpState* {.pure.} = enum HttpState* {.pure.} = enum
Alive, Closing, Closed Alive, Closing, Closed
proc raiseHttpCriticalError*(msg: string, HttpAddressErrorType* {.pure.} = enum
code = Http400) {.noinline, noreturn.} = InvalidUrlScheme,
InvalidPortNumber,
MissingHostname,
InvalidIpHostname,
NameLookupFailed,
NoAddressResolved
const
CriticalHttpAddressError* = {
HttpAddressErrorType.InvalidUrlScheme,
HttpAddressErrorType.InvalidPortNumber,
HttpAddressErrorType.MissingHostname,
HttpAddressErrorType.InvalidIpHostname
}
RecoverableHttpAddressError* = {
HttpAddressErrorType.NameLookupFailed,
HttpAddressErrorType.NoAddressResolved
}
func isCriticalError*(error: HttpAddressErrorType): bool =
error in CriticalHttpAddressError
func isRecoverableError*(error: HttpAddressErrorType): bool =
error in RecoverableHttpAddressError
func toString*(error: HttpAddressErrorType): string =
case error
of HttpAddressErrorType.InvalidUrlScheme:
"URL scheme not supported"
of HttpAddressErrorType.InvalidPortNumber:
"Invalid URL port number"
of HttpAddressErrorType.MissingHostname:
"Missing URL hostname"
of HttpAddressErrorType.InvalidIpHostname:
"Invalid IPv4/IPv6 address in hostname"
of HttpAddressErrorType.NameLookupFailed:
"Could not resolve remote address"
of HttpAddressErrorType.NoAddressResolved:
"No address has been resolved"
proc raiseHttpRequestBodyTooLargeError*() {.
noinline, noreturn, raises: [HttpRequestBodyTooLargeError].} =
raise (ref HttpRequestBodyTooLargeError)(
code: Http413, msg: MaximumBodySizeError)
proc raiseHttpCriticalError*(msg: string, code = Http400) {.
noinline, noreturn, raises: [HttpCriticalError].} =
raise (ref HttpCriticalError)(code: code, msg: msg) raise (ref HttpCriticalError)(code: code, msg: msg)
proc raiseHttpDisconnectError*() {.noinline, noreturn.} = proc raiseHttpDisconnectError*() {.
noinline, noreturn, raises: [HttpDisconnectError].} =
raise (ref HttpDisconnectError)(msg: "Remote peer disconnected") raise (ref HttpDisconnectError)(msg: "Remote peer disconnected")
proc raiseHttpDefect*(msg: string) {.noinline, noreturn.} = proc raiseHttpConnectionError*(msg: string) {.
raise (ref HttpDefect)(msg: msg) noinline, noreturn, raises: [HttpConnectionError].} =
proc raiseHttpConnectionError*(msg: string) {.noinline, noreturn.} =
raise (ref HttpConnectionError)(msg: msg) raise (ref HttpConnectionError)(msg: msg)
proc raiseHttpInterruptError*() {.noinline, noreturn.} = proc raiseHttpInterruptError*() {.
noinline, noreturn, raises: [HttpInterruptError].} =
raise (ref HttpInterruptError)(msg: "Connection was interrupted") raise (ref HttpInterruptError)(msg: "Connection was interrupted")
proc raiseHttpReadError*(msg: string) {.noinline, noreturn.} = proc raiseHttpReadError*(msg: string) {.
noinline, noreturn, raises: [HttpReadError].} =
raise (ref HttpReadError)(msg: msg) raise (ref HttpReadError)(msg: msg)
proc raiseHttpProtocolError*(msg: string) {.noinline, noreturn.} = proc raiseHttpProtocolError*(msg: string) {.
raise (ref HttpProtocolError)(msg: msg) noinline, noreturn, raises: [HttpProtocolError].} =
raise (ref HttpProtocolError)(code: Http400, msg: msg)
proc raiseHttpWriteError*(msg: string) {.noinline, noreturn.} = proc raiseHttpProtocolError*(code: HttpCode, msg: string) {.
noinline, noreturn, raises: [HttpProtocolError].} =
raise (ref HttpProtocolError)(code: code, msg: msg)
proc raiseHttpProtocolError*(msg: HttpMessage) {.
noinline, noreturn, raises: [HttpProtocolError].} =
raise (ref HttpProtocolError)(code: msg.code, msg: msg.message)
proc raiseHttpWriteError*(msg: string) {.
noinline, noreturn, raises: [HttpWriteError].} =
raise (ref HttpWriteError)(msg: msg) raise (ref HttpWriteError)(msg: msg)
proc raiseHttpRedirectError*(msg: string) {.noinline, noreturn.} = proc raiseHttpRedirectError*(msg: string) {.
noinline, noreturn, raises: [HttpRedirectError].} =
raise (ref HttpRedirectError)(msg: msg) raise (ref HttpRedirectError)(msg: msg)
proc raiseHttpAddressError*(msg: string) {.noinline, noreturn.} = proc raiseHttpAddressError*(msg: string) {.
noinline, noreturn, raises: [HttpAddressError].} =
raise (ref HttpAddressError)(msg: msg) raise (ref HttpAddressError)(msg: msg)
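A hedged sketch of how the regrouped hierarchy lets callers separate transport failures from protocol violations; fetchBody is a hypothetical helper that exercises the raise procs above:

proc fetchBody() {.raises: [HttpTransportError, HttpProtocolError].} =
  # Hypothetical failure: pretend the peer dropped the connection mid-read.
  raiseHttpDisconnectError()

try:
  fetchBody()
except HttpTransportError as exc:
  echo "transport failure: ", exc.msg
except HttpProtocolError as exc:
  echo "protocol violation (", exc.code, "): ", exc.msg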
template newHttpInterruptError*(): ref HttpInterruptError = template newHttpInterruptError*(): ref HttpInterruptError =
@ -125,9 +206,25 @@ template newHttpWriteError*(message: string): ref HttpWriteError =
template newHttpUseClosedError*(): ref HttpUseClosedError = template newHttpUseClosedError*(): ref HttpUseClosedError =
newException(HttpUseClosedError, "Connection was already closed") newException(HttpUseClosedError, "Connection was already closed")
func init*(t: typedesc[HttpMessage], code: HttpCode, message: string,
contentType: MediaType): HttpMessage =
HttpMessage(code: code, message: message, contentType: contentType)
func init*(t: typedesc[HttpMessage], code: HttpCode, message: string,
contentType: string): HttpMessage =
HttpMessage(code: code, message: message,
contentType: MediaType.init(contentType))
func init*(t: typedesc[HttpMessage], code: HttpCode,
message: string): HttpMessage =
HttpMessage(code: code, message: message,
contentType: MediaType.init("text/plain"))
func init*(t: typedesc[HttpMessage], code: HttpCode): HttpMessage =
HttpMessage(code: code)
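A short sketch of HttpMessage carrying a status code together with a human-readable message; the lookup proc and its literals are illustrative only:

let notFound = HttpMessage.init(Http404, "resource not found")

proc lookup(found: bool): HttpResultMessage[string] =
  if found:
    ok("payload")
  else:
    err(notFound)

# A message-typed error can be turned into an exception through the
# raiseHttpProtocolError(HttpMessage) overload above.
let value = lookup(true).valueOr:
  raiseHttpProtocolError(error)
# value == "payload"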
iterator queryParams*(query: string, iterator queryParams*(query: string,
flags: set[QueryParamsFlag] = {}): KeyValueTuple {. flags: set[QueryParamsFlag] = {}): KeyValueTuple =
raises: [].} =
## Iterate over url-encoded query string. ## Iterate over url-encoded query string.
for pair in query.split('&'): for pair in query.split('&'):
let items = pair.split('=', maxsplit = 1) let items = pair.split('=', maxsplit = 1)
@ -140,9 +237,9 @@ iterator queryParams*(query: string,
else: else:
yield (decodeUrl(k), decodeUrl(v)) yield (decodeUrl(k), decodeUrl(v))
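A small usage sketch for the iterator above; the query string is made up:

var params: seq[KeyValueTuple]
for pair in queryParams("name=alice&lang=nim%2B%2B"):
  params.add(pair)
# params now holds ("name", "alice") and ("lang", "nim++"); values are
# url-decoded by default.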
func getTransferEncoding*(ch: openArray[string]): HttpResult[ func getTransferEncoding*(
set[TransferEncodingFlags]] {. ch: openArray[string]
raises: [].} = ): HttpResult[set[TransferEncodingFlags]] =
## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return ## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return
## it as set of ``TransferEncodingFlags``. ## it as set of ``TransferEncodingFlags``.
var res: set[TransferEncodingFlags] = {} var res: set[TransferEncodingFlags] = {}
@ -171,9 +268,9 @@ func getTransferEncoding*(ch: openArray[string]): HttpResult[
return err("Incorrect Transfer-Encoding value") return err("Incorrect Transfer-Encoding value")
ok(res) ok(res)
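A sketch of consuming the parsed flag set, assuming the Chunked member of TransferEncodingFlags used by the chunked body reader:

let encodings = getTransferEncoding(["chunked"])
if encodings.isOk() and TransferEncodingFlags.Chunked in encodings.get():
  echo "request body uses chunked transfer encoding"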
func getContentEncoding*(ch: openArray[string]): HttpResult[ func getContentEncoding*(
set[ContentEncodingFlags]] {. ch: openArray[string]
raises: [].} = ): HttpResult[set[ContentEncodingFlags]] =
## Parse value of multiple HTTP headers ``Content-Encoding`` and return ## Parse value of multiple HTTP headers ``Content-Encoding`` and return
## it as set of ``ContentEncodingFlags``. ## it as set of ``ContentEncodingFlags``.
var res: set[ContentEncodingFlags] = {} var res: set[ContentEncodingFlags] = {}
@ -202,8 +299,7 @@ func getContentEncoding*(ch: openArray[string]): HttpResult[
return err("Incorrect Content-Encoding value") return err("Incorrect Content-Encoding value")
ok(res) ok(res)
func getContentType*(ch: openArray[string]): HttpResult[ContentTypeData] {. func getContentType*(ch: openArray[string]): HttpResult[ContentTypeData] =
raises: [].} =
## Check and prepare value of ``Content-Type`` header. ## Check and prepare value of ``Content-Type`` header.
if len(ch) == 0: if len(ch) == 0:
err("No Content-Type values found") err("No Content-Type values found")


@ -6,8 +6,11 @@
# Licensed under either of # Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
{.push raises: [].}
import std/tables import std/tables
import stew/results import results
import ../../timer import ../../timer
import httpserver, shttpserver import httpserver, shttpserver
from httpclient import HttpClientScheme from httpclient import HttpClientScheme
@ -16,8 +19,6 @@ from ../../osdefs import SocketHandle
from ../../transports/common import TransportAddress, ServerFlags from ../../transports/common import TransportAddress, ServerFlags
export HttpClientScheme, SocketHandle, TransportAddress, ServerFlags, HttpState export HttpClientScheme, SocketHandle, TransportAddress, ServerFlags, HttpState
{.push raises: [].}
type type
ConnectionType* {.pure.} = enum ConnectionType* {.pure.} = enum
NonSecure, Secure NonSecure, Secure
@ -29,6 +30,7 @@ type
handle*: SocketHandle handle*: SocketHandle
connectionType*: ConnectionType connectionType*: ConnectionType
connectionState*: ConnectionState connectionState*: ConnectionState
query*: Opt[string]
remoteAddress*: Opt[TransportAddress] remoteAddress*: Opt[TransportAddress]
localAddress*: Opt[TransportAddress] localAddress*: Opt[TransportAddress]
acceptMoment*: Moment acceptMoment*: Moment
@ -85,6 +87,12 @@ proc getConnectionState*(holder: HttpConnectionHolderRef): ConnectionState =
else: else:
ConnectionState.Accepted ConnectionState.Accepted
proc getQueryString*(holder: HttpConnectionHolderRef): Opt[string] =
if not(isNil(holder.connection)):
holder.connection.currentRawQuery
else:
Opt.none(string)
proc init*(t: typedesc[ServerConnectionInfo], proc init*(t: typedesc[ServerConnectionInfo],
holder: HttpConnectionHolderRef): ServerConnectionInfo = holder: HttpConnectionHolderRef): ServerConnectionInfo =
let let
@ -98,6 +106,7 @@ proc init*(t: typedesc[ServerConnectionInfo],
Opt.some(holder.transp.remoteAddress()) Opt.some(holder.transp.remoteAddress())
except CatchableError: except CatchableError:
Opt.none(TransportAddress) Opt.none(TransportAddress)
queryString = holder.getQueryString()
ServerConnectionInfo( ServerConnectionInfo(
handle: SocketHandle(holder.transp.fd), handle: SocketHandle(holder.transp.fd),
@ -106,6 +115,7 @@ proc init*(t: typedesc[ServerConnectionInfo],
remoteAddress: remoteAddress, remoteAddress: remoteAddress,
localAddress: localAddress, localAddress: localAddress,
acceptMoment: holder.acceptMoment, acceptMoment: holder.acceptMoment,
query: queryString,
createMoment: createMoment:
if not(isNil(holder.connection)): if not(isNil(holder.connection)):
Opt.some(holder.connection.createMoment) Opt.some(holder.connection.createMoment)

File diff suppressed because it is too large


@ -197,3 +197,7 @@ proc toList*(ht: HttpTables, normKey = false): auto =
for key, value in ht.stringItems(normKey): for key, value in ht.stringItems(normKey):
res.add((key, value)) res.add((key, value))
res res
proc clear*(ht: var HttpTables) =
## Resets the HttpTable so that it is empty.
ht.table.clear()
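A minimal sketch of the new clear(); HttpTable.init(), add() and isEmpty() are assumed to be the existing httptable helpers:

var headers = HttpTable.init()
headers.add("content-type", "text/plain")
headers.clear()
doAssert headers.isEmpty()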


@ -7,15 +7,20 @@
# Licensed under either of # Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
{.push raises: [].}
import std/[monotimes, strutils] import std/[monotimes, strutils]
import stew/results, httputils import results, httputils
import ../../asyncloop import ../../asyncloop
import ../../streams/[asyncstream, boundstream, chunkstream] import ../../streams/[asyncstream, boundstream, chunkstream]
import httptable, httpcommon, httpbodyrw import "."/[httptable, httpcommon, httpbodyrw]
export asyncloop, httptable, httpcommon, httpbodyrw, asyncstream, httputils export asyncloop, httptable, httpcommon, httpbodyrw, asyncstream, httputils
const const
UnableToReadMultipartBody = "Unable to read multipart message body" UnableToReadMultipartBody = "Unable to read multipart message body, reason: "
UnableToSendMultipartMessage = "Unable to send multipart message, reason: "
MaxMultipartHeaderSize = 4096
type type
MultiPartSource* {.pure.} = enum MultiPartSource* {.pure.} = enum
@ -66,13 +71,12 @@ type
name*: string name*: string
filename*: string filename*: string
MultipartError* = object of HttpCriticalError MultipartError* = object of HttpProtocolError
MultipartEOMError* = object of MultipartError MultipartEOMError* = object of MultipartError
BChar* = byte | char BChar* = byte | char
proc startsWith(s, prefix: openArray[byte]): bool {. proc startsWith(s, prefix: openArray[byte]): bool =
raises: [].} =
# This procedure is copy of strutils.startsWith() procedure, however, # This procedure is copy of strutils.startsWith() procedure, however,
# it is intended to work with arrays of bytes, but not with strings. # it is intended to work with arrays of bytes, but not with strings.
var i = 0 var i = 0
@ -81,8 +85,7 @@ proc startsWith(s, prefix: openArray[byte]): bool {.
if i >= len(s) or s[i] != prefix[i]: return false if i >= len(s) or s[i] != prefix[i]: return false
inc(i) inc(i)
proc parseUntil(s, until: openArray[byte]): int {. proc parseUntil(s, until: openArray[byte]): int =
raises: [].} =
# This procedure is copy of parseutils.parseUntil() procedure, however, # This procedure is copy of parseutils.parseUntil() procedure, however,
# it is intended to work with arrays of bytes, but not with strings. # it is intended to work with arrays of bytes, but not with strings.
var i = 0 var i = 0
@ -95,8 +98,7 @@ proc parseUntil(s, until: openArray[byte]): int {.
inc(i) inc(i)
-1 -1
func setPartNames(part: var MultiPart): HttpResult[void] {. func setPartNames(part: var MultiPart): HttpResult[void] =
raises: [].} =
if part.headers.count("content-disposition") != 1: if part.headers.count("content-disposition") != 1:
return err("Content-Disposition header is incorrect") return err("Content-Disposition header is incorrect")
var header = part.headers.getString("content-disposition") var header = part.headers.getString("content-disposition")
@ -105,7 +107,7 @@ func setPartNames(part: var MultiPart): HttpResult[void] {.
return err("Content-Disposition header value is incorrect") return err("Content-Disposition header value is incorrect")
let dtype = disp.dispositionType(header.toOpenArrayByte(0, len(header) - 1)) let dtype = disp.dispositionType(header.toOpenArrayByte(0, len(header) - 1))
if dtype.toLowerAscii() != "form-data": if dtype.toLowerAscii() != "form-data":
return err("Content-Disposition type is incorrect") return err("Content-Disposition header type is incorrect")
for k, v in disp.fields(header.toOpenArrayByte(0, len(header) - 1)): for k, v in disp.fields(header.toOpenArrayByte(0, len(header) - 1)):
case k.toLowerAscii() case k.toLowerAscii()
of "name": of "name":
@ -120,8 +122,7 @@ func setPartNames(part: var MultiPart): HttpResult[void] {.
proc init*[A: BChar, B: BChar](mpt: typedesc[MultiPartReader], proc init*[A: BChar, B: BChar](mpt: typedesc[MultiPartReader],
buffer: openArray[A], buffer: openArray[A],
boundary: openArray[B]): MultiPartReader {. boundary: openArray[B]): MultiPartReader =
raises: [].} =
## Create new MultiPartReader instance with `buffer` interface. ## Create new MultiPartReader instance with `buffer` interface.
## ##
## ``buffer`` - is buffer which will be used to read data. ## ``buffer`` - is buffer which will be used to read data.
@ -142,11 +143,11 @@ proc init*[A: BChar, B: BChar](mpt: typedesc[MultiPartReader],
MultiPartReader(kind: MultiPartSource.Buffer, MultiPartReader(kind: MultiPartSource.Buffer,
buffer: buf, offset: 0, boundary: fboundary) buffer: buf, offset: 0, boundary: fboundary)
proc new*[B: BChar](mpt: typedesc[MultiPartReaderRef], proc new*[B: BChar](
stream: HttpBodyReader, mpt: typedesc[MultiPartReaderRef],
boundary: openArray[B], stream: HttpBodyReader,
partHeadersMaxSize = 4096): MultiPartReaderRef {. boundary: openArray[B],
raises: [].} = partHeadersMaxSize = MaxMultipartHeaderSize): MultiPartReaderRef =
## Create new MultiPartReader instance with `stream` interface. ## Create new MultiPartReader instance with `stream` interface.
## ##
## ``stream`` is stream used to read data. ## ``stream`` is stream used to read data.
@ -173,7 +174,17 @@ proc new*[B: BChar](mpt: typedesc[MultiPartReaderRef],
stream: stream, offset: 0, boundary: fboundary, stream: stream, offset: 0, boundary: fboundary,
buffer: newSeq[byte](partHeadersMaxSize)) buffer: newSeq[byte](partHeadersMaxSize))
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = template handleAsyncStreamReaderError(targ, excarg: untyped) =
if targ.hasOverflow():
raiseHttpRequestBodyTooLargeError()
raiseHttpReadError(UnableToReadMultipartBody & $excarg.msg)
template handleAsyncStreamWriterError(targ, excarg: untyped) =
targ.state = MultiPartWriterState.MessageFailure
raiseHttpWriteError(UnableToSendMultipartMessage & $excarg.msg)
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.
async: (raises: [CancelledError, HttpReadError, HttpProtocolError]).} =
doAssert(mpr.kind == MultiPartSource.Stream) doAssert(mpr.kind == MultiPartSource.Stream)
if mpr.firstTime: if mpr.firstTime:
try: try:
@ -182,14 +193,11 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
mpr.firstTime = false mpr.firstTime = false
if not(startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 3), if not(startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 3),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1))): mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1))):
raiseHttpCriticalError("Unexpected boundary encountered") raiseHttpProtocolError(Http400, "Unexpected boundary encountered")
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamError: except AsyncStreamError as exc:
if mpr.stream.hasOverflow(): handleAsyncStreamReaderError(mpr.stream, exc)
raiseHttpCriticalError(MaximumBodySizeError, Http413)
else:
raiseHttpCriticalError(UnableToReadMultipartBody)
# Reading part's headers # Reading part's headers
try: try:
@ -203,9 +211,9 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
raise newException(MultipartEOMError, raise newException(MultipartEOMError,
"End of multipart message") "End of multipart message")
else: else:
raiseHttpCriticalError("Incorrect multipart header found") raiseHttpProtocolError(Http400, "Incorrect multipart header found")
if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8: if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
raiseHttpCriticalError("Incorrect multipart boundary found") raiseHttpProtocolError(Http400, "Incorrect multipart boundary found")
# If two bytes are CRLF we are at the part beginning. # If two bytes are CRLF we are at the part beginning.
# Reading part's headers # Reading part's headers
@ -213,7 +221,7 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
HeadersMark) HeadersMark)
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false) var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
if headersList.failed(): if headersList.failed():
raiseHttpCriticalError("Incorrect multipart's headers found") raiseHttpProtocolError(Http400, "Incorrect multipart's headers found")
inc(mpr.counter) inc(mpr.counter)
var part = MultiPart( var part = MultiPart(
@ -229,48 +237,39 @@ proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
let sres = part.setPartNames() let sres = part.setPartNames()
if sres.isErr(): if sres.isErr():
raiseHttpCriticalError($sres.error) raiseHttpProtocolError(Http400, $sres.error)
return part return part
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamError: except AsyncStreamError as exc:
if mpr.stream.hasOverflow(): handleAsyncStreamReaderError(mpr.stream, exc)
raiseHttpCriticalError(MaximumBodySizeError, Http413)
else:
raiseHttpCriticalError(UnableToReadMultipartBody)
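A hedged sketch of draining a streaming reader with the raises-annotated readPart; reader is assumed to be a MultiPartReaderRef created with MultiPartReaderRef.new(...):

proc drainParts(reader: MultiPartReaderRef) {.async.} =
  while not reader.atEoM():
    var part: MultiPart
    try:
      part = await reader.readPart()
    except MultipartEOMError:
      # End of the multipart message.
      break
    let value = await part.getBody()
    echo part.name, ": ", value.len, " bytes"
    await part.closeWait()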
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} = proc getBody*(mp: MultiPart): Future[seq[byte]] {.
async: (raises: [CancelledError, HttpReadError, HttpProtocolError]).} =
## Get multipart's ``mp`` value as sequence of bytes. ## Get multipart's ``mp`` value as sequence of bytes.
case mp.kind case mp.kind
of MultiPartSource.Stream: of MultiPartSource.Stream:
try: try:
let res = await mp.stream.read() await mp.stream.read()
return res except AsyncStreamError as exc:
except AsyncStreamError: handleAsyncStreamReaderError(mp.breader, exc)
if mp.breader.hasOverflow():
raiseHttpCriticalError(MaximumBodySizeError, Http413)
else:
raiseHttpCriticalError(UnableToReadMultipartBody)
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
return mp.buffer mp.buffer
proc consumeBody*(mp: MultiPart) {.async.} = proc consumeBody*(mp: MultiPart) {.
async: (raises: [CancelledError, HttpReadError, HttpProtocolError]).} =
## Discard multipart's ``mp`` value. ## Discard multipart's ``mp`` value.
case mp.kind case mp.kind
of MultiPartSource.Stream: of MultiPartSource.Stream:
try: try:
discard await mp.stream.consume() discard await mp.stream.consume()
except AsyncStreamError: except AsyncStreamError as exc:
if mp.breader.hasOverflow(): handleAsyncStreamReaderError(mp.breader, exc)
raiseHttpCriticalError(MaximumBodySizeError, Http413)
else:
raiseHttpCriticalError(UnableToReadMultipartBody)
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
discard discard
proc getBodyStream*(mp: MultiPart): HttpResult[AsyncStreamReader] {. proc getBodyStream*(mp: MultiPart): HttpResult[AsyncStreamReader] =
raises: [].} =
## Get multipart's ``mp`` stream, which can be used to obtain value of the ## Get multipart's ``mp`` stream, which can be used to obtain value of the
## part. ## part.
case mp.kind case mp.kind
@ -279,7 +278,7 @@ proc getBodyStream*(mp: MultiPart): HttpResult[AsyncStreamReader] {.
else: else:
err("Could not obtain stream from buffer-like part") err("Could not obtain stream from buffer-like part")
proc closeWait*(mp: MultiPart) {.async.} = proc closeWait*(mp: MultiPart) {.async: (raises: []).} =
## Close and release MultiPart's ``mp`` stream and resources. ## Close and release MultiPart's ``mp`` stream and resources.
case mp.kind case mp.kind
of MultiPartSource.Stream: of MultiPartSource.Stream:
@ -287,7 +286,7 @@ proc closeWait*(mp: MultiPart) {.async.} =
else: else:
discard discard
proc closeWait*(mpr: MultiPartReaderRef) {.async.} = proc closeWait*(mpr: MultiPartReaderRef) {.async: (raises: []).} =
## Close and release MultiPartReader's ``mpr`` stream and resources. ## Close and release MultiPartReader's ``mpr`` stream and resources.
case mpr.kind case mpr.kind
of MultiPartSource.Stream: of MultiPartSource.Stream:
@ -295,7 +294,7 @@ proc closeWait*(mpr: MultiPartReaderRef) {.async.} =
else: else:
discard discard
proc getBytes*(mp: MultiPart): seq[byte] {.raises: [].} = proc getBytes*(mp: MultiPart): seq[byte] =
## Returns value for MultiPart ``mp`` as sequence of bytes. ## Returns value for MultiPart ``mp`` as sequence of bytes.
case mp.kind case mp.kind
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
@ -304,7 +303,7 @@ proc getBytes*(mp: MultiPart): seq[byte] {.raises: [].} =
doAssert(not(mp.stream.atEof()), "Value is not obtained yet") doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
mp.buffer mp.buffer
proc getString*(mp: MultiPart): string {.raises: [].} = proc getString*(mp: MultiPart): string =
## Returns value for MultiPart ``mp`` as string. ## Returns value for MultiPart ``mp`` as string.
case mp.kind case mp.kind
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
@ -313,7 +312,7 @@ proc getString*(mp: MultiPart): string {.raises: [].} =
doAssert(not(mp.stream.atEof()), "Value is not obtained yet") doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
bytesToString(mp.buffer) bytesToString(mp.buffer)
proc atEoM*(mpr: var MultiPartReader): bool {.raises: [].} = proc atEoM*(mpr: var MultiPartReader): bool =
## Procedure returns ``true`` if MultiPartReader has reached the end of ## Procedure returns ``true`` if MultiPartReader has reached the end of
## multipart message. ## multipart message.
case mpr.kind case mpr.kind
@ -322,7 +321,7 @@ proc atEoM*(mpr: var MultiPartReader): bool {.raises: [].} =
of MultiPartSource.Stream: of MultiPartSource.Stream:
mpr.stream.atEof() mpr.stream.atEof()
proc atEoM*(mpr: MultiPartReaderRef): bool {.raises: [].} = proc atEoM*(mpr: MultiPartReaderRef): bool =
## Procedure returns ``true`` if MultiPartReader has reached the end of ## Procedure returns ``true`` if MultiPartReader has reached the end of
## multipart message. ## multipart message.
case mpr.kind case mpr.kind
@ -331,8 +330,7 @@ proc atEoM*(mpr: MultiPartReaderRef): bool {.raises: [].} =
of MultiPartSource.Stream: of MultiPartSource.Stream:
mpr.stream.atEof() mpr.stream.atEof()
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] {. proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
raises: [].} =
## Get multipart part from MultiPartReader instance. ## Get multipart part from MultiPartReader instance.
## ##
## This procedure will work only for MultiPartReader with buffer source. ## This procedure will work only for MultiPartReader with buffer source.
@ -422,8 +420,7 @@ proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] {.
else: else:
err("Incorrect multipart form") err("Incorrect multipart form")
func isEmpty*(mp: MultiPart): bool {. func isEmpty*(mp: MultiPart): bool =
raises: [].} =
## Returns ``true`` if multipart ``mp`` is not initialized/filled yet. ## Returns ``true`` if multipart ``mp`` is not initialized/filled yet.
mp.counter == 0 mp.counter == 0
@ -439,8 +436,7 @@ func validateBoundary[B: BChar](boundary: openArray[B]): HttpResult[void] =
return err("Content-Type boundary alphabet incorrect") return err("Content-Type boundary alphabet incorrect")
ok() ok()
func getMultipartBoundary*(contentData: ContentTypeData): HttpResult[string] {. func getMultipartBoundary*(contentData: ContentTypeData): HttpResult[string] =
raises: [].} =
## Returns ``multipart/form-data`` boundary value from ``Content-Type`` ## Returns ``multipart/form-data`` boundary value from ``Content-Type``
## header. ## header.
## ##
@ -480,8 +476,7 @@ proc quoteCheck(name: string): HttpResult[string] =
ok(name) ok(name)
proc init*[B: BChar](mpt: typedesc[MultiPartWriter], proc init*[B: BChar](mpt: typedesc[MultiPartWriter],
boundary: openArray[B]): MultiPartWriter {. boundary: openArray[B]): MultiPartWriter =
raises: [].} =
## Create new MultiPartWriter instance with `buffer` interface. ## Create new MultiPartWriter instance with `buffer` interface.
## ##
## ``boundary`` - is multipart boundary, this value must not be empty. ## ``boundary`` - is multipart boundary, this value must not be empty.
@ -510,8 +505,7 @@ proc init*[B: BChar](mpt: typedesc[MultiPartWriter],
proc new*[B: BChar](mpt: typedesc[MultiPartWriterRef], proc new*[B: BChar](mpt: typedesc[MultiPartWriterRef],
stream: HttpBodyWriter, stream: HttpBodyWriter,
boundary: openArray[B]): MultiPartWriterRef {. boundary: openArray[B]): MultiPartWriterRef =
raises: [].} =
doAssert(validateBoundary(boundary).isOk()) doAssert(validateBoundary(boundary).isOk())
doAssert(not(isNil(stream))) doAssert(not(isNil(stream)))
@ -538,7 +532,7 @@ proc new*[B: BChar](mpt: typedesc[MultiPartWriterRef],
proc prepareHeaders(partMark: openArray[byte], name: string, filename: string, proc prepareHeaders(partMark: openArray[byte], name: string, filename: string,
headers: HttpTable): string = headers: HttpTable): string =
const ContentDisposition = "Content-Disposition" const ContentDispositionHeader = "Content-Disposition"
let qname = let qname =
block: block:
let res = quoteCheck(name) let res = quoteCheck(name)
@ -551,10 +545,10 @@ proc prepareHeaders(partMark: openArray[byte], name: string, filename: string,
res.get() res.get()
var buffer = newString(len(partMark)) var buffer = newString(len(partMark))
copyMem(addr buffer[0], unsafeAddr partMark[0], len(partMark)) copyMem(addr buffer[0], unsafeAddr partMark[0], len(partMark))
buffer.add(ContentDisposition) buffer.add(ContentDispositionHeader)
buffer.add(": ") buffer.add(": ")
if ContentDisposition in headers: if ContentDispositionHeader in headers:
buffer.add(headers.getString(ContentDisposition)) buffer.add(headers.getString(ContentDispositionHeader))
buffer.add("\r\n") buffer.add("\r\n")
else: else:
buffer.add("form-data; name=\"") buffer.add("form-data; name=\"")
@ -567,7 +561,7 @@ proc prepareHeaders(partMark: openArray[byte], name: string, filename: string,
buffer.add("\r\n") buffer.add("\r\n")
for k, v in headers.stringItems(): for k, v in headers.stringItems():
if k != toLowerAscii(ContentDisposition): if k != ContentDispositionHeader:
if len(v) > 0: if len(v) > 0:
buffer.add(k) buffer.add(k)
buffer.add(": ") buffer.add(": ")
@ -576,7 +570,8 @@ proc prepareHeaders(partMark: openArray[byte], name: string, filename: string,
buffer.add("\r\n") buffer.add("\r\n")
buffer buffer
proc begin*(mpw: MultiPartWriterRef) {.async.} = proc begin*(mpw: MultiPartWriterRef) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Starts multipart message form and writes appropriate markers to output ## Starts multipart message form and writes appropriate markers to output
## stream. ## stream.
doAssert(mpw.kind == MultiPartSource.Stream) doAssert(mpw.kind == MultiPartSource.Stream)
@ -584,10 +579,9 @@ proc begin*(mpw: MultiPartWriterRef) {.async.} =
# write "--" # write "--"
try: try:
await mpw.stream.write(mpw.beginMark) await mpw.stream.write(mpw.beginMark)
except AsyncStreamError: mpw.state = MultiPartWriterState.MessageStarted
mpw.state = MultiPartWriterState.MessageFailure except AsyncStreamError as exc:
raiseHttpCriticalError("Unable to start multipart message") handleAsyncStreamWriterError(mpw, exc)
mpw.state = MultiPartWriterState.MessageStarted
proc begin*(mpw: var MultiPartWriter) = proc begin*(mpw: var MultiPartWriter) =
## Starts multipart message form and writes appropriate markers to output ## Starts multipart message form and writes appropriate markers to output
@ -599,7 +593,8 @@ proc begin*(mpw: var MultiPartWriter) =
mpw.state = MultiPartWriterState.MessageStarted mpw.state = MultiPartWriterState.MessageStarted
proc beginPart*(mpw: MultiPartWriterRef, name: string, proc beginPart*(mpw: MultiPartWriterRef, name: string,
filename: string, headers: HttpTable) {.async.} = filename: string, headers: HttpTable) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Starts part of multipart message and write appropriate ``headers`` to the ## Starts part of multipart message and write appropriate ``headers`` to the
## output stream. ## output stream.
## ##
@ -614,9 +609,8 @@ proc beginPart*(mpw: MultiPartWriterRef, name: string,
try: try:
await mpw.stream.write(buffer) await mpw.stream.write(buffer)
mpw.state = MultiPartWriterState.PartStarted mpw.state = MultiPartWriterState.PartStarted
except AsyncStreamError: except AsyncStreamError as exc:
mpw.state = MultiPartWriterState.MessageFailure handleAsyncStreamWriterError(mpw, exc)
raiseHttpCriticalError("Unable to start multipart part")
proc beginPart*(mpw: var MultiPartWriter, name: string, proc beginPart*(mpw: var MultiPartWriter, name: string,
filename: string, headers: HttpTable) = filename: string, headers: HttpTable) =
@ -634,38 +628,38 @@ proc beginPart*(mpw: var MultiPartWriter, name: string,
mpw.buffer.add(buffer.toOpenArrayByte(0, len(buffer) - 1)) mpw.buffer.add(buffer.toOpenArrayByte(0, len(buffer) - 1))
mpw.state = MultiPartWriterState.PartStarted mpw.state = MultiPartWriterState.PartStarted
proc write*(mpw: MultiPartWriterRef, pbytes: pointer, nbytes: int) {.async.} = proc write*(mpw: MultiPartWriterRef, pbytes: pointer, nbytes: int) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Write part's data ``data`` to the output stream. ## Write part's data ``data`` to the output stream.
doAssert(mpw.kind == MultiPartSource.Stream) doAssert(mpw.kind == MultiPartSource.Stream)
doAssert(mpw.state == MultiPartWriterState.PartStarted) doAssert(mpw.state == MultiPartWriterState.PartStarted)
try: try:
# write <chunk> of data # write <chunk> of data
await mpw.stream.write(pbytes, nbytes) await mpw.stream.write(pbytes, nbytes)
except AsyncStreamError: except AsyncStreamError as exc:
mpw.state = MultiPartWriterState.MessageFailure handleAsyncStreamWriterError(mpw, exc)
raiseHttpCriticalError("Unable to write multipart data")
proc write*(mpw: MultiPartWriterRef, data: seq[byte]) {.async.} = proc write*(mpw: MultiPartWriterRef, data: seq[byte]) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Write part's data ``data`` to the output stream. ## Write part's data ``data`` to the output stream.
doAssert(mpw.kind == MultiPartSource.Stream) doAssert(mpw.kind == MultiPartSource.Stream)
doAssert(mpw.state == MultiPartWriterState.PartStarted) doAssert(mpw.state == MultiPartWriterState.PartStarted)
try: try:
# write <chunk> of data # write <chunk> of data
await mpw.stream.write(data) await mpw.stream.write(data)
except AsyncStreamError: except AsyncStreamError as exc:
mpw.state = MultiPartWriterState.MessageFailure handleAsyncStreamWriterError(mpw, exc)
raiseHttpCriticalError("Unable to write multipart data")
proc write*(mpw: MultiPartWriterRef, data: string) {.async.} = proc write*(mpw: MultiPartWriterRef, data: string) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Write part's data ``data`` to the output stream. ## Write part's data ``data`` to the output stream.
doAssert(mpw.kind == MultiPartSource.Stream) doAssert(mpw.kind == MultiPartSource.Stream)
doAssert(mpw.state == MultiPartWriterState.PartStarted) doAssert(mpw.state == MultiPartWriterState.PartStarted)
try: try:
# write <chunk> of data # write <chunk> of data
await mpw.stream.write(data) await mpw.stream.write(data)
except AsyncStreamError: except AsyncStreamError as exc:
mpw.state = MultiPartWriterState.MessageFailure handleAsyncStreamWriterError(mpw, exc)
raiseHttpCriticalError("Unable to write multipart data")
proc write*(mpw: var MultiPartWriter, pbytes: pointer, nbytes: int) = proc write*(mpw: var MultiPartWriter, pbytes: pointer, nbytes: int) =
## Write part's data ``data`` to the output stream. ## Write part's data ``data`` to the output stream.
@ -688,16 +682,16 @@ proc write*(mpw: var MultiPartWriter, data: openArray[char]) =
doAssert(mpw.state == MultiPartWriterState.PartStarted) doAssert(mpw.state == MultiPartWriterState.PartStarted)
mpw.buffer.add(data.toOpenArrayByte(0, len(data) - 1)) mpw.buffer.add(data.toOpenArrayByte(0, len(data) - 1))
proc finishPart*(mpw: MultiPartWriterRef) {.async.} = proc finishPart*(mpw: MultiPartWriterRef) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Finish multipart's message part and send proper markers to output stream. ## Finish multipart's message part and send proper markers to output stream.
doAssert(mpw.state == MultiPartWriterState.PartStarted) doAssert(mpw.state == MultiPartWriterState.PartStarted)
try: try:
# write "<CR><LF>--" # write "<CR><LF>--"
await mpw.stream.write(mpw.finishPartMark) await mpw.stream.write(mpw.finishPartMark)
mpw.state = MultiPartWriterState.PartFinished mpw.state = MultiPartWriterState.PartFinished
except AsyncStreamError: except AsyncStreamError as exc:
mpw.state = MultiPartWriterState.MessageFailure handleAsyncStreamWriterError(mpw, exc)
raiseHttpCriticalError("Unable to finish multipart message part")
proc finishPart*(mpw: var MultiPartWriter) = proc finishPart*(mpw: var MultiPartWriter) =
## Finish multipart's message part and send proper markers to output stream. ## Finish multipart's message part and send proper markers to output stream.
@ -707,7 +701,8 @@ proc finishPart*(mpw: var MultiPartWriter) =
mpw.buffer.add(mpw.finishPartMark) mpw.buffer.add(mpw.finishPartMark)
mpw.state = MultiPartWriterState.PartFinished mpw.state = MultiPartWriterState.PartFinished
proc finish*(mpw: MultiPartWriterRef) {.async.} = proc finish*(mpw: MultiPartWriterRef) {.
async: (raises: [CancelledError, HttpWriteError]).} =
## Finish multipart's message form and send finishing markers to the output ## Finish multipart's message form and send finishing markers to the output
## stream. ## stream.
doAssert(mpw.kind == MultiPartSource.Stream) doAssert(mpw.kind == MultiPartSource.Stream)
@ -716,9 +711,8 @@ proc finish*(mpw: MultiPartWriterRef) {.async.} =
# write "<boundary>--" # write "<boundary>--"
await mpw.stream.write(mpw.finishMark) await mpw.stream.write(mpw.finishMark)
mpw.state = MultiPartWriterState.MessageFinished mpw.state = MultiPartWriterState.MessageFinished
except AsyncStreamError: except AsyncStreamError as exc:
mpw.state = MultiPartWriterState.MessageFailure handleAsyncStreamWriterError(mpw, exc)
raiseHttpCriticalError("Unable to finish multipart message")
proc finish*(mpw: var MultiPartWriter): seq[byte] = proc finish*(mpw: var MultiPartWriter): seq[byte] =
## Finish multipart's message form and send finishing markers to the output ## Finish multipart's message form and send finishing markers to the output


@ -6,8 +6,11 @@
# Licensed under either of # Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
{.push raises: [].}
import httpserver import httpserver
import ../../asyncloop, ../../asyncsync import ../../[asyncloop, asyncsync, config]
import ../../streams/[asyncstream, tlsstream] import ../../streams/[asyncstream, tlsstream]
export asyncloop, asyncsync, httpserver, asyncstream, tlsstream export asyncloop, asyncsync, httpserver, asyncstream, tlsstream
@ -24,59 +27,119 @@ type
SecureHttpConnectionRef* = ref SecureHttpConnection SecureHttpConnectionRef* = ref SecureHttpConnection
proc closeSecConnection(conn: HttpConnectionRef) {.async.} = proc closeSecConnection(conn: HttpConnectionRef) {.async: (raises: []).} =
if conn.state == HttpState.Alive: if conn.state == HttpState.Alive:
conn.state = HttpState.Closing conn.state = HttpState.Closing
var pending: seq[Future[void]] var pending: seq[Future[void]]
pending.add(conn.writer.closeWait()) pending.add(conn.writer.closeWait())
pending.add(conn.reader.closeWait()) pending.add(conn.reader.closeWait())
try: pending.add(conn.mainReader.closeWait())
await allFutures(pending) pending.add(conn.mainWriter.closeWait())
except CancelledError: pending.add(conn.transp.closeWait())
await allFutures(pending) await noCancel(allFutures(pending))
# After we going to close everything else. reset(cast[SecureHttpConnectionRef](conn)[])
pending.setLen(3)
pending[0] = conn.mainReader.closeWait()
pending[1] = conn.mainWriter.closeWait()
pending[2] = conn.transp.closeWait()
try:
await allFutures(pending)
except CancelledError:
await allFutures(pending)
untrackCounter(HttpServerSecureConnectionTrackerName) untrackCounter(HttpServerSecureConnectionTrackerName)
conn.state = HttpState.Closed conn.state = HttpState.Closed
proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef, proc new(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef,
transp: StreamTransport): SecureHttpConnectionRef = transp: StreamTransport): Result[SecureHttpConnectionRef, string] =
var res = SecureHttpConnectionRef() var res = SecureHttpConnectionRef()
HttpConnection(res[]).init(HttpServerRef(server), transp) HttpConnection(res[]).init(HttpServerRef(server), transp)
let tlsStream = let tlsStream =
newTLSServerAsyncStream(res.mainReader, res.mainWriter, try:
server.tlsPrivateKey, newTLSServerAsyncStream(res.mainReader, res.mainWriter,
server.tlsCertificate, server.tlsPrivateKey,
minVersion = TLSVersion.TLS12, server.tlsCertificate,
flags = server.secureFlags) minVersion = TLSVersion.TLS12,
flags = server.secureFlags)
except TLSStreamError as exc:
return err(exc.msg)
res.tlsStream = tlsStream res.tlsStream = tlsStream
res.reader = AsyncStreamReader(tlsStream.reader) res.reader = AsyncStreamReader(tlsStream.reader)
res.writer = AsyncStreamWriter(tlsStream.writer) res.writer = AsyncStreamWriter(tlsStream.writer)
res.closeCb = closeSecConnection res.closeCb = closeSecConnection
trackCounter(HttpServerSecureConnectionTrackerName) trackCounter(HttpServerSecureConnectionTrackerName)
res ok(res)
proc createSecConnection(server: HttpServerRef, proc createSecConnection(server: HttpServerRef,
transp: StreamTransport): Future[HttpConnectionRef] {. transp: StreamTransport): Future[HttpConnectionRef] {.
async.} = async: (raises: [CancelledError, HttpConnectionError]).} =
let secureServ = cast[SecureHttpServerRef](server) let
var sconn = SecureHttpConnectionRef.new(secureServ, transp) secureServ = cast[SecureHttpServerRef](server)
sconn = SecureHttpConnectionRef.new(secureServ, transp).valueOr:
raiseHttpConnectionError(error)
try: try:
await handshake(sconn.tlsStream) await handshake(sconn.tlsStream)
return HttpConnectionRef(sconn) HttpConnectionRef(sconn)
except CancelledError as exc: except CancelledError as exc:
await HttpConnectionRef(sconn).closeWait() await HttpConnectionRef(sconn).closeWait()
raise exc raise exc
except TLSStreamError: except AsyncStreamError as exc:
await HttpConnectionRef(sconn).closeWait() await HttpConnectionRef(sconn).closeWait()
raiseHttpCriticalError("Unable to establish secure connection") let msg = "Unable to establish secure connection, reason: " & $exc.msg
raiseHttpConnectionError(msg)
proc new*(htype: typedesc[SecureHttpServerRef],
address: TransportAddress,
processCallback: HttpProcessCallback2,
tlsPrivateKey: TLSPrivateKey,
tlsCertificate: TLSCertificate,
serverFlags: set[HttpServerFlags] = {},
socketFlags: set[ServerFlags] = {ReuseAddr},
serverUri = Uri(),
serverIdent = "",
secureFlags: set[TLSFlags] = {},
maxConnections: int = -1,
bufferSize: int = chronosTransportDefaultBufferSize,
backlogSize: int = DefaultBacklogSize,
httpHeadersTimeout = 10.seconds,
maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576,
dualstack = DualStackType.Auto
): HttpResult[SecureHttpServerRef] =
doAssert(not(isNil(tlsPrivateKey)), "TLS private key must not be nil!")
doAssert(not(isNil(tlsCertificate)), "TLS certificate must not be nil!")
let
serverInstance =
try:
createStreamServer(address, flags = socketFlags,
bufferSize = bufferSize,
backlog = backlogSize, dualstack = dualstack)
except TransportOsError as exc:
return err(exc.msg)
serverUri =
if len(serverUri.hostname) > 0:
serverUri
else:
parseUri("https://" & $serverInstance.localAddress() & "/")
res = SecureHttpServerRef(
address: serverInstance.localAddress(),
instance: serverInstance,
processCallback: processCallback,
createConnCallback: createSecConnection,
baseUri: serverUri,
serverIdent: serverIdent,
flags: serverFlags + {HttpServerFlags.Secure},
socketFlags: socketFlags,
maxConnections: maxConnections,
bufferSize: bufferSize,
backlogSize: backlogSize,
headersTimeout: httpHeadersTimeout,
maxHeadersSize: maxHeadersSize,
maxRequestBodySize: maxRequestBodySize,
# semaphore:
# if maxConnections > 0:
# newAsyncSemaphore(maxConnections)
# else:
# nil
lifetime: newFuture[void]("http.server.lifetime"),
connections: initOrderedTable[string, HttpConnectionHolderRef](),
tlsCertificate: tlsCertificate,
tlsPrivateKey: tlsPrivateKey,
secureFlags: secureFlags
)
ok(res)
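A rough usage sketch for the Result-returning constructor; pemKeyText/pemCertText stand in for PEM data loaded elsewhere, and defaultResponse()/start() are assumed to be the usual httpserver helpers:

let
  key = TLSPrivateKey.init(pemKeyText)    # hypothetical PEM strings,
  cert = TLSCertificate.init(pemCertText) # loaded elsewhere

proc process(req: RequestFence): Future[HttpResponseRef] {.
    async: (raises: [CancelledError]).} =
  # A real handler would inspect `req` and build a response; this sketch
  # always returns the library's default response.
  return defaultResponse()

let server = SecureHttpServerRef.new(
  initTAddress("127.0.0.1:8443"), process,
  tlsPrivateKey = key, tlsCertificate = cert).get()
server.start()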
proc new*(htype: typedesc[SecureHttpServerRef], proc new*(htype: typedesc[SecureHttpServerRef],
address: TransportAddress, address: TransportAddress,
@ -89,58 +152,40 @@ proc new*(htype: typedesc[SecureHttpServerRef],
serverIdent = "", serverIdent = "",
secureFlags: set[TLSFlags] = {}, secureFlags: set[TLSFlags] = {},
maxConnections: int = -1, maxConnections: int = -1,
bufferSize: int = 4096, bufferSize: int = chronosTransportDefaultBufferSize,
backlogSize: int = 100, backlogSize: int = DefaultBacklogSize,
httpHeadersTimeout = 10.seconds, httpHeadersTimeout = 10.seconds,
maxHeadersSize: int = 8192, maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576 maxRequestBodySize: int = 1_048_576,
): HttpResult[SecureHttpServerRef] {.raises: [].} = dualstack = DualStackType.Auto
): HttpResult[SecureHttpServerRef] {.
deprecated: "Callback could raise only CancelledError, annotate with " &
"{.async: (raises: [CancelledError]).}".} =
doAssert(not(isNil(tlsPrivateKey)), "TLS private key must not be nil!") proc wrap(req: RequestFence): Future[HttpResponseRef] {.
doAssert(not(isNil(tlsCertificate)), "TLS certificate must not be nil!") async: (raises: [CancelledError]).} =
let serverUri =
if len(serverUri.hostname) > 0:
serverUri
else:
try:
parseUri("https://" & $address & "/")
except TransportAddressError as exc:
return err(exc.msg)
let serverInstance =
try: try:
createStreamServer(address, flags = socketFlags, bufferSize = bufferSize, await processCallback(req)
backlog = backlogSize) except CancelledError as exc:
except TransportOsError as exc: raise exc
return err(exc.msg)
except CatchableError as exc: except CatchableError as exc:
return err(exc.msg) defaultResponse(exc)
let res = SecureHttpServerRef( SecureHttpServerRef.new(
address: address, address = address,
instance: serverInstance, processCallback = wrap,
processCallback: processCallback, tlsPrivateKey = tlsPrivateKey,
createConnCallback: createSecConnection, tlsCertificate = tlsCertificate,
baseUri: serverUri, serverFlags = serverFlags,
serverIdent: serverIdent, socketFlags = socketFlags,
flags: serverFlags + {HttpServerFlags.Secure}, serverUri = serverUri,
socketFlags: socketFlags, serverIdent = serverIdent,
maxConnections: maxConnections, secureFlags = secureFlags,
bufferSize: bufferSize, maxConnections = maxConnections,
backlogSize: backlogSize, bufferSize = bufferSize,
headersTimeout: httpHeadersTimeout, backlogSize = backlogSize,
maxHeadersSize: maxHeadersSize, httpHeadersTimeout = httpHeadersTimeout,
maxRequestBodySize: maxRequestBodySize, maxHeadersSize = maxHeadersSize,
# semaphore: maxRequestBodySize = maxRequestBodySize,
# if maxConnections > 0: dualstack = dualstack
# newAsyncSemaphore(maxConnections)
# else:
# nil
lifetime: newFuture[void]("http.server.lifetime"),
connections: initOrderedTable[string, HttpConnectionHolderRef](),
tlsCertificate: tlsCertificate,
tlsPrivateKey: tlsPrivateKey,
secureFlags: secureFlags
) )
ok(res)


@ -1,983 +0,0 @@
#
# Chronos
#
# (c) Copyright 2015 Dominik Picheta
# (c) Copyright 2018-2023 Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/sequtils
import stew/base10
when chronosStackTrace:
when defined(nimHasStacktracesModule):
import system/stacktraces
else:
const
reraisedFromBegin = -10
reraisedFromEnd = -100
template LocCreateIndex*: auto {.deprecated: "LocationKind.Create".} =
LocationKind.Create
template LocFinishIndex*: auto {.deprecated: "LocationKind.Finish".} =
LocationKind.Finish
template LocCompleteIndex*: untyped {.deprecated: "LocationKind.Finish".} =
LocationKind.Finish
func `[]`*(loc: array[LocationKind, ptr SrcLoc], v: int): ptr SrcLoc {.deprecated: "use LocationKind".} =
case v
of 0: loc[LocationKind.Create]
of 1: loc[LocationKind.Finish]
else: raiseAssert("Unknown source location " & $v)
type
FutureStr*[T] = ref object of Future[T]
## Future to hold GC strings
gcholder*: string
FutureSeq*[A, B] = ref object of Future[A]
## Future to hold GC seqs
gcholder*: seq[B]
# Backwards compatibility for old FutureState name
template Finished* {.deprecated: "Use Completed instead".} = Completed
template Finished*(T: type FutureState): FutureState {.deprecated: "Use FutureState.Completed instead".} = FutureState.Completed
proc newFutureImpl[T](loc: ptr SrcLoc): Future[T] =
let fut = Future[T]()
internalInitFutureBase(fut, loc, FutureState.Pending)
fut
proc newFutureSeqImpl[A, B](loc: ptr SrcLoc): FutureSeq[A, B] =
let fut = FutureSeq[A, B]()
internalInitFutureBase(fut, loc, FutureState.Pending)
fut
proc newFutureStrImpl[T](loc: ptr SrcLoc): FutureStr[T] =
let fut = FutureStr[T]()
internalInitFutureBase(fut, loc, FutureState.Pending)
fut
template newFuture*[T](fromProc: static[string] = ""): Future[T] =
## Creates a new future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
newFutureImpl[T](getSrcLocation(fromProc))
template newFutureSeq*[A, B](fromProc: static[string] = ""): FutureSeq[A, B] =
## Create a new future which can hold/preserve a GC sequence until the future
## is completed.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
newFutureSeqImpl[A, B](getSrcLocation(fromProc))
template newFutureStr*[T](fromProc: static[string] = ""): FutureStr[T] =
## Create a new future which can hold/preserve a GC string until the future
## is completed.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
newFutureStrImpl[T](getSrcLocation(fromProc))
proc done*(future: FutureBase): bool {.deprecated: "Use `completed` instead".} =
## This is an alias for ``completed(future)`` procedure.
completed(future)
when chronosFutureTracking:
proc futureDestructor(udata: pointer) =
## This procedure will be called when Future[T] got completed, cancelled or
## failed and all Future[T].callbacks are already scheduled and processed.
let future = cast[FutureBase](udata)
if future == futureList.tail: futureList.tail = future.prev
if future == futureList.head: futureList.head = future.next
if not(isNil(future.next)): future.next.internalPrev = future.prev
if not(isNil(future.prev)): future.prev.internalNext = future.next
futureList.count.dec()
proc scheduleDestructor(future: FutureBase) {.inline.} =
callSoon(futureDestructor, cast[pointer](future))
proc checkFinished(future: FutureBase, loc: ptr SrcLoc) =
## Checks whether `future` is finished. If it is then raises a
## ``FutureDefect``.
if future.finished():
var msg = ""
msg.add("An attempt was made to complete a Future more than once. ")
msg.add("Details:")
msg.add("\n Future ID: " & Base10.toString(future.id))
msg.add("\n Creation location:")
msg.add("\n " & $future.location[LocationKind.Create])
msg.add("\n First completion location:")
msg.add("\n " & $future.location[LocationKind.Finish])
msg.add("\n Second completion location:")
msg.add("\n " & $loc)
when chronosStackTrace:
msg.add("\n Stack trace to moment of creation:")
msg.add("\n" & indent(future.stackTrace.strip(), 4))
msg.add("\n Stack trace to moment of secondary completion:")
msg.add("\n" & indent(getStackTrace().strip(), 4))
msg.add("\n\n")
var err = newException(FutureDefect, msg)
err.cause = future
raise err
else:
future.internalLocation[LocationKind.Finish] = loc
proc finish(fut: FutureBase, state: FutureState) =
# We do not perform any checks here, because:
# 1. `finish()` is a private procedure and `state` is under our control.
# 2. `fut.state` is checked by `checkFinished()`.
fut.internalState = state
when chronosStrictFutureAccess:
doAssert fut.internalCancelcb == nil or state != FutureState.Cancelled
fut.internalCancelcb = nil # release cancellation callback memory
for item in fut.internalCallbacks.mitems():
if not(isNil(item.function)):
callSoon(item)
item = default(AsyncCallback) # release memory as early as possible
fut.internalCallbacks = default(seq[AsyncCallback]) # release seq as well
when chronosFutureTracking:
scheduleDestructor(fut)
proc complete[T](future: Future[T], val: T, loc: ptr SrcLoc) =
if not(future.cancelled()):
checkFinished(future, loc)
doAssert(isNil(future.internalError))
future.internalValue = val
future.finish(FutureState.Completed)
template complete*[T](future: Future[T], val: T) =
## Completes ``future`` with value ``val``.
complete(future, val, getSrcLocation())
proc complete(future: Future[void], loc: ptr SrcLoc) =
if not(future.cancelled()):
checkFinished(future, loc)
doAssert(isNil(future.internalError))
future.finish(FutureState.Completed)
template complete*(future: Future[void]) =
## Completes a void ``future``.
complete(future, getSrcLocation())
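A sketch of the manual completion API shown above, as used outside of {.async.} procedures:

var fut = newFuture[int]("manual-demo")
doAssert not fut.finished()
fut.complete(42)
doAssert fut.completed()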
proc fail(future: FutureBase, error: ref CatchableError, loc: ptr SrcLoc) =
if not(future.cancelled()):
checkFinished(future, loc)
future.internalError = error
when chronosStackTrace:
future.internalErrorStackTrace = if getStackTrace(error) == "":
getStackTrace()
else:
getStackTrace(error)
future.finish(FutureState.Failed)
template fail*(future: FutureBase, error: ref CatchableError) =
## Completes ``future`` with ``error``.
fail(future, error, getSrcLocation())
template newCancelledError(): ref CancelledError =
(ref CancelledError)(msg: "Future operation cancelled!")
proc cancelAndSchedule(future: FutureBase, loc: ptr SrcLoc) =
if not(future.finished()):
checkFinished(future, loc)
future.internalError = newCancelledError()
when chronosStackTrace:
future.internalErrorStackTrace = getStackTrace()
future.finish(FutureState.Cancelled)
template cancelAndSchedule*(future: FutureBase) =
cancelAndSchedule(future, getSrcLocation())
proc cancel(future: FutureBase, loc: ptr SrcLoc): bool =
## Request that Future ``future`` cancel itself.
##
## This arranges for a `CancelledError` to be thrown into the procedure which
## waits for ``future`` on the next cycle through the event loop.
## The procedure then has a chance to clean up or even deny the request
## using `try/except/finally`.
##
## This call does not guarantee that the ``future`` will be cancelled: the
## exception might be caught and acted upon, delaying cancellation of the
## ``future`` or preventing cancellation completely. The ``future`` may also
## return a value or raise a different exception.
##
## Immediately after this procedure is called, ``future.cancelled()`` will
## not return ``true`` (unless the Future was already cancelled).
if future.finished():
return false
if not(isNil(future.internalChild)):
# If you hit this assertion, you should have used the `CancelledError`
# mechanism and/or use a regular `addCallback`
when chronosStrictFutureAccess:
doAssert future.internalCancelcb.isNil,
"futures returned from `{.async.}` functions must not use `cancelCallback`"
if cancel(future.internalChild, getSrcLocation()):
return true
else:
if not(isNil(future.internalCancelcb)):
future.internalCancelcb(cast[pointer](future))
future.internalCancelcb = nil
cancelAndSchedule(future, getSrcLocation())
future.internalMustCancel = true
return true
template cancel*(future: FutureBase) =
## Cancel ``future``.
discard cancel(future, getSrcLocation())
proc clearCallbacks(future: FutureBase) =
future.internalCallbacks = default(seq[AsyncCallback])
proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer) =
## Adds the callbacks proc to be called when the future completes.
##
## If future has already completed then ``cb`` will be called immediately.
doAssert(not isNil(cb))
if future.finished():
callSoon(cb, udata)
else:
future.internalCallbacks.add AsyncCallback(function: cb, udata: udata)
proc addCallback*(future: FutureBase, cb: CallbackFunc) =
## Adds the callbacks proc to be called when the future completes.
##
## If future has already completed then ``cb`` will be called immediately.
future.addCallback(cb, cast[pointer](future))
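A sketch showing that a callback registered before completion is only scheduled (via callSoon) once the future finishes:

var fired = false
var fut = newFuture[void]("callback-demo")
fut.addCallback(proc(udata: pointer) {.gcsafe.} =
  fired = true
)
fut.complete()
# `fired` flips to true once the scheduled callback runs, e.g. after poll().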
proc removeCallback*(future: FutureBase, cb: CallbackFunc,
udata: pointer) =
## Remove future from list of callbacks - this operation may be slow if there
## are many registered callbacks!
doAssert(not isNil(cb))
# Make sure to release memory associated with callback, or reference chains
# may be created!
future.internalCallbacks.keepItIf:
it.function != cb or it.udata != udata
proc removeCallback*(future: FutureBase, cb: CallbackFunc) =
future.removeCallback(cb, cast[pointer](future))
proc `callback=`*(future: FutureBase, cb: CallbackFunc, udata: pointer) =
## Clears the list of callbacks and sets the callback proc to be called when
## the future completes.
##
## If future has already completed then ``cb`` will be called immediately.
##
## It's recommended to use ``addCallback`` or ``then`` instead.
# ZAH: how about `setLen(1); callbacks[0] = cb`
future.clearCallbacks
future.addCallback(cb, udata)
proc `callback=`*(future: FutureBase, cb: CallbackFunc) =
## Sets the callback proc to be called when the future completes.
##
## If future has already completed then ``cb`` will be called immediately.
`callback=`(future, cb, cast[pointer](future))
proc `cancelCallback=`*(future: FutureBase, cb: CallbackFunc) =
## Sets the callback procedure to be called when the future is cancelled.
##
## This callback will be called immediately when ``future.cancel()`` is invoked
## and must be set before the future is finished.
when chronosStrictFutureAccess:
doAssert not future.finished(),
"cancellation callback must be set before finishing the future"
future.internalCancelcb = cb
{.push stackTrace: off.}
proc futureContinue*(fut: FutureBase) {.raises: [], gcsafe.}
proc internalContinue(fut: pointer) {.raises: [], gcsafe.} =
let asFut = cast[FutureBase](fut)
GC_unref(asFut)
futureContinue(asFut)
proc futureContinue*(fut: FutureBase) {.raises: [], gcsafe.} =
# This function is responsible for calling the closure iterator generated by
# the `{.async.}` transformation either until it has completed its iteration
# or raised and error / been cancelled.
#
# Every call to an `{.async.}` proc is redirected to call this function
# instead with its original body captured in `fut.closure`.
var next: FutureBase
template iterate =
while true:
# Call closure to make progress on `fut` until it reaches `yield` (inside
# `await` typically) or completes / fails / is cancelled
next = fut.internalClosure(fut)
if fut.internalClosure.finished(): # Reached the end of the transformed proc
break
if next == nil:
raiseAssert "Async procedure (" & ($fut.location[LocationKind.Create]) &
") yielded `nil`, are you await'ing a `nil` Future?"
if not next.finished():
# We cannot make progress on `fut` until `next` has finished - schedule
# `fut` to continue running when that happens
GC_ref(fut)
next.addCallback(CallbackFunc(internalContinue), cast[pointer](fut))
# return here so that we don't remove the closure below
return
# Continue while the yielded future is already finished.
when chronosStrictException:
try:
iterate
except CancelledError:
fut.cancelAndSchedule()
except CatchableError as exc:
fut.fail(exc)
finally:
next = nil # GC hygiene
else:
try:
iterate
except CancelledError:
fut.cancelAndSchedule()
except CatchableError as exc:
fut.fail(exc)
except Exception as exc:
if exc of Defect:
raise (ref Defect)(exc)
fut.fail((ref ValueError)(msg: exc.msg, parent: exc))
finally:
next = nil # GC hygiene
# `futureContinue` will not be called any more for this future so we can
# clean it up
fut.internalClosure = nil
fut.internalChild = nil
{.pop.}
when chronosStackTrace:
import std/strutils
template getFilenameProcname(entry: StackTraceEntry): (string, string) =
when compiles(entry.filenameStr) and compiles(entry.procnameStr):
# We can't rely on "entry.filename" and "entry.procname" still being valid
# cstring pointers, because the "string.data" buffers they pointed to might
# be already garbage collected (this entry being a non-shallow copy,
# "entry.filename" no longer points to "entry.filenameStr.data", but to the
# buffer of the original object).
(entry.filenameStr, entry.procnameStr)
else:
($entry.filename, $entry.procname)
proc `$`(stackTraceEntries: seq[StackTraceEntry]): string =
try:
when defined(nimStackTraceOverride) and declared(addDebuggingInfo):
let entries = addDebuggingInfo(stackTraceEntries)
else:
let entries = stackTraceEntries
# Find longest filename & line number combo for alignment purposes.
var longestLeft = 0
for entry in entries:
let (filename, procname) = getFilenameProcname(entry)
if procname == "": continue
let leftLen = filename.len + len($entry.line)
if leftLen > longestLeft:
longestLeft = leftLen
var indent = 2
# Format the entries.
for entry in entries:
let (filename, procname) = getFilenameProcname(entry)
if procname == "":
if entry.line == reraisedFromBegin:
result.add(spaces(indent) & "#[\n")
indent.inc(2)
elif entry.line == reraisedFromEnd:
indent.dec(2)
result.add(spaces(indent) & "]#\n")
continue
let left = "$#($#)" % [filename, $entry.line]
result.add((spaces(indent) & "$#$# $#\n") % [
left,
spaces(longestLeft - left.len + 2),
procname
])
except ValueError as exc:
return exc.msg # Shouldn't actually happen since we set the formatting
# string
proc injectStacktrace(error: ref Exception) =
const header = "\nAsync traceback:\n"
var exceptionMsg = error.msg
if header in exceptionMsg:
# This is messy: extract the original exception message from the msg
# containing the async traceback.
let start = exceptionMsg.find(header)
exceptionMsg = exceptionMsg[0..<start]
var newMsg = exceptionMsg & header
let entries = getStackTraceEntries(error)
newMsg.add($entries)
newMsg.add("Exception message: " & exceptionMsg & "\n")
# # For debugging purposes
# newMsg.add("Exception type:")
# for entry in getStackTraceEntries(future.error):
# newMsg.add "\n" & $entry
error.msg = newMsg
proc internalCheckComplete*(fut: FutureBase) {.raises: [CatchableError].} =
# For internal use only. Used in asyncmacro
if not(isNil(fut.internalError)):
when chronosStackTrace:
injectStacktrace(fut.internalError)
raise fut.internalError
proc internalRead*[T](fut: Future[T]): T {.inline.} =
# For internal use only. Used in asyncmacro
when T isnot void:
return fut.internalValue
proc read*[T](future: Future[T] ): T {.raises: [CatchableError].} =
## Retrieves the value of ``future``. The future must be finished, otherwise
## this function will fail with a ``ValueError`` exception.
##
## If the result of the future is an error then that error will be raised.
if future.finished():
internalCheckComplete(future)
internalRead(future)
else:
# TODO: Make a custom exception type for this?
raise newException(ValueError, "Future still in progress.")
proc readError*(future: FutureBase): ref CatchableError {.raises: [ValueError].} =
## Retrieves the exception stored in ``future``.
##
## A ``ValueError`` exception will be raised if no exception exists
## in the specified Future.
if not(isNil(future.error)):
return future.error
else:
# TODO: Make a custom exception type for this?
raise newException(ValueError, "No error in future.")
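# Illustrative sketch (not part of the original module): `read` and
# `readError` only make sense on finished futures, so user code typically
# checks `failed()` (or `finished()`) first. Assumes `import chronos`.
when isMainModule:
  import chronos

  proc answer(): Future[int] {.async.} =
    return 42

  let answered = answer()
  discard waitFor answered          # make sure the future is finished
  if answered.failed():
    echo "error: ", answered.readError().msg
  else:
    echo "value: ", answered.read() # prints "value: 42"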
template taskFutureLocation(future: FutureBase): string =
let loc = future.location[LocationKind.Create]
"[" & (
if len(loc.procedure) == 0: "[unspecified]" else: $loc.procedure & "()"
) & " at " & $loc.file & ":" & $(loc.line) & "]"
template taskErrorMessage(future: FutureBase): string =
"Asynchronous task " & taskFutureLocation(future) &
" finished with an exception \"" & $future.error.name &
"\"!\nMessage: " & future.error.msg &
"\nStack trace: " & future.error.getStackTrace()
template taskCancelMessage(future: FutureBase): string =
"Asynchronous task " & taskFutureLocation(future) & " was cancelled!"
proc asyncSpawn*(future: Future[void]) =
## Spawns a new concurrent async task.
##
## Tasks may not raise exceptions or be cancelled - a ``Defect`` will be
## raised when this happens.
##
## This should be used instead of ``discard`` and ``asyncCheck`` when calling
## an ``async`` procedure without ``await``, to ensure exceptions in the
## returned future are not silently discarded.
##
## Note that if the passed ``future`` is already finished, it will be checked
## and processed immediately.
doAssert(not isNil(future), "Future is nil")
proc cb(data: pointer) =
if future.failed():
raise newException(FutureDefect, taskErrorMessage(future))
elif future.cancelled():
raise newException(FutureDefect, taskCancelMessage(future))
if not(future.finished()):
# We add the completion callback only if ``future`` is not finished yet.
future.addCallback(cb)
else:
cb(nil)
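# Illustrative sketch (not part of the original module): fire-and-forget a
# background task with `asyncSpawn` instead of plain `discard`, so that an
# unexpected failure surfaces as a Defect rather than vanishing silently.
when isMainModule:
  import chronos

  proc background() {.async.} =
    await sleepAsync(10.milliseconds)

  asyncSpawn background()               # must be a Future[void]
  waitFor sleepAsync(20.milliseconds)   # give the spawned task time to finish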
proc asyncCheck*[T](future: Future[T]) {.
deprecated: "Raises Defect on future failure, fix your code and use" &
" asyncSpawn!".} =
## This function used to raise an exception through the `poll` call if
## the given future failed - there's no way to handle such exceptions so this
## function is now an alias for `asyncSpawn`
##
when T is void:
asyncSpawn(future)
else:
proc cb(data: pointer) =
if future.failed():
raise newException(FutureDefect, taskErrorMessage(future))
elif future.cancelled():
raise newException(FutureDefect, taskCancelMessage(future))
if not(future.finished()):
# We add the completion callback only if ``future`` is not finished yet.
future.addCallback(cb)
else:
cb(nil)
proc asyncDiscard*[T](future: Future[T]) {.
deprecated: "Use asyncSpawn or `discard await`".} = discard
## `asyncDiscard` will discard the outcome of the operation - unlike `discard`
## it also throws away exceptions! Use `asyncSpawn` if you're sure your
## code doesn't raise exceptions, or `discard await` to ignore successful
## outcomes
proc `and`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] {.
deprecated: "Use allFutures[T](varargs[Future[T]])".} =
## Returns a future which will complete once both ``fut1`` and ``fut2``
## finish.
##
## If cancelled, ``fut1`` and ``fut2`` futures WILL NOT BE cancelled.
var retFuture = newFuture[void]("chronos.`and`")
proc cb(data: pointer) =
if not(retFuture.finished()):
if fut1.finished() and fut2.finished():
if cast[pointer](fut1) == data:
if fut1.failed():
retFuture.fail(fut1.error)
else:
retFuture.complete()
else:
if fut2.failed():
retFuture.fail(fut2.error)
else:
retFuture.complete()
fut1.callback = cb
fut2.callback = cb
proc cancellation(udata: pointer) =
# On cancel we remove all our callbacks only.
if not(fut1.finished()):
fut1.removeCallback(cb)
if not(fut2.finished()):
fut2.removeCallback(cb)
retFuture.cancelCallback = cancellation
return retFuture
proc `or`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] =
## Returns a future which will complete once either ``fut1`` or ``fut2``
## finish.
##
## If ``fut1`` or ``fut2`` fails, the result future will also fail with the
## error stored in ``fut1`` or ``fut2`` respectively.
##
## If both ``fut1`` and ``fut2`` are already completed or failed, the result
## future will depend on the state of ``fut1``: if ``fut1`` has failed, the
## result future will also fail; if ``fut1`` has completed, the result future
## will also complete.
##
## If cancelled, ``fut1`` and ``fut2`` futures WILL NOT BE cancelled.
var retFuture = newFuture[void]("chronos.or")
var cb: proc(udata: pointer) {.gcsafe, raises: [].}
cb = proc(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
var fut = cast[FutureBase](udata)
if cast[pointer](fut1) == udata:
fut2.removeCallback(cb)
else:
fut1.removeCallback(cb)
if fut.failed():
retFuture.fail(fut.error)
else:
retFuture.complete()
proc cancellation(udata: pointer) =
# On cancel we remove all our callbacks only.
if not(fut1.finished()):
fut1.removeCallback(cb)
if not(fut2.finished()):
fut2.removeCallback(cb)
if fut1.finished():
if fut1.failed():
retFuture.fail(fut1.error)
else:
retFuture.complete()
return retFuture
if fut2.finished():
if fut2.failed():
retFuture.fail(fut2.error)
else:
retFuture.complete()
return retFuture
fut1.addCallback(cb)
fut2.addCallback(cb)
retFuture.cancelCallback = cancellation
return retFuture
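# Illustrative sketch (not part of the original module): `or` is convenient
# for simple "whichever finishes first" situations such as a crude timeout
# guard; the durations below are arbitrary.
when isMainModule:
  import chronos

  proc work() {.async.} =
    await sleepAsync(50.milliseconds)

  # Completes as soon as either the work or the 100 ms sleep finishes.
  waitFor(work() or sleepAsync(100.milliseconds))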
proc all*[T](futs: varargs[Future[T]]): auto {.
deprecated: "Use allFutures(varargs[Future[T]])".} =
## Returns a future which will complete once all futures in ``futs`` finish.
## If the argument is empty, the returned future completes immediately.
##
## If the awaited futures are not ``Future[void]``, the returned future
## will hold the values of all awaited futures in a sequence.
##
## If the awaited futures *are* ``Future[void]``, this proc returns
## ``Future[void]``.
##
## Note that if one of the futures in ``futs`` fails, the result of ``all()``
## will also fail with the error from the failed future.
##
## TODO: This procedure has a bug in handling cancelled futures from ``futs``:
## if a future from the ``futs`` list becomes cancelled, what should be
## returned? The result ``retFuture`` cannot simply be cancelled, because that
## would lead to infinite recursion.
let totalFutures = len(futs)
var completedFutures = 0
# Because we can't capture varargs[T] in closures we need to create copy.
var nfuts = @futs
when T is void:
var retFuture = newFuture[void]("chronos.all(void)")
proc cb(udata: pointer) =
if not(retFuture.finished()):
inc(completedFutures)
if completedFutures == totalFutures:
for nfut in nfuts:
if nfut.failed():
retFuture.fail(nfut.error)
break
if not(retFuture.failed()):
retFuture.complete()
for fut in nfuts:
fut.addCallback(cb)
if len(nfuts) == 0:
retFuture.complete()
return retFuture
else:
var retFuture = newFuture[seq[T]]("chronos.all(T)")
var retValues = newSeq[T](totalFutures)
proc cb(udata: pointer) =
if not(retFuture.finished()):
inc(completedFutures)
if completedFutures == totalFutures:
for k, nfut in nfuts:
if nfut.failed():
retFuture.fail(nfut.error)
break
else:
retValues[k] = nfut.value
if not(retFuture.failed()):
retFuture.complete(retValues)
for fut in nfuts:
fut.addCallback(cb)
if len(nfuts) == 0:
retFuture.complete(retValues)
return retFuture
proc oneIndex*[T](futs: varargs[Future[T]]): Future[int] {.
deprecated: "Use one[T](varargs[Future[T]])".} =
## Returns a future which will complete once one of the futures in ``futs``
## completes.
##
## If the argument is empty, the returned future FAILS immediately.
##
## The returned future will hold the index of the completed/failed future in
## the ``futs`` argument.
var nfuts = @futs
var retFuture = newFuture[int]("chronos.oneIndex(T)")
proc cb(udata: pointer) =
var res = -1
if not(retFuture.finished()):
var rfut = cast[FutureBase](udata)
for i in 0..<len(nfuts):
if cast[FutureBase](nfuts[i]) != rfut:
nfuts[i].removeCallback(cb)
else:
res = i
retFuture.complete(res)
for fut in nfuts:
fut.addCallback(cb)
if len(nfuts) == 0:
retFuture.fail(newException(ValueError, "Empty Future[T] list"))
return retFuture
proc oneValue*[T](futs: varargs[Future[T]]): Future[T] {.
deprecated: "Use one[T](varargs[Future[T]])".} =
## Returns a future which will finish once one of the futures in ``futs``
## finishes.
##
## If the argument is empty, returned future FAILS immediately.
##
## The returned future will hold the value of the completed future from
## ``futs``, or its error if that future failed.
var nfuts = @futs
var retFuture = newFuture[T]("chronos.oneValue(T)")
proc cb(udata: pointer) =
var resFut: Future[T]
if not(retFuture.finished()):
var rfut = cast[FutureBase](udata)
for i in 0..<len(nfuts):
if cast[FutureBase](nfuts[i]) != rfut:
nfuts[i].removeCallback(cb)
else:
resFut = nfuts[i]
if resFut.failed():
retFuture.fail(resFut.error)
else:
when T is void:
retFuture.complete()
else:
retFuture.complete(resFut.read())
for fut in nfuts:
fut.addCallback(cb)
if len(nfuts) == 0:
retFuture.fail(newException(ValueError, "Empty Future[T] list"))
return retFuture
proc cancelAndWait*(fut: FutureBase): Future[void] =
## Initiate the cancellation process for Future ``fut`` and wait until ``fut``
## is done, e.g. changes its state (becomes completed, failed or cancelled).
##
## If ``fut`` is already finished (completed, failed or cancelled), the
## resulting Future[void] object will be returned already completed.
var retFuture = newFuture[void]("chronos.cancelAndWait(T)")
proc continuation(udata: pointer) =
if not(retFuture.finished()):
retFuture.complete()
proc cancellation(udata: pointer) =
if not(fut.finished()):
fut.removeCallback(continuation)
if fut.finished():
retFuture.complete()
else:
fut.addCallback(continuation)
retFuture.cancelCallback = cancellation
# Initiate cancellation process.
fut.cancel()
return retFuture
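# Illustrative sketch (not part of the original module): cancel a pending
# operation and wait until it has actually reached a final state before
# touching any resources it might still be using.
when isMainModule:
  import chronos

  proc longOp() {.async.} =
    await sleepAsync(10.seconds)

  let pending = longOp()
  waitFor pending.cancelAndWait()   # returns once `pending` is finished
  doAssert pending.cancelled()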
proc allFutures*(futs: varargs[FutureBase]): Future[void] =
## Returns a future which will complete only when all futures in ``futs``
## are completed, failed or cancelled.
##
## If the argument is empty, the returned future COMPLETES immediately.
##
## On cancellation, the awaited futures in ``futs`` WILL NOT BE cancelled.
var retFuture = newFuture[void]("chronos.allFutures()")
let totalFutures = len(futs)
var finishedFutures = 0
# Because we can't capture varargs[T] in closures we need to create copy.
var nfuts = @futs
proc cb(udata: pointer) =
if not(retFuture.finished()):
inc(finishedFutures)
if finishedFutures == totalFutures:
retFuture.complete()
proc cancellation(udata: pointer) =
# On cancel we remove all our callbacks only.
for i in 0..<len(nfuts):
if not(nfuts[i].finished()):
nfuts[i].removeCallback(cb)
for fut in nfuts:
if not(fut.finished()):
fut.addCallback(cb)
else:
inc(finishedFutures)
retFuture.cancelCallback = cancellation
if len(nfuts) == 0 or len(nfuts) == finishedFutures:
retFuture.complete()
return retFuture
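# Illustrative sketch (not part of the original module): wait for a batch of
# independent tasks; `allFutures` completes once every future has finished,
# regardless of whether the individual futures succeeded or failed.
when isMainModule:
  import chronos

  proc task(ms: int) {.async.} =
    await sleepAsync(ms.milliseconds)

  waitFor allFutures(task(10), task(20), task(30))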
proc allFutures*[T](futs: varargs[Future[T]]): Future[void] =
## Returns a future which will complete only when all futures in ``futs``
## are completed, failed or cancelled.
##
## If the argument is empty, the returned future COMPLETES immediately.
##
## On cancellation, the awaited futures in ``futs`` WILL NOT BE cancelled.
# Because we can't capture varargs[T] in closures we need to create copy.
var nfuts: seq[FutureBase]
for future in futs:
nfuts.add(future)
allFutures(nfuts)
proc allFinished*[T](futs: varargs[Future[T]]): Future[seq[Future[T]]] =
## Returns a future which will complete only when all futures in ``futs``
## are completed, failed or cancelled.
##
## The returned sequence will hold all the Future[T] objects passed to
## ``allFinished``, with their order preserved.
##
## If the argument is empty, the returned future COMPLETES immediately.
##
## On cancellation, the awaited futures in ``futs`` WILL NOT BE cancelled.
var retFuture = newFuture[seq[Future[T]]]("chronos.allFinished()")
let totalFutures = len(futs)
var finishedFutures = 0
var nfuts = @futs
proc cb(udata: pointer) =
if not(retFuture.finished()):
inc(finishedFutures)
if finishedFutures == totalFutures:
retFuture.complete(nfuts)
proc cancellation(udata: pointer) =
# On cancel we remove all our callbacks only.
for fut in nfuts.mitems():
if not(fut.finished()):
fut.removeCallback(cb)
for fut in nfuts:
if not(fut.finished()):
fut.addCallback(cb)
else:
inc(finishedFutures)
retFuture.cancelCallback = cancellation
if len(nfuts) == 0 or len(nfuts) == finishedFutures:
retFuture.complete(nfuts)
return retFuture
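# Illustrative sketch (not part of the original module): unlike `allFutures`,
# `allFinished` hands back the futures themselves so each result or error can
# be inspected afterwards.
when isMainModule:
  import chronos

  proc compute(x: int): Future[int] {.async.} =
    return x * 2

  let done = waitFor allFinished(compute(1), compute(2), compute(3))
  for f in done:
    if not f.failed():
      echo f.read()                 # prints 2, 4, 6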
proc one*[T](futs: varargs[Future[T]]): Future[Future[T]] =
## Returns a future which will complete, returning the finished Future[T]
## inside, once one of the futures in ``futs`` is completed, failed or
## cancelled.
##
## If the argument is empty, the returned future FAILS immediately.
##
## On success the returned Future will hold the finished Future[T].
##
## On cancellation, the futures in ``futs`` WILL NOT BE cancelled.
var retFuture = newFuture[Future[T]]("chronos.one()")
if len(futs) == 0:
retFuture.fail(newException(ValueError, "Empty Future[T] list"))
return retFuture
# If one of the Future[T] already finished we return it as result
for fut in futs:
if fut.finished():
retFuture.complete(fut)
return retFuture
# Because we can't capture varargs[T] in closures we need to create copy.
var nfuts = @futs
var cb: proc(udata: pointer) {.gcsafe, raises: [].}
cb = proc(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
var res: Future[T]
var rfut = cast[FutureBase](udata)
for i in 0..<len(nfuts):
if cast[FutureBase](nfuts[i]) != rfut:
nfuts[i].removeCallback(cb)
else:
res = nfuts[i]
retFuture.complete(res)
proc cancellation(udata: pointer) =
# On cancel we remove all our callbacks only.
for i in 0..<len(nfuts):
if not(nfuts[i].finished()):
nfuts[i].removeCallback(cb)
for fut in nfuts:
fut.addCallback(cb)
retFuture.cancelCallback = cancellation
return retFuture
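# Illustrative sketch (not part of the original module): `one` returns the
# first future (all of the same type) to finish, leaving the others running.
when isMainModule:
  import chronos

  proc delayed(ms: int): Future[int] {.async.} =
    await sleepAsync(ms.milliseconds)
    return ms

  let winner = waitFor one(delayed(50), delayed(10), delayed(30))
  echo winner.read()                # most likely prints 10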
proc race*(futs: varargs[FutureBase]): Future[FutureBase] =
## Returns a future which will complete and return the completed FutureBase,
## once one of the futures in ``futs`` is completed, failed or cancelled.
##
## If the argument is empty, the returned future FAILS immediately.
##
## On success the returned Future will hold the finished FutureBase.
##
## On cancellation, the futures in ``futs`` WILL NOT BE cancelled.
let retFuture = newFuture[FutureBase]("chronos.race()")
if len(futs) == 0:
retFuture.fail(newException(ValueError, "Empty Future[T] list"))
return retFuture
# If one of the Future[T] already finished we return it as result
for fut in futs:
if fut.finished():
retFuture.complete(fut)
return retFuture
# Because we can't capture varargs[T] in closures we need to create copy.
var nfuts = @futs
var cb: proc(udata: pointer) {.gcsafe, raises: [].}
cb = proc(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
var res: FutureBase
var rfut = cast[FutureBase](udata)
for i in 0..<len(nfuts):
if nfuts[i] != rfut:
nfuts[i].removeCallback(cb)
else:
res = nfuts[i]
retFuture.complete(res)
proc cancellation(udata: pointer) =
# On cancel we remove all our callbacks only.
for i in 0..<len(nfuts):
if not(nfuts[i].finished()):
nfuts[i].removeCallback(cb)
for fut in nfuts:
fut.addCallback(cb, cast[pointer](fut))
retFuture.cancelCallback = cancellation
return retFuture
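# Illustrative sketch (not part of the original module): `race` works like
# `one` but accepts futures of different result types through the FutureBase
# interface.
when isMainModule:
  import chronos

  proc ping(): Future[int] {.async.} =
    await sleepAsync(10.milliseconds)
    return 1

  proc label(): Future[string] {.async.} =
    await sleepAsync(50.milliseconds)
    return "slow"

  let
    fast = ping()
    slow = label()
    first = waitFor race(FutureBase(fast), FutureBase(slow))
  if first == FutureBase(fast):
    echo "int future finished first: ", fast.read()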

File diff suppressed because it is too large

@ -1,336 +0,0 @@
#
#
# Nim's Runtime Library
# (c) Copyright 2015 Dominik Picheta
#
# See the file "copying.txt", included in this
# distribution, for details about the copyright.
#
import std/[macros]
# `quote do` will ruin line numbers so we avoid it using these helpers
proc completeWithResult(fut, baseType: NimNode): NimNode {.compileTime.} =
# when `baseType` is void:
# complete(`fut`)
# else:
# complete(`fut`, result)
if baseType.eqIdent("void"):
# Shortcut if we know baseType at macro expansion time
newCall(ident "complete", fut)
else:
# `baseType` might be generic and resolve to `void`
nnkWhenStmt.newTree(
nnkElifExpr.newTree(
nnkInfix.newTree(ident "is", baseType, ident "void"),
newCall(ident "complete", fut)
),
nnkElseExpr.newTree(
newCall(ident "complete", fut, ident "result")
)
)
proc completeWithNode(fut, baseType, node: NimNode): NimNode {.compileTime.} =
# when typeof(`node`) is void:
# `node` # statement / explicit return
# -> completeWithResult(fut, baseType)
# else: # expression / implicit return
# complete(`fut`, `node`)
if node.kind == nnkEmpty: # shortcut when known at macro expansion time
completeWithResult(fut, baseType)
else:
# Handle both expressions and statements - since the type is not known at
# macro expansion time, we delegate this choice to a later compilation stage
# with `when`.
nnkWhenStmt.newTree(
nnkElifExpr.newTree(
nnkInfix.newTree(
ident "is", nnkTypeOfExpr.newTree(node), ident "void"),
newStmtList(
node,
completeWithResult(fut, baseType)
)
),
nnkElseExpr.newTree(
newCall(ident "complete", fut, node)
)
)
proc processBody(node, fut, baseType: NimNode): NimNode {.compileTime.} =
#echo(node.treeRepr)
case node.kind
of nnkReturnStmt:
let
res = newNimNode(nnkStmtList, node)
res.add completeWithNode(fut, baseType, processBody(node[0], fut, baseType))
res.add newNimNode(nnkReturnStmt, node).add(newNilLit())
res
of RoutineNodes-{nnkTemplateDef}:
# skip all the nested procedure definitions
node
else:
for i in 0 ..< node.len:
# We must not transform nested procedures of any form, otherwise
# `fut` will be used for all nested procedures as their own
# `retFuture`.
node[i] = processBody(node[i], fut, baseType)
node
proc getName(node: NimNode): string {.compileTime.} =
case node.kind
of nnkSym:
return node.strVal
of nnkPostfix:
return node[1].strVal
of nnkIdent:
return node.strVal
of nnkEmpty:
return "anonymous"
else:
error("Unknown name.")
macro unsupported(s: static[string]): untyped =
error s
proc params2(someProc: NimNode): NimNode =
# until https://github.com/nim-lang/Nim/pull/19563 is available
if someProc.kind == nnkProcTy:
someProc[0]
else:
params(someProc)
proc cleanupOpenSymChoice(node: NimNode): NimNode {.compileTime.} =
# Replace every Call -> OpenSymChoice by a Bracket expr
# ref https://github.com/nim-lang/Nim/issues/11091
if node.kind in nnkCallKinds and
node[0].kind == nnkOpenSymChoice and node[0].eqIdent("[]"):
result = newNimNode(nnkBracketExpr)
for child in node[1..^1]:
result.add(cleanupOpenSymChoice(child))
else:
result = node.copyNimNode()
for child in node:
result.add(cleanupOpenSymChoice(child))
proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} =
## This macro transforms a single procedure into a closure iterator.
## The ``async`` macro supports a stmtList holding multiple async procedures.
if prc.kind notin {nnkProcTy, nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}:
error("Cannot transform " & $prc.kind & " into an async proc." &
" proc/method definition or lambda node expected.", prc)
let returnType = cleanupOpenSymChoice(prc.params2[0])
# Verify that the return type is a Future[T]
let baseType =
if returnType.kind == nnkEmpty:
ident "void"
elif not (
returnType.kind == nnkBracketExpr and eqIdent(returnType[0], "Future")):
error(
"Expected return type of 'Future' got '" & repr(returnType) & "'", prc)
return
else:
returnType[1]
let
baseTypeIsVoid = baseType.eqIdent("void")
futureVoidType = nnkBracketExpr.newTree(ident "Future", ident "void")
if prc.kind in {nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}:
let
prcName = prc.name.getName
outerProcBody = newNimNode(nnkStmtList, prc.body)
# Copy comment for nimdoc
if prc.body.len > 0 and prc.body[0].kind == nnkCommentStmt:
outerProcBody.add(prc.body[0])
let
internalFutureSym = ident "chronosInternalRetFuture"
internalFutureType =
if baseTypeIsVoid: futureVoidType
else: returnType
castFutureSym = nnkCast.newTree(internalFutureType, internalFutureSym)
procBody = prc.body.processBody(castFutureSym, baseType)
# don't do anything with forward bodies (empty)
if procBody.kind != nnkEmpty:
let
# fix #13899, `defer` should not escape its original scope
procBodyBlck = nnkBlockStmt.newTree(newEmptyNode(), procBody)
resultDecl = nnkWhenStmt.newTree(
# when `baseType` is void:
nnkElifExpr.newTree(
nnkInfix.newTree(ident "is", baseType, ident "void"),
quote do:
template result: auto {.used.} =
{.fatal: "You should not reference the `result` variable inside" &
" a void async proc".}
),
# else:
nnkElseExpr.newTree(
newStmtList(
quote do: {.push warning[resultshadowed]: off.},
# var result {.used.}: `baseType`
# In the proc body, result may or may not end up being used
# depending on how the body is written - with implicit returns /
# expressions in particular, it is likely but not guaranteed that
# it is not used. Ideally, we would avoid emitting it in this
# case to avoid the default initialization. {.used.} typically
# works better than {.push.} which has a tendency to leak out of
# scope.
# TODO figure out if there's a way to detect `result` usage in
# the proc body _after_ template expansion, and therefore
# avoid creating this variable - one option is to create an
# additional when branch with a fake `result` and check
# `compiles(procBody)` - this is not without cost though
nnkVarSection.newTree(nnkIdentDefs.newTree(
nnkPragmaExpr.newTree(
ident "result",
nnkPragma.newTree(ident "used")),
baseType, newEmptyNode())
),
quote do: {.pop.},
)
)
)
completeDecl = completeWithNode(castFutureSym, baseType, procBodyBlck)
closureBody = newStmtList(resultDecl, completeDecl)
internalFutureParameter = nnkIdentDefs.newTree(
internalFutureSym, newIdentNode("FutureBase"), newEmptyNode())
iteratorNameSym = genSym(nskIterator, $prcName)
closureIterator = newProc(
iteratorNameSym,
[newIdentNode("FutureBase"), internalFutureParameter],
closureBody, nnkIteratorDef)
iteratorNameSym.copyLineInfo(prc)
closureIterator.pragma = newNimNode(nnkPragma, lineInfoFrom=prc.body)
closureIterator.addPragma(newIdentNode("closure"))
# `async` code must be gcsafe
closureIterator.addPragma(newIdentNode("gcsafe"))
# TODO when push raises is active in a module, the iterator here inherits
# that annotation - here we explicitly disable it again which goes
# against the spirit of the raises annotation - one should investigate
# here the possibility of transporting more specific error types here
# for example by casting exceptions coming out of `await`..
let raises = nnkBracket.newTree()
when chronosStrictException:
raises.add(newIdentNode("CatchableError"))
else:
raises.add(newIdentNode("Exception"))
closureIterator.addPragma(nnkExprColonExpr.newTree(
newIdentNode("raises"),
raises
))
# If proc has an explicit gcsafe pragma, we add it to iterator as well.
# TODO if these lines are not here, srcloc tests fail (!)
if prc.pragma.findChild(it.kind in {nnkSym, nnkIdent} and
it.strVal == "gcsafe") != nil:
closureIterator.addPragma(newIdentNode("gcsafe"))
outerProcBody.add(closureIterator)
# -> let resultFuture = newFuture[T]()
# declared at the end to be sure that the closure
# doesn't reference it, avoid cyclic ref (#203)
let
retFutureSym = ident "resultFuture"
retFutureSym.copyLineInfo(prc)
# Do not change this code to `quote do` version because `instantiationInfo`
# will be broken for `newFuture()` call.
outerProcBody.add(
newLetStmt(
retFutureSym,
newCall(newTree(nnkBracketExpr, ident "newFuture", baseType),
newLit(prcName))
)
)
# -> resultFuture.internalClosure = iterator
outerProcBody.add(
newAssignment(
newDotExpr(retFutureSym, newIdentNode("internalClosure")),
iteratorNameSym)
)
# -> futureContinue(resultFuture))
outerProcBody.add(
newCall(newIdentNode("futureContinue"), retFutureSym)
)
# -> return resultFuture
outerProcBody.add newNimNode(nnkReturnStmt, prc.body[^1]).add(retFutureSym)
prc.body = outerProcBody
if prc.kind notin {nnkProcTy, nnkLambda}: # TODO: Nim bug?
prc.addPragma(newColonExpr(ident "stackTrace", ident "off"))
# See **Remark 435** in this file.
# https://github.com/nim-lang/RFCs/issues/435
prc.addPragma(newIdentNode("gcsafe"))
prc.addPragma(nnkExprColonExpr.newTree(
newIdentNode("raises"),
nnkBracket.newTree()
))
if baseTypeIsVoid:
if returnType.kind == nnkEmpty:
# Add Future[void]
prc.params2[0] = futureVoidType
prc
template await*[T](f: Future[T]): untyped =
when declared(chronosInternalRetFuture):
chronosInternalRetFuture.internalChild = f
# `futureContinue` calls the iterator generated by the `async`
# transformation - `yield` gives control back to `futureContinue` which is
# responsible for resuming execution once the yielded future is finished
yield chronosInternalRetFuture.internalChild
# `child` is guaranteed to have been `finished` after the yield
if chronosInternalRetFuture.internalMustCancel:
raise newCancelledError()
# `child` released by `futureContinue`
chronosInternalRetFuture.internalChild.internalCheckComplete()
when T isnot void:
cast[type(f)](chronosInternalRetFuture.internalChild).internalRead()
else:
unsupported "await is only available within {.async.}"
template awaitne*[T](f: Future[T]): Future[T] =
when declared(chronosInternalRetFuture):
chronosInternalRetFuture.internalChild = f
yield chronosInternalRetFuture.internalChild
if chronosInternalRetFuture.internalMustCancel:
raise newCancelledError()
cast[type(f)](chronosInternalRetFuture.internalChild)
else:
unsupported "awaitne is only available within {.async.}"
macro async*(prc: untyped): untyped =
## Macro which processes async procedures into the appropriate
## iterators and yield statements.
if prc.kind == nnkStmtList:
result = newStmtList()
for oneProc in prc:
result.add asyncSingleProc(oneProc)
else:
result = asyncSingleProc(prc)
when chronosDumpAsync:
echo repr result
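# Illustrative sketch (not part of the original module): the `async` macro and
# `await` template above are what make user code like the following possible -
# the proc body is rewritten into a closure iterator driven by
# `futureContinue`, and every `await` becomes a `yield` of the awaited future.
when isMainModule:
  import chronos

  proc fetchTwice(): Future[int] {.async.} =
    await sleepAsync(5.milliseconds)  # suspends here, resumes when done
    await sleepAsync(5.milliseconds)
    return 2

  echo waitFor(fetchTwice())          # prints 2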


@ -13,7 +13,7 @@
import std/strtabs import std/strtabs
import "."/[config, asyncloop, handles, osdefs, osutils, oserrno], import "."/[config, asyncloop, handles, osdefs, osutils, oserrno],
streams/asyncstream streams/asyncstream
import stew/[results, byteutils] import stew/[byteutils], results
from std/os import quoteShell, quoteShellWindows, quoteShellPosix, envPairs from std/os import quoteShell, quoteShellWindows, quoteShellPosix, envPairs
export strtabs, results export strtabs, results
@ -24,7 +24,8 @@ const
## AsyncProcess leaks tracker name ## AsyncProcess leaks tracker name
type type
AsyncProcessError* = object of CatchableError AsyncProcessError* = object of AsyncError
AsyncProcessTimeoutError* = object of AsyncProcessError
AsyncProcessResult*[T] = Result[T, OSErrorCode] AsyncProcessResult*[T] = Result[T, OSErrorCode]
@ -107,6 +108,9 @@ type
stdError*: string stdError*: string
status*: int status*: int
WaitOperation {.pure.} = enum
Kill, Terminate
template Pipe*(t: typedesc[AsyncProcess]): ProcessStreamHandle = template Pipe*(t: typedesc[AsyncProcess]): ProcessStreamHandle =
ProcessStreamHandle(kind: ProcessStreamHandleKind.Auto) ProcessStreamHandle(kind: ProcessStreamHandleKind.Auto)
@ -227,8 +231,9 @@ proc closeProcessHandles(pipes: var AsyncProcessPipes,
lastError: OSErrorCode): OSErrorCode {.apforward.} lastError: OSErrorCode): OSErrorCode {.apforward.}
proc closeProcessStreams(pipes: AsyncProcessPipes, proc closeProcessStreams(pipes: AsyncProcessPipes,
options: set[AsyncProcessOption]): Future[void] {. options: set[AsyncProcessOption]): Future[void] {.
apforward.} async: (raises: []).}
proc closeWait(holder: AsyncStreamHolder): Future[void] {.apforward.} proc closeWait(holder: AsyncStreamHolder): Future[void] {.
async: (raises: []).}
template isOk(code: OSErrorCode): bool = template isOk(code: OSErrorCode): bool =
when defined(windows): when defined(windows):
@ -294,6 +299,11 @@ proc raiseAsyncProcessError(msg: string, exc: ref CatchableError = nil) {.
msg & " ([" & $exc.name & "]: " & $exc.msg & ")" msg & " ([" & $exc.name & "]: " & $exc.msg & ")"
raise newException(AsyncProcessError, message) raise newException(AsyncProcessError, message)
proc raiseAsyncProcessTimeoutError() {.
noreturn, noinit, noinline, raises: [AsyncProcessTimeoutError].} =
let message = "Operation timed out"
raise newException(AsyncProcessTimeoutError, message)
proc raiseAsyncProcessError(msg: string, error: OSErrorCode|cint) {. proc raiseAsyncProcessError(msg: string, error: OSErrorCode|cint) {.
noreturn, noinit, noinline, raises: [AsyncProcessError].} = noreturn, noinit, noinline, raises: [AsyncProcessError].} =
when error is OSErrorCode: when error is OSErrorCode:
@ -382,7 +392,8 @@ when defined(windows):
stdinHandle = ProcessStreamHandle(), stdinHandle = ProcessStreamHandle(),
stdoutHandle = ProcessStreamHandle(), stdoutHandle = ProcessStreamHandle(),
stderrHandle = ProcessStreamHandle(), stderrHandle = ProcessStreamHandle(),
): Future[AsyncProcessRef] {.async.} = ): Future[AsyncProcessRef] {.
async: (raises: [AsyncProcessError, CancelledError]).} =
var var
pipes = preparePipes(options, stdinHandle, stdoutHandle, pipes = preparePipes(options, stdinHandle, stdoutHandle,
stderrHandle).valueOr: stderrHandle).valueOr:
@ -508,14 +519,16 @@ when defined(windows):
ok(false) ok(false)
proc waitForExit*(p: AsyncProcessRef, proc waitForExit*(p: AsyncProcessRef,
timeout = InfiniteDuration): Future[int] {.async.} = timeout = InfiniteDuration): Future[int] {.
async: (raises: [AsyncProcessError, AsyncProcessTimeoutError,
CancelledError]).} =
if p.exitStatus.isSome(): if p.exitStatus.isSome():
return p.exitStatus.get() return p.exitStatus.get()
let wres = let wres =
try: try:
await waitForSingleObject(p.processHandle, timeout) await waitForSingleObject(p.processHandle, timeout)
except ValueError as exc: except AsyncError as exc:
raiseAsyncProcessError("Unable to wait for process handle", exc) raiseAsyncProcessError("Unable to wait for process handle", exc)
if wres == WaitableResult.Timeout: if wres == WaitableResult.Timeout:
@ -528,7 +541,8 @@ when defined(windows):
if exitCode >= 0: if exitCode >= 0:
p.exitStatus = Opt.some(exitCode) p.exitStatus = Opt.some(exitCode)
return exitCode
exitCode
proc peekExitCode(p: AsyncProcessRef): AsyncProcessResult[int] = proc peekExitCode(p: AsyncProcessRef): AsyncProcessResult[int] =
if p.exitStatus.isSome(): if p.exitStatus.isSome():
@ -778,7 +792,8 @@ else:
stdinHandle = ProcessStreamHandle(), stdinHandle = ProcessStreamHandle(),
stdoutHandle = ProcessStreamHandle(), stdoutHandle = ProcessStreamHandle(),
stderrHandle = ProcessStreamHandle(), stderrHandle = ProcessStreamHandle(),
): Future[AsyncProcessRef] {.async.} = ): Future[AsyncProcessRef] {.
async: (raises: [AsyncProcessError, CancelledError]).} =
var var
pid: Pid pid: Pid
pipes = preparePipes(options, stdinHandle, stdoutHandle, pipes = preparePipes(options, stdinHandle, stdoutHandle,
@ -878,7 +893,7 @@ else:
) )
trackCounter(AsyncProcessTrackerName) trackCounter(AsyncProcessTrackerName)
return process process
proc peekProcessExitCode(p: AsyncProcessRef, proc peekProcessExitCode(p: AsyncProcessRef,
reap = false): AsyncProcessResult[int] = reap = false): AsyncProcessResult[int] =
@ -939,7 +954,9 @@ else:
ok(false) ok(false)
proc waitForExit*(p: AsyncProcessRef, proc waitForExit*(p: AsyncProcessRef,
timeout = InfiniteDuration): Future[int] = timeout = InfiniteDuration): Future[int] {.
async: (raw: true, raises: [
AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
var var
retFuture = newFuture[int]("chronos.waitForExit()") retFuture = newFuture[int]("chronos.waitForExit()")
processHandle: ProcessHandle processHandle: ProcessHandle
@ -979,7 +996,7 @@ else:
return return
if not(isNil(timer)): if not(isNil(timer)):
clearTimer(timer) clearTimer(timer)
let exitCode = p.peekProcessExitCode().valueOr: let exitCode = p.peekProcessExitCode(reap = true).valueOr:
retFuture.fail(newException(AsyncProcessError, osErrorMsg(error))) retFuture.fail(newException(AsyncProcessError, osErrorMsg(error)))
return return
if exitCode == -1: if exitCode == -1:
@ -993,12 +1010,14 @@ else:
retFuture.fail(newException(AsyncProcessError, retFuture.fail(newException(AsyncProcessError,
osErrorMsg(res.error()))) osErrorMsg(res.error())))
timer = nil
proc cancellation(udata: pointer) {.gcsafe.} = proc cancellation(udata: pointer) {.gcsafe.} =
if not(retFuture.finished()): if not(isNil(timer)):
if not(isNil(timer)): clearTimer(timer)
clearTimer(timer) timer = nil
# Ignore any errors because of cancellation. # Ignore any errors because of cancellation.
discard removeProcess2(processHandle) discard removeProcess2(processHandle)
if timeout != InfiniteDuration: if timeout != InfiniteDuration:
timer = setTimer(Moment.fromNow(timeout), continuation, cast[pointer](2)) timer = setTimer(Moment.fromNow(timeout), continuation, cast[pointer](2))
@ -1041,7 +1060,7 @@ else:
# Process is still running, so we going to wait for SIGCHLD. # Process is still running, so we going to wait for SIGCHLD.
retFuture.cancelCallback = cancellation retFuture.cancelCallback = cancellation
return retFuture retFuture
proc peekExitCode(p: AsyncProcessRef): AsyncProcessResult[int] = proc peekExitCode(p: AsyncProcessRef): AsyncProcessResult[int] =
let res = ? p.peekProcessExitCode() let res = ? p.peekProcessExitCode()
@ -1146,7 +1165,7 @@ proc preparePipes(options: set[AsyncProcessOption],
stderrHandle: remoteStderr stderrHandle: remoteStderr
)) ))
proc closeWait(holder: AsyncStreamHolder) {.async.} = proc closeWait(holder: AsyncStreamHolder) {.async: (raises: []).} =
let (future, transp) = let (future, transp) =
case holder.kind case holder.kind
of StreamKind.None: of StreamKind.None:
@ -1173,10 +1192,11 @@ proc closeWait(holder: AsyncStreamHolder) {.async.} =
res res
if len(pending) > 0: if len(pending) > 0:
await allFutures(pending) await noCancel allFutures(pending)
proc closeProcessStreams(pipes: AsyncProcessPipes, proc closeProcessStreams(pipes: AsyncProcessPipes,
options: set[AsyncProcessOption]): Future[void] = options: set[AsyncProcessOption]): Future[void] {.
async: (raw: true, raises: []).} =
let pending = let pending =
block: block:
var res: seq[Future[void]] var res: seq[Future[void]]
@ -1187,9 +1207,53 @@ proc closeProcessStreams(pipes: AsyncProcessPipes,
if ProcessFlag.AutoStderr in pipes.flags: if ProcessFlag.AutoStderr in pipes.flags:
res.add(pipes.stderrHolder.closeWait()) res.add(pipes.stderrHolder.closeWait())
res res
allFutures(pending) noCancel allFutures(pending)
proc closeWait*(p: AsyncProcessRef) {.async.} = proc opAndWaitForExit(p: AsyncProcessRef, op: WaitOperation,
timeout = InfiniteDuration): Future[int] {.
async: (raises: [
AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
let timerFut =
if timeout == InfiniteDuration:
newFuture[void]("chronos.killAndwaitForExit")
else:
sleepAsync(timeout)
while true:
if p.running().get(true):
# We ignore operation errors because we are going to repeat calling the
# operation until the process exits.
case op
of WaitOperation.Kill:
discard p.kill()
of WaitOperation.Terminate:
discard p.terminate()
else:
let exitCode = p.peekExitCode().valueOr:
raiseAsyncProcessError("Unable to peek process exit code", error)
if not(timerFut.finished()):
await cancelAndWait(timerFut)
return exitCode
let waitFut = p.waitForExit().wait(100.milliseconds)
try:
discard await race(FutureBase(waitFut), FutureBase(timerFut))
except ValueError:
raiseAssert "This should not be happened!"
if waitFut.finished() and not(waitFut.failed()):
let res = p.peekExitCode()
if res.isOk():
if not(timerFut.finished()):
await cancelAndWait(timerFut)
return res.get()
if timerFut.finished():
if not(waitFut.finished()):
await waitFut.cancelAndWait()
raiseAsyncProcessTimeoutError()
proc closeWait*(p: AsyncProcessRef) {.async: (raises: []).} =
# Here we ignore all possible errrors, because we do not want to raise # Here we ignore all possible errrors, because we do not want to raise
# exceptions. # exceptions.
discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0)) discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0))
@ -1198,16 +1262,19 @@ proc closeWait*(p: AsyncProcessRef) {.async.} =
untrackCounter(AsyncProcessTrackerName) untrackCounter(AsyncProcessTrackerName)
proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter = proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter =
## Returns STDIN async stream associated with process `p`.
doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer, doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer,
"StdinStreamWriter is not available") "StdinStreamWriter is not available")
p.pipes.stdinHolder.writer p.pipes.stdinHolder.writer
proc stdoutStream*(p: AsyncProcessRef): AsyncStreamReader = proc stdoutStream*(p: AsyncProcessRef): AsyncStreamReader =
## Returns STDOUT async stream associated with process `p`.
doAssert(p.pipes.stdoutHolder.kind == StreamKind.Reader, doAssert(p.pipes.stdoutHolder.kind == StreamKind.Reader,
"StdoutStreamReader is not available") "StdoutStreamReader is not available")
p.pipes.stdoutHolder.reader p.pipes.stdoutHolder.reader
proc stderrStream*(p: AsyncProcessRef): AsyncStreamReader = proc stderrStream*(p: AsyncProcessRef): AsyncStreamReader =
## Returns STDERR async stream associated with process `p`.
doAssert(p.pipes.stderrHolder.kind == StreamKind.Reader, doAssert(p.pipes.stderrHolder.kind == StreamKind.Reader,
"StderrStreamReader is not available") "StderrStreamReader is not available")
p.pipes.stderrHolder.reader p.pipes.stderrHolder.reader
@ -1215,20 +1282,25 @@ proc stderrStream*(p: AsyncProcessRef): AsyncStreamReader =
proc execCommand*(command: string, proc execCommand*(command: string,
options = {AsyncProcessOption.EvalCommand}, options = {AsyncProcessOption.EvalCommand},
timeout = InfiniteDuration timeout = InfiniteDuration
): Future[int] {.async.} = ): Future[int] {.
let poptions = options + {AsyncProcessOption.EvalCommand} async: (raises: [
let process = await startProcess(command, options = poptions) AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
let res = let
try: poptions = options + {AsyncProcessOption.EvalCommand}
await process.waitForExit(timeout) process = await startProcess(command, options = poptions)
finally: res =
await process.closeWait() try:
return res await process.waitForExit(timeout)
finally:
await process.closeWait()
res
proc execCommandEx*(command: string, proc execCommandEx*(command: string,
options = {AsyncProcessOption.EvalCommand}, options = {AsyncProcessOption.EvalCommand},
timeout = InfiniteDuration timeout = InfiniteDuration
): Future[CommandExResponse] {.async.} = ): Future[CommandExResponse] {.
async: (raises: [
AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
let let
process = await startProcess(command, options = options, process = await startProcess(command, options = options,
stdoutHandle = AsyncProcess.Pipe, stdoutHandle = AsyncProcess.Pipe,
@ -1242,13 +1314,13 @@ proc execCommandEx*(command: string,
status = await process.waitForExit(timeout) status = await process.waitForExit(timeout)
output = output =
try: try:
string.fromBytes(outputReader.read()) string.fromBytes(await outputReader)
except AsyncStreamError as exc: except AsyncStreamError as exc:
raiseAsyncProcessError("Unable to read process' stdout channel", raiseAsyncProcessError("Unable to read process' stdout channel",
exc) exc)
error = error =
try: try:
string.fromBytes(errorReader.read()) string.fromBytes(await errorReader)
except AsyncStreamError as exc: except AsyncStreamError as exc:
raiseAsyncProcessError("Unable to read process' stderr channel", raiseAsyncProcessError("Unable to read process' stderr channel",
exc) exc)
@ -1256,10 +1328,47 @@ proc execCommandEx*(command: string,
finally: finally:
await process.closeWait() await process.closeWait()
return res res
proc pid*(p: AsyncProcessRef): int = proc pid*(p: AsyncProcessRef): int =
## Returns process ``p`` identifier. ## Returns process ``p`` unique process identifier.
int(p.processId) int(p.processId)
template processId*(p: AsyncProcessRef): int = pid(p) template processId*(p: AsyncProcessRef): int = pid(p)
proc killAndWaitForExit*(p: AsyncProcessRef,
timeout = InfiniteDuration): Future[int] {.
async: (raw: true, raises: [
AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
## Perform continuous attempts to kill the ``p`` process for the specified
## period of time ``timeout``.
##
## On Posix systems, killing means sending ``SIGKILL`` to the process ``p``,
## On Windows, it uses ``TerminateProcess`` to kill the process ``p``.
##
## If the process ``p`` fails to be killed within the ``timeout`` time, it
## will raise ``AsyncProcessTimeoutError``.
##
## In case of error it will raise ``AsyncProcessError``.
##
## Returns process ``p`` exit code.
opAndWaitForExit(p, WaitOperation.Kill, timeout)
proc terminateAndWaitForExit*(p: AsyncProcessRef,
timeout = InfiniteDuration): Future[int] {.
async: (raw: true, raises: [
AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
## Perform continuous attempts to terminate the ``p`` process for the
## specified period of time ``timeout``.
##
## On Posix systems, terminating means sending ``SIGTERM`` to the process
## ``p``, on Windows, it uses ``TerminateProcess`` to terminate the process
## ``p``.
##
## If the process ``p`` fails to be terminated within the ``timeout`` time, it
## will raise ``AsyncProcessTimeoutError``.
##
## In case of error it will raise ``AsyncProcessError``.
##
## Returns process ``p`` exit code.
opAndWaitForExit(p, WaitOperation.Terminate, timeout)
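# Illustrative sketch (not part of the original diff): a typical shutdown path
# for a child process - ask it to terminate, fall back to a hard kill if it
# does not exit in time. Assumes user code importing chronos/asyncproc; the
# POSIX `sleep` command, the `arguments` parameter usage and the timeouts are
# examples only.
when isMainModule:
  import chronos, chronos/asyncproc

  proc shutdownDemo() {.async.} =
    let p = await startProcess("sleep", arguments = @["10"])
    try:
      discard await p.terminateAndWaitForExit(2.seconds)
    except AsyncProcessTimeoutError:
      discard await p.killAndWaitForExit(2.seconds)
    finally:
      await p.closeWait()

  waitFor shutdownDemo()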


@ -28,7 +28,7 @@ type
## is blocked in ``acquire()`` is being processed. ## is blocked in ``acquire()`` is being processed.
locked: bool locked: bool
acquired: bool acquired: bool
waiters: seq[Future[void]] waiters: seq[Future[void].Raising([CancelledError])]
AsyncEvent* = ref object of RootRef AsyncEvent* = ref object of RootRef
## A primitive event object. ## A primitive event object.
@ -41,7 +41,7 @@ type
## state to be signaled, when event get fired, then all coroutines ## state to be signaled, when event get fired, then all coroutines
## continue proceeds in order, they have entered waiting state. ## continue proceeds in order, they have entered waiting state.
flag: bool flag: bool
waiters: seq[Future[void]] waiters: seq[Future[void].Raising([CancelledError])]
AsyncQueue*[T] = ref object of RootRef AsyncQueue*[T] = ref object of RootRef
## A queue, useful for coordinating producer and consumer coroutines. ## A queue, useful for coordinating producer and consumer coroutines.
@ -50,8 +50,8 @@ type
## infinite. If it is an integer greater than ``0``, then "await put()" ## infinite. If it is an integer greater than ``0``, then "await put()"
## will block when the queue reaches ``maxsize``, until an item is ## will block when the queue reaches ``maxsize``, until an item is
## removed by "await get()". ## removed by "await get()".
getters: seq[Future[void]] getters: seq[Future[void].Raising([CancelledError])]
putters: seq[Future[void]] putters: seq[Future[void].Raising([CancelledError])]
queue: Deque[T] queue: Deque[T]
maxsize: int maxsize: int
@ -62,50 +62,6 @@ type
AsyncLockError* = object of AsyncError AsyncLockError* = object of AsyncError
## ``AsyncLock`` is either locked or unlocked. ## ``AsyncLock`` is either locked or unlocked.
EventBusSubscription*[T] = proc(bus: AsyncEventBus,
payload: EventPayload[T]): Future[void] {.
gcsafe, raises: [].}
## EventBus subscription callback type.
EventBusAllSubscription* = proc(bus: AsyncEventBus,
event: AwaitableEvent): Future[void] {.
gcsafe, raises: [].}
## EventBus subscription callback type.
EventBusCallback = proc(bus: AsyncEventBus, event: string, key: EventBusKey,
data: EventPayloadBase) {.
gcsafe, raises: [].}
EventBusKey* = object
## Unique subscription key.
eventName: string
typeName: string
unique: uint64
cb: EventBusCallback
EventItem = object
waiters: seq[FutureBase]
subscribers: seq[EventBusKey]
AsyncEventBus* = ref object of RootObj
## An eventbus object.
counter: uint64
events: Table[string, EventItem]
subscribers: seq[EventBusKey]
waiters: seq[Future[AwaitableEvent]]
EventPayloadBase* = ref object of RootObj
loc: ptr SrcLoc
EventPayload*[T] = ref object of EventPayloadBase
## Eventbus' event payload object
value: T
AwaitableEvent* = object
## Eventbus' event payload object
eventName: string
payload: EventPayloadBase
AsyncEventQueueFullError* = object of AsyncError AsyncEventQueueFullError* = object of AsyncError
EventQueueKey* = distinct uint64 EventQueueKey* = distinct uint64
@ -113,7 +69,7 @@ type
EventQueueReader* = object EventQueueReader* = object
key: EventQueueKey key: EventQueueKey
offset: int offset: int
waiter: Future[void] waiter: Future[void].Raising([CancelledError])
overflow: bool overflow: bool
AsyncEventQueue*[T] = ref object of RootObj AsyncEventQueue*[T] = ref object of RootObj
@ -134,17 +90,14 @@ proc newAsyncLock*(): AsyncLock =
## The ``release()`` procedure changes the state to unlocked and returns ## The ``release()`` procedure changes the state to unlocked and returns
## immediately. ## immediately.
# Workaround for callSoon() not worked correctly before AsyncLock()
# getThreadDispatcher() call.
discard getThreadDispatcher()
AsyncLock(waiters: newSeq[Future[void]](), locked: false, acquired: false)
proc wakeUpFirst(lock: AsyncLock): bool {.inline.} = proc wakeUpFirst(lock: AsyncLock): bool {.inline.} =
## Wake up the first waiter if it isn't done. ## Wake up the first waiter if it isn't done.
var i = 0 var i = 0
var res = false var res = false
while i < len(lock.waiters): while i < len(lock.waiters):
var waiter = lock.waiters[i] let waiter = lock.waiters[i]
inc(i) inc(i)
if not(waiter.finished()): if not(waiter.finished()):
waiter.complete() waiter.complete()
@ -164,7 +117,7 @@ proc checkAll(lock: AsyncLock): bool {.inline.} =
return false return false
return true return true
proc acquire*(lock: AsyncLock) {.async.} = proc acquire*(lock: AsyncLock) {.async: (raises: [CancelledError]).} =
## Acquire a lock ``lock``. ## Acquire a lock ``lock``.
## ##
## This procedure blocks until the lock ``lock`` is unlocked, then sets it ## This procedure blocks until the lock ``lock`` is unlocked, then sets it
@ -173,7 +126,7 @@ proc acquire*(lock: AsyncLock) {.async.} =
lock.acquired = true lock.acquired = true
lock.locked = true lock.locked = true
else: else:
var w = newFuture[void]("AsyncLock.acquire") let w = Future[void].Raising([CancelledError]).init("AsyncLock.acquire")
lock.waiters.add(w) lock.waiters.add(w)
await w await w
lock.acquired = true lock.acquired = true
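# Illustrative sketch (not part of the original diff): the classic lock
# pattern - acquire, run the critical section, always release in `finally`.
when isMainModule:
  import chronos

  proc critical(lock: AsyncLock) {.async.} =
    await lock.acquire()
    try:
      await sleepAsync(5.milliseconds)  # exclusive section
    finally:
      lock.release()

  let lock = newAsyncLock()
  waitFor allFutures(critical(lock), critical(lock))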
@ -209,13 +162,10 @@ proc newAsyncEvent*(): AsyncEvent =
## procedure and reset to `false` with the `clear()` procedure. ## procedure and reset to `false` with the `clear()` procedure.
## The `wait()` procedure blocks until the flag is `true`. The flag is ## The `wait()` procedure blocks until the flag is `true`. The flag is
## initially `false`. ## initially `false`.
AsyncEvent()
# Workaround for callSoon() not worked correctly before proc wait*(event: AsyncEvent): Future[void] {.
# getThreadDispatcher() call. async: (raw: true, raises: [CancelledError]).} =
discard getThreadDispatcher()
AsyncEvent(waiters: newSeq[Future[void]](), flag: false)
proc wait*(event: AsyncEvent): Future[void] =
## Block until the internal flag of ``event`` is `true`. ## Block until the internal flag of ``event`` is `true`.
## If the internal flag is `true` on entry, return immediately. Otherwise, ## If the internal flag is `true` on entry, return immediately. Otherwise,
## block until another task calls `fire()` to set the flag to `true`, ## block until another task calls `fire()` to set the flag to `true`,
@ -254,20 +204,15 @@ proc isSet*(event: AsyncEvent): bool =
proc newAsyncQueue*[T](maxsize: int = 0): AsyncQueue[T] = proc newAsyncQueue*[T](maxsize: int = 0): AsyncQueue[T] =
## Creates a new asynchronous queue ``AsyncQueue``. ## Creates a new asynchronous queue ``AsyncQueue``.
# Workaround for callSoon() not worked correctly before
# getThreadDispatcher() call.
discard getThreadDispatcher()
AsyncQueue[T]( AsyncQueue[T](
getters: newSeq[Future[void]](),
putters: newSeq[Future[void]](),
queue: initDeque[T](), queue: initDeque[T](),
maxsize: maxsize maxsize: maxsize
) )
proc wakeupNext(waiters: var seq[Future[void]]) {.inline.} = proc wakeupNext(waiters: var seq) {.inline.} =
var i = 0 var i = 0
while i < len(waiters): while i < len(waiters):
var waiter = waiters[i] let waiter = waiters[i]
inc(i) inc(i)
if not(waiter.finished()): if not(waiter.finished()):
@ -294,119 +239,141 @@ proc empty*[T](aq: AsyncQueue[T]): bool {.inline.} =
## Return ``true`` if the queue is empty, ``false`` otherwise. ## Return ``true`` if the queue is empty, ``false`` otherwise.
(len(aq.queue) == 0) (len(aq.queue) == 0)
proc addFirstImpl[T](aq: AsyncQueue[T], item: T) =
aq.queue.addFirst(item)
aq.getters.wakeupNext()
proc addLastImpl[T](aq: AsyncQueue[T], item: T) =
aq.queue.addLast(item)
aq.getters.wakeupNext()
proc popFirstImpl[T](aq: AsyncQueue[T]): T =
let res = aq.queue.popFirst()
aq.putters.wakeupNext()
res
proc popLastImpl[T](aq: AsyncQueue[T]): T =
let res = aq.queue.popLast()
aq.putters.wakeupNext()
res
proc addFirstNoWait*[T](aq: AsyncQueue[T], item: T) {. proc addFirstNoWait*[T](aq: AsyncQueue[T], item: T) {.
raises: [AsyncQueueFullError].}= raises: [AsyncQueueFullError].} =
## Put an item ``item`` to the beginning of the queue ``aq`` immediately. ## Put an item ``item`` to the beginning of the queue ``aq`` immediately.
## ##
## If queue ``aq`` is full, then ``AsyncQueueFullError`` exception raised. ## If queue ``aq`` is full, then ``AsyncQueueFullError`` exception raised.
if aq.full(): if aq.full():
    raise newException(AsyncQueueFullError, "AsyncQueue is full!")
  aq.addFirstImpl(item)

proc addLastNoWait*[T](aq: AsyncQueue[T], item: T) {.
     raises: [AsyncQueueFullError].} =
  ## Put an item ``item`` at the end of the queue ``aq`` immediately.
  ##
  ## If queue ``aq`` is full, then ``AsyncQueueFullError`` exception raised.
  if aq.full():
    raise newException(AsyncQueueFullError, "AsyncQueue is full!")
  aq.addLastImpl(item)

proc popFirstNoWait*[T](aq: AsyncQueue[T]): T {.
     raises: [AsyncQueueEmptyError].} =
  ## Get an item from the beginning of the queue ``aq`` immediately.
  ##
  ## If queue ``aq`` is empty, then ``AsyncQueueEmptyError`` exception raised.
  if aq.empty():
    raise newException(AsyncQueueEmptyError, "AsyncQueue is empty!")
  aq.popFirstImpl()

proc popLastNoWait*[T](aq: AsyncQueue[T]): T {.
     raises: [AsyncQueueEmptyError].} =
  ## Get an item from the end of the queue ``aq`` immediately.
  ##
  ## If queue ``aq`` is empty, then ``AsyncQueueEmptyError`` exception raised.
  if aq.empty():
    raise newException(AsyncQueueEmptyError, "AsyncQueue is empty!")
  aq.popLastImpl()

proc addFirst*[T](aq: AsyncQueue[T], item: T) {.
     async: (raises: [CancelledError]).} =
  ## Put an ``item`` to the beginning of the queue ``aq``. If the queue is full,
  ## wait until a free slot is available before adding item.
  while aq.full():
    let putter =
      Future[void].Raising([CancelledError]).init("AsyncQueue.addFirst")
    aq.putters.add(putter)
    try:
      await putter
    except CancelledError as exc:
      if not(aq.full()) and not(putter.cancelled()):
        aq.putters.wakeupNext()
      raise exc
  aq.addFirstImpl(item)

proc addLast*[T](aq: AsyncQueue[T], item: T) {.
     async: (raises: [CancelledError]).} =
  ## Put an ``item`` to the end of the queue ``aq``. If the queue is full,
  ## wait until a free slot is available before adding item.
  while aq.full():
    let putter =
      Future[void].Raising([CancelledError]).init("AsyncQueue.addLast")
    aq.putters.add(putter)
    try:
      await putter
    except CancelledError as exc:
      if not(aq.full()) and not(putter.cancelled()):
        aq.putters.wakeupNext()
      raise exc
  aq.addLastImpl(item)

proc popFirst*[T](aq: AsyncQueue[T]): Future[T] {.
     async: (raises: [CancelledError]).} =
  ## Remove and return an ``item`` from the beginning of the queue ``aq``.
  ## If the queue is empty, wait until an item is available.
  while aq.empty():
    let getter =
      Future[void].Raising([CancelledError]).init("AsyncQueue.popFirst")
    aq.getters.add(getter)
    try:
      await getter
    except CancelledError as exc:
      if not(aq.empty()) and not(getter.cancelled()):
        aq.getters.wakeupNext()
      raise exc
  aq.popFirstImpl()

proc popLast*[T](aq: AsyncQueue[T]): Future[T] {.
     async: (raises: [CancelledError]).} =
  ## Remove and return an ``item`` from the end of the queue ``aq``.
  ## If the queue is empty, wait until an item is available.
  while aq.empty():
    let getter =
      Future[void].Raising([CancelledError]).init("AsyncQueue.popLast")
    aq.getters.add(getter)
    try:
      await getter
    except CancelledError as exc:
      if not(aq.empty()) and not(getter.cancelled()):
        aq.getters.wakeupNext()
      raise exc
  aq.popLastImpl()

proc putNoWait*[T](aq: AsyncQueue[T], item: T) {.
     raises: [AsyncQueueFullError].} =
  ## Alias of ``addLastNoWait()``.
  aq.addLastNoWait(item)

proc getNoWait*[T](aq: AsyncQueue[T]): T {.
     raises: [AsyncQueueEmptyError].} =
  ## Alias of ``popFirstNoWait()``.
  aq.popFirstNoWait()

proc put*[T](aq: AsyncQueue[T], item: T): Future[void] {.
     async: (raw: true, raises: [CancelledError]).} =
  ## Alias of ``addLast()``.
  aq.addLast(item)

proc get*[T](aq: AsyncQueue[T]): Future[T] {.
     async: (raw: true, raises: [CancelledError]).} =
  ## Alias of ``popFirst()``.
  aq.popFirst()

@@ -460,7 +427,7 @@ proc contains*[T](aq: AsyncQueue[T], item: T): bool {.inline.} =
  ## via the ``in`` operator.
  for e in aq.queue.items():
    if e == item: return true
  false

proc `$`*[T](aq: AsyncQueue[T]): string =
  ## Turn an async queue ``aq`` into its string representation.
@@ -471,190 +438,6 @@ proc `$`*[T](aq: AsyncQueue[T]): string =
  res.add("]")
  res
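With the signatures above, the waiting operations can only raise CancelledError, while the *NoWait variants fail synchronously with AsyncQueueFullError/AsyncQueueEmptyError. A minimal usage sketch - the producer/consumer/main procs are illustrative and not part of chronos; newAsyncQueue is the existing constructor from this module:

import chronos

proc producer(q: AsyncQueue[int]) {.async: (raises: [CancelledError]).} =
  for i in 0 ..< 4:
    await q.addLast(i)            # suspends while the queue is full

proc consumer(q: AsyncQueue[int]) {.async: (raises: [CancelledError]).} =
  for _ in 0 ..< 4:
    let item = await q.popFirst() # suspends while the queue is empty
    echo item

proc main() {.async.} =
  let q = newAsyncQueue[int](1)   # bounded queue with a single slot
  let prod = producer(q)
  await consumer(q)
  await prod

waitFor main()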
template generateKey(typeName, eventName: string): string =
"type[" & typeName & "]-key[" & eventName & "]"
proc newAsyncEventBus*(): AsyncEventBus {.
deprecated: "Implementation has unfixable flaws, please use" &
"AsyncEventQueue[T] instead".} =
## Creates new ``AsyncEventBus``.
AsyncEventBus(counter: 0'u64, events: initTable[string, EventItem]())
template get*[T](payload: EventPayload[T]): T =
## Returns event payload data.
payload.value
template location*(payload: EventPayloadBase): SrcLoc =
## Returns source location address of event emitter.
payload.loc[]
proc get*(event: AwaitableEvent, T: typedesc): T {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue[T] instead".} =
## Returns event's payload of type ``T`` from event ``event``.
cast[EventPayload[T]](event.payload).value
template event*(event: AwaitableEvent): string =
## Returns event's name from event ``event``.
event.eventName
template location*(event: AwaitableEvent): SrcLoc =
## Returns source location address of event emitter.
event.payload.loc[]
proc waitEvent*(bus: AsyncEventBus, T: typedesc, event: string): Future[T] {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue[T] instead".} =
## Wait for the event from AsyncEventBus ``bus`` with name ``event``.
##
## Returned ``Future[T]`` will hold event's payload of type ``T``.
var default: EventItem
var retFuture = newFuture[T]("AsyncEventBus.waitEvent")
let eventKey = generateKey(T.name, event)
proc cancellation(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
bus.events.withValue(eventKey, item):
item.waiters.keepItIf(it != cast[FutureBase](retFuture))
retFuture.cancelCallback = cancellation
let baseFuture = cast[FutureBase](retFuture)
bus.events.mgetOrPut(eventKey, default).waiters.add(baseFuture)
retFuture
proc waitAllEvents*(bus: AsyncEventBus): Future[AwaitableEvent] {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue[T] instead".} =
## Wait for any event from AsyncEventBus ``bus``.
##
## Returns ``Future`` which holds helper object. Using this object you can
## retrieve event's name and payload.
var retFuture = newFuture[AwaitableEvent]("AsyncEventBus.waitAllEvents")
proc cancellation(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
bus.waiters.keepItIf(it != retFuture)
retFuture.cancelCallback = cancellation
bus.waiters.add(retFuture)
retFuture
proc subscribe*[T](bus: AsyncEventBus, event: string,
callback: EventBusSubscription[T]): EventBusKey {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue[T] instead".} =
## Subscribe to the event ``event`` passed through eventbus ``bus`` with
## callback ``callback``.
##
## Returns key that can be used to unsubscribe.
proc trampoline(tbus: AsyncEventBus, event: string, key: EventBusKey,
data: EventPayloadBase) {.gcsafe, raises: [].} =
let payload = cast[EventPayload[T]](data)
asyncSpawn callback(bus, payload)
let subkey =
block:
inc(bus.counter)
EventBusKey(eventName: event, typeName: T.name, unique: bus.counter,
cb: trampoline)
var default: EventItem
let eventKey = generateKey(T.name, event)
bus.events.mgetOrPut(eventKey, default).subscribers.add(subkey)
subkey
proc subscribeAll*(bus: AsyncEventBus,
callback: EventBusAllSubscription): EventBusKey {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue instead".} =
## Subscribe to all events passed through eventbus ``bus`` with callback
## ``callback``.
##
## Returns key that can be used to unsubscribe.
proc trampoline(tbus: AsyncEventBus, event: string, key: EventBusKey,
data: EventPayloadBase) {.gcsafe, raises: [].} =
let event = AwaitableEvent(eventName: event, payload: data)
asyncSpawn callback(bus, event)
let subkey =
block:
inc(bus.counter)
EventBusKey(eventName: "", typeName: "", unique: bus.counter,
cb: trampoline)
bus.subscribers.add(subkey)
subkey
proc unsubscribe*(bus: AsyncEventBus, key: EventBusKey) {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue instead".} =
## Cancel subscription of subscriber with key ``key`` from eventbus ``bus``.
let eventKey = generateKey(key.typeName, key.eventName)
# Clean event's subscribers.
bus.events.withValue(eventKey, item):
item.subscribers.keepItIf(it.unique != key.unique)
# Clean subscribers subscribed to all events.
bus.subscribers.keepItIf(it.unique != key.unique)
proc emit[T](bus: AsyncEventBus, event: string, data: T, loc: ptr SrcLoc) =
let
eventKey = generateKey(T.name, event)
payload =
block:
var data = EventPayload[T](value: data, loc: loc)
cast[EventPayloadBase](data)
# Used to capture the "subscriber" variable in the loops
# sugar.capture doesn't work in Nim <1.6
proc triggerSubscriberCallback(subscriber: EventBusKey) =
callSoon(proc(udata: pointer) =
subscriber.cb(bus, event, subscriber, payload)
)
bus.events.withValue(eventKey, item):
# Schedule waiters which are waiting for the event ``event``.
for waiter in item.waiters:
var fut = cast[Future[T]](waiter)
fut.complete(data)
# Clear all the waiters.
item.waiters.setLen(0)
# Schedule subscriber's callbacks, which are subscribed to the event.
for subscriber in item.subscribers:
triggerSubscriberCallback(subscriber)
# Schedule waiters which are waiting all events
for waiter in bus.waiters:
waiter.complete(AwaitableEvent(eventName: event, payload: payload))
# Clear all the waiters.
bus.waiters.setLen(0)
# Schedule subscriber's callbacks which are subscribed to all events.
for subscriber in bus.subscribers:
triggerSubscriberCallback(subscriber)
template emit*[T](bus: AsyncEventBus, event: string, data: T) {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue instead".} =
## Emit new event ``event`` to the eventbus ``bus`` with payload ``data``.
emit(bus, event, data, getSrcLocation())
proc emitWait[T](bus: AsyncEventBus, event: string, data: T,
loc: ptr SrcLoc): Future[void] =
var retFuture = newFuture[void]("AsyncEventBus.emitWait")
proc continuation(udata: pointer) {.gcsafe.} =
if not(retFuture.finished()):
retFuture.complete()
emit(bus, event, data, loc)
callSoon(continuation)
return retFuture
template emitWait*[T](bus: AsyncEventBus, event: string,
data: T): Future[void] {.
deprecated: "Implementation has unfixable flaws, please use " &
"AsyncEventQueue instead".} =
## Emit new event ``event`` to the eventbus ``bus`` with payload ``data`` and
## wait until all the subscribers/waiters will receive notification about
## event.
emitWait(bus, event, data, getSrcLocation())
proc `==`(a, b: EventQueueKey): bool {.borrow.}

proc compact(ab: AsyncEventQueue) {.raises: [].} =
@@ -680,8 +463,7 @@ proc compact(ab: AsyncEventQueue) {.raises: [].} =
    else:
      ab.queue.clear()

proc getReaderIndex(ab: AsyncEventQueue, key: EventQueueKey): int =
  for index, value in ab.readers.pairs():
    if value.key == key:
      return index
@@ -735,11 +517,16 @@ proc close*(ab: AsyncEventQueue) {.raises: [].} =
  ab.readers.reset()
  ab.queue.clear()

proc closeWait*(ab: AsyncEventQueue): Future[void] {.
     async: (raw: true, raises: []).} =
  let retFuture = newFuture[void]("AsyncEventQueue.closeWait()",
                                  {FutureFlag.OwnCancelSchedule})
  proc continuation(udata: pointer) {.gcsafe.} =
    retFuture.complete()
  # Ignore cancellation requests - we'll complete the future soon enough
  retFuture.cancelCallback = nil
  ab.close()
  # Schedule `continuation` to be called only after all the `reader`
  # notifications will be scheduled and processed.
@@ -750,7 +537,7 @@ template readerOverflow*(ab: AsyncEventQueue,
                         reader: EventQueueReader): bool =
  ab.limit + (reader.offset - ab.offset) <= len(ab.queue)

proc emit*[T](ab: AsyncEventQueue[T], data: T) =
  if len(ab.readers) > 0:
    # We enqueue `data` only if there active reader present.
    var changesPresent = false
@@ -787,7 +574,8 @@ proc emit*[T](ab: AsyncEventQueue[T], data: T) {.raises: [].} =
proc waitEvents*[T](ab: AsyncEventQueue[T],
                    key: EventQueueKey,
                    eventsCount = -1): Future[seq[T]] {.
     async: (raises: [AsyncEventQueueFullError, CancelledError]).} =
  ## Wait for events
  var
    events: seq[T]
@@ -817,7 +605,8 @@ proc waitEvents*[T](ab: AsyncEventQueue[T],
      doAssert(length >= ab.readers[index].offset)
      if length == ab.readers[index].offset:
        # We are at the end of queue, it means that we should wait for new events.
        let waitFuture = Future[void].Raising([CancelledError]).init(
          "AsyncEventQueue.waitEvents")
        ab.readers[index].waiter = waitFuture
        resetFuture = true
        await waitFuture
@@ -848,4 +637,4 @@ proc waitEvents*[T](ab: AsyncEventQueue[T],
    if (eventsCount <= 0) or (len(events) == eventsCount):
      break
  events
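Together with the deprecation notices above, event consumers are expected to go through AsyncEventQueue. A rough sketch, assuming the newAsyncEventQueue/register/unregister helpers from the same module; the demo proc is illustrative:

import chronos

proc demo() {.async.} =
  let
    bus = newAsyncEventQueue[int]()
    key = bus.register()          # only registered readers receive events
  bus.emit(1)
  bus.emit(2)
  let events = await bus.waitEvents(key, 2)  # -> @[1, 2]
  echo events
  bus.unregister(key)
  await bus.closeWait()

waitFor demo()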
chronos/bipbuffer.nim (new file, 140 lines)
@@ -0,0 +1,140 @@
#
# Chronos
#
# (c) Copyright 2018-Present Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
## This module implements Bip Buffer (bi-partite circular buffer) by Simone
## Cooke.
##
## The Bip-Buffer is like a circular buffer, but slightly different. Instead of
## keeping one head and tail pointer to the data in the buffer, it maintains two
## revolving regions, allowing for fast data access without having to worry
## about wrapping at the end of the buffer. Buffer allocations are always
## maintained as contiguous blocks, allowing the buffer to be used in a highly
## efficient manner with API calls, and also reducing the amount of copying
## which needs to be performed to put data into the buffer. Finally, a two-phase
## allocation system allows the user to pessimistically reserve an area of
## buffer space, and then trim back the buffer to commit to only the space which
## was used.
##
## https://www.codeproject.com/Articles/3479/The-Bip-Buffer-The-Circular-Buffer-with-a-Twist
{.push raises: [].}
type
BipPos = object
start: Natural
finish: Natural
BipBuffer* = object
a, b, r: BipPos
data: seq[byte]
proc init*(t: typedesc[BipBuffer], size: int): BipBuffer =
## Creates new Bip Buffer with size `size`.
BipBuffer(data: newSeq[byte](size))
template len(pos: BipPos): Natural =
pos.finish - pos.start
template reset(pos: var BipPos) =
pos = BipPos()
func init(t: typedesc[BipPos], start, finish: Natural): BipPos =
BipPos(start: start, finish: finish)
func calcReserve(bp: BipBuffer): tuple[space: Natural, start: Natural] =
if len(bp.b) > 0:
(Natural(bp.a.start - bp.b.finish), bp.b.finish)
else:
let spaceAfterA = Natural(len(bp.data) - bp.a.finish)
if spaceAfterA >= bp.a.start:
(spaceAfterA, bp.a.finish)
else:
(bp.a.start, Natural(0))
func availSpace*(bp: BipBuffer): Natural =
## Returns amount of space available for reserve in buffer `bp`.
let (res, _) = bp.calcReserve()
res
func len*(bp: BipBuffer): Natural =
## Returns amount of used space in buffer `bp`.
len(bp.b) + len(bp.a)
proc reserve*(bp: var BipBuffer,
size: Natural = 0): tuple[data: ptr byte, size: Natural] =
## Reserve `size` bytes in buffer.
##
## If `size == 0` (default) reserve all available space from buffer.
##
## If there is not enough space in buffer for reservation - error will be
## returned.
##
## Returns current reserved range as pointer of type `pt` and size of
## type `st`.
const ErrorMessage = "Not enough space available"
doAssert(size <= len(bp.data))
let (availableSpace, reserveStart) = bp.calcReserve()
if availableSpace == 0:
raiseAssert ErrorMessage
let reserveLength =
if size == 0:
availableSpace
else:
if size < availableSpace:
raiseAssert ErrorMessage
size
bp.r = BipPos.init(reserveStart, Natural(reserveStart + reserveLength))
(addr bp.data[bp.r.start], len(bp.r))
proc commit*(bp: var BipBuffer, size: Natural) =
## Updates structure's pointers when new data inserted into buffer.
doAssert(len(bp.r) >= size,
"Committed size could not be larger than the previously reserved one")
if size == 0:
bp.r.reset()
return
let toCommit = min(size, len(bp.r))
if len(bp.a) == 0 and len(bp.b) == 0:
bp.a.start = bp.r.start
bp.a.finish = bp.r.start + toCommit
elif bp.r.start == bp.a.finish:
bp.a.finish += toCommit
else:
bp.b.finish += toCommit
bp.r.reset()
proc consume*(bp: var BipBuffer, size: Natural) =
## The procedure removes/frees `size` bytes from the buffer ``bp``.
var currentSize = size
if currentSize >= len(bp.a):
currentSize -= len(bp.a)
bp.a = bp.b
bp.b.reset()
if currentSize >= len(bp.a):
currentSize -= len(bp.a)
bp.a.reset()
else:
bp.a.start += currentSize
else:
bp.a.start += currentSize
iterator items*(bp: BipBuffer): byte =
## Iterates over all the bytes in the buffer.
for index in bp.a.start ..< bp.a.finish:
yield bp.data[index]
for index in bp.b.start ..< bp.b.finish:
yield bp.data[index]
iterator regions*(bp: var BipBuffer): tuple[data: ptr byte, size: Natural] =
## Iterates over all the regions (`a` and `b`) in the buffer.
if len(bp.a) > 0:
yield (addr bp.data[bp.a.start], len(bp.a))
if len(bp.b) > 0:
yield (addr bp.data[bp.b.start], len(bp.b))
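A minimal sketch of the two-phase reserve/commit cycle described in the module documentation above, using only the procedures defined in this file; the import path follows the file name, and the buffer size and payload are illustrative:

import chronos/bipbuffer

var buf = BipBuffer.init(16)

# Phase 1: reserve a contiguous region and write into it.
let (region, size) = buf.reserve()
let payload = [byte 1, 2, 3]
doAssert size >= payload.len
copyMem(region, unsafeAddr payload[0], payload.len)

# Phase 2: commit only the bytes that were actually written.
buf.commit(payload.len)
doAssert buf.len() == 3

# Drain: iterate the stored bytes, then free them.
for b in buf.items():
  echo b
buf.consume(3)
doAssert buf.len() == 0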
@@ -11,64 +11,100 @@
## `chronosDebug` can be defined to enable several debugging helpers that come
## with a runtime cost - it is recommeneded to not enable these in production
## code.
##
## In this file we also declare feature flags starting with `chronosHas...` -
## these constants are declared when a feature exists in a particular release -
## each flag is declared as an integer starting at 0 during experimental
## development, 1 when feature complete and higher numbers when significant
## functionality has been added. If a feature ends up being removed (or changed
## in a backwards-incompatible way), the feature flag will be removed or renamed
## also - you can use `when declared(chronosHasXxx): when chronosHasXxx >= N:`
## to require a particular version.

const
  chronosHandleException* {.booldefine.}: bool = false
    ## Remap `Exception` to `AsyncExceptionError` for all `async` functions.
    ##
    ## This modes provides backwards compatibility when using functions with
    ## inaccurate `{.raises.}` effects such as unannotated forward declarations,
    ## methods and `proc` types - it is recommened to annotate such code
    ## explicitly as the `Exception` handling mode may introduce surprising
    ## behavior in exception handlers, should `Exception` actually be raised.
    ##
    ## The setting provides the default for the per-function-based
    ## `handleException` parameter which has precedence over this global setting.
    ##
    ## `Exception` handling may be removed in future chronos versions.

  chronosStrictFutureAccess* {.booldefine.}: bool = defined(chronosPreviewV4)

  chronosStackTrace* {.booldefine.}: bool = defined(chronosDebug)
    ## Include stack traces in futures for creation and completion points

  chronosFutureId* {.booldefine.}: bool = defined(chronosDebug)
    ## Generate a unique `id` for every future - when disabled, the address of
    ## the future will be used instead

  chronosFutureTracking* {.booldefine.}: bool = defined(chronosDebug)
    ## Keep track of all pending futures and allow iterating over them -
    ## useful for detecting hung tasks

  chronosDumpAsync* {.booldefine.}: bool = defined(nimDumpAsync)
    ## Print code generated by {.async.} transformation

  chronosProcShell* {.strdefine.}: string =
    when defined(windows):
      "cmd.exe"
    else:
      when defined(android):
        "/system/bin/sh"
      else:
        "/bin/sh"
    ## Default shell binary path.
    ##
    ## The shell is used as command for command line when process started
    ## using `AsyncProcessOption.EvalCommand` and API calls such as
    ## ``execCommand(command)`` and ``execCommandEx(command)``.

  chronosEventsCount* {.intdefine.} = 64
    ## Number of OS poll events retrieved by syscall (epoll, kqueue, poll).

  chronosInitialSize* {.intdefine.} = 64
    ## Initial size of Selector[T]'s array of file descriptors.

  chronosEventEngine* {.strdefine.}: string =
    when defined(nimdoc):
      ""
    elif defined(linux) and not(defined(android) or defined(emscripten)):
      "epoll"
    elif defined(macosx) or defined(macos) or defined(ios) or
         defined(freebsd) or defined(netbsd) or defined(openbsd) or
         defined(dragonfly):
      "kqueue"
    elif defined(android) or defined(emscripten):
      "poll"
    elif defined(posix):
      "poll"
    else:
      ""
    ## OS polling engine type which is going to be used by chronos.

  chronosHasRaises* = 0
    ## raises effect support via `async: (raises: [])`

  chronosTransportDefaultBufferSize* {.intdefine.} = 16384
    ## Default size of chronos transport internal buffer.

  chronosStreamDefaultBufferSize* {.intdefine.} = 16384
    ## Default size of chronos async stream internal buffer.

  chronosTLSSessionCacheBufferSize* {.intdefine.} = 4096
    ## Default size of chronos TLS Session cache's internal buffer.

when defined(chronosStrictException):
  {.warning: "-d:chronosStrictException has been deprecated in favor of handleException".}
  # In chronos v3, this setting was used as the opposite of
  # `chronosHandleException` - the setting is deprecated to encourage
  # migration to the new mode.

when defined(debug) or defined(chronosConfig):
  import std/macros
@@ -77,9 +113,54 @@ when defined(debug) or defined(chronosConfig):
  hint("Chronos configuration:")

  template printOption(name: string, value: untyped) =
    hint(name & ": " & $value)

  printOption("chronosHandleException", chronosHandleException)
  printOption("chronosStackTrace", chronosStackTrace)
  printOption("chronosFutureId", chronosFutureId)
  printOption("chronosFutureTracking", chronosFutureTracking)
  printOption("chronosDumpAsync", chronosDumpAsync)
  printOption("chronosProcShell", chronosProcShell)
  printOption("chronosEventEngine", chronosEventEngine)
  printOption("chronosEventsCount", chronosEventsCount)
  printOption("chronosInitialSize", chronosInitialSize)
  printOption("chronosTransportDefaultBufferSize",
              chronosTransportDefaultBufferSize)
  printOption("chronosStreamDefaultBufferSize",
              chronosStreamDefaultBufferSize)
  printOption("chronosTLSSessionCacheBufferSize",
              chronosTLSSessionCacheBufferSize)

# In nim 1.6, `sink` + local variable + `move` generates the best code for
# moving a proc parameter into a closure - this only works for closure
# procedures however - in closure iterators, the parameter is always copied
# into the closure (!) meaning that non-raw `{.async.}` functions always carry
# this overhead, sink or no. See usages of chronosMoveSink for examples.
# In addition, we need to work around https://github.com/nim-lang/Nim/issues/22175
# which has not been backported to 1.6.
# Long story short, the workaround is not needed in non-raw {.async.} because
# a copy of the literal is always made.
# TODO review the above for 2.0 / 2.0+refc
type
  SeqHeader = object
    length, reserved: int

proc isLiteral(s: string): bool {.inline.} =
  when defined(gcOrc) or defined(gcArc):
    false
  else:
    s.len > 0 and (cast[ptr SeqHeader](s).reserved and (1 shl (sizeof(int) * 8 - 2))) != 0

proc isLiteral[T](s: seq[T]): bool {.inline.} =
  when defined(gcOrc) or defined(gcArc):
    false
  else:
    s.len > 0 and (cast[ptr SeqHeader](s).reserved and (1 shl (sizeof(int) * 8 - 2))) != 0

template chronosMoveSink*(val: auto): untyped =
  bind isLiteral
  when not (defined(gcOrc) or defined(gcArc)) and val is seq|string:
    if isLiteral(val):
      val
    else:
      move(val)
  else:
    move(val)
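Because the configuration constants earlier in this file are declared with booldefine/intdefine/strdefine, they can be overridden per project at compile time. A sketch of a config.nims (NimScript) doing so - the chosen values are illustrative only:

# config.nims
switch("define", "chronosHandleException")                  # booldefine -> true
switch("define", "chronosEventsCount=128")                  # intdefine
switch("define", "chronosTransportDefaultBufferSize=32768") # intdefine
switch("define", "chronosEventEngine=poll")                 # strdefine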
@@ -17,11 +17,6 @@ export srcloc
when chronosStackTrace:
  type StackTrace = string

type
  LocationKind* {.pure.} = enum
    Create
@@ -37,6 +32,24 @@ type
  FutureState* {.pure.} = enum
    Pending, Completed, Cancelled, Failed

  FutureFlag* {.pure.} = enum
    OwnCancelSchedule
      ## When OwnCancelSchedule is set, the owner of the future is responsible
      ## for implementing cancellation in one of 3 ways:
      ##
      ## * ensure that cancellation requests never reach the future by means of
      ##   not exposing it to user code, `await` and `tryCancel`
      ## * set `cancelCallback` to `nil` to stop cancellation propagation - this
      ##   is appropriate when it is expected that the future will be completed
      ##   in a regular way "soon"
      ## * set `cancelCallback` to a handler that implements cancellation in an
      ##   operation-specific way
      ##
      ## If `cancelCallback` is not set and the future gets cancelled, a
      ## `Defect` will be raised.

  FutureFlags* = set[FutureFlag]

  InternalFutureBase* = object of RootObj
    # Internal untyped future representation - the fields are not part of the
    # public API and neither is `InternalFutureBase`, ie the inheritance
@@ -47,9 +60,9 @@ type
    internalCancelcb*: CallbackFunc
    internalChild*: FutureBase
    internalState*: FutureState
    internalFlags*: FutureFlags
    internalError*: ref CatchableError ## Stored exception
    internalClosure*: iterator(f: FutureBase): FutureBase {.raises: [], gcsafe.}

    when chronosFutureId:
      internalId*: uint
@@ -73,10 +86,15 @@ type
    cause*: FutureBase

  FutureError* = object of CatchableError
    future*: FutureBase

  CancelledError* = object of FutureError
    ## Exception raised when accessing the value of a cancelled future

func raiseFutureDefect(msg: static string, fut: FutureBase) {.
     noinline, noreturn.} =
  raise (ref FutureDefect)(msg: msg, cause: fut)

when chronosFutureId:
  var currentID* {.threadvar.}: uint
  template id*(fut: FutureBase): uint = fut.internalId
@@ -94,12 +112,17 @@ when chronosFutureTracking:
  var futureList* {.threadvar.}: FutureList

# Internal utilities - these are not part of the stable API
proc internalInitFutureBase*(fut: FutureBase, loc: ptr SrcLoc,
                             state: FutureState, flags: FutureFlags) =
  fut.internalState = state
  fut.internalLocation[LocationKind.Create] = loc
  fut.internalFlags = flags
  if FutureFlag.OwnCancelSchedule in flags:
    # Owners must replace `cancelCallback` with `nil` if they want to ignore
    # cancellations
    fut.internalCancelcb = proc(_: pointer) =
      raiseAssert "Cancellation request for non-cancellable future"

  if state != FutureState.Pending:
    fut.internalLocation[LocationKind.Finish] = loc
@@ -128,21 +151,34 @@ template init*[T](F: type Future[T], fromProc: static[string] = ""): Future[T] =
  ## Specifying ``fromProc``, which is a string specifying the name of the proc
  ## that this future belongs to, is a good habit as it helps with debugging.
  let res = Future[T]()
  internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending, {})
  res

template init*[T](F: type Future[T], fromProc: static[string] = "",
                  flags: static[FutureFlags]): Future[T] =
  ## Creates a new pending future.
  ##
  ## Specifying ``fromProc``, which is a string specifying the name of the proc
  ## that this future belongs to, is a good habit as it helps with debugging.
  let res = Future[T]()
  internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending,
                         flags)
  res

template completed*(
    F: type Future, fromProc: static[string] = ""): Future[void] =
  ## Create a new completed future
  let res = Future[void]()
  internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Completed,
                         {})
  res

template completed*[T: not void](
    F: type Future, valueParam: T, fromProc: static[string] = ""): Future[T] =
  ## Create a new completed future
  let res = Future[T](internalValue: valueParam)
  internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Completed,
                         {})
  res

template failed*[T](
@@ -150,19 +186,21 @@ template failed*[T](
    fromProc: static[string] = ""): Future[T] =
  ## Create a new failed future
  let res = Future[T](internalError: errorParam)
  internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Failed, {})

  when chronosStackTrace:
    res.internalErrorStackTrace =
      if getStackTrace(res.error) == "":
        getStackTrace()
      else:
        getStackTrace(res.error)
  res

func state*(future: FutureBase): FutureState =
  future.internalState

func flags*(future: FutureBase): FutureFlags =
  future.internalFlags

func finished*(future: FutureBase): bool {.inline.} =
  ## Determines whether ``future`` has finished, i.e. ``future`` state changed
  ## from state ``Pending`` to one of the states (``Finished``, ``Cancelled``,
@@ -184,20 +222,27 @@ func completed*(future: FutureBase): bool {.inline.} =
func location*(future: FutureBase): array[LocationKind, ptr SrcLoc] =
  future.internalLocation

func value*[T: not void](future: Future[T]): lent T =
  ## Return the value in a completed future - raises Defect when
  ## `fut.completed()` is `false`.
  ##
  ## See `read` for a version that raises a catchable error when future
  ## has not completed.
  when chronosStrictFutureAccess:
    if not future.completed():
      raiseFutureDefect("Future not completed while accessing value", future)

  future.internalValue

func value*(future: Future[void]) =
  ## Return the value in a completed future - raises Defect when
  ## `fut.completed()` is `false`.
  ##
  ## See `read` for a version that raises a catchable error when future
  ## has not completed.
  when chronosStrictFutureAccess:
    if not future.completed():
      raiseFutureDefect("Future not completed while accessing value", future)

func error*(future: FutureBase): ref CatchableError =
  ## Return the error of `future`, or `nil` if future did not fail.
@@ -206,9 +251,8 @@ func error*(future: FutureBase): ref CatchableError =
  ## future has not failed.
  when chronosStrictFutureAccess:
    if not future.failed() and not future.cancelled():
      raiseFutureDefect(
        "Future not failed/cancelled while accessing error", future)

  future.internalError
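The closeWait() change shown earlier is the canonical use of the OwnCancelSchedule flag documented above: the owner creates the future with the flag and then decides explicitly what cancellation means. A minimal sketch of the same pattern outside chronos; the shutdown proc and its name are illustrative:

import chronos

proc shutdown(): Future[void] {.async: (raw: true, raises: []).} =
  let retFuture = newFuture[void]("example.shutdown",
                                  {FutureFlag.OwnCancelSchedule})
  proc continuation(udata: pointer) {.gcsafe, raises: [].} =
    retFuture.complete()
  # Ignore cancellation requests - the future completes on the next loop tick.
  retFuture.cancelCallback = nil
  callSoon(continuation)
  retFuture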
@@ -10,7 +10,7 @@
{.push raises: [].}

import "."/[asyncloop, osdefs, osutils]
import results
from nativesockets import Domain, Protocol, SockType, toInt
export Domain, Protocol, SockType, results
@@ -21,66 +21,113 @@
  asyncInvalidSocket* = AsyncFD(osdefs.INVALID_SOCKET)
  asyncInvalidPipe* = asyncInvalidSocket

proc setSocketBlocking*(s: SocketHandle, blocking: bool): bool {.
     deprecated: "Please use setDescriptorBlocking() instead".} =
  ## Sets blocking mode on socket.
  setDescriptorBlocking(s, blocking).isOkOr:
    return false
  true

proc setSockOpt2*(socket: AsyncFD,
                  level, optname, optval: int): Result[void, OSErrorCode] =
  var value = cint(optval)
  let res = osdefs.setsockopt(SocketHandle(socket), cint(level), cint(optname),
                              addr(value), SockLen(sizeof(value)))
  if res == -1:
    return err(osLastError())
  ok()

proc setSockOpt2*(socket: AsyncFD, level, optname: int, value: pointer,
                  valuelen: int): Result[void, OSErrorCode] =
  ## `setsockopt()` for custom options (pointer and length).
  ## Returns ``true`` on success, ``false`` on error.
  let res = osdefs.setsockopt(SocketHandle(socket), cint(level), cint(optname),
                              value, SockLen(valuelen))
  if res == -1:
    return err(osLastError())
  ok()

proc setSockOpt*(socket: AsyncFD, level, optname, optval: int): bool {.
     deprecated: "Please use setSockOpt2() instead".} =
  ## `setsockopt()` for integer options.
  ## Returns ``true`` on success, ``false`` on error.
  setSockOpt2(socket, level, optname, optval).isOk

proc setSockOpt*(socket: AsyncFD, level, optname: int, value: pointer,
                 valuelen: int): bool {.
     deprecated: "Please use setSockOpt2() instead".} =
  ## `setsockopt()` for custom options (pointer and length).
  ## Returns ``true`` on success, ``false`` on error.
  setSockOpt2(socket, level, optname, value, valuelen).isOk

proc getSockOpt2*(socket: AsyncFD,
                  level, optname: int): Result[cint, OSErrorCode] =
  var
    value: cint
    size = SockLen(sizeof(value))
  let res = osdefs.getsockopt(SocketHandle(socket), cint(level), cint(optname),
                              addr(value), addr(size))
  if res == -1:
    return err(osLastError())
  ok(value)

proc getSockOpt2*(socket: AsyncFD, level, optname: int,
                  T: type): Result[T, OSErrorCode] =
  var
    value = default(T)
    size = SockLen(sizeof(value))
  let res = osdefs.getsockopt(SocketHandle(socket), cint(level), cint(optname),
                              cast[ptr byte](addr(value)), addr(size))
  if res == -1:
    return err(osLastError())
  ok(value)

proc getSockOpt*(socket: AsyncFD, level, optname: int, value: var int): bool {.
     deprecated: "Please use getSockOpt2() instead".} =
  ## `getsockopt()` for integer options.
  ## Returns ``true`` on success, ``false`` on error.
  value = getSockOpt2(socket, level, optname).valueOr:
    return false
  true

proc getSockOpt*(socket: AsyncFD, level, optname: int, value: var pointer,
                 valuelen: var int): bool {.
     deprecated: "Please use getSockOpt2() instead".} =
  ## `getsockopt()` for custom options (pointer and length).
  ## Returns ``true`` on success, ``false`` on error.
  osdefs.getsockopt(SocketHandle(socket), cint(level), cint(optname),
                    value, cast[ptr SockLen](addr valuelen)) >= cint(0)

proc getSocketError*(socket: AsyncFD, err: var int): bool {.
     deprecated: "Please use getSocketError() instead".} =
  ## Recover error code associated with socket handle ``socket``.
  err = getSockOpt2(socket, cint(osdefs.SOL_SOCKET),
                    cint(osdefs.SO_ERROR)).valueOr:
    return false
  true

proc getSocketError2*(socket: AsyncFD): Result[cint, OSErrorCode] =
  getSockOpt2(socket, cint(osdefs.SOL_SOCKET), cint(osdefs.SO_ERROR))

proc isAvailable*(domain: Domain): bool =
  when defined(windows):
    let fd = wsaSocket(toInt(domain), toInt(SockType.SOCK_STREAM),
                       toInt(Protocol.IPPROTO_TCP), nil, GROUP(0), 0'u32)
    if fd == osdefs.INVALID_SOCKET:
      return if osLastError() == osdefs.WSAEAFNOSUPPORT: false else: true
    discard closeFd(fd)
    true
  else:
    let fd = osdefs.socket(toInt(domain), toInt(SockType.SOCK_STREAM),
                           toInt(Protocol.IPPROTO_TCP))
    if fd == -1:
      return if osLastError() == osdefs.EAFNOSUPPORT: false else: true
    discard closeFd(fd)
    true

proc createAsyncSocket2*(domain: Domain, sockType: SockType,
                         protocol: Protocol,
                         inherit = true): Result[AsyncFD, OSErrorCode] =
  ## Creates new asynchronous socket.
  when defined(windows):
    let flags =
@@ -93,15 +140,12 @@ proc createAsyncSocket2*(domain: Domain, sockType: SockType,
    if fd == osdefs.INVALID_SOCKET:
      return err(osLastError())

    setDescriptorBlocking(fd, false).isOkOr:
      discard closeFd(fd)
      return err(error)
    register2(AsyncFD(fd)).isOkOr:
      discard closeFd(fd)
      return err(error)

    ok(AsyncFD(fd))
  else:
@@ -114,23 +158,20 @@ proc createAsyncSocket2*(domain: Domain, sockType: SockType,
      let fd = osdefs.socket(toInt(domain), socketType, toInt(protocol))
      if fd == -1:
        return err(osLastError())
      register2(AsyncFD(fd)).isOkOr:
        discard closeFd(fd)
        return err(error)
      ok(AsyncFD(fd))
    else:
      let fd = osdefs.socket(toInt(domain), toInt(sockType), toInt(protocol))
      if fd == -1:
        return err(osLastError())
      setDescriptorFlags(cint(fd), true, true).isOkOr:
        discard closeFd(fd)
        return err(error)
      register2(AsyncFD(fd)).isOkOr:
        discard closeFd(fd)
        return err(error)
      ok(AsyncFD(fd))

proc wrapAsyncSocket2*(sock: cint|SocketHandle): Result[AsyncFD, OSErrorCode] =
@@ -230,3 +271,26 @@ proc createAsyncPipe*(): tuple[read: AsyncFD, write: AsyncFD] =
  else:
    let pipes = res.get()
    (read: AsyncFD(pipes.read), write: AsyncFD(pipes.write))

proc getDualstack*(fd: AsyncFD): Result[bool, OSErrorCode] =
  ## Returns `true` if `IPV6_V6ONLY` socket option set to `false`.
  var
    flag = cint(0)
    size = SockLen(sizeof(flag))
  let res = osdefs.getsockopt(SocketHandle(fd), cint(osdefs.IPPROTO_IPV6),
                              cint(osdefs.IPV6_V6ONLY), addr(flag), addr(size))
  if res == -1:
    return err(osLastError())
  ok(flag == cint(0))

proc setDualstack*(fd: AsyncFD, value: bool): Result[void, OSErrorCode] =
  ## Sets `IPV6_V6ONLY` socket option value to `false` if `value == true` and
  ## to `true` if `value == false`.
  var
    flag = cint(if value: 0 else: 1)
    size = SockLen(sizeof(flag))
  let res = osdefs.setsockopt(SocketHandle(fd), cint(osdefs.IPPROTO_IPV6),
                              cint(osdefs.IPV6_V6ONLY), addr(flag), size)
  if res == -1:
    return err(osLastError())
  ok()
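A rough usage sketch of the Result-style helpers above; isOkOr/valueOr come from the results package imported at the top of the file, and the probe proc is illustrative and assumes these helpers are reachable through the top-level chronos import:

import chronos

proc probe(fd: AsyncFD): bool =
  # Same job as the deprecated getSocketError(), but with explicit error flow.
  let pending = getSocketError2(fd).valueOr:
    return false                # `error` holds the OSErrorCode here
  if pending != 0:
    echo "socket has pending error code ", pending
  # Clear IPV6_V6ONLY so the socket accepts both IPv4 and IPv6 peers.
  setDualstack(fd, true).isOkOr:
    return false
  true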
File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,596 @@
#
#
# Nim's Runtime Library
# (c) Copyright 2015 Dominik Picheta
# (c) Copyright 2018-Present Status Research & Development GmbH
#
# See the file "copying.txt", included in this
# distribution, for details about the copyright.
#
import
std/[macros],
../[futures, config],
./raisesfutures
proc processBody(node, setResultSym: NimNode): NimNode {.compileTime.} =
case node.kind
of nnkReturnStmt:
# `return ...` -> `setResult(...); return`
let
res = newNimNode(nnkStmtList, node)
if node[0].kind != nnkEmpty:
res.add newCall(setResultSym, processBody(node[0], setResultSym))
res.add newNimNode(nnkReturnStmt, node).add(newEmptyNode())
res
of RoutineNodes-{nnkTemplateDef}:
# Skip nested routines since they have their own return value distinct from
# the Future we inject
node
else:
if node.kind == nnkYieldStmt:
# asyncdispatch allows `yield` but this breaks cancellation
warning(
"`yield` in async procedures not supported - use `awaitne` instead",
node)
for i in 0 ..< node.len:
node[i] = processBody(node[i], setResultSym)
node
proc wrapInTryFinally(
fut, baseType, body, raises: NimNode,
handleException: bool): NimNode {.compileTime.} =
# creates:
# try: `body`
# [for raise in raises]:
# except `raise`: closureSucceeded = false; `castFutureSym`.fail(exc)
# finally:
# if closureSucceeded:
# `castFutureSym`.complete(result)
#
# Calling `complete` inside `finally` ensures that all success paths
# (including early returns and code inside nested finally statements and
# defer) are completed with the final contents of `result`
let
closureSucceeded = genSym(nskVar, "closureSucceeded")
nTry = nnkTryStmt.newTree(body)
excName = ident"exc"
# Depending on the exception type, we must have at most one of each of these
# "special" exception handlers that are needed to implement cancellation and
# Defect propagation
var
hasDefect = false
hasCancelledError = false
hasCatchableError = false
template addDefect =
if not hasDefect:
hasDefect = true
# When a Defect is raised, the program is in an undefined state and
# continuing running other tasks while the Future completion sits on the
# callback queue may lead to further damage so we re-raise them eagerly.
nTry.add nnkExceptBranch.newTree(
nnkInfix.newTree(ident"as", ident"Defect", excName),
nnkStmtList.newTree(
nnkAsgn.newTree(closureSucceeded, ident"false"),
nnkRaiseStmt.newTree(excName)
)
)
template addCancelledError =
if not hasCancelledError:
hasCancelledError = true
nTry.add nnkExceptBranch.newTree(
ident"CancelledError",
nnkStmtList.newTree(
nnkAsgn.newTree(closureSucceeded, ident"false"),
newCall(ident "cancelAndSchedule", fut)
)
)
template addCatchableError =
if not hasCatchableError:
hasCatchableError = true
nTry.add nnkExceptBranch.newTree(
nnkInfix.newTree(ident"as", ident"CatchableError", excName),
nnkStmtList.newTree(
nnkAsgn.newTree(closureSucceeded, ident"false"),
newCall(ident "fail", fut, excName)
))
var raises = if raises == nil:
nnkTupleConstr.newTree(ident"CatchableError")
elif isNoRaises(raises):
nnkTupleConstr.newTree()
else:
raises.copyNimTree()
if handleException:
raises.add(ident"Exception")
for exc in raises:
if exc.eqIdent("Exception"):
addCancelledError
addCatchableError
addDefect
# Because we store `CatchableError` in the Future, we cannot re-raise the
# original exception
nTry.add nnkExceptBranch.newTree(
nnkInfix.newTree(ident"as", ident"Exception", excName),
newCall(ident "fail", fut,
nnkStmtList.newTree(
nnkAsgn.newTree(closureSucceeded, ident"false"),
quote do:
(ref AsyncExceptionError)(
msg: `excName`.msg, parent: `excName`)))
)
elif exc.eqIdent("CancelledError"):
addCancelledError
elif exc.eqIdent("CatchableError"):
# Ensure cancellations are re-routed to the cancellation handler even if
# not explicitly specified in the raises list
addCancelledError
addCatchableError
else:
nTry.add nnkExceptBranch.newTree(
nnkInfix.newTree(ident"as", exc, excName),
nnkStmtList.newTree(
nnkAsgn.newTree(closureSucceeded, ident"false"),
newCall(ident "fail", fut, excName)
))
addDefect # Must not complete future on defect
nTry.add nnkFinally.newTree(
nnkIfStmt.newTree(
nnkElifBranch.newTree(
closureSucceeded,
if baseType.eqIdent("void"): # shortcut for non-generic void
newCall(ident "complete", fut)
else:
nnkWhenStmt.newTree(
nnkElifExpr.newTree(
nnkInfix.newTree(ident "is", baseType, ident "void"),
newCall(ident "complete", fut)
),
nnkElseExpr.newTree(
newCall(ident "complete", fut, newCall(ident "move", ident "result"))
)
)
)
)
)
nnkStmtList.newTree(
newVarStmt(closureSucceeded, ident"true"),
nTry
)
proc getName(node: NimNode): string {.compileTime.} =
case node.kind
of nnkSym:
return node.strVal
of nnkPostfix:
return node[1].strVal
of nnkIdent:
return node.strVal
of nnkEmpty:
return "anonymous"
else:
error("Unknown name.")
macro unsupported(s: static[string]): untyped =
error s
proc params2(someProc: NimNode): NimNode {.compileTime.} =
# until https://github.com/nim-lang/Nim/pull/19563 is available
if someProc.kind == nnkProcTy:
someProc[0]
else:
params(someProc)
proc cleanupOpenSymChoice(node: NimNode): NimNode {.compileTime.} =
# Replace every Call -> OpenSymChoice by a Bracket expr
# ref https://github.com/nim-lang/Nim/issues/11091
if node.kind in nnkCallKinds and
node[0].kind == nnkOpenSymChoice and node[0].eqIdent("[]"):
result = newNimNode(nnkBracketExpr)
for child in node[1..^1]:
result.add(cleanupOpenSymChoice(child))
else:
result = node.copyNimNode()
for child in node:
result.add(cleanupOpenSymChoice(child))
type
AsyncParams = tuple
raw: bool
raises: NimNode
handleException: bool
proc decodeParams(params: NimNode): AsyncParams =
# decodes the parameter tuple given in `async: (name: value, ...)` to its
# recognised parts
params.expectKind(nnkTupleConstr)
var
raw = false
raises: NimNode = nil
handleException = false
hasLocalAnnotations = false
for param in params:
param.expectKind(nnkExprColonExpr)
if param[0].eqIdent("raises"):
hasLocalAnnotations = true
param[1].expectKind(nnkBracket)
if param[1].len == 0:
raises = makeNoRaises()
else:
raises = nnkTupleConstr.newTree()
for possibleRaise in param[1]:
raises.add(possibleRaise)
elif param[0].eqIdent("raw"):
# boolVal doesn't work in untyped macros it seems..
raw = param[1].eqIdent("true")
elif param[0].eqIdent("handleException"):
hasLocalAnnotations = true
handleException = param[1].eqIdent("true")
else:
warning("Unrecognised async parameter: " & repr(param[0]), param)
if not hasLocalAnnotations:
handleException = chronosHandleException
(raw, raises, handleException)
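# Illustration (not part of this file): for a user declaration such as
#
#   proc fetch(): Future[int] {.async: (raises: [ValueError]).} =
#     ...
#
# the pragma arguments arrive here as the tuple `(raises: [ValueError])`, so
# `decodeParams` returns `raw = false`, `raises = (ValueError,)` and
# `handleException = false`: a local `raises:` annotation sets
# `hasLocalAnnotations`, which overrides the global `chronosHandleException`
# default.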
proc isEmpty(n: NimNode): bool {.compileTime.} =
# true iff node recursively contains only comments or empties
case n.kind
of nnkEmpty, nnkCommentStmt: true
of nnkStmtList:
for child in n:
if not isEmpty(child): return false
true
else:
false
proc asyncSingleProc(prc, params: NimNode): NimNode {.compileTime.} =
## This macro transforms a single procedure into a closure iterator.
## The ``async`` macro supports a stmtList holding multiple async procedures.
if prc.kind notin {nnkProcTy, nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}:
error("Cannot transform " & $prc.kind & " into an async proc." &
" proc/method definition or lambda node expected.", prc)
for pragma in prc.pragma():
if pragma.kind == nnkExprColonExpr and pragma[0].eqIdent("raises"):
warning("The raises pragma doesn't work on async procedures - use " &
"`async: (raises: [...]) instead.", prc)
let returnType = cleanupOpenSymChoice(prc.params2[0])
# Verify that the return type is a Future[T]
let baseType =
if returnType.kind == nnkEmpty:
ident "void"
elif not (
returnType.kind == nnkBracketExpr and
(eqIdent(returnType[0], "Future") or eqIdent(returnType[0], "InternalRaisesFuture"))):
error(
"Expected return type of 'Future' got '" & repr(returnType) & "'", prc)
return
else:
returnType[1]
let
# When the base type is known to be void (and not generic), we can simplify
# code generation - however, in the case of generic async procedures it
# could still end up being void, meaning void detection needs to happen
# post-macro-expansion.
baseTypeIsVoid = baseType.eqIdent("void")
(raw, raises, handleException) = decodeParams(params)
internalFutureType =
if baseTypeIsVoid:
newNimNode(nnkBracketExpr, prc).
add(newIdentNode("Future")).
add(baseType)
else:
returnType
internalReturnType = if raises == nil:
internalFutureType
else:
nnkBracketExpr.newTree(
newIdentNode("InternalRaisesFuture"),
baseType,
raises
)
prc.params2[0] = internalReturnType
if prc.kind notin {nnkProcTy, nnkLambda}:
prc.addPragma(newColonExpr(ident "stackTrace", ident "off"))
# The proc itself doesn't raise
prc.addPragma(
nnkExprColonExpr.newTree(newIdentNode("raises"), nnkBracket.newTree()))
# `gcsafe` isn't deduced even though we require async code to be gcsafe
# https://github.com/nim-lang/RFCs/issues/435
prc.addPragma(newIdentNode("gcsafe"))
if raw: # raw async = body is left as-is
if raises != nil and prc.kind notin {nnkProcTy, nnkLambda} and not isEmpty(prc.body):
# Inject `raises` type marker that causes `newFuture` to return a raise-
# tracking future instead of an ordinary future:
#
# type InternalRaisesFutureRaises = `raisesTuple`
# `body`
prc.body = nnkStmtList.newTree(
nnkTypeSection.newTree(
nnkTypeDef.newTree(
nnkPragmaExpr.newTree(
ident"InternalRaisesFutureRaises",
nnkPragma.newTree(ident "used")),
newEmptyNode(),
raises,
)
),
prc.body
)
elif prc.kind in {nnkProcDef, nnkLambda, nnkMethodDef, nnkDo} and
not isEmpty(prc.body):
let
setResultSym = ident "setResult"
procBody = prc.body.processBody(setResultSym)
resultIdent = ident "result"
fakeResult = quote do:
template result: auto {.used.} =
{.fatal: "You should not reference the `result` variable inside" &
" a void async proc".}
resultDecl =
if baseTypeIsVoid: fakeResult
else: nnkWhenStmt.newTree(
# when `baseType` is void:
nnkElifExpr.newTree(
nnkInfix.newTree(ident "is", baseType, ident "void"),
fakeResult
),
# else:
nnkElseExpr.newTree(
newStmtList(
quote do: {.push warning[resultshadowed]: off.},
# var result {.used.}: `baseType`
# In the proc body, result may or may not end up being used
# depending on how the body is written - with implicit returns /
# expressions in particular, it is likely but not guaranteed that
# it is not used. Ideally, we would avoid emitting it in this
# case to avoid the default initialization. {.used.} typically
# works better than {.push.} which has a tendency to leak out of
# scope.
# TODO figure out if there's a way to detect `result` usage in
# the proc body _after_ template expansion, and therefore
# avoid creating this variable - one option is to create an
# additional when branch with a fake `result` and check
# `compiles(procBody)` - this is not without cost though
nnkVarSection.newTree(nnkIdentDefs.newTree(
nnkPragmaExpr.newTree(
resultIdent,
nnkPragma.newTree(ident "used")),
baseType, newEmptyNode())
),
quote do: {.pop.},
)
)
)
# ```nim
# template `setResultSym`(code: untyped) {.used.} =
# when typeof(code) is void: code
# else: `resultIdent` = code
# ```
#
# this is useful to handle implicit returns, but also
# to bind the `result` to the one we declare here
setResultDecl =
if baseTypeIsVoid: # shortcut for non-generic void
newEmptyNode()
else:
nnkTemplateDef.newTree(
setResultSym,
newEmptyNode(), newEmptyNode(),
nnkFormalParams.newTree(
newEmptyNode(),
nnkIdentDefs.newTree(
ident"code",
ident"untyped",
newEmptyNode(),
)
),
nnkPragma.newTree(ident"used"),
newEmptyNode(),
nnkWhenStmt.newTree(
nnkElifBranch.newTree(
nnkInfix.newTree(
ident"is", nnkTypeOfExpr.newTree(ident"code"), ident"void"),
ident"code"
),
nnkElse.newTree(
newAssignment(resultIdent, ident"code")
)
)
)
internalFutureSym = ident "chronosInternalRetFuture"
castFutureSym = nnkCast.newTree(internalFutureType, internalFutureSym)
# Wrapping in try/finally ensures that early returns are handled properly
# and that `defer` is processed in the right scope
completeDecl = wrapInTryFinally(
castFutureSym, baseType,
if baseTypeIsVoid: procBody # shortcut for non-generic `void`
else: newCall(setResultSym, procBody),
raises,
handleException
)
closureBody = newStmtList(resultDecl, setResultDecl, completeDecl)
internalFutureParameter = nnkIdentDefs.newTree(
internalFutureSym, newIdentNode("FutureBase"), newEmptyNode())
prcName = prc.name.getName
iteratorNameSym = genSym(nskIterator, $prcName)
closureIterator = newProc(
iteratorNameSym,
[newIdentNode("FutureBase"), internalFutureParameter],
closureBody, nnkIteratorDef)
iteratorNameSym.copyLineInfo(prc)
closureIterator.pragma = newNimNode(nnkPragma, lineInfoFrom=prc.body)
closureIterator.addPragma(newIdentNode("closure"))
# `async` code must be gcsafe
closureIterator.addPragma(newIdentNode("gcsafe"))
# Exceptions are caught inside the iterator and stored in the future
closureIterator.addPragma(nnkExprColonExpr.newTree(
newIdentNode("raises"),
nnkBracket.newTree()
))
# The body of the original procedure (now moved to the iterator) is replaced
# with:
#
# ```nim
# let resultFuture = newFuture[T]()
# resultFuture.internalClosure = `iteratorNameSym`
# futureContinue(resultFuture)
# return resultFuture
# ```
#
# Declared at the end to be sure that the closure doesn't reference it,
# avoid cyclic ref (#203)
#
# Do not change this code to `quote do` version because `instantiationInfo`
# will be broken for `newFuture()` call.
let
outerProcBody = newNimNode(nnkStmtList, prc.body)
# Copy comment for nimdoc
if prc.body.len > 0 and prc.body[0].kind == nnkCommentStmt:
outerProcBody.add(prc.body[0])
outerProcBody.add(closureIterator)
let
retFutureSym = ident "resultFuture"
newFutProc = if raises == nil:
nnkBracketExpr.newTree(ident "newFuture", baseType)
else:
nnkBracketExpr.newTree(ident "newInternalRaisesFuture", baseType, raises)
retFutureSym.copyLineInfo(prc)
outerProcBody.add(
newLetStmt(
retFutureSym,
newCall(newFutProc, newLit(prcName))
)
)
outerProcBody.add(
newAssignment(
newDotExpr(retFutureSym, newIdentNode("internalClosure")),
iteratorNameSym)
)
outerProcBody.add(
newCall(newIdentNode("futureContinue"), retFutureSym)
)
outerProcBody.add newNimNode(nnkReturnStmt, prc.body[^1]).add(retFutureSym)
prc.body = outerProcBody
when chronosDumpAsync:
echo repr prc
prc
template await*[T](f: Future[T]): T =
## Ensure that the given `Future` is finished, then return its value.
##
## If the `Future` failed or was cancelled, the corresponding exception will
## be raised instead.
##
## If the `Future` is pending, execution of the current `async` procedure
## will be suspended until the `Future` is finished.
when declared(chronosInternalRetFuture):
chronosInternalRetFuture.internalChild = f
# `futureContinue` calls the iterator generated by the `async`
# transformation - `yield` gives control back to `futureContinue` which is
# responsible for resuming execution once the yielded future is finished
yield chronosInternalRetFuture.internalChild
# `child` released by `futureContinue`
cast[type(f)](chronosInternalRetFuture.internalChild).internalRaiseIfError(f)
when T isnot void:
cast[type(f)](chronosInternalRetFuture.internalChild).value()
else:
unsupported "await is only available within {.async.}"
template await*[T, E](fut: InternalRaisesFuture[T, E]): T =
## Ensure that the given `Future` is finished, then return its value.
##
## If the `Future` failed or was cancelled, the corresponding exception will
## be raised instead.
##
## If the `Future` is pending, execution of the current `async` procedure
## will be suspended until the `Future` is finished.
when declared(chronosInternalRetFuture):
chronosInternalRetFuture.internalChild = fut
# `futureContinue` calls the iterator generated by the `async`
# transformation - `yield` gives control back to `futureContinue` which is
# responsible for resuming execution once the yielded future is finished
yield chronosInternalRetFuture.internalChild
# `child` released by `futureContinue`
cast[type(fut)](
chronosInternalRetFuture.internalChild).internalRaiseIfError(E, fut)
when T isnot void:
cast[type(fut)](chronosInternalRetFuture.internalChild).value()
else:
unsupported "await is only available within {.async.}"
template awaitne*[T](f: Future[T]): Future[T] =
when declared(chronosInternalRetFuture):
chronosInternalRetFuture.internalChild = f
yield chronosInternalRetFuture.internalChild
cast[type(f)](chronosInternalRetFuture.internalChild)
else:
unsupported "awaitne is only available within {.async.}"
macro async*(params, prc: untyped): untyped =
## Macro which processes async procedures into the appropriate
## iterators and yield statements.
if prc.kind == nnkStmtList:
result = newStmtList()
for oneProc in prc:
result.add asyncSingleProc(oneProc, params)
else:
result = asyncSingleProc(prc, params)
macro async*(prc: untyped): untyped =
## Macro which processes async procedures into the appropriate
## iterators and yield statements.
if prc.kind == nnkStmtList:
result = newStmtList()
for oneProc in prc:
result.add asyncSingleProc(oneProc, nnkTupleConstr.newTree())
else:
result = asyncSingleProc(prc, nnkTupleConstr.newTree())
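Taken together, the pieces above form the user-facing surface: the parameterless `async` macro wraps the body into a closure iterator returning a plain `Future[T]`, while the parameterised form feeds `(raises: ...)` and friends through `decodeParams` and produces an exception-tracking future. A small sketch of both forms (the proc names are invented for illustration):

```nim
import chronos

proc plainDouble(x: int): Future[int] {.async.} =
  await sleepAsync(10.milliseconds)
  return x * 2

# The raises list is checked at compile time; only CancelledError may escape.
proc checkedDouble(x: int): Future[int] {.async: (raises: [CancelledError]).} =
  await sleepAsync(10.milliseconds)
  return x * 2

echo waitFor plainDouble(21)    # 42
echo waitFor checkedDouble(21)  # 42
```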


@ -0,0 +1,9 @@
type
AsyncError* = object of CatchableError
## Generic async exception
AsyncTimeoutError* = object of AsyncError
## Timeout exception
AsyncExceptionError* = object of AsyncError
## Error raised in `handleException` mode - the original exception is
## available from the `parent` field.
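Since `AsyncTimeoutError` derives from `AsyncError`, callers can catch either the specific or the generic type. A minimal sketch using `wait()` with a deadline, which raises `AsyncTimeoutError` when the timeout expires (`slowOp` is hypothetical):

```nim
import chronos

proc slowOp(): Future[int] {.async.} =
  await sleepAsync(2.seconds)
  return 1

proc demo() {.async.} =
  try:
    discard await slowOp().wait(100.milliseconds)
  except AsyncTimeoutError:
    # `except AsyncError` would match as well, via the base type.
    echo "timed out"

waitFor demo()
```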


@ -0,0 +1,248 @@
import
std/[macros, sequtils],
../futures
{.push raises: [].}
type
InternalRaisesFuture*[T, E] = ref object of Future[T]
## Future with a tuple of possible exception types
## eg InternalRaisesFuture[void, (ValueError, OSError)]
##
## This type gets injected by `async: (raises: ...)` and similar utilities
## and should not be used manually as the internal exception representation
## is subject to change in future chronos versions.
# TODO https://github.com/nim-lang/Nim/issues/23418
# TODO https://github.com/nim-lang/Nim/issues/23419
when E is void:
dummy: E
else:
dummy: array[0, E]
proc makeNoRaises*(): NimNode {.compileTime.} =
# An empty tuple would have been easier but...
# https://github.com/nim-lang/Nim/issues/22863
# https://github.com/nim-lang/Nim/issues/22865
ident"void"
proc dig(n: NimNode): NimNode {.compileTime.} =
# Dig through the layers of type to find the raises list
if n.eqIdent("void"):
n
elif n.kind == nnkBracketExpr:
if n[0].eqIdent("tuple"):
n
elif n[0].eqIdent("typeDesc"):
dig(getType(n[1]))
else:
echo astGenRepr(n)
raiseAssert "Unkown bracket"
elif n.kind == nnkTupleConstr:
n
else:
dig(getType(getTypeInst(n)))
proc isNoRaises*(n: NimNode): bool {.compileTime.} =
dig(n).eqIdent("void")
iterator members(tup: NimNode): NimNode =
# Given a typedesc[tuple] = (A, B, C), yields the tuple members (A, B, C)
if not isNoRaises(tup):
for n in getType(getTypeInst(tup)[1])[1..^1]:
yield n
proc members(tup: NimNode): seq[NimNode] {.compileTime.} =
for t in tup.members():
result.add(t)
macro hasException(raises: typedesc, ident: static string): bool =
newLit(raises.members.anyIt(it.eqIdent(ident)))
macro Raising*[T](F: typedesc[Future[T]], E: typed): untyped =
## Given a Future type instance, return a type storing `{.raises.}`
## information
##
## Note: this type may change in the future
# An earlier version used `E: varargs[typedesc]` here but this is buggy/no
# longer supported in 2.0 in certain cases:
# https://github.com/nim-lang/Nim/issues/23432
let
e =
case E.getTypeInst().typeKind()
of ntyTypeDesc: @[E]
of ntyArray:
for x in E:
if x.getTypeInst().typeKind != ntyTypeDesc:
error("Expected typedesc, got " & repr(x), x)
E.mapIt(it)
else:
error("Expected typedesc, got " & repr(E), E)
@[]
let raises = if e.len == 0:
makeNoRaises()
else:
nnkTupleConstr.newTree(e)
nnkBracketExpr.newTree(
ident "InternalRaisesFuture",
nnkDotExpr.newTree(F, ident"T"),
raises
)
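In practice, `Raising` is mostly useful for spelling the type of a stored or returned future without naming `InternalRaisesFuture` directly, the same way the stream code further down declares its `future` fields. A minimal sketch:

```nim
import chronos

proc work(): Future[void] {.async: (raises: [CancelledError]).} =
  await sleepAsync(1.milliseconds)

# Expands to the matching InternalRaisesFuture instantiation.
let fut: Future[void].Raising([CancelledError]) = work()
waitFor fut
```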
template init*[T, E](
F: type InternalRaisesFuture[T, E], fromProc: static[string] = ""): F =
## Creates a new pending future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
when not hasException(type(E), "CancelledError"):
static:
raiseAssert "Manually created futures must either own cancellation schedule or raise CancelledError"
let res = F()
internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending, {})
res
template init*[T, E](
F: type InternalRaisesFuture[T, E], fromProc: static[string] = "",
flags: static[FutureFlags]): F =
## Creates a new pending future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
let res = F()
when not hasException(type(E), "CancelledError"):
static:
doAssert FutureFlag.OwnCancelSchedule in flags,
"Manually created futures must either own cancellation schedule or raise CancelledError"
internalInitFutureBase(
res, getSrcLocation(fromProc), FutureState.Pending, flags)
res
proc containsSignature(members: openArray[NimNode], typ: NimNode): bool {.compileTime.} =
let typHash = signatureHash(typ)
for err in members:
if signatureHash(err) == typHash:
return true
false
# Utilities for working with the E part of InternalRaisesFuture - unstable
macro prepend*(tup: typedesc, typs: varargs[typed]): typedesc =
result = nnkTupleConstr.newTree()
for err in typs:
if not tup.members().containsSignature(err):
result.add err
for err in tup.members():
result.add err
if result.len == 0:
result = makeNoRaises()
macro remove*(tup: typedesc, typs: varargs[typed]): typedesc =
result = nnkTupleConstr.newTree()
for err in tup.members():
if not typs[0..^1].containsSignature(err):
result.add err
if result.len == 0:
result = makeNoRaises()
macro union*(tup0: typedesc, tup1: typedesc): typedesc =
## Join the types of the two tuples deduplicating the entries
result = nnkTupleConstr.newTree()
for err in tup0.members():
var found = false
for err2 in tup1.members():
if signatureHash(err) == signatureHash(err2):
found = true
if not found:
result.add err
for err2 in tup1.members():
result.add err2
if result.len == 0:
result = makeNoRaises()
proc getRaisesTypes*(raises: NimNode): NimNode =
let typ = getType(raises)
case typ.typeKind
of ntyTypeDesc: typ[1]
else: typ
macro checkRaises*[T: CatchableError](
future: InternalRaisesFuture, raises: typed, error: ref T,
warn: static bool = true): untyped =
## Generate code that checks that the given error is compatible with the
## raises restrictions of `future`.
##
## This check is done either at compile time or runtime depending on the
## information available at compile time - in particular, if the raises
## inherit from `error`, we end up with the equivalent of a downcast which
## raises a Defect if it fails.
let
raises = getRaisesTypes(raises)
expectKind(getTypeInst(error), nnkRefTy)
let toMatch = getTypeInst(error)[0]
if isNoRaises(raises):
error(
"`fail`: `" & repr(toMatch) & "` incompatible with `raises: []`", future)
return
var
typeChecker = ident"false"
maybeChecker = ident"false"
runtimeChecker = ident"false"
for errorType in raises[1..^1]:
typeChecker = infix(typeChecker, "or", infix(toMatch, "is", errorType))
maybeChecker = infix(maybeChecker, "or", infix(errorType, "is", toMatch))
runtimeChecker = infix(
runtimeChecker, "or",
infix(error, "of", nnkBracketExpr.newTree(ident"typedesc", errorType)))
let
errorMsg = "`fail`: `" & repr(toMatch) & "` incompatible with `raises: " & repr(raises[1..^1]) & "`"
warningMsg = "Can't verify `fail` exception type at compile time - expected one of " & repr(raises[1..^1]) & ", got `" & repr(toMatch) & "`"
# A warning from this line means the exception type will be verified at runtime
warning = if warn:
quote do: {.warning: `warningMsg`.}
else: newEmptyNode()
# Cannot check inheritance in a macro, so we let `static` do the heavy lifting
quote do:
when not(`typeChecker`):
when not(`maybeChecker`):
static:
{.error: `errorMsg`.}
else:
`warning`
assert(`runtimeChecker`, `errorMsg`)
func failed*[T](future: InternalRaisesFuture[T, void]): bool {.inline.} =
## Determines whether ``future`` finished with an error.
static:
warning("No exceptions possible with this operation, `failed` always returns false")
false
func error*[T](future: InternalRaisesFuture[T, void]): ref CatchableError {.
raises: [].} =
static:
warning("No exceptions possible with this operation, `error` always returns nil")
nil
func readError*[T](future: InternalRaisesFuture[T, void]): ref CatchableError {.
raises: [ValueError].} =
static:
warning("No exceptions possible with this operation, `readError` always raises")
raise newException(ValueError, "No error in future.")
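These overloads exist so that querying a future whose `raises` list is empty degrades into a compile-time warning rather than a dead runtime check. A sketch, assuming a proc annotated with `async: (raises: [])` (the same annotation the stream loops below use), whose future therefore cannot fail:

```nim
import chronos

proc pureCompute(): Future[int] {.async: (raises: []).} =
  # No awaits on cancellable futures, so nothing in the body can raise.
  return 21 * 2

let fut = pureCompute()
# Resolves to the InternalRaisesFuture[T, void] overload above: warns at
# compile time and always evaluates to false.
doAssert not fut.failed()
echo waitFor fut  # 42
```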


@ -97,12 +97,12 @@ proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] =
var nmask: Sigset
if sigemptyset(nmask) < 0:
return err(osLastError())
- let epollFd = epoll_create(asyncEventsCount)
+ let epollFd = epoll_create(chronosEventsCount)
if epollFd < 0:
return err(osLastError())
let selector = Selector[T](
epollFd: epollFd,
- fds: initTable[int32, SelectorKey[T]](asyncInitialSize),
+ fds: initTable[int32, SelectorKey[T]](chronosInitialSize),
signalMask: nmask,
virtualId: -1'i32, # Should start with -1, because `InvalidIdent` == -1
childrenExited: false,
@ -411,7 +411,7 @@ proc registerProcess*[T](s: Selector, pid: int, data: T): SelectResult[cint] =
s.freeKey(fdi32)
s.freeProcess(int32(pid))
return err(res.error())
- s.pidFd = Opt.some(cast[cint](res.get()))
+ s.pidFd = Opt.some(res.get())
ok(cint(fdi32))
@ -627,7 +627,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
readyKeys: var openArray[ReadyKey]
): SelectResult[int] =
var
- queueEvents: array[asyncEventsCount, EpollEvent]
+ queueEvents: array[chronosEventsCount, EpollEvent]
k: int = 0
verifySelectParams(timeout, -1, int(high(cint)))
@ -668,7 +668,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
ok(k)
proc select2*[T](s: Selector[T], timeout: int): SelectResult[seq[ReadyKey]] =
- var res = newSeq[ReadyKey](asyncEventsCount)
+ var res = newSeq[ReadyKey](chronosEventsCount)
let count = ? selectInto2(s, timeout, res)
res.setLen(count)
ok(res)


@ -110,7 +110,7 @@ proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] =
let selector = Selector[T](
kqFd: kqFd,
- fds: initTable[int32, SelectorKey[T]](asyncInitialSize),
+ fds: initTable[int32, SelectorKey[T]](chronosInitialSize),
virtualId: -1'i32, # Should start with -1, because `InvalidIdent` == -1
virtualHoles: initDeque[int32]()
)
@ -559,7 +559,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
): SelectResult[int] =
var
tv: Timespec
- queueEvents: array[asyncEventsCount, KEvent]
+ queueEvents: array[chronosEventsCount, KEvent]
verifySelectParams(timeout, -1, high(int))
@ -575,7 +575,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
addr tv
else:
nil
- maxEventsCount = cint(min(asyncEventsCount, len(readyKeys)))
+ maxEventsCount = cint(min(chronosEventsCount, len(readyKeys)))
eventsCount =
block:
var res = 0
@ -601,7 +601,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
proc select2*[T](s: Selector[T],
timeout: int): Result[seq[ReadyKey], OSErrorCode] =
- var res = newSeq[ReadyKey](asyncEventsCount)
+ var res = newSeq[ReadyKey](chronosEventsCount)
let count = ? selectInto2(s, timeout, res)
res.setLen(count)
ok(res)


@ -16,7 +16,7 @@ import stew/base10
type
SelectorImpl[T] = object
fds: Table[int32, SelectorKey[T]]
- pollfds: seq[TPollFd]
+ pollfds: seq[TPollfd]
Selector*[T] = ref SelectorImpl[T]
type
@ -50,7 +50,7 @@ proc freeKey[T](s: Selector[T], key: int32) =
proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] =
let selector = Selector[T](
- fds: initTable[int32, SelectorKey[T]](asyncInitialSize)
+ fds: initTable[int32, SelectorKey[T]](chronosInitialSize)
)
ok(selector)
@ -72,7 +72,7 @@ proc trigger2*(event: SelectEvent): SelectResult[void] =
if res == -1:
err(osLastError())
elif res != sizeof(uint64):
- err(OSErrorCode(osdefs.EINVAL))
+ err(osdefs.EINVAL)
else:
ok()
@ -98,13 +98,14 @@ template toPollEvents(events: set[Event]): cshort =
res
template pollAdd[T](s: Selector[T], sock: cint, events: set[Event]) =
- s.pollfds.add(TPollFd(fd: sock, events: toPollEvents(events), revents: 0))
+ s.pollfds.add(TPollfd(fd: sock, events: toPollEvents(events), revents: 0))
template pollUpdate[T](s: Selector[T], sock: cint, events: set[Event]) =
var updated = false
for mitem in s.pollfds.mitems():
if mitem.fd == sock:
mitem.events = toPollEvents(events)
+ updated = true
break
if not(updated):
raiseAssert "Descriptor [" & $sock & "] is not registered in the queue!"
@ -177,7 +178,6 @@ proc unregister2*[T](s: Selector[T], event: SelectEvent): SelectResult[void] =
proc prepareKey[T](s: Selector[T], event: var TPollfd): Opt[ReadyKey] =
let
- defaultKey = SelectorKey[T](ident: InvalidIdent)
fdi32 = int32(event.fd)
revents = event.revents
@ -220,11 +220,14 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
verifySelectParams(timeout, -1, int(high(cint)))
let
- maxEventsCount = min(len(s.pollfds), len(readyKeys))
+ maxEventsCount = culong(min(len(s.pollfds), len(readyKeys)))
+ # Without `culong` conversion, this code could fail with RangeError
+ # defect on explicit Tnfds(integer) conversion (probably related to
+ # combination of nim+clang (android toolchain)).
eventsCount =
if maxEventsCount > 0:
let res = handleEintr(poll(addr(s.pollfds[0]), Tnfds(maxEventsCount),
- timeout))
+ cint(timeout)))
if res < 0:
return err(osLastError())
res
@ -241,7 +244,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
ok(k)
proc select2*[T](s: Selector[T], timeout: int): SelectResult[seq[ReadyKey]] =
- var res = newSeq[ReadyKey](asyncEventsCount)
+ var res = newSeq[ReadyKey](chronosEventsCount)
let count = ? selectInto2(s, timeout, res)
res.setLen(count)
ok(res)


@ -122,6 +122,7 @@ when defined(windows):
SO_UPDATE_ACCEPT_CONTEXT* = 0x700B SO_UPDATE_ACCEPT_CONTEXT* = 0x700B
SO_CONNECT_TIME* = 0x700C SO_CONNECT_TIME* = 0x700C
SO_UPDATE_CONNECT_CONTEXT* = 0x7010 SO_UPDATE_CONNECT_CONTEXT* = 0x7010
SO_PROTOCOL_INFOW* = 0x2005
FILE_FLAG_FIRST_PIPE_INSTANCE* = 0x00080000'u32 FILE_FLAG_FIRST_PIPE_INSTANCE* = 0x00080000'u32
FILE_FLAG_OPEN_NO_RECALL* = 0x00100000'u32 FILE_FLAG_OPEN_NO_RECALL* = 0x00100000'u32
@ -258,6 +259,9 @@ when defined(windows):
FIONBIO* = WSAIOW(102, 126) FIONBIO* = WSAIOW(102, 126)
HANDLE_FLAG_INHERIT* = 1'u32 HANDLE_FLAG_INHERIT* = 1'u32
IPV6_V6ONLY* = 27
MAX_PROTOCOL_CHAIN* = 7
WSAPROTOCOL_LEN* = 255
type type
LONG* = int32 LONG* = int32
@ -441,6 +445,32 @@ when defined(windows):
prefix*: SOCKADDR_INET prefix*: SOCKADDR_INET
prefixLength*: uint8 prefixLength*: uint8
WSAPROTOCOLCHAIN* {.final, pure.} = object
chainLen*: int32
chainEntries*: array[MAX_PROTOCOL_CHAIN, DWORD]
WSAPROTOCOL_INFO* {.final, pure.} = object
dwServiceFlags1*: uint32
dwServiceFlags2*: uint32
dwServiceFlags3*: uint32
dwServiceFlags4*: uint32
dwProviderFlags*: uint32
providerId*: GUID
dwCatalogEntryId*: DWORD
protocolChain*: WSAPROTOCOLCHAIN
iVersion*: int32
iAddressFamily*: int32
iMaxSockAddr*: int32
iMinSockAddr*: int32
iSocketType*: int32
iProtocol*: int32
iProtocolMaxOffset*: int32
iNetworkByteOrder*: int32
iSecurityScheme*: int32
dwMessageSize*: uint32
dwProviderReserved*: uint32
szProtocol*: array[WSAPROTOCOL_LEN + 1, WCHAR]
MibIpForwardRow2* {.final, pure.} = object MibIpForwardRow2* {.final, pure.} = object
interfaceLuid*: uint64 interfaceLuid*: uint64
interfaceIndex*: uint32 interfaceIndex*: uint32
@ -708,7 +738,7 @@ when defined(windows):
res: var ptr AddrInfo): cint {. res: var ptr AddrInfo): cint {.
stdcall, dynlib: "ws2_32", importc: "getaddrinfo", sideEffect.} stdcall, dynlib: "ws2_32", importc: "getaddrinfo", sideEffect.}
proc freeaddrinfo*(ai: ptr AddrInfo) {. proc freeAddrInfo*(ai: ptr AddrInfo) {.
stdcall, dynlib: "ws2_32", importc: "freeaddrinfo", sideEffect.} stdcall, dynlib: "ws2_32", importc: "freeaddrinfo", sideEffect.}
proc createIoCompletionPort*(fileHandle: HANDLE, proc createIoCompletionPort*(fileHandle: HANDLE,
@ -880,7 +910,7 @@ elif defined(macos) or defined(macosx):
sigemptyset, sigaddset, sigismember, fcntl, accept, sigemptyset, sigaddset, sigismember, fcntl, accept,
pipe, write, signal, read, setsockopt, getsockopt, pipe, write, signal, read, setsockopt, getsockopt,
getcwd, chdir, waitpid, kill, select, pselect, getcwd, chdir, waitpid, kill, select, pselect,
socketpair, socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
@ -890,7 +920,7 @@ elif defined(macos) or defined(macosx):
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPPROTO_IPV6,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR, SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
@ -905,7 +935,7 @@ elif defined(macos) or defined(macosx):
sigemptyset, sigaddset, sigismember, fcntl, accept, sigemptyset, sigaddset, sigismember, fcntl, accept,
pipe, write, signal, read, setsockopt, getsockopt, pipe, write, signal, read, setsockopt, getsockopt,
getcwd, chdir, waitpid, kill, select, pselect, getcwd, chdir, waitpid, kill, select, pselect,
socketpair, socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
@ -915,7 +945,7 @@ elif defined(macos) or defined(macosx):
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPPROTO_IPV6,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR, SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
@ -929,6 +959,21 @@ elif defined(macos) or defined(macosx):
numer*: uint32 numer*: uint32
denom*: uint32 denom*: uint32
TPollfd* {.importc: "struct pollfd", pure, final,
header: "<poll.h>".} = object
fd*: cint
events*: cshort
revents*: cshort
Tnfds* {.importc: "nfds_t", header: "<poll.h>".} = culong
const
POLLIN* = 0x0001
POLLOUT* = 0x0004
POLLERR* = 0x0008
POLLHUP* = 0x0010
POLLNVAL* = 0x0020
proc posix_gettimeofday*(tp: var Timeval, unused: pointer = nil) {. proc posix_gettimeofday*(tp: var Timeval, unused: pointer = nil) {.
importc: "gettimeofday", header: "<sys/time.h>".} importc: "gettimeofday", header: "<sys/time.h>".}
@ -938,6 +983,9 @@ elif defined(macos) or defined(macosx):
proc mach_absolute_time*(): uint64 {. proc mach_absolute_time*(): uint64 {.
importc, header: "<mach/mach_time.h>".} importc, header: "<mach/mach_time.h>".}
proc poll*(a1: ptr TPollfd, a2: Tnfds, a3: cint): cint {.
importc, header: "<poll.h>", sideEffect.}
elif defined(linux): elif defined(linux):
from std/posix import close, shutdown, sigemptyset, sigaddset, sigismember, from std/posix import close, shutdown, sigemptyset, sigaddset, sigismember,
sigdelset, write, read, waitid, getaddrinfo, sigdelset, write, read, waitid, getaddrinfo,
@ -947,20 +995,22 @@ elif defined(linux):
unlink, listen, sendmsg, recvmsg, getpid, fcntl, unlink, listen, sendmsg, recvmsg, getpid, fcntl,
pthread_sigmask, sigprocmask, clock_gettime, signal, pthread_sigmask, sigprocmask, clock_gettime, signal,
getcwd, chdir, waitpid, kill, select, pselect, getcwd, chdir, waitpid, kill, select, pselect,
socketpair, socketpair, poll, freeAddrInfo,
ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode, ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode,
SigInfo, Id, Tmsghdr, IOVec, RLimit, Timeval, TFdSet, SigInfo, Id, Tmsghdr, IOVec, RLimit, Timeval, TFdSet,
SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in,
Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle, Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle,
Suseconds, Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO, FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD, CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD,
FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK, FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK,
SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL, SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL,
MSG_PEEK, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT, AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT,
SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, SO_BROADCAST, IPPROTO_IP, IPPROTO_IPV6,
IPV6_MULTICAST_HOPS,
SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR, SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -974,20 +1024,21 @@ elif defined(linux):
unlink, listen, sendmsg, recvmsg, getpid, fcntl, unlink, listen, sendmsg, recvmsg, getpid, fcntl,
pthread_sigmask, sigprocmask, clock_gettime, signal, pthread_sigmask, sigprocmask, clock_gettime, signal,
getcwd, chdir, waitpid, kill, select, pselect, getcwd, chdir, waitpid, kill, select, pselect,
socketpair, socketpair, poll, freeAddrInfo,
ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode, ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode,
SigInfo, Id, Tmsghdr, IOVec, RLimit, TFdSet, Timeval, SigInfo, Id, Tmsghdr, IOVec, RLimit, TFdSet, Timeval,
SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in,
Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle, Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle,
Suseconds, Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO, FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD, CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD,
FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK, FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK,
SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL, SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL,
MSG_PEEK, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT, AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT,
SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, SO_BROADCAST, IPPROTO_IP, IPPROTO_IPV6, IPV6_MULTICAST_HOPS,
SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR, SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -1097,20 +1148,21 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
sigaddset, sigismember, fcntl, accept, pipe, write, sigaddset, sigismember, fcntl, accept, pipe, write,
signal, read, setsockopt, getsockopt, clock_gettime, signal, read, setsockopt, getsockopt, clock_gettime,
getcwd, chdir, waitpid, kill, select, pselect, getcwd, chdir, waitpid, kill, select, pselect,
socketpair, socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
Suseconds, Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO, FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPPROTO_IPV6,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC, SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC,
SHUT_RD, SHUT_WR, SHUT_RDWR, SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -1123,20 +1175,21 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
sigaddset, sigismember, fcntl, accept, pipe, write, sigaddset, sigismember, fcntl, accept, pipe, write,
signal, read, setsockopt, getsockopt, clock_gettime, signal, read, setsockopt, getsockopt, clock_gettime,
getcwd, chdir, waitpid, kill, select, pselect, getcwd, chdir, waitpid, kill, select, pselect,
socketpair, socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
Suseconds, Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO, FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPPROTO_IPV6,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC, SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC,
SHUT_RD, SHUT_WR, SHUT_RDWR, SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -1160,47 +1213,52 @@ when defined(linux):
SOCK_CLOEXEC* = 0x80000 SOCK_CLOEXEC* = 0x80000
TCP_NODELAY* = cint(1) TCP_NODELAY* = cint(1)
IPPROTO_TCP* = 6 IPPROTO_TCP* = 6
elif defined(freebsd) or defined(netbsd) or defined(dragonfly): O_CLOEXEC* = 0x80000
POSIX_SPAWN_USEVFORK* = 0x40
IPV6_V6ONLY* = 26
elif defined(freebsd):
const const
SOCK_NONBLOCK* = 0x20000000 SOCK_NONBLOCK* = 0x20000000
SOCK_CLOEXEC* = 0x10000000 SOCK_CLOEXEC* = 0x10000000
TCP_NODELAY* = cint(1) TCP_NODELAY* = cint(1)
IPPROTO_TCP* = 6 IPPROTO_TCP* = 6
O_CLOEXEC* = 0x00100000
POSIX_SPAWN_USEVFORK* = 0x00
IPV6_V6ONLY* = 27
elif defined(netbsd):
const
SOCK_NONBLOCK* = 0x20000000
SOCK_CLOEXEC* = 0x10000000
TCP_NODELAY* = cint(1)
IPPROTO_TCP* = 6
O_CLOEXEC* = 0x00400000
POSIX_SPAWN_USEVFORK* = 0x00
IPV6_V6ONLY* = 27
elif defined(dragonfly):
const
SOCK_NONBLOCK* = 0x20000000
SOCK_CLOEXEC* = 0x10000000
TCP_NODELAY* = cint(1)
IPPROTO_TCP* = 6
O_CLOEXEC* = 0x00020000
POSIX_SPAWN_USEVFORK* = 0x00
IPV6_V6ONLY* = 27
elif defined(openbsd): elif defined(openbsd):
const const
SOCK_CLOEXEC* = 0x8000 SOCK_CLOEXEC* = 0x8000
SOCK_NONBLOCK* = 0x4000 SOCK_NONBLOCK* = 0x4000
TCP_NODELAY* = cint(1) TCP_NODELAY* = cint(1)
IPPROTO_TCP* = 6 IPPROTO_TCP* = 6
O_CLOEXEC* = 0x10000
POSIX_SPAWN_USEVFORK* = 0x00
IPV6_V6ONLY* = 27
elif defined(macos) or defined(macosx): elif defined(macos) or defined(macosx):
const const
TCP_NODELAY* = cint(1) TCP_NODELAY* = cint(1)
IP_MULTICAST_TTL* = cint(10) IP_MULTICAST_TTL* = cint(10)
IPPROTO_TCP* = 6 IPPROTO_TCP* = 6
when defined(linux):
const
O_CLOEXEC* = 0x80000
POSIX_SPAWN_USEVFORK* = 0x40
elif defined(freebsd):
const
O_CLOEXEC* = 0x00100000
POSIX_SPAWN_USEVFORK* = 0x00
elif defined(openbsd):
const
O_CLOEXEC* = 0x10000
POSIX_SPAWN_USEVFORK* = 0x00
elif defined(netbsd):
const
O_CLOEXEC* = 0x00400000
POSIX_SPAWN_USEVFORK* = 0x00
elif defined(dragonfly):
const
O_CLOEXEC* = 0x00020000
POSIX_SPAWN_USEVFORK* = 0x00
elif defined(macos) or defined(macosx):
const
POSIX_SPAWN_USEVFORK* = 0x00 POSIX_SPAWN_USEVFORK* = 0x00
IPV6_V6ONLY* = 27
when defined(linux) or defined(macos) or defined(macosx) or defined(freebsd) or when defined(linux) or defined(macos) or defined(macosx) or defined(freebsd) or
defined(openbsd) or defined(netbsd) or defined(dragonfly): defined(openbsd) or defined(netbsd) or defined(dragonfly):
@ -1468,6 +1526,8 @@ when defined(posix):
INVALID_HANDLE_VALUE* = cint(-1) INVALID_HANDLE_VALUE* = cint(-1)
proc `==`*(x: SocketHandle, y: int): bool = int(x) == y proc `==`*(x: SocketHandle, y: int): bool = int(x) == y
when defined(nimdoc):
proc `==`*(x: SocketHandle, y: SocketHandle): bool {.borrow.}
when defined(macosx) or defined(macos) or defined(bsd): when defined(macosx) or defined(macos) or defined(bsd):
const const
@ -1595,6 +1655,8 @@ elif defined(linux):
# RTA_PRIORITY* = 6'u16 # RTA_PRIORITY* = 6'u16
RTA_PREFSRC* = 7'u16 RTA_PREFSRC* = 7'u16
# RTA_METRICS* = 8'u16 # RTA_METRICS* = 8'u16
RTM_NEWLINK* = 16'u16
RTM_NEWROUTE* = 24'u16
RTM_F_LOOKUP_TABLE* = 0x1000 RTM_F_LOOKUP_TABLE* = 0x1000


@ -1328,6 +1328,7 @@ elif defined(windows):
ERROR_CONNECTION_REFUSED* = OSErrorCode(1225)
ERROR_CONNECTION_ABORTED* = OSErrorCode(1236)
WSAEMFILE* = OSErrorCode(10024)
+ WSAEAFNOSUPPORT* = OSErrorCode(10047)
WSAENETDOWN* = OSErrorCode(10050)
WSAENETRESET* = OSErrorCode(10052)
WSAECONNABORTED* = OSErrorCode(10053)


@ -6,8 +6,8 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
- import stew/results
- import osdefs, oserrno
+ import results
+ import "."/[osdefs, oserrno]
export results
@ -346,6 +346,10 @@
return err(osLastError())
ok()
+ proc setDescriptorBlocking*(s: SocketHandle,
+ value: bool): Result[void, OSErrorCode] =
+ setDescriptorBlocking(cint(s), value)
proc setDescriptorInheritance*(s: cint,
value: bool): Result[void, OSErrorCode] =
let flags = handleEintr(osdefs.fcntl(s, osdefs.F_GETFD))


@ -88,8 +88,8 @@ proc worker(bucket: TokenBucket) {.async.} =
#buckets
sleeper = sleepAsync(milliseconds(timeToTarget))
await sleeper or eventWaiter
- sleeper.cancel()
- eventWaiter.cancel()
+ sleeper.cancelSoon()
+ eventWaiter.cancelSoon()
else:
await eventWaiter
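The switch from `cancel()` to `cancelSoon()` tracks the v4 cancellation API: `cancelSoon()` only requests cancellation and returns immediately, while `cancelAndWait()` is the awaitable variant for callers that must observe the cancellation completing. A minimal sketch of the two entry points:

```nim
import chronos

proc demo() {.async.} =
  let a = sleepAsync(10.seconds)
  a.cancelSoon()            # request cancellation and move on
  let b = sleepAsync(10.seconds)
  await b.cancelAndWait()   # suspend until `b` has actually been cancelled

waitFor demo()
```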


@ -31,32 +31,11 @@
# support - changes could potentially be backported to nim but are not
# backwards-compatible.
- import stew/results
- import osdefs, osutils, oserrno
+ import results
+ import "."/[config, osdefs, osutils, oserrno]
export results, oserrno
- const
- asyncEventsCount* {.intdefine.} = 64
- ## Number of epoll events retrieved by syscall.
- asyncInitialSize* {.intdefine.} = 64
- ## Initial size of Selector[T]'s array of file descriptors.
- asyncEventEngine* {.strdefine.} =
- when defined(linux):
- "epoll"
- elif defined(macosx) or defined(macos) or defined(ios) or
- defined(freebsd) or defined(netbsd) or defined(openbsd) or
- defined(dragonfly):
- "kqueue"
- elif defined(posix):
- "poll"
- else:
- ""
- ## Engine type which is going to be used by module.
- hasThreadSupport = compileOption("threads")
when defined(nimdoc):
type
Selector*[T] = ref object
## An object which holds descriptors to be checked for read/write status
@ -281,7 +260,9 @@ else:
var err = newException(IOSelectorsException, msg)
raise err
- when asyncEventEngine in ["epoll", "kqueue"]:
+ when chronosEventEngine in ["epoll", "kqueue"]:
+ const hasThreadSupport = compileOption("threads")
proc blockSignals(newmask: Sigset,
oldmask: var Sigset): Result[void, OSErrorCode] =
var nmask = newmask
@ -324,11 +305,11 @@ else:
doAssert((timeout >= min) and (timeout <= max),
"Cannot select with incorrect timeout value, got " & $timeout)
- when asyncEventEngine == "epoll":
+ when chronosEventEngine == "epoll":
include ./ioselects/ioselectors_epoll
- elif asyncEventEngine == "kqueue":
+ elif chronosEventEngine == "kqueue":
include ./ioselects/ioselectors_kqueue
- elif asyncEventEngine == "poll":
+ elif chronosEventEngine == "poll":
include ./ioselects/ioselectors_poll
else:
- {.fatal: "Event engine `" & asyncEventEngine & "` is not supported!".}
+ {.fatal: "Event engine `" & chronosEventEngine & "` is not supported!".}


@ -38,8 +38,12 @@ when defined(nimdoc):
## be prepared to retry the call if there were unsent bytes.
##
## On error, ``-1`` is returned.
- elif defined(linux) or defined(android):
+ elif defined(emscripten):
+ proc sendfile*(outfd, infd: int, offset: int, count: var int): int =
+ raiseAssert "sendfile() is not implemented yet"
+ elif (defined(linux) or defined(android)) and not(defined(emscripten)):
proc osSendFile*(outfd, infd: cint, offset: ptr int, count: int): int
{.importc: "sendfile", header: "<sys/sendfile.h>".}

@ -9,12 +9,12 @@
{.push raises: [].} {.push raises: [].}
import ../asyncloop, ../asyncsync import ../[config, asyncloop, asyncsync, bipbuffer]
import ../transports/common, ../transports/stream import ../transports/[common, stream]
export asyncloop, asyncsync, stream, common export asyncloop, asyncsync, stream, common
const const
AsyncStreamDefaultBufferSize* = 4096 AsyncStreamDefaultBufferSize* = chronosStreamDefaultBufferSize
## Default reading stream internal buffer size. ## Default reading stream internal buffer size.
AsyncStreamDefaultQueueSize* = 0 AsyncStreamDefaultQueueSize* = 0
## Default writing stream internal queue size. ## Default writing stream internal queue size.
@ -24,22 +24,21 @@ const
## AsyncStreamWriter leaks tracker name ## AsyncStreamWriter leaks tracker name
type type
AsyncStreamError* = object of CatchableError AsyncStreamError* = object of AsyncError
AsyncStreamIncorrectDefect* = object of Defect AsyncStreamIncorrectDefect* = object of Defect
AsyncStreamIncompleteError* = object of AsyncStreamError AsyncStreamIncompleteError* = object of AsyncStreamError
AsyncStreamLimitError* = object of AsyncStreamError AsyncStreamLimitError* = object of AsyncStreamError
AsyncStreamUseClosedError* = object of AsyncStreamError AsyncStreamUseClosedError* = object of AsyncStreamError
AsyncStreamReadError* = object of AsyncStreamError AsyncStreamReadError* = object of AsyncStreamError
par*: ref CatchableError
AsyncStreamWriteError* = object of AsyncStreamError AsyncStreamWriteError* = object of AsyncStreamError
par*: ref CatchableError
AsyncStreamWriteEOFError* = object of AsyncStreamWriteError AsyncStreamWriteEOFError* = object of AsyncStreamWriteError
AsyncBuffer* = object AsyncBuffer* = object
offset*: int backend*: BipBuffer
buffer*: seq[byte]
events*: array[2, AsyncEvent] events*: array[2, AsyncEvent]
AsyncBufferRef* = ref AsyncBuffer
WriteType* = enum WriteType* = enum
Pointer, Sequence, String Pointer, Sequence, String
@ -53,7 +52,7 @@ type
dataStr*: string dataStr*: string
size*: int size*: int
offset*: int offset*: int
future*: Future[void] future*: Future[void].Raising([CancelledError, AsyncStreamError])
AsyncStreamState* = enum AsyncStreamState* = enum
Running, ## Stream is online and working Running, ## Stream is online and working
@ -64,10 +63,10 @@ type
Closed ## Stream was closed Closed ## Stream was closed
StreamReaderLoop* = proc (stream: AsyncStreamReader): Future[void] {. StreamReaderLoop* = proc (stream: AsyncStreamReader): Future[void] {.
gcsafe, raises: [].} async: (raises: []).}
## Main read loop for read streams. ## Main read loop for read streams.
StreamWriterLoop* = proc (stream: AsyncStreamWriter): Future[void] {. StreamWriterLoop* = proc (stream: AsyncStreamWriter): Future[void] {.
gcsafe, raises: [].} async: (raises: []).}
## Main write loop for write streams. ## Main write loop for write streams.
AsyncStreamReader* = ref object of RootRef AsyncStreamReader* = ref object of RootRef
@ -75,11 +74,11 @@ type
tsource*: StreamTransport tsource*: StreamTransport
readerLoop*: StreamReaderLoop readerLoop*: StreamReaderLoop
state*: AsyncStreamState state*: AsyncStreamState
buffer*: AsyncBuffer buffer*: AsyncBufferRef
udata: pointer udata: pointer
error*: ref AsyncStreamError error*: ref AsyncStreamError
bytesCount*: uint64 bytesCount*: uint64
future: Future[void] future: Future[void].Raising([])
AsyncStreamWriter* = ref object of RootRef AsyncStreamWriter* = ref object of RootRef
wsource*: AsyncStreamWriter wsource*: AsyncStreamWriter
@ -90,7 +89,7 @@ type
error*: ref AsyncStreamError error*: ref AsyncStreamError
udata: pointer udata: pointer
bytesCount*: uint64 bytesCount*: uint64
future: Future[void] future: Future[void].Raising([])
AsyncStream* = object of RootObj AsyncStream* = object of RootObj
reader*: AsyncStreamReader reader*: AsyncStreamReader
@ -98,84 +97,51 @@ type
AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter
proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = proc new*(t: typedesc[AsyncBufferRef], size: int): AsyncBufferRef =
AsyncBuffer( AsyncBufferRef(
buffer: newSeq[byte](size), backend: BipBuffer.init(size),
events: [newAsyncEvent(), newAsyncEvent()], events: [newAsyncEvent(), newAsyncEvent()]
offset: 0
) )
proc getBuffer*(sb: AsyncBuffer): pointer {.inline.} = template wait*(sb: AsyncBufferRef): untyped =
unsafeAddr sb.buffer[sb.offset]
proc bufferLen*(sb: AsyncBuffer): int {.inline.} =
len(sb.buffer) - sb.offset
proc getData*(sb: AsyncBuffer): pointer {.inline.} =
unsafeAddr sb.buffer[0]
template dataLen*(sb: AsyncBuffer): int =
sb.offset
proc `[]`*(sb: AsyncBuffer, index: int): byte {.inline.} =
doAssert(index < sb.offset)
sb.buffer[index]
proc update*(sb: var AsyncBuffer, size: int) {.inline.} =
sb.offset += size
proc wait*(sb: var AsyncBuffer): Future[void] =
sb.events[0].clear() sb.events[0].clear()
sb.events[1].fire() sb.events[1].fire()
sb.events[0].wait() sb.events[0].wait()
proc transfer*(sb: var AsyncBuffer): Future[void] = template transfer*(sb: AsyncBufferRef): untyped =
sb.events[1].clear() sb.events[1].clear()
sb.events[0].fire() sb.events[0].fire()
sb.events[1].wait() sb.events[1].wait()
proc forget*(sb: var AsyncBuffer) {.inline.} = proc forget*(sb: AsyncBufferRef) {.inline.} =
sb.events[1].clear() sb.events[1].clear()
sb.events[0].fire() sb.events[0].fire()
proc shift*(sb: var AsyncBuffer, size: int) {.inline.} = proc upload*(sb: AsyncBufferRef, pbytes: ptr byte,
if sb.offset > size: nbytes: int): Future[void] {.
moveMem(addr sb.buffer[0], addr sb.buffer[size], sb.offset - size) async: (raises: [CancelledError]).} =
sb.offset = sb.offset - size
else:
sb.offset = 0
proc copyData*(sb: AsyncBuffer, dest: pointer, offset, length: int) {.inline.} =
copyMem(cast[pointer](cast[uint](dest) + cast[uint](offset)),
unsafeAddr sb.buffer[0], length)
proc upload*(sb: ptr AsyncBuffer, pbytes: ptr byte,
nbytes: int): Future[void] {.async.} =
## You can upload any amount of bytes to the buffer. If size of internal ## You can upload any amount of bytes to the buffer. If size of internal
## buffer is not enough to fit all the data at once, data will be uploaded ## buffer is not enough to fit all the data at once, data will be uploaded
## via chunks of size up to internal buffer size. ## via chunks of size up to internal buffer size.
var length = nbytes var
var srcBuffer = cast[ptr UncheckedArray[byte]](pbytes) length = nbytes
var srcOffset = 0 srcBuffer = pbytes.toUnchecked()
offset = 0
while length > 0: while length > 0:
let size = min(length, sb[].bufferLen()) let size = min(length, sb.backend.availSpace())
if size == 0: if size == 0:
# Internal buffer is full, we need to transfer data to consumer. # Internal buffer is full, we need to notify consumer.
await sb[].transfer() await sb.transfer()
else: else:
let (data, _) = sb.backend.reserve()
# Copy data from `pbytes` to internal buffer. # Copy data from `pbytes` to internal buffer.
copyMem(addr sb[].buffer[sb.offset], addr srcBuffer[srcOffset], size) copyMem(data, addr srcBuffer[offset], size)
sb[].offset = sb[].offset + size sb.backend.commit(size)
srcOffset = srcOffset + size offset = offset + size
length = length - size length = length - size
# We notify consumers that new data is available. # We notify consumers that new data is available.
sb[].forget() sb.forget()
template toDataOpenArray*(sb: AsyncBuffer): auto =
toOpenArray(sb.buffer, 0, sb.offset - 1)
template toBufferOpenArray*(sb: AsyncBuffer): auto =
toOpenArray(sb.buffer, sb.offset, len(sb.buffer) - 1)
template copyOut*(dest: pointer, item: WriteItem, length: int) = template copyOut*(dest: pointer, item: WriteItem, length: int) =
if item.kind == Pointer: if item.kind == Pointer:
@ -186,18 +152,20 @@ template copyOut*(dest: pointer, item: WriteItem, length: int) =
elif item.kind == String: elif item.kind == String:
copyMem(dest, unsafeAddr item.dataStr[item.offset], length) copyMem(dest, unsafeAddr item.dataStr[item.offset], length)
proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. proc newAsyncStreamReadError(
noinline.} = p: ref TransportError
): ref AsyncStreamReadError {.noinline.} =
var w = newException(AsyncStreamReadError, "Read stream failed") var w = newException(AsyncStreamReadError, "Read stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.parent = p
w w
proc newAsyncStreamWriteError(p: ref CatchableError): ref AsyncStreamWriteError {. proc newAsyncStreamWriteError(
noinline.} = p: ref TransportError
): ref AsyncStreamWriteError {.noinline.} =
var w = newException(AsyncStreamWriteError, "Write stream failed") var w = newException(AsyncStreamWriteError, "Write stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.parent = p
w w
proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {. proc newAsyncStreamIncompleteError*(): ref AsyncStreamIncompleteError {.
@ -242,7 +210,7 @@ proc atEof*(rstream: AsyncStreamReader): bool =
rstream.rsource.atEof() rstream.rsource.atEof()
else: else:
(rstream.state != AsyncStreamState.Running) and (rstream.state != AsyncStreamState.Running) and
(rstream.buffer.dataLen() == 0) (len(rstream.buffer.backend) == 0)
proc atEof*(wstream: AsyncStreamWriter): bool = proc atEof*(wstream: AsyncStreamWriter): bool =
## Returns ``true`` is writing stream ``wstream`` closed or finished. ## Returns ``true`` is writing stream ``wstream`` closed or finished.
@ -330,12 +298,12 @@ template checkStreamFinished*(t: untyped) =
template readLoop(body: untyped): untyped = template readLoop(body: untyped): untyped =
while true: while true:
if rstream.buffer.dataLen() == 0: if len(rstream.buffer.backend) == 0:
if rstream.state == AsyncStreamState.Error: if rstream.state == AsyncStreamState.Error:
raise rstream.error raise rstream.error
let (consumed, done) = body let (consumed, done) = body
rstream.buffer.shift(consumed) rstream.buffer.backend.consume(consumed)
rstream.bytesCount = rstream.bytesCount + uint64(consumed) rstream.bytesCount = rstream.bytesCount + uint64(consumed)
if done: if done:
break break
@ -344,11 +312,12 @@ template readLoop(body: untyped): untyped =
await rstream.buffer.wait() await rstream.buffer.wait()
proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer,
nbytes: int) {.async.} = nbytes: int) {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store ## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store
## it to ``pbytes``. ## it to ``pbytes``.
## ##
## If EOF is received and ``nbytes`` is not yet readed, the procedure ## If EOF is received and ``nbytes`` is not yet read, the procedure
## will raise ``AsyncStreamIncompleteError``. ## will raise ``AsyncStreamIncompleteError``.
doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(not(isNil(pbytes)), "pbytes must not be nil")
doAssert(nbytes >= 0, "nbytes must be non-negative integer") doAssert(nbytes >= 0, "nbytes must be non-negative integer")
@ -365,26 +334,33 @@ proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer,
raise exc raise exc
except TransportIncompleteError: except TransportIncompleteError:
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
await readExactly(rstream.rsource, pbytes, nbytes) await readExactly(rstream.rsource, pbytes, nbytes)
else: else:
var index = 0 var
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) index = 0
pbuffer = pbytes.toUnchecked()
readLoop(): readLoop():
if rstream.buffer.dataLen() == 0: if len(rstream.buffer.backend) == 0:
if rstream.atEof(): if rstream.atEof():
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
let count = min(nbytes - index, rstream.buffer.dataLen()) var bytesRead = 0
if count > 0: for (region, rsize) in rstream.buffer.backend.regions():
rstream.buffer.copyData(addr pbuffer[index], 0, count) let count = min(nbytes - index, rsize)
index += count bytesRead += count
(consumed: count, done: index == nbytes) if count > 0:
copyMem(addr pbuffer[index], region, count)
index += count
if index == nbytes:
break
(consumed: bytesRead, done: index == nbytes)
proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer, proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer,
nbytes: int): Future[int] {.async.} = nbytes: int): Future[int] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Perform one read operation on read-only stream ``rstream``. ## Perform one read operation on read-only stream ``rstream``.
## ##
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
@ -398,24 +374,31 @@ proc readOnce*(rstream: AsyncStreamReader, pbytes: pointer,
return await readOnce(rstream.tsource, pbytes, nbytes) return await readOnce(rstream.tsource, pbytes, nbytes)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
return await readOnce(rstream.rsource, pbytes, nbytes) return await readOnce(rstream.rsource, pbytes, nbytes)
else: else:
var count = 0 var
pbuffer = pbytes.toUnchecked()
index = 0
readLoop(): readLoop():
if rstream.buffer.dataLen() == 0: if len(rstream.buffer.backend) == 0:
(0, rstream.atEof()) (0, rstream.atEof())
else: else:
count = min(rstream.buffer.dataLen(), nbytes) for (region, rsize) in rstream.buffer.backend.regions():
rstream.buffer.copyData(pbytes, 0, count) let size = min(rsize, nbytes - index)
(count, true) copyMem(addr pbuffer[index], region, size)
return count index += size
if index >= nbytes:
break
(index, true)
index
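A sketch of the usual readOnce pumping pattern (illustrative only; the pump name and the 4 KiB buffer size are arbitrary choices, not part of the diff):

import chronos, chronos/streams/asyncstream

proc pump(reader: AsyncStreamReader) {.async.} =
  var chunk: array[4096, byte]
  while true:
    # Returns as soon as some data is available; 0 signals EOF.
    let n = await reader.readOnce(addr chunk[0], len(chunk))
    if n == 0:
      break
    echo "received ", n, " bytes"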
proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int, proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int,
sep: seq[byte]): Future[int] {.async.} = sep: seq[byte]): Future[int] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Read data from the read-only stream ``rstream`` until separator ``sep`` is ## Read data from the read-only stream ``rstream`` until separator ``sep`` is
## found. ## found.
## ##
@ -446,37 +429,42 @@ proc readUntil*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int,
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
except TransportLimitError: except TransportLimitError:
raise newAsyncStreamLimitError() raise newAsyncStreamLimitError()
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
return await readUntil(rstream.rsource, pbytes, nbytes, sep) return await readUntil(rstream.rsource, pbytes, nbytes, sep)
else: else:
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) var
var state = 0 pbuffer = pbytes.toUnchecked()
var k = 0 state = 0
k = 0
readLoop(): readLoop():
if rstream.atEof(): if rstream.atEof():
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
var index = 0 var index = 0
while index < rstream.buffer.dataLen(): for ch in rstream.buffer.backend:
if k >= nbytes: if k >= nbytes:
raise newAsyncStreamLimitError() raise newAsyncStreamLimitError()
let ch = rstream.buffer[index]
inc(index) inc(index)
pbuffer[k] = ch pbuffer[k] = ch
inc(k) inc(k)
if sep[state] == ch: if sep[state] == ch:
inc(state) inc(state)
if state == len(sep): if state == len(sep):
break break
else: else:
state = 0 state = 0
(index, state == len(sep)) (index, state == len(sep))
return k k
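A hypothetical readUntil call for a CRLF-delimited frame (the 1 KiB limit and readFrame name are illustrative):

import chronos, chronos/streams/asyncstream

proc readFrame(reader: AsyncStreamReader): Future[seq[byte]] {.async.} =
  var buf: array[1024, byte]
  let sep = @[byte 0x0D, byte 0x0A]   # CRLF
  # Raises AsyncStreamLimitError if no separator fits into `buf` and
  # AsyncStreamIncompleteError if EOF arrives before it is found.
  let n = await reader.readUntil(addr buf[0], len(buf), sep)
  return @(buf.toOpenArray(0, n - 1))   # separator is included at the end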
proc readLine*(rstream: AsyncStreamReader, limit = 0, proc readLine*(rstream: AsyncStreamReader, limit = 0,
sep = "\r\n"): Future[string] {.async.} = sep = "\r\n"): Future[string] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Read one line from read-only stream ``rstream``, where ``"line"`` is a ## Read one line from read-only stream ``rstream``, where ``"line"`` is a
## sequence of bytes ending with ``sep`` (default is ``"\r\n"``). ## sequence of bytes ending with ``sep`` (default is ``"\r\n"``).
## ##
@ -495,25 +483,26 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0,
return await readLine(rstream.tsource, limit, sep) return await readLine(rstream.tsource, limit, sep)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
return await readLine(rstream.rsource, limit, sep) return await readLine(rstream.rsource, limit, sep)
else: else:
let lim = if limit <= 0: -1 else: limit let lim = if limit <= 0: -1 else: limit
var state = 0 var
var res = "" state = 0
res = ""
readLoop(): readLoop():
if rstream.atEof(): if rstream.atEof():
(0, true) (0, true)
else: else:
var index = 0 var index = 0
while index < rstream.buffer.dataLen(): for ch in rstream.buffer.backend:
let ch = char(rstream.buffer[index])
inc(index) inc(index)
if sep[state] == ch: if sep[state] == char(ch):
inc(state) inc(state)
if state == len(sep): if state == len(sep):
break break
@ -524,13 +513,17 @@ proc readLine*(rstream: AsyncStreamReader, limit = 0,
res.add(sep[0 ..< missing]) res.add(sep[0 ..< missing])
else: else:
res.add(sep[0 ..< state]) res.add(sep[0 ..< state])
res.add(ch) state = 0
res.add(char(ch))
if len(res) == lim: if len(res) == lim:
break break
(index, (state == len(sep)) or (lim == len(res)))
return res
proc read*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = (index, (state == len(sep)) or (lim == len(res)))
res
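A small readLine sketch (the limit value and the greeting wording are assumptions for illustration):

import chronos, chronos/streams/asyncstream

proc readGreeting(reader: AsyncStreamReader) {.async.} =
  # An empty result means EOF was reached before any data arrived.
  let line = await reader.readLine(limit = 1024)
  if line.len == 0:
    echo "connection closed without a greeting"
  else:
    echo "greeting: ", line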
proc read*(rstream: AsyncStreamReader): Future[seq[byte]] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Read all bytes from read-only stream ``rstream``. ## Read all bytes from read-only stream ``rstream``.
## ##
## This procedure allocates a seq[byte] buffer and returns it as the result. ## This procedure allocates a seq[byte] buffer and returns it as the result.
@ -543,23 +536,26 @@ proc read*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} =
raise exc raise exc
except TransportLimitError: except TransportLimitError:
raise newAsyncStreamLimitError() raise newAsyncStreamLimitError()
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
return await read(rstream.rsource) return await read(rstream.rsource)
else: else:
var res = newSeq[byte]() var res: seq[byte]
readLoop(): readLoop():
if rstream.atEof(): if rstream.atEof():
(0, true) (0, true)
else: else:
let count = rstream.buffer.dataLen() var bytesRead = 0
res.add(rstream.buffer.buffer.toOpenArray(0, count - 1)) for (region, rsize) in rstream.buffer.backend.regions():
(count, false) bytesRead += rsize
return res res.add(region.toUnchecked().toOpenArray(0, rsize - 1))
(bytesRead, false)
res
proc read*(rstream: AsyncStreamReader, n: int): Future[seq[byte]] {.async.} = proc read*(rstream: AsyncStreamReader, n: int): Future[seq[byte]] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Read all bytes (n <= 0) or exactly `n` bytes from read-only stream ## Read all bytes (n <= 0) or exactly `n` bytes from read-only stream
## ``rstream``. ## ``rstream``.
## ##
@ -571,7 +567,7 @@ proc read*(rstream: AsyncStreamReader, n: int): Future[seq[byte]] {.async.} =
return await read(rstream.tsource, n) return await read(rstream.tsource, n)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
@ -585,12 +581,16 @@ proc read*(rstream: AsyncStreamReader, n: int): Future[seq[byte]] {.async.} =
if rstream.atEof(): if rstream.atEof():
(0, true) (0, true)
else: else:
let count = min(rstream.buffer.dataLen(), n - len(res)) var bytesRead = 0
res.add(rstream.buffer.buffer.toOpenArray(0, count - 1)) for (region, rsize) in rstream.buffer.backend.regions():
(count, len(res) == n) let count = min(rsize, n - len(res))
return res bytesRead += count
res.add(region.toUnchecked().toOpenArray(0, count - 1))
(bytesRead, len(res) == n)
res
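The two read overloads side by side, as a sketch (the 16-byte prefix is an arbitrary example):

import chronos, chronos/streams/asyncstream

proc slurp(reader: AsyncStreamReader) {.async.} =
  # Up to 16 bytes (fewer only if EOF comes first), then everything left.
  let prefix = await reader.read(16)
  let rest = await reader.read()
  echo "prefix: ", prefix.len, " bytes, remainder: ", rest.len, " bytes"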
proc consume*(rstream: AsyncStreamReader): Future[int] {.async.} = proc consume*(rstream: AsyncStreamReader): Future[int] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Consume (discard) all bytes from read-only stream ``rstream``. ## Consume (discard) all bytes from read-only stream ``rstream``.
## ##
## Return number of bytes actually consumed (discarded). ## Return number of bytes actually consumed (discarded).
@ -603,7 +603,7 @@ proc consume*(rstream: AsyncStreamReader): Future[int] {.async.} =
raise exc raise exc
except TransportLimitError: except TransportLimitError:
raise newAsyncStreamLimitError() raise newAsyncStreamLimitError()
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
@ -614,11 +614,13 @@ proc consume*(rstream: AsyncStreamReader): Future[int] {.async.} =
if rstream.atEof(): if rstream.atEof():
(0, true) (0, true)
else: else:
res += rstream.buffer.dataLen() let used = len(rstream.buffer.backend)
(rstream.buffer.dataLen(), false) res += used
return res (used, false)
res
proc consume*(rstream: AsyncStreamReader, n: int): Future[int] {.async.} = proc consume*(rstream: AsyncStreamReader, n: int): Future[int] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream ## Consume (discard) all bytes (n <= 0) or ``n`` bytes from read-only stream
## ``rstream``. ## ``rstream``.
## ##
@ -632,7 +634,7 @@ proc consume*(rstream: AsyncStreamReader, n: int): Future[int] {.async.} =
raise exc raise exc
except TransportLimitError: except TransportLimitError:
raise newAsyncStreamLimitError() raise newAsyncStreamLimitError()
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
@ -643,16 +645,15 @@ proc consume*(rstream: AsyncStreamReader, n: int): Future[int] {.async.} =
else: else:
var res = 0 var res = 0
readLoop(): readLoop():
if rstream.atEof(): let
(0, true) used = len(rstream.buffer.backend)
else: count = min(used, n - res)
let count = min(rstream.buffer.dataLen(), n - res) res += count
res += count (count, res == n)
(count, res == n) res
return res
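A sketch of the two consume overloads (the 8-byte padding prefix is illustrative):

import chronos, chronos/streams/asyncstream

proc skipRest(reader: AsyncStreamReader) {.async.} =
  discard await reader.consume(8)       # drop an 8-byte padding prefix
  let dropped = await reader.consume()  # drop whatever remains until EOF
  echo "discarded ", dropped, " trailing bytes"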
proc readMessage*(rstream: AsyncStreamReader, pred: ReadMessagePredicate) {. proc readMessage*(rstream: AsyncStreamReader, pred: ReadMessagePredicate) {.
async.} = async: (raises: [CancelledError, AsyncStreamError]).} =
## Read all bytes from stream ``rstream`` until ``predicate`` callback ## Read all bytes from stream ``rstream`` until ``predicate`` callback
## is satisfied. ## is satisfied.
## ##
@ -673,25 +674,29 @@ proc readMessage*(rstream: AsyncStreamReader, pred: ReadMessagePredicate) {.
await readMessage(rstream.tsource, pred) await readMessage(rstream.tsource, pred)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamReadError(exc) raise newAsyncStreamReadError(exc)
else: else:
if isNil(rstream.readerLoop): if isNil(rstream.readerLoop):
await readMessage(rstream.rsource, pred) await readMessage(rstream.rsource, pred)
else: else:
readLoop(): readLoop():
let count = rstream.buffer.dataLen() if len(rstream.buffer.backend) == 0:
if count == 0:
if rstream.atEof(): if rstream.atEof():
pred([]) pred([])
else: else:
# Case, when transport's buffer is not yet filled with data. # Case, when transport's buffer is not yet filled with data.
(0, false) (0, false)
else: else:
pred(rstream.buffer.buffer.toOpenArray(0, count - 1)) var res: tuple[consumed: int, done: bool]
for (region, rsize) in rstream.buffer.backend.regions():
res = pred(region.toUnchecked().toOpenArray(0, rsize - 1))
break
res
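A sketch of a ReadMessagePredicate that accumulates everything until EOF. The predicate shape, openArray[byte] in and (consumed, done) out, follows the loop above; the collectAll name and logic are assumptions:

import chronos, chronos/streams/asyncstream

proc collectAll(reader: AsyncStreamReader): Future[seq[byte]] {.async.} =
  var message: seq[byte]
  proc pred(data: openArray[byte]): tuple[consumed: int, done: bool] =
    if len(data) == 0:
      (0, true)             # an empty slice signals EOF
    else:
      message.add(data)     # take everything offered, keep reading
      (len(data), false)
  await reader.readMessage(pred)
  return message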
proc write*(wstream: AsyncStreamWriter, pbytes: pointer, proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
nbytes: int) {.async.} = nbytes: int) {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Write sequence of bytes pointed by ``pbytes`` of length ``nbytes`` to ## Write sequence of bytes pointed by ``pbytes`` of length ``nbytes`` to
## writer stream ``wstream``. ## writer stream ``wstream``.
## ##
@ -708,9 +713,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
res = await write(wstream.tsource, pbytes, nbytes) res = await write(wstream.tsource, pbytes, nbytes)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except AsyncStreamError as exc: except TransportError as exc:
raise exc
except CatchableError as exc:
raise newAsyncStreamWriteError(exc) raise newAsyncStreamWriteError(exc)
if res != nbytes: if res != nbytes:
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
@ -720,23 +723,17 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
await write(wstream.wsource, pbytes, nbytes) await write(wstream.wsource, pbytes, nbytes)
wstream.bytesCount = wstream.bytesCount + uint64(nbytes) wstream.bytesCount = wstream.bytesCount + uint64(nbytes)
else: else:
var item = WriteItem(kind: Pointer) let item = WriteItem(
item.dataPtr = pbytes kind: Pointer, dataPtr: pbytes, size: nbytes,
item.size = nbytes future: Future[void].Raising([CancelledError, AsyncStreamError])
item.future = newFuture[void]("async.stream.write(pointer)") .init("async.stream.write(pointer)"))
try: await wstream.queue.put(item)
await wstream.queue.put(item) await item.future
await item.future wstream.bytesCount = wstream.bytesCount + uint64(item.size)
wstream.bytesCount = wstream.bytesCount + uint64(item.size)
except CancelledError as exc:
raise exc
except AsyncStreamError as exc:
raise exc
except CatchableError as exc:
raise newAsyncStreamWriteError(exc)
proc write*(wstream: AsyncStreamWriter, sbytes: sink seq[byte], proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte],
msglen = -1) {.async.} = msglen = -1) {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
## stream ``wstream``. ## stream ``wstream``.
## ##
@ -758,7 +755,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink seq[byte],
res = await write(wstream.tsource, sbytes, length) res = await write(wstream.tsource, sbytes, length)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamWriteError(exc) raise newAsyncStreamWriteError(exc)
if res != length: if res != length:
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
@ -768,29 +765,17 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink seq[byte],
await write(wstream.wsource, sbytes, length) await write(wstream.wsource, sbytes, length)
wstream.bytesCount = wstream.bytesCount + uint64(length) wstream.bytesCount = wstream.bytesCount + uint64(length)
else: else:
var item = WriteItem(kind: Sequence) let item = WriteItem(
when declared(shallowCopy): kind: Sequence, dataSeq: sbytes, size: length,
if not(isLiteral(sbytes)): future: Future[void].Raising([CancelledError, AsyncStreamError])
shallowCopy(item.dataSeq, sbytes) .init("async.stream.write(seq)"))
else: await wstream.queue.put(item)
item.dataSeq = sbytes await item.future
else: wstream.bytesCount = wstream.bytesCount + uint64(item.size)
item.dataSeq = sbytes
item.size = length
item.future = newFuture[void]("async.stream.write(seq)")
try:
await wstream.queue.put(item)
await item.future
wstream.bytesCount = wstream.bytesCount + uint64(item.size)
except CancelledError as exc:
raise exc
except AsyncStreamError as exc:
raise exc
except CatchableError as exc:
raise newAsyncStreamWriteError(exc)
proc write*(wstream: AsyncStreamWriter, sbytes: sink string, proc write*(wstream: AsyncStreamWriter, sbytes: string,
msglen = -1) {.async.} = msglen = -1) {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``. ## Write string ``sbytes`` of length ``msglen`` to writer stream ``wstream``.
## ##
## String ``sbytes`` must not be zero-length. ## String ``sbytes`` must not be zero-length.
@ -811,7 +796,7 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink string,
res = await write(wstream.tsource, sbytes, length) res = await write(wstream.tsource, sbytes, length)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except TransportError as exc:
raise newAsyncStreamWriteError(exc) raise newAsyncStreamWriteError(exc)
if res != length: if res != length:
raise newAsyncStreamIncompleteError() raise newAsyncStreamIncompleteError()
@ -821,28 +806,16 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink string,
await write(wstream.wsource, sbytes, length) await write(wstream.wsource, sbytes, length)
wstream.bytesCount = wstream.bytesCount + uint64(length) wstream.bytesCount = wstream.bytesCount + uint64(length)
else: else:
var item = WriteItem(kind: String) let item = WriteItem(
when declared(shallowCopy): kind: String, dataStr: sbytes, size: length,
if not(isLiteral(sbytes)): future: Future[void].Raising([CancelledError, AsyncStreamError])
shallowCopy(item.dataStr, sbytes) .init("async.stream.write(string)"))
else: await wstream.queue.put(item)
item.dataStr = sbytes await item.future
else: wstream.bytesCount = wstream.bytesCount + uint64(item.size)
item.dataStr = sbytes
item.size = length
item.future = newFuture[void]("async.stream.write(string)")
try:
await wstream.queue.put(item)
await item.future
wstream.bytesCount = wstream.bytesCount + uint64(item.size)
except CancelledError as exc:
raise exc
except AsyncStreamError as exc:
raise exc
except CatchableError as exc:
raise newAsyncStreamWriteError(exc)
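The three write flavours in one sketch (the payload contents are arbitrary examples):

import chronos, chronos/streams/asyncstream

proc sendGreeting(writer: AsyncStreamWriter) {.async.} =
  await writer.write("HELLO ")                    # string overload
  await writer.write(@[byte 0x01, 0x02, 0x03])    # seq[byte] overload
  var raw = [byte 0xDE, 0xAD]
  await writer.write(addr raw[0], len(raw))       # pointer + length overload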
proc finish*(wstream: AsyncStreamWriter) {.async.} = proc finish*(wstream: AsyncStreamWriter) {.
async: (raises: [CancelledError, AsyncStreamError]).} =
## Finish write stream ``wstream``. ## Finish write stream ``wstream``.
checkStreamClosed(wstream) checkStreamClosed(wstream)
# For AsyncStreamWriter Finished state could be set manually or by stream's # For AsyncStreamWriter Finished state could be set manually or by stream's
@ -852,40 +825,18 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} =
if isNil(wstream.writerLoop): if isNil(wstream.writerLoop):
await wstream.wsource.finish() await wstream.wsource.finish()
else: else:
var item = WriteItem(kind: Pointer) let item = WriteItem(
item.size = 0 kind: Pointer, size: 0,
item.future = newFuture[void]("async.stream.finish") future: Future[void].Raising([CancelledError, AsyncStreamError])
try: .init("async.stream.finish"))
await wstream.queue.put(item) await wstream.queue.put(item)
await item.future await item.future
except CancelledError as exc:
raise exc
except AsyncStreamError as exc:
raise exc
except CatchableError as exc:
raise newAsyncStreamWriteError(exc)
proc join*(rw: AsyncStreamRW): Future[void] = proc join*(rw: AsyncStreamRW): Future[void] {.
async: (raw: true, raises: [CancelledError]).} =
## Get Future[void] which will be completed when the stream becomes finished or ## Get Future[void] which will be completed when the stream becomes finished or
## closed. ## closed.
when rw is AsyncStreamReader: rw.future.join()
var retFuture = newFuture[void]("async.stream.reader.join")
else:
var retFuture = newFuture[void]("async.stream.writer.join")
proc continuation(udata: pointer) {.gcsafe.} =
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe.} =
rw.future.removeCallback(continuation, cast[pointer](retFuture))
if not(rw.future.finished()):
rw.future.addCallback(continuation, cast[pointer](retFuture))
retFuture.cancelCallback = cancellation
else:
retFuture.complete()
return retFuture
proc close*(rw: AsyncStreamRW) = proc close*(rw: AsyncStreamRW) =
## Close and free resources of stream ``rw``. ## Close and free resources of stream ``rw``.
@ -913,7 +864,7 @@ proc close*(rw: AsyncStreamRW) =
callSoon(continuation) callSoon(continuation)
else: else:
rw.future.addCallback(continuation) rw.future.addCallback(continuation)
rw.future.cancel() rw.future.cancelSoon()
elif rw is AsyncStreamWriter: elif rw is AsyncStreamWriter:
if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future): if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future):
callSoon(continuation) callSoon(continuation)
@ -922,26 +873,29 @@ proc close*(rw: AsyncStreamRW) =
callSoon(continuation) callSoon(continuation)
else: else:
rw.future.addCallback(continuation) rw.future.addCallback(continuation)
rw.future.cancel() rw.future.cancelSoon()
proc closeWait*(rw: AsyncStreamRW): Future[void] = proc closeWait*(rw: AsyncStreamRW): Future[void] {.async: (raises: []).} =
## Close and free resources of stream ``rw``. ## Close and free resources of stream ``rw``.
rw.close() if not rw.closed():
rw.join() rw.close()
await noCancel(rw.join())
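A teardown sketch combining finish, join and closeWait; the ordering shown is one reasonable choice, not something mandated by the API:

import chronos, chronos/streams/asyncstream

proc teardown(reader: AsyncStreamReader, writer: AsyncStreamWriter) {.async.} =
  await writer.finish()      # tell the peer no more data will follow
  await writer.closeWait()   # close() plus waiting for the writer loop to stop
  await reader.closeWait()
  # join() alone can be used when another task owns the closing:
  # await reader.join()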
proc startReader(rstream: AsyncStreamReader) = proc startReader(rstream: AsyncStreamReader) =
rstream.state = Running rstream.state = Running
if not isNil(rstream.readerLoop): if not isNil(rstream.readerLoop):
rstream.future = rstream.readerLoop(rstream) rstream.future = rstream.readerLoop(rstream)
else: else:
rstream.future = newFuture[void]("async.stream.empty.reader") rstream.future = Future[void].Raising([]).init(
"async.stream.empty.reader", {FutureFlag.OwnCancelSchedule})
proc startWriter(wstream: AsyncStreamWriter) = proc startWriter(wstream: AsyncStreamWriter) =
wstream.state = Running wstream.state = Running
if not isNil(wstream.writerLoop): if not isNil(wstream.writerLoop):
wstream.future = wstream.writerLoop(wstream) wstream.future = wstream.writerLoop(wstream)
else: else:
wstream.future = newFuture[void]("async.stream.empty.writer") wstream.future = Future[void].Raising([]).init(
"async.stream.empty.writer", {FutureFlag.OwnCancelSchedule})
proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
queueSize = AsyncStreamDefaultQueueSize) = queueSize = AsyncStreamDefaultQueueSize) =
@ -975,7 +929,8 @@ proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
child.readerLoop = loop child.readerLoop = loop
child.rsource = rsource child.rsource = rsource
child.tsource = rsource.tsource child.tsource = rsource.tsource
child.buffer = AsyncBuffer.init(bufferSize) let size = max(AsyncStreamDefaultBufferSize, bufferSize)
child.buffer = AsyncBufferRef.new(size)
trackCounter(AsyncStreamReaderTrackerName) trackCounter(AsyncStreamReaderTrackerName)
child.startReader() child.startReader()
@ -987,7 +942,8 @@ proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
child.readerLoop = loop child.readerLoop = loop
child.rsource = rsource child.rsource = rsource
child.tsource = rsource.tsource child.tsource = rsource.tsource
child.buffer = AsyncBuffer.init(bufferSize) let size = max(AsyncStreamDefaultBufferSize, bufferSize)
child.buffer = AsyncBufferRef.new(size)
if not isNil(udata): if not isNil(udata):
GC_ref(udata) GC_ref(udata)
child.udata = cast[pointer](udata) child.udata = cast[pointer](udata)
@ -1126,6 +1082,22 @@ proc newAsyncStreamReader*(tsource: StreamTransport): AsyncStreamReader =
res.init(tsource) res.init(tsource)
res res
proc newAsyncStreamReader*[T](rsource: AsyncStreamReader,
udata: ref T): AsyncStreamReader =
## Create copy of AsyncStreamReader object ``rsource``.
##
## ``udata`` - user object which will be associated with new AsyncStreamReader
## object.
var res = AsyncStreamReader()
res.init(rsource, udata)
res
proc newAsyncStreamReader*(rsource: AsyncStreamReader): AsyncStreamReader =
## Create copy of AsyncStreamReader object ``rsource``.
var res = AsyncStreamReader()
res.init(rsource)
res
proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter, proc newAsyncStreamWriter*[T](wsource: AsyncStreamWriter,
loop: StreamWriterLoop, loop: StreamWriterLoop,
queueSize = AsyncStreamDefaultQueueSize, queueSize = AsyncStreamDefaultQueueSize,
@ -1191,22 +1163,6 @@ proc newAsyncStreamWriter*(wsource: AsyncStreamWriter): AsyncStreamWriter =
res.init(wsource) res.init(wsource)
res res
proc newAsyncStreamReader*[T](rsource: AsyncStreamWriter,
udata: ref T): AsyncStreamWriter =
## Create copy of AsyncStreamReader object ``rsource``.
##
## ``udata`` - user object which will be associated with new AsyncStreamReader
## object.
var res = AsyncStreamReader()
res.init(rsource, udata)
res
proc newAsyncStreamReader*(rsource: AsyncStreamReader): AsyncStreamReader =
## Create copy of AsyncStreamReader object ``rsource``.
var res = AsyncStreamReader()
res.init(rsource)
res
proc getUserData*[T](rw: AsyncStreamRW): T {.inline.} = proc getUserData*[T](rw: AsyncStreamRW): T {.inline.} =
## Obtain user data associated with AsyncStreamReader or AsyncStreamWriter ## Obtain user data associated with AsyncStreamReader or AsyncStreamWriter
## object ``rw``. ## object ``rw``.


@ -14,9 +14,12 @@
## ##
## For stream writing it means that you should write exactly the bounded number ## For stream writing it means that you should write exactly the bounded number
## of bytes. ## of bytes.
import stew/results
import ../asyncloop, ../timer {.push raises: [].}
import asyncstream, ../transports/stream, ../transports/common
import results
import ../[asyncloop, timer, bipbuffer, config]
import asyncstream, ../transports/[stream, common]
export asyncloop, asyncstream, stream, timer, common export asyncloop, asyncstream, stream, timer, common
type type
@ -41,7 +44,7 @@ type
BoundedStreamRW* = BoundedStreamReader | BoundedStreamWriter BoundedStreamRW* = BoundedStreamReader | BoundedStreamWriter
const const
BoundedBufferSize* = 4096 BoundedBufferSize* = chronosStreamDefaultBufferSize
BoundarySizeDefectMessage = "Boundary must not be empty array" BoundarySizeDefectMessage = "Boundary must not be empty array"
template newBoundedStreamIncompleteError(): ref BoundedStreamError = template newBoundedStreamIncompleteError(): ref BoundedStreamError =
@ -52,7 +55,8 @@ template newBoundedStreamOverflowError(): ref BoundedStreamOverflowError =
newException(BoundedStreamOverflowError, "Stream boundary exceeded") newException(BoundedStreamOverflowError, "Stream boundary exceeded")
proc readUntilBoundary(rstream: AsyncStreamReader, pbytes: pointer, proc readUntilBoundary(rstream: AsyncStreamReader, pbytes: pointer,
nbytes: int, sep: seq[byte]): Future[int] {.async.} = nbytes: int, sep: seq[byte]): Future[int] {.
async: (raises: [CancelledError, AsyncStreamError]).} =
doAssert(not(isNil(pbytes)), "pbytes must not be nil") doAssert(not(isNil(pbytes)), "pbytes must not be nil")
doAssert(nbytes >= 0, "nbytes must be non-negative value") doAssert(nbytes >= 0, "nbytes must be non-negative value")
checkStreamClosed(rstream) checkStreamClosed(rstream)
@ -96,10 +100,10 @@ func endsWith(s, suffix: openArray[byte]): bool =
inc(i) inc(i)
if i >= len(suffix): return true if i >= len(suffix): return true
proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = proc boundedReadLoop(stream: AsyncStreamReader) {.async: (raises: []).} =
var rstream = BoundedStreamReader(stream) var rstream = BoundedStreamReader(stream)
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
var buffer = newSeq[byte](rstream.buffer.bufferLen()) var buffer = newSeq[byte](rstream.buffer.backend.availSpace())
while true: while true:
let toRead = let toRead =
if rstream.boundSize.isNone(): if rstream.boundSize.isNone():
@ -123,7 +127,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
# There should be one step between transferring last bytes to the # There should be one step between transferring last bytes to the
# consumer and declaring stream EOF. Otherwise could not be # consumer and declaring stream EOF. Otherwise could not be
# consumed. # consumed.
await upload(addr rstream.buffer, addr buffer[0], length) await upload(rstream.buffer, addr buffer[0], length)
if rstream.state == AsyncStreamState.Running: if rstream.state == AsyncStreamState.Running:
rstream.state = AsyncStreamState.Finished rstream.state = AsyncStreamState.Finished
else: else:
@ -131,7 +135,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
# There should be one step between transferring last bytes to the # There should be one step between transferring last bytes to the
# consumer and declaring stream EOF. Otherwise could not be # consumer and declaring stream EOF. Otherwise could not be
# consumed. # consumed.
await upload(addr rstream.buffer, addr buffer[0], res) await upload(rstream.buffer, addr buffer[0], res)
if (res < toRead) and rstream.rsource.atEof(): if (res < toRead) and rstream.rsource.atEof():
case rstream.cmpop case rstream.cmpop
@ -147,7 +151,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
# There should be one step between transferring last bytes to the # There should be one step between transferring last bytes to the
# consumer and declaring stream EOF. Otherwise could not be # consumer and declaring stream EOF. Otherwise could not be
# consumed. # consumed.
await upload(addr rstream.buffer, addr buffer[0], res) await upload(rstream.buffer, addr buffer[0], res)
if (res < toRead) and rstream.rsource.atEof(): if (res < toRead) and rstream.rsource.atEof():
case rstream.cmpop case rstream.cmpop
@ -186,12 +190,16 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
break break
of AsyncStreamState.Finished: of AsyncStreamState.Finished:
# Send `EOF` state to the consumer and wait until it will be received. # Send `EOF` state to the consumer and wait until it will be received.
await rstream.buffer.transfer() try:
await rstream.buffer.transfer()
except CancelledError:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
break break
of AsyncStreamState.Closing, AsyncStreamState.Closed: of AsyncStreamState.Closing, AsyncStreamState.Closed:
break break
proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = proc boundedWriteLoop(stream: AsyncStreamWriter) {.async: (raises: []).} =
var error: ref AsyncStreamError var error: ref AsyncStreamError
var wstream = BoundedStreamWriter(stream) var wstream = BoundedStreamWriter(stream)
@ -255,7 +263,11 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
doAssert(not(isNil(error))) doAssert(not(isNil(error)))
while not(wstream.queue.empty()): while not(wstream.queue.empty()):
let item = wstream.queue.popFirstNoWait() let item =
try:
wstream.queue.popFirstNoWait()
except AsyncQueueEmptyError:
raiseAssert "AsyncQueue should not be empty at this moment"
if not(item.future.finished()): if not(item.future.finished()):
item.future.fail(error) item.future.fail(error)


@ -8,13 +8,16 @@
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
## This module implements HTTP/1.1 chunked-encoded stream reading and writing. ## This module implements HTTP/1.1 chunked-encoded stream reading and writing.
import ../asyncloop, ../timer
import asyncstream, ../transports/stream, ../transports/common {.push raises: [].}
import stew/results
import ../[asyncloop, timer, bipbuffer, config]
import asyncstream, ../transports/[stream, common]
import results
export asyncloop, asyncstream, stream, timer, common, results export asyncloop, asyncstream, stream, timer, common, results
const const
ChunkBufferSize = 4096 ChunkBufferSize = chronosStreamDefaultBufferSize
MaxChunkHeaderSize = 1024 MaxChunkHeaderSize = 1024
ChunkHeaderValueSize = 8 ChunkHeaderValueSize = 8
# This limits chunk sizes to 8 hexadecimal digits, so the maximum # This limits chunk sizes to 8 hexadecimal digits, so the maximum
@ -95,7 +98,7 @@ proc setChunkSize(buffer: var openArray[byte], length: int64): int =
buffer[c + 1] = byte(0x0A) buffer[c + 1] = byte(0x0A)
(c + 2) (c + 2)
proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = proc chunkedReadLoop(stream: AsyncStreamReader) {.async: (raises: []).} =
var rstream = ChunkedStreamReader(stream) var rstream = ChunkedStreamReader(stream)
var buffer = newSeq[byte](MaxChunkHeaderSize) var buffer = newSeq[byte](MaxChunkHeaderSize)
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
@ -115,11 +118,11 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
var chunksize = cres.get() var chunksize = cres.get()
if chunksize > 0'u64: if chunksize > 0'u64:
while chunksize > 0'u64: while chunksize > 0'u64:
let toRead = int(min(chunksize, let
uint64(rstream.buffer.bufferLen()))) (data, rsize) = rstream.buffer.backend.reserve()
await rstream.rsource.readExactly(rstream.buffer.getBuffer(), toRead = int(min(chunksize, uint64(rsize)))
toRead) await rstream.rsource.readExactly(data, toRead)
rstream.buffer.update(toRead) rstream.buffer.backend.commit(toRead)
await rstream.buffer.transfer() await rstream.buffer.transfer()
chunksize = chunksize - uint64(toRead) chunksize = chunksize - uint64(toRead)
@ -156,6 +159,10 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
if rstream.state == AsyncStreamState.Running: if rstream.state == AsyncStreamState.Running:
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
rstream.error = exc rstream.error = exc
except AsyncStreamError as exc:
if rstream.state == AsyncStreamState.Running:
rstream.state = AsyncStreamState.Error
rstream.error = exc
if rstream.state != AsyncStreamState.Running: if rstream.state != AsyncStreamState.Running:
# We need to notify consumer about error/close, but we do not care about # We need to notify consumer about error/close, but we do not care about
@ -163,7 +170,7 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.buffer.forget() rstream.buffer.forget()
break break
proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async: (raises: []).} =
var wstream = ChunkedStreamWriter(stream) var wstream = ChunkedStreamWriter(stream)
var buffer: array[16, byte] var buffer: array[16, byte]
var error: ref AsyncStreamError var error: ref AsyncStreamError
@ -220,7 +227,11 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
if not(item.future.finished()): if not(item.future.finished()):
item.future.fail(error) item.future.fail(error)
while not(wstream.queue.empty()): while not(wstream.queue.empty()):
let pitem = wstream.queue.popFirstNoWait() let pitem =
try:
wstream.queue.popFirstNoWait()
except AsyncQueueEmptyError:
raiseAssert "AsyncQueue should not be empty at this moment"
if not(pitem.future.finished()): if not(pitem.future.finished()):
pitem.future.fail(error) pitem.future.fail(error)
break break


@ -9,13 +9,19 @@
## This module implements Transport Layer Security (TLS) stream. This module ## This module implements Transport Layer Security (TLS) stream. This module
## uses sources of BearSSL <https://www.bearssl.org> by Thomas Pornin. ## uses sources of BearSSL <https://www.bearssl.org> by Thomas Pornin.
{.push raises: [].}
import import
bearssl/[brssl, ec, errors, pem, rsa, ssl, x509], bearssl/[brssl, ec, errors, pem, rsa, ssl, x509],
bearssl/certs/cacert bearssl/certs/cacert
import ../asyncloop, ../timer, ../asyncsync import ".."/[asyncloop, asyncsync, config, timer]
import asyncstream, ../transports/stream, ../transports/common import asyncstream, ../transports/[stream, common]
export asyncloop, asyncsync, timer, asyncstream export asyncloop, asyncsync, timer, asyncstream
const
TLSSessionCacheBufferSize* = chronosTLSSessionCacheBufferSize
type type
TLSStreamKind {.pure.} = enum TLSStreamKind {.pure.} = enum
Client, Server Client, Server
@ -71,7 +77,7 @@ type
scontext: ptr SslServerContext scontext: ptr SslServerContext
stream*: TLSAsyncStream stream*: TLSAsyncStream
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void].Raising([CancelledError, AsyncStreamError])
TLSStreamReader* = ref object of AsyncStreamReader TLSStreamReader* = ref object of AsyncStreamReader
case kind: TLSStreamKind case kind: TLSStreamKind
@ -81,7 +87,7 @@ type
scontext: ptr SslServerContext scontext: ptr SslServerContext
stream*: TLSAsyncStream stream*: TLSAsyncStream
handshaked*: bool handshaked*: bool
handshakeFut*: Future[void] handshakeFut*: Future[void].Raising([CancelledError, AsyncStreamError])
TLSAsyncStream* = ref object of RootRef TLSAsyncStream* = ref object of RootRef
xwc*: X509NoanchorContext xwc*: X509NoanchorContext
@ -91,18 +97,17 @@ type
x509*: X509MinimalContext x509*: X509MinimalContext
reader*: TLSStreamReader reader*: TLSStreamReader
writer*: TLSStreamWriter writer*: TLSStreamWriter
mainLoop*: Future[void] mainLoop*: Future[void].Raising([])
trustAnchors: TrustAnchorStore trustAnchors: TrustAnchorStore
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
SomeTrustAnchorType* = TrustAnchorStore | openArray[X509TrustAnchor]
TLSStreamError* = object of AsyncStreamError TLSStreamError* = object of AsyncStreamError
TLSStreamHandshakeError* = object of TLSStreamError TLSStreamHandshakeError* = object of TLSStreamError
TLSStreamInitError* = object of TLSStreamError TLSStreamInitError* = object of TLSStreamError
TLSStreamReadError* = object of TLSStreamError TLSStreamReadError* = object of TLSStreamError
par*: ref AsyncStreamError
TLSStreamWriteError* = object of TLSStreamError TLSStreamWriteError* = object of TLSStreamError
par*: ref AsyncStreamError
TLSStreamProtocolError* = object of TLSStreamError TLSStreamProtocolError* = object of TLSStreamError
errCode*: int errCode*: int
@ -110,7 +115,7 @@ proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {.
noinline.} = noinline.} =
var w = newException(TLSStreamWriteError, "Write stream failed") var w = newException(TLSStreamWriteError, "Write stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.parent = p
w w
template newTLSStreamProtocolImpl[T](message: T): ref TLSStreamProtocolError = template newTLSStreamProtocolImpl[T](message: T): ref TLSStreamProtocolError =
@ -136,38 +141,41 @@ template newTLSUnexpectedProtocolError(): ref TLSStreamProtocolError =
proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
newTLSStreamProtocolImpl(message) newTLSStreamProtocolImpl(message)
proc raiseTLSStreamProtocolError[T](message: T) {.noreturn, noinline.} = proc raiseTLSStreamProtocolError[T](message: T) {.
noreturn, noinline, raises: [TLSStreamProtocolError].} =
raise newTLSStreamProtocolImpl(message) raise newTLSStreamProtocolImpl(message)
proc new*(T: typedesc[TrustAnchorStore], anchors: openArray[X509TrustAnchor]): TrustAnchorStore = proc new*(T: typedesc[TrustAnchorStore],
anchors: openArray[X509TrustAnchor]): TrustAnchorStore =
var res: seq[X509TrustAnchor] var res: seq[X509TrustAnchor]
for anchor in anchors: for anchor in anchors:
res.add(anchor) res.add(anchor)
doAssert(unsafeAddr(anchor) != unsafeAddr(res[^1]), "Anchors should be copied") doAssert(unsafeAddr(anchor) != unsafeAddr(res[^1]),
return TrustAnchorStore(anchors: res) "Anchors should be copied")
TrustAnchorStore(anchors: res)
proc tlsWriteRec(engine: ptr SslEngineContext, proc tlsWriteRec(engine: ptr SslEngineContext,
writer: TLSStreamWriter): Future[TLSResult] {.async.} = writer: TLSStreamWriter): Future[TLSResult] {.
async: (raises: []).} =
try: try:
var length = 0'u var length = 0'u
var buf = sslEngineSendrecBuf(engine[], length) var buf = sslEngineSendrecBuf(engine[], length)
doAssert(length != 0 and not isNil(buf)) doAssert(length != 0 and not isNil(buf))
await writer.wsource.write(buf, int(length)) await writer.wsource.write(buf, int(length))
sslEngineSendrecAck(engine[], length) sslEngineSendrecAck(engine[], length)
return TLSResult.Success TLSResult.Success
except AsyncStreamError as exc: except AsyncStreamError as exc:
writer.state = AsyncStreamState.Error writer.state = AsyncStreamState.Error
writer.error = exc writer.error = exc
return TLSResult.Error TLSResult.Error
except CancelledError: except CancelledError:
if writer.state == AsyncStreamState.Running: if writer.state == AsyncStreamState.Running:
writer.state = AsyncStreamState.Stopped writer.state = AsyncStreamState.Stopped
return TLSResult.Stopped TLSResult.Stopped
return TLSResult.Error
proc tlsWriteApp(engine: ptr SslEngineContext, proc tlsWriteApp(engine: ptr SslEngineContext,
writer: TLSStreamWriter): Future[TLSResult] {.async.} = writer: TLSStreamWriter): Future[TLSResult] {.
async: (raises: []).} =
try: try:
var item = await writer.queue.get() var item = await writer.queue.get()
if item.size > 0: if item.size > 0:
@ -179,7 +187,6 @@ proc tlsWriteApp(engine: ptr SslEngineContext,
# (and discarded). # (and discarded).
writer.state = AsyncStreamState.Finished writer.state = AsyncStreamState.Finished
return TLSResult.WriteEof return TLSResult.WriteEof
let toWrite = min(int(length), item.size) let toWrite = min(int(length), item.size)
copyOut(buf, item, toWrite) copyOut(buf, item, toWrite)
if int(length) >= item.size: if int(length) >= item.size:
@ -187,28 +194,29 @@ proc tlsWriteApp(engine: ptr SslEngineContext,
sslEngineSendappAck(engine[], uint(item.size)) sslEngineSendappAck(engine[], uint(item.size))
sslEngineFlush(engine[], 0) sslEngineFlush(engine[], 0)
item.future.complete() item.future.complete()
return TLSResult.Success
else: else:
# BearSSL is not ready to accept whole item, so we will send # BearSSL is not ready to accept whole item, so we will send
# only part of item and adjust offset. # only part of item and adjust offset.
item.offset = item.offset + int(length) item.offset = item.offset + int(length)
item.size = item.size - int(length) item.size = item.size - int(length)
writer.queue.addFirstNoWait(item) try:
writer.queue.addFirstNoWait(item)
except AsyncQueueFullError:
raiseAssert "AsyncQueue should not be full at this moment"
sslEngineSendappAck(engine[], length) sslEngineSendappAck(engine[], length)
return TLSResult.Success TLSResult.Success
else: else:
sslEngineClose(engine[]) sslEngineClose(engine[])
item.future.complete() item.future.complete()
return TLSResult.Success TLSResult.Success
except CancelledError: except CancelledError:
if writer.state == AsyncStreamState.Running: if writer.state == AsyncStreamState.Running:
writer.state = AsyncStreamState.Stopped writer.state = AsyncStreamState.Stopped
return TLSResult.Stopped TLSResult.Stopped
return TLSResult.Error
proc tlsReadRec(engine: ptr SslEngineContext, proc tlsReadRec(engine: ptr SslEngineContext,
reader: TLSStreamReader): Future[TLSResult] {.async.} = reader: TLSStreamReader): Future[TLSResult] {.
async: (raises: []).} =
try: try:
var length = 0'u var length = 0'u
var buf = sslEngineRecvrecBuf(engine[], length) var buf = sslEngineRecvrecBuf(engine[], length)
@ -216,38 +224,35 @@ proc tlsReadRec(engine: ptr SslEngineContext,
sslEngineRecvrecAck(engine[], uint(res)) sslEngineRecvrecAck(engine[], uint(res))
if res == 0: if res == 0:
sslEngineClose(engine[]) sslEngineClose(engine[])
return TLSResult.ReadEof TLSResult.ReadEof
else: else:
return TLSResult.Success TLSResult.Success
except AsyncStreamError as exc: except AsyncStreamError as exc:
reader.state = AsyncStreamState.Error reader.state = AsyncStreamState.Error
reader.error = exc reader.error = exc
return TLSResult.Error TLSResult.Error
except CancelledError: except CancelledError:
if reader.state == AsyncStreamState.Running: if reader.state == AsyncStreamState.Running:
reader.state = AsyncStreamState.Stopped reader.state = AsyncStreamState.Stopped
return TLSResult.Stopped TLSResult.Stopped
return TLSResult.Error
proc tlsReadApp(engine: ptr SslEngineContext, proc tlsReadApp(engine: ptr SslEngineContext,
reader: TLSStreamReader): Future[TLSResult] {.async.} = reader: TLSStreamReader): Future[TLSResult] {.
async: (raises: []).} =
try: try:
var length = 0'u var length = 0'u
var buf = sslEngineRecvappBuf(engine[], length) var buf = sslEngineRecvappBuf(engine[], length)
await upload(addr reader.buffer, buf, int(length)) await upload(reader.buffer, buf, int(length))
sslEngineRecvappAck(engine[], length) sslEngineRecvappAck(engine[], length)
return TLSResult.Success TLSResult.Success
except CancelledError: except CancelledError:
if reader.state == AsyncStreamState.Running: if reader.state == AsyncStreamState.Running:
reader.state = AsyncStreamState.Stopped reader.state = AsyncStreamState.Stopped
return TLSResult.Stopped TLSResult.Stopped
return TLSResult.Error
template readAndReset(fut: untyped) = template readAndReset(fut: untyped) =
if fut.finished(): if fut.finished():
let res = fut.read() let res = fut.value()
case res case res
of TLSResult.Success, TLSResult.WriteEof, TLSResult.Stopped: of TLSResult.Success, TLSResult.WriteEof, TLSResult.Stopped:
fut = nil fut = nil
@ -263,22 +268,6 @@ template readAndReset(fut: untyped) =
loopState = AsyncStreamState.Finished loopState = AsyncStreamState.Finished
break break
proc cancelAndWait*(a, b, c, d: Future[TLSResult]): Future[void] =
var waiting: seq[Future[TLSResult]]
if not(isNil(a)) and not(a.finished()):
a.cancel()
waiting.add(a)
if not(isNil(b)) and not(b.finished()):
b.cancel()
waiting.add(b)
if not(isNil(c)) and not(c.finished()):
c.cancel()
waiting.add(c)
if not(isNil(d)) and not(d.finished()):
d.cancel()
waiting.add(d)
allFutures(waiting)
proc dumpState*(state: cuint): string = proc dumpState*(state: cuint): string =
var res = "" var res = ""
if (state and SSL_CLOSED) == SSL_CLOSED: if (state and SSL_CLOSED) == SSL_CLOSED:
@ -298,10 +287,10 @@ proc dumpState*(state: cuint): string =
res.add("SSL_RECVAPP") res.add("SSL_RECVAPP")
"{" & res & "}" "{" & res & "}"
proc tlsLoop*(stream: TLSAsyncStream) {.async.} = proc tlsLoop*(stream: TLSAsyncStream) {.async: (raises: []).} =
var var
sendRecFut, sendAppFut: Future[TLSResult] sendRecFut, sendAppFut: Future[TLSResult].Raising([])
recvRecFut, recvAppFut: Future[TLSResult] recvRecFut, recvAppFut: Future[TLSResult].Raising([])
let engine = let engine =
case stream.reader.kind case stream.reader.kind
@ -313,7 +302,7 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
var loopState = AsyncStreamState.Running var loopState = AsyncStreamState.Running
while true: while true:
var waiting: seq[Future[TLSResult]] var waiting: seq[Future[TLSResult].Raising([])]
var state = sslEngineCurrentState(engine[]) var state = sslEngineCurrentState(engine[])
if (state and SSL_CLOSED) == SSL_CLOSED: if (state and SSL_CLOSED) == SSL_CLOSED:
@ -364,6 +353,8 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
if len(waiting) > 0: if len(waiting) > 0:
try: try:
discard await one(waiting) discard await one(waiting)
except ValueError:
raiseAssert "array should not be empty at this moment"
except CancelledError: except CancelledError:
if loopState == AsyncStreamState.Running: if loopState == AsyncStreamState.Running:
loopState = AsyncStreamState.Stopped loopState = AsyncStreamState.Stopped
@ -371,8 +362,18 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
if loopState != AsyncStreamState.Running: if loopState != AsyncStreamState.Running:
break break
# Cancelling and waiting all the pending operations # Cancelling and waiting for all the pending operations
await cancelAndWait(sendRecFut, sendAppFut, recvRecFut, recvAppFut) var pending: seq[FutureBase]
if not(isNil(sendRecFut)) and not(sendRecFut.finished()):
pending.add(sendRecFut.cancelAndWait())
if not(isNil(sendAppFut)) and not(sendAppFut.finished()):
pending.add(sendAppFut.cancelAndWait())
if not(isNil(recvRecFut)) and not(recvRecFut.finished()):
pending.add(recvRecFut.cancelAndWait())
if not(isNil(recvAppFut)) and not(recvAppFut.finished()):
pending.add(recvAppFut.cancelAndWait())
await noCancel(allFutures(pending))
# Calculating error # Calculating error
let error = let error =
case loopState case loopState
@ -406,7 +407,11 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
if not(isNil(error)): if not(isNil(error)):
# Completing all pending writes # Completing all pending writes
while(not(stream.writer.queue.empty())): while(not(stream.writer.queue.empty())):
let item = stream.writer.queue.popFirstNoWait() let item =
try:
stream.writer.queue.popFirstNoWait()
except AsyncQueueEmptyError:
raiseAssert "AsyncQueue should not be empty at this moment"
if not(item.future.finished()): if not(item.future.finished()):
item.future.fail(error) item.future.fail(error)
# Completing handshake # Completing handshake
@ -426,18 +431,18 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} =
# Completing readers # Completing readers
stream.reader.buffer.forget() stream.reader.buffer.forget()
proc tlsWriteLoop(stream: AsyncStreamWriter) {.async.} = proc tlsWriteLoop(stream: AsyncStreamWriter) {.async: (raises: []).} =
var wstream = TLSStreamWriter(stream) var wstream = TLSStreamWriter(stream)
wstream.state = AsyncStreamState.Running wstream.state = AsyncStreamState.Running
await stepsAsync(1) await noCancel(sleepAsync(0.milliseconds))
if isNil(wstream.stream.mainLoop): if isNil(wstream.stream.mainLoop):
wstream.stream.mainLoop = tlsLoop(wstream.stream) wstream.stream.mainLoop = tlsLoop(wstream.stream)
await wstream.stream.mainLoop await wstream.stream.mainLoop
proc tlsReadLoop(stream: AsyncStreamReader) {.async.} = proc tlsReadLoop(stream: AsyncStreamReader) {.async: (raises: []).} =
var rstream = TLSStreamReader(stream) var rstream = TLSStreamReader(stream)
rstream.state = AsyncStreamState.Running rstream.state = AsyncStreamState.Running
await stepsAsync(1) await noCancel(sleepAsync(0.milliseconds))
if isNil(rstream.stream.mainLoop): if isNil(rstream.stream.mainLoop):
rstream.stream.mainLoop = tlsLoop(rstream.stream) rstream.stream.mainLoop = tlsLoop(rstream.stream)
await rstream.stream.mainLoop await rstream.stream.mainLoop
@ -453,15 +458,16 @@ proc getSignerAlgo(xc: X509Certificate): int =
else: else:
int(x509DecoderGetSignerKeyType(dc)) int(x509DecoderGetSignerKeyType(dc))
proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, proc newTLSClientAsyncStream*(
wsource: AsyncStreamWriter, rsource: AsyncStreamReader,
serverName: string, wsource: AsyncStreamWriter,
bufferSize = SSL_BUFSIZE_BIDI, serverName: string,
minVersion = TLSVersion.TLS12, bufferSize = SSL_BUFSIZE_BIDI,
maxVersion = TLSVersion.TLS12, minVersion = TLSVersion.TLS12,
flags: set[TLSFlags] = {}, maxVersion = TLSVersion.TLS12,
trustAnchors: TrustAnchorStore | openArray[X509TrustAnchor] = MozillaTrustAnchors flags: set[TLSFlags] = {},
): TLSAsyncStream = trustAnchors: SomeTrustAnchorType = MozillaTrustAnchors
): TLSAsyncStream {.raises: [TLSStreamInitError].} =
## Create new TLS asynchronous stream for outbound (client) connections ## Create new TLS asynchronous stream for outbound (client) connections
## using reading stream ``rsource`` and writing stream ``wsource``. ## using reading stream ``rsource`` and writing stream ``wsource``.
## ##
@ -484,7 +490,8 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
## a ``TrustAnchorStore`` you should reuse the same instance for ## a ``TrustAnchorStore`` you should reuse the same instance for
## every call to avoid making a copy of the trust anchors per call. ## every call to avoid making a copy of the trust anchors per call.
when trustAnchors is TrustAnchorStore: when trustAnchors is TrustAnchorStore:
doAssert(len(trustAnchors.anchors) > 0, "Empty trust anchor list is invalid") doAssert(len(trustAnchors.anchors) > 0,
"Empty trust anchor list is invalid")
else: else:
doAssert(len(trustAnchors) > 0, "Empty trust anchor list is invalid") doAssert(len(trustAnchors) > 0, "Empty trust anchor list is invalid")
var res = TLSAsyncStream() var res = TLSAsyncStream()
@ -503,8 +510,10 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
if TLSFlags.NoVerifyHost in flags: if TLSFlags.NoVerifyHost in flags:
sslClientInitFull(res.ccontext, addr res.x509, nil, 0) sslClientInitFull(res.ccontext, addr res.x509, nil, 0)
x509NoanchorInit(res.xwc, addr res.x509.vtable) x509NoanchorInit(res.xwc,
sslEngineSetX509(res.ccontext.eng, addr res.xwc.vtable) X509ClassPointerConst(addr res.x509.vtable))
sslEngineSetX509(res.ccontext.eng,
X509ClassPointerConst(addr res.xwc.vtable))
else: else:
when trustAnchors is TrustAnchorStore: when trustAnchors is TrustAnchorStore:
res.trustAnchors = trustAnchors res.trustAnchors = trustAnchors
@ -524,7 +533,7 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
uint16(maxVersion)) uint16(maxVersion))
if TLSFlags.NoVerifyServerName in flags: if TLSFlags.NoVerifyServerName in flags:
let err = sslClientReset(res.ccontext, "", 0) let err = sslClientReset(res.ccontext, nil, 0)
if err == 0: if err == 0:
raise newException(TLSStreamInitError, "Could not initialize TLS layer") raise newException(TLSStreamInitError, "Could not initialize TLS layer")
else: else:
@ -550,7 +559,8 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
minVersion = TLSVersion.TLS11, minVersion = TLSVersion.TLS11,
maxVersion = TLSVersion.TLS12, maxVersion = TLSVersion.TLS12,
cache: TLSSessionCache = nil, cache: TLSSessionCache = nil,
flags: set[TLSFlags] = {}): TLSAsyncStream = flags: set[TLSFlags] = {}): TLSAsyncStream {.
raises: [TLSStreamInitError, TLSStreamProtocolError].} =
## Create new TLS asynchronous stream for inbound (server) connections ## Create new TLS asynchronous stream for inbound (server) connections
## using reading stream ``rsource`` and writing stream ``wsource``. ## using reading stream ``rsource`` and writing stream ``wsource``.
## ##
@ -603,7 +613,8 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
uint16(maxVersion)) uint16(maxVersion))
if not isNil(cache): if not isNil(cache):
sslServerSetCache(res.scontext, addr cache.context.vtable) sslServerSetCache(
res.scontext, SslSessionCacheClassPointerConst(addr cache.context.vtable))
if TLSFlags.EnforceServerPref in flags: if TLSFlags.EnforceServerPref in flags:
sslEngineAddFlags(res.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES) sslEngineAddFlags(res.scontext.eng, OPT_ENFORCE_SERVER_PREFERENCES)
@ -618,10 +629,8 @@ proc newTLSServerAsyncStream*(rsource: AsyncStreamReader,
if err == 0: if err == 0:
raise newException(TLSStreamInitError, "Could not initialize TLS layer") raise newException(TLSStreamInitError, "Could not initialize TLS layer")
init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop, init(AsyncStreamWriter(res.writer), wsource, tlsWriteLoop, bufferSize)
bufferSize) init(AsyncStreamReader(res.reader), rsource, tlsReadLoop, bufferSize)
init(AsyncStreamReader(res.reader), rsource, tlsReadLoop,
bufferSize)
res res
proc copyKey(src: RsaPrivateKey): TLSPrivateKey = proc copyKey(src: RsaPrivateKey): TLSPrivateKey =
@ -662,7 +671,8 @@ proc copyKey(src: EcPrivateKey): TLSPrivateKey =
res.eckey.curve = src.curve res.eckey.curve = src.curve
res res
proc init*(tt: typedesc[TLSPrivateKey], data: openArray[byte]): TLSPrivateKey = proc init*(tt: typedesc[TLSPrivateKey], data: openArray[byte]): TLSPrivateKey {.
raises: [TLSStreamProtocolError].} =
## Initialize TLS private key from array of bytes ``data``. ## Initialize TLS private key from array of bytes ``data``.
## ##
## This procedure initializes private key using raw, DER-encoded format, ## This procedure initializes private key using raw, DER-encoded format,
@ -685,7 +695,8 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openArray[byte]): TLSPrivateKey =
raiseTLSStreamProtocolError("Unknown key type (" & $keyType & ")") raiseTLSStreamProtocolError("Unknown key type (" & $keyType & ")")
res res
proc pemDecode*(data: openArray[char]): seq[PEMElement] = proc pemDecode*(data: openArray[char]): seq[PEMElement] {.
raises: [TLSStreamProtocolError].} =
## Decode PEM encoded string and get array of binary blobs. ## Decode PEM encoded string and get array of binary blobs.
if len(data) == 0: if len(data) == 0:
raiseTLSStreamProtocolError("Empty PEM message") raiseTLSStreamProtocolError("Empty PEM message")
@ -726,7 +737,8 @@ proc pemDecode*(data: openArray[char]): seq[PEMElement] =
raiseTLSStreamProtocolError("Invalid PEM encoding") raiseTLSStreamProtocolError("Invalid PEM encoding")
res res
proc init*(tt: typedesc[TLSPrivateKey], data: openArray[char]): TLSPrivateKey = proc init*(tt: typedesc[TLSPrivateKey], data: openArray[char]): TLSPrivateKey {.
raises: [TLSStreamProtocolError].} =
## Initialize TLS private key from string ``data``. ## Initialize TLS private key from string ``data``.
## ##
## This procedure initializes private key using unencrypted PKCS#8 PEM ## This procedure initializes private key using unencrypted PKCS#8 PEM
@ -744,7 +756,8 @@ proc init*(tt: typedesc[TLSPrivateKey], data: openArray[char]): TLSPrivateKey =
res res
proc init*(tt: typedesc[TLSCertificate], proc init*(tt: typedesc[TLSCertificate],
data: openArray[char]): TLSCertificate = data: openArray[char]): TLSCertificate {.
raises: [TLSStreamProtocolError].} =
## Initialize TLS certificates from string ``data``. ## Initialize TLS certificates from string ``data``.
## ##
## This procedure initializes array of certificates from PEM encoded string. ## This procedure initializes array of certificates from PEM encoded string.
@ -770,18 +783,21 @@ proc init*(tt: typedesc[TLSCertificate],
raiseTLSStreamProtocolError("Could not find any certificates") raiseTLSStreamProtocolError("Could not find any certificates")
res res
proc init*(tt: typedesc[TLSSessionCache], size: int = 4096): TLSSessionCache = proc init*(tt: typedesc[TLSSessionCache],
size: int = TLSSessionCacheBufferSize): TLSSessionCache =
## Create new TLS session cache with size ``size``. ## Create new TLS session cache with size ``size``.
## ##
## One cached item is near 100 bytes size. ## One cached item is near 100 bytes size.
var rsize = min(size, 4096) let rsize = min(size, 4096)
var res = TLSSessionCache(storage: newSeq[byte](rsize)) var res = TLSSessionCache(storage: newSeq[byte](rsize))
sslSessionCacheLruInit(addr res.context, addr res.storage[0], rsize) sslSessionCacheLruInit(addr res.context, addr res.storage[0], rsize)
res res
proc handshake*(rws: SomeTLSStreamType): Future[void] = proc handshake*(rws: SomeTLSStreamType): Future[void] {.
async: (raw: true, raises: [CancelledError, AsyncStreamError]).} =
## Wait until initial TLS handshake will be successfully performed. ## Wait until initial TLS handshake will be successfully performed.
var retFuture = newFuture[void]("tlsstream.handshake") let retFuture = Future[void].Raising([CancelledError, AsyncStreamError])
.init("tlsstream.handshake")
when rws is TLSStreamReader: when rws is TLSStreamReader:
if rws.handshaked: if rws.handshaked:
retFuture.complete() retFuture.complete()

View File

@ -8,7 +8,7 @@
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
## This module implements some core async thread synchronization primitives. ## This module implements some core async thread synchronization primitives.
import stew/results import results
import "."/[timer, asyncloop] import "."/[timer, asyncloop]
export results export results
@ -272,7 +272,8 @@ proc waitSync*(signal: ThreadSignalPtr,
else: else:
return ok(true) return ok(true)
proc fire*(signal: ThreadSignalPtr): Future[void] = proc fire*(signal: ThreadSignalPtr): Future[void] {.
async: (raises: [AsyncError, CancelledError], raw: true).} =
## Set state of ``signal`` to signaled in asynchronous way. ## Set state of ``signal`` to signaled in asynchronous way.
var retFuture = newFuture[void]("asyncthreadsignal.fire") var retFuture = newFuture[void]("asyncthreadsignal.fire")
when defined(windows): when defined(windows):
@ -356,14 +357,17 @@ proc fire*(signal: ThreadSignalPtr): Future[void] =
retFuture retFuture
when defined(windows): when defined(windows):
proc wait*(signal: ThreadSignalPtr) {.async.} = proc wait*(signal: ThreadSignalPtr) {.
async: (raises: [AsyncError, CancelledError]).} =
let handle = signal[].event let handle = signal[].event
let res = await waitForSingleObject(handle, InfiniteDuration) let res = await waitForSingleObject(handle, InfiniteDuration)
# There should be no other response, because we use `InfiniteDuration`. # There should be no other response, because we use `InfiniteDuration`.
doAssert(res == WaitableResult.Ok) doAssert(res == WaitableResult.Ok)
else: else:
proc wait*(signal: ThreadSignalPtr): Future[void] = proc wait*(signal: ThreadSignalPtr): Future[void] {.
var retFuture = newFuture[void]("asyncthreadsignal.wait") async: (raises: [AsyncError, CancelledError], raw: true).} =
let retFuture = Future[void].Raising([AsyncError, CancelledError]).init(
"asyncthreadsignal.wait")
var data = 1'u64 var data = 1'u64
let eventFd = let eventFd =
when defined(linux): when defined(linux):

View File

@ -370,53 +370,42 @@ template add(a: var string, b: Base10Buf[uint64]) =
for index in 0 ..< b.len: for index in 0 ..< b.len:
a.add(char(b.data[index])) a.add(char(b.data[index]))
func `$`*(a: Duration): string {.inline.} = func toString*(a: timer.Duration, parts = int.high): string =
## Returns string representation of Duration ``a`` as nanoseconds value. ## Returns a pretty string representation of Duration ``a`` - the
var res = "" ## number of parts returned can be limited thus truncating the output to
var v = a.value ## an approximation that grows more precise as the duration becomes smaller
var
res = newStringOfCap(32)
v = a.nanoseconds()
parts = parts
template f(n: string, T: Duration) =
if parts <= 0:
return res
if v >= T.nanoseconds():
res.add(Base10.toBytes(uint64(v div T.nanoseconds())))
res.add(n)
v = v mod T.nanoseconds()
dec parts
if v == 0:
return res
f("w", Week)
f("d", Day)
f("h", Hour)
f("m", Minute)
f("s", Second)
f("ms", Millisecond)
f("us", Microsecond)
f("ns", Nanosecond)
if v >= Week.value:
res.add(Base10.toBytes(uint64(v div Week.value)))
res.add('w')
v = v mod Week.value
if v == 0: return res
if v >= Day.value:
res.add(Base10.toBytes(uint64(v div Day.value)))
res.add('d')
v = v mod Day.value
if v == 0: return res
if v >= Hour.value:
res.add(Base10.toBytes(uint64(v div Hour.value)))
res.add('h')
v = v mod Hour.value
if v == 0: return res
if v >= Minute.value:
res.add(Base10.toBytes(uint64(v div Minute.value)))
res.add('m')
v = v mod Minute.value
if v == 0: return res
if v >= Second.value:
res.add(Base10.toBytes(uint64(v div Second.value)))
res.add('s')
v = v mod Second.value
if v == 0: return res
if v >= Millisecond.value:
res.add(Base10.toBytes(uint64(v div Millisecond.value)))
res.add('m')
res.add('s')
v = v mod Millisecond.value
if v == 0: return res
if v >= Microsecond.value:
res.add(Base10.toBytes(uint64(v div Microsecond.value)))
res.add('u')
res.add('s')
v = v mod Microsecond.value
if v == 0: return res
res.add(Base10.toBytes(uint64(v div Nanosecond.value)))
res.add('n')
res.add('s')
res res
func `$`*(a: Duration): string {.inline.} =
## Returns string representation of Duration ``a``.
a.toString()
func `$`*(a: Moment): string {.inline.} = func `$`*(a: Moment): string {.inline.} =
## Returns string representation of Moment ``a`` as nanoseconds value. ## Returns string representation of Moment ``a`` as nanoseconds value.
var res = "" var res = ""

View File

@ -10,26 +10,30 @@
{.push raises: [].} {.push raises: [].}
import std/[strutils] import std/[strutils]
import results
import stew/[base10, byteutils] import stew/[base10, byteutils]
import ".."/[asyncloop, osdefs, oserrno] import ".."/[config, asyncloop, osdefs, oserrno, handles]
from std/net import Domain, `==`, IpAddress, IpAddressFamily, parseIpAddress, from std/net import Domain, `==`, IpAddress, IpAddressFamily, parseIpAddress,
SockType, Protocol, Port, `$` SockType, Protocol, Port, `$`
from std/nativesockets import toInt, `$` from std/nativesockets import toInt, `$`
export Domain, `==`, IpAddress, IpAddressFamily, parseIpAddress, SockType, export Domain, `==`, IpAddress, IpAddressFamily, parseIpAddress, SockType,
Protocol, Port, toInt, `$` Protocol, Port, toInt, `$`, results
const const
DefaultStreamBufferSize* = 4096 ## Default buffer size for stream DefaultStreamBufferSize* = chronosTransportDefaultBufferSize
## transports ## Default buffer size for stream transports
DefaultDatagramBufferSize* = 65536 ## Default buffer size for datagram DefaultDatagramBufferSize* = 65536
## transports ## Default buffer size for datagram transports
type type
ServerFlags* = enum ServerFlags* = enum
## Server's flags ## Server's flags
ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData, FirstPipe, ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData, FirstPipe,
NoPipeFlash, Broadcast NoPipeFlash, Broadcast, V4Mapped
DualStackType* {.pure.} = enum
Auto, Enabled, Disabled, Default
AddressFamily* {.pure.} = enum AddressFamily* {.pure.} = enum
None, IPv4, IPv6, Unix None, IPv4, IPv6, Unix
@ -70,12 +74,13 @@ when defined(windows) or defined(nimdoc):
udata*: pointer # User-defined pointer udata*: pointer # User-defined pointer
flags*: set[ServerFlags] # Flags flags*: set[ServerFlags] # Flags
bufferSize*: int # Size of internal transports' buffer bufferSize*: int # Size of internal transports' buffer
loopFuture*: Future[void] # Server's main Future loopFuture*: Future[void].Raising([]) # Server's main Future
domain*: Domain # Current server domain (IPv4 or IPv6) domain*: Domain # Current server domain (IPv4 or IPv6)
apending*: bool apending*: bool
asock*: AsyncFD # Current AcceptEx() socket asock*: AsyncFD # Current AcceptEx() socket
errorCode*: OSErrorCode # Current error code errorCode*: OSErrorCode # Current error code
abuffer*: array[128, byte] # Windows AcceptEx() buffer abuffer*: array[128, byte] # Windows AcceptEx() buffer
dualstack*: DualStackType # IPv4/IPv6 dualstack parameters
when defined(windows): when defined(windows):
aovl*: CustomOverlapped # AcceptEx OVERLAPPED structure aovl*: CustomOverlapped # AcceptEx OVERLAPPED structure
else: else:
@ -88,8 +93,9 @@ else:
udata*: pointer # User-defined pointer udata*: pointer # User-defined pointer
flags*: set[ServerFlags] # Flags flags*: set[ServerFlags] # Flags
bufferSize*: int # Size of internal transports' buffer bufferSize*: int # Size of internal transports' buffer
loopFuture*: Future[void] # Server's main Future loopFuture*: Future[void].Raising([]) # Server's main Future
errorCode*: OSErrorCode # Current error code errorCode*: OSErrorCode # Current error code
dualstack*: DualStackType # IPv4/IPv6 dualstack parameters
type type
TransportError* = object of AsyncError TransportError* = object of AsyncError
@ -108,6 +114,8 @@ type
## Transport's capability not supported exception ## Transport's capability not supported exception
TransportUseClosedError* = object of TransportError TransportUseClosedError* = object of TransportError
## Usage after transport close exception ## Usage after transport close exception
TransportUseEofError* = object of TransportError
## Usage after transport half-close exception
TransportTooManyError* = object of TransportError TransportTooManyError* = object of TransportError
## Too many open file descriptors exception ## Too many open file descriptors exception
TransportAbortedError* = object of TransportError TransportAbortedError* = object of TransportError
@ -193,8 +201,17 @@ proc `$`*(address: TransportAddress): string =
of AddressFamily.None: of AddressFamily.None:
"None" "None"
proc toIpAddress*(address: TransportAddress): IpAddress =
case address.family
of AddressFamily.IPv4:
IpAddress(family: IpAddressFamily.IPv4, address_v4: address.address_v4)
of AddressFamily.IPv6:
IpAddress(family: IpAddressFamily.IPv6, address_v6: address.address_v6)
else:
raiseAssert "IpAddress do not support address family " & $address.family
proc toHex*(address: TransportAddress): string = proc toHex*(address: TransportAddress): string =
## Returns hexadecimal representation of ``address`. ## Returns hexadecimal representation of ``address``.
case address.family case address.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
"0x" & address.address_v4.toHex() "0x" & address.address_v4.toHex()
@ -298,6 +315,9 @@ proc getAddrInfo(address: string, port: Port, domain: Domain,
raises: [TransportAddressError].} = raises: [TransportAddressError].} =
## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in ## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in
## ``net.nim:getAddrInfo()``, which is not cross-platform. ## ``net.nim:getAddrInfo()``, which is not cross-platform.
##
## Warning: `ptr AddrInfo` returned by `getAddrInfo()` needs to be freed by
## calling `freeAddrInfo()`.
var hints: AddrInfo var hints: AddrInfo
var res: ptr AddrInfo = nil var res: ptr AddrInfo = nil
hints.ai_family = toInt(domain) hints.ai_family = toInt(domain)
@ -420,6 +440,7 @@ proc resolveTAddress*(address: string, port: Port,
if ta notin res: if ta notin res:
res.add(ta) res.add(ta)
it = it.ai_next it = it.ai_next
freeAddrInfo(aiList)
res res
proc resolveTAddress*(address: string, domain: Domain): seq[TransportAddress] {. proc resolveTAddress*(address: string, domain: Domain): seq[TransportAddress] {.
@ -558,11 +579,11 @@ template checkClosed*(t: untyped, future: untyped) =
template checkWriteEof*(t: untyped, future: untyped) = template checkWriteEof*(t: untyped, future: untyped) =
if (WriteEof in (t).state): if (WriteEof in (t).state):
future.fail(newException(TransportError, future.fail(newException(TransportUseEofError,
"Transport connection is already dropped!")) "Transport connection is already dropped!"))
return future return future
template getError*(t: untyped): ref CatchableError = template getError*(t: untyped): ref TransportError =
var err = (t).error var err = (t).error
(t).error = nil (t).error = nil
err err
@ -585,22 +606,6 @@ proc raiseTransportOsError*(err: OSErrorCode) {.
## Raises transport specific OS error. ## Raises transport specific OS error.
raise getTransportOsError(err) raise getTransportOsError(err)
type
SeqHeader = object
length, reserved: int
proc isLiteral*(s: string): bool {.inline.} =
when defined(gcOrc) or defined(gcArc):
false
else:
(cast[ptr SeqHeader](s).reserved and (1 shl (sizeof(int) * 8 - 2))) != 0
proc isLiteral*[T](s: seq[T]): bool {.inline.} =
when defined(gcOrc) or defined(gcArc):
false
else:
(cast[ptr SeqHeader](s).reserved and (1 shl (sizeof(int) * 8 - 2))) != 0
template getTransportTooManyError*( template getTransportTooManyError*(
code = OSErrorCode(0) code = OSErrorCode(0)
): ref TransportTooManyError = ): ref TransportTooManyError =
@ -716,3 +721,97 @@ proc raiseTransportError*(ecode: OSErrorCode) {.
raise getTransportTooManyError(ecode) raise getTransportTooManyError(ecode)
else: else:
raise getTransportOsError(ecode) raise getTransportOsError(ecode)
proc isAvailable*(family: AddressFamily): bool =
case family
of AddressFamily.None:
raiseAssert "Invalid address family"
of AddressFamily.IPv4:
isAvailable(Domain.AF_INET)
of AddressFamily.IPv6:
isAvailable(Domain.AF_INET6)
of AddressFamily.Unix:
isAvailable(Domain.AF_UNIX)
proc getDomain*(socket: AsyncFD): Result[AddressFamily, OSErrorCode] =
## Returns address family which is used to create socket ``socket``.
##
## Note: `chronos` supports only `AF_INET`, `AF_INET6` and `AF_UNIX` sockets.
## For all other types of sockets this procedure returns
## `EAFNOSUPPORT/WSAEAFNOSUPPORT` error.
when defined(windows):
let protocolInfo = ? getSockOpt2(socket, cint(osdefs.SOL_SOCKET),
cint(osdefs.SO_PROTOCOL_INFOW),
WSAPROTOCOL_INFO)
if protocolInfo.iAddressFamily == toInt(Domain.AF_INET):
ok(AddressFamily.IPv4)
elif protocolInfo.iAddressFamily == toInt(Domain.AF_INET6):
ok(AddressFamily.IPv6)
else:
err(WSAEAFNOSUPPORT)
else:
var
saddr = Sockaddr_storage()
slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(socket), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
return err(osLastError())
if int(saddr.ss_family) == toInt(Domain.AF_INET):
ok(AddressFamily.IPv4)
elif int(saddr.ss_family) == toInt(Domain.AF_INET6):
ok(AddressFamily.IPv6)
elif int(saddr.ss_family) == toInt(Domain.AF_UNIX):
ok(AddressFamily.Unix)
else:
err(EAFNOSUPPORT)
proc setDualstack*(socket: AsyncFD, family: AddressFamily,
flag: DualStackType): Result[void, OSErrorCode] =
if family == AddressFamily.IPv6:
case flag
of DualStackType.Auto:
# In case of `Auto` we going to ignore all the errors.
discard setDualstack(socket, true)
ok()
of DualStackType.Enabled:
? setDualstack(socket, true)
ok()
of DualStackType.Disabled:
? setDualstack(socket, false)
ok()
of DualStackType.Default:
ok()
else:
ok()
proc setDualstack*(socket: AsyncFD,
flag: DualStackType): Result[void, OSErrorCode] =
let family =
case flag
of DualStackType.Auto:
getDomain(socket).get(AddressFamily.IPv6)
else:
? getDomain(socket)
setDualstack(socket, family, flag)
proc getAutoAddress*(port: Port): TransportAddress =
var res =
if isAvailable(AddressFamily.IPv6):
AnyAddress6
else:
AnyAddress
res.port = port
res
proc getAutoAddresses*(
localPort: Port,
remotePort: Port
): tuple[local: TransportAddress, remote: TransportAddress] =
var (local, remote) =
if isAvailable(AddressFamily.IPv6):
(AnyAddress6, AnyAddress6)
else:
(AnyAddress, AnyAddress)
local.port = localPort
remote.port = remotePort
(local, remote)

File diff suppressed because it is too large

View File

@ -52,7 +52,7 @@ proc init*(t: typedesc[IpMask], family: AddressFamily, prefix: int): IpMask =
IpMask(family: AddressFamily.IPv4, mask4: 0'u32) IpMask(family: AddressFamily.IPv4, mask4: 0'u32)
elif prefix < 32: elif prefix < 32:
let mask = 0xFFFF_FFFF'u32 shl (32 - prefix) let mask = 0xFFFF_FFFF'u32 shl (32 - prefix)
IpMask(family: AddressFamily.IPv4, mask4: mask.toBE()) IpMask(family: AddressFamily.IPv4, mask4: mask)
else: else:
IpMask(family: AddressFamily.IPv4, mask4: 0xFFFF_FFFF'u32) IpMask(family: AddressFamily.IPv4, mask4: 0xFFFF_FFFF'u32)
of AddressFamily.IPv6: of AddressFamily.IPv6:
@ -65,13 +65,13 @@ proc init*(t: typedesc[IpMask], family: AddressFamily, prefix: int): IpMask =
if prefix > 64: if prefix > 64:
let mask = 0xFFFF_FFFF_FFFF_FFFF'u64 shl (128 - prefix) let mask = 0xFFFF_FFFF_FFFF_FFFF'u64 shl (128 - prefix)
IpMask(family: AddressFamily.IPv6, IpMask(family: AddressFamily.IPv6,
mask6: [0xFFFF_FFFF_FFFF_FFFF'u64, mask.toBE()]) mask6: [0xFFFF_FFFF_FFFF_FFFF'u64, mask])
elif prefix == 64: elif prefix == 64:
IpMask(family: AddressFamily.IPv6, IpMask(family: AddressFamily.IPv6,
mask6: [0xFFFF_FFFF_FFFF_FFFF'u64, 0'u64]) mask6: [0xFFFF_FFFF_FFFF_FFFF'u64, 0'u64])
else: else:
let mask = 0xFFFF_FFFF_FFFF_FFFF'u64 shl (64 - prefix) let mask = 0xFFFF_FFFF_FFFF_FFFF'u64 shl (64 - prefix)
IpMask(family: AddressFamily.IPv6, mask6: [mask.toBE(), 0'u64]) IpMask(family: AddressFamily.IPv6, mask6: [mask, 0'u64])
else: else:
IpMask(family: family) IpMask(family: family)
@ -80,11 +80,12 @@ proc init*(t: typedesc[IpMask], netmask: TransportAddress): IpMask =
case netmask.family case netmask.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
IpMask(family: AddressFamily.IPv4, IpMask(family: AddressFamily.IPv4,
mask4: uint32.fromBytes(netmask.address_v4)) mask4: uint32.fromBytesBE(netmask.address_v4))
of AddressFamily.IPv6: of AddressFamily.IPv6:
IpMask(family: AddressFamily.IPv6, IpMask(family: AddressFamily.IPv6,
mask6: [uint64.fromBytes(netmask.address_v6.toOpenArray(0, 7)), mask6: [
uint64.fromBytes(netmask.address_v6.toOpenArray(8, 15))]) uint64.fromBytesBE(netmask.address_v6.toOpenArray(0, 7)),
uint64.fromBytesBE(netmask.address_v6.toOpenArray(8, 15))])
else: else:
IpMask(family: netmask.family) IpMask(family: netmask.family)
@ -95,8 +96,7 @@ proc initIp*(t: typedesc[IpMask], netmask: string): IpMask =
## If ``netmask`` address string is invalid, result IpMask.family will be ## If ``netmask`` address string is invalid, result IpMask.family will be
## set to ``AddressFamily.None``. ## set to ``AddressFamily.None``.
try: try:
var ip = parseIpAddress(netmask) let tip = initTAddress(parseIpAddress(netmask), Port(0))
var tip = initTAddress(ip, Port(0))
t.init(tip) t.init(tip)
except ValueError: except ValueError:
IpMask(family: AddressFamily.None) IpMask(family: AddressFamily.None)
@ -127,9 +127,9 @@ proc init*(t: typedesc[IpMask], netmask: string): IpMask =
elif netmask[offset + i] in hexLowers: elif netmask[offset + i] in hexLowers:
v = uint32(ord(netmask[offset + i]) - ord('a') + 10) v = uint32(ord(netmask[offset + i]) - ord('a') + 10)
else: else:
return return IpMask(family: AddressFamily.None)
r = (r shl 4) or v r = (r shl 4) or v
res.mask4 = r.toBE() res.mask4 = r
res res
elif length == 32 or length == (2 + 32): elif length == 32 or length == (2 + 32):
## IPv6 mask ## IPv6 mask
@ -147,10 +147,10 @@ proc init*(t: typedesc[IpMask], netmask: string): IpMask =
elif netmask[offset + i] in hexLowers: elif netmask[offset + i] in hexLowers:
v = uint64(ord(netmask[offset + i]) - ord('a') + 10) v = uint64(ord(netmask[offset + i]) - ord('a') + 10)
else: else:
return return IpMask(family: AddressFamily.None)
r = (r shl 4) or v r = (r shl 4) or v
offset += 16 offset += 16
res.mask6[i] = r.toBE() res.mask6[i] = r
res res
else: else:
IpMask(family: AddressFamily.None) IpMask(family: AddressFamily.None)
@ -167,8 +167,7 @@ proc toIPv6*(address: TransportAddress): TransportAddress =
var address6: array[16, uint8] var address6: array[16, uint8]
address6[10] = 0xFF'u8 address6[10] = 0xFF'u8
address6[11] = 0xFF'u8 address6[11] = 0xFF'u8
let ip4 = uint32.fromBytes(address.address_v4) address6[12 .. 15] = toBytesBE(uint32.fromBytesBE(address.address_v4))
address6[12 .. 15] = ip4.toBytes()
TransportAddress(family: AddressFamily.IPv6, port: address.port, TransportAddress(family: AddressFamily.IPv6, port: address.port,
address_v6: address6) address_v6: address6)
of AddressFamily.IPv6: of AddressFamily.IPv6:
@ -183,9 +182,10 @@ proc isV4Mapped*(address: TransportAddress): bool =
## Procedure returns ``false`` if ``address`` family is IPv4. ## Procedure returns ``false`` if ``address`` family is IPv4.
case address.family case address.family
of AddressFamily.IPv6: of AddressFamily.IPv6:
let data0 = uint64.fromBytes(address.address_v6.toOpenArray(0, 7)) let
let data1 = uint16.fromBytes(address.address_v6.toOpenArray(8, 9)) data0 = uint64.fromBytesBE(address.address_v6.toOpenArray(0, 7))
let data2 = uint16.fromBytes(address.address_v6.toOpenArray(10, 11)) data1 = uint16.fromBytesBE(address.address_v6.toOpenArray(8, 9))
data2 = uint16.fromBytesBE(address.address_v6.toOpenArray(10, 11))
(data0 == 0x00'u64) and (data1 == 0x00'u16) and (data2 == 0xFFFF'u16) (data0 == 0x00'u64) and (data1 == 0x00'u16) and (data2 == 0xFFFF'u16)
else: else:
false false
@ -202,9 +202,9 @@ proc toIPv4*(address: TransportAddress): TransportAddress =
address address
of AddressFamily.IPv6: of AddressFamily.IPv6:
if isV4Mapped(address): if isV4Mapped(address):
let data = uint32.fromBytes(address.address_v6.toOpenArray(12, 15)) let data = uint32.fromBytesBE(address.address_v6.toOpenArray(12, 15))
TransportAddress(family: AddressFamily.IPv4, port: address.port, TransportAddress(family: AddressFamily.IPv4, port: address.port,
address_v4: data.toBytes()) address_v4: data.toBytesBE())
else: else:
TransportAddress(family: AddressFamily.None) TransportAddress(family: AddressFamily.None)
else: else:
@ -230,34 +230,34 @@ proc mask*(a: TransportAddress, m: IpMask): TransportAddress =
## In all other cases returned address will have ``AddressFamily.None``. ## In all other cases returned address will have ``AddressFamily.None``.
if (a.family == AddressFamily.IPv4) and (m.family == AddressFamily.IPv6): if (a.family == AddressFamily.IPv4) and (m.family == AddressFamily.IPv6):
if (m.mask6[0] == 0xFFFF_FFFF_FFFF_FFFF'u64) and if (m.mask6[0] == 0xFFFF_FFFF_FFFF_FFFF'u64) and
(m.mask6[1] and 0xFFFF_FFFF'u64) == 0xFFFF_FFFF'u64: (m.mask6[1] and 0xFFFF_FFFF_0000_0000'u64) == 0xFFFF_FFFF_0000_0000'u64:
let let
mask = uint32((m.mask6[1] shr 32) and 0xFFFF_FFFF'u64) mask = uint32(m.mask6[1] and 0xFFFF_FFFF'u64)
data = uint32.fromBytes(a.address_v4) data = uint32.fromBytesBE(a.address_v4)
TransportAddress(family: AddressFamily.IPv4, port: a.port, TransportAddress(family: AddressFamily.IPv4, port: a.port,
address_v4: (data and mask).toBytes()) address_v4: (data and mask).toBytesBE())
else: else:
TransportAddress(family: AddressFamily.None) TransportAddress(family: AddressFamily.None)
elif (a.family == AddressFamily.IPv6) and (m.family == AddressFamily.IPv4): elif (a.family == AddressFamily.IPv6) and (m.family == AddressFamily.IPv4):
var ip = a.toIPv4() var ip = a.toIPv4()
if ip.family != AddressFamily.IPv4: if ip.family != AddressFamily.IPv4:
return TransportAddress(family: AddressFamily.None) return TransportAddress(family: AddressFamily.None)
let data = uint32.fromBytes(ip.address_v4) let data = uint32.fromBytesBE(ip.address_v4)
ip.address_v4[0 .. 3] = (data and m.mask4).toBytes() ip.address_v4[0 .. 3] = (data and m.mask4).toBytesBE()
var res = ip.toIPv6() var res = ip.toIPv6()
res.port = a.port res.port = a.port
res res
elif a.family == AddressFamily.IPv4 and m.family == AddressFamily.IPv4: elif a.family == AddressFamily.IPv4 and m.family == AddressFamily.IPv4:
let data = uint32.fromBytes(a.address_v4) let data = uint32.fromBytesBE(a.address_v4)
TransportAddress(family: AddressFamily.IPv4, port: a.port, TransportAddress(family: AddressFamily.IPv4, port: a.port,
address_v4: (data and m.mask4).toBytes()) address_v4: (data and m.mask4).toBytesBE())
elif a.family == AddressFamily.IPv6 and m.family == AddressFamily.IPv6: elif a.family == AddressFamily.IPv6 and m.family == AddressFamily.IPv6:
var address6: array[16, uint8] var address6: array[16, uint8]
let let
data0 = uint64.fromBytes(a.address_v6.toOpenArray(0, 7)) data0 = uint64.fromBytesBE(a.address_v6.toOpenArray(0, 7))
data1 = uint64.fromBytes(a.address_v6.toOpenArray(8, 15)) data1 = uint64.fromBytesBE(a.address_v6.toOpenArray(8, 15))
address6[0 .. 7] = (data0 and m.mask6[0]).toBytes() address6[0 .. 7] = (data0 and m.mask6[0]).toBytesBE()
address6[8 .. 15] = (data1 and m.mask6[1]).toBytes() address6[8 .. 15] = (data1 and m.mask6[1]).toBytesBE()
TransportAddress(family: AddressFamily.IPv6, port: a.port, TransportAddress(family: AddressFamily.IPv6, port: a.port,
address_v6: address6) address_v6: address6)
else: else:
@ -272,14 +272,14 @@ proc prefix*(mask: IpMask): int =
of AddressFamily.IPv4: of AddressFamily.IPv4:
var var
res = 0 res = 0
n = mask.mask4.fromBE() n = mask.mask4
while n != 0: while n != 0:
if (n and 0x8000_0000'u32) == 0'u32: return -1 if (n and 0x8000_0000'u32) == 0'u32: return -1
n = n shl 1 n = n shl 1
inc(res) inc(res)
res res
of AddressFamily.IPv6: of AddressFamily.IPv6:
let mask6 = [mask.mask6[0].fromBE(), mask.mask6[1].fromBE()] let mask6 = [mask.mask6[0], mask.mask6[1]]
var res = 0 var res = 0
if mask6[0] == 0xFFFF_FFFF_FFFF_FFFF'u64: if mask6[0] == 0xFFFF_FFFF_FFFF_FFFF'u64:
res += 64 res += 64
@ -308,11 +308,11 @@ proc subnetMask*(mask: IpMask): TransportAddress =
case mask.family case mask.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
TransportAddress(family: AddressFamily.IPv4, TransportAddress(family: AddressFamily.IPv4,
address_v4: mask.mask4.toBytes()) address_v4: mask.mask4.toBytesBE())
of AddressFamily.IPv6: of AddressFamily.IPv6:
var address6: array[16, uint8] var address6: array[16, uint8]
address6[0 .. 7] = mask.mask6[0].toBytes() address6[0 .. 7] = mask.mask6[0].toBytesBE()
address6[8 .. 15] = mask.mask6[1].toBytes() address6[8 .. 15] = mask.mask6[1].toBytesBE()
TransportAddress(family: AddressFamily.IPv6, address_v6: address6) TransportAddress(family: AddressFamily.IPv6, address_v6: address6)
else: else:
TransportAddress(family: mask.family) TransportAddress(family: mask.family)
@ -321,9 +321,10 @@ proc `$`*(mask: IpMask, include0x = false): string =
## Returns hexadecimal string representation of IP mask ``mask``. ## Returns hexadecimal string representation of IP mask ``mask``.
case mask.family case mask.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
var res = if include0x: "0x" else: "" var
var n = 32 res = if include0x: "0x" else: ""
var m = mask.mask4.fromBE() n = 32
m = mask.mask4
while n > 0: while n > 0:
n -= 4 n -= 4
var c = int((m shr n) and 0x0F) var c = int((m shr n) and 0x0F)
@ -333,7 +334,7 @@ proc `$`*(mask: IpMask, include0x = false): string =
res.add(chr(ord('A') + (c - 10))) res.add(chr(ord('A') + (c - 10)))
res res
of AddressFamily.IPv6: of AddressFamily.IPv6:
let mask6 = [mask.mask6[0].fromBE(), mask.mask6[1].fromBE()] let mask6 = [mask.mask6[0], mask.mask6[1]]
var res = if include0x: "0x" else: "" var res = if include0x: "0x" else: ""
for i in 0 .. 1: for i in 0 .. 1:
var n = 64 var n = 64
@ -353,12 +354,11 @@ proc ip*(mask: IpMask): string {.raises: [ValueError].} =
## Returns IP address text representation of IP mask ``mask``. ## Returns IP address text representation of IP mask ``mask``.
case mask.family case mask.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
var address4: array[4, uint8] $IpAddress(family: IpAddressFamily.IPv4, address_v4: mask.mask4.toBytesBE())
copyMem(addr address4[0], unsafeAddr mask.mask4, sizeof(uint32))
$IpAddress(family: IpAddressFamily.IPv4, address_v4: address4)
of AddressFamily.Ipv6: of AddressFamily.Ipv6:
var address6: array[16, uint8] var address6: array[16, uint8]
copyMem(addr address6[0], unsafeAddr mask.mask6[0], 16) address6[0 .. 7] = mask.mask6[0].toBytesBE()
address6[8 .. 15] = mask.mask6[1].toBytesBE()
$IpAddress(family: IpAddressFamily.IPv6, address_v6: address6) $IpAddress(family: IpAddressFamily.IPv6, address_v6: address6)
else: else:
raise newException(ValueError, "Invalid mask family type") raise newException(ValueError, "Invalid mask family type")
@ -387,11 +387,12 @@ proc init*(t: typedesc[IpNet], network: string): IpNet {.
raises: [TransportAddressError].} = raises: [TransportAddressError].} =
## Initialize IP Network from string representation in format ## Initialize IP Network from string representation in format
## <address>/<prefix length> or <address>/<netmask address>. ## <address>/<prefix length> or <address>/<netmask address>.
var parts = network.rsplit("/", maxsplit = 1) var
var host, mhost: TransportAddress parts = network.rsplit("/", maxsplit = 1)
var ipaddr: IpAddress host, mhost: TransportAddress
var mask: IpMask ipaddr: IpAddress
var prefix: int mask: IpMask
prefix: int
try: try:
ipaddr = parseIpAddress(parts[0]) ipaddr = parseIpAddress(parts[0])
if ipaddr.family == IpAddressFamily.IPv4: if ipaddr.family == IpAddressFamily.IPv4:
@ -428,9 +429,9 @@ proc init*(t: typedesc[IpNet], network: string): IpNet {.
raise newException(TransportAddressError, raise newException(TransportAddressError,
"Incorrect network address!") "Incorrect network address!")
if prefix == -1: if prefix == -1:
result = t.init(host, mask) t.init(host, mask)
else: else:
result = t.init(host, prefix) t.init(host, prefix)
except ValueError as exc: except ValueError as exc:
raise newException(TransportAddressError, exc.msg) raise newException(TransportAddressError, exc.msg)
@ -461,19 +462,19 @@ proc broadcast*(net: IpNet): TransportAddress =
case net.host.family case net.host.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
let let
host = uint32.fromBytes(net.host.address_v4) host = uint32.fromBytesBE(net.host.address_v4)
mask = net.mask.mask4 mask = net.mask.mask4
TransportAddress(family: AddressFamily.IPv4, TransportAddress(family: AddressFamily.IPv4,
address_v4: (host or (not(mask))).toBytes()) address_v4: (host or (not(mask))).toBytesBE())
of AddressFamily.IPv6: of AddressFamily.IPv6:
var address6: array[16, uint8] var address6: array[16, uint8]
let let
host0 = uint64.fromBytes(net.host.address_v6.toOpenArray(0, 7)) host0 = uint64.fromBytesBE(net.host.address_v6.toOpenArray(0, 7))
host1 = uint64.fromBytes(net.host.address_v6.toOpenArray(8, 15)) host1 = uint64.fromBytesBE(net.host.address_v6.toOpenArray(8, 15))
data0 = net.mask.mask6[0] data0 = net.mask.mask6[0]
data1 = net.mask.mask6[1] data1 = net.mask.mask6[1]
address6[0 .. 7] = (host0 or (not(data0))).toBytes() address6[0 .. 7] = (host0 or (not(data0))).toBytesBE()
address6[8 .. 15] = (host1 or (not(data1))).toBytes() address6[8 .. 15] = (host1 or (not(data1))).toBytesBE()
TransportAddress(family: AddressFamily.IPv6, address_v6: address6) TransportAddress(family: AddressFamily.IPv6, address_v6: address6)
else: else:
TransportAddress(family: AddressFamily.None) TransportAddress(family: AddressFamily.None)
@ -496,19 +497,19 @@ proc `and`*(address1, address2: TransportAddress): TransportAddress =
case address1.family case address1.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
let let
data1 = uint32.fromBytes(address1.address_v4) data1 = uint32.fromBytesBE(address1.address_v4)
data2 = uint32.fromBytes(address2.address_v4) data2 = uint32.fromBytesBE(address2.address_v4)
TransportAddress(family: AddressFamily.IPv4, TransportAddress(family: AddressFamily.IPv4,
address_v4: (data1 and data2).toBytes()) address_v4: (data1 and data2).toBytesBE())
of AddressFamily.IPv6: of AddressFamily.IPv6:
var address6: array[16, uint8] var address6: array[16, uint8]
let let
data1 = uint64.fromBytes(address1.address_v6.toOpenArray(0, 7)) data1 = uint64.fromBytesBE(address1.address_v6.toOpenArray(0, 7))
data2 = uint64.fromBytes(address1.address_v6.toOpenArray(8, 15)) data2 = uint64.fromBytesBE(address1.address_v6.toOpenArray(8, 15))
data3 = uint64.fromBytes(address2.address_v6.toOpenArray(0, 7)) data3 = uint64.fromBytesBE(address2.address_v6.toOpenArray(0, 7))
data4 = uint64.fromBytes(address2.address_v6.toOpenArray(8, 15)) data4 = uint64.fromBytesBE(address2.address_v6.toOpenArray(8, 15))
address6[0 .. 7] = (data1 and data3).toBytes() address6[0 .. 7] = (data1 and data3).toBytesBE()
address6[8 .. 15] = (data2 and data4).toBytes() address6[8 .. 15] = (data2 and data4).toBytesBE()
TransportAddress(family: AddressFamily.IPv6, address_v6: address6) TransportAddress(family: AddressFamily.IPv6, address_v6: address6)
else: else:
raiseAssert "Invalid address family type" raiseAssert "Invalid address family type"
@ -522,19 +523,19 @@ proc `or`*(address1, address2: TransportAddress): TransportAddress =
case address1.family case address1.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
let let
data1 = uint32.fromBytes(address1.address_v4) data1 = uint32.fromBytesBE(address1.address_v4)
data2 = uint32.fromBytes(address2.address_v4) data2 = uint32.fromBytesBE(address2.address_v4)
TransportAddress(family: AddressFamily.IPv4, TransportAddress(family: AddressFamily.IPv4,
address_v4: (data1 or data2).toBytes()) address_v4: (data1 or data2).toBytesBE())
of AddressFamily.IPv6: of AddressFamily.IPv6:
var address6: array[16, uint8] var address6: array[16, uint8]
let let
data1 = uint64.fromBytes(address1.address_v6.toOpenArray(0, 7)) data1 = uint64.fromBytesBE(address1.address_v6.toOpenArray(0, 7))
data2 = uint64.fromBytes(address1.address_v6.toOpenArray(8, 15)) data2 = uint64.fromBytesBE(address1.address_v6.toOpenArray(8, 15))
data3 = uint64.fromBytes(address2.address_v6.toOpenArray(0, 7)) data3 = uint64.fromBytesBE(address2.address_v6.toOpenArray(0, 7))
data4 = uint64.fromBytes(address2.address_v6.toOpenArray(8, 15)) data4 = uint64.fromBytesBE(address2.address_v6.toOpenArray(8, 15))
address6[0 .. 7] = (data1 or data3).toBytes() address6[0 .. 7] = (data1 or data3).toBytesBE()
address6[8 .. 15] = (data2 or data4).toBytes() address6[8 .. 15] = (data2 or data4).toBytesBE()
TransportAddress(family: AddressFamily.IPv6, address_v6: address6) TransportAddress(family: AddressFamily.IPv6, address_v6: address6)
else: else:
raiseAssert "Invalid address family type" raiseAssert "Invalid address family type"
@ -543,15 +544,15 @@ proc `not`*(address: TransportAddress): TransportAddress =
## Bitwise ``not`` operation for ``address``. ## Bitwise ``not`` operation for ``address``.
case address.family case address.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
let data = not(uint32.fromBytes(address.address_v4)) let data = not(uint32.fromBytesBE(address.address_v4))
TransportAddress(family: AddressFamily.IPv4, address_v4: data.toBytes()) TransportAddress(family: AddressFamily.IPv4, address_v4: data.toBytesBE())
of AddressFamily.IPv6: of AddressFamily.IPv6:
var address6: array[16, uint8] var address6: array[16, uint8]
let let
data1 = not(uint64.fromBytes(address.address_v6.toOpenArray(0, 7))) data1 = not(uint64.fromBytesBE(address.address_v6.toOpenArray(0, 7)))
data2 = not(uint64.fromBytes(address.address_v6.toOpenArray(8, 15))) data2 = not(uint64.fromBytesBE(address.address_v6.toOpenArray(8, 15)))
address6[0 .. 7] = data1.toBytes() address6[0 .. 7] = data1.toBytesBE()
address6[8 .. 15] = data2.toBytes() address6[8 .. 15] = data2.toBytesBE()
TransportAddress(family: AddressFamily.IPv6, address_v6: address6) TransportAddress(family: AddressFamily.IPv6, address_v6: address6)
else: else:
address address
@ -702,10 +703,10 @@ proc isZero*(address: TransportAddress): bool {.inline.} =
## not ``AddressFamily.None``. ## not ``AddressFamily.None``.
case address.family case address.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
uint32.fromBytes(a4()) == 0'u32 uint32.fromBytesBE(a4()) == 0'u32
of AddressFamily.IPv6: of AddressFamily.IPv6:
(uint64.fromBytes(a6.toOpenArray(0, 7)) == 0'u64) and (uint64.fromBytesBE(a6.toOpenArray(0, 7)) == 0'u64) and
(uint64.fromBytes(a6.toOpenArray(8, 15)) == 0'u64) (uint64.fromBytesBE(a6.toOpenArray(8, 15)) == 0'u64)
of AddressFamily.Unix: of AddressFamily.Unix:
len($cast[cstring](unsafeAddr address.address_un[0])) == 0 len($cast[cstring](unsafeAddr address.address_un[0])) == 0
else: else:
@ -804,7 +805,7 @@ proc isLoopback*(address: TransportAddress): bool =
of AddressFamily.IPv4: of AddressFamily.IPv4:
a4[0] == 127'u8 a4[0] == 127'u8
of AddressFamily.IPv6: of AddressFamily.IPv6:
(uint64.fromBytes(a6.toOpenArray(0, 7)) == 0x00'u64) and (uint64.fromBytesBE(a6.toOpenArray(0, 7)) == 0x00'u64) and
(uint64.fromBytesBE(a6.toOpenArray(8, 15)) == 0x01'u64) (uint64.fromBytesBE(a6.toOpenArray(8, 15)) == 0x01'u64)
else: else:
false false
@ -817,10 +818,10 @@ proc isAnyLocal*(address: TransportAddress): bool =
## ``IPv6``: :: ## ``IPv6``: ::
case address.family case address.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
uint32.fromBytes(a4) == 0'u32 uint32.fromBytesBE(a4) == 0'u32
of AddressFamily.IPv6: of AddressFamily.IPv6:
(uint64.fromBytes(a6.toOpenArray(0, 7)) == 0x00'u64) and (uint64.fromBytesBE(a6.toOpenArray(0, 7)) == 0x00'u64) and
(uint64.fromBytes(a6.toOpenArray(8, 15)) == 0x00'u64) (uint64.fromBytesBE(a6.toOpenArray(8, 15)) == 0x00'u64)
else: else:
false false
@ -895,7 +896,7 @@ proc isBroadcast*(address: TransportAddress): bool =
## ``IPv4``: 255.255.255.255 ## ``IPv4``: 255.255.255.255
case address.family case address.family
of AddressFamily.IPv4: of AddressFamily.IPv4:
uint32.fromBytes(a4) == 0xFFFF_FFFF'u32 uint32.fromBytesBE(a4) == 0xFFFF_FFFF'u32
of AddressFamily.IPv6: of AddressFamily.IPv6:
false false
else: else:
@ -916,7 +917,7 @@ proc isBenchmarking*(address: TransportAddress): bool =
of AddressFamily.IPv6: of AddressFamily.IPv6:
(uint16.fromBytesBE(a6.toOpenArray(0, 1)) == 0x2001'u16) and (uint16.fromBytesBE(a6.toOpenArray(0, 1)) == 0x2001'u16) and
(uint16.fromBytesBE(a6.toOpenArray(2, 3)) == 0x02'u16) and (uint16.fromBytesBE(a6.toOpenArray(2, 3)) == 0x02'u16) and
(uint16.fromBytes(a6.toOpenArray(4, 5)) == 0x00'u16) (uint16.fromBytesBE(a6.toOpenArray(4, 5)) == 0x00'u16)
else: else:
false false
@ -980,9 +981,9 @@ proc isGlobal*(address: TransportAddress): bool =
address.isLoopback() or address.isLoopback() or
( (
# IPv4-Mapped `::FFFF:0:0/96` # IPv4-Mapped `::FFFF:0:0/96`
(uint64.fromBytes(a6.toOpenArray(0, 7)) == 0x00'u64) and (uint64.fromBytesBE(a6.toOpenArray(0, 7)) == 0x00'u64) and
(uint16.fromBytes(a6.toOpenArray(8, 9)) == 0x00'u16) and (uint16.fromBytesBE(a6.toOpenArray(8, 9)) == 0x00'u16) and
(uint16.fromBytes(a6.toOpenArray(10, 11)) == 0xFFFF'u16) (uint16.fromBytesBE(a6.toOpenArray(10, 11)) == 0xFFFF'u16)
) or ) or
( (
# IPv4-IPv6 Translation `64:FF9B:1::/48` # IPv4-IPv6 Translation `64:FF9B:1::/48`
@ -993,8 +994,8 @@ proc isGlobal*(address: TransportAddress): bool =
( (
# Discard-Only Address Block `100::/64` # Discard-Only Address Block `100::/64`
(uint16.fromBytesBE(a6.toOpenArray(0, 1)) == 0x100'u16) and (uint16.fromBytesBE(a6.toOpenArray(0, 1)) == 0x100'u16) and
(uint32.fromBytes(a6.toOpenArray(2, 5)) == 0x00'u32) and (uint32.fromBytesBE(a6.toOpenArray(2, 5)) == 0x00'u32) and
(uint16.fromBytes(a6.toOpenArray(6, 7)) == 0x00'u16) (uint16.fromBytesBE(a6.toOpenArray(6, 7)) == 0x00'u16)
) or ) or
( (
# IETF Protocol Assignments `2001::/23` # IETF Protocol Assignments `2001::/23`
@ -1004,15 +1005,15 @@ proc isGlobal*(address: TransportAddress): bool =
( (
# Port Control Protocol Anycast `2001:1::1` # Port Control Protocol Anycast `2001:1::1`
(uint32.fromBytesBE(a6.toOpenArray(0, 3)) == 0x20010001'u32) and (uint32.fromBytesBE(a6.toOpenArray(0, 3)) == 0x20010001'u32) and
(uint32.fromBytes(a6.toOpenArray(4, 7)) == 0x00'u32) and (uint32.fromBytesBE(a6.toOpenArray(4, 7)) == 0x00'u32) and
(uint32.fromBytes(a6.toOpenArray(8, 11)) == 0x00'u32) and (uint32.fromBytesBE(a6.toOpenArray(8, 11)) == 0x00'u32) and
(uint32.fromBytesBE(a6.toOpenArray(12, 15)) == 0x01'u32) (uint32.fromBytesBE(a6.toOpenArray(12, 15)) == 0x01'u32)
) or ) or
( (
# Traversal Using Relays around NAT Anycast `2001:1::2` # Traversal Using Relays around NAT Anycast `2001:1::2`
(uint32.fromBytesBE(a6.toOpenArray(0, 3)) == 0x20010001'u32) and (uint32.fromBytesBE(a6.toOpenArray(0, 3)) == 0x20010001'u32) and
(uint32.fromBytes(a6.toOpenArray(4, 7)) == 0x00'u32) and (uint32.fromBytesBE(a6.toOpenArray(4, 7)) == 0x00'u32) and
(uint32.fromBytes(a6.toOpenArray(8, 11)) == 0x00'u32) and (uint32.fromBytesBE(a6.toOpenArray(8, 11)) == 0x00'u32) and
(uint32.fromBytesBE(a6.toOpenArray(12, 15)) == 0x02'u32) (uint32.fromBytesBE(a6.toOpenArray(12, 15)) == 0x02'u32)
) or ) or
( (
@ -1025,7 +1026,7 @@ proc isGlobal*(address: TransportAddress): bool =
(uint16.fromBytesBE(a6.toOpenArray(0, 1)) == 0x2001'u16) and (uint16.fromBytesBE(a6.toOpenArray(0, 1)) == 0x2001'u16) and
(uint16.fromBytesBE(a6.toOpenArray(2, 3)) == 0x04'u16) and (uint16.fromBytesBE(a6.toOpenArray(2, 3)) == 0x04'u16) and
(uint16.fromBytesBE(a6.toOpenArray(4, 5)) == 0x112'u16) and (uint16.fromBytesBE(a6.toOpenArray(4, 5)) == 0x112'u16) and
(uint16.fromBytes(a6.toOpenArray(6, 7)) == 0x00'u16) (uint16.fromBytesBE(a6.toOpenArray(6, 7)) == 0x00'u16)
) or ) or
( (
# ORCHIDv2 `2001:20::/28` # ORCHIDv2 `2001:20::/28`

View File

@ -677,10 +677,10 @@ when defined(linux):
var msg = cast[ptr NlMsgHeader](addr data[0]) var msg = cast[ptr NlMsgHeader](addr data[0])
var endflag = false var endflag = false
while NLMSG_OK(msg, length): while NLMSG_OK(msg, length):
if msg.nlmsg_type == NLMSG_ERROR: if msg.nlmsg_type in [uint16(NLMSG_DONE), uint16(NLMSG_ERROR)]:
endflag = true endflag = true
break break
else: elif msg.nlmsg_type == RTM_NEWROUTE:
res = processRoute(msg) res = processRoute(msg)
endflag = true endflag = true
break break

File diff suppressed because it is too large

View File

@ -21,11 +21,12 @@ template asyncTest*(name: string, body: untyped): untyped =
template checkLeaks*(name: string): untyped = template checkLeaks*(name: string): untyped =
let counter = getTrackerCounter(name) let counter = getTrackerCounter(name)
if counter.opened != counter.closed: checkpoint:
echo "[" & name & "] opened = ", counter.opened, "[" & name & "] opened = " & $counter.opened &
", closed = ", counter.closed ", closed = " & $ counter.closed
check counter.opened == counter.closed check counter.opened == counter.closed
template checkLeaks*(): untyped = proc checkLeaks*() =
for key in getThreadDispatcher().trackerCounterKeys(): for key in getThreadDispatcher().trackerCounterKeys():
checkLeaks(key) checkLeaks(key)
GC_fullCollect()

docs/.gitignore

@ -0,0 +1 @@
book

docs/book.toml

@ -0,0 +1,20 @@
[book]
authors = ["Jacek Sieka"]
language = "en"
multilingual = false
src = "src"
title = "Chronos"
[preprocessor.toc]
command = "mdbook-toc"
renderer = ["html"]
max-level = 2
[preprocessor.open-on-gh]
command = "mdbook-open-on-gh"
renderer = ["html"]
[output.html]
git-repository-url = "https://github.com/status-im/nim-chronos/"
git-branch = "master"
additional-css = ["open-in.css"]

View File

@ -0,0 +1,21 @@
## Simple cancellation example
import chronos
proc someTask() {.async.} = await sleepAsync(10.minutes)
proc cancellationExample() {.async.} =
# Start a task but don't wait for it to finish
let future = someTask()
future.cancelSoon()
# `cancelSoon` schedules but does not wait for the future to get cancelled -
# it might still be pending here
let future2 = someTask() # Start another task concurrently
await future2.cancelAndWait()
# Using `cancelAndWait`, we can be sure that `future2` is either
# complete, failed or cancelled at this point. `future` could still be
# pending!
assert future2.finished()
waitFor(cancellationExample())

View File

@ -0,0 +1,28 @@
## The peculiarities of `discard` in `async` procedures
import chronos
proc failingOperation() {.async.} =
echo "Raising!"
raise (ref ValueError)(msg: "My error")
proc myApp() {.async.} =
# This style of discard causes the `ValueError` to be discarded, hiding the
# failure of the operation - avoid!
discard failingOperation()
proc runAsTask(fut: Future[void]): Future[void] {.async: (raises: []).} =
# runAsTask uses `raises: []` to ensure at compile-time that no exceptions
# escape it!
try:
await fut
except CatchableError as exc:
echo "The task failed! ", exc.msg
# asyncSpawn ensures that errors don't leak unnoticed from tasks without
# blocking:
asyncSpawn runAsTask(failingOperation())
# If we didn't catch the exception with `runAsTask`, the program will crash:
asyncSpawn failingOperation()
waitFor myApp()

docs/examples/httpget.nim

@ -0,0 +1,15 @@
import chronos/apps/http/httpclient
proc retrievePage*(uri: string): Future[string] {.async.} =
# Create a new HTTP session
let httpSession = HttpSessionRef.new()
try:
# Fetch page contents
let resp = await httpSession.fetch(parseUri(uri))
# Convert response to a string, assuming its encoding matches the terminal!
bytesToString(resp.data)
finally: # Close the session
await noCancel(httpSession.closeWait())
echo waitFor retrievePage(
"https://raw.githubusercontent.com/status-im/nim-chronos/master/README.md")

View File

@ -0,0 +1,130 @@
import chronos/apps/http/httpserver
{.push raises: [].}
proc firstMiddlewareHandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
if reqfence.isErr():
# Ignore request errors
return await nextHandler(reqfence)
let request = reqfence.get()
var headers = request.headers
if request.uri.path.startsWith("/path/to/hidden/resources"):
headers.add("X-Filter", "drop")
elif request.uri.path.startsWith("/path/to/blocked/resources"):
headers.add("X-Filter", "block")
else:
headers.add("X-Filter", "pass")
# Updating request by adding new HTTP header `X-Filter`.
let res = request.updateRequest(headers)
if res.isErr():
# We use the default error handler, which will respond with the
# proper HTTP status code for the error.
return defaultResponse(res.error)
# Calling next handler.
await nextHandler(reqfence)
proc secondMiddlewareHandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
if reqfence.isErr():
# Ignore request errors
return await nextHandler(reqfence)
let
request = reqfence.get()
filtered = request.headers.getString("X-Filter", "pass")
if filtered == "drop":
# Force HTTP server to drop connection with remote peer.
dropResponse()
elif filtered == "block":
# Force HTTP server to respond with HTTP `404 Not Found` error code.
codeResponse(Http404)
else:
# Calling next handler.
await nextHandler(reqfence)
proc thirdMiddlewareHandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
if reqfence.isErr():
# Ignore request errors
return await nextHandler(reqfence)
let request = reqfence.get()
echo "QUERY = [", request.rawPath, "]"
echo request.headers
try:
if request.uri.path == "/path/to/plugin/resources/page1":
await request.respond(Http200, "PLUGIN PAGE1")
elif request.uri.path == "/path/to/plugin/resources/page2":
await request.respond(Http200, "PLUGIN PAGE2")
else:
# Calling next handler.
await nextHandler(reqfence)
except HttpWriteError as exc:
# We use the default error handler if we are unable to send a response.
defaultResponse(exc)
proc mainHandler(
reqfence: RequestFence
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
if reqfence.isErr():
return defaultResponse()
let request = reqfence.get()
try:
if request.uri.path == "/path/to/original/page1":
await request.respond(Http200, "ORIGINAL PAGE1")
elif request.uri.path == "/path/to/original/page2":
await request.respond(Http200, "ORIGINAL PAGE2")
else:
# Force HTTP server to respond with `404 Not Found` status code.
codeResponse(Http404)
except HttpWriteError as exc:
defaultResponse(exc)
proc middlewareExample() {.async: (raises: []).} =
let
middlewares = [
HttpServerMiddlewareRef(handler: firstMiddlewareHandler),
HttpServerMiddlewareRef(handler: secondMiddlewareHandler),
HttpServerMiddlewareRef(handler: thirdMiddlewareHandler)
]
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
boundAddress =
if isAvailable(AddressFamily.IPv6):
AnyAddress6
else:
AnyAddress
res = HttpServerRef.new(boundAddress, mainHandler,
socketFlags = socketFlags,
middlewares = middlewares)
doAssert(res.isOk(), "Unable to start HTTP server")
let server = res.get()
server.start()
let address = server.instance.localAddress()
echo "HTTP server running on ", address
try:
await server.join()
except CancelledError:
discard
finally:
await server.stop()
await server.closeWait()
when isMainModule:
waitFor(middlewareExample())

docs/examples/nim.cfg

@ -0,0 +1 @@
path = "../.."

View File

@ -0,0 +1,38 @@
import chronos, chronos/threadsync
import os
type
Context = object
# Context allocated by `createShared` should contain no garbage-collected
# types!
signal: ThreadSignalPtr
value: int
proc myThread(ctx: ptr Context) {.thread.} =
echo "Doing some work in a thread"
sleep(3000)
ctx.value = 42
echo "Done, firing the signal"
discard ctx.signal.fireSync().expect("correctly initialized signal should not fail")
proc main() {.async.} =
let
signal = ThreadSignalPtr.new().expect("free file descriptor for signal")
context = createShared(Context)
context.signal = signal
var thread: Thread[ptr Context]
echo "Starting thread"
createThread(thread, myThread, context)
await signal.wait()
echo "Work done: ", context.value
joinThread(thread)
signal.close().expect("closing once works")
deallocShared(context)
waitFor main()

View File

@ -0,0 +1,25 @@
## Single timeout for several operations
import chronos
proc shortTask {.async.} =
try:
await sleepAsync(1.seconds)
except CancelledError as exc:
echo "Short task was cancelled!"
raise exc # Propagate cancellation to the next operation
proc composedTimeout() {.async.} =
let
# Common timeout for several sub-tasks
timeout = sleepAsync(10.seconds)
while not timeout.finished():
let task = shortTask() # Start a task but don't `await` it
if (await race(task, timeout)) == task:
echo "Ran one more task"
else:
# This cancellation may or may not happen, as the task might have finished
# right at the timeout!
task.cancelSoon()
waitFor composedTimeout()

View File

@ -0,0 +1,20 @@
## Simple timeouts
import chronos
proc longTask {.async.} =
try:
await sleepAsync(10.minutes)
except CancelledError as exc:
echo "Long task was cancelled!"
raise exc # Propagate cancellation to the next operation
proc simpleTimeout() {.async.} =
let
task = longTask() # Start a task but don't `await` it
if not await task.withTimeout(1.seconds):
echo "Timeout reached - withTimeout should have cancelled the task"
else:
echo "Task completed"
waitFor simpleTimeout()

docs/examples/twogets.nim

@ -0,0 +1,24 @@
## Make two http requests concurrently and output the one that wins
import chronos
import ./httpget
proc twoGets() {.async.} =
let
futs = @[
# Both pages will start downloading concurrently...
httpget.retrievePage("https://duckduckgo.com/?q=chronos"),
httpget.retrievePage("https://www.google.fr/search?q=chronos")
]
# Wait for at least one request to finish..
let winner = await one(futs)
# ..and cancel the others since we won't need them
for fut in futs:
# Trying to cancel an already-finished future is harmless
fut.cancelSoon()
# An exception could be raised here if the winning request failed!
echo "Result: ", winner.read()
waitFor(twoGets())

docs/open-in.css

@ -0,0 +1,7 @@
footer {
font-size: 0.8em;
text-align: center;
border-top: 1px solid black;
padding: 5px 0;
}

docs/src/SUMMARY.md

@ -0,0 +1,16 @@
- [Introduction](./introduction.md)
- [Examples](./examples.md)
# User guide
- [Core concepts](./concepts.md)
- [`async` functions](async_procs.md)
- [Errors and exceptions](./error_handling.md)
- [Threads](./threads.md)
- [Tips, tricks and best practices](./tips.md)
- [Porting code to `chronos`](./porting.md)
- [HTTP server middleware](./http_server_middleware.md)
# Developer guide
- [Updating this book](./book.md)

docs/src/async_procs.md

@ -0,0 +1,123 @@
# Async procedures
Async procedures are those that interact with `chronos` to cooperatively
suspend and resume their execution depending on the completion of other
async procedures, timers, tasks on other threads or asynchronous I/O scheduled
with the operating system.
Async procedures are marked with the `{.async.}` pragma and return a `Future`
indicating the state of the operation.
<!-- toc -->
## The `async` pragma
The `{.async.}` pragma will transform a procedure (or a method) returning a
`Future` into a closure iterator. If there is no return type specified,
`Future[void]` is returned.
```nim
proc p() {.async.} =
await sleepAsync(100.milliseconds)
echo p().type # prints "Future[system.void]"
```
## `await` keyword
The `await` keyword operates on `Future` instances typically returned from an
`async` procedure.
Whenever `await` is encountered inside an async procedure, control is given
back to the dispatcher for as many steps as necessary for the awaited
future to complete, fail or be cancelled. `await` calls the
equivalent of `Future.read()` on the completed future to return the
encapsulated value when the operation finishes.
```nim
proc p1() {.async.} =
await sleepAsync(1.seconds)
proc p2() {.async.} =
await sleepAsync(1.seconds)
proc p3() {.async.} =
let
fut1 = p1()
fut2 = p2()
# Just by executing the async procs, both resulting futures entered the
# dispatcher queue and their "clocks" started ticking.
await fut1
await fut2
# Only one second passed while awaiting them both, not two.
waitFor p3()
```
```admonition warning
Because `async` procedures are executed concurrently, they are subject to many
of the same risks that typically accompany multithreaded programming.
In particular, if two `async` procedures have access to the same mutable state,
the value before and after `await` might not be the same, as the order of execution is not guaranteed!
```
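A minimal sketch of how this can go wrong - two tasks read a shared counter,
suspend at `await`, and then write back stale values:

```nim
import chronos

var counter = 0

proc unsafeIncrement(name: string) {.async.} =
  let seen = counter                 # read shared state...
  await sleepAsync(10.milliseconds)  # ...suspend, letting the other task run...
  counter = seen + 1                 # ...then write back a possibly stale value
  echo name, " wrote counter = ", counter

proc demo() {.async.} =
  let
    a = unsafeIncrement("a")
    b = unsafeIncrement("b")
  await a
  await b
  # Both tasks read 0 before either wrote, so the final value is 1, not 2

waitFor demo()
```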
## Raw async procedures
Raw async procedures are those that interact with `chronos` via the `Future`
type but whose body does not go through the async transformation.
Such functions are created by adding `raw: true` to the `async` parameters:
```nim
proc rawAsync(): Future[void] {.async: (raw: true).} =
let fut = newFuture[void]("rawAsync")
fut.complete()
fut
```
Raw functions must not raise exceptions directly - they are implicitly declared
as `raises: []` - instead they should store exceptions in the returned `Future`:
```nim
proc rawFailure(): Future[void] {.async: (raw: true).} =
let fut = newFuture[void]("rawAsync")
fut.fail((ref ValueError)(msg: "Oh no!"))
fut
```
Raw procedures can also use checked exceptions:
```nim
proc rawAsyncRaises(): Future[void] {.async: (raw: true, raises: [IOError]).} =
let fut = newFuture[void]()
assert not (compiles do: fut.fail((ref ValueError)(msg: "uh-uh")))
fut.fail((ref IOError)(msg: "IO"))
fut
```
## Callbacks and closures
Callback/closure types are declared using the `async` annotation as usual:
```nim
type MyCallback = proc(): Future[void] {.async.}
proc runCallback(cb: MyCallback) {.async: (raises: []).} =
try:
await cb()
except CatchableError:
discard # handle errors as usual
```
When calling a callback, it is important to remember that it may raise exceptions that need to be handled.
Checked exceptions can be used to limit the exceptions that a callback can
raise:
```nim
type MyEasyCallback = proc(): Future[void] {.async: (raises: []).}
proc runCallback(cb: MyEasyCallback) {.async: (raises: [])} =
await cb()
```

docs/src/concepts.md

@ -0,0 +1,214 @@
# Concepts
Async/await is a programming model that relies on cooperative multitasking to
coordinate the concurrent execution of procedures, using event notifications
from the operating system or other threads to resume execution.
Code execution happens in a loop that alternates between making progress on
tasks and handling events.
<!-- toc -->
## The dispatcher
The event handler loop is called a "dispatcher" and a single instance per
thread is created as soon as one is needed.
Scheduling is done by calling [async procedures](./async_procs.md) that return
`Future` objects - each time a procedure is unable to make further
progress, for example because it's waiting for some data to arrive, it hands
control back to the dispatcher which ensures that the procedure is resumed when
ready.
A single thread, and thus a single dispatcher, is typically able to handle
thousands of concurrent in-progress requests.
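As a rough sketch, a single dispatcher can juggle many pending tasks at once -
here a hundred sleeping tasks are started and then awaited together (assuming
the `allFutures` helper from `chronos`):

```nim
import chronos

proc tick() {.async.} =
  await sleepAsync(10.milliseconds)

proc runMany() {.async.} =
  var futs: seq[Future[void]]
  for _ in 0 ..< 100:
    futs.add tick()        # all 100 tasks are now pending...
  await allFutures(futs)   # ...and one dispatcher drives them all to completion
  echo "all ", futs.len, " tasks finished"

waitFor runMany()
```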
## The `Future` type
`Future` objects encapsulate the outcome of executing an `async` procedure. The
`Future` may be `pending`, meaning that the outcome is not yet known, or
`finished`, meaning that the return value is available, the operation failed
with an exception, or it was cancelled.
Inside an async procedure, you can `await` the outcome of another async
procedure - if the `Future` representing that operation is still `pending`, a
callback representing where to resume execution will be added to it and the
dispatcher will be given back control to deal with other tasks.
When a `Future` is `finished`, all its callbacks are scheduled to be run by
the dispatcher, thus continuing any operations that were waiting for an outcome.
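As a minimal sketch (names are illustrative; `waitFor` is covered in the next
section), the lifecycle of a `Future` can be observed like this:
```nim
proc compute(): Future[int] {.async.} =
  await sleepAsync(10.milliseconds)
  return 42

let fut = compute()
assert not fut.finished()  # still pending - the dispatcher hasn't finished it yet
let value = waitFor fut    # run the dispatcher until `fut` is finished
assert fut.completed()     # finished with a value rather than failed or cancelled
echo value                 # 42
```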
## The `poll` call
To trigger the processing step of the dispatcher, we need to call `poll()` -
either directly or through a wrapper like `runForever()` or `waitFor()`.
Each call to poll handles any file descriptors, timers and callbacks that are
ready to be processed.
Using `waitFor`, the result of a single asynchronous operation can be obtained:
```nim
proc myApp() {.async.} =
  echo "Waiting for a second..."
  await sleepAsync(1.seconds)
  echo "done!"

waitFor myApp()
```
It is also possible to keep running the event loop forever using `runForever`:
```nim
proc myApp() {.async.} =
  while true:
    await sleepAsync(1.seconds)
    echo "A bit more than a second passed!"

let future = myApp()
runForever()
```
Such an application never terminates, so it is rare for applications to be
structured this way.
```admonish warning
Both `waitFor` and `runForever` call `poll` which offers fine-grained
control over the event loop steps.
Nested calls to `poll` - directly or indirectly via `waitFor` and `runForever` -
are not allowed.
```
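When such fine-grained control is needed, the dispatcher can also be driven
manually - a minimal sketch:
```nim
proc myApp() {.async.} =
  await sleepAsync(100.milliseconds)

let fut = myApp()
while not fut.finished():
  poll() # one dispatcher step: handle ready file descriptors, timers and callbacks
```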
## Cancellation
Any pending `Future` can be cancelled. This can be used for timeouts, to start
multiple parallel operations and cancel the rest as soon as one finishes,
to initiate the orderly shutdown of an application, etc.
```nim
{{#include ../examples/cancellation.nim}}
```
Even if cancellation is initiated, it is not guaranteed that the operation gets
cancelled - the future might still be completed or fail depending on the
order of events in the dispatcher and the specifics of the operation.
If the future indeed gets cancelled, `await` will raise a
`CancelledError` as is likely to happen in the following example:
```nim
proc c1 {.async.} =
  echo "Before sleep"
  try:
    await sleepAsync(10.minutes)
    echo "After sleep" # not reached due to cancellation
  except CancelledError as exc:
    echo "We got cancelled!"
    # `CancelledError` is typically re-raised to notify the caller that the
    # operation is being cancelled
    raise exc

proc c2 {.async.} =
  await c1()
  echo "Never reached, since the CancelledError got re-raised"

let work = c2()
waitFor(work.cancelAndWait())
```
The `CancelledError` will now travel up the stack like any other exception.
It can be caught, for instance to free some resources, and is then typically
re-raised for the whole chain of operations to get cancelled.
Alternatively, the cancellation request can be translated to a regular outcome
of the operation - for example, a `read` operation might return an empty result.
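A minimal sketch of translating cancellation into a regular outcome (the
procedure name is hypothetical):
```nim
proc readOrEmpty(): Future[string] {.async.} =
  try:
    await sleepAsync(10.minutes) # stand-in for a slow read operation
    return "data"
  except CancelledError:
    # Swallow the cancellation and report it as an ordinary empty result
    return ""
```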
Cancelling an already-finished `Future` has no effect, as the following example
of downloading two web pages concurrently shows:
```nim
{{#include ../examples/twogets.nim}}
```
### Ownership
When calling a procedure that returns a `Future`, ownership of that `Future` is
shared between the callee that created it and the caller that waits for it to be
finished.
The `Future` can be thought of as a single-item channel between a producer and a
consumer. The producer creates the `Future` and is responsible for completing or
failing it while the caller waits for completion and may `cancel` it.
Although it is technically possible, callers must not `complete` or `fail`
futures, and callees or other intermediate observers must not `cancel` them, as
this may lead to panics and shutdown (i.e. if the future is completed twice or a
cancellation is not handled by the original caller).
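As an illustrative sketch of this split between producer and consumer
(hypothetical names):
```nim
proc produce(): Future[int] =
  let fut = newFuture[int]("produce")
  # The producer creates the future and is the only party that may
  # complete or fail it
  fut.complete(42)
  fut

proc consume() {.async.} =
  # The consumer awaits the future - and, as caller, may cancel it -
  # but must never complete or fail it
  let value = await produce()
  echo value

waitFor consume()
```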
### `noCancel`
Certain operations must not be cancelled for semantic reasons. Common scenarios
include `closeWait`, which releases a resource irrevocably, and composed
operations whose individual steps should be performed together or not at all.
In such cases, the `noCancel` modifier to `await` can be used to temporarily
disable cancellation propagation, allowing the operation to complete even if
the caller initiates a cancellation request:
```nim
proc deepSleep(dur: Duration) {.async.} =
  # `noCancel` prevents any cancellation request by the caller of `deepSleep`
  # from reaching `sleepAsync` - even if `deepSleep` is cancelled, its future
  # will not complete until the sleep finishes.
  await noCancel sleepAsync(dur)

let future = deepSleep(10.minutes)

# This will take ~10 minutes even if we try to cancel the call to `deepSleep`!
waitFor cancelAndWait(future)
```
### `join`
The `join` modifier to `await` allows cancelling an `async` procedure without
propagating the cancellation to the awaited operation. This is useful when
`await`:ing a `Future` for monitoring purposes, i.e. when a procedure is not the
owner of the future that's being `await`:ed.
One situation where this happens is when implementing the "observer" pattern,
where a helper monitors an operation it did not initiate:
```nim
var tick: Future[void]
proc ticker() {.async.} =
  while true:
    tick = sleepAsync(1.seconds)
    await tick
    echo "tick!"

proc tocker() {.async.} =
  # This operation does not own or implement the operation behind `tick`,
  # so it should not cancel it when `tocker` is cancelled
  await join tick
  echo "tock!"

let
  fut = ticker() # `ticker` is now looping and most likely waiting for `tick`
  fut2 = tocker() # both `ticker` and `tocker` are waiting for `tick`

# We don't want `tocker` to cancel a future that was created in `ticker`
waitFor fut2.cancelAndWait()
waitFor fut # keeps printing `tick!` every second.
```
## Compile-time configuration
`chronos` contains several compile-time
[configuration options](./chronos/config.nim) enabling stricter compile-time
checks and debugging helpers whose runtime cost may be significant.
Strictness options will generally become the default in future chronos releases,
allowing existing code to be adapted ahead of the upgrade - see the
[`config.nim`](./chronos/config.nim) module for more information.
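For example, the options are ordinary compile-time defines passed on the
command line - a hedged sketch using two flags that appear elsewhere in this
changeset (consult `config.nim` for the authoritative list and accepted values):
```text
nim c -d:chronosHandleException -d:chronosEventEngine=poll myapp.nim
```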

docs/src/error_handling.md Normal file
@ -0,0 +1,168 @@
# Errors and exceptions
<!-- toc -->
## Exceptions
Exceptions inheriting from [`CatchableError`](https://nim-lang.org/docs/system.html#CatchableError)
interrupt execution of an `async` procedure. The exception is placed in the
`Future.error` field, the status of the `Future` changes to `Failed`, and its
callbacks are scheduled.
When a future is read or awaited the exception is re-raised, traversing the
`async` execution chain until handled.
```nim
proc p1() {.async.} =
  await sleepAsync(1.seconds)
  raise newException(ValueError, "ValueError inherits from CatchableError")

proc p2() {.async.} =
  await sleepAsync(1.seconds)

proc p3() {.async.} =
  let
    fut1 = p1()
    fut2 = p2()
  await fut1
  echo "unreachable code here"
  await fut2

# `waitFor()` would call `Future.read()` unconditionally, which would raise the
# exception in `Future.error`.
let fut3 = p3()
while not(fut3.finished()):
  poll()

echo "fut3.state = ", fut3.state # "Failed"

if fut3.failed():
  echo "p3() failed: ", fut3.error.name, ": ", fut3.error.msg
  # prints "p3() failed: ValueError: ValueError inherits from CatchableError"
```
You can put the `await` in a `try` block to deal with that exception sooner:
```nim
proc p3() {.async.} =
  let
    fut1 = p1()
    fut2 = p2()
  try:
    await fut1
  except CatchableError:
    echo "p1() failed: ", fut1.error.name, ": ", fut1.error.msg
  echo "reachable code here"
  await fut2
```
Because `chronos` ensures that all exceptions are re-routed to the `Future`,
`poll` will not itself raise exceptions.
`poll` may still panic / raise `Defect` if one is raised in user code, since
these indicate undefined behavior.
## Checked exceptions
By specifying a `raises` list to an async procedure, you can check which
exceptions can be raised by it:
```nim
proc p1(): Future[void] {.async: (raises: [IOError]).} =
  assert not (compiles do: raise newException(ValueError, "uh-uh"))
  raise newException(IOError, "works") # Or any child of IOError

proc p2(): Future[void] {.async: (raises: [IOError]).} =
  await p1() # Works, because await knows that p1
             # can only raise IOError
```
Under the hood, the return type of `p1` will be rewritten to an internal type
which conveys the `raises` information to `await`.
```admonish note
Most `async` procedures include `CancelledError` in the list of `raises`, indicating that
the operation they implement might get cancelled resulting in neither value nor
error!
```
When using checked exceptions, the `Future` type is modified to include
`raises` information - it can be constructed with the `Raising` helper:
```nim
# Create a variable of the type that will be returned by an async function
# raising `[CancelledError]`:
var fut: Future[int].Raising([CancelledError])
```
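For illustration, such a variable can then hold the future returned by a
matching procedure (a minimal sketch; `returnAfterSleep` is a hypothetical
name):
```nim
proc returnAfterSleep(): Future[int] {.async: (raises: [CancelledError]).} =
  await sleepAsync(10.milliseconds)
  return 42

var fut: Future[int].Raising([CancelledError])
fut = returnAfterSleep() # the declared type matches the annotated return type
echo waitFor(fut)        # 42
```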
```admonish note
`Raising` creates a specialization of `InternalRaisesFuture` type - as the name
suggests, this is an internal type whose implementation details are likely to
change in future `chronos` versions.
```
## The `Exception` type
Exceptions deriving from `Exception` are not caught by default, as these may
include `Defect` and other forms of undefined or uncatchable behavior.
Because exception effect tracking is turned on for `async` functions, this may
sometimes lead to compile errors around forward declarations, methods and
closures, as Nim conservatively assumes that any `Exception` might be raised
from those.
Make sure to explicitly annotate these with `{.raises.}`:
```nim
# Forward declarations need to explicitly include a raises list:
proc myfunction() {.raises: [ValueError].}

# ... as do `proc` types
type MyClosure = proc() {.raises: [ValueError].}

proc myfunction() =
  raise (ref ValueError)(msg: "Implementation here")

let closure: MyClosure = myfunction
```
## Compatibility modes
**Individual functions.** For compatibility, `async` functions can be instructed
to handle `Exception` as well, specifying `handleException: true`. Any
`Exception` that is not a `Defect` and not a `CatchableError` will then be
caught and remapped to `AsyncExceptionError`:
```nim
proc raiseException() {.async: (handleException: true, raises: [AsyncExceptionError]).} =
  raise (ref Exception)(msg: "Raising Exception is UB")

proc callRaiseException() {.async: (raises: []).} =
  try:
    await raiseException()
  except AsyncExceptionError as exc:
    # The original Exception is available from the `parent` field
    echo exc.parent.msg
```
**Global flag.** This mode can be enabled globally with
`-d:chronosHandleException` as a help when porting code to `chronos`. The
behavior in this case will be that:
1. Old-style functions annotated with plain `async` will behave as if they had
been annotated with `async: (handleException: true)`.
This is functionally equivalent to
`async: (handleException: true, raises: [CatchableError])` and will, as
before, remap any `Exception` that is not `Defect` into
`AsyncExceptionError`, while also allowing any `CatchableError` (including
`AsyncExceptionError`) to get through without compilation errors.
2. New-style functions with `async: (raises: [...])` annotations or their own
`handleException` annotations will not be affected.
The rationale here is to allow one to incrementally introduce exception
annotations and get compiler feedback while not requiring that every bit of
legacy code is updated at once.
This should be used sparingly and with care, however, as global configuration
settings may interfere with libraries that use `chronos` leading to unexpected
behavior.

docs/src/examples.md Normal file
@ -0,0 +1,23 @@
# Examples
Examples are available in the [`docs/examples/`](https://github.com/status-im/nim-chronos/tree/master/docs/examples/) folder.
## Basic concepts
* [cancellation](https://github.com/status-im/nim-chronos/tree/master/docs/examples/cancellation.nim) - Cancellation primer
* [timeoutsimple](https://github.com/status-im/nim-chronos/tree/master/docs/examples/timeoutsimple.nim) - Simple timeouts
* [timeoutcomposed](https://github.com/status-im/nim-chronos/tree/master/docs/examples/timeoutcomposed.nim) - Shared timeout of multiple tasks
## Threads
* [signalling](https://github.com/status-im/nim-chronos/tree/master/docs/examples/signalling.nim) - Cross-thread signalling
## TCP
* [tcpserver](https://github.com/status-im/nim-chronos/tree/master/docs/examples/tcpserver.nim) - Simple TCP/IP v4/v6 echo server
## HTTP
* [httpget](https://github.com/status-im/nim-chronos/tree/master/docs/examples/httpget.nim) - Downloading a web page using the http client
* [twogets](https://github.com/status-im/nim-chronos/tree/master/docs/examples/twogets.nim) - Download two pages concurrently
* [middleware](https://github.com/status-im/nim-chronos/tree/master/docs/examples/middleware.nim) - Deploy multiple HTTP server middlewares

@ -0,0 +1,19 @@
## Getting started
Install `chronos` using `nimble`:
```text
nimble install chronos
```
or add a dependency to your `.nimble` file:
```text
requires "chronos"
```
and start using it:
```nim
{{#include ../examples/httpget.nim}}
```

@ -0,0 +1,102 @@
## HTTP server middleware
Chronos provides a powerful mechanism for customizing HTTP request handlers via
middlewares.
A middleware is a coroutine that can modify, block or filter an HTTP request.
A single HTTP server can support an unlimited number of middlewares, but keep in mind that, in the worst case, each request may pass through all of them, so a large number of middlewares can have a significant impact on HTTP server performance.
The order of middlewares is also important: right after the HTTP server has received a request, the request is sent to the first middleware in the list, and each middleware is responsible for passing control on to the next one. Therefore, when building the list, it is a good idea to place request handlers at the end, while keeping middlewares that could block or modify the request at the beginning.
A middleware can also modify the HTTP request, and these changes will be visible to all subsequent handlers (other middlewares as well as the original request handler). This can be done using the following helpers:
```nim
proc updateRequest*(request: HttpRequestRef, scheme: string, meth: HttpMethod,
                    version: HttpVersion, requestUri: string,
                    headers: HttpTable): HttpResultMessage[void]

proc updateRequest*(request: HttpRequestRef, meth: HttpMethod,
                    requestUri: string,
                    headers: HttpTable): HttpResultMessage[void]

proc updateRequest*(request: HttpRequestRef, requestUri: string,
                    headers: HttpTable): HttpResultMessage[void]

proc updateRequest*(request: HttpRequestRef,
                    requestUri: string): HttpResultMessage[void]

proc updateRequest*(request: HttpRequestRef,
                    headers: HttpTable): HttpResultMessage[void]
```
As you can see, all HTTP request parameters can be modified: the request method, version, request path and request headers.
A middleware can also use helpers to obtain more information about the remote and local addresses of the request's connection (this can be helpful when you need to do some IP address filtering).
```nim
proc remote*(request: HttpRequestRef): Opt[TransportAddress]
  ## Returns the remote address of the HTTP request's connection.

proc local*(request: HttpRequestRef): Opt[TransportAddress]
  ## Returns the local address of the HTTP request's connection.
```
Every middleware is a coroutine that looks like this:
```nim
proc middlewareHandler(
    middleware: HttpServerMiddlewareRef,
    reqfence: RequestFence,
    nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
```
Here, the `middleware` argument is an object that can hold middleware-specific values, `reqfence` is the HTTP request wrapped together with HTTP server error information, and `nextHandler` is a reference to the next request handler - either another middleware handler or the original request processing callback.
```nim
await nextHandler(reqfence)
```
You should `await` the response from `nextHandler(reqfence)`. Usually you call the next handler when you do not want to handle the request or do not know how to handle it, for example:
```nim
proc middlewareHandler(
    middleware: HttpServerMiddlewareRef,
    reqfence: RequestFence,
    nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
  if reqfence.isErr():
    # We don't know how, or do not want, to handle failed requests, so we call
    # the next handler.
    return await nextHandler(reqfence)
  let request = reqfence.get()
  if request.uri.path == "/path/we/able/to/respond":
    try:
      # Sending some response.
      await request.respond(Http200, "TEST")
    except HttpWriteError as exc:
      # We could also return the default response for this exception or other
      # types of error.
      defaultResponse(exc)
  elif request.uri.path == "/path/for/rewrite":
    # We are going to modify the request object; the next handler will receive
    # it with a different request path.
    let res = request.updateRequest("/path/to/new/location")
    if res.isErr():
      return defaultResponse(res.error)
    await nextHandler(reqfence)
  elif request.uri.path == "/restricted/path":
    if request.remote().isNone():
      # We can't obtain the remote address, so we force the HTTP server to
      # respond with a `401 Unauthorized` status code.
      return codeResponse(Http401)
    if $(request.remote().get()).startsWith("127.0.0.1"):
      # The remote peer's address starts with "127.0.0.1", so we send the
      # proper response.
      await request.respond(Http200, "AUTHORIZED")
    else:
      # Force the HTTP server to respond with a `403 Forbidden` status code.
      codeResponse(Http403)
  elif request.uri.path == "/blackhole":
    # Force the HTTP server to drop the connection with the remote peer.
    dropResponse()
  else:
    # All other requests should be handled by somebody else.
    await nextHandler(reqfence)
```

docs/src/introduction.md Normal file
@ -0,0 +1,50 @@
# Introduction
Chronos implements the [async/await](https://en.wikipedia.org/wiki/Async/await)
paradigm in a self-contained library using macro and closure iterator
transformation features provided by Nim.
Features include:
* Asynchronous socket and process I/O
* HTTP client / server with SSL/TLS support out of the box (no OpenSSL needed)
* Synchronization primitives like queues, events and locks
* [Cancellation](./concepts.md#cancellation)
* Efficient dispatch pipeline with excellent multi-platform support
* Exception [effect support](./guide.md#error-handling)
## Installation
Install `chronos` using `nimble`:
```text
nimble install chronos
```
or add a dependency to your `.nimble` file:
```text
requires "chronos"
```
and start using it:
```nim
{{#include ../examples/httpget.nim}}
```
There are more [examples](./examples.md) throughout the manual!
## Platform support
Several platforms are supported, with different backend [options](./concepts.md#compile-time-configuration):
* Windows: [`IOCP`](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports)
* Linux: [`epoll`](https://en.wikipedia.org/wiki/Epoll) / `poll`
* OSX / BSD: [`kqueue`](https://en.wikipedia.org/wiki/Kqueue) / `poll`
* Android / Emscripten / posix: `poll`
## API documentation
This guide covers basic usage of chronos - for details, see the
[API reference](./api/chronos.html).

docs/src/porting.md Normal file
@ -0,0 +1,59 @@
# Porting code to `chronos` v4
<!-- toc -->
Thanks to its macro support, Nim allows `async`/`await` to be implemented in
libraries with only minimal support from the language - as such, multiple
`async` libraries exist, including `chronos` and `asyncdispatch`, and more may
come to be developed in the future.
## Chronos v3
Chronos v4 introduces new features for IPv6, exception effects, a stand-alone
`Future` type as well as several other changes - when upgrading from chronos v3,
here are several things to consider:
* Exception handling is now strict by default - see the [error handling](./error_handling.md)
chapter for how to deal with `raises` effects
* `AsyncEventBus` was removed - use `AsyncEventQueue` instead
* `Future.value` and `Future.error` panic when accessed in the wrong state
* `Future.read` and `Future.readError` raise `FutureError` instead of
`ValueError` when accessed in the wrong state
## `asyncdispatch`
Code written for `asyncdispatch` and `chronos` looks similar but there are
several differences to be aware of:
* `chronos` has its own dispatch loop - you can typically not mix `chronos` and
`asyncdispatch` in the same thread
* `import chronos` instead of `import asyncdispatch`
* cleanup is important - make sure to use `closeWait` to release any resources
you're using, or file descriptor and other leaks will ensue
* cancellation support means that `CancelledError` may be raised from most
`{.async.}` functions
* Calling `yield` directly in tasks is not supported - instead, use `awaitne`.
* `asyncSpawn` is used instead of `asyncCheck` - note that exceptions raised
  in tasks that are `asyncSpawn`:ed cause a panic, as sketched below
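A minimal sketch of the `asyncSpawn` pattern (the task name is hypothetical);
annotating the task with `raises: []` makes it explicit that no exception can
escape and crash the application:
```nim
proc backgroundTask() {.async: (raises: []).} =
  try:
    await sleepAsync(1.seconds)
    echo "background work done"
  except CancelledError:
    discard # swallow cancellation so nothing escapes the task

asyncSpawn backgroundTask() # instead of `asyncCheck` from `asyncdispatch`
```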
## Supporting multiple backends
Libraries built on top of `async`/`await` may wish to support multiple async
backends - the best way to do so is to create separate modules for each backend
that may be imported side-by-side - see [nim-metrics](https://github.com/status-im/nim-metrics/blob/master/metrics/)
for an example.
An alternative way is to select the backend using a global compile flag - this
method makes it difficult to compose applications that use both backends, as may
happen with transitive dependencies, but may be appropriate in some cases -
libraries choosing this path should call the flag `asyncBackend`, allowing
applications to choose the backend with `-d:asyncBackend=<backend_name>`.
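A hedged sketch of what such flag-based selection might look like in a library
module (`asyncBackend` is the conventional flag name mentioned above, read via
Nim's `strdefine`):
```nim
const asyncBackend {.strdefine.} = "chronos"

when asyncBackend == "chronos":
  import chronos
  export chronos
elif asyncBackend == "asyncdispatch":
  import std/asyncdispatch
  export asyncdispatch
elif asyncBackend == "none":
  discard # async APIs disabled entirely
else:
  {.error: "unknown asyncBackend".}
```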
Known `async` backends include:
* `chronos` - this library (`-d:asyncBackend=chronos`)
* `asyncdispatch` - the standard library `asyncdispatch` [module](https://nim-lang.org/docs/asyncdispatch.html) (`-d:asyncBackend=asyncdispatch`)
* `none` - `-d:asyncBackend=none` - disable `async` support completely.
  `none` can be used when a library supports both a synchronous and an
  asynchronous API, to disable the latter.

docs/src/threads.md Normal file
@ -0,0 +1,18 @@
# Threads
While the cooperative [`async`](./concepts.md) model offers an efficient model
for dealing with many tasks that often are blocked on I/O, it is not suitable
for long-running computations that would prevent concurrent tasks from progressing.
Multithreading offers a way to offload heavy computations to be executed in
parallel with the async work, or, in cases where a single event loop gets
overloaded, to manage multiple event loops in parallel.
For interaction between threads, the `ThreadSignalPtr` type (found in the
[`chronos/threadsync`](https://github.com/status-im/nim-chronos/blob/master/chronos/threadsync.nim)
module) is used - both to wait for notifications coming from other threads and
to notify other threads of progress from within an async procedure.
```nim
{{#include ../examples/signalling.nim}}
```

docs/src/tips.md Normal file
@ -0,0 +1,34 @@
# Tips, tricks and best practices
## Timeouts
To prevent a single task from taking too long, `withTimeout` can be used:
```nim
{{#include ../examples/timeoutsimple.nim}}
```
When several tasks should share a single timeout, a common timer can be created
with `sleepAsync`:
```nim
{{#include ../examples/timeoutcomposed.nim}}
```
## `discard`
When calling an asynchronous procedure without `await`, the operation is started
but its result is not processed until the corresponding `Future` is `read`.
It is therefore important to never `discard` futures directly - instead, one
can discard the result of awaiting the future or use `asyncSpawn` to monitor
the outcome of the future as if it were running in a separate thread.
Similar to threads, tasks managed by `asyncSpawn` may cause the application to
crash if any exceptions leak out of them - use
[checked exceptions](./error_handling.md#checked-exceptions) to avoid this
problem.
```nim
{{#include ../examples/discards.nim}}
```

nim.cfg Normal file
@ -0,0 +1 @@
nimcache = "build/nimcache/$projectName"

@ -5,10 +5,22 @@
# Licensed under either of # Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
import testmacro, testsync, testsoon, testtime, testfut, testsignal, import ".."/chronos/config
testaddress, testdatagram, teststream, testserver, testbugs, testnet,
testasyncstream, testhttpserver, testshttpserver, testhttpclient,
testproc, testratelimit, testfutures, testthreadsync
# Must be imported last to check for Pending futures when (chronosEventEngine in ["epoll", "kqueue"]) or defined(windows):
import testutils import testmacro, testsync, testsoon, testtime, testfut, testsignal,
testaddress, testdatagram, teststream, testserver, testbugs, testnet,
testasyncstream, testhttpserver, testshttpserver, testhttpclient,
testproc, testratelimit, testfutures, testthreadsync
# Must be imported last to check for Pending futures
import testutils
elif chronosEventEngine == "poll":
# `poll` engine do not support signals and processes
import testmacro, testsync, testsoon, testtime, testfut, testaddress,
testdatagram, teststream, testserver, testbugs, testnet,
testasyncstream, testhttpserver, testshttpserver, testhttpclient,
testratelimit, testfutures, testthreadsync
# Must be imported last to check for Pending futures
import testutils

tests/testasyncstream.c Normal file
@ -0,0 +1,63 @@
#include <brssl.h>
// This is the X509TrustAnchor for the SelfSignedRsaCert above
// Generate by doing the following:
// 1. Compile `brssl` from BearSSL
// 2. Run `brssl ta filewithSelfSignedRsaCert.pem`
// 3. Paste the output in the emit block below
// 4. Rename `TAs` to `SelfSignedTAs`
static const unsigned char TA0_DN[] = {
0x30, 0x5F, 0x31, 0x0B, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08,
0x0C, 0x0A, 0x53, 0x6F, 0x6D, 0x65, 0x2D, 0x53, 0x74, 0x61, 0x74, 0x65,
0x31, 0x21, 0x30, 0x1F, 0x06, 0x03, 0x55, 0x04, 0x0A, 0x0C, 0x18, 0x49,
0x6E, 0x74, 0x65, 0x72, 0x6E, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67,
0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4C, 0x74, 0x64, 0x31,
0x18, 0x30, 0x16, 0x06, 0x03, 0x55, 0x04, 0x03, 0x0C, 0x0F, 0x31, 0x32,
0x37, 0x2E, 0x30, 0x2E, 0x30, 0x2E, 0x31, 0x3A, 0x34, 0x33, 0x38, 0x30,
0x38
};
static const unsigned char TA0_RSA_N[] = {
0xA7, 0xEE, 0xD5, 0xC6, 0x2C, 0xA3, 0x08, 0x33, 0x33, 0x86, 0xB5, 0x5C,
0xD4, 0x8B, 0x16, 0xB1, 0xD7, 0xF7, 0xED, 0x95, 0x22, 0xDC, 0xA4, 0x40,
0x24, 0x64, 0xC3, 0x91, 0xBA, 0x20, 0x82, 0x9D, 0x88, 0xED, 0x20, 0x98,
0x46, 0x65, 0xDC, 0xD1, 0x15, 0x90, 0xBC, 0x7C, 0x19, 0x5F, 0x00, 0x96,
0x69, 0x2C, 0x80, 0x0E, 0x7D, 0x7D, 0x8B, 0xD9, 0xFD, 0x49, 0x66, 0xEC,
0x29, 0xC0, 0x39, 0x0E, 0x22, 0xF3, 0x6A, 0x28, 0xC0, 0x6B, 0x97, 0x93,
0x2F, 0x92, 0x5E, 0x5A, 0xCC, 0xF4, 0xF4, 0xAE, 0xD9, 0xE3, 0xBB, 0x0A,
0xDC, 0xA8, 0xDE, 0x4D, 0x16, 0xD6, 0xE6, 0x64, 0xF2, 0x85, 0x62, 0xF6,
0xE3, 0x7B, 0x1D, 0x9A, 0x5C, 0x6A, 0xA3, 0x97, 0x93, 0x16, 0x9D, 0x02,
0x2C, 0xFD, 0x90, 0x3E, 0xF8, 0x35, 0x44, 0x5E, 0x66, 0x8D, 0xF6, 0x80,
0xF1, 0x71, 0x9B, 0x2F, 0x44, 0xC0, 0xCA, 0x7E, 0xB1, 0x90, 0x7F, 0xD8,
0x8B, 0x7A, 0x85, 0x4B, 0xE3, 0xB1, 0xB1, 0xF4, 0xAA, 0x6A, 0x36, 0xA0,
0xFF, 0x24, 0xB2, 0x27, 0xE0, 0xBA, 0x62, 0x7A, 0xE9, 0x95, 0xC9, 0x88,
0x9D, 0x9B, 0xAB, 0xA4, 0x4C, 0xEA, 0x87, 0x46, 0xFA, 0xD6, 0x9B, 0x7E,
0xB2, 0xE9, 0x5B, 0xCA, 0x5B, 0x84, 0xC4, 0xF7, 0xB4, 0xC7, 0x69, 0xC5,
0x0B, 0x9A, 0x47, 0x9A, 0x86, 0xD4, 0xDF, 0xF3, 0x30, 0xC9, 0x6D, 0xB8,
0x78, 0x10, 0xEF, 0xA0, 0x89, 0xF8, 0x30, 0x80, 0x9D, 0x96, 0x05, 0x44,
0xB4, 0xFB, 0x98, 0x4C, 0x71, 0x6B, 0xBC, 0xD7, 0x5D, 0x66, 0x5E, 0x66,
0xA7, 0x94, 0xE5, 0x65, 0x72, 0x85, 0xBC, 0x7C, 0x7F, 0x11, 0x98, 0xF8,
0xCB, 0xD5, 0xE2, 0xB5, 0x67, 0x78, 0xF7, 0x49, 0x51, 0xC4, 0x7F, 0xBA,
0x16, 0x66, 0xD2, 0x15, 0x5B, 0x98, 0x06, 0x03, 0x48, 0xD0, 0x9D, 0xF0,
0x38, 0x2B, 0x9D, 0x51
};
static const unsigned char TA0_RSA_E[] = {
0x01, 0x00, 0x01
};
const br_x509_trust_anchor SelfSignedTAs[1] = {
{
{ (unsigned char *)TA0_DN, sizeof TA0_DN },
BR_X509_TA_CA,
{
BR_KEYTYPE_RSA,
{ .rsa = {
(unsigned char *)TA0_RSA_N, sizeof TA0_RSA_N,
(unsigned char *)TA0_RSA_E, sizeof TA0_RSA_E,
} }
}
}
};

File diff suppressed because it is too large

@ -14,23 +14,26 @@ suite "Asynchronous issues test suite":
const HELLO_PORT = 45679 const HELLO_PORT = 45679
const TEST_MSG = "testmsg" const TEST_MSG = "testmsg"
const MSG_LEN = TEST_MSG.len() const MSG_LEN = TEST_MSG.len()
const TestsCount = 500 const TestsCount = 100
type type
CustomData = ref object CustomData = ref object
test: string test: string
proc udp4DataAvailable(transp: DatagramTransport, proc udp4DataAvailable(transp: DatagramTransport,
remote: TransportAddress) {.async, gcsafe.} = remote: TransportAddress) {.async: (raises: []).} =
var udata = getUserData[CustomData](transp) try:
var expect = TEST_MSG var udata = getUserData[CustomData](transp)
var data: seq[byte] var expect = TEST_MSG
var datalen: int var data: seq[byte]
transp.peekMessage(data, datalen) var datalen: int
if udata.test == "CHECK" and datalen == MSG_LEN and transp.peekMessage(data, datalen)
equalMem(addr data[0], addr expect[0], datalen): if udata.test == "CHECK" and datalen == MSG_LEN and
udata.test = "OK" equalMem(addr data[0], addr expect[0], datalen):
transp.close() udata.test = "OK"
transp.close()
except CatchableError as exc:
raiseAssert exc.msg
proc issue6(): Future[bool] {.async.} = proc issue6(): Future[bool] {.async.} =
var myself = initTAddress("127.0.0.1:" & $HELLO_PORT) var myself = initTAddress("127.0.0.1:" & $HELLO_PORT)
@ -132,6 +135,16 @@ suite "Asynchronous issues test suite":
await server.closeWait() await server.closeWait()
return true return true
proc testOrDeadlock(): Future[bool] {.async.} =
proc f(): Future[void] {.async.} =
await sleepAsync(2.seconds) or sleepAsync(1.seconds)
let fx = f()
try:
await fx.cancelAndWait().wait(2.seconds)
except AsyncTimeoutError:
return false
true
test "Issue #6": test "Issue #6":
check waitFor(issue6()) == true check waitFor(issue6()) == true
@ -149,3 +162,6 @@ suite "Asynchronous issues test suite":
test "IndexError crash test": test "IndexError crash test":
check waitFor(testIndexError()) == true check waitFor(testIndexError()) == true
test "`or` deadlock [#516] test":
check waitFor(testOrDeadlock()) == true

File diff suppressed because it is too large
File diff suppressed because it is too large

@ -9,7 +9,7 @@ import std/[strutils, sha1]
import ".."/chronos/unittest2/asynctests import ".."/chronos/unittest2/asynctests
import ".."/chronos, import ".."/chronos,
".."/chronos/apps/http/[httpserver, shttpserver, httpclient] ".."/chronos/apps/http/[httpserver, shttpserver, httpclient]
import stew/base10 import stew/[byteutils, base10]
{.used.} {.used.}
@ -74,6 +74,8 @@ N8r5CwGcIX/XPC3lKazzbZ8baA==
""" """
suite "HTTP client testing suite": suite "HTTP client testing suite":
teardown:
checkLeaks()
type type
TestResponseTuple = tuple[status: int, data: string, count: int] TestResponseTuple = tuple[status: int, data: string, count: int]
@ -85,7 +87,8 @@ suite "HTTP client testing suite":
res res
proc createServer(address: TransportAddress, proc createServer(address: TransportAddress,
process: HttpProcessCallback, secure: bool): HttpServerRef = process: HttpProcessCallback2,
secure: bool): HttpServerRef =
let let
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
serverFlags = {HttpServerFlags.Http11Pipeline} serverFlags = {HttpServerFlags.Http11Pipeline}
@ -128,18 +131,24 @@ suite "HTTP client testing suite":
(MethodPatch, "/test/patch") (MethodPatch, "/test/patch")
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path case request.uri.path
of "/test/get", "/test/post", "/test/head", "/test/put", of "/test/get", "/test/post", "/test/head", "/test/put",
"/test/delete", "/test/trace", "/test/options", "/test/connect", "/test/delete", "/test/trace", "/test/options", "/test/connect",
"/test/patch", "/test/error": "/test/patch", "/test/error":
return await request.respond(Http200, request.uri.path) try:
await request.respond(Http200, request.uri.path)
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return await request.respond(Http404, "Page not found") try:
await request.respond(Http404, "Page not found")
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -157,7 +166,7 @@ suite "HTTP client testing suite":
var req = HttpClientRequestRef.new(session, ha, item[0]) var req = HttpClientRequestRef.new(session, ha, item[0])
let response = await fetch(req) let response = await fetch(req)
if response.status == 200: if response.status == 200:
let data = cast[string](response.data) let data = string.fromBytes(response.data)
if data == item[1]: if data == item[1]:
inc(counter) inc(counter)
await req.closeWait() await req.closeWait()
@ -173,7 +182,7 @@ suite "HTTP client testing suite":
var req = HttpClientRequestRef.new(session, ha, item[0]) var req = HttpClientRequestRef.new(session, ha, item[0])
let response = await fetch(req) let response = await fetch(req)
if response.status == 200: if response.status == 200:
let data = cast[string](response.data) let data = string.fromBytes(response.data)
if data == item[1]: if data == item[1]:
inc(counter) inc(counter)
await req.closeWait() await req.closeWait()
@ -187,15 +196,15 @@ suite "HTTP client testing suite":
let ResponseTests = [ let ResponseTests = [
(MethodGet, "/test/short_size_response", 65600, 1024, (MethodGet, "/test/short_size_response", 65600, 1024,
"SHORTSIZERESPONSE"), "SHORTSIZERESPONSE"),
(MethodGet, "/test/long_size_response", 262400, 1024, (MethodGet, "/test/long_size_response", 131200, 1024,
"LONGSIZERESPONSE"), "LONGSIZERESPONSE"),
(MethodGet, "/test/short_chunked_response", 65600, 1024, (MethodGet, "/test/short_chunked_response", 65600, 1024,
"SHORTCHUNKRESPONSE"), "SHORTCHUNKRESPONSE"),
(MethodGet, "/test/long_chunked_response", 262400, 1024, (MethodGet, "/test/long_chunked_response", 131200, 1024,
"LONGCHUNKRESPONSE") "LONGCHUNKRESPONSE")
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path case request.uri.path
@ -203,46 +212,58 @@ suite "HTTP client testing suite":
var response = request.getResponse() var response = request.getResponse()
var data = createBigMessage(ResponseTests[0][4], ResponseTests[0][2]) var data = createBigMessage(ResponseTests[0][4], ResponseTests[0][2])
response.status = Http200 response.status = Http200
await response.sendBody(data) try:
return response await response.sendBody(data)
except HttpWriteError as exc:
return defaultResponse(exc)
response
of "/test/long_size_response": of "/test/long_size_response":
var response = request.getResponse() var response = request.getResponse()
var data = createBigMessage(ResponseTests[1][4], ResponseTests[1][2]) var data = createBigMessage(ResponseTests[1][4], ResponseTests[1][2])
response.status = Http200 response.status = Http200
await response.sendBody(data) try:
return response await response.sendBody(data)
except HttpWriteError as exc:
return defaultResponse(exc)
response
of "/test/short_chunked_response": of "/test/short_chunked_response":
var response = request.getResponse() var response = request.getResponse()
var data = createBigMessage(ResponseTests[2][4], ResponseTests[2][2]) var data = createBigMessage(ResponseTests[2][4], ResponseTests[2][2])
response.status = Http200 response.status = Http200
await response.prepare() try:
var offset = 0 await response.prepare()
while true: var offset = 0
if len(data) == offset: while true:
break if len(data) == offset:
let toWrite = min(1024, len(data) - offset) break
await response.sendChunk(addr data[offset], toWrite) let toWrite = min(1024, len(data) - offset)
offset = offset + toWrite await response.sendChunk(addr data[offset], toWrite)
await response.finish() offset = offset + toWrite
return response await response.finish()
except HttpWriteError as exc:
return defaultResponse(exc)
response
of "/test/long_chunked_response": of "/test/long_chunked_response":
var response = request.getResponse() var response = request.getResponse()
var data = createBigMessage(ResponseTests[3][4], ResponseTests[3][2]) var data = createBigMessage(ResponseTests[3][4], ResponseTests[3][2])
response.status = Http200 response.status = Http200
await response.prepare() try:
var offset = 0 await response.prepare()
while true: var offset = 0
if len(data) == offset: while true:
break if len(data) == offset:
let toWrite = min(1024, len(data) - offset) break
await response.sendChunk(addr data[offset], toWrite) let toWrite = min(1024, len(data) - offset)
offset = offset + toWrite await response.sendChunk(addr data[offset], toWrite)
await response.finish() offset = offset + toWrite
return response await response.finish()
except HttpWriteError as exc:
return defaultResponse(exc)
response
else: else:
return await request.respond(Http404, "Page not found") defaultResponse()
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -311,21 +332,26 @@ suite "HTTP client testing suite":
(MethodPost, "/test/big_request", 262400) (MethodPost, "/test/big_request", 262400)
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path case request.uri.path
of "/test/big_request": of "/test/big_request":
if request.hasBody(): try:
let body = await request.getBody() if request.hasBody():
let digest = $secureHash(cast[string](body)) let body = await request.getBody()
return await request.respond(Http200, digest) let digest = $secureHash(string.fromBytes(body))
else: await request.respond(Http200, digest)
return await request.respond(Http400, "Missing content body") else:
await request.respond(Http400, "Missing content body")
except HttpProtocolError as exc:
defaultResponse(exc)
except HttpTransportError as exc:
defaultResponse(exc)
else: else:
return await request.respond(Http404, "Page not found") defaultResponse()
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -348,7 +374,7 @@ suite "HTTP client testing suite":
session, ha, item[0], headers = headers session, ha, item[0], headers = headers
) )
var expectDigest = $secureHash(cast[string](data)) var expectDigest = $secureHash(string.fromBytes(data))
# Sending big request by 1024bytes long chunks # Sending big request by 1024bytes long chunks
var writer = await open(request) var writer = await open(request)
var offset = 0 var offset = 0
@ -364,7 +390,7 @@ suite "HTTP client testing suite":
if response.status == 200: if response.status == 200:
var res = await response.getBodyBytes() var res = await response.getBodyBytes()
if cast[string](res) == expectDigest: if string.fromBytes(res) == expectDigest:
inc(counter) inc(counter)
await response.closeWait() await response.closeWait()
await request.closeWait() await request.closeWait()
@ -381,21 +407,27 @@ suite "HTTP client testing suite":
(MethodPost, "/test/big_chunk_request", 262400) (MethodPost, "/test/big_chunk_request", 262400)
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path case request.uri.path
of "/test/big_chunk_request": of "/test/big_chunk_request":
if request.hasBody(): try:
let body = await request.getBody() if request.hasBody():
let digest = $secureHash(cast[string](body)) let
return await request.respond(Http200, digest) body = await request.getBody()
else: digest = $secureHash(string.fromBytes(body))
return await request.respond(Http400, "Missing content body") await request.respond(Http200, digest)
else:
await request.respond(Http400, "Missing content body")
except HttpProtocolError as exc:
defaultResponse(exc)
except HttpTransportError as exc:
defaultResponse(exc)
else: else:
return await request.respond(Http404, "Page not found") defaultResponse()
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -418,7 +450,7 @@ suite "HTTP client testing suite":
session, ha, item[0], headers = headers session, ha, item[0], headers = headers
) )
var expectDigest = $secureHash(cast[string](data)) var expectDigest = $secureHash(string.fromBytes(data))
# Sending big request by 1024bytes long chunks # Sending big request by 1024bytes long chunks
var writer = await open(request) var writer = await open(request)
var offset = 0 var offset = 0
@ -434,7 +466,7 @@ suite "HTTP client testing suite":
if response.status == 200: if response.status == 200:
var res = await response.getBodyBytes() var res = await response.getBodyBytes()
if cast[string](res) == expectDigest: if string.fromBytes(res) == expectDigest:
inc(counter) inc(counter)
await response.closeWait() await response.closeWait()
await request.closeWait() await request.closeWait()
@ -455,23 +487,28 @@ suite "HTTP client testing suite":
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path case request.uri.path
of "/test/post/urlencoded_size", "/test/post/urlencoded_chunked": of "/test/post/urlencoded_size", "/test/post/urlencoded_chunked":
if request.hasBody(): try:
var postTable = await request.post() if request.hasBody():
let body = postTable.getString("field1") & ":" & var postTable = await request.post()
postTable.getString("field2") & ":" & let body = postTable.getString("field1") & ":" &
postTable.getString("field3") postTable.getString("field2") & ":" &
return await request.respond(Http200, body) postTable.getString("field3")
else: await request.respond(Http200, body)
return await request.respond(Http400, "Missing content body") else:
await request.respond(Http400, "Missing content body")
except HttpTransportError as exc:
defaultResponse(exc)
except HttpProtocolError as exc:
defaultResponse(exc)
else: else:
return await request.respond(Http404, "Page not found") defaultResponse()
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -491,12 +528,12 @@ suite "HTTP client testing suite":
] ]
var request = HttpClientRequestRef.new( var request = HttpClientRequestRef.new(
session, ha, MethodPost, headers = headers, session, ha, MethodPost, headers = headers,
body = cast[seq[byte]](PostRequests[0][1])) body = PostRequests[0][1].toBytes())
var response = await send(request) var response = await send(request)
if response.status == 200: if response.status == 200:
var res = await response.getBodyBytes() var res = await response.getBodyBytes()
if cast[string](res) == PostRequests[0][2]: if string.fromBytes(res) == PostRequests[0][2]:
inc(counter) inc(counter)
await response.closeWait() await response.closeWait()
await request.closeWait() await request.closeWait()
@ -532,7 +569,7 @@ suite "HTTP client testing suite":
var response = await request.finish() var response = await request.finish()
if response.status == 200: if response.status == 200:
var res = await response.getBodyBytes() var res = await response.getBodyBytes()
if cast[string](res) == PostRequests[1][2]: if string.fromBytes(res) == PostRequests[1][2]:
inc(counter) inc(counter)
await response.closeWait() await response.closeWait()
await request.closeWait() await request.closeWait()
@ -554,23 +591,28 @@ suite "HTTP client testing suite":
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path case request.uri.path
of "/test/post/multipart_size", "/test/post/multipart_chunked": of "/test/post/multipart_size", "/test/post/multipart_chunked":
if request.hasBody(): try:
var postTable = await request.post() if request.hasBody():
let body = postTable.getString("field1") & ":" & var postTable = await request.post()
postTable.getString("field2") & ":" & let body = postTable.getString("field1") & ":" &
postTable.getString("field3") postTable.getString("field2") & ":" &
return await request.respond(Http200, body) postTable.getString("field3")
else: await request.respond(Http200, body)
return await request.respond(Http400, "Missing content body") else:
await request.respond(Http400, "Missing content body")
except HttpProtocolError as exc:
defaultResponse(exc)
except HttpTransportError as exc:
defaultResponse(exc)
else: else:
return await request.respond(Http404, "Page not found") defaultResponse()
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -601,7 +643,7 @@ suite "HTTP client testing suite":
var response = await send(request) var response = await send(request)
if response.status == 200: if response.status == 200:
var res = await response.getBodyBytes() var res = await response.getBodyBytes()
if cast[string](res) == PostRequests[0][3]: if string.fromBytes(res) == PostRequests[0][3]:
inc(counter) inc(counter)
await response.closeWait() await response.closeWait()
await request.closeWait() await request.closeWait()
@ -634,7 +676,7 @@ suite "HTTP client testing suite":
let response = await request.finish() let response = await request.finish()
if response.status == 200: if response.status == 200:
var res = await response.getBodyBytes() var res = await response.getBodyBytes()
if cast[string](res) == PostRequests[1][3]: if string.fromBytes(res) == PostRequests[1][3]:
inc(counter) inc(counter)
await response.closeWait() await response.closeWait()
await request.closeWait() await request.closeWait()
@ -649,26 +691,29 @@ suite "HTTP client testing suite":
var lastAddress: Uri var lastAddress: Uri
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path try:
of "/": case request.uri.path
return await request.redirect(Http302, "/redirect/1") of "/":
of "/redirect/1": await request.redirect(Http302, "/redirect/1")
return await request.redirect(Http302, "/next/redirect/2") of "/redirect/1":
of "/next/redirect/2": await request.redirect(Http302, "/next/redirect/2")
return await request.redirect(Http302, "redirect/3") of "/next/redirect/2":
of "/next/redirect/redirect/3": await request.redirect(Http302, "redirect/3")
return await request.redirect(Http302, "next/redirect/4") of "/next/redirect/redirect/3":
of "/next/redirect/redirect/next/redirect/4": await request.redirect(Http302, "next/redirect/4")
return await request.redirect(Http302, lastAddress) of "/next/redirect/redirect/next/redirect/4":
of "/final/5": await request.redirect(Http302, lastAddress)
return await request.respond(Http200, "ok-5") of "/final/5":
else: await request.respond(Http200, "ok-5")
return await request.respond(Http404, "Page not found") else:
await request.respond(Http404, "Page not found")
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -704,6 +749,107 @@ suite "HTTP client testing suite":
await server.closeWait() await server.closeWait()
return "redirect-" & $res return "redirect-" & $res
proc testSendCancelLeaksTest(secure: bool): Future[bool] {.async.} =
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
let ha =
if secure:
getAddress(address, HttpClientScheme.Secure, "/")
else:
getAddress(address, HttpClientScheme.NonSecure, "/")
var counter = 0
while true:
let
session = createSession(secure)
request = HttpClientRequestRef.new(session, ha, MethodGet)
requestFut = request.send()
if counter > 0:
await stepsAsync(counter)
let exitLoop =
if not(requestFut.finished()):
await cancelAndWait(requestFut)
doAssert(cancelled(requestFut) or completed(requestFut),
"Future should be Cancelled or Completed at this point")
if requestFut.completed():
let response = await requestFut
await response.closeWait()
inc(counter)
false
else:
let response = await requestFut
await response.closeWait()
true
await request.closeWait()
await session.closeWait()
if exitLoop:
break
await server.stop()
await server.closeWait()
return true
proc testOpenCancelLeaksTest(secure: bool): Future[bool] {.async.} =
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
let ha =
if secure:
getAddress(address, HttpClientScheme.Secure, "/")
else:
getAddress(address, HttpClientScheme.NonSecure, "/")
var counter = 0
while true:
let
session = createSession(secure)
request = HttpClientRequestRef.new(session, ha, MethodPost)
bodyFut = request.open()
if counter > 0:
await stepsAsync(counter)
let exitLoop =
if not(bodyFut.finished()):
await cancelAndWait(bodyFut)
doAssert(cancelled(bodyFut) or completed(bodyFut),
"Future should be Cancelled or Completed at this point")
if bodyFut.completed():
let bodyWriter = await bodyFut
await bodyWriter.closeWait()
inc(counter)
false
else:
let bodyWriter = await bodyFut
await bodyWriter.closeWait()
true
await request.closeWait()
await session.closeWait()
if exitLoop:
break
await server.stop()
await server.closeWait()
return true
# proc testBasicAuthorization(): Future[bool] {.async.} = # proc testBasicAuthorization(): Future[bool] {.async.} =
# let session = HttpSessionRef.new({HttpClientFlag.NoVerifyHost}, # let session = HttpSessionRef.new({HttpClientFlag.NoVerifyHost},
# maxRedirections = 10) # maxRedirections = 10)
@ -766,20 +912,24 @@ suite "HTTP client testing suite":
return @[(data1.status, data1.data.bytesToString(), count), return @[(data1.status, data1.data.bytesToString(), count),
(data2.status, data2.data.bytesToString(), count)] (data2.status, data2.data.bytesToString(), count)]
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path try:
of "/keep": case request.uri.path
let headers = HttpTable.init([("connection", "keep-alive")]) of "/keep":
return await request.respond(Http200, "ok", headers = headers) let headers = HttpTable.init([("connection", "keep-alive")])
of "/drop": await request.respond(Http200, "ok", headers = headers)
let headers = HttpTable.init([("connection", "close")]) of "/drop":
return await request.respond(Http200, "ok", headers = headers) let headers = HttpTable.init([("connection", "close")])
else: await request.respond(Http200, "ok", headers = headers)
return await request.respond(Http404, "Page not found") else:
await request.respond(Http404, "Page not found")
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, false) var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start() server.start()
@ -901,16 +1051,20 @@ suite "HTTP client testing suite":
await request.closeWait() await request.closeWait()
return (data.status, data.data.bytesToString(), 0) return (data.status, data.data.bytesToString(), 0)
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path try:
of "/test": case request.uri.path
return await request.respond(Http200, "ok") of "/test":
else: await request.respond(Http200, "ok")
return await request.respond(Http404, "Page not found") else:
await request.respond(Http404, "Page not found")
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, false) var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start() server.start()
@ -960,19 +1114,23 @@ suite "HTTP client testing suite":
await request.closeWait() await request.closeWait()
return (data.status, data.data.bytesToString(), 0) return (data.status, data.data.bytesToString(), 0)
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
case request.uri.path try:
of "/test": case request.uri.path
return await request.respond(Http200, "ok") of "/test":
of "/keep-test": await request.respond(Http200, "ok")
let headers = HttpTable.init([("Connection", "keep-alive")]) of "/keep-test":
return await request.respond(Http200, "not-alive", headers) let headers = HttpTable.init([("Connection", "keep-alive")])
else: await request.respond(Http200, "not-alive", headers)
return await request.respond(Http404, "Page not found") else:
await request.respond(Http404, "Page not found")
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, false) var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start() server.start()
@ -1075,58 +1233,62 @@ suite "HTTP client testing suite":
return false return false
true true
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
if request.uri.path.startsWith("/test/single/"): try:
let index = if request.uri.path.startsWith("/test/single/"):
block: let index =
var res = -1 block:
for index, value in SingleGoodTests.pairs(): var res = -1
if value[0] == request.uri.path: for index, value in SingleGoodTests.pairs():
res = index if value[0] == request.uri.path:
break res = index
res break
if index < 0: res
return await request.respond(Http404, "Page not found") if index < 0:
var response = request.getResponse() return await request.respond(Http404, "Page not found")
response.status = Http200 var response = request.getResponse()
await response.sendBody(SingleGoodTests[index][1]) response.status = Http200
return response await response.sendBody(SingleGoodTests[index][1])
elif request.uri.path.startsWith("/test/multiple/"): response
let index = elif request.uri.path.startsWith("/test/multiple/"):
block: let index =
var res = -1 block:
for index, value in MultipleGoodTests.pairs(): var res = -1
if value[0] == request.uri.path: for index, value in MultipleGoodTests.pairs():
res = index if value[0] == request.uri.path:
break res = index
res break
if index < 0: res
return await request.respond(Http404, "Page not found") if index < 0:
var response = request.getResponse() return await request.respond(Http404, "Page not found")
response.status = Http200 var response = request.getResponse()
await response.sendBody(MultipleGoodTests[index][1]) response.status = Http200
return response await response.sendBody(MultipleGoodTests[index][1])
elif request.uri.path.startsWith("/test/overflow/"): response
let index = elif request.uri.path.startsWith("/test/overflow/"):
block: let index =
var res = -1 block:
for index, value in OverflowTests.pairs(): var res = -1
if value[0] == request.uri.path: for index, value in OverflowTests.pairs():
res = index if value[0] == request.uri.path:
break res = index
res break
if index < 0: res
return await request.respond(Http404, "Page not found") if index < 0:
var response = request.getResponse() return await request.respond(Http404, "Page not found")
response.status = Http200 var response = request.getResponse()
await response.sendBody(OverflowTests[index][1]) response.status = Http200
return response await response.sendBody(OverflowTests[index][1])
else: response
return await request.respond(Http404, "Page not found") else:
defaultResponse()
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, secure) var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start() server.start()
@ -1243,6 +1405,18 @@ suite "HTTP client testing suite":
test "HTTP(S) client maximum redirections test": test "HTTP(S) client maximum redirections test":
check waitFor(testRequestRedirectTest(true, 4)) == "redirect-true" check waitFor(testRequestRedirectTest(true, 4)) == "redirect-true"
test "HTTP send() cancellation leaks test":
check waitFor(testSendCancelLeaksTest(false)) == true
test "HTTP(S) send() cancellation leaks test":
check waitFor(testSendCancelLeaksTest(true)) == true
test "HTTP open() cancellation leaks test":
check waitFor(testOpenCancelLeaksTest(false)) == true
test "HTTP(S) open() cancellation leaks test":
check waitFor(testOpenCancelLeaksTest(true)) == true
test "HTTPS basic authorization test": test "HTTPS basic authorization test":
skip() skip()
# This test disabled because remote service is pretty flaky and fails pretty # This test disabled because remote service is pretty flaky and fails pretty
@ -1262,5 +1436,145 @@ suite "HTTP client testing suite":
test "HTTP client server-sent events test": test "HTTP client server-sent events test":
check waitFor(testServerSentEvents(false)) == true check waitFor(testServerSentEvents(false)) == true
test "Leaks test": test "HTTP getHttpAddress() test":
checkLeaks() block:
# HTTP client supports only `http` and `https` schemes in URL.
let res = getHttpAddress("ftp://ftp.scene.org")
check:
res.isErr()
res.error == HttpAddressErrorType.InvalidUrlScheme
res.error.isCriticalError()
block:
# HTTP URL default ports and custom ports test
let
res1 = getHttpAddress("http://www.google.com")
res2 = getHttpAddress("https://www.google.com")
res3 = getHttpAddress("http://www.google.com:35000")
res4 = getHttpAddress("https://www.google.com:25000")
check:
res1.isOk()
res2.isOk()
res3.isOk()
res4.isOk()
res1.get().port == 80
res2.get().port == 443
res3.get().port == 35000
res4.get().port == 25000
block:
# HTTP URL invalid port values test
let
res1 = getHttpAddress("http://www.google.com:-80")
res2 = getHttpAddress("http://www.google.com:0")
res3 = getHttpAddress("http://www.google.com:65536")
res4 = getHttpAddress("http://www.google.com:65537")
res5 = getHttpAddress("https://www.google.com:-443")
res6 = getHttpAddress("https://www.google.com:0")
res7 = getHttpAddress("https://www.google.com:65536")
res8 = getHttpAddress("https://www.google.com:65537")
check:
res1.isErr() and res1.error == HttpAddressErrorType.InvalidPortNumber
res1.error.isCriticalError()
res2.isOk()
res2.get().port == 0
res3.isErr() and res3.error == HttpAddressErrorType.InvalidPortNumber
res3.error.isCriticalError()
res4.isErr() and res4.error == HttpAddressErrorType.InvalidPortNumber
res4.error.isCriticalError()
res5.isErr() and res5.error == HttpAddressErrorType.InvalidPortNumber
res5.error.isCriticalError()
res6.isOk()
res6.get().port == 0
res7.isErr() and res7.error == HttpAddressErrorType.InvalidPortNumber
res7.error.isCriticalError()
res8.isErr() and res8.error == HttpAddressErrorType.InvalidPortNumber
res8.error.isCriticalError()
block:
# HTTP URL missing hostname
let
res1 = getHttpAddress("http://")
res2 = getHttpAddress("https://")
check:
res1.isErr() and res1.error == HttpAddressErrorType.MissingHostname
res1.error.isCriticalError()
res2.isErr() and res2.error == HttpAddressErrorType.MissingHostname
res2.error.isCriticalError()
block:
# No resolution flags and incorrect URL
let
flags = {HttpClientFlag.NoInet4Resolution,
HttpClientFlag.NoInet6Resolution}
res1 = getHttpAddress("http://256.256.256.256", flags)
res2 = getHttpAddress(
"http://[FFFFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF]", flags)
check:
res1.isErr() and res1.error == HttpAddressErrorType.InvalidIpHostname
res1.error.isCriticalError()
res2.isErr() and res2.error == HttpAddressErrorType.InvalidIpHostname
res2.error.isCriticalError()
block:
# Resolution of non-existent hostname
let res = getHttpAddress("http://eYr6bdBo.com")
check:
res.isErr() and res.error == HttpAddressErrorType.NameLookupFailed
res.error.isRecoverableError()
not(res.error.isCriticalError())
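A minimal usage sketch of the error classification exercised above (the URL below is illustrative and not part of the tests): callers can branch on whether an address error is recoverable, such as a failed name lookup worth retrying, or critical, meaning the URL itself is invalid.

  block:
    # Illustrative only: example.com and port 8080 are arbitrary values.
    let res = getHttpAddress("http://example.com:8080")
    if res.isErr():
      if res.error.isRecoverableError():
        discard # e.g. NameLookupFailed; a retry may succeed later
      else:
        discard # critical: scheme, hostname or port is invalid
    else:
      doAssert res.get().port == 8080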
asyncTest "HTTPS response headers buffer size test":
const HeadersSize = HttpMaxHeadersSize
let expectValue =
string.fromBytes(createBigMessage("HEADERSTEST", HeadersSize))
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk():
let request = r.get()
try:
case request.uri.path
of "/test":
let headers = HttpTable.init([("big-header", expectValue)])
await request.respond(Http200, "ok", headers)
else:
await request.respond(Http404, "Page not found")
except HttpWriteError as exc:
defaultResponse(exc)
else:
defaultResponse()
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start()
let
address = server.instance.localAddress()
ha = getAddress(address, HttpClientScheme.NonSecure, "/test")
session = HttpSessionRef.new()
let
req1 = HttpClientRequestRef.new(session, ha)
req2 =
HttpClientRequestRef.new(session, ha,
maxResponseHeadersSize = HttpMaxHeadersSize * 2)
res1 =
try:
let res {.used.} = await send(req1)
await closeWait(req1)
await closeWait(res)
false
except HttpReadError:
true
except HttpError:
await closeWait(req1)
false
except CancelledError:
await closeWait(req1)
false
res2 = await send(req2)
check:
res1 == true
res2.status == 200
res2.headers.getString("big-header") == expectValue
await req1.closeWait()
await req2.closeWait()
await res2.closeWait()
await session.closeWait()
await server.stop()
await server.closeWait()
@ -7,21 +7,35 @@
 # MIT license (LICENSE-MIT)
 import std/[strutils, algorithm]
 import ".."/chronos/unittest2/asynctests,
-       ".."/chronos, ".."/chronos/apps/http/httpserver,
-       ".."/chronos/apps/http/httpcommon,
-       ".."/chronos/apps/http/httpdebug
+       ".."/chronos,
+       ".."/chronos/apps/http/[httpserver, httpcommon, httpdebug]
 import stew/base10
 {.used.}
# Trouble finding this if defined near its use for `data2.sorted`, etc., likely
# related to "generic sandwich" issues. If any test ever wants to `sort` a
# `seq[(string, seq[string])]` differently, it may need to re-work that test.
proc `<`(a, b: (string, seq[string])): bool = a[0] < b[0]
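A small sketch of what that comparator enables (the data below is illustrative, not taken from the tests): with the module-level `<` in scope, `sorted` from `std/algorithm` orders `(string, seq[string])` pairs by key, independent of table hash order.

# Illustrative only; assumes std/algorithm is imported as above.
let dumped = @[("Header2", @["b"]), ("Header1", @["a"])]
doAssert dumped.sorted == @[("Header1", @["a"]), ("Header2", @["b"])]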
suite "HTTP server testing suite": suite "HTTP server testing suite":
teardown:
checkLeaks()
type type
TooBigTest = enum TooBigTest = enum
GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest
TestHttpResponse = object TestHttpResponse = object
status: int
headers: HttpTable headers: HttpTable
data: string data: string
FirstMiddlewareRef = ref object of HttpServerMiddlewareRef
someInteger: int
SecondMiddlewareRef = ref object of HttpServerMiddlewareRef
someString: string
proc httpClient(address: TransportAddress, proc httpClient(address: TransportAddress,
data: string): Future[string] {.async.} = data: string): Future[string] {.async.} =
var transp: StreamTransport var transp: StreamTransport
@ -51,7 +65,7 @@ suite "HTTP server testing suite":
       zeroMem(addr buffer[0], len(buffer))
       await transp.readExactly(addr buffer[0], length)
       let data = bytesToString(buffer.toOpenArray(0, length - 1))
-      let headers =
+      let (status, headers) =
         block:
           let resp = parseResponse(hdata, false)
           if resp.failed():
@ -59,13 +73,43 @@ suite "HTTP server testing suite":
           var res = HttpTable.init()
           for key, value in resp.headers(hdata):
             res.add(key, value)
-          res
-      return TestHttpResponse(headers: headers, data: data)
+          (resp.code, res)
+      TestHttpResponse(status: status, headers: headers, data: data)
proc httpClient3(address: TransportAddress,
data: string): Future[TestHttpResponse] {.async.} =
var
transp: StreamTransport
buffer = newSeq[byte](4096)
sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8]
try:
transp = await connect(address)
if len(data) > 0:
let wres = await transp.write(data)
if wres != len(data):
raise newException(ValueError, "Unable to write full request")
let hres = await transp.readUntil(addr buffer[0], len(buffer), sep)
var hdata = @buffer
hdata.setLen(hres)
var rres = bytesToString(await transp.read())
let (status, headers) =
block:
let resp = parseResponse(hdata, false)
if resp.failed():
raise newException(ValueError, "Unable to decode response headers")
var res = HttpTable.init()
for key, value in resp.headers(hdata):
res.add(key, value)
(resp.code, res)
TestHttpResponse(status: status, headers: headers, data: rres)
finally:
if not(isNil(transp)):
await closeWait(transp)
proc testTooBigBodyChunked(operation: TooBigTest): Future[bool] {.async.} = proc testTooBigBodyChunked(operation: TooBigTest): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
try: try:
@ -78,13 +122,15 @@ suite "HTTP server testing suite":
let ptable {.used.} = await request.post() let ptable {.used.} = await request.post()
of PostMultipartTest: of PostMultipartTest:
let ptable {.used.} = await request.post() let ptable {.used.} = await request.post()
except HttpCriticalError as exc: defaultResponse()
except HttpTransportError as exc:
defaultResponse(exc)
except HttpProtocolError as exc:
if exc.code == Http413: if exc.code == Http413:
serverRes = true serverRes = true
# Reraising exception, because processor should properly handle it. defaultResponse(exc)
raise exc
else: else:
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -129,14 +175,17 @@ suite "HTTP server testing suite":
proc testTimeout(): Future[bool] {.async.} = proc testTimeout(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init()) try:
await request.respond(Http200, "TEST_OK", HttpTable.init())
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
if r.error.kind == HttpServerError.TimeoutError: if r.error.kind == HttpServerError.TimeoutError:
serverRes = true serverRes = true
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), let res = HttpServerRef.new(initTAddress("127.0.0.1:0"),
@ -159,14 +208,17 @@ suite "HTTP server testing suite":
proc testEmpty(): Future[bool] {.async.} = proc testEmpty(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init()) try:
await request.respond(Http200, "TEST_OK", HttpTable.init())
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
if r.error.kind == HttpServerError.CriticalError: if r.error.kind == HttpServerError.ProtocolError:
serverRes = true serverRes = true
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), let res = HttpServerRef.new(initTAddress("127.0.0.1:0"),
@ -189,14 +241,17 @@ suite "HTTP server testing suite":
proc testTooBig(): Future[bool] {.async.} = proc testTooBig(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init()) try:
await request.respond(Http200, "TEST_OK", HttpTable.init())
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
if r.error.error == HttpServerError.CriticalError: if r.error.error == HttpServerError.ProtocolError:
serverRes = true serverRes = true
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -220,13 +275,11 @@ suite "HTTP server testing suite":
proc testTooBigBody(): Future[bool] {.async.} = proc testTooBigBody(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isErr():
discard if r.error.error == HttpServerError.ProtocolError:
else:
if r.error.error == HttpServerError.CriticalError:
serverRes = true serverRes = true
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -267,7 +320,7 @@ suite "HTTP server testing suite":
proc testQuery(): Future[bool] {.async.} = proc testQuery(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
var kres = newSeq[string]() var kres = newSeq[string]()
@ -275,11 +328,14 @@ suite "HTTP server testing suite":
kres.add(k & ":" & v) kres.add(k & ":" & v)
sort(kres) sort(kres)
serverRes = true serverRes = true
return await request.respond(Http200, "TEST_OK:" & kres.join(":"), try:
HttpTable.init()) await request.respond(Http200, "TEST_OK:" & kres.join(":"),
HttpTable.init())
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false defaultResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -297,10 +353,9 @@ suite "HTTP server testing suite":
"GET /?a=%D0%9F&%D0%A4=%D0%91&b=%D0%A6&c=%D0%AE HTTP/1.0\r\n\r\n") "GET /?a=%D0%9F&%D0%A4=%D0%91&b=%D0%A6&c=%D0%AE HTTP/1.0\r\n\r\n")
await server.stop() await server.stop()
await server.closeWait() await server.closeWait()
let r = serverRes and serverRes and
(data1.find("TEST_OK:a:1:a:2:b:3:c:4") >= 0) and (data1.find("TEST_OK:a:1:a:2:b:3:c:4") >= 0) and
(data2.find("TEST_OK:a:П:b:Ц:c:Ю:Ф:Б") >= 0) (data2.find("TEST_OK:a:П:b:Ц:c:Ю:Ф:Б") >= 0)
return r
check waitFor(testQuery()) == true check waitFor(testQuery()) == true
@ -308,7 +363,7 @@ suite "HTTP server testing suite":
proc testHeaders(): Future[bool] {.async.} = proc testHeaders(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
var kres = newSeq[string]() var kres = newSeq[string]()
@ -316,11 +371,14 @@ suite "HTTP server testing suite":
kres.add(k & ":" & v) kres.add(k & ":" & v)
sort(kres) sort(kres)
serverRes = true serverRes = true
return await request.respond(Http200, "TEST_OK:" & kres.join(":"), try:
HttpTable.init()) await request.respond(Http200, "TEST_OK:" & kres.join(":"),
HttpTable.init())
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false defaultResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -352,21 +410,30 @@ suite "HTTP server testing suite":
proc testPostUrl(): Future[bool] {.async.} = proc testPostUrl(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
let request = r.get() let request = r.get()
if request.meth in PostMethods: if request.meth in PostMethods:
let post = await request.post() let post =
try:
await request.post()
except HttpProtocolError as exc:
return defaultResponse(exc)
except HttpTransportError as exc:
return defaultResponse(exc)
for k, v in post.stringItems(): for k, v in post.stringItems():
kres.add(k & ":" & v) kres.add(k & ":" & v)
sort(kres) sort(kres)
serverRes = true serverRes = true
return await request.respond(Http200, "TEST_OK:" & kres.join(":"), try:
HttpTable.init()) await request.respond(Http200, "TEST_OK:" & kres.join(":"),
HttpTable.init())
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false defaultResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -396,21 +463,30 @@ suite "HTTP server testing suite":
proc testPostUrl2(): Future[bool] {.async.} = proc testPostUrl2(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
let request = r.get() let request = r.get()
if request.meth in PostMethods: if request.meth in PostMethods:
let post = await request.post() let post =
try:
await request.post()
except HttpProtocolError as exc:
return defaultResponse(exc)
except HttpTransportError as exc:
return defaultResponse(exc)
for k, v in post.stringItems(): for k, v in post.stringItems():
kres.add(k & ":" & v) kres.add(k & ":" & v)
sort(kres) sort(kres)
serverRes = true serverRes = true
return await request.respond(Http200, "TEST_OK:" & kres.join(":"), try:
HttpTable.init()) await request.respond(Http200, "TEST_OK:" & kres.join(":"),
HttpTable.init())
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false defaultResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -441,21 +517,30 @@ suite "HTTP server testing suite":
proc testPostMultipart(): Future[bool] {.async.} = proc testPostMultipart(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
let request = r.get() let request = r.get()
if request.meth in PostMethods: if request.meth in PostMethods:
let post = await request.post() let post =
try:
await request.post()
except HttpProtocolError as exc:
return defaultResponse(exc)
except HttpTransportError as exc:
return defaultResponse(exc)
for k, v in post.stringItems(): for k, v in post.stringItems():
kres.add(k & ":" & v) kres.add(k & ":" & v)
sort(kres) sort(kres)
serverRes = true serverRes = true
return await request.respond(Http200, "TEST_OK:" & kres.join(":"), try:
HttpTable.init()) await request.respond(Http200, "TEST_OK:" & kres.join(":"),
HttpTable.init())
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false defaultResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -497,21 +582,31 @@ suite "HTTP server testing suite":
proc testPostMultipart2(): Future[bool] {.async.} = proc testPostMultipart2(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
let request = r.get() let request = r.get()
if request.meth in PostMethods: if request.meth in PostMethods:
let post = await request.post() let post =
try:
await request.post()
except HttpProtocolError as exc:
return defaultResponse(exc)
except HttpTransportError as exc:
return defaultResponse(exc)
for k, v in post.stringItems(): for k, v in post.stringItems():
kres.add(k & ":" & v) kres.add(k & ":" & v)
sort(kres) sort(kres)
serverRes = true serverRes = true
return await request.respond(Http200, "TEST_OK:" & kres.join(":"), try:
HttpTable.init()) await request.respond(Http200, "TEST_OK:" & kres.join(":"),
HttpTable.init())
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false serverRes = false
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -566,16 +661,20 @@ suite "HTTP server testing suite":
var eventContinue = newAsyncEvent() var eventContinue = newAsyncEvent()
var count = 0 var count = 0
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
inc(count) inc(count)
if count == ClientsCount: if count == ClientsCount:
eventWait.fire() eventWait.fire()
await eventContinue.wait() await eventContinue.wait()
return await request.respond(Http404, "", HttpTable.init()) try:
await request.respond(Http404, "", HttpTable.init())
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -752,11 +851,11 @@ suite "HTTP server testing suite":
     for key, value in table1.items(true):
       data2.add((key, value))
-    check:
-      data1 == @[("Header2", "value2"), ("Header2", "VALUE3"),
-                 ("Header1", "value1")]
-      data2 == @[("Header2", @["value2", "VALUE3"]),
-                 ("Header1", @["value1"])]
+    check: # .sorted to not depend upon hash(key)-order
+      data1.sorted == sorted(@[("Header2", "value2"), ("Header2", "VALUE3"),
+                               ("Header1", "value1")])
+      data2.sorted == sorted(@[("Header2", @["value2", "VALUE3"]),
+                               ("Header1", @["value1"])])
     table1.set("header2", "value4")
     check:
@ -1230,23 +1329,26 @@ suite "HTTP server testing suite":
proc testPostMultipart2(): Future[bool] {.async.} = proc testPostMultipart2(): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
let response = request.getResponse() let response = request.getResponse()
await response.prepareSSE() try:
await response.send("event: event1\r\ndata: data1\r\n\r\n") await response.prepareSSE()
await response.send("event: event2\r\ndata: data2\r\n\r\n") await response.send("event: event1\r\ndata: data1\r\n\r\n")
await response.sendEvent("event3", "data3") await response.send("event: event2\r\ndata: data2\r\n\r\n")
await response.sendEvent("event4", "data4") await response.sendEvent("event3", "data3")
await response.send("data: data5\r\n\r\n") await response.sendEvent("event4", "data4")
await response.sendEvent("", "data6") await response.send("data: data5\r\n\r\n")
await response.finish() await response.sendEvent("", "data6")
serverRes = true await response.finish()
return response serverRes = true
response
except HttpWriteError as exc:
serverRes = false
defaultResponse(exc)
else: else:
serverRes = false defaultResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
@ -1305,12 +1407,16 @@ suite "HTTP server testing suite":
{}, false, "close") {}, false, "close")
] ]
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init()) try:
await request.respond(Http200, "TEST_OK", HttpTable.init())
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
for test in TestMessages: for test in TestMessages:
let let
@ -1327,44 +1433,47 @@ suite "HTTP server testing suite":
     server.start()
     var transp: StreamTransport
-    try:
-      transp = await connect(address)
-      block:
-        let response = await transp.httpClient2(test[0], 7)
-        check:
-          response.data == "TEST_OK"
-          response.headers.getString("connection") == test[3]
-      # We do this sleeping here just because we running both server and
-      # client in single process, so when we received response from server
-      # it does not mean that connection has been immediately closed - it
-      # takes some more calls, so we trying to get this calls happens.
-      await sleepAsync(50.milliseconds)
-      let connectionStillAvailable =
-        try:
-          let response {.used.} = await transp.httpClient2(test[0], 7)
-          true
-        except CatchableError:
-          false
-      check connectionStillAvailable == test[2]
-    finally:
-      if not(isNil(transp)):
-        await transp.closeWait()
-      await server.stop()
-      await server.closeWait()
+    transp = await connect(address)
+    block:
+      let response = await transp.httpClient2(test[0], 7)
+      check:
+        response.data == "TEST_OK"
+        response.headers.getString("connection") == test[3]
+    # We do this sleeping here just because we running both server and
+    # client in single process, so when we received response from server
+    # it does not mean that connection has been immediately closed - it
+    # takes some more calls, so we trying to get this calls happens.
+    await sleepAsync(50.milliseconds)
+    let connectionStillAvailable =
+      try:
+        let response {.used.} = await transp.httpClient2(test[0], 7)
+        true
+      except CatchableError:
+        false
+    check connectionStillAvailable == test[2]
+    if not(isNil(transp)):
+      await transp.closeWait()
+    await server.stop()
+    await server.closeWait()
asyncTest "HTTP debug tests": asyncTest "HTTP debug tests":
const const
TestsCount = 10 TestsCount = 10
TestRequest = "GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n" TestRequest = "GET /httpdebug HTTP/1.1\r\nConnection: keep-alive\r\n\r\n"
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init()) try:
await request.respond(Http200, "TEST_OK", HttpTable.init())
except HttpWriteError as exc:
defaultResponse(exc)
else: else:
return defaultResponse() defaultResponse()
proc client(address: TransportAddress, proc client(address: TransportAddress,
data: string): Future[StreamTransport] {.async.} = data: string): Future[StreamTransport] {.async.} =
@ -1401,31 +1510,320 @@ suite "HTTP server testing suite":
info.flags == {HttpServerFlags.Http11Pipeline} info.flags == {HttpServerFlags.Http11Pipeline}
info.socketFlags == socketFlags info.socketFlags == socketFlags
var clientFutures: seq[Future[StreamTransport]]
for i in 0 ..< TestsCount:
clientFutures.add(client(address, TestRequest))
await allFutures(clientFutures)
let connections = server.getConnections()
check len(connections) == TestsCount
let currentTime = Moment.now()
for index, connection in connections.pairs():
let transp = clientFutures[index].read()
check:
connection.remoteAddress.get() == transp.localAddress()
connection.localAddress.get() == transp.remoteAddress()
connection.connectionType == ConnectionType.NonSecure
connection.connectionState == ConnectionState.Alive
connection.query.get("") == "/httpdebug"
(currentTime - connection.createMoment.get()) != ZeroDuration
(currentTime - connection.acceptMoment) != ZeroDuration
var pending: seq[Future[void]]
for transpFut in clientFutures:
pending.add(closeWait(transpFut.read()))
await allFutures(pending)
await server.stop()
await server.closeWait()
asyncTest "HTTP middleware request filtering test":
proc init(t: typedesc[FirstMiddlewareRef],
data: int): HttpServerMiddlewareRef =
proc shandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
let mw = FirstMiddlewareRef(middleware)
if reqfence.isErr():
# Our handler is not supposed to handle request errors, so we
# call next handler in sequence which could process errors.
return await nextHandler(reqfence)
let request = reqfence.get()
if request.uri.path == "/first":
# This is request we are waiting for, so we going to process it.
try:
await request.respond(Http200, $mw.someInteger)
except HttpWriteError as exc:
defaultResponse(exc)
else:
# We know nothing about request's URI, so we pass this request to the
# next handler which could process such request.
await nextHandler(reqfence)
HttpServerMiddlewareRef(
FirstMiddlewareRef(someInteger: data, handler: shandler))
proc init(t: typedesc[SecondMiddlewareRef],
data: string): HttpServerMiddlewareRef =
proc shandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
let mw = SecondMiddlewareRef(middleware)
if reqfence.isErr():
# Our handler is not supposed to handle request errors, so we
# call next handler in sequence which could process errors.
return await nextHandler(reqfence)
let request = reqfence.get()
if request.uri.path == "/second":
# This is request we are waiting for, so we going to process it.
try:
await request.respond(Http200, mw.someString)
except HttpWriteError as exc:
defaultResponse(exc)
else:
# We know nothing about request's URI, so we pass this request to the
# next handler which could process such request.
await nextHandler(reqfence)
HttpServerMiddlewareRef(
SecondMiddlewareRef(someString: data, handler: shandler))
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk():
let request = r.get()
if request.uri.path == "/test":
try:
await request.respond(Http200, "ORIGIN")
except HttpWriteError as exc:
defaultResponse(exc)
else:
defaultResponse()
else:
defaultResponse()
let
middlewares = [FirstMiddlewareRef.init(655370),
SecondMiddlewareRef.init("SECOND")]
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags,
middlewares = middlewares)
check res.isOk()
let server = res.get()
server.start()
let
address = server.instance.localAddress()
req1 = "GET /test HTTP/1.1\r\n\r\n"
req2 = "GET /first HTTP/1.1\r\n\r\n"
req3 = "GET /second HTTP/1.1\r\n\r\n"
req4 = "GET /noway HTTP/1.1\r\n\r\n"
resp1 = await httpClient3(address, req1)
resp2 = await httpClient3(address, req2)
resp3 = await httpClient3(address, req3)
resp4 = await httpClient3(address, req4)
check:
resp1.status == 200
resp1.data == "ORIGIN"
resp2.status == 200
resp2.data == "655370"
resp3.status == 200
resp3.data == "SECOND"
resp4.status == 404
await server.stop()
await server.closeWait()
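For contrast with the two filtering middlewares above, a minimal pass-through sketch (the `NoopMiddlewareRef` name and handler are hypothetical, not part of the test suite): it forwards every request fence, error fences included, to the next handler in the chain, which is exactly the fallback path the comments above describe.

type NoopMiddlewareRef = ref object of HttpServerMiddlewareRef

proc noopHandler(middleware: HttpServerMiddlewareRef,
                 reqfence: RequestFence,
                 nextHandler: HttpProcessCallback2
                ): Future[HttpResponseRef] {.
     async: (raises: [CancelledError]).} =
  # Inspect nothing; delegate both normal requests and error fences to the
  # next handler in the middleware sequence.
  await nextHandler(reqfence)

# Installed the same way as the middlewares above, e.g.
#   middlewares = [HttpServerMiddlewareRef(NoopMiddlewareRef(handler: noopHandler))]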
asyncTest "HTTP middleware request modification test":
proc init(t: typedesc[FirstMiddlewareRef],
data: int): HttpServerMiddlewareRef =
proc shandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
let mw = FirstMiddlewareRef(middleware)
if reqfence.isErr():
# Our handler is not supposed to handle request errors, so we
# call next handler in sequence which could process errors.
return await nextHandler(reqfence)
let
request = reqfence.get()
modifiedUri = "/modified/" & $mw.someInteger & request.rawPath
var modifiedHeaders = request.headers
modifiedHeaders.add("X-Modified", "test-value")
let res = request.updateRequest(modifiedUri, modifiedHeaders)
if res.isErr():
return defaultResponse(res.error)
# We sending modified request to the next handler.
await nextHandler(reqfence)
HttpServerMiddlewareRef(
FirstMiddlewareRef(someInteger: data, handler: shandler))
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk():
let request = r.get()
try:
await request.respond(Http200, request.rawPath & ":" &
request.headers.getString("x-modified"))
except HttpWriteError as exc:
defaultResponse(exc)
else:
defaultResponse()
let
middlewares = [FirstMiddlewareRef.init(655370)]
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags,
middlewares = middlewares)
check res.isOk()
let server = res.get()
server.start()
let
address = server.instance.localAddress()
req1 = "GET /test HTTP/1.1\r\n\r\n"
req2 = "GET /first HTTP/1.1\r\n\r\n"
req3 = "GET /second HTTP/1.1\r\n\r\n"
req4 = "GET /noway HTTP/1.1\r\n\r\n"
resp1 = await httpClient3(address, req1)
resp2 = await httpClient3(address, req2)
resp3 = await httpClient3(address, req3)
resp4 = await httpClient3(address, req4)
check:
resp1.status == 200
resp1.data == "/modified/655370/test:test-value"
resp2.status == 200
resp2.data == "/modified/655370/first:test-value"
resp3.status == 200
resp3.data == "/modified/655370/second:test-value"
resp4.status == 200
resp4.data == "/modified/655370/noway:test-value"
await server.stop()
await server.closeWait()
asyncTest "HTTP middleware request blocking test":
proc init(t: typedesc[FirstMiddlewareRef],
data: int): HttpServerMiddlewareRef =
proc shandler(
middleware: HttpServerMiddlewareRef,
reqfence: RequestFence,
nextHandler: HttpProcessCallback2
): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
if reqfence.isErr():
# Our handler is not supposed to handle request errors, so we
# call next handler in sequence which could process errors.
return await nextHandler(reqfence)
let request = reqfence.get()
if request.uri.path == "/first":
# Blocking request by disconnecting remote peer.
dropResponse()
elif request.uri.path == "/second":
# Blocking request by sending HTTP error message with 401 code.
codeResponse(Http401)
else:
# Allow all other requests to be processed by next handler.
await nextHandler(reqfence)
HttpServerMiddlewareRef(
FirstMiddlewareRef(someInteger: data, handler: shandler))
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
if r.isOk():
let request = r.get()
try:
await request.respond(Http200, "ORIGIN")
except HttpWriteError as exc:
defaultResponse(exc)
else:
defaultResponse()
let
middlewares = [FirstMiddlewareRef.init(655370)]
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags,
middlewares = middlewares)
check res.isOk()
let server = res.get()
server.start()
let
address = server.instance.localAddress()
req1 = "GET /test HTTP/1.1\r\n\r\n"
req2 = "GET /first HTTP/1.1\r\n\r\n"
req3 = "GET /second HTTP/1.1\r\n\r\n"
resp1 = await httpClient3(address, req1)
resp3 = await httpClient3(address, req3)
check:
resp1.status == 200
resp1.data == "ORIGIN"
resp3.status == 401
let checked =
try:
let res {.used.} = await httpClient3(address, req2)
false
except TransportIncompleteError:
true
check:
checked == true
await server.stop()
await server.closeWait()
asyncTest "HTTP server - baseUri value test":
proc process(r: RequestFence): Future[HttpResponseRef] {.
async: (raises: [CancelledError]).} =
defaultResponse()
let
expectUri2 = "http://www.chronos-test.com/"
address = initTAddress("127.0.0.1:0")
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
res1 = HttpServerRef.new(address, process,
socketFlags = socketFlags)
res2 = HttpServerRef.new(address, process,
socketFlags = socketFlags,
serverUri = parseUri(expectUri2))
check:
res1.isOk == true
res2.isOk == true
let
server1 = res1.get()
server2 = res2.get()
-    try:
-      var clientFutures: seq[Future[StreamTransport]]
-      for i in 0 ..< TestsCount:
-        clientFutures.add(client(address, TestRequest))
-      await allFutures(clientFutures)
-      let connections = server.getConnections()
-      check len(connections) == TestsCount
-      let currentTime = Moment.now()
-      for index, connection in connections.pairs():
-        let transp = clientFutures[index].read()
-        check:
-          connection.remoteAddress.get() == transp.localAddress()
-          connection.localAddress.get() == transp.remoteAddress()
-          connection.connectionType == ConnectionType.NonSecure
-          connection.connectionState == ConnectionState.Alive
-          (currentTime - connection.createMoment.get()) != ZeroDuration
-          (currentTime - connection.acceptMoment) != ZeroDuration
-      var pending: seq[Future[void]]
-      for transpFut in clientFutures:
-        pending.add(closeWait(transpFut.read()))
-      await allFutures(pending)
-    finally:
-      await server.stop()
-      await server.closeWait()
-  test "Leaks test":
-    checkLeaks()
+    try:
+      server1.start()
+      server2.start()
+      let
+        localAddress = server1.instance.localAddress()
+        expectUri1 = "http://127.0.0.1:" & $localAddress.port & "/"
+      check:
+        server1.baseUri == parseUri(expectUri1)
+        server2.baseUri == parseUri(expectUri2)
+    finally:
+      await server1.stop()
+      await server1.closeWait()
+      await server2.stop()
+      await server2.closeWait()
@ -8,6 +8,7 @@
import std/[macros, strutils] import std/[macros, strutils]
import unittest2 import unittest2
import ../chronos import ../chronos
import ../chronos/config
{.used.} {.used.}
@ -94,6 +95,11 @@ proc testAwaitne(): Future[bool] {.async.} =
return true return true
template returner =
# can't use `return 5`
result = 5
return
suite "Macro transformations test suite": suite "Macro transformations test suite":
test "`await` command test": test "`await` command test":
check waitFor(testAwait()) == true check waitFor(testAwait()) == true
@ -136,6 +142,151 @@ suite "Macro transformations test suite":
check: check:
waitFor(gen(int)) == default(int) waitFor(gen(int)) == default(int)
test "Nested return":
proc nr: Future[int] {.async.} =
return
if 1 == 1:
return 42
else:
33
check waitFor(nr()) == 42
# There are a few unreachable statements to ensure that we don't regress in
# generated code
{.push warning[UnreachableCode]: off.}
suite "Macro transformations - completions":
test "Run closure to completion on return": # issue #415
var x = 0
proc test415 {.async.} =
try:
return
finally:
await sleepAsync(1.milliseconds)
x = 5
waitFor(test415())
check: x == 5
test "Run closure to completion on defer":
var x = 0
proc testDefer {.async.} =
defer:
await sleepAsync(1.milliseconds)
x = 5
return
waitFor(testDefer())
check: x == 5
test "Run closure to completion with exceptions":
var x = 0
proc testExceptionHandling {.async.} =
try:
return
finally:
try:
await sleepAsync(1.milliseconds)
raise newException(ValueError, "")
except ValueError:
await sleepAsync(1.milliseconds)
await sleepAsync(1.milliseconds)
x = 5
waitFor(testExceptionHandling())
check: x == 5
test "Correct return value when updating result after return":
proc testWeirdCase: int =
try: return 33
finally: result = 55
proc testWeirdCaseAsync: Future[int] {.async.} =
try:
await sleepAsync(1.milliseconds)
return 33
finally: result = 55
check:
testWeirdCase() == waitFor(testWeirdCaseAsync())
testWeirdCase() == 55
test "Correct return value with result assignment in defer":
proc testWeirdCase: int =
defer:
result = 55
result = 33
proc testWeirdCaseAsync: Future[int] {.async.} =
defer:
result = 55
await sleepAsync(1.milliseconds)
return 33
check:
testWeirdCase() == waitFor(testWeirdCaseAsync())
testWeirdCase() == 55
test "Generic & finally calling async":
proc testGeneric(T: type): Future[T] {.async.} =
try:
try:
await sleepAsync(1.milliseconds)
return
finally:
await sleepAsync(1.milliseconds)
await sleepAsync(1.milliseconds)
result = 11
finally:
await sleepAsync(1.milliseconds)
await sleepAsync(1.milliseconds)
result = 12
check waitFor(testGeneric(int)) == 12
proc testFinallyCallsAsync(T: type): Future[T] {.async.} =
try:
await sleepAsync(1.milliseconds)
return
finally:
result = await testGeneric(T)
check waitFor(testFinallyCallsAsync(int)) == 12
test "templates returning":
proc testReturner: Future[int] {.async.} =
returner
doAssert false
check waitFor(testReturner()) == 5
proc testReturner2: Future[int] {.async.} =
template returner2 =
return 6
returner2
doAssert false
check waitFor(testReturner2()) == 6
test "raising defects":
proc raiser {.async.} =
# sleeping to make sure our caller is the poll loop
await sleepAsync(0.milliseconds)
raise newException(Defect, "uh-oh")
let fut = raiser()
expect(Defect): waitFor(fut)
check not fut.completed()
fut.complete()
test "return result":
proc returnResult: Future[int] {.async.} =
var result: int
result = 12
return result
check waitFor(returnResult()) == 12
test "async in async":
proc asyncInAsync: Future[int] {.async.} =
proc a2: Future[int] {.async.} =
result = 12
result = await a2()
check waitFor(asyncInAsync()) == 12
{.pop.}
suite "Macro transformations - implicit returns":
test "Implicit return": test "Implicit return":
proc implicit(): Future[int] {.async.} = proc implicit(): Future[int] {.async.} =
42 42
@ -232,3 +383,244 @@ suite "Closure iterator's exception transformation issues":
waitFor(x()) waitFor(x())
suite "Exceptions tracking":
template checkNotCompiles(body: untyped) =
check (not compiles(body))
test "Can raise valid exception":
proc test1 {.async.} = raise newException(ValueError, "hey")
proc test2 {.async: (raises: [ValueError]).} = raise newException(ValueError, "hey")
proc test3 {.async: (raises: [IOError, ValueError]).} =
if 1 == 2:
raise newException(ValueError, "hey")
else:
raise newException(IOError, "hey")
proc test4 {.async: (raises: []), used.} = raise newException(Defect, "hey")
proc test5 {.async: (raises: []).} = discard
proc test6 {.async: (raises: []).} = await test5()
expect(ValueError): waitFor test1()
expect(ValueError): waitFor test2()
expect(IOError): waitFor test3()
waitFor test6()
test "Cannot raise invalid exception":
checkNotCompiles:
proc test3 {.async: (raises: [IOError]).} = raise newException(ValueError, "hey")
test "Explicit return in non-raising proc":
proc test(): Future[int] {.async: (raises: []).} = return 12
check:
waitFor(test()) == 12
test "Non-raising compatibility":
proc test1 {.async: (raises: [ValueError]).} = raise newException(ValueError, "hey")
let testVar: Future[void] = test1()
proc test2 {.async.} = raise newException(ValueError, "hey")
let testVar2: proc: Future[void] = test2
# Doesn't work unfortunately
#let testVar3: proc: Future[void] = test1
test "Cannot store invalid future types":
proc test1 {.async: (raises: [ValueError]).} = raise newException(ValueError, "hey")
proc test2 {.async: (raises: [IOError]).} = raise newException(IOError, "hey")
var a = test1()
checkNotCompiles:
a = test2()
test "Await raises the correct types":
proc test1 {.async: (raises: [ValueError]).} = raise newException(ValueError, "hey")
proc test2 {.async: (raises: [ValueError, CancelledError]).} = await test1()
checkNotCompiles:
proc test3 {.async: (raises: [CancelledError]).} = await test1()
test "Can create callbacks":
proc test1 {.async: (raises: [ValueError]).} = raise newException(ValueError, "hey")
let callback: proc() {.async: (raises: [ValueError]).} = test1
test "Can return values":
proc test1: Future[int] {.async: (raises: [ValueError]).} =
if 1 == 0: raise newException(ValueError, "hey")
return 12
proc test2: Future[int] {.async: (raises: [ValueError, IOError, CancelledError]).} =
return await test1()
checkNotCompiles:
proc test3: Future[int] {.async: (raises: [CancelledError]).} = await test1()
check waitFor(test2()) == 12
test "Manual tracking":
proc test1: Future[int] {.async: (raw: true, raises: [ValueError]).} =
result = newFuture[int]()
result.complete(12)
check waitFor(test1()) == 12
proc test2: Future[int] {.async: (raw: true, raises: [IOError, OSError]).} =
checkNotCompiles:
result.fail(newException(ValueError, "fail"))
result = newFuture[int]()
result.fail(newException(IOError, "fail"))
proc test3: Future[void] {.async: (raw: true, raises: []).} =
result = newFuture[void]()
checkNotCompiles:
result.fail(newException(ValueError, "fail"))
result.complete()
# Inheritance
proc test4: Future[void] {.async: (raw: true, raises: [CatchableError]).} =
result = newFuture[void]()
result.fail(newException(IOError, "fail"))
check:
waitFor(test1()) == 12
expect(IOError):
discard waitFor(test2())
waitFor(test3())
expect(IOError):
waitFor(test4())
test "or errors":
proc testit {.async: (raises: [ValueError]).} =
raise (ref ValueError)()
proc testit2 {.async: (raises: [IOError]).} =
raise (ref IOError)()
proc test {.async: (raises: [CancelledError, ValueError, IOError]).} =
await testit() or testit2()
proc noraises() {.raises: [].} =
expect(ValueError):
try:
let f = test()
waitFor(f)
except IOError:
doAssert false
noraises()
test "Wait errors":
proc testit {.async: (raises: [ValueError]).} =
raise newException(ValueError, "hey")
proc test {.async: (raises: [ValueError, AsyncTimeoutError, CancelledError]).} =
await wait(testit(), 1000.milliseconds)
proc noraises() {.raises: [].} =
try:
expect(ValueError): waitFor(test())
except CancelledError: doAssert false
except AsyncTimeoutError: doAssert false
noraises()
test "Nocancel errors with raises":
proc testit {.async: (raises: [ValueError, CancelledError]).} =
await sleepAsync(5.milliseconds)
raise (ref ValueError)()
proc test {.async: (raises: [ValueError]).} =
await noCancel testit()
proc noraises() {.raises: [].} =
expect(ValueError):
let f = test()
waitFor(f.cancelAndWait())
waitFor(f)
noraises()
test "Nocancel with no errors":
proc testit {.async: (raises: [CancelledError]).} =
await sleepAsync(5.milliseconds)
proc test {.async: (raises: []).} =
await noCancel testit()
proc noraises() {.raises: [].} =
let f = test()
waitFor(f.cancelAndWait())
waitFor(f)
noraises()
test "Nocancel errors without raises":
proc testit {.async.} =
await sleepAsync(5.milliseconds)
raise (ref ValueError)()
proc test {.async.} =
await noCancel testit()
proc noraises() =
expect(ValueError):
let f = test()
waitFor(f.cancelAndWait())
waitFor(f)
noraises()
test "Defect on wrong exception type at runtime":
{.push warning[User]: off}
let f = InternalRaisesFuture[void, (ValueError,)]()
expect(Defect): f.fail((ref CatchableError)())
{.pop.}
check: not f.finished()
expect(Defect): f.fail((ref CatchableError)(), warn = false)
check: not f.finished()
test "handleException behavior":
proc raiseException() {.
async: (handleException: true, raises: [AsyncExceptionError]).} =
raise (ref Exception)(msg: "Raising Exception is UB and support for it may change in the future")
proc callCatchAll() {.async: (raises: []).} =
expect(AsyncExceptionError):
await raiseException()
waitFor(callCatchAll())
test "Global handleException does not override local annotations":
when chronosHandleException:
proc unnanotated() {.async.} = raise (ref CatchableError)()
checkNotCompiles:
proc annotated() {.async: (raises: [ValueError]).} =
raise (ref CatchableError)()
checkNotCompiles:
proc noHandleException() {.async: (handleException: false).} =
raise (ref Exception)()
else:
skip()
test "Results compatibility":
proc returnOk(): Future[Result[int, string]] {.async: (raises: []).} =
ok(42)
proc returnErr(): Future[Result[int, string]] {.async: (raises: []).} =
err("failed")
proc testit(): Future[Result[void, string]] {.async: (raises: []).} =
let
v = await returnOk()
check:
v.isOk() and v.value() == 42
let
vok = ?v
check:
vok == 42
discard ?await returnErr()
check:
waitFor(testit()).error() == "failed"
@ -2,6 +2,8 @@
IF /I "%1" == "STDIN" ( IF /I "%1" == "STDIN" (
GOTO :STDINTEST GOTO :STDINTEST
) ELSE IF /I "%1" == "TIMEOUT1" (
GOTO :TIMEOUTTEST1
) ELSE IF /I "%1" == "TIMEOUT2" ( ) ELSE IF /I "%1" == "TIMEOUT2" (
GOTO :TIMEOUTTEST2 GOTO :TIMEOUTTEST2
) ELSE IF /I "%1" == "TIMEOUT10" ( ) ELSE IF /I "%1" == "TIMEOUT10" (
@ -19,6 +21,10 @@ SET /P "INPUTDATA="
ECHO STDIN DATA: %INPUTDATA% ECHO STDIN DATA: %INPUTDATA%
EXIT 0 EXIT 0
:TIMEOUTTEST1
ping -n 1 127.0.0.1 > NUL
EXIT 1
:TIMEOUTTEST2 :TIMEOUTTEST2
ping -n 2 127.0.0.1 > NUL ping -n 2 127.0.0.1 > NUL
EXIT 2 EXIT 2
@ -28,7 +34,7 @@ ping -n 10 127.0.0.1 > NUL
EXIT 0 EXIT 0
:BIGDATA :BIGDATA
-FOR /L %%G IN (1, 1, 400000) DO ECHO ALICEWASBEGINNINGTOGETVERYTIREDOFSITTINGBYHERSISTERONTHEBANKANDO
+FOR /L %%G IN (1, 1, 100000) DO ECHO ALICEWASBEGINNINGTOGETVERYTIREDOFSITTINGBYHERSISTERONTHEBANKANDO
EXIT 0 EXIT 0
:ENVTEST :ENVTEST
@ -8,6 +8,7 @@
import std/os import std/os
import stew/[base10, byteutils] import stew/[base10, byteutils]
import ".."/chronos/unittest2/asynctests import ".."/chronos/unittest2/asynctests
import ".."/chronos/asyncproc
when defined(posix): when defined(posix):
from ".."/chronos/osdefs import SIGKILL from ".."/chronos/osdefs import SIGKILL
@ -15,6 +16,9 @@ when defined(posix):
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
suite "Asynchronous process management test suite": suite "Asynchronous process management test suite":
teardown:
checkLeaks()
const OutputTests = const OutputTests =
when defined(windows): when defined(windows):
[ [
@ -96,7 +100,11 @@ suite "Asynchronous process management test suite":
let let
options = {AsyncProcessOption.EvalCommand} options = {AsyncProcessOption.EvalCommand}
command = "exit 1" command =
when defined(windows):
"tests\\testproc.bat timeout1"
else:
"tests/testproc.sh timeout1"
process = await startProcess(command, options = options) process = await startProcess(command, options = options)
@ -201,31 +209,34 @@ suite "Asynchronous process management test suite":
await process.closeWait() await process.closeWait()
asyncTest "Capture big amount of bytes from STDOUT stream test": asyncTest "Capture big amount of bytes from STDOUT stream test":
let options = {AsyncProcessOption.EvalCommand} when sizeof(int) == 4:
let command = skip()
when defined(windows): else:
"tests\\testproc.bat bigdata" let options = {AsyncProcessOption.EvalCommand}
else: let command =
"tests/testproc.sh bigdata" when defined(windows):
let expect = "tests\\testproc.bat bigdata"
when defined(windows): else:
400_000 * (64 + 2) "tests/testproc.sh bigdata"
else: let expect =
400_000 * (64 + 1) when defined(windows):
let process = await startProcess(command, options = options, 100_000 * (64 + 2)
stdoutHandle = AsyncProcess.Pipe, else:
stderrHandle = AsyncProcess.Pipe) 100_000 * (64 + 1)
try: let process = await startProcess(command, options = options,
let outBytesFut = process.stdoutStream.read() stdoutHandle = AsyncProcess.Pipe,
let errBytesFut = process.stderrStream.read() stderrHandle = AsyncProcess.Pipe)
let res = await process.waitForExit(InfiniteDuration) try:
await allFutures(outBytesFut, errBytesFut) let outBytesFut = process.stdoutStream.read()
check: let errBytesFut = process.stderrStream.read()
res == 0 let res = await process.waitForExit(InfiniteDuration)
len(outBytesFut.read()) == expect await allFutures(outBytesFut, errBytesFut)
len(errBytesFut.read()) == 0 check:
finally: res == 0
await process.closeWait() len(outBytesFut.read()) == expect
len(errBytesFut.read()) == 0
finally:
await process.closeWait()
asyncTest "Long-waiting waitForExit() test": asyncTest "Long-waiting waitForExit() test":
let command = let command =
@ -407,11 +418,54 @@ suite "Asynchronous process management test suite":
finally: finally:
await process.closeWait() await process.closeWait()
asyncTest "killAndWaitForExit() test":
let command =
when defined(windows):
("tests\\testproc.bat", "timeout10", 0)
else:
("tests/testproc.sh", "timeout10", 128 + int(SIGKILL))
let process = await startProcess(command[0], arguments = @[command[1]])
try:
let exitCode = await process.killAndWaitForExit(10.seconds)
check exitCode == command[2]
finally:
await process.closeWait()
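The expected status in the command tuple above relies on the usual POSIX shell convention that a process killed by a signal exits with status 128 plus the signal number; a tiny sketch (the numeric signal values in the asserts are the common Linux ones and are assumptions here):

# Hypothetical helper, not part of the test suite: map a fatal signal number
# to the conventional 128+N exit status that the checks above expect.
proc signalExitStatus(signal: int): int =
  128 + signal

doAssert signalExitStatus(9) == 137   # SIGKILL is 9 on Linux (assumed value)
doAssert signalExitStatus(15) == 143  # SIGTERM is 15 on Linux (assumed value)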
asyncTest "terminateAndWaitForExit() test":
let command =
when defined(windows):
("tests\\testproc.bat", "timeout10", 0)
else:
("tests/testproc.sh", "timeout10", 128 + int(SIGTERM))
let process = await startProcess(command[0], arguments = @[command[1]])
try:
let exitCode = await process.terminateAndWaitForExit(10.seconds)
check exitCode == command[2]
finally:
await process.closeWait()
asyncTest "terminateAndWaitForExit() timeout test":
when defined(windows):
skip()
else:
let
command = ("tests/testproc.sh", "noterm", 128 + int(SIGKILL))
process = await startProcess(command[0], arguments = @[command[1]])
# We should wait here to allow `bash` execute `trap` command, otherwise
# our test script will be killed with SIGTERM. Increase this timeout
# if test become flaky.
await sleepAsync(1.seconds)
try:
expect AsyncProcessTimeoutError:
let exitCode {.used.} =
await process.terminateAndWaitForExit(1.seconds)
let exitCode = await process.killAndWaitForExit(10.seconds)
check exitCode == command[2]
finally:
await process.closeWait()
test "File descriptors leaks test": test "File descriptors leaks test":
when defined(windows): when defined(windows):
skip() skip()
else: else:
check getCurrentFD() == markFD check getCurrentFD() == markFD
test "Leaks test":
checkLeaks()
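The expected exit codes in these new tests rely on the usual POSIX convention that a shell reports 128 + signal number for a child killed by a signal, while the Windows batch script simply exits normally. With the typical signal numbers (SIGTERM = 15, SIGKILL = 9 on mainstream platforms) that yields 143 and 137; a tiny illustrative check, assuming a POSIX target:

  when not defined(windows):
    import std/posix

    # Typical values on Linux/macOS; exact signal numbers are platform-defined.
    doAssert 128 + int(SIGTERM) == 143
    doAssert 128 + int(SIGKILL) == 137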


@@ -3,18 +3,26 @@
 if [ "$1" == "stdin" ]; then
   read -r inputdata
   echo "STDIN DATA: $inputdata"
+elif [ "$1" == "timeout1" ]; then
+  sleep 1
+  exit 1
 elif [ "$1" == "timeout2" ]; then
   sleep 2
   exit 2
 elif [ "$1" == "timeout10" ]; then
   sleep 10
 elif [ "$1" == "bigdata" ]; then
-  for i in {1..400000}
+  for i in {1..100000}
   do
     echo "ALICEWASBEGINNINGTOGETVERYTIREDOFSITTINGBYHERSISTERONTHEBANKANDO"
   done
 elif [ "$1" == "envtest" ]; then
   echo "$CHRONOSASYNC"
+elif [ "$1" == "noterm" ]; then
+  trap -- '' SIGTERM
+  while true; do
+    sleep 1
+  done
 else
   echo "arguments missing"
 fi
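The new `noterm` branch installs `trap -- '' SIGTERM`, so the script ignores termination requests and can only be stopped by SIGKILL; that is what makes terminateAndWaitForExit() time out in the test above. A minimal sketch of the resulting terminate-then-kill pattern (the import path, helper name and timeouts here are illustrative assumptions, not part of the tests):

  import chronos, chronos/asyncproc

  proc stopChild(p: AsyncProcessRef) {.async.} =
    ## Ask the child to exit; if it ignores SIGTERM (like `noterm`),
    ## fall back to SIGKILL.
    try:
      discard await p.terminateAndWaitForExit(1.seconds)
    except AsyncProcessTimeoutError:
      discard await p.killAndWaitForExit(10.seconds)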


@@ -49,7 +49,7 @@ suite "Token Bucket":
     # Consume 10* the budget cap
     let beforeStart = Moment.now()
     waitFor(bucket.consume(1000).wait(5.seconds))
-    check Moment.now() - beforeStart in 900.milliseconds .. 1500.milliseconds
+    check Moment.now() - beforeStart in 900.milliseconds .. 2200.milliseconds

   test "Sync manual replenish":
     var bucket = TokenBucket.new(1000, 0.seconds)
@@ -96,7 +96,7 @@ suite "Token Bucket":
       futBlocker.finished == false
       fut2.finished == false

-    futBlocker.cancel()
+    futBlocker.cancelSoon()
     waitFor(fut2.wait(10.milliseconds))

   test "Very long replenish":
@@ -117,9 +117,14 @@ suite "Token Bucket":
     check bucket.tryConsume(1, fakeNow) == true

   test "Short replenish":
-    var bucket = TokenBucket.new(15000, 1.milliseconds)
-    let start = Moment.now()
-    check bucket.tryConsume(15000, start)
-    check bucket.tryConsume(1, start) == false
-    check bucket.tryConsume(15000, start + 1.milliseconds) == true
+    skip()
+    # TODO (cheatfate): This test was disabled because it continuously fails
+    # in the GitHub Actions Windows x64 CI when using Nim 1.6.14.
+    # Unable to reproduce the failure locally.
+    # var bucket = TokenBucket.new(15000, 1.milliseconds)
+    # let start = Moment.now()
+    # check bucket.tryConsume(15000, start)
+    # check bucket.tryConsume(1, start) == false
+    # check bucket.tryConsume(15000, start + 1.milliseconds) == true
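Several suites in this changeset replace plain cancel() with the explicit variants used elsewhere in the diff: cancelSoon() only schedules the cancellation request, cancelAndWait() additionally waits until the future has actually finished cancelling, and tryCancel() returns a bool that the utilities test now discards. A minimal sketch of the first two, assuming the chronos v4 future API:

  import chronos

  proc cancellationDemo() {.async.} =
    let a = sleepAsync(10.seconds)
    a.cancelSoon()            # fire-and-forget cancellation request

    let b = sleepAsync(10.seconds)
    await b.cancelAndWait()   # resumes once `b` has finished cancelling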


@@ -5,8 +5,8 @@
 # Licensed under either of
 #  Apache License, version 2.0, (LICENSE-APACHEv2)
 #              MIT license (LICENSE-MIT)
-import unittest2
-import ../chronos
+import ../chronos/unittest2/asynctests

 {.used.}
@@ -23,30 +23,40 @@ suite "Server's test suite":
     CustomData = ref object
       test: string

+  teardown:
+    checkLeaks()
+
   proc serveStreamClient(server: StreamServer,
-                         transp: StreamTransport) {.async.} =
+                         transp: StreamTransport) {.async: (raises: []).} =
     discard

   proc serveCustomStreamClient(server: StreamServer,
-                               transp: StreamTransport) {.async.} =
-    var cserver = cast[CustomServer](server)
-    var ctransp = cast[CustomTransport](transp)
-    cserver.test1 = "CONNECTION"
-    cserver.test2 = ctransp.test
-    cserver.test3 = await transp.readLine()
-    var answer = "ANSWER\r\n"
-    discard await transp.write(answer)
-    transp.close()
-    await transp.join()
+                               transp: StreamTransport) {.async: (raises: []).} =
+    try:
+      var cserver = cast[CustomServer](server)
+      var ctransp = cast[CustomTransport](transp)
+      cserver.test1 = "CONNECTION"
+      cserver.test2 = ctransp.test
+      cserver.test3 = await transp.readLine()
+      var answer = "ANSWER\r\n"
+      discard await transp.write(answer)
+      transp.close()
+      await transp.join()
+    except CatchableError as exc:
+      raiseAssert exc.msg

   proc serveUdataStreamClient(server: StreamServer,
-                              transp: StreamTransport) {.async.} =
-    var udata = getUserData[CustomData](server)
-    var line = await transp.readLine()
-    var msg = line & udata.test & "\r\n"
-    discard await transp.write(msg)
-    transp.close()
-    await transp.join()
+                              transp: StreamTransport) {.async: (raises: []).} =
+    try:
+      var udata = getUserData[CustomData](server)
+      var line = await transp.readLine()
+      var msg = line & udata.test & "\r\n"
+      discard await transp.write(msg)
+      transp.close()
+      await transp.join()
+    except CatchableError as exc:
+      raiseAssert exc.msg

   proc customServerTransport(server: StreamServer,
                              fd: AsyncFD): StreamTransport =
@@ -54,37 +64,47 @@ suite "Server's test suite":
     transp.test = "CUSTOM"
     result = cast[StreamTransport](transp)

-  proc test1(): bool =
+  asyncTest "Stream Server start/stop test":
     var ta = initTAddress("127.0.0.1:31354")
     var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
     server1.start()
     server1.stop()
     server1.close()
-    waitFor server1.join()
+    await server1.join()
     var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
     server2.start()
     server2.stop()
     server2.close()
-    waitFor server2.join()
-    result = true
+    await server2.join()

-  proc test5(): bool =
-    var ta = initTAddress("127.0.0.1:31354")
+  asyncTest "Stream Server stop without start test":
+    var ta = initTAddress("127.0.0.1:0")
     var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
+    ta = server1.localAddress()
     server1.stop()
     server1.close()
-    waitFor server1.join()
+    await server1.join()
     var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
     server2.stop()
     server2.close()
-    waitFor server2.join()
-    result = true
+    await server2.join()

-  proc client1(server: CustomServer, ta: TransportAddress) {.async.} =
+  asyncTest "Stream Server inherited object test":
+    var server = CustomServer()
+    server.test1 = "TEST"
+    var ta = initTAddress("127.0.0.1:0")
+    var pserver = createStreamServer(ta, serveCustomStreamClient, {ReuseAddr},
+                                     child = server,
+                                     init = customServerTransport)
+    check:
+      pserver == server
+
     var transp = CustomTransport()
     transp.test = "CLIENT"
     server.start()
-    var ptransp = await connect(ta, child = transp)
+    var ptransp = await connect(server.localAddress(), child = transp)
     var etransp = cast[CustomTransport](ptransp)
     doAssert(etransp.test == "CLIENT")
     var msg = "TEST\r\n"
@@ -96,44 +116,48 @@ suite "Server's test suite":
     server.close()
     await server.join()

-  proc client2(server: StreamServer,
-               ta: TransportAddress): Future[bool] {.async.} =
+    check:
+      server.test1 == "CONNECTION"
+      server.test2 == "CUSTOM"
+
+  asyncTest "StreamServer[T] test":
+    var co = CustomData()
+    co.test = "CUSTOMDATA"
+    var ta = initTAddress("127.0.0.1:0")
+    var server = createStreamServer(ta, serveUdataStreamClient, {ReuseAddr},
+                                    udata = co)
     server.start()
-    var transp = await connect(ta)
+    var transp = await connect(server.localAddress())
     var msg = "TEST\r\n"
     discard await transp.write(msg)
     var line = await transp.readLine()
-    result = (line == "TESTCUSTOMDATA")
+    check:
+      line == "TESTCUSTOMDATA"
     transp.close()
     server.stop()
     server.close()
     await server.join()

-  proc test3(): bool =
-    var server = CustomServer()
-    server.test1 = "TEST"
-    var ta = initTAddress("127.0.0.1:31354")
-    var pserver = createStreamServer(ta, serveCustomStreamClient, {ReuseAddr},
-                                     child = cast[StreamServer](server),
-                                     init = customServerTransport)
-    doAssert(not isNil(pserver))
-    waitFor client1(server, ta)
-    result = (server.test1 == "CONNECTION") and (server.test2 == "CUSTOM")
-
-  proc test4(): bool =
-    var co = CustomData()
-    co.test = "CUSTOMDATA"
-    var ta = initTAddress("127.0.0.1:31354")
-    var server = createStreamServer(ta, serveUdataStreamClient, {ReuseAddr},
-                                    udata = co)
-    result = waitFor client2(server, ta)
-
-  test "Stream Server start/stop test":
-    check test1() == true
-  test "Stream Server stop without start test":
-    check test5() == true
-  test "Stream Server inherited object test":
-    check test3() == true
-  test "StreamServer[T] test":
-    check test4() == true
+  asyncTest "Backlog and connect cancellation":
+    var ta = initTAddress("127.0.0.1:0")
+    var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr}, backlog = 1)
+    ta = server1.localAddress()
+
+    var clients: seq[Future[StreamTransport]]
+    for i in 0..<10:
+      clients.add(connect(server1.localAddress))
+
+    # Check for leaks in cancellation / connect when server is not accepting
+    for c in clients:
+      if not c.finished:
+        await c.cancelAndWait()
+      else:
+        # The backlog connection "should" end up here
+        try:
+          await c.read().closeWait()
+        except CatchableError:
+          discard
+
+    server1.close()
+    await server1.join()
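A recurring change in this suite is binding test servers to 127.0.0.1:0 and then reading the real port back with localAddress(), so parallel test runs no longer compete for hard-coded ports such as 31354. A minimal standalone sketch of that pattern, reusing the handler shape and flags from the tests above:

  import chronos

  proc serveClient(server: StreamServer,
                   transp: StreamTransport) {.async: (raises: []).} =
    discard

  let server = createStreamServer(initTAddress("127.0.0.1:0"),
                                  serveClient, {ReuseAddr})
  echo "bound to ", server.localAddress()   # OS-assigned ephemeral port
  server.stop()
  server.close()
  waitFor server.join()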


@@ -7,7 +7,8 @@
 # MIT license (LICENSE-MIT)
 import std/strutils
 import ".."/chronos/unittest2/asynctests
-import ".."/chronos, ".."/chronos/apps/http/shttpserver
+import ".."/chronos,
+       ".."/chronos/apps/http/shttpserver
 import stew/base10

 {.used.}
@@ -74,6 +75,8 @@ N8r5CwGcIX/XPC3lKazzbZ8baA==

 suite "Secure HTTP server testing suite":
+  teardown:
+    checkLeaks()
+
   proc httpsClient(address: TransportAddress,
                    data: string, flags = {NoVerifyHost, NoVerifyServerName}
@@ -107,15 +110,18 @@ suite "Secure HTTP server testing suite":
   proc testHTTPS(address: TransportAddress): Future[bool] {.async.} =
     var serverRes = false
     proc process(r: RequestFence): Future[HttpResponseRef] {.
-        async.} =
+        async: (raises: [CancelledError]).} =
       if r.isOk():
         let request = r.get()
         serverRes = true
-        return await request.respond(Http200, "TEST_OK:" & $request.meth,
-                                     HttpTable.init())
+        try:
+          await request.respond(Http200, "TEST_OK:" & $request.meth,
+                                HttpTable.init())
+        except HttpWriteError as exc:
+          serverRes = false
+          defaultResponse(exc)
       else:
-        serverRes = false
-        return defaultResponse()
+        defaultResponse()

     let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
     let serverFlags = {Secure}
@@ -145,16 +151,18 @@ suite "Secure HTTP server testing suite":
     var serverRes = false
     var testFut = newFuture[void]()
     proc process(r: RequestFence): Future[HttpResponseRef] {.
-        async.} =
+        async: (raises: [CancelledError]).} =
       if r.isOk():
         let request = r.get()
-        serverRes = false
-        return await request.respond(Http200, "TEST_OK:" & $request.meth,
-                                     HttpTable.init())
+        try:
+          await request.respond(Http200, "TEST_OK:" & $request.meth,
+                                HttpTable.init())
+        except HttpWriteError as exc:
+          defaultResponse(exc)
       else:
         serverRes = true
         testFut.complete()
-        return defaultResponse()
+        defaultResponse()

     let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
     let serverFlags = {Secure}
@@ -179,5 +187,48 @@ suite "Secure HTTP server testing suite":
     check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true

-  test "Leaks test":
-    checkLeaks()
+  asyncTest "HTTPS server - baseUri value test":
+    proc process(r: RequestFence): Future[HttpResponseRef] {.
+        async: (raises: [CancelledError]).} =
+      defaultResponse()
+
+    let
+      expectUri2 = "https://www.chronos-test.com/"
+      address = initTAddress("127.0.0.1:0")
+      socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
+      serverFlags = {Secure}
+      secureKey = TLSPrivateKey.init(HttpsSelfSignedRsaKey)
+      secureCert = TLSCertificate.init(HttpsSelfSignedRsaCert)
+      res1 = SecureHttpServerRef.new(address, process,
+                                     socketFlags = socketFlags,
+                                     serverFlags = serverFlags,
+                                     tlsPrivateKey = secureKey,
+                                     tlsCertificate = secureCert)
+      res2 = SecureHttpServerRef.new(address, process,
+                                     socketFlags = socketFlags,
+                                     serverFlags = serverFlags,
+                                     serverUri = parseUri(expectUri2),
+                                     tlsPrivateKey = secureKey,
+                                     tlsCertificate = secureCert)
+    check:
+      res1.isOk == true
+      res2.isOk == true
+
+    let
+      server1 = res1.get()
+      server2 = res2.get()
+    try:
+      server1.start()
+      server2.start()
+      let
+        localAddress = server1.instance.localAddress()
+        expectUri1 = "https://127.0.0.1:" & $localAddress.port & "/"
+      check:
+        server1.baseUri == parseUri(expectUri1)
+        server2.baseUri == parseUri(expectUri2)
+    finally:
+      await server1.stop()
+      await server1.closeWait()
+      await server2.stop()
+      await server2.closeWait()
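The handler rewrites above all follow one shape: with the async: (raises: [CancelledError]) annotation the callback may not leak HTTP write errors, so respond() is wrapped and failures are converted with defaultResponse(exc). A stripped-down skeleton of that shape (imports as used by this test file; the response body is illustrative):

  import chronos, chronos/apps/http/shttpserver

  proc process(r: RequestFence): Future[HttpResponseRef] {.
      async: (raises: [CancelledError]).} =
    if r.isOk():
      let request = r.get()
      try:
        await request.respond(Http200, "TEST_OK", HttpTable.init())
      except HttpWriteError as exc:
        defaultResponse(exc)
    else:
      defaultResponse()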


@@ -11,75 +11,83 @@ import ../chronos
 {.used.}

 suite "callSoon() tests suite":
-  const CallSoonTests = 10
-  var soonTest1 = 0'u
-  var timeoutsTest1 = 0
-  var timeoutsTest2 = 0
-  var soonTest2 = 0
-
-  proc callback1(udata: pointer) {.gcsafe.} =
-    soonTest1 = soonTest1 xor cast[uint](udata)
-
-  proc test1(): uint =
-    callSoon(callback1, cast[pointer](0x12345678'u))
-    callSoon(callback1, cast[pointer](0x23456789'u))
-    callSoon(callback1, cast[pointer](0x3456789A'u))
-    callSoon(callback1, cast[pointer](0x456789AB'u))
-    callSoon(callback1, cast[pointer](0x56789ABC'u))
-    callSoon(callback1, cast[pointer](0x6789ABCD'u))
-    callSoon(callback1, cast[pointer](0x789ABCDE'u))
-    callSoon(callback1, cast[pointer](0x89ABCDEF'u))
-    callSoon(callback1, cast[pointer](0x9ABCDEF1'u))
-    callSoon(callback1, cast[pointer](0xABCDEF12'u))
-    callSoon(callback1, cast[pointer](0xBCDEF123'u))
-    callSoon(callback1, cast[pointer](0xCDEF1234'u))
-    callSoon(callback1, cast[pointer](0xDEF12345'u))
-    callSoon(callback1, cast[pointer](0xEF123456'u))
-    callSoon(callback1, cast[pointer](0xF1234567'u))
-    callSoon(callback1, cast[pointer](0x12345678'u))
-    ## All callbacks must be processed exactly with 1 poll() call.
-    poll()
-    result = soonTest1
-
-  proc testProc() {.async.} =
-    for i in 1..CallSoonTests:
-      await sleepAsync(100.milliseconds)
-      timeoutsTest1 += 1
-
-  var callbackproc: proc(udata: pointer) {.gcsafe, raises: [].}
-  callbackproc = proc (udata: pointer) {.gcsafe, raises: [].} =
-    timeoutsTest2 += 1
-    {.gcsafe.}:
-      callSoon(callbackproc)
-
-  proc test2(timers, callbacks: var int) =
-    callSoon(callbackproc)
-    waitFor(testProc())
-    timers = timeoutsTest1
-    callbacks = timeoutsTest2
-
-  proc testCallback(udata: pointer) =
-    soonTest2 = 987654321
-
-  proc test3(): bool =
-    callSoon(testCallback)
-    poll()
-    result = soonTest2 == 987654321
-
   test "User-defined callback argument test":
-    var values = [0x12345678'u, 0x23456789'u, 0x3456789A'u, 0x456789AB'u,
-                  0x56789ABC'u, 0x6789ABCD'u, 0x789ABCDE'u, 0x89ABCDEF'u,
-                  0x9ABCDEF1'u, 0xABCDEF12'u, 0xBCDEF123'u, 0xCDEF1234'u,
-                  0xDEF12345'u, 0xEF123456'u, 0xF1234567'u, 0x12345678'u]
-    var expect = 0'u
-    for item in values:
-      expect = expect xor item
-    check test1() == expect
+    proc test(): bool =
+      var soonTest = 0'u
+
+      proc callback(udata: pointer) {.gcsafe.} =
+        soonTest = soonTest xor cast[uint](udata)
+
+      callSoon(callback, cast[pointer](0x12345678'u))
+      callSoon(callback, cast[pointer](0x23456789'u))
+      callSoon(callback, cast[pointer](0x3456789A'u))
+      callSoon(callback, cast[pointer](0x456789AB'u))
+      callSoon(callback, cast[pointer](0x56789ABC'u))
+      callSoon(callback, cast[pointer](0x6789ABCD'u))
+      callSoon(callback, cast[pointer](0x789ABCDE'u))
+      callSoon(callback, cast[pointer](0x89ABCDEF'u))
+      callSoon(callback, cast[pointer](0x9ABCDEF1'u))
+      callSoon(callback, cast[pointer](0xABCDEF12'u))
+      callSoon(callback, cast[pointer](0xBCDEF123'u))
+      callSoon(callback, cast[pointer](0xCDEF1234'u))
+      callSoon(callback, cast[pointer](0xDEF12345'u))
+      callSoon(callback, cast[pointer](0xEF123456'u))
+      callSoon(callback, cast[pointer](0xF1234567'u))
+      callSoon(callback, cast[pointer](0x12345678'u))
+      ## All callbacks must be processed exactly with 1 poll() call.
+      poll()
+      var values = [0x12345678'u, 0x23456789'u, 0x3456789A'u, 0x456789AB'u,
+                    0x56789ABC'u, 0x6789ABCD'u, 0x789ABCDE'u, 0x89ABCDEF'u,
+                    0x9ABCDEF1'u, 0xABCDEF12'u, 0xBCDEF123'u, 0xCDEF1234'u,
+                    0xDEF12345'u, 0xEF123456'u, 0xF1234567'u, 0x12345678'u]
+      var expect = 0'u
+      for item in values:
+        expect = expect xor item
+      soonTest == expect
+    check test() == true

   test "`Asynchronous dead end` #7193 test":
-    var timers, callbacks: int
-    test2(timers, callbacks)
-    check:
-      timers == CallSoonTests
-      callbacks > CallSoonTests * 2
+    const CallSoonTests = 5
+    proc test() =
+      var
+        timeoutsTest1 = 0
+        timeoutsTest2 = 0
+        stopFlag = false
+      var callbackproc: proc(udata: pointer) {.gcsafe, raises: [].}
+      callbackproc = proc (udata: pointer) {.gcsafe, raises: [].} =
+        timeoutsTest2 += 1
+        if not(stopFlag):
+          callSoon(callbackproc)
+      proc testProc() {.async.} =
+        for i in 1 .. CallSoonTests:
+          await sleepAsync(10.milliseconds)
+          timeoutsTest1 += 1
+      callSoon(callbackproc)
+      waitFor(testProc())
+      stopFlag = true
+      poll()
+      check:
+        timeoutsTest1 == CallSoonTests
+        timeoutsTest2 > CallSoonTests * 2
+    test()

   test "`callSoon() is not working prior getGlobalDispatcher()` #7192 test":
-    check test3() == true
+    proc test(): bool =
+      var soonTest = 0
+      proc testCallback(udata: pointer) =
+        soonTest = 987654321
+      callSoon(testCallback)
+      poll()
+      soonTest == 987654321
+    check test() == true

File diff suppressed because it is too large.


@@ -150,9 +150,9 @@ suite "Asynchronous sync primitives test suite":
     var fut2 = task(lock, 2, n2)
     var fut3 = task(lock, 3, n3)
     if cancelIndex == 2:
-      fut2.cancel()
+      fut2.cancelSoon()
     else:
-      fut3.cancel()
+      fut3.cancelSoon()
     await allFutures(fut1, fut2, fut3)
     result = stripe


@@ -39,9 +39,12 @@ type
     Sync, Async

 const
-  TestsCount = 1000
+  TestsCount = when sizeof(int) == 8: 1000 else: 100

 suite "Asynchronous multi-threading sync primitives test suite":
+  teardown:
+    checkLeaks()
+
   proc setResult(thr: ThreadResultPtr, value: int) =
     thr[].value = value
@@ -322,19 +325,31 @@ suite "Asynchronous multi-threading sync primitives test suite":
   asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
             "] test [sync -> sync]":
-    threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync)
+    when sizeof(int) == 8:
+      threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync)
+    else:
+      skip()

   asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
             "] test [async -> async]":
-    threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Async)
+    when sizeof(int) == 8:
+      threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Async)
+    else:
+      skip()

   asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
             "] test [sync -> async]":
-    threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Async)
+    when sizeof(int) == 8:
+      threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Async)
+    else:
+      skip()

   asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
             "] test [async -> sync]":
-    threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Sync)
+    when sizeof(int) == 8:
+      threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Sync)
+    else:
+      skip()

   asyncTest "ThreadSignal: Multiple signals [" & $TestsCount &
             "] to multiple threads [" & $numProcs & "] test [sync -> sync]":


@@ -89,28 +89,41 @@ suite "Asynchronous timers & steps test suite":
       $nanoseconds(1_000_000_900) == "1s900ns"
       $nanoseconds(1_800_700_000) == "1s800ms700us"
      $nanoseconds(1_800_000_600) == "1s800ms600ns"
+      nanoseconds(1_800_000_600).toString(0) == ""
+      nanoseconds(1_800_000_600).toString(1) == "1s"
+      nanoseconds(1_800_000_600).toString(2) == "1s800ms"

   test "Asynchronous steps test":
-    var futn1 = stepsAsync(-1)
-    var fut0 = stepsAsync(0)
     var fut1 = stepsAsync(1)
     var fut2 = stepsAsync(2)
     var fut3 = stepsAsync(3)
     check:
-      futn1.completed() == true
-      fut0.completed() == true
       fut1.completed() == false
       fut2.completed() == false
       fut3.completed() == false
-    poll()
+    # We need `fut` because `stepsAsync` no longer drives `poll()`.
+    block:
+      var fut {.used.} = sleepAsync(50.milliseconds)
+      poll()
     check:
       fut1.completed() == true
       fut2.completed() == false
       fut3.completed() == false
-    poll()
+    block:
+      var fut {.used.} = sleepAsync(50.milliseconds)
+      poll()
     check:
       fut2.completed() == true
       fut3.completed() == false
-    poll()
+    block:
+      var fut {.used.} = sleepAsync(50.milliseconds)
+      poll()
     check:
       fut3.completed() == true
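The added checks exercise the Duration pretty-printer: `$` prints every non-zero unit, while toString(n) limits the output to the n leading units, with toString(0) producing an empty string, as the new assertions show. For example, with the value used above:

  import chronos

  let d = nanoseconds(1_800_000_600)
  doAssert $d == "1s800ms600ns"
  doAssert d.toString(1) == "1s"
  doAssert d.toString(2) == "1s800ms"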


@@ -56,7 +56,7 @@ suite "Asynchronous utilities test suite":
     check:
       getCount() == 1'u
       pendingFuturesCount() == 1'u
-    fut3.cancel()
+    discard fut3.tryCancel()
     poll()
     check:
       getCount() == 0'u
@@ -75,11 +75,6 @@
       pendingFuturesCount() == 2'u

     waitFor fut
-    check:
-      getCount() == 1'u
-      pendingFuturesCount() == 1'u
-
-    poll()
     check:
       getCount() == 0'u
       pendingFuturesCount() == 0'u