witnessgenerator

This commit is contained in:
Nicholas Ward 2021-09-02 11:54:20 -07:00
parent 10d016a92c
commit d1fea5cfd3
2 changed files with 68 additions and 5 deletions

View File

@ -12,6 +12,7 @@ edition = "2018"
default-run = "bench_recursion"
[dependencies]
array_tool = "1.0.3"
bimap = "0.4.0"
env_logger = "0.9.0"
log = "0.4.14"

View File

@ -1,10 +1,12 @@
use std::marker::PhantomData;
use array_tool::vec::Union;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::generator::{GeneratedValues, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
@ -198,8 +200,22 @@ struct SwitchGenerator<F: Extendable<D>, const D: usize> {
copy: usize,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for SwitchGenerator<F, D> {
fn dependencies(&self) -> Vec<Target> {
impl<F: Extendable<D>, const D: usize> SwitchGenerator<F, D> {
fn in_out_dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new();
for e in 0..self.gate.chunk_size {
deps.push(local_target(self.gate.wire_first_input(self.copy, e)));
deps.push(local_target(self.gate.wire_second_input(self.copy, e)));
deps.push(local_target(self.gate.wire_first_output(self.copy, e)));
deps.push(local_target(self.gate.wire_second_output(self.copy, e)));
}
deps
}
fn in_switch_dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new();
@ -212,7 +228,32 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for SwitchGenerator<F,
deps
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
fn run_in_out(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
for e in 0..self.gate.chunk_size {
let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy));
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e));
let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e));
if first_output == first_input && second_output == second_input {
out_buffer.set_wire(switch_bool_wire, F::ZERO);
} else if first_output == second_input && second_output == first_input {
out_buffer.set_wire(switch_bool_wire, F::ONE);
} else {
panic!("No permutation from given inputs to given outputs");
}
}
}
fn run_in_switch(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
@ -229,8 +270,10 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for SwitchGenerator<F,
let (first_output, second_output) = if switch_bool == F::ZERO {
(first_input, second_input)
} else {
} else if switch_bool == F::ONE {
(second_input, first_input)
} else {
panic!("Invalid switch bool value");
};
out_buffer.set_wire(first_output_wire, first_output);
@ -239,6 +282,25 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for SwitchGenerator<F,
}
}
impl<F: Extendable<D>, const D: usize> WitnessGenerator<F> for SwitchGenerator<F, D> {
fn watch_list(&self) -> Vec<Target> {
self.in_out_dependencies()
.union(self.in_switch_dependencies())
}
fn run(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) -> bool {
if witness.contains_all(&self.in_out_dependencies()) {
self.run_in_out(witness, out_buffer);
true
} else if witness.contains_all(&self.in_switch_dependencies()) {
self.run_in_switch(witness, out_buffer);
true
} else {
false
}
}
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;