Packed evaluation for most gates (#395)

* Most gates support packed evaluation

* ComparisonGate

* Minor: outdated todo marker

* Revert superfluous change

* Post-merge fixes

* Daniel comments

* Minor: Markdown in comments
This commit is contained in:
Jakub Nabaglo 2021-12-20 15:08:07 -08:00 committed by GitHub
parent bbbb57caa6
commit d4a0a8661e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 768 additions and 407 deletions

View File

@ -1,14 +1,19 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
/// supports enough routed wires, it can support several such operations in one gate.
@ -70,21 +75,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
panic!("use eval_unfiltered_base_packed instead");
}
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
yield_constr.one(output - computed_output);
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -152,6 +150,27 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate
}
}
impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for ArithmeticGate {
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
&self,
vars: EvaluationVarsBasePacked<P>,
mut yield_constr: StridedConstraintConsumer<P>,
) {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
yield_constr.one(output - computed_output);
}
}
}
#[derive(Clone, Debug)]
struct ArithmeticBaseGenerator<F: RichField + Extendable<D>, const D: usize> {
gate_index: usize,

View File

@ -5,7 +5,9 @@ use itertools::unfold;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -13,7 +15,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand).
#[derive(Copy, Clone, Debug)]
@ -94,8 +99,8 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
constraints.push(combined_output - computed_output);
let mut combined_low_limbs = F::Extension::ZERO;
let mut combined_high_limbs = F::Extension::ZERO;
let mut combined_low_limbs = <F::Extension as Field>::ZERO;
let mut combined_high_limbs = <F::Extension as Field>::ZERO;
let midpoint = Self::num_limbs() / 2;
let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
@ -121,45 +126,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[self.wire_ith_addend(i)];
panic!("use eval_unfiltered_base_packed instead");
}
let computed_output = multiplicand_0 * multiplicand_1 + addend;
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
let base = F::from_canonical_u64(1 << 32u64);
let combined_output = output_high * base + output_low;
yield_constr.one(combined_output - computed_output);
let mut combined_low_limbs = F::ZERO;
let mut combined_high_limbs = F::ZERO;
let midpoint = Self::num_limbs() / 2;
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
yield_constr.one(product);
if j < midpoint {
combined_low_limbs = base * combined_low_limbs + this_limb;
} else {
combined_high_limbs = base * combined_high_limbs + this_limb;
}
}
yield_constr.one(combined_low_limbs - output_low);
yield_constr.one(combined_high_limbs - output_high);
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -257,6 +231,51 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for U32ArithmeticGate<F, D> {
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
&self,
vars: EvaluationVarsBasePacked<P>,
mut yield_constr: StridedConstraintConsumer<P>,
) {
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[self.wire_ith_addend(i)];
let computed_output = multiplicand_0 * multiplicand_1 + addend;
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
let base = F::from_canonical_u64(1 << 32u64);
let combined_output = output_high * base + output_low;
yield_constr.one(combined_output - computed_output);
let mut combined_low_limbs = P::ZERO;
let mut combined_high_limbs = P::ZERO;
let midpoint = Self::num_limbs() / 2;
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
yield_constr.one(product);
if j < midpoint {
combined_low_limbs = combined_low_limbs * base + this_limb;
} else {
combined_high_limbs = combined_high_limbs * base + this_limb;
}
}
yield_constr.one(combined_low_limbs - output_low);
yield_constr.one(combined_high_limbs - output_high);
}
}
}
#[derive(Clone, Debug)]
struct U32ArithmeticGenerator<F: Extendable<D>, const D: usize> {
gate: U32ArithmeticGate<F, D>,

View File

@ -3,7 +3,9 @@ use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -11,7 +13,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
use crate::util::{bits_u64, ceil_div_usize};
// TODO: replace/merge this gate with `ComparisonGate`.
@ -109,7 +114,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::Extension::ZERO;
let mut most_significant_diff_so_far = <F::Extension as Field>::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
@ -127,14 +132,15 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal));
constraints
.push(difference * equality_dummy - (<F::Extension as Field>::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
intermediate_value + (<F::Extension as Field>::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
@ -151,70 +157,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
panic!("use eval_unfiltered_base_packed instead");
}
// Get chunks and assert that they match
let first_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
yield_constr.one(first_chunks_combined - first_input);
yield_constr.one(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
yield_constr.one(first_product);
yield_constr.one(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal));
yield_constr.one(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::from_canonical_usize(x))
.product();
yield_constr.one(product);
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -329,6 +279,78 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThan
}
}
impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D>
for AssertLessThanGate<F, D>
{
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
&self,
vars: EvaluationVarsBasePacked<P>,
mut yield_constr: StridedConstraintConsumer<P>,
) {
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<_> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<_> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
yield_constr.one(first_chunks_combined - first_input);
yield_constr.one(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = P::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
yield_constr.one(first_product);
yield_constr.one(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
yield_constr.one(difference * equality_dummy - (P::ONE - chunks_equal));
yield_constr.one(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (P::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::from_canonical_usize(x))
.product();
yield_constr.one(product);
}
}
#[derive(Debug)]
struct AssertLessThanGenerator<F: RichField + Extendable<D>, const D: usize> {
gate_index: usize,

View File

@ -3,7 +3,9 @@ use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -11,7 +13,10 @@ use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate which can decompose a number into base B little-endian limbs.
#[derive(Copy, Clone, Debug)]
@ -60,21 +65,14 @@ impl<F: Extendable<D>, const D: usize, const B: usize> Gate<F, D> for BaseSumGat
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
let sum = vars.local_wires[Self::WIRE_SUM];
let limbs = vars.local_wires.view(self.limbs());
let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B));
panic!("use eval_unfiltered_base_packed instead");
}
yield_constr.one(computed_sum - sum);
let constraints_iter = limbs.iter().map(|&limb| {
(0..B)
.map(|i| unsafe { limb.sub_canonical_u64(i as u64) })
.product::<F>()
});
yield_constr.many(constraints_iter);
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -136,6 +134,29 @@ impl<F: Extendable<D>, const D: usize, const B: usize> Gate<F, D> for BaseSumGat
}
}
impl<F: Extendable<D>, const D: usize, const B: usize> PackedEvaluableBase<F, D>
for BaseSumGate<B>
{
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
&self,
vars: EvaluationVarsBasePacked<P>,
mut yield_constr: StridedConstraintConsumer<P>,
) {
let sum = vars.local_wires[Self::WIRE_SUM];
let limbs = vars.local_wires.view(self.limbs());
let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B));
yield_constr.one(computed_sum - sum);
let constraints_iter = limbs.iter().map(|&limb| {
(0..B)
.map(|i| limb - F::from_canonical_usize(i))
.product::<P>()
});
yield_constr.many(constraints_iter);
}
}
#[derive(Debug)]
pub struct BaseSplitGenerator<const B: usize> {
gate_index: usize,

View File

@ -3,7 +3,9 @@ use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField};
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -11,7 +13,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
use crate::util::{bits_u64, ceil_div_usize};
/// A gate for checking that one value is less than or equal to another.
@ -116,7 +121,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::Extension::ZERO;
let mut most_significant_diff_so_far = <F::Extension as Field>::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
@ -134,14 +139,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal));
constraints
.push(difference * equality_dummy - (<F::Extension as Field>::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
intermediate_value + (<F::Extension as Field>::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
@ -153,7 +159,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
// Range-check the bits.
for &bit in &most_significant_diff_bits {
constraints.push(bit * (F::Extension::ONE - bit));
constraints.push(bit * (<F::Extension as Field>::ONE - bit));
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO);
@ -169,81 +175,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
panic!("use eval_unfiltered_base_packed instead");
}
// Get chunks and assert that they match
let first_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
yield_constr.one(first_chunks_combined - first_input);
yield_constr.one(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product: F = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product: F = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
yield_constr.one(first_product);
yield_constr.one(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
yield_constr.one(difference * equality_dummy - (F::ONE - chunks_equal));
yield_constr.one(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
let most_significant_diff_bits: Vec<F> = (0..self.chunk_bits() + 1)
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
.collect();
// Range-check the bits.
for &bit in &most_significant_diff_bits {
yield_constr.one(bit * (F::ONE - bit));
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
yield_constr.one((two_n + most_significant_diff) - bits_combined);
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]);
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -373,6 +312,87 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for ComparisonGate<F, D> {
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
&self,
vars: EvaluationVarsBasePacked<P>,
mut yield_constr: StridedConstraintConsumer<P>,
) {
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<_> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<_> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
yield_constr.one(first_chunks_combined - first_input);
yield_constr.one(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = P::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product: P = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product: P = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
yield_constr.one(first_product);
yield_constr.one(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
yield_constr.one(difference * equality_dummy - (P::ONE - chunks_equal));
yield_constr.one(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (P::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
let most_significant_diff_bits: Vec<_> = (0..self.chunk_bits() + 1)
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
.collect();
// Range-check the bits.
for &bit in &most_significant_diff_bits {
yield_constr.one(bit * (P::ONE - bit));
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
yield_constr.one((most_significant_diff + two_n) - bits_combined);
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]);
}
}
#[derive(Debug)]
struct ComparisonGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,

View File

@ -3,14 +3,19 @@ use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::PartitionWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate which takes a single constant parameter and outputs that value.
#[derive(Copy, Clone, Debug)]
@ -42,14 +47,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ConstantGate {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
yield_constr.many(
self.consts_inputs()
.zip(self.wires_outputs())
.map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]),
);
panic!("use eval_unfiltered_base_packed instead");
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -95,6 +100,20 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ConstantGate {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for ConstantGate {
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
&self,
vars: EvaluationVarsBasePacked<P>,
mut yield_constr: StridedConstraintConsumer<P>,
) {
yield_constr.many(
self.consts_inputs()
.zip(self.wires_outputs())
.map(|(con, out)| vars.local_constants[con] - vars.local_wires[out]),
);
}
}
#[derive(Debug)]
struct ConstantGenerator<F: Field> {
gate_index: usize,

View File

@ -3,7 +3,9 @@ use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -11,7 +13,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate for raising a value to a power.
#[derive(Clone, Debug)]
@ -81,15 +86,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
for i in 0..self.num_power_bits {
let prev_intermediate_value = if i == 0 {
F::Extension::ONE
<F::Extension as Field>::ONE
} else {
intermediate_values[i - 1].square()
<F::Extension as Field>::square(&intermediate_values[i - 1])
};
// power_bits is in LE order, but we accumulate in BE order.
let cur_bit = power_bits[self.num_power_bits - i - 1];
let not_cur_bit = F::Extension::ONE - cur_bit;
let not_cur_bit = <F::Extension as Field>::ONE - cur_bit;
let computed_intermediate_value =
prev_intermediate_value * (cur_bit * base + not_cur_bit);
constraints.push(computed_intermediate_value - intermediate_values[i]);
@ -102,37 +107,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
let base = vars.local_wires[self.wire_base()];
panic!("use eval_unfiltered_base_packed instead");
}
let power_bits: Vec<_> = (0..self.num_power_bits)
.map(|i| vars.local_wires[self.wire_power_bit(i)])
.collect();
let intermediate_values: Vec<_> = (0..self.num_power_bits)
.map(|i| vars.local_wires[self.wire_intermediate_value(i)])
.collect();
let output = vars.local_wires[self.wire_output()];
for i in 0..self.num_power_bits {
let prev_intermediate_value = if i == 0 {
F::ONE
} else {
intermediate_values[i - 1].square()
};
// power_bits is in LE order, but we accumulate in BE order.
let cur_bit = power_bits[self.num_power_bits - i - 1];
let not_cur_bit = F::ONE - cur_bit;
let computed_intermediate_value =
prev_intermediate_value * (cur_bit * base + not_cur_bit);
yield_constr.one(computed_intermediate_value - intermediate_values[i]);
}
yield_constr.one(output - intermediate_values[self.num_power_bits - 1]);
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -205,6 +187,43 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for ExponentiationGate<F, D> {
    /// Packed base-field evaluation of the exponentiation constraints.
    ///
    /// Mirrors the extension-field `eval_unfiltered` logic: checks that each
    /// intermediate value follows the square-and-multiply recurrence and that
    /// the final intermediate equals the claimed output. Constraints must be
    /// emitted in the same order as `eval_unfiltered`, since
    /// `StridedConstraintConsumer` assigns constraint indices by emission order.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars: EvaluationVarsBasePacked<P>,
        mut yield_constr: StridedConstraintConsumer<P>,
    ) {
        let base = vars.local_wires[self.wire_base()];
        let power_bits: Vec<_> = (0..self.num_power_bits)
            .map(|i| vars.local_wires[self.wire_power_bit(i)])
            .collect();
        let intermediate_values: Vec<_> = (0..self.num_power_bits)
            .map(|i| vars.local_wires[self.wire_intermediate_value(i)])
            .collect();
        let output = vars.local_wires[self.wire_output()];
        for i in 0..self.num_power_bits {
            // Square-and-multiply: the accumulator starts at 1, then each round
            // squares the previous intermediate value.
            let prev_intermediate_value = if i == 0 {
                P::ONE
            } else {
                intermediate_values[i - 1].square()
            };
            // power_bits is in LE order, but we accumulate in BE order.
            let cur_bit = power_bits[self.num_power_bits - i - 1];
            let not_cur_bit = P::ONE - cur_bit;
            // Multiply by `base` when the bit is 1, by 1 when the bit is 0
            // (selected via `cur_bit * base + not_cur_bit`).
            let computed_intermediate_value =
                prev_intermediate_value * (cur_bit * base + not_cur_bit);
            yield_constr.one(computed_intermediate_value - intermediate_values[i]);
        }
        // The last intermediate value must equal the claimed output.
        yield_constr.one(output - intermediate_values[self.num_power_bits - 1]);
    }
}
#[derive(Debug)]
struct ExponentiationGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,

View File

@ -3,7 +3,9 @@ use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::gmimc;
use crate::hash::gmimc::GMiMC;
@ -12,7 +14,10 @@ use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// Evaluates a full GMiMC permutation with 12 state elements.
///
@ -68,7 +73,7 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::Extension::ONE));
constraints.push(swap * (swap - <F::Extension as Field>::ONE));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
@ -87,7 +92,7 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
// Value that is implicitly added to each element.
// See https://affine.group/2020/02/starkware-challenge
let mut addition_buffer = F::Extension::ZERO;
let mut addition_buffer = <F::Extension as Field>::ZERO;
for r in 0..gmimc::NUM_ROUNDS {
let active = r % WIDTH;
@ -110,47 +115,14 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
yield_constr.one(swap * swap.sub_one());
panic!("use eval_unfiltered_base_packed instead");
}
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
state.push(a + swap * (b - a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
state.push(a + swap * (b - a));
}
for i in 8..12 {
state.push(vars.local_wires[i]);
}
// Value that is implicitly added to each element.
// See https://affine.group/2020/02/starkware-challenge
let mut addition_buffer = F::ZERO;
for r in 0..gmimc::NUM_ROUNDS {
let active = r % WIDTH;
let constant = F::from_canonical_u64(<F as GMiMC<WIDTH>>::ROUND_CONSTANTS[r]);
let cubing_input = state[active] + addition_buffer + constant;
let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
yield_constr.one(cubing_input - cubing_input_wire);
let f = cubing_input_wire.cube();
addition_buffer += f;
state[active] -= f;
}
for i in 0..WIDTH {
state[i] += addition_buffer;
yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -236,6 +208,55 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Gate<F
}
}
impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> PackedEvaluableBase<F, D>
    for GMiMCGate<F, D, WIDTH>
{
    /// Packed base-field evaluation of the full GMiMC permutation constraints.
    ///
    /// Constraint emission order must match the extension-field
    /// `eval_unfiltered`: swap-is-binary, then one constraint per round's cubing
    /// input, then one per output wire.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars: EvaluationVarsBasePacked<P>,
        mut yield_constr: StridedConstraintConsumer<P>,
    ) {
        // Assert that `swap` is binary.
        let swap = vars.local_wires[Self::WIRE_SWAP];
        yield_constr.one(swap * (swap - F::ONE));
        // Build the initial state: the first two groups of four wires are
        // conditionally exchanged (a + swap*(b - a) selects `b` when swap == 1),
        // and the last four wires are taken as-is.
        let mut state = Vec::with_capacity(12);
        for i in 0..4 {
            let a = vars.local_wires[i];
            let b = vars.local_wires[i + 4];
            state.push(a + swap * (b - a));
        }
        for i in 0..4 {
            let a = vars.local_wires[i + 4];
            let b = vars.local_wires[i];
            state.push(a + swap * (b - a));
        }
        for i in 8..12 {
            state.push(vars.local_wires[i]);
        }
        // Value that is implicitly added to each element.
        // See https://affine.group/2020/02/starkware-challenge
        let mut addition_buffer = P::ZERO;
        for r in 0..gmimc::NUM_ROUNDS {
            let active = r % WIDTH;
            let constant = F::from_canonical_u64(<F as GMiMC<WIDTH>>::ROUND_CONSTANTS[r]);
            let cubing_input = state[active] + addition_buffer + constant;
            // Constrain the witnessed cubing-input wire to equal the computed
            // value, then use the wire (not the computed value) from here on —
            // presumably to bound the constraint degree; confirm against the
            // extension-field evaluation.
            let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
            yield_constr.one(cubing_input - cubing_input_wire);
            // Cube via square-then-multiply.
            let f = cubing_input_wire.square() * cubing_input_wire;
            addition_buffer += f;
            state[active] -= f;
        }
        // Flush the deferred additions and constrain the final state to the
        // output wires.
        for i in 0..WIDTH {
            state[i] += addition_buffer;
            yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]);
        }
    }
}
#[derive(Debug)]
struct GMiMCGenerator<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> {
gate_index: usize,

View File

@ -17,6 +17,7 @@ pub mod interpolation;
pub mod low_degree_interpolation;
pub mod multiplication_extension;
pub mod noop;
mod packed_util;
pub mod poseidon;
pub(crate) mod poseidon_mds;
pub(crate) mod public_input;

39
src/gates/packed_util.rs Normal file
View File

@ -0,0 +1,39 @@
use crate::field::extension_field::Extendable;
use crate::field::packable::Packable;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::util::StridedConstraintConsumer;
use crate::plonk::vars::{EvaluationVarsBaseBatch, EvaluationVarsBasePacked};
/// Gates implementing this trait can evaluate their base-field constraints on
/// packed (SIMD) chunks of evaluation points, and get a batched evaluation of
/// the whole point set for free via `eval_unfiltered_base_batch_packed`.
pub trait PackedEvaluableBase<F: Extendable<D>, const D: usize>: Gate<F, D> {
    /// Evaluates this gate's constraints on one packed chunk of base-field
    /// evaluation points, emitting one packed value per constraint through
    /// `yield_constr`.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars_base: EvaluationVarsBasePacked<P>,
        yield_constr: StridedConstraintConsumer<P>,
    );

    /// Evaluates entire batch of points. Returns a matrix of constraints. Constraint `j` for point
    /// `i` is at `index j * batch_size + i`.
    fn eval_unfiltered_base_batch_packed(&self, vars_batch: EvaluationVarsBaseBatch<F>) -> Vec<F> {
        // Row-contiguous matrix: `num_constraints()` rows of `vars_batch.len()` points.
        let mut res = vec![F::ZERO; vars_batch.len() * self.num_constraints()];
        // Split the batch into full-width packed chunks and trailing leftovers.
        let (vars_packed_iter, vars_leftovers_iter) = vars_batch.pack::<<F as Packable>::Packing>();
        let leftovers_start = vars_batch.len() - vars_leftovers_iter.len();
        for (i, vars_packed) in vars_packed_iter.enumerate() {
            // Chunk `i` writes points [WIDTH*i, WIDTH*(i+1)) of each constraint
            // row; the stride between rows is the batch length.
            self.eval_unfiltered_base_packed(
                vars_packed,
                StridedConstraintConsumer::new(
                    &mut res[..],
                    vars_batch.len(),
                    <F as Packable>::Packing::WIDTH * i,
                ),
            );
        }
        for (i, vars_leftovers) in vars_leftovers_iter.enumerate() {
            // Leftover points are evaluated with `F` itself as the packing
            // (one point at a time).
            self.eval_unfiltered_base_packed(
                vars_leftovers,
                StridedConstraintConsumer::new(&mut res[..], vars_batch.len(), leftovers_start + i),
            );
        }
        res
    }
}

View File

@ -2,11 +2,16 @@ use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::WitnessGenerator;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate whose first four wires will be equal to a hash of public inputs.
pub struct PublicInputGate;
@ -31,14 +36,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PublicInputGate {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
yield_constr.many(
Self::wires_public_inputs_hash()
.zip(vars.public_inputs_hash.elements)
.map(|(wire, hash_part)| vars.local_wires[wire] - hash_part),
)
panic!("use eval_unfiltered_base_packed instead");
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -80,6 +85,20 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PublicInputGate {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for PublicInputGate {
    /// Packed evaluation: each designated public-input wire must equal the
    /// corresponding element of the public-inputs hash. One constraint per
    /// wire, emitted in wire order.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars: EvaluationVarsBasePacked<P>,
        mut yield_constr: StridedConstraintConsumer<P>,
    ) {
        let wires = Self::wires_public_inputs_hash();
        for (wire, hash_part) in wires.zip(vars.public_inputs_hash.elements) {
            yield_constr.one(vars.local_wires[wire] - hash_part);
        }
    }
}
#[cfg(test)]
mod tests {
use crate::field::goldilocks_field::GoldilocksField;

View File

@ -5,7 +5,9 @@ use itertools::Itertools;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -13,7 +15,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate for checking that a particular element of a list matches a given value.
#[derive(Copy, Clone, Debug)]
@ -99,14 +104,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
// Assert that each bit wire value is indeed boolean.
for &b in &bits {
constraints.push(b * (b - F::Extension::ONE));
constraints.push(b * (b - <F::Extension as Field>::ONE));
}
// Assert that the binary decomposition was correct.
let reconstructed_index = bits
.iter()
.rev()
.fold(F::Extension::ZERO, |acc, &b| acc.double() + b);
.fold(<F::Extension as Field>::ZERO, |acc, &b| acc.double() + b);
constraints.push(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
@ -128,41 +133,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)];
let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
panic!("use eval_unfiltered_base_packed instead");
}
// Assert that each bit wire value is indeed boolean.
for &b in &bits {
yield_constr.one(b * (b - F::ONE));
}
// Assert that the binary decomposition was correct.
let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b);
yield_constr.one(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| x + b * (y - x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
yield_constr.one(list_items[0] - claimed_element);
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -251,6 +229,47 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for RandomAccessGate<F, D> {
    /// Packed evaluation of the random-access constraints, one group per copy:
    /// bit booleanity, index-decomposition correctness, and the selected-item
    /// check, emitted in that order.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars: EvaluationVarsBasePacked<P>,
        mut yield_constr: StridedConstraintConsumer<P>,
    ) {
        for copy in 0..self.num_copies {
            let access_index = vars.local_wires[self.wire_access_index(copy)];
            let mut list_items = (0..self.vec_size())
                .map(|i| vars.local_wires[self.wire_list_item(i, copy)])
                .collect::<Vec<_>>();
            let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
            let bits = (0..self.bits)
                .map(|i| vars.local_wires[self.wire_bit(i, copy)])
                .collect::<Vec<_>>();
            // Assert that each bit wire value is indeed boolean.
            for &b in &bits {
                yield_constr.one(b * (b - F::ONE));
            }
            // Assert that the binary decomposition was correct.
            // Folds MSB-first; `acc + acc` doubles the accumulator.
            let reconstructed_index = bits.iter().rev().fold(P::ZERO, |acc, &b| acc + acc + b);
            yield_constr.one(reconstructed_index - access_index);
            // Repeatedly fold the list, selecting the left or right item from each pair based on
            // the corresponding bit.
            for b in bits {
                list_items = list_items
                    .iter()
                    .tuples()
                    .map(|(&x, &y)| x + b * (y - x))
                    .collect()
            }
            // After `self.bits` halvings exactly one candidate remains.
            debug_assert_eq!(list_items.len(), 1);
            yield_constr.one(list_items[0] - claimed_element);
        }
    }
}
#[derive(Debug)]
struct RandomAccessGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,

View File

@ -3,7 +3,9 @@ use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -11,7 +13,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns
/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked.
@ -93,7 +98,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
constraints.push(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::Extension::ZERO;
let mut combined_limbs = <F::Extension as Field>::ZERO;
let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
@ -108,7 +113,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
constraints.push(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
constraints.push(output_borrow * (F::Extension::ONE - output_borrow));
constraints.push(output_borrow * (<F::Extension as Field>::ONE - output_borrow));
}
constraints
@ -116,40 +121,14 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
panic!("use eval_unfiltered_base_packed instead");
}
let result_initial = input_x - input_y - input_borrow;
let base = F::from_canonical_u64(1 << 32u64);
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
yield_constr.one(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::ZERO;
let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
yield_constr.one(product);
combined_limbs = limb_base * combined_limbs + this_limb;
}
yield_constr.one(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
yield_constr.one(output_borrow * (F::ONE - output_borrow));
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -240,6 +219,48 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
}
}
impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D>
    for U32SubtractionGate<F, D>
{
    /// Packed evaluation of the 32-bit subtraction constraints, one group per
    /// op: result correctness, per-limb range checks, limb recombination, and
    /// output-borrow booleanity, emitted in that order.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars: EvaluationVarsBasePacked<P>,
        mut yield_constr: StridedConstraintConsumer<P>,
    ) {
        for i in 0..self.num_ops {
            let input_x = vars.local_wires[self.wire_ith_input_x(i)];
            let input_y = vars.local_wires[self.wire_ith_input_y(i)];
            let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
            // result = x - y - borrow may "underflow"; when it does, the output
            // borrow re-adds 2^32 so that `output_result` is the wrapped value.
            let result_initial = input_x - input_y - input_borrow;
            let base = F::from_canonical_u64(1 << 32u64);
            let output_result = vars.local_wires[self.wire_ith_output_result(i)];
            let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
            yield_constr.one(output_result - (result_initial + output_borrow * base));
            // Range-check output_result to be at most 32 bits.
            let mut combined_limbs = P::ZERO;
            let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits());
            for j in (0..Self::num_limbs()).rev() {
                let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
                let max_limb = 1 << Self::limb_bits();
                // The product of (limb - k) for k in 0..max_limb vanishes iff
                // the limb takes one of those values, i.e. fits in limb_bits.
                let product = (0..max_limb)
                    .map(|x| this_limb - F::from_canonical_usize(x))
                    .product();
                yield_constr.one(product);
                // Horner accumulation (most-significant limb first).
                combined_limbs = combined_limbs * limb_base + this_limb;
            }
            // The limbs must recombine to the claimed result.
            yield_constr.one(combined_limbs - output_result);
            // Range-check output_borrow to be one bit.
            yield_constr.one(output_borrow * (P::ONE - output_borrow));
        }
    }
}
#[derive(Clone, Debug)]
struct U32SubtractionGenerator<F: RichField + Extendable<D>, const D: usize> {
gate: U32SubtractionGate<F, D>,

View File

@ -5,7 +5,9 @@ use array_tool::vec::Union;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::gates::gate::Gate;
use crate::gates::packed_util::PackedEvaluableBase;
use crate::gates::util::StridedConstraintConsumer;
use crate::iop::generator::{GeneratedValues, WitnessGenerator};
use crate::iop::target::Target;
@ -13,7 +15,10 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
};
/// A gate for conditionally swapping input values based on a boolean.
#[derive(Clone, Debug)]
@ -77,7 +82,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
for c in 0..self.num_copies {
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
let not_switch = F::Extension::ONE - switch_bool;
let not_switch = <F::Extension as Field>::ONE - switch_bool;
for e in 0..self.chunk_size {
let first_input = vars.local_wires[self.wire_first_input(c, e)];
@ -97,25 +102,14 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
fn eval_unfiltered_base_one(
&self,
vars: EvaluationVarsBase<F>,
mut yield_constr: StridedConstraintConsumer<F>,
_vars: EvaluationVarsBase<F>,
_yield_constr: StridedConstraintConsumer<F>,
) {
for c in 0..self.num_copies {
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
let not_switch = F::ONE - switch_bool;
panic!("use eval_unfiltered_base_packed instead");
}
for e in 0..self.chunk_size {
let first_input = vars.local_wires[self.wire_first_input(c, e)];
let second_input = vars.local_wires[self.wire_second_input(c, e)];
let first_output = vars.local_wires[self.wire_first_output(c, e)];
let second_output = vars.local_wires[self.wire_second_output(c, e)];
yield_constr.one(switch_bool * (first_input - second_output));
yield_constr.one(switch_bool * (second_input - first_output));
yield_constr.one(not_switch * (first_input - first_output));
yield_constr.one(not_switch * (second_input - second_output));
}
}
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
self.eval_unfiltered_base_batch_packed(vars_base)
}
fn eval_unfiltered_recursively(
@ -194,6 +188,31 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
}
}
impl<F: Extendable<D>, const D: usize> PackedEvaluableBase<F, D> for SwitchGate<F, D> {
    /// Packed evaluation of the switch constraints.
    ///
    /// For every copy and chunk element: when the switch bit is 1 the outputs
    /// must be the inputs crossed over, and when it is 0 the outputs must
    /// mirror the inputs. Four constraints per element, emitted in the same
    /// order as the extension-field evaluation.
    fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
        &self,
        vars: EvaluationVarsBasePacked<P>,
        mut yield_constr: StridedConstraintConsumer<P>,
    ) {
        for copy in 0..self.num_copies {
            let switch = vars.local_wires[self.wire_switch_bool(copy)];
            let no_switch = P::ONE - switch;
            for elem in 0..self.chunk_size {
                let in_first = vars.local_wires[self.wire_first_input(copy, elem)];
                let in_second = vars.local_wires[self.wire_second_input(copy, elem)];
                let out_first = vars.local_wires[self.wire_first_output(copy, elem)];
                let out_second = vars.local_wires[self.wire_second_output(copy, elem)];
                // Switched case: outputs are swapped inputs.
                yield_constr.one(switch * (in_first - out_second));
                yield_constr.one(switch * (in_second - out_first));
                // Unswitched case: outputs equal inputs.
                yield_constr.one(no_switch * (in_first - out_first));
                yield_constr.one(no_switch * (in_second - out_second));
            }
        }
    }
}
#[derive(Debug)]
struct SwitchGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,

View File

@ -1,22 +1,23 @@
use std::marker::PhantomData;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
/// Writes constraints yielded by a gate to a buffer, with a given stride.
/// Permits us to abstract the underlying memory layout. In particular, we can make a matrix of
/// constraints where every column is an evaluation point and every row is a constraint index, with
/// the matrix stored in row-contiguous form.
pub struct StridedConstraintConsumer<'a, F: Field> {
pub struct StridedConstraintConsumer<'a, P: PackedField> {
// This is a particularly neat way of doing this, more so than a slice. We increase start by
// stride at every step and terminate when it equals end.
start: *mut F,
end: *mut F,
start: *mut P::Scalar,
end: *mut P::Scalar,
stride: usize,
_phantom: PhantomData<&'a mut [F]>,
_phantom: PhantomData<&'a mut [P::Scalar]>,
}
impl<'a, F: Field> StridedConstraintConsumer<'a, F> {
pub fn new(buffer: &'a mut [F], stride: usize, offset: usize) -> Self {
impl<'a, P: PackedField> StridedConstraintConsumer<'a, P> {
pub fn new(buffer: &'a mut [P::Scalar], stride: usize, offset: usize) -> Self {
assert!(stride >= P::WIDTH);
assert!(offset < stride);
assert_eq!(buffer.len() % stride, 0);
let ptr_range = buffer.as_mut_ptr_range();
@ -38,12 +39,12 @@ impl<'a, F: Field> StridedConstraintConsumer<'a, F> {
}
/// Emit one constraint.
pub fn one(&mut self, constraint: F) {
pub fn one(&mut self, constraint: P) {
if self.start != self.end {
// # Safety
// The checks in `new` guarantee that this points to valid space.
unsafe {
*self.start = constraint;
*self.start.cast() = constraint;
}
// See the comment in `new`. `wrapping_add` is needed to avoid UB if we've just
// exhausted our buffer (and hence we're setting `self.start` to point past the end).
@ -54,7 +55,7 @@ impl<'a, F: Field> StridedConstraintConsumer<'a, F> {
}
/// Convenience method that calls `.one()` multiple times.
pub fn many<I: IntoIterator<Item = F>>(&mut self, constraints: I) {
pub fn many<I: IntoIterator<Item = P>>(&mut self, constraints: I) {
constraints
.into_iter()
.for_each(|constraint| self.one(constraint));

View File

@ -1,6 +1,7 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::field::packed_field::PackedField;
use crate::fri::commitment::SALT_SIZE;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -157,14 +158,14 @@ pub(crate) fn reduce_with_powers_multi<
cumul
}
pub(crate) fn reduce_with_powers<'a, F: Field, T: IntoIterator<Item = &'a F>>(
pub(crate) fn reduce_with_powers<'a, P: PackedField, T: IntoIterator<Item = &'a P>>(
terms: T,
alpha: F,
) -> F
alpha: P::Scalar,
) -> P
where
T::IntoIter: DoubleEndedIterator,
{
let mut sum = F::ZERO;
let mut sum = P::ZERO;
for &term in terms.into_iter().rev() {
sum = sum * alpha + term;
}

View File

@ -4,6 +4,7 @@ use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::util::strided_view::PackedStridedView;
@ -33,6 +34,16 @@ pub struct EvaluationVarsBase<'a, F: Field> {
pub(crate) public_inputs_hash: &'a HashOut<F>,
}
/// Like `EvaluationVarsBase`, but packed.
// It's a separate struct because `EvaluationVarsBase` implements `get_local_ext` and we do not yet
// have packed extension fields.
#[derive(Debug, Copy, Clone)]
pub struct EvaluationVarsBasePacked<'a, P: PackedField> {
    // Strided views into the batch's constant / wire evaluation matrices,
    // covering one packed chunk of evaluation points.
    pub(crate) local_constants: PackedStridedView<'a, P>,
    pub(crate) local_wires: PackedStridedView<'a, P>,
    // Hash of the public inputs; shared by every point in the batch.
    pub(crate) public_inputs_hash: &'a HashOut<P::Scalar>,
}
impl<'a, F: Extendable<D>, const D: usize> EvaluationVars<'a, F, D> {
pub fn get_local_ext_algebra(
&self,
@ -88,6 +99,19 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> {
pub fn iter(&self) -> EvaluationVarsBaseBatchIter<'a, F> {
EvaluationVarsBaseBatchIter::new(*self)
}
    /// Splits the batch into two iterators: one yielding full chunks of
    /// `P::WIDTH` points viewed as packed values, and one over the trailing
    /// `len() % P::WIDTH` leftover points, packed with `F` itself
    /// (presumably a width-1 packing — confirm against the `PackedField`
    /// impl for scalars).
    pub fn pack<P: PackedField<Scalar = F>>(
        &self,
    ) -> (
        EvaluationVarsBaseBatchIterPacked<'a, P>,
        EvaluationVarsBaseBatchIterPacked<'a, F>,
    ) {
        let n_leftovers = self.len() % P::WIDTH;
        (
            EvaluationVarsBaseBatchIterPacked::new_with_start(*self, 0),
            EvaluationVarsBaseBatchIterPacked::new_with_start(*self, self.len() - n_leftovers),
        )
    }
}
impl<'a, F: Field> EvaluationVarsBase<'a, F> {
@ -126,6 +150,63 @@ impl<'a, F: Field> Iterator for EvaluationVarsBaseBatchIter<'a, F> {
}
}
/// Iterator of packed views (`EvaluationVarsBasePacked`) into a `EvaluationVarsBaseBatch`.
/// Note: if the length of `EvaluationVarsBaseBatch` is not a multiple of `P::WIDTH`, then the
/// leftovers at the end are ignored.
pub struct EvaluationVarsBaseBatchIterPacked<'a, P: PackedField> {
    /// Index to yield next, in units of `P::Scalar`. E.g. if `P::WIDTH == 4`, then we will yield
    /// the vars for points `i`, `i + 1`, `i + 2`, and `i + 3`, packed.
    i: usize,
    /// The underlying batch of evaluation points being viewed.
    vars_batch: EvaluationVarsBaseBatch<'a, P::Scalar>,
}
impl<'a, P: PackedField> EvaluationVarsBaseBatchIterPacked<'a, P> {
    /// Creates an iterator positioned at `start` (in units of evaluation
    /// points), so the first yielded chunk covers points
    /// `start..start + P::WIDTH`.
    ///
    /// Panics if `start` is past the end of the batch.
    pub fn new_with_start(
        vars_batch: EvaluationVarsBaseBatch<'a, P::Scalar>,
        start: usize,
    ) -> Self {
        assert!(start <= vars_batch.len());
        Self {
            i: start,
            vars_batch,
        }
    }
}
impl<'a, P: PackedField> Iterator for EvaluationVarsBaseBatchIterPacked<'a, P> {
    type Item = EvaluationVarsBasePacked<'a, P>;
    fn next(&mut self) -> Option<Self::Item> {
        // Only yield while a full `P::WIDTH`-sized chunk remains; a trailing
        // partial chunk is never yielded (see the struct-level note).
        if self.i + P::WIDTH <= self.vars_batch.len() {
            // Build strided views over the batch matrices starting at point
            // `self.i`, with the batch length as the stride between rows.
            let local_constants = PackedStridedView::new(
                self.vars_batch.local_constants,
                self.vars_batch.len(),
                self.i,
            );
            let local_wires =
                PackedStridedView::new(self.vars_batch.local_wires, self.vars_batch.len(), self.i);
            let res = EvaluationVarsBasePacked {
                local_constants,
                local_wires,
                public_inputs_hash: self.vars_batch.public_inputs_hash,
            };
            self.i += P::WIDTH;
            Some(res)
        } else {
            None
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact size: delegate to the `ExactSizeIterator` length.
        let len = self.len();
        (len, Some(len))
    }
}
impl<'a, P: PackedField> ExactSizeIterator for EvaluationVarsBaseBatchIterPacked<'a, P> {
    fn len(&self) -> usize {
        // Number of full `P::WIDTH`-sized chunks remaining; a trailing partial
        // chunk is not counted, matching `next`, which never yields one.
        (self.vars_batch.len() - self.i) / P::WIDTH
    }
}
impl<'a, const D: usize> EvaluationTargets<'a, D> {
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len()..];