cleanups (#366)
* reuse connection timeout for noise handshake (avoid extra timer) * enforce nbytes > 0 for readOnce * avoid some unnecessary memory zeroing * simplify noise * fix dumping when noise splits message
This commit is contained in:
parent
b0d86b95dd
commit
b7e5d1122c
|
@ -276,28 +276,33 @@ template read_s: untyped =
|
||||||
|
|
||||||
msg.consume(rsLen)
|
msg.consume(rsLen)
|
||||||
|
|
||||||
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
|
proc readFrame(sconn: Connection): Future[seq[byte]] {.async.} =
|
||||||
var besize: array[2, byte]
|
var besize {.noinit.}: array[2, byte]
|
||||||
await sconn.readExactly(addr besize[0], besize.len).wait(HandshakeTimeout)
|
await sconn.readExactly(addr besize[0], besize.len)
|
||||||
let size = uint16.fromBytesBE(besize).int
|
let size = uint16.fromBytesBE(besize).int
|
||||||
trace "receiveHSMessage", size
|
trace "readFrame", sconn, size
|
||||||
if size == 0:
|
if size == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
var buffer = newSeq[byte](size)
|
var buffer = newSeqUninitialized[byte](size)
|
||||||
await sconn.readExactly(addr buffer[0], buffer.len).wait(HandshakeTimeout)
|
await sconn.readExactly(addr buffer[0], buffer.len)
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] =
|
proc writeFrame(sconn: Connection, buf: openArray[byte]): Future[void] =
|
||||||
|
doAssert buf.len <= uint16.high.int
|
||||||
var
|
var
|
||||||
lesize = buf.len.uint16
|
lesize = buf.len.uint16
|
||||||
besize = lesize.toBytesBE
|
besize = lesize.toBytesBE
|
||||||
outbuf = newSeqOfCap[byte](besize.len + buf.len)
|
outbuf = newSeqOfCap[byte](besize.len + buf.len)
|
||||||
trace "sendHSMessage", size = lesize
|
trace "writeFrame", sconn, size = lesize, data = shortLog(buf)
|
||||||
outbuf &= besize
|
outbuf &= besize
|
||||||
outbuf &= buf
|
outbuf &= buf
|
||||||
sconn.write(outbuf)
|
sconn.write(outbuf)
|
||||||
|
|
||||||
|
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] = readFrame(sconn)
|
||||||
|
proc sendHSMessage(sconn: Connection, buf: openArray[byte]): Future[void] =
|
||||||
|
writeFrame(sconn, buf)
|
||||||
|
|
||||||
proc handshakeXXOutbound(
|
proc handshakeXXOutbound(
|
||||||
p: Noise, conn: Connection,
|
p: Noise, conn: Connection,
|
||||||
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
||||||
|
@ -399,20 +404,15 @@ proc handshakeXXInbound(
|
||||||
|
|
||||||
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||||
while true: # Discard 0-length payloads
|
while true: # Discard 0-length payloads
|
||||||
var besize: array[2, byte]
|
let frame = await sconn.stream.readFrame()
|
||||||
await sconn.stream.readExactly(addr besize[0], besize.len)
|
sconn.activity = true
|
||||||
let size = uint16.fromBytesBE(besize).int # Cannot overflow
|
if frame.len > 0:
|
||||||
trace "receiveEncryptedMessage", size, sconn
|
let res = sconn.readCs.decryptWithAd([], frame)
|
||||||
if size > 0:
|
if res.len > 0:
|
||||||
var buffer = newSeq[byte](size)
|
|
||||||
await sconn.stream.readExactly(addr buffer[0], buffer.len)
|
|
||||||
when defined(libp2p_dump):
|
when defined(libp2p_dump):
|
||||||
let res = sconn.readCs.decryptWithAd([], buffer)
|
|
||||||
dumpMessage(sconn, FlowDirection.Incoming, res)
|
dumpMessage(sconn, FlowDirection.Incoming, res)
|
||||||
return res
|
return res
|
||||||
else:
|
|
||||||
return sconn.readCs.decryptWithAd([], buffer)
|
|
||||||
else:
|
|
||||||
when defined(libp2p_dump):
|
when defined(libp2p_dump):
|
||||||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||||
trace "Received 0-length message", sconn
|
trace "Received 0-length message", sconn
|
||||||
|
@ -426,28 +426,27 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
|
||||||
offset = 0
|
offset = 0
|
||||||
while left > 0:
|
while left > 0:
|
||||||
let
|
let
|
||||||
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
|
chunkSize = min(MaxPlainSize, left)
|
||||||
cipher = sconn.writeCs.encryptWithAd(
|
cipher = sconn.writeCs.encryptWithAd(
|
||||||
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
||||||
left = left - chunkSize
|
|
||||||
offset = offset + chunkSize
|
await sconn.stream.writeFrame(cipher)
|
||||||
var
|
|
||||||
lesize = cipher.len.uint16
|
|
||||||
besize = lesize.toBytesBE
|
|
||||||
outbuf = newSeqOfCap[byte](cipher.len + 2)
|
|
||||||
trace "sendEncryptedMessage", sconn, size = lesize, left, offset
|
|
||||||
outbuf &= besize
|
|
||||||
outbuf &= cipher
|
|
||||||
await sconn.stream.write(outbuf)
|
|
||||||
|
|
||||||
when defined(libp2p_dump):
|
when defined(libp2p_dump):
|
||||||
dumpMessage(sconn, FlowDirection.Outgoing, message)
|
dumpMessage(
|
||||||
|
sconn, FlowDirection.Outgoing,
|
||||||
|
message.toOpenArray(offset, offset + chunkSize - 1))
|
||||||
|
|
||||||
|
left = left - chunkSize
|
||||||
|
offset = offset + chunkSize
|
||||||
sconn.activity = true
|
sconn.activity = true
|
||||||
|
|
||||||
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
||||||
trace "Starting Noise handshake", conn, initiator
|
trace "Starting Noise handshake", conn, initiator
|
||||||
|
|
||||||
|
let timeout = conn.timeout
|
||||||
|
conn.timeout = HandshakeTimeout
|
||||||
|
|
||||||
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
|
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
|
||||||
let
|
let
|
||||||
signedPayload = p.localPrivateKey.sign(
|
signedPayload = p.localPrivateKey.sign(
|
||||||
|
@ -524,6 +523,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
||||||
|
|
||||||
trace "Noise handshake completed!", initiator, peer = shortLog(secure.peerInfo)
|
trace "Noise handshake completed!", initiator, peer = shortLog(secure.peerInfo)
|
||||||
|
|
||||||
|
conn.timeout = timeout
|
||||||
|
|
||||||
return secure
|
return secure
|
||||||
|
|
||||||
method close*(s: NoiseConnection) {.async.} =
|
method close*(s: NoiseConnection) {.async.} =
|
||||||
|
|
|
@ -108,15 +108,14 @@ method init*(s: Secure) {.gcsafe.} =
|
||||||
method secure*(s: Secure,
|
method secure*(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool):
|
initiator: bool):
|
||||||
Future[Connection] {.async, base, gcsafe.} =
|
Future[Connection] {.base, gcsafe.} =
|
||||||
result = await s.handleConn(conn, initiator)
|
s.handleConn(conn, initiator)
|
||||||
|
|
||||||
method readOnce*(s: SecureConn,
|
method readOnce*(s: SecureConn,
|
||||||
pbytes: pointer,
|
pbytes: pointer,
|
||||||
nbytes: int):
|
nbytes: int):
|
||||||
Future[int] {.async, gcsafe.} =
|
Future[int] {.async, gcsafe.} =
|
||||||
if nbytes == 0:
|
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||||
return 0
|
|
||||||
|
|
||||||
if s.buf.data().len() == 0:
|
if s.buf.data().len() == 0:
|
||||||
let buf = await s.readMessage()
|
let buf = await s.readMessage()
|
||||||
|
|
|
@ -128,6 +128,8 @@ method readOnce*(s: BufferStream,
|
||||||
pbytes: pointer,
|
pbytes: pointer,
|
||||||
nbytes: int):
|
nbytes: int):
|
||||||
Future[int] {.async.} =
|
Future[int] {.async.} =
|
||||||
|
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||||
|
|
||||||
if s.isEof and s.readBuf.len() == 0:
|
if s.isEof and s.readBuf.len() == 0:
|
||||||
raise newLPStreamEOFError()
|
raise newLPStreamEOFError()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue