diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 6dcf1b3d..71dbf310 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -170,14 +170,12 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `exponent`, given by its little-endian bits. pub fn exp_from_bits(&mut self, base: Target, exponent_bits: &[Target]) -> Target { let mut current = base; - let one_ext = self.one_extension(); - let mut product = self.one(); + let one = self.one(); + let mut product = one; for &bit in exponent_bits { - // TODO: Add base field select. - let current_ext = self.convert_to_ext(current); - let multiplicand = self.select(bit, current_ext, one_ext); - product = self.mul(product, multiplicand.0[0]); + let multiplicand = self.select(bit, current, one); + product = self.mul(product, multiplicand); current = self.mul(current, current); } @@ -189,14 +187,12 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `2^bit_length-1-exponent`, given by its little-endian bits. pub fn exp_from_complement_bits(&mut self, base: Target, exponent_bits: &[Target]) -> Target { let mut current = base; - let one_ext = self.one_extension(); - let mut product = self.one(); + let one = self.one(); + let mut product = one; for &bit in exponent_bits { - let current_ext = self.convert_to_ext(current); - // TODO: Add base field select. - let multiplicand = self.select(bit, one_ext, current_ext); - product = self.mul(product, multiplicand.0[0]); + let multiplicand = self.select(bit, one, current); + product = self.mul(product, multiplicand); current = self.mul(current, current); } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 2f216870..4c4160e1 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -5,6 +5,6 @@ pub mod insert; pub mod interpolation; pub mod polynomial; pub mod range_check; -pub mod rotate; +pub mod select; pub mod split_base; pub(crate) mod split_join; diff --git a/src/gadgets/rotate.rs b/src/gadgets/rotate.rs deleted file mode 100644 index 67677795..00000000 --- a/src/gadgets/rotate.rs +++ /dev/null @@ -1,167 +0,0 @@ -use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; -use crate::target::Target; -use crate::util::log2_ceil; - -impl, const D: usize> CircuitBuilder { - /// Selects `x` or `y` based on `b`, which is assumed to be binary. - /// In particular, this returns `if b { x } else { y }`. - /// Note: This does not range-check `b`. - // TODO: This uses 10 gates per call. If addends are added to `MulExtensionGate`, this will be - // reduced to 2 gates. We could also use a new degree 2 `SelectGate` for this. - // If `num_routed_wire` is larger than 26, we could batch two `select` in one gate. - pub fn select( - &mut self, - b: Target, - x: ExtensionTarget, - y: ExtensionTarget, - ) -> ExtensionTarget { - let b_y_minus_y = self.scalar_mul_sub_extension(b, y, y); - self.scalar_mul_sub_extension(b, x, b_y_minus_y) - } - - /// Left-rotates an array `k` times if `b=1` else return the same array. - pub fn rotate_left_fixed( - &mut self, - b: Target, - k: usize, - v: &[ExtensionTarget], - ) -> Vec> { - let len = v.len(); - debug_assert!(k < len, "Trying to rotate by more than the vector length."); - let mut res = Vec::new(); - - for i in 0..len { - res.push(self.select(b, v[(i + k) % len], v[i])); - } - - res - } - - /// Left-rotates an array `k` times if `b=1` else return the same array. - pub fn rotate_right_fixed( - &mut self, - b: Target, - k: usize, - v: &[ExtensionTarget], - ) -> Vec> { - let len = v.len(); - debug_assert!(k < len, "Trying to rotate by more than the vector length."); - let mut res = Vec::new(); - - for i in 0..len { - res.push(self.select(b, v[(len + i - k) % len], v[i])); - } - - res - } - - /// Left-rotates an vector by the `Target` having bits given in little-endian by `num_rotation_bits`. - pub fn rotate_left_from_bits( - &mut self, - num_rotation_bits: &[Target], - v: &[ExtensionTarget], - ) -> Vec> { - let mut v = v.to_vec(); - - for i in 0..num_rotation_bits.len() { - v = self.rotate_left_fixed(num_rotation_bits[i], 1 << i, &v); - } - - v - } - - pub fn rotate_right_from_bits( - &mut self, - num_rotation_bits: &[Target], - v: &[ExtensionTarget], - ) -> Vec> { - let mut v = v.to_vec(); - - for i in 0..num_rotation_bits.len() { - v = self.rotate_right_fixed(num_rotation_bits[i], 1 << i, &v); - } - - v - } - - /// Left-rotates an array by `num_rotation`. Assumes that `num_rotation` is range-checked to be - /// less than `2^len_bits`. - pub fn rotate_left( - &mut self, - num_rotation: Target, - v: &[ExtensionTarget], - ) -> Vec> { - let len_bits = log2_ceil(v.len()); - let bits = self.split_le(num_rotation, len_bits); - - self.rotate_left_from_bits(&bits, v) - } - - pub fn rotate_right( - &mut self, - num_rotation: Target, - v: &[ExtensionTarget], - ) -> Vec> { - let len_bits = log2_ceil(v.len()); - let bits = self.split_le(num_rotation, len_bits); - - self.rotate_right_from_bits(&bits, v) - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use super::*; - use crate::circuit_data::CircuitConfig; - use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; - use crate::field::field::Field; - use crate::verifier::verify; - use crate::witness::PartialWitness; - - fn real_rotate( - num_rotation: usize, - v: &[ExtensionTarget], - ) -> Vec> { - let mut res = v.to_vec(); - res.rotate_left(num_rotation); - res - } - - fn test_rotate_given_len(len: usize) -> Result<()> { - type F = CrandallField; - type FF = QuarticCrandallField; - let config = CircuitConfig::large_config(); - let mut builder = CircuitBuilder::::new(config); - let v = (0..len) - .map(|_| builder.constant_extension(FF::rand())) - .collect::>(); - - for i in 0..len { - let it = builder.constant(F::from_canonical_usize(i)); - let rotated = real_rotate(i, &v); - let purported_rotated = builder.rotate_left(it, &v); - - for (x, y) in rotated.into_iter().zip(purported_rotated) { - builder.assert_equal_extension(x, y); - } - } - - let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_rotate() -> Result<()> { - for len in 1..5 { - test_rotate_given_len(len)?; - } - Ok(()) - } -} diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs new file mode 100644 index 00000000..bbd36d76 --- /dev/null +++ b/src/gadgets/select.rs @@ -0,0 +1,76 @@ +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::target::Target; + +impl, const D: usize> CircuitBuilder { + /// Selects `x` or `y` based on `b`, which is assumed to be binary, i.e., this returns `if b { x } else { y }`. + /// This expression is gotten as `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`. + /// Note: This does not range-check `b`. + pub fn select_ext( + &mut self, + b: Target, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + let b_ext = self.convert_to_ext(b); + let gate = self.num_gates(); + // Holds `by - y`. + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, b_ext, y, y, b_ext, x, first_out) + .1 + } + + /// See `select_ext`. + pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target { + let x_ext = self.convert_to_ext(x); + let y_ext = self.convert_to_ext(y); + self.select_ext(b, x_ext, y_ext).to_target_array()[0] + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use super::*; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field::Field; + use crate::verifier::verify; + use crate::witness::PartialWitness; + + #[test] + fn test_select() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::new(); + + let (x, y) = (FF::rand(), FF::rand()); + let xt = builder.add_virtual_extension_target(); + let yt = builder.add_virtual_extension_target(); + let truet = builder.add_virtual_target(); + let falset = builder.add_virtual_target(); + + pw.set_extension_target(xt, x); + pw.set_extension_target(yt, y); + pw.set_target(truet, F::ONE); + pw.set_target(falset, F::ZERO); + + let should_be_x = builder.select_ext(truet, xt, yt); + let should_be_y = builder.select_ext(falset, xt, yt); + + builder.assert_equal_extension(should_be_x, xt); + builder.assert_equal_extension(should_be_y, yt); + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } +}