Fix #8 and related issues, added more tests for it.

Fix Unix connection failed bug.
This commit is contained in:
cheatfate 2018-08-24 15:20:08 +03:00
parent f94cedb47b
commit 530905f276
8 changed files with 455 additions and 146 deletions

View File

@ -1,5 +1,5 @@
packageName = "asyncdispatch2"
version = "2.0.8"
version = "2.0.9"
author = "Status Research & Development GmbH"
description = "Asyncdispatch2"
license = "Apache License 2.0 or MIT"

View File

@ -287,12 +287,14 @@ when defined(windows) or defined(nimdoc):
var gDisp{.threadvar.}: PDispatcher ## Global dispatcher
proc setGlobalDispatcher*(disp: PDispatcher) =
## Set current thread's dispatcher instance to ``disp``.
if not gDisp.isNil:
assert gDisp.callbacks.len == 0
gDisp = disp
initCallSoonProc()
proc getGlobalDispatcher*(): PDispatcher =
## Returns current thread's dispatcher instance.
if gDisp.isNil:
setGlobalDispatcher(newDispatcher())
result = gDisp
@ -303,14 +305,15 @@ when defined(windows) or defined(nimdoc):
return disp.ioPort
proc register*(fd: AsyncFD) =
## Registers ``fd`` with the dispatcher.
let p = getGlobalDispatcher()
if createIoCompletionPort(fd.Handle, p.ioPort,
## Register file descriptor ``fd`` in thread's dispatcher.
let loop = getGlobalDispatcher()
if createIoCompletionPort(fd.Handle, loop.ioPort,
cast[CompletionKey](fd), 1) == 0:
raiseOSError(osLastError())
p.handles.incl(fd)
loop.handles.incl(fd)
proc poll*() =
## Perform single asynchronous step.
let loop = getGlobalDispatcher()
var curTime = fastEpochTime()
var curTimeout = DWORD(0)
@ -397,16 +400,21 @@ when defined(windows) or defined(nimdoc):
loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer)
close(sock)
proc closeSocket*(socket: AsyncFD) =
proc closeSocket*(socket: AsyncFD, aftercb: CallbackFunc = nil) =
## Closes a socket and ensures that it is unregistered.
let loop = getGlobalDispatcher()
socket.SocketHandle.close()
getGlobalDispatcher().handles.excl(socket)
loop.handles.excl(socket)
if not isNil(aftercb):
var acb = AsyncCallback(function: aftercb)
loop.callbacks.addLast(acb)
proc unregister*(fd: AsyncFD) =
## Unregisters ``fd``.
getGlobalDispatcher().handles.excl(fd)
proc contains*(disp: PDispatcher, fd: AsyncFD): bool =
## Returns ``true`` if ``fd`` is registered in thread's dispatcher.
return fd in disp.handles
else:
@ -435,6 +443,7 @@ else:
proc `==`*(x, y: AsyncFD): bool {.borrow.}
proc newDispatcher*(): PDispatcher =
## Create new dispatcher.
new result
result.selector = newSelector[SelectorData]()
result.timers.newHeapQueue()
@ -444,40 +453,44 @@ else:
var gDisp{.threadvar.}: PDispatcher ## Global dispatcher
proc setGlobalDispatcher*(disp: PDispatcher) =
## Set current thread's dispatcher instance to ``disp``.
if not gDisp.isNil:
assert gDisp.callbacks.len == 0
gDisp = disp
initCallSoonProc()
proc getGlobalDispatcher*(): PDispatcher =
## Returns current thread's dispatcher instance.
if gDisp.isNil:
setGlobalDispatcher(newDispatcher())
result = gDisp
proc getIoHandler*(disp: PDispatcher): Selector[SelectorData] =
## Returns system specific OS queue.
return disp.selector
proc register*(fd: AsyncFD) =
## Register file descriptor ``fd`` in selector.
## Register file descriptor ``fd`` in thread's dispatcher.
let loop = getGlobalDispatcher()
var data: SelectorData
data.rdata.fd = fd
data.wdata.fd = fd
let loop = getGlobalDispatcher()
loop.selector.registerHandle(int(fd), {}, data)
proc unregister*(fd: AsyncFD) =
## Unregister file descriptor ``fd`` from selector.
## Unregister file descriptor ``fd`` from thread's dispatcher.
getGlobalDispatcher().selector.unregister(int(fd))
proc contains*(disp: PDispatcher, fd: AsyncFd): bool {.inline.} =
## Returns ``true`` if ``fd`` is registered in thread's dispatcher.
result = int(fd) in disp.selector
proc addReader*(fd: AsyncFD, cb: CallbackFunc, udata: pointer = nil) =
## Start watching the file descriptor ``fd`` for read availability and then
## call the callback ``cb`` with specified argument ``udata``.
let p = getGlobalDispatcher()
let loop = getGlobalDispatcher()
var newEvents = {Event.Read}
withData(p.selector, int(fd), adata) do:
withData(loop.selector, int(fd), adata) do:
let acb = AsyncCallback(function: cb, udata: addr adata.rdata)
adata.reader = acb
adata.rdata = CompletionData(fd: fd, udata: udata)
@ -485,27 +498,27 @@ else:
if not isNil(adata.writer.function): newEvents.incl(Event.Write)
do:
raise newException(ValueError, "File descriptor not registered.")
p.selector.updateHandle(int(fd), newEvents)
loop.selector.updateHandle(int(fd), newEvents)
proc removeReader*(fd: AsyncFD) =
## Stop watching the file descriptor ``fd`` for read availability.
let p = getGlobalDispatcher()
let loop = getGlobalDispatcher()
var newEvents: set[Event]
withData(p.selector, int(fd), adata) do:
withData(loop.selector, int(fd), adata) do:
# We need to clear `reader` data, because `selectors` don't do it
adata.reader = AsyncCallback()
adata.rdata = CompletionData()
adata.reader.function = nil
# adata.rdata = CompletionData()
if not isNil(adata.writer.function): newEvents.incl(Event.Write)
do:
raise newException(ValueError, "File descriptor not registered.")
p.selector.updateHandle(int(fd), newEvents)
loop.selector.updateHandle(int(fd), newEvents)
proc addWriter*(fd: AsyncFD, cb: CallbackFunc, udata: pointer = nil) =
## Start watching the file descriptor ``fd`` for write availability and then
## call the callback ``cb`` with specified argument ``udata``.
let p = getGlobalDispatcher()
let loop = getGlobalDispatcher()
var newEvents = {Event.Write}
withData(p.selector, int(fd), adata) do:
withData(loop.selector, int(fd), adata) do:
let acb = AsyncCallback(function: cb, udata: addr adata.wdata)
adata.writer = acb
adata.wdata = CompletionData(fd: fd, udata: udata)
@ -513,20 +526,44 @@ else:
if not isNil(adata.reader.function): newEvents.incl(Event.Read)
do:
raise newException(ValueError, "File descriptor not registered.")
p.selector.updateHandle(int(fd), newEvents)
loop.selector.updateHandle(int(fd), newEvents)
proc removeWriter*(fd: AsyncFD) =
## Stop watching the file descriptor ``fd`` for write availability.
let p = getGlobalDispatcher()
let loop = getGlobalDispatcher()
var newEvents: set[Event]
withData(p.selector, int(fd), adata) do:
withData(loop.selector, int(fd), adata) do:
# We need to clear `writer` data, because `selectors` don't do it
adata.writer = AsyncCallback()
adata.wdata = CompletionData()
adata.writer.function = nil
# adata.wdata = CompletionData()
if not isNil(adata.reader.function): newEvents.incl(Event.Read)
do:
raise newException(ValueError, "File descriptor not registered.")
p.selector.updateHandle(int(fd), newEvents)
loop.selector.updateHandle(int(fd), newEvents)
proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) =
## Close asynchronous socket.
##
## Please note, that socket is not closed immediately. To avoid bugs with
## closing socket, while operation pending, socket will be closed as
## soon as all pending operations will be notified.
## You can execute ``aftercb`` before actual socket close operation.
let loop = getGlobalDispatcher()
proc continuation(udata: pointer) =
aftercb(nil)
unregister(fd)
close(SocketHandle(fd))
withData(loop.selector, int(fd), adata) do:
if not isNil(adata.reader.function):
loop.callbacks.addLast(adata.reader)
if not isNil(adata.writer.function):
loop.callbacks.addLast(adata.writer)
if not isNil(aftercb):
var acb = AsyncCallback(function: continuation)
loop.callbacks.addLast(acb)
when ioselSupportedPlatform:
proc addSignal*(signal: int, cb: CallbackFunc,
@ -535,10 +572,10 @@ else:
## callback ``cb`` with specified argument ``udata``. Returns signal
## identifier code, which can be used to remove signal callback
## via ``removeSignal``.
let p = getGlobalDispatcher()
let loop = getGlobalDispatcher()
var data: SelectorData
result = p.selector.registerSignal(signal, data)
withData(p.selector, result, adata) do:
result = loop.selector.registerSignal(signal, data)
withData(loop.selector, result, adata) do:
adata.reader = AsyncCallback(function: cb, udata: addr adata.rdata)
adata.rdata.fd = AsyncFD(result)
adata.rdata.udata = udata
@ -547,8 +584,8 @@ else:
proc removeSignal*(sigfd: int) =
## Remove watching signal ``signal``.
let p = getGlobalDispatcher()
p.selector.unregister(sigfd)
let loop = getGlobalDispatcher()
loop.selector.unregister(sigfd)
proc poll*() =
## Perform single asynchronous step.
@ -569,21 +606,18 @@ else:
let fd = loop.keys[i].fd
let events = loop.keys[i].events
if Event.Read in events or events == {Event.Error}:
withData(loop.selector, fd, adata) do:
withData(loop.selector, fd, adata) do:
if Event.Read in events or events == {Event.Error}:
loop.callbacks.addLast(adata.reader)
if Event.Write in events or events == {Event.Error}:
withData(loop.selector, fd, adata) do:
if Event.Write in events or events == {Event.Error}:
loop.callbacks.addLast(adata.writer)
if Event.User in events:
withData(loop.selector, fd, adata) do:
if Event.User in events:
loop.callbacks.addLast(adata.reader)
when ioselSupportedPlatform:
if customSet * events != {}:
withData(loop.selector, fd, adata) do:
when ioselSupportedPlatform:
if customSet * events != {}:
loop.callbacks.addLast(adata.reader)
# Moving expired timers to `loop.callbacks`.
@ -618,10 +652,6 @@ proc removeTimer*(at: uint64, cb: CallbackFunc, udata: pointer = nil) =
if index != -1:
loop.timers.del(index)
# proc completeProxy*[T](data: pointer) =
# var future = cast[Future[T]](data)
# future.complete()
proc sleepAsync*(ms: int): Future[void] =
## Suspends the execution of the current async procedure for the next
## ``ms`` milliseconds.
@ -656,7 +686,7 @@ proc withTimeout*[T](fut: Future[T], timeout: int): Future[bool] =
proc wait*[T](fut: Future[T], timeout = -1): Future[T] =
## Returns a future which will complete once future ``fut`` completes
## or if timeout of ``timeout`` milliseconds has been expired.
##
##
## If ``timeout`` is ``-1``, then statement ``await wait(fut)`` is
## equal to ``await fut``.
var retFuture = newFuture[T]("asyncdispatch.wait")

View File

@ -90,8 +90,3 @@ proc wrapAsyncSocket*(sock: SocketHandle): AsyncFD =
return asyncInvalidSocket
result = AsyncFD(sock)
register(result)
proc closeAsyncSocket*(s: AsyncFD) {.inline.} =
## Closes asynchronous socket handle ``s``.
unregister(s)
close(SocketHandle(s))

View File

@ -228,7 +228,7 @@ when defined(windows):
if not setSockOpt(localSock, SOL_SOCKET, SO_REUSEADDR, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
## Fix for Q263823.
@ -247,7 +247,7 @@ when defined(windows):
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
result.local = local
else:
@ -263,7 +263,7 @@ when defined(windows):
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
if remote.port != Port(0):
@ -274,7 +274,7 @@ when defined(windows):
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
result.remote = remote
@ -297,30 +297,34 @@ when defined(windows):
else:
result.state.incl(ReadPaused)
proc close*(transp: DatagramTransport) =
## Closes and frees resources of transport ``transp``.
if ReadClosed notin transp.state and WriteClosed notin transp.state:
# discard cancelIo(Handle(transp.fd))
closeAsyncSocket(transp.fd)
transp.state.incl(WriteClosed)
transp.state.incl(ReadClosed)
transp.future.complete()
if not isNil(transp.udata) and GCUserData in transp.flags:
GC_unref(cast[ref int](transp.udata))
GC_unref(transp)
# proc close*(transp: DatagramTransport) =
# ## Closes and frees resources of transport ``transp``.
# if ReadClosed notin transp.state and WriteClosed notin transp.state:
# # discard cancelIo(Handle(transp.fd))
# closeSocket(transp.fd)
# transp.state.incl(WriteClosed)
# transp.state.incl(ReadClosed)
# transp.future.complete()
# if not isNil(transp.udata) and GCUserData in transp.flags:
# GC_unref(cast[ref int](transp.udata))
# GC_unref(transp)
else:
# Linux/BSD/MacOS part
proc readDatagramLoop(udata: pointer) =
var raddr: TransportAddress
doAssert(not isNil(udata))
var cdata = cast[ptr CompletionData](udata)
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting
return
var transp = cast[DatagramTransport](cdata.udata)
let fd = SocketHandle(cdata.fd)
if not isNil(transp):
if int(fd) == 0:
## This situation can be happen, when there events present
## after transport was closed.
return
if ReadClosed in transp.state:
transp.state.incl({ReadPaused})
else:
while true:
transp.ralen = SockLen(sizeof(Sockaddr_storage))
var res = posix.recvfrom(fd, addr transp.buffer[0],
@ -343,13 +347,17 @@ else:
proc writeDatagramLoop(udata: pointer) =
var res: int
doAssert(not isNil(udata))
var cdata = cast[ptr CompletionData](udata)
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting
return
var transp = cast[DatagramTransport](cdata.udata)
let fd = SocketHandle(cdata.fd)
if not isNil(transp):
if int(fd) == 0:
## This situation can be happen, when there events present
## after transport was closed.
return
if WriteClosed in transp.state:
transp.state.incl({WritePaused})
else:
if len(transp.queue) > 0:
var vector = transp.queue.popFirst()
while true:
@ -420,7 +428,7 @@ else:
if not setSockOpt(localSock, SOL_SOCKET, SO_REUSEADDR, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
if local.port != Port(0):
@ -431,7 +439,7 @@ else:
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
result.local = local
@ -443,7 +451,7 @@ else:
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
closeSocket(localSock)
raiseTransportOsError(err)
result.remote = remote
@ -461,13 +469,22 @@ else:
else:
result.state.incl(ReadPaused)
proc close*(transp: DatagramTransport) =
## Closes and frees resources of transport ``transp``.
proc close*(transp: DatagramTransport) =
## Closes and frees resources of transport ``transp``.
when defined(windows):
if {ReadClosed, WriteClosed} * transp.state == {}:
closeAsyncSocket(transp.fd)
discard cancelIo(Handle(transp.fd))
closeSocket(transp.fd)
transp.state.incl({WriteClosed, ReadClosed})
transp.future.complete()
GC_unref(transp)
else:
proc continuation(udata: pointer) =
transp.future.complete()
GC_unref(transp)
if {ReadClosed, WriteClosed} * transp.state == {}:
transp.state.incl({WriteClosed, ReadClosed})
closeSocket(transp.fd, continuation)
proc newDatagramTransport*(cbproc: DatagramCallback,
remote: TransportAddress = AnyAddress,
@ -543,10 +560,19 @@ proc newDatagramTransport6*[T](cbproc: DatagramCallback,
fflags, cast[pointer](udata),
child, bufSize)
proc join*(transp: DatagramTransport) {.async.} =
proc join*(transp: DatagramTransport): Future[void] =
## Wait until the transport ``transp`` will be closed.
var retFuture = newFuture[void]("datagramtransport.join")
proc continuation(udata: pointer) = retFuture.complete()
if not transp.future.finished:
await transp.future
transp.future.addCallback(continuation)
else:
retFuture.complete()
return retFuture
proc closeWait*(transp: DatagramTransport): Future[void] =
## Close transport ``transp`` and release all resources.
result = transp.join()
proc send*(transp: DatagramTransport, pbytes: pointer,
nbytes: int): Future[void] =

View File

@ -291,6 +291,11 @@ when defined(windows):
## Continuation
transp.state.excl(ReadPending)
if ReadClosed in transp.state:
transp.state.incl({ReadPaused})
if not isNil(transp.reader):
if not transp.reader.finished:
transp.reader.complete()
transp.reader = nil
break
let err = transp.rovl.data.errCode
if err == OSErrorCode(-1):
@ -353,6 +358,11 @@ when defined(windows):
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
else:
transp.state.incl(ReadPaused)
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
## Finish Loop
break
@ -411,7 +421,7 @@ when defined(windows):
result.fail(newException(TransportOsError, osErrorMsg(osLastError())))
if not bindToDomain(sock, address.address.getDomain()):
sock.closeAsyncSocket()
sock.closeSocket()
result.fail(newException(TransportOsError, osErrorMsg(osLastError())))
proc continuation(udata: pointer) =
@ -421,7 +431,7 @@ when defined(windows):
if setsockopt(SocketHandle(sock), cint(SOL_SOCKET),
cint(SO_UPDATE_CONNECT_CONTEXT), nil,
SockLen(0)) != 0'i32:
sock.closeAsyncSocket()
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
osErrorMsg(osLastError())))
else:
@ -429,7 +439,7 @@ when defined(windows):
bufferSize,
child))
else:
sock.closeAsyncSocket()
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
osErrorMsg(ovl.data.errCode)))
GC_unref(ovl)
@ -446,7 +456,7 @@ when defined(windows):
let err = osLastError()
if int32(err) != ERROR_IO_PENDING:
GC_unref(povl)
sock.closeAsyncSocket()
sock.closeSocket()
retFuture.fail(newException(TransportOsError, osErrorMsg(err)))
return retFuture
@ -460,14 +470,14 @@ when defined(windows):
## Continuation
server.apending = false
if server.status == ServerStatus.Stopped:
server.asock.closeAsyncSocket()
server.asock.closeSocket()
else:
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET),
cint(SO_UPDATE_ACCEPT_CONTEXT),
addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
server.asock.closeAsyncSocket()
server.asock.closeSocket()
raiseTransportOsError(osLastError())
else:
if not isNil(server.init):
@ -482,10 +492,10 @@ when defined(windows):
asyncCheck server.function(server, ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
server.asock.closeAsyncSocket()
server.asock.closeSocket()
break
else:
server.asock.closeAsyncSocket()
server.asock.closeSocket()
raiseTransportOsError(osLastError())
else:
## Initiation
@ -553,11 +563,14 @@ else:
proc writeStreamLoop(udata: pointer) {.gcsafe.} =
var cdata = cast[ptr CompletionData](udata)
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting
return
var transp = cast[StreamTransport](cdata.udata)
let fd = SocketHandle(cdata.fd)
if int(fd) == 0:
## This situation can be happen, when there events present
## after transport was closed.
return
if len(transp.queue) > 0:
var vector = transp.queue.popFirst()
while true:
@ -601,36 +614,46 @@ else:
proc readStreamLoop(udata: pointer) {.gcsafe.} =
var cdata = cast[ptr CompletionData](udata)
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting
return
var transp = cast[StreamTransport](cdata.udata)
let fd = SocketHandle(cdata.fd)
while true:
var res = posix.recv(fd, addr transp.buffer[transp.offset],
len(transp.buffer) - transp.offset, cint(0))
if res < 0:
let err = osLastError()
if int(err) == EINTR:
continue
elif int(err) in {ECONNRESET}:
if int(fd) == 0:
## This situation can be happen, when there events present
## after transport was closed.
return
if ReadClosed in transp.state:
transp.state.incl({ReadPaused})
if not isNil(transp.reader):
if not transp.reader.finished:
transp.reader.complete()
transp.reader = nil
else:
while true:
var res = posix.recv(fd, addr transp.buffer[transp.offset],
len(transp.buffer) - transp.offset, cint(0))
if res < 0:
let err = osLastError()
if int(err) == EINTR:
continue
elif int(err) in {ECONNRESET}:
transp.state.incl({ReadEof, ReadPaused})
cdata.fd.removeReader()
else:
transp.state.incl(ReadPaused)
transp.setReadError(err)
cdata.fd.removeReader()
elif res == 0:
transp.state.incl({ReadEof, ReadPaused})
cdata.fd.removeReader()
else:
transp.setReadError(err)
cdata.fd.removeReader()
elif res == 0:
transp.state.incl({ReadEof, ReadPaused})
cdata.fd.removeReader()
else:
transp.offset += res
if transp.offset == len(transp.buffer):
transp.state.incl(ReadPaused)
cdata.fd.removeReader()
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
break
transp.offset += res
if transp.offset == len(transp.buffer):
transp.state.incl(ReadPaused)
cdata.fd.removeReader()
if not isNil(transp.reader):
transp.reader.complete()
transp.reader = nil
break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport =
@ -670,17 +693,17 @@ else:
var data = cast[ptr CompletionData](udata)
var err = 0
let fd = data.fd
fd.removeWriter()
if not fd.getSocketError(err):
fd.closeAsyncSocket()
closeSocket(fd)
retFuture.fail(newException(TransportOsError,
osErrorMsg(osLastError())))
return
if err != 0:
fd.closeAsyncSocket()
closeSocket(fd)
retFuture.fail(newException(TransportOsError,
osErrorMsg(OSErrorCode(err))))
return
fd.removeWriter()
retFuture.complete(newStreamSocketTransport(fd, bufferSize, child))
while true:
@ -697,7 +720,7 @@ else:
sock.addWriter(continuation)
break
else:
sock.closeAsyncSocket()
sock.closeSocket()
retFuture.fail(newException(TransportOsError, osErrorMsg(err)))
break
return retFuture
@ -757,20 +780,33 @@ proc stop*(server: StreamServer) =
elif server.status == ServerStatus.Starting:
server.status = ServerStatus.Stopped
proc join*(server: StreamServer) {.async.} =
proc join*(server: StreamServer): Future[void] =
## Waits until ``server`` is not closed.
var retFuture = newFuture[void]("streamserver.join")
proc continuation(udata: pointer) = retFuture.complete()
if not server.loopFuture.finished:
await server.loopFuture
server.loopFuture.addCallback(continuation)
else:
retFuture.complete()
return retFuture
proc close*(server: StreamServer) =
## Release ``server`` resources.
if server.status == ServerStatus.Stopped:
closeAsyncSocket(server.sock)
server.status = ServerStatus.Closed
##
## Please note that release of resources is not completed immediately, to be
## sure all resources got released please use ``await server.join()``.
proc continuation(udata: pointer) =
server.loopFuture.complete()
if not isNil(server.udata) and GCUserData in server.flags:
GC_unref(cast[ref int](server.udata))
GC_unref(server)
if server.status == ServerStatus.Stopped:
server.status = ServerStatus.Closed
server.sock.closeSocket(continuation)
proc closeWait*(server: StreamServer): Future[void] =
## Close server ``server`` and release all resources.
result = server.join()
proc createStreamServer*(host: TransportAddress,
cbproc: StreamCallback,
@ -814,7 +850,7 @@ proc createStreamServer*(host: TransportAddress,
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
serverSocket.closeSocket()
raiseTransportOsError(err)
toSockAddr(host.address, host.port, saddr, slen)
@ -822,13 +858,13 @@ proc createStreamServer*(host: TransportAddress,
slen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
serverSocket.closeSocket()
raiseTransportOsError(err)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
serverSocket.closeSocket()
raiseTransportOsError(err)
if not isNil(child):
@ -1193,21 +1229,35 @@ proc consume*(transp: StreamTransport, n = -1): Future[int] {.async.} =
transp.resumeRead()
await fut
proc join*(transp: StreamTransport) {.async.} =
proc join*(transp: StreamTransport): Future[void] =
## Wait until ``transp`` will not be closed.
var retFuture = newFuture[void]("streamtransport.join")
proc continuation(udata: pointer) = retFuture.complete()
if not transp.future.finished:
await transp.future
transp.future.addCallback(continuation)
else:
retFuture.complete()
return retFuture
proc close*(transp: StreamTransport) =
## Closes and frees resources of transport ``transp``.
if {ReadClosed, WriteClosed} * transp.state == {}:
when defined(windows):
discard cancelIo(Handle(transp.fd))
closeAsyncSocket(transp.fd)
transp.state.incl({WriteClosed, ReadClosed})
##
## Please note that release of resources is not completed immediately, to be
## sure all resources got released please use ``await transp.join()``.
proc continuation(udata: pointer) =
transp.future.complete()
GC_unref(transp)
if {ReadClosed, WriteClosed} * transp.state == {}:
transp.state.incl({WriteClosed, ReadClosed})
when defined(windows):
discard cancelIo(Handle(transp.fd))
closeSocket(transp.fd, continuation)
proc closeWait*(transp: StreamTransport): Future[void] =
## Close and frees resources of transport ``transp``.
result = transp.join()
proc closed*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in closed state.
result = ({ReadClosed, WriteClosed} * transp.state != {})

View File

@ -139,7 +139,7 @@ proc client5(transp: DatagramTransport,
if counterPtr[] == MessagesCount:
transp.close()
else:
var ta = initTAddress("127.0.0.1:33337")
var ta = initTAddress("127.0.0.1:33341")
var req = "REQUEST" & $counterPtr[]
await transp.sendTo(ta, addr req[0], len(req))
else:
@ -272,7 +272,7 @@ proc client10(transp: DatagramTransport,
if counterPtr[] == TestsCount:
transp.close()
else:
var ta = initTAddress("127.0.0.1:33336")
var ta = initTAddress("127.0.0.1:33338")
var req = "REQUEST" & $counterPtr[]
var reqseq = newSeq[byte](len(req))
copyMem(addr reqseq[0], addr req[0], len(req))
@ -370,7 +370,7 @@ proc testStringSend(): Future[int] {.async.} =
proc testSeqSendTo(): Future[int] {.async.} =
## sendTo(string) test
var ta = initTAddress("127.0.0.1:33336")
var ta = initTAddress("127.0.0.1:33338")
var counter = 0
var dgram1 = newDatagramTransport(client9, udata = addr counter, local = ta)
var dgram2 = newDatagramTransport(client10, udata = addr counter)
@ -385,7 +385,7 @@ proc testSeqSendTo(): Future[int] {.async.} =
proc testSeqSend(): Future[int] {.async.} =
## send(string) test
var ta = initTAddress("127.0.0.1:33337")
var ta = initTAddress("127.0.0.1:33339")
var counter = 0
var dgram1 = newDatagramTransport(client9, udata = addr counter, local = ta)
var dgram2 = newDatagramTransport(client11, udata = addr counter, remote = ta)
@ -412,7 +412,11 @@ proc waitAll(futs: seq[Future[void]]): Future[void] =
return retFuture
proc test3(bounded: bool): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:33337")
var ta: TransportAddress
if bounded:
ta = initTAddress("127.0.0.1:33340")
else:
ta = initTAddress("127.0.0.1:33341")
var counter = 0
var dgram1 = newDatagramTransport(client1, udata = addr counter, local = ta)
var clients = newSeq[Future[void]](ClientsCount)

View File

@ -35,6 +35,7 @@ proc serveCustomStreamClient(server: StreamServer,
var answer = "ANSWER\r\n"
discard await transp.write(answer)
transp.close()
await transp.join()
proc serveUdataStreamClient(server: StreamServer,
transp: StreamTransport) {.async.} =
@ -43,6 +44,7 @@ proc serveUdataStreamClient(server: StreamServer,
var msg = line & udata.test & "\r\n"
discard await transp.write(msg)
transp.close()
await transp.join()
proc customServerTransport(server: StreamServer,
fd: AsyncFD): StreamTransport =
@ -56,10 +58,12 @@ proc test1(): bool =
server1.start()
server1.stop()
server1.close()
waitFor server1.join()
var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
server2.start()
server2.stop()
server2.close()
waitFor server2.join()
result = true
proc client1(server: CustomServer, ta: TransportAddress) {.async.} =
@ -75,6 +79,7 @@ proc client1(server: CustomServer, ta: TransportAddress) {.async.} =
transp.close()
server.stop()
server.close()
await server.join()
proc client2(server: StreamServer,
ta: TransportAddress): Future[bool] {.async.} =
@ -87,6 +92,7 @@ proc client2(server: StreamServer,
transp.close()
server.stop()
server.close()
await server.join()
proc test3(): bool =
var server = CustomServer()

View File

@ -24,7 +24,7 @@ when sizeof(int) == 8:
ClientsCount = 100
MessagesCount = 100
MessageSize = 20
FilesCount = 50
FilesCount = 100
elif sizeof(int) == 4:
const
BigMessageCount = 200
@ -46,6 +46,7 @@ proc serveClient1(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(cast[pointer](addr ans[0]), len(ans))
doAssert(res == len(ans))
transp.close()
await transp.join()
proc serveClient2(server: StreamServer, transp: StreamTransport) {.async.} =
var buffer: array[20, char]
@ -69,6 +70,7 @@ proc serveClient2(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(cast[pointer](addr buffer[0]), MessageSize)
doAssert(res == MessageSize)
transp.close()
await transp.join()
proc serveClient3(server: StreamServer, transp: StreamTransport) {.async.} =
var buffer: array[20, char]
@ -95,6 +97,7 @@ proc serveClient3(server: StreamServer, transp: StreamTransport) {.async.} =
doAssert(res == len(ans))
dec(counter)
transp.close()
await transp.join()
proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} =
var pathname = await transp.readLine()
@ -110,6 +113,7 @@ proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(cast[pointer](addr answer[0]), len(answer))
doAssert(res == len(answer))
transp.close()
await transp.join()
proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} =
var data = await transp.read()
@ -124,6 +128,7 @@ proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} =
if counter[] == 0:
server.stop()
server.close()
await server.join()
proc serveClient6(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ConstantMessage
@ -138,6 +143,7 @@ proc serveClient6(server: StreamServer, transp: StreamTransport) {.async.} =
if counter[] == 0:
server.stop()
server.close()
await server.join()
proc serveClient7(server: StreamServer, transp: StreamTransport) {.async.} =
var answer = "DONE\r\n"
@ -150,6 +156,7 @@ proc serveClient7(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(answer)
doAssert(res == len(answer))
transp.close()
await transp.join()
proc serveClient8(server: StreamServer, transp: StreamTransport) {.async.} =
var answer = "DONE\r\n"
@ -171,6 +178,7 @@ proc serveClient8(server: StreamServer, transp: StreamTransport) {.async.} =
transp.close()
server.stop()
server.close()
await server.join()
proc swarmWorker1(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -185,6 +193,7 @@ proc swarmWorker1(address: TransportAddress): Future[int] {.async.} =
doAssert(num == i)
inc(result)
transp.close()
await transp.join()
proc swarmWorker2(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -208,6 +217,7 @@ proc swarmWorker2(address: TransportAddress): Future[int] {.async.} =
doAssert(num == i)
inc(result)
transp.close()
await transp.join()
proc swarmWorker3(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -235,6 +245,7 @@ proc swarmWorker3(address: TransportAddress): Future[int] {.async.} =
doAssert(num == i)
inc(result)
transp.close()
await transp.join()
proc swarmWorker4(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -249,9 +260,9 @@ proc swarmWorker4(address: TransportAddress): Future[int] {.async.} =
handle = int(getFileHandle(fhandle))
doAssert(handle > 0)
name = name & "\r\n"
ssize = $size & "\r\n"
var res = await transp.write(cast[pointer](addr name[0]), len(name))
doAssert(res == len(name))
ssize = $size & "\r\n"
res = await transp.write(cast[pointer](addr ssize[0]), len(ssize))
doAssert(res == len(ssize))
var checksize = await transp.writeFile(handle, 0'u, size)
@ -261,6 +272,7 @@ proc swarmWorker4(address: TransportAddress): Future[int] {.async.} =
doAssert(ans == "OK")
result = 1
transp.close()
await transp.join()
proc swarmWorker5(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -269,6 +281,7 @@ proc swarmWorker5(address: TransportAddress): Future[int] {.async.} =
var res = await transp.write(data)
result = MessagesCount
transp.close()
await transp.join()
proc swarmWorker6(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -279,6 +292,7 @@ proc swarmWorker6(address: TransportAddress): Future[int] {.async.} =
var res = await transp.write(seqdata)
result = MessagesCount
transp.close()
await transp.join()
proc swarmWorker7(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -292,6 +306,7 @@ proc swarmWorker7(address: TransportAddress): Future[int] {.async.} =
doAssert(line == "DONE")
result = 1
transp.close()
await transp.join()
proc swarmWorker8(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
@ -305,6 +320,7 @@ proc swarmWorker8(address: TransportAddress): Future[int] {.async.} =
doAssert(line == "DONE")
result = 1
transp.close()
await transp.join()
proc waitAll[T](futs: seq[Future[T]]): Future[void] =
var counter = len(futs)
@ -390,6 +406,7 @@ proc test1(): Future[int] {.async.} =
result = await swarmManager1(ta)
server.stop()
server.close()
await server.join()
proc test2(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31345")
@ -399,6 +416,7 @@ proc test2(): Future[int] {.async.} =
result = await swarmManager2(ta)
server.stop()
server.close()
await server.join()
proc test3(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31346")
@ -408,6 +426,7 @@ proc test3(): Future[int] {.async.} =
result = await swarmManager3(ta)
server.stop()
server.close()
await server.join()
proc test4(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31347")
@ -416,6 +435,7 @@ proc test4(): Future[int] {.async.} =
result = await swarmManager4(ta)
server.stop()
server.close()
await server.join()
proc test5(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31348")
@ -424,7 +444,6 @@ proc test5(): Future[int] {.async.} =
udata = cast[pointer](addr counter))
server.start()
result = await swarmManager5(ta)
await server.join()
proc test6(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31349")
@ -433,7 +452,6 @@ proc test6(): Future[int] {.async.} =
udata = cast[pointer](addr counter))
server.start()
result = await swarmManager6(ta)
await server.join()
proc test7(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31350")
@ -442,6 +460,7 @@ proc test7(): Future[int] {.async.} =
result = await swarmWorker7(ta)
server.stop()
server.close()
await server.join()
proc test8(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31350")
@ -450,6 +469,168 @@ proc test8(): Future[int] {.async.} =
result = await swarmWorker8(ta)
server.stop()
server.close()
await server.join()
proc serveClient9(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var res = await transp.write(expect)
doAssert(res == len(expect))
transp.close()
await transp.join()
proc swarmWorker9(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var line = await transp.readLine()
if line == expect:
result = 1
else:
result = 0
transp.close()
await transp.join()
proc test9(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31351")
var server = createStreamServer(ta, serveClient9, {ReuseAddr})
server.start()
result = await swarmWorker9(ta)
server.stop()
server.close()
await server.join()
proc serveClient10(server: StreamServer, transp: StreamTransport) {.async.} =
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var res = await transp.write(expect)
doAssert(res == len(expect))
transp.close()
await transp.join()
proc swarmWorker10(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var expect = ""
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
var line = await transp.read()
if equalMem(addr line[0], addr expect[0], len(expect)):
result = 1
else:
result = 0
transp.close()
await transp.join()
proc test10(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31351")
var server = createStreamServer(ta, serveClient10, {ReuseAddr})
server.start()
result = await swarmWorker10(ta)
server.stop()
server.close()
await server.join()
proc serveClient11(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(BigMessagePattern)
doAssert(res == len(BigMessagePattern))
transp.close()
await transp.join()
proc swarmWorker11(address: TransportAddress): Future[int] {.async.} =
var buffer: array[len(BigMessagePattern) + 1, byte]
var transp = await connect(address)
try:
await transp.readExactly(addr buffer[0], len(buffer))
except TransportIncompleteError:
result = 1
transp.close()
await transp.join()
proc test11(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31352")
var server = createStreamServer(ta, serveClient11, {ReuseAddr})
server.start()
result = await swarmWorker11(ta)
server.stop()
server.close()
await server.join()
proc serveClient12(server: StreamServer, transp: StreamTransport) {.async.} =
var res = await transp.write(BigMessagePattern)
doAssert(res == len(BigMessagePattern))
transp.close()
await transp.join()
proc swarmWorker12(address: TransportAddress): Future[int] {.async.} =
var buffer: array[len(BigMessagePattern), byte]
var sep = @[0x0D'u8, 0x0A'u8]
var transp = await connect(address)
try:
var res = await transp.readUntil(addr buffer[0], len(buffer), sep)
except TransportIncompleteError:
result = 1
transp.close()
await transp.join()
proc test12(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31353")
var server = createStreamServer(ta, serveClient12, {ReuseAddr})
server.start()
result = await swarmWorker12(ta)
server.stop()
server.close()
await server.join()
proc serveClient13(server: StreamServer, transp: StreamTransport) {.async.} =
transp.close()
await transp.join()
proc swarmWorker13(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var line = await transp.readLine()
if line == "":
result = 1
else:
result = 0
transp.close()
await transp.join()
proc test13(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31354")
var server = createStreamServer(ta, serveClient13, {ReuseAddr})
server.start()
result = await swarmWorker13(ta)
server.stop()
server.close()
await server.join()
proc serveClient14(server: StreamServer, transp: StreamTransport) {.async.} =
discard
proc test14(): Future[int] {.async.} =
var subres = 0
var ta = initTAddress("127.0.0.1:31354")
var server = createStreamServer(ta, serveClient13, {ReuseAddr})
proc swarmWorker(transp: StreamTransport): Future[void] {.async.} =
var line = await transp.readLine()
if line == "":
subres = 1
else:
subres = 0
server.start()
var transp = await connect(ta)
var fut = swarmWorker(transp)
transp.close()
await fut
server.stop()
server.close()
await server.join()
result = subres
when isMainModule:
const
@ -466,12 +647,29 @@ when isMainModule:
$ClientsCount & " clients x " & $MessagesCount & " messages)"
m7 = "readLine() buffer overflow test"
m8 = "readUntil() buffer overflow test"
m9 = "readLine() unexpected disconnect test"
m10 = "read() unexpected disconnect test"
m11 = "readExactly() unexpected disconnect test"
m12 = "readUntil() unexpected disconnect test"
m13 = "readLine() unexpected disconnect empty string test"
m14 = "Closing socket while operation pending test (issue #8)"
suite "Stream Transport test suite":
test m8:
check waitFor(test8()) == 1
test m7:
check waitFor(test7()) == 1
test m9:
check waitFor(test9()) == 1
test m10:
check waitFor(test10()) == 1
test m11:
check waitFor(test11()) == 1
test m12:
check waitFor(test12()) == 1
test m13:
check waitFor(test13()) == 1
test m14:
check waitFor(test14()) == 1
test m1:
check waitFor(test1()) == ClientsCount * MessagesCount
test m2: