diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index e82b4b32..37469237 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -38,7 +38,8 @@ impl, const D: usize> InsertionGate { 1..D + 1 } - pub fn wires_list_item(&self, i: usize) -> Range { + pub fn wires_original_list_item(&self, i: usize) -> Range { + debug_assert!(i < self.vec_size); let start = (i + 1) * D + 1; start..start + D } @@ -48,6 +49,7 @@ impl, const D: usize> InsertionGate { } pub fn wires_output_list_item(&self, i: usize) -> Range { + debug_assert!(i <= self.vec_size); let start = self.start_of_output_wires() + i * D; start..start + D } @@ -56,10 +58,15 @@ impl, const D: usize> InsertionGate { self.start_of_output_wires() + (self.vec_size + 1) * D } + /// An intermediate wire for a dummy variable used to show equality. + /// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if + /// x == y. pub fn wires_equality_dummy_for_round_r(&self, r: usize) -> usize { self.start_of_intermediate_wires() + r } + // An intermediate wire for the "insert_here" variable (1 if the current index is the index at + /// which to insert the new value, 0 otherwise). pub fn wires_insert_here_for_round_r(&self, r: usize) -> usize { self.start_of_intermediate_wires() + (self.vec_size + 1) + r } @@ -73,24 +80,20 @@ impl, const D: usize> Gate for InsertionGate { fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let insertion_index = vars.local_wires[self.wires_insertion_index()]; - let mut list_items = Vec::new(); - for i in 0..self.vec_size { - list_items.push(vars.get_local_ext_algebra(self.wires_list_item(i))); - } - let dummy_value: ExtensionAlgebra = F::Extension::ZERO.into(); // will never be reached - list_items.push(dummy_value); + let list_items = (0..self.vec_size) + .map(|i| vars.get_local_ext_algebra(self.wires_original_list_item(i))) + .collect::>(); - let mut output_list_items = Vec::new(); - for i in 0..self.vec_size + 1 { - output_list_items.push(vars.get_local_ext_algebra(self.wires_output_list_item(i))); - } + let output_list_items = (0..=self.vec_size) + .map(|i| vars.get_local_ext_algebra(self.wires_output_list_item(i))) + .collect::>(); let element_to_insert = vars.get_local_ext_algebra(self.wires_element_to_insert()); let mut constraints = Vec::new(); let mut already_inserted = F::Extension::ZERO; - for r in 0..self.vec_size + 1 { + for r in 0..=self.vec_size { let cur_index = F::Extension::from_canonical_usize(r); let equality_dummy = vars.local_wires[self.wires_equality_dummy_for_round_r(r)]; @@ -108,7 +111,9 @@ impl, const D: usize> Gate for InsertionGate { } already_inserted += insert_here; - new_item += list_items[r] * (F::Extension::ONE - already_inserted).into(); + if r < self.vec_size { + new_item += list_items[r] * (F::Extension::ONE - already_inserted).into(); + } constraints.extend((new_item - output_list_items[r]).to_basefield_array()); } @@ -132,7 +137,6 @@ impl, const D: usize> Gate for InsertionGate { let gen = InsertionGenerator:: { gate_index, gate: self.clone(), - _phantom: PhantomData, }; vec![Box::new(gen)] } @@ -150,7 +154,7 @@ impl, const D: usize> Gate for InsertionGate { } fn num_constraints(&self) -> usize { - (self.vec_size + 1) * 3 + (self.vec_size + 1) * (2 + D) } } @@ -158,17 +162,11 @@ impl, const D: usize> Gate for InsertionGate { struct InsertionGenerator, const D: usize> { gate_index: usize, gate: InsertionGate, - _phantom: PhantomData, } impl, const D: usize> SimpleGenerator for InsertionGenerator { fn dependencies(&self) -> Vec { - let local_target = |input| { - Target::Wire(Wire { - gate: self.gate_index, - input, - }) - }; + let local_target = |input| Target::wire(self.gate_index, input); let local_targets = |inputs: Range| inputs.map(local_target); @@ -176,7 +174,7 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator deps.push(local_target(self.gate.wires_insertion_index())); deps.extend(local_targets(self.gate.wires_element_to_insert())); for i in 0..self.gate.vec_size { - deps.extend(local_targets(self.gate.wires_list_item(i))); + deps.extend(local_targets(self.gate.wires_original_list_item(i))); } deps } @@ -197,38 +195,38 @@ impl, const D: usize> SimpleGenerator for InsertionGenerator }; // Compute the new vector and the values for equality_dummy and insert_here - let n = self.gate.vec_size; - let orig_vec = (0..n) - .map(|i| get_local_ext(self.gate.wires_list_item(i))) + let vec_size = self.gate.vec_size; + let orig_vec = (0..vec_size) + .map(|i| get_local_ext(self.gate.wires_original_list_item(i))) .collect::>(); let to_insert = get_local_ext(self.gate.wires_element_to_insert()); let insertion_index_f = get_local_wire(self.gate.wires_insertion_index()); let insertion_index = insertion_index_f.to_canonical_u64() as usize; - let mut new_vec = Vec::new(); - new_vec.extend(&orig_vec[..insertion_index]); - new_vec.push(to_insert); - new_vec.extend(&orig_vec[insertion_index..]); + debug_assert!( + insertion_index <= vec_size, + "Insertion index {} is larger than the vector size {}", + insertion_index, + vec_size + ); + + let mut new_vec = orig_vec.clone(); + new_vec.insert(insertion_index, to_insert); let mut equality_dummy_vals = Vec::new(); - for i in 0..n + 1 { - if i != insertion_index { - let diff = if i > insertion_index { - F::from_canonical_usize(i - insertion_index) - } else { - F::ZERO - F::from_canonical_usize(insertion_index - i) - }; - equality_dummy_vals.push(diff.inverse()); + for i in 0..=vec_size { + equality_dummy_vals.push(if i == insertion_index { + F::ONE } else { - equality_dummy_vals.push(F::ONE); - } + (F::from_canonical_usize(i) - insertion_index_f).inverse() + }); } - let mut insert_here_vals = vec![F::ZERO; n]; + let mut insert_here_vals = vec![F::ZERO; vec_size]; insert_here_vals.insert(insertion_index, F::ONE); let mut result = PartialWitness::::new(); - for i in 0..n + 1 { + for i in 0..=vec_size { let output_wires = self.gate.wires_output_list_item(i).map(local_wire); result.set_ext_wires(output_wires, new_vec[i]); let equality_dummy_wire = local_wire(self.gate.wires_equality_dummy_for_round_r(i)); @@ -263,8 +261,8 @@ mod tests { assert_eq!(gate.wires_insertion_index(), 0); assert_eq!(gate.wires_element_to_insert(), 1..5); - assert_eq!(gate.wires_list_item(0), 5..9); - assert_eq!(gate.wires_list_item(2), 13..17); + assert_eq!(gate.wires_original_list_item(0), 5..9); + assert_eq!(gate.wires_original_list_item(2), 13..17); assert_eq!(gate.wires_output_list_item(0), 17..21); assert_eq!(gate.wires_output_list_item(3), 29..33); assert_eq!(gate.wires_equality_dummy_for_round_r(0), 33); @@ -285,51 +283,37 @@ mod tests { type FF = QuarticCrandallField; const D: usize = 4; - /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. - fn get_wires( - vec_size: usize, - orig_vec: Vec, - insertion_index: usize, - element_to_insert: FF, - ) -> Vec { - let mut v = vec![F::ZERO; 2 * (vec_size + 1) * (D + 1) + 1]; - v[0] = F::from_canonical_usize(insertion_index as usize); - for i in 0..D { - v[1 + i] = >::to_basefield_array(&element_to_insert)[i]; - } + /// Returns the local wires for an insertion gate for given the original vector, element to + /// insert, and index. + fn get_wires(orig_vec: Vec, insertion_index: usize, element_to_insert: FF) -> Vec { + let vec_size = orig_vec.len(); + + let mut v = Vec::new(); + v.push(F::from_canonical_usize(insertion_index)); + v.extend(element_to_insert.0); for j in 0..vec_size { - for i in 0..D { - v[(j + 1) * D + 1 + i] = - >::to_basefield_array(&orig_vec[j])[i]; - } + v.extend(orig_vec[j].0); } let mut new_vec = orig_vec.clone(); new_vec.insert(insertion_index, element_to_insert); let mut equality_dummy_vals = Vec::new(); - for i in 0..vec_size + 1 { - if i != insertion_index { - let diff = if i > insertion_index { - F::from_canonical_usize(i - insertion_index) - } else { - F::ZERO - F::from_canonical_usize(insertion_index - i) - }; - equality_dummy_vals.push(diff.inverse()); + for i in 0..=vec_size { + equality_dummy_vals.push(if i == insertion_index { + F::ONE } else { - equality_dummy_vals.push(F::ONE); - } + (F::from_canonical_usize(i) - F::from_canonical_usize(insertion_index)) + .inverse() + }); } let mut insert_here_vals = vec![F::ZERO; vec_size]; insert_here_vals.insert(insertion_index, F::ONE); - for j in 0..vec_size + 1 { - for i in 0..D { - v[(vec_size + j + 1) * D + 1 + i] = - >::to_basefield_array(&new_vec[j])[i]; - } - v[(2 * vec_size + 2) * D + 1 + j] = equality_dummy_vals[j]; - v[(2 * vec_size + 2) * D + 1 + (vec_size + 1) + j] = insert_here_vals[j]; + for j in 0..=vec_size { + v.extend(new_vec[j].0); } + v.extend(equality_dummy_vals); + v.extend(insert_here_vals); v.iter().map(|&x| x.into()).collect::>() } @@ -343,7 +327,7 @@ mod tests { }; let vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(3, orig_vec, insertion_index, element_to_insert), + local_wires: &get_wires(orig_vec, insertion_index, element_to_insert), }; assert!(