From d4a0a8661ee596d23880920695c4d64e40e5f971 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Mon, 20 Dec 2021 15:08:07 -0800 Subject: [PATCH] Packed evaluation for most gates (#395) * Most gates support packed evaluation * ComparisonGate * Minor: outdated todo marker * Revert superfluous change * Post-merge fixes * Daniel comments * Minor: Markdown in comments --- src/gates/arithmetic_base.rs | 47 +++++++--- src/gates/arithmetic_u32.rs | 99 ++++++++++++-------- src/gates/assert_le.rs | 154 +++++++++++++++++------------- src/gates/base_sum.rs | 49 +++++++--- src/gates/comparison.rs | 176 +++++++++++++++++++---------------- src/gates/constant.rs | 35 +++++-- src/gates/exponentiation.rs | 85 ++++++++++------- src/gates/gmimc.rs | 105 ++++++++++++--------- src/gates/mod.rs | 1 + src/gates/packed_util.rs | 39 ++++++++ src/gates/public_input.rs | 35 +++++-- src/gates/random_access.rs | 91 +++++++++++------- src/gates/subtraction_u32.rs | 91 +++++++++++------- src/gates/switch.rs | 57 ++++++++---- src/gates/util.rs | 21 +++-- src/plonk/plonk_common.rs | 9 +- src/plonk/vars.rs | 81 ++++++++++++++++ 17 files changed, 768 insertions(+), 407 deletions(-) create mode 100644 src/gates/packed_util.rs diff --git a/src/gates/arithmetic_base.rs b/src/gates/arithmetic_base.rs index 48a40b67..a765ff61 100644 --- a/src/gates/arithmetic_base.rs +++ b/src/gates/arithmetic_base.rs @@ -1,14 +1,19 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; 
-use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config /// supports enough routed wires, it can support several such operations in one gate. @@ -70,21 +75,14 @@ impl, const D: usize> Gate for ArithmeticGate fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - let const_0 = vars.local_constants[0]; - let const_1 = vars.local_constants[1]; + panic!("use eval_unfiltered_base_packed instead"); + } - for i in 0..self.num_ops { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; - let output = vars.local_wires[Self::wire_ith_output(i)]; - let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; - - yield_constr.one(output - computed_output); - } + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -152,6 +150,27 @@ impl, const D: usize> Gate for ArithmeticGate } } +impl, const D: usize> PackedEvaluableBase for ArithmeticGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[Self::wire_ith_addend(i)]; + let output = vars.local_wires[Self::wire_ith_output(i)]; + let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; + + yield_constr.one(output - computed_output); + } + } +} + #[derive(Clone, Debug)] struct ArithmeticBaseGenerator, const D: usize> { gate_index: usize, diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 0dc66ee9..b3c31b13 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -5,7 +5,9 @@ use itertools::unfold; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -13,7 +15,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). 
#[derive(Copy, Clone, Debug)] @@ -94,8 +99,8 @@ impl, const D: usize> Gate for U32ArithmeticGate { constraints.push(combined_output - computed_output); - let mut combined_low_limbs = F::Extension::ZERO; - let mut combined_high_limbs = F::Extension::ZERO; + let mut combined_low_limbs = ::ZERO; + let mut combined_high_limbs = ::ZERO; let midpoint = Self::num_limbs() / 2; let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { @@ -121,45 +126,14 @@ impl, const D: usize> Gate for U32ArithmeticGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - for i in 0..self.num_ops { - let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[self.wire_ith_addend(i)]; + panic!("use eval_unfiltered_base_packed instead"); + } - let computed_output = multiplicand_0 * multiplicand_1 + addend; - - let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; - - let base = F::from_canonical_u64(1 << 32u64); - let combined_output = output_high * base + output_low; - - yield_constr.one(combined_output - computed_output); - - let mut combined_low_limbs = F::ZERO; - let mut combined_high_limbs = F::ZERO; - let midpoint = Self::num_limbs() / 2; - let base = F::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); - - if j < midpoint { - combined_low_limbs = base * combined_low_limbs + this_limb; - } else { - combined_high_limbs = 
base * combined_high_limbs + this_limb; - } - } - yield_constr.one(combined_low_limbs - output_low); - yield_constr.one(combined_high_limbs - output_high); - } + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -257,6 +231,51 @@ impl, const D: usize> Gate for U32ArithmeticGate { } } +impl, const D: usize> PackedEvaluableBase for U32ArithmeticGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; + + let computed_output = multiplicand_0 * multiplicand_1 + addend; + + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + + let base = F::from_canonical_u64(1 << 32u64); + let combined_output = output_high * base + output_low; + + yield_constr.one(combined_output - computed_output); + + let mut combined_low_limbs = P::ZERO; + let mut combined_high_limbs = P::ZERO; + let midpoint = Self::num_limbs() / 2; + let base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + if j < midpoint { + combined_low_limbs = combined_low_limbs * base + this_limb; + } else { + combined_high_limbs = combined_high_limbs * base + this_limb; + } + } + yield_constr.one(combined_low_limbs - output_low); + yield_constr.one(combined_high_limbs - output_high); + } + } +} + #[derive(Clone, Debug)] struct U32ArithmeticGenerator, const D: usize> { gate: U32ArithmeticGate, diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 687699a2..3ab50b4e 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -3,7 +3,9 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use 
crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -11,7 +13,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; use crate::util::{bits_u64, ceil_div_usize}; // TODO: replace/merge this gate with `ComparisonGate`. @@ -109,7 +114,7 @@ impl, const D: usize> Gate for AssertLessThan let chunk_size = 1 << self.chunk_bits(); - let mut most_significant_diff_so_far = F::Extension::ZERO; + let mut most_significant_diff_so_far = ::ZERO; for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. @@ -127,14 +132,15 @@ impl, const D: usize> Gate for AssertLessThan let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; // Two constraints to assert that `chunks_equal` is valid. - constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); + constraints + .push(difference * equality_dummy - (::ONE - chunks_equal)); constraints.push(chunks_equal * difference); // Update `most_significant_diff_so_far`. 
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); most_significant_diff_so_far = - intermediate_value + (F::Extension::ONE - chunks_equal) * difference; + intermediate_value + (::ONE - chunks_equal) * difference; } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; @@ -151,70 +157,14 @@ impl, const D: usize> Gate for AssertLessThan fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; + panic!("use eval_unfiltered_base_packed instead"); + } - // Get chunks and assert that they match - let first_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let first_chunks_combined = reduce_with_powers( - &first_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - let second_chunks_combined = reduce_with_powers( - &second_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - - yield_constr.one(first_chunks_combined - first_input); - yield_constr.one(second_chunks_combined - second_input); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = F::ZERO; - - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. 
- let first_product = (0..chunk_size) - .map(|x| first_chunks[i] - F::from_canonical_usize(x)) - .product(); - let second_product = (0..chunk_size) - .map(|x| second_chunks[i] - F::from_canonical_usize(x)) - .product(); - yield_constr.one(first_product); - yield_constr.one(second_product); - - let difference = second_chunks[i] - first_chunks[i]; - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal)); - yield_constr.one(chunks_equal * difference); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); - most_significant_diff_so_far = - intermediate_value + (F::ONE - chunks_equal) * difference; - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - yield_constr.one(most_significant_diff - most_significant_diff_so_far); - - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -329,6 +279,78 @@ impl, const D: usize> Gate for AssertLessThan } } +impl, const D: usize> PackedEvaluableBase + for AssertLessThanGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec<_> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec<_> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + + yield_constr.one(first_chunks_combined - first_input); + yield_constr.one(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = P::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product = (0..chunk_size) + .map(|x| first_chunks[i] - F::from_canonical_usize(x)) + .product(); + let second_product = (0..chunk_size) + .map(|x| second_chunks[i] - F::from_canonical_usize(x)) + .product(); + yield_constr.one(first_product); + yield_constr.one(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + yield_constr.one(difference * equality_dummy - (P::ONE - chunks_equal)); + yield_constr.one(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. 
+ let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (P::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + yield_constr.one(most_significant_diff - most_significant_diff_so_far); + + // Range check `most_significant_diff` to be less than `chunk_size`. + let product = (0..chunk_size) + .map(|x| most_significant_diff - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + } +} + #[derive(Debug)] struct AssertLessThanGenerator, const D: usize> { gate_index: usize, diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index de8b7fb3..db5a5d80 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -3,7 +3,9 @@ use std::ops::Range; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField, RichField}; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -11,7 +13,10 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate which can decompose a number into base B little-endian limbs. 
#[derive(Copy, Clone, Debug)] @@ -60,21 +65,14 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - let sum = vars.local_wires[Self::WIRE_SUM]; - let limbs = vars.local_wires.view(self.limbs()); - let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B)); + panic!("use eval_unfiltered_base_packed instead"); + } - yield_constr.one(computed_sum - sum); - - let constraints_iter = limbs.iter().map(|&limb| { - (0..B) - .map(|i| unsafe { limb.sub_canonical_u64(i as u64) }) - .product::() - }); - yield_constr.many(constraints_iter); + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -136,6 +134,29 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat } } +impl, const D: usize, const B: usize> PackedEvaluableBase + for BaseSumGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + let sum = vars.local_wires[Self::WIRE_SUM]; + let limbs = vars.local_wires.view(self.limbs()); + let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B)); + + yield_constr.one(computed_sum - sum); + + let constraints_iter = limbs.iter().map(|&limb| { + (0..B) + .map(|i| limb - F::from_canonical_usize(i)) + .product::

() + }); + yield_constr.many(constraints_iter); + } +} + #[derive(Debug)] pub struct BaseSplitGenerator { gate_index: usize, diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 45819cf6..d04b4ba4 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -3,7 +3,9 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField}; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -11,7 +13,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; use crate::util::{bits_u64, ceil_div_usize}; /// A gate for checking that one value is less than or equal to another. @@ -116,7 +121,7 @@ impl, const D: usize> Gate for ComparisonGate { let chunk_size = 1 << self.chunk_bits(); - let mut most_significant_diff_so_far = F::Extension::ZERO; + let mut most_significant_diff_so_far = ::ZERO; for i in 0..self.num_chunks { // Range-check the chunks to be less than `chunk_size`. @@ -134,14 +139,15 @@ impl, const D: usize> Gate for ComparisonGate { let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; // Two constraints to assert that `chunks_equal` is valid. 
- constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); + constraints + .push(difference * equality_dummy - (::ONE - chunks_equal)); constraints.push(chunks_equal * difference); // Update `most_significant_diff_so_far`. let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); most_significant_diff_so_far = - intermediate_value + (F::Extension::ONE - chunks_equal) * difference; + intermediate_value + (::ONE - chunks_equal) * difference; } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; @@ -153,7 +159,7 @@ impl, const D: usize> Gate for ComparisonGate { // Range-check the bits. for &bit in &most_significant_diff_bits { - constraints.push(bit * (F::Extension::ONE - bit)); + constraints.push(bit * (::ONE - bit)); } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); @@ -169,81 +175,14 @@ impl, const D: usize> Gate for ComparisonGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; + panic!("use eval_unfiltered_base_packed instead"); + } - // Get chunks and assert that they match - let first_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let first_chunks_combined = reduce_with_powers( - &first_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - let second_chunks_combined = reduce_with_powers( - &second_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - - yield_constr.one(first_chunks_combined - 
first_input); - yield_constr.one(second_chunks_combined - second_input); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = F::ZERO; - - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let first_product: F = (0..chunk_size) - .map(|x| first_chunks[i] - F::from_canonical_usize(x)) - .product(); - let second_product: F = (0..chunk_size) - .map(|x| second_chunks[i] - F::from_canonical_usize(x)) - .product(); - yield_constr.one(first_product); - yield_constr.one(second_product); - - let difference = second_chunks[i] - first_chunks[i]; - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal)); - yield_constr.one(chunks_equal * difference); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); - most_significant_diff_so_far = - intermediate_value + (F::ONE - chunks_equal) * difference; - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - yield_constr.one(most_significant_diff - most_significant_diff_so_far); - - let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) - .collect(); - - // Range-check the bits. - for &bit in &most_significant_diff_bits { - yield_constr.one(bit * (F::ONE - bit)); - } - - let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); - let two_n = F::from_canonical_u64(1 << self.chunk_bits()); - yield_constr.one((two_n + most_significant_diff) - bits_combined); - - // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. 
- let result_bool = vars.local_wires[self.wire_result_bool()]; - yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]); + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -373,6 +312,87 @@ impl, const D: usize> Gate for ComparisonGate { } } +impl, const D: usize> PackedEvaluableBase for ComparisonGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec<_> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec<_> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + + yield_constr.one(first_chunks_combined - first_input); + yield_constr.one(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = P::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product: P = (0..chunk_size) + .map(|x| first_chunks[i] - F::from_canonical_usize(x)) + .product(); + let second_product: P = (0..chunk_size) + .map(|x| second_chunks[i] - F::from_canonical_usize(x)) + .product(); + yield_constr.one(first_product); + yield_constr.one(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + yield_constr.one(difference * equality_dummy - (P::ONE - chunks_equal)); + yield_constr.one(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. 
+ let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (P::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + yield_constr.one(most_significant_diff - most_significant_diff_so_far); + + let most_significant_diff_bits: Vec<_> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for &bit in &most_significant_diff_bits { + yield_constr.one(bit * (P::ONE - bit)); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); + let two_n = F::from_canonical_u64(1 << self.chunk_bits()); + yield_constr.one((most_significant_diff + two_n) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. 
+ let result_bool = vars.local_wires[self.wire_result_bool()]; + yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]); + } +} + #[derive(Debug)] struct ComparisonGenerator, const D: usize> { gate_index: usize, diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 3a790ee2..e3dfc47f 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -3,14 +3,19 @@ use std::ops::Range; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::PartitionWitness; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate which takes a single constant parameter and outputs that value. 
#[derive(Copy, Clone, Debug)] @@ -42,14 +47,14 @@ impl, const D: usize> Gate for ConstantGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - yield_constr.many( - self.consts_inputs() - .zip(self.wires_outputs()) - .map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]), - ); + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -95,6 +100,20 @@ impl, const D: usize> Gate for ConstantGate { } } +impl, const D: usize> PackedEvaluableBase for ConstantGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + yield_constr.many( + self.consts_inputs() + .zip(self.wires_outputs()) + .map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]), + ); + } +} + #[derive(Debug)] struct ConstantGenerator { gate_index: usize, diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index ed4fd4d6..5eb377b6 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -3,7 +3,9 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -11,7 +13,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate for raising a value to a power. #[derive(Clone, Debug)] @@ -81,15 +86,15 @@ impl, const D: usize> Gate for ExponentiationGate { for i in 0..self.num_power_bits { let prev_intermediate_value = if i == 0 { - F::Extension::ONE + ::ONE } else { - intermediate_values[i - 1].square() + ::square(&intermediate_values[i - 1]) }; // power_bits is in LE order, but we accumulate in BE order. 
let cur_bit = power_bits[self.num_power_bits - i - 1]; - let not_cur_bit = F::Extension::ONE - cur_bit; + let not_cur_bit = ::ONE - cur_bit; let computed_intermediate_value = prev_intermediate_value * (cur_bit * base + not_cur_bit); constraints.push(computed_intermediate_value - intermediate_values[i]); @@ -102,37 +107,14 @@ impl, const D: usize> Gate for ExponentiationGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - let base = vars.local_wires[self.wire_base()]; + panic!("use eval_unfiltered_base_packed instead"); + } - let power_bits: Vec<_> = (0..self.num_power_bits) - .map(|i| vars.local_wires[self.wire_power_bit(i)]) - .collect(); - let intermediate_values: Vec<_> = (0..self.num_power_bits) - .map(|i| vars.local_wires[self.wire_intermediate_value(i)]) - .collect(); - - let output = vars.local_wires[self.wire_output()]; - - for i in 0..self.num_power_bits { - let prev_intermediate_value = if i == 0 { - F::ONE - } else { - intermediate_values[i - 1].square() - }; - - // power_bits is in LE order, but we accumulate in BE order. - let cur_bit = power_bits[self.num_power_bits - i - 1]; - - let not_cur_bit = F::ONE - cur_bit; - let computed_intermediate_value = - prev_intermediate_value * (cur_bit * base + not_cur_bit); - yield_constr.one(computed_intermediate_value - intermediate_values[i]); - } - - yield_constr.one(output - intermediate_values[self.num_power_bits - 1]); + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -205,6 +187,43 @@ impl, const D: usize> Gate for ExponentiationGate { } } +impl, const D: usize> PackedEvaluableBase for ExponentiationGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

<P>, + mut yield_constr: StridedConstraintConsumer<P>

, + ) { + let base = vars.local_wires[self.wire_base()]; + + let power_bits: Vec<_> = (0..self.num_power_bits) + .map(|i| vars.local_wires[self.wire_power_bit(i)]) + .collect(); + let intermediate_values: Vec<_> = (0..self.num_power_bits) + .map(|i| vars.local_wires[self.wire_intermediate_value(i)]) + .collect(); + + let output = vars.local_wires[self.wire_output()]; + + for i in 0..self.num_power_bits { + let prev_intermediate_value = if i == 0 { + P::ONE + } else { + intermediate_values[i - 1].square() + }; + + // power_bits is in LE order, but we accumulate in BE order. + let cur_bit = power_bits[self.num_power_bits - i - 1]; + + let not_cur_bit = P::ONE - cur_bit; + let computed_intermediate_value = + prev_intermediate_value * (cur_bit * base + not_cur_bit); + yield_constr.one(computed_intermediate_value - intermediate_values[i]); + } + + yield_constr.one(output - intermediate_values[self.num_power_bits - 1]); + } +} + #[derive(Debug)] struct ExponentiationGenerator, const D: usize> { gate_index: usize, diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index d09d63dd..1819ccef 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -3,7 +3,9 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::hash::gmimc; use crate::hash::gmimc::GMiMC; @@ -12,7 +14,10 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// 
Evaluates a full GMiMC permutation with 12 state elements. /// @@ -68,7 +73,7 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate::ONE)); let mut state = Vec::with_capacity(12); for i in 0..4 { @@ -87,7 +92,7 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate::ZERO; for r in 0..gmimc::NUM_ROUNDS { let active = r % WIDTH; @@ -110,47 +115,14 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - // Assert that `swap` is binary. - let swap = vars.local_wires[Self::WIRE_SWAP]; - yield_constr.one(swap * swap.sub_one()); + panic!("use eval_unfiltered_base_packed instead"); + } - let mut state = Vec::with_capacity(12); - for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..12 { - state.push(vars.local_wires[i]); - } - - // Value that is implicitly added to each element. 
- // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = F::ZERO; - - for r in 0..gmimc::NUM_ROUNDS { - let active = r % WIDTH; - let constant = F::from_canonical_u64(>::ROUND_CONSTANTS[r]); - let cubing_input = state[active] + addition_buffer + constant; - let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; - yield_constr.one(cubing_input - cubing_input_wire); - let f = cubing_input_wire.cube(); - addition_buffer += f; - state[active] -= f; - } - - for i in 0..WIDTH { - state[i] += addition_buffer; - yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); - } + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -236,6 +208,55 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate + GMiMC, const D: usize, const WIDTH: usize> PackedEvaluableBase + for GMiMCGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

<P>, + mut yield_constr: StridedConstraintConsumer<P>

, + ) { + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + yield_constr.one(swap * (swap - F::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + // Value that is implicitly added to each element. + // See https://affine.group/2020/02/starkware-challenge + let mut addition_buffer = P::ZERO; + + for r in 0..gmimc::NUM_ROUNDS { + let active = r % WIDTH; + let constant = F::from_canonical_u64(>::ROUND_CONSTANTS[r]); + let cubing_input = state[active] + addition_buffer + constant; + let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; + yield_constr.one(cubing_input - cubing_input_wire); + let f = cubing_input_wire.square() * cubing_input_wire; + addition_buffer += f; + state[active] -= f; + } + + for i in 0..WIDTH { + state[i] += addition_buffer; + yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); + } + } +} + #[derive(Debug)] struct GMiMCGenerator + GMiMC, const D: usize, const WIDTH: usize> { gate_index: usize, diff --git a/src/gates/mod.rs b/src/gates/mod.rs index dbbe174c..b26f88f8 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -17,6 +17,7 @@ pub mod interpolation; pub mod low_degree_interpolation; pub mod multiplication_extension; pub mod noop; +mod packed_util; pub mod poseidon; pub(crate) mod poseidon_mds; pub(crate) mod public_input; diff --git a/src/gates/packed_util.rs b/src/gates/packed_util.rs new file mode 100644 index 00000000..b8874b40 --- /dev/null +++ b/src/gates/packed_util.rs @@ -0,0 +1,39 @@ +use crate::field::extension_field::Extendable; +use crate::field::packable::Packable; +use crate::field::packed_field::PackedField; +use crate::gates::gate::Gate; +use 
crate::gates::util::StridedConstraintConsumer; +use crate::plonk::vars::{EvaluationVarsBaseBatch, EvaluationVarsBasePacked}; + +pub trait PackedEvaluableBase, const D: usize>: Gate { + fn eval_unfiltered_base_packed>( + &self, + vars_base: EvaluationVarsBasePacked

<P>, + yield_constr: StridedConstraintConsumer<P>

, + ); + + /// Evaluates entire batch of points. Returns a matrix of constraints. Constraint `j` for point + /// `i` is at `index j * batch_size + i`. + fn eval_unfiltered_base_batch_packed(&self, vars_batch: EvaluationVarsBaseBatch) -> Vec { + let mut res = vec![F::ZERO; vars_batch.len() * self.num_constraints()]; + let (vars_packed_iter, vars_leftovers_iter) = vars_batch.pack::<::Packing>(); + let leftovers_start = vars_batch.len() - vars_leftovers_iter.len(); + for (i, vars_packed) in vars_packed_iter.enumerate() { + self.eval_unfiltered_base_packed( + vars_packed, + StridedConstraintConsumer::new( + &mut res[..], + vars_batch.len(), + ::Packing::WIDTH * i, + ), + ); + } + for (i, vars_leftovers) in vars_leftovers_iter.enumerate() { + self.eval_unfiltered_base_packed( + vars_leftovers, + StridedConstraintConsumer::new(&mut res[..], vars_batch.len(), leftovers_start + i), + ); + } + res + } +} diff --git a/src/gates/public_input.rs b/src/gates/public_input.rs index 116d8917..1e32cab6 100644 --- a/src/gates/public_input.rs +++ b/src/gates/public_input.rs @@ -2,11 +2,16 @@ use std::ops::Range; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::WitnessGenerator; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate whose first four wires will be equal to a hash of public inputs. 
pub struct PublicInputGate; @@ -31,14 +36,14 @@ impl, const D: usize> Gate for PublicInputGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - yield_constr.many( - Self::wires_public_inputs_hash() - .zip(vars.public_inputs_hash.elements) - .map(|(wire, hash_part)| vars.local_wires[wire] - hash_part), - ) + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -80,6 +85,20 @@ impl, const D: usize> Gate for PublicInputGate { } } +impl, const D: usize> PackedEvaluableBase for PublicInputGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

<P>, + mut yield_constr: StridedConstraintConsumer<P>

, + ) { + yield_constr.many( + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| vars.local_wires[wire] - hash_part), + ) + } +} + #[cfg(test)] mod tests { use crate::field::goldilocks_field::GoldilocksField; diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index 06c1274f..b1197883 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -5,7 +5,9 @@ use itertools::Itertools; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -13,7 +15,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate for checking that a particular element of a list matches a given value. #[derive(Copy, Clone, Debug)] @@ -99,14 +104,14 @@ impl, const D: usize> Gate for RandomAccessGate { // Assert that each bit wire value is indeed boolean. for &b in &bits { - constraints.push(b * (b - F::Extension::ONE)); + constraints.push(b * (b - ::ONE)); } // Assert that the binary decomposition was correct. 
let reconstructed_index = bits .iter() .rev() - .fold(F::Extension::ZERO, |acc, &b| acc.double() + b); + .fold(::ZERO, |acc, &b| acc.double() + b); constraints.push(reconstructed_index - access_index); // Repeatedly fold the list, selecting the left or right item from each pair based on @@ -128,41 +133,14 @@ impl, const D: usize> Gate for RandomAccessGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - for copy in 0..self.num_copies { - let access_index = vars.local_wires[self.wire_access_index(copy)]; - let mut list_items = (0..self.vec_size()) - .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) - .collect::>(); - let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; - let bits = (0..self.bits) - .map(|i| vars.local_wires[self.wire_bit(i, copy)]) - .collect::>(); + panic!("use eval_unfiltered_base_packed instead"); + } - // Assert that each bit wire value is indeed boolean. - for &b in &bits { - yield_constr.one(b * (b - F::ONE)); - } - - // Assert that the binary decomposition was correct. - let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b); - yield_constr.one(reconstructed_index - access_index); - - // Repeatedly fold the list, selecting the left or right item from each pair based on - // the corresponding bit. 
- for b in bits { - list_items = list_items - .iter() - .tuples() - .map(|(&x, &y)| x + b * (y - x)) - .collect() - } - - debug_assert_eq!(list_items.len(), 1); - yield_constr.one(list_items[0] - claimed_element); - } + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -251,6 +229,47 @@ impl, const D: usize> Gate for RandomAccessGate { } } +impl, const D: usize> PackedEvaluableBase for RandomAccessGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

<P>, + mut yield_constr: StridedConstraintConsumer<P>

, + ) { + for copy in 0..self.num_copies { + let access_index = vars.local_wires[self.wire_access_index(copy)]; + let mut list_items = (0..self.vec_size()) + .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) + .collect::>(); + let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + let bits = (0..self.bits) + .map(|i| vars.local_wires[self.wire_bit(i, copy)]) + .collect::>(); + + // Assert that each bit wire value is indeed boolean. + for &b in &bits { + yield_constr.one(b * (b - F::ONE)); + } + + // Assert that the binary decomposition was correct. + let reconstructed_index = bits.iter().rev().fold(P::ZERO, |acc, &b| acc + acc + b); + yield_constr.one(reconstructed_index - access_index); + + // Repeatedly fold the list, selecting the left or right item from each pair based on + // the corresponding bit. + for b in bits { + list_items = list_items + .iter() + .tuples() + .map(|(&x, &y)| x + b * (y - x)) + .collect() + } + + debug_assert_eq!(list_items.len(), 1); + yield_constr.one(list_items[0] - claimed_element); + } + } +} + #[derive(Debug)] struct RandomAccessGenerator, const D: usize> { gate_index: usize, diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index de884a24..aaa5dd09 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -3,7 +3,9 @@ use std::marker::PhantomData; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -11,7 +13,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use 
crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns /// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. @@ -93,7 +98,7 @@ impl, const D: usize> Gate for U32Subtraction constraints.push(output_result - (result_initial + base * output_borrow)); // Range-check output_result to be at most 32 bits. - let mut combined_limbs = F::Extension::ZERO; + let mut combined_limbs = ::ZERO; let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; @@ -108,7 +113,7 @@ impl, const D: usize> Gate for U32Subtraction constraints.push(combined_limbs - output_result); // Range-check output_borrow to be one bit. 
- constraints.push(output_borrow * (F::Extension::ONE - output_borrow)); + constraints.push(output_borrow * (::ONE - output_borrow)); } constraints @@ -116,40 +121,14 @@ impl, const D: usize> Gate for U32Subtraction fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - for i in 0..self.num_ops { - let input_x = vars.local_wires[self.wire_ith_input_x(i)]; - let input_y = vars.local_wires[self.wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + panic!("use eval_unfiltered_base_packed instead"); + } - let result_initial = input_x - input_y - input_borrow; - let base = F::from_canonical_u64(1 << 32u64); - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; - - yield_constr.one(output_result - (result_initial + base * output_borrow)); - - // Range-check output_result to be at most 32 bits. - let mut combined_limbs = F::ZERO; - let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); - - combined_limbs = limb_base * combined_limbs + this_limb; - } - yield_constr.one(combined_limbs - output_result); - - // Range-check output_borrow to be one bit. 
- yield_constr.one(output_borrow * (F::ONE - output_borrow)); - } + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -240,6 +219,48 @@ impl, const D: usize> Gate for U32Subtraction } } +impl, const D: usize> PackedEvaluableBase + for U32SubtractionGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

<P>, + mut yield_constr: StridedConstraintConsumer<P>

, + ) { + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + yield_constr.one(output_result - (result_initial + output_borrow * base)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = P::ZERO; + let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + combined_limbs = combined_limbs * limb_base + this_limb; + } + yield_constr.one(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. 
+ yield_constr.one(output_borrow * (P::ONE - output_borrow)); + } + } +} + #[derive(Clone, Debug)] struct U32SubtractionGenerator, const D: usize> { gate: U32SubtractionGate, diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 9026201e..fc4da35c 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -5,7 +5,9 @@ use array_tool::vec::Union; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, WitnessGenerator}; use crate::iop::target::Target; @@ -13,7 +15,10 @@ use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; /// A gate for conditionally swapping input values based on a boolean. 
#[derive(Clone, Debug)] @@ -77,7 +82,7 @@ impl, const D: usize> Gate for SwitchGate { for c in 0..self.num_copies { let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; - let not_switch = F::Extension::ONE - switch_bool; + let not_switch = ::ONE - switch_bool; for e in 0..self.chunk_size { let first_input = vars.local_wires[self.wire_first_input(c, e)]; @@ -97,25 +102,14 @@ impl, const D: usize> Gate for SwitchGate { fn eval_unfiltered_base_one( &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, ) { - for c in 0..self.num_copies { - let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; - let not_switch = F::ONE - switch_bool; + panic!("use eval_unfiltered_base_packed instead"); + } - for e in 0..self.chunk_size { - let first_input = vars.local_wires[self.wire_first_input(c, e)]; - let second_input = vars.local_wires[self.wire_second_input(c, e)]; - let first_output = vars.local_wires[self.wire_first_output(c, e)]; - let second_output = vars.local_wires[self.wire_second_output(c, e)]; - - yield_constr.one(switch_bool * (first_input - second_output)); - yield_constr.one(switch_bool * (second_input - first_output)); - yield_constr.one(not_switch * (first_input - first_output)); - yield_constr.one(not_switch * (second_input - second_output)); - } - } + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) } fn eval_unfiltered_recursively( @@ -194,6 +188,31 @@ impl, const D: usize> Gate for SwitchGate { } } +impl, const D: usize> PackedEvaluableBase for SwitchGate { + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

<P>, + mut yield_constr: StridedConstraintConsumer<P>

, + ) { + for c in 0..self.num_copies { + let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; + let not_switch = P::ONE - switch_bool; + + for e in 0..self.chunk_size { + let first_input = vars.local_wires[self.wire_first_input(c, e)]; + let second_input = vars.local_wires[self.wire_second_input(c, e)]; + let first_output = vars.local_wires[self.wire_first_output(c, e)]; + let second_output = vars.local_wires[self.wire_second_output(c, e)]; + + yield_constr.one(switch_bool * (first_input - second_output)); + yield_constr.one(switch_bool * (second_input - first_output)); + yield_constr.one(not_switch * (first_input - first_output)); + yield_constr.one(not_switch * (second_input - second_output)); + } + } + } +} + #[derive(Debug)] struct SwitchGenerator, const D: usize> { gate_index: usize, diff --git a/src/gates/util.rs b/src/gates/util.rs index 3f26db50..8f83e445 100644 --- a/src/gates/util.rs +++ b/src/gates/util.rs @@ -1,22 +1,23 @@ use std::marker::PhantomData; -use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; /// Writes constraints yielded by a gate to a buffer, with a given stride. /// Permits us to abstract the underlying memory layout. In particular, we can make a matrix of /// constraints where every column is an evaluation point and every row is a constraint index, with /// the matrix stored in row-contiguous form. -pub struct StridedConstraintConsumer<'a, F: Field> { +pub struct StridedConstraintConsumer<'a, P: PackedField> { // This is a particularly neat way of doing this, more so than a slice. We increase start by // stride at every step and terminate when it equals end. 
- start: *mut F, - end: *mut F, + start: *mut P::Scalar, + end: *mut P::Scalar, stride: usize, - _phantom: PhantomData<&'a mut [F]>, + _phantom: PhantomData<&'a mut [P::Scalar]>, } -impl<'a, F: Field> StridedConstraintConsumer<'a, F> { - pub fn new(buffer: &'a mut [F], stride: usize, offset: usize) -> Self { +impl<'a, P: PackedField> StridedConstraintConsumer<'a, P> { + pub fn new(buffer: &'a mut [P::Scalar], stride: usize, offset: usize) -> Self { + assert!(stride >= P::WIDTH); assert!(offset < stride); assert_eq!(buffer.len() % stride, 0); let ptr_range = buffer.as_mut_ptr_range(); @@ -38,12 +39,12 @@ impl<'a, F: Field> StridedConstraintConsumer<'a, F> { } /// Emit one constraint. - pub fn one(&mut self, constraint: F) { + pub fn one(&mut self, constraint: P) { if self.start != self.end { // # Safety // The checks in `new` guarantee that this points to valid space. unsafe { - *self.start = constraint; + *self.start.cast() = constraint; } // See the comment in `new`. `wrapping_add` is needed to avoid UB if we've just // exhausted our buffer (and hence we're setting `self.start` to point past the end). @@ -54,7 +55,7 @@ impl<'a, F: Field> StridedConstraintConsumer<'a, F> { } /// Convenience method that calls `.one()` multiple times. 
- pub fn many>(&mut self, constraints: I) { + pub fn many>(&mut self, constraints: I) { constraints .into_iter() .for_each(|constraint| self.one(constraint)); diff --git a/src/plonk/plonk_common.rs b/src/plonk/plonk_common.rs index 89be659e..0bf172ae 100644 --- a/src/plonk/plonk_common.rs +++ b/src/plonk/plonk_common.rs @@ -1,6 +1,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; +use crate::field::packed_field::PackedField; use crate::fri::commitment::SALT_SIZE; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -157,14 +158,14 @@ pub(crate) fn reduce_with_powers_multi< cumul } -pub(crate) fn reduce_with_powers<'a, F: Field, T: IntoIterator>( +pub(crate) fn reduce_with_powers<'a, P: PackedField, T: IntoIterator>( terms: T, - alpha: F, -) -> F + alpha: P::Scalar, +) -> P where T::IntoIter: DoubleEndedIterator, { - let mut sum = F::ZERO; + let mut sum = P::ZERO; for &term in terms.into_iter().rev() { sum = sum * alpha + term; } diff --git a/src/plonk/vars.rs b/src/plonk/vars.rs index 5f5b3cf8..62e770d1 100644 --- a/src/plonk/vars.rs +++ b/src/plonk/vars.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; +use crate::field::packed_field::PackedField; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::util::strided_view::PackedStridedView; @@ -33,6 +34,16 @@ pub struct EvaluationVarsBase<'a, F: Field> { pub(crate) public_inputs_hash: &'a HashOut, } +/// Like `EvaluationVarsBase`, but packed. +// It's a separate struct because `EvaluationVarsBase` implements `get_local_ext` and we do not yet +// have packed extension fields. 
+#[derive(Debug, Copy, Clone)] +pub struct EvaluationVarsBasePacked<'a, P: PackedField> { + pub(crate) local_constants: PackedStridedView<'a, P>, + pub(crate) local_wires: PackedStridedView<'a, P>, + pub(crate) public_inputs_hash: &'a HashOut, +} + impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { pub fn get_local_ext_algebra( &self, @@ -88,6 +99,19 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { pub fn iter(&self) -> EvaluationVarsBaseBatchIter<'a, F> { EvaluationVarsBaseBatchIter::new(*self) } + + pub fn pack>( + &self, + ) -> ( + EvaluationVarsBaseBatchIterPacked<'a, P>, + EvaluationVarsBaseBatchIterPacked<'a, F>, + ) { + let n_leftovers = self.len() % P::WIDTH; + ( + EvaluationVarsBaseBatchIterPacked::new_with_start(*self, 0), + EvaluationVarsBaseBatchIterPacked::new_with_start(*self, self.len() - n_leftovers), + ) + } } impl<'a, F: Field> EvaluationVarsBase<'a, F> { @@ -126,6 +150,63 @@ impl<'a, F: Field> Iterator for EvaluationVarsBaseBatchIter<'a, F> { } } +/// Iterator of packed views (`EvaluationVarsBasePacked`) into a `EvaluationVarsBaseBatch`. +/// Note: if the length of `EvaluationVarsBaseBatch` is not a multiple of `P::WIDTH`, then the +/// leftovers at the end are ignored. +pub struct EvaluationVarsBaseBatchIterPacked<'a, P: PackedField> { + /// Index to yield next, in units of `P::Scalar`. E.g. if `P::WIDTH == 4`, then we will yield + /// the vars for points `i`, `i + 1`, `i + 2`, and `i + 3`, packed. 
+ i: usize, + vars_batch: EvaluationVarsBaseBatch<'a, P::Scalar>, +} + +impl<'a, P: PackedField> EvaluationVarsBaseBatchIterPacked<'a, P> { + pub fn new_with_start( + vars_batch: EvaluationVarsBaseBatch<'a, P::Scalar>, + start: usize, + ) -> Self { + assert!(start <= vars_batch.len()); + EvaluationVarsBaseBatchIterPacked { + i: start, + vars_batch, + } + } +} + +impl<'a, P: PackedField> Iterator for EvaluationVarsBaseBatchIterPacked<'a, P> { + type Item = EvaluationVarsBasePacked<'a, P>; + fn next(&mut self) -> Option { + if self.i + P::WIDTH <= self.vars_batch.len() { + let local_constants = PackedStridedView::new( + self.vars_batch.local_constants, + self.vars_batch.len(), + self.i, + ); + let local_wires = + PackedStridedView::new(self.vars_batch.local_wires, self.vars_batch.len(), self.i); + let res = EvaluationVarsBasePacked { + local_constants, + local_wires, + public_inputs_hash: self.vars_batch.public_inputs_hash, + }; + self.i += P::WIDTH; + Some(res) + } else { + None + } + } + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl<'a, P: PackedField> ExactSizeIterator for EvaluationVarsBaseBatchIterPacked<'a, P> { + fn len(&self) -> usize { + (self.vars_batch.len() - self.i) / P::WIDTH + } +} + impl<'a, const D: usize> EvaluationTargets<'a, D> { pub fn remove_prefix(&mut self, prefix: &[bool]) { self.local_constants = &self.local_constants[prefix.len()..];