diff --git a/libp2p/connection.nim b/libp2p/connection.nim index ea2baec10..d30f223dc 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -109,7 +109,7 @@ method readExactly*(s: Connection, pbytes: pointer, nbytes: int): Future[void] {.gcsafe.} = - s.stream.readExactly(pbytes, nbytes) + s.stream.readExactly(pbytes, nbytes) method readOnce*(s: Connection, pbytes: pointer, @@ -118,10 +118,9 @@ method readOnce*(s: Connection, s.stream.readOnce(pbytes, nbytes) method write*(s: Connection, - msg: seq[byte], - msglen = -1): + msg: seq[byte]): Future[void] {.gcsafe.} = - s.stream.write(msg, msglen) + s.stream.write(msg) method closed*(s: Connection): bool = if isNil(s.stream): diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index ea21e155d..8aeff1f6d 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -161,6 +161,6 @@ template writePrefix: untyped = if s.isLazy and not s.isOpen: await s.open() -method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} = +method write*(s: LPChannel, msg: seq[byte]) {.async.} = writePrefix() - await procCall write(BufferStream(s), msg, msglen) + await procCall write(BufferStream(s), msg) diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 6f6f551dd..4e71ac264 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -267,11 +267,12 @@ template read_s: untyped = proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} = var besize: array[2, byte] - await sconn.readExactly(addr besize[0], 2) + await sconn.stream.readExactly(addr besize[0], besize.len) let size = uint16.fromBytesBE(besize).int trace "receiveHSMessage", size var buffer = newSeq[byte](size) - await sconn.readExactly(addr buffer[0], size) + if buffer.len > 0: + await sconn.stream.readExactly(addr buffer[0], buffer.len) return buffer proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = @@ -416,25 +417,25 @@ proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Futu let (cs1, cs2) = hs.ss.split() return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) -method readMessage(sconn: NoiseConnection): Future[seq[byte]] {.async.} = - try: +method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = + while true: # Discard 0-length payloads var besize: array[2, byte] - await sconn.readExactly(addr besize[0], 2) - let size = uint16.fromBytesBE(besize).int + await sconn.stream.readExactly(addr besize[0], besize.len) + let size = uint16.fromBytesBE(besize).int # Cannot overflow trace "receiveEncryptedMessage", size, peer = $sconn.peerInfo - if size == 0: - return @[] - var buffer = newSeq[byte](size) - await sconn.readExactly(addr buffer[0], size) - var plain = sconn.readCs.decryptWithAd([], buffer) - unpackNoisePayload(plain) - return plain - except LPStreamIncompleteError: - trace "Connection dropped while reading" - except LPStreamReadError: - trace "Error reading from connection" + if size > 0: + var buffer = newSeq[byte](size) + await sconn.stream.readExactly(addr buffer[0], buffer.len) + var plain = sconn.readCs.decryptWithAd([], buffer) + unpackNoisePayload(plain) + return plain + else: + trace "Received 0-length message", conn = $conn + +method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} = + if message.len == 0: + return -method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} = try: var left = message.len @@ -453,7 +454,7 @@ method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {. trace "sendEncryptedMessage", size = lesize, peer = $sconn.peerInfo, left, offset outbuf &= besize outbuf &= cipher - await sconn.write(outbuf) + await sconn.stream.write(outbuf) except AsyncStreamWriteError: trace "Could not write to connection" @@ -520,15 +521,6 @@ method init*(p: Noise) {.gcsafe.} = procCall Secure(p).init() p.codec = NoiseCodec -proc secure*(p: Noise, conn: Connection): Future[Connection] {.async, gcsafe.} = - trace "Noise.secure called", initiator=p.outgoing - try: - result = await p.handleConn(conn, p.outgoing) - except CatchableError as exc: - warn "securing connection failed", msg = exc.msg - if not conn.closed(): - await conn.close() - proc newNoise*(privateKey: PrivateKey; outgoing: bool = true; commonPrologue: seq[byte] = @[]): Noise = new result result.outgoing = outgoing diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index 77c9b2a3e..44e155384 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -6,7 +6,7 @@ ## at your option. ## This file may not be copied, modified, or distributed except according to ## those terms. -import chronos, chronicles, oids +import chronos, chronicles, oids, stew/endians2 import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode] import secure, ../../connection, @@ -174,36 +174,44 @@ proc macCheckAndDecode(sconn: SecioConn, data: var seq[byte]): bool = data.setLen(mark) result = true -method readMessage(sconn: SecioConn): Future[seq[byte]] {.async.} = +proc readRawMessage(conn: Connection): Future[seq[byte]] {.async.} = + while true: # Discard 0-length payloads + var lengthBuf: array[4, byte] + await conn.stream.readExactly(addr lengthBuf[0], lengthBuf.len) + let length = uint32.fromBytesBE(lengthBuf) + + trace "Recieved message header", header = lengthBuf.shortLog, length = length + + if length > SecioMaxMessageSize: # Verify length before casting! + trace "Received size of message exceed limits", conn = $conn, length = length + raise (ref SecioError)(msg: "Message exceeds maximum length") + + if length > 0: + var buf = newSeq[byte](int(length)) + await conn.stream.readExactly(addr buf[0], buf.len) + trace "Received message body", + conn = $conn, length = buf.len, buff = buf.shortLog + return buf + + trace "Discarding 0-length payload", conn = $conn + +method readMessage*(sconn: SecioConn): Future[seq[byte]] {.async.} = ## Read message from channel secure connection ``sconn``. when chronicles.enabledLogLevel == LogLevel.TRACE: logScope: stream_oid = $sconn.stream.oid - try: - var buf = newSeq[byte](4) - await sconn.readExactly(addr buf[0], 4) - let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or - (int(buf[2]) shl 8) or (int(buf[3])) - trace "Received message header", header = buf.shortLog, length = length - if length <= SecioMaxMessageSize: - buf.setLen(length) - await sconn.readExactly(addr buf[0], length) - trace "Received message body", length = length, - buffer = buf.shortLog - if sconn.macCheckAndDecode(buf): - result = buf - else: - trace "Message MAC verification failed", buf = buf.shortLog - else: - trace "Received message header size is more then allowed", - length = length, allowed_length = SecioMaxMessageSize - except LPStreamIncompleteError: - trace "Connection dropped while reading" - except LPStreamReadError: - trace "Error reading from connection" + var buf = await sconn.readRawMessage() + if sconn.macCheckAndDecode(buf): + result = buf + else: + trace "Message MAC verification failed", buf = buf.shortLog + raise (ref SecioError)(msg: "message failed MAC verification") -method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} = +method write*(sconn: SecioConn, message: seq[byte]) {.async.} = ## Write message ``message`` to secure connection ``sconn``. + if message.len == 0: + return + try: var left = message.len @@ -211,8 +219,12 @@ method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} = while left > 0: let chunkSize = if left > SecioMaxMessageSize - 64: SecioMaxMessageSize - 64 else: left - let macsize = sconn.writerMac.sizeDigest() + macsize = sconn.writerMac.sizeDigest() + length = chunkSize + macsize + var msg = newSeq[byte](chunkSize + 4 + macsize) + msg[0..<4] = uint32(length).toBytesBE() + sconn.writerCoder.encrypt(message.toOpenArray(offset, offset + chunkSize - 1), msg.toOpenArray(4, 4 + chunkSize - 1)) left = left - chunkSize @@ -221,13 +233,9 @@ method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} = sconn.writerMac.update(msg.toOpenArray(4, 4 + chunkSize - 1)) sconn.writerMac.finish(msg.toOpenArray(mo, mo + macsize - 1)) sconn.writerMac.reset() - let length = chunkSize + macsize - msg[0] = byte((length shr 24) and 0xFF) - msg[1] = byte((length shr 16) and 0xFF) - msg[2] = byte((length shr 8) and 0xFF) - msg[3] = byte(length and 0xFF) + trace "Writing message", message = msg.shortLog, left, offset - await sconn.write(msg) + await sconn.stream.write(msg) except AsyncStreamWriteError: trace "Could not write to connection" @@ -269,30 +277,9 @@ proc newSecioConn(conn: Connection, proc transactMessage(conn: Connection, msg: seq[byte]): Future[seq[byte]] {.async.} = - var buf = newSeq[byte](4) - try: - trace "Sending message", message = msg.shortLog, length = len(msg) - await conn.write(msg) - await conn.readExactly(addr buf[0], 4) - let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or - (int(buf[2]) shl 8) or (int(buf[3])) - trace "Recieved message header", header = buf.shortLog, length = length - if length <= SecioMaxMessageSize: - buf.setLen(length) - await conn.readExactly(addr buf[0], length) - trace "Received message body", conn = $conn, - length = length, - buff = buf.shortLog - result = buf - else: - trace "Received size of message exceed limits", conn = $conn, - length = length - except LPStreamIncompleteError: - trace "Connection dropped while reading", conn = $conn - except LPStreamReadError: - trace "Error reading from connection", conn = $conn - except LPStreamWriteError: - trace "Could not write to connection", conn = $conn + trace "Sending message", message = msg.shortLog, length = len(msg) + await conn.write(msg) + return await conn.readRawMessage() method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} = var @@ -312,7 +299,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S localBytesPubkey = s.localPublicKey.getBytes() if randomBytes(localNonce) != SecioNonceSize: - raise newException(CatchableError, "Could not generate random data") + raise (ref SecioError)(msg: "Could not generate random data") var request = createProposal(localNonce, localBytesPubkey, @@ -332,16 +319,16 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S if len(answer) == 0: trace "Proposal exchange failed", conn = $conn - raise newException(SecioError, "Proposal exchange failed") + raise (ref SecioError)(msg: "Proposal exchange failed") if not decodeProposal(answer, remoteNonce, remoteBytesPubkey, remoteExchanges, remoteCiphers, remoteHashes): trace "Remote proposal decoding failed", conn = $conn - raise newException(SecioError, "Remote proposal decoding failed") + raise (ref SecioError)(msg: "Remote proposal decoding failed") if not remotePubkey.init(remoteBytesPubkey): trace "Remote public key incorrect or corrupted", pubkey = remoteBytesPubkey.shortLog - raise newException(SecioError, "Remote public key incorrect or corrupted") + raise (ref SecioError)(msg: "Remote public key incorrect or corrupted") remotePeerId = PeerID.init(remotePubkey) @@ -358,7 +345,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S let hash = selectBest(order, SecioHashes, remoteHashes) if len(scheme) == 0 or len(cipher) == 0 or len(hash) == 0: trace "No algorithms in common", peer = remotePeerId - raise newException(SecioError, "No algorithms in common") + raise (ref SecioError)(msg: "No algorithms in common") trace "Encryption scheme selected", scheme = scheme, cipher = cipher, hash = hash @@ -373,15 +360,15 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S var remoteExchange = await transactMessage(conn, localExchange) if len(remoteExchange) == 0: trace "Corpus exchange failed", conn = $conn - raise newException(SecioError, "Corpus exchange failed") + raise (ref SecioError)(msg: "Corpus exchange failed") if not decodeExchange(remoteExchange, remoteEBytesPubkey, remoteEBytesSig): trace "Remote exchange decoding failed", conn = $conn - raise newException(SecioError, "Remote exchange decoding failed") + raise (ref SecioError)(msg: "Remote exchange decoding failed") if not remoteESignature.init(remoteEBytesSig): trace "Remote signature incorrect or corrupted", signature = remoteEBytesSig.shortLog - raise newException(SecioError, "Remote signature incorrect or corrupted") + raise (ref SecioError)(msg: "Remote signature incorrect or corrupted") var remoteCorpus = answer & request[4..^1] & remoteEBytesPubkey if not remoteESignature.verify(remoteCorpus, remotePubkey): @@ -389,21 +376,21 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S signature = $remoteESignature, pubkey = $remotePubkey, corpus = $remoteCorpus - raise newException(SecioError, "Signature verification failed") + raise (ref SecioError)(msg: "Signature verification failed") trace "Signature verified", scheme = remotePubkey.scheme if not remoteEPubkey.eckey.initRaw(remoteEBytesPubkey): trace "Remote ephemeral public key incorrect or corrupted", pubkey = toHex(remoteEBytesPubkey) - raise newException(SecioError, "Remote ephemeral public key incorrect or corrupted") + raise (ref SecioError)(msg: "Remote ephemeral public key incorrect or corrupted") var secret = getSecret(remoteEPubkey, ekeypair.seckey) if len(secret) == 0: trace "Shared secret could not be created", pubkeyScheme = remoteEPubkey.scheme, seckeyScheme = ekeypair.seckey.scheme - raise newException(SecioError, "Shared secret could not be created") + raise (ref SecioError)(msg: "Shared secret could not be created") trace "Shared secret calculated", secret = secret.shortLog @@ -419,13 +406,13 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S var secioConn = newSecioConn(conn, hash, cipher, keys, order, remotePubkey) result = secioConn - await secioConn.writeMessage(remoteNonce) + await secioConn.write(remoteNonce) var res = await secioConn.readMessage() if res != @localNonce: trace "Nonce verification failed", receivedNonce = res.shortLog, localNonce = localNonce.shortLog - raise newException(CatchableError, "Nonce verification failed") + raise (ref SecioError)(msg: "Nonce verification failed") else: trace "Secure handshake succeeded" diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 15446c94a..83b234bbf 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -10,57 +10,28 @@ import options import chronos, chronicles import ../protocol, - ../../stream/bufferstream, + ../../stream/streamseq, ../../connection, - ../../peerinfo, - ../../utility + ../../peerinfo type Secure* = ref object of LPProtocol # base type for secure managers SecureConn* = ref object of Connection + buf: StreamSeq method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = doAssert(false, "Not implemented!") -method writeMessage*(c: SecureConn, data: seq[byte]) {.async, base.} = - doAssert(false, "Not implemented!") - method handshake(s: Secure, conn: Connection, initiator: bool): Future[SecureConn] {.async, base.} = doAssert(false, "Not implemented!") -proc readLoop(sconn: SecureConn, conn: Connection) {.async.} = - try: - let stream = BufferStream(conn.stream) - while not sconn.closed: - let msg = await sconn.readMessage() - if msg.len == 0: - trace "stream EOF" - return - - await stream.pushTo(msg) - except CatchableError as exc: - trace "Exception occurred Secure.readLoop", exc = exc.msg - finally: - trace "closing conn", closed = conn.closed() - if not conn.closed: - await conn.close() - - trace "closing sconn", closed = sconn.closed() - if not sconn.closed: - await sconn.close() - trace "ending Secure readLoop" - proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn, initiator) - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = - trace "sending encrypted bytes", bytes = data.shortLog - await sconn.writeMessage(data) - result = newConnection(newBufferStream(writeHandler)) - conn.readLoops &= readLoop(sconn, result) + result = sconn if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) @@ -86,3 +57,37 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection] warn "securing connection failed", msg = exc.msg if not conn.closed(): await conn.close() + +method readExactly*(s: SecureConn, + pbytes: pointer, + nbytes: int): + Future[void] {.async, gcsafe.} = + if nbytes == 0: + return + + while s.buf.data().len < nbytes: + # TODO write decrypted content straight into buf using `prepare` + let buf = await s.readMessage() + if buf.len == 0: + raise newLPStreamIncompleteError() + s.buf.add(buf) + + var p = cast[ptr UncheckedArray[byte]](pbytes) + let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) + doAssert consumed == nbytes, "checked above" + +method readOnce*(s: SecureConn, + pbytes: pointer, + nbytes: int): + Future[int] {.async, gcsafe.} = + if nbytes == 0: + return 0 + + if s.buf.data().len() == 0: + let buf = await s.readMessage() + if buf.len == 0: + raise newLPStreamIncompleteError() + s.buf.add(buf) + + var p = cast[ptr UncheckedArray[byte]](pbytes) + return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 1d21c6d6d..6437b7bd8 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -216,9 +216,7 @@ method readOnce*(s: BufferStream, await s.readExactly(pbytes, len) result = len -method write*(s: BufferStream, - msg: seq[byte], - msglen = -1): Future[void] = +method write*(s: BufferStream, msg: seq[byte]): Future[void] = ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## stream ``wstream``. ## @@ -233,7 +231,7 @@ method write*(s: BufferStream, retFuture.fail(newNotWritableError()) return retFuture - result = s.writeHandler(if msglen >= 0: msg[0..= 0: msglen else: msg.len - s.write(@(toOpenArrayByte(msg, 0, nbytes - 1))) +proc write*(s: LPStream, msg: string): Future[void] = + s.write(@(toOpenArrayByte(msg, 0, msg.high))) method close*(s: LPStream) {.base, async.} = diff --git a/libp2p/stream/streamseq.nim b/libp2p/stream/streamseq.nim new file mode 100644 index 000000000..bbb44aeba --- /dev/null +++ b/libp2p/stream/streamseq.nim @@ -0,0 +1,73 @@ +import stew/bitops2 + +type + StreamSeq* = object + # Seq adapted to the stream use case where we add data at the back and + # consume at the front in chunks. A bit like a deque but contiguous memory + # area - will try to avoid moving data unless it has to, subject to buffer + # space. The assumption is that data is typically consumed fully. + # + # See also asio::stream_buf + + buf: seq[byte] # Data store + rpos: int # Reading position - valid data starts here + wpos: int # Writing position - valid data ends here + +template len*(v: StreamSeq): int = + v.wpos - v.rpos + +func grow(v: var StreamSeq, n: int) = + if v.rpos == v.wpos: + # All data has been consumed, reset positions + v.rpos = 0 + v.wpos = 0 + + if v.buf.len - v.wpos < n: + if v.rpos > 0: + # We've consumed some data so we'll try to move that data to the beginning + # of the buffer, hoping that this will clear up enough capacity to avoid + # reallocation + moveMem(addr v.buf[0], addr v.buf[v.rpos], v.wpos - v.rpos) + v.wpos -= v.rpos + v.rpos = 0 + + if v.buf.len - v.wpos >= n: + return + + # TODO this is inefficient - `setLen` will copy all data of buf, even though + # we know that only a part of it contains "valid" data + v.buf.setLen(nextPow2(max(64, v.wpos + n).uint64).int) + +template prepare*(v: var StreamSeq, n: int): var openArray[byte] = + ## Return a buffer that is at least `n` bytes long + mixin grow + v.grow(n) + + v.buf.toOpenArray(v.wpos, v.buf.len - 1) + +template commit*(v: var StreamSeq, n: int) = + ## Mark `n` bytes in the buffer returned by `prepare` as ready for reading + v.wpos += n + +func add*(v: var StreamSeq, data: openArray[byte]) = + ## Add data - the equivalent of `buf.prepare(n) = data; buf.commit(n)` + if data.len > 0: + v.grow(data.len) + copyMem(addr v.buf[v.wpos], unsafeAddr data[0], data.len) + v.commit(data.len) + +template data*(v: StreamSeq): openArray[byte] = + # Data that is ready to be consumed + # TODO a double-hash comment here breaks compile (!) + v.buf.toOpenArray(v.rpos, v.wpos - 1) + +func consume*(v: var StreamSeq, n: int) = + ## Mark `n` bytes that were returned via `data` as consumed + v.rpos += n + +func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int = + let bytes = min(buf.len, v.len) + if bytes > 0: + copyMem(addr buf[0], addr v.buf[v.rpos], bytes) + v.consume(bytes) + bytes diff --git a/tests/testbufferstream.nim b/tests/testbufferstream.nim index d2971b3d9..d5ca1d6b6 100644 --- a/tests/testbufferstream.nim +++ b/tests/testbufferstream.nim @@ -149,7 +149,7 @@ suite "BufferStream": let buff = newBufferStream(writeHandler, 10) check buff.len == 0 - await buff.write("Hello!", 6) + await buff.write("Hello!") result = true @@ -166,7 +166,7 @@ suite "BufferStream": let buff = newBufferStream(writeHandler, 10) check buff.len == 0 - await buff.write(cast[seq[byte]]("Hello!"), 6) + await buff.write(cast[seq[byte]]("Hello!")) result = true diff --git a/tests/testinterop.nim b/tests/testinterop.nim index f519f6397..9c4c98d99 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -48,21 +48,17 @@ proc readLp*(s: StreamTransport): Future[seq[byte]] {.async, gcsafe.} = length: int res: VarintStatus result = newSeq[byte](10) - try: - for i in 0.. 0.uint: - await s.readExactly(addr result[0], int(size)) - except TransportIncompleteError as exc: - trace "remote connection ended unexpectedly", exc = exc.msg - except TransportError as exc: - trace "unable to read from remote connection", exc = exc.msg + + for i in 0.. 0.uint: + await s.readExactly(addr result[0], int(size)) proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey), address: string = "/ip4/127.0.0.1/tcp/0", diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 5fdfbe0f2..5b20ad2b9 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -48,8 +48,7 @@ method readExactly*(s: TestSelectStream, cstring("\0x3na\n"), "\0x3na\n".len()) -method write*(s: TestSelectStream, msg: seq[byte], msglen = -1) - {.async, gcsafe.} = discard +method write*(s: TestSelectStream, msg: seq[byte]) {.async, gcsafe.} = discard method close(s: TestSelectStream) {.async, gcsafe.} = s.isClosed = true @@ -92,7 +91,7 @@ method readExactly*(s: TestLsStream, var buf = "na\n" copyMem(pbytes, addr buf[0], buf.len()) -method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} = +method write*(s: TestLsStream, msg: seq[byte]) {.async, gcsafe.} = if s.step == 4: await s.ls(msg) @@ -139,7 +138,7 @@ method readExactly*(s: TestNaStream, cstring("\0x3na\n"), "\0x3na\n".len()) -method write*(s: TestNaStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} = +method write*(s: TestNaStream, msg: seq[byte]) {.async, gcsafe.} = if s.step == 4: await s.na(string.fromBytes(msg)) diff --git a/tests/testnative.nim b/tests/testnative.nim index bb4147f48..53e4a244e 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -1,4 +1,5 @@ -import testvarint +import testvarint, + teststreamseq import testrsa, testecnist, diff --git a/tests/testnoise.nim b/tests/testnoise.nim index 965c7aaad..1b2cda0c5 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -93,7 +93,7 @@ suite "Noise": serverNoise = newNoise(serverInfo.privateKey, outgoing = false) proc connHandler(conn: Connection) {.async, gcsafe.} = - let sconn = await serverNoise.secure(conn) + let sconn = await serverNoise.secure(conn, false) defer: await sconn.close() await conn.close() @@ -108,7 +108,7 @@ suite "Noise": clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientNoise = newNoise(clientInfo.privateKey, outgoing = true) conn = await transport2.dial(transport1.ma) - sconn = await clientNoise.secure(conn) + sconn = await clientNoise.secure(conn, true) msg = await sconn.read(6) @@ -131,7 +131,7 @@ suite "Noise": readTask = newFuture[void]() proc connHandler(conn: Connection) {.async, gcsafe.} = - let sconn = await serverNoise.secure(conn) + let sconn = await serverNoise.secure(conn, false) defer: await sconn.close() await conn.close() @@ -148,7 +148,7 @@ suite "Noise": clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientNoise = newNoise(clientInfo.privateKey, outgoing = true) conn = await transport2.dial(transport1.ma) - sconn = await clientNoise.secure(conn) + sconn = await clientNoise.secure(conn, true) await sconn.write("Hello!".cstring, 6) await readTask @@ -175,7 +175,7 @@ suite "Noise": trace "Sending huge payload", size = hugePayload.len proc connHandler(conn: Connection) {.async, gcsafe.} = - let sconn = await serverNoise.secure(conn) + let sconn = await serverNoise.secure(conn, false) defer: await sconn.close() let msg = await sconn.readLp() @@ -191,7 +191,7 @@ suite "Noise": clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientNoise = newNoise(clientInfo.privateKey, outgoing = true) conn = await transport2.dial(transport1.ma) - sconn = await clientNoise.secure(conn) + sconn = await clientNoise.secure(conn, true) await sconn.writeLp(hugePayload) await readTask diff --git a/tests/teststreamseq.nim b/tests/teststreamseq.nim new file mode 100644 index 000000000..51932fce8 --- /dev/null +++ b/tests/teststreamseq.nim @@ -0,0 +1,55 @@ +{.used.} + +import unittest +import stew/byteutils +import ../libp2p/stream/streamseq + +suite "StreamSeq": + test "basics": + var s: StreamSeq + + check: + s.data().len == 0 + + s.add([byte 0, 1, 2, 3]) + + check: + @(s.data()) == [byte 0, 1, 2, 3] + + s.prepare(10)[0..<3] = [byte 4, 5, 6] + + check: + @(s.data()) == [byte 0, 1, 2, 3] + + s.commit(3) + + check: + @(s.data()) == [byte 0, 1, 2, 3, 4, 5, 6] + + s.consume(1) + + check: + @(s.data()) == [byte 1, 2, 3, 4, 5, 6] + + s.consume(6) + + check: @(s.data()) == [] + + s.add([]) + check: @(s.data()) == [] + + var o: seq[byte] + + check: 0 == s.consumeTo(o) + + s.add([byte 1, 2, 3]) + + o.setLen(2) + o.setLen(s.consumeTo(o)) + check: + o == [byte 1, 2] + + o.setLen(s.consumeTo(o)) + + check: + o == [byte 3]