diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index edf8ca9c..86abd51b 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -9,7 +9,7 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::plonk_common::reduce_with_powers; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::util::ceil_div_usize; @@ -194,7 +194,64 @@ impl, const D: usize> Gate for ComparisonGate builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - todo!() + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base); + + constraints.push(builder.sub_extension(first_chunks_combined, first_input)); + constraints.push(builder.sub_extension(second_chunks_combined, second_input)); + + let mut most_significant_diff = builder.zero_extension(); + + let one = builder.one_extension(); + // Find the chosen chunk. + for i in 0..self.num_chunks { + let difference = builder.sub_extension(first_chunks[i], second_chunks[i]); + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints identifying index. + let diff_times_equal = builder.mul_extension(difference, equality_dummy); + let not_equal = builder.sub_extension(one, chunks_equal); + constraints.push(builder.sub_extension(diff_times_equal, not_equal)); + constraints.push(builder.mul_extension(chunks_equal, difference)); + + let this_diff = builder.sub_extension(first_chunks[i], second_chunks[i]); + let old_diff = builder.mul_extension(chunks_equal, most_significant_diff); + let not_equal = builder.sub_extension(one, chunks_equal); + let new_diff = builder.mul_extension(not_equal, this_diff); + most_significant_diff = builder.add_extension(old_diff, new_diff); + } + + let two = builder.constant(F::TWO); + let z_bits: Vec> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_z_bit(i)]) + .collect(); + let z_bits_combined = reduce_with_powers_ext_recursive(builder, &z_bits, two); + + let two_n = builder.constant_extension(F::Extension::TWO.exp_u64(self.chunk_bits() as u64)); + let expected_z = builder.add_extension(two_n, most_significant_diff); + let z_diff = builder.sub_extension(z_bits_combined, expected_z); + constraints.push(z_diff); + + constraints.push(z_bits[self.chunk_bits()]); + + constraints } fn generators(