Fix nested waitFor (IndexError defect crash) bug. (#309)
Add IndexError crash test for Linux/MacOS. Add Sentinel version of fix. Add GetQueuedCompletionStatusEx() support for Windows, which allows to capture more then one event in single `poll()` call.
This commit is contained in:
parent
8e8263370b
commit
be2352027e
|
@ -166,11 +166,13 @@ export timer
|
|||
|
||||
# TODO: Check if yielded future is nil and throw a more meaningful exception
|
||||
|
||||
const unixPlatform = defined(macosx) or defined(freebsd) or
|
||||
defined(netbsd) or defined(openbsd) or
|
||||
defined(dragonfly) or defined(macos) or
|
||||
defined(linux) or defined(android) or
|
||||
defined(solaris)
|
||||
const
|
||||
unixPlatform = defined(macosx) or defined(freebsd) or
|
||||
defined(netbsd) or defined(openbsd) or
|
||||
defined(dragonfly) or defined(macos) or
|
||||
defined(linux) or defined(android) or
|
||||
defined(solaris)
|
||||
MaxEventsCount* = 64
|
||||
|
||||
when defined(windows):
|
||||
import winlean, sets, hashes
|
||||
|
@ -212,6 +214,16 @@ type
|
|||
idlers*: Deque[AsyncCallback]
|
||||
trackers*: Table[string, TrackerBase]
|
||||
|
||||
proc sentinelCallbackImpl(arg: pointer) {.gcsafe, raises: [Defect].} =
|
||||
raiseAssert "Sentinel callback MUST not be scheduled"
|
||||
|
||||
const
|
||||
SentinelCallback = AsyncCallback(function: sentinelCallbackImpl,
|
||||
udata: nil)
|
||||
|
||||
proc isSentinel(acb: AsyncCallback): bool {.raises: [Defect].} =
|
||||
acb == SentinelCallback
|
||||
|
||||
proc `<`(a, b: TimerCallback): bool =
|
||||
result = a.finishAt < b.finishAt
|
||||
|
||||
|
@ -275,16 +287,11 @@ template processIdlers(loop: untyped) =
|
|||
loop.callbacks.addLast(loop.idlers.popFirst())
|
||||
|
||||
template processCallbacks(loop: untyped) =
|
||||
var count = len(loop.callbacks)
|
||||
for i in 0..<count:
|
||||
# This is mostly workaround for people which are using `waitFor` where
|
||||
# it must be used `await`. While using `waitFor` inside of callbacks
|
||||
# dispatcher's callback list is got decreased and length of
|
||||
# `loop.callbacks` become not equal to `count`, its why `IndexError`
|
||||
# can be generated.
|
||||
if len(loop.callbacks) == 0: break
|
||||
let callable = loop.callbacks.popFirst()
|
||||
if not isNil(callable.function):
|
||||
while true:
|
||||
let callable = loop.callbacks.popFirst() # len must be > 0 due to sentinel
|
||||
if isSentinel(callable):
|
||||
break
|
||||
if not(isNil(callable.function)):
|
||||
callable.function(callable.udata)
|
||||
|
||||
proc raiseAsDefect*(exc: ref Exception, msg: string) {.
|
||||
|
@ -305,6 +312,14 @@ when defined(windows):
|
|||
dwReserved: DWORD): cint {.
|
||||
gcsafe, stdcall, raises: [].}
|
||||
|
||||
LPFN_GETQUEUEDCOMPLETIONSTATUSEX = proc(completionPort: Handle,
|
||||
lpPortEntries: ptr OVERLAPPED_ENTRY,
|
||||
ulCount: DWORD,
|
||||
ulEntriesRemoved: var ULONG,
|
||||
dwMilliseconds: DWORD,
|
||||
fAlertable: WINBOOL): WINBOOL {.
|
||||
gcsafe, stdcall, raises: [].}
|
||||
|
||||
CompletionKey = ULONG_PTR
|
||||
|
||||
CompletionData* = object
|
||||
|
@ -316,6 +331,12 @@ when defined(windows):
|
|||
CustomOverlapped* = object of OVERLAPPED
|
||||
data*: CompletionData
|
||||
|
||||
OVERLAPPED_ENTRY* = object
|
||||
lpCompletionKey*: ULONG_PTR
|
||||
lpOverlapped*: ptr CustomOverlapped
|
||||
internal: ULONG_PTR
|
||||
dwNumberOfBytesTransferred: DWORD
|
||||
|
||||
PDispatcher* = ref object of PDispatcherBase
|
||||
ioPort: Handle
|
||||
handles: HashSet[AsyncFD]
|
||||
|
@ -323,6 +344,7 @@ when defined(windows):
|
|||
acceptEx*: WSAPROC_ACCEPTEX
|
||||
getAcceptExSockAddrs*: WSAPROC_GETACCEPTEXSOCKADDRS
|
||||
transmitFile*: WSAPROC_TRANSMITFILE
|
||||
getQueuedCompletionStatusEx*: LPFN_GETQUEUEDCOMPLETIONSTATUSEX
|
||||
|
||||
PtrCustomOverlapped* = ptr CustomOverlapped
|
||||
|
||||
|
@ -330,6 +352,13 @@ when defined(windows):
|
|||
|
||||
AsyncFD* = distinct int
|
||||
|
||||
proc getModuleHandle(lpModuleName: WideCString): Handle {.
|
||||
stdcall, dynlib: "kernel32", importc: "GetModuleHandleW", sideEffect.}
|
||||
proc getProcAddress(hModule: Handle, lpProcName: cstring): pointer {.
|
||||
stdcall, dynlib: "kernel32", importc: "GetProcAddress", sideEffect.}
|
||||
proc rtlNtStatusToDosError(code: uint64): ULONG {.
|
||||
stdcall, dynlib: "ntdll", importc: "RtlNtStatusToDosError", sideEffect.}
|
||||
|
||||
proc hash(x: AsyncFD): Hash {.borrow.}
|
||||
proc `==`*(x: AsyncFD, y: AsyncFD): bool {.borrow, gcsafe.}
|
||||
|
||||
|
@ -352,6 +381,10 @@ when defined(windows):
|
|||
D4: [0x95'i8, 0xca'i8, 0x00'i8, 0x80'i8,
|
||||
0x5f'i8, 0x48'i8, 0xa1'i8, 0x92'i8])
|
||||
|
||||
let kernel32 = getModuleHandle(newWideCString("kernel32.dll"))
|
||||
loop.getQueuedCompletionStatusEx = cast[LPFN_GETQUEUEDCOMPLETIONSTATUSEX](
|
||||
getProcAddress(kernel32, "GetQueuedCompletionStatusEx"))
|
||||
|
||||
let sock = winlean.socket(winlean.AF_INET, 1, 6)
|
||||
if sock == INVALID_SOCKET:
|
||||
raiseOSError(osLastError())
|
||||
|
@ -396,6 +429,7 @@ when defined(windows):
|
|||
# Pre 0.20.0 Nim's stdlib version
|
||||
res.timers = newHeapQueue[TimerCallback]()
|
||||
res.callbacks = initDeque[AsyncCallback](64)
|
||||
res.callbacks.addLast(SentinelCallback)
|
||||
res.idlers = initDeque[AsyncCallback]()
|
||||
res.trackers = initTable[string, TrackerBase]()
|
||||
initAPI(res)
|
||||
|
@ -425,59 +459,91 @@ when defined(windows):
|
|||
|
||||
proc poll*() {.raises: [Defect, CatchableError].} =
|
||||
## Perform single asynchronous step, processing timers and completing
|
||||
## unblocked tasks. Blocks until at least one event has completed.
|
||||
## tasks. Blocks until at least one event has completed.
|
||||
##
|
||||
## Exceptions raised here indicate that waiting for tasks to be unblocked
|
||||
## failed - exceptions from within tasks are instead propagated through
|
||||
## their respective futures and not allowed to interrrupt the poll call.
|
||||
let loop = getThreadDispatcher()
|
||||
var curTime = Moment.now()
|
||||
var curTimeout = DWORD(0)
|
||||
var noNetworkEvents = false
|
||||
var
|
||||
curTime = Moment.now()
|
||||
curTimeout = DWORD(0)
|
||||
events: array[MaxEventsCount, OVERLAPPED_ENTRY]
|
||||
|
||||
# On reentrant `poll` calls from `processCallbacks`, e.g., `waitFor`,
|
||||
# complete pending work of the outer `processCallbacks` call.
|
||||
# On non-reentrant `poll` calls, this only removes sentinel element.
|
||||
processCallbacks(loop)
|
||||
|
||||
# Moving expired timers to `loop.callbacks` and calculate timeout
|
||||
loop.processTimersGetTimeout(curTimeout)
|
||||
|
||||
# Processing handles
|
||||
var lpNumberOfBytesTransferred: DWORD
|
||||
var lpCompletionKey: ULONG_PTR
|
||||
var customOverlapped: PtrCustomOverlapped
|
||||
let networkEventsCount =
|
||||
if isNil(loop.getQueuedCompletionStatusEx):
|
||||
let res = getQueuedCompletionStatus(
|
||||
loop.ioPort,
|
||||
addr events[0].dwNumberOfBytesTransferred,
|
||||
addr events[0].lpCompletionKey,
|
||||
cast[ptr POVERLAPPED](addr events[0].lpOverlapped),
|
||||
curTimeout
|
||||
)
|
||||
if res == WINBOOL(0):
|
||||
let errCode = osLastError()
|
||||
if not(isNil(events[0].lpOverlapped)):
|
||||
1
|
||||
else:
|
||||
if int32(errCode) != WAIT_TIMEOUT:
|
||||
raiseOSError(errCode)
|
||||
0
|
||||
else:
|
||||
1
|
||||
else:
|
||||
var eventsReceived = ULONG(0)
|
||||
let res = loop.getQueuedCompletionStatusEx(
|
||||
loop.ioPort,
|
||||
addr events[0],
|
||||
ULONG(len(events)),
|
||||
eventsReceived,
|
||||
curTimeout,
|
||||
WINBOOL(0)
|
||||
)
|
||||
if res == WINBOOL(0):
|
||||
let errCode = osLastError()
|
||||
if int32(errCode) != WAIT_TIMEOUT:
|
||||
raiseOSError(errCode)
|
||||
0
|
||||
else:
|
||||
eventsReceived
|
||||
|
||||
let res = getQueuedCompletionStatus(
|
||||
loop.ioPort, addr lpNumberOfBytesTransferred,
|
||||
addr lpCompletionKey, cast[ptr POVERLAPPED](addr customOverlapped),
|
||||
curTimeout).bool
|
||||
|
||||
if res:
|
||||
customOverlapped.data.bytesCount = lpNumberOfBytesTransferred
|
||||
customOverlapped.data.errCode = OSErrorCode(-1)
|
||||
for i in 0 ..< networkEventsCount:
|
||||
var customOverlapped = events[i].lpOverlapped
|
||||
customOverlapped.data.errCode =
|
||||
block:
|
||||
let res = cast[uint64](customOverlapped.internal)
|
||||
if res == 0'u64:
|
||||
OSErrorCode(-1)
|
||||
else:
|
||||
OSErrorCode(rtlNtStatusToDosError(res))
|
||||
customOverlapped.data.bytesCount = events[i].dwNumberOfBytesTransferred
|
||||
let acb = AsyncCallback(function: customOverlapped.data.cb,
|
||||
udata: cast[pointer](customOverlapped))
|
||||
loop.callbacks.addLast(acb)
|
||||
else:
|
||||
let errCode = osLastError()
|
||||
if customOverlapped != nil:
|
||||
customOverlapped.data.errCode = errCode
|
||||
let acb = AsyncCallback(function: customOverlapped.data.cb,
|
||||
udata: cast[pointer](customOverlapped))
|
||||
loop.callbacks.addLast(acb)
|
||||
else:
|
||||
if int32(errCode) != WAIT_TIMEOUT:
|
||||
raiseOSError(errCode)
|
||||
else:
|
||||
noNetworkEvents = true
|
||||
|
||||
# Moving expired timers to `loop.callbacks`.
|
||||
loop.processTimers()
|
||||
|
||||
# We move idle callbacks to `loop.callbacks` only if there no pending
|
||||
# network events.
|
||||
if noNetworkEvents:
|
||||
if networkEventsCount == 0:
|
||||
loop.processIdlers()
|
||||
|
||||
# All callbacks which will be added in process will be processed on next
|
||||
# poll() call.
|
||||
loop.processCallbacks()
|
||||
# All callbacks which will be added during `processCallbacks` will be
|
||||
# scheduled after the sentinel and are processed on next `poll()` call.
|
||||
loop.callbacks.addLast(SentinelCallback)
|
||||
processCallbacks(loop)
|
||||
|
||||
# All callbacks done, skip `processCallbacks` at start.
|
||||
loop.callbacks.addFirst(SentinelCallback)
|
||||
|
||||
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
|
||||
## Closes a socket and ensures that it is unregistered.
|
||||
|
@ -536,6 +602,7 @@ elif unixPlatform:
|
|||
# Before 0.20.0 Nim's stdlib version
|
||||
res.timers.newHeapQueue()
|
||||
res.callbacks = initDeque[AsyncCallback](64)
|
||||
res.callbacks.addLast(SentinelCallback)
|
||||
res.idlers = initDeque[AsyncCallback]()
|
||||
res.keys = newSeq[ReadyKey](64)
|
||||
res.trackers = initTable[string, TrackerBase]()
|
||||
|
@ -706,6 +773,11 @@ elif unixPlatform:
|
|||
let customSet = {Event.Timer, Event.Signal, Event.Process,
|
||||
Event.Vnode}
|
||||
|
||||
# On reentrant `poll` calls from `processCallbacks`, e.g., `waitFor`,
|
||||
# complete pending work of the outer `processCallbacks` call.
|
||||
# On non-reentrant `poll` calls, this only removes sentinel element.
|
||||
processCallbacks(loop)
|
||||
|
||||
# Moving expired timers to `loop.callbacks` and calculate timeout.
|
||||
loop.processTimersGetTimeout(curTimeout)
|
||||
|
||||
|
@ -741,9 +813,13 @@ elif unixPlatform:
|
|||
if count == 0:
|
||||
loop.processIdlers()
|
||||
|
||||
# All callbacks which will be added in process, will be processed on next
|
||||
# poll() call.
|
||||
loop.processCallbacks()
|
||||
# All callbacks which will be added during `processCallbacks` will be
|
||||
# scheduled after the sentinel and are processed on next `poll()` call.
|
||||
loop.callbacks.addLast(SentinelCallback)
|
||||
processCallbacks(loop)
|
||||
|
||||
# All callbacks done, skip `processCallbacks` at start.
|
||||
loop.callbacks.addFirst(SentinelCallback)
|
||||
|
||||
else:
|
||||
proc initAPI() = discard
|
||||
|
|
|
@ -21,7 +21,7 @@ suite "Asynchronous issues test suite":
|
|||
test: string
|
||||
|
||||
proc udp4DataAvailable(transp: DatagramTransport,
|
||||
remote: TransportAddress): Future[void] {.async, gcsafe.} =
|
||||
remote: TransportAddress) {.async, gcsafe.} =
|
||||
var udata = getUserData[CustomData](transp)
|
||||
var expect = TEST_MSG
|
||||
var data: seq[byte]
|
||||
|
@ -98,6 +98,40 @@ suite "Asynchronous issues test suite":
|
|||
|
||||
result = r1 and r2
|
||||
|
||||
proc createBigMessage(size: int): seq[byte] =
|
||||
var message = "MESSAGE"
|
||||
var res = newSeq[byte](size)
|
||||
for i in 0 ..< len(result):
|
||||
res[i] = byte(message[i mod len(message)])
|
||||
res
|
||||
|
||||
proc testIndexError(): Future[bool] {.async.} =
|
||||
var server = createStreamServer(initTAddress("127.0.0.1:0"),
|
||||
flags = {ReuseAddr})
|
||||
let messageSize = DefaultStreamBufferSize * 4
|
||||
var buffer = newSeq[byte](messageSize)
|
||||
let msg = createBigMessage(messageSize)
|
||||
let address = server.localAddress()
|
||||
let afut = server.accept()
|
||||
let outTransp = await connect(address)
|
||||
let inpTransp = await afut
|
||||
let bytesSent = await outTransp.write(msg)
|
||||
check bytesSent == messageSize
|
||||
var rfut = inpTransp.readExactly(addr buffer[0], messageSize)
|
||||
|
||||
proc waiterProc(udata: pointer) {.raises: [Defect], gcsafe.} =
|
||||
try:
|
||||
waitFor(sleepAsync(0.milliseconds))
|
||||
except CatchableError as exc:
|
||||
raiseAssert "Unexpected exception happened"
|
||||
let timer = setTimer(Moment.fromNow(0.seconds), waiterProc, nil)
|
||||
await sleepAsync(100.milliseconds)
|
||||
|
||||
await inpTransp.closeWait()
|
||||
await outTransp.closeWait()
|
||||
await server.closeWait()
|
||||
return true
|
||||
|
||||
test "Issue #6":
|
||||
check waitFor(issue6()) == true
|
||||
|
||||
|
@ -112,3 +146,6 @@ suite "Asynchronous issues test suite":
|
|||
|
||||
test "Defer for asynchronous procedures test [Nim's issue #13899]":
|
||||
check waitFor(testDefer()) == true
|
||||
|
||||
test "IndexError crash test":
|
||||
check waitFor(testIndexError()) == true
|
||||
|
|
Loading…
Reference in New Issue