diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 102f30cd..1d1938b4 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -168,7 +168,7 @@ impl SimpleGenerator for MemoryOpSortGenerator { #[cfg(test)] mod tests { use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng}; + use rand::{seq::SliceRandom, thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField; @@ -177,7 +177,7 @@ mod tests { use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; - fn test_permutation_good(size: usize) -> Result<()> { + fn test_sorting(size: usize, address_bits: usize, timestamp_bits: usize) -> Result<()> { type F = CrandallField; const D: usize = 4; @@ -186,15 +186,28 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let lst: Vec = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect(); - let a: Vec> = lst[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + let mut rng = thread_rng(); + let is_write_vals: Vec<_> = (0..size).map(|_| rng.gen_range(0..2) != 0).collect(); + let address_vals: Vec<_> = (0..size) + .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << address_bits as u64))) .collect(); - let mut b = a.clone(); - b.shuffle(&mut thread_rng()); + let timestamp_vals: Vec<_> = (0..size) + .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << timestamp_bits as u64))) + .collect(); + let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect(); - builder.assert_permutation(a, b); + let input_ops: Vec = + izip!(is_write_vals, address_vals, timestamp_vals, value_vals) + .map(|(is_write, address, timestamp, value)| MemoryOpTarget { + is_write: builder.constant_bool(is_write), + address: builder.constant(address), + timestamp: builder.constant(timestamp), + value: builder.constant(value), + }) + .collect(); + + let _output_ops = + builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -202,64 +215,12 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - fn test_permutation_bad(size: usize) -> Result<()> { - type F = CrandallField; - const D: usize = 4; - - let config = CircuitConfig::large_zk_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let lst1: Vec = F::rand_vec(size * 2); - let lst2: Vec = F::rand_vec(size * 2); - let a: Vec> = lst1[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - let b: Vec> = lst2[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - - builder.assert_permutation(a, b); - - let data = builder.build(); - data.prove(pw).unwrap(); - - Ok(()) - } - #[test] - fn test_permutations_good() -> Result<()> { - for n in 2..9 { - test_permutation_good(n)?; - } + fn test_sorting_small() -> Result<()> { + let size = 5; + let address_bits = 20; + let timestamp_bits = 20; - Ok(()) - } - - #[test] - #[should_panic] - fn test_permutation_bad_small() { - let size = 2; - - test_permutation_bad(size).unwrap() - } - - #[test] - #[should_panic] - fn test_permutation_bad_medium() { - let size = 6; - - test_permutation_bad(size).unwrap() - } - - #[test] - #[should_panic] - fn test_permutation_bad_large() { - let size = 10; - - test_permutation_bad(size).unwrap() + test_sorting(size, address_bits, timestamp_bits) } }