mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-08 08:43:06 +00:00
Merge remote-tracking branch 'origin/main' into generic_configuration
This commit is contained in:
commit
81c6f6c7bf
65
src/field/batch_util.rs
Normal file
65
src/field/batch_util.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::packable::Packable;
|
||||
use crate::field::packed_field::PackedField;
|
||||
|
||||
fn pack_with_leftovers_split_point<P: PackedField>(slice: &[P::Scalar]) -> usize {
|
||||
let n = slice.len();
|
||||
let n_leftover = n % P::WIDTH;
|
||||
n - n_leftover
|
||||
}
|
||||
|
||||
fn pack_slice_with_leftovers<P: PackedField>(slice: &[P::Scalar]) -> (&[P], &[P::Scalar]) {
|
||||
let split_point = pack_with_leftovers_split_point::<P>(slice);
|
||||
let (slice_packable, slice_leftovers) = slice.split_at(split_point);
|
||||
let slice_packed = P::pack_slice(slice_packable);
|
||||
(slice_packed, slice_leftovers)
|
||||
}
|
||||
|
||||
fn pack_slice_with_leftovers_mut<P: PackedField>(
|
||||
slice: &mut [P::Scalar],
|
||||
) -> (&mut [P], &mut [P::Scalar]) {
|
||||
let split_point = pack_with_leftovers_split_point::<P>(slice);
|
||||
let (slice_packable, slice_leftovers) = slice.split_at_mut(split_point);
|
||||
let slice_packed = P::pack_slice_mut(slice_packable);
|
||||
(slice_packed, slice_leftovers)
|
||||
}
|
||||
|
||||
/// Elementwise inplace multiplication of two slices of field elements.
|
||||
/// Implementation be faster than the trivial for loop.
|
||||
pub fn batch_multiply_inplace<F: Field>(out: &mut [F], a: &[F]) {
|
||||
let n = out.len();
|
||||
assert_eq!(n, a.len(), "both arrays must have the same length");
|
||||
|
||||
// Split out slice of vectors, leaving leftovers as scalars
|
||||
let (out_packed, out_leftovers) =
|
||||
pack_slice_with_leftovers_mut::<<F as Packable>::Packing>(out);
|
||||
let (a_packed, a_leftovers) = pack_slice_with_leftovers::<<F as Packable>::Packing>(a);
|
||||
|
||||
// Multiply packed and the leftovers
|
||||
for (x_out, x_a) in out_packed.iter_mut().zip(a_packed) {
|
||||
*x_out *= *x_a;
|
||||
}
|
||||
for (x_out, x_a) in out_leftovers.iter_mut().zip(a_leftovers) {
|
||||
*x_out *= *x_a;
|
||||
}
|
||||
}
|
||||
|
||||
/// Elementwise inplace addition of two slices of field elements.
|
||||
/// Implementation be faster than the trivial for loop.
|
||||
pub fn batch_add_inplace<F: Field>(out: &mut [F], a: &[F]) {
|
||||
let n = out.len();
|
||||
assert_eq!(n, a.len(), "both arrays must have the same length");
|
||||
|
||||
// Split out slice of vectors, leaving leftovers as scalars
|
||||
let (out_packed, out_leftovers) =
|
||||
pack_slice_with_leftovers_mut::<<F as Packable>::Packing>(out);
|
||||
let (a_packed, a_leftovers) = pack_slice_with_leftovers::<<F as Packable>::Packing>(a);
|
||||
|
||||
// Add packed and the leftovers
|
||||
for (x_out, x_a) in out_packed.iter_mut().zip(a_packed) {
|
||||
*x_out += *x_a;
|
||||
}
|
||||
for (x_out, x_a) in out_leftovers.iter_mut().zip(a_leftovers) {
|
||||
*x_out += *x_a;
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
pub(crate) mod batch_util;
|
||||
pub(crate) mod cosets;
|
||||
pub mod extension_field;
|
||||
pub mod fft;
|
||||
|
||||
@ -2,6 +2,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -67,11 +68,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
@ -79,10 +83,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
|
||||
let output = vars.local_wires[Self::wire_ith_output(i)];
|
||||
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
|
||||
|
||||
constraints.push(output - computed_output);
|
||||
yield_constr.one(output - computed_output);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -70,11 +71,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
|
||||
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
|
||||
@ -83,10 +87,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
|
||||
let computed_output =
|
||||
(multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
|
||||
|
||||
constraints.extend((output - computed_output).to_basefield_array());
|
||||
yield_constr.many((output - computed_output).to_basefield_array());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -118,8 +119,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
@ -133,7 +137,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let combined_output = output_high * base + output_low;
|
||||
|
||||
constraints.push(combined_output - computed_output);
|
||||
yield_constr.one(combined_output - computed_output);
|
||||
|
||||
let mut combined_low_limbs = F::ZERO;
|
||||
let mut combined_high_limbs = F::ZERO;
|
||||
@ -145,7 +149,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
yield_constr.one(product);
|
||||
|
||||
if j < midpoint {
|
||||
combined_low_limbs = base * combined_low_limbs + this_limb;
|
||||
@ -153,11 +157,9 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
combined_high_limbs = base * combined_high_limbs + this_limb;
|
||||
}
|
||||
}
|
||||
constraints.push(combined_low_limbs - output_low);
|
||||
constraints.push(combined_high_limbs - output_high);
|
||||
yield_constr.one(combined_low_limbs - output_low);
|
||||
yield_constr.one(combined_high_limbs - output_high);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -148,9 +149,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
@ -171,8 +174,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
|
||||
F::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
|
||||
constraints.push(first_chunks_combined - first_input);
|
||||
constraints.push(second_chunks_combined - second_input);
|
||||
yield_constr.one(first_chunks_combined - first_input);
|
||||
yield_constr.one(second_chunks_combined - second_input);
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
@ -186,34 +189,32 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
|
||||
let second_product = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
yield_constr.one(first_product);
|
||||
yield_constr.one(second_product);
|
||||
|
||||
let difference = second_chunks[i] - first_chunks[i];
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
|
||||
constraints.push(chunks_equal * difference);
|
||||
yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal));
|
||||
yield_constr.one(chunks_equal * difference);
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
most_significant_diff_so_far =
|
||||
intermediate_value + (F::ONE - chunks_equal) * difference;
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
constraints
|
||||
yield_constr.one(product);
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -57,22 +58,23 @@ impl<F: Extendable<D>, const D: usize, const B: usize> Gate<F, D> for BaseSumGat
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let sum = vars.local_wires[Self::WIRE_SUM];
|
||||
let limbs = &vars.local_wires[self.limbs()];
|
||||
let limbs = vars.local_wires.view(self.limbs());
|
||||
let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B));
|
||||
|
||||
let mut constraints = Vec::with_capacity(limbs.len() + 1);
|
||||
constraints.push(computed_sum - sum);
|
||||
yield_constr.one(computed_sum - sum);
|
||||
|
||||
let constraints_iter = limbs.iter().map(|&limb| {
|
||||
(0..B)
|
||||
.map(|i| unsafe { limb.sub_canonical_u64(i as u64) })
|
||||
.product::<F>()
|
||||
});
|
||||
constraints.extend(constraints_iter);
|
||||
|
||||
constraints
|
||||
yield_constr.many(constraints_iter);
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -166,9 +167,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
@ -189,8 +192,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
F::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
|
||||
constraints.push(first_chunks_combined - first_input);
|
||||
constraints.push(second_chunks_combined - second_input);
|
||||
yield_constr.one(first_chunks_combined - first_input);
|
||||
yield_constr.one(second_chunks_combined - second_input);
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
@ -204,26 +207,26 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
let second_product: F = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
yield_constr.one(first_product);
|
||||
yield_constr.one(second_product);
|
||||
|
||||
let difference = second_chunks[i] - first_chunks[i];
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
|
||||
constraints.push(chunks_equal * difference);
|
||||
yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal));
|
||||
yield_constr.one(chunks_equal * difference);
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
most_significant_diff_so_far =
|
||||
intermediate_value + (F::ONE - chunks_equal) * difference;
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
let most_significant_diff_bits: Vec<F> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
@ -231,18 +234,16 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
|
||||
// Range-check the bits.
|
||||
for &bit in &most_significant_diff_bits {
|
||||
constraints.push(bit * (F::ONE - bit));
|
||||
yield_constr.one(bit * (F::ONE - bit));
|
||||
}
|
||||
|
||||
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
|
||||
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
|
||||
constraints.push((two_n + most_significant_diff) - bits_combined);
|
||||
yield_constr.one((two_n + most_significant_diff) - bits_combined);
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
|
||||
|
||||
constraints
|
||||
yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]);
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -39,11 +40,16 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ConstantGate {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
self.consts_inputs()
|
||||
.zip(self.wires_outputs())
|
||||
.map(|(con, out)| vars.local_constants[con] - vars.local_wires[out])
|
||||
.collect()
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
yield_constr.many(
|
||||
self.consts_inputs()
|
||||
.zip(self.wires_outputs())
|
||||
.map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]),
|
||||
);
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -99,7 +100,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let base = vars.local_wires[self.wire_base()];
|
||||
|
||||
let power_bits: Vec<_> = (0..self.num_power_bits)
|
||||
@ -111,8 +116,6 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
|
||||
|
||||
let output = vars.local_wires[self.wire_output()];
|
||||
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
for i in 0..self.num_power_bits {
|
||||
let prev_intermediate_value = if i == 0 {
|
||||
F::ONE
|
||||
@ -126,12 +129,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
|
||||
let not_cur_bit = F::ONE - cur_bit;
|
||||
let computed_intermediate_value =
|
||||
prev_intermediate_value * (cur_bit * base + not_cur_bit);
|
||||
constraints.push(computed_intermediate_value - intermediate_values[i]);
|
||||
yield_constr.one(computed_intermediate_value - intermediate_values[i]);
|
||||
}
|
||||
|
||||
constraints.push(output - intermediate_values[self.num_power_bits - 1]);
|
||||
|
||||
constraints
|
||||
yield_constr.one(output - intermediate_values[self.num_power_bits - 1]);
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -2,13 +2,17 @@ use std::fmt::{Debug, Error, Formatter};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::field::batch_util::batch_multiply_inplace;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate_tree::Tree;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::WitnessGenerator;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::plonk::vars::{
|
||||
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
|
||||
};
|
||||
|
||||
/// A custom gate.
|
||||
pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
|
||||
@ -18,9 +22,19 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
|
||||
|
||||
/// Like `eval_unfiltered`, but specialized for points in the base field.
|
||||
///
|
||||
///
|
||||
/// `eval_unfiltered_base_batch` calls this method by default. If `eval_unfiltered_base_batch`
|
||||
/// is overridden, then `eval_unfiltered_base_one` is not necessary.
|
||||
///
|
||||
/// By default, this just calls `eval_unfiltered`, which treats the point as an extension field
|
||||
/// element. This isn't very efficient.
|
||||
fn eval_unfiltered_base(&self, vars_base: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars_base: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
// Note that this method uses `yield_constr` instead of returning its constraints.
|
||||
// `yield_constr` abstracts out the underlying memory layout.
|
||||
let local_constants = &vars_base
|
||||
.local_constants
|
||||
.iter()
|
||||
@ -40,13 +54,21 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
|
||||
let values = self.eval_unfiltered(vars);
|
||||
|
||||
// Each value should be in the base field, i.e. only the degree-zero part should be nonzero.
|
||||
values
|
||||
.into_iter()
|
||||
.map(|value| {
|
||||
debug_assert!(F::Extension::is_in_basefield(&value));
|
||||
value.to_basefield_array()[0]
|
||||
})
|
||||
.collect()
|
||||
values.into_iter().for_each(|value| {
|
||||
debug_assert!(F::Extension::is_in_basefield(&value));
|
||||
yield_constr.one(value.to_basefield_array()[0])
|
||||
})
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
|
||||
let mut res = vec![F::ZERO; vars_base.len() * self.num_constraints()];
|
||||
for (i, vars_base_one) in vars_base.iter().enumerate() {
|
||||
self.eval_unfiltered_base_one(
|
||||
vars_base_one,
|
||||
StridedConstraintConsumer::new(&mut res, vars_base.len(), i),
|
||||
);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
@ -64,26 +86,23 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Like `eval_filtered`, but specialized for points in the base field.
|
||||
fn eval_filtered_base(&self, mut vars: EvaluationVarsBase<F>, prefix: &[bool]) -> Vec<F> {
|
||||
let filter = compute_filter(prefix, vars.local_constants);
|
||||
vars.remove_prefix(prefix);
|
||||
let mut res = self.eval_unfiltered_base(vars);
|
||||
res.iter_mut().for_each(|c| {
|
||||
*c *= filter;
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
/// The result is an array of length `vars_batch.len() * self.num_constraints()`. Constraint `j`
|
||||
/// for point `i` is at index `j * batch_size + i`.
|
||||
fn eval_filtered_base_batch(
|
||||
&self,
|
||||
vars_batch: &[EvaluationVarsBase<F>],
|
||||
mut vars_batch: EvaluationVarsBaseBatch<F>,
|
||||
prefix: &[bool],
|
||||
) -> Vec<Vec<F>> {
|
||||
vars_batch
|
||||
) -> Vec<F> {
|
||||
let filters: Vec<_> = vars_batch
|
||||
.iter()
|
||||
.map(|&vars| self.eval_filtered_base(vars, prefix))
|
||||
.collect()
|
||||
.map(|vars| compute_filter(prefix, vars.local_constants))
|
||||
.collect();
|
||||
vars_batch.remove_prefix(prefix);
|
||||
let mut res_batch = self.eval_unfiltered_base_batch(vars_batch);
|
||||
for res_chunk in res_batch.chunks_exact_mut(filters.len()) {
|
||||
batch_multiply_inplace(res_chunk, &filters);
|
||||
}
|
||||
res_batch
|
||||
}
|
||||
|
||||
/// Adds this gate's filtered constraints into the `combined_gate_constraints` buffer.
|
||||
@ -174,17 +193,11 @@ impl<F: RichField + Extendable<D>, const D: usize> PrefixedGate<F, D> {
|
||||
|
||||
/// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and
|
||||
/// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`.
|
||||
fn compute_filter<K: Field>(prefix: &[bool], constants: &[K]) -> K {
|
||||
fn compute_filter<'a, K: Field, T: IntoIterator<Item = &'a K>>(prefix: &[bool], constants: T) -> K {
|
||||
prefix
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &b)| {
|
||||
if b {
|
||||
constants[i]
|
||||
} else {
|
||||
K::ONE - constants[i]
|
||||
}
|
||||
})
|
||||
.zip(constants)
|
||||
.map(|(&b, &c)| if b { c } else { K::ONE - c })
|
||||
.product()
|
||||
}
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ use crate::iop::witness::{PartialWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::config::GenericConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch};
|
||||
use crate::plonk::verifier::verify;
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::util::{log2_ceil, transpose};
|
||||
@ -106,19 +106,18 @@ pub(crate) fn test_eval_fns<
|
||||
.collect::<Vec<_>>();
|
||||
let public_inputs_hash = HashOut::rand();
|
||||
|
||||
let vars_base = EvaluationVarsBase {
|
||||
local_constants: &constants_base,
|
||||
local_wires: &wires_base,
|
||||
public_inputs_hash: &public_inputs_hash,
|
||||
};
|
||||
// Batch of 1.
|
||||
let vars_base_batch =
|
||||
EvaluationVarsBaseBatch::new(1, &constants_base, &wires_base, &public_inputs_hash);
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &constants,
|
||||
local_wires: &wires,
|
||||
public_inputs_hash: &public_inputs_hash,
|
||||
};
|
||||
|
||||
let evals_base = gate.eval_unfiltered_base(vars_base);
|
||||
let evals_base = gate.eval_unfiltered_base_batch(vars_base_batch);
|
||||
let evals = gate.eval_unfiltered(vars);
|
||||
// This works because we have a batch of 1.
|
||||
ensure!(
|
||||
evals
|
||||
== evals_base
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::hash::gmimc;
|
||||
use crate::hash::gmimc::GMiMC;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
@ -107,12 +108,14 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
// Assert that `swap` is binary.
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(swap * swap.sub_one());
|
||||
yield_constr.one(swap * swap.sub_one());
|
||||
|
||||
let mut state = Vec::with_capacity(12);
|
||||
for i in 0..4 {
|
||||
@ -138,7 +141,7 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
|
||||
let constant = F::from_canonical_u64(<F as GMiMC<WIDTH>>::ROUND_CONSTANTS[r]);
|
||||
let cubing_input = state[active] + addition_buffer + constant;
|
||||
let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
|
||||
constraints.push(cubing_input - cubing_input_wire);
|
||||
yield_constr.one(cubing_input - cubing_input_wire);
|
||||
let f = cubing_input_wire.cube();
|
||||
addition_buffer += f;
|
||||
state[active] -= f;
|
||||
@ -146,10 +149,8 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
|
||||
|
||||
for i in 0..WIDTH {
|
||||
state[i] += addition_buffer;
|
||||
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -5,6 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -112,7 +113,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let insertion_index = vars.local_wires[self.wires_insertion_index()];
|
||||
let list_items = (0..self.vec_size)
|
||||
.map(|i| vars.get_local_ext(self.wires_original_list_item(i)))
|
||||
@ -122,7 +127,6 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
|
||||
.collect::<Vec<_>>();
|
||||
let element_to_insert = vars.get_local_ext(self.wires_element_to_insert());
|
||||
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
let mut already_inserted = F::ZERO;
|
||||
for r in 0..=self.vec_size {
|
||||
let cur_index = F::from_canonical_usize(r);
|
||||
@ -131,8 +135,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
|
||||
let insert_here = vars.local_wires[self.wire_insert_here_for_round_r(r)];
|
||||
|
||||
// The two equality constraints.
|
||||
constraints.push(difference * equality_dummy - (F::ONE - insert_here));
|
||||
constraints.push(insert_here * difference);
|
||||
yield_constr.one(difference * equality_dummy - (F::ONE - insert_here));
|
||||
yield_constr.one(insert_here * difference);
|
||||
|
||||
let mut new_item = element_to_insert.scalar_mul(insert_here);
|
||||
if r > 0 {
|
||||
@ -144,10 +148,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InsertionGate<F, D> {
|
||||
}
|
||||
|
||||
// Output constraint.
|
||||
constraints.extend((new_item - output_list_items[r]).to_basefield_array());
|
||||
yield_constr.many((new_item - output_list_items[r]).to_basefield_array());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -8,6 +8,7 @@ use crate::field::interpolation::interpolant;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -106,9 +107,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for HighDegreeInterpolationGat
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let coeffs = (0..self.num_points())
|
||||
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
|
||||
.collect();
|
||||
@ -118,15 +121,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for HighDegreeInterpolationGat
|
||||
for (i, point) in coset.into_iter().enumerate() {
|
||||
let value = vars.get_local_ext(self.wires_value(i));
|
||||
let computed_value = interpolant.eval_base(point);
|
||||
constraints.extend(&(value - computed_value).to_basefield_array());
|
||||
yield_constr.many((value - computed_value).to_basefield_array());
|
||||
}
|
||||
|
||||
let evaluation_point = vars.get_local_ext(self.wires_evaluation_point());
|
||||
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
|
||||
let computed_evaluation_value = interpolant.eval(evaluation_point);
|
||||
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
|
||||
|
||||
constraints
|
||||
yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array());
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -9,6 +9,7 @@ use crate::field::interpolation::interpolant;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -130,9 +131,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let coeffs = (0..self.num_points())
|
||||
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
@ -142,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
|
||||
.collect::<Vec<_>>();
|
||||
let shift = powers_shift[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
|
||||
yield_constr.one(powers_shift[i - 1] * shift - powers_shift[i]);
|
||||
}
|
||||
powers_shift.insert(0, F::ONE);
|
||||
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
|
||||
@ -161,7 +164,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
|
||||
{
|
||||
let value = vars.get_local_ext(self.wires_value(i));
|
||||
let computed_value = altered_interpolant.eval_base(point);
|
||||
constraints.extend(&(value - computed_value).to_basefield_array());
|
||||
yield_constr.many((value - computed_value).to_basefield_array());
|
||||
}
|
||||
|
||||
let evaluation_point_powers = (1..self.num_points())
|
||||
@ -169,16 +172,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
|
||||
.collect::<Vec<_>>();
|
||||
let evaluation_point = evaluation_point_powers[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.extend(
|
||||
yield_constr.many(
|
||||
(evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
|
||||
.to_basefield_array(),
|
||||
);
|
||||
}
|
||||
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
|
||||
let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
|
||||
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
|
||||
|
||||
constraints
|
||||
yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array());
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -25,6 +25,7 @@ pub mod reducing;
|
||||
pub mod reducing_extension;
|
||||
pub mod subtraction_u32;
|
||||
pub mod switch;
|
||||
mod util;
|
||||
|
||||
#[cfg(test)]
|
||||
mod gate_testing;
|
||||
|
||||
@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -65,20 +66,21 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGa
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let const_0 = vars.local_constants[0];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
|
||||
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
|
||||
let output = vars.get_local_ext(Self::wires_ith_output(i));
|
||||
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0);
|
||||
|
||||
constraints.extend((output - computed_output).to_basefield_array());
|
||||
yield_constr.many((output - computed_output).to_basefield_array());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -3,7 +3,7 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::WitnessGenerator;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch};
|
||||
|
||||
/// A gate which does nothing.
|
||||
pub struct NoopGate;
|
||||
@ -17,7 +17,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for NoopGate {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, _vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_batch(&self, _vars: EvaluationVarsBaseBatch<F>) -> Vec<F> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::poseidon_mds::PoseidonMdsGate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::hash::hashing::SPONGE_WIDTH;
|
||||
use crate::hash::poseidon;
|
||||
use crate::hash::poseidon::Poseidon;
|
||||
@ -181,19 +182,21 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
// Assert that `swap` is binary.
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(swap * swap.sub_one());
|
||||
yield_constr.one(swap * swap.sub_one());
|
||||
|
||||
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
|
||||
for i in 0..4 {
|
||||
let input_lhs = vars.local_wires[Self::wire_input(i)];
|
||||
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
constraints.push(swap * (input_rhs - input_lhs) - delta_i);
|
||||
yield_constr.one(swap * (input_rhs - input_lhs) - delta_i);
|
||||
}
|
||||
|
||||
// Compute the possibly-swapped input layer.
|
||||
@ -217,7 +220,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
if r != 0 {
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
yield_constr.one(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
}
|
||||
@ -231,13 +234,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
state = <F as Poseidon>::mds_partial_layer_init(&state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
yield_constr.one(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
|
||||
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast(&state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
yield_constr.one(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
@ -247,7 +250,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
<F as Poseidon>::constant_layer(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
yield_constr.one(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon>::sbox_layer(&mut state);
|
||||
@ -256,10 +259,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
}
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -7,6 +7,7 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::hash::hashing::SPONGE_WIDTH;
|
||||
use crate::hash::poseidon::Poseidon;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
@ -123,7 +124,11 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> Gate<F, D> for Pos
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
@ -132,11 +137,12 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> Gate<F, D> for Pos
|
||||
|
||||
let computed_outputs = F::mds_layer_field(&inputs);
|
||||
|
||||
(0..SPONGE_WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
|
||||
.collect()
|
||||
yield_constr.many(
|
||||
(0..SPONGE_WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array()),
|
||||
)
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -3,6 +3,7 @@ use std::ops::Range;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::WitnessGenerator;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
@ -28,11 +29,16 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PublicInputGate {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
Self::wires_public_inputs_hash()
|
||||
.zip(vars.public_inputs_hash.elements)
|
||||
.map(|(wire, hash_part)| vars.local_wires[wire] - hash_part)
|
||||
.collect()
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
yield_constr.many(
|
||||
Self::wires_public_inputs_hash()
|
||||
.zip(vars.public_inputs_hash.elements)
|
||||
.map(|(wire, hash_part)| vars.local_wires[wire] - hash_part),
|
||||
)
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -125,9 +126,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
for copy in 0..self.num_copies {
|
||||
let access_index = vars.local_wires[self.wire_access_index(copy)];
|
||||
let mut list_items = (0..self.vec_size())
|
||||
@ -140,12 +143,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
|
||||
// Assert that each bit wire value is indeed boolean.
|
||||
for &b in &bits {
|
||||
constraints.push(b * (b - F::ONE));
|
||||
yield_constr.one(b * (b - F::ONE));
|
||||
}
|
||||
|
||||
// Assert that the binary decomposition was correct.
|
||||
let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b);
|
||||
constraints.push(reconstructed_index - access_index);
|
||||
yield_constr.one(reconstructed_index - access_index);
|
||||
|
||||
// Repeatedly fold the list, selecting the left or right item from each pair based on
|
||||
// the corresponding bit.
|
||||
@ -158,10 +161,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
}
|
||||
|
||||
debug_assert_eq!(list_items.len(), 1);
|
||||
constraints.push(list_items[0] - claimed_element);
|
||||
yield_constr.one(list_items[0] - claimed_element);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -80,7 +81,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingGate<D
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let alpha = vars.get_local_ext(Self::wires_alpha());
|
||||
let old_acc = vars.get_local_ext(Self::wires_old_acc());
|
||||
let coeffs = self
|
||||
@ -91,14 +96,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingGate<D
|
||||
.map(|i| vars.get_local_ext(self.wires_accs(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
|
||||
let mut acc = old_acc;
|
||||
for i in 0..self.num_coeffs {
|
||||
constraints.extend((acc * alpha + coeffs[i].into() - accs[i]).to_basefield_array());
|
||||
yield_constr.many((acc * alpha + coeffs[i].into() - accs[i]).to_basefield_array());
|
||||
acc = accs[i];
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -5,6 +5,7 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -82,7 +83,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtens
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let alpha = vars.get_local_ext(Self::wires_alpha());
|
||||
let old_acc = vars.get_local_ext(Self::wires_old_acc());
|
||||
let coeffs = (0..self.num_coeffs)
|
||||
@ -92,14 +97,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtens
|
||||
.map(|i| vars.get_local_ext(self.wires_accs(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
|
||||
let mut acc = old_acc;
|
||||
for i in 0..self.num_coeffs {
|
||||
constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array());
|
||||
yield_constr.many((acc * alpha + coeffs[i] - accs[i]).to_basefield_array());
|
||||
acc = accs[i];
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -113,8 +114,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
@ -126,7 +130,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
constraints.push(output_result - (result_initial + base * output_borrow));
|
||||
yield_constr.one(output_result - (result_initial + base * output_borrow));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = F::ZERO;
|
||||
@ -137,17 +141,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
yield_constr.one(product);
|
||||
|
||||
combined_limbs = limb_base * combined_limbs + this_limb;
|
||||
}
|
||||
constraints.push(combined_limbs - output_result);
|
||||
yield_constr.one(combined_limbs - output_result);
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
constraints.push(output_borrow * (F::ONE - output_borrow));
|
||||
yield_constr.one(output_borrow * (F::ONE - output_borrow));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
@ -6,6 +6,7 @@ use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::util::StridedConstraintConsumer;
|
||||
use crate::iop::generator::{GeneratedValues, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -94,9 +95,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
for c in 0..self.num_copies {
|
||||
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
|
||||
let not_switch = F::ONE - switch_bool;
|
||||
@ -107,14 +110,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
|
||||
let first_output = vars.local_wires[self.wire_first_output(c, e)];
|
||||
let second_output = vars.local_wires[self.wire_second_output(c, e)];
|
||||
|
||||
constraints.push(switch_bool * (first_input - second_output));
|
||||
constraints.push(switch_bool * (second_input - first_output));
|
||||
constraints.push(not_switch * (first_input - first_output));
|
||||
constraints.push(not_switch * (second_input - second_output));
|
||||
yield_constr.one(switch_bool * (first_input - second_output));
|
||||
yield_constr.one(switch_bool * (second_input - first_output));
|
||||
yield_constr.one(not_switch * (first_input - first_output));
|
||||
yield_constr.one(not_switch * (second_input - second_output));
|
||||
}
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
|
||||
62
src/gates/util.rs
Normal file
62
src/gates/util.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
/// Writes constraints yielded by a gate to a buffer, with a given stride.
|
||||
/// Permits us to abstract the underlying memory layout. In particular, we can make a matrix of
|
||||
/// constraints where every column is an evaluation point and every row is a constraint index, with
|
||||
/// the matrix stored in row-contiguous form.
|
||||
pub struct StridedConstraintConsumer<'a, F: Field> {
|
||||
// This is a particularly neat way of doing this, more so than a slice. We increase start by
|
||||
// stride at every step and terminate when it equals end.
|
||||
start: *mut F,
|
||||
end: *mut F,
|
||||
stride: usize,
|
||||
_phantom: PhantomData<&'a mut [F]>,
|
||||
}
|
||||
|
||||
impl<'a, F: Field> StridedConstraintConsumer<'a, F> {
|
||||
pub fn new(buffer: &'a mut [F], stride: usize, offset: usize) -> Self {
|
||||
assert!(offset < stride);
|
||||
assert_eq!(buffer.len() % stride, 0);
|
||||
let ptr_range = buffer.as_mut_ptr_range();
|
||||
// `wrapping_add` is needed to avoid undefined behavior. Plain `add` causes UB if 'the ...
|
||||
// resulting pointer [is neither] in bounds or one byte past the end of the same allocated
|
||||
// object'; the UB results even if the pointer is not dereferenced. `end` will be more than
|
||||
// one byte past the buffer unless `offset` is 0. The same applies to `start` if the buffer
|
||||
// has length 0 and the offset is not 0.
|
||||
// We _could_ do pointer arithmetic without `wrapping_add`, but the logic would be
|
||||
// unnecessarily complicated.
|
||||
let start = ptr_range.start.wrapping_add(offset);
|
||||
let end = ptr_range.end.wrapping_add(offset);
|
||||
Self {
|
||||
start,
|
||||
end,
|
||||
stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit one constraint.
|
||||
pub fn one(&mut self, constraint: F) {
|
||||
if self.start != self.end {
|
||||
// # Safety
|
||||
// The checks in `new` guarantee that this points to valid space.
|
||||
unsafe {
|
||||
*self.start = constraint;
|
||||
}
|
||||
// See the comment in `new`. `wrapping_add` is needed to avoid UB if we've just
|
||||
// exhausted our buffer (and hence we're setting `self.start` to point past the end).
|
||||
self.start = self.start.wrapping_add(self.stride);
|
||||
} else {
|
||||
panic!("gate produced too many constraints");
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience method that calls `.one()` multiple times.
|
||||
pub fn many<I: IntoIterator<Item = F>>(&mut self, constraints: I) {
|
||||
constraints
|
||||
.into_iter()
|
||||
.for_each(|constraint| self.one(constraint));
|
||||
}
|
||||
}
|
||||
@ -157,9 +157,15 @@ pub(crate) fn reduce_with_powers_multi<
|
||||
cumul
|
||||
}
|
||||
|
||||
pub(crate) fn reduce_with_powers<F: Field>(terms: &[F], alpha: F) -> F {
|
||||
pub(crate) fn reduce_with_powers<'a, F: Field, T: IntoIterator<Item = &'a F>>(
|
||||
terms: T,
|
||||
alpha: F,
|
||||
) -> F
|
||||
where
|
||||
T::IntoIter: DoubleEndedIterator,
|
||||
{
|
||||
let mut sum = F::ZERO;
|
||||
for &term in terms.iter().rev() {
|
||||
for &term in terms.into_iter().rev() {
|
||||
sum = sum * alpha + term;
|
||||
}
|
||||
sum
|
||||
|
||||
@ -14,7 +14,7 @@ use crate::plonk::plonk_common::PlonkPolynomials;
|
||||
use crate::plonk::plonk_common::ZeroPolyOnCoset;
|
||||
use crate::plonk::proof::{Proof, ProofWithPublicInputs};
|
||||
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
|
||||
use crate::plonk::vars::EvaluationVarsBase;
|
||||
use crate::plonk::vars::EvaluationVarsBaseBatch;
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::timed;
|
||||
use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products};
|
||||
@ -333,12 +333,14 @@ fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, cons
|
||||
(BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect();
|
||||
|
||||
let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut vars_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut local_zs_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut next_zs_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut partial_products_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len());
|
||||
|
||||
let mut local_constants_batch_refs = Vec::with_capacity(xs_batch.len());
|
||||
let mut local_wires_batch_refs = Vec::with_capacity(xs_batch.len());
|
||||
|
||||
for (&i, &x) in indices_batch.iter().zip(xs_batch) {
|
||||
let shifted_x = F::coset_shift() * x;
|
||||
let i_next = (i + next_step) % lde_size;
|
||||
@ -357,24 +359,45 @@ fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, cons
|
||||
debug_assert_eq!(local_wires.len(), common_data.config.num_wires);
|
||||
debug_assert_eq!(local_zs.len(), num_challenges);
|
||||
|
||||
let vars = EvaluationVarsBase {
|
||||
local_constants,
|
||||
local_wires,
|
||||
public_inputs_hash,
|
||||
};
|
||||
local_constants_batch_refs.push(local_constants);
|
||||
local_wires_batch_refs.push(local_wires);
|
||||
|
||||
shifted_xs_batch.push(shifted_x);
|
||||
vars_batch.push(vars);
|
||||
local_zs_batch.push(local_zs);
|
||||
next_zs_batch.push(next_zs);
|
||||
partial_products_batch.push(partial_products);
|
||||
s_sigmas_batch.push(s_sigmas);
|
||||
}
|
||||
|
||||
// NB (JN): I'm not sure how (in)efficient the below is. It needs measuring.
|
||||
let mut local_constants_batch =
|
||||
vec![F::ZERO; xs_batch.len() * local_constants_batch_refs[0].len()];
|
||||
for (i, constants) in local_constants_batch_refs.iter().enumerate() {
|
||||
for (j, &constant) in constants.iter().enumerate() {
|
||||
local_constants_batch[i + j * xs_batch.len()] = constant;
|
||||
}
|
||||
}
|
||||
|
||||
let mut local_wires_batch =
|
||||
vec![F::ZERO; xs_batch.len() * local_wires_batch_refs[0].len()];
|
||||
for (i, wires) in local_wires_batch_refs.iter().enumerate() {
|
||||
for (j, &wire) in wires.iter().enumerate() {
|
||||
local_wires_batch[i + j * xs_batch.len()] = wire;
|
||||
}
|
||||
}
|
||||
|
||||
let vars_batch = EvaluationVarsBaseBatch::new(
|
||||
xs_batch.len(),
|
||||
&local_constants_batch,
|
||||
&local_wires_batch,
|
||||
public_inputs_hash,
|
||||
);
|
||||
|
||||
let mut quotient_values_batch = eval_vanishing_poly_base_batch(
|
||||
common_data,
|
||||
&indices_batch,
|
||||
&shifted_xs_batch,
|
||||
&vars_batch,
|
||||
vars_batch,
|
||||
&local_zs_batch,
|
||||
&next_zs_batch,
|
||||
&partial_products_batch,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use crate::field::batch_util::batch_add_inplace;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::Field;
|
||||
@ -8,9 +9,10 @@ use crate::plonk::circuit_data::CommonCircuitData;
|
||||
use crate::plonk::config::GenericConfig;
|
||||
use crate::plonk::plonk_common;
|
||||
use crate::plonk::plonk_common::{eval_l_1_recursively, ZeroPolyOnCoset};
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch};
|
||||
use crate::util::partial_products::{check_partial_products, check_partial_products_recursively};
|
||||
use crate::util::reducing::ReducingFactorTarget;
|
||||
use crate::util::strided_view::PackedStridedView;
|
||||
use crate::with_context;
|
||||
|
||||
/// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random
|
||||
@ -96,7 +98,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
|
||||
common_data: &CommonCircuitData<F, C, D>,
|
||||
indices_batch: &[usize],
|
||||
xs_batch: &[F],
|
||||
vars_batch: &[EvaluationVarsBase<F>],
|
||||
vars_batch: EvaluationVarsBaseBatch<F>,
|
||||
local_zs_batch: &[&[F]],
|
||||
next_zs_batch: &[&[F]],
|
||||
partial_products_batch: &[&[F]],
|
||||
@ -138,14 +140,13 @@ pub(crate) fn eval_vanishing_poly_base_batch<
|
||||
for k in 0..n {
|
||||
let index = indices_batch[k];
|
||||
let x = xs_batch[k];
|
||||
let vars = vars_batch[k];
|
||||
let vars = vars_batch.view(k);
|
||||
let local_zs = local_zs_batch[k];
|
||||
let next_zs = next_zs_batch[k];
|
||||
let partial_products = partial_products_batch[k];
|
||||
let s_sigmas = s_sigmas_batch[k];
|
||||
|
||||
let constraint_terms =
|
||||
&constraint_terms_batch[k * num_gate_constraints..(k + 1) * num_gate_constraints];
|
||||
let constraint_terms = PackedStridedView::new(&constraint_terms_batch, n, k);
|
||||
|
||||
let l1_x = z_h_on_coset.eval_l1(index, x);
|
||||
for i in 0..num_challenges {
|
||||
@ -221,13 +222,13 @@ pub fn evaluate_gate_constraints<F: Extendable<D>, const D: usize>(
|
||||
|
||||
/// Evaluate all gate constraints in the base field.
|
||||
///
|
||||
/// Returns a vector of num_gate_constraints * vars_batch.len() field elements. The constraints
|
||||
/// corresponding to vars_batch[i] are found in
|
||||
/// result[num_gate_constraints * i..num_gate_constraints * (i + 1)].
|
||||
/// Returns a vector of `num_gate_constraints * vars_batch.len()` field elements. The constraints
|
||||
/// corresponding to `vars_batch[i]` are found in `result[i], result[vars_batch.len() + i],
|
||||
/// result[2 * vars_batch.len() + i], ...`.
|
||||
pub fn evaluate_gate_constraints_base_batch<F: Extendable<D>, const D: usize>(
|
||||
gates: &[PrefixedGate<F, D>],
|
||||
num_gate_constraints: usize,
|
||||
vars_batch: &[EvaluationVarsBase<F>],
|
||||
vars_batch: EvaluationVarsBaseBatch<F>,
|
||||
) -> Vec<F> {
|
||||
let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()];
|
||||
for gate in gates {
|
||||
@ -235,20 +236,15 @@ pub fn evaluate_gate_constraints_base_batch<F: Extendable<D>, const D: usize>(
|
||||
.gate
|
||||
.0
|
||||
.eval_filtered_base_batch(vars_batch, &gate.prefix);
|
||||
for (constraints, gate_constraints) in constraints_batch
|
||||
.chunks_exact_mut(num_gate_constraints)
|
||||
.zip(gate_constraints_batch.iter())
|
||||
{
|
||||
debug_assert!(
|
||||
gate_constraints.len() <= constraints.len(),
|
||||
"num_constraints() gave too low of a number"
|
||||
);
|
||||
for (constraint, &gate_constraint) in
|
||||
constraints.iter_mut().zip(gate_constraints.iter())
|
||||
{
|
||||
*constraint += gate_constraint;
|
||||
}
|
||||
}
|
||||
debug_assert!(
|
||||
gate_constraints_batch.len() <= constraints_batch.len(),
|
||||
"num_constraints() gave too low of a number"
|
||||
);
|
||||
// below adds all constraints for all points
|
||||
batch_add_inplace(
|
||||
&mut constraints_batch[..gate_constraints_batch.len()],
|
||||
&gate_constraints_batch,
|
||||
);
|
||||
}
|
||||
constraints_batch
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTar
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::hash::hash_types::{HashOut, HashOutTarget};
|
||||
use crate::util::strided_view::PackedStridedView;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct EvaluationVars<'a, F: Extendable<D>, const D: usize> {
|
||||
@ -13,13 +14,25 @@ pub struct EvaluationVars<'a, F: Extendable<D>, const D: usize> {
|
||||
pub(crate) public_inputs_hash: &'a HashOut<F>,
|
||||
}
|
||||
|
||||
/// A batch of evaluation vars, in the base field.
|
||||
/// Wires and constants are stored in an evaluation point-major order (that is, wire 0 for all
|
||||
/// evaluation points, then wire 1 for all points, and so on).
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct EvaluationVarsBase<'a, F: Field> {
|
||||
pub struct EvaluationVarsBaseBatch<'a, F: Field> {
|
||||
batch_size: usize,
|
||||
pub(crate) local_constants: &'a [F],
|
||||
pub(crate) local_wires: &'a [F],
|
||||
pub(crate) public_inputs_hash: &'a HashOut<F>,
|
||||
}
|
||||
|
||||
/// A view into `EvaluationVarsBaseBatch` for a particular evaluation point. Does not copy the data.
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct EvaluationVarsBase<'a, F: Field> {
|
||||
pub(crate) local_constants: PackedStridedView<'a, F>,
|
||||
pub(crate) local_wires: PackedStridedView<'a, F>,
|
||||
pub(crate) public_inputs_hash: &'a HashOut<F>,
|
||||
}
|
||||
|
||||
impl<'a, F: Extendable<D>, const D: usize> EvaluationVars<'a, F, D> {
|
||||
pub fn get_local_ext_algebra(
|
||||
&self,
|
||||
@ -35,18 +48,81 @@ impl<'a, F: Extendable<D>, const D: usize> EvaluationVars<'a, F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> {
|
||||
pub fn new(
|
||||
batch_size: usize,
|
||||
local_constants: &'a [F],
|
||||
local_wires: &'a [F],
|
||||
public_inputs_hash: &'a HashOut<F>,
|
||||
) -> Self {
|
||||
assert_eq!(local_constants.len() % batch_size, 0);
|
||||
assert_eq!(local_wires.len() % batch_size, 0);
|
||||
Self {
|
||||
batch_size,
|
||||
local_constants,
|
||||
local_wires,
|
||||
public_inputs_hash,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_prefix(&mut self, prefix: &[bool]) {
|
||||
self.local_constants = &self.local_constants[prefix.len() * self.len()..];
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.batch_size
|
||||
}
|
||||
|
||||
pub fn view(&self, index: usize) -> EvaluationVarsBase<'a, F> {
|
||||
// We cannot implement `Index` as `EvaluationVarsBase` is a struct, not a reference.
|
||||
assert!(index < self.len());
|
||||
let local_constants = PackedStridedView::new(self.local_constants, self.len(), index);
|
||||
let local_wires = PackedStridedView::new(self.local_wires, self.len(), index);
|
||||
EvaluationVarsBase {
|
||||
local_constants,
|
||||
local_wires,
|
||||
public_inputs_hash: self.public_inputs_hash,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> EvaluationVarsBaseBatchIter<'a, F> {
|
||||
EvaluationVarsBaseBatchIter::new(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F: Field> EvaluationVarsBase<'a, F> {
|
||||
pub fn get_local_ext<const D: usize>(&self, wire_range: Range<usize>) -> F::Extension
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
debug_assert_eq!(wire_range.len(), D);
|
||||
let arr = self.local_wires[wire_range].try_into().unwrap();
|
||||
let arr = self.local_wires.view(wire_range).try_into().unwrap();
|
||||
F::Extension::from_basefield_array(arr)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_prefix(&mut self, prefix: &[bool]) {
|
||||
self.local_constants = &self.local_constants[prefix.len()..];
|
||||
/// Iterator of views (`EvaluationVarsBase`) into a `EvaluationVarsBaseBatch`.
|
||||
pub struct EvaluationVarsBaseBatchIter<'a, F: Field> {
|
||||
i: usize,
|
||||
vars_batch: EvaluationVarsBaseBatch<'a, F>,
|
||||
}
|
||||
|
||||
impl<'a, F: Field> EvaluationVarsBaseBatchIter<'a, F> {
|
||||
pub fn new(vars_batch: EvaluationVarsBaseBatch<'a, F>) -> Self {
|
||||
EvaluationVarsBaseBatchIter { i: 0, vars_batch }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F: Field> Iterator for EvaluationVarsBaseBatchIter<'a, F> {
|
||||
type Item = EvaluationVarsBase<'a, F>;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.i < self.vars_batch.len() {
|
||||
let res = self.vars_batch.view(self.i);
|
||||
self.i += 1;
|
||||
Some(res)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ pub(crate) mod marking;
|
||||
pub(crate) mod partial_products;
|
||||
pub mod reducing;
|
||||
pub mod serialization;
|
||||
pub(crate) mod strided_view;
|
||||
pub(crate) mod timing;
|
||||
|
||||
pub(crate) fn bits_u64(n: u64) -> usize {
|
||||
|
||||
317
src/util/strided_view.rs
Normal file
317
src/util/strided_view.rs
Normal file
@ -0,0 +1,317 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::mem::size_of;
|
||||
use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
|
||||
|
||||
use crate::field::packed_field::PackedField;
|
||||
|
||||
/// Imagine a slice, but with a stride (a la a NumPy array).
|
||||
///
|
||||
/// For example, if the stride is 3,
|
||||
/// `packed_strided_view[0]` is `data[0]`,
|
||||
/// `packed_strided_view[1]` is `data[3]`,
|
||||
/// `packed_strided_view[2]` is `data[6]`,
|
||||
/// and so on. An offset may be specified. With an offset of 1, we get
|
||||
/// `packed_strided_view[0]` is `data[1]`,
|
||||
/// `packed_strided_view[1]` is `data[4]`,
|
||||
/// `packed_strided_view[2]` is `data[7]`,
|
||||
/// and so on.
|
||||
///
|
||||
/// Additionally, this view is *packed*, which means that it may yield a packing of the underlying
|
||||
/// field slice. With a packing of width 4 and a stride of 5, the accesses are
|
||||
/// `packed_strided_view[0]` is `data[0..4]`, transmuted to the packing,
|
||||
/// `packed_strided_view[1]` is `data[5..9]`, transmuted to the packing,
|
||||
/// `packed_strided_view[2]` is `data[10..14]`, transmuted to the packing,
|
||||
/// and so on.
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct PackedStridedView<'a, P: PackedField> {
|
||||
// This type has to be a struct, which means that it is not itself a reference (in the sense
|
||||
// that a slice is a reference so we can return it from e.g. `Index::index`).
|
||||
|
||||
// Raw pointers rarely appear in good Rust code, but I think this is the most elegant way to
|
||||
// implement this. The alternative would be to replace `start_ptr` and `length` with one slice
|
||||
// (`&[P::Scalar]`). Unfortunately, with a slice, an empty view becomes an edge case that
|
||||
// necessitates separate handling. It _could_ be done but it would also be uglier.
|
||||
start_ptr: *const P::Scalar,
|
||||
/// This is the total length of elements accessible through the view. In other words, valid
|
||||
/// indices are in `0..length`.
|
||||
length: usize,
|
||||
/// This stride is in units of `P::Scalar` (NOT in bytes and NOT in `P`).
|
||||
stride: usize,
|
||||
_phantom: PhantomData<&'a [P::Scalar]>,
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> PackedStridedView<'a, P> {
|
||||
// `wrapping_add` is needed throughout to avoid undefined behavior. Plain `add` causes UB if
|
||||
// '[either] the starting [or] resulting pointer [is neither] in bounds or one byte past the
|
||||
// end of the same allocated object'; the UB results even if the pointer is not dereferenced.
|
||||
|
||||
#[inline]
|
||||
pub fn new(data: &'a [P::Scalar], stride: usize, offset: usize) -> Self {
|
||||
assert!(
|
||||
stride >= P::WIDTH,
|
||||
"stride (got {}) must be at least P::WIDTH ({})",
|
||||
stride,
|
||||
P::WIDTH
|
||||
);
|
||||
assert_eq!(
|
||||
data.len() % stride,
|
||||
0,
|
||||
"data.len() ({}) must be a multiple of stride (got {})",
|
||||
data.len(),
|
||||
stride
|
||||
);
|
||||
|
||||
// This requirement means that stride divides data into slices of `data.len() / stride`
|
||||
// elements. Every access must fit entirely within one of those slices.
|
||||
assert!(
|
||||
offset + P::WIDTH <= stride,
|
||||
"offset (got {}) + P::WIDTH ({}) cannot be greater than stride (got {})",
|
||||
offset,
|
||||
P::WIDTH,
|
||||
stride
|
||||
);
|
||||
|
||||
// See comment above. `start_ptr` will be more than one byte past the buffer if `data` has
|
||||
// length 0 and `offset` is not 0.
|
||||
let start_ptr = data.as_ptr().wrapping_add(offset);
|
||||
|
||||
Self {
|
||||
start_ptr,
|
||||
length: data.len() / stride,
|
||||
stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self, index: usize) -> Option<&'a P> {
|
||||
if index < self.length {
|
||||
// Cast scalar pointer to vector pointer.
|
||||
let res_ptr = unsafe { self.start_ptr.add(index * self.stride) }.cast();
|
||||
// This transmutation is safe by the spec in `PackedField`.
|
||||
Some(unsafe { &*res_ptr })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Take a range of `PackedStridedView` indices, as `PackedStridedView`.
|
||||
#[inline]
|
||||
pub fn view<I>(&self, index: I) -> Self
|
||||
where
|
||||
Self: Viewable<I, View = Self>,
|
||||
{
|
||||
// We cannot implement `Index` as `PackedStridedView` is a struct, not a reference.
|
||||
|
||||
// The `Viewable` trait is needed for overloading.
|
||||
// Re-export `Viewable::view` so users don't have to import `Viewable`.
|
||||
<Self as Viewable<I>>::view(self, index)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn iter(&self) -> PackedStridedViewIter<'a, P> {
|
||||
PackedStridedViewIter::new(
|
||||
self.start_ptr,
|
||||
// See comment at the top of the `impl`. Below will point more than one byte past the
|
||||
// end of the buffer (unless `offset` is 0) so `wrapping_add` is needed.
|
||||
self.start_ptr.wrapping_add(self.length * self.stride),
|
||||
self.stride,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.length
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Index<usize> for PackedStridedView<'a, P> {
|
||||
type Output = P;
|
||||
#[inline]
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
self.get(index)
|
||||
.expect("invalid memory access in PackedStridedView")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> IntoIterator for PackedStridedView<'a, P> {
|
||||
type Item = &'a P;
|
||||
type IntoIter = PackedStridedViewIter<'a, P>;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TryFromPackedStridedViewError;
|
||||
|
||||
impl<P: PackedField, const N: usize> TryInto<[P; N]> for PackedStridedView<'_, P> {
|
||||
type Error = TryFromPackedStridedViewError;
|
||||
fn try_into(self) -> Result<[P; N], Self::Error> {
|
||||
if N == self.len() {
|
||||
let mut res = [P::ZERO; N];
|
||||
for i in 0..N {
|
||||
res[i] = *self.get(i).unwrap();
|
||||
}
|
||||
Ok(res)
|
||||
} else {
|
||||
Err(TryFromPackedStridedViewError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not deriving `Copy`. An implicit copy of an iterator is likely a bug.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PackedStridedViewIter<'a, P: PackedField> {
|
||||
// Again, a pair of pointers is a neater solution than a slice. `start` and `end` are always
|
||||
// separated by a multiple of stride elements. To advance the iterator from the front, we
|
||||
// advance `start` by `stride` elements. To advance it from the end, we subtract `stride`
|
||||
// elements. Iteration is done when they meet.
|
||||
// A slice cannot recreate the same pattern. The end pointer may point past the underlying
|
||||
// buffer (this is okay as we do not dereference it in that case); it becomes valid as soon as
|
||||
// it is decreased by `stride`. On the other hand, a slice that ends on invalid memory is
|
||||
// instant undefined behavior.
|
||||
start: *const P::Scalar,
|
||||
end: *const P::Scalar,
|
||||
stride: usize,
|
||||
_phantom: PhantomData<&'a [P::Scalar]>,
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> PackedStridedViewIter<'a, P> {
|
||||
pub(self) fn new(start: *const P::Scalar, end: *const P::Scalar, stride: usize) -> Self {
|
||||
Self {
|
||||
start,
|
||||
end,
|
||||
stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Iterator for PackedStridedViewIter<'a, P> {
|
||||
type Item = &'a P;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
debug_assert_eq!(
|
||||
(self.end as usize).wrapping_sub(self.start as usize)
|
||||
% (self.stride * size_of::<P::Scalar>()),
|
||||
0,
|
||||
"start and end pointers should be separated by a multiple of stride"
|
||||
);
|
||||
|
||||
if self.start != self.end {
|
||||
let res = unsafe { &*self.start.cast() };
|
||||
// See comment in `PackedStridedView`. Below will point more than one byte past the end
|
||||
// of the buffer if the offset is not 0 and we've reached the end.
|
||||
self.start = self.start.wrapping_add(self.stride);
|
||||
Some(res)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> DoubleEndedIterator for PackedStridedViewIter<'a, P> {
|
||||
fn next_back(&mut self) -> Option<Self::Item> {
|
||||
debug_assert_eq!(
|
||||
(self.end as usize).wrapping_sub(self.start as usize)
|
||||
% (self.stride * size_of::<P::Scalar>()),
|
||||
0,
|
||||
"start and end pointers should be separated by a multiple of stride"
|
||||
);
|
||||
|
||||
if self.start != self.end {
|
||||
// See comment in `PackedStridedView`. `self.end` starts off pointing more than one byte
|
||||
// past the end of the buffer unless `offset` is 0.
|
||||
self.end = self.end.wrapping_sub(self.stride);
|
||||
Some(unsafe { &*self.end.cast() })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Viewable<F> {
|
||||
// We cannot implement `Index` as `PackedStridedView` is a struct, not a reference.
|
||||
type View;
|
||||
fn view(&self, index: F) -> Self::View;
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Viewable<Range<usize>> for PackedStridedView<'a, P> {
|
||||
type View = Self;
|
||||
fn view(&self, range: Range<usize>) -> Self::View {
|
||||
assert!(range.start <= self.len(), "Invalid access");
|
||||
assert!(range.end <= self.len(), "Invalid access");
|
||||
Self {
|
||||
// See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
|
||||
// past the end of the buffer if the offset is not 0 and the buffer has length 0.
|
||||
start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
|
||||
length: range.end - range.start,
|
||||
stride: self.stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Viewable<RangeFrom<usize>> for PackedStridedView<'a, P> {
|
||||
type View = Self;
|
||||
fn view(&self, range: RangeFrom<usize>) -> Self::View {
|
||||
assert!(range.start <= self.len(), "Invalid access");
|
||||
Self {
|
||||
// See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
|
||||
// past the end of the buffer if the offset is not 0 and the buffer has length 0.
|
||||
start_ptr: self.start_ptr.wrapping_add(self.stride * range.start),
|
||||
length: self.len() - range.start,
|
||||
stride: self.stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Viewable<RangeFull> for PackedStridedView<'a, P> {
|
||||
type View = Self;
|
||||
fn view(&self, _range: RangeFull) -> Self::View {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Viewable<RangeInclusive<usize>> for PackedStridedView<'a, P> {
|
||||
type View = Self;
|
||||
fn view(&self, range: RangeInclusive<usize>) -> Self::View {
|
||||
assert!(*range.start() <= self.len(), "Invalid access");
|
||||
assert!(*range.end() < self.len(), "Invalid access");
|
||||
Self {
|
||||
// See comment in `PackedStridedView`. `self.start_ptr` will point more than one byte
|
||||
// past the end of the buffer if the offset is not 0 and the buffer has length 0.
|
||||
start_ptr: self.start_ptr.wrapping_add(self.stride * range.start()),
|
||||
length: range.end() - range.start() + 1,
|
||||
stride: self.stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Viewable<RangeTo<usize>> for PackedStridedView<'a, P> {
|
||||
type View = Self;
|
||||
fn view(&self, range: RangeTo<usize>) -> Self::View {
|
||||
assert!(range.end <= self.len(), "Invalid access");
|
||||
Self {
|
||||
start_ptr: self.start_ptr,
|
||||
length: range.end,
|
||||
stride: self.stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, P: PackedField> Viewable<RangeToInclusive<usize>> for PackedStridedView<'a, P> {
|
||||
type View = Self;
|
||||
fn view(&self, range: RangeToInclusive<usize>) -> Self::View {
|
||||
assert!(range.end < self.len(), "Invalid access");
|
||||
Self {
|
||||
start_ptr: self.start_ptr,
|
||||
length: range.end + 1,
|
||||
stride: self.stride,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user