This commit is contained in:
Daniel Lubarov 2021-05-19 15:57:28 -07:00
parent 0ce1a4c5eb
commit 3311981fc4
3 changed files with 65 additions and 35 deletions

View File

@ -96,24 +96,17 @@ impl<F: Field + Extendable<D>, const D: usize> QuarticInterpolationGate<F, D> {
impl<F: Field + Extendable<D>, const D: usize> Gate<F> for QuarticInterpolationGate<F, D> { impl<F: Field + Extendable<D>, const D: usize> Gate<F> for QuarticInterpolationGate<F, D> {
fn id(&self) -> String { fn id(&self) -> String {
let qfe_name = std::any::type_name::<F::Extension>(); format!("{:?}<D={}>", self, D)
format!("{} {:?}", qfe_name, self)
} }
fn eval_unfiltered(&self, vars: EvaluationVars<F>) -> Vec<F> { fn eval_unfiltered(&self, vars: EvaluationVars<F>) -> Vec<F> {
let lookup_fe = |wire_range: Range<usize>| {
debug_assert_eq!(wire_range.len(), D);
let arr = vars.local_wires[wire_range].try_into().unwrap();
F::Extension::from_basefield_array(arr)
};
let mut constraints = Vec::with_capacity(self.num_constraints()); let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points) let coeffs = (0..self.num_points)
.map(|i| lookup_fe(self.wires_coeff(i))) .map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect(); .collect();
let interpolant = PolynomialCoeffs::new(coeffs); let interpolant = PolynomialCoeffs::new(coeffs);
let x_eval = lookup_fe(self.wires_evaluation_point()); let x_eval = vars.get_local_ext(self.wires_evaluation_point());
let x_eval_powers = x_eval.powers().take(self.num_points); let x_eval_powers = x_eval.powers().take(self.num_points);
// TODO // TODO
@ -169,23 +162,39 @@ impl<F: Field + Extendable<D>, const D: usize> SimpleGenerator<F>
for QuarticInterpolationGenerator<F, D> for QuarticInterpolationGenerator<F, D>
{ {
fn dependencies(&self) -> Vec<Target> { fn dependencies(&self) -> Vec<Target> {
todo!() let local_target = |input| {
Target::Wire(Wire {
gate: self.gate_index,
input,
})
};
let local_targets = |inputs: Range<usize>| inputs.map(|i| local_target(i));
let mut deps = Vec::new();
deps.extend(local_targets(self.gate.wires_evaluation_point()));
deps.extend(local_targets(self.gate.wires_evaluation_value()));
for i in 0..self.gate.num_points {
deps.push(local_target(self.gate.wire_point(i)));
deps.extend(local_targets(self.gate.wires_value(i)));
deps.extend(local_targets(self.gate.wires_coeff(i)));
}
deps
} }
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> { fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let n = self.gate.num_points; let n = self.gate.num_points;
let local_wire = |input| { let local_wire = |input| Wire {
Wire { gate: self.gate_index, input } gate: self.gate_index,
input,
}; };
let lookup_fe = |wire_range: Range<usize>| { let get_local_wire = |input| witness.get_wire(local_wire(input));
let get_local_ext = |wire_range: Range<usize>| {
debug_assert_eq!(wire_range.len(), D); debug_assert_eq!(wire_range.len(), D);
let values = wire_range let values = wire_range.map(get_local_wire).collect::<Vec<_>>();
.map(|input| {
witness.get_wire(local_wire(input))
})
.collect::<Vec<_>>();
let arr = values.try_into().unwrap(); let arr = values.try_into().unwrap();
F::Extension::from_basefield_array(arr) F::Extension::from_basefield_array(arr)
}; };
@ -194,11 +203,8 @@ impl<F: Field + Extendable<D>, const D: usize> SimpleGenerator<F>
let points = (0..n) let points = (0..n)
.map(|i| { .map(|i| {
( (
F::Extension::from_basefield(witness.get_wire(Wire { F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))),
gate: self.gate_index, get_local_ext(self.gate.wires_value(i)),
input: self.gate.wire_point(i),
})),
lookup_fe(self.gate.wires_value(i)),
) )
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -206,12 +212,16 @@ impl<F: Field + Extendable<D>, const D: usize> SimpleGenerator<F>
let mut result = PartialWitness::<F>::new(); let mut result = PartialWitness::<F>::new();
for (i, &coeff) in interpolant.coeffs.iter().enumerate() { for (i, &coeff) in interpolant.coeffs.iter().enumerate() {
let wire_range = self.gate.wires_coeff(i); let wires = self.gate.wires_coeff(i).map(local_wire);
let wires = wire_range.map(|i| local_wire(i)).collect::<Vec<_>>(); result.set_ext_wires(wires, coeff);
result.set_ext_wires(&wires, coeff);
} }
todo!() let evaluation_point = get_local_ext(self.gate.wires_evaluation_point());
let evaluation_value = interpolant.eval(evaluation_point);
let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire);
result.set_ext_wires(evaluation_value_wires, evaluation_value);
result
} }
} }

View File

@ -1,3 +1,7 @@
use std::convert::TryInto;
use std::ops::Range;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field; use crate::field::field::Field;
use crate::target::Target; use crate::target::Target;
@ -7,6 +11,17 @@ pub struct EvaluationVars<'a, F: Field> {
pub(crate) local_wires: &'a [F], pub(crate) local_wires: &'a [F],
} }
impl<'a, F: Field> EvaluationVars<'a, F> {
pub fn get_local_ext<const D: usize>(&self, wire_range: Range<usize>) -> F::Extension
where
F: Extendable<D>,
{
debug_assert_eq!(wire_range.len(), D);
let arr = self.local_wires[wire_range].try_into().unwrap();
F::Extension::from_basefield_array(arr)
}
}
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct EvaluationTargets<'a> { pub struct EvaluationTargets<'a> {
pub(crate) local_constants: &'a [Target], pub(crate) local_constants: &'a [Target],

View File

@ -1,9 +1,9 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field; use crate::field::field::Field;
use crate::target::Target; use crate::target::Target;
use crate::wire::Wire; use crate::wire::Wire;
use crate::field::extension_field::{Extendable, FieldExtension};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PartialWitness<F: Field> { pub struct PartialWitness<F: Field> {
@ -74,16 +74,21 @@ impl<F: Field> PartialWitness<F> {
self.set_target(Target::Wire(wire), value) self.set_target(Target::Wire(wire), value)
} }
pub fn set_wires(&mut self, wires: &[Wire], values: &[F]) { pub fn set_wires<W>(&mut self, wires: W, values: &[F])
debug_assert_eq!(wires.len(), values.len()); where
for (&wire, &value) in wires.iter().zip(values) { W: IntoIterator<Item = Wire>,
{
// If we used itertools, we could use zip_eq for extra safety.
for (wire, &value) in wires.into_iter().zip(values) {
self.set_wire(wire, value); self.set_wire(wire, value);
} }
} }
pub fn set_ext_wires<const D: usize>(&mut self, wires: &[Wire], value: F::Extension) pub fn set_ext_wires<W, const D: usize>(&mut self, wires: W, value: F::Extension)
where F: Extendable<D> { where
debug_assert_eq!(wires.len(), D); F: Extendable<D>,
W: IntoIterator<Item = Wire>,
{
self.set_wires(wires, &value.to_basefield_array()); self.set_wires(wires, &value.to_basefield_array());
} }