From 0c40b2e338aaac0e76e94e1a20574868337192c4 Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Thu, 10 Jul 2025 09:41:01 +0200 Subject: [PATCH] add more proof tracking tests. --- proof-input/tests/tree_circuit.rs | 191 ++++++++++++++++++++++-------- 1 file changed, 140 insertions(+), 51 deletions(-) diff --git a/proof-input/tests/tree_circuit.rs b/proof-input/tests/tree_circuit.rs index 9fe56ee..18fd6a3 100644 --- a/proof-input/tests/tree_circuit.rs +++ b/proof-input/tests/tree_circuit.rs @@ -576,78 +576,167 @@ mod proof_tracking_tests { #[test] fn test_split_index() -> anyhow::Result<()> { - // Create a circuit where we register the outputs q and r of split_index. - let mut builder = CircuitBuilder::::new(CircuitConfig::standard_recursion_config()); - // Let index = 45. - let index_val: u64 = 45; - let index_target = builder.constant(F::from_canonical_u64(index_val)); - // Call split_index with bucket_size=32 and num_buckets=4. We expect q = 1 and r = 13. - let (q_target, r_target) = - split_index::(&mut builder, index_target, BUCKET_SIZE, 4)?; - // Register outputs as public inputs. - builder.register_public_input(q_target); - builder.register_public_input(r_target); - // Build and prove the circuit. - let pub_inputs = build_and_prove(builder); - // We expect the first public input to be q = 1 and the second r = 13. - assert_eq!(pub_inputs[0].to_canonical_u64(), 1, "q should be 1"); - assert_eq!(pub_inputs[1].to_canonical_u64(), 13, "r should be 13"); + // Test split_index for indices 0..128 with 4 buckets of size BUCKET_SIZE. + for index_val in 0..128 { + let mut builder = CircuitBuilder::::new(CircuitConfig::standard_recursion_config()); + let index_target = builder.constant(F::from_canonical_u64(index_val as u64)); + let (q_target, r_target) = + split_index::(&mut builder, index_target, BUCKET_SIZE, 4)?; + builder.register_public_input(q_target); + builder.register_public_input(r_target); + let pub_inputs = build_and_prove(builder); + let expected_q = index_val / BUCKET_SIZE; + let expected_r = index_val % BUCKET_SIZE; + assert_eq!( + pub_inputs[0].to_canonical_u64(), + expected_q as u64, + "q for index {} should be {}", + index_val, + expected_q + ); + assert_eq!( + pub_inputs[1].to_canonical_u64(), + expected_r as u64, + "r for index {} should be {}", + index_val, + expected_r + ); + } Ok(()) } #[test] - fn test_compute_power_of_two() -> anyhow::Result<()> { - // Create a circuit to compute 2^r. + fn test_split_index_invalid_index() -> anyhow::Result<()> { + // The maximum valid index is BUCKET_SIZE * num_buckets - 1. + // Test that an out-of-range index fails to prove. + let invalid_index = BUCKET_SIZE * 4; let mut builder = CircuitBuilder::::new(CircuitConfig::standard_recursion_config()); - // Let r = 13. - let r_val: u64 = 13; - let r_target = builder.constant(F::from_canonical_u64(r_val)); - let pow_target = - compute_power_of_two::(&mut builder, r_target)?; - builder.register_public_input(pow_target); - let pub_inputs = build_and_prove(builder); - // Expect 2^13 = 8192. - assert_eq!( - pub_inputs[0].to_canonical_u64(), - 1 << 13, - "2^13 should be 8192" + let index_target = builder.constant(F::from_canonical_u64(invalid_index as u64)); + let (q_target, r_target) = + split_index::(&mut builder, index_target, BUCKET_SIZE, 4)?; + // Register the outputs as public inputs. + builder.register_public_input(q_target); + builder.register_public_input(r_target); + // Build the circuit and attempt to prove. + let circuit = builder.build::(); + let pw = PartialWitness::new(); + // Proof should fail. + assert!( + circuit.prove(pw).is_err(), + "Proving should fail for out-of-range index {}", + invalid_index ); Ok(()) } #[test] - fn test_compute_flag_buckets() -> anyhow::Result<()> { - // Create a circuit to compute flag buckets. - // Let index = 45 and flag = true. + fn test_compute_power_of_two() -> anyhow::Result<()> { + // Test compute_power_of_two for r in 0..128. + let two = F::from_canonical_u64(2); + let mut expected = F::ONE; + for r_val in 0..BUCKET_SIZE { + // Update expected = 2^r_val in the field. + if r_val == 0 { + expected = F::ONE; + } else { + expected = expected * two; + } + // Build a circuit for this r_val. + let mut builder = CircuitBuilder::::new(CircuitConfig::standard_recursion_config()); + let r_target = builder.constant(F::from_canonical_u64(r_val as u64)); + let pow_target = compute_power_of_two::(&mut builder, r_target)?; + builder.register_public_input(pow_target); + let pub_inputs = build_and_prove(builder); + // Compare the circuit output to the expected 2^r_val. + assert_eq!( + pub_inputs[0].to_canonical_u64(), + expected.to_canonical_u64(), + "2^{} should be {}", + r_val, + expected.to_canonical_u64() + ); + } + Ok(()) + } + + fn test_compute_flag_buckets(flag: bool) -> anyhow::Result<()> { + // Test compute_flag_buckets for all indices 0..128 with given flag. + for index_val in 0..(BUCKET_SIZE * 4) { + let mut builder = CircuitBuilder::::new(CircuitConfig::standard_recursion_config()); + let index_target = builder.constant(F::from_canonical_u64(index_val as u64)); + let flag_target = builder.constant_bool(flag); + let buckets = compute_flag_buckets::( + &mut builder, + index_target, + flag_target, + BUCKET_SIZE, + 4, + )?; + for bucket in buckets.iter() { + builder.register_public_input(*bucket); + } + let pub_inputs = build_and_prove(builder); + // Build expected buckets: only bucket q = index_val / BUCKET_SIZE has value 2^r. + let mut expected = vec![0u64; 4]; + let q = index_val / BUCKET_SIZE; + let r = index_val % BUCKET_SIZE; + if flag { + expected[q] = 1u64 << r; + } else { + expected[q] = 0u64; + } + for (i, &expected_val) in expected.iter().enumerate() { + let computed = pub_inputs[i].to_canonical_u64(); + assert_eq!( + computed, + expected_val, + "Bucket {} for index {}: expected {} but got {}", + i, + index_val, + expected_val, + computed + ); + } + } + Ok(()) + } + + #[test] + fn test_compute_flag_buckets_real() -> anyhow::Result<()> { + test_compute_flag_buckets(true) + } + + #[test] + fn test_compute_flag_buckets_dummy() -> anyhow::Result<()> { + test_compute_flag_buckets(false) + } + + #[test] + fn test_compute_flag_buckets_invalid_index() -> anyhow::Result<()> { + // The maximum valid index is BUCKET_SIZE * num_buckets - 1. + // Test that an out-of-range index fails to prove. + let invalid_index = BUCKET_SIZE * 4; let mut builder = CircuitBuilder::::new(CircuitConfig::standard_recursion_config()); - let index_val: u64 = 45; - let index_target = builder.constant(F::from_canonical_u64(index_val)); - // Create a boolean constant target for flag = true. + let index_target = builder.constant(F::from_canonical_u64(invalid_index as u64)); let flag_target = builder.constant_bool(true); - // Compute the flag buckets with bucket_size = 32 and num_buckets = 4. - let buckets = compute_flag_buckets::( + let buckets = compute_flag_buckets::( &mut builder, index_target, flag_target, BUCKET_SIZE, 4, )?; - // Register each bucket as a public input. for bucket in buckets.iter() { builder.register_public_input(*bucket); } - let pub_inputs = build_and_prove(builder); - // With index = 45, we expect: - // q = 45 / 32 = 1 and r = 45 % 32 = 13, so bucket 1 should be 2^13 = 8192 and the others 0. - let expected = vec![0, 8192, 0, 0]; - for (i, &expected_val) in expected.iter().enumerate() { - let computed = pub_inputs[i].to_canonical_u64(); - assert_eq!( - computed, expected_val, - "Bucket {}: expected {} but got {}", - i, expected_val, computed - ); - } + // Build and attempt to prove. + let circuit = builder.build::(); + let pw = PartialWitness::new(); + assert!( + circuit.prove(pw).is_err(), + "Proving should fail for out-of-range index {}", + invalid_index + ); Ok(()) } }