* Address issue #219 and add tests for it.
Some cosmetic refactoring.

* Fix *nix tests.
This commit is contained in:
Eugene Kabanov 2021-09-05 00:53:27 +03:00 committed by GitHub
parent 05c91418be
commit 5034f0a5a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 46 deletions

View File

@ -180,7 +180,7 @@ template setReadError(t, e: untyped) =
(t).error = getTransportOsError(e) (t).error = getTransportOsError(e)
template checkPending(t: untyped) = template checkPending(t: untyped) =
if not isNil((t).reader): if not(isNil((t).reader)):
raise newException(TransportError, "Read operation already pending!") raise newException(TransportError, "Read operation already pending!")
template shiftBuffer(t, c: untyped) = template shiftBuffer(t, c: untyped) =
@ -282,7 +282,7 @@ proc clean(server: StreamServer) {.inline.} =
if not(server.loopFuture.finished()): if not(server.loopFuture.finished()):
untrackServer(server) untrackServer(server)
server.loopFuture.complete() server.loopFuture.complete()
if not isNil(server.udata) and GCUserData in server.flags: if not(isNil(server.udata)) and (GCUserData in server.flags):
GC_unref(cast[ref int](server.udata)) GC_unref(cast[ref int](server.udata))
GC_unref(server) GC_unref(server)
@ -634,7 +634,7 @@ when defined(windows):
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport = child: StreamTransport): StreamTransport =
var transp: StreamTransport var transp: StreamTransport
if not isNil(child): if not(isNil(child)):
transp = child transp = child
else: else:
transp = StreamTransport(kind: TransportKind.Socket) transp = StreamTransport(kind: TransportKind.Socket)
@ -654,7 +654,7 @@ when defined(windows):
child: StreamTransport, child: StreamTransport,
flags: set[TransportFlags] = {}): StreamTransport = flags: set[TransportFlags] = {}): StreamTransport =
var transp: StreamTransport var transp: StreamTransport
if not isNil(child): if not(isNil(child)):
transp = child transp = child
else: else:
transp = StreamTransport(kind: TransportKind.Pipe) transp = StreamTransport(kind: TransportKind.Pipe)
@ -718,7 +718,7 @@ when defined(windows):
retFuture.fail(getTransportOsError(osLastError())) retFuture.fail(getTransportOsError(osLastError()))
return retFuture return retFuture
if not bindToDomain(sock, raddress.getDomain()): if not(bindToDomain(sock, raddress.getDomain())):
let err = wsaGetLastError() let err = wsaGetLastError()
sock.closeSocket() sock.closeSocket()
retFuture.fail(getTransportOsError(err)) retFuture.fail(getTransportOsError(err))
@ -751,12 +751,12 @@ when defined(windows):
povl = RefCustomOverlapped() povl = RefCustomOverlapped()
GC_ref(povl) GC_ref(povl)
povl.data = CompletionData(fd: sock, cb: socketContinuation) povl.data = CompletionData(fd: sock, cb: socketContinuation)
var res = loop.connectEx(SocketHandle(sock), let res = loop.connectEx(SocketHandle(sock),
cast[ptr SockAddr](addr saddr), cast[ptr SockAddr](addr saddr),
DWORD(slen), nil, 0, nil, DWORD(slen), nil, 0, nil,
cast[POVERLAPPED](povl)) cast[POVERLAPPED](povl))
# We will not process immediate completion, to avoid undefined behavior. # We will not process immediate completion, to avoid undefined behavior.
if not res: if not(res):
let err = osLastError() let err = osLastError()
if int32(err) != ERROR_IO_PENDING: if int32(err) != ERROR_IO_PENDING:
GC_unref(povl) GC_unref(povl)
@ -839,7 +839,7 @@ when defined(windows):
var flags = {WinServerPipe} var flags = {WinServerPipe}
if NoPipeFlash in server.flags: if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash) flags.incl(WinNoPipeFlash)
if not isNil(server.init): if not(isNil(server.init)):
var transp = server.init(server, server.sock) var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize, ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags) transp, flags)
@ -928,7 +928,7 @@ when defined(windows):
raiseAssert osErrorMsg(err) raiseAssert osErrorMsg(err)
else: else:
var ntransp: StreamTransport var ntransp: StreamTransport
if not isNil(server.init): if not(isNil(server.init)):
let transp = server.init(server, server.asock) let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock, ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, server.bufferSize,
@ -982,7 +982,7 @@ when defined(windows):
dwReceiveDataLength, dwLocalAddressLength, dwReceiveDataLength, dwLocalAddressLength,
dwRemoteAddressLength, addr dwBytesReceived, dwRemoteAddressLength, addr dwBytesReceived,
cast[POVERLAPPED](addr server.aovl)) cast[POVERLAPPED](addr server.aovl))
if not res: if not(res):
let err = osLastError() let err = osLastError()
if int32(err) == ERROR_OPERATION_ABORTED: if int32(err) == ERROR_OPERATION_ABORTED:
server.apending = false server.apending = false
@ -1014,7 +1014,7 @@ when defined(windows):
discard cancelIO(Handle(server.sock)) discard cancelIO(Handle(server.sock))
proc resumeAccept(server: StreamServer) {.inline.} = proc resumeAccept(server: StreamServer) {.inline.} =
if not server.apending: if not(server.apending):
server.aovl.data.cb(addr server.aovl) server.aovl.data.cb(addr server.aovl)
proc accept*(server: StreamServer): Future[StreamTransport] = proc accept*(server: StreamServer): Future[StreamTransport] =
@ -1052,7 +1052,7 @@ when defined(windows):
retFuture.fail(getTransportOsError(err)) retFuture.fail(getTransportOsError(err))
else: else:
var ntransp: StreamTransport var ntransp: StreamTransport
if not isNil(server.init): if not(isNil(server.init)):
let transp = server.init(server, server.asock) let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock, ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, server.bufferSize,
@ -1073,6 +1073,8 @@ when defined(windows):
retFuture.fail(getTransportOsError(ovl.data.errCode)) retFuture.fail(getTransportOsError(ovl.data.errCode))
proc cancellationSocket(udata: pointer) {.gcsafe.} = proc cancellationSocket(udata: pointer) {.gcsafe.} =
if server.apending:
server.apending = false
server.asock.closeSocket() server.asock.closeSocket()
proc continuationPipe(udata: pointer) {.gcsafe.} = proc continuationPipe(udata: pointer) {.gcsafe.} =
@ -1091,7 +1093,7 @@ when defined(windows):
var flags = {WinServerPipe} var flags = {WinServerPipe}
if NoPipeFlash in server.flags: if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash) flags.incl(WinNoPipeFlash)
if not isNil(server.init): if not(isNil(server.init)):
var transp = server.init(server, server.sock) var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize, ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags) transp, flags)
@ -1126,6 +1128,8 @@ when defined(windows):
retFuture.fail(getTransportOsError(ovl.data.errCode)) retFuture.fail(getTransportOsError(ovl.data.errCode))
proc cancellationPipe(udata: pointer) {.gcsafe.} = proc cancellationPipe(udata: pointer) {.gcsafe.} =
if server.apending:
server.apending = false
server.sock.closeHandle() server.sock.closeHandle()
if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
@ -1160,7 +1164,7 @@ when defined(windows):
dwReceiveDataLength, dwLocalAddressLength, dwReceiveDataLength, dwLocalAddressLength,
dwRemoteAddressLength, addr dwBytesReceived, dwRemoteAddressLength, addr dwBytesReceived,
cast[POVERLAPPED](addr server.aovl)) cast[POVERLAPPED](addr server.aovl))
if not res: if not(res):
let err = osLastError() let err = osLastError()
if int32(err) == ERROR_OPERATION_ABORTED: if int32(err) == ERROR_OPERATION_ABORTED:
server.apending = false server.apending = false
@ -1494,7 +1498,7 @@ else:
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport = child: StreamTransport): StreamTransport =
var transp: StreamTransport var transp: StreamTransport
if not isNil(child): if not(isNil(child)):
transp = child transp = child
else: else:
transp = StreamTransport(kind: TransportKind.Socket) transp = StreamTransport(kind: TransportKind.Socket)
@ -1510,7 +1514,7 @@ else:
proc newStreamPipeTransport(fd: AsyncFD, bufsize: int, proc newStreamPipeTransport(fd: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport = child: StreamTransport): StreamTransport =
var transp: StreamTransport var transp: StreamTransport
if not isNil(child): if not(isNil(child)):
transp = child transp = child
else: else:
transp = StreamTransport(kind: TransportKind.Pipe) transp = StreamTransport(kind: TransportKind.Pipe)
@ -1569,7 +1573,7 @@ else:
retFuture.fail(exc) retFuture.fail(exc)
return return
if not fd.getSocketError(err): if not(fd.getSocketError(err)):
closeSocket(fd) closeSocket(fd)
retFuture.fail(getTransportOsError(osLastError())) retFuture.fail(getTransportOsError(osLastError()))
return return
@ -1636,7 +1640,7 @@ else:
raiseAsDefect exc, "wrapAsyncSocket" raiseAsDefect exc, "wrapAsyncSocket"
if sock != asyncInvalidSocket: if sock != asyncInvalidSocket:
var ntransp: StreamTransport var ntransp: StreamTransport
if not isNil(server.init): if not(isNil(server.init)):
let transp = server.init(server, sock) let transp = server.init(server, sock)
ntransp = newStreamSocketTransport(sock, server.bufferSize, transp) ntransp = newStreamSocketTransport(sock, server.bufferSize, transp)
else: else:
@ -1719,7 +1723,7 @@ else:
if sock != asyncInvalidSocket: if sock != asyncInvalidSocket:
var ntransp: StreamTransport var ntransp: StreamTransport
if not isNil(server.init): if not(isNil(server.init)):
let transp = server.init(server, sock) let transp = server.init(server, sock)
ntransp = newStreamSocketTransport(sock, server.bufferSize, ntransp = newStreamSocketTransport(sock, server.bufferSize,
transp) transp)
@ -1773,7 +1777,8 @@ else:
proc start*(server: StreamServer) {. proc start*(server: StreamServer) {.
raises: [Defect, IOSelectorsException, ValueError].} = raises: [Defect, IOSelectorsException, ValueError].} =
## Starts ``server``. ## Starts ``server``.
doAssert(not(isNil(server.function))) doAssert(not(isNil(server.function)),
"You should not start server, if you have not set processing callback!")
if server.status == ServerStatus.Starting: if server.status == ServerStatus.Starting:
server.resumeAccept() server.resumeAccept()
server.status = ServerStatus.Running server.status = ServerStatus.Running
@ -1782,6 +1787,7 @@ proc stop*(server: StreamServer) {.
raises: [Defect, IOSelectorsException, ValueError].} = raises: [Defect, IOSelectorsException, ValueError].} =
## Stops ``server``. ## Stops ``server``.
if server.status == ServerStatus.Running: if server.status == ServerStatus.Running:
if not(isNil(server.function)):
server.pauseAccept() server.pauseAccept()
server.status = ServerStatus.Stopped server.status = ServerStatus.Stopped
elif server.status == ServerStatus.Starting: elif server.status == ServerStatus.Starting:
@ -1814,27 +1820,19 @@ proc close*(server: StreamServer) =
if not(server.loopFuture.finished()): if not(server.loopFuture.finished()):
server.clean() server.clean()
let r1 = (server.status == ServerStatus.Stopped) and if server.status in {ServerStatus.Starting, ServerStatus.Stopped}:
not(isNil(server.function))
let r2 = (server.status == ServerStatus.Starting) and isNil(server.function)
if r1 or r2:
server.status = ServerStatus.Closed server.status = ServerStatus.Closed
when defined(windows): when defined(windows):
if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if not server.apending: if server.apending:
server.sock.closeSocket(continuation)
else:
server.asock.closeSocket() server.asock.closeSocket()
server.sock.closeSocket() server.apending = false
server.sock.closeSocket(continuation)
elif server.local.family in {AddressFamily.Unix}: elif server.local.family in {AddressFamily.Unix}:
if NoPipeFlash notin server.flags: if NoPipeFlash notin server.flags:
discard flushFileBuffers(Handle(server.sock)) discard flushFileBuffers(Handle(server.sock))
discard disconnectNamedPipe(Handle(server.sock)) discard disconnectNamedPipe(Handle(server.sock))
if not server.apending:
server.sock.closeHandle(continuation) server.sock.closeHandle(continuation)
else:
server.sock.closeHandle()
else: else:
server.sock.closeSocket(continuation) server.sock.closeSocket(continuation)
@ -1883,21 +1881,21 @@ proc createStreamServer*(host: TransportAddress,
if serverSocket == asyncInvalidSocket: if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError()) raiseTransportOsError(osLastError())
else: else:
if not setSocketBlocking(SocketHandle(sock), false): if not(setSocketBlocking(SocketHandle(sock), false)):
raiseTransportOsError(osLastError()) raiseTransportOsError(osLastError())
register(sock) register(sock)
serverSocket = sock serverSocket = sock
# SO_REUSEADDR is not useful for Unix domain sockets. # SO_REUSEADDR is not useful for Unix domain sockets.
if ServerFlags.ReuseAddr in flags: if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
serverSocket.closeSocket() serverSocket.closeSocket()
raiseTransportOsError(err) raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets. # TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags: if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP, if not(setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1): handles.TCP_NODELAY, 1)):
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
serverSocket.closeSocket() serverSocket.closeSocket()
@ -1940,7 +1938,7 @@ proc createStreamServer*(host: TransportAddress,
if serverSocket == asyncInvalidSocket: if serverSocket == asyncInvalidSocket:
raiseTransportOsError(osLastError()) raiseTransportOsError(osLastError())
else: else:
if not setSocketBlocking(SocketHandle(sock), false): if not(setSocketBlocking(SocketHandle(sock), false)):
raiseTransportOsError(osLastError()) raiseTransportOsError(osLastError())
register(sock) register(sock)
@ -1949,21 +1947,21 @@ proc createStreamServer*(host: TransportAddress,
if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
# SO_REUSEADDR and SO_REUSEPORT are not useful for Unix domain sockets. # SO_REUSEADDR and SO_REUSEPORT are not useful for Unix domain sockets.
if ServerFlags.ReuseAddr in flags: if ServerFlags.ReuseAddr in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
serverSocket.closeSocket() serverSocket.closeSocket()
raiseTransportOsError(err) raiseTransportOsError(err)
if ServerFlags.ReusePort in flags: if ServerFlags.ReusePort in flags:
if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEPORT, 1): if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
serverSocket.closeSocket() serverSocket.closeSocket()
raiseTransportOsError(err) raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets. # TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags: if ServerFlags.TcpNoDelay in flags:
if not setSockOpt(serverSocket, handles.IPPROTO_TCP, if not(setSockOpt(serverSocket, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1): handles.TCP_NODELAY, 1)):
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
serverSocket.closeSocket() serverSocket.closeSocket()
@ -1997,7 +1995,7 @@ proc createStreamServer*(host: TransportAddress,
serverSocket.closeSocket() serverSocket.closeSocket()
raiseTransportOsError(err) raiseTransportOsError(err)
if not isNil(child): if not(isNil(child)):
result = child result = child
else: else:
result = StreamServer() result = StreamServer()
@ -2099,7 +2097,7 @@ proc write*(transp: StreamTransport, msg: string, msglen = -1): Future[int] =
var retFuture = newFutureStr[int]("stream.transport.write(string)") var retFuture = newFutureStr[int]("stream.transport.write(string)")
transp.checkClosed(retFuture) transp.checkClosed(retFuture)
transp.checkWriteEof(retFuture) transp.checkWriteEof(retFuture)
if not isLiteral(msg): if not(isLiteral(msg)):
shallowCopy(retFuture.gcholder, msg) shallowCopy(retFuture.gcholder, msg)
else: else:
retFuture.gcholder = msg retFuture.gcholder = msg
@ -2117,7 +2115,7 @@ proc write*[T](transp: StreamTransport, msg: seq[T], msglen = -1): Future[int] =
var retFuture = newFutureSeq[int, T]("stream.transport.write(seq)") var retFuture = newFutureSeq[int, T]("stream.transport.write(seq)")
transp.checkClosed(retFuture) transp.checkClosed(retFuture)
transp.checkWriteEof(retFuture) transp.checkWriteEof(retFuture)
if not isLiteral(msg): if not(isLiteral(msg)):
shallowCopy(retFuture.gcholder, msg) shallowCopy(retFuture.gcholder, msg)
else: else:
retFuture.gcholder = msg retFuture.gcholder = msg

View File

@ -1201,6 +1201,44 @@ suite "Stream Transport test suite":
await acceptFut await acceptFut
return res return res
proc testAcceptRace(address: TransportAddress): Future[bool] {.async.} =
proc test1(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
server.close()
await allFutures(acceptFut.cancelAndWait(), server.join())
proc test2(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
await acceptFut.cancelAndWait()
server.close()
await server.join()
proc test3(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
server.stop()
server.close()
await allFutures(acceptFut.cancelAndWait(), server.join())
proc test4(address: TransportAddress) {.async.} =
let server = createStreamServer(address, flags = {ReuseAddr})
let acceptFut = server.accept()
await acceptFut.cancelAndWait()
server.stop()
server.close()
await server.join()
try:
await test1(address).wait(5.seconds)
await test2(address).wait(5.seconds)
await test3(address).wait(5.seconds)
await test4(address).wait(5.seconds)
return true
except AsyncTimeoutError:
return false
markFD = getCurrentFD() markFD = getCurrentFD()
for i in 0..<len(addresses): for i in 0..<len(addresses):
@ -1275,6 +1313,8 @@ suite "Stream Transport test suite":
skip() skip()
else: else:
check waitFor(testAcceptTooMany(addresses[i])) == true check waitFor(testAcceptTooMany(addresses[i])) == true
test prefixes[i] & "accept() and close() race test":
check waitFor(testAcceptRace(addresses[i])) == true
test prefixes[i] & "write() queue notification on close() test": test prefixes[i] & "write() queue notification on close() test":
check waitFor(testWriteOnClose(addresses[i])) == true check waitFor(testWriteOnClose(addresses[i])) == true
test prefixes[i] & "read() notification on close() test": test prefixes[i] & "read() notification on close() test":