Merge pull request #123 from mir-protocol/remove_acc_in_gmimc

Remove accumulator in `GMiMCGate` and only use bits in the recursive FRI verifier
This commit is contained in:
wborgeaud 2021-07-23 08:22:16 +02:00 committed by GitHub
commit d435720d04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 136 additions and 122 deletions

View File

@ -56,7 +56,7 @@ impl CircuitConfig {
pub(crate) fn large_config() -> Self {
Self {
num_wires: 134,
num_wires: 126,
num_routed_wires: 34,
security_bits: 128,
rate_bits: 3,

View File

@ -20,7 +20,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn compute_evaluation(
&mut self,
x: Target,
old_x_index: Target,
old_x_index_bits: &[Target],
arity_bits: usize,
last_evals: &[ExtensionTarget<D>],
beta: ExtensionTarget<D>,
@ -33,13 +33,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// The evaluation vector needs to be reordered first.
let mut evals = last_evals.to_vec();
reverse_index_bits_in_place(&mut evals);
let mut old_x_index_bits = self.split_le(old_x_index, arity_bits);
old_x_index_bits.reverse();
// Want `g^(arity - rev_old_x_index)` as in the out-of-circuit version.
// Compute it as `g^(arity-1-rev_old_x_index) * g`, where the first term is gotten using two's complement.
// TODO: Once the exponentiation gate lands, we won't need the bits and will be able to compute
// `g^(arity-rev_old_x_index)` directly.
let start = self.exp_from_complement_bits(gt, &old_x_index_bits);
let start = self.exp_from_complement_bits(gt, old_x_index_bits.iter().rev());
let coset_start = self.mul_many(&[start, gt, x]);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
@ -151,7 +147,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn fri_verify_initial_proof(
&mut self,
x_index: Target,
x_index_bits: &[Target],
proof: &FriInitialTreeProofTarget,
initial_merkle_roots: &[HashTarget],
) {
@ -164,7 +160,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
context!(
self,
&format!("verify {}'th initial Merkle proof", i),
self.verify_merkle_proof(evals.clone(), x_index, root, merkle_proof)
self.verify_merkle_proof(evals.clone(), x_index_bits, root, merkle_proof)
);
}
}
@ -256,27 +252,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let config = &common_data.config.fri_config;
let n_log = log2_strict(n);
// TODO: Do we need to range check `x_index` to a target smaller than `p`?
let mut x_index = challenger.get_challenge(self);
x_index = self.split_low_high(x_index, n_log, 64).0;
let mut x_index_num_bits = n_log;
let x_index = challenger.get_challenge(self);
let mut x_index_bits = self.low_bits(x_index, n_log, 64);
let mut domain_size = n;
context!(
self,
"check FRI initial proof",
self.fri_verify_initial_proof(
x_index,
&x_index_bits,
&round_proof.initial_trees_proof,
initial_merkle_roots,
)
);
let mut old_x_index = self.zero();
let mut old_x_index_bits = Vec::new();
// `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain.
let mut subgroup_x = context!(self, "compute x from its index", {
let g = self.constant(F::MULTIPLICATIVE_GROUP_GENERATOR);
let phi = self.constant(F::primitive_root_of_unity(n_log));
let reversed_x = self.reverse_limbs::<2>(x_index, n_log);
let reversed_x = self.le_sum(x_index_bits.iter().rev());
let phi = self.exp(phi, reversed_x, n_log);
self.mul(g, phi)
});
@ -305,7 +300,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
"infer evaluation using interpolation",
self.compute_evaluation(
subgroup_x,
old_x_index,
&old_x_index_bits,
config.reduction_arity_bits[i - 1],
last_evals,
betas[i - 1],
@ -314,15 +309,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
};
let mut evals = round_proof.steps[i].evals.clone();
// Insert P(y) into the evaluation vector, since it wasn't included by the prover.
let (low_x_index, high_x_index) =
self.split_low_high(x_index, arity_bits, x_index_num_bits);
let high_x_index_bits = x_index_bits.split_off(arity_bits);
old_x_index_bits = x_index_bits;
let low_x_index = self.le_sum(old_x_index_bits.iter());
evals = self.insert(low_x_index, e_x, evals);
context!(
self,
"verify FRI round Merkle proof.",
self.verify_merkle_proof(
flatten_target(&evals),
high_x_index,
&high_x_index_bits,
proof.commit_phase_merkle_roots[i],
&round_proof.steps[i].merkle_proof,
)
@ -334,9 +330,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
subgroup_x = self.exp_power_of_2(subgroup_x, config.reduction_arity_bits[i - 1]);
}
domain_size = next_domain_size;
old_x_index = low_x_index;
x_index = high_x_index;
x_index_num_bits -= arity_bits;
x_index_bits = high_x_index_bits;
}
let last_evals = evaluations.last().unwrap();
@ -346,7 +340,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
"infer final evaluation using interpolation",
self.compute_evaluation(
subgroup_x,
old_x_index,
&old_x_index_bits,
final_arity_bits,
last_evals,
*betas.last().unwrap(),

View File

@ -1,3 +1,5 @@
use std::borrow::Borrow;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::Extendable;
use crate::target::Target;
@ -185,13 +187,17 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// TODO: Optimize this, maybe with a new gate.
// TODO: Test
/// Exponentiate `base` to the power of `2^bit_length-1-exponent`, given by its little-endian bits.
pub fn exp_from_complement_bits(&mut self, base: Target, exponent_bits: &[Target]) -> Target {
pub fn exp_from_complement_bits(
&mut self,
base: Target,
exponent_bits: impl Iterator<Item = impl Borrow<Target>>,
) -> Target {
let mut current = base;
let one = self.one();
let mut product = one;
for &bit in exponent_bits {
let multiplicand = self.select(bit, one, current);
for bit in exponent_bits {
let multiplicand = self.select(*bit.borrow(), one, current);
product = self.mul(product, multiplicand);
current = self.mul(current, current);
}

View File

@ -22,15 +22,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
});
self.route(zero, swap_wire);
// The old accumulator wire doesn't matter, since we won't read the new accumulator wire.
// We do have to set it to something though, so we'll arbitrary pick 0.
let old_acc_wire = GMiMCGate::<F, D, GMIMC_ROUNDS>::WIRE_INDEX_ACCUMULATOR_OLD;
let old_acc_wire = Target::Wire(Wire {
gate,
input: old_acc_wire,
});
self.route(zero, old_acc_wire);
// Route input wires.
for i in 0..12 {
let in_wire = GMiMCGate::<F, D, GMIMC_ROUNDS>::wire_input(i);

View File

@ -14,6 +14,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.route(x, sum);
}
/// Returns the first `num_low_bits` little-endian bits of `x`.
pub fn low_bits(&mut self, x: Target, num_low_bits: usize, num_bits: usize) -> Vec<Target> {
let mut res = self.split_le(x, num_bits);
res.truncate(num_low_bits);
res
}
/// Returns `(a,b)` such that `x = a + 2^n_log * b` with `a < 2^n_log`.
/// `x` is assumed to be range-checked for having `num_bits` bits.
pub fn split_low_high(&mut self, x: Target, n_log: usize, num_bits: usize) -> (Target, Target) {

View File

@ -1,7 +1,12 @@
use std::borrow::Borrow;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::gates::base_sum::BaseSumGate;
use crate::generator::{GeneratedValues, SimpleGenerator};
use crate::target::Target;
use crate::witness::PartialWitness;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Split the given element into a list of targets, where each one represents a
@ -33,11 +38,65 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
Target::wire(gate, BaseSumGate::<B>::WIRE_REVERSED_SUM)
}
/// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e.,
/// the number with little-endian bit representation given by `bits`.
pub(crate) fn le_sum(
&mut self,
bits: impl ExactSizeIterator<Item = impl Borrow<Target>> + Clone,
) -> Target {
let num_bits = bits.len();
debug_assert!(
BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires,
"Not enough routed wires."
);
let gate_index = self.add_gate(BaseSumGate::<2>::new(num_bits), vec![]);
for (limb, wire) in bits
.clone()
.zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits)
{
self.route(*limb.borrow(), Target::wire(gate_index, wire));
}
self.add_generator(BaseSumGenerator::<2> {
gate_index,
limbs: bits.map(|l| *l.borrow()).collect(),
});
Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM)
}
}
#[derive(Debug)]
struct BaseSumGenerator<const B: usize> {
gate_index: usize,
limbs: Vec<Target>,
}
impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSumGenerator<B> {
fn dependencies(&self) -> Vec<Target> {
self.limbs.clone()
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
let sum = self
.limbs
.iter()
.map(|&t| witness.get_target(t))
.rev()
.fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb);
GeneratedValues::singleton_target(
Target::wire(self.gate_index, BaseSumGate::<B>::WIRE_SUM),
sum,
)
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use rand::{thread_rng, Rng};
use super::*;
use crate::circuit_data::CircuitConfig;
@ -73,4 +132,36 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_base_sum() -> Result<()> {
type F = CrandallField;
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let n = thread_rng().gen_range(0, 1 << 10);
let x = builder.constant(F::from_canonical_usize(n));
let zero = builder.zero();
let one = builder.one();
let y = builder.le_sum(
(0..10)
.scan(n, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(if tmp == 1 { one } else { zero })
})
.collect::<Vec<_>>()
.iter(),
);
builder.assert_equal(x, y);
let data = builder.build();
let proof = data.prove(PartialWitness::new())?;
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -48,22 +48,18 @@ impl<F: Extendable<D>, const D: usize, const R: usize> GMiMCGate<F, D, R> {
W + i
}
/// Used to incrementally compute the index of the leaf based on a series of swap bits.
pub const WIRE_INDEX_ACCUMULATOR_OLD: usize = 2 * W;
pub const WIRE_INDEX_ACCUMULATOR_NEW: usize = 2 * W + 1;
/// If this is set to 1, the first four inputs will be swapped with the next four inputs. This
/// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0.
pub const WIRE_SWAP: usize = 2 * W + 2;
pub const WIRE_SWAP: usize = 2 * W;
/// A wire which stores the input to the `i`th cubing.
fn wire_cubing_input(i: usize) -> usize {
2 * W + 3 + i
2 * W + 1 + i
}
/// End of wire indices, exclusive.
fn end() -> usize {
2 * W + 3 + R
2 * W + 1 + R
}
}
@ -79,11 +75,6 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::Extension::ONE));
let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD];
let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW];
let computed_new_index_acc = F::Extension::TWO * old_index_acc + swap;
constraints.push(computed_new_index_acc - new_index_acc);
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
@ -128,11 +119,6 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::ONE));
let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD];
let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW];
let computed_new_index_acc = F::TWO * old_index_acc + swap;
constraints.push(computed_new_index_acc - new_index_acc);
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
@ -180,13 +166,6 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(builder.mul_sub_extension(swap, swap, swap));
let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD];
let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW];
// computed_new_index_acc = 2 * old_index_acc + swap
let two = builder.two_extension();
let computed_new_index_acc = builder.mul_add_extension(two, old_index_acc, swap);
constraints.push(builder.sub_extension(computed_new_index_acc, new_index_acc));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
@ -256,7 +235,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
}
fn num_constraints(&self) -> usize {
R + W + 2
R + W + 1
}
}
@ -270,12 +249,11 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
for GMiMCGenerator<F, D, R>
{
fn dependencies(&self) -> Vec<Target> {
let mut dep_input_indices = Vec::with_capacity(W + 2);
let mut dep_input_indices = Vec::with_capacity(W + 1);
for i in 0..W {
dep_input_indices.push(GMiMCGate::<F, D, R>::wire_input(i));
}
dep_input_indices.push(GMiMCGate::<F, D, R>::WIRE_SWAP);
dep_input_indices.push(GMiMCGate::<F, D, R>::WIRE_INDEX_ACCUMULATOR_OLD);
dep_input_indices
.into_iter()
@ -289,7 +267,7 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
}
fn run_once(&self, witness: &PartialWitness<F>) -> GeneratedValues<F> {
let mut result = GeneratedValues::with_capacity(R + W + 1);
let mut result = GeneratedValues::with_capacity(R + W);
let mut state = (0..W)
.map(|i| {
@ -311,20 +289,6 @@ impl<F: Extendable<D>, const D: usize, const R: usize> SimpleGenerator<F>
}
}
// Update the index accumulator.
let old_index_acc_value = witness.get_wire(Wire {
gate: self.gate_index,
input: GMiMCGate::<F, D, R>::WIRE_INDEX_ACCUMULATOR_OLD,
});
let new_index_acc_value = F::TWO * old_index_acc_value + swap_value;
result.set_wire(
Wire {
gate: self.gate_index,
input: GMiMCGate::<F, D, R>::WIRE_INDEX_ACCUMULATOR_NEW,
},
new_index_acc_value,
);
// Value that is implicitly added to each element.
// See https://affine.group/2020/02/starkware-challenge
let mut addition_buffer = F::ZERO;
@ -389,22 +353,9 @@ mod tests {
type Gate = GMiMCGate<F, 4, R>;
let gate = Gate::with_constants(constants.clone());
let config = CircuitConfig {
num_wires: 134,
num_routed_wires: 200,
..Default::default()
};
let permutation_inputs = (0..W).map(F::from_canonical_usize).collect::<Vec<_>>();
let mut witness = PartialWitness::new();
witness.set_wire(
Wire {
gate: 0,
input: Gate::WIRE_INDEX_ACCUMULATOR_OLD,
},
F::from_canonical_usize(7),
);
witness.set_wire(
Wire {
gate: 0,
@ -435,12 +386,6 @@ mod tests {
});
assert_eq!(out, expected_outputs[i]);
}
let acc_new = witness.get_wire(Wire {
gate: 0,
input: Gate::WIRE_INDEX_ACCUMULATOR_NEW,
});
assert_eq!(acc_new, F::from_canonical_usize(7 * 2));
}
#[test]

View File

@ -59,22 +59,19 @@ pub(crate) fn verify_merkle_proof<F: Field>(
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given root.
/// given root. The index is given by it's little-endian bits.
pub(crate) fn verify_merkle_proof(
&mut self,
leaf_data: Vec<Target>,
leaf_index: Target,
leaf_index_bits: &[Target],
merkle_root: HashTarget,
proof: &MerkleProofTarget,
) {
let zero = self.zero();
let height = proof.siblings.len();
let purported_index_bits = self.split_le_virtual(leaf_index, height);
let mut state: HashTarget = self.hash_or_noop(leaf_data);
let mut acc_leaf_index = zero;
for (bit, &sibling) in purported_index_bits.into_iter().zip(&proof.siblings) {
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let gate = self
.add_gate_no_constants(GMiMCGate::<F, D, GMIMC_ROUNDS>::with_automatic_constants());
@ -85,20 +82,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
});
self.generate_copy(bit, swap_wire);
let old_acc_wire = GMiMCGate::<F, D, GMIMC_ROUNDS>::WIRE_INDEX_ACCUMULATOR_OLD;
let old_acc_wire = Target::Wire(Wire {
gate,
input: old_acc_wire,
});
self.route(acc_leaf_index, old_acc_wire);
let new_acc_wire = GMiMCGate::<F, D, GMIMC_ROUNDS>::WIRE_INDEX_ACCUMULATOR_NEW;
let new_acc_wire = Target::Wire(Wire {
gate,
input: new_acc_wire,
});
acc_leaf_index = new_acc_wire;
let input_wires = (0..12)
.map(|i| {
Target::Wire(Wire {
@ -126,10 +109,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
)
}
// TODO: this is far from optimal.
let leaf_index_rev = self.reverse_limbs::<2>(leaf_index, height);
self.assert_equal(acc_leaf_index, leaf_index_rev);
self.named_assert_hashes_equal(state, merkle_root, "check Merkle root".into())
}
@ -191,13 +170,14 @@ mod tests {
pw.set_hash_target(root_t, tree.root);
let i_c = builder.constant(F::from_canonical_usize(i));
let i_bits = builder.split_le(i_c, log_n);
let data = builder.add_virtual_targets(tree.leaves[i].len());
for j in 0..data.len() {
pw.set_target(data[j], tree.leaves[i][j]);
}
builder.verify_merkle_proof(data, i_c, root_t, &proof_t);
builder.verify_merkle_proof(data, &i_bits, root_t, &proof_t);
let data = builder.build();
let proof = data.prove(pw)?;

View File

@ -365,7 +365,7 @@ mod tests {
type F = CrandallField;
const D: usize = 4;
let config = CircuitConfig {
num_wires: 134,
num_wires: 126,
num_routed_wires: 33,
security_bits: 128,
rate_bits: 3,