diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 124bcccb..a8ec3754 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -4,7 +4,6 @@ use crate::field::extension_field::Extendable; use crate::gates::exponentiation::ExponentiationGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::util::log2_ceil; impl, const D: usize> CircuitBuilder { /// Computes `-x`. @@ -97,12 +96,8 @@ impl, const D: usize> CircuitBuilder { } /// Exponentiate `base` to the power of `2^power_log`. - // TODO: Test - pub fn exp_power_of_2(&mut self, mut base: Target, power_log: usize) -> Target { - for _ in 0..power_log { - base = self.square(base); - } - base + pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { + self.exp_u64(base, 1 << power_log) } // TODO: Test @@ -110,12 +105,12 @@ impl, const D: usize> CircuitBuilder { pub fn exp_from_bits( &mut self, base: Target, - exponent_bits: impl Iterator>, + exponent_bits: impl IntoIterator>, ) -> Target { let zero = self.zero(); let gate = ExponentiationGate::new(self.config.clone()); let num_power_bits = gate.num_power_bits; - let mut exp_bits_vec: Vec = exponent_bits.map(|b| *b.borrow()).collect(); + let mut exp_bits_vec: Vec = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); while exp_bits_vec.len() < num_power_bits { exp_bits_vec.push(zero); } @@ -139,10 +134,16 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of a known `exponent`. // TODO: Test - pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { - let exp_target = self.constant(F::from_canonical_u64(exponent)); - let num_bits = log2_ceil(exponent as usize + 1); - self.exp(base, exp_target, num_bits) + pub fn exp_u64(&mut self, base: Target, mut exponent: u64) -> Target { + let mut exp_bits = Vec::new(); + while exponent != 0 { + let bit = exponent & 1; + let bit_target = self.constant(F::from_canonical_u64(bit)); + exp_bits.push(bit_target); + exponent >>= 1; + } + + self.exp_from_bits(base, exp_bits) } /// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`. diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 4c4160e1..0eb42e27 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -4,6 +4,7 @@ pub mod hash; pub mod insert; pub mod interpolation; pub mod polynomial; +pub mod random_access; pub mod range_check; pub mod select; pub mod split_base; diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs new file mode 100644 index 00000000..a435b99f --- /dev/null +++ b/src/gadgets/random_access.rs @@ -0,0 +1,76 @@ +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::gates::random_access::RandomAccessGate; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +impl, const D: usize> CircuitBuilder { + /// Checks that a `Target` matches a vector at a non-deterministic index. + /// Note: `index` is not range-checked. + pub fn random_access( + &mut self, + access_index: Target, + claimed_element: ExtensionTarget, + v: Vec>, + ) { + let gate = RandomAccessGate::new(v.len()); + let gate_index = self.add_gate(gate.clone(), vec![]); + + v.iter().enumerate().for_each(|(i, &val)| { + self.route_extension( + val, + ExtensionTarget::from_range(gate_index, gate.wires_list_item(i)), + ); + }); + self.route( + access_index, + Target::wire(gate_index, gate.wires_access_index()), + ); + self.route_extension( + claimed_element, + ExtensionTarget::from_range(gate_index, gate.wires_claimed_element()), + ); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use super::*; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_random_access_given_len(len_log: usize) -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + let len = 1 << len_log; + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::::new(config); + let vec = FF::rand_vec(len); + let v: Vec<_> = vec.iter().map(|x| builder.constant_extension(*x)).collect(); + + for i in 0..len { + let it = builder.constant(F::from_canonical_usize(i)); + let elem = builder.constant_extension(vec[i]); + builder.random_access(it, elem, v.clone()); + } + + let data = builder.build(); + let proof = data.prove(PartialWitness::new())?; + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_random_access() -> Result<()> { + for len_log in 1..3 { + test_random_access_given_len(len_log)?; + } + Ok(()) + } +} diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 4a9425af..97847d41 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -334,6 +334,7 @@ mod tests { use crate::iop::generator::generate_partial_witness; use crate::iop::wire::Wire; use crate::iop::witness::PartialWitness; + use crate::util::timing::TimingTree; #[test] fn generated_output() { @@ -364,7 +365,7 @@ mod tests { } let generators = gate.generators(0, &[]); - generate_partial_witness(&mut witness, &generators); + generate_partial_witness(&mut witness, &generators, &mut TimingTree::default()); let expected_outputs: [F; W] = gmimc_permute_naive(permutation_inputs.try_into().unwrap(), constants); diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index 413a6ff3..ff35eaa2 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; @@ -217,13 +216,6 @@ impl, const D: usize> SimpleGenerator for RandomAccessGenera let get_local_wire = |input| witness.get_wire(local_wire(input)); - let get_local_ext = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let values = wire_range.map(get_local_wire).collect::>(); - let arr = values.try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - // Compute the new vector and the values for equality_dummy and index_matches let vec_size = self.gate.vec_size; let access_index_f = get_local_wire(self.gate.wires_access_index()); diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index d15eab30..53867de7 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -350,6 +350,7 @@ mod tests { use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + use crate::util::timing::TimingTree; #[test] fn no_duplicate_challenges() { @@ -409,7 +410,11 @@ mod tests { } let circuit = builder.build(); let mut witness = PartialWitness::new(); - generate_partial_witness(&mut witness, &circuit.prover_only.generators); + generate_partial_witness( + &mut witness, + &circuit.prover_only.generators, + &mut TimingTree::default(), + ); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() .map(|outputs| witness.get_targets(outputs)) diff --git a/src/iop/generator.rs b/src/iop/generator.rs index 9f421b14..ad492810 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::convert::identity; use std::fmt::Debug; @@ -9,27 +9,32 @@ use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::PartialWitness; +use crate::timed; +use crate::util::timing::TimingTree; /// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the /// given set of generators. pub(crate) fn generate_partial_witness( witness: &mut PartialWitness, generators: &[Box>], + timing: &mut TimingTree, ) { // Index generator indices by their watched targets. let mut generator_indices_by_watches = HashMap::new(); - for (i, generator) in generators.iter().enumerate() { - for watch in generator.watch_list() { - generator_indices_by_watches - .entry(watch) - .or_insert_with(Vec::new) - .push(i); + timed!(timing, "index generators by their watched targets", { + for (i, generator) in generators.iter().enumerate() { + for watch in generator.watch_list() { + generator_indices_by_watches + .entry(watch) + .or_insert_with(Vec::new) + .push(i); + } } - } + }); // Build a list of "pending" generators which are queued to be run. Initially, all generators // are queued. - let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect(); + let mut pending_generator_indices: Vec<_> = (0..generators.len()).collect(); // We also track a list of "expired" generators which have already returned false. let mut generator_is_expired = vec![false; generators.len()]; @@ -38,9 +43,13 @@ pub(crate) fn generate_partial_witness( // Keep running generators until no generators are queued. while !pending_generator_indices.is_empty() { - let mut next_pending_generator_indices = HashSet::new(); + let mut next_pending_generator_indices = Vec::new(); for &generator_idx in &pending_generator_indices { + if generator_is_expired[generator_idx] { + continue; + } + let finished = generators[generator_idx].run(&witness, &mut buffer); if finished { generator_is_expired[generator_idx] = true; @@ -50,9 +59,7 @@ pub(crate) fn generate_partial_witness( for (watch, _) in &buffer.target_values { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { for &watching_generator_idx in watching_generator_indices { - if !generator_is_expired[watching_generator_idx] { - next_pending_generator_indices.insert(watching_generator_idx); - } + next_pending_generator_indices.push(watching_generator_idx); } } } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 339f8cfb..5b97f546 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -57,6 +57,18 @@ impl PartialWitness { ) } + pub fn get_extension_targets( + &self, + ets: &[ExtensionTarget], + ) -> Vec + where + F: Extendable, + { + ets.iter() + .map(|&et| self.get_extension_target(et)) + .collect() + } + pub fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), @@ -150,9 +162,7 @@ impl PartialWitness { } pub fn extend>(&mut self, pairs: I) { - for (target, value) in pairs { - self.set_target(target, value); - } + self.target_values.extend(pairs); } pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness { diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index a8520dde..06670a08 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -37,7 +37,7 @@ pub(crate) fn prove, const D: usize>( timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(&mut partial_witness, &prover_data.generators) + generate_partial_witness(&mut partial_witness, &prover_data.generators, &mut timing) ); let public_inputs = partial_witness.get_targets(&prover_data.public_inputs); @@ -249,17 +249,27 @@ fn wires_permutation_partial_products, const D: usize>( .enumerate() .map(|(i, &x)| { let s_sigmas = &prover_data.sigmas[i]; - let quotient_values = (0..common_data.config.num_routed_wires) + let numerators = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = witness.get_wire(i, j); let k_i = k_is[j]; let s_id = k_i * x; - let s_sigma = s_sigmas[j]; - let numerator = wire_value + beta * s_id + gamma; - let denominator = wire_value + beta * s_sigma + gamma; - numerator / denominator + wire_value + beta * s_id + gamma }) .collect::>(); + let denominators = (0..common_data.config.num_routed_wires) + .map(|j| { + let wire_value = witness.get_wire(i, j); + let s_sigma = s_sigmas[j]; + wire_value + beta * s_sigma + gamma + }) + .collect::>(); + let denominator_invs = F::batch_multiplicative_inverse(&denominators); + let quotient_values = numerators + .into_iter() + .zip(denominator_invs) + .map(|(num, den_inv)| num * den_inv) + .collect::>(); let quotient_partials = partial_products("ient_values, degree);