mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-21 23:23:13 +00:00
Degree shrinker
This commit is contained in:
parent
383812dffd
commit
5d6da4f94a
@ -10,6 +10,7 @@ use num::{BigUint, FromPrimitive, One, Zero};
|
||||
use crate::field::field::Field;
|
||||
use crate::wire::Wire;
|
||||
use crate::gates::output_graph::GateOutputLocation;
|
||||
use std::borrow::Borrow;
|
||||
|
||||
pub(crate) struct EvaluationVars<'a, F: Field> {
|
||||
pub(crate) local_constants: &'a [F],
|
||||
@ -25,7 +26,7 @@ pub(crate) struct EvaluationVars<'a, F: Field> {
|
||||
/// This type implements `Hash` and `Eq` based on references rather
|
||||
/// than content. This is useful when we want to use constraint polynomials as `HashMap` keys, but
|
||||
/// we want address-based hashing for performance reasons.
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConstraintPolynomial<F: Field>(pub(crate) Rc<ConstraintPolynomialInner<F>>);
|
||||
|
||||
impl<F: Field> ConstraintPolynomial<F> {
|
||||
@ -214,12 +215,68 @@ impl<F: Field> ConstraintPolynomial<F> {
|
||||
from: Self,
|
||||
to: Self,
|
||||
) -> Self {
|
||||
Self::from_inner(self.0.replace_all(from, to))
|
||||
self.replace_all_helper(from, to, &mut HashMap::new())
|
||||
}
|
||||
|
||||
/// Replace all occurrences of `from` with `to` in this polynomial graph. In order to preserve
|
||||
/// the structure of the graph, we keep track of any `ConstraintPolynomial`s that have been
|
||||
/// replaced already.
|
||||
fn replace_all_helper(
|
||||
&self, from: Self,
|
||||
to: Self,
|
||||
replacements: &mut HashMap<Self, Self>,
|
||||
) -> Self {
|
||||
if *self == from {
|
||||
return to;
|
||||
}
|
||||
|
||||
if let Some(replacement) = replacements.get(self) {
|
||||
return replacement.clone();
|
||||
}
|
||||
|
||||
match self.0.borrow() {
|
||||
ConstraintPolynomialInner::Constant(_) => self.clone(),
|
||||
ConstraintPolynomialInner::LocalConstant(_) => self.clone(),
|
||||
ConstraintPolynomialInner::NextConstant(_) => self.clone(),
|
||||
ConstraintPolynomialInner::LocalWireValue(_) => self.clone(),
|
||||
ConstraintPolynomialInner::NextWireValue(_) => self.clone(),
|
||||
ConstraintPolynomialInner::Sum { lhs, rhs } => {
|
||||
let lhs = lhs.replace_all_helper(from.clone(), to.clone(), replacements);
|
||||
let rhs = rhs.replace_all_helper(from, to, replacements);
|
||||
let replacement = Self::from_inner(ConstraintPolynomialInner::Sum { lhs, rhs });
|
||||
debug_assert!(!replacements.contains_key(self));
|
||||
replacements.insert(self.clone(), replacement.clone());
|
||||
replacement
|
||||
}
|
||||
ConstraintPolynomialInner::Product { lhs, rhs } => {
|
||||
let lhs = lhs.replace_all_helper(from.clone(), to.clone(), replacements);
|
||||
let rhs = rhs.replace_all_helper(from, to, replacements);
|
||||
let replacement = Self::from_inner(ConstraintPolynomialInner::Product { lhs, rhs });
|
||||
debug_assert!(!replacements.contains_key(self));
|
||||
replacements.insert(self.clone(), replacement.clone());
|
||||
replacement
|
||||
}
|
||||
ConstraintPolynomialInner::Exponentiation { base, exponent } => {
|
||||
let base = base.replace_all_helper(from, to, replacements);
|
||||
let replacement = Self::from_inner(
|
||||
ConstraintPolynomialInner::Exponentiation { base, exponent: *exponent });
|
||||
debug_assert!(!replacements.contains_key(self));
|
||||
replacements.insert(self.clone(), replacement.clone());
|
||||
replacement
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_inner(inner: ConstraintPolynomialInner<F>) -> Self {
|
||||
Self(Rc::new(inner))
|
||||
}
|
||||
|
||||
/// The number of polynomials in this graph.
|
||||
fn graph_size(&self) -> usize {
|
||||
let mut degrees = HashMap::new();
|
||||
self.populate_degree_map(&mut degrees);
|
||||
degrees.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> PartialEq for ConstraintPolynomial<F> {
|
||||
@ -400,6 +457,7 @@ impl<F: Field> Product for ConstraintPolynomial<F> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) enum ConstraintPolynomialInner<F: Field> {
|
||||
Constant(F),
|
||||
|
||||
@ -429,41 +487,41 @@ impl<F: Field> ConstraintPolynomialInner<F> {
|
||||
ConstraintPolynomialInner::LocalConstant(_) => (),
|
||||
ConstraintPolynomialInner::NextConstant(_) => (),
|
||||
ConstraintPolynomialInner::LocalWireValue(i) =>
|
||||
{ deps.insert(Wire { gate, input: *i }); },
|
||||
{ deps.insert(Wire { gate, input: *i }); }
|
||||
ConstraintPolynomialInner::NextWireValue(i) =>
|
||||
{ deps.insert(Wire { gate: gate + 1, input: *i }); }
|
||||
ConstraintPolynomialInner::Sum { lhs, rhs } => {
|
||||
lhs.0.add_dependencies(gate, deps);
|
||||
rhs.0.add_dependencies(gate, deps);
|
||||
},
|
||||
}
|
||||
ConstraintPolynomialInner::Product { lhs, rhs } => {
|
||||
lhs.0.add_dependencies(gate, deps);
|
||||
rhs.0.add_dependencies(gate, deps);
|
||||
},
|
||||
}
|
||||
ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => {
|
||||
base.0.add_dependencies(gate, deps);
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_constant_indices(&self, indices: &mut HashSet<usize>) {
|
||||
match self {
|
||||
ConstraintPolynomialInner::Constant(_) => (),
|
||||
ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); },
|
||||
ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); },
|
||||
ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); }
|
||||
ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); }
|
||||
ConstraintPolynomialInner::LocalWireValue(_) => (),
|
||||
ConstraintPolynomialInner::NextWireValue(_) => (),
|
||||
ConstraintPolynomialInner::Sum { lhs, rhs } => {
|
||||
lhs.0.add_constant_indices(indices);
|
||||
rhs.0.add_constant_indices(indices);
|
||||
},
|
||||
}
|
||||
ConstraintPolynomialInner::Product { lhs, rhs } => {
|
||||
lhs.0.add_constant_indices(indices);
|
||||
rhs.0.add_constant_indices(indices);
|
||||
},
|
||||
}
|
||||
ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => {
|
||||
base.0.add_constant_indices(indices);
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -482,16 +540,16 @@ impl<F: Field> ConstraintPolynomialInner<F> {
|
||||
let lhs = lhs.evaluate_memoized(vars, mem);
|
||||
let rhs = rhs.evaluate_memoized(vars, mem);
|
||||
lhs + rhs
|
||||
},
|
||||
}
|
||||
ConstraintPolynomialInner::Product { lhs, rhs } => {
|
||||
let lhs = lhs.evaluate_memoized(vars, mem);
|
||||
let rhs = rhs.evaluate_memoized(vars, mem);
|
||||
lhs * rhs
|
||||
},
|
||||
}
|
||||
ConstraintPolynomialInner::Exponentiation { base, exponent } => {
|
||||
let base = base.evaluate_memoized(vars, mem);
|
||||
base.exp_usize(*exponent)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -508,12 +566,52 @@ impl<F: Field> ConstraintPolynomialInner<F> {
|
||||
base.0.degree() * BigUint::from_usize(*exponent).unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_all(
|
||||
&self,
|
||||
from: ConstraintPolynomial<F>,
|
||||
to: ConstraintPolynomial<F>,
|
||||
) -> Self {
|
||||
todo!()
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::constraint_polynomial::ConstraintPolynomial;
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
|
||||
#[test]
|
||||
fn equality() {
|
||||
type F = CrandallField;
|
||||
let wire0 = ConstraintPolynomial::<F>::local_wire_value(0);
|
||||
// == should compare the pointers, and the clone should point to the same underlying
|
||||
// ConstraintPolynomialInner.
|
||||
assert_eq!(wire0.clone(), wire0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replace_all() {
|
||||
type F = CrandallField;
|
||||
let wire0 = ConstraintPolynomial::<F>::local_wire_value(0);
|
||||
let wire1 = ConstraintPolynomial::<F>::local_wire_value(1);
|
||||
let wire2 = ConstraintPolynomial::<F>::local_wire_value(2);
|
||||
let wire3 = ConstraintPolynomial::<F>::local_wire_value(3);
|
||||
let wire4 = ConstraintPolynomial::<F>::local_wire_value(4);
|
||||
let sum01 = &wire0 + &wire1;
|
||||
let sum12 = &wire1 + &wire2;
|
||||
let sum23 = &wire2 + &wire3;
|
||||
let product = &sum01 * &sum12 * &sum23;
|
||||
|
||||
assert_eq!(
|
||||
wire0.replace_all(wire0.clone(), wire1.clone()),
|
||||
wire1);
|
||||
|
||||
assert_eq!(
|
||||
wire0.replace_all(wire1.clone(), wire2.clone()),
|
||||
wire0);
|
||||
|
||||
// This should be a no-op, since wire 4 is not present in the product.
|
||||
assert_eq!(
|
||||
product.replace_all(wire4.clone(), wire3.clone()).graph_size(),
|
||||
product.graph_size());
|
||||
|
||||
// This shouldn't change the graph structure at all, since the replacement wire 4 was not
|
||||
// previously present.
|
||||
assert_eq!(
|
||||
product.replace_all(wire3.clone(), wire4.clone()).graph_size(),
|
||||
product.graph_size());
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,7 +99,10 @@ impl<F: Field, const W: usize, const R: usize> DeterministicGate<F> for GMiMCGat
|
||||
|
||||
// A degree of 9 is reasonable for most circuits, and it means that we only need wires for
|
||||
// every other addition buffer state.
|
||||
OutputGraph { outputs }.shrink_degree(9)
|
||||
println!("before");
|
||||
let out = OutputGraph { outputs }.shrink_degree(9);
|
||||
println!("after");
|
||||
out
|
||||
}
|
||||
|
||||
fn additional_constraints(&self, _config: CircuitConfig) -> Vec<ConstraintPolynomial<F>> {
|
||||
@ -108,3 +111,29 @@ impl<F: Field, const W: usize, const R: usize> DeterministicGate<F> for GMiMCGat
|
||||
vec![switch_bool_constraint]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::gates::gmimc::GMiMCGate;
|
||||
use crate::field::field::Field;
|
||||
use crate::circuit_data::CircuitConfig;
|
||||
use crate::gates::deterministic_gate::DeterministicGate;
|
||||
|
||||
#[test]
|
||||
fn degree() {
|
||||
type F = CrandallField;
|
||||
const W: usize = 12;
|
||||
const R: usize = 101;
|
||||
let gate = GMiMCGate::<F, W, R> { constants: Arc::new([F::TWO; R]) };
|
||||
let config = CircuitConfig {
|
||||
num_wires: 200,
|
||||
num_routed_wires: 200,
|
||||
security_bits: 128
|
||||
};
|
||||
let outs = gate.outputs(config);
|
||||
assert_eq!(outs.max_wire_input_index(), Some(50));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,25 +1,18 @@
|
||||
use std::collections::HashMap;
|
||||
use std::iter;
|
||||
|
||||
use crate::constraint_polynomial::{ConstraintPolynomial};
|
||||
use num::{BigUint, FromPrimitive, One};
|
||||
|
||||
use crate::constraint_polynomial::ConstraintPolynomial;
|
||||
use crate::field::field::Field;
|
||||
use std::collections::HashMap;
|
||||
use num::BigUint;
|
||||
|
||||
/// Represents a set of deterministic gate outputs, expressed as polynomials over witness
|
||||
/// values.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OutputGraph<F: Field> {
|
||||
pub(crate) outputs: Vec<(GateOutputLocation, ConstraintPolynomial<F>)>
|
||||
}
|
||||
|
||||
/// Represents an output location of a deterministic gate.
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum GateOutputLocation {
|
||||
/// A wire belonging to the gate itself.
|
||||
LocalWire(usize),
|
||||
/// A wire belonging to the following gate.
|
||||
NextWire(usize),
|
||||
}
|
||||
|
||||
impl<F: Field> OutputGraph<F> {
|
||||
/// Creates an output graph with a single output.
|
||||
pub fn single_output(loc: GateOutputLocation, out: ConstraintPolynomial<F>) -> Self {
|
||||
@ -31,10 +24,50 @@ impl<F: Field> OutputGraph<F> {
|
||||
///
|
||||
/// Note that this uses a simple greedy algorithm, so the result may not be optimal in terms of wire
|
||||
/// count.
|
||||
// TODO: This doesn't yet work with large exponentiations, i.e. x^n where n > new_degree. Not an
|
||||
// TODO: This doesn't yet work with large exponentiations, i.e. x^n where n > max_degree. Not an
|
||||
// immediate problem since our gates don't use those.
|
||||
pub fn shrink_degree(&self, new_degree: usize) -> Self {
|
||||
todo!()
|
||||
pub fn shrink_degree(&self, max_degree: usize) -> Self {
|
||||
let max_degree_biguint = BigUint::from_usize(max_degree).unwrap();
|
||||
|
||||
let mut current_graph = self.clone();
|
||||
|
||||
while current_graph.count_high_degree_polys(max_degree) > 0 {
|
||||
// Find polynomials with a degree between 2 and the max, inclusive.
|
||||
// These are candidates for becoming new wires.
|
||||
let mut candidates = current_graph.degree_map().into_iter()
|
||||
.filter(|(_poly, deg)| deg > &BigUint::one() && deg <= &max_degree_biguint)
|
||||
.map(|(poly, _deg)| poly);
|
||||
|
||||
// Pick the candidate that minimizes the number of high-degree polynomials in our graph.
|
||||
// This is just a simple heuristic; it won't always give an optimal wire count.
|
||||
let mut first = candidates.next().expect("No candidate; cannot reduce degree further");
|
||||
let mut leader_graph = current_graph.allocate_wire(first);
|
||||
let mut leader_high_deg_count = leader_graph.count_high_degree_polys(max_degree);
|
||||
|
||||
for candidate in candidates {
|
||||
let candidate_graph = current_graph.allocate_wire(candidate);
|
||||
let candidate_high_deg_count = candidate_graph.count_high_degree_polys(max_degree);
|
||||
if candidate_high_deg_count < leader_high_deg_count {
|
||||
leader_graph = candidate_graph;
|
||||
leader_high_deg_count = candidate_high_deg_count;
|
||||
}
|
||||
}
|
||||
|
||||
// println!("before {:?}", current_graph);
|
||||
// println!("after {:?}", leader_graph);
|
||||
current_graph = leader_graph;
|
||||
println!("{}", leader_high_deg_count);
|
||||
}
|
||||
|
||||
current_graph
|
||||
}
|
||||
|
||||
/// The number of polynomials in this graph which exceed the given maximum degree.
|
||||
fn count_high_degree_polys(&self, max_degree: usize) -> usize {
|
||||
let max_degree = BigUint::from_usize(max_degree).unwrap();
|
||||
self.degree_map().into_iter()
|
||||
.filter(|(_poly, deg)| deg > &max_degree)
|
||||
.count()
|
||||
}
|
||||
|
||||
fn degree_map(&self) -> HashMap<ConstraintPolynomial<F>, BigUint> {
|
||||
@ -45,12 +78,17 @@ impl<F: Field> OutputGraph<F> {
|
||||
degrees
|
||||
}
|
||||
|
||||
/// The largest local wire index in this entire graph.
|
||||
pub(crate) fn max_wire_input_index(&self) -> Option<usize> {
|
||||
self.outputs.iter()
|
||||
.flat_map(|(loc, out)| out.max_wire_input_index())
|
||||
.max()
|
||||
}
|
||||
|
||||
/// Allocate a new wire for the given target polynomial, and return a new output graph with
|
||||
/// references to the target polynomial replaced with references to that wire.
|
||||
fn allocate_wire(&self, target: ConstraintPolynomial<F>) -> Self {
|
||||
let new_wire_index = self.outputs.iter()
|
||||
.flat_map(|(loc, out)| out.max_wire_input_index())
|
||||
.max()
|
||||
let new_wire_index = self.max_wire_input_index()
|
||||
.map_or(0, |i| i + 1);
|
||||
|
||||
let new_wire = ConstraintPolynomial::local_wire_value(new_wire_index);
|
||||
@ -63,16 +101,52 @@ impl<F: Field> OutputGraph<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an output location of a deterministic gate.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum GateOutputLocation {
|
||||
/// A wire belonging to the gate itself.
|
||||
LocalWire(usize),
|
||||
/// A wire belonging to the following gate.
|
||||
NextWire(usize),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::constraint_polynomial::ConstraintPolynomial;
|
||||
use crate::gates::output_graph::shrink_degree;
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::gates::output_graph::{GateOutputLocation, OutputGraph};
|
||||
|
||||
#[test]
|
||||
fn shrink_exp() {
|
||||
let original = ConstraintPolynomial::local_wire_value(0).exp(10);
|
||||
let shrunk = shrink_degree(original, 3);
|
||||
// `shrunk` should be something similar to (wire0^3)^3 * wire0.
|
||||
assert_eq!(shrunk.max_wire_input_index(), Some(2))
|
||||
fn shrink_squaring_graph() {
|
||||
type F = CrandallField;
|
||||
let deg1 = ConstraintPolynomial::<F>::local_wire_value(0);
|
||||
let deg2 = deg1.square();
|
||||
let deg4 = deg2.square();
|
||||
let deg8 = deg4.square();
|
||||
let deg16 = deg8.square();
|
||||
|
||||
let original = OutputGraph::single_output(
|
||||
GateOutputLocation::NextWire(0),
|
||||
deg16);
|
||||
|
||||
let degree_map = original.degree_map();
|
||||
assert_eq!(degree_map.len(), 5);
|
||||
|
||||
assert_eq!(original.count_high_degree_polys(2), 3);
|
||||
assert_eq!(original.count_high_degree_polys(3), 3);
|
||||
assert_eq!(original.count_high_degree_polys(4), 2);
|
||||
|
||||
let shrunk_deg_2 = original.shrink_degree(2);
|
||||
let shrunk_deg_3 = original.shrink_degree(3);
|
||||
let shrunk_deg_4 = original.shrink_degree(4);
|
||||
|
||||
// `shrunk_deg_2` should have an intermediate wire for deg2, deg4, and deg8.
|
||||
assert_eq!(shrunk_deg_2.max_wire_input_index(), Some(3));
|
||||
|
||||
// `shrunk_deg_3` should also have an intermediate wire for deg2, deg4, and deg8.
|
||||
assert_eq!(shrunk_deg_3.max_wire_input_index(), Some(3));
|
||||
|
||||
// `shrunk_deg_4` should have an intermediate wire for deg4 only.
|
||||
assert_eq!(shrunk_deg_4.max_wire_input_index(), Some(1));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user