mirror of
https://github.com/status-im/nim-libp2p.git
synced 2025-01-11 05:26:02 +00:00
remove readLoop in secure protocols (#162)
* remove readLoop in secure protocols, fix security issues * fix Defect on remote sending 0-byte noise/secio message * remove msglen from `write` (unused) * simplify SecureConn data flow * document some control-flow issues * unify exception behaviour across noise and secio * secio would not raise on mac/decryption errors * fix compile error
This commit is contained in:
parent
330da51819
commit
1efada474c
@ -109,7 +109,7 @@ method readExactly*(s: Connection,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.gcsafe.} =
|
||||
s.stream.readExactly(pbytes, nbytes)
|
||||
s.stream.readExactly(pbytes, nbytes)
|
||||
|
||||
method readOnce*(s: Connection,
|
||||
pbytes: pointer,
|
||||
@ -118,10 +118,9 @@ method readOnce*(s: Connection,
|
||||
s.stream.readOnce(pbytes, nbytes)
|
||||
|
||||
method write*(s: Connection,
|
||||
msg: seq[byte],
|
||||
msglen = -1):
|
||||
msg: seq[byte]):
|
||||
Future[void] {.gcsafe.} =
|
||||
s.stream.write(msg, msglen)
|
||||
s.stream.write(msg)
|
||||
|
||||
method closed*(s: Connection): bool =
|
||||
if isNil(s.stream):
|
||||
|
@ -161,6 +161,6 @@ template writePrefix: untyped =
|
||||
if s.isLazy and not s.isOpen:
|
||||
await s.open()
|
||||
|
||||
method write*(s: LPChannel, msg: seq[byte], msglen = -1) {.async.} =
|
||||
method write*(s: LPChannel, msg: seq[byte]) {.async.} =
|
||||
writePrefix()
|
||||
await procCall write(BufferStream(s), msg, msglen)
|
||||
await procCall write(BufferStream(s), msg)
|
||||
|
@ -267,11 +267,12 @@ template read_s: untyped =
|
||||
|
||||
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
|
||||
var besize: array[2, byte]
|
||||
await sconn.readExactly(addr besize[0], 2)
|
||||
await sconn.stream.readExactly(addr besize[0], besize.len)
|
||||
let size = uint16.fromBytesBE(besize).int
|
||||
trace "receiveHSMessage", size
|
||||
var buffer = newSeq[byte](size)
|
||||
await sconn.readExactly(addr buffer[0], size)
|
||||
if buffer.len > 0:
|
||||
await sconn.stream.readExactly(addr buffer[0], buffer.len)
|
||||
return buffer
|
||||
|
||||
proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
|
||||
@ -416,25 +417,25 @@ proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Futu
|
||||
let (cs1, cs2) = hs.ss.split()
|
||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||
|
||||
method readMessage(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||
try:
|
||||
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||
while true: # Discard 0-length payloads
|
||||
var besize: array[2, byte]
|
||||
await sconn.readExactly(addr besize[0], 2)
|
||||
let size = uint16.fromBytesBE(besize).int
|
||||
await sconn.stream.readExactly(addr besize[0], besize.len)
|
||||
let size = uint16.fromBytesBE(besize).int # Cannot overflow
|
||||
trace "receiveEncryptedMessage", size, peer = $sconn.peerInfo
|
||||
if size == 0:
|
||||
return @[]
|
||||
var buffer = newSeq[byte](size)
|
||||
await sconn.readExactly(addr buffer[0], size)
|
||||
var plain = sconn.readCs.decryptWithAd([], buffer)
|
||||
unpackNoisePayload(plain)
|
||||
return plain
|
||||
except LPStreamIncompleteError:
|
||||
trace "Connection dropped while reading"
|
||||
except LPStreamReadError:
|
||||
trace "Error reading from connection"
|
||||
if size > 0:
|
||||
var buffer = newSeq[byte](size)
|
||||
await sconn.stream.readExactly(addr buffer[0], buffer.len)
|
||||
var plain = sconn.readCs.decryptWithAd([], buffer)
|
||||
unpackNoisePayload(plain)
|
||||
return plain
|
||||
else:
|
||||
trace "Received 0-length message", conn = $conn
|
||||
|
||||
method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
|
||||
if message.len == 0:
|
||||
return
|
||||
|
||||
method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
|
||||
try:
|
||||
var
|
||||
left = message.len
|
||||
@ -453,7 +454,7 @@ method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.
|
||||
trace "sendEncryptedMessage", size = lesize, peer = $sconn.peerInfo, left, offset
|
||||
outbuf &= besize
|
||||
outbuf &= cipher
|
||||
await sconn.write(outbuf)
|
||||
await sconn.stream.write(outbuf)
|
||||
except AsyncStreamWriteError:
|
||||
trace "Could not write to connection"
|
||||
|
||||
@ -520,15 +521,6 @@ method init*(p: Noise) {.gcsafe.} =
|
||||
procCall Secure(p).init()
|
||||
p.codec = NoiseCodec
|
||||
|
||||
proc secure*(p: Noise, conn: Connection): Future[Connection] {.async, gcsafe.} =
|
||||
trace "Noise.secure called", initiator=p.outgoing
|
||||
try:
|
||||
result = await p.handleConn(conn, p.outgoing)
|
||||
except CatchableError as exc:
|
||||
warn "securing connection failed", msg = exc.msg
|
||||
if not conn.closed():
|
||||
await conn.close()
|
||||
|
||||
proc newNoise*(privateKey: PrivateKey; outgoing: bool = true; commonPrologue: seq[byte] = @[]): Noise =
|
||||
new result
|
||||
result.outgoing = outgoing
|
||||
|
@ -6,7 +6,7 @@
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
import chronos, chronicles, oids
|
||||
import chronos, chronicles, oids, stew/endians2
|
||||
import nimcrypto/[sysrand, hmac, sha2, sha, hash, rijndael, twofish, bcmode]
|
||||
import secure,
|
||||
../../connection,
|
||||
@ -174,36 +174,44 @@ proc macCheckAndDecode(sconn: SecioConn, data: var seq[byte]): bool =
|
||||
data.setLen(mark)
|
||||
result = true
|
||||
|
||||
method readMessage(sconn: SecioConn): Future[seq[byte]] {.async.} =
|
||||
proc readRawMessage(conn: Connection): Future[seq[byte]] {.async.} =
|
||||
while true: # Discard 0-length payloads
|
||||
var lengthBuf: array[4, byte]
|
||||
await conn.stream.readExactly(addr lengthBuf[0], lengthBuf.len)
|
||||
let length = uint32.fromBytesBE(lengthBuf)
|
||||
|
||||
trace "Recieved message header", header = lengthBuf.shortLog, length = length
|
||||
|
||||
if length > SecioMaxMessageSize: # Verify length before casting!
|
||||
trace "Received size of message exceed limits", conn = $conn, length = length
|
||||
raise (ref SecioError)(msg: "Message exceeds maximum length")
|
||||
|
||||
if length > 0:
|
||||
var buf = newSeq[byte](int(length))
|
||||
await conn.stream.readExactly(addr buf[0], buf.len)
|
||||
trace "Received message body",
|
||||
conn = $conn, length = buf.len, buff = buf.shortLog
|
||||
return buf
|
||||
|
||||
trace "Discarding 0-length payload", conn = $conn
|
||||
|
||||
method readMessage*(sconn: SecioConn): Future[seq[byte]] {.async.} =
|
||||
## Read message from channel secure connection ``sconn``.
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
logScope:
|
||||
stream_oid = $sconn.stream.oid
|
||||
try:
|
||||
var buf = newSeq[byte](4)
|
||||
await sconn.readExactly(addr buf[0], 4)
|
||||
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
|
||||
(int(buf[2]) shl 8) or (int(buf[3]))
|
||||
trace "Received message header", header = buf.shortLog, length = length
|
||||
if length <= SecioMaxMessageSize:
|
||||
buf.setLen(length)
|
||||
await sconn.readExactly(addr buf[0], length)
|
||||
trace "Received message body", length = length,
|
||||
buffer = buf.shortLog
|
||||
if sconn.macCheckAndDecode(buf):
|
||||
result = buf
|
||||
else:
|
||||
trace "Message MAC verification failed", buf = buf.shortLog
|
||||
else:
|
||||
trace "Received message header size is more then allowed",
|
||||
length = length, allowed_length = SecioMaxMessageSize
|
||||
except LPStreamIncompleteError:
|
||||
trace "Connection dropped while reading"
|
||||
except LPStreamReadError:
|
||||
trace "Error reading from connection"
|
||||
var buf = await sconn.readRawMessage()
|
||||
if sconn.macCheckAndDecode(buf):
|
||||
result = buf
|
||||
else:
|
||||
trace "Message MAC verification failed", buf = buf.shortLog
|
||||
raise (ref SecioError)(msg: "message failed MAC verification")
|
||||
|
||||
method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} =
|
||||
method write*(sconn: SecioConn, message: seq[byte]) {.async.} =
|
||||
## Write message ``message`` to secure connection ``sconn``.
|
||||
if message.len == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
var
|
||||
left = message.len
|
||||
@ -211,8 +219,12 @@ method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} =
|
||||
while left > 0:
|
||||
let
|
||||
chunkSize = if left > SecioMaxMessageSize - 64: SecioMaxMessageSize - 64 else: left
|
||||
let macsize = sconn.writerMac.sizeDigest()
|
||||
macsize = sconn.writerMac.sizeDigest()
|
||||
length = chunkSize + macsize
|
||||
|
||||
var msg = newSeq[byte](chunkSize + 4 + macsize)
|
||||
msg[0..<4] = uint32(length).toBytesBE()
|
||||
|
||||
sconn.writerCoder.encrypt(message.toOpenArray(offset, offset + chunkSize - 1),
|
||||
msg.toOpenArray(4, 4 + chunkSize - 1))
|
||||
left = left - chunkSize
|
||||
@ -221,13 +233,9 @@ method writeMessage(sconn: SecioConn, message: seq[byte]) {.async.} =
|
||||
sconn.writerMac.update(msg.toOpenArray(4, 4 + chunkSize - 1))
|
||||
sconn.writerMac.finish(msg.toOpenArray(mo, mo + macsize - 1))
|
||||
sconn.writerMac.reset()
|
||||
let length = chunkSize + macsize
|
||||
msg[0] = byte((length shr 24) and 0xFF)
|
||||
msg[1] = byte((length shr 16) and 0xFF)
|
||||
msg[2] = byte((length shr 8) and 0xFF)
|
||||
msg[3] = byte(length and 0xFF)
|
||||
|
||||
trace "Writing message", message = msg.shortLog, left, offset
|
||||
await sconn.write(msg)
|
||||
await sconn.stream.write(msg)
|
||||
except AsyncStreamWriteError:
|
||||
trace "Could not write to connection"
|
||||
|
||||
@ -269,30 +277,9 @@ proc newSecioConn(conn: Connection,
|
||||
|
||||
proc transactMessage(conn: Connection,
|
||||
msg: seq[byte]): Future[seq[byte]] {.async.} =
|
||||
var buf = newSeq[byte](4)
|
||||
try:
|
||||
trace "Sending message", message = msg.shortLog, length = len(msg)
|
||||
await conn.write(msg)
|
||||
await conn.readExactly(addr buf[0], 4)
|
||||
let length = (int(buf[0]) shl 24) or (int(buf[1]) shl 16) or
|
||||
(int(buf[2]) shl 8) or (int(buf[3]))
|
||||
trace "Recieved message header", header = buf.shortLog, length = length
|
||||
if length <= SecioMaxMessageSize:
|
||||
buf.setLen(length)
|
||||
await conn.readExactly(addr buf[0], length)
|
||||
trace "Received message body", conn = $conn,
|
||||
length = length,
|
||||
buff = buf.shortLog
|
||||
result = buf
|
||||
else:
|
||||
trace "Received size of message exceed limits", conn = $conn,
|
||||
length = length
|
||||
except LPStreamIncompleteError:
|
||||
trace "Connection dropped while reading", conn = $conn
|
||||
except LPStreamReadError:
|
||||
trace "Error reading from connection", conn = $conn
|
||||
except LPStreamWriteError:
|
||||
trace "Could not write to connection", conn = $conn
|
||||
trace "Sending message", message = msg.shortLog, length = len(msg)
|
||||
await conn.write(msg)
|
||||
return await conn.readRawMessage()
|
||||
|
||||
method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} =
|
||||
var
|
||||
@ -312,7 +299,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||
localBytesPubkey = s.localPublicKey.getBytes()
|
||||
|
||||
if randomBytes(localNonce) != SecioNonceSize:
|
||||
raise newException(CatchableError, "Could not generate random data")
|
||||
raise (ref SecioError)(msg: "Could not generate random data")
|
||||
|
||||
var request = createProposal(localNonce,
|
||||
localBytesPubkey,
|
||||
@ -332,16 +319,16 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||
|
||||
if len(answer) == 0:
|
||||
trace "Proposal exchange failed", conn = $conn
|
||||
raise newException(SecioError, "Proposal exchange failed")
|
||||
raise (ref SecioError)(msg: "Proposal exchange failed")
|
||||
|
||||
if not decodeProposal(answer, remoteNonce, remoteBytesPubkey, remoteExchanges,
|
||||
remoteCiphers, remoteHashes):
|
||||
trace "Remote proposal decoding failed", conn = $conn
|
||||
raise newException(SecioError, "Remote proposal decoding failed")
|
||||
raise (ref SecioError)(msg: "Remote proposal decoding failed")
|
||||
|
||||
if not remotePubkey.init(remoteBytesPubkey):
|
||||
trace "Remote public key incorrect or corrupted", pubkey = remoteBytesPubkey.shortLog
|
||||
raise newException(SecioError, "Remote public key incorrect or corrupted")
|
||||
raise (ref SecioError)(msg: "Remote public key incorrect or corrupted")
|
||||
|
||||
remotePeerId = PeerID.init(remotePubkey)
|
||||
|
||||
@ -358,7 +345,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||
let hash = selectBest(order, SecioHashes, remoteHashes)
|
||||
if len(scheme) == 0 or len(cipher) == 0 or len(hash) == 0:
|
||||
trace "No algorithms in common", peer = remotePeerId
|
||||
raise newException(SecioError, "No algorithms in common")
|
||||
raise (ref SecioError)(msg: "No algorithms in common")
|
||||
|
||||
trace "Encryption scheme selected", scheme = scheme, cipher = cipher,
|
||||
hash = hash
|
||||
@ -373,15 +360,15 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||
var remoteExchange = await transactMessage(conn, localExchange)
|
||||
if len(remoteExchange) == 0:
|
||||
trace "Corpus exchange failed", conn = $conn
|
||||
raise newException(SecioError, "Corpus exchange failed")
|
||||
raise (ref SecioError)(msg: "Corpus exchange failed")
|
||||
|
||||
if not decodeExchange(remoteExchange, remoteEBytesPubkey, remoteEBytesSig):
|
||||
trace "Remote exchange decoding failed", conn = $conn
|
||||
raise newException(SecioError, "Remote exchange decoding failed")
|
||||
raise (ref SecioError)(msg: "Remote exchange decoding failed")
|
||||
|
||||
if not remoteESignature.init(remoteEBytesSig):
|
||||
trace "Remote signature incorrect or corrupted", signature = remoteEBytesSig.shortLog
|
||||
raise newException(SecioError, "Remote signature incorrect or corrupted")
|
||||
raise (ref SecioError)(msg: "Remote signature incorrect or corrupted")
|
||||
|
||||
var remoteCorpus = answer & request[4..^1] & remoteEBytesPubkey
|
||||
if not remoteESignature.verify(remoteCorpus, remotePubkey):
|
||||
@ -389,21 +376,21 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||
signature = $remoteESignature,
|
||||
pubkey = $remotePubkey,
|
||||
corpus = $remoteCorpus
|
||||
raise newException(SecioError, "Signature verification failed")
|
||||
raise (ref SecioError)(msg: "Signature verification failed")
|
||||
|
||||
trace "Signature verified", scheme = remotePubkey.scheme
|
||||
|
||||
if not remoteEPubkey.eckey.initRaw(remoteEBytesPubkey):
|
||||
trace "Remote ephemeral public key incorrect or corrupted",
|
||||
pubkey = toHex(remoteEBytesPubkey)
|
||||
raise newException(SecioError, "Remote ephemeral public key incorrect or corrupted")
|
||||
raise (ref SecioError)(msg: "Remote ephemeral public key incorrect or corrupted")
|
||||
|
||||
var secret = getSecret(remoteEPubkey, ekeypair.seckey)
|
||||
if len(secret) == 0:
|
||||
trace "Shared secret could not be created",
|
||||
pubkeyScheme = remoteEPubkey.scheme,
|
||||
seckeyScheme = ekeypair.seckey.scheme
|
||||
raise newException(SecioError, "Shared secret could not be created")
|
||||
raise (ref SecioError)(msg: "Shared secret could not be created")
|
||||
|
||||
trace "Shared secret calculated", secret = secret.shortLog
|
||||
|
||||
@ -419,13 +406,13 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||
|
||||
var secioConn = newSecioConn(conn, hash, cipher, keys, order, remotePubkey)
|
||||
result = secioConn
|
||||
await secioConn.writeMessage(remoteNonce)
|
||||
await secioConn.write(remoteNonce)
|
||||
var res = await secioConn.readMessage()
|
||||
|
||||
if res != @localNonce:
|
||||
trace "Nonce verification failed", receivedNonce = res.shortLog,
|
||||
localNonce = localNonce.shortLog
|
||||
raise newException(CatchableError, "Nonce verification failed")
|
||||
raise (ref SecioError)(msg: "Nonce verification failed")
|
||||
else:
|
||||
trace "Secure handshake succeeded"
|
||||
|
||||
|
@ -10,57 +10,28 @@
|
||||
import options
|
||||
import chronos, chronicles
|
||||
import ../protocol,
|
||||
../../stream/bufferstream,
|
||||
../../stream/streamseq,
|
||||
../../connection,
|
||||
../../peerinfo,
|
||||
../../utility
|
||||
../../peerinfo
|
||||
|
||||
type
|
||||
Secure* = ref object of LPProtocol # base type for secure managers
|
||||
|
||||
SecureConn* = ref object of Connection
|
||||
buf: StreamSeq
|
||||
|
||||
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
method writeMessage*(c: SecureConn, data: seq[byte]) {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
method handshake(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool): Future[SecureConn] {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
proc readLoop(sconn: SecureConn, conn: Connection) {.async.} =
|
||||
try:
|
||||
let stream = BufferStream(conn.stream)
|
||||
while not sconn.closed:
|
||||
let msg = await sconn.readMessage()
|
||||
if msg.len == 0:
|
||||
trace "stream EOF"
|
||||
return
|
||||
|
||||
await stream.pushTo(msg)
|
||||
except CatchableError as exc:
|
||||
trace "Exception occurred Secure.readLoop", exc = exc.msg
|
||||
finally:
|
||||
trace "closing conn", closed = conn.closed()
|
||||
if not conn.closed:
|
||||
await conn.close()
|
||||
|
||||
trace "closing sconn", closed = sconn.closed()
|
||||
if not sconn.closed:
|
||||
await sconn.close()
|
||||
trace "ending Secure readLoop"
|
||||
|
||||
proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} =
|
||||
var sconn = await s.handshake(conn, initiator)
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
|
||||
trace "sending encrypted bytes", bytes = data.shortLog
|
||||
await sconn.writeMessage(data)
|
||||
|
||||
result = newConnection(newBufferStream(writeHandler))
|
||||
conn.readLoops &= readLoop(sconn, result)
|
||||
result = sconn
|
||||
|
||||
if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome:
|
||||
result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get())
|
||||
@ -86,3 +57,37 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection]
|
||||
warn "securing connection failed", msg = exc.msg
|
||||
if not conn.closed():
|
||||
await conn.close()
|
||||
|
||||
method readExactly*(s: SecureConn,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.async, gcsafe.} =
|
||||
if nbytes == 0:
|
||||
return
|
||||
|
||||
while s.buf.data().len < nbytes:
|
||||
# TODO write decrypted content straight into buf using `prepare`
|
||||
let buf = await s.readMessage()
|
||||
if buf.len == 0:
|
||||
raise newLPStreamIncompleteError()
|
||||
s.buf.add(buf)
|
||||
|
||||
var p = cast[ptr UncheckedArray[byte]](pbytes)
|
||||
let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
|
||||
doAssert consumed == nbytes, "checked above"
|
||||
|
||||
method readOnce*(s: SecureConn,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.async, gcsafe.} =
|
||||
if nbytes == 0:
|
||||
return 0
|
||||
|
||||
if s.buf.data().len() == 0:
|
||||
let buf = await s.readMessage()
|
||||
if buf.len == 0:
|
||||
raise newLPStreamIncompleteError()
|
||||
s.buf.add(buf)
|
||||
|
||||
var p = cast[ptr UncheckedArray[byte]](pbytes)
|
||||
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
|
||||
|
@ -216,9 +216,7 @@ method readOnce*(s: BufferStream,
|
||||
await s.readExactly(pbytes, len)
|
||||
result = len
|
||||
|
||||
method write*(s: BufferStream,
|
||||
msg: seq[byte],
|
||||
msglen = -1): Future[void] =
|
||||
method write*(s: BufferStream, msg: seq[byte]): Future[void] =
|
||||
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
|
||||
## stream ``wstream``.
|
||||
##
|
||||
@ -233,7 +231,7 @@ method write*(s: BufferStream,
|
||||
retFuture.fail(newNotWritableError())
|
||||
return retFuture
|
||||
|
||||
result = s.writeHandler(if msglen >= 0: msg[0..<msglen] else: msg)
|
||||
result = s.writeHandler(msg)
|
||||
|
||||
proc pipe*(s: BufferStream,
|
||||
target: BufferStream): BufferStream =
|
||||
|
@ -8,7 +8,7 @@
|
||||
## those terms.
|
||||
|
||||
import chronos, chronicles
|
||||
import lpstream
|
||||
import lpstream, ../utility
|
||||
|
||||
logScope:
|
||||
topic = "ChronosStream"
|
||||
@ -56,12 +56,12 @@ method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.
|
||||
withExceptions:
|
||||
result = await s.reader.readOnce(pbytes, nbytes)
|
||||
|
||||
method write*(s: ChronosStream, msg: seq[byte], msglen = -1) {.async.} =
|
||||
method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
|
||||
if s.writer.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
withExceptions:
|
||||
await s.writer.write(msg, msglen)
|
||||
await s.writer.write(msg)
|
||||
|
||||
method closed*(s: ChronosStream): bool {.inline.} =
|
||||
# TODO: we might only need to check for reader's EOF
|
||||
|
@ -27,28 +27,31 @@ type
|
||||
par*: ref Exception
|
||||
LPStreamEOFError* = object of LPStreamError
|
||||
|
||||
proc newLPStreamReadError*(p: ref Exception): ref Exception {.inline.} =
|
||||
proc newLPStreamReadError*(p: ref Exception): ref Exception =
|
||||
var w = newException(LPStreamReadError, "Read stream failed")
|
||||
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
|
||||
w.par = p
|
||||
result = w
|
||||
|
||||
proc newLPStreamWriteError*(p: ref Exception): ref Exception {.inline.} =
|
||||
proc newLPStreamReadError*(msg: string): ref Exception =
|
||||
newException(LPStreamReadError, msg)
|
||||
|
||||
proc newLPStreamWriteError*(p: ref Exception): ref Exception =
|
||||
var w = newException(LPStreamWriteError, "Write stream failed")
|
||||
w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg
|
||||
w.par = p
|
||||
result = w
|
||||
|
||||
proc newLPStreamIncompleteError*(): ref Exception {.inline.} =
|
||||
proc newLPStreamIncompleteError*(): ref Exception =
|
||||
result = newException(LPStreamIncompleteError, "Incomplete data received")
|
||||
|
||||
proc newLPStreamLimitError*(): ref Exception {.inline.} =
|
||||
proc newLPStreamLimitError*(): ref Exception =
|
||||
result = newException(LPStreamLimitError, "Buffer limit reached")
|
||||
|
||||
proc newLPStreamIncorrectDefect*(m: string): ref Exception {.inline.} =
|
||||
proc newLPStreamIncorrectDefect*(m: string): ref Exception =
|
||||
result = newException(LPStreamIncorrectDefect, m)
|
||||
|
||||
proc newLPStreamEOFError*(): ref Exception {.inline.} =
|
||||
proc newLPStreamEOFError*(): ref Exception =
|
||||
result = newException(LPStreamEOFError, "Stream EOF!")
|
||||
|
||||
method closed*(s: LPStream): bool {.base, inline.} =
|
||||
@ -106,16 +109,14 @@ proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] {.async, de
|
||||
except LPStreamIncompleteError, LPStreamReadError:
|
||||
discard # EOF, in which case we should return whatever we read so far..
|
||||
|
||||
method write*(s: LPStream, msg: seq[byte], msglen = -1)
|
||||
{.base, async.} =
|
||||
method write*(s: LPStream, msg: seq[byte]) {.base, async.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
proc write*(s: LPStream, pbytes: pointer, nbytes: int): Future[void] {.deprecated: "seq".} =
|
||||
s.write(@(toOpenArray(cast[ptr UncheckedArray[byte]](pbytes), 0, nbytes - 1)))
|
||||
|
||||
proc write*(s: LPStream, msg: string, msglen = -1): Future[void] =
|
||||
let nbytes = if msglen >= 0: msglen else: msg.len
|
||||
s.write(@(toOpenArrayByte(msg, 0, nbytes - 1)))
|
||||
proc write*(s: LPStream, msg: string): Future[void] =
|
||||
s.write(@(toOpenArrayByte(msg, 0, msg.high)))
|
||||
|
||||
method close*(s: LPStream)
|
||||
{.base, async.} =
|
||||
|
73
libp2p/stream/streamseq.nim
Normal file
73
libp2p/stream/streamseq.nim
Normal file
@ -0,0 +1,73 @@
|
||||
import stew/bitops2
|
||||
|
||||
type
|
||||
StreamSeq* = object
|
||||
# Seq adapted to the stream use case where we add data at the back and
|
||||
# consume at the front in chunks. A bit like a deque but contiguous memory
|
||||
# area - will try to avoid moving data unless it has to, subject to buffer
|
||||
# space. The assumption is that data is typically consumed fully.
|
||||
#
|
||||
# See also asio::stream_buf
|
||||
|
||||
buf: seq[byte] # Data store
|
||||
rpos: int # Reading position - valid data starts here
|
||||
wpos: int # Writing position - valid data ends here
|
||||
|
||||
template len*(v: StreamSeq): int =
|
||||
v.wpos - v.rpos
|
||||
|
||||
func grow(v: var StreamSeq, n: int) =
|
||||
if v.rpos == v.wpos:
|
||||
# All data has been consumed, reset positions
|
||||
v.rpos = 0
|
||||
v.wpos = 0
|
||||
|
||||
if v.buf.len - v.wpos < n:
|
||||
if v.rpos > 0:
|
||||
# We've consumed some data so we'll try to move that data to the beginning
|
||||
# of the buffer, hoping that this will clear up enough capacity to avoid
|
||||
# reallocation
|
||||
moveMem(addr v.buf[0], addr v.buf[v.rpos], v.wpos - v.rpos)
|
||||
v.wpos -= v.rpos
|
||||
v.rpos = 0
|
||||
|
||||
if v.buf.len - v.wpos >= n:
|
||||
return
|
||||
|
||||
# TODO this is inefficient - `setLen` will copy all data of buf, even though
|
||||
# we know that only a part of it contains "valid" data
|
||||
v.buf.setLen(nextPow2(max(64, v.wpos + n).uint64).int)
|
||||
|
||||
template prepare*(v: var StreamSeq, n: int): var openArray[byte] =
|
||||
## Return a buffer that is at least `n` bytes long
|
||||
mixin grow
|
||||
v.grow(n)
|
||||
|
||||
v.buf.toOpenArray(v.wpos, v.buf.len - 1)
|
||||
|
||||
template commit*(v: var StreamSeq, n: int) =
|
||||
## Mark `n` bytes in the buffer returned by `prepare` as ready for reading
|
||||
v.wpos += n
|
||||
|
||||
func add*(v: var StreamSeq, data: openArray[byte]) =
|
||||
## Add data - the equivalent of `buf.prepare(n) = data; buf.commit(n)`
|
||||
if data.len > 0:
|
||||
v.grow(data.len)
|
||||
copyMem(addr v.buf[v.wpos], unsafeAddr data[0], data.len)
|
||||
v.commit(data.len)
|
||||
|
||||
template data*(v: StreamSeq): openArray[byte] =
|
||||
# Data that is ready to be consumed
|
||||
# TODO a double-hash comment here breaks compile (!)
|
||||
v.buf.toOpenArray(v.rpos, v.wpos - 1)
|
||||
|
||||
func consume*(v: var StreamSeq, n: int) =
|
||||
## Mark `n` bytes that were returned via `data` as consumed
|
||||
v.rpos += n
|
||||
|
||||
func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int =
|
||||
let bytes = min(buf.len, v.len)
|
||||
if bytes > 0:
|
||||
copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
|
||||
v.consume(bytes)
|
||||
bytes
|
@ -149,7 +149,7 @@ suite "BufferStream":
|
||||
let buff = newBufferStream(writeHandler, 10)
|
||||
check buff.len == 0
|
||||
|
||||
await buff.write("Hello!", 6)
|
||||
await buff.write("Hello!")
|
||||
|
||||
result = true
|
||||
|
||||
@ -166,7 +166,7 @@ suite "BufferStream":
|
||||
let buff = newBufferStream(writeHandler, 10)
|
||||
check buff.len == 0
|
||||
|
||||
await buff.write(cast[seq[byte]]("Hello!"), 6)
|
||||
await buff.write(cast[seq[byte]]("Hello!"))
|
||||
|
||||
result = true
|
||||
|
||||
|
@ -48,21 +48,17 @@ proc readLp*(s: StreamTransport): Future[seq[byte]] {.async, gcsafe.} =
|
||||
length: int
|
||||
res: VarintStatus
|
||||
result = newSeq[byte](10)
|
||||
try:
|
||||
for i in 0..<len(result):
|
||||
await s.readExactly(addr result[i], 1)
|
||||
res = LP.getUVarint(result.toOpenArray(0, i), length, size)
|
||||
if res == VarintStatus.Success:
|
||||
break
|
||||
if res != VarintStatus.Success:
|
||||
raise newInvalidVarintException()
|
||||
result.setLen(size)
|
||||
if size > 0.uint:
|
||||
await s.readExactly(addr result[0], int(size))
|
||||
except TransportIncompleteError as exc:
|
||||
trace "remote connection ended unexpectedly", exc = exc.msg
|
||||
except TransportError as exc:
|
||||
trace "unable to read from remote connection", exc = exc.msg
|
||||
|
||||
for i in 0..<len(result):
|
||||
await s.readExactly(addr result[i], 1)
|
||||
res = LP.getUVarint(result.toOpenArray(0, i), length, size)
|
||||
if res == VarintStatus.Success:
|
||||
break
|
||||
if res != VarintStatus.Success:
|
||||
raise newInvalidVarintException()
|
||||
result.setLen(size)
|
||||
if size > 0.uint:
|
||||
await s.readExactly(addr result[0], int(size))
|
||||
|
||||
proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey),
|
||||
address: string = "/ip4/127.0.0.1/tcp/0",
|
||||
|
@ -48,8 +48,7 @@ method readExactly*(s: TestSelectStream,
|
||||
cstring("\0x3na\n"),
|
||||
"\0x3na\n".len())
|
||||
|
||||
method write*(s: TestSelectStream, msg: seq[byte], msglen = -1)
|
||||
{.async, gcsafe.} = discard
|
||||
method write*(s: TestSelectStream, msg: seq[byte]) {.async, gcsafe.} = discard
|
||||
|
||||
method close(s: TestSelectStream) {.async, gcsafe.} =
|
||||
s.isClosed = true
|
||||
@ -92,7 +91,7 @@ method readExactly*(s: TestLsStream,
|
||||
var buf = "na\n"
|
||||
copyMem(pbytes, addr buf[0], buf.len())
|
||||
|
||||
method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
||||
method write*(s: TestLsStream, msg: seq[byte]) {.async, gcsafe.} =
|
||||
if s.step == 4:
|
||||
await s.ls(msg)
|
||||
|
||||
@ -139,7 +138,7 @@ method readExactly*(s: TestNaStream,
|
||||
cstring("\0x3na\n"),
|
||||
"\0x3na\n".len())
|
||||
|
||||
method write*(s: TestNaStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} =
|
||||
method write*(s: TestNaStream, msg: seq[byte]) {.async, gcsafe.} =
|
||||
if s.step == 4:
|
||||
await s.na(string.fromBytes(msg))
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import testvarint
|
||||
import testvarint,
|
||||
teststreamseq
|
||||
|
||||
import testrsa,
|
||||
testecnist,
|
||||
|
@ -93,7 +93,7 @@ suite "Noise":
|
||||
serverNoise = newNoise(serverInfo.privateKey, outgoing = false)
|
||||
|
||||
proc connHandler(conn: Connection) {.async, gcsafe.} =
|
||||
let sconn = await serverNoise.secure(conn)
|
||||
let sconn = await serverNoise.secure(conn, false)
|
||||
defer:
|
||||
await sconn.close()
|
||||
await conn.close()
|
||||
@ -108,7 +108,7 @@ suite "Noise":
|
||||
clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
|
||||
clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.ma)
|
||||
sconn = await clientNoise.secure(conn)
|
||||
sconn = await clientNoise.secure(conn, true)
|
||||
|
||||
msg = await sconn.read(6)
|
||||
|
||||
@ -131,7 +131,7 @@ suite "Noise":
|
||||
readTask = newFuture[void]()
|
||||
|
||||
proc connHandler(conn: Connection) {.async, gcsafe.} =
|
||||
let sconn = await serverNoise.secure(conn)
|
||||
let sconn = await serverNoise.secure(conn, false)
|
||||
defer:
|
||||
await sconn.close()
|
||||
await conn.close()
|
||||
@ -148,7 +148,7 @@ suite "Noise":
|
||||
clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
|
||||
clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.ma)
|
||||
sconn = await clientNoise.secure(conn)
|
||||
sconn = await clientNoise.secure(conn, true)
|
||||
|
||||
await sconn.write("Hello!".cstring, 6)
|
||||
await readTask
|
||||
@ -175,7 +175,7 @@ suite "Noise":
|
||||
trace "Sending huge payload", size = hugePayload.len
|
||||
|
||||
proc connHandler(conn: Connection) {.async, gcsafe.} =
|
||||
let sconn = await serverNoise.secure(conn)
|
||||
let sconn = await serverNoise.secure(conn, false)
|
||||
defer:
|
||||
await sconn.close()
|
||||
let msg = await sconn.readLp()
|
||||
@ -191,7 +191,7 @@ suite "Noise":
|
||||
clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
|
||||
clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.ma)
|
||||
sconn = await clientNoise.secure(conn)
|
||||
sconn = await clientNoise.secure(conn, true)
|
||||
|
||||
await sconn.writeLp(hugePayload)
|
||||
await readTask
|
||||
|
55
tests/teststreamseq.nim
Normal file
55
tests/teststreamseq.nim
Normal file
@ -0,0 +1,55 @@
|
||||
{.used.}
|
||||
|
||||
import unittest
|
||||
import stew/byteutils
|
||||
import ../libp2p/stream/streamseq
|
||||
|
||||
suite "StreamSeq":
|
||||
test "basics":
|
||||
var s: StreamSeq
|
||||
|
||||
check:
|
||||
s.data().len == 0
|
||||
|
||||
s.add([byte 0, 1, 2, 3])
|
||||
|
||||
check:
|
||||
@(s.data()) == [byte 0, 1, 2, 3]
|
||||
|
||||
s.prepare(10)[0..<3] = [byte 4, 5, 6]
|
||||
|
||||
check:
|
||||
@(s.data()) == [byte 0, 1, 2, 3]
|
||||
|
||||
s.commit(3)
|
||||
|
||||
check:
|
||||
@(s.data()) == [byte 0, 1, 2, 3, 4, 5, 6]
|
||||
|
||||
s.consume(1)
|
||||
|
||||
check:
|
||||
@(s.data()) == [byte 1, 2, 3, 4, 5, 6]
|
||||
|
||||
s.consume(6)
|
||||
|
||||
check: @(s.data()) == []
|
||||
|
||||
s.add([])
|
||||
check: @(s.data()) == []
|
||||
|
||||
var o: seq[byte]
|
||||
|
||||
check: 0 == s.consumeTo(o)
|
||||
|
||||
s.add([byte 1, 2, 3])
|
||||
|
||||
o.setLen(2)
|
||||
o.setLen(s.consumeTo(o))
|
||||
check:
|
||||
o == [byte 1, 2]
|
||||
|
||||
o.setLen(s.consumeTo(o))
|
||||
|
||||
check:
|
||||
o == [byte 3]
|
Loading…
x
Reference in New Issue
Block a user