refactor: localize private_pda_npk_by_position and extract authorization helper

Addresses the following review comments from @Arjentix:

- "I think we can move this into `derive_from_outputs()`"
  (on the position → npk map construction in main())
  I moved the construction inside ExecutionState::derive_from_outputs
  and stored the map as a field of ExecutionState. derive_from_outputs
  now takes `private_account_keys` directly and builds the map as part
  of state initialization. main() no longer owns the intermediate
  structure. validate_and_sync_states reads the npk through
  self.private_pda_npk_by_position.

- "Let's move this whole `is_authorized` computation into a separate
  function. This became really bulky"
  I extracted the caller-seeds resolution, family-binding recording,
  and is_authorized computation into a free function
  `resolve_authorization_and_record_bindings`. It takes the three
  field borrows it needs (`&mut pda_family_binding`, `&mut
  private_pda_bound_positions`, `&private_pda_npk_by_position`), same
  shape as `assert_family_binding`. A method would have conflicted
  with the `&mut self.post_states` borrow held by the Occupied match
  arm; the free function lets rustc split-borrow the self fields.
This commit is contained in:
Moudy 2026-04-22 15:55:35 +02:00
parent 22aa5ef70b
commit e5b77a27d5
36 changed files with 114 additions and 90 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -43,16 +43,42 @@ struct ExecutionState {
/// `AccountId` entry or as an equality check against the existing one, making the rule: one
/// `(program, seed)` → one account per tx.
pda_family_binding: HashMap<(ProgramId, PdaSeed), AccountId>,
/// Map from a mask-3 `pre_state`'s position in `visibility_mask` to the npk supplied for
/// that position in `private_account_keys`. Built once in `derive_from_outputs` by walking
/// `visibility_mask` in lock-step with `private_account_keys`, used later by the claim and
/// caller-seeds authorization paths.
private_pda_npk_by_position: HashMap<usize, NullifierPublicKey>,
}
impl ExecutionState {
/// Validate program outputs and derive the overall execution state.
pub fn derive_from_outputs(
visibility_mask: &[u8],
private_pda_npk_by_position: &HashMap<usize, NullifierPublicKey>,
private_account_keys: &[(NullifierPublicKey, SharedSecretKey)],
program_id: ProgramId,
program_outputs: Vec<ProgramOutput>,
) -> Self {
// Build position → npk map for mask-3 pre_states. `private_account_keys` is consumed in
// pre_state order across all masks 1/2/3, so walk `visibility_mask` in lock-step. The
// downstream `compute_circuit_output` also consumes the same iterator and its trailing
// assertions catch an over-supply of keys; under-supply surfaces here.
let mut private_pda_npk_by_position: HashMap<usize, NullifierPublicKey> = HashMap::new();
{
let mut keys_iter = private_account_keys.iter();
for (pos, &mask) in visibility_mask.iter().enumerate() {
if matches!(mask, 1..=3) {
let (npk, _) = keys_iter.next().unwrap_or_else(|| {
panic!(
"private_account_keys shorter than visibility_mask demands: no key for masked position {pos} (mask {mask})"
)
});
if mask == 3 {
private_pda_npk_by_position.insert(pos, *npk);
}
}
}
}
let block_valid_from = program_outputs
.iter()
.filter_map(|output| output.block_validity_window.start())
@ -89,6 +115,7 @@ impl ExecutionState {
timestamp_validity_window,
private_pda_bound_positions: HashSet::new(),
pda_family_binding: HashMap::new(),
private_pda_npk_by_position,
};
let Some(first_output) = program_outputs.first() else {
@ -166,7 +193,6 @@ impl ExecutionState {
execution_state.validate_and_sync_states(
visibility_mask,
private_pda_npk_by_position,
chained_call.program_id,
caller_program_id,
&chained_call.pda_seeds,
@ -221,14 +247,9 @@ impl ExecutionState {
}
/// Validate program pre and post states and populate the execution state.
#[expect(
clippy::too_many_arguments,
reason = "breaking out a context struct does not buy us anything here"
)]
fn validate_and_sync_states(
&mut self,
visibility_mask: &[u8],
private_pda_npk_by_position: &HashMap<usize, NullifierPublicKey>,
program_id: ProgramId,
caller_program_id: Option<ProgramId>,
caller_pda_seeds: &[PdaSeed],
@ -270,43 +291,16 @@ impl ExecutionState {
|(pos, acc)| (acc.is_authorized, pos)
);
// Find which caller seed (if any) authorizes this pre_state, under the
// public or the private derivation. We need the *specific* seed (not just a
// bool) so we can record the `(caller, seed) → account_id` family binding.
// The match arm also returns the caller so the consumer below does not have
// to re-unwrap `caller_program_id`. Only reachable when
// `caller_program_id.is_some()`, top-level flows have no caller-emitted
// seeds, so binding at top level must come from the claim path below.
let matched_caller_seed: Option<(PdaSeed, bool, ProgramId)> = caller_program_id
.and_then(|caller| {
caller_pda_seeds.iter().find_map(|seed| {
if AccountId::for_public_pda(&caller, seed) == pre_account_id {
return Some((*seed, false, caller));
}
if let Some(npk) =
private_pda_npk_by_position.get(&pre_state_position)
&& AccountId::for_private_pda(&caller, seed, npk)
== pre_account_id
{
return Some((*seed, true, caller));
}
None
})
});
if let Some((seed, is_private_form, caller)) = matched_caller_seed {
assert_family_binding(
&mut self.pda_family_binding,
caller,
seed,
pre_account_id,
);
if is_private_form {
self.private_pda_bound_positions.insert(pre_state_position);
}
}
let is_authorized = previous_is_authorized || matched_caller_seed.is_some();
let is_authorized = resolve_authorization_and_record_bindings(
&mut self.pda_family_binding,
&mut self.private_pda_bound_positions,
&self.private_pda_npk_by_position,
pre_account_id,
pre_state_position,
caller_program_id,
caller_pda_seeds,
previous_is_authorized,
);
assert_eq!(
pre_is_authorized, is_authorized,
@ -358,31 +352,34 @@ impl ExecutionState {
);
}
},
3 => match claim {
Claim::Authorized => {
assert!(
pre_is_authorized,
"Cannot claim unauthorized private PDA {pre_account_id}"
);
3 => {
match claim {
Claim::Authorized => {
assert!(
pre_is_authorized,
"Cannot claim unauthorized private PDA {pre_account_id}"
);
}
Claim::Pda(seed) => {
let npk = self
.private_pda_npk_by_position
.get(&pre_state_position)
.expect("private PDA pre_state must have an npk in the position map");
let pda = AccountId::for_private_pda(&program_id, &seed, npk);
assert_eq!(
pre_account_id, pda,
"Invalid private PDA claim for account {pre_account_id}"
);
self.private_pda_bound_positions.insert(pre_state_position);
assert_family_binding(
&mut self.pda_family_binding,
program_id,
seed,
pre_account_id,
);
}
}
Claim::Pda(seed) => {
let npk = private_pda_npk_by_position.get(&pre_state_position).expect(
"private PDA pre_state must have an npk in the position map",
);
let pda = AccountId::for_private_pda(&program_id, &seed, npk);
assert_eq!(
pre_account_id, pda,
"Invalid private PDA claim for account {pre_account_id}"
);
self.private_pda_bound_positions.insert(pre_state_position);
assert_family_binding(
&mut self.pda_family_binding,
program_id,
seed,
pre_account_id,
);
}
},
}
_ => {
// Mask 1/2: standard private accounts don't enforce the claim semantics.
// Unauthorized private claiming is intentionally allowed since operating
@ -439,6 +436,54 @@ fn assert_family_binding(
}
}
/// Resolve the authorization state of a `pre_state` seen again in a chained call and record
/// any resulting bindings. Returns `true` if the `pre_state` is authorized through either a
/// previously-seen authorization or a matching caller seed (under the public or private
/// derivation). When a caller seed matches, also records the `(caller, seed) → account_id`
/// family binding and, for the private form, marks the position in
/// `private_pda_bound_positions`. Only reachable when `caller_program_id.is_some()`,
/// top-level flows have no caller-emitted seeds, so binding at top level must come from the
/// claim path. Free function so callers can pass individual `&mut self.*` field borrows
/// without holding a borrow on the surrounding struct's other fields.
#[expect(
clippy::too_many_arguments,
reason = "breaking out a context struct does not buy us anything here"
)]
fn resolve_authorization_and_record_bindings(
pda_family_binding: &mut HashMap<(ProgramId, PdaSeed), AccountId>,
private_pda_bound_positions: &mut HashSet<usize>,
private_pda_npk_by_position: &HashMap<usize, NullifierPublicKey>,
pre_account_id: AccountId,
pre_state_position: usize,
caller_program_id: Option<ProgramId>,
caller_pda_seeds: &[PdaSeed],
previous_is_authorized: bool,
) -> bool {
let matched_caller_seed: Option<(PdaSeed, bool, ProgramId)> =
caller_program_id.and_then(|caller| {
caller_pda_seeds.iter().find_map(|seed| {
if AccountId::for_public_pda(&caller, seed) == pre_account_id {
return Some((*seed, false, caller));
}
if let Some(npk) = private_pda_npk_by_position.get(&pre_state_position)
&& AccountId::for_private_pda(&caller, seed, npk) == pre_account_id
{
return Some((*seed, true, caller));
}
None
})
});
if let Some((seed, is_private_form, caller)) = matched_caller_seed {
assert_family_binding(pda_family_binding, caller, seed, pre_account_id);
if is_private_form {
private_pda_bound_positions.insert(pre_state_position);
}
}
previous_is_authorized || matched_caller_seed.is_some()
}
fn compute_circuit_output(
execution_state: ExecutionState,
visibility_mask: &[u8],
@ -718,30 +763,9 @@ fn main() {
program_id,
} = env::read();
// Build a position → npk map for mask-3 pre_states. `private_account_keys` is consumed in
// pre_state order across all masks 1/2/3, so walk `visibility_mask` in lock-step. The
// downstream `compute_circuit_output` also consumes the same iterator and its trailing
// assertions catch an over-supply of keys; under-supply surfaces here.
let mut private_pda_npk_by_position: HashMap<usize, NullifierPublicKey> = HashMap::new();
{
let mut keys_iter = private_account_keys.iter();
for (pos, &mask) in visibility_mask.iter().enumerate() {
if matches!(mask, 1..=3) {
let (npk, _) = keys_iter.next().unwrap_or_else(|| {
panic!(
"private_account_keys shorter than visibility_mask demands: no key for masked position {pos} (mask {mask})"
)
});
if mask == 3 {
private_pda_npk_by_position.insert(pos, *npk);
}
}
}
}
let execution_state = ExecutionState::derive_from_outputs(
&visibility_mask,
&private_pda_npk_by_position,
&private_account_keys,
program_id,
program_outputs,
);