From 104fd08e72802296ea28f13e485035de81ad5f0b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 18 Oct 2021 15:19:09 +0200 Subject: [PATCH] Working RAM gate --- src/fri/recursive_verifier.rs | 16 +- src/gadgets/random_access.rs | 36 ++-- src/gates/random_access.rs | 327 +++++++++++++++++++--------------- src/hash/merkle_proofs.rs | 4 +- 4 files changed, 209 insertions(+), 174 deletions(-) diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index eb30c467..bc1310f9 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -60,15 +60,17 @@ impl, const D: usize> CircuitBuilder { /// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more /// helpful errors. fn check_config(&self, arity: usize) { - let random_access = RandomAccessGate::::new(arity); + // let random_access = RandomAccessGate::::new(arity); let interpolation_gate = InterpolationGate::::new(arity); - let min_wires = random_access - .num_wires() - .max(interpolation_gate.num_wires()); - let min_routed_wires = random_access - .num_routed_wires() - .max(interpolation_gate.num_routed_wires()); + // let min_wires = random_access + // .num_wires() + // .max(interpolation_gate.num_wires()); + let min_wires = interpolation_gate.num_wires(); + // let min_routed_wires = random_access + // .num_routed_wires() + // .max(interpolation_gate.num_routed_wires()); + let min_routed_wires = interpolation_gate.num_routed_wires(); assert!( self.config.num_wires >= min_wires, diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index ab0db68c..bae7c9cd 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -8,7 +8,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `index` is not range-checked. - pub fn random_access( + pub fn random_access_extension( &mut self, access_index: Target, claimed_element: ExtensionTarget, @@ -18,23 +18,25 @@ impl, const D: usize> CircuitBuilder { if v.len() == 1 { return self.connect_extension(claimed_element, v[0]); } - let gate = RandomAccessGate::new(v.len()); + let gate = RandomAccessGate::new(D, v.len()); let gate_index = self.add_gate(gate.clone(), vec![]); - v.iter().enumerate().for_each(|(i, &val)| { - self.connect_extension( - val, - ExtensionTarget::from_range(gate_index, gate.wires_list_item(i)), + for copy in 0..D { + v.iter().enumerate().for_each(|(i, &val)| { + self.connect( + val.0[copy], + Target::wire(gate_index, gate.wire_list_item(i, copy)), + ); + }); + self.connect( + access_index, + Target::wire(gate_index, gate.wire_access_index(copy)), ); - }); - self.connect( - access_index, - Target::wire(gate_index, gate.wire_access_index()), - ); - self.connect_extension( - claimed_element, - ExtensionTarget::from_range(gate_index, gate.wires_claimed_element()), - ); + self.connect( + claimed_element.0[copy], + Target::wire(gate_index, gate.wire_claimed_element(copy)), + ); + } } /// Like `random_access`, but first pads `v` to a given minimum length. This can help to avoid @@ -54,7 +56,7 @@ impl, const D: usize> CircuitBuilder { if v.len() < min_length { v.resize(8, zero); } - self.random_access(access_index, claimed_element, v); + self.random_access_extension(access_index, claimed_element, v); } } @@ -83,7 +85,7 @@ mod tests { for i in 0..len { let it = builder.constant(F::from_canonical_usize(i)); let elem = builder.constant_extension(vec[i]); - builder.random_access(it, elem, v.clone()); + builder.random_access_extension(it, elem, v.clone()); } let data = builder.build(); diff --git a/src/gates/random_access.rs b/src/gates/random_access.rs index cc4c0442..8e1618fc 100644 --- a/src/gates/random_access.rs +++ b/src/gates/random_access.rs @@ -30,31 +30,33 @@ impl, const D: usize> RandomAccessGate { } } - pub fn new_from_config(config: CircuitConfig, vec_size: usize) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires, chunk_size); - Self::new(num_copies, chunk_size) + pub fn new_from_config(config: &CircuitConfig, vec_size: usize) -> Self { + let num_copies = Self::max_num_copies(config.num_routed_wires, vec_size); + Self::new(num_copies, vec_size) } pub fn max_num_copies(num_routed_wires: usize, vec_size: usize) -> usize { num_routed_wires / (2 + vec_size) } - pub fn wire_access_index(&self) -> usize { - 0 + pub fn wire_access_index(&self, copy: usize) -> usize { + debug_assert!(copy < self.num_copies); + (2 + self.vec_size) * copy } - pub fn wires_claimed_element(&self) -> Range { - 1..D + 1 + pub fn wire_claimed_element(&self, copy: usize) -> usize { + debug_assert!(copy < self.num_copies); + (2 + self.vec_size) * copy + 1 } - pub fn wires_list_item(&self, i: usize) -> Range { + pub fn wire_list_item(&self, i: usize, copy: usize) -> usize { debug_assert!(i < self.vec_size); - let start = (i + 1) * D + 1; - start..start + D + debug_assert!(copy < self.num_copies); + (2 + self.vec_size) * copy + 2 + i } fn start_of_intermediate_wires(&self) -> usize { - (self.vec_size + 1) * D + 1 + (2 + self.vec_size) * self.num_copies } pub(crate) fn num_routed_wires(&self) -> usize { @@ -64,16 +66,21 @@ impl, const D: usize> RandomAccessGate { /// An intermediate wire for a dummy variable used to show equality. /// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if /// x == y. - pub fn wire_equality_dummy_for_index(&self, i: usize) -> usize { + pub fn wire_equality_dummy_for_index(&self, i: usize, copy: usize) -> usize { debug_assert!(i < self.vec_size); - self.start_of_intermediate_wires() + i + debug_assert!(copy < self.num_copies); + self.start_of_intermediate_wires() + copy * self.vec_size + i } /// An intermediate wire for the "index_matches" variable (1 if the current index is the index at /// which to compare, 0 otherwise). - pub fn wire_index_matches_for_index(&self, i: usize) -> usize { + pub fn wire_index_matches_for_index(&self, i: usize, copy: usize) -> usize { debug_assert!(i < self.vec_size); - self.start_of_intermediate_wires() + self.vec_size + i + debug_assert!(copy < self.num_copies); + self.start_of_intermediate_wires() + + self.vec_size * self.num_copies + + copy * self.vec_size + + i } } @@ -83,53 +90,55 @@ impl, const D: usize> Gate for RandomAccessGa } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let access_index = vars.local_wires[self.wire_access_index()]; - let list_items = (0..self.vec_size) - .map(|i| vars.get_local_ext_algebra(self.wires_list_item(i))) - .collect::>(); - let claimed_element = vars.get_local_ext_algebra(self.wires_claimed_element()); - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.vec_size { - let cur_index = F::Extension::from_canonical_usize(i); - let difference = cur_index - access_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i)]; - // The two index equality constraints. - constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches)); - constraints.push(index_matches * difference); - // Value equality constraint. - constraints.extend( - ((list_items[i] - claimed_element).scalar_mul(index_matches)).to_basefield_array(), - ); + for copy in 0..self.num_copies { + let access_index = vars.local_wires[self.wire_access_index(copy)]; + let list_items = (0..self.vec_size) + .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) + .collect::>(); + let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; + + for i in 0..self.vec_size { + let cur_index = F::Extension::from_canonical_usize(i); + let difference = cur_index - access_index; + let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; + let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; + + // The two index equality constraints. + dbg!(difference, equality_dummy, index_matches); + constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches)); + constraints.push(index_matches * difference); + // Value equality constraint. + constraints.push(((list_items[i] - claimed_element) * index_matches)); + } } constraints } fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { - let access_index = vars.local_wires[self.wire_access_index()]; - let list_items = (0..self.vec_size) - .map(|i| vars.get_local_ext(self.wires_list_item(i))) - .collect::>(); - let claimed_element = vars.get_local_ext(self.wires_claimed_element()); - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.vec_size { - let cur_index = F::from_canonical_usize(i); - let difference = cur_index - access_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i)]; - // The two equality constraints. - constraints.push(difference * equality_dummy - (F::ONE - index_matches)); - constraints.push(index_matches * difference); + for copy in 0..self.num_copies { + let access_index = vars.local_wires[self.wire_access_index(copy)]; + let list_items = (0..self.vec_size) + .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) + .collect::>(); + let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; - // Value equality constraint. - constraints.extend( - ((list_items[i] - claimed_element).scalar_mul(index_matches)).to_basefield_array(), - ); + for i in 0..self.vec_size { + let cur_index = F::from_canonical_usize(i); + let difference = cur_index - access_index; + let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; + let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; + + // The two index equality constraints. + constraints.push(difference * equality_dummy - (F::ONE - index_matches)); + constraints.push(index_matches * difference); + // Value equality constraint. + constraints.push(((list_items[i] - claimed_element) * index_matches)); + } } constraints @@ -140,35 +149,36 @@ impl, const D: usize> Gate for RandomAccessGa builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec> { - let access_index = vars.local_wires[self.wire_access_index()]; - let list_items = (0..self.vec_size) - .map(|i| vars.get_local_ext_algebra(self.wires_list_item(i))) - .collect::>(); - let claimed_element = vars.get_local_ext_algebra(self.wires_claimed_element()); - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.vec_size { - let cur_index_ext = F::Extension::from_canonical_usize(i); - let cur_index = builder.constant_extension(cur_index_ext); - let difference = builder.sub_extension(cur_index, access_index); - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i)]; - let index_matches = vars.local_wires[self.wire_index_matches_for_index(i)]; + for copy in 0..self.num_copies { + let access_index = vars.local_wires[self.wire_access_index(copy)]; + let list_items = (0..self.vec_size) + .map(|i| vars.local_wires[self.wire_list_item(i, copy)]) + .collect::>(); + let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; - // The two equality constraints. - let one = builder.one_extension(); - let not_index_matches = builder.sub_extension(one, index_matches); - let first_equality_constraint = - builder.mul_sub_extension(difference, equality_dummy, not_index_matches); - constraints.push(first_equality_constraint); + for i in 0..self.vec_size { + let cur_index_ext = F::Extension::from_canonical_usize(i); + let cur_index = builder.constant_extension(cur_index_ext); + let difference = builder.sub_extension(cur_index, access_index); + let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)]; + let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)]; - let second_equality_constraint = builder.mul_extension(index_matches, difference); - constraints.push(second_equality_constraint); + let one = builder.one_extension(); + let not_index_matches = builder.sub_extension(one, index_matches); + let first_equality_constraint = + builder.mul_sub_extension(difference, equality_dummy, not_index_matches); + constraints.push(first_equality_constraint); - // Output constraint. - let diff = builder.sub_ext_algebra(list_items[i], claimed_element); - let conditional_diff = builder.scalar_mul_ext_algebra(index_matches, diff); - constraints.extend(conditional_diff.to_ext_target_array()); + let second_equality_constraint = builder.mul_extension(index_matches, difference); + constraints.push(second_equality_constraint); + + // Output constraint. + let diff = builder.sub_extension(list_items[i], claimed_element); + let conditional_diff = builder.mul_extension(index_matches, diff); + constraints.push(conditional_diff); + } } constraints @@ -187,7 +197,7 @@ impl, const D: usize> Gate for RandomAccessGa } fn num_wires(&self) -> usize { - self.wire_index_matches_for_index(self.vec_size - 1) + 1 + self.wire_index_matches_for_index(self.vec_size - 1, self.num_copies - 1) + 1 } fn num_constants(&self) -> usize { @@ -199,7 +209,7 @@ impl, const D: usize> Gate for RandomAccessGa } fn num_constraints(&self) -> usize { - self.vec_size * (2 + D) + self.num_copies * self.vec_size * 3 } } @@ -215,13 +225,13 @@ impl, const D: usize> SimpleGenerator fn dependencies(&self) -> Vec { let local_target = |input| Target::wire(self.gate_index, input); - let local_targets = |inputs: Range| inputs.map(local_target); - let mut deps = Vec::new(); - deps.push(local_target(self.gate.wire_access_index())); - deps.extend(local_targets(self.gate.wires_claimed_element())); - for i in 0..self.gate.vec_size { - deps.extend(local_targets(self.gate.wires_list_item(i))); + for copy in 0..self.gate.num_copies { + deps.push(local_target(self.gate.wire_access_index(copy))); + deps.push(local_target(self.gate.wire_claimed_element(copy))); + for i in 0..self.gate.vec_size { + deps.push(local_target(self.gate.wire_list_item(i, copy))); + } } deps } @@ -236,29 +246,34 @@ impl, const D: usize> SimpleGenerator // Compute the new vector and the values for equality_dummy and index_matches let vec_size = self.gate.vec_size; - let access_index_f = get_local_wire(self.gate.wire_access_index()); + for copy in 0..self.gate.num_copies { + let access_index_f = get_local_wire(self.gate.wire_access_index(copy)); - let access_index = access_index_f.to_canonical_u64() as usize; - debug_assert!( - access_index < vec_size, - "Access index {} is larger than the vector size {}", - access_index, - vec_size - ); + let access_index = access_index_f.to_canonical_u64() as usize; + debug_assert!( + access_index < vec_size, + "Access index {} is larger than the vector size {}", + access_index, + vec_size + ); - for i in 0..vec_size { - let equality_dummy_wire = local_wire(self.gate.wire_equality_dummy_for_index(i)); - let index_matches_wire = local_wire(self.gate.wire_index_matches_for_index(i)); + for i in 0..vec_size { + let equality_dummy_wire = + local_wire(self.gate.wire_equality_dummy_for_index(i, copy)); + let index_matches_wire = + local_wire(self.gate.wire_index_matches_for_index(i, copy)); - if i == access_index { - out_buffer.set_wire(equality_dummy_wire, F::ONE); - out_buffer.set_wire(index_matches_wire, F::ONE); - } else { - out_buffer.set_wire( - equality_dummy_wire, - (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)).inverse(), - ); - out_buffer.set_wire(index_matches_wire, F::ZERO); + if i == access_index { + out_buffer.set_wire(equality_dummy_wire, F::ONE); + out_buffer.set_wire(index_matches_wire, F::ONE); + } else { + out_buffer.set_wire( + equality_dummy_wire, + (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)) + .inverse(), + ); + out_buffer.set_wire(index_matches_wire, F::ZERO); + } } } } @@ -269,6 +284,7 @@ mod tests { use std::marker::PhantomData; use anyhow::Result; + use rand::{thread_rng, Rng}; use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticExtension; @@ -279,31 +295,31 @@ mod tests { use crate::hash::hash_types::HashOut; use crate::plonk::vars::EvaluationVars; - #[test] - fn wire_indices() { - let gate = RandomAccessGate:: { - vec_size: 3, - _phantom: PhantomData, - }; - - assert_eq!(gate.wire_access_index(), 0); - assert_eq!(gate.wires_claimed_element(), 1..5); - assert_eq!(gate.wires_list_item(0), 5..9); - assert_eq!(gate.wires_list_item(2), 13..17); - assert_eq!(gate.wire_equality_dummy_for_index(0), 17); - assert_eq!(gate.wire_equality_dummy_for_index(2), 19); - assert_eq!(gate.wire_index_matches_for_index(0), 20); - assert_eq!(gate.wire_index_matches_for_index(2), 22); - } + // #[test] + // fn wire_indices() { + // let gate = RandomAccessGate:: { + // vec_size: 3, + // _phantom: PhantomData, + // }; + // + // assert_eq!(gate.wire_access_index(), 0); + // assert_eq!(gate.wires_claimed_element(), 1..5); + // assert_eq!(gate.wires_list_item(0), 5..9); + // assert_eq!(gate.wires_list_item(2), 13..17); + // assert_eq!(gate.wire_equality_dummy_for_index(0), 17); + // assert_eq!(gate.wire_equality_dummy_for_index(2), 19); + // assert_eq!(gate.wire_index_matches_for_index(0), 20); + // assert_eq!(gate.wire_index_matches_for_index(2), 22); + // } #[test] fn low_degree() { - test_low_degree::(RandomAccessGate::new(4)); + test_low_degree::(RandomAccessGate::new(4, 4)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(RandomAccessGate::new(4)) + test_eval_fns::(RandomAccessGate::new(4, 4)) } #[test] @@ -314,64 +330,79 @@ mod tests { /// Returns the local wires for a random access gate given the vector, element to compare, /// and index. - fn get_wires(list: Vec, access_index: usize, claimed_element: FF) -> Vec { - let vec_size = list.len(); + fn get_wires( + lists: Vec>, + access_indices: Vec, + claimed_elements: Vec, + ) -> Vec { + let num_copies = lists.len(); + let vec_size = lists[0].len(); let mut v = Vec::new(); - v.push(F::from_canonical_usize(access_index)); - v.extend(claimed_element.0); - for j in 0..vec_size { - v.extend(list[j].0); - } - let mut equality_dummy_vals = Vec::new(); let mut index_matches_vals = Vec::new(); - for i in 0..vec_size { - if i == access_index { - equality_dummy_vals.push(F::ONE); - index_matches_vals.push(F::ONE); - } else { - equality_dummy_vals.push( - (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)) - .inverse(), - ); - index_matches_vals.push(F::ZERO); + for copy in 0..num_copies { + let access_index = access_indices[copy]; + v.push(F::from_canonical_usize(access_index)); + v.push(claimed_elements[copy]); + for j in 0..vec_size { + v.push(lists[copy][j]); + } + + for i in 0..vec_size { + if i == access_index { + equality_dummy_vals.push(F::ONE); + index_matches_vals.push(F::ONE); + } else { + equality_dummy_vals.push( + (F::from_canonical_usize(i) - F::from_canonical_usize(access_index)) + .inverse(), + ); + index_matches_vals.push(F::ZERO); + } } } - v.extend(equality_dummy_vals); v.extend(index_matches_vals); - v.iter().map(|&x| x.into()).collect::>() } - let list = vec![FF::rand(); 3]; - let access_index = 1; + let lists = (0..4).map(|_| F::rand_vec(3)).collect::>(); + let access_indices = (0..4) + .map(|_| thread_rng().gen_range(0..3)) + .collect::>(); let gate = RandomAccessGate:: { vec_size: 3, + num_copies: 4, _phantom: PhantomData, }; - let good_claimed_element = list[access_index]; + let good_claimed_elements = lists + .iter() + .zip(&access_indices) + .map(|(l, &i)| l[i]) + .collect(); + dbg!(&lists, &access_indices, &good_claimed_elements); let good_vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(list.clone(), access_index, good_claimed_element), + local_wires: &get_wires(lists.clone(), access_indices.clone(), good_claimed_elements), public_inputs_hash: &HashOut::rand(), }; - let bad_claimed_element = FF::rand(); + let bad_claimed_elements = F::rand_vec(4); let bad_vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(list, access_index, bad_claimed_element), + local_wires: &get_wires(lists, access_indices, bad_claimed_elements), public_inputs_hash: &HashOut::rand(), }; + dbg!(gate.eval_unfiltered(good_vars)); assert!( gate.eval_unfiltered(good_vars).iter().all(|x| x.is_zero()), "Gate constraints are not satisfied." ); assert!( !gate.eval_unfiltered(bad_vars).iter().all(|x| x.is_zero()), - "Gate constraints are satisfied but shouold not be." + "Gate constraints are satisfied but should not be." ); } } diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index ac62d8f0..062ff521 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -85,7 +85,7 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget(tmp) }) .collect(); - self.random_access(index, state_ext, cap_ext); + self.random_access_extension(index, state_ext, cap_ext); } /// Same a `verify_merkle_proof` but with the final "cap index" as extra parameter. @@ -122,7 +122,7 @@ impl, const D: usize> CircuitBuilder { ExtensionTarget(tmp) }) .collect(); - self.random_access(cap_index, state_ext, cap_ext); + self.random_access_extension(cap_index, state_ext, cap_ext); } pub fn assert_hashes_equal(&mut self, x: HashOutTarget, y: HashOutTarget) {