add more proof tracking tests.

This commit is contained in:
M Alghazwi 2025-07-10 09:41:01 +02:00
parent 29a1a2fbb1
commit 0c40b2e338
No known key found for this signature in database
GPG Key ID: 646E567CAD7DB607

View File

@ -576,78 +576,167 @@ mod proof_tracking_tests {
#[test] #[test]
fn test_split_index() -> anyhow::Result<()> { fn test_split_index() -> anyhow::Result<()> {
// Create a circuit where we register the outputs q and r of split_index. // Test split_index for indices 0..128 with 4 buckets of size BUCKET_SIZE.
let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config()); for index_val in 0..128 {
// Let index = 45. let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
let index_val: u64 = 45; let index_target = builder.constant(F::from_canonical_u64(index_val as u64));
let index_target = builder.constant(F::from_canonical_u64(index_val)); let (q_target, r_target) =
// Call split_index with bucket_size=32 and num_buckets=4. We expect q = 1 and r = 13. split_index::<F, D>(&mut builder, index_target, BUCKET_SIZE, 4)?;
let (q_target, r_target) = builder.register_public_input(q_target);
split_index::<F,D>(&mut builder, index_target, BUCKET_SIZE, 4)?; builder.register_public_input(r_target);
// Register outputs as public inputs. let pub_inputs = build_and_prove(builder);
builder.register_public_input(q_target); let expected_q = index_val / BUCKET_SIZE;
builder.register_public_input(r_target); let expected_r = index_val % BUCKET_SIZE;
// Build and prove the circuit. assert_eq!(
let pub_inputs = build_and_prove(builder); pub_inputs[0].to_canonical_u64(),
// We expect the first public input to be q = 1 and the second r = 13. expected_q as u64,
assert_eq!(pub_inputs[0].to_canonical_u64(), 1, "q should be 1"); "q for index {} should be {}",
assert_eq!(pub_inputs[1].to_canonical_u64(), 13, "r should be 13"); index_val,
expected_q
);
assert_eq!(
pub_inputs[1].to_canonical_u64(),
expected_r as u64,
"r for index {} should be {}",
index_val,
expected_r
);
}
Ok(()) Ok(())
} }
#[test] #[test]
fn test_compute_power_of_two() -> anyhow::Result<()> { fn test_split_index_invalid_index() -> anyhow::Result<()> {
// Create a circuit to compute 2^r. // The maximum valid index is BUCKET_SIZE * num_buckets - 1.
// Test that an out-of-range index fails to prove.
let invalid_index = BUCKET_SIZE * 4;
let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config()); let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
// Let r = 13. let index_target = builder.constant(F::from_canonical_u64(invalid_index as u64));
let r_val: u64 = 13; let (q_target, r_target) =
let r_target = builder.constant(F::from_canonical_u64(r_val)); split_index::<F, D>(&mut builder, index_target, BUCKET_SIZE, 4)?;
let pow_target = // Register the outputs as public inputs.
compute_power_of_two::<F,D>(&mut builder, r_target)?; builder.register_public_input(q_target);
builder.register_public_input(pow_target); builder.register_public_input(r_target);
let pub_inputs = build_and_prove(builder); // Build the circuit and attempt to prove.
// Expect 2^13 = 8192. let circuit = builder.build::<C>();
assert_eq!( let pw = PartialWitness::new();
pub_inputs[0].to_canonical_u64(), // Proof should fail.
1 << 13, assert!(
"2^13 should be 8192" circuit.prove(pw).is_err(),
"Proving should fail for out-of-range index {}",
invalid_index
); );
Ok(()) Ok(())
} }
#[test] #[test]
fn test_compute_flag_buckets() -> anyhow::Result<()> { fn test_compute_power_of_two() -> anyhow::Result<()> {
// Create a circuit to compute flag buckets. // Test compute_power_of_two for r in 0..128.
// Let index = 45 and flag = true. let two = F::from_canonical_u64(2);
let mut expected = F::ONE;
for r_val in 0..BUCKET_SIZE {
// Update expected = 2^r_val in the field.
if r_val == 0 {
expected = F::ONE;
} else {
expected = expected * two;
}
// Build a circuit for this r_val.
let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
let r_target = builder.constant(F::from_canonical_u64(r_val as u64));
let pow_target = compute_power_of_two::<F, D>(&mut builder, r_target)?;
builder.register_public_input(pow_target);
let pub_inputs = build_and_prove(builder);
// Compare the circuit output to the expected 2^r_val.
assert_eq!(
pub_inputs[0].to_canonical_u64(),
expected.to_canonical_u64(),
"2^{} should be {}",
r_val,
expected.to_canonical_u64()
);
}
Ok(())
}
fn test_compute_flag_buckets(flag: bool) -> anyhow::Result<()> {
// Test compute_flag_buckets for all indices 0..128 with given flag.
for index_val in 0..(BUCKET_SIZE * 4) {
let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
let index_target = builder.constant(F::from_canonical_u64(index_val as u64));
let flag_target = builder.constant_bool(flag);
let buckets = compute_flag_buckets::<F, D>(
&mut builder,
index_target,
flag_target,
BUCKET_SIZE,
4,
)?;
for bucket in buckets.iter() {
builder.register_public_input(*bucket);
}
let pub_inputs = build_and_prove(builder);
// Build expected buckets: only bucket q = index_val / BUCKET_SIZE has value 2^r.
let mut expected = vec![0u64; 4];
let q = index_val / BUCKET_SIZE;
let r = index_val % BUCKET_SIZE;
if flag {
expected[q] = 1u64 << r;
} else {
expected[q] = 0u64;
}
for (i, &expected_val) in expected.iter().enumerate() {
let computed = pub_inputs[i].to_canonical_u64();
assert_eq!(
computed,
expected_val,
"Bucket {} for index {}: expected {} but got {}",
i,
index_val,
expected_val,
computed
);
}
}
Ok(())
}
#[test]
fn test_compute_flag_buckets_real() -> anyhow::Result<()> {
test_compute_flag_buckets(true)
}
#[test]
fn test_compute_flag_buckets_dummy() -> anyhow::Result<()> {
test_compute_flag_buckets(false)
}
#[test]
fn test_compute_flag_buckets_invalid_index() -> anyhow::Result<()> {
// The maximum valid index is BUCKET_SIZE * num_buckets - 1.
// Test that an out-of-range index fails to prove.
let invalid_index = BUCKET_SIZE * 4;
let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config()); let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
let index_val: u64 = 45; let index_target = builder.constant(F::from_canonical_u64(invalid_index as u64));
let index_target = builder.constant(F::from_canonical_u64(index_val));
// Create a boolean constant target for flag = true.
let flag_target = builder.constant_bool(true); let flag_target = builder.constant_bool(true);
// Compute the flag buckets with bucket_size = 32 and num_buckets = 4. let buckets = compute_flag_buckets::<F, D>(
let buckets = compute_flag_buckets::<F,D>(
&mut builder, &mut builder,
index_target, index_target,
flag_target, flag_target,
BUCKET_SIZE, BUCKET_SIZE,
4, 4,
)?; )?;
// Register each bucket as a public input.
for bucket in buckets.iter() { for bucket in buckets.iter() {
builder.register_public_input(*bucket); builder.register_public_input(*bucket);
} }
let pub_inputs = build_and_prove(builder); // Build and attempt to prove.
// With index = 45, we expect: let circuit = builder.build::<C>();
// q = 45 / 32 = 1 and r = 45 % 32 = 13, so bucket 1 should be 2^13 = 8192 and the others 0. let pw = PartialWitness::new();
let expected = vec![0, 8192, 0, 0]; assert!(
for (i, &expected_val) in expected.iter().enumerate() { circuit.prove(pw).is_err(),
let computed = pub_inputs[i].to_canonical_u64(); "Proving should fail for out-of-range index {}",
assert_eq!( invalid_index
computed, expected_val, );
"Bucket {}: expected {} but got {}",
i, expected_val, computed
);
}
Ok(()) Ok(())
} }
} }