diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index ea3e8b13..3b8b11c7 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -613,6 +613,7 @@ mod tests { type FF = >::FE; let config = CircuitConfig::standard_recursion_zk_config(); + let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 168dab71..b5fc7772 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -86,13 +86,14 @@ pub trait Gate, const D: usize>: 'static + Send + S gate_index: usize, selector_index: usize, combination_range: (usize, usize), + num_selectors: usize, ) -> Vec { let filter = compute_filter( gate_index, combination_range, vars.local_constants[selector_index], ); - vars.remove_prefix(prefix); + vars.remove_prefix(num_selectors); self.eval_unfiltered(vars) .into_iter() .map(|c| filter * c) @@ -107,6 +108,7 @@ pub trait Gate, const D: usize>: 'static + Send + S gate_index: usize, selector_index: usize, combination_range: (usize, usize), + num_selectors: usize, ) -> Vec { let filters: Vec<_> = vars_batch .iter() @@ -118,7 +120,7 @@ pub trait Gate, const D: usize>: 'static + Send + S ) }) .collect(); - vars_batch.remove_prefix(prefix); + vars_batch.remove_prefix(num_selectors); let mut res_batch = self.eval_unfiltered_base_batch(vars_batch); for res_chunk in res_batch.chunks_exact_mut(filters.len()) { batch_multiply_inplace(res_chunk, &filters); @@ -131,17 +133,19 @@ pub trait Gate, const D: usize>: 'static + Send + S &self, builder: &mut CircuitBuilder, mut vars: EvaluationTargets, + gate_index: usize, selector_index: usize, combination_range: (usize, usize), combined_gate_constraints: &mut [ExtensionTarget], + num_selectors: usize, ) { let filter = compute_filter_recursively( builder, - selector_index, + gate_index, combination_range, vars.local_constants[selector_index], ); - vars.remove_prefix(prefix); + vars.remove_prefix(num_selectors); let my_constraints = self.eval_unfiltered_recursively(builder, vars); for (acc, c) in combined_gate_constraints.iter_mut().zip(my_constraints) { *acc = builder.mul_add_extension(filter, c, *acc); @@ -244,18 +248,20 @@ fn compute_filter<'a, K: Field>( ) -> K { (combination_range.0..combination_range.1) .filter(|&i| i != gate_index) + .chain(Some(u32::MAX as usize)) .map(|i| K::from_canonical_usize(i) - constant) .product() } fn compute_filter_recursively, const D: usize>( builder: &mut CircuitBuilder, - selector_index: usize, + gate_index: usize, combination_range: (usize, usize), constant: ExtensionTarget, ) -> ExtensionTarget { let v = (combination_range.0..combination_range.1) - .filter(|&i| i != selector_index) + .filter(|&i| i != gate_index) + .chain(Some(u32::MAX as usize)) .map(|i| builder.constant_extension(F::Extension::from_canonical_usize(i))) .collect::>(); let v = v diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs index c1461c09..3eb5282e 100644 --- a/plonky2/src/gates/selectors.rs +++ b/plonky2/src/gates/selectors.rs @@ -8,7 +8,12 @@ pub(crate) fn compute_selectors, const D: usize>( mut gates: Vec>, instances: &[GateInstance], max_degree: usize, -) -> (Vec>, Vec, Vec<(usize, usize)>) { +) -> ( + Vec>, + Vec, + Vec<(usize, usize)>, + usize, +) { let n = instances.len(); let mut combinations = Vec::new(); @@ -16,15 +21,16 @@ pub(crate) fn compute_selectors, const D: usize>( while pos < gates.len() { let mut i = 0; - while (pos + i < gates.len()) && (i + gates[pos + i].0.degree() <= max_degree + 1) { + while (pos + i < gates.len()) && (i + gates[pos + i].0.degree() <= max_degree) { i += 1; } combinations.push((pos, pos + i)); pos += i; } + dbg!(&combinations); + let bad = F::from_canonical_usize(u32::MAX as usize); - let num_constants_polynomials = - 0.max(gates.iter().map(|g| g.0.num_constants()).max().unwrap() - combinations.len() + 1); + let num_constants_polynomials = gates.iter().map(|g| g.0.num_constants()).max().unwrap(); let mut polynomials = vec![PolynomialValues::zero(n); combinations.len() + num_constants_polynomials]; @@ -53,17 +59,20 @@ pub(crate) fn compute_selectors, const D: usize>( let i = index(gate_ref.0.id()); let comb = combination(i); polynomials[comb].values[j] = F::from_canonical_usize(i); - let mut k = 0; - let mut constant_ind = 0; - while k < constants.len() { - if constant_ind == comb { - constant_ind += 1; - } else { - polynomials[constant_ind].values[j] = constants[k]; - constant_ind += 1; - k += 1; - } + + for combis in (0..combinations.len()).filter(|&combis| combis != comb) { + polynomials[combis].values[j] = bad; + } + + for k in 0..constants.len() { + polynomials[combinations.len() + k].values[j] = constants[k]; } } - (polynomials, selector_indices, combination_ranges) + + ( + polynomials, + selector_indices, + combination_ranges, + combinations.len(), + ) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index ffa253c1..49acaade 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -676,12 +676,13 @@ impl, const D: usize> CircuitBuilder { let mut gates = self.gates.iter().cloned().collect::>(); gates.sort_unstable_by_key(|g| g.0.degree()); dbg!(&gates); - let (constant_vecs, selector_indices, combination_nums) = compute_selectors( - gates.clone(), - &self.gate_instances, - self.config.max_quotient_degree_factor + 1, - ); - dbg!(&constant_vecs, &selector_indices, &combination_nums); + let (constant_vecs, selector_indices, combination_ranges, num_selectors) = + compute_selectors( + gates.clone(), + &self.gate_instances, + self.config.max_quotient_degree_factor + 1, + ); + dbg!(&constant_vecs, &selector_indices, &combination_ranges); let num_constants = constant_vecs.len(); // let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates); // let prefixed_gates = PrefixedGate::from_tree(gate_tree); @@ -805,6 +806,7 @@ impl, const D: usize> CircuitBuilder { gates, selector_indices, combination_ranges, + num_selectors, quotient_degree_factor, num_gate_constraints, num_constants, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 4ce51b3a..c889def9 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -250,6 +250,7 @@ pub struct CommonCircuitData< pub(crate) selector_indices: Vec, pub(crate) combination_ranges: Vec<(usize, usize)>, + pub(crate) num_selectors: usize, /// The degree of the PLONK quotient polynomial. pub(crate) quotient_degree_factor: usize, diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index 7be35e4f..ff578759 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -217,10 +217,11 @@ pub fn evaluate_gate_constraints< let mut constraints = vec![F::Extension::ZERO; common_data.num_gate_constraints]; for (i, gate) in common_data.gates.iter().enumerate() { let gate_constraints = gate.0.eval_filtered( - vars.clone(), + vars, i, common_data.selector_indices[i], common_data.combination_ranges[i], + common_data.num_selectors, ); for (i, c) in gate_constraints.into_iter().enumerate() { debug_assert!( @@ -253,6 +254,7 @@ pub fn evaluate_gate_constraints_base_batch< i, common_data.selector_indices[i], common_data.combination_ranges[i], + common_data.num_selectors, ); debug_assert!( gate_constraints_batch.len() <= constraints_batch.len(), @@ -283,10 +285,12 @@ pub fn evaluate_gate_constraints_recursively< &format!("evaluate {} constraints", gate.0.id()), gate.0.eval_filtered_recursively( builder, - vars.clone(), + vars, + i, common_data.selector_indices[i], common_data.combination_ranges[i], - &mut all_gate_constraints + &mut all_gate_constraints, + common_data.num_selectors ) ); } diff --git a/plonky2/src/plonk/vars.rs b/plonky2/src/plonk/vars.rs index e2e52cfb..8c7b71e3 100644 --- a/plonky2/src/plonk/vars.rs +++ b/plonky2/src/plonk/vars.rs @@ -55,8 +55,8 @@ impl<'a, F: RichField + Extendable, const D: usize> EvaluationVars<'a, F, D> ExtensionAlgebra::from_basefield_array(arr) } - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len()..]; + pub fn remove_prefix(&mut self, num_selectors: usize) { + self.local_constants = &self.local_constants[num_selectors..]; } } @@ -77,8 +77,8 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { } } - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len() * self.len()..]; + pub fn remove_prefix(&mut self, num_selectors: usize) { + self.local_constants = &self.local_constants[num_selectors * self.len()..]; } pub fn len(&self) -> usize { @@ -209,8 +209,8 @@ impl<'a, P: PackedField> ExactSizeIterator for EvaluationVarsBaseBatchIterPacked } impl<'a, const D: usize> EvaluationTargets<'a, D> { - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len()..]; + pub fn remove_prefix(&mut self, num_selectors: usize) { + self.local_constants = &self.local_constants[num_selectors..]; } }