From cca79a992c7fded98d899b8e778825091cb19752 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 31 Mar 2021 21:15:24 -0700 Subject: [PATCH] Sponges etc --- src/circuit_builder.rs | 4 + src/hash.rs | 47 +++++--- src/main.rs | 6 +- src/plonk_challenger.rs | 248 ++++++++++++++++++++++++++++++++++++++++ src/proof.rs | 4 +- src/prover.rs | 19 ++- 6 files changed, 302 insertions(+), 26 deletions(-) create mode 100644 src/plonk_challenger.rs diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 49360c65..fce4c641 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -109,6 +109,10 @@ impl CircuitBuilder { Target::Wire(Wire { gate, input: ConstantGate::WIRE_OUTPUT }) } + pub fn permute(&mut self, inputs: [Target; 12]) -> [Target; 12] { + todo!() + } + fn blind_and_pad(&mut self) { // TODO: Blind. diff --git a/src/hash.rs b/src/hash.rs index 3e2ebae2..609f10f7 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -7,15 +7,19 @@ use rayon::prelude::*; use crate::field::field::Field; use crate::gmimc::gmimc_permute_array; use crate::proof::Hash; -use crate::util::reverse_index_bits_in_place; +use crate::util::{log2_ceil, reverse_index_bits_in_place}; -const RATE: usize = 8; -const CAPACITY: usize = 4; -const WIDTH: usize = RATE + CAPACITY; +pub(crate) const SPONGE_RATE: usize = 8; +pub(crate) const SPONGE_CAPACITY: usize = 4; +pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; const GMIMC_ROUNDS: usize = 101; const GMIMC_CONSTANTS: [u64; GMIMC_ROUNDS] = [11875528958976719239, 6107683892976199900, 7756999550758271958, 14819109722912164804, 9716579428412441110, 13627117528901194436, 16260683900833506663, 5942251937084147420, 3340009544523273897, 5103423085715007461, 17051583366444092101, 11122892258227244197, 16564300648907092407, 978667924592675864, 17676416205210517593, 1938246372790494499, 8857737698008340728, 1616088456497468086, 15961521580811621978, 17427220057097673602, 14693961562064090188, 694121596646283736, 554241305747273747, 5783347729647881086, 14933083198980931734, 2600898787591841337, 9178797321043036456, 18068112389665928586, 14493389459750307626, 1650694762687203587, 12538946551586403559, 10144328970401184255, 4215161528137084719, 17559540991336287827, 1632269449854444901, 986434918028205468, 14921385763379308253, 4345141219277982730, 2645897826751167170, 9815223670029373528, 7687983869685434132, 13956100321958014639, 519639453142393369, 15617837024229225911, 1557446238053329052, 8130006133842942201, 864716631341688017, 2860289738131495304, 16723700803638270299, 8363528906277648001, 13196016034228493087, 2514677332206134618, 15626342185220554936, 466271571343554681, 17490024028988898434, 6454235936129380878, 15187752952940298536, 18043495619660620405, 17118101079533798167, 13420382916440963101, 535472393366793763, 1071152303676936161, 6351382326603870931, 12029593435043638097, 9983185196487342247, 414304527840226604, 1578977347398530191, 13594880016528059526, 13219707576179925776, 6596253305527634647, 17708788597914990288, 7005038999589109658, 10171979740390484633, 1791376803510914239, 2405996319967739434, 12383033218117026776, 17648019043455213923, 6600216741450137683, 5359884112225925883, 1501497388400572107, 11860887439428904719, 64080876483307031, 11909038931518362287, 14166132102057826906, 14172584203466994499, 593515702472765471, 3423583343794830614, 10041710997716717966, 13434212189787960052, 9943803922749087030, 3216887087479209126, 17385898166602921353, 617799950397934255, 9245115057096506938, 13290383521064450731, 10193883853810413351, 14648839921475785656, 14635698366607946133, 9134302981480720532, 10045888297267997632, 10752096344939765738]; +/// If we're building a Merkle tree involving more field elements than this, it will be broken up +/// into smaller sub-trees that will be built in parallel. +const ELEMS_PER_CHUNK: usize = 1 << 8; + /// Hash the vector if necessary to reduce its length to ~256 bits. If it already fits, this is a /// no-op. pub fn hash_or_noop(mut inputs: Vec) -> Hash { @@ -34,38 +38,42 @@ pub fn compress(x: Hash, y: Hash) -> Hash { hash_n_to_hash(inputs, false) } +pub fn permute(xs: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] { + gmimc_permute_array(xs, GMIMC_CONSTANTS) +} + /// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required /// for the hash to be secure, but it can safely be disabled in certain cases, like if the input /// length is fixed. pub fn hash_n_to_m(mut inputs: Vec, num_outputs: usize, pad: bool) -> Vec { if pad { inputs.push(F::ZERO); - while (inputs.len() + 1) % WIDTH != 0 { + while (inputs.len() + 1) % SPONGE_WIDTH != 0 { inputs.push(F::ONE); } inputs.push(F::ZERO); } - let mut state = [F::ZERO; WIDTH]; + let mut state = [F::ZERO; SPONGE_WIDTH]; // Absorb all input chunks. - for input_chunk in inputs.chunks(WIDTH - 1) { + for input_chunk in inputs.chunks(SPONGE_WIDTH - 1) { for i in 0..input_chunk.len() { - state[i] = state[i] + input_chunk[i]; + state[i] += input_chunk[i]; } - state = gmimc_permute_array(state, GMIMC_CONSTANTS); + state = permute(state); } // Squeeze until we have the desired number of outputs. let mut outputs = Vec::new(); loop { - for i in 0..(WIDTH - 1) { + for i in 0..(SPONGE_WIDTH - 1) { outputs.push(state[i]); if outputs.len() == num_outputs { return outputs; } } - state = gmimc_permute_array(state, GMIMC_CONSTANTS); + state = permute(state); } } @@ -81,16 +89,25 @@ pub fn hash_n_to_1(inputs: Vec, pad: bool) -> F { /// Like `merkle_root`, but first reorders each vector so that `new[i] = old[i.reverse_bits()]`. pub(crate) fn merkle_root_bit_rev_order(mut vecs: Vec>) -> Hash { reverse_index_bits_in_place(&mut vecs); - merkle_root(vecs) + merkle_root(&vecs) } /// Given `n` vectors, each of length `l`, constructs a Merkle tree with `l` leaves, where each leaf /// is a hash obtained by hashing a "leaf set" consisting of `n` elements. If `n <= 4`, this hashing /// is skipped, as there is no need to compress leaf data. -pub(crate) fn merkle_root(vecs: Vec>) -> Hash { - // TODO: Parallelize. +pub(crate) fn merkle_root(vecs: &[Vec]) -> Hash { + let elems_per_leaf = vecs[0].len(); + let leaves_per_chunk = (ELEMS_PER_CHUNK / elems_per_leaf).next_power_of_two(); + let subtree_roots: Vec> = vecs.par_chunks(leaves_per_chunk) + .map(|chunk| merkle_root_inner(chunk).elements.to_vec()) + .collect(); + merkle_root_inner(&subtree_roots) +} + +pub(crate) fn merkle_root_inner(vecs: &[Vec]) -> Hash { + // TODO: to_vec() not really needed. let mut hashes = vecs.into_iter() - .map(|leaf_set| hash_or_noop(leaf_set)) + .map(|leaf_set| hash_or_noop(leaf_set.to_vec())) .collect::>(); while hashes.len() > 1 { hashes = hashes.chunks(2) diff --git a/src/main.rs b/src/main.rs index 9a4bf678..c8823647 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,6 +27,7 @@ mod generator; mod gmimc; mod hash; mod partition; +mod plonk_challenger; mod plonk_common; mod polynomial; mod proof; @@ -49,17 +50,12 @@ fn main() { // change this to info or warn later. env_logger::Builder::from_env(Env::default().default_filter_or("debug")).init(); - let overall_start = Instant::now(); - bench_prove::(); // bench_fft(); println!(); // bench_gmimc::(); - let overall_duration = overall_start.elapsed(); - println!("Overall time: {:?}", overall_duration); - // field_search() } diff --git a/src/plonk_challenger.rs b/src/plonk_challenger.rs new file mode 100644 index 00000000..1c5ab14d --- /dev/null +++ b/src/plonk_challenger.rs @@ -0,0 +1,248 @@ +use crate::circuit_builder::CircuitBuilder; +use crate::field::field::Field; +use crate::hash::{permute, SPONGE_WIDTH, SPONGE_RATE}; +use crate::target::Target; +use crate::proof::{Hash, HashTarget}; + +/// Observes prover messages, and generates challenges by hashing the transcript. +#[derive(Clone)] +pub struct Challenger { + sponge_state: [F; SPONGE_WIDTH], + input_buffer: Vec, + output_buffer: Vec, +} + +/// Observes prover messages, and generates verifier challenges based on the transcript. +/// +/// The implementation is roughly based on a duplex sponge with a Rescue permutation. Note that in +/// each round, our sponge can absorb an arbitrary number of prover messages and generate an +/// arbitrary number of verifier challenges. This might appear to diverge from the duplex sponge +/// design, but it can be viewed as a duplex sponge whose inputs are sometimes zero (when we perform +/// multiple squeezes) and whose outputs are sometimes ignored (when we perform multiple +/// absorptions). Thus the security properties of a duplex sponge still apply to our design. +impl Challenger { + pub fn new() -> Challenger { + Challenger { + sponge_state: [F::ZERO; SPONGE_WIDTH], + input_buffer: Vec::new(), + output_buffer: Vec::new(), + } + } + + pub fn observe_element(&mut self, element: F) { + // Any buffered outputs are now invalid, since they wouldn't reflect this input. + self.output_buffer.clear(); + + self.input_buffer.push(element); + } + + pub fn observe_elements(&mut self, elements: &[F]) { + for &element in elements { + self.observe_element(element); + } + } + + pub fn observe_hash(&mut self, hash: &Hash) { + self.observe_elements(&hash.elements) + } + + pub fn get_challenge(&mut self) -> F { + self.absorb_buffered_inputs(); + + if self.output_buffer.is_empty() { + // Evaluate the permutation to produce `r` new outputs. + self.sponge_state = permute(self.sponge_state); + self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); + } + + self.output_buffer + .pop() + .expect("Output buffer should be non-empty") + } + + pub fn get_2_challenges(&mut self) -> (F, F) { + (self.get_challenge(), self.get_challenge()) + } + + pub fn get_3_challenges(&mut self) -> (F, F, F) { + ( + self.get_challenge(), + self.get_challenge(), + self.get_challenge(), + ) + } + + pub fn get_n_challenges(&mut self, n: usize) -> Vec { + (0..n).map(|_| self.get_challenge()).collect() + } + + /// Absorb any buffered inputs. After calling this, the input buffer will be empty. + fn absorb_buffered_inputs(&mut self) { + for input_chunk in self.input_buffer.chunks(SPONGE_RATE) { + // Add the inputs to our sponge state. + for (i, &input) in input_chunk.iter().enumerate() { + self.sponge_state[i] = self.sponge_state[i] + input; + } + + // Apply the permutation. + self.sponge_state = permute(self.sponge_state); + } + + self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); + + self.input_buffer.clear(); + } +} + +/// A recursive version of `Challenger`. +pub(crate) struct RecursiveChallenger { + sponge_state: [Target; SPONGE_WIDTH], + input_buffer: Vec, + output_buffer: Vec, +} + +impl RecursiveChallenger { + pub(crate) fn new(builder: &mut CircuitBuilder) -> Self { + let zero = builder.zero(); + RecursiveChallenger { + sponge_state: [zero; SPONGE_WIDTH], + input_buffer: Vec::new(), + output_buffer: Vec::new(), + } + } + + pub(crate) fn observe_element(&mut self, target: Target) { + // Any buffered outputs are now invalid, since they wouldn't reflect this input. + self.output_buffer.clear(); + + self.input_buffer.push(target); + } + + pub(crate) fn observe_elements(&mut self, targets: &[Target]) { + for &target in targets { + self.observe_element(target); + } + } + + pub fn observe_hash(&mut self, hash: &HashTarget) { + self.observe_elements(&hash.elements) + } + + pub(crate) fn get_challenge( + &mut self, + builder: &mut CircuitBuilder, + ) -> Target { + self.absorb_buffered_inputs(builder); + + if self.output_buffer.is_empty() { + // Evaluate the permutation to produce `r` new outputs. + self.sponge_state = builder.permute(self.sponge_state); + self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); + } + + self.output_buffer + .pop() + .expect("Output buffer should be non-empty") + } + + pub(crate) fn get_2_challenges( + &mut self, + builder: &mut CircuitBuilder, + ) -> (Target, Target) { + (self.get_challenge(builder), self.get_challenge(builder)) + } + + pub(crate) fn get_3_challenges( + &mut self, + builder: &mut CircuitBuilder, + ) -> (Target, Target, Target) { + ( + self.get_challenge(builder), + self.get_challenge(builder), + self.get_challenge(builder), + ) + } + + pub(crate) fn get_n_challenges( + &mut self, + builder: &mut CircuitBuilder, + n: usize, + ) -> Vec { + (0..n).map(|_| self.get_challenge(builder)).collect() + } + + /// Absorb any buffered inputs. After calling this, the input buffer will be empty. + fn absorb_buffered_inputs( + &mut self, + builder: &mut CircuitBuilder, + ) { + for input_chunk in self.input_buffer.chunks(SPONGE_RATE) { + // Add the inputs to our sponge state. + for (i, &input) in input_chunk.iter().enumerate() { + self.sponge_state[i] = builder.add(self.sponge_state[i], input); + } + + // Apply the permutation. + self.sponge_state = builder.permute(self.sponge_state); + } + + self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); + + self.input_buffer.clear(); + } +} + +#[cfg(test)] +mod tests { + use crate::{CircuitBuilder, Curve, Field, PartialWitness, Target, Tweedledum}; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::generator::generate_partial_witness; + use crate::plonk_challenger::{Challenger, RecursiveChallenger}; + use crate::target::Target; + + /// Tests for consistency between `Challenger` and `RecursiveChallenger`. + #[test] + fn test_consistency() { + type F = CrandallField; + + // These are mostly arbitrary, but we want to test some rounds with enough inputs/outputs to + // trigger multiple absorptions/squeezes. + let num_inputs_per_round = vec![2, 5, 3]; + let num_outputs_per_round = vec![1, 2, 4]; + + // Generate random input messages. + let inputs_per_round: Vec> = num_inputs_per_round + .iter() + .map(|&n| (0..n).map(|_| F::rand()).collect::>()) + .collect(); + + let mut challenger = Challenger::new(128); + let mut outputs_per_round: Vec> = Vec::new(); + for (r, inputs) in inputs_per_round.iter().enumerate() { + challenger.observe_elements(inputs); + outputs_per_round.push(challenger.get_n_challenges(num_outputs_per_round[r])); + } + + let config = CircuitConfig::default(); + let mut builder = CircuitBuilder::::new(config); + let mut recursive_challenger = RecursiveChallenger::new(&mut builder); + let mut recursive_outputs_per_round: Vec> = + Vec::new(); + for (r, inputs) in inputs_per_round.iter().enumerate() { + recursive_challenger.observe_elements(&builder.constant_wires(inputs)); + recursive_outputs_per_round.push( + recursive_challenger.get_n_challenges(&mut builder, num_outputs_per_round[r]), + ); + } + let circuit = builder.build(); + let mut witness = PartialWitness::new(); + generate_partial_witness(&mut witness, &circuit.prover_only.generators); + let recursive_output_values_per_round: Vec> = recursive_outputs_per_round + .iter() + .map(|outputs| witness.get_targets(outputs)) + .collect(); + + assert_eq!(outputs_per_round, recursive_output_values_per_round); + } +} diff --git a/src/proof.rs b/src/proof.rs index 5cd7c65d..4e581819 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -2,7 +2,7 @@ use crate::field::field::Field; use crate::target::Target; /// Represents a ~256 bit hash output. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct Hash { pub(crate) elements: [F; 4], } @@ -18,7 +18,7 @@ impl Hash { } pub struct HashTarget { - elements: Vec, + pub(crate) elements: Vec, } pub struct Proof { diff --git a/src/prover.rs b/src/prover.rs index bd490319..cce4a122 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -16,6 +16,7 @@ use crate::proof::Proof; use crate::util::transpose_poly_values; use crate::wire::Wire; use crate::witness::PartialWitness; +use crate::plonk_challenger::Challenger; pub(crate) fn prove( prover_data: &ProverOnlyCircuitData, @@ -53,9 +54,14 @@ pub(crate) fn prove( // TODO: Could avoid cloning if it's significant? let start_wires_root = Instant::now(); let wires_root = merkle_root_bit_rev_order(wire_ldes_t.clone()); - info!("{} to Merklizing wire LDEs", + info!("{} to Merklize wire LDEs", start_wires_root.elapsed().as_secs_f32()); + let mut challenger = Challenger::new(); + challenger.observe_hash(&wires_root); + let betas = challenger.get_n_challenges(config.num_checks); + let gammas = challenger.get_n_challenges(config.num_checks); + let start_plonk_z = Instant::now(); let plonk_z_vecs = compute_zs(&common_data); let plonk_z_ldes = PolynomialValues::lde_multiple(plonk_z_vecs, config.rate_bits); @@ -68,9 +74,14 @@ pub(crate) fn prove( info!("{}s to Merklize Z's", start_plonk_z_root.elapsed().as_secs_f32()); - let beta = F::ZERO; // TODO - let gamma = F::ZERO; // TODO - let alpha = F::ZERO; // TODO + challenger.observe_hash(&plonk_z_root); + + let alphas = challenger.get_n_challenges(config.num_checks); + + // TODO + let beta = betas[0]; + let gamma = gammas[0]; + let alpha = alphas[0]; let start_vanishing_poly = Instant::now(); let vanishing_poly = compute_vanishing_poly(