many fixes

This commit is contained in:
Nicholas Ward 2021-08-27 14:34:53 -07:00
parent a1d5f5b6fe
commit fe843db57f

View File

@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::convert::TryInto;
use std::marker::PhantomData;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gates::switch::SwitchGate;
@ -64,13 +65,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.extend(vec![None; CHUNK_SIZE - self.current_switch_gates.len()]);
}
let (gate, gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] {
let (gate_index, mut next_copy) = match self.current_switch_gates[CHUNK_SIZE - 1] {
None => {
let gate = SwitchGate::<F, D, CHUNK_SIZE>::new_from_config(self.config.clone());
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index, 0)
(gate_index, 0)
}
Some((idx, next_copy)) => (self.gate_instances[idx], idx, next_copy),
Some((idx, next_copy)) => (idx, next_copy),
};
let num_copies =
@ -105,7 +106,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let switch = Target::wire(
gate_index,
SwitchGate::<F, D, CHUNK_SIZE>::wire_switch_bool(gate.num_copies, next_copy),
SwitchGate::<F, D, CHUNK_SIZE>::wire_switch_bool(num_copies, next_copy),
);
let c_arr: [Target; CHUNK_SIZE] = c.try_into().unwrap();
@ -144,13 +145,17 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a_num_switches
};
let mut a_switches = Vec::new();
let mut b_switches = Vec::new();
for i in 0..a_num_switches {
let (a_switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]);
let (switch, out_1, out_2) = self.create_switch(a[i * 2], a[i * 2 + 1]);
a_switches.push(switch);
child_1_a.push(out_1);
child_2_a.push(out_2);
}
for i in 0..b_num_switches {
let (b_switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]);
let (switch, out_1, out_2) = self.create_switch(b[i * 2], b[i * 2 + 1]);
b_switches.push(switch);
child_1_b.push(out_1);
child_2_b.push(out_2);
}
@ -167,15 +172,21 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.assert_permutation(child_1_a, child_1_b);
self.assert_permutation(child_2_a, child_2_b);
self.add_generator(PermutationGenerator {});
self.add_generator(PermutationGenerator::<F, CHUNK_SIZE> {
a,
b,
a_switches,
b_switches,
_phantom: PhantomData,
});
}
}
fn route<F: Field, const CHUNK_SIZE: usize>(
a_values: Vec<[F; CHUNK_SIZE]>,
b_values: Vec<[F; CHUNK_SIZE]>,
a_switches: Vec<[Target; CHUNK_SIZE]>,
b_switches: Vec<[Target; CHUNK_SIZE]>,
a_switches: Vec<Target>,
b_switches: Vec<Target>,
witness: &PartialWitness<F>,
out_buffer: &mut GeneratedValues<F>,
) {
@ -219,7 +230,7 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
return;
}
if witness.contains_all(&switches[other_side][other_switch_i]) {
if witness.contains(switches[other_side][other_switch_i]) {
// The other switch has already been routed.
return;
}
@ -252,9 +263,7 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
switch_index: usize,
swap: bool| {
// First, we actually set the switch configuration.
for e in 0..CHUNK_SIZE {
out_buffer.set_target(switches[side][switch_index][e], F::from_bool(swap));
}
out_buffer.set_target(switches[side][switch_index], F::from_bool(swap));
// Then, we enqueue the two corresponding wires on the other side of the network, to ensure
// that they get routed in the next step.
@ -292,7 +301,7 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
} else {
// We can route any switch next. Continue our scan for pending switches.
while scan_index[side] < switches[side].len()
&& witness.contains_all(&switches[side][scan_index[side]])
&& witness.contains(switches[side][scan_index[side]])
{
scan_index[side] += 1;
}
@ -313,21 +322,22 @@ fn route<F: Field, const CHUNK_SIZE: usize>(
}
struct PermutationGenerator<F: Field, const CHUNK_SIZE: usize> {
a_wires: Vec<[Target; CHUNK_SIZE]>,
b_wires: Vec<[Target; CHUNK_SIZE]>,
a: Vec<[Target; CHUNK_SIZE]>,
b: Vec<[Target; CHUNK_SIZE]>,
a_switches: Vec<Target>,
b_switches: Vec<Target>,
_phantom: PhantomData<F>,
}
impl<F: Field, const CHUNK_SIZE: usize> SimpleGenerator<F> for PermutationGenerator<F, CHUNK_SIZE> {
fn dependencies(&self) -> Vec<Target> {
self.a_wires
.iter()
.map(|arr| arr.to_vec())
.flatten()
.collect()
let mut deps = self.a_switches.clone();
deps.extend(self.b_switches.clone());
deps
}
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let wire_chunk_to_vals = |wire| {
let wire_chunk_to_vals = |wire: [Target; CHUNK_SIZE]| {
let mut vals = [F::ZERO; CHUNK_SIZE];
for e in 0..CHUNK_SIZE {
vals[e] = witness.get_target(wire[e]);
@ -335,13 +345,13 @@ impl<F: Field, const CHUNK_SIZE: usize> SimpleGenerator<F> for PermutationGenera
vals
};
let a_values = self.a_wires.iter().map(wire_chunk_to_vals).collect();
let b_values = self.b_wires.iter().map(wire_chunk_to_vals).collect();
let a_values = self.a.iter().map(|chunk| wire_chunk_to_vals(*chunk)).collect();
let b_values = self.b.iter().map(|chunk| wire_chunk_to_vals(*chunk)).collect();
route(
a_values.clone(),
b_values.clone(),
self.a_wires.clone(),
self.b_wires.clone(),
a_values,
b_values,
self.a_switches.clone(),
self.b_switches.clone(),
witness,
out_buffer,
);