From be2352027ecd54bb63509c6347fc7bccb01550c6 Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Fri, 16 Sep 2022 23:34:18 +0300 Subject: [PATCH] Fix nested waitFor (IndexError defect crash) bug. (#309) Add IndexError crash test for Linux/MacOS. Add Sentinel version of fix. Add GetQueuedCompletionStatusEx() support for Windows, which allows to capture more then one event in single `poll()` call. --- chronos/asyncloop.nim | 176 ++++++++++++++++++++++++++++++------------ tests/testbugs.nim | 39 +++++++++- 2 files changed, 164 insertions(+), 51 deletions(-) diff --git a/chronos/asyncloop.nim b/chronos/asyncloop.nim index 131b318d..55340a33 100644 --- a/chronos/asyncloop.nim +++ b/chronos/asyncloop.nim @@ -166,11 +166,13 @@ export timer # TODO: Check if yielded future is nil and throw a more meaningful exception -const unixPlatform = defined(macosx) or defined(freebsd) or - defined(netbsd) or defined(openbsd) or - defined(dragonfly) or defined(macos) or - defined(linux) or defined(android) or - defined(solaris) +const + unixPlatform = defined(macosx) or defined(freebsd) or + defined(netbsd) or defined(openbsd) or + defined(dragonfly) or defined(macos) or + defined(linux) or defined(android) or + defined(solaris) + MaxEventsCount* = 64 when defined(windows): import winlean, sets, hashes @@ -212,6 +214,16 @@ type idlers*: Deque[AsyncCallback] trackers*: Table[string, TrackerBase] +proc sentinelCallbackImpl(arg: pointer) {.gcsafe, raises: [Defect].} = + raiseAssert "Sentinel callback MUST not be scheduled" + +const + SentinelCallback = AsyncCallback(function: sentinelCallbackImpl, + udata: nil) + +proc isSentinel(acb: AsyncCallback): bool {.raises: [Defect].} = + acb == SentinelCallback + proc `<`(a, b: TimerCallback): bool = result = a.finishAt < b.finishAt @@ -275,16 +287,11 @@ template processIdlers(loop: untyped) = loop.callbacks.addLast(loop.idlers.popFirst()) template processCallbacks(loop: untyped) = - var count = len(loop.callbacks) - for i in 0.. 0 due to sentinel + if isSentinel(callable): + break + if not(isNil(callable.function)): callable.function(callable.udata) proc raiseAsDefect*(exc: ref Exception, msg: string) {. @@ -305,6 +312,14 @@ when defined(windows): dwReserved: DWORD): cint {. gcsafe, stdcall, raises: [].} + LPFN_GETQUEUEDCOMPLETIONSTATUSEX = proc(completionPort: Handle, + lpPortEntries: ptr OVERLAPPED_ENTRY, + ulCount: DWORD, + ulEntriesRemoved: var ULONG, + dwMilliseconds: DWORD, + fAlertable: WINBOOL): WINBOOL {. + gcsafe, stdcall, raises: [].} + CompletionKey = ULONG_PTR CompletionData* = object @@ -316,6 +331,12 @@ when defined(windows): CustomOverlapped* = object of OVERLAPPED data*: CompletionData + OVERLAPPED_ENTRY* = object + lpCompletionKey*: ULONG_PTR + lpOverlapped*: ptr CustomOverlapped + internal: ULONG_PTR + dwNumberOfBytesTransferred: DWORD + PDispatcher* = ref object of PDispatcherBase ioPort: Handle handles: HashSet[AsyncFD] @@ -323,6 +344,7 @@ when defined(windows): acceptEx*: WSAPROC_ACCEPTEX getAcceptExSockAddrs*: WSAPROC_GETACCEPTEXSOCKADDRS transmitFile*: WSAPROC_TRANSMITFILE + getQueuedCompletionStatusEx*: LPFN_GETQUEUEDCOMPLETIONSTATUSEX PtrCustomOverlapped* = ptr CustomOverlapped @@ -330,6 +352,13 @@ when defined(windows): AsyncFD* = distinct int + proc getModuleHandle(lpModuleName: WideCString): Handle {. + stdcall, dynlib: "kernel32", importc: "GetModuleHandleW", sideEffect.} + proc getProcAddress(hModule: Handle, lpProcName: cstring): pointer {. + stdcall, dynlib: "kernel32", importc: "GetProcAddress", sideEffect.} + proc rtlNtStatusToDosError(code: uint64): ULONG {. + stdcall, dynlib: "ntdll", importc: "RtlNtStatusToDosError", sideEffect.} + proc hash(x: AsyncFD): Hash {.borrow.} proc `==`*(x: AsyncFD, y: AsyncFD): bool {.borrow, gcsafe.} @@ -352,6 +381,10 @@ when defined(windows): D4: [0x95'i8, 0xca'i8, 0x00'i8, 0x80'i8, 0x5f'i8, 0x48'i8, 0xa1'i8, 0x92'i8]) + let kernel32 = getModuleHandle(newWideCString("kernel32.dll")) + loop.getQueuedCompletionStatusEx = cast[LPFN_GETQUEUEDCOMPLETIONSTATUSEX]( + getProcAddress(kernel32, "GetQueuedCompletionStatusEx")) + let sock = winlean.socket(winlean.AF_INET, 1, 6) if sock == INVALID_SOCKET: raiseOSError(osLastError()) @@ -396,6 +429,7 @@ when defined(windows): # Pre 0.20.0 Nim's stdlib version res.timers = newHeapQueue[TimerCallback]() res.callbacks = initDeque[AsyncCallback](64) + res.callbacks.addLast(SentinelCallback) res.idlers = initDeque[AsyncCallback]() res.trackers = initTable[string, TrackerBase]() initAPI(res) @@ -425,59 +459,91 @@ when defined(windows): proc poll*() {.raises: [Defect, CatchableError].} = ## Perform single asynchronous step, processing timers and completing - ## unblocked tasks. Blocks until at least one event has completed. + ## tasks. Blocks until at least one event has completed. ## ## Exceptions raised here indicate that waiting for tasks to be unblocked ## failed - exceptions from within tasks are instead propagated through ## their respective futures and not allowed to interrrupt the poll call. let loop = getThreadDispatcher() - var curTime = Moment.now() - var curTimeout = DWORD(0) - var noNetworkEvents = false + var + curTime = Moment.now() + curTimeout = DWORD(0) + events: array[MaxEventsCount, OVERLAPPED_ENTRY] + + # On reentrant `poll` calls from `processCallbacks`, e.g., `waitFor`, + # complete pending work of the outer `processCallbacks` call. + # On non-reentrant `poll` calls, this only removes sentinel element. + processCallbacks(loop) # Moving expired timers to `loop.callbacks` and calculate timeout loop.processTimersGetTimeout(curTimeout) - # Processing handles - var lpNumberOfBytesTransferred: DWORD - var lpCompletionKey: ULONG_PTR - var customOverlapped: PtrCustomOverlapped + let networkEventsCount = + if isNil(loop.getQueuedCompletionStatusEx): + let res = getQueuedCompletionStatus( + loop.ioPort, + addr events[0].dwNumberOfBytesTransferred, + addr events[0].lpCompletionKey, + cast[ptr POVERLAPPED](addr events[0].lpOverlapped), + curTimeout + ) + if res == WINBOOL(0): + let errCode = osLastError() + if not(isNil(events[0].lpOverlapped)): + 1 + else: + if int32(errCode) != WAIT_TIMEOUT: + raiseOSError(errCode) + 0 + else: + 1 + else: + var eventsReceived = ULONG(0) + let res = loop.getQueuedCompletionStatusEx( + loop.ioPort, + addr events[0], + ULONG(len(events)), + eventsReceived, + curTimeout, + WINBOOL(0) + ) + if res == WINBOOL(0): + let errCode = osLastError() + if int32(errCode) != WAIT_TIMEOUT: + raiseOSError(errCode) + 0 + else: + eventsReceived - let res = getQueuedCompletionStatus( - loop.ioPort, addr lpNumberOfBytesTransferred, - addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped), - curTimeout).bool - - if res: - customOverlapped.data.bytesCount = lpNumberOfBytesTransferred - customOverlapped.data.errCode = OSErrorCode(-1) + for i in 0 ..< networkEventsCount: + var customOverlapped = events[i].lpOverlapped + customOverlapped.data.errCode = + block: + let res = cast[uint64](customOverlapped.internal) + if res == 0'u64: + OSErrorCode(-1) + else: + OSErrorCode(rtlNtStatusToDosError(res)) + customOverlapped.data.bytesCount = events[i].dwNumberOfBytesTransferred let acb = AsyncCallback(function: customOverlapped.data.cb, udata: cast[pointer](customOverlapped)) loop.callbacks.addLast(acb) - else: - let errCode = osLastError() - if customOverlapped != nil: - customOverlapped.data.errCode = errCode - let acb = AsyncCallback(function: customOverlapped.data.cb, - udata: cast[pointer](customOverlapped)) - loop.callbacks.addLast(acb) - else: - if int32(errCode) != WAIT_TIMEOUT: - raiseOSError(errCode) - else: - noNetworkEvents = true # Moving expired timers to `loop.callbacks`. loop.processTimers() # We move idle callbacks to `loop.callbacks` only if there no pending # network events. - if noNetworkEvents: + if networkEventsCount == 0: loop.processIdlers() - # All callbacks which will be added in process will be processed on next - # poll() call. - loop.processCallbacks() + # All callbacks which will be added during `processCallbacks` will be + # scheduled after the sentinel and are processed on next `poll()` call. + loop.callbacks.addLast(SentinelCallback) + processCallbacks(loop) + + # All callbacks done, skip `processCallbacks` at start. + loop.callbacks.addFirst(SentinelCallback) proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) = ## Closes a socket and ensures that it is unregistered. @@ -536,6 +602,7 @@ elif unixPlatform: # Before 0.20.0 Nim's stdlib version res.timers.newHeapQueue() res.callbacks = initDeque[AsyncCallback](64) + res.callbacks.addLast(SentinelCallback) res.idlers = initDeque[AsyncCallback]() res.keys = newSeq[ReadyKey](64) res.trackers = initTable[string, TrackerBase]() @@ -706,6 +773,11 @@ elif unixPlatform: let customSet = {Event.Timer, Event.Signal, Event.Process, Event.Vnode} + # On reentrant `poll` calls from `processCallbacks`, e.g., `waitFor`, + # complete pending work of the outer `processCallbacks` call. + # On non-reentrant `poll` calls, this only removes sentinel element. + processCallbacks(loop) + # Moving expired timers to `loop.callbacks` and calculate timeout. loop.processTimersGetTimeout(curTimeout) @@ -741,9 +813,13 @@ elif unixPlatform: if count == 0: loop.processIdlers() - # All callbacks which will be added in process, will be processed on next - # poll() call. - loop.processCallbacks() + # All callbacks which will be added during `processCallbacks` will be + # scheduled after the sentinel and are processed on next `poll()` call. + loop.callbacks.addLast(SentinelCallback) + processCallbacks(loop) + + # All callbacks done, skip `processCallbacks` at start. + loop.callbacks.addFirst(SentinelCallback) else: proc initAPI() = discard diff --git a/tests/testbugs.nim b/tests/testbugs.nim index a5cc2ede..d31ea994 100644 --- a/tests/testbugs.nim +++ b/tests/testbugs.nim @@ -21,7 +21,7 @@ suite "Asynchronous issues test suite": test: string proc udp4DataAvailable(transp: DatagramTransport, - remote: TransportAddress): Future[void] {.async, gcsafe.} = + remote: TransportAddress) {.async, gcsafe.} = var udata = getUserData[CustomData](transp) var expect = TEST_MSG var data: seq[byte] @@ -98,6 +98,40 @@ suite "Asynchronous issues test suite": result = r1 and r2 + proc createBigMessage(size: int): seq[byte] = + var message = "MESSAGE" + var res = newSeq[byte](size) + for i in 0 ..< len(result): + res[i] = byte(message[i mod len(message)]) + res + + proc testIndexError(): Future[bool] {.async.} = + var server = createStreamServer(initTAddress("127.0.0.1:0"), + flags = {ReuseAddr}) + let messageSize = DefaultStreamBufferSize * 4 + var buffer = newSeq[byte](messageSize) + let msg = createBigMessage(messageSize) + let address = server.localAddress() + let afut = server.accept() + let outTransp = await connect(address) + let inpTransp = await afut + let bytesSent = await outTransp.write(msg) + check bytesSent == messageSize + var rfut = inpTransp.readExactly(addr buffer[0], messageSize) + + proc waiterProc(udata: pointer) {.raises: [Defect], gcsafe.} = + try: + waitFor(sleepAsync(0.milliseconds)) + except CatchableError as exc: + raiseAssert "Unexpected exception happened" + let timer = setTimer(Moment.fromNow(0.seconds), waiterProc, nil) + await sleepAsync(100.milliseconds) + + await inpTransp.closeWait() + await outTransp.closeWait() + await server.closeWait() + return true + test "Issue #6": check waitFor(issue6()) == true @@ -112,3 +146,6 @@ suite "Asynchronous issues test suite": test "Defer for asynchronous procedures test [Nim's issue #13899]": check waitFor(testDefer()) == true + + test "IndexError crash test": + check waitFor(testIndexError()) == true