Fix windows asyncLoop.

Fix OSError -> TransportOSError.
Add inherited objects initialization.
Add tests for inherited objects.
This commit is contained in:
cheatfate 2018-06-11 02:08:17 +03:00
parent 0ee9a148c7
commit 5815897de6
3 changed files with 201 additions and 85 deletions

View File

@ -281,6 +281,11 @@ template getError*(t: untyped): ref Exception =
(t).error = nil
err
proc raiseTransportOsError*(err: OSErrorCode) =
## Raises transport specific OS error.
var msg = "(" & $int(err) & ") " & osErrorMsg(err)
raise newException(TransportOsError, msg)
when defined(windows):
import winlean

View File

@ -33,25 +33,56 @@ type
Pipe, # Pipe transport
File # File transport
when defined(windows):
const SO_UPDATE_CONNECT_CONTEXT = 0x7010
type
StreamTransport* = ref object of RootRef
fd*: AsyncFD # File descriptor
state: set[TransportState] # Current Transport state
reader: Future[void] # Current reader Future
buffer: seq[byte] # Reading buffer
offset: int # Reading buffer offset
error: ref Exception # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
# Windows specific part
rwsabuf: TWSABuf # Reader WSABUF
wwsabuf: TWSABuf # Writer WSABUF
rovl: CustomOverlapped # Reader OVERLAPPED structure
wovl: CustomOverlapped # Writer OVERLAPPED structure
roffset: int # Pending reading offset
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
local: TransportAddress # Local address
remote: TransportAddress # Remote address
of TransportKind.Pipe:
todo1: int
of TransportKind.File:
todo2: int
else:
type
StreamTransport* = ref object of RootRef
fd*: AsyncFD # File descriptor
state: set[TransportState] # Current Transport state
reader: Future[void] # Current reader Future
buffer: seq[byte] # Reading buffer
offset: int # Reading buffer offset
error: ref Exception # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
local: TransportAddress # Local address
remote: TransportAddress # Remote address
of TransportKind.Pipe:
todo1: int
of TransportKind.File:
todo2: int
type
StreamTransport* = ref object of RootRef
fd*: AsyncFD # File descriptor
state: set[TransportState] # Current Transport state
reader: Future[void] # Current reader Future
buffer: seq[byte] # Reading buffer
offset: int # Reading buffer offset
error: ref Exception # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
local: TransportAddress # Local address
remote: TransportAddress # Remote address
of TransportKind.Pipe:
todo1: int
of TransportKind.File:
todo2: int
StreamCallback* = proc(server: StreamServer,
client: StreamTransport): Future[void] {.gcsafe.}
@ -59,8 +90,17 @@ type
## ``server`` - StreamServer object.
## ``client`` - accepted client transport.
TransportInitCallback* = proc(server: StreamServer,
fd: AsyncFD): StreamTransport {.gcsafe.}
## Custom transport initialization procedure, which can allocated inherited
## StreamTransport object.
StreamServer* = ref object of SocketServer
function*: StreamCallback
## StreamServer object
function*: StreamCallback # callback which will be called after new
# client accepted
init*: TransportInitCallback # callback which will be called before
# transport for new client
proc remoteAddress*(transp: StreamTransport): TransportAddress =
## Returns ``transp`` remote socket address.
@ -116,16 +156,6 @@ template shiftVectorFile(v, o: untyped) =
(v).offset += cast[uint]((o))
when defined(windows):
import winlean
type
WindowsStreamTransport = ref object of StreamTransport
rwsabuf: TWSABuf # Reader WSABUF
wwsabuf: TWSABuf # Writer WSABUF
rovl: CustomOverlapped # Reader OVERLAPPED structure
wovl: CustomOverlapped # Writer OVERLAPPED structure
roffset: int # Pending reading offset
const SO_UPDATE_CONNECT_CONTEXT = 0x7010
template zeroOvelappedOffset(t: untyped) =
(t).offset = 0
@ -157,7 +187,7 @@ when defined(windows):
proc writeStreamLoop(udata: pointer) {.gcsafe, nimcall.} =
var bytesCount: int32
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsStreamTransport](ovl.data.udata)
var transp = cast[StreamTransport](ovl.data.udata)
while len(transp.queue) > 0:
if WritePending in transp.state:
@ -258,7 +288,7 @@ when defined(windows):
proc readStreamLoop(udata: pointer) {.gcsafe, nimcall.} =
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsStreamTransport](ovl.data.udata)
var transp = cast[StreamTransport](ovl.data.udata)
while true:
if ReadPending in transp.state:
@ -324,8 +354,13 @@ when defined(windows):
## Finish Loop
break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport =
var transp = WindowsStreamTransport(kind: TransportKind.Socket)
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Socket)
transp.fd = sock
transp.rovl.data = CompletionData(fd: sock, cb: readStreamLoop,
udata: cast[pointer](transp))
@ -335,12 +370,8 @@ when defined(windows):
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("stream.socket.transport")
# ZAH: If these objects are going to be manually managed, why do we bother
# with using the GC at all? It's better to rely on a destructor. If someone
# wants to share a Transport reference, they can still create a GC-managed
# wrapping object.
GC_ref(transp)
result = cast[StreamTransport](transp)
result = transp
proc bindToDomain(handle: AsyncFD, domain: Domain): bool =
result = true
@ -358,7 +389,8 @@ when defined(windows):
result = false
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize): Future[StreamTransport] =
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport.
@ -392,7 +424,8 @@ when defined(windows):
osErrorMsg(osLastError())))
else:
retFuture.complete(newStreamSocketTransport(povl.data.fd,
bufferSize))
bufferSize,
child))
else:
sock.closeAsyncSocket()
retFuture.fail(newException(TransportOsError,
@ -433,24 +466,38 @@ when defined(windows):
addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32:
server.asock.closeAsyncSocket()
raise newException(TransportOsError, osErrorMsg(osLastError()))
raiseTransportOsError(osLastError())
else:
discard server.function(server,
newStreamSocketTransport(server.asock, server.bufferSize))
if not isNil(server.init):
var transp = server.init(server, server.asock)
discard server.function(
server,
newStreamSocketTransport(server.asock, server.bufferSize,
transp)
)
else:
discard server.function(
server,
newStreamSocketTransport(server.asock, server.bufferSize, nil)
)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt
server.asock.closeAsyncSocket()
break
else:
server.asock.closeAsyncSocket()
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
else:
## Initiation
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
## Server was already stopped/closed exiting
break
server.apending = true
server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
if server.asock == asyncInvalidSocket:
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
var dwBytesReceived = DWORD(0)
let dwReceiveDataLength = DWORD(0)
@ -471,18 +518,16 @@ when defined(windows):
elif int32(err) == ERROR_IO_PENDING:
discard
else:
raiseOsError(osLastError())
raiseTransportOsError(err)
break
proc resumeRead(transp: StreamTransport) {.inline.} =
var wtransp = cast[WindowsStreamTransport](transp)
wtransp.state.excl(ReadPaused)
readStreamLoop(cast[pointer](addr wtransp.rovl))
transp.state.excl(ReadPaused)
readStreamLoop(cast[pointer](addr transp.rovl))
proc resumeWrite(transp: StreamTransport) {.inline.} =
var wtransp = cast[WindowsStreamTransport](transp)
wtransp.state.excl(WritePaused)
writeStreamLoop(cast[pointer](addr wtransp.wovl))
transp.state.excl(WritePaused)
writeStreamLoop(cast[pointer](addr transp.wovl))
proc pauseAccept(server: StreamServer) {.inline.} =
if server.apending:
@ -492,10 +537,6 @@ when defined(windows):
if not server.apending:
acceptLoop(cast[pointer](addr server.aovl))
else:
import posix
type
UnixStreamTransport* = ref object of StreamTransport
template getVectorBuffer(v: untyped): pointer =
cast[pointer](cast[uint]((v).buf) + uint((v).boffset))
@ -514,7 +555,7 @@ else:
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting
return
var transp = cast[UnixStreamTransport](cdata.udata)
var transp = cast[StreamTransport](cdata.udata)
let fd = SocketHandle(cdata.fd)
if len(transp.queue) > 0:
var vector = transp.queue.popFirst()
@ -562,7 +603,7 @@ else:
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting
return
var transp = cast[UnixStreamTransport](cdata.udata)
var transp = cast[StreamTransport](cdata.udata)
let fd = SocketHandle(cdata.fd)
while true:
var res = posix.recv(fd, addr transp.buffer[transp.offset],
@ -589,18 +630,25 @@ else:
transp.finishReader()
break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport =
var transp = UnixStreamTransport(kind: TransportKind.Socket)
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Socket)
transp.fd = sock
transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("socket.stream.transport")
GC_ref(transp)
result = cast[StreamTransport](transp)
result = transp
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize): Future[StreamTransport] =
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` - size of internal buffer for transport.
@ -613,7 +661,7 @@ else:
sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
if sock == asyncInvalidSocket:
retFuture.fail(newException(OSError, osErrorMsg(osLastError())))
retFuture.fail(newException(TransportOsError, osErrorMsg(osLastError())))
return retFuture
proc continuation(udata: pointer) =
@ -622,20 +670,22 @@ else:
let fd = data.fd
if not fd.getSocketError(err):
fd.closeAsyncSocket()
retFuture.fail(newException(OSError, osErrorMsg(osLastError())))
retFuture.fail(newException(TransportOsError,
osErrorMsg(osLastError())))
return
if err != 0:
fd.closeAsyncSocket()
retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(err))))
retFuture.fail(newException(TransportOsError,
osErrorMsg(OSErrorCode(err))))
return
fd.removeWriter()
retFuture.complete(newStreamSocketTransport(fd, bufferSize))
retFuture.complete(newStreamSocketTransport(fd, bufferSize, child))
while true:
var res = posix.connect(SocketHandle(sock),
cast[ptr SockAddr](addr saddr), slen)
if res == 0:
retFuture.complete(newStreamSocketTransport(sock, bufferSize))
retFuture.complete(newStreamSocketTransport(sock, bufferSize, child))
break
else:
let err = osLastError()
@ -646,11 +696,11 @@ else:
break
else:
sock.closeAsyncSocket()
retFuture.fail(newException(OSError, osErrorMsg(err)))
retFuture.fail(newException(TransportOsError, osErrorMsg(err)))
break
return retFuture
proc serverCallback(udata: pointer) =
proc acceptLoop(udata: pointer) =
var
saddr: Sockaddr_storage
slen: SockLen
@ -661,8 +711,13 @@ else:
if int(res) > 0:
let sock = wrapAsyncSocket(res)
if sock != asyncInvalidSocket:
discard server.function(server,
newStreamSocketTransport(sock, server.bufferSize))
if not isNil(server.init):
var transp = server.init(server, sock)
discard server.function(server,
newStreamSocketTransport(sock, server.bufferSize, transp))
else:
discard server.function(server,
newStreamSocketTransport(sock, server.bufferSize, nil))
break
else:
let err = osLastError()
@ -670,10 +725,10 @@ else:
continue
else:
## Critical unrecoverable error
raiseOsError(err)
raiseTransportOsError(err)
proc resumeAccept(server: StreamServer) =
addReader(server.sock, serverCallback, cast[pointer](server))
addReader(server.sock, acceptLoop, cast[pointer](server))
proc pauseAccept(server: StreamServer) =
removeReader(server.sock)
@ -709,7 +764,7 @@ proc close*(server: StreamServer) =
## Release ``server`` resources.
if server.status == ServerStatus.Stopped:
closeAsyncSocket(server.sock)
server.status = Closed
server.status = ServerStatus.Closed
server.loopFuture.complete()
if not isNil(server.udata) and GCUserData in server.flags:
GC_unref(cast[ref int](server.udata))
@ -721,6 +776,8 @@ proc createStreamServer*(host: TransportAddress,
sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100,
bufferSize: int = DefaultStreamBufferSize,
child: StreamServer = nil,
init: TransportInitCallback = nil,
udata: pointer = nil): StreamServer =
## Create new TCP stream server.
##
@ -732,6 +789,8 @@ proc createStreamServer*(host: TransportAddress,
## ``backlog`` - number of outstanding connections in the socket's listen
## queue.
## ``bufferSize`` - size of internal buffer for transport.
## ``child`` - existing object ``StreamServer``object to initialize, can be
## used to initalize ``StreamServer`` inherited objects.
## ``udata`` - user-defined pointer.
var
saddr: Sockaddr_storage
@ -742,10 +801,10 @@ proc createStreamServer*(host: TransportAddress,
SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
if serverSocket == asyncInvalidSocket:
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
register(sock)
serverSocket = sock
@ -754,7 +813,7 @@ proc createStreamServer*(host: TransportAddress,
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
raiseOsError(err)
raiseTransportOsError(err)
toSockAddr(host.address, host.port, saddr, slen)
if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
@ -762,17 +821,22 @@ proc createStreamServer*(host: TransportAddress,
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
raiseOsError(err)
raiseTransportOsError(err)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket)
raiseOsError(err)
raiseTransportOsError(err)
if not isNil(child):
result = child
else:
result = StreamServer()
result = StreamServer()
result.sock = serverSocket
result.function = cbproc
result.init = init
result.bufferSize = bufferSize
result.status = Starting
result.loopFuture = newFuture[void]("stream.server")
@ -790,10 +854,11 @@ proc createStreamServer*(host: TransportAddress,
proc createStreamServer*[T](host: TransportAddress,
cbproc: StreamCallback,
flags: set[ServerFlags] = {},
udata: ref T,
sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100,
bufferSize: int = DefaultStreamBufferSize,
udata: ref T): StreamServer =
child: StreamServer = nil): StreamServer =
var fflags = flags + {GCUserData}
GC_ref(udata)
result = createStreamServer(host, cbproc, flags, sock, backlog, bufferSize,

View File

@ -9,10 +9,34 @@
import strutils, unittest
import ../asyncdispatch2
type
CustomServer = ref object of StreamServer
test1: string
test2: string
CustomTransport = ref object of StreamTransport
test: string
proc serveStreamClient(server: StreamServer,
transp: StreamTransport) {.async.} =
discard
proc serveCustomStreamClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var cserver = cast[CustomServer](server)
var ctransp = cast[CustomTransport](transp)
cserver.test1 = "CONNECTION"
cserver.test2 = ctransp.test
transp.close()
server.stop()
server.close()
proc customServerTransport(server: StreamServer,
fd: AsyncFD): StreamTransport =
var transp = CustomTransport()
transp.test = "CUSTOM"
result = cast[StreamTransport](transp)
proc serveDatagramClient(transp: DatagramTransport,
pbytes: pointer, nbytes: int,
raddr: TransportAddress,
@ -21,11 +45,11 @@ proc serveDatagramClient(transp: DatagramTransport,
proc test1(): bool =
var ta = initTAddress("127.0.0.1:31354")
var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
var server1 = createStreamServer(ta, serveStreamClient, {})
server1.start()
server1.stop()
server1.close()
var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr})
var server2 = createStreamServer(ta, serveStreamClient, {})
server2.start()
server2.stop()
server2.close()
@ -33,19 +57,41 @@ proc test1(): bool =
proc test2(): bool =
var ta = initTAddress("127.0.0.1:31354")
var server1 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr})
var server1 = createDatagramServer(ta, serveDatagramClient, {})
server1.start()
server1.stop()
server1.close()
var server2 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr})
var server2 = createDatagramServer(ta, serveDatagramClient, {})
server2.start()
server2.stop()
server2.close()
result = true
proc client(server: CustomServer, ta: TransportAddress) {.async.} =
var transp = CustomTransport()
transp.test = "CLIENT"
server.start()
var ptransp = await connect(ta, child = transp)
var etransp = cast[CustomTransport](ptransp)
doAssert(etransp.test == "CLIENT")
transp.close()
await server.join()
proc test3(): bool =
var server = CustomServer()
server.test1 = "TEST"
var ta = initTAddress("127.0.0.1:31354")
var pserver = createStreamServer(ta, serveCustomStreamClient, {},
child = cast[StreamServer](server),
init = customServerTransport)
waitFor client(server, ta)
result = (server.test1 == "CONNECTION") and (server.test2 == "CUSTOM")
when isMainModule:
suite "Server's test suite":
test "Stream Server start/stop test":
check test1() == true
test "Stream Server inherited object test":
check test3() == true
test "Datagram Server start/stop test":
check test2() == true