From 3311981fc4de5a3d274d8bda4261d7aaea3aa274 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 19 May 2021 15:57:28 -0700 Subject: [PATCH] Minor --- src/gates/interpolation_quartic.rs | 66 +++++++++++++++++------------- src/vars.rs | 15 +++++++ src/witness.rs | 19 +++++---- 3 files changed, 65 insertions(+), 35 deletions(-) diff --git a/src/gates/interpolation_quartic.rs b/src/gates/interpolation_quartic.rs index 0dcf22eb..6b37a48f 100644 --- a/src/gates/interpolation_quartic.rs +++ b/src/gates/interpolation_quartic.rs @@ -96,24 +96,17 @@ impl, const D: usize> QuarticInterpolationGate { impl, const D: usize> Gate for QuarticInterpolationGate { fn id(&self) -> String { - let qfe_name = std::any::type_name::(); - format!("{} {:?}", qfe_name, self) + format!("{:?}", self, D) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let lookup_fe = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let arr = vars.local_wires[wire_range].try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - let mut constraints = Vec::with_capacity(self.num_constraints()); let coeffs = (0..self.num_points) - .map(|i| lookup_fe(self.wires_coeff(i))) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffs::new(coeffs); - let x_eval = lookup_fe(self.wires_evaluation_point()); + let x_eval = vars.get_local_ext(self.wires_evaluation_point()); let x_eval_powers = x_eval.powers().take(self.num_points); // TODO @@ -169,23 +162,39 @@ impl, const D: usize> SimpleGenerator for QuarticInterpolationGenerator { fn dependencies(&self) -> Vec { - todo!() + let local_target = |input| { + Target::Wire(Wire { + gate: self.gate_index, + input, + }) + }; + + let local_targets = |inputs: Range| inputs.map(|i| local_target(i)); + + let mut deps = Vec::new(); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + deps.extend(local_targets(self.gate.wires_evaluation_value())); + for i in 0..self.gate.num_points { + deps.push(local_target(self.gate.wire_point(i))); + deps.extend(local_targets(self.gate.wires_value(i))); + deps.extend(local_targets(self.gate.wires_coeff(i))); + } + deps } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { let n = self.gate.num_points; - let local_wire = |input| { - Wire { gate: self.gate_index, input } + let local_wire = |input| Wire { + gate: self.gate_index, + input, }; - let lookup_fe = |wire_range: Range| { + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let get_local_ext = |wire_range: Range| { debug_assert_eq!(wire_range.len(), D); - let values = wire_range - .map(|input| { - witness.get_wire(local_wire(input)) - }) - .collect::>(); + let values = wire_range.map(get_local_wire).collect::>(); let arr = values.try_into().unwrap(); F::Extension::from_basefield_array(arr) }; @@ -194,11 +203,8 @@ impl, const D: usize> SimpleGenerator let points = (0..n) .map(|i| { ( - F::Extension::from_basefield(witness.get_wire(Wire { - gate: self.gate_index, - input: self.gate.wire_point(i), - })), - lookup_fe(self.gate.wires_value(i)), + F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))), + get_local_ext(self.gate.wires_value(i)), ) }) .collect::>(); @@ -206,12 +212,16 @@ impl, const D: usize> SimpleGenerator let mut result = PartialWitness::::new(); for (i, &coeff) in interpolant.coeffs.iter().enumerate() { - let wire_range = self.gate.wires_coeff(i); - let wires = wire_range.map(|i| local_wire(i)).collect::>(); - result.set_ext_wires(&wires, coeff); + let wires = self.gate.wires_coeff(i).map(local_wire); + result.set_ext_wires(wires, coeff); } - todo!() + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + let evaluation_value = interpolant.eval(evaluation_point); + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + result.set_ext_wires(evaluation_value_wires, evaluation_value); + + result } } diff --git a/src/vars.rs b/src/vars.rs index f2744e8f..7c0afab9 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -1,3 +1,7 @@ +use std::convert::TryInto; +use std::ops::Range; + +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::target::Target; @@ -7,6 +11,17 @@ pub struct EvaluationVars<'a, F: Field> { pub(crate) local_wires: &'a [F], } +impl<'a, F: Field> EvaluationVars<'a, F> { + pub fn get_local_ext(&self, wire_range: Range) -> F::Extension + where + F: Extendable, + { + debug_assert_eq!(wire_range.len(), D); + let arr = self.local_wires[wire_range].try_into().unwrap(); + F::Extension::from_basefield_array(arr) + } +} + #[derive(Copy, Clone)] pub struct EvaluationTargets<'a> { pub(crate) local_constants: &'a [Target], diff --git a/src/witness.rs b/src/witness.rs index e0afb77c..a0b4b2a4 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::target::Target; use crate::wire::Wire; -use crate::field::extension_field::{Extendable, FieldExtension}; #[derive(Clone, Debug)] pub struct PartialWitness { @@ -74,16 +74,21 @@ impl PartialWitness { self.set_target(Target::Wire(wire), value) } - pub fn set_wires(&mut self, wires: &[Wire], values: &[F]) { - debug_assert_eq!(wires.len(), values.len()); - for (&wire, &value) in wires.iter().zip(values) { + pub fn set_wires(&mut self, wires: W, values: &[F]) + where + W: IntoIterator, + { + // If we used itertools, we could use zip_eq for extra safety. + for (wire, &value) in wires.into_iter().zip(values) { self.set_wire(wire, value); } } - pub fn set_ext_wires(&mut self, wires: &[Wire], value: F::Extension) - where F: Extendable { - debug_assert_eq!(wires.len(), D); + pub fn set_ext_wires(&mut self, wires: W, value: F::Extension) + where + F: Extendable, + W: IntoIterator, + { self.set_wires(wires, &value.to_basefield_array()); }