comparison gate should also be <=

This commit is contained in:
Nicholas Ward 2021-10-12 13:31:05 -07:00
parent 26959d11c9
commit 6dd14eb27a

View File

@ -158,11 +158,10 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO);
let two_n_minus_1 =
F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE;
constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined);
let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits());
constraints.push((two_n + most_significant_diff) - bits_combined);
// Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
@ -239,10 +238,10 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
let two_n_minus_1 = F::from_canonical_u64(1 << self.chunk_bits()) - F::ONE;
constraints.push((two_n_minus_1 + most_significant_diff) - bits_combined);
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
constraints.push((two_n + most_significant_diff) - bits_combined);
// Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
@ -334,13 +333,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
let two = builder.two();
let bits_combined =
reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two);
let two_n_minus_1 = builder.constant_extension(
F::Extension::from_canonical_u64(1 << self.chunk_bits()) - F::Extension::ONE,
);
let sum = builder.add_extension(two_n_minus_1, most_significant_diff);
let two_n =
builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()));
let sum = builder.add_extension(two_n, most_significant_diff);
constraints.push(builder.sub_extension(sum, bits_combined));
// Iff first < second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(
builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]),
@ -410,7 +408,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let first_input_u64 = first_input.to_canonical_u64();
let second_input_u64 = second_input.to_canonical_u64();
let result = F::from_canonical_usize((first_input_u64 < second_input_u64) as usize);
let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize);
let chunk_size = 1 << self.gate.chunk_bits();
let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
@ -450,7 +448,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let most_significant_diff = most_significant_diff_so_far;
let two_n_plus_msd =
((1 << self.gate.chunk_bits()) - 1) as u64 + most_significant_diff.to_canonical_u64();
(1 << self.gate.chunk_bits()) as u64 + most_significant_diff.to_canonical_u64();
let msd_bits: Vec<F> = (0..self.gate.chunk_bits() + 1)
.scan(two_n_plus_msd, |acc, _| {
let tmp = *acc % 2;
@ -571,7 +569,7 @@ mod tests {
let first_input_u64 = first_input.to_canonical_u64();
let second_input_u64 = second_input.to_canonical_u64();
let result_bool = F::from_bool(first_input_u64 < second_input_u64);
let result_bool = F::from_bool(first_input_u64 <= second_input_u64);
let chunk_size = 1 << chunk_bits;
let mut first_input_chunks: Vec<F> = (0..num_chunks)
@ -610,10 +608,10 @@ mod tests {
}
let most_significant_diff = most_significant_diff_so_far;
let two_n_min_1_plus_msd =
((1 << chunk_bits) - 1) as u64 + most_significant_diff.to_canonical_u64();
let two_n_plus_msd =
(1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64();
let mut msd_bits: Vec<F> = (0..chunk_bits + 1)
.scan(two_n_min_1_plus_msd, |acc, _| {
.scan(two_n_plus_msd, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(F::from_canonical_u64(tmp))