many fixes

This commit is contained in:
Nicholas Ward 2021-08-27 14:34:53 -07:00
parent a1d5f5b6fe
commit fe843db57f

View File

@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::convert::TryInto; use std::convert::TryInto;
use std::marker::PhantomData;
use crate::field::{extension_field::Extendable, field_types::Field}; use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gates::switch::SwitchGate; use crate::gates::switch::SwitchGate;
@ -64,13 +65,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.extend(vec![None; CHUNK_SIZE - self.current_switch_gates.len()]); .extend(vec![None; CHUNK_SIZE - self.current_switch_gates.len()]);
} }
let (gate, gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] { let (gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] {
None => { None => {
let gate = SwitchGate::<F, D, CHUNK_SIZE>::new_from_config(self.config.clone()); let gate = SwitchGate::<F, D, CHUNK_SIZE>::new_from_config(self.config.clone());
let gate_index = self.add_gate(gate.clone(), vec![]); let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index, 0) (gate_index, 0)
} }
Some((idx, next_copy)) => (self.gate_instances[idx], idx, next_copy), Some((idx, next_copy)) => (idx, next_copy),
}; };
let num_copies = let num_copies =
@ -105,7 +106,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let switch = Target::wire( let switch = Target::wire(
gate_index, gate_index,
SwitchGate::<F, D, CHUNK_SIZE>::wire_switch_bool(gate.num_copies, next_copy), SwitchGate::<F, D, CHUNK_SIZE>::wire_switch_bool(num_copies, next_copy),
); );
let c_arr: [Target; CHUNK_SIZE] = c.try_into().unwrap(); let c_arr: [Target; CHUNK_SIZE] = c.try_into().unwrap();
@ -144,13 +145,17 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a_num_switches a_num_switches
}; };
let mut a_switches = Vec::new();
let mut b_switches = Vec::new();
for i in 0..a_num_switches { for i in 0..a_num_switches {
let (a_switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]); let (switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]);
a_switches.push(switch);
child_1_a.push(out_1); child_1_a.push(out_1);
child_2_a.push(out_2); child_2_a.push(out_2);
} }
for i in 0..b_num_switches { for i in 0..b_num_switches {
let (b_switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]); let (switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]);
b_switches.push(switch);
child_1_b.push(out_1); child_1_b.push(out_1);
child_2_b.push(out_2); child_2_b.push(out_2);
} }
@ -167,15 +172,21 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.assert_permutation(child_1_a, child_1_b); self.assert_permutation(child_1_a, child_1_b);
self.assert_permutation(child_2_a, child_2_b); self.assert_permutation(child_2_a, child_2_b);
self.add_generator(PermutationGenerator {}); self.add_generator(PermutationGenerator::<F, CHUNK_SIZE> {
a,
b,
a_switches,
b_switches,
_phantom: PhantomData,
});
} }
} }
fn route<F: Field, const CHUNK_SIZE: usize>( fn route<F: Field, const CHUNK_SIZE: usize>(
a_values: Vec<[F; CHUNK_SIZE]>, a_values: Vec<[F; CHUNK_SIZE]>,
b_values: Vec<[F; CHUNK_SIZE]>, b_values: Vec<[F; CHUNK_SIZE]>,
a_switches: Vec<[Target; CHUNK_SIZE]>, a_switches: Vec<Target>,
b_switches: Vec<[Target; CHUNK_SIZE]>, b_switches: Vec<Target>,
witness: &PartialWitness<F>, witness: &PartialWitness<F>,
out_buffer: &mut GeneratedValues<F>, out_buffer: &mut GeneratedValues<F>,
) { ) {
@ -219,7 +230,7 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
return; return;
} }
if witness.contains_all(&switches[other_side][other_switch_i]) { if witness.contains(switches[other_side][other_switch_i]) {
// The other switch has already been routed. // The other switch has already been routed.
return; return;
} }
@ -252,9 +263,7 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
switch_index: usize, switch_index: usize,
swap: bool| { swap: bool| {
// First, we actually set the switch configuration. // First, we actually set the switch configuration.
for e in 0..CHUNK_SIZE { out_buffer.set_target(switches[side][switch_index], F::from_bool(swap));
out_buffer.set_target(switches[side][switch_index][e], F::from_bool(swap));
}
// Then, we enqueue the two corresponding wires on the other side of the network, to ensure // Then, we enqueue the two corresponding wires on the other side of the network, to ensure
// that they get routed in the next step. // that they get routed in the next step.
@ -292,7 +301,7 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
} else { } else {
// We can route any switch next. Continue our scan for pending switches. // We can route any switch next. Continue our scan for pending switches.
while scan_index[side] < switches[side].len() while scan_index[side] < switches[side].len()
&& witness.contains_all(&switches[side][scan_index[side]]) && witness.contains(switches[side][scan_index[side]])
{ {
scan_index[side] += 1; scan_index[side] += 1;
} }
@ -313,21 +322,22 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
} }
struct PermutationGenerator<F: Field, const CHUNK_SIZE: usize> { struct PermutationGenerator<F: Field, const CHUNK_SIZE: usize> {
a_wires: Vec<[Target; CHUNK_SIZE]>, a: Vec<[Target; CHUNK_SIZE]>,
b_wires: Vec<[Target; CHUNK_SIZE]>, b: Vec<[Target; CHUNK_SIZE]>,
a_switches: Vec<Target>,
b_switches: Vec<Target>,
_phantom: PhantomData<F>,
} }
impl<F: Field, const CHUNK_SIZE: usize> SimpleGenerator<F> for PermutationGenerator<F, CHUNK_SIZE> { impl<F: Field, const CHUNK_SIZE: usize> SimpleGenerator<F> for PermutationGenerator<F, CHUNK_SIZE> {
fn dependencies(&self) -> Vec<Target> { fn dependencies(&self) -> Vec<Target> {
self.a_wires let mut deps = self.a_switches.clone();
.iter() deps.extend(self.b_switches.clone());
.map(|arr| arr.to_vec()) deps
.flatten()
.collect()
} }
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let wire_chunk_to_vals = |wire| { let wire_chunk_to_vals = |wire: [Target; CHUNK_SIZE]| {
let mut vals = [F::ZERO; CHUNK_SIZE]; let mut vals = [F::ZERO; CHUNK_SIZE];
for e in 0..CHUNK_SIZE { for e in 0..CHUNK_SIZE {
vals[e] = witness.get_target(wire[e]); vals[e] = witness.get_target(wire[e]);
@ -335,13 +345,13 @@ impl<F: Field, const CHUNK_SIZE: usize> SimpleGenerator<F> for PermutationGenera
vals vals
}; };
let a_values = self.a_wires.iter().map(wire_chunk_to_vals).collect(); let a_values = self.a.iter().map(|chunk| wire_chunk_to_vals(*chunk)).collect();
let b_values = self.b_wires.iter().map(wire_chunk_to_vals).collect(); let b_values = self.b.iter().map(|chunk| wire_chunk_to_vals(*chunk)).collect();
route( route(
a_values.clone(), a_values,
b_values.clone(), b_values,
self.a_wires.clone(), self.a_switches.clone(),
self.b_wires.clone(), self.b_switches.clone(),
witness, witness,
out_buffer, out_buffer,
); );