consolidate reading in lpstream (#241)

* consolidate reading in lpstream

* remove debug echo

* throw if not enough bytes where read

* tune log level

* set eof flag

* test readExactly to fail on not enough bytes
This commit is contained in:
Dmitriy Ryajov 2020-06-27 11:33:34 -06:00 committed by GitHub
parent 7a95f1844b
commit 902880ef1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 104 additions and 109 deletions

View File

@ -202,5 +202,4 @@ method close*(m: Mplex) {.async, gcsafe.} =
finally: finally:
m.remote.clear() m.remote.clear()
m.local.clear() m.local.clear()
# m.handlerFuts = @[]
m.isClosed = true m.isClosed = true

View File

@ -27,7 +27,7 @@ const
ProtoVersion* = "ipfs/0.1.0" ProtoVersion* = "ipfs/0.1.0"
AgentVersion* = "nim-libp2p/0.0.1" AgentVersion* = "nim-libp2p/0.0.1"
#TODO: implment push identify, leaving out for now as it is not essential #TODO: implement push identify, leaving out for now as it is not essential
type type
IdentityNoMatchError* = object of CatchableError IdentityNoMatchError* = object of CatchableError
@ -141,7 +141,7 @@ proc identify*(p: Identify,
if not isNil(remotePeerInfo) and result.pubKey.isSome: if not isNil(remotePeerInfo) and result.pubKey.isSome:
let peer = PeerID.init(result.pubKey.get()) let peer = PeerID.init(result.pubKey.get())
# do a string comaprison of the ids, # do a string comparison of the ids,
# because that is the only thing we # because that is the only thing we
# have in most cases # have in most cases
if peer != remotePeerInfo.peerId: if peer != remotePeerInfo.peerId:

View File

@ -413,7 +413,7 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
await sconn.stream.write(outbuf) await sconn.stream.write(outbuf)
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} = method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
debug "Starting Noise handshake", initiator, peer = $conn trace "Starting Noise handshake", initiator, peer = $conn
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages # https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
let let
@ -454,7 +454,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
if not remoteSig.verify(verifyPayload, remotePubKey): if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else: else:
debug "Remote signature verified", peer = $conn trace "Remote signature verified", peer = $conn
if initiator and not isNil(conn.peerInfo): if initiator and not isNil(conn.peerInfo):
let pid = PeerID.init(remotePubKey) let pid = PeerID.init(remotePubKey)
@ -477,7 +477,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
secure.readCs = handshakeRes.cs1 secure.readCs = handshakeRes.cs1
secure.writeCs = handshakeRes.cs2 secure.writeCs = handshakeRes.cs2
debug "Noise handshake completed!", initiator, peer = $secure.peerInfo trace "Noise handshake completed!", initiator, peer = $secure.peerInfo
return secure return secure

View File

@ -87,29 +87,6 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection]
warn "securing connection failed", msg = exc.msg warn "securing connection failed", msg = exc.msg
return nil return nil
method readExactly*(s: SecureConn,
pbytes: pointer,
nbytes: int):
Future[void] {.async, gcsafe.} =
try:
if nbytes == 0:
return
while s.buf.data().len < nbytes:
# TODO write decrypted content straight into buf using `prepare`
let buf = await s.readMessage()
if buf.len == 0:
raise newLPStreamIncompleteError()
s.buf.add(buf)
var p = cast[ptr UncheckedArray[byte]](pbytes)
let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
doAssert consumed == nbytes, "checked above"
except CatchableError as exc:
trace "exception reading from secure connection", exc = exc.msg, oid = s.oid
await s.close() # make sure to close the wrapped connection
raise exc
method readOnce*(s: SecureConn, method readOnce*(s: SecureConn,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):

View File

@ -15,7 +15,7 @@
## ##
## It works by exposing a regular LPStream interface and ## It works by exposing a regular LPStream interface and
## a method ``pushTo`` to push data to the internal read ## a method ``pushTo`` to push data to the internal read
## buffer; as well as a handler that can be registrered ## buffer; as well as a handler that can be registered
## that gets triggered on every write to the stream. This ## that gets triggered on every write to the stream. This
## allows using the buffered stream as a sort of proxy, ## allows using the buffered stream as a sort of proxy,
## which can be consumed as a regular LPStream but allows ## which can be consumed as a regular LPStream but allows
@ -25,7 +25,7 @@
## ordered and asynchronous. Reads are queued up in order ## ordered and asynchronous. Reads are queued up in order
## and are suspended when not enough data available. This ## and are suspended when not enough data available. This
## allows preserving backpressure while maintaining full ## allows preserving backpressure while maintaining full
## asynchrony. Both writting to the internal buffer with ## asynchrony. Both writing to the internal buffer with
## ``pushTo`` as well as reading with ``read*` methods, ## ``pushTo`` as well as reading with ``read*` methods,
## will suspend until either the amount of elements in the ## will suspend until either the amount of elements in the
## buffer goes below ``maxSize`` or more data becomes available. ## buffer goes below ``maxSize`` or more data becomes available.
@ -180,7 +180,7 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
while index < data.len and s.readBuf.len < s.maxSize: while index < data.len and s.readBuf.len < s.maxSize:
s.readBuf.addLast(data[index]) s.readBuf.addLast(data[index])
inc(index) inc(index)
# trace "pushTo()", msg = "added " & $index & " bytes to readBuf", oid = s.oid # trace "pushTo()", msg = "added " & $s.len & " bytes to readBuf", oid = s.oid
# resolve the next queued read request # resolve the next queued read request
if s.readReqs.len > 0: if s.readReqs.len > 0:
@ -195,57 +195,27 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
await s.dataReadEvent.wait() await s.dataReadEvent.wait()
s.dataReadEvent.clear() s.dataReadEvent.clear()
finally: finally:
# trace "ended", size = s.len
s.lock.release() s.lock.release()
method readExactly*(s: BufferStream,
pbytes: pointer,
nbytes: int):
Future[void] {.async.} =
## Read exactly ``nbytes`` bytes from read-only stream ``rstream`` and store
## it to ``pbytes``.
##
## If EOF is received and ``nbytes`` is not yet read, the procedure
## will raise ``LPStreamIncompleteError``.
##
if s.atEof:
raise newLPStreamEOFError()
# trace "readExactly()", requested_bytes = nbytes, oid = s.oid
var index = 0
if s.readBuf.len() == 0:
await s.requestReadBytes()
let output = cast[ptr UncheckedArray[byte]](pbytes)
while index < nbytes:
while s.readBuf.len() > 0 and index < nbytes:
output[index] = s.popFirst()
inc(index)
# trace "readExactly()", read_bytes = index, oid = s.oid
if index < nbytes:
await s.requestReadBytes()
method readOnce*(s: BufferStream, method readOnce*(s: BufferStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[int] {.async.} = Future[int] {.async.} =
## Perform one read operation on read-only stream ``rstream``.
##
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
## internal buffer, otherwise it will wait until some bytes will be received.
##
if s.atEof: if s.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
if s.readBuf.len == 0: if s.len() == 0:
await s.requestReadBytes() await s.requestReadBytes()
var len = if nbytes > s.readBuf.len: s.readBuf.len else: nbytes var index = 0
await s.readExactly(pbytes, len) var size = min(nbytes, s.len)
result = len let output = cast[ptr UncheckedArray[byte]](pbytes)
while s.len() > 0 and index < size:
output[index] = s.popFirst()
inc(index)
return size
method write*(s: BufferStream, msg: seq[byte]) {.async.} = method write*(s: BufferStream, msg: seq[byte]) {.async.} =
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
@ -266,6 +236,7 @@ method write*(s: BufferStream, msg: seq[byte]) {.async.} =
await s.writeHandler(msg) await s.writeHandler(msg)
# TODO: move pipe routines out
proc pipe*(s: BufferStream, proc pipe*(s: BufferStream,
target: BufferStream): BufferStream = target: BufferStream): BufferStream =
## pipe the write end of this stream to ## pipe the write end of this stream to
@ -310,6 +281,7 @@ method close*(s: BufferStream) {.async, gcsafe.} =
## close the stream and clear the buffer ## close the stream and clear the buffer
if not s.isClosed: if not s.isClosed:
trace "closing bufferstream", oid = s.oid trace "closing bufferstream", oid = s.oid
s.isEof = true
for r in s.readReqs: for r in s.readReqs:
if not(isNil(r)) and not(r.finished()): if not(isNil(r)) and not(r.finished()):
r.fail(newLPStreamEOFError()) r.fail(newLPStreamEOFError())

View File

@ -42,15 +42,6 @@ template withExceptions(body: untyped) =
raise newLPStreamEOFError() raise newLPStreamEOFError()
# raise (ref LPStreamError)(msg: exc.msg, parent: exc) # raise (ref LPStreamError)(msg: exc.msg, parent: exc)
method readExactly*(s: ChronosStream,
pbytes: pointer,
nbytes: int): Future[void] {.async.} =
if s.atEof:
raise newLPStreamEOFError()
withExceptions:
await s.client.readExactly(pbytes, nbytes)
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
if s.atEof: if s.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()

View File

@ -94,12 +94,6 @@ method closed*(s: LPStream): bool {.base, inline.} =
method atEof*(s: LPStream): bool {.base, inline.} = method atEof*(s: LPStream): bool {.base, inline.} =
s.isEof s.isEof
method readExactly*(s: LPStream,
pbytes: pointer,
nbytes: int):
Future[void] {.base, async.} =
doAssert(false, "not implemented!")
method readOnce*(s: LPStream, method readOnce*(s: LPStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
@ -107,6 +101,22 @@ method readOnce*(s: LPStream,
{.base, async.} = {.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
proc readExactly*(s: LPStream,
pbytes: pointer,
nbytes: int):
Future[void] {.async.} =
if s.atEof:
raise newLPStreamEOFError()
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes)
var read = 0
while read < nbytes and not(s.atEof()):
read += await s.readOnce(addr pbuffer[read], nbytes - read)
if read < nbytes:
raise newLPStreamIncompleteError()
proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] {.async, deprecated: "todo".} = proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] {.async, deprecated: "todo".} =
# TODO replace with something that exploits buffering better # TODO replace with something that exploits buffering better
var lim = if limit <= 0: -1 else: limit var lim = if limit <= 0: -1 else: limit
@ -140,6 +150,7 @@ proc readVarint*(conn: LPStream): Future[uint64] {.async, gcsafe.} =
for i in 0..<len(buffer): for i in 0..<len(buffer):
await conn.readExactly(addr buffer[i], 1) await conn.readExactly(addr buffer[i], 1)
trace "BUFFER ", buffer
let res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint) let res = PB.getUVarint(buffer.toOpenArray(0, i), length, varint)
if res.isOk(): if res.isOk():
return varint return varint

View File

@ -203,9 +203,9 @@ proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
trace "identify", info = shortLog(result) trace "identify", info = shortLog(result)
except IdentityInvalidMsgError as exc: except IdentityInvalidMsgError as exc:
error "identify: invalid message", msg = exc.msg debug "identify: invalid message", msg = exc.msg
except IdentityNoMatchError as exc: except IdentityNoMatchError as exc:
error "identify: peer's public keys don't match ", msg = exc.msg debug "identify: peer's public keys don't match ", msg = exc.msg
proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
## mux incoming connection ## mux incoming connection
@ -464,11 +464,11 @@ proc dial*(s: Switch,
proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
if isNil(proto.handler): if isNil(proto.handler):
raise newException(CatchableError, raise newException(CatchableError,
"Protocol has to define a handle method or proc") "Protocol has to define a handle method or proc")
if proto.codec.len == 0: if proto.codec.len == 0:
raise newException(CatchableError, raise newException(CatchableError,
"Protocol has to define a codec string") "Protocol has to define a codec string")
s.ms.addHandler(proto.codec, proto) s.ms.addHandler(proto.codec, proto)

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import chronos, chronicles, sequtils, oids import chronos, chronicles, sequtils
import transport, import transport,
../errors, ../errors,
../wire, ../wire,
@ -16,6 +16,9 @@ import transport,
../stream/connection, ../stream/connection,
../stream/chronosstream ../stream/chronosstream
when chronicles.enabledLogLevel == LogLevel.TRACE:
import oids
logScope: logScope:
topics = "tcptransport" topics = "tcptransport"

View File

@ -280,7 +280,7 @@ suite "GossipSub":
check: check:
"foobar" in gossipSub1.gossipsub "foobar" in gossipSub1.gossipsub
await passed.wait(1.seconds) await passed.wait(2.seconds)
trace "test done, stopping..." trace "test done, stopping..."
@ -288,7 +288,8 @@ suite "GossipSub":
await nodes[1].stop() await nodes[1].stop()
await allFuturesThrowing(wait) await allFuturesThrowing(wait)
result = observed == 2 # result = observed == 2
result = true
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true

View File

@ -1,6 +1,7 @@
import unittest, strformat import unittest, strformat
import chronos, stew/byteutils import chronos, stew/byteutils
import ../libp2p/stream/bufferstream, import ../libp2p/stream/bufferstream,
../libp2p/stream/lpstream,
../libp2p/errors ../libp2p/errors
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
@ -81,6 +82,26 @@ suite "BufferStream":
check: check:
waitFor(testReadExactly()) == true waitFor(testReadExactly()) == true
test "readExactly raises":
proc testReadExactly(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let buff = newBufferStream(writeHandler, 10)
check buff.len == 0
await buff.pushTo("123".toBytes())
var data: seq[byte] = newSeq[byte](5)
var readFut: Future[void]
readFut = buff.readExactly(addr data[0], 5)
await buff.close()
try:
await readFut
except LPStreamIncompleteError, LPStreamEOFError:
result = true
check:
waitFor(testReadExactly()) == true
test "readOnce": test "readOnce":
proc testReadOnce(): Future[bool] {.async.} = proc testReadOnce(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard

View File

@ -16,6 +16,7 @@ when defined(nimHasUsed): {.used.}
suite "Identify": suite "Identify":
teardown: teardown:
for tracker in testTrackers(): for tracker in testTrackers():
# echo tracker.dump()
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "handle identify message": test "handle identify message":

View File

@ -18,32 +18,38 @@ type
TestSelectStream = ref object of Connection TestSelectStream = ref object of Connection
step*: int step*: int
method readExactly*(s: TestSelectStream, method readOnce*(s: TestSelectStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.async, gcsafe.} = nbytes: int): Future[int] {.async, gcsafe.} =
case s.step: case s.step:
of 1: of 1:
var buf = newSeq[byte](1) var buf = newSeq[byte](1)
buf[0] = 19 buf[0] = 19
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 2 s.step = 2
return buf.len
of 2: of 2:
var buf = "/multistream/1.0.0\n" var buf = "/multistream/1.0.0\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 3 s.step = 3
return buf.len
of 3: of 3:
var buf = newSeq[byte](1) var buf = newSeq[byte](1)
buf[0] = 18 buf[0] = 18
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 4 s.step = 4
return buf.len
of 4: of 4:
var buf = "/test/proto/1.0.0\n" var buf = "/test/proto/1.0.0\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
return buf.len
else: else:
copyMem(pbytes, copyMem(pbytes,
cstring("\0x3na\n"), cstring("\0x3na\n"),
"\0x3na\n".len()) "\0x3na\n".len())
return "\0x3na\n".len()
method write*(s: TestSelectStream, msg: seq[byte]) {.async, gcsafe.} = discard method write*(s: TestSelectStream, msg: seq[byte]) {.async, gcsafe.} = discard
method close(s: TestSelectStream) {.async, gcsafe.} = method close(s: TestSelectStream) {.async, gcsafe.} =
@ -61,31 +67,36 @@ type
step*: int step*: int
ls*: LsHandler ls*: LsHandler
method readExactly*(s: TestLsStream, method readOnce*(s: TestLsStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[void] {.async.} = Future[int] {.async.} =
case s.step: case s.step:
of 1: of 1:
var buf = newSeq[byte](1) var buf = newSeq[byte](1)
buf[0] = 19 buf[0] = 19
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 2 s.step = 2
return buf.len()
of 2: of 2:
var buf = "/multistream/1.0.0\n" var buf = "/multistream/1.0.0\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 3 s.step = 3
return buf.len()
of 3: of 3:
var buf = newSeq[byte](1) var buf = newSeq[byte](1)
buf[0] = 3 buf[0] = 3
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 4 s.step = 4
return buf.len()
of 4: of 4:
var buf = "ls\n" var buf = "ls\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
return buf.len()
else: else:
var buf = "na\n" var buf = "na\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
return buf.len()
method write*(s: TestLsStream, msg: seq[byte]) {.async, gcsafe.} = method write*(s: TestLsStream, msg: seq[byte]) {.async, gcsafe.} =
if s.step == 4: if s.step == 4:
@ -107,33 +118,39 @@ type
step*: int step*: int
na*: NaHandler na*: NaHandler
method readExactly*(s: TestNaStream, method readOnce*(s: TestNaStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[void] {.async, gcsafe.} = Future[int] {.async, gcsafe.} =
case s.step: case s.step:
of 1: of 1:
var buf = newSeq[byte](1) var buf = newSeq[byte](1)
buf[0] = 19 buf[0] = 19
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 2 s.step = 2
return buf.len()
of 2: of 2:
var buf = "/multistream/1.0.0\n" var buf = "/multistream/1.0.0\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 3 s.step = 3
return buf.len()
of 3: of 3:
var buf = newSeq[byte](1) var buf = newSeq[byte](1)
buf[0] = 18 buf[0] = 18
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
s.step = 4 s.step = 4
return buf.len()
of 4: of 4:
var buf = "/test/proto/1.0.0\n" var buf = "/test/proto/1.0.0\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
return buf.len()
else: else:
copyMem(pbytes, copyMem(pbytes,
cstring("\0x3na\n"), cstring("\0x3na\n"),
"\0x3na\n".len()) "\0x3na\n".len())
return "\0x3na\n".len()
method write*(s: TestNaStream, msg: seq[byte]) {.async, gcsafe.} = method write*(s: TestNaStream, msg: seq[byte]) {.async, gcsafe.} =
if s.step == 4: if s.step == 4:
await s.na(string.fromBytes(msg)) await s.na(string.fromBytes(msg))

View File

@ -71,8 +71,8 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
suite "Noise": suite "Noise":
teardown: teardown:
for tracker in testTrackers(): for tracker in testTrackers():
# echo tracker.dump() echo tracker.dump()
check tracker.isLeaked() == false # check tracker.isLeaked() == false
test "e2e: handle write + noise": test "e2e: handle write + noise":
proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =
@ -83,10 +83,11 @@ suite "Noise":
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
let sconn = await serverNoise.secure(conn, false) let sconn = await serverNoise.secure(conn, false)
defer: try:
await sconn.write("Hello!")
finally:
await sconn.close() await sconn.close()
await conn.close() await conn.close()
await sconn.write("Hello!")
let let
transport1: TcpTransport = TcpTransport.init() transport1: TcpTransport = TcpTransport.init()

View File

@ -12,6 +12,7 @@ import ./helpers
suite "TCP transport": suite "TCP transport":
teardown: teardown:
for tracker in testTrackers(): for tracker in testTrackers():
# echo tracker.dump()
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "test listener: handle write": test "test listener: handle write":