diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 9b321efe..0ac911ee 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -16,16 +16,18 @@ use crate::util::{bits_u64, ceil_div_usize}; /// A gate for checking that one value is smaller than another. #[derive(Clone, Debug)] pub(crate) struct ComparisonGate, const D: usize> { - pub(crate) chunk_bits: usize, pub(crate) num_copies: usize, + pub(crate) num_bits: usize, + pub(crate) num_chunks: usize, _phantom: PhantomData, } impl, const D: usize> ComparisonGate { - pub fn new(num_copies: usize, chunk_bits: usize) -> Self { + pub fn new(num_copies: usize, num_bits: usize, num_chunks: usize) -> Self { Self { - chunk_bits, num_copies, + num_bits, + num_chunks, _phantom: PhantomData, } } @@ -34,23 +36,23 @@ impl, const D: usize> ComparisonGate { bits_u64(F::ORDER) } - pub fn num_chunks(&self) -> usize { - ceil_div_usize(Self::field_bits(), self.chunk_bits) + pub fn chunk_bits(&self) -> usize { + ceil_div_usize(self.num_bits, self.num_chunks) } - pub fn new_from_config(config: CircuitConfig, chunk_bits: usize) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires, chunk_bits); - Self::new(num_copies, chunk_bits) + pub fn new_from_config(config: CircuitConfig, num_bits: usize, num_chunks: usize) -> Self { + let num_copies = Self::max_num_copies(config.num_routed_wires, num_bits, num_chunks); + Self::new(num_copies, num_bits, num_chunks) } - pub fn max_num_copies(num_routed_wires: usize, chunk_bits: usize) -> usize { - let num_chunks = ceil_div_usize(Self::field_bits(), chunk_bits); + pub fn max_num_copies(num_routed_wires: usize, num_bits: usize, num_chunks: usize) -> usize { + let chunk_bits = ceil_div_usize(num_bits, num_chunks); let wires_per_copy = 4 + chunk_bits + 4 * num_chunks; num_routed_wires / wires_per_copy } pub fn wires_per_copy(&self) -> usize { - 4 + self.chunk_bits + 4 * self.num_chunks() + 4 + self.chunk_bits() + 4 * self.num_chunks } pub fn wire_first_input(&self, copy: usize) -> usize { @@ -68,32 +70,32 @@ impl, const D: usize> ComparisonGate { } pub fn wire_z_bit(&self, copy: usize, bit_index: usize) -> usize { - debug_assert!(bit_index < self.chunk_bits + 1); + debug_assert!(bit_index < self.chunk_bits() + 1); copy * self.wires_per_copy() + 4 + bit_index } pub fn wire_first_chunk_val(&self, copy: usize, chunk: usize) -> usize { debug_assert!(copy < self.num_copies); - debug_assert!(chunk < self.num_chunks()); - copy * self.wires_per_copy() + 4 + self.chunk_bits + chunk + debug_assert!(chunk < self.num_chunks); + copy * self.wires_per_copy() + 4 + self.chunk_bits() + chunk } pub fn wire_second_chunk_val(&self, copy: usize, chunk: usize) -> usize { debug_assert!(copy < self.num_copies); - debug_assert!(chunk < self.num_chunks()); - copy * self.wires_per_copy() + 4 + self.chunk_bits + self.num_chunks() + chunk + debug_assert!(chunk < self.num_chunks); + copy * self.wires_per_copy() + 4 + self.chunk_bits() + self.num_chunks + chunk } pub fn wire_equality_dummy(&self, copy: usize, chunk: usize) -> usize { debug_assert!(copy < self.num_copies); - debug_assert!(chunk < self.num_chunks()); - copy * self.wires_per_copy() + 4 + self.chunk_bits + 2 * self.num_chunks() + chunk + debug_assert!(chunk < self.num_chunks); + copy * self.wires_per_copy() + 4 + self.chunk_bits() + 2 * self.num_chunks + chunk } pub fn wire_chunks_equal(&self, copy: usize, chunk: usize) -> usize { debug_assert!(copy < self.num_copies); - debug_assert!(chunk < self.num_chunks()); - copy * self.wires_per_copy() + 4 + self.chunk_bits + 3 * self.num_chunks() + chunk + debug_assert!(chunk < self.num_chunks); + copy * self.wires_per_copy() + 4 + self.chunk_bits() + 3 * self.num_chunks + chunk } } @@ -110,15 +112,15 @@ impl, const D: usize> Gate for ComparisonGate let second_input = vars.local_wires[self.wire_second_input(c)]; // Get chunks and assert that they match - let first_chunks: Vec = (0..self.num_chunks()) + let first_chunks: Vec = (0..self.num_chunks) .map(|i| vars.local_wires[self.wire_first_chunk_val(c, i)]) .collect(); - let second_chunks: Vec = (0..self.num_chunks()) + let second_chunks: Vec = (0..self.num_chunks) .map(|i| vars.local_wires[self.wire_second_chunk_val(c, i)]) .collect(); - let chunk_base_powers: Vec = (0..self.chunk_bits) - .map(|i| F::Extension::TWO.exp_u64((i * self.chunk_bits) as u64)) + let chunk_base_powers: Vec = (0..self.chunk_bits()) + .map(|i| F::Extension::TWO.exp_u64((i * self.chunk_bits()) as u64)) .collect(); let first_chunks_combined = first_chunks @@ -136,15 +138,15 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(second_chunks_combined - second_input); // Get bits to assert they match the chosen chunk. - let powers_of_two: Vec = (0..self.chunk_bits) + let powers_of_two: Vec = (0..self.chunk_bits()) .map(|i| F::Extension::TWO.exp_u64(i as u64)) .collect(); let mut most_significant_diff = - first_chunks[self.num_chunks() - 1] - second_chunks[self.num_chunks() - 1]; + first_chunks[self.num_chunks - 1] - second_chunks[self.num_chunks - 1]; // Find the chosen chunk. - for i in (0..self.num_chunks()).rev() { + for i in (0..self.num_chunks).rev() { let difference = first_chunks[i] - second_chunks[i]; let equality_dummy = vars.local_wires[self.wire_equality_dummy(c, i)]; let chunks_equal = vars.local_wires[self.wire_chunks_equal(c, i)]; @@ -158,11 +160,11 @@ impl, const D: usize> Gate for ComparisonGate + (F::Extension::ONE - chunks_equal) * this_diff; } - let z_bits: Vec = (0..self.chunk_bits + 1) + let z_bits: Vec = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_z_bit(c, i)]) .collect(); - let powers_of_two: Vec = (0..self.chunk_bits + 1) + let powers_of_two: Vec = (0..self.chunk_bits() + 1) .map(|i| F::Extension::TWO.exp_u64(i as u64)) .collect(); let z_bits_combined = z_bits @@ -171,10 +173,10 @@ impl, const D: usize> Gate for ComparisonGate .map(|(b, x)| *b * *x) .fold(F::Extension::ZERO, |a, b| a + b); - let two_n = F::Extension::TWO.exp_u64(self.chunk_bits as u64); + let two_n = F::Extension::TWO.exp_u64(self.chunk_bits() as u64); constraints.push(z_bits_combined - (two_n + most_significant_diff)); - constraints.push(z_bits[self.chunk_bits - 1]); + constraints.push(z_bits[self.chunk_bits() - 1]); } constraints @@ -211,7 +213,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_wires(&self) -> usize { - self.wire_chunks_equal(self.num_copies - 1, self.num_chunks() - 1) + 1 + self.wire_chunks_equal(self.num_copies - 1, self.num_chunks - 1) + 1 } fn num_constants(&self) -> usize { @@ -223,7 +225,7 @@ impl, const D: usize> Gate for ComparisonGate } fn num_constraints(&self) -> usize { - 4 * self.num_copies * self.chunk_bits + 4 * self.num_copies * self.chunk_bits() } } @@ -276,11 +278,11 @@ impl, const D: usize> SimpleGenerator }) .collect(); - let powers_of_two: Vec = (0..self.gate.chunk_bits) + let powers_of_two: Vec = (0..self.gate.chunk_bits()) .map(|i| F::TWO.exp_u64(i as u64)) .collect(); let first_input_chunks: Vec = first_input_bits - .chunks(self.gate.chunk_bits) + .chunks(self.gate.chunk_bits()) .map(|bits| { bits.iter() .zip(powers_of_two.iter()) @@ -289,7 +291,7 @@ impl, const D: usize> SimpleGenerator }) .collect(); let second_input_chunks: Vec = second_input_bits - .chunks(self.gate.chunk_bits) + .chunks(self.gate.chunk_bits()) .map(|bits| { bits.iter() .zip(powers_of_two.iter()) @@ -298,7 +300,7 @@ impl, const D: usize> SimpleGenerator }) .collect(); - let chunks_equal: Vec = (0..self.gate.num_chunks()) + let chunks_equal: Vec = (0..self.gate.num_chunks) .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) .collect(); let equality_dummies: Vec = first_input_chunks @@ -307,8 +309,8 @@ impl, const D: usize> SimpleGenerator .map(|(f, s)| if *f == *s { F::ONE } else { F::ONE / (*f - *s) }) .collect(); - let z = F::TWO.exp_u64(self.gate.chunk_bits as u64) + first_input - second_input; - let z_bits: Vec = (0..self.gate.chunk_bits + 1) + let z = F::TWO.exp_u64(self.gate.chunk_bits() as u64) + first_input - second_input; + let z_bits: Vec = (0..self.gate.chunk_bits() + 1) .scan(z.to_canonical_u64(), |acc, _| { let tmp = *acc % 2; *acc /= 2; @@ -317,10 +319,10 @@ impl, const D: usize> SimpleGenerator .collect(); out_buffer.set_wire(local_wire(self.gate.wire_z_val(self.copy)), z); - for b in 0..self.gate.chunk_bits + 1 { + for b in 0..self.gate.chunk_bits() + 1 { out_buffer.set_wire(local_wire(self.gate.wire_z_bit(self.copy, b)), z_bits[b]); } - for i in 0..self.gate.num_chunks() { + for i in 0..self.gate.num_chunks { out_buffer.set_wire( local_wire(self.gate.wire_first_chunk_val(self.copy, i)), first_input_chunks[i], @@ -348,7 +350,7 @@ mod tests { use anyhow::Result; use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::gates::comparison::ComparisonGate; use crate::gates::gate::Gate; @@ -360,15 +362,25 @@ mod tests { #[test] fn wire_indices() { type CG = ComparisonGate; + let num_bits = 40; let num_copies = 3; - let chunk_bits = 3; + let num_chunks = 5; let gate = CG { - chunk_bits, + num_bits, + num_chunks, num_copies, _phantom: PhantomData, }; + assert_eq!(gate.wire_first_input(0), 0); + assert_eq!(gate.wire_second_input(0), 1); + assert_eq!(gate.wire_z_val(0), 2); + assert_eq!(gate.wire_z_bit(0, 0), 3); + assert_eq!(gate.wire_z_bit(0, 3), 6); + assert_eq!(gate.wire_first_chunk_val(0, 0), 7); + assert_eq!(gate.wire_first_chunk_val(0, 0), 7); + assert_eq!(gate.wire_first_input(0, 0), 0); assert_eq!(gate.wire_first_input(0, 2), 2); assert_eq!(gate.wire_second_input(0, 0), 3);