Degree shrinker

This commit is contained in:
Daniel Lubarov 2021-02-26 23:30:22 -08:00
parent 383812dffd
commit 5d6da4f94a
3 changed files with 246 additions and 45 deletions

View File

@ -10,6 +10,7 @@ use num::{BigUint, FromPrimitive, One, Zero};
use crate::field::field::Field;
use crate::wire::Wire;
use crate::gates::output_graph::GateOutputLocation;
use std::borrow::Borrow;
pub(crate) struct EvaluationVars<'a, F: Field> {
pub(crate) local_constants: &'a [F],
@ -25,7 +26,7 @@ pub(crate) struct EvaluationVars<'a, F: Field> {
/// This type implements `Hash` and `Eq` based on references rather
/// than content. This is useful when we want to use constraint polynomials as `HashMap` keys, but
/// we want address-based hashing for performance reasons.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct ConstraintPolynomial<F: Field>(pub(crate) Rc<ConstraintPolynomialInner<F>>);
impl<F: Field> ConstraintPolynomial<F> {
@ -214,12 +215,68 @@ impl<F: Field> ConstraintPolynomial<F> {
from: Self,
to: Self,
) -> Self {
Self::from_inner(self.0.replace_all(from, to))
self.replace_all_helper(from, to, &mut HashMap::new())
}
/// Replace all occurrences of `from` with `to` in this polynomial graph. In order to preserve
/// the structure of the graph, we keep track of any `ConstraintPolynomial`s that have been
/// replaced already.
fn replace_all_helper(
&self, from: Self,
to: Self,
replacements: &mut HashMap<Self, Self>,
) -> Self {
if *self == from {
return to;
}
if let Some(replacement) = replacements.get(self) {
return replacement.clone();
}
match self.0.borrow() {
ConstraintPolynomialInner::Constant(_) => self.clone(),
ConstraintPolynomialInner::LocalConstant(_) => self.clone(),
ConstraintPolynomialInner::NextConstant(_) => self.clone(),
ConstraintPolynomialInner::LocalWireValue(_) => self.clone(),
ConstraintPolynomialInner::NextWireValue(_) => self.clone(),
ConstraintPolynomialInner::Sum { lhs, rhs } => {
let lhs = lhs.replace_all_helper(from.clone(), to.clone(), replacements);
let rhs = rhs.replace_all_helper(from, to, replacements);
let replacement = Self::from_inner(ConstraintPolynomialInner::Sum { lhs, rhs });
debug_assert!(!replacements.contains_key(self));
replacements.insert(self.clone(), replacement.clone());
replacement
}
ConstraintPolynomialInner::Product { lhs, rhs } => {
let lhs = lhs.replace_all_helper(from.clone(), to.clone(), replacements);
let rhs = rhs.replace_all_helper(from, to, replacements);
let replacement = Self::from_inner(ConstraintPolynomialInner::Product { lhs, rhs });
debug_assert!(!replacements.contains_key(self));
replacements.insert(self.clone(), replacement.clone());
replacement
}
ConstraintPolynomialInner::Exponentiation { base, exponent } => {
let base = base.replace_all_helper(from, to, replacements);
let replacement = Self::from_inner(
ConstraintPolynomialInner::Exponentiation { base, exponent: *exponent });
debug_assert!(!replacements.contains_key(self));
replacements.insert(self.clone(), replacement.clone());
replacement
}
}
}
fn from_inner(inner: ConstraintPolynomialInner<F>) -> Self {
Self(Rc::new(inner))
}
/// The number of polynomials in this graph.
fn graph_size(&self) -> usize {
let mut degrees = HashMap::new();
self.populate_degree_map(&mut degrees);
degrees.len()
}
}
impl<F: Field> PartialEq for ConstraintPolynomial<F> {
@ -400,6 +457,7 @@ impl<F: Field> Product for ConstraintPolynomial<F> {
}
}
#[derive(Clone, Debug)]
pub(crate) enum ConstraintPolynomialInner<F: Field> {
Constant(F),
@ -429,41 +487,41 @@ impl<F: Field> ConstraintPolynomialInner<F> {
ConstraintPolynomialInner::LocalConstant(_) => (),
ConstraintPolynomialInner::NextConstant(_) => (),
ConstraintPolynomialInner::LocalWireValue(i) =>
{ deps.insert(Wire { gate, input: *i }); },
{ deps.insert(Wire { gate, input: *i }); }
ConstraintPolynomialInner::NextWireValue(i) =>
{ deps.insert(Wire { gate: gate + 1, input: *i }); }
ConstraintPolynomialInner::Sum { lhs, rhs } => {
lhs.0.add_dependencies(gate, deps);
rhs.0.add_dependencies(gate, deps);
},
}
ConstraintPolynomialInner::Product { lhs, rhs } => {
lhs.0.add_dependencies(gate, deps);
rhs.0.add_dependencies(gate, deps);
},
}
ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => {
base.0.add_dependencies(gate, deps);
},
}
}
}
fn add_constant_indices(&self, indices: &mut HashSet<usize>) {
match self {
ConstraintPolynomialInner::Constant(_) => (),
ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); },
ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); },
ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); }
ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); }
ConstraintPolynomialInner::LocalWireValue(_) => (),
ConstraintPolynomialInner::NextWireValue(_) => (),
ConstraintPolynomialInner::Sum { lhs, rhs } => {
lhs.0.add_constant_indices(indices);
rhs.0.add_constant_indices(indices);
},
}
ConstraintPolynomialInner::Product { lhs, rhs } => {
lhs.0.add_constant_indices(indices);
rhs.0.add_constant_indices(indices);
},
}
ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => {
base.0.add_constant_indices(indices);
},
}
}
}
@ -482,16 +540,16 @@ impl<F: Field> ConstraintPolynomialInner<F> {
let lhs = lhs.evaluate_memoized(vars, mem);
let rhs = rhs.evaluate_memoized(vars, mem);
lhs + rhs
},
}
ConstraintPolynomialInner::Product { lhs, rhs } => {
let lhs = lhs.evaluate_memoized(vars, mem);
let rhs = rhs.evaluate_memoized(vars, mem);
lhs * rhs
},
}
ConstraintPolynomialInner::Exponentiation { base, exponent } => {
let base = base.evaluate_memoized(vars, mem);
base.exp_usize(*exponent)
},
}
}
}
@ -508,12 +566,52 @@ impl<F: Field> ConstraintPolynomialInner<F> {
base.0.degree() * BigUint::from_usize(*exponent).unwrap(),
}
}
}
fn replace_all(
&self,
from: ConstraintPolynomial<F>,
to: ConstraintPolynomial<F>,
) -> Self {
todo!()
#[cfg(test)]
mod tests {
use crate::constraint_polynomial::ConstraintPolynomial;
use crate::field::crandall_field::CrandallField;
#[test]
fn equality() {
type F = CrandallField;
let wire0 = ConstraintPolynomial::<F>::local_wire_value(0);
// == should compare the pointers, and the clone should point to the same underlying
// ConstraintPolynomialInner.
assert_eq!(wire0.clone(), wire0);
}
#[test]
fn replace_all() {
type F = CrandallField;
let wire0 = ConstraintPolynomial::<F>::local_wire_value(0);
let wire1 = ConstraintPolynomial::<F>::local_wire_value(1);
let wire2 = ConstraintPolynomial::<F>::local_wire_value(2);
let wire3 = ConstraintPolynomial::<F>::local_wire_value(3);
let wire4 = ConstraintPolynomial::<F>::local_wire_value(4);
let sum01 = &wire0 + &wire1;
let sum12 = &wire1 + &wire2;
let sum23 = &wire2 + &wire3;
let product = &sum01 * &sum12 * &sum23;
assert_eq!(
wire0.replace_all(wire0.clone(), wire1.clone()),
wire1);
assert_eq!(
wire0.replace_all(wire1.clone(), wire2.clone()),
wire0);
// This should be a no-op, since wire 4 is not present in the product.
assert_eq!(
product.replace_all(wire4.clone(), wire3.clone()).graph_size(),
product.graph_size());
// This shouldn't change the graph structure at all, since the replacement wire 4 was not
// previously present.
assert_eq!(
product.replace_all(wire3.clone(), wire4.clone()).graph_size(),
product.graph_size());
}
}

View File

@ -99,7 +99,10 @@ impl<F: Field, const W: usize, const R: usize> DeterministicGate<F> for GMiMCGat
// A degree of 9 is reasonable for most circuits, and it means that we only need wires for
// every other addition buffer state.
OutputGraph { outputs }.shrink_degree(9)
println!("before");
let out = OutputGraph { outputs }.shrink_degree(9);
println!("after");
out
}
fn additional_constraints(&self, _config: CircuitConfig) -> Vec<ConstraintPolynomial<F>> {
@ -108,3 +111,29 @@ impl<F: Field, const W: usize, const R: usize> DeterministicGate<F> for GMiMCGat
vec![switch_bool_constraint]
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::field::crandall_field::CrandallField;
use crate::gates::gmimc::GMiMCGate;
use crate::field::field::Field;
use crate::circuit_data::CircuitConfig;
use crate::gates::deterministic_gate::DeterministicGate;
#[test]
fn degree() {
type F = CrandallField;
const W: usize = 12;
const R: usize = 101;
let gate = GMiMCGate::<F, W, R> { constants: Arc::new([F::TWO; R]) };
let config = CircuitConfig {
num_wires: 200,
num_routed_wires: 200,
security_bits: 128
};
let outs = gate.outputs(config);
assert_eq!(outs.max_wire_input_index(), Some(50));
}
}

View File

@ -1,25 +1,18 @@
use std::collections::HashMap;
use std::iter;
use crate::constraint_polynomial::{ConstraintPolynomial};
use num::{BigUint, FromPrimitive, One};
use crate::constraint_polynomial::ConstraintPolynomial;
use crate::field::field::Field;
use std::collections::HashMap;
use num::BigUint;
/// Represents a set of deterministic gate outputs, expressed as polynomials over witness
/// values.
#[derive(Clone, Debug)]
pub struct OutputGraph<F: Field> {
pub(crate) outputs: Vec<(GateOutputLocation, ConstraintPolynomial<F>)>
}
/// Represents an output location of a deterministic gate.
#[derive(Copy, Clone)]
pub enum GateOutputLocation {
/// A wire belonging to the gate itself.
LocalWire(usize),
/// A wire belonging to the following gate.
NextWire(usize),
}
impl<F: Field> OutputGraph<F> {
/// Creates an output graph with a single output.
pub fn single_output(loc: GateOutputLocation, out: ConstraintPolynomial<F>) -> Self {
@ -31,10 +24,50 @@ impl<F: Field> OutputGraph<F> {
///
/// Note that this uses a simple greedy algorithm, so the result may not be optimal in terms of wire
/// count.
// TODO: This doesn't yet work with large exponentiations, i.e. x^n where n > new_degree. Not an
// TODO: This doesn't yet work with large exponentiations, i.e. x^n where n > max_degree. Not an
// immediate problem since our gates don't use those.
pub fn shrink_degree(&self, new_degree: usize) -> Self {
todo!()
pub fn shrink_degree(&self, max_degree: usize) -> Self {
let max_degree_biguint = BigUint::from_usize(max_degree).unwrap();
let mut current_graph = self.clone();
while current_graph.count_high_degree_polys(max_degree) > 0 {
// Find polynomials with a degree between 2 and the max, inclusive.
// These are candidates for becoming new wires.
let mut candidates = current_graph.degree_map().into_iter()
.filter(|(_poly, deg)| deg > &BigUint::one() && deg <= &max_degree_biguint)
.map(|(poly, _deg)| poly);
// Pick the candidate that minimizes the number of high-degree polynomials in our graph.
// This is just a simple heuristic; it won't always give an optimal wire count.
let mut first = candidates.next().expect("No candidate; cannot reduce degree further");
let mut leader_graph = current_graph.allocate_wire(first);
let mut leader_high_deg_count = leader_graph.count_high_degree_polys(max_degree);
for candidate in candidates {
let candidate_graph = current_graph.allocate_wire(candidate);
let candidate_high_deg_count = candidate_graph.count_high_degree_polys(max_degree);
if candidate_high_deg_count < leader_high_deg_count {
leader_graph = candidate_graph;
leader_high_deg_count = candidate_high_deg_count;
}
}
// println!("before {:?}", current_graph);
// println!("after {:?}", leader_graph);
current_graph = leader_graph;
println!("{}", leader_high_deg_count);
}
current_graph
}
/// The number of polynomials in this graph which exceed the given maximum degree.
fn count_high_degree_polys(&self, max_degree: usize) -> usize {
let max_degree = BigUint::from_usize(max_degree).unwrap();
self.degree_map().into_iter()
.filter(|(_poly, deg)| deg > &max_degree)
.count()
}
fn degree_map(&self) -> HashMap<ConstraintPolynomial<F>, BigUint> {
@ -45,12 +78,17 @@ impl<F: Field> OutputGraph<F> {
degrees
}
/// The largest local wire index in this entire graph.
pub(crate) fn max_wire_input_index(&self) -> Option<usize> {
self.outputs.iter()
.flat_map(|(loc, out)| out.max_wire_input_index())
.max()
}
/// Allocate a new wire for the given target polynomial, and return a new output graph with
/// references to the target polynomial replaced with references to that wire.
fn allocate_wire(&self, target: ConstraintPolynomial<F>) -> Self {
let new_wire_index = self.outputs.iter()
.flat_map(|(loc, out)| out.max_wire_input_index())
.max()
let new_wire_index = self.max_wire_input_index()
.map_or(0, |i| i + 1);
let new_wire = ConstraintPolynomial::local_wire_value(new_wire_index);
@ -63,16 +101,52 @@ impl<F: Field> OutputGraph<F> {
}
}
/// Represents an output location of a deterministic gate.
#[derive(Copy, Clone, Debug)]
pub enum GateOutputLocation {
/// A wire belonging to the gate itself.
LocalWire(usize),
/// A wire belonging to the following gate.
NextWire(usize),
}
#[cfg(test)]
mod tests {
use crate::constraint_polynomial::ConstraintPolynomial;
use crate::gates::output_graph::shrink_degree;
use crate::field::crandall_field::CrandallField;
use crate::gates::output_graph::{GateOutputLocation, OutputGraph};
#[test]
fn shrink_exp() {
let original = ConstraintPolynomial::local_wire_value(0).exp(10);
let shrunk = shrink_degree(original, 3);
// `shrunk` should be something similar to (wire0^3)^3 * wire0.
assert_eq!(shrunk.max_wire_input_index(), Some(2))
fn shrink_squaring_graph() {
type F = CrandallField;
let deg1 = ConstraintPolynomial::<F>::local_wire_value(0);
let deg2 = deg1.square();
let deg4 = deg2.square();
let deg8 = deg4.square();
let deg16 = deg8.square();
let original = OutputGraph::single_output(
GateOutputLocation::NextWire(0),
deg16);
let degree_map = original.degree_map();
assert_eq!(degree_map.len(), 5);
assert_eq!(original.count_high_degree_polys(2), 3);
assert_eq!(original.count_high_degree_polys(3), 3);
assert_eq!(original.count_high_degree_polys(4), 2);
let shrunk_deg_2 = original.shrink_degree(2);
let shrunk_deg_3 = original.shrink_degree(3);
let shrunk_deg_4 = original.shrink_degree(4);
// `shrunk_deg_2` should have an intermediate wire for deg2, deg4, and deg8.
assert_eq!(shrunk_deg_2.max_wire_input_index(), Some(3));
// `shrunk_deg_3` should also have an intermediate wire for deg2, deg4, and deg8.
assert_eq!(shrunk_deg_3.max_wire_input_index(), Some(3));
// `shrunk_deg_4` should have an intermediate wire for deg4 only.
assert_eq!(shrunk_deg_4.max_wire_input_index(), Some(1));
}
}