use std::fmt::{Debug, Error, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::gates::gate_tree::Tree; use crate::generator::WitnessGenerator; use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A custom gate. pub trait Gate, const D: usize>: 'static + Send + Sync { fn id(&self) -> String; fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec; /// Like `eval_unfiltered`, but specialized for points in the base field. /// /// By default, this just calls `eval_unfiltered`, which treats the point as an extension field /// element. This isn't very efficient. fn eval_unfiltered_base(&self, vars_base: EvaluationVarsBase) -> Vec { let local_constants = &vars_base .local_constants .iter() .map(|c| F::Extension::from_basefield(*c)) .collect::>(); let local_wires = &vars_base .local_wires .iter() .map(|w| F::Extension::from_basefield(*w)) .collect::>(); let vars = EvaluationVars { local_constants, local_wires, }; let values = self.eval_unfiltered(vars); // Each value should be in the base field, i.e. only the degree-zero part should be nonzero. values .into_iter() .map(|value| { // TODO: Change to debug-only once our gate code is mostly finished/stable. assert!(F::Extension::is_in_basefield(&value)); value.to_basefield_array()[0] }) .collect() } fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec>; fn eval_filtered(&self, mut vars: EvaluationVars, prefix: &[bool]) -> Vec { let filter = compute_filter(prefix, vars.local_constants); vars.remove_prefix(prefix); self.eval_unfiltered(vars) .into_iter() .map(|c| filter * c) .collect() } /// Like `eval_filtered`, but specialized for points in the base field. fn eval_filtered_base(&self, mut vars: EvaluationVarsBase, prefix: &[bool]) -> Vec { let filter = compute_filter(prefix, vars.local_constants); vars.remove_prefix(prefix); self.eval_unfiltered_base(vars) .into_iter() .map(|c| c * filter) .collect() } fn eval_filtered_recursively( &self, builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { // TODO: Filter self.eval_unfiltered_recursively(builder, vars) } fn generators( &self, gate_index: usize, local_constants: &[F], ) -> Vec>>; /// The number of wires used by this gate. fn num_wires(&self) -> usize; /// The number of constants used by this gate. fn num_constants(&self) -> usize; /// The maximum degree among this gate's constraint polynomials. fn degree(&self) -> usize; fn num_constraints(&self) -> usize; } /// A wrapper around an `Rc` which implements `PartialEq`, `Eq` and `Hash` based on gate IDs. #[derive(Clone)] pub struct GateRef, const D: usize>(pub(crate) Arc>); impl, const D: usize> GateRef { pub fn new>(gate: G) -> GateRef { GateRef(Arc::new(gate)) } } impl, const D: usize> PartialEq for GateRef { fn eq(&self, other: &Self) -> bool { self.0.id() == other.0.id() } } impl, const D: usize> Hash for GateRef { fn hash(&self, state: &mut H) { self.0.id().hash(state) } } impl, const D: usize> Eq for GateRef {} impl, const D: usize> Debug for GateRef { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { write!(f, "{}", self.0.id()) } } /// A gate along with any constants used to configure it. pub struct GateInstance, const D: usize> { pub gate_type: GateRef, pub constants: Vec, } /// Map each gate to a boolean prefix used to construct the gate's selector polynomial. #[derive(Debug, Clone)] pub struct PrefixedGate, const D: usize> { pub gate: GateRef, pub prefix: Vec, } impl, const D: usize> PrefixedGate { pub fn from_tree(tree: Tree>) -> Vec { tree.traversal() .into_iter() .map(|(gate, prefix)| PrefixedGate { gate, prefix }) .collect() } } /// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and /// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`. fn compute_filter(prefix: &[bool], constants: &[K]) -> K { prefix .iter() .enumerate() .map(|(i, &b)| { if b { constants[i] } else { K::ONE - constants[i] } }) .product() }