clean up tests and bench

This commit is contained in:
M Alghazwi 2025-04-03 11:15:09 +02:00
parent f49d5c5218
commit 65508f3da9
No known key found for this signature in database
GPG Key ID: 646E567CAD7DB607
4 changed files with 19 additions and 23 deletions

View File

@ -9,7 +9,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
unroll = { workspace = true } unroll = { workspace = true }
plonky2 = { workspace = true } plonky2 = { workspace = true , features = ["gate_testing"]}
plonky2_field = { workspace = true } plonky2_field = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }

View File

@ -1,6 +1,4 @@
use std::fs;
use anyhow::Result; use anyhow::Result;
use std::time::Instant;
use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use plonky2::field::extension::Extendable; use plonky2::field::extension::Extendable;
use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::goldilocks_field::GoldilocksField;
@ -33,7 +31,6 @@ pub struct PoseidonCircuit<
> { > {
public_input: Vec<Target>, public_input: Vec<Target>,
circuit_data: CircuitData<F, C, D>, circuit_data: CircuitData<F, C, D>,
num_powers: usize,
_hasher: PhantomData<H>, _hasher: PhantomData<H>,
} }
@ -60,7 +57,7 @@ impl<
state.set_from_slice(&initial, 0); state.set_from_slice(&initial, 0);
for k in 0..num_hashes { for _k in 0..num_hashes {
state = builder.permute::<H>(state); state = builder.permute::<H>(state);
} }
@ -74,17 +71,16 @@ impl<
Self { Self {
public_input: initial, public_input: initial,
circuit_data: data, circuit_data: data,
num_powers: num_hashes,
_hasher: PhantomData::<H>, _hasher: PhantomData::<H>,
} }
} }
pub fn generate_proof(&self, init: F) -> Result<ProofWithPublicInputs<F, C, D>> { pub fn generate_proof(&self, init: Vec<F>) -> Result<ProofWithPublicInputs<F, C, D>> {
const T: usize = 12; const T: usize = 12;
let mut pw = PartialWitness::<F>::new(); let mut pw = PartialWitness::<F>::new();
for j in 0..T { for j in 0..T {
pw.set_target(self.public_input[j], F::from_canonical_usize(j)); pw.set_target(self.public_input[j], init[j])?;
} }
let proof = self.circuit_data.prove(pw).unwrap(); let proof = self.circuit_data.prove(pw).unwrap();
@ -135,14 +131,14 @@ fn bench_poseidon2_perm<
format!("prove circuit with 2^{} permutations", log_num_hashes).as_str(), format!("prove circuit with 2^{} permutations", log_num_hashes).as_str(),
|b| { |b| {
b.iter_batched( b.iter_batched(
|| F::rand(), || F::rand_vec(12),
|init| poseidon_circuit.generate_proof(init).unwrap(), |init| poseidon_circuit.generate_proof(init).unwrap(),
BatchSize::PerIteration, BatchSize::PerIteration,
) )
}, },
); );
let proof = poseidon_circuit.generate_proof(F::rand()).unwrap(); let proof = poseidon_circuit.generate_proof(F::rand_vec(12)).unwrap();
pretty_print!("proof size: {}", proof.to_bytes().len()); pretty_print!("proof size: {}", proof.to_bytes().len());

View File

@ -517,7 +517,7 @@ mod tests {
} }
#[test] #[test]
fn generated_output() { fn generated_output() -> Result<()>{
const D: usize = 2; const D: usize = 2;
type C = Poseidon2GoldilocksConfig; type C = Poseidon2GoldilocksConfig;
type F = <C as GenericConfig<D>>::F; type F = <C as GenericConfig<D>>::F;
@ -547,7 +547,7 @@ mod tests {
column: Gate::WIRE_SWAP, column: Gate::WIRE_SWAP,
}, },
F::ZERO, F::ZERO,
); )?;
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
inputs.set_wire( inputs.set_wire(
Wire { Wire {
@ -555,7 +555,7 @@ mod tests {
column: Gate::wire_input(i), column: Gate::wire_input(i),
}, },
permutation_inputs[i], permutation_inputs[i],
); )?;
} }
let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap(); let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap();
@ -568,7 +568,9 @@ mod tests {
}); });
println!("out {} = {}", i, out.clone()); println!("out {} = {}", i, out.clone());
assert_eq!(out, expected_outputs[i]); assert_eq!(out, expected_outputs[i]);
} };
Ok(())
} }
#[test] #[test]
@ -588,7 +590,7 @@ mod tests {
} }
#[test] #[test]
fn test_proof() { fn test_proof() -> Result<()>{
use plonky2_field::types::Sample; use plonky2_field::types::Sample;
use plonky2::gates::gate::Gate; use plonky2::gates::gate::Gate;
use plonky2::hash::hash_types::HashOut; use plonky2::hash::hash_types::HashOut;
@ -607,10 +609,10 @@ mod tests {
let wires_t = builder.add_virtual_extension_targets(wires.len()); let wires_t = builder.add_virtual_extension_targets(wires.len());
let constants_t = builder.add_virtual_extension_targets(constants.len()); let constants_t = builder.add_virtual_extension_targets(constants.len());
pw.set_extension_targets(&wires_t, &wires); pw.set_extension_targets(&wires_t, &wires)?;
pw.set_extension_targets(&constants_t, &constants); pw.set_extension_targets(&constants_t, &constants)?;
let public_inputs_hash_t = builder.add_virtual_hash(); let public_inputs_hash_t = builder.add_virtual_hash();
pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); pw.set_hash_target(public_inputs_hash_t, public_inputs_hash)?;
let vars = EvaluationVars { let vars = EvaluationVars {
local_constants: &constants, local_constants: &constants,
@ -625,9 +627,10 @@ mod tests {
public_inputs_hash: &public_inputs_hash_t, public_inputs_hash: &public_inputs_hash_t,
}; };
let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t);
pw.set_extension_targets(&evals_t, &evals); pw.set_extension_targets(&evals_t, &evals)?;
let data = builder.build::<C>(); let data = builder.build::<C>();
let proof = data.prove(pw); let proof = data.prove(pw);
assert!(proof.is_ok()); assert!(proof.is_ok());
Ok(())
} }
} }

View File

@ -151,18 +151,15 @@ impl Poseidon2 for GoldilocksField {
mod tests { mod tests {
use plonky2_field::goldilocks_field::GoldilocksField as F; use plonky2_field::goldilocks_field::GoldilocksField as F;
use plonky2_field::types::{Field, PrimeField64};
use crate::poseidon2_hash::poseidon2::test_helpers::check_test_vectors; use crate::poseidon2_hash::poseidon2::test_helpers::check_test_vectors;
#[test] #[test]
fn p2new_test_vectors() { fn test_vectors() {
// Test inputs are: // Test inputs are:
// 1. range 0..WIDTH // 1. range 0..WIDTH
// expected output calculated with reference implementation here: // expected output calculated with reference implementation here:
// https://github.com/HorizenLabs/poseidon2 // https://github.com/HorizenLabs/poseidon2
let neg_one: u64 = F::NEG_ONE.to_canonical_u64();
#[rustfmt::skip] #[rustfmt::skip]
let test_vectors12: Vec<([u64; 12], [u64; 12])> = vec![ let test_vectors12: Vec<([u64; 12], [u64; 12])> = vec![
([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ], ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ],