Add datagram transport utility templates send(string) send(seq[byte]).

Fix bugs in stream.nim
Add more tests for stream.nim
This commit is contained in:
cheatfate 2018-06-05 08:51:59 +03:00
parent 2b8eeef7aa
commit 3cb521c920
3 changed files with 152 additions and 47 deletions

View File

@ -559,6 +559,28 @@ proc sendTo*(transp: DatagramTransport, pbytes: pointer, nbytes: int,
if WriteError in transp.state: if WriteError in transp.state:
raise transp.getError() raise transp.getError()
template send*(transp: DatagramTransport, msg: var string): untyped =
## Send message ``msg`` using transport ``transp`` to remote destination
## address which was bounded on transport.
send(transp, addr msg[0], len(msg))
template send*(transp: DatagramTransport, msg: var seq[byte]): untyped =
## Send message ``msg`` using transport ``transp`` to remote destination
## address which was bounded on transport.
send(transp, addr msg[0], len(msg))
template sendTo*(transp: DatagramTransport, msg: var string,
remote: TransportAddress): untyped =
## Send message ``msg`` using transport ``transp`` to remote
## destination address ``remote``.
sendTo(transp, addr msg[0], len(msg), remote)
template sendTo*(transp: DatagramTransport, msg: var seq[byte],
remote: TransportAddress): untyped =
## Send message ``msg`` using transport ``transp`` to remote
## destination address ``remote``.
sendTo(transp, addr msg[0], len(msg), remote)
proc createDatagramServer*(host: TransportAddress, proc createDatagramServer*(host: TransportAddress,
cbproc: DatagramCallback, cbproc: DatagramCallback,
flags: set[ServerFlags] = {}, flags: set[ServerFlags] = {},

View File

@ -680,7 +680,7 @@ else:
let err = osLastError() let err = osLastError()
if int(err) == EINTR: if int(err) == EINTR:
continue continue
elif int(err) in {EBADF, EINVAL, ENOTSOCK, EOPNOTSUPP, EPROTO}: else:
## Critical unrecoverable error ## Critical unrecoverable error
raiseOsError(err) raiseOsError(err)
@ -859,7 +859,7 @@ proc readExactly*(transp: StreamTransport, pbytes: pointer,
if transp.offset == 0: if transp.offset == 0:
if (ReadError in transp.state): if (ReadError in transp.state):
raise transp.getError() raise transp.getError()
if ReadClosed in transp.state or transp.atEof(): if (ReadClosed in transp.state) or transp.atEof():
raise newException(TransportIncompleteError, "Data incomplete!") raise newException(TransportIncompleteError, "Data incomplete!")
if transp.offset >= (nbytes - index): if transp.offset >= (nbytes - index):
@ -894,7 +894,7 @@ proc readOnce*(transp: StreamTransport, pbytes: pointer,
if transp.offset == 0: if transp.offset == 0:
if (ReadError in transp.state): if (ReadError in transp.state):
raise transp.getError() raise transp.getError()
if (ReadEof in transp.state) or (ReadClosed in transp.state): if (ReadClosed in transp.state) or transp.atEof():
result = 0 result = 0
break break
transp.reader = newFuture[void]("stream.transport.readOnce") transp.reader = newFuture[void]("stream.transport.readOnce")
@ -937,13 +937,10 @@ proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int,
var index = 0 var index = 0
while true: while true:
if (transp.offset - index) == 0: if ReadError in transp.state:
if ReadError in transp.state: raise transp.getError()
transp.shiftBuffer(index) if (ReadClosed in transp.state) or transp.atEof():
raise transp.getError() raise newException(TransportIncompleteError, "Data incomplete!")
if (ReadEof in transp.state) or (ReadClosed in transp.state):
transp.shiftBuffer(index)
raise newException(TransportIncompleteError, "Data incomplete!")
index = 0 index = 0
while index < transp.offset: while index < transp.offset:
@ -957,25 +954,23 @@ proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int,
inc(k) inc(k)
else: else:
raise newException(TransportLimitError, "Limit reached!") raise newException(TransportLimitError, "Limit reached!")
if state == len(sep): if state == len(sep):
transp.shiftBuffer(index + 1)
break break
inc(index) inc(index)
if state == len(sep): if state == len(sep):
transp.shiftBuffer(index + 1)
result = k result = k
break break
else: else:
if (transp.offset - index) == 0: transp.shiftBuffer(transp.offset)
transp.reader = newFuture[void]("stream.transport.readUntil") transp.reader = newFuture[void]("stream.transport.readUntil")
if ReadPaused in transp.state: if ReadPaused in transp.state:
transp.resumeRead() transp.resumeRead()
await transp.reader await transp.reader
# we need to clear transp.reader to avoid double completion of this # we need to clear transp.reader to avoid double completion of this
# Future[T], because readLoop continues working. # Future[T], because readLoop continues working.
transp.reader = nil transp.reader = nil
proc readLine*(transp: StreamTransport, limit = 0, proc readLine*(transp: StreamTransport, limit = 0,
sep = "\r\n"): Future[string] {.async.} = sep = "\r\n"): Future[string] {.async.} =
@ -998,13 +993,10 @@ proc readLine*(transp: StreamTransport, limit = 0,
var index = 0 var index = 0
while true: while true:
if (transp.offset - index) == 0: if (ReadError in transp.state):
if (ReadError in transp.state): raise transp.getError()
transp.shiftBuffer(index) if (ReadClosed in transp.state) or transp.atEof():
raise transp.getError() break
if (ReadEof in transp.state) or (ReadClosed in transp.state):
transp.shiftBuffer(index)
break
index = 0 index = 0
while index < transp.offset: while index < transp.offset:
@ -1025,17 +1017,17 @@ proc readLine*(transp: StreamTransport, limit = 0,
if (state == len(sep)) or (lim == len(result)): if (state == len(sep)) or (lim == len(result)):
break break
else: else:
if (transp.offset - index) == 0: transp.shiftBuffer(transp.offset)
transp.reader = newFuture[void]("stream.transport.readLine") transp.reader = newFuture[void]("stream.transport.readLine")
if ReadPaused in transp.state: if ReadPaused in transp.state:
transp.resumeRead() transp.resumeRead()
await transp.reader await transp.reader
# we need to clear transp.reader to avoid double completion of this # we need to clear transp.reader to avoid double completion of this
# Future[T], because readLoop continues working. # Future[T], because readLoop continues working.
transp.reader = nil transp.reader = nil
proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} =
## Read all bytes (n == -1) or `n` bytes from transport ``transp``. ## Read all bytes (n == -1) or exactly `n` bytes from transport ``transp``.
## ##
## This procedure allocates buffer seq[byte] and return it as result. ## This procedure allocates buffer seq[byte] and return it as result.
checkClosed(transp) checkClosed(transp)
@ -1044,9 +1036,9 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} =
while true: while true:
if (ReadError in transp.state): if (ReadError in transp.state):
raise transp.getError() raise transp.getError()
if ReadClosed in transp.state or transp.atEof(): if (ReadClosed in transp.state) or transp.atEof():
break break
if transp.offset > 0: if transp.offset > 0:
let s = len(result) let s = len(result)
let o = s + transp.offset let o = s + transp.offset
@ -1057,12 +1049,13 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} =
transp.offset) transp.offset)
transp.offset = 0 transp.offset = 0
else: else:
if transp.offset >= (n - s): let left = n - s
if transp.offset >= left:
# size of buffer data is more then we need, grabbing only part # size of buffer data is more then we need, grabbing only part
result.setLen(n) result.setLen(n)
copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]), copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]),
n - s) left)
transp.shiftBuffer(n - s) transp.shiftBuffer(left)
break break
else: else:
# there not enough data in buffer, grabbing all # there not enough data in buffer, grabbing all
@ -1091,6 +1084,7 @@ proc consume*(transp: StreamTransport, n = -1): Future[int] {.async.} =
raise transp.getError() raise transp.getError()
if ReadClosed in transp.state or transp.atEof(): if ReadClosed in transp.state or transp.atEof():
break break
if transp.offset > 0: if transp.offset > 0:
if n == -1: if n == -1:
# consume all incoming data, until EOF # consume all incoming data, until EOF

View File

@ -16,6 +16,8 @@ else:
const const
ConstantMessage = "SOMEDATA" ConstantMessage = "SOMEDATA"
BigMessagePattern = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
BigMessageCount = 1000
ClientsCount = 100 ClientsCount = 100
MessagesCount = 100 MessagesCount = 100
MessageSize = 20 MessageSize = 20
@ -68,7 +70,8 @@ proc serveClient3(server: StreamServer,
var suffixStr = "SUFFIX" var suffixStr = "SUFFIX"
var suffix = newSeq[byte](6) var suffix = newSeq[byte](6)
copyMem(addr suffix[0], addr suffixStr[0], len(suffixStr)) copyMem(addr suffix[0], addr suffixStr[0], len(suffixStr))
while not transp.atEof(): var counter = MessagesCount
while counter > 0:
zeroMem(addr buffer[0], MessageSize) zeroMem(addr buffer[0], MessageSize)
var res = await transp.readUntil(addr buffer[0], MessageSize, suffix) var res = await transp.readUntil(addr buffer[0], MessageSize, suffix)
doAssert(equalMem(addr buffer[0], addr check[0], len(check))) doAssert(equalMem(addr buffer[0], addr check[0], len(check)))
@ -84,6 +87,7 @@ proc serveClient3(server: StreamServer,
copyMem(addr buffer[0], addr ans[0], len(ans)) copyMem(addr buffer[0], addr ans[0], len(ans))
res = await transp.write(cast[pointer](addr buffer[0]), len(ans)) res = await transp.write(cast[pointer](addr buffer[0]), len(ans))
doAssert(res == len(ans)) doAssert(res == len(ans))
dec(counter)
transp.close() transp.close()
proc serveClient4(server: StreamServer, proc serveClient4(server: StreamServer,
@ -100,6 +104,7 @@ proc serveClient4(server: StreamServer,
var answer = "OK\r\n" var answer = "OK\r\n"
var res = await transp.write(cast[pointer](addr answer[0]), len(answer)) var res = await transp.write(cast[pointer](addr answer[0]), len(answer))
doAssert(res == len(answer)) doAssert(res == len(answer))
transp.close()
proc serveClient5(server: StreamServer, proc serveClient5(server: StreamServer,
transp: StreamTransport, udata: pointer) {.async.} = transp: StreamTransport, udata: pointer) {.async.} =
@ -131,6 +136,41 @@ proc serveClient6(server: StreamServer,
server.stop() server.stop()
server.close() server.close()
proc serveClient7(server: StreamServer,
transp: StreamTransport, udata: pointer) {.async.} =
var answer = "DONE\r\n"
var expect = ""
var line = await transp.readLine()
doAssert(len(line) == BigMessageCount * len(BigMessagePattern))
for i in 0..<BigMessageCount:
expect.add(BigMessagePattern)
doAssert(line == expect)
var res = await transp.write(answer)
doAssert(res == len(answer))
transp.close()
proc serveClient8(server: StreamServer,
transp: StreamTransport, udata: pointer) {.async.} =
var answer = "DONE\r\n"
var strpattern = BigMessagePattern
var pattern = newSeq[byte](len(BigMessagePattern))
var expect = newSeq[byte]()
var data = newSeq[byte]((BigMessageCount + 1) * len(BigMessagePattern))
var sep = @[0x0D'u8, 0x0A'u8]
copyMem(addr pattern[0], addr strpattern[0], len(BigMessagePattern))
var count = await transp.readUntil(addr data[0], len(data), sep = sep)
doAssert(count == BigMessageCount * len(BigMessagePattern) + 2)
for i in 0..<BigMessageCount:
expect.add(pattern)
expect.add(sep)
data.setLen(count)
doAssert(expect == data)
var res = await transp.write(answer)
doAssert(res == len(answer))
transp.close()
server.stop()
server.close()
proc swarmWorker1(address: TransportAddress): Future[int] {.async.} = proc swarmWorker1(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address) var transp = await connect(address)
for i in 0..<MessagesCount: for i in 0..<MessagesCount:
@ -214,6 +254,7 @@ proc swarmWorker4(address: TransportAddress): Future[int] {.async.} =
res = await transp.write(cast[pointer](addr ssize[0]), len(ssize)) res = await transp.write(cast[pointer](addr ssize[0]), len(ssize))
doAssert(res == len(ssize)) doAssert(res == len(ssize))
await transp.writeFile(handle, 0'u, size) await transp.writeFile(handle, 0'u, size)
close(fhandle)
var ans = await transp.readLine() var ans = await transp.readLine()
doAssert(ans == "OK") doAssert(ans == "OK")
result = 1 result = 1
@ -237,6 +278,32 @@ proc swarmWorker6(address: TransportAddress): Future[int] {.async.} =
result = MessagesCount result = MessagesCount
transp.close() transp.close()
proc swarmWorker7(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = BigMessagePattern
var crlf = "\r\n"
for i in 0..<BigMessageCount:
var res = await transp.write(data)
doAssert(res == len(data))
var res = await transp.write(crlf)
var line = await transp.readLine()
doAssert(line == "DONE")
result = 1
transp.close()
proc swarmWorker8(address: TransportAddress): Future[int] {.async.} =
var transp = await connect(address)
var data = BigMessagePattern
var crlf = "\r\n"
for i in 0..<BigMessageCount:
var res = await transp.write(data)
doAssert(res == len(data))
var res = await transp.write(crlf)
var line = await transp.readLine()
doAssert(line == "DONE")
result = 1
transp.close()
proc waitAll[T](futs: seq[Future[T]]): Future[void] = proc waitAll[T](futs: seq[Future[T]]): Future[void] =
var counter = len(futs) var counter = len(futs)
var retFuture = newFuture[void]("waitAll") var retFuture = newFuture[void]("waitAll")
@ -349,7 +416,7 @@ proc test4(): Future[int] {.async.} =
server.close() server.close()
proc test5(): Future[int] {.async.} = proc test5(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31347") var ta = initTAddress("127.0.0.1:31348")
var counter = ClientsCount var counter = ClientsCount
var server = createStreamServer(ta, serveClient5, {ReuseAddr}, var server = createStreamServer(ta, serveClient5, {ReuseAddr},
udata = cast[pointer](addr counter)) udata = cast[pointer](addr counter))
@ -358,7 +425,7 @@ proc test5(): Future[int] {.async.} =
await server.join() await server.join()
proc test6(): Future[int] {.async.} = proc test6(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31347") var ta = initTAddress("127.0.0.1:31349")
var counter = ClientsCount var counter = ClientsCount
var server = createStreamServer(ta, serveClient6, {ReuseAddr}, var server = createStreamServer(ta, serveClient6, {ReuseAddr},
udata = cast[pointer](addr counter)) udata = cast[pointer](addr counter))
@ -366,6 +433,22 @@ proc test6(): Future[int] {.async.} =
result = await swarmManager6(ta) result = await swarmManager6(ta)
await server.join() await server.join()
proc test7(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31350")
var server = createStreamServer(ta, serveClient7, {ReuseAddr})
server.start()
result = await swarmWorker7(ta)
server.stop()
server.close()
proc test8(): Future[int] {.async.} =
var ta = initTAddress("127.0.0.1:31350")
var server = createStreamServer(ta, serveClient8, {ReuseAddr})
server.start()
result = await swarmWorker8(ta)
server.stop()
server.close()
when isMainModule: when isMainModule:
const const
m1 = "readLine() multiple clients with messages (" & $ClientsCount & m1 = "readLine() multiple clients with messages (" & $ClientsCount &
@ -379,17 +462,23 @@ when isMainModule:
" clients x " & $MessagesCount & " messages)" " clients x " & $MessagesCount & " messages)"
m6 = "write(seq[byte])/consume(int)/read(int) multiple clients (" & m6 = "write(seq[byte])/consume(int)/read(int) multiple clients (" &
$ClientsCount & " clients x " & $MessagesCount & " messages)" $ClientsCount & " clients x " & $MessagesCount & " messages)"
m7 = "readLine() buffer overflow test"
m8 = "readUntil() buffer overflow test"
suite "Stream Transport test suite": suite "Stream Transport test suite":
test m8:
check waitFor(test8()) == 1
test m7:
check waitFor(test7()) == 1
test m1: test m1:
check waitFor(test1()) == ClientsCount * MessagesCount check waitFor(test1()) == ClientsCount * MessagesCount
test m2: test m2:
check waitFor(test2()) == ClientsCount * MessagesCount check waitFor(test2()) == ClientsCount * MessagesCount
test m3: test m3:
check waitFor(test3()) == ClientsCount * MessagesCount check waitFor(test3()) == ClientsCount * MessagesCount
test m4:
check waitFor(test4()) == FilesCount
test m5: test m5:
check waitFor(test5()) == ClientsCount * MessagesCount check waitFor(test5()) == ClientsCount * MessagesCount
test m6: test m6:
check waitFor(test6()) == ClientsCount * MessagesCount check waitFor(test6()) == ClientsCount * MessagesCount
test m4:
check waitFor(test4()) == FilesCount