Fix nested waitFor (IndexError defect crash) bug. (#309)

Add IndexError crash test for Linux/MacOS.
Add Sentinel version of fix.
Add GetQueuedCompletionStatusEx() support for Windows, which allows to capture more then one event in single `poll()` call.
This commit is contained in:
Eugene Kabanov 2022-09-16 23:34:18 +03:00 committed by GitHub
parent 8e8263370b
commit be2352027e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 164 additions and 51 deletions

View File

@ -166,11 +166,13 @@ export timer
# TODO: Check if yielded future is nil and throw a more meaningful exception
const unixPlatform = defined(macosx) or defined(freebsd) or
defined(netbsd) or defined(openbsd) or
defined(dragonfly) or defined(macos) or
defined(linux) or defined(android) or
defined(solaris)
const
unixPlatform = defined(macosx) or defined(freebsd) or
defined(netbsd) or defined(openbsd) or
defined(dragonfly) or defined(macos) or
defined(linux) or defined(android) or
defined(solaris)
MaxEventsCount* = 64
when defined(windows):
import winlean, sets, hashes
@ -212,6 +214,16 @@ type
idlers*: Deque[AsyncCallback]
trackers*: Table[string, TrackerBase]
proc sentinelCallbackImpl(arg: pointer) {.gcsafe, raises: [Defect].} =
raiseAssert "Sentinel callback MUST not be scheduled"
const
SentinelCallback = AsyncCallback(function: sentinelCallbackImpl,
udata: nil)
proc isSentinel(acb: AsyncCallback): bool {.raises: [Defect].} =
acb == SentinelCallback
proc `<`(a, b: TimerCallback): bool =
result = a.finishAt < b.finishAt
@ -275,16 +287,11 @@ template processIdlers(loop: untyped) =
loop.callbacks.addLast(loop.idlers.popFirst())
template processCallbacks(loop: untyped) =
var count = len(loop.callbacks)
for i in 0..<count:
# This is mostly workaround for people which are using `waitFor` where
# it must be used `await`. While using `waitFor` inside of callbacks
# dispatcher's callback list is got decreased and length of
# `loop.callbacks` become not equal to `count`, its why `IndexError`
# can be generated.
if len(loop.callbacks) == 0: break
let callable = loop.callbacks.popFirst()
if not isNil(callable.function):
while true:
let callable = loop.callbacks.popFirst() # len must be > 0 due to sentinel
if isSentinel(callable):
break
if not(isNil(callable.function)):
callable.function(callable.udata)
proc raiseAsDefect*(exc: ref Exception, msg: string) {.
@ -305,6 +312,14 @@ when defined(windows):
dwReserved: DWORD): cint {.
gcsafe, stdcall, raises: [].}
LPFN_GETQUEUEDCOMPLETIONSTATUSEX = proc(completionPort: Handle,
lpPortEntries: ptr OVERLAPPED_ENTRY,
ulCount: DWORD,
ulEntriesRemoved: var ULONG,
dwMilliseconds: DWORD,
fAlertable: WINBOOL): WINBOOL {.
gcsafe, stdcall, raises: [].}
CompletionKey = ULONG_PTR
CompletionData* = object
@ -316,6 +331,12 @@ when defined(windows):
CustomOverlapped* = object of OVERLAPPED
data*: CompletionData
OVERLAPPED_ENTRY* = object
lpCompletionKey*: ULONG_PTR
lpOverlapped*: ptr CustomOverlapped
internal: ULONG_PTR
dwNumberOfBytesTransferred: DWORD
PDispatcher* = ref object of PDispatcherBase
ioPort: Handle
handles: HashSet[AsyncFD]
@ -323,6 +344,7 @@ when defined(windows):
acceptEx*: WSAPROC_ACCEPTEX
getAcceptExSockAddrs*: WSAPROC_GETACCEPTEXSOCKADDRS
transmitFile*: WSAPROC_TRANSMITFILE
getQueuedCompletionStatusEx*: LPFN_GETQUEUEDCOMPLETIONSTATUSEX
PtrCustomOverlapped* = ptr CustomOverlapped
@ -330,6 +352,13 @@ when defined(windows):
AsyncFD* = distinct int
proc getModuleHandle(lpModuleName: WideCString): Handle {.
stdcall, dynlib: "kernel32", importc: "GetModuleHandleW", sideEffect.}
proc getProcAddress(hModule: Handle, lpProcName: cstring): pointer {.
stdcall, dynlib: "kernel32", importc: "GetProcAddress", sideEffect.}
proc rtlNtStatusToDosError(code: uint64): ULONG {.
stdcall, dynlib: "ntdll", importc: "RtlNtStatusToDosError", sideEffect.}
proc hash(x: AsyncFD): Hash {.borrow.}
proc `==`*(x: AsyncFD, y: AsyncFD): bool {.borrow, gcsafe.}
@ -352,6 +381,10 @@ when defined(windows):
D4: [0x95'i8, 0xca'i8, 0x00'i8, 0x80'i8,
0x5f'i8, 0x48'i8, 0xa1'i8, 0x92'i8])
let kernel32 = getModuleHandle(newWideCString("kernel32.dll"))
loop.getQueuedCompletionStatusEx = cast[LPFN_GETQUEUEDCOMPLETIONSTATUSEX](
getProcAddress(kernel32, "GetQueuedCompletionStatusEx"))
let sock = winlean.socket(winlean.AF_INET, 1, 6)
if sock == INVALID_SOCKET:
raiseOSError(osLastError())
@ -396,6 +429,7 @@ when defined(windows):
# Pre 0.20.0 Nim's stdlib version
res.timers = newHeapQueue[TimerCallback]()
res.callbacks = initDeque[AsyncCallback](64)
res.callbacks.addLast(SentinelCallback)
res.idlers = initDeque[AsyncCallback]()
res.trackers = initTable[string, TrackerBase]()
initAPI(res)
@ -425,59 +459,91 @@ when defined(windows):
proc poll*() {.raises: [Defect, CatchableError].} =
## Perform single asynchronous step, processing timers and completing
## unblocked tasks. Blocks until at least one event has completed.
## tasks. Blocks until at least one event has completed.
##
## Exceptions raised here indicate that waiting for tasks to be unblocked
## failed - exceptions from within tasks are instead propagated through
## their respective futures and not allowed to interrrupt the poll call.
let loop = getThreadDispatcher()
var curTime = Moment.now()
var curTimeout = DWORD(0)
var noNetworkEvents = false
var
curTime = Moment.now()
curTimeout = DWORD(0)
events: array[MaxEventsCount, OVERLAPPED_ENTRY]
# On reentrant `poll` calls from `processCallbacks`, e.g., `waitFor`,
# complete pending work of the outer `processCallbacks` call.
# On non-reentrant `poll` calls, this only removes sentinel element.
processCallbacks(loop)
# Moving expired timers to `loop.callbacks` and calculate timeout
loop.processTimersGetTimeout(curTimeout)
# Processing handles
var lpNumberOfBytesTransferred: DWORD
var lpCompletionKey: ULONG_PTR
var customOverlapped: PtrCustomOverlapped
let networkEventsCount =
if isNil(loop.getQueuedCompletionStatusEx):
let res = getQueuedCompletionStatus(
loop.ioPort,
addr events[0].dwNumberOfBytesTransferred,
addr events[0].lpCompletionKey,
cast[ptr POVERLAPPED](addr events[0].lpOverlapped),
curTimeout
)
if res == WINBOOL(0):
let errCode = osLastError()
if not(isNil(events[0].lpOverlapped)):
1
else:
if int32(errCode) != WAIT_TIMEOUT:
raiseOSError(errCode)
0
else:
1
else:
var eventsReceived = ULONG(0)
let res = loop.getQueuedCompletionStatusEx(
loop.ioPort,
addr events[0],
ULONG(len(events)),
eventsReceived,
curTimeout,
WINBOOL(0)
)
if res == WINBOOL(0):
let errCode = osLastError()
if int32(errCode) != WAIT_TIMEOUT:
raiseOSError(errCode)
0
else:
eventsReceived
let res = getQueuedCompletionStatus(
loop.ioPort, addr lpNumberOfBytesTransferred,
addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped),
curTimeout).bool
if res:
customOverlapped.data.bytesCount = lpNumberOfBytesTransferred
customOverlapped.data.errCode = OSErrorCode(-1)
for i in 0 ..< networkEventsCount:
var customOverlapped = events[i].lpOverlapped
customOverlapped.data.errCode =
block:
let res = cast[uint64](customOverlapped.internal)
if res == 0'u64:
OSErrorCode(-1)
else:
OSErrorCode(rtlNtStatusToDosError(res))
customOverlapped.data.bytesCount = events[i].dwNumberOfBytesTransferred
let acb = AsyncCallback(function: customOverlapped.data.cb,
udata: cast[pointer](customOverlapped))
loop.callbacks.addLast(acb)
else:
let errCode = osLastError()
if customOverlapped != nil:
customOverlapped.data.errCode = errCode
let acb = AsyncCallback(function: customOverlapped.data.cb,
udata: cast[pointer](customOverlapped))
loop.callbacks.addLast(acb)
else:
if int32(errCode) != WAIT_TIMEOUT:
raiseOSError(errCode)
else:
noNetworkEvents = true
# Moving expired timers to `loop.callbacks`.
loop.processTimers()
# We move idle callbacks to `loop.callbacks` only if there no pending
# network events.
if noNetworkEvents:
if networkEventsCount == 0:
loop.processIdlers()
# All callbacks which will be added in process will be processed on next
# poll() call.
loop.processCallbacks()
# All callbacks which will be added during `processCallbacks` will be
# scheduled after the sentinel and are processed on next `poll()` call.
loop.callbacks.addLast(SentinelCallback)
processCallbacks(loop)
# All callbacks done, skip `processCallbacks` at start.
loop.callbacks.addFirst(SentinelCallback)
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
## Closes a socket and ensures that it is unregistered.
@ -536,6 +602,7 @@ elif unixPlatform:
# Before 0.20.0 Nim's stdlib version
res.timers.newHeapQueue()
res.callbacks = initDeque[AsyncCallback](64)
res.callbacks.addLast(SentinelCallback)
res.idlers = initDeque[AsyncCallback]()
res.keys = newSeq[ReadyKey](64)
res.trackers = initTable[string, TrackerBase]()
@ -706,6 +773,11 @@ elif unixPlatform:
let customSet = {Event.Timer, Event.Signal, Event.Process,
Event.Vnode}
# On reentrant `poll` calls from `processCallbacks`, e.g., `waitFor`,
# complete pending work of the outer `processCallbacks` call.
# On non-reentrant `poll` calls, this only removes sentinel element.
processCallbacks(loop)
# Moving expired timers to `loop.callbacks` and calculate timeout.
loop.processTimersGetTimeout(curTimeout)
@ -741,9 +813,13 @@ elif unixPlatform:
if count == 0:
loop.processIdlers()
# All callbacks which will be added in process, will be processed on next
# poll() call.
loop.processCallbacks()
# All callbacks which will be added during `processCallbacks` will be
# scheduled after the sentinel and are processed on next `poll()` call.
loop.callbacks.addLast(SentinelCallback)
processCallbacks(loop)
# All callbacks done, skip `processCallbacks` at start.
loop.callbacks.addFirst(SentinelCallback)
else:
proc initAPI() = discard

View File

@ -21,7 +21,7 @@ suite "Asynchronous issues test suite":
test: string
proc udp4DataAvailable(transp: DatagramTransport,
remote: TransportAddress): Future[void] {.async, gcsafe.} =
remote: TransportAddress) {.async, gcsafe.} =
var udata = getUserData[CustomData](transp)
var expect = TEST_MSG
var data: seq[byte]
@ -98,6 +98,40 @@ suite "Asynchronous issues test suite":
result = r1 and r2
proc createBigMessage(size: int): seq[byte] =
var message = "MESSAGE"
var res = newSeq[byte](size)
for i in 0 ..< len(result):
res[i] = byte(message[i mod len(message)])
res
proc testIndexError(): Future[bool] {.async.} =
var server = createStreamServer(initTAddress("127.0.0.1:0"),
flags = {ReuseAddr})
let messageSize = DefaultStreamBufferSize * 4
var buffer = newSeq[byte](messageSize)
let msg = createBigMessage(messageSize)
let address = server.localAddress()
let afut = server.accept()
let outTransp = await connect(address)
let inpTransp = await afut
let bytesSent = await outTransp.write(msg)
check bytesSent == messageSize
var rfut = inpTransp.readExactly(addr buffer[0], messageSize)
proc waiterProc(udata: pointer) {.raises: [Defect], gcsafe.} =
try:
waitFor(sleepAsync(0.milliseconds))
except CatchableError as exc:
raiseAssert "Unexpected exception happened"
let timer = setTimer(Moment.fromNow(0.seconds), waiterProc, nil)
await sleepAsync(100.milliseconds)
await inpTransp.closeWait()
await outTransp.closeWait()
await server.closeWait()
return true
test "Issue #6":
check waitFor(issue6()) == true
@ -112,3 +146,6 @@ suite "Asynchronous issues test suite":
test "Defer for asynchronous procedures test [Nim's issue #13899]":
check waitFor(testDefer()) == true
test "IndexError crash test":
check waitFor(testIndexError()) == true