Fix windows asyncLoop.

Fix OSError -> TransportOSError.
Add inherited objects initialization.
Add tests for inherited objects.
This commit is contained in:
cheatfate 2018-06-11 02:08:17 +03:00
parent 0ee9a148c7
commit 5815897de6
3 changed files with 201 additions and 85 deletions

View File

@ -281,6 +281,11 @@ template getError*(t: untyped): ref Exception =
(t).error = nil (t).error = nil
err err
proc raiseTransportOsError*(err: OSErrorCode) =
## Raises transport specific OS error.
var msg = "(" & $int(err) & ") " & osErrorMsg(err)
raise newException(TransportOsError, msg)
when defined(windows): when defined(windows):
import winlean import winlean

View File

@ -33,6 +33,35 @@ type
Pipe, # Pipe transport Pipe, # Pipe transport
File # File transport File # File transport
when defined(windows):
const SO_UPDATE_CONNECT_CONTEXT = 0x7010
type
StreamTransport* = ref object of RootRef
fd*: AsyncFD # File descriptor
state: set[TransportState] # Current Transport state
reader: Future[void] # Current reader Future
buffer: seq[byte] # Reading buffer
offset: int # Reading buffer offset
error: ref Exception # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
# Windows specific part
rwsabuf: TWSABuf # Reader WSABUF
wwsabuf: TWSABuf # Writer WSABUF
rovl: CustomOverlapped # Reader OVERLAPPED structure
wovl: CustomOverlapped # Writer OVERLAPPED structure
roffset: int # Pending reading offset
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
local: TransportAddress # Local address
remote: TransportAddress # Remote address
of TransportKind.Pipe:
todo1: int
of TransportKind.File:
todo2: int
else:
type type
StreamTransport* = ref object of RootRef StreamTransport* = ref object of RootRef
fd*: AsyncFD # File descriptor fd*: AsyncFD # File descriptor
@ -53,14 +82,25 @@ type
of TransportKind.File: of TransportKind.File:
todo2: int todo2: int
type
StreamCallback* = proc(server: StreamServer, StreamCallback* = proc(server: StreamServer,
client: StreamTransport): Future[void] {.gcsafe.} client: StreamTransport): Future[void] {.gcsafe.}
## New remote client connection callback ## New remote client connection callback
## ``server`` - StreamServer object. ## ``server`` - StreamServer object.
## ``client`` - accepted client transport. ## ``client`` - accepted client transport.
TransportInitCallback* = proc(server: StreamServer,
fd: AsyncFD): StreamTransport {.gcsafe.}
## Custom transport initialization procedure, which can allocated inherited
## StreamTransport object.
StreamServer* = ref object of SocketServer StreamServer* = ref object of SocketServer
function*: StreamCallback ## StreamServer object
function*: StreamCallback # callback which will be called after new
# client accepted
init*: TransportInitCallback # callback which will be called before
# transport for new client
proc remoteAddress*(transp: StreamTransport): TransportAddress = proc remoteAddress*(transp: StreamTransport): TransportAddress =
## Returns ``transp`` remote socket address. ## Returns ``transp`` remote socket address.
@ -116,16 +156,6 @@ template shiftVectorFile(v, o: untyped) =
(v).offset += cast[uint]((o)) (v).offset += cast[uint]((o))
when defined(windows): when defined(windows):
import winlean
type
WindowsStreamTransport = ref object of StreamTransport
rwsabuf: TWSABuf # Reader WSABUF
wwsabuf: TWSABuf # Writer WSABUF
rovl: CustomOverlapped # Reader OVERLAPPED structure
wovl: CustomOverlapped # Writer OVERLAPPED structure
roffset: int # Pending reading offset
const SO_UPDATE_CONNECT_CONTEXT = 0x7010
template zeroOvelappedOffset(t: untyped) = template zeroOvelappedOffset(t: untyped) =
(t).offset = 0 (t).offset = 0
@ -157,7 +187,7 @@ when defined(windows):
proc writeStreamLoop(udata: pointer) {.gcsafe, nimcall.} = proc writeStreamLoop(udata: pointer) {.gcsafe, nimcall.} =
var bytesCount: int32 var bytesCount: int32
var ovl = cast[PtrCustomOverlapped](udata) var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsStreamTransport](ovl.data.udata) var transp = cast[StreamTransport](ovl.data.udata)
while len(transp.queue) > 0: while len(transp.queue) > 0:
if WritePending in transp.state: if WritePending in transp.state:
@ -258,7 +288,7 @@ when defined(windows):
proc readStreamLoop(udata: pointer) {.gcsafe, nimcall.} = proc readStreamLoop(udata: pointer) {.gcsafe, nimcall.} =
var ovl = cast[PtrCustomOverlapped](udata) var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[WindowsStreamTransport](ovl.data.udata) var transp = cast[StreamTransport](ovl.data.udata)
while true: while true:
if ReadPending in transp.state: if ReadPending in transp.state:
@ -324,8 +354,13 @@ when defined(windows):
## Finish Loop ## Finish Loop
break break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
var transp = WindowsStreamTransport(kind: TransportKind.Socket) child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Socket)
transp.fd = sock transp.fd = sock
transp.rovl.data = CompletionData(fd: sock, cb: readStreamLoop, transp.rovl.data = CompletionData(fd: sock, cb: readStreamLoop,
udata: cast[pointer](transp)) udata: cast[pointer](transp))
@ -335,12 +370,8 @@ when defined(windows):
transp.state = {ReadPaused, WritePaused} transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]() transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("stream.socket.transport") transp.future = newFuture[void]("stream.socket.transport")
# ZAH: If these objects are going to be manually managed, why do we bother
# with using the GC at all? It's better to rely on a destructor. If someone
# wants to share a Transport reference, they can still create a GC-managed
# wrapping object.
GC_ref(transp) GC_ref(transp)
result = cast[StreamTransport](transp) result = transp
proc bindToDomain(handle: AsyncFD, domain: Domain): bool = proc bindToDomain(handle: AsyncFD, domain: Domain): bool =
result = true result = true
@ -358,7 +389,8 @@ when defined(windows):
result = false result = false
proc connect*(address: TransportAddress, proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create ## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection. ## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport. ## ``bufferSize`` is size of internal buffer for transport.
@ -392,7 +424,8 @@ when defined(windows):
osErrorMsg(osLastError()))) osErrorMsg(osLastError())))
else: else:
retFuture.complete(newStreamSocketTransport(povl.data.fd, retFuture.complete(newStreamSocketTransport(povl.data.fd,
bufferSize)) bufferSize,
child))
else: else:
sock.closeAsyncSocket() sock.closeAsyncSocket()
retFuture.fail(newException(TransportOsError, retFuture.fail(newException(TransportOsError,
@ -433,24 +466,38 @@ when defined(windows):
addr server.sock, addr server.sock,
SockLen(sizeof(SocketHandle))) != 0'i32: SockLen(sizeof(SocketHandle))) != 0'i32:
server.asock.closeAsyncSocket() server.asock.closeAsyncSocket()
raise newException(TransportOsError, osErrorMsg(osLastError())) raiseTransportOsError(osLastError())
else: else:
discard server.function(server, if not isNil(server.init):
newStreamSocketTransport(server.asock, server.bufferSize)) var transp = server.init(server, server.asock)
discard server.function(
server,
newStreamSocketTransport(server.asock, server.bufferSize,
transp)
)
else:
discard server.function(
server,
newStreamSocketTransport(server.asock, server.bufferSize, nil)
)
elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED:
# CancelIO() interrupt # CancelIO() interrupt
server.asock.closeAsyncSocket() server.asock.closeAsyncSocket()
break break
else: else:
server.asock.closeAsyncSocket() server.asock.closeAsyncSocket()
raiseOsError(osLastError()) raiseTransportOsError(osLastError())
else: else:
## Initiation ## Initiation
if server.status in {ServerStatus.Stopped, ServerStatus.Closed}:
## Server was already stopped/closed exiting
break
server.apending = true server.apending = true
server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP) Protocol.IPPROTO_TCP)
if server.asock == asyncInvalidSocket: if server.asock == asyncInvalidSocket:
raiseOsError(osLastError()) raiseTransportOsError(osLastError())
var dwBytesReceived = DWORD(0) var dwBytesReceived = DWORD(0)
let dwReceiveDataLength = DWORD(0) let dwReceiveDataLength = DWORD(0)
@ -471,18 +518,16 @@ when defined(windows):
elif int32(err) == ERROR_IO_PENDING: elif int32(err) == ERROR_IO_PENDING:
discard discard
else: else:
raiseOsError(osLastError()) raiseTransportOsError(err)
break break
proc resumeRead(transp: StreamTransport) {.inline.} = proc resumeRead(transp: StreamTransport) {.inline.} =
var wtransp = cast[WindowsStreamTransport](transp) transp.state.excl(ReadPaused)
wtransp.state.excl(ReadPaused) readStreamLoop(cast[pointer](addr transp.rovl))
readStreamLoop(cast[pointer](addr wtransp.rovl))
proc resumeWrite(transp: StreamTransport) {.inline.} = proc resumeWrite(transp: StreamTransport) {.inline.} =
var wtransp = cast[WindowsStreamTransport](transp) transp.state.excl(WritePaused)
wtransp.state.excl(WritePaused) writeStreamLoop(cast[pointer](addr transp.wovl))
writeStreamLoop(cast[pointer](addr wtransp.wovl))
proc pauseAccept(server: StreamServer) {.inline.} = proc pauseAccept(server: StreamServer) {.inline.} =
if server.apending: if server.apending:
@ -492,10 +537,6 @@ when defined(windows):
if not server.apending: if not server.apending:
acceptLoop(cast[pointer](addr server.aovl)) acceptLoop(cast[pointer](addr server.aovl))
else: else:
import posix
type
UnixStreamTransport* = ref object of StreamTransport
template getVectorBuffer(v: untyped): pointer = template getVectorBuffer(v: untyped): pointer =
cast[pointer](cast[uint]((v).buf) + uint((v).boffset)) cast[pointer](cast[uint]((v).buf) + uint((v).boffset))
@ -514,7 +555,7 @@ else:
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting # Transport was closed earlier, exiting
return return
var transp = cast[UnixStreamTransport](cdata.udata) var transp = cast[StreamTransport](cdata.udata)
let fd = SocketHandle(cdata.fd) let fd = SocketHandle(cdata.fd)
if len(transp.queue) > 0: if len(transp.queue) > 0:
var vector = transp.queue.popFirst() var vector = transp.queue.popFirst()
@ -562,7 +603,7 @@ else:
if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)):
# Transport was closed earlier, exiting # Transport was closed earlier, exiting
return return
var transp = cast[UnixStreamTransport](cdata.udata) var transp = cast[StreamTransport](cdata.udata)
let fd = SocketHandle(cdata.fd) let fd = SocketHandle(cdata.fd)
while true: while true:
var res = posix.recv(fd, addr transp.buffer[transp.offset], var res = posix.recv(fd, addr transp.buffer[transp.offset],
@ -589,18 +630,25 @@ else:
transp.finishReader() transp.finishReader()
break break
proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = proc newStreamSocketTransport(sock: AsyncFD, bufsize: int,
var transp = UnixStreamTransport(kind: TransportKind.Socket) child: StreamTransport): StreamTransport =
var transp: StreamTransport
if not isNil(child):
transp = child
else:
transp = StreamTransport(kind: TransportKind.Socket)
transp.fd = sock transp.fd = sock
transp.buffer = newSeq[byte](bufsize) transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused} transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]() transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("socket.stream.transport") transp.future = newFuture[void]("socket.stream.transport")
GC_ref(transp) GC_ref(transp)
result = cast[StreamTransport](transp) result = transp
proc connect*(address: TransportAddress, proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create ## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection. ## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` - size of internal buffer for transport. ## ``bufferSize`` - size of internal buffer for transport.
@ -613,7 +661,7 @@ else:
sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP) Protocol.IPPROTO_TCP)
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) retFuture.fail(newException(TransportOsError, osErrorMsg(osLastError())))
return retFuture return retFuture
proc continuation(udata: pointer) = proc continuation(udata: pointer) =
@ -622,20 +670,22 @@ else:
let fd = data.fd let fd = data.fd
if not fd.getSocketError(err): if not fd.getSocketError(err):
fd.closeAsyncSocket() fd.closeAsyncSocket()
retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) retFuture.fail(newException(TransportOsError,
osErrorMsg(osLastError())))
return return
if err != 0: if err != 0:
fd.closeAsyncSocket() fd.closeAsyncSocket()
retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(err)))) retFuture.fail(newException(TransportOsError,
osErrorMsg(OSErrorCode(err))))
return return
fd.removeWriter() fd.removeWriter()
retFuture.complete(newStreamSocketTransport(fd, bufferSize)) retFuture.complete(newStreamSocketTransport(fd, bufferSize, child))
while true: while true:
var res = posix.connect(SocketHandle(sock), var res = posix.connect(SocketHandle(sock),
cast[ptr SockAddr](addr saddr), slen) cast[ptr SockAddr](addr saddr), slen)
if res == 0: if res == 0:
retFuture.complete(newStreamSocketTransport(sock, bufferSize)) retFuture.complete(newStreamSocketTransport(sock, bufferSize, child))
break break
else: else:
let err = osLastError() let err = osLastError()
@ -646,11 +696,11 @@ else:
break break
else: else:
sock.closeAsyncSocket() sock.closeAsyncSocket()
retFuture.fail(newException(OSError, osErrorMsg(err))) retFuture.fail(newException(TransportOsError, osErrorMsg(err)))
break break
return retFuture return retFuture
proc serverCallback(udata: pointer) = proc acceptLoop(udata: pointer) =
var var
saddr: Sockaddr_storage saddr: Sockaddr_storage
slen: SockLen slen: SockLen
@ -661,8 +711,13 @@ else:
if int(res) > 0: if int(res) > 0:
let sock = wrapAsyncSocket(res) let sock = wrapAsyncSocket(res)
if sock != asyncInvalidSocket: if sock != asyncInvalidSocket:
if not isNil(server.init):
var transp = server.init(server, sock)
discard server.function(server, discard server.function(server,
newStreamSocketTransport(sock, server.bufferSize)) newStreamSocketTransport(sock, server.bufferSize, transp))
else:
discard server.function(server,
newStreamSocketTransport(sock, server.bufferSize, nil))
break break
else: else:
let err = osLastError() let err = osLastError()
@ -670,10 +725,10 @@ else:
continue continue
else: else:
## Critical unrecoverable error ## Critical unrecoverable error
raiseOsError(err) raiseTransportOsError(err)
proc resumeAccept(server: StreamServer) = proc resumeAccept(server: StreamServer) =
addReader(server.sock, serverCallback, cast[pointer](server)) addReader(server.sock, acceptLoop, cast[pointer](server))
proc pauseAccept(server: StreamServer) = proc pauseAccept(server: StreamServer) =
removeReader(server.sock) removeReader(server.sock)
@ -709,7 +764,7 @@ proc close*(server: StreamServer) =
## Release ``server`` resources. ## Release ``server`` resources.
if server.status == ServerStatus.Stopped: if server.status == ServerStatus.Stopped:
closeAsyncSocket(server.sock) closeAsyncSocket(server.sock)
server.status = Closed server.status = ServerStatus.Closed
server.loopFuture.complete() server.loopFuture.complete()
if not isNil(server.udata) and GCUserData in server.flags: if not isNil(server.udata) and GCUserData in server.flags:
GC_unref(cast[ref int](server.udata)) GC_unref(cast[ref int](server.udata))
@ -721,6 +776,8 @@ proc createStreamServer*(host: TransportAddress,
sock: AsyncFD = asyncInvalidSocket, sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100, backlog: int = 100,
bufferSize: int = DefaultStreamBufferSize, bufferSize: int = DefaultStreamBufferSize,
child: StreamServer = nil,
init: TransportInitCallback = nil,
udata: pointer = nil): StreamServer = udata: pointer = nil): StreamServer =
## Create new TCP stream server. ## Create new TCP stream server.
## ##
@ -732,6 +789,8 @@ proc createStreamServer*(host: TransportAddress,
## ``backlog`` - number of outstanding connections in the socket's listen ## ``backlog`` - number of outstanding connections in the socket's listen
## queue. ## queue.
## ``bufferSize`` - size of internal buffer for transport. ## ``bufferSize`` - size of internal buffer for transport.
## ``child`` - existing object ``StreamServer``object to initialize, can be
## used to initalize ``StreamServer`` inherited objects.
## ``udata`` - user-defined pointer. ## ``udata`` - user-defined pointer.
var var
saddr: Sockaddr_storage saddr: Sockaddr_storage
@ -742,10 +801,10 @@ proc createStreamServer*(host: TransportAddress,
SockType.SOCK_STREAM, SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP) Protocol.IPPROTO_TCP)
if serverSocket == asyncInvalidSocket: if serverSocket == asyncInvalidSocket:
raiseOsError(osLastError()) raiseTransportOsError(osLastError())
else: else:
if not setSocketBlocking(SocketHandle(sock), false): if not setSocketBlocking(SocketHandle(sock), false):
raiseOsError(osLastError()) raiseTransportOsError(osLastError())
register(sock) register(sock)
serverSocket = sock serverSocket = sock
@ -754,7 +813,7 @@ proc createStreamServer*(host: TransportAddress,
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket) closeAsyncSocket(serverSocket)
raiseOsError(err) raiseTransportOsError(err)
toSockAddr(host.address, host.port, saddr, slen) toSockAddr(host.address, host.port, saddr, slen)
if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr),
@ -762,17 +821,22 @@ proc createStreamServer*(host: TransportAddress,
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket) closeAsyncSocket(serverSocket)
raiseOsError(err) raiseTransportOsError(err)
if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0:
let err = osLastError() let err = osLastError()
if sock == asyncInvalidSocket: if sock == asyncInvalidSocket:
closeAsyncSocket(serverSocket) closeAsyncSocket(serverSocket)
raiseOsError(err) raiseTransportOsError(err)
if not isNil(child):
result = child
else:
result = StreamServer() result = StreamServer()
result.sock = serverSocket result.sock = serverSocket
result.function = cbproc result.function = cbproc
result.init = init
result.bufferSize = bufferSize result.bufferSize = bufferSize
result.status = Starting result.status = Starting
result.loopFuture = newFuture[void]("stream.server") result.loopFuture = newFuture[void]("stream.server")
@ -790,10 +854,11 @@ proc createStreamServer*(host: TransportAddress,
proc createStreamServer*[T](host: TransportAddress, proc createStreamServer*[T](host: TransportAddress,
cbproc: StreamCallback, cbproc: StreamCallback,
flags: set[ServerFlags] = {}, flags: set[ServerFlags] = {},
udata: ref T,
sock: AsyncFD = asyncInvalidSocket, sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100, backlog: int = 100,
bufferSize: int = DefaultStreamBufferSize, bufferSize: int = DefaultStreamBufferSize,
udata: ref T): StreamServer = child: StreamServer = nil): StreamServer =
var fflags = flags + {GCUserData} var fflags = flags + {GCUserData}
GC_ref(udata) GC_ref(udata)
result = createStreamServer(host, cbproc, flags, sock, backlog, bufferSize, result = createStreamServer(host, cbproc, flags, sock, backlog, bufferSize,

View File

@ -9,10 +9,34 @@
import strutils, unittest import strutils, unittest
import ../asyncdispatch2 import ../asyncdispatch2
type
CustomServer = ref object of StreamServer
test1: string
test2: string
CustomTransport = ref object of StreamTransport
test: string
proc serveStreamClient(server: StreamServer, proc serveStreamClient(server: StreamServer,
transp: StreamTransport) {.async.} = transp: StreamTransport) {.async.} =
discard discard
proc serveCustomStreamClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var cserver = cast[CustomServer](server)
var ctransp = cast[CustomTransport](transp)
cserver.test1 = "CONNECTION"
cserver.test2 = ctransp.test
transp.close()
server.stop()
server.close()
proc customServerTransport(server: StreamServer,
fd: AsyncFD): StreamTransport =
var transp = CustomTransport()
transp.test = "CUSTOM"
result = cast[StreamTransport](transp)
proc serveDatagramClient(transp: DatagramTransport, proc serveDatagramClient(transp: DatagramTransport,
pbytes: pointer, nbytes: int, pbytes: pointer, nbytes: int,
raddr: TransportAddress, raddr: TransportAddress,
@ -21,11 +45,11 @@ proc serveDatagramClient(transp: DatagramTransport,
proc test1(): bool = proc test1(): bool =
var ta = initTAddress("127.0.0.1:31354") var ta = initTAddress("127.0.0.1:31354")
var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) var server1 = createStreamServer(ta, serveStreamClient, {})
server1.start() server1.start()
server1.stop() server1.stop()
server1.close() server1.close()
var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) var server2 = createStreamServer(ta, serveStreamClient, {})
server2.start() server2.start()
server2.stop() server2.stop()
server2.close() server2.close()
@ -33,19 +57,41 @@ proc test1(): bool =
proc test2(): bool = proc test2(): bool =
var ta = initTAddress("127.0.0.1:31354") var ta = initTAddress("127.0.0.1:31354")
var server1 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr}) var server1 = createDatagramServer(ta, serveDatagramClient, {})
server1.start() server1.start()
server1.stop() server1.stop()
server1.close() server1.close()
var server2 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr}) var server2 = createDatagramServer(ta, serveDatagramClient, {})
server2.start() server2.start()
server2.stop() server2.stop()
server2.close() server2.close()
result = true result = true
proc client(server: CustomServer, ta: TransportAddress) {.async.} =
var transp = CustomTransport()
transp.test = "CLIENT"
server.start()
var ptransp = await connect(ta, child = transp)
var etransp = cast[CustomTransport](ptransp)
doAssert(etransp.test == "CLIENT")
transp.close()
await server.join()
proc test3(): bool =
var server = CustomServer()
server.test1 = "TEST"
var ta = initTAddress("127.0.0.1:31354")
var pserver = createStreamServer(ta, serveCustomStreamClient, {},
child = cast[StreamServer](server),
init = customServerTransport)
waitFor client(server, ta)
result = (server.test1 == "CONNECTION") and (server.test2 == "CUSTOM")
when isMainModule: when isMainModule:
suite "Server's test suite": suite "Server's test suite":
test "Stream Server start/stop test": test "Stream Server start/stop test":
check test1() == true check test1() == true
test "Stream Server inherited object test":
check test3() == true
test "Datagram Server start/stop test": test "Datagram Server start/stop test":
check test2() == true check test2() == true