This commit is contained in:
Daniel Lubarov 2021-05-19 15:57:28 -07:00
parent 0ce1a4c5eb
commit 3311981fc4
3 changed files with 65 additions and 35 deletions

View File

@ -96,24 +96,17 @@ impl<F: Field + Extendable<D>, const D: usize> QuarticInterpolationGate<F, D> {
impl<F: Field + Extendable<D>, const D: usize> Gate<F> for QuarticInterpolationGate<F, D> {
fn id(&self) -> String {
let qfe_name = std::any::type_name::<F::Extension>();
format!("{} {:?}", qfe_name, self)
format!("{:?}<D={}>", self, D)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F>) -> Vec<F> {
let lookup_fe = |wire_range: Range<usize>| {
debug_assert_eq!(wire_range.len(), D);
let arr = vars.local_wires[wire_range].try_into().unwrap();
F::Extension::from_basefield_array(arr)
};
let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points)
.map(|i| lookup_fe(self.wires_coeff(i)))
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect();
let interpolant = PolynomialCoeffs::new(coeffs);
let x_eval = lookup_fe(self.wires_evaluation_point());
let x_eval = vars.get_local_ext(self.wires_evaluation_point());
let x_eval_powers = x_eval.powers().take(self.num_points);
// TODO
@ -169,23 +162,39 @@ impl<F: Field + Extendable<D>, const D: usize> SimpleGenerator<F>
for QuarticInterpolationGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
todo!()
let local_target = |input| {
Target::Wire(Wire {
gate: self.gate_index,
input,
})
};
let local_targets = |inputs: Range<usize>| inputs.map(|i| local_target(i));
let mut deps = Vec::new();
deps.extend(local_targets(self.gate.wires_evaluation_point()));
deps.extend(local_targets(self.gate.wires_evaluation_value()));
for i in 0..self.gate.num_points {
deps.push(local_target(self.gate.wire_point(i)));
deps.extend(local_targets(self.gate.wires_value(i)));
deps.extend(local_targets(self.gate.wires_coeff(i)));
}
deps
}
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let n = self.gate.num_points;
let local_wire = |input| {
Wire { gate: self.gate_index, input }
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let lookup_fe = |wire_range: Range<usize>| {
let get_local_wire = |input| witness.get_wire(local_wire(input));
let get_local_ext = |wire_range: Range<usize>| {
debug_assert_eq!(wire_range.len(), D);
let values = wire_range
.map(|input| {
witness.get_wire(local_wire(input))
})
.collect::<Vec<_>>();
let values = wire_range.map(get_local_wire).collect::<Vec<_>>();
let arr = values.try_into().unwrap();
F::Extension::from_basefield_array(arr)
};
@ -194,11 +203,8 @@ impl<F: Field + Extendable<D>, const D: usize> SimpleGenerator<F>
let points = (0..n)
.map(|i| {
(
F::Extension::from_basefield(witness.get_wire(Wire {
gate: self.gate_index,
input: self.gate.wire_point(i),
})),
lookup_fe(self.gate.wires_value(i)),
F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))),
get_local_ext(self.gate.wires_value(i)),
)
})
.collect::<Vec<_>>();
@ -206,12 +212,16 @@ impl<F: Field + Extendable<D>, const D: usize> SimpleGenerator<F>
let mut result = PartialWitness::<F>::new();
for (i, &coeff) in interpolant.coeffs.iter().enumerate() {
let wire_range = self.gate.wires_coeff(i);
let wires = wire_range.map(|i| local_wire(i)).collect::<Vec<_>>();
result.set_ext_wires(&wires, coeff);
let wires = self.gate.wires_coeff(i).map(local_wire);
result.set_ext_wires(wires, coeff);
}
todo!()
let evaluation_point = get_local_ext(self.gate.wires_evaluation_point());
let evaluation_value = interpolant.eval(evaluation_point);
let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire);
result.set_ext_wires(evaluation_value_wires, evaluation_value);
result
}
}

View File

@ -1,3 +1,7 @@
use std::convert::TryInto;
use std::ops::Range;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::target::Target;
@ -7,6 +11,17 @@ pub struct EvaluationVars<'a, F: Field> {
pub(crate) local_wires: &'a [F],
}
impl<'a, F: Field> EvaluationVars<'a, F> {
pub fn get_local_ext<const D: usize>(&self, wire_range: Range<usize>) -> F::Extension
where
F: Extendable<D>,
{
debug_assert_eq!(wire_range.len(), D);
let arr = self.local_wires[wire_range].try_into().unwrap();
F::Extension::from_basefield_array(arr)
}
}
#[derive(Copy, Clone)]
pub struct EvaluationTargets<'a> {
pub(crate) local_constants: &'a [Target],

View File

@ -1,9 +1,9 @@
use std::collections::HashMap;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::target::Target;
use crate::wire::Wire;
use crate::field::extension_field::{Extendable, FieldExtension};
#[derive(Clone, Debug)]
pub struct PartialWitness<F: Field> {
@ -74,16 +74,21 @@ impl<F: Field> PartialWitness<F> {
self.set_target(Target::Wire(wire), value)
}
pub fn set_wires(&mut self, wires: &[Wire], values: &[F]) {
debug_assert_eq!(wires.len(), values.len());
for (&wire, &value) in wires.iter().zip(values) {
pub fn set_wires<W>(&mut self, wires: W, values: &[F])
where
W: IntoIterator<Item = Wire>,
{
// If we used itertools, we could use zip_eq for extra safety.
for (wire, &value) in wires.into_iter().zip(values) {
self.set_wire(wire, value);
}
}
pub fn set_ext_wires<const D: usize>(&mut self, wires: &[Wire], value: F::Extension)
where F: Extendable<D> {
debug_assert_eq!(wires.len(), D);
pub fn set_ext_wires<W, const D: usize>(&mut self, wires: W, value: F::Extension)
where
F: Extendable<D>,
W: IntoIterator<Item = Wire>,
{
self.set_wires(wires, &value.to_basefield_array());
}