diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 20cd3dcb..971d8836 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -11,6 +11,20 @@ impl ExtensionTarget { pub fn to_target_array(&self) -> [Target; D] { self.0 } + + pub fn frobenius>(&self, builder: &mut CircuitBuilder) -> Self { + let arr = self.to_target_array(); + let k = (F::ORDER - 1) / (D as u64); + let z0 = builder.constant(F::Extension::W.exp(k)); + let mut z = builder.one(); + let mut res = [builder.zero(); D]; + for i in 0..D { + res[i] = builder.mul(arr[i], z); + z = builder.mul(z, z0); + } + + Self(res) + } } /// `Target`s representing an element of an extension of an extension field. diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index bd5559a3..2006d21f 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -179,22 +179,26 @@ impl, const D: usize> CircuitBuilder { .iter() .map(|&e| self.convert_to_ext(e)) .collect::>(); + // TODO: Would probably be more efficient using `CircuitBuilder::reduce_with_powers_recursive` let mut ev = self.zero_extension(); for &e in &evs { let a = alpha_powers.next(self); - let tmp = self.mul_extension(a, e); - ev = self.add_extension(ev, tmp); + ev = self.mul_add_extension(a, e, ev); } let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); let zeta_right = self.mul_extension(g, zeta); - let zs_interpol = self.interpolate2([ - (zeta, reduce_with_iter(&os.plonk_zs, alpha_powers.clone())), - ( - zeta_right, - reduce_with_iter(&os.plonk_zs_right, &mut alpha_powers), - ), - ]); + let mut ev_zeta = self.zero_extension(); + for &t in &os.plonk_zs { + let a = alpha_powers.next(self); + ev_zeta = self.mul_add_extension(a, t, ev_zeta); + } + let mut ev_zeta_right = self.zero_extension(); + for &t in &os.plonk_zs_right { + let a = alpha_powers.next(self); + ev_zeta_right = self.mul_add_extension(a, t, ev_zeta); + } + let zs_interpol = self.interpolate2([(zeta, ev_zeta), (zeta_right, ev_zeta_right)]); let interpol_val = zs_interpol.eval(self, subgroup_x); let numerator = self.sub_extension(ev, interpol_val); let vanish = self.sub_extension(subgroup_x, zeta); @@ -202,22 +206,40 @@ impl, const D: usize> CircuitBuilder { let denominator = self.mul_extension(vanish, vanish_right); let quotient = self.div_unsafe_extension(numerator, denominator); let sum = self.add_extension(sum, quotient); - // - // let ev: F::Extension = proof - // .unsalted_evals(2, config) - // .iter() - // .zip(alpha_powers.clone()) - // .map(|(&e, a)| a * e.into()) - // .sum(); - // let zeta_frob = zeta.frobenius(); - // let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::>(); - // let wires_interpol = interpolant(&[ - // (zeta, reduce_with_iter(&os.wires, alpha_powers.clone())), - // (zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)), - // ]); - // let numerator = ev - wires_interpol.eval(subgroup_x); - // let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob); - // sum += numerator / denominator; + + let evs = proof + .unsalted_evals(2, config) + .iter() + .map(|&e| self.convert_to_ext(e)) + .collect::>(); + let mut ev = self.zero_extension(); + for &e in &evs { + let a = alpha_powers.next(self); + ev = self.mul_add_extension(a, e, ev); + } + let zeta_frob = zeta.frobenius(self); + let wire_evals_frob = os + .wires + .iter() + .map(|e| e.frobenius(self)) + .collect::>(); + let mut ev_zeta = self.zero_extension(); + for &t in &os.wires { + let a = alpha_powers.next(self); + ev_zeta = self.mul_add_extension(a, t, ev_zeta); + } + let mut ev_zeta_frob = self.zero_extension(); + for &t in &wire_evals_frob { + let a = alpha_powers.next(self); + ev_zeta_right = self.mul_add_extension(a, t, ev_zeta); + } + let wires_interpol = self.interpolate2([(zeta, ev_zeta), (zeta_frob, ev_zeta_frob)]); + let interpol_val = wires_interpol.eval(self, subgroup_x); + let numerator = self.sub_extension(ev, interpol_val); + let vanish_frob = self.sub_extension(subgroup_x, zeta_frob); + let denominator = self.mul_extension(vanish, vanish_frob); + let quotient = self.div_unsafe_extension(numerator, denominator); + let sum = self.add_extension(sum, quotient); sum } diff --git a/src/proof.rs b/src/proof.rs index 65634f4c..51e32938 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -187,5 +187,6 @@ pub struct OpeningSetTarget { pub plonk_sigmas: Vec>, pub wires: Vec>, pub plonk_zs: Vec>, + pub plonk_zs_right: Vec>, pub quotient_polys: Vec>, }