diff --git a/src/constraint_polynomial.rs b/src/constraint_polynomial.rs index 16da5926..12a44436 100644 --- a/src/constraint_polynomial.rs +++ b/src/constraint_polynomial.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet}; use std::hash::{Hash, Hasher}; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use std::ptr; +use std::{ptr, fmt}; use std::rc::Rc; use num::{BigUint, FromPrimitive, One, Zero}; @@ -11,6 +11,7 @@ use crate::field::field::Field; use crate::wire::Wire; use crate::gates::output_graph::GateOutputLocation; use std::borrow::Borrow; +use std::fmt::{Display, Formatter, Debug}; pub(crate) struct EvaluationVars<'a, F: Field> { pub(crate) local_constants: &'a [F], @@ -279,6 +280,12 @@ impl ConstraintPolynomial { } } +impl Display for ConstraintPolynomial { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + impl PartialEq for ConstraintPolynomial { fn eq(&self, other: &Self) -> bool { ptr::eq(&*self.0, &*other.0) @@ -480,6 +487,29 @@ pub(crate) enum ConstraintPolynomialInner { }, } +impl Display for ConstraintPolynomialInner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ConstraintPolynomialInner::Constant(c) => + write!(f, "{}", c), + ConstraintPolynomialInner::LocalConstant(i) => + write!(f, "local_const_{}", i), + ConstraintPolynomialInner::NextConstant(i) => + write!(f, "next_const_{}", i), + ConstraintPolynomialInner::LocalWireValue(i) => + write!(f, "local_wire_{}", i), + ConstraintPolynomialInner::NextWireValue(i) => + write!(f, "next_wire_{}", i), + ConstraintPolynomialInner::Sum { lhs, rhs } => + write!(f, "({} + {})", lhs, rhs), + ConstraintPolynomialInner::Product { lhs, rhs } => + write!(f, "({} * {})", lhs, rhs), + ConstraintPolynomialInner::Exponentiation { base, exponent } => + write!(f, "({} ^ {})", base, exponent), + } + } +} + impl ConstraintPolynomialInner { fn add_dependencies(&self, gate: usize, deps: &mut HashSet) { match self { diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 66ea5966..3fc143f0 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -1,4 +1,4 @@ -use std::fmt::{Debug, Formatter}; +use std::fmt::{Debug, Formatter, Display}; use std::fmt; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; @@ -29,9 +29,15 @@ impl PartialEq for CrandallField { impl Eq for CrandallField {} +impl Display for CrandallField { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + impl Debug for CrandallField { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + Debug::fmt(&self.0, f) } } diff --git a/src/field/field.rs b/src/field/field.rs index 4c3776e5..85cdce26 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -1,4 +1,4 @@ -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; /// A finite field with prime order less than 2^64. @@ -16,6 +16,7 @@ pub trait Field: 'static + Div + DivAssign + Debug ++ Display + Send + Sync { const ZERO: Self; diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index f21a47e5..cecc4b5e 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -123,10 +123,11 @@ mod tests { use crate::gates::deterministic_gate::DeterministicGate; #[test] + #[ignore] fn degree() { type F = CrandallField; const W: usize = 12; - const R: usize = 101; + const R: usize = 20; let gate = GMiMCGate:: { constants: Arc::new([F::TWO; R]) }; let config = CircuitConfig { num_wires: 200, diff --git a/src/gates/output_graph.rs b/src/gates/output_graph.rs index 846941f7..48f77dce 100644 --- a/src/gates/output_graph.rs +++ b/src/gates/output_graph.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; -use std::iter; +use std::{iter, fmt}; -use num::{BigUint, FromPrimitive, One}; +use num::{BigUint, FromPrimitive, One, ToPrimitive}; use crate::constraint_polynomial::ConstraintPolynomial; use crate::field::field::Field; +use std::fmt::{Display, Formatter}; /// Represents a set of deterministic gate outputs, expressed as polynomials over witness /// values. @@ -31,32 +32,39 @@ impl OutputGraph { let mut current_graph = self.clone(); - while current_graph.count_high_degree_polys(max_degree) > 0 { + 'shrinker: while current_graph.count_high_degree_polys(max_degree) > 0 { // Find polynomials with a degree between 2 and the max, inclusive. // These are candidates for becoming new wires. - let mut candidates = current_graph.degree_map().into_iter() - .filter(|(_poly, deg)| deg > &BigUint::one() && deg <= &max_degree_biguint) - .map(|(poly, _deg)| poly); + let degrees = current_graph.degree_map(); + let current_high_deg_count = current_graph.count_high_degree_polys(max_degree); + let mut candidate_degrees: Vec<(ConstraintPolynomial, usize)> = degrees + .iter() + .filter(|(poly, deg)| *deg > &BigUint::one() && *deg <= &max_degree_biguint) + .map(|(poly, deg)| (poly.clone(), deg.to_usize().unwrap())) + .collect(); + candidate_degrees.sort_unstable_by_key(|(poly, deg)| *deg); + candidate_degrees.reverse(); - // Pick the candidate that minimizes the number of high-degree polynomials in our graph. - // This is just a simple heuristic; it won't always give an optimal wire count. - let mut first = candidates.next().expect("No candidate; cannot reduce degree further"); - let mut leader_graph = current_graph.allocate_wire(first); - let mut leader_high_deg_count = leader_graph.count_high_degree_polys(max_degree); - - for candidate in candidates { - let candidate_graph = current_graph.allocate_wire(candidate); + for (poly, _deg) in &candidate_degrees { + let candidate_graph = current_graph.allocate_wire(poly.clone()); let candidate_high_deg_count = candidate_graph.count_high_degree_polys(max_degree); - if candidate_high_deg_count < leader_high_deg_count { - leader_graph = candidate_graph; - leader_high_deg_count = candidate_high_deg_count; + if candidate_high_deg_count < current_high_deg_count { + // println!("before {}", ¤t_graph); + // println!("after {}", &candidate_graph); + current_graph = candidate_graph; + println!("Reduced high degree polys to {}", candidate_high_deg_count); + continue 'shrinker; } } - // println!("before {:?}", current_graph); - // println!("after {:?}", leader_graph); - current_graph = leader_graph; - println!("{}", leader_high_deg_count); + println!("No good candidates; cannot reduce high degree polys"); + for (poly, _deg) in candidate_degrees { + let candidate_graph = current_graph.allocate_wire(poly); + current_graph = candidate_graph; + continue 'shrinker; + } + + panic!("No candidate; cannot make progress"); } current_graph @@ -101,6 +109,15 @@ impl OutputGraph { } } +impl Display for OutputGraph { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for (loc, out) in &self.outputs { + write!(f, "{} := {}, ", loc, out)?; + } + Ok(()) + } +} + /// Represents an output location of a deterministic gate. #[derive(Copy, Clone, Debug)] pub enum GateOutputLocation { @@ -110,12 +127,51 @@ pub enum GateOutputLocation { NextWire(usize), } +impl Display for GateOutputLocation { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + GateOutputLocation::LocalWire(i) => write!(f, "local_wire_{}", i), + GateOutputLocation::NextWire(i) => write!(f, "next_wire_{}", i), + } + } +} + #[cfg(test)] mod tests { use crate::constraint_polynomial::ConstraintPolynomial; use crate::field::crandall_field::CrandallField; use crate::gates::output_graph::{GateOutputLocation, OutputGraph}; + #[test] + fn shrink_mimc() { + // This is like a simplified version of GMiMC, for easy debugging. + type F = CrandallField; + let switch = ConstraintPolynomial::::local_wire_value(0); + let x = ConstraintPolynomial::::local_wire_value(1); + let y = ConstraintPolynomial::::local_wire_value(2); + + // deg 2 + let delta = &switch * (&y - &x); + let l0 = &x + δ + let r0 = &y - δ + let s0 = &l0 + &r0; + + // 2*3 + let l1 = s0.cube(); let r1 = r0.cube(); let s1 = &l1 + &r1; + // 2*3*3 + let l2 = s1.cube(); let r2 = r1.cube(); let s2 = &l2 + &r2; + // 2*3*3*3 + let l3 = s2.cube(); let r3 = r2.cube(); let s3 = &l3 + &r3; + + let og = OutputGraph { outputs: vec![ + (GateOutputLocation::NextWire(0), l3), + (GateOutputLocation::NextWire(1), r3), + ] }; + + let shrunk = og.shrink_degree(9); + assert_eq!(shrunk.max_wire_input_index(), Some(4)); + } + #[test] fn shrink_squaring_graph() { type F = CrandallField;