From e77383b5593749521151c7d32e97ecf6d175bcd4 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 17 Mar 2022 11:08:25 +0100 Subject: [PATCH] Progress --- plonky2/src/gates/add_many_u32.rs | 2 +- plonky2/src/gates/arithmetic_u32.rs | 2 +- plonky2/src/gates/assert_le.rs | 4 +- plonky2/src/gates/comparison.rs | 4 +- plonky2/src/gates/exponentiation.rs | 2 +- plonky2/src/gates/gate.rs | 39 +++++--- plonky2/src/gates/gate_testing.rs | 6 +- plonky2/src/gates/interpolation.rs | 2 +- plonky2/src/gates/low_degree_interpolation.rs | 2 +- plonky2/src/gates/random_access.rs | 4 +- plonky2/src/gates/range_check_u32.rs | 2 +- plonky2/src/gates/selectors.rs | 34 ++++++- plonky2/src/gates/subtraction_u32.rs | 2 +- plonky2/src/gates/switch.rs | 2 +- plonky2/src/plonk/circuit_builder.rs | 53 +++++----- plonky2/src/plonk/circuit_data.rs | 9 +- plonky2/src/plonk/recursive_verifier.rs | 1 + plonky2/src/plonk/vanishing_poly.rs | 97 ++++++++++--------- plonky2/src/plonk/vars.rs | 43 ++++---- plonky2/src/plonk/verifier.rs | 2 +- plonky2/src/util/strided_view.rs | 2 +- 21 files changed, 186 insertions(+), 128 deletions(-) diff --git a/plonky2/src/gates/add_many_u32.rs b/plonky2/src/gates/add_many_u32.rs index 4f9c4293..e6248399 100644 --- a/plonky2/src/gates/add_many_u32.rs +++ b/plonky2/src/gates/add_many_u32.rs @@ -448,7 +448,7 @@ mod tests { }; let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(addends, carries), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/arithmetic_u32.rs b/plonky2/src/gates/arithmetic_u32.rs index dc03e296..eac2d23d 100644 --- a/plonky2/src/gates/arithmetic_u32.rs +++ b/plonky2/src/gates/arithmetic_u32.rs @@ -445,7 +445,7 @@ mod tests { }; let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(multiplicands_0, multiplicands_1, addends), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/assert_le.rs b/plonky2/src/gates/assert_le.rs index cec7274b..3e89baf7 100644 --- a/plonky2/src/gates/assert_le.rs +++ b/plonky2/src/gates/assert_le.rs @@ -602,7 +602,7 @@ mod tests { _phantom: PhantomData, }; let less_than_vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(first_input, second_input), public_inputs_hash: &HashOut::rand(), }; @@ -620,7 +620,7 @@ mod tests { _phantom: PhantomData, }; let equal_vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(first_input, first_input), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/comparison.rs b/plonky2/src/gates/comparison.rs index b1cf7b98..5207b7ae 100644 --- a/plonky2/src/gates/comparison.rs +++ b/plonky2/src/gates/comparison.rs @@ -682,7 +682,7 @@ mod tests { _phantom: PhantomData, }; let less_than_vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(first_input, second_input), public_inputs_hash: &HashOut::rand(), }; @@ -700,7 +700,7 @@ mod tests { _phantom: PhantomData, }; let equal_vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(first_input, first_input), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/exponentiation.rs b/plonky2/src/gates/exponentiation.rs index 51558a21..8db135ac 100644 --- a/plonky2/src/gates/exponentiation.rs +++ b/plonky2/src/gates/exponentiation.rs @@ -394,7 +394,7 @@ mod tests { }; let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(base, power as u64), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 03bd0a7b..634c71f6 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -38,7 +38,7 @@ pub trait Gate, const D: usize>: 'static + Send + S ) { // Note that this method uses `yield_constr` instead of returning its constraints. // `yield_constr` abstracts out the underlying memory layout. - let local_constants = &vars_base + let local_constants = vars_base .local_constants .iter() .map(|c| F::Extension::from_basefield(*c)) @@ -80,9 +80,18 @@ pub trait Gate, const D: usize>: 'static + Send + S vars: EvaluationTargets, ) -> Vec>; - fn eval_filtered(&self, mut vars: EvaluationVars, prefix: &[bool]) -> Vec { - let filter = compute_filter(prefix, vars.local_constants); - vars.remove_prefix(prefix); + fn eval_filtered( + &self, + mut vars: EvaluationVars, + selector_index: usize, + combination_num: usize, + ) -> Vec { + let filter = compute_filter( + selector_index, + combination_num, + vars.local_constants[selector_index], + ); + vars.remove_prefix(selector_index); self.eval_unfiltered(vars) .into_iter() .map(|c| filter * c) @@ -94,13 +103,20 @@ pub trait Gate, const D: usize>: 'static + Send + S fn eval_filtered_base_batch( &self, mut vars_batch: EvaluationVarsBaseBatch, - prefix: &[bool], + selector_index: usize, + combination_num: usize, ) -> Vec { let filters: Vec<_> = vars_batch .iter() - .map(|vars| compute_filter(prefix, vars.local_constants)) + .map(|vars| { + compute_filter( + selector_index, + combination_num, + vars.local_constants[selector_index], + ) + }) .collect(); - vars_batch.remove_prefix(prefix); + vars_batch.remove_prefix(selector_index); let mut res_batch = self.eval_unfiltered_base_batch(vars_batch); for res_chunk in res_batch.chunks_exact_mut(filters.len()) { batch_multiply_inplace(res_chunk, &filters); @@ -213,11 +229,10 @@ impl, const D: usize> PrefixedGate { /// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and /// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`. -fn compute_filter<'a, K: Field, T: IntoIterator>(prefix: &[bool], constants: T) -> K { - prefix - .iter() - .zip(constants) - .map(|(&b, &c)| if b { c } else { K::ONE - c }) +fn compute_filter<'a, K: Field>(selector_index: usize, combination_num: usize, constant: K) -> K { + (0..combination_num) + .filter(|&i| i != selector_index) + .map(|i| K::from_canonical_usize(i) - constant) .product() } diff --git a/plonky2/src/gates/gate_testing.rs b/plonky2/src/gates/gate_testing.rs index 51768ba8..20e45e32 100644 --- a/plonky2/src/gates/gate_testing.rs +++ b/plonky2/src/gates/gate_testing.rs @@ -32,7 +32,7 @@ pub fn test_low_degree, G: Gate, const D: usi .iter() .zip(constant_ldes.iter()) .map(|(local_wires, local_constants)| EvaluationVars { - local_constants, + local_constants: local_constants.to_vec(), local_wires, public_inputs_hash, }) @@ -113,7 +113,7 @@ where let vars_base_batch = EvaluationVarsBaseBatch::new(1, &constants_base, &wires_base, &public_inputs_hash); let vars = EvaluationVars { - local_constants: &constants, + local_constants: constants, local_wires: &wires, public_inputs_hash: &public_inputs_hash, }; @@ -145,7 +145,7 @@ where pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); let vars = EvaluationVars { - local_constants: &constants, + local_constants: constants, local_wires: &wires, public_inputs_hash: &public_inputs_hash, }; diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index 46c42113..4f1f6b33 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -352,7 +352,7 @@ mod tests { let eval_point = FF::rand(); let gate = HighDegreeInterpolationGate::::new(1); let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(&gate, shift, coeffs, eval_point), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index 845da5ab..02943106 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -453,7 +453,7 @@ mod tests { let eval_point = FF::rand(); let gate = LowDegreeInterpolationGate::::new(subgroup_bits); let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(&gate, shift, coeffs, eval_point), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 6379f99f..0d84fa81 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -413,7 +413,7 @@ mod tests { .map(|(l, &i)| l[i]) .collect(); let good_vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires( bits, lists.clone(), @@ -424,7 +424,7 @@ mod tests { }; let bad_claimed_elements = F::rand_vec(4); let bad_vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(bits, lists, access_indices, bad_claimed_elements), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/range_check_u32.rs b/plonky2/src/gates/range_check_u32.rs index 79e91de8..f51b246e 100644 --- a/plonky2/src/gates/range_check_u32.rs +++ b/plonky2/src/gates/range_check_u32.rs @@ -292,7 +292,7 @@ mod tests { }; let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(input_limbs), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs index 11463512..23db1256 100644 --- a/plonky2/src/gates/selectors.rs +++ b/plonky2/src/gates/selectors.rs @@ -8,9 +8,8 @@ pub(crate) fn compute_selectors, const D: usize>( mut gates: Vec>, instances: &[GateInstance], max_degree: usize, -) { +) -> (Vec>, Vec, Vec) { let n = instances.len(); - gates.sort_unstable_by_key(|g| g.0.degree()); let mut combinations = Vec::new(); let mut pos = 0; @@ -25,16 +24,41 @@ pub(crate) fn compute_selectors, const D: usize>( } let num_constants_polynomials = - 0.max(gates.iter().map(|g| g.0.num_constants()).max().unwrap() - combinations.len() - 1); + 0.max(gates.iter().map(|g| g.0.num_constants()).max().unwrap() - combinations.len() + 1); let mut polynomials = vec![PolynomialValues::zero(n); combinations.len() + num_constants_polynomials]; let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap(); let combination = |i| combinations.iter().position(|&(a, _)| a <= i).unwrap(); + let selector_indices = gates + .iter() + .map(|g| combination(index(g.0.id()))) + .collect::>(); + let combination_nums = selector_indices + .iter() + .map(|&i| combinations[i].1 - combinations[i].0) + .collect(); + for (j, g) in instances.iter().enumerate() { - let i = index(g.gate_ref.0.id()); + let GateInstance { + gate_ref, + constants, + } = g; + let i = index(gate_ref.0.id()); let comb = combination(i); - polynomials[comb].values[j] = i - combinations[comb].0; + polynomials[comb].values[j] = F::from_canonical_usize(i - combinations[comb].0); + let mut k = 0; + let mut constant_ind = 0; + while k < constants.len() { + if constant_ind == comb { + constant_ind += 1; + } else { + polynomials[constant_ind].values[j] = constants[k]; + constant_ind += 1; + k += 1; + } + } } + (polynomials, selector_indices, combination_nums) } diff --git a/plonky2/src/gates/subtraction_u32.rs b/plonky2/src/gates/subtraction_u32.rs index b1e4d84f..afa212d5 100644 --- a/plonky2/src/gates/subtraction_u32.rs +++ b/plonky2/src/gates/subtraction_u32.rs @@ -437,7 +437,7 @@ mod tests { }; let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(inputs_x, inputs_y, borrows), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/gates/switch.rs b/plonky2/src/gates/switch.rs index bd298762..3cfa3349 100644 --- a/plonky2/src/gates/switch.rs +++ b/plonky2/src/gates/switch.rs @@ -446,7 +446,7 @@ mod tests { }; let vars = EvaluationVars { - local_constants: &[], + local_constants: vec![], local_wires: &get_wires(first_inputs, second_inputs, switch_bools), public_inputs_hash: &HashOut::rand(), }; diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 50b5931a..8fe38da8 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -670,34 +670,34 @@ impl, const D: usize> CircuitBuilder { "FRI total reduction arity is too large.", ); - let gates = self.gates.iter().cloned().collect::>(); - for g in &gates { - println!("{} {}", g.0.id(), g.0.num_constants()); - } - dbg!(compute_selectors( + let mut gates = self.gates.iter().cloned().collect::>(); + gates.sort_unstable_by_key(|g| g.0.degree()); + let (constant_vecs, selector_indices, combination_nums) = compute_selectors( gates.clone(), &self.gate_instances, - self.config.max_quotient_degree_factor + 1 - )); - let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates); - let prefixed_gates = PrefixedGate::from_tree(gate_tree); - - // `quotient_degree_factor` has to be between `max_filtered_constraint_degree-1` and `1<, const D: usize> CircuitBuilder { fft_root_table: Some(fft_root_table), }; - // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we - // sort by ID to make the ordering deterministic. - let mut gates = self.gates.iter().cloned().collect::>(); - gates.sort_unstable_by_key(|gate| gate.0.id()); - let num_gate_constraints = gates .iter() .map(|gate| gate.0.num_constraints()) @@ -802,7 +797,9 @@ impl, const D: usize> CircuitBuilder { config: self.config, fri_params, degree_bits, - gates: prefixed_gates, + gates, + selector_indices, + combination_nums, quotient_degree_factor, num_gate_constraints, num_constants, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index f0fa4bd5..c0bf2251 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -12,7 +12,7 @@ use crate::fri::structure::{ FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriPolynomialInfo, }; use crate::fri::{FriConfig, FriParams}; -use crate::gates::gate::PrefixedGate; +use crate::gates::gate::{Gate, GateRef, PrefixedGate}; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; @@ -246,7 +246,10 @@ pub struct CommonCircuitData< pub(crate) degree_bits: usize, /// The types of gates used in this circuit, along with their prefixes. - pub(crate) gates: Vec>, + pub(crate) gates: Vec>, + + pub(crate) selector_indices: Vec, + pub(crate) combination_nums: Vec, /// The degree of the PLONK quotient polynomial. pub(crate) quotient_degree_factor: usize, @@ -290,7 +293,7 @@ impl, C: GenericConfig, const D: usize> pub fn constraint_degree(&self) -> usize { self.gates .iter() - .map(|g| g.gate.0.degree()) + .map(|g| g.0.degree()) .max() .expect("No gates?") } diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/plonk/recursive_verifier.rs index 2fe7d648..fb4238ee 100644 --- a/plonky2/src/plonk/recursive_verifier.rs +++ b/plonky2/src/plonk/recursive_verifier.rs @@ -224,6 +224,7 @@ mod tests { // Start with a degree 2^14 proof let (proof, vd, cd) = dummy_proof::(&config, 16_000)?; assert_eq!(cd.degree_bits, 14); + test_serialization(&proof, &cd)?; // Shrink it to 2^13. let (proof, vd, cd) = diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index 70de5833..c5e0d563 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -3,7 +3,7 @@ use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::field_types::Field; use plonky2_field::zero_poly_coset::ZeroPolyOnCoset; -use crate::gates::gate::PrefixedGate; +use crate::gates::gate::{Gate, GateRef, PrefixedGate}; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; @@ -40,8 +40,7 @@ pub(crate) fn eval_vanishing_poly< let max_degree = common_data.quotient_degree_factor; let num_prods = common_data.num_partial_products; - let constraint_terms = - evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); + let constraint_terms = evaluate_gate_constraints(common_data, vars.clone()); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); @@ -129,7 +128,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< let num_gate_constraints = common_data.num_gate_constraints; let constraint_terms_batch = - evaluate_gate_constraints_base_batch(&common_data.gates, num_gate_constraints, vars_batch); + evaluate_gate_constraints_base_batch(&common_data, vars_batch.clone()); debug_assert!(constraint_terms_batch.len() == n * num_gate_constraints); let num_challenges = common_data.config.num_challenges; @@ -153,7 +152,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< let partial_products = partial_products_batch[k]; let s_sigmas = s_sigmas_batch[k]; - let constraint_terms = PackedStridedView::new(&constraint_terms_batch, n, k); + let constraint_terms = PackedStridedView::new(constraint_terms_batch.clone(), n, k); let l1_x = z_h_on_coset.eval_l1(index, x); for i in 0..num_challenges { @@ -208,17 +207,24 @@ pub(crate) fn eval_vanishing_poly_base_batch< /// `num_gate_constraints` is the largest number of constraints imposed by any gate. It is not /// strictly necessary, but it helps performance by ensuring that we allocate a vector with exactly /// the capacity that we need. -pub fn evaluate_gate_constraints, const D: usize>( - gates: &[PrefixedGate], - num_gate_constraints: usize, +pub fn evaluate_gate_constraints< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, vars: EvaluationVars, ) -> Vec { - let mut constraints = vec![F::Extension::ZERO; num_gate_constraints]; - for gate in gates { - let gate_constraints = gate.gate.0.eval_filtered(vars, &gate.prefix); + let mut constraints = vec![F::Extension::ZERO; common_data.num_gate_constraints]; + for (i, gate) in common_data.gates.iter().enumerate() { + let gate_constraints = gate.0.eval_filtered( + vars.clone(), + common_data.selector_indices[i], + common_data.combination_nums[i], + ); for (i, c) in gate_constraints.into_iter().enumerate() { debug_assert!( - i < num_gate_constraints, + i < common_data.num_gate_constraints, "num_constraints() gave too low of a number" ); constraints[i] += c; @@ -232,17 +238,21 @@ pub fn evaluate_gate_constraints, const D: usize>( /// Returns a vector of `num_gate_constraints * vars_batch.len()` field elements. The constraints /// corresponding to `vars_batch[i]` are found in `result[i], result[vars_batch.len() + i], /// result[2 * vars_batch.len() + i], ...`. -pub fn evaluate_gate_constraints_base_batch, const D: usize>( - gates: &[PrefixedGate], - num_gate_constraints: usize, +pub fn evaluate_gate_constraints_base_batch< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, vars_batch: EvaluationVarsBaseBatch, ) -> Vec { - let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()]; - for gate in gates { - let gate_constraints_batch = gate - .gate - .0 - .eval_filtered_base_batch(vars_batch, &gate.prefix); + let mut constraints_batch = vec![F::ZERO; common_data.num_gate_constraints * vars_batch.len()]; + for (i, gate) in common_data.gates.iter().enumerate() { + let gate_constraints_batch = gate.0.eval_filtered_base_batch( + vars_batch.clone(), + common_data.selector_indices[i], + common_data.combination_nums[i], + ); debug_assert!( gate_constraints_batch.len() <= constraints_batch.len(), "num_constraints() gave too low of a number" @@ -256,26 +266,30 @@ pub fn evaluate_gate_constraints_base_batch, const constraints_batch } -pub fn evaluate_gate_constraints_recursively, const D: usize>( +pub fn evaluate_gate_constraints_recursively< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( builder: &mut CircuitBuilder, - gates: &[PrefixedGate], - num_gate_constraints: usize, + common_data: &CommonCircuitData, vars: EvaluationTargets, ) -> Vec> { - let mut all_gate_constraints = vec![builder.zero_extension(); num_gate_constraints]; - for gate in gates { - with_context!( - builder, - &format!("evaluate {} constraints", gate.gate.0.id()), - gate.gate.0.eval_filtered_recursively( - builder, - vars, - &gate.prefix, - &mut all_gate_constraints - ) - ); - } - all_gate_constraints + todo!(); + // let mut all_gate_constraints = vec![builder.zero_extension(); num_gate_constraints]; + // for gate in gates { + // with_context!( + // builder, + // &format!("evaluate {} constraints", gate.gate.0.id()), + // gate.gate.0.eval_filtered_recursively( + // builder, + // vars, + // &gate.prefix, + // &mut all_gate_constraints + // ) + // ); + // } + // all_gate_constraints } /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random @@ -308,12 +322,7 @@ pub(crate) fn eval_vanishing_poly_recursively< let constraint_terms = with_context!( builder, "evaluate gate constraints", - evaluate_gate_constraints_recursively( - builder, - &common_data.gates, - common_data.num_gate_constraints, - vars, - ) + evaluate_gate_constraints_recursively(builder, common_data, vars,) ); // The L_1(x) (Z(x) - 1) vanishing terms. diff --git a/plonky2/src/plonk/vars.rs b/plonky2/src/plonk/vars.rs index e2e52cfb..b9001703 100644 --- a/plonky2/src/plonk/vars.rs +++ b/plonky2/src/plonk/vars.rs @@ -9,9 +9,9 @@ use crate::hash::hash_types::{HashOut, HashOutTarget, RichField}; use crate::iop::ext_target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::util::strided_view::PackedStridedView; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct EvaluationVars<'a, F: RichField + Extendable, const D: usize> { - pub local_constants: &'a [F::Extension], + pub local_constants: Vec, pub local_wires: &'a [F::Extension], pub public_inputs_hash: &'a HashOut, } @@ -19,10 +19,10 @@ pub struct EvaluationVars<'a, F: RichField + Extendable, const D: usize> { /// A batch of evaluation vars, in the base field. /// Wires and constants are stored in an evaluation point-major order (that is, wire 0 for all /// evaluation points, then wire 1 for all points, and so on). -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct EvaluationVarsBaseBatch<'a, F: Field> { batch_size: usize, - pub local_constants: &'a [F], + pub local_constants: Vec, pub local_wires: &'a [F], pub public_inputs_hash: &'a HashOut, } @@ -55,8 +55,8 @@ impl<'a, F: RichField + Extendable, const D: usize> EvaluationVars<'a, F, D> ExtensionAlgebra::from_basefield_array(arr) } - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len()..]; + pub fn remove_prefix(&mut self, selector_index: usize) { + self.local_constants.remove(selector_index); } } @@ -71,14 +71,16 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { assert_eq!(local_wires.len() % batch_size, 0); Self { batch_size, - local_constants, + local_constants: local_constants.to_vec(), local_wires, public_inputs_hash, } } - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len() * self.len()..]; + pub fn remove_prefix(&mut self, selector_index: usize) { + let mut v = self.local_constants[..self.len() * selector_index].to_vec(); + v.extend(&self.local_constants[self.len() * (selector_index + 1)..]); + self.local_constants = v; } pub fn len(&self) -> usize { @@ -88,8 +90,9 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { pub fn view(&self, index: usize) -> EvaluationVarsBase<'a, F> { // We cannot implement `Index` as `EvaluationVarsBase` is a struct, not a reference. assert!(index < self.len()); - let local_constants = PackedStridedView::new(self.local_constants, self.len(), index); - let local_wires = PackedStridedView::new(self.local_wires, self.len(), index); + let local_constants = + PackedStridedView::new(self.local_constants.clone(), self.len(), index); + let local_wires = PackedStridedView::new(self.local_wires.to_vec(), self.len(), index); EvaluationVarsBase { local_constants, local_wires, @@ -98,7 +101,7 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { } pub fn iter(&self) -> EvaluationVarsBaseBatchIter<'a, F> { - EvaluationVarsBaseBatchIter::new(*self) + EvaluationVarsBaseBatchIter::new(self.clone()) } pub fn pack>( @@ -109,8 +112,11 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { ) { let n_leftovers = self.len() % P::WIDTH; ( - EvaluationVarsBaseBatchIterPacked::new_with_start(*self, 0), - EvaluationVarsBaseBatchIterPacked::new_with_start(*self, self.len() - n_leftovers), + EvaluationVarsBaseBatchIterPacked::new_with_start(self.clone(), 0), + EvaluationVarsBaseBatchIterPacked::new_with_start( + self.clone(), + self.len() - n_leftovers, + ), ) } } @@ -179,12 +185,15 @@ impl<'a, P: PackedField> Iterator for EvaluationVarsBaseBatchIterPacked<'a, P> { fn next(&mut self) -> Option { if self.i + P::WIDTH <= self.vars_batch.len() { let local_constants = PackedStridedView::new( - self.vars_batch.local_constants, + self.vars_batch.local_constants.to_vec(), + self.vars_batch.len(), + self.i, + ); + let local_wires = PackedStridedView::new( + self.vars_batch.local_wires.to_vec(), self.vars_batch.len(), self.i, ); - let local_wires = - PackedStridedView::new(self.vars_batch.local_wires, self.vars_batch.len(), self.i); let res = EvaluationVarsBasePacked { local_constants, local_wires, diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index ee0e976f..6839424c 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -49,7 +49,7 @@ pub(crate) fn verify_with_challenges< where [(); C::Hasher::HASH_SIZE]:, { - let local_constants = &proof.openings.constants; + let local_constants = proof.openings.constants.clone(); let local_wires = &proof.openings.wires; let vars = EvaluationVars { local_constants, diff --git a/plonky2/src/util/strided_view.rs b/plonky2/src/util/strided_view.rs index 6ea270ce..ed877f71 100644 --- a/plonky2/src/util/strided_view.rs +++ b/plonky2/src/util/strided_view.rs @@ -46,7 +46,7 @@ impl<'a, P: PackedField> PackedStridedView<'a, P> { // end of the same allocated object'; the UB results even if the pointer is not dereferenced. #[inline] - pub fn new(data: &'a [P::Scalar], stride: usize, offset: usize) -> Self { + pub fn new(data: Vec, stride: usize, offset: usize) -> Self { assert!( stride >= P::WIDTH, "stride (got {}) must be at least P::WIDTH ({})",