diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index 126846ec..0f320dfd 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -384,7 +384,7 @@ impl SimpleGenerator for PermutationGenerator { #[cfg(test)] mod tests { use anyhow::Result; - use rand::{seq::SliceRandom, thread_rng}; + use rand::{seq::SliceRandom, thread_rng, Rng}; use super::*; use crate::field::crandall_field::CrandallField; @@ -418,6 +418,35 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + fn test_permutation_duplicates(size: usize) -> Result<()> { + type F = CrandallField; + const D: usize = 4; + + let config = CircuitConfig::large_zk_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = thread_rng(); + let lst: Vec = (0..size * 2) + .map(|_| F::from_canonical_usize(rng.gen_range(0..2usize))) + .collect(); + let a: Vec> = lst[..] + .chunks(2) + .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) + .collect(); + + let mut b = a.clone(); + b.shuffle(&mut thread_rng()); + + builder.assert_permutation(a, b); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + fn test_permutation_bad(size: usize) -> Result<()> { type F = CrandallField; const D: usize = 4; @@ -446,6 +475,15 @@ mod tests { Ok(()) } + #[test] + fn test_permutations_duplicates() -> Result<()> { + for n in 2..9 { + test_permutation_duplicates(n)?; + } + + Ok(()) + } + #[test] fn test_permutations_good() -> Result<()> { for n in 2..9 { diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 1d1938b4..099209c9 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,16 +1,15 @@ -use std::marker::PhantomData; +use itertools::izip; -use itertools::{izip, Itertools}; - -use crate::field::field_types::{PrimeField, RichField}; -use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::field::field_types::RichField; +use crate::field::extension_field::Extendable; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, address: Target, @@ -41,7 +40,8 @@ impl, const D: usize> CircuitBuilder { let n = ops.len(); let combined_bits = address_bits + timestamp_bits; - let chunk_size = 3; + let chunk_bits = 3; + let num_chunks = ceil_div_usize(combined_bits, chunk_bits); let is_write_targets: Vec<_> = self .add_virtual_targets(n) @@ -69,12 +69,14 @@ impl, const D: usize> CircuitBuilder { let two_n = self.constant(F::from_canonical_usize(1 << timestamp_bits)); let address_timestamp_combined: Vec<_> = output_targets .iter() - .map(|op| self.mul_add(op.timestamp, two_n, op.address)) + .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); + let mut gate_indices = Vec::new(); + let mut gates = Vec::new(); for i in 1..n { let (gate, gate_index) = { - let gate = ComparisonGate::new(combined_bits, chunk_size); + let gate = ComparisonGate::new(combined_bits, num_chunks); let gate_index = self.add_gate(gate.clone(), vec![]); (gate, gate_index) }; @@ -87,24 +89,39 @@ impl, const D: usize> CircuitBuilder { Target::wire(gate_index, gate.wire_second_input()), address_timestamp_combined[i], ); + + gate_indices.push(gate_index); + gates.push(gate); } self.assert_permutation_memory_ops(ops, output_targets.as_slice()); + self.add_simple_generator(MemoryOpSortGenerator:: { + input_ops: ops.to_vec(), + gate_indices, + gates: gates.clone(), + output_ops: output_targets.clone(), + address_bits, + timestamp_bits, + }); + output_targets } } #[derive(Debug)] -struct MemoryOpSortGenerator { +struct MemoryOpSortGenerator, const D: usize> { input_ops: Vec, + gate_indices: Vec, + gates: Vec>, output_ops: Vec, address_bits: usize, timestamp_bits: usize, - _phantom: PhantomData, } -impl SimpleGenerator for MemoryOpSortGenerator { +impl, const D: usize> SimpleGenerator + for MemoryOpSortGenerator +{ fn dependencies(&self) -> Vec { self.input_ops .iter() @@ -136,6 +153,13 @@ impl SimpleGenerator for MemoryOpSortGenerator { }) .collect(); + let mut combined_values_sorted = combined_values_u64.clone(); + combined_values_sorted.sort(); + let combined_values: Vec<_> = combined_values_sorted + .iter() + .map(|&x| F::from_canonical_u64(x)) + .collect(); + let mut input_ops_and_keys: Vec<_> = self .input_ops .iter() @@ -161,12 +185,30 @@ impl SimpleGenerator for MemoryOpSortGenerator { self.output_ops[i].value, witness.get_target(input_ops_sorted[i].value), ); + + if i > 0 { + out_buffer.set_target( + Target::wire( + self.gate_indices[i - 1], + self.gates[i - 1].wire_second_input(), + ), + combined_values[i], + ); + } + if i < n - 1 { + out_buffer.set_target( + Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()), + combined_values[i], + ); + } } } } #[cfg(test)] mod tests { + use std::collections::HashSet; + use anyhow::Result; use rand::{seq::SliceRandom, thread_rng, Rng}; @@ -223,4 +265,13 @@ mod tests { test_sorting(size, address_bits, timestamp_bits) } + + #[test] + fn test_sorting_large() -> Result<()> { + let size = 20; + let address_bits = 20; + let timestamp_bits = 20; + + test_sorting(size, address_bits, timestamp_bits) + } } diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 86601fba..b1d57929 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -386,7 +386,7 @@ impl, const D: usize> SimpleGenerator let mut most_significant_diff_so_far = F::ZERO; let mut intermediate_values = Vec::new(); - for i in 1..self.gate.num_chunks { + for i in 0..self.gate.num_chunks { if first_input_chunks[i] != second_input_chunks[i] { most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; intermediate_values.push(F::ZERO); diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 1df5fea9..26efac0e 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -236,20 +236,43 @@ impl, const D: usize> SwitchGenerator { let get_local_wire = |input| witness.get_wire(local_wire(input)); - for e in 0..self.gate.chunk_size { - let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy)); - let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); - let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); - let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e)); - let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e)); + let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy)); - if first_output == first_input && second_output == second_input { - out_buffer.set_wire(switch_bool_wire, F::ZERO); - } else if first_output == second_input && second_output == first_input { - out_buffer.set_wire(switch_bool_wire, F::ONE); - } else { - panic!("No permutation from given inputs to given outputs"); - } + let mut first_inputs = Vec::new(); + let mut second_inputs = Vec::new(); + let mut first_outputs = Vec::new(); + let mut second_outputs = Vec::new(); + for e in 0..self.gate.chunk_size { + first_inputs.push(get_local_wire(self.gate.wire_first_input(self.copy, e))); + second_inputs.push(get_local_wire(self.gate.wire_second_input(self.copy, e))); + first_outputs.push(get_local_wire(self.gate.wire_first_output(self.copy, e))); + second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e))); + } + + let first_keep = first_outputs + .iter() + .zip(first_inputs.iter()) + .all(|(x, y)| x == y); + let second_keep = second_outputs + .iter() + .zip(second_inputs.iter()) + .all(|(x, y)| x == y); + + let first_swap = first_outputs + .iter() + .zip(second_inputs.iter()) + .all(|(x, y)| x == y); + let second_swap = second_outputs + .iter() + .zip(first_inputs.iter()) + .all(|(x, y)| x == y); + + if first_keep && second_keep { + out_buffer.set_wire(switch_bool_wire, F::ZERO); + } else if first_swap && second_swap { + out_buffer.set_wire(switch_bool_wire, F::ONE); + } else { + panic!("No permutation from given inputs to given outputs"); } } @@ -261,12 +284,12 @@ impl, const D: usize> SwitchGenerator { let get_local_wire = |input| witness.get_wire(local_wire(input)); + let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); for e in 0..self.gate.chunk_size { let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e)); let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e)); let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); - let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); let (first_output, second_output) = if switch_bool == F::ZERO { (first_input, second_input)