diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index fd1a178e..b324b2d1 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -152,14 +152,6 @@ impl DivAssign for CrandallField { } } -#[inline(always)] -fn reduce64(x: u64) -> u64 { - // TODO: slow? try cond sub - // x % P - let over = x > P; - x - (over as u64) * P -} - /// no final reduction #[inline(always)] fn reduce128(x: u128) -> CrandallField { diff --git a/src/field/field.rs b/src/field/field.rs index 34aa25af..430a9a92 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -15,7 +15,9 @@ pub trait Field: 'static + MulAssign + Div + DivAssign -+ Debug { ++ Debug ++ Send ++ Sync { const ZERO: Self; const ONE: Self; const NEG_ONE: Self; diff --git a/src/gates/deterministic_gate.rs b/src/gates/deterministic_gate.rs index 481808ad..2420831f 100644 --- a/src/gates/deterministic_gate.rs +++ b/src/gates/deterministic_gate.rs @@ -3,11 +3,11 @@ use std::marker::PhantomData; use crate::circuit_data::CircuitConfig; use crate::constraint_polynomial::{ConstraintPolynomial, EvaluationVars}; use crate::field::field::Field; -use crate::gates::gate::Gate2; +use crate::gates::gate::Gate; use crate::generator::{SimpleGenerator, WitnessGenerator2}; use crate::target::Target2; use crate::wire::Wire; -use crate::witness::PartialWitness2; +use crate::witness::PartialWitness; /// A deterministic gate. Each entry in `outputs()` describes how that output is evaluated; this is /// used to create both the constraint set and the generator set. @@ -52,7 +52,7 @@ impl> DeterministicGateAdapter { } } -impl> Gate2 for DeterministicGateAdapter { +impl> Gate for DeterministicGateAdapter { fn id(&self) -> String { self.gate.id() } @@ -110,7 +110,7 @@ impl SimpleGenerator for OutputGenerator { .collect() } - fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2 { + fn run_once(&mut self, witness: &PartialWitness) -> PartialWitness { let mut local_wire_values = Vec::new(); let mut next_wire_values = Vec::new(); @@ -141,6 +141,6 @@ impl SimpleGenerator for OutputGenerator { let result_wire = Wire { gate: self.gate_index, input: self.input_index }; let result_value = self.out.evaluate(vars); - PartialWitness2::singleton(Target2::Wire(result_wire), result_value) + PartialWitness::singleton(Target2::Wire(result_wire), result_value) } } diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 7cce3acb..5af85949 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -8,7 +8,7 @@ use crate::generator::WitnessGenerator2; /// A custom gate. // TODO: Remove CircuitConfig params? Could just use fields within each struct. -pub trait Gate2: 'static { +pub trait Gate: 'static { fn id(&self) -> String; /// A set of expressions which must evaluate to zero. @@ -52,10 +52,10 @@ pub trait Gate2: 'static { /// A wrapper around an `Rc` which implements `PartialEq`, `Eq` and `Hash` based on gate IDs. #[derive(Clone)] -pub struct GateRef(pub(crate) Rc>); +pub struct GateRef(pub(crate) Rc>); impl GateRef { - pub fn new>(gate: G) -> GateRef { + pub fn new>(gate: G) -> GateRef { GateRef(Rc::new(gate)) } } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 259d2fbc..79acc8b4 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -1,29 +1,77 @@ +use std::convert::TryInto; + use crate::circuit_data::CircuitConfig; use crate::constraint_polynomial::ConstraintPolynomial; use crate::field::field::Field; use crate::gates::deterministic_gate::DeterministicGate; -use crate::gates::gate::{Gate2, GateRef}; +use crate::gates::gate::{Gate, GateRef}; +use crate::generator::{SimpleGenerator, WitnessGenerator2}; +use crate::gmimc::gmimc_permute; +use crate::target::Target2; +use crate::wire::Wire; +use crate::witness::PartialWitness; +use std::sync::Arc; -/// Evaluates a full GMiMC permutation. +/// Evaluates a full GMiMC permutation, and writes the output to the next gate's first `width` +/// wires (which could be the input of another `GMiMCGate`). #[derive(Debug)] -pub struct GMiMCGate { - num_rounds: usize, - width: usize, - round_constants: Vec, +pub struct GMiMCGate { + round_constants: Arc<[F; R]>, } -impl GMiMCGate { - fn new(width: usize) -> GateRef { +impl GMiMCGate { + fn new() -> GateRef { todo!() } } -impl DeterministicGate for GMiMCGate { +impl Gate for GMiMCGate { fn id(&self) -> String { + // TODO: Add W/R format!("{:?}", self) } - fn outputs(&self, config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)> { + fn constraints(&self, config: CircuitConfig) -> Vec> { unimplemented!() } + + fn generators(&self, config: CircuitConfig, gate_index: usize, local_constants: Vec, next_constants: Vec) -> Vec>> { + let generator = GMiMCGenerator:: { + round_constants: self.round_constants.clone(), + gate_index, + }; + vec![Box::new(generator)] + } +} + +struct GMiMCGenerator { + round_constants: Arc<[F; R]>, + gate_index: usize, +} + +impl SimpleGenerator for GMiMCGenerator { + fn dependencies(&self) -> Vec { + (0..W) + .map(|i| Target2::Wire( + Wire { gate: self.gate_index, input: i })) + .collect() + } + + fn run_once(&mut self, witness: &PartialWitness) -> PartialWitness { + let mut inputs: [F; W] = [F::ZERO; W]; + for i in 0..W { + inputs[i] = witness.get_wire( + Wire { gate: self.gate_index, input: i }); + } + + let outputs = gmimc_permute::(inputs, self.round_constants.clone()); + + let mut result = PartialWitness::new(); + for i in 0..W { + result.set_wire( + Wire { gate: self.gate_index + 1, input: i }, + outputs[i]); + } + result + } } diff --git a/src/generator.rs b/src/generator.rs index 9b69acd4..251d8298 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,6 +1,6 @@ use crate::field::field::Field; use crate::target::Target2; -use crate::witness::PartialWitness2; +use crate::witness::PartialWitness; /// A generator participates in the generation of the witness. pub trait WitnessGenerator2: 'static { @@ -12,14 +12,14 @@ pub trait WitnessGenerator2: 'static { /// flag indicating whether the generator is finished. If the flag is true, the generator will /// never be run again, otherwise it will be queued for another run next time a target in its /// watch list is populated. - fn run(&mut self, witness: &PartialWitness2) -> (PartialWitness2, bool); + fn run(&mut self, witness: &PartialWitness) -> (PartialWitness, bool); } /// A generator which runs once after a list of dependencies is present in the witness. pub trait SimpleGenerator: 'static { fn dependencies(&self) -> Vec; - fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2; + fn run_once(&mut self, witness: &PartialWitness) -> PartialWitness; } impl> WitnessGenerator2 for SG { @@ -27,11 +27,11 @@ impl> WitnessGenerator2 for SG { self.dependencies() } - fn run(&mut self, witness: &PartialWitness2) -> (PartialWitness2, bool) { + fn run(&mut self, witness: &PartialWitness) -> (PartialWitness, bool) { if witness.contains_all(&self.dependencies()) { (self.run_once(witness), true) } else { - (PartialWitness2::new(), false) + (PartialWitness::new(), false) } } } @@ -47,8 +47,8 @@ impl SimpleGenerator for CopyGenerator { vec![self.src] } - fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2 { + fn run_once(&mut self, witness: &PartialWitness) -> PartialWitness { let value = witness.get_target(self.src); - PartialWitness2::singleton(self.dst, value) + PartialWitness::singleton(self.dst, value) } } diff --git a/src/gmimc.rs b/src/gmimc.rs index 6bb9c7c3..2e5ede16 100644 --- a/src/gmimc.rs +++ b/src/gmimc.rs @@ -1,39 +1,35 @@ +use std::sync::Arc; + use unroll::unroll_for_loops; use crate::field::field::Field; -const GMIMC_CONSTANTS: [u64; GMIMC_ROUNDS] = [11875528958976719239, 6107683892976199900, 7756999550758271958, 14819109722912164804, 9716579428412441110, 13627117528901194436, 16260683900833506663, 5942251937084147420, 3340009544523273897, 5103423085715007461, 17051583366444092101, 11122892258227244197, 16564300648907092407, 978667924592675864, 17676416205210517593, 1938246372790494499, 8857737698008340728, 1616088456497468086, 15961521580811621978, 17427220057097673602, 14693961562064090188, 694121596646283736, 554241305747273747, 5783347729647881086, 14933083198980931734, 2600898787591841337, 9178797321043036456, 18068112389665928586, 14493389459750307626, 1650694762687203587, 12538946551586403559, 10144328970401184255, 4215161528137084719, 17559540991336287827, 1632269449854444901, 986434918028205468, 14921385763379308253, 4345141219277982730, 2645897826751167170, 9815223670029373528, 7687983869685434132, 13956100321958014639, 519639453142393369, 15617837024229225911, 1557446238053329052, 8130006133842942201, 864716631341688017, 2860289738131495304, 16723700803638270299, 8363528906277648001, 13196016034228493087, 2514677332206134618, 15626342185220554936, 466271571343554681, 17490024028988898434, 6454235936129380878, 15187752952940298536, 18043495619660620405, 17118101079533798167, 13420382916440963101, 535472393366793763, 1071152303676936161, 6351382326603870931, 12029593435043638097, 9983185196487342247, 414304527840226604, 1578977347398530191, 13594880016528059526, 13219707576179925776, 6596253305527634647, 17708788597914990288, 7005038999589109658, 10171979740390484633, 1791376803510914239, 2405996319967739434, 12383033218117026776, 17648019043455213923, 6600216741450137683, 5359884112225925883, 1501497388400572107, 11860887439428904719, 64080876483307031, 11909038931518362287, 14166132102057826906, 14172584203466994499, 593515702472765471, 3423583343794830614, 10041710997716717966, 13434212189787960052, 9943803922749087030, 3216887087479209126, 17385898166602921353, 617799950397934255, 9245115057096506938, 13290383521064450731, 10193883853810413351, 14648839921475785656, 14635698366607946133, 9134302981480720532, 10045888297267997632, 10752096344939765738, 12049167771599274839, 16471532489936095930, 7118567245891966484, 272840212090177715, 7530334979534674340, 12300300144661791831, 14334496540665732547]; - -const GMIMC_ROUNDS: usize = 108; - -const W: usize = 12; - -pub fn gmimc_compress(a: [F; 4], b: [F; 4]) -> [F; 4] { +pub fn gmimc_compress(a: [F; 4], b: [F; 4], constants: Arc<[F; R]>) -> [F; 4] { // Sponge with r=8, c=4. let state_0 = [a[0], a[1], a[2], a[3], b[0], b[1], b[2], b[3], F::ZERO, F::ZERO, F::ZERO, F::ZERO]; - let state_1 = gmimc_permute(state_0); + let state_1 = gmimc_permute::(state_0, constants.clone()); [state_1[0], state_1[1], state_1[2], state_1[3]] } #[unroll_for_loops] -pub fn gmimc_permute(mut xs: [F; W]) -> [F; W] { - // TODO: Hardcoded width and num rounds for now, since unroll_for_loops doesn't work with - // constants or anything. Maybe use const generics when stable? - +pub fn gmimc_permute( + mut xs: [F; W], + constants: Arc<[F; R]>, +) -> [F; W] { // Value that is implicitly added to each element. // See https://affine.group/2020/02/starkware-challenge let mut addition_buffer = F::ZERO; - for r in 0..108 { - let active = r % 12; - let f = (xs[active] + addition_buffer + F::from_canonical_u64(GMIMC_CONSTANTS[r])).cube(); + for r in 0..R { + let active = r % W; + let f = (xs[active] + addition_buffer + constants[r]).cube(); addition_buffer += f; xs[active] -= f; } - for i in 0..12 { + for i in 0..W { xs[i] += addition_buffer; } @@ -41,14 +37,14 @@ pub fn gmimc_permute(mut xs: [F; W]) -> [F; W] { } #[unroll_for_loops] -pub fn gmimc_permute_naive(mut xs: [F; W]) -> [F; W] { - // TODO: Hardcoded width and num rounds for now, since unroll_for_loops doesn't work with - // constants or anything. Maybe use const generics when stable? - - for r in 0..108 { - let active = r % 12; - let f = (xs[active] + F::from_canonical_u64(GMIMC_CONSTANTS[r])).cube(); - for i in 0..12 { +pub fn gmimc_permute_naive( + mut xs: [F; W], + constants: Arc<[F; R]>, +) -> [F; W] { + for r in 0..R { + let active = r % W; + let f = (xs[active] + constants[r]).cube(); + for i in 0..W { if i != active { xs[i] = xs[i] + f; } @@ -71,8 +67,8 @@ mod tests { for i in 0..12 { xs[i] = F::from_canonical_usize(i); } - let out = gmimc_permute(xs); - let out_naive = gmimc_permute_naive(xs); + let out = gmimc_permute::<_, _, 108>(xs); + let out_naive = gmimc_permute_naive::<_, _, 108>(xs); assert_eq!(out, out_naive); } } diff --git a/src/main.rs b/src/main.rs index 0fe2cf14..de849346 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![feature(const_generics)] + use std::thread; use std::time::Instant; @@ -9,6 +11,7 @@ use field::fft::fft_precompute; use crate::field::field::Field; use crate::util::log2_ceil; +use std::sync::Arc; mod circuit_data; mod constraint_polynomial; @@ -42,27 +45,39 @@ fn main() { // field_search() } +const GMIMC_ROUNDS: usize = 108; +const GMIMC_CONSTANTS: [u64; GMIMC_ROUNDS] = [11875528958976719239, 6107683892976199900, 7756999550758271958, 14819109722912164804, 9716579428412441110, 13627117528901194436, 16260683900833506663, 5942251937084147420, 3340009544523273897, 5103423085715007461, 17051583366444092101, 11122892258227244197, 16564300648907092407, 978667924592675864, 17676416205210517593, 1938246372790494499, 8857737698008340728, 1616088456497468086, 15961521580811621978, 17427220057097673602, 14693961562064090188, 694121596646283736, 554241305747273747, 5783347729647881086, 14933083198980931734, 2600898787591841337, 9178797321043036456, 18068112389665928586, 14493389459750307626, 1650694762687203587, 12538946551586403559, 10144328970401184255, 4215161528137084719, 17559540991336287827, 1632269449854444901, 986434918028205468, 14921385763379308253, 4345141219277982730, 2645897826751167170, 9815223670029373528, 7687983869685434132, 13956100321958014639, 519639453142393369, 15617837024229225911, 1557446238053329052, 8130006133842942201, 864716631341688017, 2860289738131495304, 16723700803638270299, 8363528906277648001, 13196016034228493087, 2514677332206134618, 15626342185220554936, 466271571343554681, 17490024028988898434, 6454235936129380878, 15187752952940298536, 18043495619660620405, 17118101079533798167, 13420382916440963101, 535472393366793763, 1071152303676936161, 6351382326603870931, 12029593435043638097, 9983185196487342247, 414304527840226604, 1578977347398530191, 13594880016528059526, 13219707576179925776, 6596253305527634647, 17708788597914990288, 7005038999589109658, 10171979740390484633, 1791376803510914239, 2405996319967739434, 12383033218117026776, 17648019043455213923, 6600216741450137683, 5359884112225925883, 1501497388400572107, 11860887439428904719, 64080876483307031, 11909038931518362287, 14166132102057826906, 14172584203466994499, 593515702472765471, 3423583343794830614, 10041710997716717966, 13434212189787960052, 9943803922749087030, 3216887087479209126, 17385898166602921353, 617799950397934255, 9245115057096506938, 13290383521064450731, 10193883853810413351, 14648839921475785656, 14635698366607946133, 9134302981480720532, 10045888297267997632, 10752096344939765738, 12049167771599274839, 16471532489936095930, 7118567245891966484, 272840212090177715, 7530334979534674340, 12300300144661791831, 14334496540665732547]; + fn bench_gmimc() { + let mut constants: [F; GMIMC_ROUNDS] = [F::ZERO; GMIMC_ROUNDS]; + for i in 0..GMIMC_ROUNDS { + constants[i] = F::from_canonical_u64(GMIMC_CONSTANTS[i]); + } + let constants = Arc::new(constants); + let threads = 12; // let hashes_per_poly = 623328; // let hashes_per_poly = 1 << log2_ceil(hashes_per_poly); let hashes_per_poly = 1 << (13 + 3); - let threads = (0..threads).map(|_i| thread::spawn(move || { - let mut x = [F::ZERO; 12]; - for i in 0..12 { - x[i] = F::from_canonical_u64((i as u64) * 123456 + 789); - } + let threads = (0..threads).map(|_i| { + let constants = constants.clone(); + thread::spawn(move || { + let mut x = [F::ZERO; 12]; + for i in 0..12 { + x[i] = F::from_canonical_u64((i as u64) * 123456 + 789); + } - let hashes_per_thread = hashes_per_poly * PROVER_POLYS / threads; - let start = Instant::now(); - for _ in 0..hashes_per_thread { - x = gmimc::gmimc_permute(x); - } - let duration = start.elapsed(); - println!("took {:?}", duration); - println!("avg {:?}us", duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64)); - println!("result {:?}", x); - })).collect::>(); + let hashes_per_thread = hashes_per_poly * PROVER_POLYS / threads; + let start = Instant::now(); + for _ in 0..hashes_per_thread { + x = gmimc::gmimc_permute::<_, 12, 108>(x, constants.clone()); + } + let duration = start.elapsed(); + println!("took {:?}", duration); + println!("avg {:?}us", duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64)); + println!("result {:?}", x); + }) + }).collect::>(); for t in threads { t.join().expect("oops"); diff --git a/src/witness.rs b/src/witness.rs index 47910874..56dde3af 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -5,19 +5,19 @@ use crate::target::Target2; use crate::wire::Wire; #[derive(Debug)] -pub struct PartialWitness2 { +pub struct PartialWitness { target_values: HashMap, } -impl PartialWitness2 { +impl PartialWitness { pub fn new() -> Self { - PartialWitness2 { + PartialWitness { target_values: HashMap::new(), } } pub fn singleton(target: Target2, value: F) -> Self { - let mut witness = PartialWitness2::new(); + let mut witness = PartialWitness::new(); witness.set_target(target, value); witness }