streamline socket read/write hot path (#473)
* streamline socket read/write hot path This avoids some unnecessary memory copying on the hot path of noise / mplex, as well as getting rid of a few futures - profiling shows that this is one of the main culprits of small memory allocations, which makes sense - this is where gossip fan-out happens. * fewer futures (and corresponding closures) when sending lpchannel messages * avoid data copies when encrypting and framing noise messages * avoid copying tuple when reading noise data (poor c codegen) * fix setting eof flag in secure read * write noise frames in one go ...and closing secure socket once is enough
This commit is contained in:
parent
1befeb8c2e
commit
6f1ecc8df7
|
@ -56,10 +56,7 @@ proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} =
|
|||
proc writeMsg*(conn: Connection,
|
||||
id: uint64,
|
||||
msgType: MessageType,
|
||||
data: seq[byte] = @[]) {.async, gcsafe.} =
|
||||
if conn.closed:
|
||||
return # No point in trying to write to an already-closed connection
|
||||
|
||||
data: seq[byte] = @[]): Future[void] =
|
||||
var
|
||||
left = data.len
|
||||
offset = 0
|
||||
|
@ -81,17 +78,9 @@ proc writeMsg*(conn: Connection,
|
|||
trace "writing mplex message",
|
||||
conn, id, msgType, data = data.len, encoded = buf.buffer.len
|
||||
|
||||
try:
|
||||
# Write all chunks in a single write to avoid async races where a close
|
||||
# message gets written before some of the chunks
|
||||
await conn.write(buf.buffer)
|
||||
trace "wrote mplex", conn, id, msgType
|
||||
except CatchableError as exc:
|
||||
# If the write to the underlying connection failed it should be closed so
|
||||
# that the other channels are notified as well
|
||||
trace "failed write", conn, id, msg = exc.msg
|
||||
await conn.close()
|
||||
raise exc
|
||||
conn.write(buf.buffer)
|
||||
|
||||
proc writeMsg*(conn: Connection,
|
||||
id: uint64,
|
||||
|
|
|
@ -50,8 +50,6 @@ type
|
|||
resetCode*: MessageType # cached in/out reset code
|
||||
writes*: int # In-flight writes
|
||||
|
||||
proc open*(s: LPChannel) {.async, gcsafe.}
|
||||
|
||||
func shortLog*(s: LPChannel): auto =
|
||||
if s.isNil: "LPChannel(nil)"
|
||||
elif s.conn.peerInfo.isNil: $s.oid
|
||||
|
@ -62,8 +60,14 @@ chronicles.formatIt(LPChannel): shortLog(it)
|
|||
|
||||
proc open*(s: LPChannel) {.async, gcsafe.} =
|
||||
trace "Opening channel", s, conn = s.conn
|
||||
if s.conn.isClosed:
|
||||
return
|
||||
try:
|
||||
await s.conn.writeMsg(s.id, MessageType.New, s.name)
|
||||
s.isOpen = true
|
||||
except CatchableError as exc:
|
||||
await s.conn.close()
|
||||
raise exc
|
||||
|
||||
method closed*(s: LPChannel): bool =
|
||||
s.closedLocal
|
||||
|
@ -91,7 +95,8 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
|
|||
trace "sending reset message", s, conn = s.conn
|
||||
await s.conn.writeMsg(s.id, s.resetCode) # write reset
|
||||
except CatchableError as exc:
|
||||
# No cancellations, errors handled in writeMsg
|
||||
# No cancellations
|
||||
await s.conn.close()
|
||||
trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
|
||||
|
||||
asyncSpawn resetMessage()
|
||||
|
@ -115,10 +120,12 @@ method close*(s: LPChannel) {.async, gcsafe.} =
|
|||
try:
|
||||
await s.conn.writeMsg(s.id, s.closeCode) # write close
|
||||
except CancelledError as exc:
|
||||
await s.conn.close()
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
# It's harmless that close message cannot be sent - the connection is
|
||||
# likely down already
|
||||
await s.conn.close()
|
||||
trace "Cannot send close message", s, id = s.id, msg = exc.msg
|
||||
|
||||
await s.closeUnderlying() # maybe already eofed
|
||||
|
|
|
@ -121,18 +121,28 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
|
|||
proc hasKey(cs: CipherState): bool =
|
||||
cs.k != EmptyKey
|
||||
|
||||
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||
var
|
||||
tag: ChaChaPolyTag
|
||||
nonce: ChaChaPolyNonce
|
||||
proc encrypt(
|
||||
state: var CipherState, data: var openArray[byte],
|
||||
ad: openArray[byte]): ChaChaPolyTag {.noinit.} =
|
||||
var nonce: ChaChaPolyNonce
|
||||
nonce[4..<12] = toBytesLE(state.n)
|
||||
result = @data
|
||||
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
|
||||
|
||||
ChaChaPoly.encrypt(state.k, nonce, result, data, ad)
|
||||
|
||||
inc state.n
|
||||
if state.n > NonceMax:
|
||||
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
|
||||
result &= tag
|
||||
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
||||
|
||||
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||
result = newSeqOfCap[byte](data.len + sizeof(ChachaPolyTag))
|
||||
result.add(data)
|
||||
|
||||
let tag = encrypt(state, result, ad)
|
||||
|
||||
result.add(tag)
|
||||
|
||||
trace "encryptWithAd",
|
||||
tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
||||
|
||||
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||
var
|
||||
|
@ -417,20 +427,47 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
|||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||
trace "Received 0-length message", sconn
|
||||
|
||||
|
||||
proc encryptFrame(
|
||||
sconn: NoiseConnection, cipherFrame: var openArray[byte], src: openArray[byte]) =
|
||||
# Frame consists of length + cipher data + tag
|
||||
doAssert src.len <= MaxPlainSize
|
||||
doAssert cipherFrame.len == 2 + src.len + sizeof(ChaChaPolyTag)
|
||||
|
||||
cipherFrame[0..<2] = toBytesBE(uint16(src.len + sizeof(ChaChaPolyTag)))
|
||||
|
||||
copyMem(addr cipherFrame[2], unsafeAddr src[0], src.len())
|
||||
|
||||
let tag = encrypt(
|
||||
sconn.writeCs, cipherFrame.toOpenArray(2, 2 + src.len() - 1), [])
|
||||
|
||||
copyMem(
|
||||
addr cipherFrame[cipherFrame.len - sizeof(tag)], unsafeAddr tag[0],
|
||||
sizeof(tag))
|
||||
|
||||
method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
|
||||
if message.len == 0:
|
||||
return
|
||||
|
||||
const FramingSize = 2 + sizeof(ChaChaPolyTag)
|
||||
|
||||
let
|
||||
frames = (message.len + MaxPlainSize - 1) div MaxPlainSize
|
||||
|
||||
var
|
||||
cipherFrames = newSeqUninitialized[byte](message.len + frames * FramingSize)
|
||||
left = message.len
|
||||
offset = 0
|
||||
woffset = 0
|
||||
|
||||
while left > 0:
|
||||
let
|
||||
chunkSize = min(MaxPlainSize, left)
|
||||
cipher = sconn.writeCs.encryptWithAd(
|
||||
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
|
||||
await sconn.stream.writeFrame(cipher)
|
||||
encryptFrame(
|
||||
sconn,
|
||||
cipherFrames.toOpenArray(woffset, woffset + chunkSize + FramingSize - 1),
|
||||
message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
|
||||
when defined(libp2p_dump):
|
||||
dumpMessage(
|
||||
|
@ -438,9 +475,12 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
|
|||
message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
|
||||
left = left - chunkSize
|
||||
offset = offset + chunkSize
|
||||
offset += chunkSize
|
||||
woffset += chunkSize + FramingSize
|
||||
sconn.activity = true
|
||||
|
||||
await sconn.stream.write(cipherFrames)
|
||||
|
||||
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
||||
trace "Starting Noise handshake", conn, initiator
|
||||
|
||||
|
@ -529,8 +569,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
|
||||
return secure
|
||||
|
||||
method close*(s: NoiseConnection) {.async.} =
|
||||
await procCall SecureConn(s).close()
|
||||
method closeImpl*(s: NoiseConnection) {.async.} =
|
||||
await procCall SecureConn(s).closeImpl()
|
||||
|
||||
burnMem(s.readCs)
|
||||
burnMem(s.writeCs)
|
||||
|
|
|
@ -56,29 +56,29 @@ method initStream*(s: SecureConn) =
|
|||
|
||||
procCall Connection(s).initStream()
|
||||
|
||||
method close*(s: SecureConn) {.async.} =
|
||||
method closeImpl*(s: SecureConn) {.async.} =
|
||||
trace "Closing secure conn", s, dir = s.dir
|
||||
if not(isNil(s.stream)):
|
||||
await s.stream.close()
|
||||
|
||||
await procCall Connection(s).close()
|
||||
await procCall Connection(s).closeImpl()
|
||||
|
||||
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
method handshake(s: Secure,
|
||||
method handshake*(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool): Future[SecureConn] {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
proc handleConn*(s: Secure,
|
||||
proc handleConn(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool): Future[Connection] {.async, gcsafe.} =
|
||||
initiator: bool): Future[Connection] {.async.} =
|
||||
var sconn = await s.handshake(conn, initiator)
|
||||
|
||||
proc cleanup() {.async.} =
|
||||
try:
|
||||
let futs = @[conn.join(), sconn.join()]
|
||||
let futs = [conn.join(), sconn.join()]
|
||||
await futs[0] or futs[1]
|
||||
for f in futs:
|
||||
if not f.finished: await f.cancelAndWait() # cancel outstanding join()
|
||||
|
@ -90,7 +90,7 @@ proc handleConn*(s: Secure,
|
|||
# do not need to propagate CancelledError.
|
||||
discard
|
||||
except CatchableError as exc:
|
||||
trace "error cleaning up secure connection", err = exc.msg, sconn
|
||||
debug "error cleaning up secure connection", err = exc.msg, sconn
|
||||
|
||||
if not isNil(sconn):
|
||||
# All the errors are handled inside `cleanup()` procedure.
|
||||
|
@ -98,10 +98,10 @@ proc handleConn*(s: Secure,
|
|||
|
||||
return sconn
|
||||
|
||||
method init*(s: Secure) {.gcsafe.} =
|
||||
method init*(s: Secure) =
|
||||
procCall LPProtocol(s).init()
|
||||
|
||||
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||
proc handle(conn: Connection, proto: string) {.async.} =
|
||||
trace "handling connection upgrade", proto, conn
|
||||
try:
|
||||
# We don't need the result but we
|
||||
|
@ -121,23 +121,28 @@ method init*(s: Secure) {.gcsafe.} =
|
|||
method secure*(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool):
|
||||
Future[Connection] {.base, gcsafe.} =
|
||||
Future[Connection] {.base.} =
|
||||
s.handleConn(conn, initiator)
|
||||
|
||||
method readOnce*(s: SecureConn,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.async, gcsafe.} =
|
||||
Future[int] {.async.} =
|
||||
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||
|
||||
if s.buf.data().len() == 0:
|
||||
let (buf, err) = try:
|
||||
(await s.readMessage(), nil)
|
||||
except CatchableError as exc:
|
||||
(@[], exc)
|
||||
if s.isEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
if not isNil(err):
|
||||
if not (err of LPStreamEOFError):
|
||||
if s.buf.data().len() == 0:
|
||||
try:
|
||||
let buf = await s.readMessage() # Always returns >0 bytes or raises
|
||||
s.activity = true
|
||||
s.buf.add(buf)
|
||||
except LPStreamEOFError as err:
|
||||
s.isEof = true
|
||||
await s.close()
|
||||
raise err
|
||||
except CatchableError as err:
|
||||
debug "Error while reading message from secure connection, closing.",
|
||||
error = err.name,
|
||||
message = err.msg,
|
||||
|
@ -145,12 +150,5 @@ method readOnce*(s: SecureConn,
|
|||
await s.close()
|
||||
raise err
|
||||
|
||||
s.activity = true
|
||||
|
||||
if buf.len == 0:
|
||||
raise newLPStreamIncompleteError()
|
||||
|
||||
s.buf.add(buf)
|
||||
|
||||
var p = cast[ptr UncheckedArray[byte]](pbytes)
|
||||
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
|
||||
|
|
Loading…
Reference in New Issue