diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index f3b9cf0d..57f65a51 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -31,6 +31,26 @@ impl, const D: usize> CircuitBuilder { self.assert_permutation(a_chunks, b_chunks); } + /// Add a ComparisonGate to + /// Returns the gate and its index + pub fn assert_le( + &mut self, + lhs: Target, + rhs: Target, + bits: usize, + num_chunks: usize, + ) -> (ComparisonGate, usize) { + let gate = ComparisonGate::new(bits, num_chunks); + let gate_index = self.add_gate(gate.clone(), vec![]); + + self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs); + self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs); + + (gate, gate_index) + } + + /// Sort memory operations by address value, then by timestamp value. + /// This is done by combining address and timestamp into one field element (using their given bit lengths). pub fn sort_memory_ops( &mut self, ops: &[MemoryOpTarget], @@ -43,11 +63,13 @@ impl, const D: usize> CircuitBuilder { let chunk_bits = 3; let num_chunks = ceil_div_usize(combined_bits, chunk_bits); + // This is safe because `assert_permutation` will force these targets (in the output list) to match the boolean values from the input list. let is_write_targets: Vec<_> = self .add_virtual_targets(n) .iter() .map(|&t| BoolTarget::new_unsafe(t)) .collect(); + let address_targets = self.add_virtual_targets(n); let timestamp_targets = self.add_virtual_targets(n); let value_targets = self.add_virtual_targets(n); @@ -72,22 +94,14 @@ impl, const D: usize> CircuitBuilder { .map(|op| self.mul_add(op.address, two_n, op.timestamp)) .collect(); - let mut gate_indices = Vec::new(); let mut gates = Vec::new(); + let mut gate_indices = Vec::new(); for i in 1..n { - let (gate, gate_index) = { - let gate = ComparisonGate::new(combined_bits, num_chunks); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index) - }; - - self.connect( - Target::wire(gate_index, gate.wire_first_input()), + let (gate, gate_index) = self.assert_le( address_timestamp_combined[i - 1], - ); - self.connect( - Target::wire(gate_index, gate.wire_second_input()), address_timestamp_combined[i], + combined_bits, + num_chunks, ); gate_indices.push(gate_index); diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index b1d57929..44f2923a 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -15,7 +15,7 @@ use crate::util::ceil_div_usize; /// A gate for checking that one value is less than or equal to another. #[derive(Clone, Debug)] -pub(crate) struct ComparisonGate, const D: usize> { +pub struct ComparisonGate, const D: usize> { pub(crate) num_bits: usize, pub(crate) num_chunks: usize, _phantom: PhantomData, @@ -436,7 +436,6 @@ mod tests { use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; - use crate::plonk::plonk_common::reduce_with_powers; use crate::plonk::vars::EvaluationVars; #[test] diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 26efac0e..2f6b7122 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -249,27 +249,9 @@ impl, const D: usize> SwitchGenerator { second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e))); } - let first_keep = first_outputs - .iter() - .zip(first_inputs.iter()) - .all(|(x, y)| x == y); - let second_keep = second_outputs - .iter() - .zip(second_inputs.iter()) - .all(|(x, y)| x == y); - - let first_swap = first_outputs - .iter() - .zip(second_inputs.iter()) - .all(|(x, y)| x == y); - let second_swap = second_outputs - .iter() - .zip(first_inputs.iter()) - .all(|(x, y)| x == y); - - if first_keep && second_keep { + if first_outputs == first_inputs && second_outputs == second_inputs { out_buffer.set_wire(switch_bool_wire, F::ZERO); - } else if first_swap && second_swap { + } else if first_outputs == second_inputs && second_outputs == first_inputs { out_buffer.set_wire(switch_bool_wire, F::ONE); } else { panic!("No permutation from given inputs to given outputs");