mirror of https://github.com/vacp2p/nim-libp2p.git
streamline socket read/write hot path (#473)
* streamline socket read/write hot path This avoids some unnecessary memory copying on the hot path of noise / mplex, as well as getting rid of a few futures - profiling shows that this is one of the main culprits of small memory allocations, which makes sense - this is where gossip fan-out happens. * fewer futures (and corresponding closures) when sending lpchannel messages * avoid data copies when encrypting and framing noise messages * avoid copying tuple when reading noise data (poor c codegen) * fix setting eof flag in secure read * write noise frames in one go ...and closing secure socket once is enough
This commit is contained in:
parent
1befeb8c2e
commit
6f1ecc8df7
|
@ -56,10 +56,7 @@ proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} =
|
||||||
proc writeMsg*(conn: Connection,
|
proc writeMsg*(conn: Connection,
|
||||||
id: uint64,
|
id: uint64,
|
||||||
msgType: MessageType,
|
msgType: MessageType,
|
||||||
data: seq[byte] = @[]) {.async, gcsafe.} =
|
data: seq[byte] = @[]): Future[void] =
|
||||||
if conn.closed:
|
|
||||||
return # No point in trying to write to an already-closed connection
|
|
||||||
|
|
||||||
var
|
var
|
||||||
left = data.len
|
left = data.len
|
||||||
offset = 0
|
offset = 0
|
||||||
|
@ -81,17 +78,9 @@ proc writeMsg*(conn: Connection,
|
||||||
trace "writing mplex message",
|
trace "writing mplex message",
|
||||||
conn, id, msgType, data = data.len, encoded = buf.buffer.len
|
conn, id, msgType, data = data.len, encoded = buf.buffer.len
|
||||||
|
|
||||||
try:
|
# Write all chunks in a single write to avoid async races where a close
|
||||||
# Write all chunks in a single write to avoid async races where a close
|
# message gets written before some of the chunks
|
||||||
# message gets written before some of the chunks
|
conn.write(buf.buffer)
|
||||||
await conn.write(buf.buffer)
|
|
||||||
trace "wrote mplex", conn, id, msgType
|
|
||||||
except CatchableError as exc:
|
|
||||||
# If the write to the underlying connection failed it should be closed so
|
|
||||||
# that the other channels are notified as well
|
|
||||||
trace "failed write", conn, id, msg = exc.msg
|
|
||||||
await conn.close()
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
proc writeMsg*(conn: Connection,
|
proc writeMsg*(conn: Connection,
|
||||||
id: uint64,
|
id: uint64,
|
||||||
|
|
|
@ -50,8 +50,6 @@ type
|
||||||
resetCode*: MessageType # cached in/out reset code
|
resetCode*: MessageType # cached in/out reset code
|
||||||
writes*: int # In-flight writes
|
writes*: int # In-flight writes
|
||||||
|
|
||||||
proc open*(s: LPChannel) {.async, gcsafe.}
|
|
||||||
|
|
||||||
func shortLog*(s: LPChannel): auto =
|
func shortLog*(s: LPChannel): auto =
|
||||||
if s.isNil: "LPChannel(nil)"
|
if s.isNil: "LPChannel(nil)"
|
||||||
elif s.conn.peerInfo.isNil: $s.oid
|
elif s.conn.peerInfo.isNil: $s.oid
|
||||||
|
@ -62,8 +60,14 @@ chronicles.formatIt(LPChannel): shortLog(it)
|
||||||
|
|
||||||
proc open*(s: LPChannel) {.async, gcsafe.} =
|
proc open*(s: LPChannel) {.async, gcsafe.} =
|
||||||
trace "Opening channel", s, conn = s.conn
|
trace "Opening channel", s, conn = s.conn
|
||||||
await s.conn.writeMsg(s.id, MessageType.New, s.name)
|
if s.conn.isClosed:
|
||||||
s.isOpen = true
|
return
|
||||||
|
try:
|
||||||
|
await s.conn.writeMsg(s.id, MessageType.New, s.name)
|
||||||
|
s.isOpen = true
|
||||||
|
except CatchableError as exc:
|
||||||
|
await s.conn.close()
|
||||||
|
raise exc
|
||||||
|
|
||||||
method closed*(s: LPChannel): bool =
|
method closed*(s: LPChannel): bool =
|
||||||
s.closedLocal
|
s.closedLocal
|
||||||
|
@ -88,10 +92,11 @@ proc reset*(s: LPChannel) {.async, gcsafe.} =
|
||||||
# If the connection is still active, notify the other end
|
# If the connection is still active, notify the other end
|
||||||
proc resetMessage() {.async.} =
|
proc resetMessage() {.async.} =
|
||||||
try:
|
try:
|
||||||
trace "sending reset message", s, conn = s.conn
|
trace "sending reset message", s, conn = s.conn
|
||||||
await s.conn.writeMsg(s.id, s.resetCode) # write reset
|
await s.conn.writeMsg(s.id, s.resetCode) # write reset
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
# No cancellations, errors handled in writeMsg
|
# No cancellations
|
||||||
|
await s.conn.close()
|
||||||
trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
|
trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
|
||||||
|
|
||||||
asyncSpawn resetMessage()
|
asyncSpawn resetMessage()
|
||||||
|
@ -115,10 +120,12 @@ method close*(s: LPChannel) {.async, gcsafe.} =
|
||||||
try:
|
try:
|
||||||
await s.conn.writeMsg(s.id, s.closeCode) # write close
|
await s.conn.writeMsg(s.id, s.closeCode) # write close
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
|
await s.conn.close()
|
||||||
raise exc
|
raise exc
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
# It's harmless that close message cannot be sent - the connection is
|
# It's harmless that close message cannot be sent - the connection is
|
||||||
# likely down already
|
# likely down already
|
||||||
|
await s.conn.close()
|
||||||
trace "Cannot send close message", s, id = s.id, msg = exc.msg
|
trace "Cannot send close message", s, id = s.id, msg = exc.msg
|
||||||
|
|
||||||
await s.closeUnderlying() # maybe already eofed
|
await s.closeUnderlying() # maybe already eofed
|
||||||
|
|
|
@ -121,18 +121,28 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
|
||||||
proc hasKey(cs: CipherState): bool =
|
proc hasKey(cs: CipherState): bool =
|
||||||
cs.k != EmptyKey
|
cs.k != EmptyKey
|
||||||
|
|
||||||
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
proc encrypt(
|
||||||
var
|
state: var CipherState, data: var openArray[byte],
|
||||||
tag: ChaChaPolyTag
|
ad: openArray[byte]): ChaChaPolyTag {.noinit.} =
|
||||||
nonce: ChaChaPolyNonce
|
var nonce: ChaChaPolyNonce
|
||||||
nonce[4..<12] = toBytesLE(state.n)
|
nonce[4..<12] = toBytesLE(state.n)
|
||||||
result = @data
|
|
||||||
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
|
ChaChaPoly.encrypt(state.k, nonce, result, data, ad)
|
||||||
|
|
||||||
inc state.n
|
inc state.n
|
||||||
if state.n > NonceMax:
|
if state.n > NonceMax:
|
||||||
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
|
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
|
||||||
result &= tag
|
|
||||||
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||||
|
result = newSeqOfCap[byte](data.len + sizeof(ChachaPolyTag))
|
||||||
|
result.add(data)
|
||||||
|
|
||||||
|
let tag = encrypt(state, result, ad)
|
||||||
|
|
||||||
|
result.add(tag)
|
||||||
|
|
||||||
|
trace "encryptWithAd",
|
||||||
|
tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
||||||
|
|
||||||
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||||
var
|
var
|
||||||
|
@ -417,20 +427,47 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||||
trace "Received 0-length message", sconn
|
trace "Received 0-length message", sconn
|
||||||
|
|
||||||
|
|
||||||
|
proc encryptFrame(
|
||||||
|
sconn: NoiseConnection, cipherFrame: var openArray[byte], src: openArray[byte]) =
|
||||||
|
# Frame consists of length + cipher data + tag
|
||||||
|
doAssert src.len <= MaxPlainSize
|
||||||
|
doAssert cipherFrame.len == 2 + src.len + sizeof(ChaChaPolyTag)
|
||||||
|
|
||||||
|
cipherFrame[0..<2] = toBytesBE(uint16(src.len + sizeof(ChaChaPolyTag)))
|
||||||
|
|
||||||
|
copyMem(addr cipherFrame[2], unsafeAddr src[0], src.len())
|
||||||
|
|
||||||
|
let tag = encrypt(
|
||||||
|
sconn.writeCs, cipherFrame.toOpenArray(2, 2 + src.len() - 1), [])
|
||||||
|
|
||||||
|
copyMem(
|
||||||
|
addr cipherFrame[cipherFrame.len - sizeof(tag)], unsafeAddr tag[0],
|
||||||
|
sizeof(tag))
|
||||||
|
|
||||||
method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
|
method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
|
||||||
if message.len == 0:
|
if message.len == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
const FramingSize = 2 + sizeof(ChaChaPolyTag)
|
||||||
|
|
||||||
|
let
|
||||||
|
frames = (message.len + MaxPlainSize - 1) div MaxPlainSize
|
||||||
|
|
||||||
var
|
var
|
||||||
|
cipherFrames = newSeqUninitialized[byte](message.len + frames * FramingSize)
|
||||||
left = message.len
|
left = message.len
|
||||||
offset = 0
|
offset = 0
|
||||||
|
woffset = 0
|
||||||
|
|
||||||
while left > 0:
|
while left > 0:
|
||||||
let
|
let
|
||||||
chunkSize = min(MaxPlainSize, left)
|
chunkSize = min(MaxPlainSize, left)
|
||||||
cipher = sconn.writeCs.encryptWithAd(
|
|
||||||
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
|
||||||
|
|
||||||
await sconn.stream.writeFrame(cipher)
|
encryptFrame(
|
||||||
|
sconn,
|
||||||
|
cipherFrames.toOpenArray(woffset, woffset + chunkSize + FramingSize - 1),
|
||||||
|
message.toOpenArray(offset, offset + chunkSize - 1))
|
||||||
|
|
||||||
when defined(libp2p_dump):
|
when defined(libp2p_dump):
|
||||||
dumpMessage(
|
dumpMessage(
|
||||||
|
@ -438,9 +475,12 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
|
||||||
message.toOpenArray(offset, offset + chunkSize - 1))
|
message.toOpenArray(offset, offset + chunkSize - 1))
|
||||||
|
|
||||||
left = left - chunkSize
|
left = left - chunkSize
|
||||||
offset = offset + chunkSize
|
offset += chunkSize
|
||||||
|
woffset += chunkSize + FramingSize
|
||||||
sconn.activity = true
|
sconn.activity = true
|
||||||
|
|
||||||
|
await sconn.stream.write(cipherFrames)
|
||||||
|
|
||||||
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
||||||
trace "Starting Noise handshake", conn, initiator
|
trace "Starting Noise handshake", conn, initiator
|
||||||
|
|
||||||
|
@ -529,8 +569,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
||||||
|
|
||||||
return secure
|
return secure
|
||||||
|
|
||||||
method close*(s: NoiseConnection) {.async.} =
|
method closeImpl*(s: NoiseConnection) {.async.} =
|
||||||
await procCall SecureConn(s).close()
|
await procCall SecureConn(s).closeImpl()
|
||||||
|
|
||||||
burnMem(s.readCs)
|
burnMem(s.readCs)
|
||||||
burnMem(s.writeCs)
|
burnMem(s.writeCs)
|
||||||
|
|
|
@ -56,29 +56,29 @@ method initStream*(s: SecureConn) =
|
||||||
|
|
||||||
procCall Connection(s).initStream()
|
procCall Connection(s).initStream()
|
||||||
|
|
||||||
method close*(s: SecureConn) {.async.} =
|
method closeImpl*(s: SecureConn) {.async.} =
|
||||||
trace "Closing secure conn", s, dir = s.dir
|
trace "Closing secure conn", s, dir = s.dir
|
||||||
if not(isNil(s.stream)):
|
if not(isNil(s.stream)):
|
||||||
await s.stream.close()
|
await s.stream.close()
|
||||||
|
|
||||||
await procCall Connection(s).close()
|
await procCall Connection(s).closeImpl()
|
||||||
|
|
||||||
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
|
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
|
||||||
doAssert(false, "Not implemented!")
|
doAssert(false, "Not implemented!")
|
||||||
|
|
||||||
method handshake(s: Secure,
|
method handshake*(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool): Future[SecureConn] {.async, base.} =
|
initiator: bool): Future[SecureConn] {.async, base.} =
|
||||||
doAssert(false, "Not implemented!")
|
doAssert(false, "Not implemented!")
|
||||||
|
|
||||||
proc handleConn*(s: Secure,
|
proc handleConn(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool): Future[Connection] {.async, gcsafe.} =
|
initiator: bool): Future[Connection] {.async.} =
|
||||||
var sconn = await s.handshake(conn, initiator)
|
var sconn = await s.handshake(conn, initiator)
|
||||||
|
|
||||||
proc cleanup() {.async.} =
|
proc cleanup() {.async.} =
|
||||||
try:
|
try:
|
||||||
let futs = @[conn.join(), sconn.join()]
|
let futs = [conn.join(), sconn.join()]
|
||||||
await futs[0] or futs[1]
|
await futs[0] or futs[1]
|
||||||
for f in futs:
|
for f in futs:
|
||||||
if not f.finished: await f.cancelAndWait() # cancel outstanding join()
|
if not f.finished: await f.cancelAndWait() # cancel outstanding join()
|
||||||
|
@ -90,7 +90,7 @@ proc handleConn*(s: Secure,
|
||||||
# do not need to propagate CancelledError.
|
# do not need to propagate CancelledError.
|
||||||
discard
|
discard
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "error cleaning up secure connection", err = exc.msg, sconn
|
debug "error cleaning up secure connection", err = exc.msg, sconn
|
||||||
|
|
||||||
if not isNil(sconn):
|
if not isNil(sconn):
|
||||||
# All the errors are handled inside `cleanup()` procedure.
|
# All the errors are handled inside `cleanup()` procedure.
|
||||||
|
@ -98,10 +98,10 @@ proc handleConn*(s: Secure,
|
||||||
|
|
||||||
return sconn
|
return sconn
|
||||||
|
|
||||||
method init*(s: Secure) {.gcsafe.} =
|
method init*(s: Secure) =
|
||||||
procCall LPProtocol(s).init()
|
procCall LPProtocol(s).init()
|
||||||
|
|
||||||
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
proc handle(conn: Connection, proto: string) {.async.} =
|
||||||
trace "handling connection upgrade", proto, conn
|
trace "handling connection upgrade", proto, conn
|
||||||
try:
|
try:
|
||||||
# We don't need the result but we
|
# We don't need the result but we
|
||||||
|
@ -121,36 +121,34 @@ method init*(s: Secure) {.gcsafe.} =
|
||||||
method secure*(s: Secure,
|
method secure*(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool):
|
initiator: bool):
|
||||||
Future[Connection] {.base, gcsafe.} =
|
Future[Connection] {.base.} =
|
||||||
s.handleConn(conn, initiator)
|
s.handleConn(conn, initiator)
|
||||||
|
|
||||||
method readOnce*(s: SecureConn,
|
method readOnce*(s: SecureConn,
|
||||||
pbytes: pointer,
|
pbytes: pointer,
|
||||||
nbytes: int):
|
nbytes: int):
|
||||||
Future[int] {.async, gcsafe.} =
|
Future[int] {.async.} =
|
||||||
doAssert(nbytes > 0, "nbytes must be positive integer")
|
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||||
|
|
||||||
if s.buf.data().len() == 0:
|
if s.isEof:
|
||||||
let (buf, err) = try:
|
raise newLPStreamEOFError()
|
||||||
(await s.readMessage(), nil)
|
|
||||||
except CatchableError as exc:
|
|
||||||
(@[], exc)
|
|
||||||
|
|
||||||
if not isNil(err):
|
if s.buf.data().len() == 0:
|
||||||
if not (err of LPStreamEOFError):
|
try:
|
||||||
debug "Error while reading message from secure connection, closing.",
|
let buf = await s.readMessage() # Always returns >0 bytes or raises
|
||||||
error=err.name,
|
s.activity = true
|
||||||
message=err.msg,
|
s.buf.add(buf)
|
||||||
connection=s
|
except LPStreamEOFError as err:
|
||||||
|
s.isEof = true
|
||||||
|
await s.close()
|
||||||
|
raise err
|
||||||
|
except CatchableError as err:
|
||||||
|
debug "Error while reading message from secure connection, closing.",
|
||||||
|
error = err.name,
|
||||||
|
message = err.msg,
|
||||||
|
connection = s
|
||||||
await s.close()
|
await s.close()
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
s.activity = true
|
|
||||||
|
|
||||||
if buf.len == 0:
|
|
||||||
raise newLPStreamIncompleteError()
|
|
||||||
|
|
||||||
s.buf.add(buf)
|
|
||||||
|
|
||||||
var p = cast[ptr UncheckedArray[byte]](pbytes)
|
var p = cast[ptr UncheckedArray[byte]](pbytes)
|
||||||
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
|
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
|
||||||
|
|
Loading…
Reference in New Issue