Merge pull request #71 from mir-protocol/gate_tree

Add gate tree, gate prefixes and filtered methods
This commit is contained in:
wborgeaud 2021-06-24 21:02:44 +02:00 committed by GitHub
commit fb89d637e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 358 additions and 29 deletions

View File

@ -11,7 +11,8 @@ use crate::field::cosets::get_unique_coset_shifts;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::gates::constant::ConstantGate;
use crate::gates::gate::{GateInstance, GateRef};
use crate::gates::gate::{GateInstance, GateRef, PrefixedGate};
use crate::gates::gate_tree::Tree;
use crate::gates::noop::NoopGate;
use crate::generator::{CopyGenerator, WitnessGenerator};
use crate::hash::hash_n_to_hash;
@ -229,22 +230,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
fn constant_polys(&self) -> Vec<PolynomialValues<F>> {
let num_constants = self
.gate_instances
fn constant_polys(&self, gates: &[PrefixedGate<F, D>]) -> Vec<PolynomialValues<F>> {
let num_constants = gates
.iter()
.map(|gate_inst| gate_inst.constants.len())
.map(|gate| gate.gate.0.num_constants() + gate.prefix.len())
.max()
.unwrap();
let constants_per_gate = self
.gate_instances
.iter()
.map(|gate_inst| {
let mut padded_constants = gate_inst.constants.clone();
for _ in padded_constants.len()..num_constants {
padded_constants.push(F::ZERO);
}
padded_constants
.map(|gate| {
let prefix = &gates
.iter()
.find(|g| g.gate.0.id() == gate.gate_type.0.id())
.unwrap()
.prefix;
let mut prefixed_constants = Vec::with_capacity(num_constants);
prefixed_constants.extend(prefix.iter().map(|&b| if b { F::ONE } else { F::ZERO }));
prefixed_constants.extend_from_slice(&gate.constants);
prefixed_constants.resize(num_constants, F::ZERO);
prefixed_constants
})
.collect::<Vec<_>>();
@ -288,7 +293,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let degree = self.gate_instances.len();
info!("degree after blinding & padding: {}", degree);
let constant_vecs = self.constant_polys();
let gates = self.gates.iter().cloned().collect();
let gate_tree = Tree::from_gates(gates);
let prefixed_gates = PrefixedGate::from_tree(gate_tree);
let constant_vecs = self.constant_polys(&prefixed_gates);
let constants_commitment = ListPolynomialCommitment::new(
constant_vecs.into_iter().map(|v| v.ifft()).collect(),
self.config.fri_config.rate_bits,
@ -338,7 +347,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let common = CommonCircuitData {
config: self.config,
degree_bits,
gates,
gates: prefixed_gates,
num_gate_constraints,
k_is,
circuit_digest,

View File

@ -3,7 +3,7 @@ use anyhow::Result;
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::fri::FriConfig;
use crate::gates::gate::{GateInstance, GateRef};
use crate::gates::gate::{GateInstance, PrefixedGate};
use crate::generator::WitnessGenerator;
use crate::polynomial::commitment::ListPolynomialCommitment;
use crate::proof::{Hash, HashTarget, Proof};
@ -141,8 +141,8 @@ pub(crate) struct CommonCircuitData<F: Extendable<D>, const D: usize> {
pub(crate) degree_bits: usize,
/// The types of gates used in this circuit.
pub(crate) gates: Vec<GateRef<F, D>>,
/// The types of gates used in this circuit, along with their prefixes.
pub(crate) gates: Vec<PrefixedGate<F, D>>,
/// The largest number of constraints imposed by any gate.
pub(crate) num_gate_constraints: usize,
@ -171,7 +171,7 @@ impl<F: Extendable<D>, const D: usize> CommonCircuitData<F, D> {
pub fn constraint_degree(&self) -> usize {
self.gates
.iter()
.map(|g| g.0.degree())
.map(|g| g.gate.0.degree())
.max()
.expect("No gates?")
}

View File

@ -1,6 +1,7 @@
use crate::field::field::Field;
use std::convert::TryInto;
use crate::field::field::Field;
pub mod algebra;
pub mod quadratic;
pub mod quartic;

View File

@ -1,9 +1,12 @@
use std::fmt::{Debug, Error, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::gates::gate_tree::Tree;
use crate::generator::WitnessGenerator;
use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
@ -51,15 +54,23 @@ pub trait Gate<F: Extendable<D>, const D: usize>: 'static + Send + Sync {
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>>;
fn eval_filtered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
// TODO: Filter
fn eval_filtered(&self, mut vars: EvaluationVars<F, D>, prefix: &[bool]) -> Vec<F::Extension> {
let filter = compute_filter(prefix, vars.local_constants);
vars.remove_prefix(prefix);
self.eval_unfiltered(vars)
.into_iter()
.map(|c| filter * c)
.collect()
}
/// Like `eval_filtered`, but specialized for points in the base field.
fn eval_filtered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
// TODO: Filter
fn eval_filtered_base(&self, mut vars: EvaluationVarsBase<F>, prefix: &[bool]) -> Vec<F> {
let filter = compute_filter(prefix, vars.local_constants);
vars.remove_prefix(prefix);
self.eval_unfiltered_base(vars)
.into_iter()
.map(|c| c * filter)
.collect()
}
fn eval_filtered_recursively(
@ -113,8 +124,46 @@ impl<F: Extendable<D>, const D: usize> Hash for GateRef<F, D> {
impl<F: Extendable<D>, const D: usize> Eq for GateRef<F, D> {}
impl<F: Extendable<D>, const D: usize> Debug for GateRef<F, D> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
write!(f, "{}", self.0.id())
}
}
/// A gate along with any constants used to configure it.
pub struct GateInstance<F: Extendable<D>, const D: usize> {
pub gate_type: GateRef<F, D>,
pub constants: Vec<F>,
}
/// Map each gate to a boolean prefix used to construct the gate's selector polynomial.
#[derive(Debug, Clone)]
pub struct PrefixedGate<F: Extendable<D>, const D: usize> {
pub gate: GateRef<F, D>,
pub prefix: Vec<bool>,
}
impl<F: Extendable<D>, const D: usize> PrefixedGate<F, D> {
pub fn from_tree(tree: Tree<GateRef<F, D>>) -> Vec<Self> {
tree.traversal()
.into_iter()
.map(|(gate, prefix)| PrefixedGate { gate, prefix })
.collect()
}
}
/// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and
/// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`.
fn compute_filter<K: Field>(prefix: &[bool], constants: &[K]) -> K {
prefix
.iter()
.enumerate()
.map(|(i, &b)| {
if b {
constants[i]
} else {
K::ONE - constants[i]
}
})
.product()
}

257
src/gates/gate_tree.rs Normal file
View File

@ -0,0 +1,257 @@
use log::info;
use crate::field::extension_field::Extendable;
use crate::gates::gate::GateRef;
/// A binary tree where leaves hold some type `T` and other nodes are empty.
#[derive(Debug, Clone)]
pub enum Tree<T> {
Leaf(T),
Bifurcation(Option<Box<Tree<T>>>, Option<Box<Tree<T>>>),
}
impl<T> Default for Tree<T> {
fn default() -> Self {
Self::Bifurcation(None, None)
}
}
impl<T: Clone> Tree<T> {
/// Traverse a tree using a depth-first traversal and collect data and position for each leaf.
/// A leaf's position is represented by its left/right path, where `false` means left and `true` means right.
pub fn traversal(&self) -> Vec<(T, Vec<bool>)> {
let mut res = Vec::new();
let prefix = [];
self.traverse(&prefix, &mut res);
res
}
/// Utility function to traverse the tree.
fn traverse(&self, prefix: &[bool], current: &mut Vec<(T, Vec<bool>)>) {
match &self {
// If node is a leaf, collect the data and position.
Tree::Leaf(t) => {
current.push((t.clone(), prefix.to_vec()));
}
// Otherwise, traverse the left subtree and then the right subtree.
Tree::Bifurcation(left, right) => {
if let Some(l) = left {
let mut left_prefix = prefix.to_vec();
left_prefix.push(false);
l.traverse(&left_prefix, current);
}
if let Some(r) = right {
let mut right_prefix = prefix.to_vec();
right_prefix.push(true);
r.traverse(&right_prefix, current);
}
}
}
}
}
impl<F: Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
/// The binary gate tree influences the degree `D` of the constraint polynomial and the number `C`
/// of constant wires in the circuit. We want to construct a tree minimizing both values. To do so
/// we iterate over possible values of `(D, C)` and try to construct a tree with these values.
/// For this construction, we use the greedy algorithm in `Self::find_tree`.
/// This latter function greedily adds gates at the depth where
/// `filtered_deg(gate)=D, constant_wires(gate)=C` to ensure no space is wasted.
/// We return the first tree found in this manner.
pub fn from_gates(mut gates: Vec<GateRef<F, D>>) -> Self {
let timer = std::time::Instant::now();
gates.sort_unstable_by_key(|g| (-(g.0.degree() as isize), -(g.0.num_constants() as isize)));
for max_degree_bits in 1..10 {
// The constraint polynomials are padded to the next power in `compute_vanishig_polys`.
// So we can restrict our search space by setting `max_degree` to a power of 2.
let max_degree = 1 << max_degree_bits;
for max_constants in 1..100 {
if let Some(mut tree) = Self::find_tree(&gates, max_degree, max_constants) {
tree.shorten();
info!(
"Found tree with max degree {} in {}s.",
max_degree,
timer.elapsed().as_secs_f32()
);
return tree;
}
}
}
panic!("Can't find a tree.")
}
/// Greedily add gates wherever possible. Returns `None` if this fails.
fn find_tree(gates: &[GateRef<F, D>], max_degree: usize, max_constants: usize) -> Option<Self> {
let mut tree = Tree::default();
for g in gates {
tree.try_add_gate(g, max_degree, max_constants)?;
}
Some(tree)
}
/// Try to add a gate in the tree. Returns `None` if this fails.
fn try_add_gate(
&mut self,
g: &GateRef<F, D>,
max_degree: usize,
max_constants: usize,
) -> Option<()> {
// We want `gate.degree + depth <= max_degree` and `gate.num_constants + depth <= max_wires`.
let depth = max_degree
.checked_sub(g.0.degree())?
.min(max_constants.checked_sub(g.0.num_constants())?);
self.try_add_gate_at_depth(g, depth)
}
/// Try to add a gate in the tree at a specified depth. Returns `None` if this fails.
fn try_add_gate_at_depth(&mut self, g: &GateRef<F, D>, depth: usize) -> Option<()> {
// If depth is 0, we have to insert the gate here.
if depth == 0 {
return if let Tree::Bifurcation(None, None) = self {
// Insert the gate as a new leaf.
*self = Tree::Leaf(g.clone());
Some(())
} else {
// A leaf is already here.
None
};
}
// A leaf is already here so we cannot go deeper.
if let Tree::Leaf(_) = self {
return None;
}
if let Tree::Bifurcation(left, right) = self {
if let Some(left) = left {
// Try to add the gate to the left if there's already a left subtree.
if left.try_add_gate_at_depth(g, depth - 1).is_some() {
return Some(());
}
} else {
// Add a new left subtree and try to add the gate to it.
let mut new_left = Tree::default();
if new_left.try_add_gate_at_depth(g, depth - 1).is_some() {
*left = Some(Box::new(new_left));
return Some(());
}
}
if let Some(right) = right {
// Try to add the gate to the right if there's already a right subtree.
if right.try_add_gate_at_depth(g, depth - 1).is_some() {
return Some(());
}
} else {
// Add a new right subtree and try to add the gate to it.
let mut new_right = Tree::default();
if new_right.try_add_gate_at_depth(g, depth - 1).is_some() {
*right = Some(Box::new(new_right));
return Some(());
}
}
}
None
}
/// `Self::find_tree` returns a tree where each gate has `F(gate)=M` (see `Self::from_gates` comment).
/// This can produce subtrees with more nodes than necessary. This function removes useless nodes,
/// i.e., nodes that have a left but no right subtree.
fn shorten(&mut self) {
if let Tree::Bifurcation(left, right) = self {
if let (Some(left), None) = (left, right) {
// If the node has a left but no right subtree, set the node to its (shortened) left subtree.
let mut new = *left.clone();
new.shorten();
*self = new;
}
}
if let Tree::Bifurcation(left, right) = self {
if let Some(left) = left {
// Shorten the left subtree if there is one.
left.shorten();
}
if let Some(right) = right {
// Shorten the right subtree if there is one.
right.shorten();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::field::crandall_field::CrandallField;
use crate::gates::arithmetic::ArithmeticGate;
use crate::gates::base_sum::BaseSumGate;
use crate::gates::constant::ConstantGate;
use crate::gates::gmimc::GMiMCGate;
use crate::gates::interpolation::InterpolationGate;
use crate::gates::mul_extension::MulExtensionGate;
use crate::gates::noop::NoopGate;
use crate::hash::GMIMC_ROUNDS;
#[test]
fn test_prefix_generation() {
env_logger::init();
type F = CrandallField;
const D: usize = 4;
let gates = vec![
NoopGate::get::<F, D>(),
ConstantGate::get(),
ArithmeticGate::new(),
BaseSumGate::<4>::new(4),
GMiMCGate::<F, D, GMIMC_ROUNDS>::with_automatic_constants(),
InterpolationGate::new(4),
MulExtensionGate::new(),
];
let len = gates.len();
let tree = Tree::from_gates(gates.clone());
let mut gates_with_prefix = tree.traversal();
for (g, p) in &gates_with_prefix {
info!(
"\nGate: {}, prefix: {:?}.\n\
Filtered constraint degree: {}, Num constant wires: {}",
&g.0.id()[..20.min(g.0.id().len())],
p,
g.0.degree() + p.len(),
g.0.num_constants() + p.len()
);
}
assert_eq!(
gates_with_prefix.len(),
gates.len(),
"The tree has too much or too little gates."
);
assert!(
gates
.iter()
.all(|g| gates_with_prefix.iter().map(|(gg, _)| gg).any(|gg| gg == g)),
"Some gates are not in the tree."
);
assert!(
gates_with_prefix
.iter()
.all(|(g, p)| g.0.degree() + g.0.num_constants() + p.len() <= 8),
"Total degree is larger than 8."
);
gates_with_prefix.sort_unstable_by_key(|(g, p)| p.len());
for i in 0..gates_with_prefix.len() {
for j in i + 1..gates_with_prefix.len() {
assert_ne!(
&gates_with_prefix[i].1,
&gates_with_prefix[j].1[0..gates_with_prefix[i].1.len()],
"Some gates share an overlapping prefix"
);
}
}
}
}

View File

@ -2,6 +2,7 @@ pub(crate) mod arithmetic;
pub mod base_sum;
pub mod constant;
pub(crate) mod gate;
pub mod gate_tree;
pub mod gmimc;
pub mod interpolation;
pub mod mul_extension;

View File

@ -5,7 +5,7 @@ use crate::circuit_data::CommonCircuitData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::gates::gate::GateRef;
use crate::gates::gate::{GateRef, PrefixedGate};
use crate::polynomial::commitment::SALT_SIZE;
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::target::Target;
@ -167,13 +167,13 @@ pub(crate) fn eval_vanishing_poly_base<F: Extendable<D>, const D: usize>(
/// strictly necessary, but it helps performance by ensuring that we allocate a vector with exactly
/// the capacity that we need.
pub fn evaluate_gate_constraints<F: Extendable<D>, const D: usize>(
gates: &[GateRef<F, D>],
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
vars: EvaluationVars<F, D>,
) -> Vec<F::Extension> {
let mut constraints = vec![F::Extension::ZERO; num_gate_constraints];
for gate in gates {
let gate_constraints = gate.0.eval_filtered(vars);
let gate_constraints = gate.gate.0.eval_filtered(vars, &gate.prefix);
for (i, c) in gate_constraints.into_iter().enumerate() {
debug_assert!(
i < num_gate_constraints,
@ -186,13 +186,13 @@ pub fn evaluate_gate_constraints<F: Extendable<D>, const D: usize>(
}
pub fn evaluate_gate_constraints_base<F: Extendable<D>, const D: usize>(
gates: &[GateRef<F, D>],
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
vars: EvaluationVarsBase<F>,
) -> Vec<F> {
let mut constraints = vec![F::ZERO; num_gate_constraints];
for gate in gates {
let gate_constraints = gate.0.eval_filtered_base(vars);
let gate_constraints = gate.gate.0.eval_filtered_base(vars, &gate.prefix);
for (i, c) in gate_constraints.into_iter().enumerate() {
debug_assert!(
i < num_gate_constraints,

View File

@ -1,6 +1,7 @@
use std::ops::Range;
use crate::circuit_data::CircuitConfig;
use crate::wire::Wire;
use std::ops::Range;
/// A location in the witness.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]

View File

@ -27,6 +27,16 @@ impl<'a, F: Extendable<D>, const D: usize> EvaluationVars<'a, F, D> {
let arr = self.local_wires[wire_range].try_into().unwrap();
ExtensionAlgebra::from_basefield_array(arr)
}
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len()..];
}
}
impl<'a, F: Field> EvaluationVarsBase<'a, F> {
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len()..];
}
}
#[derive(Copy, Clone)]

View File

@ -1,6 +1,7 @@
use crate::circuit_data::CircuitConfig;
use std::ops::Range;
use crate::circuit_data::CircuitConfig;
/// Represents a wire in the circuit.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Wire {