chore: remove in memory hashmap for ephemeral keys

This commit is contained in:
kaichaosun 2026-03-12 16:22:51 +08:00
parent 5d87b1d19a
commit 3db9210ac3
No known key found for this signature in database
GPG Key ID: 223E0F992F4F03BF
4 changed files with 134 additions and 91 deletions

View File

@ -1,11 +1,9 @@
use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use storage::StorageConfig; use storage::StorageConfig;
use crate::{ use crate::{
conversation::{ConversationId, ConversationStore, Convo, Id}, conversation::{ConversationId, ConversationStore, Convo, Id},
crypto::PrivateKey,
errors::ChatError, errors::ChatError,
identity::Identity, identity::Identity,
inbox::Inbox, inbox::Inbox,
@ -23,7 +21,6 @@ pub struct Context {
_identity: Rc<Identity>, _identity: Rc<Identity>,
store: ConversationStore, store: ConversationStore,
inbox: Inbox, inbox: Inbox,
#[allow(dead_code)] // Will be used for conversation persistence
storage: ChatStorage, storage: ChatStorage,
} }
@ -46,17 +43,7 @@ impl Context {
}; };
let identity = Rc::new(identity); let identity = Rc::new(identity);
let mut inbox = Inbox::new(Rc::clone(&identity)); let inbox = Inbox::new(Rc::clone(&identity));
// Restore ephemeral keys from storage
let stored_keys = storage.load_ephemeral_keys()?;
if !stored_keys.is_empty() {
let keys: HashMap<String, PrivateKey> = stored_keys
.into_iter()
.map(|record| (record.public_key_hex.clone(), PrivateKey::from(record.secret_key)))
.collect();
inbox.restore_ephemeral_keys(keys);
}
Ok(Self { Ok(Self {
_identity: identity, _identity: identity,
@ -138,8 +125,18 @@ impl Context {
&mut self, &mut self,
enc_payload: EncryptedPayload, enc_payload: EncryptedPayload,
) -> Result<Option<ContentData>, ChatError> { ) -> Result<Option<ContentData>, ChatError> {
let (convo, content, consumed_key_hex) = self.inbox.handle_frame(enc_payload)?; // Look up the ephemeral key from storage
self.storage.remove_ephemeral_key(&consumed_key_hex)?; let key_hex = Inbox::extract_ephemeral_key_hex(&enc_payload)?;
let ephemeral_key = self
.storage
.load_ephemeral_key(&key_hex)?
.ok_or(ChatError::UnknownEphemeralKey())?;
let (convo, content) = self.inbox.handle_frame(&ephemeral_key, enc_payload)?;
// Remove consumed ephemeral key from storage
self.storage.remove_ephemeral_key(&key_hex)?;
self.add_convo(convo); self.add_convo(convo);
Ok(content) Ok(content)
} }
@ -268,4 +265,37 @@ mod tests {
assert_eq!(pubkey1, pubkey2, "public key should persist"); assert_eq!(pubkey1, pubkey2, "public key should persist");
assert_eq!(name1, name2, "name should persist"); assert_eq!(name1, name2, "name should persist");
} }
#[test]
fn ephemeral_key_persistence() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir
.path()
.join("test_ephemeral.db")
.to_string_lossy()
.to_string();
let config = StorageConfig::File(db_path);
// Create context and generate an intro bundle (creates ephemeral key)
let mut ctx1 = Context::open("alice", config.clone()).unwrap();
let bundle1 = ctx1.create_intro_bundle().unwrap();
// Drop and reopen - ephemeral keys should be restored from db
drop(ctx1);
let mut ctx2 = Context::open("alice", config.clone()).unwrap();
// Use the intro bundle from before restart to start a conversation
let intro = Introduction::try_from(bundle1.as_slice()).unwrap();
let mut bob = Context::new_with_name("bob");
let (_, payloads) = bob.create_private_convo(&intro, b"hello after restart");
// Alice (ctx2) should be able to handle the payload using the persisted ephemeral key
let payload = payloads.first().unwrap();
let content = ctx2
.handle_payload(&payload.data)
.expect("should handle payload with persisted ephemeral key")
.expect("should have content");
assert_eq!(content.data, b"hello after restart");
assert!(content.is_new_convo);
}
} }

View File

@ -3,7 +3,6 @@ use chat_proto::logoschat::encryption::EncryptedPayload;
use prost::Message; use prost::Message;
use prost::bytes::Bytes; use prost::bytes::Bytes;
use rand_core::OsRng; use rand_core::OsRng;
use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use crypto::{PrekeyBundle, SymmetricKey32}; use crypto::{PrekeyBundle, SymmetricKey32};
@ -25,7 +24,6 @@ fn delivery_address_for_installation(_: PublicKey) -> String {
pub struct Inbox { pub struct Inbox {
ident: Rc<Identity>, ident: Rc<Identity>,
local_convo_id: String, local_convo_id: String,
ephemeral_keys: HashMap<String, PrivateKey>,
} }
impl std::fmt::Debug for Inbox { impl std::fmt::Debug for Inbox {
@ -33,10 +31,6 @@ impl std::fmt::Debug for Inbox {
f.debug_struct("Inbox") f.debug_struct("Inbox")
.field("ident", &self.ident) .field("ident", &self.ident)
.field("convo_id", &self.local_convo_id) .field("convo_id", &self.local_convo_id)
.field(
"ephemeral_keys",
&format!("<{} keys>", self.ephemeral_keys.len()),
)
.finish() .finish()
} }
} }
@ -47,24 +41,16 @@ impl Inbox {
Self { Self {
ident, ident,
local_convo_id, local_convo_id,
ephemeral_keys: HashMap::<String, PrivateKey>::new(),
} }
} }
/// Restores ephemeral keys from storage into the in-memory map. /// Creates an intro bundle and returns the Introduction along with the
pub fn restore_ephemeral_keys(&mut self, keys: HashMap<String, PrivateKey>) { /// generated ephemeral key pair (public_key_hex, private_key) for the caller to persist.
self.ephemeral_keys = keys; pub fn create_intro_bundle(&self) -> (Introduction, String, PrivateKey) {
}
/// Creates an intro bundle and returns the (public_key_hex, private_key) pair
/// so the caller can persist it.
pub fn create_intro_bundle(&mut self) -> (Introduction, String, PrivateKey) {
let ephemeral = PrivateKey::random(); let ephemeral = PrivateKey::random();
let ephemeral_key: PublicKey = (&ephemeral).into(); let ephemeral_key: PublicKey = (&ephemeral).into();
let public_key_hex = hex::encode(ephemeral_key.as_bytes()); let public_key_hex = hex::encode(ephemeral_key.as_bytes());
self.ephemeral_keys
.insert(public_key_hex.clone(), ephemeral.clone());
let intro = Introduction::new(self.ident.secret(), ephemeral_key, OsRng); let intro = Introduction::new(self.ident.secret(), ephemeral_key, OsRng);
(intro, public_key_hex, ephemeral) (intro, public_key_hex, ephemeral)
@ -123,22 +109,19 @@ impl Inbox {
Ok((convo, payloads)) Ok((convo, payloads))
} }
/// Handles an incoming inbox frame. Returns the created conversation, /// Handles an incoming inbox frame. The caller must provide the ephemeral private key
/// optional content data, and the consumed ephemeral key hex (for storage cleanup). /// looked up from storage. Returns the created conversation and optional content data.
pub fn handle_frame( pub fn handle_frame(
&mut self, &self,
ephemeral_key: &PrivateKey,
enc_payload: EncryptedPayload, enc_payload: EncryptedPayload,
) -> Result<(Box<dyn Convo>, Option<ContentData>, String), ChatError> { ) -> Result<(Box<dyn Convo>, Option<ContentData>), ChatError> {
let handshake = Self::extract_payload(enc_payload)?; let handshake = Self::extract_payload(enc_payload)?;
let header = handshake let header = handshake
.header .header
.ok_or(ChatError::UnexpectedPayload("InboxV1Header".into()))?; .ok_or(ChatError::UnexpectedPayload("InboxV1Header".into()))?;
// Get Ephemeral key used by the initator
let key_index = hex::encode(header.responder_ephemeral.as_ref());
let ephemeral_key = self.lookup_ephemeral_key(&key_index)?;
// Perform handshake and decrypt frame // Perform handshake and decrypt frame
let (seed_key, frame) = self.perform_handshake(ephemeral_key, header, handshake.payload)?; let (seed_key, frame) = self.perform_handshake(ephemeral_key, header, handshake.payload)?;
@ -159,11 +142,29 @@ impl Inbox {
None => return Err(ChatError::Protocol("expected contentData".into())), None => return Err(ChatError::Protocol("expected contentData".into())),
}; };
Ok((Box::new(convo), Some(content), key_index)) Ok((Box::new(convo), Some(content)))
} }
} }
} }
/// Extracts the ephemeral key hex from an incoming encrypted payload
/// so the caller can look it up from storage before calling handle_frame.
pub fn extract_ephemeral_key_hex(
enc_payload: &EncryptedPayload,
) -> Result<String, ChatError> {
let Some(proto::Encryption::InboxHandshake(ref handshake)) = enc_payload.encryption else {
let got = format!("{:?}", enc_payload.encryption);
return Err(ChatError::ProtocolExpectation("inboxhandshake", got));
};
let header = handshake
.header
.as_ref()
.ok_or(ChatError::UnexpectedPayload("InboxV1Header".into()))?;
Ok(hex::encode(header.responder_ephemeral.as_ref()))
}
fn wrap_in_invite(payload: proto::EncryptedPayload) -> proto::InboxV1Frame { fn wrap_in_invite(payload: proto::EncryptedPayload) -> proto::InboxV1Frame {
let invite = proto::InvitePrivateV1 { let invite = proto::InvitePrivateV1 {
discriminator: "default".into(), discriminator: "default".into(),
@ -225,12 +226,6 @@ impl Inbox {
Ok(frame) Ok(frame)
} }
fn lookup_ephemeral_key(&self, key: &str) -> Result<&PrivateKey, ChatError> {
self.ephemeral_keys
.get(key)
.ok_or(ChatError::UnknownEphemeralKey())
}
pub fn inbox_identifier_for_key(pubkey: PublicKey) -> String { pub fn inbox_identifier_for_key(pubkey: PublicKey) -> String {
// TODO: Implement ID according to spec // TODO: Implement ID according to spec
hex::encode(Blake2b512::digest(pubkey)) hex::encode(Blake2b512::digest(pubkey))
@ -246,24 +241,34 @@ impl Id for Inbox {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::storage::ChatStorage;
use storage::StorageConfig;
#[test] #[test]
fn test_invite_privatev1_roundtrip() { fn test_invite_privatev1_roundtrip() {
let mut storage = ChatStorage::new(StorageConfig::InMemory).unwrap();
let saro_ident = Identity::new("saro"); let saro_ident = Identity::new("saro");
let saro_inbox = Inbox::new(saro_ident.into()); let saro_inbox = Inbox::new(saro_ident.into());
let raya_ident = Identity::new("raya"); let raya_ident = Identity::new("raya");
let mut raya_inbox = Inbox::new(raya_ident.into()); let raya_inbox = Inbox::new(raya_ident.into());
let (bundle, key_hex, private_key) = raya_inbox.create_intro_bundle();
storage.save_ephemeral_key(&key_hex, &private_key).unwrap();
let (bundle, _key_hex, _private_key) = raya_inbox.create_intro_bundle();
let (_, mut payloads) = saro_inbox let (_, mut payloads) = saro_inbox
.invite_to_private_convo(&bundle, "hello".as_bytes()) .invite_to_private_convo(&bundle, "hello".as_bytes())
.unwrap(); .unwrap();
let payload = payloads.remove(0); let payload = payloads.remove(0);
// Look up ephemeral key from storage
let key_hex = Inbox::extract_ephemeral_key_hex(&payload.data).unwrap();
let ephemeral_key = storage.load_ephemeral_key(&key_hex).unwrap().unwrap();
// Test handle_frame with valid payload // Test handle_frame with valid payload
let result = raya_inbox.handle_frame(payload.data); let result = raya_inbox.handle_frame(&ephemeral_key, payload.data);
assert!( assert!(
result.is_ok(), result.is_ok(),

View File

@ -4,7 +4,7 @@ use storage::{RusqliteError, SqliteDb, StorageConfig, StorageError, params};
use zeroize::Zeroize; use zeroize::Zeroize;
use super::migrations; use super::migrations;
use super::types::{EphemeralKeyRecord, IdentityRecord}; use super::types::IdentityRecord;
use crate::crypto::PrivateKey; use crate::crypto::PrivateKey;
use crate::identity::Identity; use crate::identity::Identity;
@ -66,42 +66,39 @@ impl ChatStorage {
Ok(()) Ok(())
} }
/// Loads all ephemeral keys from storage. /// Loads a single ephemeral key by its public key hex.
pub fn load_ephemeral_keys( pub fn load_ephemeral_key(
&self, &self,
) -> Result<Vec<EphemeralKeyRecord>, StorageError> { public_key_hex: &str,
) -> Result<Option<PrivateKey>, StorageError> {
let mut stmt = self let mut stmt = self
.db .db
.connection() .connection()
.prepare("SELECT public_key_hex, secret_key FROM ephemeral_keys")?; .prepare("SELECT secret_key FROM ephemeral_keys WHERE public_key_hex = ?1")?;
let records = stmt let result = stmt.query_row(params![public_key_hex], |row| {
.query_map([], |row| { let secret_key: Vec<u8> = row.get(0)?;
let public_key_hex: String = row.get(0)?; Ok(secret_key)
let secret_key: Vec<u8> = row.get(1)?; });
Ok((public_key_hex, secret_key))
})?
.collect::<Result<Vec<_>, _>>()?;
let mut result = Vec::with_capacity(records.len()); match result {
for (public_key_hex, mut secret_key_vec) in records { Ok(mut secret_key_vec) => {
let bytes: Result<[u8; 32], _> = secret_key_vec.as_slice().try_into(); let bytes: Result<[u8; 32], _> = secret_key_vec.as_slice().try_into();
let bytes = match bytes { let bytes = match bytes {
Ok(b) => b, Ok(b) => b,
Err(_) => { Err(_) => {
secret_key_vec.zeroize(); secret_key_vec.zeroize();
return Err(StorageError::InvalidData( return Err(StorageError::InvalidData(
"Invalid ephemeral secret key length".into(), "Invalid ephemeral secret key length".into(),
)); ));
} }
}; };
secret_key_vec.zeroize(); secret_key_vec.zeroize();
result.push(EphemeralKeyRecord { Ok(Some(PrivateKey::from(bytes)))
public_key_hex, }
secret_key: bytes, Err(RusqliteError::QueryReturnedNoRows) => Ok(None),
}); Err(e) => Err(e.into()),
} }
Ok(result)
} }
/// Removes an ephemeral key from storage. /// Removes an ephemeral key from storage.
@ -176,4 +173,25 @@ mod tests {
let loaded = storage.load_identity().unwrap().unwrap(); let loaded = storage.load_identity().unwrap().unwrap();
assert_eq!(loaded.public_key(), pubkey); assert_eq!(loaded.public_key(), pubkey);
} }
#[test]
fn test_ephemeral_key_roundtrip() {
let mut storage = ChatStorage::new(StorageConfig::InMemory).unwrap();
let key1 = PrivateKey::random();
let pub1: crate::crypto::PublicKey = (&key1).into();
let hex1 = hex::encode(pub1.as_bytes());
// Initially not found
assert!(storage.load_ephemeral_key(&hex1).unwrap().is_none());
// Save and load
storage.save_ephemeral_key(&hex1, &key1).unwrap();
let loaded = storage.load_ephemeral_key(&hex1).unwrap().unwrap();
assert_eq!(loaded.DANGER_to_bytes(), key1.DANGER_to_bytes());
// Remove and verify gone
storage.remove_ephemeral_key(&hex1).unwrap();
assert!(storage.load_ephemeral_key(&hex1).unwrap().is_none());
}
} }

View File

@ -22,16 +22,6 @@ impl From<IdentityRecord> for Identity {
} }
} }
/// Record for storing an ephemeral key pair.
/// Implements ZeroizeOnDrop to securely clear secret key from memory.
#[derive(Debug, Zeroize, ZeroizeOnDrop)]
pub struct EphemeralKeyRecord {
/// Hex-encoded public key (used as lookup key).
pub public_key_hex: String,
/// The secret key bytes (32 bytes).
pub secret_key: [u8; 32],
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;