Const generics in GMiMC

This commit is contained in:
Daniel Lubarov 2021-02-24 12:25:13 -08:00
parent 1480876c9a
commit 3ba9ef8ab7
9 changed files with 132 additions and 79 deletions

View File

@ -152,14 +152,6 @@ impl DivAssign for CrandallField {
}
}
#[inline(always)]
fn reduce64(x: u64) -> u64 {
// TODO: slow? try cond sub
// x % P
let over = x > P;
x - (over as u64) * P
}
/// no final reduction
#[inline(always)]
fn reduce128(x: u128) -> CrandallField {

View File

@ -15,7 +15,9 @@ pub trait Field: 'static
+ MulAssign<Self>
+ Div<Self, Output=Self>
+ DivAssign<Self>
+ Debug {
+ Debug
+ Send
+ Sync {
const ZERO: Self;
const ONE: Self;
const NEG_ONE: Self;

View File

@ -3,11 +3,11 @@ use std::marker::PhantomData;
use crate::circuit_data::CircuitConfig;
use crate::constraint_polynomial::{ConstraintPolynomial, EvaluationVars};
use crate::field::field::Field;
use crate::gates::gate::Gate2;
use crate::gates::gate::Gate;
use crate::generator::{SimpleGenerator, WitnessGenerator2};
use crate::target::Target2;
use crate::wire::Wire;
use crate::witness::PartialWitness2;
use crate::witness::PartialWitness;
/// A deterministic gate. Each entry in `outputs()` describes how that output is evaluated; this is
/// used to create both the constraint set and the generator set.
@ -52,7 +52,7 @@ impl<F: Field, DG: DeterministicGate<F>> DeterministicGateAdapter<F, DG> {
}
}
impl<F: Field, DG: DeterministicGate<F>> Gate2<F> for DeterministicGateAdapter<F, DG> {
impl<F: Field, DG: DeterministicGate<F>> Gate<F> for DeterministicGateAdapter<F, DG> {
fn id(&self) -> String {
self.gate.id()
}
@ -110,7 +110,7 @@ impl<F: Field> SimpleGenerator<F> for OutputGenerator<F> {
.collect()
}
fn run_once(&mut self, witness: &PartialWitness2<F>) -> PartialWitness2<F> {
fn run_once(&mut self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let mut local_wire_values = Vec::new();
let mut next_wire_values = Vec::new();
@ -141,6 +141,6 @@ impl<F: Field> SimpleGenerator<F> for OutputGenerator<F> {
let result_wire = Wire { gate: self.gate_index, input: self.input_index };
let result_value = self.out.evaluate(vars);
PartialWitness2::singleton(Target2::Wire(result_wire), result_value)
PartialWitness::singleton(Target2::Wire(result_wire), result_value)
}
}

View File

@ -8,7 +8,7 @@ use crate::generator::WitnessGenerator2;
/// A custom gate.
// TODO: Remove CircuitConfig params? Could just use fields within each struct.
pub trait Gate2<F: Field>: 'static {
pub trait Gate<F: Field>: 'static {
fn id(&self) -> String;
/// A set of expressions which must evaluate to zero.
@ -52,10 +52,10 @@ pub trait Gate2<F: Field>: 'static {
/// A wrapper around an `Rc<Gate>` which implements `PartialEq`, `Eq` and `Hash` based on gate IDs.
#[derive(Clone)]
pub struct GateRef<F: Field>(pub(crate) Rc<dyn Gate2<F>>);
pub struct GateRef<F: Field>(pub(crate) Rc<dyn Gate<F>>);
impl<F: Field> GateRef<F> {
pub fn new<G: Gate2<F>>(gate: G) -> GateRef<F> {
pub fn new<G: Gate<F>>(gate: G) -> GateRef<F> {
GateRef(Rc::new(gate))
}
}

View File

@ -1,29 +1,77 @@
use std::convert::TryInto;
use crate::circuit_data::CircuitConfig;
use crate::constraint_polynomial::ConstraintPolynomial;
use crate::field::field::Field;
use crate::gates::deterministic_gate::DeterministicGate;
use crate::gates::gate::{Gate2, GateRef};
use crate::gates::gate::{Gate, GateRef};
use crate::generator::{SimpleGenerator, WitnessGenerator2};
use crate::gmimc::gmimc_permute;
use crate::target::Target2;
use crate::wire::Wire;
use crate::witness::PartialWitness;
use std::sync::Arc;
/// Evaluates a full GMiMC permutation.
/// Evaluates a full GMiMC permutation, and writes the output to the next gate's first `width`
/// wires (which could be the input of another `GMiMCGate`).
#[derive(Debug)]
pub struct GMiMCGate<F: Field> {
num_rounds: usize,
width: usize,
round_constants: Vec<F>,
pub struct GMiMCGate<F: Field, const W: usize, const R: usize> {
round_constants: Arc<[F; R]>,
}
impl<F: Field> GMiMCGate<F> {
fn new(width: usize) -> GateRef<F> {
impl<F: Field, const W: usize, const R: usize> GMiMCGate<F, W, R> {
fn new() -> GateRef<F> {
todo!()
}
}
impl<F: Field> DeterministicGate<F> for GMiMCGate<F> {
impl<F: Field, const W: usize, const R: usize> Gate<F> for GMiMCGate<F, W, R> {
fn id(&self) -> String {
// TODO: Add W/R
format!("{:?}", self)
}
fn outputs(&self, config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial<F>)> {
fn constraints(&self, config: CircuitConfig) -> Vec<ConstraintPolynomial<F>> {
unimplemented!()
}
fn generators(&self, config: CircuitConfig, gate_index: usize, local_constants: Vec<F>, next_constants: Vec<F>) -> Vec<Box<dyn WitnessGenerator2<F>>> {
let generator = GMiMCGenerator::<F, W, R> {
round_constants: self.round_constants.clone(),
gate_index,
};
vec![Box::new(generator)]
}
}
struct GMiMCGenerator<F: Field, const W: usize, const R: usize> {
round_constants: Arc<[F; R]>,
gate_index: usize,
}
impl<F: Field, const W: usize, const R: usize> SimpleGenerator<F> for GMiMCGenerator<F, W, R> {
fn dependencies(&self) -> Vec<Target2> {
(0..W)
.map(|i| Target2::Wire(
Wire { gate: self.gate_index, input: i }))
.collect()
}
fn run_once(&mut self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let mut inputs: [F; W] = [F::ZERO; W];
for i in 0..W {
inputs[i] = witness.get_wire(
Wire { gate: self.gate_index, input: i });
}
let outputs = gmimc_permute::<F, W, R>(inputs, self.round_constants.clone());
let mut result = PartialWitness::new();
for i in 0..W {
result.set_wire(
Wire { gate: self.gate_index + 1, input: i },
outputs[i]);
}
result
}
}

View File

@ -1,6 +1,6 @@
use crate::field::field::Field;
use crate::target::Target2;
use crate::witness::PartialWitness2;
use crate::witness::PartialWitness;
/// A generator participates in the generation of the witness.
pub trait WitnessGenerator2<F: Field>: 'static {
@ -12,14 +12,14 @@ pub trait WitnessGenerator2<F: Field>: 'static {
/// flag indicating whether the generator is finished. If the flag is true, the generator will
/// never be run again, otherwise it will be queued for another run next time a target in its
/// watch list is populated.
fn run(&mut self, witness: &PartialWitness2<F>) -> (PartialWitness2<F>, bool);
fn run(&mut self, witness: &PartialWitness<F>) -> (PartialWitness<F>, bool);
}
/// A generator which runs once after a list of dependencies is present in the witness.
pub trait SimpleGenerator<F: Field>: 'static {
fn dependencies(&self) -> Vec<Target2>;
fn run_once(&mut self, witness: &PartialWitness2<F>) -> PartialWitness2<F>;
fn run_once(&mut self, witness: &PartialWitness<F>) -> PartialWitness<F>;
}
impl<F: Field, SG: SimpleGenerator<F>> WitnessGenerator2<F> for SG {
@ -27,11 +27,11 @@ impl<F: Field, SG: SimpleGenerator<F>> WitnessGenerator2<F> for SG {
self.dependencies()
}
fn run(&mut self, witness: &PartialWitness2<F>) -> (PartialWitness2<F>, bool) {
fn run(&mut self, witness: &PartialWitness<F>) -> (PartialWitness<F>, bool) {
if witness.contains_all(&self.dependencies()) {
(self.run_once(witness), true)
} else {
(PartialWitness2::new(), false)
(PartialWitness::new(), false)
}
}
}
@ -47,8 +47,8 @@ impl<F: Field> SimpleGenerator<F> for CopyGenerator {
vec![self.src]
}
fn run_once(&mut self, witness: &PartialWitness2<F>) -> PartialWitness2<F> {
fn run_once(&mut self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let value = witness.get_target(self.src);
PartialWitness2::singleton(self.dst, value)
PartialWitness::singleton(self.dst, value)
}
}

View File

@ -1,39 +1,35 @@
use std::sync::Arc;
use unroll::unroll_for_loops;
use crate::field::field::Field;
const GMIMC_CONSTANTS: [u64; GMIMC_ROUNDS] = [11875528958976719239, 6107683892976199900, 7756999550758271958, 14819109722912164804, 9716579428412441110, 13627117528901194436, 16260683900833506663, 5942251937084147420, 3340009544523273897, 5103423085715007461, 17051583366444092101, 11122892258227244197, 16564300648907092407, 978667924592675864, 17676416205210517593, 1938246372790494499, 8857737698008340728, 1616088456497468086, 15961521580811621978, 17427220057097673602, 14693961562064090188, 694121596646283736, 554241305747273747, 5783347729647881086, 14933083198980931734, 2600898787591841337, 9178797321043036456, 18068112389665928586, 14493389459750307626, 1650694762687203587, 12538946551586403559, 10144328970401184255, 4215161528137084719, 17559540991336287827, 1632269449854444901, 986434918028205468, 14921385763379308253, 4345141219277982730, 2645897826751167170, 9815223670029373528, 7687983869685434132, 13956100321958014639, 519639453142393369, 15617837024229225911, 1557446238053329052, 8130006133842942201, 864716631341688017, 2860289738131495304, 16723700803638270299, 8363528906277648001, 13196016034228493087, 2514677332206134618, 15626342185220554936, 466271571343554681, 17490024028988898434, 6454235936129380878, 15187752952940298536, 18043495619660620405, 17118101079533798167, 13420382916440963101, 535472393366793763, 1071152303676936161, 6351382326603870931, 12029593435043638097, 9983185196487342247, 414304527840226604, 1578977347398530191, 13594880016528059526, 13219707576179925776, 6596253305527634647, 17708788597914990288, 7005038999589109658, 10171979740390484633, 1791376803510914239, 2405996319967739434, 12383033218117026776, 17648019043455213923, 6600216741450137683, 5359884112225925883, 1501497388400572107, 11860887439428904719, 64080876483307031, 11909038931518362287, 14166132102057826906, 14172584203466994499, 593515702472765471, 3423583343794830614, 10041710997716717966, 13434212189787960052, 9943803922749087030, 3216887087479209126, 17385898166602921353, 617799950397934255, 9245115057096506938, 13290383521064450731, 10193883853810413351, 14648839921475785656, 14635698366607946133, 9134302981480720532, 10045888297267997632, 10752096344939765738, 12049167771599274839, 16471532489936095930, 7118567245891966484, 272840212090177715, 7530334979534674340, 12300300144661791831, 14334496540665732547];
const GMIMC_ROUNDS: usize = 108;
const W: usize = 12;
pub fn gmimc_compress<F: Field>(a: [F; 4], b: [F; 4]) -> [F; 4] {
pub fn gmimc_compress<F: Field, const R: usize>(a: [F; 4], b: [F; 4], constants: Arc<[F; R]>) -> [F; 4] {
// Sponge with r=8, c=4.
let state_0 = [a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3],
F::ZERO, F::ZERO, F::ZERO, F::ZERO];
let state_1 = gmimc_permute(state_0);
let state_1 = gmimc_permute::<F, 12, R>(state_0, constants.clone());
[state_1[0], state_1[1], state_1[2], state_1[3]]
}
#[unroll_for_loops]
pub fn gmimc_permute<F: Field>(mut xs: [F; W]) -> [F; W] {
// TODO: Hardcoded width and num rounds for now, since unroll_for_loops doesn't work with
// constants or anything. Maybe use const generics when stable?
pub fn gmimc_permute<F: Field, const W: usize, const R: usize>(
mut xs: [F; W],
constants: Arc<[F; R]>,
) -> [F; W] {
// Value that is implicitly added to each element.
// See https://affine.group/2020/02/starkware-challenge
let mut addition_buffer = F::ZERO;
for r in 0..108 {
let active = r % 12;
let f = (xs[active] + addition_buffer + F::from_canonical_u64(GMIMC_CONSTANTS[r])).cube();
for r in 0..R {
let active = r % W;
let f = (xs[active] + addition_buffer + constants[r]).cube();
addition_buffer += f;
xs[active] -= f;
}
for i in 0..12 {
for i in 0..W {
xs[i] += addition_buffer;
}
@ -41,14 +37,14 @@ pub fn gmimc_permute<F: Field>(mut xs: [F; W]) -> [F; W] {
}
#[unroll_for_loops]
pub fn gmimc_permute_naive<F: Field>(mut xs: [F; W]) -> [F; W] {
// TODO: Hardcoded width and num rounds for now, since unroll_for_loops doesn't work with
// constants or anything. Maybe use const generics when stable?
for r in 0..108 {
let active = r % 12;
let f = (xs[active] + F::from_canonical_u64(GMIMC_CONSTANTS[r])).cube();
for i in 0..12 {
pub fn gmimc_permute_naive<F: Field, const W: usize, const R: usize>(
mut xs: [F; W],
constants: Arc<[F; R]>,
) -> [F; W] {
for r in 0..R {
let active = r % W;
let f = (xs[active] + constants[r]).cube();
for i in 0..W {
if i != active {
xs[i] = xs[i] + f;
}
@ -71,8 +67,8 @@ mod tests {
for i in 0..12 {
xs[i] = F::from_canonical_usize(i);
}
let out = gmimc_permute(xs);
let out_naive = gmimc_permute_naive(xs);
let out = gmimc_permute::<_, _, 108>(xs);
let out_naive = gmimc_permute_naive::<_, _, 108>(xs);
assert_eq!(out, out_naive);
}
}

View File

@ -1,3 +1,5 @@
#![feature(const_generics)]
use std::thread;
use std::time::Instant;
@ -9,6 +11,7 @@ use field::fft::fft_precompute;
use crate::field::field::Field;
use crate::util::log2_ceil;
use std::sync::Arc;
mod circuit_data;
mod constraint_polynomial;
@ -42,27 +45,39 @@ fn main() {
// field_search()
}
const GMIMC_ROUNDS: usize = 108;
const GMIMC_CONSTANTS: [u64; GMIMC_ROUNDS] = [11875528958976719239, 6107683892976199900, 7756999550758271958, 14819109722912164804, 9716579428412441110, 13627117528901194436, 16260683900833506663, 5942251937084147420, 3340009544523273897, 5103423085715007461, 17051583366444092101, 11122892258227244197, 16564300648907092407, 978667924592675864, 17676416205210517593, 1938246372790494499, 8857737698008340728, 1616088456497468086, 15961521580811621978, 17427220057097673602, 14693961562064090188, 694121596646283736, 554241305747273747, 5783347729647881086, 14933083198980931734, 2600898787591841337, 9178797321043036456, 18068112389665928586, 14493389459750307626, 1650694762687203587, 12538946551586403559, 10144328970401184255, 4215161528137084719, 17559540991336287827, 1632269449854444901, 986434918028205468, 14921385763379308253, 4345141219277982730, 2645897826751167170, 9815223670029373528, 7687983869685434132, 13956100321958014639, 519639453142393369, 15617837024229225911, 1557446238053329052, 8130006133842942201, 864716631341688017, 2860289738131495304, 16723700803638270299, 8363528906277648001, 13196016034228493087, 2514677332206134618, 15626342185220554936, 466271571343554681, 17490024028988898434, 6454235936129380878, 15187752952940298536, 18043495619660620405, 17118101079533798167, 13420382916440963101, 535472393366793763, 1071152303676936161, 6351382326603870931, 12029593435043638097, 9983185196487342247, 414304527840226604, 1578977347398530191, 13594880016528059526, 13219707576179925776, 6596253305527634647, 17708788597914990288, 7005038999589109658, 10171979740390484633, 1791376803510914239, 2405996319967739434, 12383033218117026776, 17648019043455213923, 6600216741450137683, 5359884112225925883, 1501497388400572107, 11860887439428904719, 64080876483307031, 11909038931518362287, 14166132102057826906, 14172584203466994499, 593515702472765471, 3423583343794830614, 10041710997716717966, 13434212189787960052, 9943803922749087030, 3216887087479209126, 17385898166602921353, 617799950397934255, 9245115057096506938, 13290383521064450731, 10193883853810413351, 14648839921475785656, 14635698366607946133, 9134302981480720532, 10045888297267997632, 10752096344939765738, 12049167771599274839, 16471532489936095930, 7118567245891966484, 272840212090177715, 7530334979534674340, 12300300144661791831, 14334496540665732547];
fn bench_gmimc<F: Field>() {
let mut constants: [F; GMIMC_ROUNDS] = [F::ZERO; GMIMC_ROUNDS];
for i in 0..GMIMC_ROUNDS {
constants[i] = F::from_canonical_u64(GMIMC_CONSTANTS[i]);
}
let constants = Arc::new(constants);
let threads = 12;
// let hashes_per_poly = 623328;
// let hashes_per_poly = 1 << log2_ceil(hashes_per_poly);
let hashes_per_poly = 1 << (13 + 3);
let threads = (0..threads).map(|_i| thread::spawn(move || {
let mut x = [F::ZERO; 12];
for i in 0..12 {
x[i] = F::from_canonical_u64((i as u64) * 123456 + 789);
}
let threads = (0..threads).map(|_i| {
let constants = constants.clone();
thread::spawn(move || {
let mut x = [F::ZERO; 12];
for i in 0..12 {
x[i] = F::from_canonical_u64((i as u64) * 123456 + 789);
}
let hashes_per_thread = hashes_per_poly * PROVER_POLYS / threads;
let start = Instant::now();
for _ in 0..hashes_per_thread {
x = gmimc::gmimc_permute(x);
}
let duration = start.elapsed();
println!("took {:?}", duration);
println!("avg {:?}us", duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64));
println!("result {:?}", x);
})).collect::<Vec<_>>();
let hashes_per_thread = hashes_per_poly * PROVER_POLYS / threads;
let start = Instant::now();
for _ in 0..hashes_per_thread {
x = gmimc::gmimc_permute::<_, 12, 108>(x, constants.clone());
}
let duration = start.elapsed();
println!("took {:?}", duration);
println!("avg {:?}us", duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64));
println!("result {:?}", x);
})
}).collect::<Vec<_>>();
for t in threads {
t.join().expect("oops");

View File

@ -5,19 +5,19 @@ use crate::target::Target2;
use crate::wire::Wire;
#[derive(Debug)]
pub struct PartialWitness2<F: Field> {
pub struct PartialWitness<F: Field> {
target_values: HashMap<Target2, F>,
}
impl<F: Field> PartialWitness2<F> {
impl<F: Field> PartialWitness<F> {
pub fn new() -> Self {
PartialWitness2 {
PartialWitness {
target_values: HashMap::new(),
}
}
pub fn singleton(target: Target2, value: F) -> Self {
let mut witness = PartialWitness2::new();
let mut witness = PartialWitness::new();
witness.set_target(target, value);
witness
}