Merge remote-tracking branch 'origin/main' into generic_configuration

Jakub Nabaglo 2021-12-16 13:28:49 -08:00
commit 81c6f6c7bf
34 changed files with 847 additions and 242 deletions

65 src/field/batch_util.rs Normal file
View File

@ -0,0 +1,65 @@
use crate::field::field_types::Field;
use crate::field::packable::Packable;
use crate::field::packed_field::PackedField;
/// Returns the index at which the packable prefix of `slice` ends, i.e. the largest multiple of
/// `P::WIDTH` that is at most `slice.len()`.
fn pack_with_leftovers_split_point<P: PackedField>(slice: &[P::Scalar]) -> usize {
let n = slice.len();
let n_leftover = n % P::WIDTH;
n - n_leftover
}
/// Splits `slice` into a packed prefix and a scalar remainder of length less than `P::WIDTH`.
fn pack_slice_with_leftovers<P: PackedField>(slice: &[P::Scalar]) -> (&[P], &[P::Scalar]) {
let split_point = pack_with_leftovers_split_point::<P>(slice);
let (slice_packable, slice_leftovers) = slice.split_at(split_point);
let slice_packed = P::pack_slice(slice_packable);
(slice_packed, slice_leftovers)
}
/// Mutable variant of `pack_slice_with_leftovers`.
fn pack_slice_with_leftovers_mut<P: PackedField>(
slice: &mut [P::Scalar],
) -> (&mut [P], &mut [P::Scalar]) {
let split_point = pack_with_leftovers_split_point::<P>(slice);
let (slice_packable, slice_leftovers) = slice.split_at_mut(split_point);
let slice_packed = P::pack_slice_mut(slice_packable);
(slice_packed, slice_leftovers)
}
/// Elementwise in-place multiplication of two slices of field elements.
/// The implementation should be faster than the trivial for loop.
pub fn batch_multiply_inplace<F: Field>(out: &mut [F], a: &[F]) {
let n = out.len();
assert_eq!(n, a.len(), "both arrays must have the same length");
// Split out slice of vectors, leaving leftovers as scalars
let (out_packed, out_leftovers) =
pack_slice_with_leftovers_mut::<<F as Packable>::Packing>(out);
let (a_packed, a_leftovers) = pack_slice_with_leftovers::<<F as Packable>::Packing>(a);
// Multiply packed and the leftovers
for (x_out, x_a) in out_packed.iter_mut().zip(a_packed) {
*x_out *= *x_a;
}
for (x_out, x_a) in out_leftovers.iter_mut().zip(a_leftovers) {
*x_out *= *x_a;
}
}
/// Elementwise in-place addition of two slices of field elements.
/// The implementation should be faster than the trivial for loop.
pub fn batch_add_inplace<F: Field>(out: &mut [F], a: &[F]) {
let n = out.len();
assert_eq!(n, a.len(), "both arrays must have the same length");
// Split out slice of vectors, leaving leftovers as scalars
let (out_packed, out_leftovers) =
pack_slice_with_leftovers_mut::<<F as Packable>::Packing>(out);
let (a_packed, a_leftovers) = pack_slice_with_leftovers::<<F as Packable>::Packing>(a);
// Add packed and the leftovers
for (x_out, x_a) in out_packed.iter_mut().zip(a_packed) {
*x_out += *x_a;
}
for (x_out, x_a) in out_leftovers.iter_mut().zip(a_leftovers) {
*x_out += *x_a;
}
}
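
A minimal usage sketch of the two helpers above (the caller is hypothetical; `F` stands for any `Field` implementor):

fn fused_mul_add_demo<F: Field>(out: &mut [F], a: &[F], b: &[F]) {
    // out[i] <- out[i] * a[i], vectorized over the packable prefix.
    batch_multiply_inplace(out, a);
    // out[i] <- out[i] + b[i], with the same packed/leftover split.
    batch_add_inplace(out, b);
}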

View File

@ -1,3 +1,4 @@
pub(crate) mod batch_util;
pub(crate) mod cosets;
pub mod extension_field;
pub mod fft;

View File

@ -2,6 +2,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -67,11 +68,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
@ -79,10 +83,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
constraints.push(output - computed_output);
yield_constr.one(output - computed_output);
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -70,11 +71,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
@ -83,10 +87,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let computed_output =
(multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
constraints.extend((output - computed_output).to_basefield_array());
yield_constr.many((output - computed_output).to_basefield_array());
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -118,8 +119,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
@ -133,7 +137,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
let base = F::from_canonical_u64(1 << 32u64);
let combined_output = output_high * base + output_low;
constraints.push(combined_output - computed_output);
yield_constr.one(combined_output - computed_output);
let mut combined_low_limbs = F::ZERO;
let mut combined_high_limbs = F::ZERO;
@ -145,7 +149,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
constraints.push(product);
yield_constr.one(product);
if j < midpoint {
combined_low_limbs = base * combined_low_limbs + this_limb;
@ -153,11 +157,9 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
combined_high_limbs = base * combined_high_limbs + this_limb;
}
}
constraints.push(combined_low_limbs - output_low);
constraints.push(combined_high_limbs - output_high);
yield_constr.one(combined_low_limbs - output_low);
yield_constr.one(combined_high_limbs - output_high);
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -148,9 +149,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
@ -171,8 +174,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
F::from_canonical_usize(1 << self.chunk_bits()),
);
constraints.push(first_chunks_combined - first_input);
constraints.push(second_chunks_combined - second_input);
yield_constr.one(first_chunks_combined - first_input);
yield_constr.one(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
@ -186,34 +189,32 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
constraints.push(first_product);
constraints.push(second_product);
yield_constr.one(first_product);
yield_constr.one(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal));
yield_constr.one(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::from_canonical_usize(x))
.product();
constraints.push(product);
constraints
yield_constr.one(product);
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -57,22 +58,23 @@ impl<F: Extendable<D>, const D: usize, const B: usize> Gate<F, D> for BaseSumGat
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let sum = vars.local_wires[Self::WIRE_SUM];
let limbs = &vars.local_wires[self.limbs()];
let limbs = vars.local_wires.view(self.limbs());
let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B));
let mut constraints = Vec::with_capacity(limbs.len() + 1);
constraints.push(computed_sum - sum);
yield_constr.one(computed_sum - sum);
let constraints_iter = limbs.iter().map(|&limb| {
(0..B)
.map(|i| unsafe { limb.sub_canonical_u64(i as u64) })
.product::<F>()
});
constraints.extend(constraints_iter);
constraints
yield_constr.many(constraints_iter);
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField};
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -166,9 +167,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
@ -189,8 +192,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
F::from_canonical_usize(1 << self.chunk_bits()),
);
constraints.push(first_chunks_combined - first_input);
constraints.push(second_chunks_combined - second_input);
yield_constr.one(first_chunks_combined - first_input);
yield_constr.one(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
@ -204,26 +207,26 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
let second_product: F = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
constraints.push(first_product);
constraints.push(second_product);
yield_constr.one(first_product);
yield_constr.one(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal));
yield_constr.one(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
let most_significant_diff_bits: Vec<F> = (0..self.chunk_bits() + 1)
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
@ -231,18 +234,16 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
// Range-check the bits.
for &bit in &most_significant_diff_bits {
constraints.push(bit * (F::ONE - bit));
yield_constr.one(bit * (F::ONE - bit));
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
constraints.push((two_n + most_significant_diff) - bits_combined);
yield_constr.one((two_n + most_significant_diff) - bits_combined);
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
constraints
yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]);
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -39,11 +40,16 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ConstantGate {
.collect()
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
self.consts_inputs()
.zip(self.wires_outputs())
.map(|(con, out)| vars.local_constants[con] - vars.local_wires[out])
.collect()
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
yield_constr.many(
self.consts_inputs()
.zip(self.wires_outputs())
.map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]),
);
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -99,7 +100,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let base = vars.local_wires[self.wire_base()];
let power_bits: Vec<_> = (0..self.num_power_bits)
@ -111,8 +116,6 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
let output = vars.local_wires[self.wire_output()];
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_power_bits {
let prev_intermediate_value = if i == 0 {
F::ONE
@ -126,12 +129,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
let not_cur_bit = F::ONE - cur_bit;
let computed_intermediate_value =
prev_intermediate_value * (cur_bit * base + not_cur_bit);
constraints.push(computed_intermediate_value - intermediate_values[i]);
yield_constr.one(computed_intermediate_value - intermediate_values[i]);
}
constraints.push(output - intermediate_values[self.num_power_bits - 1]);
constraints
yield_constr.one(output - intermediate_values[self.num_power_bits - 1]);
}
fn eval_unfiltered_recursively(

View File

@ -2,13 +2,17 @@ use std::fmt::{Debug, Error, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::field::batch_util::batch_multiply_inplace;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField};
use crate::gates::gate_tree::Tree;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::WitnessGenerator;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
};
/// A custom gate.
pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
@ -18,9 +22,19 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
/// Like `eval_unfiltered`, but specialized for points in the base field.
///
/// `eval_unfiltered_base_batch` calls this method by default. If `eval_unfiltered_base_batch`
/// is overridden, then `eval_unfiltered_base_one` is not necessary.
///
/// By default, this just calls `eval_unfiltered`, which treats the point as an extension field
/// element. This isn't very efficient.
fn eval_unfiltered_base(&self, vars_base: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars_base: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
// Note that this method uses `yield_constr` instead of returning its constraints.
// `yield_constr` abstracts out the underlying memory layout.
let local_constants = &vars_base
.local_constants
.iter()
@ -40,13 +54,21 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
let values = self.eval_unfiltered(vars);
// Each value should be in the base field, i.e. only the degree-zero part should be nonzero.
values
.into_iter()
.map(|value| {
debug_assert!(F::Extension::is_in_basefield(&value));
value.to_basefield_array()[0]
})
.collect()
values.into_iter().for_each(|value| {
debug_assert!(F::Extension::is_in_basefield(&value));
yield_constr.one(value.to_basefield_array()[0])
})
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
let mut res = vec![F::ZERO; vars_base.len() * self.num_constraints()];
for (i, vars_base_one) in vars_base.iter().enumerate() {
self.eval_unfiltered_base_one(
vars_base_one,
StridedConstraintConsumer::new(&mut res, vars_base.len(), i),
);
}
res
}
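
To make the layout of `res` concrete: constraint `j` for point `i` lands at flat index `j * batch_size + i`. A self-contained toy check (hypothetical sizes, not crate code):

fn layout_demo() {
    let batch_size = 2; // two evaluation points
    let num_constraints = 3;
    // Fill a buffer with its own indices so the layout is visible.
    let buffer: Vec<usize> = (0..num_constraints * batch_size).collect();
    // Constraint j = 1 occupies one contiguous row, holding both points:
    let j = 1;
    assert_eq!(&buffer[j * batch_size..(j + 1) * batch_size], &[2, 3]);
}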
fn eval_unfiltered_recursively(
@ -64,26 +86,23 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
.collect()
}
/// Like `eval_filtered`, but specialized for points in the base field.
fn eval_filtered_base(&self, mut vars: EvaluationVarsBase<F>, prefix: &[bool]) -> Vec<F> {
let filter = compute_filter(prefix, vars.local_constants);
vars.remove_prefix(prefix);
let mut res = self.eval_unfiltered_base(vars);
res.iter_mut().for_each(|c| {
*c *= filter;
});
res
}
/// The result is an array of length `vars_batch.len() * self.num_constraints()`. Constraint `j`
/// for point `i` is at index `j * batch_size + i`.
fn eval_filtered_base_batch(
&self,
vars_batch: &[EvaluationVarsBase<F>],
mut vars_batch: EvaluationVarsBaseBatch<F>,
prefix: &[bool],
) -> Vec<Vec<F>> {
vars_batch
) -> Vec<F> {
let filters: Vec<_> = vars_batch
.iter()
.map(|&vars| self.eval_filtered_base(vars, prefix))
.collect()
.map(|vars| compute_filter(prefix, vars.local_constants))
.collect();
vars_batch.remove_prefix(prefix);
let mut res_batch = self.eval_unfiltered_base_batch(vars_batch);
for res_chunk in res_batch.chunks_exact_mut(filters.len()) {
batch_multiply_inplace(res_chunk, &filters);
}
res_batch
}
/// Adds this gate's filtered constraints into the `combined_gate_constraints` buffer.
@ -174,17 +193,11 @@ impl<F: RichField + Extendable<D>, const D: usize> PrefixedGate<F, D> {
/// A gate's filter is computed as `prod_i (b_i*c_i + (1-b_i)*(1-c_i))`, with `(b_i)` the prefix
/// and `(c_i)` the local constants; it equals one exactly when the prefix of `constants` matches
/// `prefix`.
fn compute_filter<K: Field>(prefix: &[bool], constants: &[K]) -> K {
fn compute_filter<'a, K: Field, T: IntoIterator<Item = &'a K>>(prefix: &[bool], constants: T) -> K {
prefix
.iter()
.enumerate()
.map(|(i, &b)| {
if b {
constants[i]
} else {
K::ONE - constants[i]
}
})
.zip(constants)
.map(|(&b, &c)| if b { c } else { K::ONE - c })
.product()
}
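
For concreteness, a hypothetical wrapper showing what the rewritten `compute_filter` computes for a two-bit prefix (`compute_filter` is private to this module, so this is illustrative only):

fn filter_demo<F: Field>(constants: &[F; 2]) -> F {
    // prefix = [true, false] selects constants of the form (1, 0),
    // so the filter reduces to c0 * (1 - c1).
    compute_filter(&[true, false], constants.iter())
}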

View File

@ -8,7 +8,7 @@ use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::GenericConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch};
use crate::plonk::verifier::verify;
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::{log2_ceil, transpose};
@ -106,19 +106,18 @@ pub(crate) fn test_eval_fns<
.collect::<Vec<_>>();
let public_inputs_hash = HashOut::rand();
let vars_base = EvaluationVarsBase {
local_constants: &constants_base,
local_wires: &wires_base,
public_inputs_hash: &public_inputs_hash,
};
// Batch of 1.
let vars_base_batch =
EvaluationVarsBaseBatch::new(1, &constants_base, &wires_base, &public_inputs_hash);
let vars = EvaluationVars {
local_constants: &constants,
local_wires: &wires,
public_inputs_hash: &public_inputs_hash,
};
let evals_base = gate.eval_unfiltered_base(vars_base);
let evals_base = gate.eval_unfiltered_base_batch(vars_base_batch);
let evals = gate.eval_unfiltered(vars);
// This works because we have a batch of 1.
ensure!(
evals
== evals_base

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::gmimc;
use crate::hash::gmimc::GMiMC;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
@ -107,12 +108,14 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * swap.sub_one());
yield_constr.one(swap * swap.sub_one());
let mut state = Vec::with_capacity(12);
for i in 0..4 {
@ -138,7 +141,7 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
let constant = F::from_canonical_u64(<F as GMiMC<WIDTH>>::ROUND_CONSTANTS[r]);
let cubing_input = state[active] + addition_buffer + constant;
let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
constraints.push(cubing_input - cubing_input_wire);
yield_constr.one(cubing_input - cubing_input_wire);
let f = cubing_input_wire.cube();
addition_buffer += f;
state[active] -= f;
@ -146,10 +149,8 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
for i in 0..WIDTH {
state[i] += addition_buffer;
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -5,6 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -112,7 +113,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let insertion_index = vars.local_wires[self.wires_insertion_index()];
let list_items = (0..self.vec_size)
.map(|i| vars.get_local_ext(self.wires_original_list_item(i)))
@ -122,7 +127,6 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
.collect::<Vec<_>>();
let element_to_insert = vars.get_local_ext(self.wires_element_to_insert());
let mut constraints = Vec::with_capacity(self.num_constraints());
let mut already_inserted = F::ZERO;
for r in 0..=self.vec_size {
let cur_index = F::from_canonical_usize(r);
@ -131,8 +135,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
let insert_here = vars.local_wires[self.wire_insert_here_for_round_r(r)];
// The two equality constraints.
constraints.push(difference * equality_dummy - (F::ONE - insert_here));
constraints.push(insert_here * difference);
yield_constr.one(difference * equality_dummy - (F::ONE - insert_here));
yield_constr.one(insert_here * difference);
let mut new_item = element_to_insert.scalar_mul(insert_here);
if r > 0 {
@ -144,10 +148,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
}
// Output constraint.
constraints.extend((new_item - output_list_items[r]).to_basefield_array());
yield_constr.many((new_item - output_list_items[r]).to_basefield_array());
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -8,6 +8,7 @@ use crate::field::interpolation::interpolant;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -106,9 +107,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for HighDegreeInterpolationGat
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect();
@ -118,15 +121,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for HighDegreeInterpolationGat
for (i, point) in coset.into_iter().enumerate() {
let value = vars.get_local_ext(self.wires_value(i));
let computed_value = interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array());
yield_constr.many((value - computed_value).to_basefield_array());
}
let evaluation_point = vars.get_local_ext(self.wires_evaluation_point());
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
let computed_evaluation_value = interpolant.eval(evaluation_point);
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
constraints
yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array());
}
fn eval_unfiltered_recursively(

View File

@ -9,6 +9,7 @@ use crate::field::interpolation::interpolant;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -130,9 +131,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect::<Vec<_>>();
@ -142,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
.collect::<Vec<_>>();
let shift = powers_shift[0];
for i in 1..self.num_points() - 1 {
constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
yield_constr.one(powers_shift[i - 1] * shift - powers_shift[i]);
}
powers_shift.insert(0, F::ONE);
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
@ -161,7 +164,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
{
let value = vars.get_local_ext(self.wires_value(i));
let computed_value = altered_interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array());
yield_constr.many((value - computed_value).to_basefield_array());
}
let evaluation_point_powers = (1..self.num_points())
@ -169,16 +172,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
.collect::<Vec<_>>();
let evaluation_point = evaluation_point_powers[0];
for i in 1..self.num_points() - 1 {
constraints.extend(
yield_constr.many(
(evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
.to_basefield_array(),
);
}
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
constraints
yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array());
}
fn eval_unfiltered_recursively(

View File

@ -25,6 +25,7 @@ pub mod reducing;
pub mod reducing_extension;
pub mod subtraction_u32;
pub mod switch;
mod util;
#[cfg(test)]
mod gate_testing;

View File

@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -65,20 +66,21 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGa
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let const_0 = vars.local_constants[0];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
let output = vars.get_local_ext(Self::wires_ith_output(i));
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0);
constraints.extend((output - computed_output).to_basefield_array());
yield_constr.many((output - computed_output).to_basefield_array());
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -3,7 +3,7 @@ use crate::field::extension_field::Extendable;
use crate::gates::gate::Gate;
use crate::iop::generator::WitnessGenerator;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch};
/// A gate which does nothing.
pub struct NoopGate;
@ -17,7 +17,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for NoopGate {
Vec::new()
}
fn eval_unfiltered_base(&self, _vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_batch(&self, _vars: EvaluationVarsBaseBatch<F>) -> Vec<F> {
Vec::new()
}

View File

@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::gates::poseidon_mds::PoseidonMdsGate;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::hashing::SPONGE_WIDTH;
use crate::hash::poseidon;
use crate::hash::poseidon::Poseidon;
@ -181,19 +182,21 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * swap.sub_one());
yield_constr.one(swap * swap.sub_one());
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
for i in 0..4 {
let input_lhs = vars.local_wires[Self::wire_input(i)];
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
let delta_i = vars.local_wires[Self::wire_delta(i)];
constraints.push(swap * (input_rhs - input_lhs) - delta_i);
yield_constr.one(swap * (input_rhs - input_lhs) - delta_i);
}
// Compute the possibly-swapped input layer.
@ -217,7 +220,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
if r != 0 {
for i in 0..SPONGE_WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
yield_constr.one(state[i] - sbox_in);
state[i] = sbox_in;
}
}
@ -231,13 +234,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
state = <F as Poseidon>::mds_partial_layer_init(&state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
yield_constr.one(state[0] - sbox_in);
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon>::mds_partial_layer_fast(&state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(state[0] - sbox_in);
yield_constr.one(state[0] - sbox_in);
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
state = <F as Poseidon>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
@ -247,7 +250,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
<F as Poseidon>::constant_layer(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(state[i] - sbox_in);
yield_constr.one(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon>::sbox_layer(&mut state);
@ -256,10 +259,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
}
for i in 0..SPONGE_WIDTH {
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -7,6 +7,7 @@ use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::hashing::SPONGE_WIDTH;
use crate::hash::poseidon::Poseidon;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
@ -123,7 +124,11 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> Gate<F, D> for Pos
.collect()
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
.map(|i| vars.get_local_ext(Self::wires_input(i)))
.collect::<Vec<_>>()
@ -132,11 +137,12 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> Gate<F, D> for Pos
let computed_outputs = F::mds_layer_field(&inputs);
(0..SPONGE_WIDTH)
.map(|i| vars.get_local_ext(Self::wires_output(i)))
.zip(computed_outputs)
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
.collect()
yield_constr.many(
(0..SPONGE_WIDTH)
.map(|i| vars.get_local_ext(Self::wires_output(i)))
.zip(computed_outputs)
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()),
)
}
fn eval_unfiltered_recursively(

View File

@ -3,6 +3,7 @@ use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::WitnessGenerator;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
@ -28,11 +29,16 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PublicInputGate {
.collect()
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
Self::wires_public_inputs_hash()
.zip(vars.public_inputs_hash.elements)
.map(|(wire, hash_part)| vars.local_wires[wire] - hash_part)
.collect()
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
yield_constr.many(
Self::wires_public_inputs_hash()
.zip(vars.public_inputs_hash.elements)
.map(|(wire, hash_part)| vars.local_wires[wire] - hash_part),
)
}
fn eval_unfiltered_recursively(

View File

@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -125,9 +126,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)];
let mut list_items = (0..self.vec_size())
@ -140,12 +143,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
// Assert that each bit wire value is indeed boolean.
for &b in &bits {
constraints.push(b * (b - F::ONE));
yield_constr.one(b * (b - F::ONE));
}
// Assert that the binary decomposition was correct.
let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b);
constraints.push(reconstructed_index - access_index);
yield_constr.one(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
@ -158,10 +161,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(list_items[0] - claimed_element);
yield_constr.one(list_items[0] - claimed_element);
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -80,7 +81,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingGate<D
.collect()
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let alpha = vars.get_local_ext(Self::wires_alpha());
let old_acc = vars.get_local_ext(Self::wires_old_acc());
let coeffs = self
@ -91,14 +96,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingGate<D
.map(|i| vars.get_local_ext(self.wires_accs(i)))
.collect::<Vec<_>>();
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
let mut acc = old_acc;
for i in 0..self.num_coeffs {
constraints.extend((acc * alpha + coeffs[i].into() - accs[i]).to_basefield_array());
yield_constr.many((acc * alpha + coeffs[i].into() - accs[i]).to_basefield_array());
acc = accs[i];
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -82,7 +83,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtens
.collect()
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
let alpha = vars.get_local_ext(Self::wires_alpha());
let old_acc = vars.get_local_ext(Self::wires_old_acc());
let coeffs = (0..self.num_coeffs)
@ -92,14 +97,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtens
.map(|i| vars.get_local_ext(self.wires_accs(i)))
.collect::<Vec<_>>();
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
let mut acc = old_acc;
for i in 0..self.num_coeffs {
constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array());
yield_constr.many((acc * alpha + coeffs[i] - accs[i]).to_basefield_array());
acc = accs[i];
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -113,8 +114,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
@ -126,7 +130,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
constraints.push(output_result - (result_initial + base * output_borrow));
yield_constr.one(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::ZERO;
@ -137,17 +141,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
constraints.push(product);
yield_constr.one(product);
combined_limbs = limb_base * combined_limbs + this_limb;
}
constraints.push(combined_limbs - output_result);
yield_constr.one(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
constraints.push(output_borrow * (F::ONE - output_borrow));
yield_constr.one(output_borrow * (F::ONE - output_borrow));
}
constraints
}
fn eval_unfiltered_recursively(

View File

@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -94,9 +95,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
) {
for c in 0..self.num_copies {
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
let not_switch = F::ONE - switch_bool;
@ -107,14 +110,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
let first_output = vars.local_wires[self.wire_first_output(c, e)];
let second_output = vars.local_wires[self.wire_second_output(c, e)];
constraints.push(switch_bool * (first_input - second_output));
constraints.push(switch_bool * (second_input - first_output));
constraints.push(not_switch * (first_input - first_output));
constraints.push(not_switch * (second_input - second_output));
yield_constr.one(switch_bool * (first_input - second_output));
yield_constr.one(switch_bool * (second_input - first_output));
yield_constr.one(not_switch * (first_input - first_output));
yield_constr.one(not_switch * (second_input - second_output));
}
}
constraints
}
fn eval_unfiltered_recursively(

62 src/gates/util.rs Normal file
View File

@ -0,0 +1,62 @@
use std::marker::PhantomData;
use crate::field::field_types::Field;
/// Writes constraints yielded by a gate to a buffer, with a given stride.
/// Permits us to abstract the underlying memory layout. In particular, we can make a matrix of
/// constraints where every column is an evaluation point and every row is a constraint index, with
/// the matrix stored in row-contiguous form.
pub struct StridedConstraintConsumer<'a, F: Field> {
// Raw pointers are a particularly neat way of doing this, neater than a slice: we advance
// `start` by `stride` at every step and terminate when it equals `end`.
start: *mut F,
end: *mut F,
stride: usize,
_phantom: PhantomData<&'a mut [F]>,
}
impl<'a, F: Field> StridedConstraintConsumer<'a, F> {
pub fn new(buffer: &'a mut [F], stride: usize, offset: usize) -> Self {
assert!(offset < stride);
assert_eq!(buffer.len() % stride, 0);
let ptr_range = buffer.as_mut_ptr_range();
// `wrapping_add` is needed to avoid undefined behavior. Plain `add` causes UB if 'the ...
// resulting pointer [is not] in bounds or one byte past the end of the same allocated
// object'; the UB results even if the pointer is not dereferenced. `end` will be more than
// one byte past the buffer unless `offset` is 0. The same applies to `start` if the buffer
// has length 0 and the offset is not 0.
// We _could_ do pointer arithmetic without `wrapping_add`, but the logic would be
// unnecessarily complicated.
let start = ptr_range.start.wrapping_add(offset);
let end = ptr_range.end.wrapping_add(offset);
Self {
start,
end,
stride,
_phantom: PhantomData,
}
}
/// Emit one constraint.
pub fn one(&mut self, constraint: F) {
if self.start != self.end {
// # Safety
// The checks in `new` guarantee that this points to valid space.
unsafe {
*self.start = constraint;
}
// See the comment in `new`. `wrapping_add` is needed to avoid UB if we've just
// exhausted our buffer (and hence we're setting `self.start` to point past the end).
self.start = self.start.wrapping_add(self.stride);
} else {
panic!("gate produced too many constraints");
}
}
/// Convenience method that calls `.one()` multiple times.
pub fn many<I: IntoIterator<Item = F>>(&mut self, constraints: I) {
constraints
.into_iter()
.for_each(|constraint| self.one(constraint));
}
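
A short sketch of how a batch evaluator drives the consumer, mirroring `eval_unfiltered_base_batch` in `gates/gate.rs` (buffer sizes are hypothetical):

fn consumer_demo<F: Field>(buffer: &mut [F], batch_size: usize, point: usize, c0: F, c1: F) {
    // `buffer` holds 2 * batch_size elements: two constraints for batch_size points.
    let mut yield_constr = StridedConstraintConsumer::new(buffer, batch_size, point);
    yield_constr.one(c0); // written to buffer[point]
    yield_constr.one(c1); // written to buffer[point + batch_size]
}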
}

View File

@ -157,9 +157,15 @@ pub(crate) fn reduce_with_powers_multi<
cumul
}
pub(crate) fn reduce_with_powers<F: Field>(terms: &[F], alpha: F) -> F {
pub(crate) fn reduce_with_powers<'a, F: Field, T: IntoIterator<Item = &'a F>>(
terms: T,
alpha: F,
) -> F
where
T::IntoIter: DoubleEndedIterator,
{
let mut sum = F::ZERO;
for &term in terms.iter().rev() {
for &term in terms.into_iter().rev() {
sum = sum * alpha + term;
}
sum
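
The generic signature changes nothing about the arithmetic: the sum `t_0 + t_1*alpha + t_2*alpha^2 + ...` is still evaluated by Horner's rule from the highest term down. A hypothetical equivalence check:

fn horner_check<F: Field>(t: [F; 3], alpha: F) -> bool {
    let direct = t[0] + t[1] * alpha + t[2] * alpha * alpha;
    let horner = (t[2] * alpha + t[1]) * alpha + t[0];
    direct == horner
}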

View File

@ -14,7 +14,7 @@ use crate::plonk::plonk_common::PlonkPolynomials;
use crate::plonk::plonk_common::ZeroPolyOnCoset;
use crate::plonk::proof::{Proof, ProofWithPublicInputs};
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
use crate::plonk::vars::EvaluationVarsBase;
use crate::plonk::vars::EvaluationVarsBaseBatch;
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed;
use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products};
@ -333,12 +333,14 @@ fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, cons
(BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect();
let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len());
let mut vars_batch = Vec::with_capacity(xs_batch.len());
let mut local_zs_batch = Vec::with_capacity(xs_batch.len());
let mut next_zs_batch = Vec::with_capacity(xs_batch.len());
let mut partial_products_batch = Vec::with_capacity(xs_batch.len());
let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len());
let mut local_constants_batch_refs = Vec::with_capacity(xs_batch.len());
let mut local_wires_batch_refs = Vec::with_capacity(xs_batch.len());
for (&i, &x) in indices_batch.iter().zip(xs_batch) {
let shifted_x = F::coset_shift() * x;
let i_next = (i + next_step) % lde_size;
@ -357,24 +359,45 @@ fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, cons
debug_assert_eq!(local_wires.len(), common_data.config.num_wires);
debug_assert_eq!(local_zs.len(), num_challenges);
let vars = EvaluationVarsBase {
local_constants,
local_wires,
public_inputs_hash,
};
local_constants_batch_refs.push(local_constants);
local_wires_batch_refs.push(local_wires);
shifted_xs_batch.push(shifted_x);
vars_batch.push(vars);
local_zs_batch.push(local_zs);
next_zs_batch.push(next_zs);
partial_products_batch.push(partial_products);
s_sigmas_batch.push(s_sigmas);
}
// NB (JN): I'm not sure how (in)efficient the below is. It needs measuring.
let mut local_constants_batch =
vec![F::ZERO; xs_batch.len() * local_constants_batch_refs[0].len()];
for (i, constants) in local_constants_batch_refs.iter().enumerate() {
for (j, &constant) in constants.iter().enumerate() {
local_constants_batch[i + j * xs_batch.len()] = constant;
}
}
let mut local_wires_batch =
vec![F::ZERO; xs_batch.len() * local_wires_batch_refs[0].len()];
for (i, wires) in local_wires_batch_refs.iter().enumerate() {
for (j, &wire) in wires.iter().enumerate() {
local_wires_batch[i + j * xs_batch.len()] = wire;
}
}
let vars_batch = EvaluationVarsBaseBatch::new(
xs_batch.len(),
&local_constants_batch,
&local_wires_batch,
public_inputs_hash,
);
let mut quotient_values_batch = eval_vanishing_poly_base_batch(
common_data,
&indices_batch,
&shifted_xs_batch,
&vars_batch,
vars_batch,
&local_zs_batch,
&next_zs_batch,
&partial_products_batch,
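
The scatter loops above amount to a transpose into point-major order; the same index arithmetic as a standalone sketch (hypothetical helper, not crate code; assumes at least one row):

fn to_point_major<F: Field>(rows: &[&[F]]) -> Vec<F> {
    let n = rows.len(); // number of evaluation points
    let m = rows[0].len(); // values per point
    let mut out = vec![F::ZERO; n * m];
    for (i, row) in rows.iter().enumerate() {
        for (j, &v) in row.iter().enumerate() {
            out[i + j * n] = v; // value j for point i
        }
    }
    out
}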

View File

@ -1,3 +1,4 @@
use crate::field::batch_util::batch_add_inplace;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::Field;
@ -8,9 +9,10 @@ use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common;
use crate::plonk::plonk_common::{eval_l_1_recursively, ZeroPolyOnCoset};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch};
use crate::util::partial_products::{check_partial_products, check_partial_products_recursively};
use crate::util::reducing::ReducingFactorTarget;
use crate::util::strided_view::PackedStridedView;
use crate::with_context;
/// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random
@ -96,7 +98,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
common_data: &CommonCircuitData<F, C, D>,
indices_batch: &[usize],
xs_batch: &[F],
vars_batch: &[EvaluationVarsBase<F>],
vars_batch: EvaluationVarsBaseBatch<F>,
local_zs_batch: &[&[F]],
next_zs_batch: &[&[F]],
partial_products_batch: &[&[F]],
@ -138,14 +140,13 @@ pub(crate) fn eval_vanishing_poly_base_batch<
for k in 0..n {
let index = indices_batch[k];
let x = xs_batch[k];
let vars = vars_batch[k];
let vars = vars_batch.view(k);
let local_zs = local_zs_batch[k];
let next_zs = next_zs_batch[k];
let partial_products = partial_products_batch[k];
let s_sigmas = s_sigmas_batch[k];
let constraint_terms =
&constraint_terms_batch[k * num_gate_constraints..(k + 1) * num_gate_constraints];
let constraint_terms = PackedStridedView::new(&constraint_terms_batch, n, k);
let l1_x = z_h_on_coset.eval_l1(index, x);
for i in 0..num_challenges {
@ -221,13 +222,13 @@ pub fn evaluate_gate_constraints<F: Extendable<D>, const D: usize>(
/// Evaluate all gate constraints in the base field.
///
/// Returns a vector of num_gate_constraints * vars_batch.len() field elements. The constraints
/// corresponding to vars_batch[i] are found in
/// result[num_gate_constraints * i..num_gate_constraints * (i + 1)].
/// Returns a vector of `num_gate_constraints * vars_batch.len()` field elements. The constraints
/// corresponding to `vars_batch[i]` are found in `result[i], result[vars_batch.len() + i],
/// result[2 * vars_batch.len() + i], ...`.
pub fn evaluate_gate_constraints_base_batch<F: Extendable<D>, const D: usize>(
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
vars_batch: &[EvaluationVarsBase<F>],
vars_batch: EvaluationVarsBaseBatch<F>,
) -> Vec<F> {
let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()];
for gate in gates {
@ -235,20 +236,15 @@ pub fn evaluate_gate_constraints_base_batch<F: Extendable<D>, const D: usize>(
.gate
.0
.eval_filtered_base_batch(vars_batch, &gate.prefix);
for (constraints, gate_constraints) in constraints_batch
.chunks_exact_mut(num_gate_constraints)
.zip(gate_constraints_batch.iter())
{
debug_assert!(
gate_constraints.len() <= constraints.len(),
"num_constraints() gave too low of a number"
);
for (constraint, &gate_constraint) in
constraints.iter_mut().zip(gate_constraints.iter())
{
*constraint += gate_constraint;
}
}
debug_assert!(
gate_constraints_batch.len() <= constraints_batch.len(),
"num_constraints() gave too low of a number"
);
// The batch stores all points' values for constraint 0, then constraint 1, and so on, so a
// single elementwise add over the first `gate_constraints_batch.len()` entries updates every
// point's constraints at once.
batch_add_inplace(
&mut constraints_batch[..gate_constraints_batch.len()],
&gate_constraints_batch,
);
}
constraints_batch
}
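// A minimal illustrative sketch (`constraints_for_point` is a hypothetical helper, not part
// of the diff): recovering one point's constraints from the constraint-major buffer uses the
// same stride/offset pattern as the `PackedStridedView::new(&constraint_terms_batch, n, k)`
// call above. It assumes the crate's width-1 packing for scalar fields.
#[allow(dead_code)]
fn constraints_for_point<F: Field>(
    constraints_batch: &[F],
    n: usize,
    k: usize,
) -> PackedStridedView<'_, F> {
    // Constraint `c` for point `k` lives at index `c * n + k`: stride `n`, offset `k`.
    PackedStridedView::new(constraints_batch, n, k)
}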

View File

@ -5,6 +5,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::Field;
use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::util::strided_view::PackedStridedView;
#[derive(Debug, Copy, Clone)]
pub struct EvaluationVars<'a, F: Extendable<D>, const D: usize> {
@ -13,13 +14,25 @@ pub struct EvaluationVars<'a, F: Extendable<D>, const D: usize> {
pub(crate) public_inputs_hash: &'a HashOut<F>,
}
/// A batch of evaluation vars, in the base field.
/// Wires and constants are stored in an evaluation point-major order (that is, wire 0 for all
/// evaluation points, then wire 1 for all points, and so on).
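/// For example, with 2 wires and a batch of 3 points, `local_wires` is
/// `[w0(p0), w0(p1), w0(p2), w1(p0), w1(p1), w1(p2)]`.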
#[derive(Debug, Copy, Clone)]
pub struct EvaluationVarsBase<'a, F: Field> {
pub struct EvaluationVarsBaseBatch<'a, F: Field> {
batch_size: usize,
pub(crate) local_constants: &'a [F],
pub(crate) local_wires: &'a [F],
pub(crate) public_inputs_hash: &'a HashOut<F>,
}
/// A view into `EvaluationVarsBaseBatch` for a particular evaluation point. Does not copy the data.
#[derive(Debug, Copy, Clone)]
pub struct EvaluationVarsBase<'a, F: Field> {
pub(crate) local_constants: PackedStridedView<'a, F>,
pub(crate) local_wires: PackedStridedView<'a, F>,
pub(crate) public_inputs_hash: &'a HashOut<F>,
}
impl<'a, F: Extendable<D>, const D: usize> EvaluationVars<'a, F, D> {
pub fn get_local_ext_algebra(
&self,
@ -35,18 +48,81 @@ impl<'a, F: Extendable<D>, const D: usize> EvaluationVars<'a, F, D> {
}
}
impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> {
pub fn new(
batch_size: usize,
local_constants: &'a [F],
local_wires: &'a [F],
public_inputs_hash: &'a HashOut<F>,
) -> Self {
assert_eq!(local_constants.len() % batch_size, 0);
assert_eq!(local_wires.len() % batch_size, 0);
Self {
batch_size,
local_constants,
local_wires,
public_inputs_hash,
}
}
pub fn remove_prefix(&mut self, prefix: &[bool]) {
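// Each skipped constant occupies `self.len()` consecutive scalars (one per evaluation
// point), hence the `prefix.len() * self.len()` offset.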
self.local_constants = &self.local_constants[prefix.len() * self.len()..];
}
pub fn len(&self) -> usize {
self.batch_size
}
pub fn view(&self, index: usize) -> EvaluationVarsBase<'a, F> {
// We cannot implement `Index` as `EvaluationVarsBase` is a struct, not a reference.
assert!(index < self.len());
let local_constants = PackedStridedView::new(self.local_constants, self.len(), index);
let local_wires = PackedStridedView::new(self.local_wires, self.len(), index);
EvaluationVarsBase {
local_constants,
local_wires,
public_inputs_hash: self.public_inputs_hash,
}
}
pub fn iter(&self) -> EvaluationVarsBaseBatchIter<'a, F> {
EvaluationVarsBaseBatchIter::new(*self)
}
}
impl<'a, F: Field> EvaluationVarsBase<'a, F> {
pub fn get_local_ext<const D: usize>(&self, wire_range: Range<usize>) -> F::Extension
where
F: Extendable<D>,
{
debug_assert_eq!(wire_range.len(), D);
let arr = self.local_wires[wire_range].try_into().unwrap();
let arr = self.local_wires.view(wire_range).try_into().unwrap();
F::Extension::from_basefield_array(arr)
}
}
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len()..];
/// Iterator of views (`EvaluationVarsBase`) into an `EvaluationVarsBaseBatch`.
pub struct EvaluationVarsBaseBatchIter<'a, F: Field> {
i: usize,
vars_batch: EvaluationVarsBaseBatch<'a, F>,
}
impl<'a, F: Field> EvaluationVarsBaseBatchIter<'a, F> {
pub fn new(vars_batch: EvaluationVarsBaseBatch<'a, F>) -> Self {
EvaluationVarsBaseBatchIter { i: 0, vars_batch }
}
}
impl<'a, F: Field> Iterator for EvaluationVarsBaseBatchIter<'a, F> {
type Item = EvaluationVarsBase<'a, F>;
fn next(&mut self) -> Option<Self::Item> {
if self.i < self.vars_batch.len() {
let res = self.vars_batch.view(self.i);
self.i += 1;
Some(res)
} else {
None
}
}
}
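// A minimal usage sketch (`for_each_point` is a hypothetical function, not part of the
// diff): the iterator visits the batch one evaluation point at a time; each view borrows
// the batch's buffers, so nothing is copied.
#[allow(dead_code)]
fn for_each_point<F: Field>(vars_batch: EvaluationVarsBaseBatch<'_, F>) {
    for vars in vars_batch.iter() {
        // Each view gives strided access to this point's wires and constants.
        let _n_wires = vars.local_wires.len();
    }
}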

View File

@ -10,6 +10,7 @@ pub(crate) mod marking;
pub(crate) mod partial_products;
pub mod reducing;
pub mod serialization;
pub(crate) mod strided_view;
pub(crate) mod timing;
pub(crate) fn bits_u64(n: u64) -> usize {

317
src/util/strided_view.rs Normal file
View File

@ -0,0 +1,317 @@
use std::marker::PhantomData;
use std::mem::size_of;
use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use crate::field::packed_field::PackedField;
/// Imagine a slice, but with a stride (à la NumPy arrays).
///
/// For example, if the stride is 3,
/// `packed_strided_view[0]` is `data[0]`,
/// `packed_strided_view[1]` is `data[3]`,
/// `packed_strided_view[2]` is `data[6]`,
/// and so on. An offset may be specified. With an offset of 1, we get
/// `packed_strided_view[0]` is `data[1]`,
/// `packed_strided_view[1]` is `data[4]`,
/// `packed_strided_view[2]` is `data[7]`,
/// and so on.
///
/// Additionally, this view is *packed*, which means that it may yield a packing of the underlying
/// field slice. With a packing of width 4 and a stride of 5, the accesses are
/// `packed_strided_view[0]` is `data[0..4]`, transmuted to the packing,
/// `packed_strided_view[1]` is `data[5..9]`, transmuted to the packing,
/// `packed_strided_view[2]` is `data[10..14]`, transmuted to the packing,
/// and so on.
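/// For example, with the width-4 packing and stride of 5 above plus an offset of 1,
/// `packed_strided_view[0]` is `data[1..5]`, `packed_strided_view[1]` is `data[6..10]`, and
/// so on.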
#[derive(Debug, Copy, Clone)]
pub struct PackedStridedView<'a, P: PackedField> {
// This type has to be a struct, which means that it is not itself a reference the way a
// slice is; that is also why it cannot be returned from e.g. `Index::index`.
// Raw pointers rarely appear in good Rust code, but I think this is the most elegant way to
// implement this. The alternative would be to replace `start_ptr` and `length` with one slice
// (`&[P::Scalar]`). Unfortunately, with a slice, an empty view becomes an edge case that
// necessitates separate handling. It _could_ be done but it would also be uglier.
start_ptr: *const P::Scalar,
/// The total number of elements accessible through the view. In other words, valid
/// indices are in `0..length`.
length: usize,
/// This stride is in units of `P::Scalar` (NOT in bytes and NOT in `P`).
stride: usize,
_phantom: PhantomData<&'a [P::Scalar]>,
}
impl<'a, P: PackedField> PackedStridedView<'a, P> {
// `wrapping_add` is needed throughout to avoid undefined behavior. Plain `add` causes UB if
// '[either] the starting [or] resulting pointer [is neither] in bounds or one byte past the
// end of the same allocated object'; the UB results even if the pointer is not dereferenced.
#[inline]
pub fn new(data: &'a [P::Scalar], stride: usize, offset: usize) -> Self {
assert!(
stride >= P::WIDTH,
"stride (got {}) must be at least P::WIDTH ({})",
stride,
P::WIDTH
);
assert_eq!(
data.len() % stride,
0,
"data.len() ({}) must be a multiple of stride (got {})",
data.len(),
stride
);
// This requirement means that the stride divides the data into `data.len() / stride`
// chunks of `stride` elements each. Every access must fit entirely within one of those
// chunks.
assert!(
offset + P::WIDTH <= stride,
"offset (got {}) + P::WIDTH ({}) cannot be greater than stride (got {})",
offset,
P::WIDTH,
stride
);
// See comment above. `start_ptr` will be more than one byte past the buffer if `data` has
// length 0 and `offset` is not 0.
let start_ptr = data.as_ptr().wrapping_add(offset);
Self {
start_ptr,
length: data.len() / stride,
stride,
_phantom: PhantomData,
}
}
#[inline]
pub fn get(&self, index: usize) -> Option<&'a P> {
if index < self.length {
// Cast scalar pointer to vector pointer.
let res_ptr = unsafe { self.start_ptr.add(index * self.stride) }.cast();
// This transmutation is safe by the spec in `PackedField`.
Some(unsafe { &*res_ptr })
} else {
None
}
}
/// Take a subview spanning a range of this view's indices, returned as another
/// `PackedStridedView`.
#[inline]
pub fn view<I>(&self, index: I) -> Self
where
Self: Viewable<I, View = Self>,
{
// We cannot implement `Index` as `PackedStridedView` is a struct, not a reference.
// The `Viewable` trait is needed for overloading.
// Re-export `Viewable::view` so users don't have to import `Viewable`.
<Self as Viewable<I>>::view(self, index)
}
#[inline]
pub fn iter(&self) -> PackedStridedViewIter<'a, P> {
PackedStridedViewIter::new(
self.start_ptr,
// See comment at the top of the `impl`. Below will point more than one byte past the
// end of the buffer (unless `offset` is 0) so `wrapping_add` is needed.
self.start_ptr.wrapping_add(self.length * self.stride),
self.stride,
)
}
#[inline]
pub fn len(&self) -> usize {
self.length
}
}
impl<'a, P: PackedField> Index<usize> for PackedStridedView<'a, P> {
type Output = P;
#[inline]
fn index(&self, index: usize) -> &Self::Output {
self.get(index)
.expect("invalid memory access in PackedStridedView")
}
}
impl<'a, P: PackedField> IntoIterator for PackedStridedView<'a, P> {
type Item = &'a P;
type IntoIter = PackedStridedViewIter<'a, P>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[derive(Clone, Copy, Debug)]
pub struct TryFromPackedStridedViewError;
impl<P: PackedField, const N: usize> TryInto<[P; N]> for PackedStridedView<'_, P> {
type Error = TryFromPackedStridedViewError;
fn try_into(self) -> Result<[P; N], Self::Error> {
if N == self.len() {
let mut res = [P::ZERO; N];
for i in 0..N {
res[i] = *self.get(i).unwrap();
}
Ok(res)
} else {
Err(TryFromPackedStridedViewError)
}
}
}
// Not deriving `Copy`. An implicit copy of an iterator is likely a bug.
#[derive(Clone, Debug)]
pub struct PackedStridedViewIter<'a, P: PackedField> {
// Again, a pair of pointers is a neater solution than a slice. `start` and `end` are always
// separated by a multiple of stride elements. To advance the iterator from the front, we
// advance `start` by `stride` elements. To advance it from the end, we subtract `stride`
// elements. Iteration is done when they meet.
// A slice cannot recreate the same pattern. The end pointer may point past the underlying
// buffer (this is okay as we do not dereference it in that case); it becomes valid as soon as
// it is decreased by `stride`. On the other hand, a slice that ends on invalid memory is
// instant undefined behavior.
start: *const P::Scalar,
end: *const P::Scalar,
stride: usize,
_phantom: PhantomData<&'a [P::Scalar]>,
}
impl<'a, P: PackedField> PackedStridedViewIter<'a, P> {
pub(self) fn new(start: *const P::Scalar, end: *const P::Scalar, stride: usize) -> Self {
Self {
start,
end,
stride,
_phantom: PhantomData,
}
}
}
impl<'a, P: PackedField> Iterator for PackedStridedViewIter<'a, P> {
type Item = &'a P;
fn next(&mut self) -> Option<Self::Item> {
debug_assert_eq!(
(self.end as usize).wrapping_sub(self.start as usize)
% (self.stride * size_of::<P::Scalar>()),
0,
"start and end pointers should be separated by a multiple of stride"
);
if self.start != self.end {
let res = unsafe { &*self.start.cast() };
// See comment in `PackedStridedView`. Below will point more than one byte past the end
// of the buffer if the offset is not 0 and we've reached the end.
self.start = self.start.wrapping_add(self.stride);
Some(res)
} else {
None
}
}
}
impl<'a, P: PackedField> DoubleEndedIterator for PackedStridedViewIter<'a, P> {
fn next_back(&mut self) -> Option<Self::Item> {
debug_assert_eq!(
(self.end as usize).wrapping_sub(self.start as usize)
% (self.stride * size_of::<P::Scalar>()),
0,
"start and end pointers should be separated by a multiple of stride"
);
if self.start != self.end {
// See comment in `PackedStridedView`. `self.end` starts off pointing more than one byte
// past the end of the buffer unless `offset` is 0.
self.end = self.end.wrapping_sub(self.stride);
Some(unsafe { &*self.end.cast() })
} else {
None
}
}
}
pub trait Viewable<F> {
// We cannot implement `Index` as `PackedStridedView` is a struct, not a reference.
type View;
fn view(&self, index: F) -> Self::View;
}
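// The impls below give `view` the same overload set as slice indexing: `v.view(1..3)`,
// `v.view(4..)`, `v.view(..=2)`, `v.view(..)`, and so on all work.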
impl<'a, P: PackedField> Viewable<Range<usize>> for PackedStridedView<'a, P> {
type View = Self;
fn view(&self, range: Range<usize>) -> Self::View {
assert!(range.start <= self.len(), "Invalid access");
assert!(range.end <= self.len(), "Invalid access");
Self {
// See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
// past the end of the buffer if the offset is not 0 and the buffer has length 0.
start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
length: range.end - range.start,
stride: self.stride,
_phantom: PhantomData,
}
}
}
impl<'a, P: PackedField> Viewable<RangeFrom<usize>> for PackedStridedView<'a, P> {
type View = Self;
fn view(&self, range: RangeFrom<usize>) -> Self::View {
assert!(range.start <= self.len(), "Invalid access");
Self {
// See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
// past the end of the buffer if the offset is not 0 and the buffer has length 0.
start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
length: self.len() - range.start,
stride: self.stride,
_phantom: PhantomData,
}
}
}
impl<'a, P: PackedField> Viewable<RangeFull> for PackedStridedView<'a, P> {
type View = Self;
fn view(&self, _range: RangeFull) -> Self::View {
*self
}
}
impl<'a, P: PackedField> Viewable<RangeInclusive<usize>> for PackedStridedView<'a, P> {
type View = Self;
fn view(&self, range: RangeInclusive<usize>) -> Self::View {
assert!(*range.start() <= self.len(), "Invalid access");
assert!(*range.end() < self.len(), "Invalid access");
Self {
// See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
// past the end of the buffer if the offset is not 0 and the buffer has length 0.
start_ptr: self.start_ptr.wrapping_add(self.stride * range.start()),
length: range.end() - range.start() + 1,
stride: self.stride,
_phantom: PhantomData,
}
}
}
impl<'a, P: PackedField> Viewable<RangeTo<usize>> for PackedStridedView<'a, P> {
type View = Self;
fn view(&self, range: RangeTo<usize>) -> Self::View {
assert!(range.end <= self.len(), "Invalid access");
Self {
start_ptr: self.start_ptr,
length: range.end,
stride: self.stride,
_phantom: PhantomData,
}
}
}
impl<'a, P: PackedField> Viewable<RangeToInclusive<usize>> for PackedStridedView<'a, P> {
type View = Self;
fn view(&self, range: RangeToInclusive<usize>) -> Self::View {
assert!(range.end < self.len(), "Invalid access");
Self {
start_ptr: self.start_ptr,
length: range.end + 1,
stride: self.stride,
_phantom: PhantomData,
}
}
}
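// A minimal illustrative sketch (`packed_strided_view_demo` is a hypothetical function, not
// part of the diff): it exercises construction, bounds-checked access, and iteration for any
// packing `P`, assuming only the invariants that `new` itself asserts (stride >= P::WIDTH,
// buffer length a multiple of the stride, offset + P::WIDTH <= stride).
#[allow(dead_code)]
fn packed_strided_view_demo<P: PackedField>(data: &[P::Scalar], stride: usize, offset: usize) {
    let view = PackedStridedView::<P>::new(data, stride, offset);
    // One packed element per stride-sized chunk of the buffer.
    assert_eq!(view.len(), data.len() / stride);
    // `get` returns `None` out of bounds, so this walks exactly `view.len()` elements.
    let mut count = 0;
    while view.get(count).is_some() {
        count += 1;
    }
    assert_eq!(count, view.len());
    // Iteration visits the same elements, front to back.
    assert_eq!(view.iter().count(), view.len());
}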