* Address issue #219 and add tests for it.
Some cosmetic refactoring.

* Fix *nix tests.
This commit is contained in:
Eugene Kabanov 2021-09-05 00:53:27 +03:00 committed by GitHub
parent 05c91418be
commit 5034f0a5a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 46 deletions

View File

@ -180,7 +180,7 @@ template setReadError(t, e: untyped) =
(t).error = getTransportOsError(e)
template checkPending(t: untyped) =
if not isNil((t).reader):
if not(isNil((t).reader)):
raise newException(TransportError, "Read operation already pending!")
template shiftBuffer(t, c: untyped) =
@ -282,7 +282,7 @@ proc clean(server: StreamServer) {.inline.} =
if not(server.loopFuture.finished()):
untrackServer(server)
server.loopFuture.complete()
if not isNil(server.udata) and GCUserData in server.flags:
if not(isNil(server.udata)) and (GCUserData in server.flags):
GC_unref(cast[ref int](server.udata))
GC_unref(server)
@ -634,7 +634,7 @@ when defined(windows):
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
if not(isNil(child)):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Socket)
@ -654,7 +654,7 @@ when defined(windows):
child: StreamTransport,
flags: set[TransportFlags] = {}): StreamTransport =
var transp: StreamTransport
if not isNil(child):
if not(isNil(child)):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Pipe)
@ -718,7 +718,7 @@ when defined(windows):
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
if not bindToDomain(sock, raddress.getDomain()):
if not(bindToDomain(sock, raddress.getDomain())):
let err = wsaGetLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
@ -751,12 +751,12 @@ when defined(windows):
povl = RefCustomOverlapped()
GC_ref(povl)
povl.data = CompletionData(fd: sock, cb: socketContinuation)
var res = loop.connectEx(SocketHandle(sock),
let res = loop.connectEx(SocketHandle(sock),
cast[ptr SockAddr](addr saddr),
DWORD(slen), nil, 0, nil,
cast[POVERLAPPED](povl))
# We will not process immediate completion, to avoid undefined behavior.
if not res:
if not(res):
let err = osLastError()
if int32(err) != ERROR_IO_PENDING:
GC_unref(povl)
@ -839,7 +839,7 @@ when defined(windows):
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
if not(isNil(server.init)):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
@ -928,7 +928,7 @@ when defined(windows):
raiseAssert osErrorMsg(err)
else:
var ntransp: StreamTransport
if not isNil(server.init):
if not(isNil(server.init)):
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
@ -982,7 +982,7 @@ when defined(windows):
dwReceiveDataLength, dwLocalAddressLength,
dwRemoteAddressLength, addr dwBytesReceived,
cast[POVERLAPPED](addr server.aovl))
if not res:
if not(res):
let err = osLastError()
if int32(err) == ERROR_OPERATION_ABORTED:
server.apending = false
@ -1014,7 +1014,7 @@ when defined(windows):
discard cancelIO(Handle(server.sock))
proc resumeAccept(server: StreamServer) {.inline.} =
if not server.apending:
if not(server.apending):
server.aovl.data.cb(addr server.aovl)
proc accept*(server: StreamServer): Future[StreamTransport] =
@ -1052,7 +1052,7 @@ when defined(windows):
retFuture.fail(getTransportOsError(err))
else:
var ntransp: StreamTransport
if not isNil(server.init):
if not(isNil(server.init)):
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
@ -1073,6 +1073,8 @@ when defined(windows):
retFuture.fail(getTransportOsError(ovl.data.errCode))
proc cancellationSocket(udata: pointer) {.gcsafe.} =
if server.apending:
server.apending = false
server.asock.closeSocket()
proc continuationPipe(udata: pointer) {.gcsafe.} =
@ -1091,7 +1093,7 @@ when defined(windows):
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
if not(isNil(server.init)):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
@ -1126,6 +1128,8 @@ when defined(windows):
retFuture.fail(getTransportOsError(ovl.data.errCode))
proc cancellationPipe(udata: pointer) {.gcsafe.} =
if server.apending:
server.apending = false
server.sock.closeHandle()
if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
@ -1160,7 +1164,7 @@ when defined(windows):
dwReceiveDataLength, dwLocalAddressLength,
dwRemoteAddressLength, addr dwBytesReceived,
cast[POVERLAPPED](addr server.aovl))
if not res:
if not(res):
let err = osLastError()
if int32(err) == ERROR_OPERATION_ABORTED:
server.apending = false
@ -1494,7 +1498,7 @@ else:
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
if not(isNil(child)):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Socket)
@ -1510,7 +1514,7 @@ else:
proc newStreamPipeTransport(fd: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
if not(isNil(child)):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Pipe)
@ -1569,7 +1573,7 @@ else:
retFuture.fail(exc)
return
if not fd.getSocketError(err):
if not(fd.getSocketError(err)):
closeSocket(fd)
retFuture.fail(getTransportOsError(osLastError()))
return
@ -1636,7 +1640,7 @@ else:
raiseAsDefect exc, "wrapAsyncSocket"
if sock != asyncInvalidSocket:
var ntransp: StreamTransport
if not isNil(server.init):
if not(isNil(server.init)):
let transp = server.init(server, sock)
ntransp = newStreamSocketTransport(sock, server.bufferSize, transp)
else:
@ -1719,7 +1723,7 @@ else:
if sock != asyncInvalidSocket:
var ntransp: StreamTransport
if not isNil(server.init):
if not(isNil(server.init)):
let transp = server.init(server, sock)
ntransp = newStreamSocketTransport(sock, server.bufferSize,
transp)
@ -1773,7 +1777,8 @@ else:
proc start*(server: StreamServer) {.
raises: [Defect, IOSelectorsException, ValueError].} =
## Starts ``server``.
doAssert(not(isNil(server.function)))
doAssert(not(isNil(server.function)),
"You should not start server, if you have not set processing callback!")
if server.status == ServerStatus.Starting:
server.resumeAccept()
server.status = ServerStatus.Running
@ -1782,7 +1787,8 @@ proc stop*(server: StreamServer) {.
raises: [Defect, IOSelectorsException, ValueError].} =
## Stops ``server``.
if server.status == ServerStatus.Running:
server.pauseAccept()
if not(isNil(server.function)):
server.pauseAccept()
server.status = ServerStatus.Stopped
elif server.status == ServerStatus.Starting:
server.status = ServerStatus.Stopped
@ -1814,27 +1820,19 @@ proc close*(server: StreamServer) =
if not(server.loopFuture.finished()):
server.clean()
let r1 = (server.status == ServerStatus.Stopped) and
not(isNil(server.function))
let r2 = (server.status == ServerStatus.Starting) and isNil(server.function)
if r1 or r2:
if server.status in {ServerStatus.Starting, ServerStatus.Stopped}:
server.status = ServerStatus.Closed
when defined(windows):
if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if not server.apending:
server.sock.closeSocket(continuation)
else:
if server.apending:
server.asock.closeSocket()
server.sock.closeSocket()
server.apending = false
server.sock.closeSocket(continuation)
elif server.local.family in {AddressFamily.Unix}:
if NoPipeFlash notin server.flags:
discard flushFileBuffers(Handle(server.sock))
discard disconnectNamedPipe(Handle(server.sock))
if not server.apending:
server.sock.closeHandle(continuation)
else:
server.sock.closeHandle()
server.sock.closeHandle(continuation)
else:
server.sock.closeSocket(continuation)
@ -1883,21 +1881,21 @@ proc createStreamServer*(host: TransportAddress,
if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
if not(setSocketBlocking(SocketHandle(sock), false)):
raiseTransportOsError(osLastError())
register(sock)
serverSocket = sock
# SO_REUSEADDR is not useful for Unix domain sockets.
if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1):
if not(setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
@ -1940,7 +1938,7 @@ proc createStreamServer*(host: TransportAddress,
if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
if not(setSocketBlocking(SocketHandle(sock), false)):
raiseTransportOsError(osLastError())
register(sock)
@ -1949,21 +1947,21 @@ proc createStreamServer*(host: TransportAddress,
if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
# SO_REUSEADDR and SO_REUSEPORT are not useful for Unix domain sockets.
if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1):
if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
if ServerFlags.ReusePort in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEPORT, 1):
if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1):
if not(setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
serverSocket.closeSocket()
@ -1997,7 +1995,7 @@ proc createStreamServer*(host: TransportAddress,
serverSocket.closeSocket()
raiseTransportOsError(err)
if not isNil(child):
if not(isNil(child)):
result = child
else:
result = StreamServer()
@ -2099,7 +2097,7 @@ proc write*(transp: StreamTransport, msg: string, msglen = -1): Future[int] =
var retFuture = newFutureStr[int]("stream.transport.write(string)")
transp.checkClosed(retFuture)
transp.checkWriteEof(retFuture)
if not isLiteral(msg):
if not(isLiteral(msg)):
shallowCopy(retFuture.gcholder, msg)
else:
retFuture.gcholder = msg
@ -2117,7 +2115,7 @@ proc write*[T](transp: StreamTransport, msg: seq[T], msglen = -1): Future[int] =
var retFuture = newFutureSeq[int, T]("stream.transport.write(seq)")
transp.checkClosed(retFuture)
transp.checkWriteEof(retFuture)
if not isLiteral(msg):
if not(isLiteral(msg)):
shallowCopy(retFuture.gcholder, msg)
else:
retFuture.gcholder = msg

View File

@ -1201,6 +1201,44 @@ suite "Stream Transport test suite":
await acceptFut
return res
proc testAcceptRace(address: TransportAddress): Future[bool] {.async.} =
proc test1(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
server.close()
await allFutures(acceptFut.cancelAndWait(), server.join())
proc test2(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
await acceptFut.cancelAndWait()
server.close()
await server.join()
proc test3(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
server.stop()
server.close()
await allFutures(acceptFut.cancelAndWait(), server.join())
proc test4(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
await acceptFut.cancelAndWait()
server.stop()
server.close()
await server.join()
try:
await test1(address).wait(5.seconds)
await test2(address).wait(5.seconds)
await test3(address).wait(5.seconds)
await test4(address).wait(5.seconds)
return true
except AsyncTimeoutError:
return false
markFD = getCurrentFD()
for i in 0..<len(addresses):
@ -1275,6 +1313,8 @@ suite "Stream Transport test suite":
skip()
else:
check waitFor(testAcceptTooMany(addresses[i])) == true
test prefixes[i] & "accept() and close() race test":
check waitFor(testAcceptRace(addresses[i])) == true
test prefixes[i] & "write() queue notification on close() test":
check waitFor(testWriteOnClose(addresses[i])) == true
test prefixes[i] & "read() notification on close() test":