streamline socket read/write hot path (#473)

* streamline socket read/write hot path

This avoids some unnecessary memory copying on the hot path of noise /
mplex, as well as getting rid of a few futures - profiling shows that
this is one of the main culprits of small memory allocations, which
makes sense - this is where gossip fan-out happens.

* fewer futures (and corresponding closures) when sending lpchannel
messages
* avoid data copies when encrypting and framing noise messages
* avoid copying tuple when reading noise data (poor c codegen)
* fix setting eof flag in secure read

* write noise frames in one go

...and closing secure socket once is enough
This commit is contained in:
Jacek Sieka 2020-12-09 15:56:40 +01:00 committed by GitHub
parent 1befeb8c2e
commit 6f1ecc8df7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 67 deletions

View File

@ -56,10 +56,7 @@ proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} =
proc writeMsg*(conn: Connection, proc writeMsg*(conn: Connection,
id: uint64, id: uint64,
msgType: MessageType, msgType: MessageType,
data: seq[byte] = @[]) {.async, gcsafe.} = data: seq[byte] = @[]): Future[void] =
if conn.closed:
return # No point in trying to write to an already-closed connection
var var
left = data.len left = data.len
offset = 0 offset = 0
@ -81,17 +78,9 @@ proc writeMsg*(conn: Connection,
trace "writing mplex message", trace "writing mplex message",
conn, id, msgType, data = data.len, encoded = buf.buffer.len conn, id, msgType, data = data.len, encoded = buf.buffer.len
try: # Write all chunks in a single write to avoid async races where a close
# Write all chunks in a single write to avoid async races where a close # message gets written before some of the chunks
# message gets written before some of the chunks conn.write(buf.buffer)
await conn.write(buf.buffer)
trace "wrote mplex", conn, id, msgType
except CatchableError as exc:
# If the write to the underlying connection failed it should be closed so
# that the other channels are notified as well
trace "failed write", conn, id, msg = exc.msg
await conn.close()
raise exc
proc writeMsg*(conn: Connection, proc writeMsg*(conn: Connection,
id: uint64, id: uint64,

View File

@ -50,8 +50,6 @@ type
resetCode*: MessageType # cached in/out reset code resetCode*: MessageType # cached in/out reset code
writes*: int # In-flight writes writes*: int # In-flight writes
proc open*(s: LPChannel) {.async, gcsafe.}
func shortLog*(s: LPChannel): auto = func shortLog*(s: LPChannel): auto =
if s.isNil: "LPChannel(nil)" if s.isNil: "LPChannel(nil)"
elif s.conn.peerInfo.isNil: $s.oid elif s.conn.peerInfo.isNil: $s.oid
@ -62,8 +60,14 @@ chronicles.formatIt(LPChannel): shortLog(it)
proc open*(s: LPChannel) {.async, gcsafe.} = proc open*(s: LPChannel) {.async, gcsafe.} =
trace "Opening channel", s, conn = s.conn trace "Opening channel", s, conn = s.conn
await s.conn.writeMsg(s.id, MessageType.New, s.name) if s.conn.isClosed:
s.isOpen = true return
try:
await s.conn.writeMsg(s.id, MessageType.New, s.name)
s.isOpen = true
except CatchableError as exc:
await s.conn.close()
raise exc
method closed*(s: LPChannel): bool = method closed*(s: LPChannel): bool =
s.closedLocal s.closedLocal
@ -88,10 +92,11 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
# If the connection is still active, notify the other end # If the connection is still active, notify the other end
proc resetMessage() {.async.} = proc resetMessage() {.async.} =
try: try:
trace "sending reset message", s, conn = s.conn trace "sending reset message", s, conn = s.conn
await s.conn.writeMsg(s.id, s.resetCode) # write reset await s.conn.writeMsg(s.id, s.resetCode) # write reset
except CatchableError as exc: except CatchableError as exc:
# No cancellations, errors handled in writeMsg # No cancellations
await s.conn.close()
trace "Can't send reset message", s, conn = s.conn, msg = exc.msg trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
asyncSpawn resetMessage() asyncSpawn resetMessage()
@ -115,10 +120,12 @@ method close*(s: LPChannel) {.async, gcsafe.} =
try: try:
await s.conn.writeMsg(s.id, s.closeCode) # write close await s.conn.writeMsg(s.id, s.closeCode) # write close
except CancelledError as exc: except CancelledError as exc:
await s.conn.close()
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
# It's harmless that close message cannot be sent - the connection is # It's harmless that close message cannot be sent - the connection is
# likely down already # likely down already
await s.conn.close()
trace "Cannot send close message", s, id = s.id, msg = exc.msg trace "Cannot send close message", s, id = s.id, msg = exc.msg
await s.closeUnderlying() # maybe already eofed await s.closeUnderlying() # maybe already eofed

View File

@ -121,18 +121,28 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
proc hasKey(cs: CipherState): bool = proc hasKey(cs: CipherState): bool =
cs.k != EmptyKey cs.k != EmptyKey
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = proc encrypt(
var state: var CipherState, data: var openArray[byte],
tag: ChaChaPolyTag ad: openArray[byte]): ChaChaPolyTag {.noinit.} =
nonce: ChaChaPolyNonce var nonce: ChaChaPolyNonce
nonce[4..<12] = toBytesLE(state.n) nonce[4..<12] = toBytesLE(state.n)
result = @data
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad) ChaChaPoly.encrypt(state.k, nonce, result, data, ad)
inc state.n inc state.n
if state.n > NonceMax: if state.n > NonceMax:
raise newException(NoiseNonceMaxError, "Noise max nonce value reached") raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
result &= tag
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1 proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
result = newSeqOfCap[byte](data.len + sizeof(ChachaPolyTag))
result.add(data)
let tag = encrypt(state, result, ad)
result.add(tag)
trace "encryptWithAd",
tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
var var
@ -417,20 +427,47 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
dumpMessage(sconn, FlowDirection.Incoming, []) dumpMessage(sconn, FlowDirection.Incoming, [])
trace "Received 0-length message", sconn trace "Received 0-length message", sconn
proc encryptFrame(
sconn: NoiseConnection, cipherFrame: var openArray[byte], src: openArray[byte]) =
# Frame consists of length + cipher data + tag
doAssert src.len <= MaxPlainSize
doAssert cipherFrame.len == 2 + src.len + sizeof(ChaChaPolyTag)
cipherFrame[0..<2] = toBytesBE(uint16(src.len + sizeof(ChaChaPolyTag)))
copyMem(addr cipherFrame[2], unsafeAddr src[0], src.len())
let tag = encrypt(
sconn.writeCs, cipherFrame.toOpenArray(2, 2 + src.len() - 1), [])
copyMem(
addr cipherFrame[cipherFrame.len - sizeof(tag)], unsafeAddr tag[0],
sizeof(tag))
method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} = method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
if message.len == 0: if message.len == 0:
return return
const FramingSize = 2 + sizeof(ChaChaPolyTag)
let
frames = (message.len + MaxPlainSize - 1) div MaxPlainSize
var var
cipherFrames = newSeqUninitialized[byte](message.len + frames * FramingSize)
left = message.len left = message.len
offset = 0 offset = 0
woffset = 0
while left > 0: while left > 0:
let let
chunkSize = min(MaxPlainSize, left) chunkSize = min(MaxPlainSize, left)
cipher = sconn.writeCs.encryptWithAd(
[], message.toOpenArray(offset, offset + chunkSize - 1))
await sconn.stream.writeFrame(cipher) encryptFrame(
sconn,
cipherFrames.toOpenArray(woffset, woffset + chunkSize + FramingSize - 1),
message.toOpenArray(offset, offset + chunkSize - 1))
when defined(libp2p_dump): when defined(libp2p_dump):
dumpMessage( dumpMessage(
@ -438,9 +475,12 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
message.toOpenArray(offset, offset + chunkSize - 1)) message.toOpenArray(offset, offset + chunkSize - 1))
left = left - chunkSize left = left - chunkSize
offset = offset + chunkSize offset += chunkSize
woffset += chunkSize + FramingSize
sconn.activity = true sconn.activity = true
await sconn.stream.write(cipherFrames)
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} = method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
trace "Starting Noise handshake", conn, initiator trace "Starting Noise handshake", conn, initiator
@ -529,8 +569,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
return secure return secure
method close*(s: NoiseConnection) {.async.} = method closeImpl*(s: NoiseConnection) {.async.} =
await procCall SecureConn(s).close() await procCall SecureConn(s).closeImpl()
burnMem(s.readCs) burnMem(s.readCs)
burnMem(s.writeCs) burnMem(s.writeCs)

View File

@ -56,29 +56,29 @@ method initStream*(s: SecureConn) =
procCall Connection(s).initStream() procCall Connection(s).initStream()
method close*(s: SecureConn) {.async.} = method closeImpl*(s: SecureConn) {.async.} =
trace "Closing secure conn", s, dir = s.dir trace "Closing secure conn", s, dir = s.dir
if not(isNil(s.stream)): if not(isNil(s.stream)):
await s.stream.close() await s.stream.close()
await procCall Connection(s).close() await procCall Connection(s).closeImpl()
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
method handshake(s: Secure, method handshake*(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): Future[SecureConn] {.async, base.} = initiator: bool): Future[SecureConn] {.async, base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
proc handleConn*(s: Secure, proc handleConn(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): Future[Connection] {.async, gcsafe.} = initiator: bool): Future[Connection] {.async.} =
var sconn = await s.handshake(conn, initiator) var sconn = await s.handshake(conn, initiator)
proc cleanup() {.async.} = proc cleanup() {.async.} =
try: try:
let futs = @[conn.join(), sconn.join()] let futs = [conn.join(), sconn.join()]
await futs[0] or futs[1] await futs[0] or futs[1]
for f in futs: for f in futs:
if not f.finished: await f.cancelAndWait() # cancel outstanding join() if not f.finished: await f.cancelAndWait() # cancel outstanding join()
@ -90,7 +90,7 @@ proc handleConn*(s: Secure,
# do not need to propagate CancelledError. # do not need to propagate CancelledError.
discard discard
except CatchableError as exc: except CatchableError as exc:
trace "error cleaning up secure connection", err = exc.msg, sconn debug "error cleaning up secure connection", err = exc.msg, sconn
if not isNil(sconn): if not isNil(sconn):
# All the errors are handled inside `cleanup()` procedure. # All the errors are handled inside `cleanup()` procedure.
@ -98,10 +98,10 @@ proc handleConn*(s: Secure,
return sconn return sconn
method init*(s: Secure) {.gcsafe.} = method init*(s: Secure) =
procCall LPProtocol(s).init() procCall LPProtocol(s).init()
proc handle(conn: Connection, proto: string) {.async, gcsafe.} = proc handle(conn: Connection, proto: string) {.async.} =
trace "handling connection upgrade", proto, conn trace "handling connection upgrade", proto, conn
try: try:
# We don't need the result but we # We don't need the result but we
@ -121,36 +121,34 @@ method init*(s: Secure) {.gcsafe.} =
method secure*(s: Secure, method secure*(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): initiator: bool):
Future[Connection] {.base, gcsafe.} = Future[Connection] {.base.} =
s.handleConn(conn, initiator) s.handleConn(conn, initiator)
method readOnce*(s: SecureConn, method readOnce*(s: SecureConn,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):
Future[int] {.async, gcsafe.} = Future[int] {.async.} =
doAssert(nbytes > 0, "nbytes must be positive integer") doAssert(nbytes > 0, "nbytes must be positive integer")
if s.buf.data().len() == 0: if s.isEof:
let (buf, err) = try: raise newLPStreamEOFError()
(await s.readMessage(), nil)
except CatchableError as exc:
(@[], exc)
if not isNil(err): if s.buf.data().len() == 0:
if not (err of LPStreamEOFError): try:
debug "Error while reading message from secure connection, closing.", let buf = await s.readMessage() # Always returns >0 bytes or raises
error=err.name, s.activity = true
message=err.msg, s.buf.add(buf)
connection=s except LPStreamEOFError as err:
s.isEof = true
await s.close()
raise err
except CatchableError as err:
debug "Error while reading message from secure connection, closing.",
error = err.name,
message = err.msg,
connection = s
await s.close() await s.close()
raise err raise err
s.activity = true
if buf.len == 0:
raise newLPStreamIncompleteError()
s.buf.add(buf)
var p = cast[ptr UncheckedArray[byte]](pbytes) var p = cast[ptr UncheckedArray[byte]](pbytes)
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))