diff --git a/src/field/batch_util.rs b/src/field/batch_util.rs new file mode 100644 index 00000000..6ff5fc1d --- /dev/null +++ b/src/field/batch_util.rs @@ -0,0 +1,65 @@ +use crate::field::field_types::Field; +use crate::field::packable::Packable; +use crate::field::packed_field::PackedField; + +fn pack_with_leftovers_split_point(slice: &[P::Scalar]) -> usize { + let n = slice.len(); + let n_leftover = n % P::WIDTH; + n - n_leftover +} + +fn pack_slice_with_leftovers(slice: &[P::Scalar]) -> (&[P], &[P::Scalar]) { + let split_point = pack_with_leftovers_split_point::

(slice); + let (slice_packable, slice_leftovers) = slice.split_at(split_point); + let slice_packed = P::pack_slice(slice_packable); + (slice_packed, slice_leftovers) +} + +fn pack_slice_with_leftovers_mut( + slice: &mut [P::Scalar], +) -> (&mut [P], &mut [P::Scalar]) { + let split_point = pack_with_leftovers_split_point::

(slice); + let (slice_packable, slice_leftovers) = slice.split_at_mut(split_point); + let slice_packed = P::pack_slice_mut(slice_packable); + (slice_packed, slice_leftovers) +} + +/// Elementwise inplace multiplication of two slices of field elements. +/// Implementation be faster than the trivial for loop. +pub fn batch_multiply_inplace(out: &mut [F], a: &[F]) { + let n = out.len(); + assert_eq!(n, a.len(), "both arrays must have the same length"); + + // Split out slice of vectors, leaving leftovers as scalars + let (out_packed, out_leftovers) = + pack_slice_with_leftovers_mut::<::Packing>(out); + let (a_packed, a_leftovers) = pack_slice_with_leftovers::<::Packing>(a); + + // Multiply packed and the leftovers + for (x_out, x_a) in out_packed.iter_mut().zip(a_packed) { + *x_out *= *x_a; + } + for (x_out, x_a) in out_leftovers.iter_mut().zip(a_leftovers) { + *x_out *= *x_a; + } +} + +/// Elementwise inplace addition of two slices of field elements. +/// Implementation be faster than the trivial for loop. 
+pub fn batch_add_inplace(out: &mut [F], a: &[F]) { + let n = out.len(); + assert_eq!(n, a.len(), "both arrays must have the same length"); + + // Split out slice of vectors, leaving leftovers as scalars + let (out_packed, out_leftovers) = + pack_slice_with_leftovers_mut::<::Packing>(out); + let (a_packed, a_leftovers) = pack_slice_with_leftovers::<::Packing>(a); + + // Add packed and the leftovers + for (x_out, x_a) in out_packed.iter_mut().zip(a_packed) { + *x_out += *x_a; + } + for (x_out, x_a) in out_leftovers.iter_mut().zip(a_leftovers) { + *x_out += *x_a; + } +} diff --git a/src/field/mod.rs b/src/field/mod.rs index 74e0fbf4..3cf4cdfb 100644 --- a/src/field/mod.rs +++ b/src/field/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod batch_util; pub(crate) mod cosets; pub mod extension_field; pub mod fft; diff --git a/src/gates/arithmetic_base.rs b/src/gates/arithmetic_base.rs index eb4f8225..48a40b67 100644 --- a/src/gates/arithmetic_base.rs +++ b/src/gates/arithmetic_base.rs @@ -2,6 +2,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -67,11 +68,14 @@ impl, const D: usize> Gate for ArithmeticGate constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let mut constraints = Vec::new(); for i in 0..self.num_ops { let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; @@ -79,10 +83,8 @@ impl, const D: usize> Gate 
for ArithmeticGate let output = vars.local_wires[Self::wire_ith_output(i)]; let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1; - constraints.push(output - computed_output); + yield_constr.one(output - computed_output); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/arithmetic_extension.rs b/src/gates/arithmetic_extension.rs index 62a89af4..c52224a2 100644 --- a/src/gates/arithmetic_extension.rs +++ b/src/gates/arithmetic_extension.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -70,11 +71,14 @@ impl, const D: usize> Gate for ArithmeticExtensionGate constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let mut constraints = Vec::new(); for i in 0..self.num_ops { let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); @@ -83,10 +87,8 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1); - constraints.extend((output - computed_output).to_basefield_array()); + yield_constr.many((output - computed_output).to_basefield_array()); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index 77b2bf97..0dc66ee9 100644 --- 
a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -118,8 +119,11 @@ impl, const D: usize> Gate for U32ArithmeticGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { for i in 0..self.num_ops { let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; @@ -133,7 +137,7 @@ impl, const D: usize> Gate for U32ArithmeticGate { let base = F::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; - constraints.push(combined_output - computed_output); + yield_constr.one(combined_output - computed_output); let mut combined_low_limbs = F::ZERO; let mut combined_high_limbs = F::ZERO; @@ -145,7 +149,7 @@ impl, const D: usize> Gate for U32ArithmeticGate { let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) .product(); - constraints.push(product); + yield_constr.one(product); if j < midpoint { combined_low_limbs = base * combined_low_limbs + this_limb; @@ -153,11 +157,9 @@ impl, const D: usize> Gate for U32ArithmeticGate { combined_high_limbs = base * combined_high_limbs + this_limb; } } - constraints.push(combined_low_limbs - output_low); - constraints.push(combined_high_limbs - output_high); + yield_constr.one(combined_low_limbs - output_low); + yield_constr.one(combined_high_limbs - output_high); } - - constraints } fn 
eval_unfiltered_recursively( diff --git a/src/gates/assert_le.rs b/src/gates/assert_le.rs index 078585fc..687699a2 100644 --- a/src/gates/assert_le.rs +++ b/src/gates/assert_le.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -148,9 +149,11 @@ impl, const D: usize> Gate for AssertLessThan constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let first_input = vars.local_wires[self.wire_first_input()]; let second_input = vars.local_wires[self.wire_second_input()]; @@ -171,8 +174,8 @@ impl, const D: usize> Gate for AssertLessThan F::from_canonical_usize(1 << self.chunk_bits()), ); - constraints.push(first_chunks_combined - first_input); - constraints.push(second_chunks_combined - second_input); + yield_constr.one(first_chunks_combined - first_input); + yield_constr.one(second_chunks_combined - second_input); let chunk_size = 1 << self.chunk_bits(); @@ -186,34 +189,32 @@ impl, const D: usize> Gate for AssertLessThan let second_product = (0..chunk_size) .map(|x| second_chunks[i] - F::from_canonical_usize(x)) .product(); - constraints.push(first_product); - constraints.push(second_product); + yield_constr.one(first_product); + yield_constr.one(second_product); let difference = second_chunks[i] - first_chunks[i]; let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; // Two constraints to assert that `chunks_equal` is 
valid. - constraints.push(difference * equality_dummy - (F::ONE - chunks_equal)); - constraints.push(chunks_equal * difference); + yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal)); + yield_constr.one(chunks_equal * difference); // Update `most_significant_diff_so_far`. let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); most_significant_diff_so_far = intermediate_value + (F::ONE - chunks_equal) * difference; } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - constraints.push(most_significant_diff - most_significant_diff_so_far); + yield_constr.one(most_significant_diff - most_significant_diff_so_far); // Range check `most_significant_diff` to be less than `chunk_size`. let product = (0..chunk_size) .map(|x| most_significant_diff - F::from_canonical_usize(x)) .product(); - constraints.push(product); - - constraints + yield_constr.one(product); } fn eval_unfiltered_recursively( diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index cc4886ea..de8b7fb3 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField, RichField}; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -57,22 +58,23 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, 
+ ) { let sum = vars.local_wires[Self::WIRE_SUM]; - let limbs = &vars.local_wires[self.limbs()]; + let limbs = vars.local_wires.view(self.limbs()); let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B)); - let mut constraints = Vec::with_capacity(limbs.len() + 1); - constraints.push(computed_sum - sum); + yield_constr.one(computed_sum - sum); let constraints_iter = limbs.iter().map(|&limb| { (0..B) .map(|i| unsafe { limb.sub_canonical_u64(i as u64) }) .product::() }); - constraints.extend(constraints_iter); - - constraints + yield_constr.many(constraints_iter); } fn eval_unfiltered_recursively( diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 5d1fcf4f..45819cf6 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField}; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -166,9 +167,11 @@ impl, const D: usize> Gate for ComparisonGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let first_input = vars.local_wires[self.wire_first_input()]; let second_input = vars.local_wires[self.wire_second_input()]; @@ -189,8 +192,8 @@ impl, const D: usize> Gate for ComparisonGate { F::from_canonical_usize(1 << self.chunk_bits()), ); - constraints.push(first_chunks_combined - first_input); - constraints.push(second_chunks_combined - second_input); + yield_constr.one(first_chunks_combined - first_input); + yield_constr.one(second_chunks_combined - second_input); let 
chunk_size = 1 << self.chunk_bits(); @@ -204,26 +207,26 @@ impl, const D: usize> Gate for ComparisonGate { let second_product: F = (0..chunk_size) .map(|x| second_chunks[i] - F::from_canonical_usize(x)) .product(); - constraints.push(first_product); - constraints.push(second_product); + yield_constr.one(first_product); + yield_constr.one(second_product); let difference = second_chunks[i] - first_chunks[i]; let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; // Two constraints to assert that `chunks_equal` is valid. - constraints.push(difference * equality_dummy - (F::ONE - chunks_equal)); - constraints.push(chunks_equal * difference); + yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal)); + yield_constr.one(chunks_equal * difference); // Update `most_significant_diff_so_far`. let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); most_significant_diff_so_far = intermediate_value + (F::ONE - chunks_equal) * difference; } let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - constraints.push(most_significant_diff - most_significant_diff_so_far); + yield_constr.one(most_significant_diff - most_significant_diff_so_far); let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) @@ -231,18 +234,16 @@ impl, const D: usize> Gate for ComparisonGate { // Range-check the bits. 
for &bit in &most_significant_diff_bits { - constraints.push(bit * (F::ONE - bit)); + yield_constr.one(bit * (F::ONE - bit)); } let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); let two_n = F::from_canonical_u64(1 << self.chunk_bits()); - constraints.push((two_n + most_significant_diff) - bits_combined); + yield_constr.one((two_n + most_significant_diff) - bits_combined); // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. let result_bool = vars.local_wires[self.wire_result_bool()]; - constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); - - constraints + yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]); } fn eval_unfiltered_recursively( diff --git a/src/gates/constant.rs b/src/gates/constant.rs index ee3f4545..3a790ee2 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -39,11 +40,16 @@ impl, const D: usize> Gate for ConstantGate { .collect() } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - self.consts_inputs() - .zip(self.wires_outputs()) - .map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]) - .collect() + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + yield_constr.many( + self.consts_inputs() + .zip(self.wires_outputs()) + .map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]), + ); } fn eval_unfiltered_recursively( diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index cc2f970d..ed4fd4d6 100644 --- 
a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -99,7 +100,11 @@ impl, const D: usize> Gate for ExponentiationGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let base = vars.local_wires[self.wire_base()]; let power_bits: Vec<_> = (0..self.num_power_bits) @@ -111,8 +116,6 @@ impl, const D: usize> Gate for ExponentiationGate { let output = vars.local_wires[self.wire_output()]; - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.num_power_bits { let prev_intermediate_value = if i == 0 { F::ONE @@ -126,12 +129,10 @@ impl, const D: usize> Gate for ExponentiationGate { let not_cur_bit = F::ONE - cur_bit; let computed_intermediate_value = prev_intermediate_value * (cur_bit * base + not_cur_bit); - constraints.push(computed_intermediate_value - intermediate_values[i]); + yield_constr.one(computed_intermediate_value - intermediate_values[i]); } - constraints.push(output - intermediate_values[self.num_power_bits - 1]); - - constraints + yield_constr.one(output - intermediate_values[self.num_power_bits - 1]); } fn eval_unfiltered_recursively( diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 65234f4e..f31a60c2 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -2,13 +2,17 @@ use std::fmt::{Debug, Error, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::field::batch_util::batch_multiply_inplace; use 
crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; use crate::gates::gate_tree::Tree; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::WitnessGenerator; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, +}; /// A custom gate. pub trait Gate, const D: usize>: 'static + Send + Sync { @@ -18,9 +22,19 @@ pub trait Gate, const D: usize>: 'static + Send + Sync { /// Like `eval_unfiltered`, but specialized for points in the base field. /// + /// + /// `eval_unfiltered_base_batch` calls this method by default. If `eval_unfiltered_base_batch` + /// is overridden, then `eval_unfiltered_base_one` is not necessary. + /// /// By default, this just calls `eval_unfiltered`, which treats the point as an extension field /// element. This isn't very efficient. - fn eval_unfiltered_base(&self, vars_base: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars_base: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + // Note that this method uses `yield_constr` instead of returning its constraints. + // `yield_constr` abstracts out the underlying memory layout. let local_constants = &vars_base .local_constants .iter() @@ -40,13 +54,21 @@ pub trait Gate, const D: usize>: 'static + Send + Sync { let values = self.eval_unfiltered(vars); // Each value should be in the base field, i.e. only the degree-zero part should be nonzero. 
- values - .into_iter() - .map(|value| { - debug_assert!(F::Extension::is_in_basefield(&value)); - value.to_basefield_array()[0] - }) - .collect() + values.into_iter().for_each(|value| { + debug_assert!(F::Extension::is_in_basefield(&value)); + yield_constr.one(value.to_basefield_array()[0]) + }) + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + let mut res = vec![F::ZERO; vars_base.len() * self.num_constraints()]; + for (i, vars_base_one) in vars_base.iter().enumerate() { + self.eval_unfiltered_base_one( + vars_base_one, + StridedConstraintConsumer::new(&mut res, vars_base.len(), i), + ); + } + res } fn eval_unfiltered_recursively( @@ -64,26 +86,23 @@ pub trait Gate, const D: usize>: 'static + Send + Sync { .collect() } - /// Like `eval_filtered`, but specialized for points in the base field. - fn eval_filtered_base(&self, mut vars: EvaluationVarsBase, prefix: &[bool]) -> Vec { - let filter = compute_filter(prefix, vars.local_constants); - vars.remove_prefix(prefix); - let mut res = self.eval_unfiltered_base(vars); - res.iter_mut().for_each(|c| { - *c *= filter; - }); - res - } - + /// The result is an array of length `vars_batch.len() * self.num_constraints()`. Constraint `j` + /// for point `i` is at index `j * batch_size + i`. fn eval_filtered_base_batch( &self, - vars_batch: &[EvaluationVarsBase], + mut vars_batch: EvaluationVarsBaseBatch, prefix: &[bool], - ) -> Vec> { - vars_batch + ) -> Vec { + let filters: Vec<_> = vars_batch .iter() - .map(|&vars| self.eval_filtered_base(vars, prefix)) - .collect() + .map(|vars| compute_filter(prefix, vars.local_constants)) + .collect(); + vars_batch.remove_prefix(prefix); + let mut res_batch = self.eval_unfiltered_base_batch(vars_batch); + for res_chunk in res_batch.chunks_exact_mut(filters.len()) { + batch_multiply_inplace(res_chunk, &filters); + } + res_batch } /// Adds this gate's filtered constraints into the `combined_gate_constraints` buffer. 
@@ -174,17 +193,11 @@ impl, const D: usize> PrefixedGate { /// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and /// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`. -fn compute_filter(prefix: &[bool], constants: &[K]) -> K { +fn compute_filter<'a, K: Field, T: IntoIterator>(prefix: &[bool], constants: T) -> K { prefix .iter() - .enumerate() - .map(|(i, &b)| { - if b { - constants[i] - } else { - K::ONE - constants[i] - } - }) + .zip(constants) + .map(|(&b, &c)| if b { c } else { K::ONE - c }) .product() } diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 97a69fdf..e9e823e1 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -8,7 +8,7 @@ use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::GenericConfig; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::plonk::verifier::verify; use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, transpose}; @@ -106,19 +106,18 @@ pub(crate) fn test_eval_fns< .collect::>(); let public_inputs_hash = HashOut::rand(); - let vars_base = EvaluationVarsBase { - local_constants: &constants_base, - local_wires: &wires_base, - public_inputs_hash: &public_inputs_hash, - }; + // Batch of 1. + let vars_base_batch = + EvaluationVarsBaseBatch::new(1, &constants_base, &wires_base, &public_inputs_hash); let vars = EvaluationVars { local_constants: &constants, local_wires: &wires, public_inputs_hash: &public_inputs_hash, }; - let evals_base = gate.eval_unfiltered_base(vars_base); + let evals_base = gate.eval_unfiltered_base_batch(vars_base_batch); let evals = gate.eval_unfiltered(vars); + // This works because we have a batch of 1. 
ensure!( evals == evals_base diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 8a34943d..d09d63dd 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::hash::gmimc; use crate::hash::gmimc::GMiMC; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -107,12 +108,14 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { // Assert that `swap` is binary. let swap = vars.local_wires[Self::WIRE_SWAP]; - constraints.push(swap * swap.sub_one()); + yield_constr.one(swap * swap.sub_one()); let mut state = Vec::with_capacity(12); for i in 0..4 { @@ -138,7 +141,7 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate>::ROUND_CONSTANTS[r]); let cubing_input = state[active] + addition_buffer + constant; let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; - constraints.push(cubing_input - cubing_input_wire); + yield_constr.one(cubing_input - cubing_input_wire); let f = cubing_input_wire.cube(); addition_buffer += f; state[active] -= f; @@ -146,10 +149,8 @@ impl + GMiMC, const D: usize, const WIDTH: usize> Gate, const D: usize> Gate for InsertionGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let insertion_index = vars.local_wires[self.wires_insertion_index()]; let list_items = (0..self.vec_size) .map(|i| vars.get_local_ext(self.wires_original_list_item(i))) @@ -122,7 +127,6 @@ impl, const D: usize> Gate for 
InsertionGate { .collect::>(); let element_to_insert = vars.get_local_ext(self.wires_element_to_insert()); - let mut constraints = Vec::with_capacity(self.num_constraints()); let mut already_inserted = F::ZERO; for r in 0..=self.vec_size { let cur_index = F::from_canonical_usize(r); @@ -131,8 +135,8 @@ impl, const D: usize> Gate for InsertionGate { let insert_here = vars.local_wires[self.wire_insert_here_for_round_r(r)]; // The two equality constraints. - constraints.push(difference * equality_dummy - (F::ONE - insert_here)); - constraints.push(insert_here * difference); + yield_constr.one(difference * equality_dummy - (F::ONE - insert_here)); + yield_constr.one(insert_here * difference); let mut new_item = element_to_insert.scalar_mul(insert_here); if r > 0 { @@ -144,10 +148,8 @@ impl, const D: usize> Gate for InsertionGate { } // Output constraint. - constraints.extend((new_item - output_list_items[r]).to_basefield_array()); + yield_constr.many((new_item - output_list_items[r]).to_basefield_array()); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 2dbe8ef8..61adce4a 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -8,6 +8,7 @@ use crate::field::interpolation::interpolant; use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -106,9 +107,11 @@ impl, const D: usize> Gate for HighDegreeInterpolationGat constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let 
coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect(); @@ -118,15 +121,13 @@ impl, const D: usize> Gate for HighDegreeInterpolationGat for (i, point) in coset.into_iter().enumerate() { let value = vars.get_local_ext(self.wires_value(i)); let computed_value = interpolant.eval_base(point); - constraints.extend(&(value - computed_value).to_basefield_array()); + yield_constr.many((value - computed_value).to_basefield_array()); } let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval(evaluation_point); - constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); - - constraints + yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); } fn eval_unfiltered_recursively( diff --git a/src/gates/low_degree_interpolation.rs b/src/gates/low_degree_interpolation.rs index 47b43359..1462af3b 100644 --- a/src/gates/low_degree_interpolation.rs +++ b/src/gates/low_degree_interpolation.rs @@ -9,6 +9,7 @@ use crate::field::interpolation::interpolant; use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -130,9 +131,11 @@ impl, const D: usize> Gate for LowDegreeInter constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let coeffs = (0..self.num_points()) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect::>(); @@ -142,7 +145,7 @@ impl, 
const D: usize> Gate for LowDegreeInter .collect::>(); let shift = powers_shift[0]; for i in 1..self.num_points() - 1 { - constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); + yield_constr.one(powers_shift[i - 1] * shift - powers_shift[i]); } powers_shift.insert(0, F::ONE); // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. @@ -161,7 +164,7 @@ impl, const D: usize> Gate for LowDegreeInter { let value = vars.get_local_ext(self.wires_value(i)); let computed_value = altered_interpolant.eval_base(point); - constraints.extend(&(value - computed_value).to_basefield_array()); + yield_constr.many((value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) @@ -169,16 +172,14 @@ impl, const D: usize> Gate for LowDegreeInter .collect::>(); let evaluation_point = evaluation_point_powers[0]; for i in 1..self.num_points() - 1 { - constraints.extend( + yield_constr.many( (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) .to_basefield_array(), ); } let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); - constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); - - constraints + yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); } fn eval_unfiltered_recursively( diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 54289733..dbbe174c 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -25,6 +25,7 @@ pub mod reducing; pub mod reducing_extension; pub mod subtraction_u32; pub mod switch; +mod util; #[cfg(test)] mod gate_testing; diff --git a/src/gates/multiplication_extension.rs b/src/gates/multiplication_extension.rs index e532f804..a2041689 100644 --- a/src/gates/multiplication_extension.rs +++ b/src/gates/multiplication_extension.rs @@ -5,6 +5,7 @@ use 
crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; use crate::field::field_types::RichField; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -65,20 +66,21 @@ impl, const D: usize> Gate for MulExtensionGa constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let const_0 = vars.local_constants[0]; - let mut constraints = Vec::new(); for i in 0..self.num_ops { let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); let output = vars.get_local_ext(Self::wires_ith_output(i)); let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0); - constraints.extend((output - computed_output).to_basefield_array()); + yield_constr.many((output - computed_output).to_basefield_array()); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/noop.rs b/src/gates/noop.rs index a1f572d8..4230e678 100644 --- a/src/gates/noop.rs +++ b/src/gates/noop.rs @@ -3,7 +3,7 @@ use crate::field::extension_field::Extendable; use crate::gates::gate::Gate; use crate::iop::generator::WitnessGenerator; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; /// A gate which does nothing. 
pub struct NoopGate; @@ -17,7 +17,7 @@ impl, const D: usize> Gate for NoopGate { Vec::new() } - fn eval_unfiltered_base(&self, _vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_batch(&self, _vars: EvaluationVarsBaseBatch) -> Vec { Vec::new() } diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 0f5963e3..141983e9 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; use crate::gates::poseidon_mds::PoseidonMdsGate; +use crate::gates::util::StridedConstraintConsumer; use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon; use crate::hash::poseidon::Poseidon; @@ -181,19 +182,21 @@ impl, const D: usize> Gate for PoseidonGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { // Assert that `swap` is binary. let swap = vars.local_wires[Self::WIRE_SWAP]; - constraints.push(swap * swap.sub_one()); + yield_constr.one(swap * swap.sub_one()); // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { let input_lhs = vars.local_wires[Self::wire_input(i)]; let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; let delta_i = vars.local_wires[Self::wire_delta(i)]; - constraints.push(swap * (input_rhs - input_lhs) - delta_i); + yield_constr.one(swap * (input_rhs - input_lhs) - delta_i); } // Compute the possibly-swapped input layer. 
@@ -217,7 +220,7 @@ impl, const D: usize> Gate for PoseidonGate { if r != 0 { for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); + yield_constr.one(state[i] - sbox_in); state[i] = sbox_in; } } @@ -231,13 +234,13 @@ impl, const D: usize> Gate for PoseidonGate { state = ::mds_partial_layer_init(&state); for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; - constraints.push(state[0] - sbox_in); + yield_constr.one(state[0] - sbox_in); state[0] = ::sbox_monomial(sbox_in); state[0] += F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); state = ::mds_partial_layer_fast(&state, r); } let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; - constraints.push(state[0] - sbox_in); + yield_constr.one(state[0] - sbox_in); state[0] = ::sbox_monomial(sbox_in); state = ::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1); round_ctr += poseidon::N_PARTIAL_ROUNDS; @@ -247,7 +250,7 @@ impl, const D: usize> Gate for PoseidonGate { ::constant_layer(&mut state, round_ctr); for i in 0..SPONGE_WIDTH { let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; - constraints.push(state[i] - sbox_in); + yield_constr.one(state[i] - sbox_in); state[i] = sbox_in; } ::sbox_layer(&mut state); @@ -256,10 +259,8 @@ impl, const D: usize> Gate for PoseidonGate { } for i in 0..SPONGE_WIDTH { - constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/poseidon_mds.rs b/src/gates/poseidon_mds.rs index a9234619..da54ab2c 100644 --- a/src/gates/poseidon_mds.rs +++ b/src/gates/poseidon_mds.rs @@ -7,6 +7,7 @@ use crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; use crate::field::field_types::{Field, RichField}; use 
crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -123,7 +124,11 @@ impl + Poseidon, const D: usize> Gate for Pos .collect() } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH) .map(|i| vars.get_local_ext(Self::wires_input(i))) .collect::>() @@ -132,11 +137,12 @@ impl + Poseidon, const D: usize> Gate for Pos let computed_outputs = F::mds_layer_field(&inputs); - (0..SPONGE_WIDTH) - .map(|i| vars.get_local_ext(Self::wires_output(i))) - .zip(computed_outputs) - .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()) - .collect() + yield_constr.many( + (0..SPONGE_WIDTH) + .map(|i| vars.get_local_ext(Self::wires_output(i))) + .zip(computed_outputs) + .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()), + ) } fn eval_unfiltered_recursively( diff --git a/src/gates/public_input.rs b/src/gates/public_input.rs index 7447cd1d..116d8917 100644 --- a/src/gates/public_input.rs +++ b/src/gates/public_input.rs @@ -3,6 +3,7 @@ use std::ops::Range; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::WitnessGenerator; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; @@ -28,11 +29,16 @@ impl, const D: usize> Gate for PublicInputGate { .collect() } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - Self::wires_public_inputs_hash() - .zip(vars.public_inputs_hash.elements) - .map(|(wire, hash_part)| 
vars.local_wires[wire] - hash_part) - .collect() + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + yield_constr.many( + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| vars.local_wires[wire] - hash_part), + ) } fn eval_unfiltered_recursively( diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index 9ea0db55..06c1274f 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -125,9 +126,11 @@ impl, const D: usize> Gate for RandomAccessGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { for copy in 0..self.num_copies { let access_index = vars.local_wires[self.wire_access_index(copy)]; let mut list_items = (0..self.vec_size()) @@ -140,12 +143,12 @@ impl, const D: usize> Gate for RandomAccessGate { // Assert that each bit wire value is indeed boolean. for &b in &bits { - constraints.push(b * (b - F::ONE)); + yield_constr.one(b * (b - F::ONE)); } // Assert that the binary decomposition was correct. let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b); - constraints.push(reconstructed_index - access_index); + yield_constr.one(reconstructed_index - access_index); // Repeatedly fold the list, selecting the left or right item from each pair based on // the corresponding bit. 
@@ -158,10 +161,8 @@ impl, const D: usize> Gate for RandomAccessGate { } debug_assert_eq!(list_items.len(), 1); - constraints.push(list_items[0] - claimed_element); + yield_constr.one(list_items[0] - claimed_element); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/reducing.rs b/src/gates/reducing.rs index c9ffce57..5d918781 100644 --- a/src/gates/reducing.rs +++ b/src/gates/reducing.rs @@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; use crate::field::field_types::RichField; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -80,7 +81,11 @@ impl, const D: usize> Gate for ReducingGate) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let alpha = vars.get_local_ext(Self::wires_alpha()); let old_acc = vars.get_local_ext(Self::wires_old_acc()); let coeffs = self @@ -91,14 +96,11 @@ impl, const D: usize> Gate for ReducingGate>(); - let mut constraints = Vec::with_capacity(>::num_constraints(self)); let mut acc = old_acc; for i in 0..self.num_coeffs { - constraints.extend((acc * alpha + coeffs[i].into() - accs[i]).to_basefield_array()); + yield_constr.many((acc * alpha + coeffs[i].into() - accs[i]).to_basefield_array()); acc = accs[i]; } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/reducing_extension.rs b/src/gates/reducing_extension.rs index 09b5420b..8d06dfbd 100644 --- a/src/gates/reducing_extension.rs +++ b/src/gates/reducing_extension.rs @@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable; use crate::field::extension_field::FieldExtension; use crate::field::field_types::RichField; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; 
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::witness::{PartitionWitness, Witness}; @@ -82,7 +83,11 @@ impl, const D: usize> Gate for ReducingExtens .collect() } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { let alpha = vars.get_local_ext(Self::wires_alpha()); let old_acc = vars.get_local_ext(Self::wires_old_acc()); let coeffs = (0..self.num_coeffs) @@ -92,14 +97,11 @@ impl, const D: usize> Gate for ReducingExtens .map(|i| vars.get_local_ext(self.wires_accs(i))) .collect::>(); - let mut constraints = Vec::with_capacity(>::num_constraints(self)); let mut acc = old_acc; for i in 0..self.num_coeffs { - constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array()); + yield_constr.many((acc * alpha + coeffs[i] - accs[i]).to_basefield_array()); acc = accs[i]; } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index a15cf6e8..de884a24 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -113,8 +114,11 @@ impl, const D: usize> Gate for U32Subtraction constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { for i in 0..self.num_ops { let 
input_x = vars.local_wires[self.wire_ith_input_x(i)]; let input_y = vars.local_wires[self.wire_ith_input_y(i)]; @@ -126,7 +130,7 @@ impl, const D: usize> Gate for U32Subtraction let output_result = vars.local_wires[self.wire_ith_output_result(i)]; let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; - constraints.push(output_result - (result_initial + base * output_borrow)); + yield_constr.one(output_result - (result_initial + base * output_borrow)); // Range-check output_result to be at most 32 bits. let mut combined_limbs = F::ZERO; @@ -137,17 +141,15 @@ impl, const D: usize> Gate for U32Subtraction let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) .product(); - constraints.push(product); + yield_constr.one(product); combined_limbs = limb_base * combined_limbs + this_limb; } - constraints.push(combined_limbs - output_result); + yield_constr.one(combined_limbs - output_result); // Range-check output_borrow to be one bit. - constraints.push(output_borrow * (F::ONE - output_borrow)); + yield_constr.one(output_borrow * (F::ONE - output_borrow)); } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/switch.rs b/src/gates/switch.rs index be14e76e..9026201e 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::{GeneratedValues, WitnessGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -94,9 +95,11 @@ impl, const D: usize> Gate for SwitchGate { constraints } - fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { 
for c in 0..self.num_copies { let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; let not_switch = F::ONE - switch_bool; @@ -107,14 +110,12 @@ impl, const D: usize> Gate for SwitchGate { let first_output = vars.local_wires[self.wire_first_output(c, e)]; let second_output = vars.local_wires[self.wire_second_output(c, e)]; - constraints.push(switch_bool * (first_input - second_output)); - constraints.push(switch_bool * (second_input - first_output)); - constraints.push(not_switch * (first_input - first_output)); - constraints.push(not_switch * (second_input - second_output)); + yield_constr.one(switch_bool * (first_input - second_output)); + yield_constr.one(switch_bool * (second_input - first_output)); + yield_constr.one(not_switch * (first_input - first_output)); + yield_constr.one(not_switch * (second_input - second_output)); } } - - constraints } fn eval_unfiltered_recursively( diff --git a/src/gates/util.rs b/src/gates/util.rs new file mode 100644 index 00000000..3f26db50 --- /dev/null +++ b/src/gates/util.rs @@ -0,0 +1,62 @@ +use std::marker::PhantomData; + +use crate::field::field_types::Field; + +/// Writes constraints yielded by a gate to a buffer, with a given stride. +/// Permits us to abstract the underlying memory layout. In particular, we can make a matrix of +/// constraints where every column is an evaluation point and every row is a constraint index, with +/// the matrix stored in row-contiguous form. +pub struct StridedConstraintConsumer<'a, F: Field> { + // This is a particularly neat way of doing this, more so than a slice. We increase start by + // stride at every step and terminate when it equals end. 
+ start: *mut F, + end: *mut F, + stride: usize, + _phantom: PhantomData<&'a mut [F]>, +} + +impl<'a, F: Field> StridedConstraintConsumer<'a, F> { + pub fn new(buffer: &'a mut [F], stride: usize, offset: usize) -> Self { + assert!(offset < stride); + assert_eq!(buffer.len() % stride, 0); + let ptr_range = buffer.as_mut_ptr_range(); + // `wrapping_add` is needed to avoid undefined behavior. Plain `add` causes UB if 'the ... + // resulting pointer [is neither] in bounds or one byte past the end of the same allocated + // object'; the UB results even if the pointer is not dereferenced. `end` will be more than + // one byte past the buffer unless `offset` is 0. The same applies to `start` if the buffer + // has length 0 and the offset is not 0. + // We _could_ do pointer arithmetic without `wrapping_add`, but the logic would be + // unnecessarily complicated. + let start = ptr_range.start.wrapping_add(offset); + let end = ptr_range.end.wrapping_add(offset); + Self { + start, + end, + stride, + _phantom: PhantomData, + } + } + + /// Emit one constraint. + pub fn one(&mut self, constraint: F) { + if self.start != self.end { + // # Safety + // The checks in `new` guarantee that this points to valid space. + unsafe { + *self.start = constraint; + } + // See the comment in `new`. `wrapping_add` is needed to avoid UB if we've just + // exhausted our buffer (and hence we're setting `self.start` to point past the end). + self.start = self.start.wrapping_add(self.stride); + } else { + panic!("gate produced too many constraints"); + } + } + + /// Convenience method that calls `.one()` multiple times. 
+ pub fn many>(&mut self, constraints: I) { + constraints + .into_iter() + .for_each(|constraint| self.one(constraint)); + } +} diff --git a/src/plonk/plonk_common.rs b/src/plonk/plonk_common.rs index 5be13740..89be659e 100644 --- a/src/plonk/plonk_common.rs +++ b/src/plonk/plonk_common.rs @@ -157,9 +157,15 @@ pub(crate) fn reduce_with_powers_multi< cumul } -pub(crate) fn reduce_with_powers(terms: &[F], alpha: F) -> F { +pub(crate) fn reduce_with_powers<'a, F: Field, T: IntoIterator>( + terms: T, + alpha: F, +) -> F +where + T::IntoIter: DoubleEndedIterator, +{ let mut sum = F::ZERO; - for &term in terms.iter().rev() { + for &term in terms.into_iter().rev() { sum = sum * alpha + term; } sum diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 8f7cf3cd..6cf569e6 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -14,7 +14,7 @@ use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::plonk_common::ZeroPolyOnCoset; use crate::plonk::proof::{Proof, ProofWithPublicInputs}; use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; -use crate::plonk::vars::EvaluationVarsBase; +use crate::plonk::vars::EvaluationVarsBaseBatch; use crate::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; @@ -333,12 +333,14 @@ fn compute_quotient_polys<'a, F: Extendable, C: GenericConfig, cons (BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect(); let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len()); - let mut vars_batch = Vec::with_capacity(xs_batch.len()); let mut local_zs_batch = Vec::with_capacity(xs_batch.len()); let mut next_zs_batch = Vec::with_capacity(xs_batch.len()); let mut partial_products_batch = Vec::with_capacity(xs_batch.len()); let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len()); + let mut local_constants_batch_refs = Vec::with_capacity(xs_batch.len()); + let mut local_wires_batch_refs = 
Vec::with_capacity(xs_batch.len()); + for (&i, &x) in indices_batch.iter().zip(xs_batch) { let shifted_x = F::coset_shift() * x; let i_next = (i + next_step) % lde_size; @@ -357,24 +359,45 @@ fn compute_quotient_polys<'a, F: Extendable, C: GenericConfig, cons debug_assert_eq!(local_wires.len(), common_data.config.num_wires); debug_assert_eq!(local_zs.len(), num_challenges); - let vars = EvaluationVarsBase { - local_constants, - local_wires, - public_inputs_hash, - }; + local_constants_batch_refs.push(local_constants); + local_wires_batch_refs.push(local_wires); shifted_xs_batch.push(shifted_x); - vars_batch.push(vars); local_zs_batch.push(local_zs); next_zs_batch.push(next_zs); partial_products_batch.push(partial_products); s_sigmas_batch.push(s_sigmas); } + + // NB (JN): I'm not sure how (in)efficient the below is. It needs measuring. + let mut local_constants_batch = + vec![F::ZERO; xs_batch.len() * local_constants_batch_refs[0].len()]; + for (i, constants) in local_constants_batch_refs.iter().enumerate() { + for (j, &constant) in constants.iter().enumerate() { + local_constants_batch[i + j * xs_batch.len()] = constant; + } + } + + let mut local_wires_batch = + vec![F::ZERO; xs_batch.len() * local_wires_batch_refs[0].len()]; + for (i, wires) in local_wires_batch_refs.iter().enumerate() { + for (j, &wire) in wires.iter().enumerate() { + local_wires_batch[i + j * xs_batch.len()] = wire; + } + } + + let vars_batch = EvaluationVarsBaseBatch::new( + xs_batch.len(), + &local_constants_batch, + &local_wires_batch, + public_inputs_hash, + ); + let mut quotient_values_batch = eval_vanishing_poly_base_batch( common_data, &indices_batch, &shifted_xs_batch, - &vars_batch, + vars_batch, &local_zs_batch, &next_zs_batch, &partial_products_batch, diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 025fd0bd..e1f35bd4 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -1,3 +1,4 @@ +use crate::field::batch_util::batch_add_inplace; 
use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; @@ -8,9 +9,10 @@ use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common; use crate::plonk::plonk_common::{eval_l_1_recursively, ZeroPolyOnCoset}; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::util::partial_products::{check_partial_products, check_partial_products_recursively}; use crate::util::reducing::ReducingFactorTarget; +use crate::util::strided_view::PackedStridedView; use crate::with_context; /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random @@ -96,7 +98,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< common_data: &CommonCircuitData, indices_batch: &[usize], xs_batch: &[F], - vars_batch: &[EvaluationVarsBase], + vars_batch: EvaluationVarsBaseBatch, local_zs_batch: &[&[F]], next_zs_batch: &[&[F]], partial_products_batch: &[&[F]], @@ -138,14 +140,13 @@ pub(crate) fn eval_vanishing_poly_base_batch< for k in 0..n { let index = indices_batch[k]; let x = xs_batch[k]; - let vars = vars_batch[k]; + let vars = vars_batch.view(k); let local_zs = local_zs_batch[k]; let next_zs = next_zs_batch[k]; let partial_products = partial_products_batch[k]; let s_sigmas = s_sigmas_batch[k]; - let constraint_terms = - &constraint_terms_batch[k * num_gate_constraints..(k + 1) * num_gate_constraints]; + let constraint_terms = PackedStridedView::new(&constraint_terms_batch, n, k); let l1_x = z_h_on_coset.eval_l1(index, x); for i in 0..num_challenges { @@ -221,13 +222,13 @@ pub fn evaluate_gate_constraints, const D: usize>( /// Evaluate all gate constraints in the base field. /// -/// Returns a vector of num_gate_constraints * vars_batch.len() field elements. 
The constraints -/// corresponding to vars_batch[i] are found in -/// result[num_gate_constraints * i..num_gate_constraints * (i + 1)]. +/// Returns a vector of `num_gate_constraints * vars_batch.len()` field elements. The constraints +/// corresponding to `vars_batch[i]` are found in `result[i], result[vars_batch.len() + i], +/// result[2 * vars_batch.len() + i], ...`. pub fn evaluate_gate_constraints_base_batch, const D: usize>( gates: &[PrefixedGate], num_gate_constraints: usize, - vars_batch: &[EvaluationVarsBase], + vars_batch: EvaluationVarsBaseBatch, ) -> Vec { let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()]; for gate in gates { @@ -235,20 +236,15 @@ pub fn evaluate_gate_constraints_base_batch, const D: usize>( .gate .0 .eval_filtered_base_batch(vars_batch, &gate.prefix); - for (constraints, gate_constraints) in constraints_batch - .chunks_exact_mut(num_gate_constraints) - .zip(gate_constraints_batch.iter()) - { - debug_assert!( - gate_constraints.len() <= constraints.len(), - "num_constraints() gave too low of a number" - ); - for (constraint, &gate_constraint) in - constraints.iter_mut().zip(gate_constraints.iter()) - { - *constraint += gate_constraint; - } - } + debug_assert!( + gate_constraints_batch.len() <= constraints_batch.len(), + "num_constraints() gave too low of a number" + ); + // below adds all constraints for all points + batch_add_inplace( + &mut constraints_batch[..gate_constraints_batch.len()], + &gate_constraints_batch, + ); } constraints_batch } diff --git a/src/plonk/vars.rs b/src/plonk/vars.rs index b643b7b7..5f5b3cf8 100644 --- a/src/plonk/vars.rs +++ b/src/plonk/vars.rs @@ -5,6 +5,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; use crate::hash::hash_types::{HashOut, HashOutTarget}; +use crate::util::strided_view::PackedStridedView; #[derive(Debug, Copy, Clone)] 
pub struct EvaluationVars<'a, F: Extendable, const D: usize> { @@ -13,13 +14,25 @@ pub struct EvaluationVars<'a, F: Extendable, const D: usize> { pub(crate) public_inputs_hash: &'a HashOut, } +/// A batch of evaluation vars, in the base field. +/// Wires and constants are stored in an evaluation point-major order (that is, wire 0 for all +/// evaluation points, then wire 1 for all points, and so on). #[derive(Debug, Copy, Clone)] -pub struct EvaluationVarsBase<'a, F: Field> { +pub struct EvaluationVarsBaseBatch<'a, F: Field> { + batch_size: usize, pub(crate) local_constants: &'a [F], pub(crate) local_wires: &'a [F], pub(crate) public_inputs_hash: &'a HashOut, } +/// A view into `EvaluationVarsBaseBatch` for a particular evaluation point. Does not copy the data. +#[derive(Debug, Copy, Clone)] +pub struct EvaluationVarsBase<'a, F: Field> { + pub(crate) local_constants: PackedStridedView<'a, F>, + pub(crate) local_wires: PackedStridedView<'a, F>, + pub(crate) public_inputs_hash: &'a HashOut, +} + impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { pub fn get_local_ext_algebra( &self, @@ -35,18 +48,81 @@ impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { } } +impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { + pub fn new( + batch_size: usize, + local_constants: &'a [F], + local_wires: &'a [F], + public_inputs_hash: &'a HashOut, + ) -> Self { + assert_eq!(local_constants.len() % batch_size, 0); + assert_eq!(local_wires.len() % batch_size, 0); + Self { + batch_size, + local_constants, + local_wires, + public_inputs_hash, + } + } + + pub fn remove_prefix(&mut self, prefix: &[bool]) { + self.local_constants = &self.local_constants[prefix.len() * self.len()..]; + } + + pub fn len(&self) -> usize { + self.batch_size + } + + pub fn view(&self, index: usize) -> EvaluationVarsBase<'a, F> { + // We cannot implement `Index` as `EvaluationVarsBase` is a struct, not a reference. 
+ assert!(index < self.len()); + let local_constants = PackedStridedView::new(self.local_constants, self.len(), index); + let local_wires = PackedStridedView::new(self.local_wires, self.len(), index); + EvaluationVarsBase { + local_constants, + local_wires, + public_inputs_hash: self.public_inputs_hash, + } + } + + pub fn iter(&self) -> EvaluationVarsBaseBatchIter<'a, F> { + EvaluationVarsBaseBatchIter::new(*self) + } +} + impl<'a, F: Field> EvaluationVarsBase<'a, F> { pub fn get_local_ext(&self, wire_range: Range) -> F::Extension where F: Extendable, { debug_assert_eq!(wire_range.len(), D); - let arr = self.local_wires[wire_range].try_into().unwrap(); + let arr = self.local_wires.view(wire_range).try_into().unwrap(); F::Extension::from_basefield_array(arr) } +} - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len()..]; +/// Iterator of views (`EvaluationVarsBase`) into a `EvaluationVarsBaseBatch`. +pub struct EvaluationVarsBaseBatchIter<'a, F: Field> { + i: usize, + vars_batch: EvaluationVarsBaseBatch<'a, F>, +} + +impl<'a, F: Field> EvaluationVarsBaseBatchIter<'a, F> { + pub fn new(vars_batch: EvaluationVarsBaseBatch<'a, F>) -> Self { + EvaluationVarsBaseBatchIter { i: 0, vars_batch } + } +} + +impl<'a, F: Field> Iterator for EvaluationVarsBaseBatchIter<'a, F> { + type Item = EvaluationVarsBase<'a, F>; + fn next(&mut self) -> Option { + if self.i < self.vars_batch.len() { + let res = self.vars_batch.view(self.i); + self.i += 1; + Some(res) + } else { + None + } } } diff --git a/src/util/mod.rs b/src/util/mod.rs index 3f7c5dd1..5d1de5ec 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -10,6 +10,7 @@ pub(crate) mod marking; pub(crate) mod partial_products; pub mod reducing; pub mod serialization; +pub(crate) mod strided_view; pub(crate) mod timing; pub(crate) fn bits_u64(n: u64) -> usize { diff --git a/src/util/strided_view.rs b/src/util/strided_view.rs new file mode 100644 index 00000000..9beb13aa --- 
/dev/null +++ b/src/util/strided_view.rs @@ -0,0 +1,317 @@ +use std::marker::PhantomData; +use std::mem::size_of; +use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; + +use crate::field::packed_field::PackedField; + +/// Imagine a slice, but with a stride (a la a NumPy array). +/// +/// For example, if the stride is 3, +/// `packed_strided_view[0]` is `data[0]`, +/// `packed_strided_view[1]` is `data[3]`, +/// `packed_strided_view[2]` is `data[6]`, +/// and so on. An offset may be specified. With an offset of 1, we get +/// `packed_strided_view[0]` is `data[1]`, +/// `packed_strided_view[1]` is `data[4]`, +/// `packed_strided_view[2]` is `data[7]`, +/// and so on. +/// +/// Additionally, this view is *packed*, which means that it may yield a packing of the underlying +/// field slice. With a packing of width 4 and a stride of 5, the accesses are +/// `packed_strided_view[0]` is `data[0..4]`, transmuted to the packing, +/// `packed_strided_view[1]` is `data[5..9]`, transmuted to the packing, +/// `packed_strided_view[2]` is `data[10..14]`, transmuted to the packing, +/// and so on. +#[derive(Debug, Copy, Clone)] +pub struct PackedStridedView<'a, P: PackedField> { + // This type has to be a struct, which means that it is not itself a reference (in the sense + // that a slice is a reference so we can return it from e.g. `Index::index`). + + // Raw pointers rarely appear in good Rust code, but I think this is the most elegant way to + // implement this. The alternative would be to replace `start_ptr` and `length` with one slice + // (`&[P::Scalar]`). Unfortunately, with a slice, an empty view becomes an edge case that + // necessitates separate handling. It _could_ be done but it would also be uglier. + start_ptr: *const P::Scalar, + /// This is the total length of elements accessible through the view. In other words, valid + /// indices are in `0..length`. 
+ length: usize, + /// This stride is in units of `P::Scalar` (NOT in bytes and NOT in `P`). + stride: usize, + _phantom: PhantomData<&'a [P::Scalar]>, +} + +impl<'a, P: PackedField> PackedStridedView<'a, P> { + // `wrapping_add` is needed throughout to avoid undefined behavior. Plain `add` causes UB if + // '[either] the starting [or] resulting pointer [is neither] in bounds or one byte past the + // end of the same allocated object'; the UB results even if the pointer is not dereferenced. + + #[inline] + pub fn new(data: &'a [P::Scalar], stride: usize, offset: usize) -> Self { + assert!( + stride >= P::WIDTH, + "stride (got {}) must be at least P::WIDTH ({})", + stride, + P::WIDTH + ); + assert_eq!( + data.len() % stride, + 0, + "data.len() ({}) must be a multiple of stride (got {})", + data.len(), + stride + ); + + // This requirement means that stride divides data into slices of `data.len() / stride` + // elements. Every access must fit entirely within one of those slices. + assert!( + offset + P::WIDTH <= stride, + "offset (got {}) + P::WIDTH ({}) cannot be greater than stride (got {})", + offset, + P::WIDTH, + stride + ); + + // See comment above. `start_ptr` will be more than one byte past the buffer if `data` has + // length 0 and `offset` is not 0. + let start_ptr = data.as_ptr().wrapping_add(offset); + + Self { + start_ptr, + length: data.len() / stride, + stride, + _phantom: PhantomData, + } + } + + #[inline] + pub fn get(&self, index: usize) -> Option<&'a P> { + if index < self.length { + // Cast scalar pointer to vector pointer. + let res_ptr = unsafe { self.start_ptr.add(index * self.stride) }.cast(); + // This transmutation is safe by the spec in `PackedField`. + Some(unsafe { &*res_ptr }) + } else { + None + } + } + + /// Take a range of `PackedStridedView` indices, as `PackedStridedView`. 
+ #[inline] + pub fn view<I>(&self, index: I) -> Self + where + Self: Viewable<I>, + { + // We cannot implement `Index` as `PackedStridedView` is a struct, not a reference. + + // The `Viewable` trait is needed for overloading. + // Re-export `Viewable::view` so users don't have to import `Viewable`. + <Self as Viewable<I>>::view(self, index) + } + + #[inline] + pub fn iter(&self) -> PackedStridedViewIter<'a, P> { + PackedStridedViewIter::new( + self.start_ptr, + // See comment at the top of the `impl`. Below will point more than one byte past the + // end of the buffer (unless `offset` is 0) so `wrapping_add` is needed. + self.start_ptr.wrapping_add(self.length * self.stride), + self.stride, + ) + } + + #[inline] + pub fn len(&self) -> usize { + self.length + } +} + +impl<'a, P: PackedField> Index<usize> for PackedStridedView<'a, P> { + type Output = P; + #[inline] + fn index(&self, index: usize) -> &Self::Output { + self.get(index) + .expect("invalid memory access in PackedStridedView") + } +} + +impl<'a, P: PackedField> IntoIterator for PackedStridedView<'a, P> { + type Item = &'a P; + type IntoIter = PackedStridedViewIter<'a, P>; + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +#[derive(Clone, Copy, Debug)] +pub struct TryFromPackedStridedViewError; + +impl<P: PackedField, const N: usize> TryInto<[P; N]> for PackedStridedView<'_, P> { + type Error = TryFromPackedStridedViewError; + fn try_into(self) -> Result<[P; N], Self::Error> { + if N == self.len() { + let mut res = [P::ZERO; N]; + for i in 0..N { + res[i] = *self.get(i).unwrap(); + } + Ok(res) + } else { + Err(TryFromPackedStridedViewError) + } + } +} + +// Not deriving `Copy`. An implicit copy of an iterator is likely a bug. +#[derive(Clone, Debug)] +pub struct PackedStridedViewIter<'a, P: PackedField> { + // Again, a pair of pointers is a neater solution than a slice. `start` and `end` are always + // separated by a multiple of stride elements. To advance the iterator from the front, we + // advance `start` by `stride` elements.
To advance it from the end, we subtract `stride` + // elements. Iteration is done when they meet. + // A slice cannot recreate the same pattern. The end pointer may point past the underlying + // buffer (this is okay as we do not dereference it in that case); it becomes valid as soon as + // it is decreased by `stride`. On the other hand, a slice that ends on invalid memory is + // instant undefined behavior. + start: *const P::Scalar, + end: *const P::Scalar, + stride: usize, + _phantom: PhantomData<&'a [P::Scalar]>, +} + +impl<'a, P: PackedField> PackedStridedViewIter<'a, P> { + pub(self) fn new(start: *const P::Scalar, end: *const P::Scalar, stride: usize) -> Self { + Self { + start, + end, + stride, + _phantom: PhantomData, + } + } +} + +impl<'a, P: PackedField> Iterator for PackedStridedViewIter<'a, P> { + type Item = &'a P; + fn next(&mut self) -> Option<Self::Item> { + debug_assert_eq!( + (self.end as usize).wrapping_sub(self.start as usize) + % (self.stride * size_of::<P::Scalar>()), + 0, + "start and end pointers should be separated by a multiple of stride" + ); + + if self.start != self.end { + let res = unsafe { &*self.start.cast() }; + // See comment in `PackedStridedView`. Below will point more than one byte past the end + // of the buffer if the offset is not 0 and we've reached the end. + self.start = self.start.wrapping_add(self.stride); + Some(res) + } else { + None + } + } +} + +impl<'a, P: PackedField> DoubleEndedIterator for PackedStridedViewIter<'a, P> { + fn next_back(&mut self) -> Option<Self::Item> { + debug_assert_eq!( + (self.end as usize).wrapping_sub(self.start as usize) + % (self.stride * size_of::<P::Scalar>()), + 0, + "start and end pointers should be separated by a multiple of stride" + ); + + if self.start != self.end { + // See comment in `PackedStridedView`. `self.end` starts off pointing more than one byte + // past the end of the buffer unless `offset` is 0.
+ self.end = self.end.wrapping_sub(self.stride); + Some(unsafe { &*self.end.cast() }) + } else { + None + } + } +} + +pub trait Viewable<F> { + // We cannot implement `Index` as `PackedStridedView` is a struct, not a reference. + type View; + fn view(&self, index: F) -> Self::View; +} + +impl<'a, P: PackedField> Viewable<Range<usize>> for PackedStridedView<'a, P> { + type View = Self; + fn view(&self, range: Range<usize>) -> Self::View { + assert!(range.start <= self.len(), "Invalid access"); + assert!(range.end <= self.len(), "Invalid access"); + Self { + // See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte + // past the end of the buffer if the offset is not 0 and the buffer has length 0. + start_ptr: self.start_ptr.wrapping_add(self.stride * range.start), + length: range.end - range.start, + stride: self.stride, + _phantom: PhantomData, + } + } +} + +impl<'a, P: PackedField> Viewable<RangeFrom<usize>> for PackedStridedView<'a, P> { + type View = Self; + fn view(&self, range: RangeFrom<usize>) -> Self::View { + assert!(range.start <= self.len(), "Invalid access"); + Self { + // See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte + // past the end of the buffer if the offset is not 0 and the buffer has length 0. + start_ptr: self.start_ptr.wrapping_add(self.stride * range.start), + length: self.len() - range.start, + stride: self.stride, + _phantom: PhantomData, + } + } +} + +impl<'a, P: PackedField> Viewable<RangeFull> for PackedStridedView<'a, P> { + type View = Self; + fn view(&self, _range: RangeFull) -> Self::View { + *self + } +} + +impl<'a, P: PackedField> Viewable<RangeInclusive<usize>> for PackedStridedView<'a, P> { + type View = Self; + fn view(&self, range: RangeInclusive<usize>) -> Self::View { + assert!(*range.start() <= self.len(), "Invalid access"); + assert!(*range.end() < self.len(), "Invalid access"); + Self { + // See comment in `PackedStridedView`.
`self.start_ptr` will point more than one byte + // past the end of the buffer if the offset is not 0 and the buffer has length 0. + start_ptr: self.start_ptr.wrapping_add(self.stride * range.start()), + length: range.end() - range.start() + 1, + stride: self.stride, + _phantom: PhantomData, + } + } +} + +impl<'a, P: PackedField> Viewable<RangeTo<usize>> for PackedStridedView<'a, P> { + type View = Self; + fn view(&self, range: RangeTo<usize>) -> Self::View { + assert!(range.end <= self.len(), "Invalid access"); + Self { + start_ptr: self.start_ptr, + length: range.end, + stride: self.stride, + _phantom: PhantomData, + } + } +} + +impl<'a, P: PackedField> Viewable<RangeToInclusive<usize>> for PackedStridedView<'a, P> { + type View = Self; + fn view(&self, range: RangeToInclusive<usize>) -> Self::View { + assert!(range.end < self.len(), "Invalid access"); + Self { + start_ptr: self.start_ptr, + length: range.end + 1, + stride: self.stride, + _phantom: PhantomData, + } + } +}