diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 3d7a95a9..d928bd6f 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -65,6 +65,11 @@ impl, const D: usize> ComparisonGate { debug_assert!(chunk < self.num_chunks); 3 + 3 * self.num_chunks + chunk } + + pub fn wire_intermediate_value(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 3 + 4 * self.num_chunks + chunk + } } impl, const D: usize> Gate for ComparisonGate { @@ -122,8 +127,10 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(chunks_equal * difference); // Update `most_significant_diff_so_far`. - most_significant_diff_so_far = chunks_equal * most_significant_diff_so_far - + (F::Extension::ONE - chunks_equal) * difference; + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::Extension::ONE - chunks_equal) * difference; } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; @@ -188,8 +195,10 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(chunks_equal * difference); // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); most_significant_diff_so_far = - chunks_equal * most_significant_diff_so_far + (F::ONE - chunks_equal) * difference; + intermediate_value + (F::ONE - chunks_equal) * difference; } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; @@ -262,10 +271,13 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(builder.mul_extension(chunks_equal, difference)); // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); + constraints.push(builder.sub_extension(intermediate_value, old_diff)); + let not_equal = builder.sub_extension(one, chunks_equal); let new_diff = builder.mul_extension(not_equal, difference); - most_significant_diff_so_far = builder.add_extension(old_diff, new_diff); + most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; @@ -297,7 +309,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_wires(&self) -> usize { - self.wire_chunks_equal(self.num_chunks - 1) + 1 + self.wire_intermediate_value(self.num_chunks - 1) + 1 } fn num_constants(&self) -> usize { @@ -305,11 +317,11 @@ impl, const D: usize> Gate for ComparisonGate } fn degree(&self) -> usize { - (self.num_chunks + 1).max(1 << self.chunk_bits()) + 1 << self.chunk_bits() } fn num_constraints(&self) -> usize { - 4 + 4 * self.num_chunks + 4 + 5 * self.num_chunks } } @@ -372,15 +384,17 @@ impl, const D: usize> SimpleGenerator .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) .collect(); - let mut diff_index = 0; + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); for i in 1..self.gate.num_chunks { if first_input_chunks[i] != second_input_chunks[i] { - diff_index = i; + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); } } - - let most_significant_diff = - second_input_chunks[diff_index] - first_input_chunks[diff_index]; + let most_significant_diff = most_significant_diff_so_far; out_buffer.set_wire( local_wire(self.gate.wire_most_significant_diff()), @@ -400,6 +414,10 @@ impl, const D: usize> SimpleGenerator equality_dummies[i], ); out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); + out_buffer.set_wire( + local_wire(self.gate.wire_intermediate_value(i)), + intermediate_values[i], + ); } } } @@ -444,6 +462,8 @@ mod tests { assert_eq!(gate.wire_equality_dummy(4), 17); assert_eq!(gate.wire_chunks_equal(0), 18); assert_eq!(gate.wire_chunks_equal(4), 22); + assert_eq!(gate.wire_intermediate_value(0), 23); + assert_eq!(gate.wire_intermediate_value(4), 27); } #[test] @@ -504,15 +524,17 @@ mod tests { .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) .collect(); - let mut diff_index = 0; - for i in 1..num_chunks { + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..num_chunks { if first_input_chunks[i] != second_input_chunks[i] { - diff_index = i; + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); } } - - let most_significant_diff = - second_input_chunks[diff_index] - first_input_chunks[diff_index]; + let most_significant_diff = most_significant_diff_so_far; v.push(first_input); v.push(second_input); @@ -521,6 +543,7 @@ mod tests { v.append(&mut second_input_chunks); v.append(&mut equality_dummies); v.append(&mut chunks_equal); + v.append(&mut intermediate_values); v.iter().map(|&x| x.into()).collect::>() };