Merge branch 'main' into remove_reverse_limbs

This commit is contained in:
wborgeaud 2021-08-05 16:04:16 +02:00
commit df07909f1e
9 changed files with 148 additions and 45 deletions

View File

@ -4,7 +4,6 @@ use crate::field::extension_field::Extendable;
use crate::gates::exponentiation::ExponentiationGate; use crate::gates::exponentiation::ExponentiationGate;
use crate::iop::target::Target; use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::log2_ceil;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> { impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `-x`. /// Computes `-x`.
@ -97,12 +96,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
/// Exponentiate `base` to the power of `2^power_log`. /// Exponentiate `base` to the power of `2^power_log`.
// TODO: Test pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target {
pub fn exp_power_of_2(&mut self, mut base: Target, power_log: usize) -> Target { self.exp_u64(base, 1 << power_log)
for _ in 0..power_log {
base = self.square(base);
}
base
} }
// TODO: Test // TODO: Test
@ -110,12 +105,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn exp_from_bits( pub fn exp_from_bits(
&mut self, &mut self,
base: Target, base: Target,
exponent_bits: impl Iterator<Item = impl Borrow<Target>>, exponent_bits: impl IntoIterator<Item = impl Borrow<Target>>,
) -> Target { ) -> Target {
let zero = self.zero(); let zero = self.zero();
let gate = ExponentiationGate::new(self.config.clone()); let gate = ExponentiationGate::new(self.config.clone());
let num_power_bits = gate.num_power_bits; let num_power_bits = gate.num_power_bits;
let mut exp_bits_vec: Vec<Target> = exponent_bits.map(|b| *b.borrow()).collect(); let mut exp_bits_vec: Vec<Target> = exponent_bits.into_iter().map(|b| *b.borrow()).collect();
while exp_bits_vec.len() < num_power_bits { while exp_bits_vec.len() < num_power_bits {
exp_bits_vec.push(zero); exp_bits_vec.push(zero);
} }
@ -139,10 +134,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Exponentiate `base` to the power of a known `exponent`. /// Exponentiate `base` to the power of a known `exponent`.
// TODO: Test // TODO: Test
pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target { pub fn exp_u64(&mut self, base: Target, mut exponent: u64) -> Target {
let exp_target = self.constant(F::from_canonical_u64(exponent)); let mut exp_bits = Vec::new();
let num_bits = log2_ceil(exponent as usize + 1); while exponent != 0 {
self.exp(base, exp_target, num_bits) let bit = exponent & 1;
let bit_target = self.constant(F::from_canonical_u64(bit));
exp_bits.push(bit_target);
exponent >>= 1;
}
self.exp_from_bits(base, exp_bits)
} }
/// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`. /// Computes `x / y`. Results in an unsatisfiable instance if `y = 0`.

View File

@ -4,6 +4,7 @@ pub mod hash;
pub mod insert; pub mod insert;
pub mod interpolation; pub mod interpolation;
pub mod polynomial; pub mod polynomial;
pub mod random_access;
pub mod range_check; pub mod range_check;
pub mod select; pub mod select;
pub mod split_base; pub mod split_base;

View File

@ -0,0 +1,76 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::gates::random_access::RandomAccessGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Checks that a `Target` matches a vector at a non-deterministic index.
/// Note: `index` is not range-checked.
pub fn random_access(
&mut self,
access_index: Target,
claimed_element: ExtensionTarget<D>,
v: Vec<ExtensionTarget<D>>,
) {
let gate = RandomAccessGate::new(v.len());
let gate_index = self.add_gate(gate.clone(), vec![]);
v.iter().enumerate().for_each(|(i, &val)| {
self.route_extension(
val,
ExtensionTarget::from_range(gate_index, gate.wires_list_item(i)),
);
});
self.route(
access_index,
Target::wire(gate_index, gate.wires_access_index()),
);
self.route_extension(
claimed_element,
ExtensionTarget::from_range(gate_index, gate.wires_claimed_element()),
);
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use super::*;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::quartic::QuarticCrandallField;
use crate::field::field_types::Field;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
fn test_random_access_given_len(len_log: usize) -> Result<()> {
type F = CrandallField;
type FF = QuarticCrandallField;
let len = 1 << len_log;
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let vec = FF::rand_vec(len);
let v: Vec<_> = vec.iter().map(|x| builder.constant_extension(*x)).collect();
for i in 0..len {
let it = builder.constant(F::from_canonical_usize(i));
let elem = builder.constant_extension(vec[i]);
builder.random_access(it, elem, v.clone());
}
let data = builder.build();
let proof = data.prove(PartialWitness::new())?;
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_random_access() -> Result<()> {
for len_log in 1..3 {
test_random_access_given_len(len_log)?;
}
Ok(())
}
}

View File

@ -334,6 +334,7 @@ mod tests {
use crate::iop::generator::generate_partial_witness; use crate::iop::generator::generate_partial_witness;
use crate::iop::wire::Wire; use crate::iop::wire::Wire;
use crate::iop::witness::PartialWitness; use crate::iop::witness::PartialWitness;
use crate::util::timing::TimingTree;
#[test] #[test]
fn generated_output() { fn generated_output() {
@ -364,7 +365,7 @@ mod tests {
} }
let generators = gate.generators(0, &[]); let generators = gate.generators(0, &[]);
generate_partial_witness(&mut witness, &generators); generate_partial_witness(&mut witness, &generators, &mut TimingTree::default());
let expected_outputs: [F; W] = let expected_outputs: [F; W] =
gmimc_permute_naive(permutation_inputs.try_into().unwrap(), constants); gmimc_permute_naive(permutation_inputs.try_into().unwrap(), constants);

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::Range; use std::ops::Range;
@ -217,13 +216,6 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
let get_local_wire = |input| witness.get_wire(local_wire(input)); let get_local_wire = |input| witness.get_wire(local_wire(input));
let get_local_ext = |wire_range: Range<usize>| {
debug_assert_eq!(wire_range.len(), D);
let values = wire_range.map(get_local_wire).collect::<Vec<_>>();
let arr = values.try_into().unwrap();
F::Extension::from_basefield_array(arr)
};
// Compute the new vector and the values for equality_dummy and index_matches // Compute the new vector and the values for equality_dummy and index_matches
let vec_size = self.gate.vec_size; let vec_size = self.gate.vec_size;
let access_index_f = get_local_wire(self.gate.wires_access_index()); let access_index_f = get_local_wire(self.gate.wires_access_index());

View File

@ -350,6 +350,7 @@ mod tests {
use crate::iop::witness::PartialWitness; use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::circuit_data::CircuitConfig;
use crate::util::timing::TimingTree;
#[test] #[test]
fn no_duplicate_challenges() { fn no_duplicate_challenges() {
@ -409,7 +410,11 @@ mod tests {
} }
let circuit = builder.build(); let circuit = builder.build();
let mut witness = PartialWitness::new(); let mut witness = PartialWitness::new();
generate_partial_witness(&mut witness, &circuit.prover_only.generators); generate_partial_witness(
&mut witness,
&circuit.prover_only.generators,
&mut TimingTree::default(),
);
let recursive_output_values_per_round: Vec<Vec<F>> = recursive_outputs_per_round let recursive_output_values_per_round: Vec<Vec<F>> = recursive_outputs_per_round
.iter() .iter()
.map(|outputs| witness.get_targets(outputs)) .map(|outputs| witness.get_targets(outputs))

View File

@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::convert::identity; use std::convert::identity;
use std::fmt::Debug; use std::fmt::Debug;
@ -9,27 +9,32 @@ use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::iop::target::Target; use crate::iop::target::Target;
use crate::iop::wire::Wire; use crate::iop::wire::Wire;
use crate::iop::witness::PartialWitness; use crate::iop::witness::PartialWitness;
use crate::timed;
use crate::util::timing::TimingTree;
/// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the /// Given a `PartialWitness` that has only inputs set, populates the rest of the witness using the
/// given set of generators. /// given set of generators.
pub(crate) fn generate_partial_witness<F: Field>( pub(crate) fn generate_partial_witness<F: Field>(
witness: &mut PartialWitness<F>, witness: &mut PartialWitness<F>,
generators: &[Box<dyn WitnessGenerator<F>>], generators: &[Box<dyn WitnessGenerator<F>>],
timing: &mut TimingTree,
) { ) {
// Index generator indices by their watched targets. // Index generator indices by their watched targets.
let mut generator_indices_by_watches = HashMap::new(); let mut generator_indices_by_watches = HashMap::new();
for (i, generator) in generators.iter().enumerate() { timed!(timing, "index generators by their watched targets", {
for watch in generator.watch_list() { for (i, generator) in generators.iter().enumerate() {
generator_indices_by_watches for watch in generator.watch_list() {
.entry(watch) generator_indices_by_watches
.or_insert_with(Vec::new) .entry(watch)
.push(i); .or_insert_with(Vec::new)
.push(i);
}
} }
} });
// Build a list of "pending" generators which are queued to be run. Initially, all generators // Build a list of "pending" generators which are queued to be run. Initially, all generators
// are queued. // are queued.
let mut pending_generator_indices: HashSet<_> = (0..generators.len()).collect(); let mut pending_generator_indices: Vec<_> = (0..generators.len()).collect();
// We also track a list of "expired" generators which have already returned false. // We also track a list of "expired" generators which have already returned false.
let mut generator_is_expired = vec![false; generators.len()]; let mut generator_is_expired = vec![false; generators.len()];
@ -38,9 +43,13 @@ pub(crate) fn generate_partial_witness<F: Field>(
// Keep running generators until no generators are queued. // Keep running generators until no generators are queued.
while !pending_generator_indices.is_empty() { while !pending_generator_indices.is_empty() {
let mut next_pending_generator_indices = HashSet::new(); let mut next_pending_generator_indices = Vec::new();
for &generator_idx in &pending_generator_indices { for &generator_idx in &pending_generator_indices {
if generator_is_expired[generator_idx] {
continue;
}
let finished = generators[generator_idx].run(&witness, &mut buffer); let finished = generators[generator_idx].run(&witness, &mut buffer);
if finished { if finished {
generator_is_expired[generator_idx] = true; generator_is_expired[generator_idx] = true;
@ -50,9 +59,7 @@ pub(crate) fn generate_partial_witness<F: Field>(
for (watch, _) in &buffer.target_values { for (watch, _) in &buffer.target_values {
if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) { if let Some(watching_generator_indices) = generator_indices_by_watches.get(watch) {
for &watching_generator_idx in watching_generator_indices { for &watching_generator_idx in watching_generator_indices {
if !generator_is_expired[watching_generator_idx] { next_pending_generator_indices.push(watching_generator_idx);
next_pending_generator_indices.insert(watching_generator_idx);
}
} }
} }
} }

View File

@ -57,6 +57,18 @@ impl<F: Field> PartialWitness<F> {
) )
} }
pub fn get_extension_targets<const D: usize>(
&self,
ets: &[ExtensionTarget<D>],
) -> Vec<F::Extension>
where
F: Extendable<D>,
{
ets.iter()
.map(|&et| self.get_extension_target(et))
.collect()
}
pub fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> { pub fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> {
HashOut { HashOut {
elements: self.get_targets(&ht.elements).try_into().unwrap(), elements: self.get_targets(&ht.elements).try_into().unwrap(),
@ -150,9 +162,7 @@ impl<F: Field> PartialWitness<F> {
} }
pub fn extend<I: Iterator<Item = (Target, F)>>(&mut self, pairs: I) { pub fn extend<I: Iterator<Item = (Target, F)>>(&mut self, pairs: I) {
for (target, value) in pairs { self.target_values.extend(pairs);
self.set_target(target, value);
}
} }
pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness<F> { pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness<F> {

View File

@ -37,7 +37,7 @@ pub(crate) fn prove<F: Extendable<D>, const D: usize>(
timed!( timed!(
timing, timing,
&format!("run {} generators", prover_data.generators.len()), &format!("run {} generators", prover_data.generators.len()),
generate_partial_witness(&mut partial_witness, &prover_data.generators) generate_partial_witness(&mut partial_witness, &prover_data.generators, &mut timing)
); );
let public_inputs = partial_witness.get_targets(&prover_data.public_inputs); let public_inputs = partial_witness.get_targets(&prover_data.public_inputs);
@ -249,17 +249,27 @@ fn wires_permutation_partial_products<F: Extendable<D>, const D: usize>(
.enumerate() .enumerate()
.map(|(i, &x)| { .map(|(i, &x)| {
let s_sigmas = &prover_data.sigmas[i]; let s_sigmas = &prover_data.sigmas[i];
let quotient_values = (0..common_data.config.num_routed_wires) let numerators = (0..common_data.config.num_routed_wires)
.map(|j| { .map(|j| {
let wire_value = witness.get_wire(i, j); let wire_value = witness.get_wire(i, j);
let k_i = k_is[j]; let k_i = k_is[j];
let s_id = k_i * x; let s_id = k_i * x;
let s_sigma = s_sigmas[j]; wire_value + beta * s_id + gamma
let numerator = wire_value + beta * s_id + gamma;
let denominator = wire_value + beta * s_sigma + gamma;
numerator / denominator
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let denominators = (0..common_data.config.num_routed_wires)
.map(|j| {
let wire_value = witness.get_wire(i, j);
let s_sigma = s_sigmas[j];
wire_value + beta * s_sigma + gamma
})
.collect::<Vec<_>>();
let denominator_invs = F::batch_multiplicative_inverse(&denominators);
let quotient_values = numerators
.into_iter()
.zip(denominator_invs)
.map(|(num, den_inv)| num * den_inv)
.collect::<Vec<_>>();
let quotient_partials = partial_products(&quotient_values, degree); let quotient_partials = partial_products(&quotient_values, degree);