diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 083f923a..9783a42c 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -158,11 +158,10 @@ impl, const D: usize> Gate for ComparisonGate } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); - let two_n_minus_1 = - F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE; - constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); + let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); - // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); @@ -239,10 +238,10 @@ impl, const D: usize> Gate for ComparisonGate } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); - let two_n_minus_1 = F::from_canonical_u64(1 << self.chunk_bits()) - F::ONE; - constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined); + let two_n = F::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); - // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); @@ -334,13 +333,12 @@ impl, const D: usize> Gate for ComparisonGate let two = builder.two(); let bits_combined = reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two); - let two_n_minus_1 = builder.constant_extension( - F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE, - ); - let sum = builder.add_extension(two_n_minus_1, most_significant_diff); + let two_n = + builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits())); + let sum = builder.add_extension(two_n, most_significant_diff); constraints.push(builder.sub_extension(sum, bits_combined)); - // Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; constraints.push( builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), @@ -410,7 +408,7 @@ impl, const D: usize> SimpleGenerator let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - let result = F::from_canonical_usize((first_input_u64 < second_input_u64) as usize); + let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize); let chunk_size = 1 << self.gate.chunk_bits(); let first_input_chunks: Vec = (0..self.gate.num_chunks) @@ -450,7 +448,7 @@ impl, const D: usize> SimpleGenerator let most_significant_diff = most_significant_diff_so_far; let two_n_plus_msd = - ((1 << self.gate.chunk_bits()) - 1) as u64 + most_significant_diff.to_canonical_u64(); + (1 << self.gate.chunk_bits()) as u64 + most_significant_diff.to_canonical_u64(); let msd_bits: Vec = (0..self.gate.chunk_bits() + 1) .scan(two_n_plus_msd, |acc, _| { let tmp = *acc % 2; @@ -571,7 +569,7 @@ mod tests { let first_input_u64 = first_input.to_canonical_u64(); let second_input_u64 = second_input.to_canonical_u64(); - let result_bool = F::from_bool(first_input_u64 < second_input_u64); + let result_bool = F::from_bool(first_input_u64 <= second_input_u64); let chunk_size = 1 << chunk_bits; let mut first_input_chunks: Vec = (0..num_chunks) @@ -610,10 +608,10 @@ mod tests { } let most_significant_diff = most_significant_diff_so_far; - let two_n_min_1_plus_msd = - ((1 << chunk_bits) - 1) as u64 + most_significant_diff.to_canonical_u64(); + let two_n_plus_msd = + (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64(); let mut msd_bits: Vec = (0..chunk_bits + 1) - .scan(two_n_min_1_plus_msd, |acc, _| { + .scan(two_n_plus_msd, |acc, _| { let tmp = *acc % 2; *acc /= 2; Some(F::from_canonical_u64(tmp))