diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index efab1dd8..35ba280d 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -19,7 +19,7 @@ fn main() { // change this to info or warn later. env_logger::Builder::from_env(Env::default().default_filter_or("debug")).init(); - bench_prove::(); + bench_prove::(); // bench_field_mul::(); @@ -29,7 +29,7 @@ fn main() { } fn bench_prove, const D: usize>() { - let gmimc_gate = GMiMCGate::::with_automatic_constants(); + let gmimc_gate = GMiMCGate::::with_automatic_constants(); let config = CircuitConfig { num_wires: 134, @@ -46,7 +46,7 @@ fn bench_prove, const D: usize>() { }, }; - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); for _ in 0..10000 { builder.add_gate_no_constants(gmimc_gate.clone()); diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 18ff666c..a2c94f83 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -8,28 +8,28 @@ use crate::circuit_data::{ VerifierCircuitData, VerifierOnlyCircuitData, }; use crate::field::cosets::get_unique_coset_shifts; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::constant::ConstantGate; use crate::gates::gate::{GateInstance, GateRef}; use crate::gates::noop::NoopGate; use crate::generator::{CopyGenerator, WitnessGenerator}; use crate::hash::hash_n_to_hash; -use crate::merkle_tree::MerkleTree; use crate::permutation_argument::TargetPartitions; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; -use crate::util::{log2_strict, transpose, transpose_poly_values}; +use crate::util::{log2_strict, transpose}; use crate::wire::Wire; -pub struct CircuitBuilder { +pub struct CircuitBuilder, const D: usize> { pub(crate) config: CircuitConfig, /// The types of gates used in this circuit. - gates: HashSet>, + gates: HashSet>, /// The concrete placement of each gate. - gate_instances: Vec>, + gate_instances: Vec>, /// The next available index for a public input. public_input_index: usize, @@ -46,7 +46,7 @@ pub struct CircuitBuilder { targets_to_constants: HashMap, } -impl CircuitBuilder { +impl, const D: usize> CircuitBuilder { pub fn new(config: CircuitConfig) -> Self { CircuitBuilder { config, @@ -89,12 +89,12 @@ impl CircuitBuilder { (0..n).map(|_i| self.add_virtual_advice_target()).collect() } - pub fn add_gate_no_constants(&mut self, gate_type: GateRef) -> usize { + pub fn add_gate_no_constants(&mut self, gate_type: GateRef) -> usize { self.add_gate(gate_type, Vec::new()) } /// Adds a gate to the circuit, and returns its index. - pub fn add_gate(&mut self, gate_type: GateRef, constants: Vec) -> usize { + pub fn add_gate(&mut self, gate_type: GateRef, constants: Vec) -> usize { // If we haven't seen a gate of this type before, check that it's compatible with our // circuit configuration, then register it. if !self.gates.contains(&gate_type) { @@ -113,7 +113,7 @@ impl CircuitBuilder { index } - fn check_gate_compatibility(&self, gate: &GateRef) { + fn check_gate_compatibility(&self, gate: &GateRef) { assert!( gate.0.num_wires() <= self.config.num_wires, "{:?} requires {} wires, but our GateConfig has only {}", @@ -261,7 +261,7 @@ impl CircuitBuilder { } /// Builds a "full circuit", with both prover and verifier data. - pub fn build(mut self) -> CircuitData { + pub fn build(mut self) -> CircuitData { let start = Instant::now(); info!( "degree before blinding & padding: {}", @@ -335,7 +335,7 @@ impl CircuitBuilder { } /// Builds a "prover circuit", with data needed to generate proofs but not verify them. - pub fn build_prover(self) -> ProverCircuitData { + pub fn build_prover(self) -> ProverCircuitData { // TODO: Can skip parts of this. let CircuitData { prover_only, @@ -349,7 +349,7 @@ impl CircuitBuilder { } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. - pub fn build_verifier(self) -> VerifierCircuitData { + pub fn build_verifier(self) -> VerifierCircuitData { // TODO: Can skip parts of this. let CircuitData { verifier_only, diff --git a/src/circuit_data.rs b/src/circuit_data.rs index 0e61a583..4d9a7110 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -5,7 +5,6 @@ use crate::field::field::Field; use crate::fri::FriConfig; use crate::gates::gate::GateRef; use crate::generator::WitnessGenerator; -use crate::merkle_tree::MerkleTree; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::proof::{Hash, HashTarget, Proof}; use crate::prover::prove; @@ -52,24 +51,18 @@ impl CircuitConfig { } /// Circuit data required by the prover or the verifier. -pub struct CircuitData { +pub struct CircuitData, const D: usize> { pub(crate) prover_only: ProverOnlyCircuitData, pub(crate) verifier_only: VerifierOnlyCircuitData, - pub(crate) common: CommonCircuitData, + pub(crate) common: CommonCircuitData, } -impl CircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Proof - where - F: Extendable, - { +impl, const D: usize> CircuitData { + pub fn prove(&self, inputs: PartialWitness) -> Proof { prove(&self.prover_only, &self.common, inputs) } - pub fn verify(&self, proof: Proof) -> Result<()> - where - F: Extendable, - { + pub fn verify(&self, proof: Proof) -> Result<()> { verify(proof, &self.verifier_only, &self.common) } } @@ -81,31 +74,25 @@ impl CircuitData { /// structure as succinct as we can. Thus we include various precomputed data which isn't strictly /// required, like LDEs of preprocessed polynomials. If more succinctness was desired, we could /// construct a more minimal prover structure and convert back and forth. -pub struct ProverCircuitData { +pub struct ProverCircuitData, const D: usize> { pub(crate) prover_only: ProverOnlyCircuitData, - pub(crate) common: CommonCircuitData, + pub(crate) common: CommonCircuitData, } -impl ProverCircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Proof - where - F: Extendable, - { +impl, const D: usize> ProverCircuitData { + pub fn prove(&self, inputs: PartialWitness) -> Proof { prove(&self.prover_only, &self.common, inputs) } } /// Circuit data required by the prover. -pub struct VerifierCircuitData { +pub struct VerifierCircuitData, const D: usize> { pub(crate) verifier_only: VerifierOnlyCircuitData, - pub(crate) common: CommonCircuitData, + pub(crate) common: CommonCircuitData, } -impl VerifierCircuitData { - pub fn verify(&self, proof: Proof) -> Result<()> - where - F: Extendable, - { +impl, const D: usize> VerifierCircuitData { + pub fn verify(&self, proof: Proof) -> Result<()> { verify(proof, &self.verifier_only, &self.common) } } @@ -129,13 +116,13 @@ pub(crate) struct VerifierOnlyCircuitData { } /// Circuit data required by both the prover and the verifier. -pub(crate) struct CommonCircuitData { +pub(crate) struct CommonCircuitData, const D: usize> { pub(crate) config: CircuitConfig, pub(crate) degree_bits: usize, /// The types of gates used in this circuit. - pub(crate) gates: Vec>, + pub(crate) gates: Vec>, /// The largest number of constraints imposed by any gate. pub(crate) num_gate_constraints: usize, @@ -148,7 +135,7 @@ pub(crate) struct CommonCircuitData { pub(crate) circuit_digest: Hash, } -impl CommonCircuitData { +impl, const D: usize> CommonCircuitData { pub fn degree(&self) -> usize { 1 << self.degree_bits } diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs index 46c99f11..7e74be78 100644 --- a/src/field/extension_field/mod.rs +++ b/src/field/extension_field/mod.rs @@ -2,6 +2,7 @@ use crate::field::field::Field; pub mod quadratic; pub mod quartic; +mod quartic_quartic; pub mod target; /// Optimal extension field trait. @@ -32,7 +33,7 @@ impl OEF<1> for F { const W: Self::BaseField = F::ZERO; } -pub trait Extendable: Sized { +pub trait Extendable: Field + Sized { type Extension: Field + OEF + From; } diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index fc74ec88..27fc33a1 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -3,11 +3,11 @@ use crate::field::extension_field::{FieldExtension, OEF}; use crate::field::field::Field; use rand::Rng; use std::fmt::{Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Eq, PartialEq, Hash)] pub struct QuadraticCrandallField([CrandallField; 2]); impl OEF<2> for QuadraticCrandallField { @@ -38,23 +38,6 @@ impl From<>::BaseField> for QuadraticCrandallField { } } -impl PartialEq for QuadraticCrandallField { - fn eq(&self, other: &Self) -> bool { - FieldExtension::<2>::to_basefield_array(self) - == FieldExtension::<2>::to_basefield_array(other) - } -} - -impl Eq for QuadraticCrandallField {} - -impl Hash for QuadraticCrandallField { - fn hash(&self, state: &mut H) { - for l in &FieldExtension::<2>::to_basefield_array(self) { - Hash::hash(l, state); - } - } -} - impl Field for QuadraticCrandallField { const ZERO: Self = Self([CrandallField::ZERO; 2]); const ONE: Self = Self([CrandallField::ONE, CrandallField::ZERO]); diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index f1791d0e..6bd8ac55 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -1,18 +1,23 @@ -use crate::field::crandall_field::CrandallField; -use crate::field::extension_field::{FieldExtension, OEF}; -use crate::field::field::Field; -use rand::Rng; use std::fmt::{Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -#[derive(Copy, Clone)] -pub struct QuarticCrandallField([CrandallField; 4]); +use rand::Rng; + +use crate::field::crandall_field::CrandallField; +use crate::field::extension_field::quartic_quartic::QuarticQuarticCrandallField; +use crate::field::extension_field::{Extendable, FieldExtension, OEF}; +use crate::field::field::Field; + +/// A quartic extension of `CrandallField`. +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub struct QuarticCrandallField(pub(crate) [CrandallField; 4]); impl OEF<4> for QuarticCrandallField { // Verifiable in Sage with - // ``R. = GF(p)[]; assert (x^4 -3).is_irreducible()`. + // R. = GF(p)[] + // assert (x^4 - 3).is_irreducible() const W: CrandallField = CrandallField(3); } @@ -43,23 +48,6 @@ impl From<>::BaseField> for QuarticCrandallField { } } -impl PartialEq for QuarticCrandallField { - fn eq(&self, other: &Self) -> bool { - FieldExtension::<4>::to_basefield_array(self) - == FieldExtension::<4>::to_basefield_array(other) - } -} - -impl Eq for QuarticCrandallField {} - -impl Hash for QuarticCrandallField { - fn hash(&self, state: &mut H) { - for l in &FieldExtension::<4>::to_basefield_array(self) { - Hash::hash(l, state); - } - } -} - impl Field for QuarticCrandallField { const ZERO: Self = Self([CrandallField::ZERO; 4]); const ONE: Self = Self([ @@ -251,6 +239,10 @@ impl DivAssign for QuarticCrandallField { } } +impl Extendable<4> for QuarticCrandallField { + type Extension = QuarticQuarticCrandallField; +} + #[cfg(test)] mod tests { use crate::field::extension_field::quartic::QuarticCrandallField; diff --git a/src/field/extension_field/quartic_quartic.rs b/src/field/extension_field/quartic_quartic.rs new file mode 100644 index 00000000..62fad962 --- /dev/null +++ b/src/field/extension_field/quartic_quartic.rs @@ -0,0 +1,259 @@ +use crate::field::crandall_field::CrandallField; +use crate::field::extension_field::quartic::QuarticCrandallField; +use crate::field::extension_field::{FieldExtension, OEF}; +use crate::field::field::Field; +use rand::Rng; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; +use std::iter::{Product, Sum}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// A quartic extension of `QuarticCrandallField`. +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +pub struct QuarticQuarticCrandallField(pub(crate) [QuarticCrandallField; 4]); + +impl OEF<4> for QuarticQuarticCrandallField { + // Verifiable in Sage with + // p = 2^64 - 9 * 2^28 + 1 + // F = GF(p) + // PR_F. = PolynomialRing(F) + // assert (x^4 - 3).is_irreducible() + // F4. = F.extension(x^4 - 3) + // PR_F4. = PolynomialRing(F4) + // assert (x^4 - y).is_irreducible() + // F44. = F4.extension(x^4 - y) + const W: QuarticCrandallField = QuarticCrandallField([ + CrandallField(0), + CrandallField(1), + CrandallField(0), + CrandallField(0), + ]); +} + +impl FieldExtension<4> for QuarticQuarticCrandallField { + type BaseField = QuarticCrandallField; + + fn to_basefield_array(&self) -> [Self::BaseField; 4] { + self.0 + } + + fn from_basefield_array(arr: [Self::BaseField; 4]) -> Self { + Self(arr) + } + + fn from_basefield(x: Self::BaseField) -> Self { + x.into() + } +} + +impl From<>::BaseField> for QuarticQuarticCrandallField { + fn from(x: >::BaseField) -> Self { + Self([ + x, + >::BaseField::ZERO, + >::BaseField::ZERO, + >::BaseField::ZERO, + ]) + } +} + +impl Field for QuarticQuarticCrandallField { + const ZERO: Self = Self([QuarticCrandallField::ZERO; 4]); + const ONE: Self = Self([ + QuarticCrandallField::ONE, + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + ]); + const TWO: Self = Self([ + QuarticCrandallField::TWO, + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + ]); + const NEG_ONE: Self = Self([ + QuarticCrandallField::NEG_ONE, + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + ]); + + // Does not fit in 64-bits. + const ORDER: u64 = 0; + const TWO_ADICITY: usize = 32; + const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([ + QuarticCrandallField([ + CrandallField(7562951059982399618), + CrandallField(16734862117167184487), + CrandallField(8532193866847630013), + CrandallField(15462716295551021898), + ]), + QuarticCrandallField([ + CrandallField(16143979237658148445), + CrandallField(12004617499933809221), + CrandallField(11826153143854535879), + CrandallField(14780824604953232397), + ]), + QuarticCrandallField([ + CrandallField(12779077039546101185), + CrandallField(15745975127331074164), + CrandallField(4297791107105154033), + CrandallField(5966855376644799108), + ]), + QuarticCrandallField([ + CrandallField(1942992936904935291), + CrandallField(6041097781717465159), + CrandallField(16875726992388585780), + CrandallField(17742746479895474446), + ]), + ]); + const POWER_OF_TWO_GENERATOR: Self = Self([ + QuarticCrandallField::ZERO, + QuarticCrandallField([ + CrandallField::ZERO, + CrandallField::ZERO, + CrandallField::ZERO, + CrandallField(6809469153480715254), + ]), + QuarticCrandallField::ZERO, + QuarticCrandallField::ZERO, + ]); + + fn try_inverse(&self) -> Option { + todo!() + } + + fn to_canonical_u64(&self) -> u64 { + panic!("Doesn't fit!") + } + + fn from_canonical_u64(n: u64) -> Self { + >::BaseField::from_canonical_u64(n).into() + } + + fn rand_from_rng(rng: &mut R) -> Self { + Self([ + >::BaseField::rand_from_rng(rng), + >::BaseField::rand_from_rng(rng), + >::BaseField::rand_from_rng(rng), + >::BaseField::rand_from_rng(rng), + ]) + } +} + +impl Display for QuarticQuarticCrandallField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "({}) + ({})*b + ({})*b^2 + ({})*b^3", + self.0[0], self.0[1], self.0[2], self.0[3] + ) + } +} + +impl Debug for QuarticQuarticCrandallField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } +} + +impl Neg for QuarticQuarticCrandallField { + type Output = Self; + + #[inline] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]]) + } +} + +impl Add for QuarticQuarticCrandallField { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + Self([ + self.0[0] + rhs.0[0], + self.0[1] + rhs.0[1], + self.0[2] + rhs.0[2], + self.0[3] + rhs.0[3], + ]) + } +} + +impl AddAssign for QuarticQuarticCrandallField { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl Sum for QuarticQuarticCrandallField { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |acc, x| acc + x) + } +} + +impl Sub for QuarticQuarticCrandallField { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([ + self.0[0] - rhs.0[0], + self.0[1] - rhs.0[1], + self.0[2] - rhs.0[2], + self.0[3] - rhs.0[3], + ]) + } +} + +impl SubAssign for QuarticQuarticCrandallField { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Mul for QuarticQuarticCrandallField { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self { + let Self([a0, a1, a2, a3]) = self; + let Self([b0, b1, b2, b3]) = rhs; + + let c0 = a0 * b0 + >::W * (a1 * b3 + a2 * b2 + a3 * b1); + let c1 = a0 * b1 + a1 * b0 + >::W * (a2 * b3 + a3 * b2); + let c2 = a0 * b2 + a1 * b1 + a2 * b0 + >::W * a3 * b3; + let c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0; + + Self([c0, c1, c2, c3]) + } +} + +impl MulAssign for QuarticQuarticCrandallField { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl Product for QuarticQuarticCrandallField { + fn product>(iter: I) -> Self { + iter.fold(Self::ONE, |acc, x| acc * x) + } +} + +impl Div for QuarticQuarticCrandallField { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } +} + +impl DivAssign for QuarticQuarticCrandallField { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } +} diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 7a72f7e0..7198ebf6 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -3,6 +3,7 @@ use crate::field::extension_field::{Extendable, FieldExtension, OEF}; use crate::field::field::Field; use crate::target::Target; +/// `Target`s representing an element of an extension field. #[derive(Copy, Clone, Debug)] pub struct ExtensionTarget(pub [Target; D]); @@ -12,50 +13,117 @@ impl ExtensionTarget { } } -impl CircuitBuilder { - pub fn zero_extension(&mut self) -> ExtensionTarget - where - F: Extendable, - { - ExtensionTarget([self.zero(); D]) +/// `Target`s representing an element of an extension of an extension field. +#[derive(Copy, Clone, Debug)] +pub struct ExtensionExtensionTarget(pub [ExtensionTarget; D]); + +impl ExtensionExtensionTarget { + pub fn to_ext_target_array(&self) -> [ExtensionTarget; D] { + self.0 + } +} + +impl, const D: usize> CircuitBuilder { + pub fn constant_extension(&mut self, c: F::Extension) -> ExtensionTarget { + let c_parts = c.to_basefield_array(); + let mut parts = [self.zero(); D]; + for i in 0..D { + parts[i] = self.constant(c_parts[i]); + } + ExtensionTarget(parts) } - pub fn add_extension( + pub fn constant_ext_ext( + &mut self, + c: <>::Extension as Extendable>::Extension, + ) -> ExtensionExtensionTarget + where + F::Extension: Extendable, + { + let c_parts = c.to_basefield_array(); + let mut parts = [self.zero_extension(); D]; + for i in 0..D { + parts[i] = self.constant_extension(c_parts[i]); + } + ExtensionExtensionTarget(parts) + } + + pub fn zero_extension(&mut self) -> ExtensionTarget { + self.constant_extension(F::Extension::ZERO) + } + + pub fn one_extension(&mut self) -> ExtensionTarget { + self.constant_extension(F::Extension::ONE) + } + + pub fn two_extension(&mut self) -> ExtensionTarget { + self.constant_extension(F::Extension::TWO) + } + + pub fn zero_ext_ext(&mut self) -> ExtensionExtensionTarget + where + F::Extension: Extendable, + { + self.constant_ext_ext(<>::Extension as Extendable>::Extension::ZERO) + } + + pub fn add_extension( &mut self, mut a: ExtensionTarget, b: ExtensionTarget, - ) -> ExtensionTarget - where - F: Extendable, - { + ) -> ExtensionTarget { for i in 0..D { a.0[i] = self.add(a.0[i], b.0[i]); } a } - pub fn sub_extension( + pub fn add_ext_ext( + &mut self, + mut a: ExtensionExtensionTarget, + b: ExtensionExtensionTarget, + ) -> ExtensionExtensionTarget { + for i in 0..D { + a.0[i] = self.add_extension(a.0[i], b.0[i]); + } + a + } + + pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let mut sum = self.zero_extension(); + for term in terms { + sum = self.add_extension(sum, *term); + } + sum + } + + pub fn sub_extension( &mut self, mut a: ExtensionTarget, b: ExtensionTarget, - ) -> ExtensionTarget - where - F: Extendable, - { + ) -> ExtensionTarget { for i in 0..D { a.0[i] = self.sub(a.0[i], b.0[i]); } a } - pub fn mul_extension( + pub fn sub_ext_ext( + &mut self, + mut a: ExtensionExtensionTarget, + b: ExtensionExtensionTarget, + ) -> ExtensionExtensionTarget { + for i in 0..D { + a.0[i] = self.sub_extension(a.0[i], b.0[i]); + } + a + } + + pub fn mul_extension( &mut self, a: ExtensionTarget, b: ExtensionTarget, - ) -> ExtensionTarget - where - F: Extendable, - { + ) -> ExtensionTarget { let mut res = [self.zero(); D]; for i in 0..D { for j in 0..D { @@ -70,18 +138,72 @@ impl CircuitBuilder { ExtensionTarget(res) } - /// Returns a*b where `b` is in the extension field and `a` is in the base field. - pub fn scalar_mul( + pub fn mul_ext_ext( &mut self, - a: Target, - mut b: ExtensionTarget, - ) -> ExtensionTarget + mut a: ExtensionExtensionTarget, + b: ExtensionExtensionTarget, + ) -> ExtensionExtensionTarget where - F: Extendable, + F::Extension: Extendable, { + let mut res = [self.zero_extension(); D]; + let w = self + .constant_extension(<>::Extension as Extendable>::Extension::W); + for i in 0..D { + for j in 0..D { + let ai_bi = self.mul_extension(a.0[i], b.0[j]); + res[(i + j) % D] = if i + j < D { + self.add_extension(ai_bi, res[(i + j) % D]) + } else { + let w_ai_bi = self.mul_extension(w, ai_bi); + self.add_extension(w_ai_bi, res[(i + j) % D]) + } + } + } + ExtensionExtensionTarget(res) + } + + pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { + let mut product = self.one_extension(); + for term in terms { + product = self.mul_extension(product, *term); + } + product + } + + /// Like `mul_add`, but for `ExtensionTarget`s. Note that, unlike `mul_add`, this has no + /// performance benefit over separate muls and adds. + pub fn mul_add_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let product = self.mul_extension(a, b); + self.add_extension(product, c) + } + + /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. + pub fn scalar_mul_ext(&mut self, a: Target, mut b: ExtensionTarget) -> ExtensionTarget { for i in 0..D { b.0[i] = self.mul(a, b.0[i]); } b } + + /// Returns `a * b`, where `b` is in the extension of the extension field, and `a` is in the + /// extension field. + pub fn scalar_mul_ext_ext( + &mut self, + a: ExtensionTarget, + mut b: ExtensionExtensionTarget, + ) -> ExtensionExtensionTarget + where + F::Extension: Extendable, + { + for i in 0..D { + b.0[i] = self.mul_extension(a, b.0[i]); + } + b + } } diff --git a/src/fri/mod.rs b/src/fri/mod.rs index d701c421..7c71eb77 100644 --- a/src/fri/mod.rs +++ b/src/fri/mod.rs @@ -50,12 +50,12 @@ fn fri_l(codeword_len: usize, rate_log: usize, conjecture: bool) -> f64 { #[cfg(test)] mod tests { - use super::*; + use anyhow::Result; + use rand::rngs::ThreadRng; + use rand::Rng; + use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quadratic::QuadraticCrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; - use crate::field::extension_field::{flatten, Extendable, FieldExtension}; - use crate::field::fft::ifft; + use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::fri::prover::fri_proof; use crate::fri::verifier::verify_fri_proof; @@ -63,9 +63,8 @@ mod tests { use crate::plonk_challenger::Challenger; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::reverse_index_bits_in_place; - use anyhow::Result; - use rand::rngs::ThreadRng; - use rand::Rng; + + use super::*; fn check_fri, const D: usize>( degree_log: usize, diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index aabaed6a..a214df3f 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,4 +1,5 @@ use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; use crate::generator::SimpleGenerator; @@ -6,7 +7,7 @@ use crate::target::Target; use crate::wire::Wire; use crate::witness::PartialWitness; -impl CircuitBuilder { +impl, const D: usize> CircuitBuilder { /// Computes `-x`. pub fn neg(&mut self, x: Target) -> Target { let neg_one = self.neg_one(); diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index 14cf0aca..fae6d433 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -1,21 +1,21 @@ use std::convert::TryInto; use crate::circuit_builder::CircuitBuilder; -use crate::field::field::Field; +use crate::field::extension_field::Extendable; use crate::gates::gmimc::GMiMCGate; use crate::hash::GMIMC_ROUNDS; use crate::target::Target; use crate::wire::Wire; // TODO: Move to be next to native `permute`? -impl CircuitBuilder { +impl, const D: usize> CircuitBuilder { pub fn permute(&mut self, inputs: [Target; 12]) -> [Target; 12] { let zero = self.zero(); let gate = - self.add_gate_no_constants(GMiMCGate::::with_automatic_constants()); + self.add_gate_no_constants(GMiMCGate::::with_automatic_constants()); // We don't want to swap any inputs, so set that wire to 0. - let swap_wire = GMiMCGate::::WIRE_SWAP; + let swap_wire = GMiMCGate::::WIRE_SWAP; let swap_wire = Target::Wire(Wire { gate, input: swap_wire, @@ -24,7 +24,7 @@ impl CircuitBuilder { // The old accumulator wire doesn't matter, since we won't read the new accumulator wire. // We do have to set it to something though, so we'll arbitrary pick 0. - let old_acc_wire = GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD; + let old_acc_wire = GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD; let old_acc_wire = Target::Wire(Wire { gate, input: old_acc_wire, @@ -33,7 +33,7 @@ impl CircuitBuilder { // Route input wires. for i in 0..12 { - let in_wire = GMiMCGate::::wire_input(i); + let in_wire = GMiMCGate::::wire_input(i); let in_wire = Target::Wire(Wire { gate, input: in_wire, @@ -46,7 +46,7 @@ impl CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_output(i), + input: GMiMCGate::::wire_output(i), }) }) .collect::>() diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs index 64bf4ca3..cac51452 100644 --- a/src/gadgets/polynomial.rs +++ b/src/gadgets/polynomial.rs @@ -1,28 +1,27 @@ use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::target::{ExtensionExtensionTarget, ExtensionTarget}; use crate::field::extension_field::Extendable; -use crate::field::field::Field; use crate::target::Target; -pub struct PolynomialCoeffsTarget(pub Vec>); +pub struct PolynomialCoeffsExtTarget(pub Vec>); -impl PolynomialCoeffsTarget { - pub fn eval_scalar>( +impl PolynomialCoeffsExtTarget { + pub fn eval_scalar>( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, point: Target, ) -> ExtensionTarget { let mut acc = builder.zero_extension(); for &c in self.0.iter().rev() { - let tmp = builder.scalar_mul(point, acc); + let tmp = builder.scalar_mul_ext(point, acc); acc = builder.add_extension(tmp, c); } acc } - pub fn eval>( + pub fn eval>( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, point: ExtensionTarget, ) -> ExtensionTarget { let mut acc = builder.zero_extension(); @@ -33,3 +32,41 @@ impl PolynomialCoeffsTarget { acc } } + +pub struct PolynomialCoeffsExtExtTarget(pub Vec>); + +impl PolynomialCoeffsExtExtTarget { + pub fn eval_scalar( + &self, + builder: &mut CircuitBuilder, + point: ExtensionTarget, + ) -> ExtensionExtensionTarget + where + F: Extendable, + F::Extension: Extendable, + { + let mut acc = builder.zero_ext_ext(); + for &c in self.0.iter().rev() { + let tmp = builder.scalar_mul_ext_ext(point, acc); + acc = builder.add_ext_ext(tmp, c); + } + acc + } + + pub fn eval( + &self, + builder: &mut CircuitBuilder, + point: ExtensionExtensionTarget, + ) -> ExtensionExtensionTarget + where + F: Extendable, + F::Extension: Extendable, + { + let mut acc = builder.zero_ext_ext(); + for &c in self.0.iter().rev() { + let tmp = builder.mul_ext_ext(point, acc); + acc = builder.add_ext_ext(tmp, c); + } + acc + } +} diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 7dbb127c..e65198ec 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -1,11 +1,12 @@ use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::target::Target; use crate::wire::Wire; use crate::witness::PartialWitness; -impl CircuitBuilder { +impl, const D: usize> CircuitBuilder { /// Split the given integer into a list of virtual advice targets, where each one represents a /// bit of the integer, with little-endian ordering. /// diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index f24ea7c1..0d0fdd7c 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -1,4 +1,6 @@ use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; @@ -16,7 +18,7 @@ use crate::witness::PartialWitness; pub struct ArithmeticGate; impl ArithmeticGate { - pub fn new() -> GateRef { + pub fn new, const D: usize>() -> GateRef { GateRef::new(ArithmeticGate) } @@ -26,12 +28,12 @@ impl ArithmeticGate { pub const WIRE_OUTPUT: usize = 3; } -impl Gate for ArithmeticGate { +impl, const D: usize> Gate for ArithmeticGate { fn id(&self) -> String { format!("{:?}", self) } - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; @@ -44,9 +46,9 @@ impl Gate for ArithmeticGate { fn eval_unfiltered_recursively( &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec { + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; let multiplicand_0 = vars.local_wires[Self::WIRE_MULTIPLICAND_0]; @@ -54,10 +56,10 @@ impl Gate for ArithmeticGate { let addend = vars.local_wires[Self::WIRE_ADDEND]; let output = vars.local_wires[Self::WIRE_OUTPUT]; - let product_term = builder.mul_many(&[const_0, multiplicand_0, multiplicand_1]); - let addend_term = builder.mul(const_1, addend); - let computed_output = builder.add_many(&[product_term, addend_term]); - vec![builder.sub(computed_output, output)] + let product_term = builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]); + let addend_term = builder.mul_extension(const_1, addend); + let computed_output = builder.add_many_extension(&[product_term, addend_term]); + vec![builder.sub_extension(computed_output, output)] } fn generators( @@ -150,6 +152,6 @@ mod tests { #[test] fn low_degree() { - test_low_degree(ArithmeticGate::new::()) + test_low_degree(ArithmeticGate::new::()) } } diff --git a/src/gates/constant.rs b/src/gates/constant.rs index a0a7685e..3845031a 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -1,4 +1,6 @@ use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; @@ -11,7 +13,7 @@ use crate::witness::PartialWitness; pub struct ConstantGate; impl ConstantGate { - pub fn get() -> GateRef { + pub fn get, const D: usize>() -> GateRef { GateRef::new(ConstantGate) } @@ -20,12 +22,12 @@ impl ConstantGate { pub const WIRE_OUTPUT: usize = 0; } -impl Gate for ConstantGate { +impl, const D: usize> Gate for ConstantGate { fn id(&self) -> String { "ConstantGate".into() } - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let input = vars.local_constants[Self::CONST_INPUT]; let output = vars.local_wires[Self::WIRE_OUTPUT]; vec![output - input] @@ -33,12 +35,12 @@ impl Gate for ConstantGate { fn eval_unfiltered_recursively( &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec { + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { let input = vars.local_constants[Self::CONST_INPUT]; let output = vars.local_wires[Self::WIRE_OUTPUT]; - vec![builder.sub(output, input)] + vec![builder.sub_extension(output, input)] } fn generators( @@ -98,6 +100,6 @@ mod tests { #[test] fn low_degree() { - test_low_degree(ConstantGate::get::()) + test_low_degree(ConstantGate::get::()) } } diff --git a/src/gates/gate.rs b/src/gates/gate.rs index a340a1bd..1765191e 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -2,33 +2,71 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::circuit_builder::CircuitBuilder; -use crate::field::field::Field; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::generator::WitnessGenerator; -use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A custom gate. -pub trait Gate: 'static + Send + Sync { +pub trait Gate, const D: usize>: 'static + Send + Sync { fn id(&self) -> String; - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec; + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec; + + /// Like `eval_unfiltered`, but specialized for points in the base field. + /// + /// By default, this just calls `eval_unfiltered`, which treats the point as an extension field + /// element. This isn't very efficient. + fn eval_unfiltered_base(&self, vars_base: EvaluationVarsBase) -> Vec { + let local_constants = &vars_base + .local_constants + .iter() + .map(|c| F::Extension::from_basefield(*c)) + .collect::>(); + let local_wires = &vars_base + .local_wires + .iter() + .map(|w| F::Extension::from_basefield(*w)) + .collect::>(); + let vars = EvaluationVars { + local_constants, + local_wires, + }; + let values = self.eval_unfiltered(vars); + + // Each value should be in the base field, i.e. only the degree-zero part should be nonzero. + values + .into_iter() + .map(|value| { + // TODO: Change to debug-only once our gate code is mostly finished/stable. + assert!(F::Extension::is_in_basefield(&value)); + value.to_basefield_array()[0] + }) + .collect() + } fn eval_unfiltered_recursively( &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec; + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec>; - fn eval_filtered(&self, vars: EvaluationVars) -> Vec { + fn eval_filtered(&self, vars: EvaluationVars) -> Vec { // TODO: Filter self.eval_unfiltered(vars) } + /// Like `eval_filtered`, but specialized for points in the base field. + fn eval_filtered_base(&self, vars: EvaluationVarsBase) -> Vec { + // TODO: Filter + self.eval_unfiltered_base(vars) + } + fn eval_filtered_recursively( &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec { + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { // TODO: Filter self.eval_unfiltered_recursively(builder, vars) } @@ -53,30 +91,30 @@ pub trait Gate: 'static + Send + Sync { /// A wrapper around an `Rc` which implements `PartialEq`, `Eq` and `Hash` based on gate IDs. #[derive(Clone)] -pub struct GateRef(pub(crate) Arc>); +pub struct GateRef, const D: usize>(pub(crate) Arc>); -impl GateRef { - pub fn new>(gate: G) -> GateRef { +impl, const D: usize> GateRef { + pub fn new>(gate: G) -> GateRef { GateRef(Arc::new(gate)) } } -impl PartialEq for GateRef { +impl, const D: usize> PartialEq for GateRef { fn eq(&self, other: &Self) -> bool { self.0.id() == other.0.id() } } -impl Hash for GateRef { +impl, const D: usize> Hash for GateRef { fn hash(&self, state: &mut H) { self.0.id().hash(state) } } -impl Eq for GateRef {} +impl, const D: usize> Eq for GateRef {} /// A gate along with any constants used to configure it. -pub struct GateInstance { - pub gate_type: GateRef, +pub struct GateInstance, const D: usize> { + pub gate_type: GateRef, pub constants: Vec, } diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index b7a62c7e..a6345249 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -1,3 +1,4 @@ +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; @@ -9,12 +10,12 @@ const WITNESS_DEGREE: usize = WITNESS_SIZE - 1; /// Tests that the constraints imposed by the given gate are low-degree by applying them to random /// low-degree witness polynomials. -pub(crate) fn test_low_degree(gate: GateRef) { +pub(crate) fn test_low_degree, const D: usize>(gate: GateRef) { let gate = gate.0; let rate_bits = log2_ceil(gate.degree() + 1); - let wire_ldes = random_low_degree_matrix(gate.num_wires(), rate_bits); - let constant_ldes = random_low_degree_matrix::(gate.num_constants(), rate_bits); + let wire_ldes = random_low_degree_matrix::(gate.num_wires(), rate_bits); + let constant_ldes = random_low_degree_matrix::(gate.num_constants(), rate_bits); assert_eq!(wire_ldes.len(), constant_ldes.len()); let constraint_evals = wire_ldes diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 3097ccfb..19042d57 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -1,9 +1,10 @@ use std::sync::Arc; use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; -use crate::gates::gmimc_eval::GMiMCEvalGate; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::gmimc::gmimc_automatic_constants; use crate::target::Target; @@ -22,17 +23,17 @@ const W: usize = 12; /// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for /// computing the index of the leaf based on these swap bits. #[derive(Debug)] -pub struct GMiMCGate { +pub struct GMiMCGate, const D: usize, const R: usize> { constants: Arc<[F; R]>, } -impl GMiMCGate { - pub fn with_constants(constants: Arc<[F; R]>) -> GateRef { - let gate = GMiMCGate:: { constants }; +impl, const D: usize, const R: usize> GMiMCGate { + pub fn with_constants(constants: Arc<[F; R]>) -> GateRef { + let gate = GMiMCGate:: { constants }; GateRef::new(gate) } - pub fn with_automatic_constants() -> GateRef { + pub fn with_automatic_constants() -> GateRef { let constants = Arc::new(gmimc_automatic_constants::()); Self::with_constants(constants) } @@ -66,21 +67,21 @@ impl GMiMCGate { } } -impl Gate for GMiMCGate { +impl, const D: usize, const R: usize> Gate for GMiMCGate { fn id(&self) -> String { format!(" {:?}", R, self) } - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); // Assert that `swap` is binary. let swap = vars.local_wires[Self::WIRE_SWAP]; - constraints.push(swap * (swap - F::ONE)); + constraints.push(swap * (swap - F::Extension::ONE)); let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD]; let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW]; - let computed_new_index_acc = F::TWO * old_index_acc + swap; + let computed_new_index_acc = F::Extension::TWO * old_index_acc + swap; constraints.push(computed_new_index_acc - new_index_acc); let mut state = Vec::with_capacity(12); @@ -100,11 +101,11 @@ impl Gate for GMiMCGate { // Value that is implicitly added to each element. // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = F::ZERO; + let mut addition_buffer = F::Extension::ZERO; for r in 0..R { let active = r % W; - let cubing_input = state[active] + addition_buffer + self.constants[r]; + let cubing_input = state[active] + addition_buffer + self.constants[r].into(); let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; constraints.push(cubing_input - cubing_input_wire); let f = cubing_input_wire.cube(); @@ -122,37 +123,36 @@ impl Gate for GMiMCGate { fn eval_unfiltered_recursively( &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec { + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - // Assert that `swap` is binary. Usually we would assert that - // swap(swap - 1) = 0 - // but to make it work with a single ArithmeticGate, we will instead write it as - // swap*swap - swap = 0 let swap = vars.local_wires[Self::WIRE_SWAP]; - constraints.push(builder.mul_sub(swap, swap, swap)); + let one_ext = builder.one_extension(); + let not_swap = builder.sub_extension(swap, one_ext); + constraints.push(builder.mul_extension(swap, not_swap)); let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD]; let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW]; // computed_new_index_acc = 2 * old_index_acc + swap let two = builder.two(); - let computed_new_index_acc = builder.mul_add(two, old_index_acc, swap); - constraints.push(builder.sub(computed_new_index_acc, new_index_acc)); + let double_old_index_acc = builder.scalar_mul_ext(two, old_index_acc); + let computed_new_index_acc = builder.add_extension(double_old_index_acc, swap); + constraints.push(builder.sub_extension(computed_new_index_acc, new_index_acc)); let mut state = Vec::with_capacity(12); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; - let delta = builder.sub(b, a); - state.push(builder.mul_add(swap, delta, a)); + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); } for i in 0..4 { let a = vars.local_wires[i + 4]; let b = vars.local_wires[i]; - let delta = builder.sub(b, a); - state.push(builder.mul_add(swap, delta, a)); + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); } for i in 8..12 { state.push(vars.local_wires[i]); @@ -160,56 +160,24 @@ impl Gate for GMiMCGate { // Value that is implicitly added to each element. // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = builder.zero(); + let mut addition_buffer = builder.zero_extension(); for r in 0..R { let active = r % W; - let gate = builder.add_gate(GMiMCEvalGate::get(), vec![self.constants[r]]); - let cubing_input = vars.local_wires[Self::wire_cubing_input(r)]; - builder.route( - cubing_input, - Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_CUBING_INPUT, - }), - ); - - builder.route( - addition_buffer, - Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_ADDITION_BUFFER_OLD, - }), - ); - - builder.route( - state[active], - Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_STATE_A_OLD, - }), - ); - - constraints.push(Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_CONSTRAINT, - })); - - addition_buffer = Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_ADDITION_BUFFER_NEW, - }); - - state[active] = Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_STATE_A_NEW, - }); + let constant = builder.constant_extension(self.constants[r].into()); + let cubing_input = + builder.add_many_extension(&[state[active], addition_buffer, constant]); + let square = builder.mul_extension(cubing_input, cubing_input); + let f = builder.mul_extension(square, cubing_input); + addition_buffer = builder.add_extension(addition_buffer, f); + state[active] = builder.sub_extension(state[active], f); } for i in 0..W { - state[i] = builder.add(state[i], addition_buffer); - constraints.push(builder.sub(state[i], vars.local_wires[Self::wire_output(i)])); + state[i] = builder.add_extension(state[i], addition_buffer); + constraints + .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); } constraints @@ -245,19 +213,21 @@ impl Gate for GMiMCGate { } #[derive(Debug)] -struct GMiMCGenerator { +struct GMiMCGenerator, const D: usize, const R: usize> { gate_index: usize, constants: Arc<[F; R]>, } -impl SimpleGenerator for GMiMCGenerator { +impl, const D: usize, const R: usize> SimpleGenerator + for GMiMCGenerator +{ fn dependencies(&self) -> Vec { let mut dep_input_indices = Vec::with_capacity(W + 2); for i in 0..W { - dep_input_indices.push(GMiMCGate::::wire_input(i)); + dep_input_indices.push(GMiMCGate::::wire_input(i)); } - dep_input_indices.push(GMiMCGate::::WIRE_SWAP); - dep_input_indices.push(GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD); + dep_input_indices.push(GMiMCGate::::WIRE_SWAP); + dep_input_indices.push(GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD); dep_input_indices .into_iter() @@ -277,14 +247,14 @@ impl SimpleGenerator for GMiMCGenerator { .map(|i| { witness.get_wire(Wire { gate: self.gate_index, - input: GMiMCGate::::wire_input(i), + input: GMiMCGate::::wire_input(i), }) }) .collect::>(); let swap_value = witness.get_wire(Wire { gate: self.gate_index, - input: GMiMCGate::::WIRE_SWAP, + input: GMiMCGate::::WIRE_SWAP, }); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); if swap_value == F::ONE { @@ -296,13 +266,13 @@ impl SimpleGenerator for GMiMCGenerator { // Update the index accumulator. let old_index_acc_value = witness.get_wire(Wire { gate: self.gate_index, - input: GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD, + input: GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD, }); let new_index_acc_value = F::TWO * old_index_acc_value + swap_value; result.set_wire( Wire { gate: self.gate_index, - input: GMiMCGate::::WIRE_INDEX_ACCUMULATOR_NEW, + input: GMiMCGate::::WIRE_INDEX_ACCUMULATOR_NEW, }, new_index_acc_value, ); @@ -317,7 +287,7 @@ impl SimpleGenerator for GMiMCGenerator { result.set_wire( Wire { gate: self.gate_index, - input: GMiMCGate::::wire_cubing_input(r), + input: GMiMCGate::::wire_cubing_input(r), }, cubing_input, ); @@ -331,7 +301,7 @@ impl SimpleGenerator for GMiMCGenerator { result.set_wire( Wire { gate: self.gate_index, - input: GMiMCGate::::wire_output(i), + input: GMiMCGate::::wire_output(i), }, state[i], ); @@ -361,7 +331,7 @@ mod tests { type F = CrandallField; const R: usize = 101; let constants = Arc::new([F::TWO; R]); - type Gate = GMiMCGate; + type Gate = GMiMCGate; let gate = Gate::with_constants(constants.clone()); let config = CircuitConfig { @@ -423,7 +393,7 @@ mod tests { type F = CrandallField; const R: usize = 101; let constants = Arc::new([F::TWO; R]); - type Gate = GMiMCGate; + type Gate = GMiMCGate; let gate = Gate::with_constants(constants); test_low_degree(gate) } diff --git a/src/gates/gmimc_eval.rs b/src/gates/gmimc_eval.rs deleted file mode 100644 index 57d9206c..00000000 --- a/src/gates/gmimc_eval.rs +++ /dev/null @@ -1,239 +0,0 @@ -use std::marker::PhantomData; - -use crate::circuit_builder::CircuitBuilder; -use crate::field::field::Field; -use crate::gates::gate::{Gate, GateRef}; -use crate::generator::{SimpleGenerator, WitnessGenerator}; -use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; -use crate::wire::Wire; -use crate::witness::PartialWitness; - -/// Performs some arithmetic involved in the evaluation of GMiMC's constraint polynomials for one -/// round. In particular, this performs the following computations: -/// -/// - `constraint := state_a_old + addition_buffer_old + C_r - cubing_input` -/// - `f := cubing_input^3` -/// - `addition_buffer_new := addition_buffer_old + f` -/// - `state_a_new := state_a_old - f` -/// -/// Here `state_a_{old,new}` represent the old and new states of the `a`th element of the GMiMC -/// permutation. `addition_buffer_{old,new}` represents a value that is implicitly added to each -/// element; see https://affine.group/2020/02/starkware-challenge. `C_r` represents the round -/// constant for round `r`. -#[derive(Debug)] -pub struct GMiMCEvalGate { - _phantom: PhantomData, -} - -impl GMiMCEvalGate { - pub fn get() -> GateRef { - GateRef::new(GMiMCEvalGate { - _phantom: PhantomData, - }) - } - - pub const CONST_C_R: usize = 0; - - pub const WIRE_CONSTRAINT: usize = 0; - pub const WIRE_STATE_A_OLD: usize = 1; - pub const WIRE_STATE_A_NEW: usize = 2; - pub const WIRE_ADDITION_BUFFER_OLD: usize = 3; - pub const WIRE_ADDITION_BUFFER_NEW: usize = 4; - pub const WIRE_CUBING_INPUT: usize = 5; - const WIRE_F: usize = 6; -} - -impl Gate for GMiMCEvalGate { - fn id(&self) -> String { - format!("{:?}", self) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let c_r = vars.local_constants[Self::CONST_C_R]; - let constraint = vars.local_wires[Self::WIRE_CONSTRAINT]; - let state_a_old = vars.local_wires[Self::WIRE_STATE_A_OLD]; - let state_a_new = vars.local_wires[Self::WIRE_STATE_A_NEW]; - let addition_buffer_old = vars.local_wires[Self::WIRE_ADDITION_BUFFER_OLD]; - let addition_buffer_new = vars.local_wires[Self::WIRE_ADDITION_BUFFER_NEW]; - let cubing_input = vars.local_wires[Self::WIRE_CUBING_INPUT]; - let f = vars.local_wires[Self::WIRE_F]; - - let mut constraints = Vec::with_capacity(self.num_constraints()); - - // constraint := state_a_old + addition_buffer_old + C_r - cubing_input - let computed_constraint = state_a_old + addition_buffer_old + c_r - cubing_input; - constraints.push(constraint - computed_constraint); - - // f := cubing_input^3 - let computed_f = cubing_input.cube(); - constraints.push(f - computed_f); - - // addition_buffer_new := addition_buffer_old + f - let computed_addition_buffer_new = addition_buffer_old + f; - constraints.push(addition_buffer_new - computed_addition_buffer_new); - - // state_a_new := state_a_old - f - let computed_state_a_new = state_a_old - f; - constraints.push(state_a_new - computed_state_a_new); - - constraints - } - - fn eval_unfiltered_recursively( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec { - let c_r = vars.local_constants[Self::CONST_C_R]; - let constraint = vars.local_wires[Self::WIRE_CONSTRAINT]; - let state_a_old = vars.local_wires[Self::WIRE_STATE_A_OLD]; - let state_a_new = vars.local_wires[Self::WIRE_STATE_A_NEW]; - let addition_buffer_old = vars.local_wires[Self::WIRE_ADDITION_BUFFER_OLD]; - let addition_buffer_new = vars.local_wires[Self::WIRE_ADDITION_BUFFER_NEW]; - let cubing_input = vars.local_wires[Self::WIRE_CUBING_INPUT]; - let f = vars.local_wires[Self::WIRE_F]; - - let mut constraints = Vec::with_capacity(self.num_constraints()); - - // constraint := state_a_old + addition_buffer_old + C_r - cubing_input - let sum = builder.add_many(&[state_a_old, addition_buffer_old, c_r]); - let computed_constraint = builder.sub(sum, cubing_input); - constraints.push(builder.sub(constraint, computed_constraint)); - - // f := cubing_input^3 - let computed_f = builder.cube(cubing_input); - constraints.push(builder.sub(f, computed_f)); - - // addition_buffer_new := addition_buffer_old + f - let computed_addition_buffer_new = builder.add(addition_buffer_old, f); - constraints.push(builder.sub(addition_buffer_new, computed_addition_buffer_new)); - - // state_a_new := state_a_old - f - let computed_state_a_new = builder.sub(state_a_old, f); - constraints.push(builder.sub(state_a_new, computed_state_a_new)); - - constraints - } - - fn generators( - &self, - gate_index: usize, - local_constants: &[F], - ) -> Vec>> { - let gen = GMiMCEvalGenerator:: { - gate_index, - c_r: local_constants[Self::CONST_C_R], - }; - vec![Box::new(gen)] - } - - fn num_wires(&self) -> usize { - 7 - } - - fn num_constants(&self) -> usize { - 1 - } - - fn degree(&self) -> usize { - 3 - } - - fn num_constraints(&self) -> usize { - 4 - } -} - -#[derive(Debug)] -struct GMiMCEvalGenerator { - gate_index: usize, - c_r: F, -} - -impl SimpleGenerator for GMiMCEvalGenerator { - fn dependencies(&self) -> Vec { - let gate = self.gate_index; - vec![ - Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_CUBING_INPUT, - }), - Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_ADDITION_BUFFER_OLD, - }), - Target::Wire(Wire { - gate, - input: GMiMCEvalGate::::WIRE_STATE_A_OLD, - }), - ] - } - - fn run_once(&self, witness: &PartialWitness) -> PartialWitness { - let gate = self.gate_index; - let wire_constraint = Wire { - gate, - input: GMiMCEvalGate::::WIRE_CONSTRAINT, - }; - let wire_state_a_old = Wire { - gate, - input: GMiMCEvalGate::::WIRE_STATE_A_OLD, - }; - let wire_state_a_new = Wire { - gate, - input: GMiMCEvalGate::::WIRE_STATE_A_NEW, - }; - let wire_addition_buffer_old = Wire { - gate, - input: GMiMCEvalGate::::WIRE_ADDITION_BUFFER_OLD, - }; - let wire_addition_buffer_new = Wire { - gate, - input: GMiMCEvalGate::::WIRE_ADDITION_BUFFER_NEW, - }; - let wire_cubing_input = Wire { - gate, - input: GMiMCEvalGate::::WIRE_CUBING_INPUT, - }; - let wire_f = Wire { - gate, - input: GMiMCEvalGate::::WIRE_F, - }; - - let addition_buffer_old = witness.get_wire(wire_addition_buffer_old); - let state_a_old = witness.get_wire(wire_state_a_old); - let cubing_input = witness.get_wire(wire_cubing_input); - - // constraint := state_a_old + addition_buffer_old + C_r - cubing_input - let constraint = state_a_old + addition_buffer_old + self.c_r - cubing_input; - - // f := cubing_input^3 - let f = cubing_input.cube(); - - // addition_buffer_new := addition_buffer_old + f - let addition_buffer_new = addition_buffer_old + f; - - // state_a_new := state_a_old - f - let state_a_new = state_a_old - f; - - let mut witness = PartialWitness::new(); - witness.set_wire(wire_constraint, constraint); - witness.set_wire(wire_f, f); - witness.set_wire(wire_state_a_new, addition_buffer_new); - witness.set_wire(wire_addition_buffer_new, state_a_new); - witness - } -} - -#[cfg(test)] -mod tests { - use crate::field::crandall_field::CrandallField; - use crate::gates::gate_testing::test_low_degree; - use crate::gates::gmimc_eval::GMiMCEvalGate; - - #[test] - fn low_degree() { - test_low_degree(GMiMCEvalGate::::get()) - } -} diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index a1b9ddd5..53e0b7e7 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -3,10 +3,10 @@ use std::marker::PhantomData; use std::ops::Range; use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; -use crate::field::field::Field; use crate::field::lagrange::interpolant; -use crate::gadgets::polynomial::PolynomialCoeffsTarget; +use crate::gadgets::polynomial::PolynomialCoeffsExtExtTarget; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -21,13 +21,19 @@ use crate::witness::PartialWitness; /// to evaluate the interpolant at. It computes the interpolant and outputs its evaluation at the /// given point. #[derive(Clone, Debug)] -pub(crate) struct InterpolationGate, const D: usize> { +pub(crate) struct InterpolationGate, const D: usize> +where + F::Extension: Extendable, +{ num_points: usize, _phantom: PhantomData, } -impl, const D: usize> InterpolationGate { - pub fn new(num_points: usize) -> GateRef { +impl, const D: usize> InterpolationGate +where + F::Extension: Extendable, +{ + pub fn new(num_points: usize) -> GateRef { let gate = Self { num_points, _phantom: PhantomData, @@ -93,28 +99,31 @@ impl, const D: usize> InterpolationGate { } } -impl, const D: usize> Gate for InterpolationGate { +impl, const D: usize> Gate for InterpolationGate +where + F::Extension: Extendable, +{ fn id(&self) -> String { format!("{:?}", self, D) } - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); let coeffs = (0..self.num_points) - .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .map(|i| vars.get_local_ext_ext(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffs::new(coeffs); for i in 0..self.num_points { - let point = F::Extension::from_basefield(vars.local_wires[self.wire_point(i)]); - let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = interpolant.eval(point); + let point = vars.local_wires[self.wire_point(i)]; + let value = vars.get_local_ext_ext(self.wires_value(i)); + let computed_value = interpolant.eval(point.into()); constraints.extend(&(value - computed_value).to_basefield_array()); } - let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let evaluation_point = vars.get_local_ext_ext(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_ext(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval(evaluation_point); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); @@ -123,34 +132,34 @@ impl, const D: usize> Gate for InterpolationGate, - vars: EvaluationTargets, - ) -> Vec { + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); let coeffs = (0..self.num_points) - .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .map(|i| vars.get_local_ext_ext(self.wires_coeff(i))) .collect(); - let interpolant = PolynomialCoeffsTarget(coeffs); + let interpolant = PolynomialCoeffsExtExtTarget(coeffs); for i in 0..self.num_points { let point = vars.local_wires[self.wire_point(i)]; - let value = vars.get_local_ext(self.wires_value(i)); + let value = vars.get_local_ext_ext(self.wires_value(i)); let computed_value = interpolant.eval_scalar(builder, point); constraints.extend( &builder - .sub_extension(value, computed_value) - .to_target_array(), + .sub_ext_ext(value, computed_value) + .to_ext_target_array(), ); } - let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let evaluation_point = vars.get_local_ext_ext(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_ext(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval(builder, evaluation_point); constraints.extend( &builder - .sub_extension(evaluation_value, computed_evaluation_value) - .to_target_array(), + .sub_ext_ext(evaluation_value, computed_evaluation_value) + .to_ext_target_array(), ); constraints @@ -190,13 +199,19 @@ impl, const D: usize> Gate for InterpolationGate, const D: usize> { +struct InterpolationGenerator, const D: usize> +where + F::Extension: Extendable, +{ gate_index: usize, gate: InterpolationGate, _phantom: PhantomData, } -impl, const D: usize> SimpleGenerator for InterpolationGenerator { +impl, const D: usize> SimpleGenerator for InterpolationGenerator +where + F::Extension: Extendable, +{ fn dependencies(&self) -> Vec { let local_target = |input| { Target::Wire(Wire { @@ -293,7 +308,6 @@ mod tests { #[test] fn low_degree() { type F = CrandallField; - test_low_degree(InterpolationGate::::new(4)); test_low_degree(InterpolationGate::::new(4)); } } diff --git a/src/gates/mod.rs b/src/gates/mod.rs index afc0b904..ebcf6e3f 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -2,7 +2,6 @@ pub(crate) mod arithmetic; pub mod constant; pub(crate) mod gate; pub mod gmimc; -pub(crate) mod gmimc_eval; mod interpolation; pub(crate) mod noop; diff --git a/src/gates/noop.rs b/src/gates/noop.rs index fdde6ec6..eddd0361 100644 --- a/src/gates/noop.rs +++ b/src/gates/noop.rs @@ -1,33 +1,33 @@ use crate::circuit_builder::CircuitBuilder; -use crate::field::field::Field; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; use crate::generator::WitnessGenerator; -use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; /// A gate which takes a single constant parameter and outputs that value. pub struct NoopGate; impl NoopGate { - pub fn get() -> GateRef { + pub fn get, const D: usize>() -> GateRef { GateRef::new(NoopGate) } } -impl Gate for NoopGate { +impl, const D: usize> Gate for NoopGate { fn id(&self) -> String { "NoopGate".into() } - fn eval_unfiltered(&self, _vars: EvaluationVars) -> Vec { + fn eval_unfiltered(&self, _vars: EvaluationVars) -> Vec { Vec::new() } fn eval_unfiltered_recursively( &self, - _builder: &mut CircuitBuilder, - _vars: EvaluationTargets, - ) -> Vec { + _builder: &mut CircuitBuilder, + _vars: EvaluationTargets, + ) -> Vec> { Vec::new() } @@ -64,6 +64,6 @@ mod tests { #[test] fn low_degree() { - test_low_degree(NoopGate::get::()) + test_low_degree(NoopGate::get::()) } } diff --git a/src/hash.rs b/src/hash.rs index a51d1f08..47d703cd 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -1,6 +1,7 @@ //! Concrete instantiation of a hash function. use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gmimc::gmimc_permute_array; use crate::proof::{Hash, HashTarget}; @@ -132,7 +133,7 @@ pub fn hash_or_noop(inputs: Vec) -> Hash { } } -impl CircuitBuilder { +impl, const D: usize> CircuitBuilder { pub fn hash_or_noop(&mut self, inputs: Vec) -> HashTarget { let zero = self.zero(); if inputs.len() <= 4 { diff --git a/src/merkle_proofs.rs b/src/merkle_proofs.rs index 3fd573ba..d5ab8a78 100644 --- a/src/merkle_proofs.rs +++ b/src/merkle_proofs.rs @@ -1,4 +1,7 @@ +use anyhow::{ensure, Result}; + use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gmimc::GMiMCGate; use crate::hash::GMIMC_ROUNDS; @@ -6,7 +9,6 @@ use crate::hash::{compress, hash_or_noop}; use crate::proof::{Hash, HashTarget}; use crate::target::Target; use crate::wire::Wire; -use anyhow::{ensure, Result}; #[derive(Clone, Debug)] pub struct MerkleProof { @@ -52,7 +54,7 @@ pub(crate) fn verify_merkle_proof( Ok(()) } -impl CircuitBuilder { +impl, const D: usize> CircuitBuilder { /// Verifies that the given leaf data is present at the given index in the Merkle tree with the /// given root. pub(crate) fn verify_merkle_proof( @@ -71,23 +73,23 @@ impl CircuitBuilder { for (bit, sibling) in purported_index_bits.into_iter().zip(proof.siblings) { let gate = self - .add_gate_no_constants(GMiMCGate::::with_automatic_constants()); + .add_gate_no_constants(GMiMCGate::::with_automatic_constants()); - let swap_wire = GMiMCGate::::WIRE_SWAP; + let swap_wire = GMiMCGate::::WIRE_SWAP; let swap_wire = Target::Wire(Wire { gate, input: swap_wire, }); self.generate_copy(bit, swap_wire); - let old_acc_wire = GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD; + let old_acc_wire = GMiMCGate::::WIRE_INDEX_ACCUMULATOR_OLD; let old_acc_wire = Target::Wire(Wire { gate, input: old_acc_wire, }); self.route(acc_leaf_index, old_acc_wire); - let new_acc_wire = GMiMCGate::::WIRE_INDEX_ACCUMULATOR_NEW; + let new_acc_wire = GMiMCGate::::WIRE_INDEX_ACCUMULATOR_NEW; let new_acc_wire = Target::Wire(Wire { gate, input: new_acc_wire, @@ -98,7 +100,7 @@ impl CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_input(i), + input: GMiMCGate::::wire_input(i), }) }) .collect::>(); @@ -114,7 +116,7 @@ impl CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_output(i), + input: GMiMCGate::::wire_output(i), }) }) .collect(), diff --git a/src/plonk_challenger.rs b/src/plonk_challenger.rs index 6a1a8888..4c8ff167 100644 --- a/src/plonk_challenger.rs +++ b/src/plonk_challenger.rs @@ -160,7 +160,9 @@ pub(crate) struct RecursiveChallenger { } impl RecursiveChallenger { - pub(crate) fn new(builder: &mut CircuitBuilder) -> Self { + pub(crate) fn new, const D: usize>( + builder: &mut CircuitBuilder, + ) -> Self { let zero = builder.zero(); RecursiveChallenger { sponge_state: [zero; SPONGE_WIDTH], @@ -186,7 +188,10 @@ impl RecursiveChallenger { self.observe_elements(&hash.elements) } - pub(crate) fn get_challenge(&mut self, builder: &mut CircuitBuilder) -> Target { + pub(crate) fn get_challenge, const D: usize>( + &mut self, + builder: &mut CircuitBuilder, + ) -> Target { self.absorb_buffered_inputs(builder); if self.output_buffer.is_empty() { @@ -200,16 +205,16 @@ impl RecursiveChallenger { .expect("Output buffer should be non-empty") } - pub(crate) fn get_2_challenges( + pub(crate) fn get_2_challenges, const D: usize>( &mut self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, ) -> (Target, Target) { (self.get_challenge(builder), self.get_challenge(builder)) } - pub(crate) fn get_3_challenges( + pub(crate) fn get_3_challenges, const D: usize>( &mut self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, ) -> (Target, Target, Target) { ( self.get_challenge(builder), @@ -218,16 +223,19 @@ impl RecursiveChallenger { ) } - pub(crate) fn get_n_challenges( + pub(crate) fn get_n_challenges, const D: usize>( &mut self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, n: usize, ) -> Vec { (0..n).map(|_| self.get_challenge(builder)).collect() } /// Absorb any buffered inputs. After calling this, the input buffer will be empty. - fn absorb_buffered_inputs(&mut self, builder: &mut CircuitBuilder) { + fn absorb_buffered_inputs, const D: usize>( + &mut self, + builder: &mut CircuitBuilder, + ) { if self.input_buffer.is_empty() { return; } @@ -308,7 +316,7 @@ mod tests { num_routed_wires: 27, ..CircuitConfig::default() }; - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::::new(config); let mut recursive_challenger = RecursiveChallenger::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); for (r, inputs) in inputs_per_round.iter().enumerate() { diff --git a/src/plonk_common.rs b/src/plonk_common.rs index ed780817..73c6d65c 100644 --- a/src/plonk_common.rs +++ b/src/plonk_common.rs @@ -1,20 +1,22 @@ use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::gate::GateRef; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// Evaluates all gate constraints. /// /// `num_gate_constraints` is the largest number of constraints imposed by any gate. It is not /// strictly necessary, but it helps performance by ensuring that we allocate a vector with exactly /// the capacity that we need. -pub fn evaluate_gate_constraints( - gates: &[GateRef], +pub fn evaluate_gate_constraints, const D: usize>( + gates: &[GateRef], num_gate_constraints: usize, - vars: EvaluationVars, -) -> Vec { - let mut constraints = vec![F::ZERO; num_gate_constraints]; + vars: EvaluationVars, +) -> Vec { + let mut constraints = vec![F::Extension::ZERO; num_gate_constraints]; for gate in gates { let gate_constraints = gate.0.eval_filtered(vars); for (i, c) in gate_constraints.into_iter().enumerate() { @@ -28,17 +30,36 @@ pub fn evaluate_gate_constraints( constraints } -pub fn evaluate_gate_constraints_recursively( - builder: &mut CircuitBuilder, - gates: &[GateRef], +pub fn evaluate_gate_constraints_base, const D: usize>( + gates: &[GateRef], num_gate_constraints: usize, - vars: EvaluationTargets, -) -> Vec { - let mut constraints = vec![builder.zero(); num_gate_constraints]; + vars: EvaluationVarsBase, +) -> Vec { + let mut constraints = vec![F::ZERO; num_gate_constraints]; + for gate in gates { + let gate_constraints = gate.0.eval_filtered_base(vars); + for (i, c) in gate_constraints.into_iter().enumerate() { + debug_assert!( + i < num_gate_constraints, + "num_constraints() gave too low of a number" + ); + constraints[i] += c; + } + } + constraints +} + +pub fn evaluate_gate_constraints_recursively, const D: usize>( + builder: &mut CircuitBuilder, + gates: &[GateRef], + num_gate_constraints: usize, + vars: EvaluationTargets, +) -> Vec> { + let mut constraints = vec![builder.zero_extension(); num_gate_constraints]; for gate in gates { let gate_constraints = gate.0.eval_filtered_recursively(builder, vars); for (i, c) in gate_constraints.into_iter().enumerate() { - constraints[i] = builder.add(constraints[i], c); + constraints[i] = builder.add_extension(constraints[i], c); } } constraints @@ -80,8 +101,8 @@ pub(crate) fn reduce_with_powers(terms: &[F], alpha: F) -> F { sum } -pub(crate) fn reduce_with_powers_recursive( - builder: &mut CircuitBuilder, +pub(crate) fn reduce_with_powers_recursive, const D: usize>( + builder: &mut CircuitBuilder, terms: Vec, alpha: Target, ) -> Target { diff --git a/src/polynomial/commitment.rs b/src/polynomial/commitment.rs index aa58c561..8727ada9 100644 --- a/src/polynomial/commitment.rs +++ b/src/polynomial/commitment.rs @@ -359,8 +359,6 @@ impl, const D: usize> OpeningProof { mod tests { use anyhow::Result; - use crate::field::crandall_field::CrandallField; - use super::*; fn gen_random_test_case, const D: usize>( diff --git a/src/prover.rs b/src/prover.rs index a90d81f1..40cad1b9 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -9,22 +9,22 @@ use crate::field::fft::ifft; use crate::field::field::Field; use crate::generator::generate_partial_witness; use crate::plonk_challenger::Challenger; -use crate::plonk_common::{eval_l_1, evaluate_gate_constraints, reduce_with_powers_multi}; +use crate::plonk_common::{eval_l_1, evaluate_gate_constraints_base, reduce_with_powers_multi}; use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::proof::Proof; use crate::timed; use crate::util::transpose; -use crate::vars::EvaluationVars; +use crate::vars::EvaluationVarsBase; use crate::wire::Wire; use crate::witness::PartialWitness; /// Corresponds to constants - sigmas - wires - zs - quotient — polynomial commitments. pub const PLONK_BLINDING: [bool; 5] = [false, false, true, true, true]; -pub(crate) fn prove, const D: usize>( +pub(crate) fn prove, const D: usize>( prover_data: &ProverOnlyCircuitData, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, inputs: PartialWitness, ) -> Proof { let fri_config = &common_data.config.fri_config; @@ -113,7 +113,7 @@ pub(crate) fn prove, const D: usize>( challenger.observe_hash("ient_polys_commitment.merkle_tree.root); - let zetas = challenger.get_n_extension_challenges(config.num_challenges); + let zeta = challenger.get_extension_challenge(); let (opening_proof, openings) = timed!( ListPolynomialCommitment::batch_open_plonk( @@ -124,7 +124,7 @@ pub(crate) fn prove, const D: usize>( &plonk_zs_commitment, "ient_polys_commitment, ], - &zetas, + &[zeta], &mut challenger, &common_data.config.fri_config ), @@ -145,19 +145,23 @@ pub(crate) fn prove, const D: usize>( } } -fn compute_zs(common_data: &CommonCircuitData) -> Vec> { +fn compute_zs, const D: usize>( + common_data: &CommonCircuitData, +) -> Vec> { (0..common_data.config.num_challenges) .map(|i| compute_z(common_data, i)) .collect() } -fn compute_z(common_data: &CommonCircuitData, _i: usize) -> PolynomialCoeffs { +fn compute_z, const D: usize>( + common_data: &CommonCircuitData, + _i: usize, +) -> PolynomialCoeffs { PolynomialCoeffs::zero(common_data.degree()) // TODO } -// TODO: Parallelize. -fn compute_vanishing_polys( - common_data: &CommonCircuitData, +fn compute_vanishing_polys, const D: usize>( + common_data: &CommonCircuitData, prover_data: &ProverOnlyCircuitData, wires_commitment: &ListPolynomialCommitment, plonk_zs_commitment: &ListPolynomialCommitment, @@ -184,7 +188,7 @@ fn compute_vanishing_polys( debug_assert_eq!(local_wires.len(), common_data.config.num_wires); debug_assert_eq!(local_plonk_zs.len(), num_challenges); - let vars = EvaluationVars { + let vars = EvaluationVarsBase { local_constants, local_wires, }; @@ -211,10 +215,10 @@ fn compute_vanishing_polys( /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random /// linear combination of gate constraints, plus some other terms relating to the permutation /// argument. All such terms should vanish on `H`. -fn compute_vanishing_poly_entry( - common_data: &CommonCircuitData, +fn compute_vanishing_poly_entry, const D: usize>( + common_data: &CommonCircuitData, x: F, - vars: EvaluationVars, + vars: EvaluationVarsBase, local_plonk_zs: &[F], next_plonk_zs: &[F], s_sigmas: &[F], @@ -223,7 +227,7 @@ fn compute_vanishing_poly_entry( alphas: &[F], ) -> Vec { let constraint_terms = - evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); + evaluate_gate_constraints_base(&common_data.gates, common_data.num_gate_constraints, vars); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index b850e587..a9d37553 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -1,6 +1,6 @@ use crate::circuit_builder::CircuitBuilder; use crate::circuit_data::{CircuitConfig, VerifierCircuitTarget}; -use crate::field::field::Field; +use crate::field::extension_field::Extendable; use crate::gates::gate::GateRef; use crate::proof::ProofTarget; @@ -8,11 +8,11 @@ const MIN_WIRES: usize = 120; // TODO: Double check. const MIN_ROUTED_WIRES: usize = 8; // TODO: Double check. /// Recursively verifies an inner proof. -pub fn add_recursive_verifier( - builder: &mut CircuitBuilder, +pub fn add_recursive_verifier, const D: usize>( + builder: &mut CircuitBuilder, inner_config: CircuitConfig, inner_circuit: VerifierCircuitTarget, - inner_gates: Vec>, + inner_gates: Vec>, inner_proof: ProofTarget, ) { assert!(builder.config.num_wires >= MIN_WIRES); diff --git a/src/vars.rs b/src/vars.rs index aa8a3561..88d0759a 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -1,38 +1,46 @@ use std::convert::TryInto; use std::ops::Range; -use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::target::{ExtensionExtensionTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; -use crate::target::Target; #[derive(Copy, Clone)] -pub struct EvaluationVars<'a, F: Field> { +pub struct EvaluationVars<'a, F: Extendable, const D: usize> { + pub(crate) local_constants: &'a [F::Extension], + pub(crate) local_wires: &'a [F::Extension], +} + +#[derive(Copy, Clone)] +pub struct EvaluationVarsBase<'a, F: Field> { pub(crate) local_constants: &'a [F], pub(crate) local_wires: &'a [F], } -impl<'a, F: Field> EvaluationVars<'a, F> { - pub fn get_local_ext(&self, wire_range: Range) -> F::Extension +impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { + pub fn get_local_ext_ext( + &self, + wire_range: Range, + ) -> <>::Extension as Extendable>::Extension where - F: Extendable, + F::Extension: Extendable, { debug_assert_eq!(wire_range.len(), D); let arr = self.local_wires[wire_range].try_into().unwrap(); - F::Extension::from_basefield_array(arr) + <>::Extension as Extendable>::Extension::from_basefield_array(arr) } } #[derive(Copy, Clone)] -pub struct EvaluationTargets<'a> { - pub(crate) local_constants: &'a [Target], - pub(crate) local_wires: &'a [Target], +pub struct EvaluationTargets<'a, const D: usize> { + pub(crate) local_constants: &'a [ExtensionTarget], + pub(crate) local_wires: &'a [ExtensionTarget], } -impl<'a> EvaluationTargets<'a> { - pub fn get_local_ext(&self, wire_range: Range) -> ExtensionTarget { +impl<'a, const D: usize> EvaluationTargets<'a, D> { + pub fn get_local_ext_ext(&self, wire_range: Range) -> ExtensionExtensionTarget { debug_assert_eq!(wire_range.len(), D); let arr = self.local_wires[wire_range].try_into().unwrap(); - ExtensionTarget(arr) + ExtensionExtensionTarget(arr) } } diff --git a/src/verifier.rs b/src/verifier.rs index 0a02c19e..7af0f8a8 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -2,14 +2,13 @@ use anyhow::Result; use crate::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; use crate::field::extension_field::Extendable; -use crate::field::field::Field; use crate::plonk_challenger::Challenger; use crate::proof::Proof; -pub(crate) fn verify, const D: usize>( +pub(crate) fn verify, const D: usize>( proof: Proof, verifier_data: &VerifierOnlyCircuitData, - common_data: &CommonCircuitData, + common_data: &CommonCircuitData, ) -> Result<()> { let config = &common_data.config; let fri_config = &config.fri_config; @@ -28,7 +27,7 @@ pub(crate) fn verify, const D: usize>( let alphas = challenger.get_n_challenges(num_challenges); challenger.observe_hash(&proof.quotient_polys_root); - let zetas = challenger.get_n_extension_challenges(config.num_challenges); + let zeta = challenger.get_extension_challenge(); // TODO: Compute PI(zeta), Z_H(zeta), etc. and check the identity at zeta. @@ -43,7 +42,7 @@ pub(crate) fn verify, const D: usize>( ]; proof.opening_proof.verify( - &zetas, + &[zeta], evaluations, merkle_roots, &mut challenger,