Fix accept races. (#110)

* Fix accept races.
This commit is contained in:
Eugene Kabanov 2020-07-15 11:09:34 +03:00 committed by GitHub
parent 783f84aa4b
commit 31fec25063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 170 additions and 135 deletions

View File

@ -809,30 +809,38 @@ when defined(windows):
if server.apending:
## Continuation
server.apending = false
if ovl.data.errCode == OSErrorCode(-1):
var ntransp: StreamTransport
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
if server.status notin {ServerStatus.Stopped, ServerStatus.Closed}:
if ovl.data.errCode == OSErrorCode(-1):
var ntransp: StreamTransport
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
else:
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
nil, flags)
# Start tracking transport
trackStream(ntransp)
asyncCheck server.function(server, ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close call.
if server.status in {ServerStatus.Closed, ServerStatus.Stopped}:
server.clean()
break
else:
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
nil, flags)
# Start tracking transport
trackStream(ntransp)
asyncCheck server.function(server, ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close call.
if server.status in {ServerStatus.Closed, ServerStatus.Stopped}:
# We should not raise defects in this loop.
discard disconnectNamedPipe(Handle(server.sock))
discard closeHandle(HANDLE(server.sock))
raiseTransportOsError(osLastError())
else:
# Server close happens in callback, and we are not started new
# connectNamedPipe session.
if not(server.loopFuture.finished()):
server.clean()
break
else:
doAssert disconnectNamedPipe(Handle(server.sock)) == 1
doAssert closeHandle(HANDLE(server.sock)) == 1
raiseTransportOsError(osLastError())
else:
## Initiation
if server.status notin {ServerStatus.Stopped, ServerStatus.Closed}:
@ -871,9 +879,9 @@ when defined(windows):
else:
# Server close happens in callback, and we are not started new
# connectNamedPipe session.
if server.status in {ServerStatus.Closed, ServerStatus.Stopped}:
if not(server.loopFuture.finished()):
server.clean()
if not(server.loopFuture.finished()):
server.clean()
break
proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} =
var ovl = cast[PtrCustomOverlapped](udata)
@ -884,38 +892,45 @@ when defined(windows):
if server.apending:
## Continuation
server.apending = false
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET),
cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
let err = OSErrorCode(wsaGetLastError())
server.asock.closeSocket()
raiseTransportOsError(err)
else:
var ntransp: StreamTransport
if not isNil(server.init):
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
transp)
if server.status notin {ServerStatus.Stopped, ServerStatus.Closed}:
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET),
cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
let err = OSErrorCode(wsaGetLastError())
server.asock.closeSocket()
raiseTransportOsError(err)
else:
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
asyncCheck server.function(server, ntransp)
var ntransp: StreamTransport
if not isNil(server.init):
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
transp)
else:
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
asyncCheck server.function(server, ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close.
server.asock.closeSocket()
if server.status in {ServerStatus.Closed, ServerStatus.Stopped}:
# Stop tracking server
if not(server.loopFuture.finished()):
server.clean()
break
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close.
server.asock.closeSocket()
if server.status in {ServerStatus.Closed, ServerStatus.Stopped}:
# Stop tracking server
if not(server.loopFuture.finished()):
server.clean()
break
else:
server.asock.closeSocket()
raiseTransportOsError(ovl.data.errCode)
else:
server.asock.closeSocket()
raiseTransportOsError(ovl.data.errCode)
# Server close happens in callback, and we are not started new
# AcceptEx session.
if not(server.loopFuture.finished()):
server.clean()
break
else:
## Initiation
if server.status notin {ServerStatus.Stopped, ServerStatus.Closed}:
@ -949,9 +964,9 @@ when defined(windows):
else:
# Server close happens in callback, and we are not started new
# AcceptEx session.
if server.status in {ServerStatus.Closed, ServerStatus.Stopped}:
if not(server.loopFuture.finished()):
server.clean()
if not(server.loopFuture.finished()):
server.clean()
break
proc resumeRead(transp: StreamTransport) {.inline.} =
if ReadPaused in transp.state:
@ -986,34 +1001,44 @@ when defined(windows):
var server = cast[StreamServer](ovl.data.udata)
server.apending = false
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET),
cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
let err = OSErrorCode(wsaGetLastError())
server.asock.closeSocket()
retFuture.fail(getTransportOsError(err))
else:
var ntransp: StreamTransport
if not isNil(server.init):
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
transp)
else:
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
retFuture.complete(ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close.
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
server.asock.closeSocket()
retFuture.fail(getServerUseClosedError())
server.clean()
else:
server.asock.closeSocket()
retFuture.fail(getTransportOsError(ovl.data.errCode))
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET),
cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
let err = OSErrorCode(wsaGetLastError())
server.asock.closeSocket()
if int32(err) == WSAENOTSOCK:
# This can be happened when server get closed, but continuation was
# already scheduled, so we failing it not with OS error.
retFuture.fail(getServerUseClosedError())
else:
retFuture.fail(getTransportOsError(err))
else:
var ntransp: StreamTransport
if not isNil(server.init):
let transp = server.init(server, server.asock)
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize,
transp)
else:
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
retFuture.complete(ntransp)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close.
server.asock.closeSocket()
retFuture.fail(getServerUseClosedError())
server.clean()
else:
server.asock.closeSocket()
retFuture.fail(getTransportOsError(ovl.data.errCode))
proc cancellationSocket(udata: pointer) {.gcsafe.} =
server.asock.closeSocket()
@ -1023,33 +1048,37 @@ when defined(windows):
var server = cast[StreamServer](ovl.data.udata)
server.apending = false
if ovl.data.errCode == OSErrorCode(-1):
var ntransp: StreamTransport
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
else:
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
nil, flags)
# Start tracking transport
trackStream(ntransp)
server.createAcceptPipe()
retFuture.complete(ntransp)
elif int32(ovl.data.errCode) in {ERROR_OPERATION_ABORTED,
ERROR_PIPE_NOT_CONNECTED}:
# CancelIO() interrupt or close call.
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
retFuture.fail(getServerUseClosedError())
server.clean()
else:
let sock = server.sock
server.createAcceptPipe()
closeHandle(sock)
retFuture.fail(getTransportOsError(ovl.data.errCode))
if ovl.data.errCode == OSErrorCode(-1):
var ntransp: StreamTransport
var flags = {WinServerPipe}
if NoPipeFlash in server.flags:
flags.incl(WinNoPipeFlash)
if not isNil(server.init):
var transp = server.init(server, server.sock)
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
transp, flags)
else:
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
nil, flags)
# Start tracking transport
trackStream(ntransp)
server.createAcceptPipe()
retFuture.complete(ntransp)
elif int32(ovl.data.errCode) in {ERROR_OPERATION_ABORTED,
ERROR_PIPE_NOT_CONNECTED}:
# CancelIO() interrupt or close call.
retFuture.fail(getServerUseClosedError())
server.clean()
else:
let sock = server.sock
server.createAcceptPipe()
closeHandle(sock)
retFuture.fail(getTransportOsError(ovl.data.errCode))
proc cancellationPipe(udata: pointer) {.gcsafe.} =
server.sock.closeHandle()
@ -1458,6 +1487,9 @@ else:
slen: SockLen
var server = cast[StreamServer](cast[ptr CompletionData](udata).udata)
while true:
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
break
let res = posix.accept(SocketHandle(server.sock),
cast[ptr SockAddr](addr saddr), addr slen)
if int(res) > 0:
@ -1514,38 +1546,41 @@ else:
var
saddr: Sockaddr_storage
slen: SockLen
while true:
let res = posix.accept(SocketHandle(server.sock),
cast[ptr SockAddr](addr saddr), addr slen)
if int(res) > 0:
let sock = wrapAsyncSocket(res)
if sock != asyncInvalidSocket:
var ntransp: StreamTransport
if not isNil(server.init):
let transp = server.init(server, sock)
ntransp = newStreamSocketTransport(sock, server.bufferSize,
transp)
else:
ntransp = newStreamSocketTransport(sock, server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
retFuture.complete(ntransp)
else:
retFuture.fail(getTransportOsError(osLastError()))
else:
let err = osLastError()
if int(err) == EINTR:
continue
elif int(err) == EAGAIN:
# This error appears only when server get closed, while accept()
# call pending.
retFuture.fail(getServerUseClosedError())
elif int(err) == EMFILE:
retFuture.fail(getTransportTooManyError())
else:
retFuture.fail(getTransportOsError(err))
break
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
retFuture.fail(getServerUseClosedError())
else:
while true:
let res = posix.accept(SocketHandle(server.sock),
cast[ptr SockAddr](addr saddr), addr slen)
if int(res) > 0:
let sock = wrapAsyncSocket(res)
if sock != asyncInvalidSocket:
var ntransp: StreamTransport
if not isNil(server.init):
let transp = server.init(server, sock)
ntransp = newStreamSocketTransport(sock, server.bufferSize,
transp)
else:
ntransp = newStreamSocketTransport(sock, server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
retFuture.complete(ntransp)
else:
retFuture.fail(getTransportOsError(osLastError()))
else:
let err = osLastError()
if int(err) == EINTR:
continue
elif int(err) == EAGAIN:
# This error appears only when server get closed, while accept()
# continuation is already scheduled.
retFuture.fail(getServerUseClosedError())
elif int(err) == EMFILE:
retFuture.fail(getTransportTooManyError())
else:
retFuture.fail(getTransportOsError(err))
break
removeReader(server.sock)
proc cancellation(udata: pointer) {.gcsafe.} =