cleanups (#366)
* reuse connection timeout for noise handshake (avoid extra timer) * enforce nbytes > 0 for readOnce * avoid some unnecessary memory zeroing * simplify noise * fix dumping when noise splits message
This commit is contained in:
parent
b0d86b95dd
commit
b7e5d1122c
|
@ -276,28 +276,33 @@ template read_s: untyped =
|
|||
|
||||
msg.consume(rsLen)
|
||||
|
||||
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
|
||||
var besize: array[2, byte]
|
||||
await sconn.readExactly(addr besize[0], besize.len).wait(HandshakeTimeout)
|
||||
proc readFrame(sconn: Connection): Future[seq[byte]] {.async.} =
|
||||
var besize {.noinit.}: array[2, byte]
|
||||
await sconn.readExactly(addr besize[0], besize.len)
|
||||
let size = uint16.fromBytesBE(besize).int
|
||||
trace "receiveHSMessage", size
|
||||
trace "readFrame", sconn, size
|
||||
if size == 0:
|
||||
return
|
||||
|
||||
var buffer = newSeq[byte](size)
|
||||
await sconn.readExactly(addr buffer[0], buffer.len).wait(HandshakeTimeout)
|
||||
var buffer = newSeqUninitialized[byte](size)
|
||||
await sconn.readExactly(addr buffer[0], buffer.len)
|
||||
return buffer
|
||||
|
||||
proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] =
|
||||
proc writeFrame(sconn: Connection, buf: openArray[byte]): Future[void] =
|
||||
doAssert buf.len <= uint16.high.int
|
||||
var
|
||||
lesize = buf.len.uint16
|
||||
besize = lesize.toBytesBE
|
||||
outbuf = newSeqOfCap[byte](besize.len + buf.len)
|
||||
trace "sendHSMessage", size = lesize
|
||||
trace "writeFrame", sconn, size = lesize, data = shortLog(buf)
|
||||
outbuf &= besize
|
||||
outbuf &= buf
|
||||
sconn.write(outbuf)
|
||||
|
||||
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] = readFrame(sconn)
|
||||
proc sendHSMessage(sconn: Connection, buf: openArray[byte]): Future[void] =
|
||||
writeFrame(sconn, buf)
|
||||
|
||||
proc handshakeXXOutbound(
|
||||
p: Noise, conn: Connection,
|
||||
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
||||
|
@ -399,23 +404,18 @@ proc handshakeXXInbound(
|
|||
|
||||
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||
while true: # Discard 0-length payloads
|
||||
var besize: array[2, byte]
|
||||
await sconn.stream.readExactly(addr besize[0], besize.len)
|
||||
let size = uint16.fromBytesBE(besize).int # Cannot overflow
|
||||
trace "receiveEncryptedMessage", size, sconn
|
||||
if size > 0:
|
||||
var buffer = newSeq[byte](size)
|
||||
await sconn.stream.readExactly(addr buffer[0], buffer.len)
|
||||
when defined(libp2p_dump):
|
||||
let res = sconn.readCs.decryptWithAd([], buffer)
|
||||
dumpMessage(sconn, FlowDirection.Incoming, res)
|
||||
let frame = await sconn.stream.readFrame()
|
||||
sconn.activity = true
|
||||
if frame.len > 0:
|
||||
let res = sconn.readCs.decryptWithAd([], frame)
|
||||
if res.len > 0:
|
||||
when defined(libp2p_dump):
|
||||
dumpMessage(sconn, FlowDirection.Incoming, res)
|
||||
return res
|
||||
else:
|
||||
return sconn.readCs.decryptWithAd([], buffer)
|
||||
else:
|
||||
when defined(libp2p_dump):
|
||||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||
trace "Received 0-length message", sconn
|
||||
|
||||
when defined(libp2p_dump):
|
||||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||
trace "Received 0-length message", sconn
|
||||
|
||||
method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
|
||||
if message.len == 0:
|
||||
|
@ -426,28 +426,27 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
|
|||
offset = 0
|
||||
while left > 0:
|
||||
let
|
||||
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
|
||||
chunkSize = min(MaxPlainSize, left)
|
||||
cipher = sconn.writeCs.encryptWithAd(
|
||||
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
left = left - chunkSize
|
||||
offset = offset + chunkSize
|
||||
var
|
||||
lesize = cipher.len.uint16
|
||||
besize = lesize.toBytesBE
|
||||
outbuf = newSeqOfCap[byte](cipher.len + 2)
|
||||
trace "sendEncryptedMessage", sconn, size = lesize, left, offset
|
||||
outbuf &= besize
|
||||
outbuf &= cipher
|
||||
await sconn.stream.write(outbuf)
|
||||
|
||||
await sconn.stream.writeFrame(cipher)
|
||||
|
||||
when defined(libp2p_dump):
|
||||
dumpMessage(sconn, FlowDirection.Outgoing, message)
|
||||
dumpMessage(
|
||||
sconn, FlowDirection.Outgoing,
|
||||
message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
|
||||
left = left - chunkSize
|
||||
offset = offset + chunkSize
|
||||
sconn.activity = true
|
||||
|
||||
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
||||
trace "Starting Noise handshake", conn, initiator
|
||||
|
||||
let timeout = conn.timeout
|
||||
conn.timeout = HandshakeTimeout
|
||||
|
||||
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
|
||||
let
|
||||
signedPayload = p.localPrivateKey.sign(
|
||||
|
@ -524,6 +523,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
|
||||
trace "Noise handshake completed!", initiator, peer = shortLog(secure.peerInfo)
|
||||
|
||||
conn.timeout = timeout
|
||||
|
||||
return secure
|
||||
|
||||
method close*(s: NoiseConnection) {.async.} =
|
||||
|
|
|
@ -108,15 +108,14 @@ method init*(s: Secure) {.gcsafe.} =
|
|||
method secure*(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool):
|
||||
Future[Connection] {.async, base, gcsafe.} =
|
||||
result = await s.handleConn(conn, initiator)
|
||||
Future[Connection] {.base, gcsafe.} =
|
||||
s.handleConn(conn, initiator)
|
||||
|
||||
method readOnce*(s: SecureConn,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.async, gcsafe.} =
|
||||
if nbytes == 0:
|
||||
return 0
|
||||
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||
|
||||
if s.buf.data().len() == 0:
|
||||
let buf = await s.readMessage()
|
||||
|
|
|
@ -128,6 +128,8 @@ method readOnce*(s: BufferStream,
|
|||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.async.} =
|
||||
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||
|
||||
if s.isEof and s.readBuf.len() == 0:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
|
|
Loading…
Reference in New Issue