diff --git a/src/constraint_polynomial.rs b/src/constraint_polynomial.rs index 40e3f2f6..16da5926 100644 --- a/src/constraint_polynomial.rs +++ b/src/constraint_polynomial.rs @@ -10,6 +10,7 @@ use num::{BigUint, FromPrimitive, One, Zero}; use crate::field::field::Field; use crate::wire::Wire; use crate::gates::output_graph::GateOutputLocation; +use std::borrow::Borrow; pub(crate) struct EvaluationVars<'a, F: Field> { pub(crate) local_constants: &'a [F], @@ -25,7 +26,7 @@ pub(crate) struct EvaluationVars<'a, F: Field> { /// This type implements `Hash` and `Eq` based on references rather /// than content. This is useful when we want to use constraint polynomials as `HashMap` keys, but /// we want address-based hashing for performance reasons. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ConstraintPolynomial(pub(crate) Rc>); impl ConstraintPolynomial { @@ -214,12 +215,68 @@ impl ConstraintPolynomial { from: Self, to: Self, ) -> Self { - Self::from_inner(self.0.replace_all(from, to)) + self.replace_all_helper(from, to, &mut HashMap::new()) + } + + /// Replace all occurrences of `from` with `to` in this polynomial graph. In order to preserve + /// the structure of the graph, we keep track of any `ConstraintPolynomial`s that have been + /// replaced already. + fn replace_all_helper( + &self, from: Self, + to: Self, + replacements: &mut HashMap, + ) -> Self { + if *self == from { + return to; + } + + if let Some(replacement) = replacements.get(self) { + return replacement.clone(); + } + + match self.0.borrow() { + ConstraintPolynomialInner::Constant(_) => self.clone(), + ConstraintPolynomialInner::LocalConstant(_) => self.clone(), + ConstraintPolynomialInner::NextConstant(_) => self.clone(), + ConstraintPolynomialInner::LocalWireValue(_) => self.clone(), + ConstraintPolynomialInner::NextWireValue(_) => self.clone(), + ConstraintPolynomialInner::Sum { lhs, rhs } => { + let lhs = lhs.replace_all_helper(from.clone(), to.clone(), replacements); + let rhs = rhs.replace_all_helper(from, to, replacements); + let replacement = Self::from_inner(ConstraintPolynomialInner::Sum { lhs, rhs }); + debug_assert!(!replacements.contains_key(self)); + replacements.insert(self.clone(), replacement.clone()); + replacement + } + ConstraintPolynomialInner::Product { lhs, rhs } => { + let lhs = lhs.replace_all_helper(from.clone(), to.clone(), replacements); + let rhs = rhs.replace_all_helper(from, to, replacements); + let replacement = Self::from_inner(ConstraintPolynomialInner::Product { lhs, rhs }); + debug_assert!(!replacements.contains_key(self)); + replacements.insert(self.clone(), replacement.clone()); + replacement + } + ConstraintPolynomialInner::Exponentiation { base, exponent } => { + let base = base.replace_all_helper(from, to, replacements); + let replacement = Self::from_inner( + ConstraintPolynomialInner::Exponentiation { base, exponent: *exponent }); + debug_assert!(!replacements.contains_key(self)); + replacements.insert(self.clone(), replacement.clone()); + replacement + } + } } fn from_inner(inner: ConstraintPolynomialInner) -> Self { Self(Rc::new(inner)) } + + /// The number of polynomials in this graph. + fn graph_size(&self) -> usize { + let mut degrees = HashMap::new(); + self.populate_degree_map(&mut degrees); + degrees.len() + } } impl PartialEq for ConstraintPolynomial { @@ -400,6 +457,7 @@ impl Product for ConstraintPolynomial { } } +#[derive(Clone, Debug)] pub(crate) enum ConstraintPolynomialInner { Constant(F), @@ -429,41 +487,41 @@ impl ConstraintPolynomialInner { ConstraintPolynomialInner::LocalConstant(_) => (), ConstraintPolynomialInner::NextConstant(_) => (), ConstraintPolynomialInner::LocalWireValue(i) => - { deps.insert(Wire { gate, input: *i }); }, + { deps.insert(Wire { gate, input: *i }); } ConstraintPolynomialInner::NextWireValue(i) => { deps.insert(Wire { gate: gate + 1, input: *i }); } ConstraintPolynomialInner::Sum { lhs, rhs } => { lhs.0.add_dependencies(gate, deps); rhs.0.add_dependencies(gate, deps); - }, + } ConstraintPolynomialInner::Product { lhs, rhs } => { lhs.0.add_dependencies(gate, deps); rhs.0.add_dependencies(gate, deps); - }, + } ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => { base.0.add_dependencies(gate, deps); - }, + } } } fn add_constant_indices(&self, indices: &mut HashSet) { match self { ConstraintPolynomialInner::Constant(_) => (), - ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); }, - ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); }, + ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); } + ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); } ConstraintPolynomialInner::LocalWireValue(_) => (), ConstraintPolynomialInner::NextWireValue(_) => (), ConstraintPolynomialInner::Sum { lhs, rhs } => { lhs.0.add_constant_indices(indices); rhs.0.add_constant_indices(indices); - }, + } ConstraintPolynomialInner::Product { lhs, rhs } => { lhs.0.add_constant_indices(indices); rhs.0.add_constant_indices(indices); - }, + } ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => { base.0.add_constant_indices(indices); - }, + } } } @@ -482,16 +540,16 @@ impl ConstraintPolynomialInner { let lhs = lhs.evaluate_memoized(vars, mem); let rhs = rhs.evaluate_memoized(vars, mem); lhs + rhs - }, + } ConstraintPolynomialInner::Product { lhs, rhs } => { let lhs = lhs.evaluate_memoized(vars, mem); let rhs = rhs.evaluate_memoized(vars, mem); lhs * rhs - }, + } ConstraintPolynomialInner::Exponentiation { base, exponent } => { let base = base.evaluate_memoized(vars, mem); base.exp_usize(*exponent) - }, + } } } @@ -508,12 +566,52 @@ impl ConstraintPolynomialInner { base.0.degree() * BigUint::from_usize(*exponent).unwrap(), } } +} - fn replace_all( - &self, - from: ConstraintPolynomial, - to: ConstraintPolynomial, - ) -> Self { - todo!() +#[cfg(test)] +mod tests { + use crate::constraint_polynomial::ConstraintPolynomial; + use crate::field::crandall_field::CrandallField; + + #[test] + fn equality() { + type F = CrandallField; + let wire0 = ConstraintPolynomial::::local_wire_value(0); + // == should compare the pointers, and the clone should point to the same underlying + // ConstraintPolynomialInner. + assert_eq!(wire0.clone(), wire0); + } + + #[test] + fn replace_all() { + type F = CrandallField; + let wire0 = ConstraintPolynomial::::local_wire_value(0); + let wire1 = ConstraintPolynomial::::local_wire_value(1); + let wire2 = ConstraintPolynomial::::local_wire_value(2); + let wire3 = ConstraintPolynomial::::local_wire_value(3); + let wire4 = ConstraintPolynomial::::local_wire_value(4); + let sum01 = &wire0 + &wire1; + let sum12 = &wire1 + &wire2; + let sum23 = &wire2 + &wire3; + let product = &sum01 * &sum12 * &sum23; + + assert_eq!( + wire0.replace_all(wire0.clone(), wire1.clone()), + wire1); + + assert_eq!( + wire0.replace_all(wire1.clone(), wire2.clone()), + wire0); + + // This should be a no-op, since wire 4 is not present in the product. + assert_eq!( + product.replace_all(wire4.clone(), wire3.clone()).graph_size(), + product.graph_size()); + + // This shouldn't change the graph structure at all, since the replacement wire 4 was not + // previously present. + assert_eq!( + product.replace_all(wire3.clone(), wire4.clone()).graph_size(), + product.graph_size()); } } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 5b36ca4a..f21a47e5 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -99,7 +99,10 @@ impl DeterministicGate for GMiMCGat // A degree of 9 is reasonable for most circuits, and it means that we only need wires for // every other addition buffer state. - OutputGraph { outputs }.shrink_degree(9) + println!("before"); + let out = OutputGraph { outputs }.shrink_degree(9); + println!("after"); + out } fn additional_constraints(&self, _config: CircuitConfig) -> Vec> { @@ -108,3 +111,29 @@ impl DeterministicGate for GMiMCGat vec![switch_bool_constraint] } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::field::crandall_field::CrandallField; + use crate::gates::gmimc::GMiMCGate; + use crate::field::field::Field; + use crate::circuit_data::CircuitConfig; + use crate::gates::deterministic_gate::DeterministicGate; + + #[test] + fn degree() { + type F = CrandallField; + const W: usize = 12; + const R: usize = 101; + let gate = GMiMCGate:: { constants: Arc::new([F::TWO; R]) }; + let config = CircuitConfig { + num_wires: 200, + num_routed_wires: 200, + security_bits: 128 + }; + let outs = gate.outputs(config); + assert_eq!(outs.max_wire_input_index(), Some(50)); + } +} diff --git a/src/gates/output_graph.rs b/src/gates/output_graph.rs index c4f7d30a..846941f7 100644 --- a/src/gates/output_graph.rs +++ b/src/gates/output_graph.rs @@ -1,25 +1,18 @@ +use std::collections::HashMap; use std::iter; -use crate::constraint_polynomial::{ConstraintPolynomial}; +use num::{BigUint, FromPrimitive, One}; + +use crate::constraint_polynomial::ConstraintPolynomial; use crate::field::field::Field; -use std::collections::HashMap; -use num::BigUint; /// Represents a set of deterministic gate outputs, expressed as polynomials over witness /// values. +#[derive(Clone, Debug)] pub struct OutputGraph { pub(crate) outputs: Vec<(GateOutputLocation, ConstraintPolynomial)> } -/// Represents an output location of a deterministic gate. -#[derive(Copy, Clone)] -pub enum GateOutputLocation { - /// A wire belonging to the gate itself. - LocalWire(usize), - /// A wire belonging to the following gate. - NextWire(usize), -} - impl OutputGraph { /// Creates an output graph with a single output. pub fn single_output(loc: GateOutputLocation, out: ConstraintPolynomial) -> Self { @@ -31,10 +24,50 @@ impl OutputGraph { /// /// Note that this uses a simple greedy algorithm, so the result may not be optimal in terms of wire /// count. - // TODO: This doesn't yet work with large exponentiations, i.e. x^n where n > new_degree. Not an + // TODO: This doesn't yet work with large exponentiations, i.e. x^n where n > max_degree. Not an // immediate problem since our gates don't use those. - pub fn shrink_degree(&self, new_degree: usize) -> Self { - todo!() + pub fn shrink_degree(&self, max_degree: usize) -> Self { + let max_degree_biguint = BigUint::from_usize(max_degree).unwrap(); + + let mut current_graph = self.clone(); + + while current_graph.count_high_degree_polys(max_degree) > 0 { + // Find polynomials with a degree between 2 and the max, inclusive. + // These are candidates for becoming new wires. + let mut candidates = current_graph.degree_map().into_iter() + .filter(|(_poly, deg)| deg > &BigUint::one() && deg <= &max_degree_biguint) + .map(|(poly, _deg)| poly); + + // Pick the candidate that minimizes the number of high-degree polynomials in our graph. + // This is just a simple heuristic; it won't always give an optimal wire count. + let mut first = candidates.next().expect("No candidate; cannot reduce degree further"); + let mut leader_graph = current_graph.allocate_wire(first); + let mut leader_high_deg_count = leader_graph.count_high_degree_polys(max_degree); + + for candidate in candidates { + let candidate_graph = current_graph.allocate_wire(candidate); + let candidate_high_deg_count = candidate_graph.count_high_degree_polys(max_degree); + if candidate_high_deg_count < leader_high_deg_count { + leader_graph = candidate_graph; + leader_high_deg_count = candidate_high_deg_count; + } + } + + // println!("before {:?}", current_graph); + // println!("after {:?}", leader_graph); + current_graph = leader_graph; + println!("{}", leader_high_deg_count); + } + + current_graph + } + + /// The number of polynomials in this graph which exceed the given maximum degree. + fn count_high_degree_polys(&self, max_degree: usize) -> usize { + let max_degree = BigUint::from_usize(max_degree).unwrap(); + self.degree_map().into_iter() + .filter(|(_poly, deg)| deg > &max_degree) + .count() } fn degree_map(&self) -> HashMap, BigUint> { @@ -45,12 +78,17 @@ impl OutputGraph { degrees } + /// The largest local wire index in this entire graph. + pub(crate) fn max_wire_input_index(&self) -> Option { + self.outputs.iter() + .flat_map(|(loc, out)| out.max_wire_input_index()) + .max() + } + /// Allocate a new wire for the given target polynomial, and return a new output graph with /// references to the target polynomial replaced with references to that wire. fn allocate_wire(&self, target: ConstraintPolynomial) -> Self { - let new_wire_index = self.outputs.iter() - .flat_map(|(loc, out)| out.max_wire_input_index()) - .max() + let new_wire_index = self.max_wire_input_index() .map_or(0, |i| i + 1); let new_wire = ConstraintPolynomial::local_wire_value(new_wire_index); @@ -63,16 +101,52 @@ impl OutputGraph { } } +/// Represents an output location of a deterministic gate. +#[derive(Copy, Clone, Debug)] +pub enum GateOutputLocation { + /// A wire belonging to the gate itself. + LocalWire(usize), + /// A wire belonging to the following gate. + NextWire(usize), +} + #[cfg(test)] mod tests { use crate::constraint_polynomial::ConstraintPolynomial; - use crate::gates::output_graph::shrink_degree; + use crate::field::crandall_field::CrandallField; + use crate::gates::output_graph::{GateOutputLocation, OutputGraph}; #[test] - fn shrink_exp() { - let original = ConstraintPolynomial::local_wire_value(0).exp(10); - let shrunk = shrink_degree(original, 3); - // `shrunk` should be something similar to (wire0^3)^3 * wire0. - assert_eq!(shrunk.max_wire_input_index(), Some(2)) + fn shrink_squaring_graph() { + type F = CrandallField; + let deg1 = ConstraintPolynomial::::local_wire_value(0); + let deg2 = deg1.square(); + let deg4 = deg2.square(); + let deg8 = deg4.square(); + let deg16 = deg8.square(); + + let original = OutputGraph::single_output( + GateOutputLocation::NextWire(0), + deg16); + + let degree_map = original.degree_map(); + assert_eq!(degree_map.len(), 5); + + assert_eq!(original.count_high_degree_polys(2), 3); + assert_eq!(original.count_high_degree_polys(3), 3); + assert_eq!(original.count_high_degree_polys(4), 2); + + let shrunk_deg_2 = original.shrink_degree(2); + let shrunk_deg_3 = original.shrink_degree(3); + let shrunk_deg_4 = original.shrink_degree(4); + + // `shrunk_deg_2` should have an intermediate wire for deg2, deg4, and deg8. + assert_eq!(shrunk_deg_2.max_wire_input_index(), Some(3)); + + // `shrunk_deg_3` should also have an intermediate wire for deg2, deg4, and deg8. + assert_eq!(shrunk_deg_3.max_wire_input_index(), Some(3)); + + // `shrunk_deg_4` should have an intermediate wire for deg4 only. + assert_eq!(shrunk_deg_4.max_wire_input_index(), Some(1)); } }