From c729a3c2353a3c2e8d8dd0c39c24f7a9f3b4b5e8 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 22 Jul 2021 16:07:07 +0200 Subject: [PATCH] Remove all non-bits indices in the FRI verifier --- src/fri/recursive_verifier.rs | 11 ++--- src/gadgets/split_base.rs | 89 +++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 4c3a3704..12246504 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -253,10 +253,8 @@ impl, const D: usize> CircuitBuilder { let config = &common_data.config.fri_config; let n_log = log2_strict(n); // TODO: Do we need to range check `x_index` to a target smaller than `p`? - let mut x_index = challenger.get_challenge(self); - x_index = self.split_low_high(x_index, n_log, 64).0; + let x_index = challenger.get_challenge(self); let mut x_index_bits = self.low_bits(x_index, n_log, 64); - let mut x_index_num_bits = n_log; let mut domain_size = n; context!( self, @@ -274,7 +272,7 @@ impl, const D: usize> CircuitBuilder { let g = self.constant(F::MULTIPLICATIVE_GROUP_GENERATOR); let phi = self.constant(F::primitive_root_of_unity(n_log)); - let reversed_x = self.reverse_limbs::<2>(x_index, n_log); + let reversed_x = self.base_sum(x_index_bits.iter().rev()); let phi = self.exp(phi, reversed_x, n_log); self.mul(g, phi) }); @@ -312,10 +310,9 @@ impl, const D: usize> CircuitBuilder { }; let mut evals = round_proof.steps[i].evals.clone(); // Insert P(y) into the evaluation vector, since it wasn't included by the prover. - let (low_x_index, high_x_index) = - self.split_low_high(x_index, arity_bits, x_index_num_bits); let high_x_index_bits = x_index_bits.split_off(arity_bits); old_x_index_bits = x_index_bits; + let low_x_index = self.base_sum(old_x_index_bits.iter()); evals = self.insert(low_x_index, e_x, evals); context!( self, @@ -334,8 +331,6 @@ impl, const D: usize> CircuitBuilder { subgroup_x = self.exp_power_of_2(subgroup_x, config.reduction_arity_bits[i - 1]); } domain_size = next_domain_size; - x_index = high_x_index; - x_index_num_bits -= arity_bits; x_index_bits = high_x_index_bits; } diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 1223170e..da26c1db 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -1,7 +1,12 @@ +use std::borrow::Borrow; + use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::Extendable; +use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; +use crate::generator::{GeneratedValues, SimpleGenerator}; use crate::target::Target; +use crate::witness::PartialWitness; impl, const D: usize> CircuitBuilder { /// Split the given element into a list of targets, where each one represents a @@ -33,11 +38,63 @@ impl, const D: usize> CircuitBuilder { Target::wire(gate, BaseSumGate::::WIRE_REVERSED_SUM) } + + pub(crate) fn base_sum( + &mut self, + limbs: impl ExactSizeIterator> + Clone, + ) -> Target { + let num_limbs = limbs.len(); + debug_assert!( + BaseSumGate::<2>::START_LIMBS + num_limbs <= self.config.num_routed_wires, + "Not enough routed wires." + ); + let gate_index = self.add_gate(BaseSumGate::<2>::new(num_limbs), vec![]); + for (limb, wire) in limbs + .clone() + .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_limbs) + { + self.route(*limb.borrow(), Target::wire(gate_index, wire)); + } + + self.add_generator(BaseSumGenerator::<2> { + gate_index, + limbs: limbs.map(|l| *l.borrow()).collect(), + }); + + Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM) + } +} + +#[derive(Debug)] +struct BaseSumGenerator { + gate_index: usize, + limbs: Vec, +} + +impl SimpleGenerator for BaseSumGenerator { + fn dependencies(&self) -> Vec { + self.limbs.clone() + } + + fn run_once(&self, witness: &PartialWitness) -> GeneratedValues { + let sum = self + .limbs + .iter() + .map(|&t| witness.get_target(t)) + .rev() + .fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb); + + GeneratedValues::singleton_target( + Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM), + sum, + ) + } } #[cfg(test)] mod tests { use anyhow::Result; + use rand::{thread_rng, Rng}; use super::*; use crate::circuit_data::CircuitConfig; @@ -73,4 +130,36 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_base_sum() -> Result<()> { + type F = CrandallField; + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::::new(config); + + let n = thread_rng().gen_range(0, 1 << 10); + let x = builder.constant(F::from_canonical_usize(n)); + + let zero = builder.zero(); + let one = builder.one(); + + let y = builder.base_sum( + (0..10) + .scan(n, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(if tmp == 1 { one } else { zero }) + }) + .collect::>() + .iter(), + ); + + builder.assert_equal(x, y); + + let data = builder.build(); + + let proof = data.prove(PartialWitness::new())?; + + verify(proof, &data.verifier_only, &data.common) + } }