remove readLoop in secure protocols (#162)

* remove readLoop in secure protocols, fix security issues

* fix Defect on remote sending 0-byte noise/secio message
* remove msglen from `write` (unused)
* simplify SecureConn data flow
* document some control-flow issues

* unify exception behaviour across noise and secio

* secio would not raise on mac/decryption errors

* fix compile error
This commit is contained in:
Jacek Sieka 2020-05-07 22:37:46 +02:00 committed by GitHub
parent 330da51819
commit 1efada474c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 288 additions and 182 deletions

View File

@ -109,7 +109,7 @@ method readExactly*(s: Connection,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[void] {.gcsafe.} = Future[void] {.gcsafe.} =
s.stream.readExactly(pbytes, nbytes) s.stream.readExactly(pbytes, nbytes)
method readOnce*(s: Connection, method readOnce*(s: Connection,
pbytes: pointer, pbytes: pointer,
@ -118,10 +118,9 @@ method readOnce*(s: Connection,
s.stream.readOnce(pbytes, nbytes) s.stream.readOnce(pbytes, nbytes)
method write*(s: Connection, method write*(s: Connection,
msg: seq[byte], msg: seq[byte]):
msglen = -1):
Future[void] {.gcsafe.} = Future[void] {.gcsafe.} =
s.stream.write(msg, msglen) s.stream.write(msg)
method closed*(s: Connection): bool = method closed*(s: Connection): bool =
if isNil(s.stream): if isNil(s.stream):

View File

@ -161,6 +161,6 @@ template writePrefix: untyped =
if s.isLazy and not s.isOpen: if s.isLazy and not s.isOpen:
await s.open() await s.open()
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} = method write*(s: LPChannel, msg: seq[byte]) {.async.} =
writePrefix() writePrefix()
await procCall write(BufferStream(s), msg, msglen) await procCall write(BufferStream(s), msg)

View File

@ -267,11 +267,12 @@ template read_s: untyped =
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} = proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
var besize: array[2, byte] var besize: array[2, byte]
await sconn.readExactly(addr besize[0], 2) await sconn.stream.readExactly(addr besize[0], besize.len)
let size = uint16.fromBytesBE(besize).int let size = uint16.fromBytesBE(besize).int
trace "receiveHSMessage", size trace "receiveHSMessage", size
var buffer = newSeq[byte](size) var buffer = newSeq[byte](size)
await sconn.readExactly(addr buffer[0], size) if buffer.len > 0:
await sconn.stream.readExactly(addr buffer[0], buffer.len)
return buffer return buffer
proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
@ -416,25 +417,25 @@ proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Futu
let (cs1, cs2) = hs.ss.split() let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
method readMessage(sconn: NoiseConnection): Future[seq[byte]] {.async.} = method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
try: while true: # Discard 0-length payloads
var besize: array[2, byte] var besize: array[2, byte]
await sconn.readExactly(addr besize[0], 2) await sconn.stream.readExactly(addr besize[0], besize.len)
let size = uint16.fromBytesBE(besize).int let size = uint16.fromBytesBE(besize).int # Cannot overflow
trace "receiveEncryptedMessage", size, peer = $sconn.peerInfo trace "receiveEncryptedMessage", size, peer = $sconn.peerInfo
if size == 0: if size > 0:
return @[] var buffer = newSeq[byte](size)
var buffer = newSeq[byte](size) await sconn.stream.readExactly(addr buffer[0], buffer.len)
await sconn.readExactly(addr buffer[0], size) var plain = sconn.readCs.decryptWithAd([], buffer)
var plain = sconn.readCs.decryptWithAd([], buffer) unpackNoisePayload(plain)
unpackNoisePayload(plain) return plain
return plain else:
except LPStreamIncompleteError: trace "Received 0-length message", conn = $conn
trace "Connection dropped while reading"
except LPStreamReadError: method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
trace "Error reading from connection" if message.len == 0:
return
method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
try: try:
var var
left = message.len left = message.len
@ -453,7 +454,7 @@ method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.
trace "sendEncryptedMessage", size = lesize, peer = $sconn.peerInfo, left, offset trace "sendEncryptedMessage", size = lesize, peer = $sconn.peerInfo, left, offset
outbuf &= besize outbuf &= besize
outbuf &= cipher outbuf &= cipher
await sconn.write(outbuf) await sconn.stream.write(outbuf)
except AsyncStreamWriteError: except AsyncStreamWriteError:
trace "Could not write to connection" trace "Could not write to connection"
@ -520,15 +521,6 @@ method init*(p: Noise) {.gcsafe.} =
procCall Secure(p).init() procCall Secure(p).init()
p.codec = NoiseCodec p.codec = NoiseCodec
proc secure*(p: Noise, conn: Connection): Future[Connection] {.async, gcsafe.} =
trace "Noise.secure called", initiator=p.outgoing
try:
result = await p.handleConn(conn, p.outgoing)
except CatchableError as exc:
warn "securing connection failed", msg = exc.msg
if not conn.closed():
await conn.close()
proc newNoise*(privateKey: PrivateKey; outgoing: bool = true; commonPrologue: seq[byte] = @[]): Noise = proc newNoise*(privateKey: PrivateKey; outgoing: bool = true; commonPrologue: seq[byte] = @[]): Noise =
new result new result
result.outgoing = outgoing result.outgoing = outgoing

View File

@ -6,7 +6,7 @@
## at your option. ## at your option.
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import chronos, chronicles, oids import chronos, chronicles, oids, stew/endians2
import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode] import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode]
import secure, import secure,
../../connection, ../../connection,
@ -174,36 +174,44 @@ proc macCheckAndDecode(sconn: SecioConn, data: var seq[byte]): bool =
data.setLen(mark) data.setLen(mark)
result = true result = true
method readMessage(sconn: SecioConn): Future[seq[byte]] {.async.} = proc readRawMessage(conn: Connection): Future[seq[byte]] {.async.} =
while true: # Discard 0-length payloads
var lengthBuf: array[4, byte]
await conn.stream.readExactly(addr lengthBuf[0], lengthBuf.len)
let length = uint32.fromBytesBE(lengthBuf)
trace "Recieved message header", header = lengthBuf.shortLog, length = length
if length > SecioMaxMessageSize: # Verify length before casting!
trace "Received size of message exceed limits", conn = $conn, length = length
raise (ref SecioError)(msg: "Message exceeds maximum length")
if length > 0:
var buf = newSeq[byte](int(length))
await conn.stream.readExactly(addr buf[0], buf.len)
trace "Received message body",
conn = $conn, length = buf.len, buff = buf.shortLog
return buf
trace "Discarding 0-length payload", conn = $conn
method readMessage*(sconn: SecioConn): Future[seq[byte]] {.async.} =
## Read message from channel secure connection ``sconn``. ## Read message from channel secure connection ``sconn``.
when chronicles.enabledLogLevel == LogLevel.TRACE: when chronicles.enabledLogLevel == LogLevel.TRACE:
logScope: logScope:
stream_oid = $sconn.stream.oid stream_oid = $sconn.stream.oid
try: var buf = await sconn.readRawMessage()
var buf = newSeq[byte](4) if sconn.macCheckAndDecode(buf):
await sconn.readExactly(addr buf[0], 4) result = buf
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or else:
(int(buf[2]) shl 8) or (int(buf[3])) trace "Message MAC verification failed", buf = buf.shortLog
trace "Received message header", header = buf.shortLog, length = length raise (ref SecioError)(msg: "message failed MAC verification")
if length <= SecioMaxMessageSize:
buf.setLen(length)
await sconn.readExactly(addr buf[0], length)
trace "Received message body", length = length,
buffer = buf.shortLog
if sconn.macCheckAndDecode(buf):
result = buf
else:
trace "Message MAC verification failed", buf = buf.shortLog
else:
trace "Received message header size is more then allowed",
length = length, allowed_length = SecioMaxMessageSize
except LPStreamIncompleteError:
trace "Connection dropped while reading"
except LPStreamReadError:
trace "Error reading from connection"
method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} = method write*(sconn: SecioConn, message: seq[byte]) {.async.} =
## Write message ``message`` to secure connection ``sconn``. ## Write message ``message`` to secure connection ``sconn``.
if message.len == 0:
return
try: try:
var var
left = message.len left = message.len
@ -211,8 +219,12 @@ method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} =
while left > 0: while left > 0:
let let
chunkSize = if left > SecioMaxMessageSize - 64: SecioMaxMessageSize - 64 else: left chunkSize = if left > SecioMaxMessageSize - 64: SecioMaxMessageSize - 64 else: left
let macsize = sconn.writerMac.sizeDigest() macsize = sconn.writerMac.sizeDigest()
length = chunkSize + macsize
var msg = newSeq[byte](chunkSize + 4 + macsize) var msg = newSeq[byte](chunkSize + 4 + macsize)
msg[0..<4] = uint32(length).toBytesBE()
sconn.writerCoder.encrypt(message.toOpenArray(offset, offset + chunkSize - 1), sconn.writerCoder.encrypt(message.toOpenArray(offset, offset + chunkSize - 1),
msg.toOpenArray(4, 4 + chunkSize - 1)) msg.toOpenArray(4, 4 + chunkSize - 1))
left = left - chunkSize left = left - chunkSize
@ -221,13 +233,9 @@ method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} =
sconn.writerMac.update(msg.toOpenArray(4, 4 + chunkSize - 1)) sconn.writerMac.update(msg.toOpenArray(4, 4 + chunkSize - 1))
sconn.writerMac.finish(msg.toOpenArray(mo, mo + macsize - 1)) sconn.writerMac.finish(msg.toOpenArray(mo, mo + macsize - 1))
sconn.writerMac.reset() sconn.writerMac.reset()
let length = chunkSize + macsize
msg[0] = byte((length shr 24) and 0xFF)
msg[1] = byte((length shr 16) and 0xFF)
msg[2] = byte((length shr 8) and 0xFF)
msg[3] = byte(length and 0xFF)
trace "Writing message", message = msg.shortLog, left, offset trace "Writing message", message = msg.shortLog, left, offset
await sconn.write(msg) await sconn.stream.write(msg)
except AsyncStreamWriteError: except AsyncStreamWriteError:
trace "Could not write to connection" trace "Could not write to connection"
@ -269,30 +277,9 @@ proc newSecioConn(conn: Connection,
proc transactMessage(conn: Connection, proc transactMessage(conn: Connection,
msg: seq[byte]): Future[seq[byte]] {.async.} = msg: seq[byte]): Future[seq[byte]] {.async.} =
var buf = newSeq[byte](4) trace "Sending message", message = msg.shortLog, length = len(msg)
try: await conn.write(msg)
trace "Sending message", message = msg.shortLog, length = len(msg) return await conn.readRawMessage()
await conn.write(msg)
await conn.readExactly(addr buf[0], 4)
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
(int(buf[2]) shl 8) or (int(buf[3]))
trace "Recieved message header", header = buf.shortLog, length = length
if length <= SecioMaxMessageSize:
buf.setLen(length)
await conn.readExactly(addr buf[0], length)
trace "Received message body", conn = $conn,
length = length,
buff = buf.shortLog
result = buf
else:
trace "Received size of message exceed limits", conn = $conn,
length = length
except LPStreamIncompleteError:
trace "Connection dropped while reading", conn = $conn
except LPStreamReadError:
trace "Error reading from connection", conn = $conn
except LPStreamWriteError:
trace "Could not write to connection", conn = $conn
method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} = method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} =
var var
@ -312,7 +299,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
localBytesPubkey = s.localPublicKey.getBytes() localBytesPubkey = s.localPublicKey.getBytes()
if randomBytes(localNonce) != SecioNonceSize: if randomBytes(localNonce) != SecioNonceSize:
raise newException(CatchableError, "Could not generate random data") raise (ref SecioError)(msg: "Could not generate random data")
var request = createProposal(localNonce, var request = createProposal(localNonce,
localBytesPubkey, localBytesPubkey,
@ -332,16 +319,16 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
if len(answer) == 0: if len(answer) == 0:
trace "Proposal exchange failed", conn = $conn trace "Proposal exchange failed", conn = $conn
raise newException(SecioError, "Proposal exchange failed") raise (ref SecioError)(msg: "Proposal exchange failed")
if not decodeProposal(answer, remoteNonce, remoteBytesPubkey, remoteExchanges, if not decodeProposal(answer, remoteNonce, remoteBytesPubkey, remoteExchanges,
remoteCiphers, remoteHashes): remoteCiphers, remoteHashes):
trace "Remote proposal decoding failed", conn = $conn trace "Remote proposal decoding failed", conn = $conn
raise newException(SecioError, "Remote proposal decoding failed") raise (ref SecioError)(msg: "Remote proposal decoding failed")
if not remotePubkey.init(remoteBytesPubkey): if not remotePubkey.init(remoteBytesPubkey):
trace "Remote public key incorrect or corrupted", pubkey = remoteBytesPubkey.shortLog trace "Remote public key incorrect or corrupted", pubkey = remoteBytesPubkey.shortLog
raise newException(SecioError, "Remote public key incorrect or corrupted") raise (ref SecioError)(msg: "Remote public key incorrect or corrupted")
remotePeerId = PeerID.init(remotePubkey) remotePeerId = PeerID.init(remotePubkey)
@ -358,7 +345,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
let hash = selectBest(order, SecioHashes, remoteHashes) let hash = selectBest(order, SecioHashes, remoteHashes)
if len(scheme) == 0 or len(cipher) == 0 or len(hash) == 0: if len(scheme) == 0 or len(cipher) == 0 or len(hash) == 0:
trace "No algorithms in common", peer = remotePeerId trace "No algorithms in common", peer = remotePeerId
raise newException(SecioError, "No algorithms in common") raise (ref SecioError)(msg: "No algorithms in common")
trace "Encryption scheme selected", scheme = scheme, cipher = cipher, trace "Encryption scheme selected", scheme = scheme, cipher = cipher,
hash = hash hash = hash
@ -373,15 +360,15 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
var remoteExchange = await transactMessage(conn, localExchange) var remoteExchange = await transactMessage(conn, localExchange)
if len(remoteExchange) == 0: if len(remoteExchange) == 0:
trace "Corpus exchange failed", conn = $conn trace "Corpus exchange failed", conn = $conn
raise newException(SecioError, "Corpus exchange failed") raise (ref SecioError)(msg: "Corpus exchange failed")
if not decodeExchange(remoteExchange, remoteEBytesPubkey, remoteEBytesSig): if not decodeExchange(remoteExchange, remoteEBytesPubkey, remoteEBytesSig):
trace "Remote exchange decoding failed", conn = $conn trace "Remote exchange decoding failed", conn = $conn
raise newException(SecioError, "Remote exchange decoding failed") raise (ref SecioError)(msg: "Remote exchange decoding failed")
if not remoteESignature.init(remoteEBytesSig): if not remoteESignature.init(remoteEBytesSig):
trace "Remote signature incorrect or corrupted", signature = remoteEBytesSig.shortLog trace "Remote signature incorrect or corrupted", signature = remoteEBytesSig.shortLog
raise newException(SecioError, "Remote signature incorrect or corrupted") raise (ref SecioError)(msg: "Remote signature incorrect or corrupted")
var remoteCorpus = answer & request[4..^1] & remoteEBytesPubkey var remoteCorpus = answer & request[4..^1] & remoteEBytesPubkey
if not remoteESignature.verify(remoteCorpus, remotePubkey): if not remoteESignature.verify(remoteCorpus, remotePubkey):
@ -389,21 +376,21 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
signature = $remoteESignature, signature = $remoteESignature,
pubkey = $remotePubkey, pubkey = $remotePubkey,
corpus = $remoteCorpus corpus = $remoteCorpus
raise newException(SecioError, "Signature verification failed") raise (ref SecioError)(msg: "Signature verification failed")
trace "Signature verified", scheme = remotePubkey.scheme trace "Signature verified", scheme = remotePubkey.scheme
if not remoteEPubkey.eckey.initRaw(remoteEBytesPubkey): if not remoteEPubkey.eckey.initRaw(remoteEBytesPubkey):
trace "Remote ephemeral public key incorrect or corrupted", trace "Remote ephemeral public key incorrect or corrupted",
pubkey = toHex(remoteEBytesPubkey) pubkey = toHex(remoteEBytesPubkey)
raise newException(SecioError, "Remote ephemeral public key incorrect or corrupted") raise (ref SecioError)(msg: "Remote ephemeral public key incorrect or corrupted")
var secret = getSecret(remoteEPubkey, ekeypair.seckey) var secret = getSecret(remoteEPubkey, ekeypair.seckey)
if len(secret) == 0: if len(secret) == 0:
trace "Shared secret could not be created", trace "Shared secret could not be created",
pubkeyScheme = remoteEPubkey.scheme, pubkeyScheme = remoteEPubkey.scheme,
seckeyScheme = ekeypair.seckey.scheme seckeyScheme = ekeypair.seckey.scheme
raise newException(SecioError, "Shared secret could not be created") raise (ref SecioError)(msg: "Shared secret could not be created")
trace "Shared secret calculated", secret = secret.shortLog trace "Shared secret calculated", secret = secret.shortLog
@ -419,13 +406,13 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
var secioConn = newSecioConn(conn, hash, cipher, keys, order, remotePubkey) var secioConn = newSecioConn(conn, hash, cipher, keys, order, remotePubkey)
result = secioConn result = secioConn
await secioConn.writeMessage(remoteNonce) await secioConn.write(remoteNonce)
var res = await secioConn.readMessage() var res = await secioConn.readMessage()
if res != @localNonce: if res != @localNonce:
trace "Nonce verification failed", receivedNonce = res.shortLog, trace "Nonce verification failed", receivedNonce = res.shortLog,
localNonce = localNonce.shortLog localNonce = localNonce.shortLog
raise newException(CatchableError, "Nonce verification failed") raise (ref SecioError)(msg: "Nonce verification failed")
else: else:
trace "Secure handshake succeeded" trace "Secure handshake succeeded"

View File

@ -10,57 +10,28 @@
import options import options
import chronos, chronicles import chronos, chronicles
import ../protocol, import ../protocol,
../../stream/bufferstream, ../../stream/streamseq,
../../connection, ../../connection,
../../peerinfo, ../../peerinfo
../../utility
type type
Secure* = ref object of LPProtocol # base type for secure managers Secure* = ref object of LPProtocol # base type for secure managers
SecureConn* = ref object of Connection SecureConn* = ref object of Connection
buf: StreamSeq
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
method writeMessage*(c: SecureConn, data: seq[byte]) {.async, base.} =
doAssert(false, "Not implemented!")
method handshake(s: Secure, method handshake(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): Future[SecureConn] {.async, base.} = initiator: bool): Future[SecureConn] {.async, base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
proc readLoop(sconn: SecureConn, conn: Connection) {.async.} =
try:
let stream = BufferStream(conn.stream)
while not sconn.closed:
let msg = await sconn.readMessage()
if msg.len == 0:
trace "stream EOF"
return
await stream.pushTo(msg)
except CatchableError as exc:
trace "Exception occurred Secure.readLoop", exc = exc.msg
finally:
trace "closing conn", closed = conn.closed()
if not conn.closed:
await conn.close()
trace "closing sconn", closed = sconn.closed()
if not sconn.closed:
await sconn.close()
trace "ending Secure readLoop"
proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} = proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn, initiator) var sconn = await s.handshake(conn, initiator)
proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
trace "sending encrypted bytes", bytes = data.shortLog
await sconn.writeMessage(data)
result = newConnection(newBufferStream(writeHandler)) result = sconn
conn.readLoops &= readLoop(sconn, result)
if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome:
result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get())
@ -86,3 +57,37 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection]
warn "securing connection failed", msg = exc.msg warn "securing connection failed", msg = exc.msg
if not conn.closed(): if not conn.closed():
await conn.close() await conn.close()
method readExactly*(s: SecureConn,
pbytes: pointer,
nbytes: int):
Future[void] {.async, gcsafe.} =
if nbytes == 0:
return
while s.buf.data().len < nbytes:
# TODO write decrypted content straight into buf using `prepare`
let buf = await s.readMessage()
if buf.len == 0:
raise newLPStreamIncompleteError()
s.buf.add(buf)
var p = cast[ptr UncheckedArray[byte]](pbytes)
let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
doAssert consumed == nbytes, "checked above"
method readOnce*(s: SecureConn,
pbytes: pointer,
nbytes: int):
Future[int] {.async, gcsafe.} =
if nbytes == 0:
return 0
if s.buf.data().len() == 0:
let buf = await s.readMessage()
if buf.len == 0:
raise newLPStreamIncompleteError()
s.buf.add(buf)
var p = cast[ptr UncheckedArray[byte]](pbytes)
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))

View File

@ -216,9 +216,7 @@ method readOnce*(s: BufferStream,
await s.readExactly(pbytes, len) await s.readExactly(pbytes, len)
result = len result = len
method write*(s: BufferStream, method write*(s: BufferStream, msg: seq[byte]): Future[void] =
msg: seq[byte],
msglen = -1): Future[void] =
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
## stream ``wstream``. ## stream ``wstream``.
## ##
@ -233,7 +231,7 @@ method write*(s: BufferStream,
retFuture.fail(newNotWritableError()) retFuture.fail(newNotWritableError())
return retFuture return retFuture
result = s.writeHandler(if msglen >= 0: msg[0..<msglen] else: msg) result = s.writeHandler(msg)
proc pipe*(s: BufferStream, proc pipe*(s: BufferStream,
target: BufferStream): BufferStream = target: BufferStream): BufferStream =

View File

@ -8,7 +8,7 @@
## those terms. ## those terms.
import chronos, chronicles import chronos, chronicles
import lpstream import lpstream, ../utility
logScope: logScope:
topic = "ChronosStream" topic = "ChronosStream"
@ -56,12 +56,12 @@ method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.
withExceptions: withExceptions:
result = await s.reader.readOnce(pbytes, nbytes) result = await s.reader.readOnce(pbytes, nbytes)
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} = method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
if s.writer.atEof: if s.writer.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
withExceptions: withExceptions:
await s.writer.write(msg, msglen) await s.writer.write(msg)
method closed*(s: ChronosStream): bool {.inline.} = method closed*(s: ChronosStream): bool {.inline.} =
# TODO: we might only need to check for reader's EOF # TODO: we might only need to check for reader's EOF

View File

@ -27,28 +27,31 @@ type
par*: ref Exception par*: ref Exception
LPStreamEOFError* = object of LPStreamError LPStreamEOFError* = object of LPStreamError
proc newLPStreamReadError*(p: ref Exception): ref Exception {.inline.} = proc newLPStreamReadError*(p: ref Exception): ref Exception =
var w = newException(LPStreamReadError, "Read stream failed") var w = newException(LPStreamReadError, "Read stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.par = p
result = w result = w
proc newLPStreamWriteError*(p: ref Exception): ref Exception {.inline.} = proc newLPStreamReadError*(msg: string): ref Exception =
newException(LPStreamReadError, msg)
proc newLPStreamWriteError*(p: ref Exception): ref Exception =
var w = newException(LPStreamWriteError, "Write stream failed") var w = newException(LPStreamWriteError, "Write stream failed")
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
w.par = p w.par = p
result = w result = w
proc newLPStreamIncompleteError*(): ref Exception {.inline.} = proc newLPStreamIncompleteError*(): ref Exception =
result = newException(LPStreamIncompleteError, "Incomplete data received") result = newException(LPStreamIncompleteError, "Incomplete data received")
proc newLPStreamLimitError*(): ref Exception {.inline.} = proc newLPStreamLimitError*(): ref Exception =
result = newException(LPStreamLimitError, "Buffer limit reached") result = newException(LPStreamLimitError, "Buffer limit reached")
proc newLPStreamIncorrectDefect*(m: string): ref Exception {.inline.} = proc newLPStreamIncorrectDefect*(m: string): ref Exception =
result = newException(LPStreamIncorrectDefect, m) result = newException(LPStreamIncorrectDefect, m)
proc newLPStreamEOFError*(): ref Exception {.inline.} = proc newLPStreamEOFError*(): ref Exception =
result = newException(LPStreamEOFError, "Stream EOF!") result = newException(LPStreamEOFError, "Stream EOF!")
method closed*(s: LPStream): bool {.base, inline.} = method closed*(s: LPStream): bool {.base, inline.} =
@ -106,16 +109,14 @@ proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] {.async, de
except LPStreamIncompleteError, LPStreamReadError: except LPStreamIncompleteError, LPStreamReadError:
discard # EOF, in which case we should return whatever we read so far.. discard # EOF, in which case we should return whatever we read so far..
method write*(s: LPStream, msg: seq[byte], msglen = -1) method write*(s: LPStream, msg: seq[byte]) {.base, async.} =
{.base, async.} =
doAssert(false, "not implemented!") doAssert(false, "not implemented!")
proc write*(s: LPStream, pbytes: pointer, nbytes: int): Future[void] {.deprecated: "seq".} = proc write*(s: LPStream, pbytes: pointer, nbytes: int): Future[void] {.deprecated: "seq".} =
s.write(@(toOpenArray(cast[ptr UncheckedArray[byte]](pbytes), 0, nbytes - 1))) s.write(@(toOpenArray(cast[ptr UncheckedArray[byte]](pbytes), 0, nbytes - 1)))
proc write*(s: LPStream, msg: string, msglen = -1): Future[void] = proc write*(s: LPStream, msg: string): Future[void] =
let nbytes = if msglen >= 0: msglen else: msg.len s.write(@(toOpenArrayByte(msg, 0, msg.high)))
s.write(@(toOpenArrayByte(msg, 0, nbytes - 1)))
method close*(s: LPStream) method close*(s: LPStream)
{.base, async.} = {.base, async.} =

View File

@ -0,0 +1,73 @@
import stew/bitops2
type
StreamSeq* = object
# Seq adapted to the stream use case where we add data at the back and
# consume at the front in chunks. A bit like a deque but contiguous memory
# area - will try to avoid moving data unless it has to, subject to buffer
# space. The assumption is that data is typically consumed fully.
#
# See also asio::stream_buf
buf: seq[byte] # Data store
rpos: int # Reading position - valid data starts here
wpos: int # Writing position - valid data ends here
template len*(v: StreamSeq): int =
v.wpos - v.rpos
func grow(v: var StreamSeq, n: int) =
if v.rpos == v.wpos:
# All data has been consumed, reset positions
v.rpos = 0
v.wpos = 0
if v.buf.len - v.wpos < n:
if v.rpos > 0:
# We've consumed some data so we'll try to move that data to the beginning
# of the buffer, hoping that this will clear up enough capacity to avoid
# reallocation
moveMem(addr v.buf[0], addr v.buf[v.rpos], v.wpos - v.rpos)
v.wpos -= v.rpos
v.rpos = 0
if v.buf.len - v.wpos >= n:
return
# TODO this is inefficient - `setLen` will copy all data of buf, even though
# we know that only a part of it contains "valid" data
v.buf.setLen(nextPow2(max(64, v.wpos + n).uint64).int)
template prepare*(v: var StreamSeq, n: int): var openArray[byte] =
## Return a buffer that is at least `n` bytes long
mixin grow
v.grow(n)
v.buf.toOpenArray(v.wpos, v.buf.len - 1)
template commit*(v: var StreamSeq, n: int) =
## Mark `n` bytes in the buffer returned by `prepare` as ready for reading
v.wpos += n
func add*(v: var StreamSeq, data: openArray[byte]) =
## Add data - the equivalent of `buf.prepare(n) = data; buf.commit(n)`
if data.len > 0:
v.grow(data.len)
copyMem(addr v.buf[v.wpos], unsafeAddr data[0], data.len)
v.commit(data.len)
template data*(v: StreamSeq): openArray[byte] =
# Data that is ready to be consumed
# TODO a double-hash comment here breaks compile (!)
v.buf.toOpenArray(v.rpos, v.wpos - 1)
func consume*(v: var StreamSeq, n: int) =
## Mark `n` bytes that were returned via `data` as consumed
v.rpos += n
func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int =
let bytes = min(buf.len, v.len)
if bytes > 0:
copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
v.consume(bytes)
bytes

View File

@ -149,7 +149,7 @@ suite "BufferStream":
let buff = newBufferStream(writeHandler, 10) let buff = newBufferStream(writeHandler, 10)
check buff.len == 0 check buff.len == 0
await buff.write("Hello!", 6) await buff.write("Hello!")
result = true result = true
@ -166,7 +166,7 @@ suite "BufferStream":
let buff = newBufferStream(writeHandler, 10) let buff = newBufferStream(writeHandler, 10)
check buff.len == 0 check buff.len == 0
await buff.write(cast[seq[byte]]("Hello!"), 6) await buff.write(cast[seq[byte]]("Hello!"))
result = true result = true

View File

@ -48,21 +48,17 @@ proc readLp*(s: StreamTransport): Future[seq[byte]] {.async, gcsafe.} =
length: int length: int
res: VarintStatus res: VarintStatus
result = newSeq[byte](10) result = newSeq[byte](10)
try:
for i in 0..<len(result): for i in 0..<len(result):
await s.readExactly(addr result[i], 1) await s.readExactly(addr result[i], 1)
res = LP.getUVarint(result.toOpenArray(0, i), length, size) res = LP.getUVarint(result.toOpenArray(0, i), length, size)
if res == VarintStatus.Success: if res == VarintStatus.Success:
break break
if res != VarintStatus.Success: if res != VarintStatus.Success:
raise newInvalidVarintException() raise newInvalidVarintException()
result.setLen(size) result.setLen(size)
if size > 0.uint: if size > 0.uint:
await s.readExactly(addr result[0], int(size)) await s.readExactly(addr result[0], int(size))
except TransportIncompleteError as exc:
trace "remote connection ended unexpectedly", exc = exc.msg
except TransportError as exc:
trace "unable to read from remote connection", exc = exc.msg
proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey), proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey),
address: string = "/ip4/127.0.0.1/tcp/0", address: string = "/ip4/127.0.0.1/tcp/0",

View File

@ -48,8 +48,7 @@ method readExactly*(s: TestSelectStream,
cstring("\0x3na\n"), cstring("\0x3na\n"),
"\0x3na\n".len()) "\0x3na\n".len())
method write*(s: TestSelectStream, msg: seq[byte], msglen = -1) method write*(s: TestSelectStream, msg: seq[byte]) {.async, gcsafe.} = discard
{.async, gcsafe.} = discard
method close(s: TestSelectStream) {.async, gcsafe.} = method close(s: TestSelectStream) {.async, gcsafe.} =
s.isClosed = true s.isClosed = true
@ -92,7 +91,7 @@ method readExactly*(s: TestLsStream,
var buf = "na\n" var buf = "na\n"
copyMem(pbytes, addr buf[0], buf.len()) copyMem(pbytes, addr buf[0], buf.len())
method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} = method write*(s: TestLsStream, msg: seq[byte]) {.async, gcsafe.} =
if s.step == 4: if s.step == 4:
await s.ls(msg) await s.ls(msg)
@ -139,7 +138,7 @@ method readExactly*(s: TestNaStream,
cstring("\0x3na\n"), cstring("\0x3na\n"),
"\0x3na\n".len()) "\0x3na\n".len())
method write*(s: TestNaStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} = method write*(s: TestNaStream, msg: seq[byte]) {.async, gcsafe.} =
if s.step == 4: if s.step == 4:
await s.na(string.fromBytes(msg)) await s.na(string.fromBytes(msg))

View File

@ -1,4 +1,5 @@
import testvarint import testvarint,
teststreamseq
import testrsa, import testrsa,
testecnist, testecnist,

View File

@ -93,7 +93,7 @@ suite "Noise":
serverNoise = newNoise(serverInfo.privateKey, outgoing = false) serverNoise = newNoise(serverInfo.privateKey, outgoing = false)
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
let sconn = await serverNoise.secure(conn) let sconn = await serverNoise.secure(conn, false)
defer: defer:
await sconn.close() await sconn.close()
await conn.close() await conn.close()
@ -108,7 +108,7 @@ suite "Noise":
clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
clientNoise = newNoise(clientInfo.privateKey, outgoing = true) clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
conn = await transport2.dial(transport1.ma) conn = await transport2.dial(transport1.ma)
sconn = await clientNoise.secure(conn) sconn = await clientNoise.secure(conn, true)
msg = await sconn.read(6) msg = await sconn.read(6)
@ -131,7 +131,7 @@ suite "Noise":
readTask = newFuture[void]() readTask = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
let sconn = await serverNoise.secure(conn) let sconn = await serverNoise.secure(conn, false)
defer: defer:
await sconn.close() await sconn.close()
await conn.close() await conn.close()
@ -148,7 +148,7 @@ suite "Noise":
clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
clientNoise = newNoise(clientInfo.privateKey, outgoing = true) clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
conn = await transport2.dial(transport1.ma) conn = await transport2.dial(transport1.ma)
sconn = await clientNoise.secure(conn) sconn = await clientNoise.secure(conn, true)
await sconn.write("Hello!".cstring, 6) await sconn.write("Hello!".cstring, 6)
await readTask await readTask
@ -175,7 +175,7 @@ suite "Noise":
trace "Sending huge payload", size = hugePayload.len trace "Sending huge payload", size = hugePayload.len
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
let sconn = await serverNoise.secure(conn) let sconn = await serverNoise.secure(conn, false)
defer: defer:
await sconn.close() await sconn.close()
let msg = await sconn.readLp() let msg = await sconn.readLp()
@ -191,7 +191,7 @@ suite "Noise":
clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
clientNoise = newNoise(clientInfo.privateKey, outgoing = true) clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
conn = await transport2.dial(transport1.ma) conn = await transport2.dial(transport1.ma)
sconn = await clientNoise.secure(conn) sconn = await clientNoise.secure(conn, true)
await sconn.writeLp(hugePayload) await sconn.writeLp(hugePayload)
await readTask await readTask

55
tests/teststreamseq.nim Normal file
View File

@ -0,0 +1,55 @@
{.used.}
import unittest
import stew/byteutils
import ../libp2p/stream/streamseq
suite "StreamSeq":
test "basics":
var s: StreamSeq
check:
s.data().len == 0
s.add([byte 0, 1, 2, 3])
check:
@(s.data()) == [byte 0, 1, 2, 3]
s.prepare(10)[0..<3] = [byte 4, 5, 6]
check:
@(s.data()) == [byte 0, 1, 2, 3]
s.commit(3)
check:
@(s.data()) == [byte 0, 1, 2, 3, 4, 5, 6]
s.consume(1)
check:
@(s.data()) == [byte 1, 2, 3, 4, 5, 6]
s.consume(6)
check: @(s.data()) == []
s.add([])
check: @(s.data()) == []
var o: seq[byte]
check: 0 == s.consumeTo(o)
s.add([byte 1, 2, 3])
o.setLen(2)
o.setLen(s.consumeTo(o))
check:
o == [byte 1, 2]
o.setLen(s.consumeTo(o))
check:
o == [byte 3]