use std::fmt::{Debug, Error, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::field::batch_util::batch_multiply_inplace; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; use crate::gates::gate_tree::Tree; use crate::gates::util::StridedConstraintConsumer; use crate::iop::generator::WitnessGenerator; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{ EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, }; /// A custom gate. pub trait Gate, const D: usize>: 'static + Send + Sync { fn id(&self) -> String; fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec; /// Like `eval_unfiltered`, but specialized for points in the base field. /// /// /// `eval_unfiltered_base_batch` calls this method by default. If `eval_unfiltered_base_batch` /// is overridden, then `eval_unfiltered_base_one` is not necessary. /// /// By default, this just calls `eval_unfiltered`, which treats the point as an extension field /// element. This isn't very efficient. fn eval_unfiltered_base_one( &self, vars_base: EvaluationVarsBase, mut yield_constr: StridedConstraintConsumer, ) { // Note that this method uses `yield_constr` instead of returning its constraints. // `yield_constr` abstracts out the underlying memory layout. let local_constants = &vars_base .local_constants .iter() .map(|c| F::Extension::from_basefield(*c)) .collect::>(); let local_wires = &vars_base .local_wires .iter() .map(|w| F::Extension::from_basefield(*w)) .collect::>(); let public_inputs_hash = &vars_base.public_inputs_hash; let vars = EvaluationVars { local_constants, local_wires, public_inputs_hash, }; let values = self.eval_unfiltered(vars); // Each value should be in the base field, i.e. only the degree-zero part should be nonzero. values.into_iter().for_each(|value| { debug_assert!(F::Extension::is_in_basefield(&value)); yield_constr.one(value.to_basefield_array()[0]) }) } fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { let mut res = vec![F::ZERO; vars_base.len() * self.num_constraints()]; for (i, vars_base_one) in vars_base.iter().enumerate() { self.eval_unfiltered_base_one( vars_base_one, StridedConstraintConsumer::new(&mut res, vars_base.len(), i), ); } res } fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec>; fn eval_filtered(&self, mut vars: EvaluationVars, prefix: &[bool]) -> Vec { let filter = compute_filter(prefix, vars.local_constants); vars.remove_prefix(prefix); self.eval_unfiltered(vars) .into_iter() .map(|c| filter * c) .collect() } /// The result is an array of length `vars_batch.len() * self.num_constraints()`. Constraint `j` /// for point `i` is at index `j * batch_size + i`. fn eval_filtered_base_batch( &self, mut vars_batch: EvaluationVarsBaseBatch, prefix: &[bool], ) -> Vec { let filters: Vec<_> = vars_batch .iter() .map(|vars| compute_filter(prefix, vars.local_constants)) .collect(); vars_batch.remove_prefix(prefix); let mut res_batch = self.eval_unfiltered_base_batch(vars_batch); for res_chunk in res_batch.chunks_exact_mut(filters.len()) { batch_multiply_inplace(res_chunk, &filters); } res_batch } /// Adds this gate's filtered constraints into the `combined_gate_constraints` buffer. fn eval_filtered_recursively( &self, builder: &mut CircuitBuilder, mut vars: EvaluationTargets, prefix: &[bool], combined_gate_constraints: &mut Vec>, ) { let filter = compute_filter_recursively(builder, prefix, vars.local_constants); vars.remove_prefix(prefix); let my_constraints = self.eval_unfiltered_recursively(builder, vars); for (acc, c) in combined_gate_constraints.iter_mut().zip(my_constraints) { *acc = builder.mul_add_extension(filter, c, *acc); } } fn generators( &self, gate_index: usize, local_constants: &[F], ) -> Vec>>; /// The number of wires used by this gate. fn num_wires(&self) -> usize; /// The number of constants used by this gate. fn num_constants(&self) -> usize; /// The maximum degree among this gate's constraint polynomials. fn degree(&self) -> usize; fn num_constraints(&self) -> usize; } /// A wrapper around an `Rc` which implements `PartialEq`, `Eq` and `Hash` based on gate IDs. #[derive(Clone)] pub struct GateRef, const D: usize>(pub(crate) Arc>); impl, const D: usize> GateRef { pub fn new>(gate: G) -> GateRef { GateRef(Arc::new(gate)) } } impl, const D: usize> PartialEq for GateRef { fn eq(&self, other: &Self) -> bool { self.0.id() == other.0.id() } } impl, const D: usize> Hash for GateRef { fn hash(&self, state: &mut H) { self.0.id().hash(state) } } impl, const D: usize> Eq for GateRef {} impl, const D: usize> Debug for GateRef { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { write!(f, "{}", self.0.id()) } } /// A gate along with any constants used to configure it. pub struct GateInstance, const D: usize> { pub gate_ref: GateRef, pub constants: Vec, } /// Map each gate to a boolean prefix used to construct the gate's selector polynomial. #[derive(Debug, Clone)] pub struct PrefixedGate, const D: usize> { pub gate: GateRef, pub prefix: Vec, } impl, const D: usize> PrefixedGate { pub fn from_tree(tree: Tree>) -> Vec { tree.traversal() .into_iter() .map(|(gate, prefix)| PrefixedGate { gate, prefix }) .collect() } } /// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and /// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`. fn compute_filter<'a, K: Field, T: IntoIterator>(prefix: &[bool], constants: T) -> K { prefix .iter() .zip(constants) .map(|(&b, &c)| if b { c } else { K::ONE - c }) .product() } fn compute_filter_recursively, const D: usize>( builder: &mut CircuitBuilder, prefix: &[bool], constants: &[ExtensionTarget], ) -> ExtensionTarget { let one = builder.one_extension(); let v = prefix .iter() .enumerate() .map(|(i, &b)| { if b { constants[i] } else { builder.sub_extension(one, constants[i]) } }) .collect::>(); builder.mul_many_extension(&v) }