diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 93454ebc..f9ef81c4 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -11,7 +11,7 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -use crate::util::{ceil_div_usize, log2_ceil}; +use crate::util::{bits_u64, ceil_div_usize}; /// A gate for checking that one value is smaller than another. #[derive(Clone, Debug)] @@ -31,7 +31,7 @@ impl, const D: usize> ComparisonGate { } pub fn field_bits() -> usize { - log2_ceil(F::ORDER) + bits_u64(F::ORDER) } pub fn num_chunks(&self) -> usize { @@ -118,34 +118,34 @@ impl, const D: usize> Gate for ComparisonGate let second_input = vars.local_wires[self.wire_second_input(c)]; // Get chunks and assert that they match - let first_chunks: Vec = (0..self.num_chunks()) + let first_chunks: Vec = (0..self.num_chunks()) .map(|i| vars.local_wires[self.wire_first_chunk_val(c, i)]) .collect(); - let second_chunks: Vec = (0..self.num_chunks()) + let second_chunks: Vec = (0..self.num_chunks()) .map(|i| vars.local_wires[self.wire_second_chunk_val(c, i)]) .collect(); - let chunk_base_powers = (0..self.chunk_bits) - .map(|i| F::TWO.exp_u64(i * self.chunk_bits as u64)) + let chunk_base_powers: Vec = (0..self.chunk_bits) + .map(|i| F::Extension::TWO.exp_u64((i * self.chunk_bits) as u64)) .collect(); let first_chunks_combined = first_chunks .iter() .zip(chunk_base_powers.iter()) - .map(|(b, x)| b * x) - .fold(F::ZERO, |a, b| a + b); + .map(|(b, x)| *b * *x) + .fold(F::Extension::ZERO, |a, b| a + b); let second_chunks_combined = second_chunks .iter() .zip(chunk_base_powers.iter()) - .map(|(b, x)| b * x) - .fold(F::ZERO, |a, b| a + b); + .map(|(b, x)| *b * *x) + .fold(F::Extension::ZERO, |a, b| a + b); constraints.push(first_chunks_combined - first_input); constraints.push(second_chunks_combined - second_input); // Get bits to assert they match the chosen chunk. - let powers_of_two: Vec = (0..self.chunk_bits) - .map(|i| F::TWO.exp_u64(i as u64)) + let powers_of_two: Vec = (0..self.chunk_bits) + .map(|i| F::Extension::TWO.exp_u64(i as u64)) .collect(); let mut most_significant_diff = @@ -155,7 +155,7 @@ impl, const D: usize> Gate for ComparisonGate for i in (0..self.num_chunks()).rev() { let difference = first_chunks[i] - second_chunks[i]; let equality_dummy = vars.local_wires[self.wire_equality_dummy(c, i)]; - let chunks_equal = vars.local_wires[self.wires_chunks_equal(c, i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(c, i)]; // Two constraints identifying index. constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); @@ -166,25 +166,21 @@ impl, const D: usize> Gate for ComparisonGate + (F::Extension::ONE - chunks_equal) * this_diff; } - constraints.push(first_bits_combined - most_significant_diff[0]); - constraints.push(second_bits_combined - most_significant_diff[1]); - - let z_bits: Vec = (0..self.chunk_size + 1) + let z_bits: Vec = (0..self.chunk_bits + 1) .map(|i| vars.local_wires[self.wire_z_bit(c, i)]) .collect(); - let powers_of_two: Vec = (0..self.chunk_bits + 1) - .map(|i| F::TWO.exp_u64(i as u64)) + let powers_of_two: Vec = (0..self.chunk_bits + 1) + .map(|i| F::Extension::TWO.exp_u64(i as u64)) .collect(); let z_bits_combined = z_bits .iter() .zip(powers_of_two.iter()) - .map(|(b, x)| b * x) - .fold(F::ZERO, |a, b| a + b); + .map(|(b, x)| *b * *x) + .fold(F::Extension::ZERO, |a, b| a + b); - let two_n = F::TWO.exp_u64(self.chunk_bits); - let (x, y) = most_significant_diff; - constraints.push(z_bits_combined - (two_n + x - y)); + let two_n = F::Extension::TWO.exp_u64(self.chunk_bits as u64); + constraints.push(z_bits_combined - (two_n + most_significant_diff)); constraints.push(z_bits[self.chunk_bits - 1]); } @@ -211,18 +207,19 @@ impl, const D: usize> Gate for ComparisonGate ) -> Vec>> { (0..self.num_copies) .map(|c| { - let g: Box> = Box::new(ComparisonGenerator:: { + let gen = ComparisonGenerator:: { gate_index, gate: self.clone(), copy: c, - }); + }; + let g: Box> = Box::new(gen.adapter()); g }) .collect() } fn num_wires(&self) -> usize { - self.wire_switch_bool(self.num_copies - 1) + 1 + self.wire_chunks_equal(self.num_copies - 1, self.num_chunks() - 1) + 1 } fn num_constants(&self) -> usize { @@ -268,7 +265,7 @@ impl, const D: usize> SimpleGenerator let first_input = get_local_wire(self.gate.wire_first_input(self.copy)); let second_input = get_local_wire(self.gate.wire_second_input(self.copy)); - let field_bits = log2_ceil(F::ORDER); + let field_bits = bits_u64(F::ORDER); let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); @@ -329,7 +326,7 @@ impl, const D: usize> SimpleGenerator out_buffer.set_wire(local_wire(self.gate.wire_z_val(self.copy)), z); for b in 0..self.gate.chunk_bits + 1 { - out_buffer.set_wire(local_wire(self.gate.wire_z_bit(c, b)), z_bits[b]); + out_buffer.set_wire(local_wire(self.gate.wire_z_bit(self.copy, b)), z_bits[b]); } for i in 0..self.gate.num_chunks() { out_buffer.set_wire(