plonky2/src/gadgets/permutation.rs

520 lines
18 KiB
Rust
Raw Normal View History

2021-08-26 14:23:30 -07:00
use std::collections::BTreeMap;
2021-08-27 14:34:53 -07:00
use std::marker::PhantomData;
2021-08-23 14:53:32 -07:00
use crate::field::field_types::RichField;
use crate::field::{extension_field::Extendable, field_types::Field};
2021-08-23 12:13:31 -07:00
use crate::gates::switch::SwitchGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
2021-09-06 08:38:47 -07:00
use crate::iop::target::Target;
2021-09-02 15:03:03 -07:00
use crate::iop::witness::{PartitionWitness, Witness};
2021-08-23 12:13:31 -07:00
use crate::plonk::circuit_builder::CircuitBuilder;
2021-08-26 14:23:30 -07:00
use crate::util::bimap::bimap_from_lists;
2021-08-23 12:13:31 -07:00
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
2021-08-23 12:13:31 -07:00
/// Assert that two lists of expressions evaluate to permutations of one another.
2021-09-01 16:38:10 -07:00
pub fn assert_permutation(&mut self, a: Vec<Vec<Target>>, b: Vec<Vec<Target>>) {
2021-08-23 12:13:31 -07:00
assert_eq!(
a.len(),
b.len(),
"Permutation must have same number of inputs and outputs"
);
2021-09-01 16:38:10 -07:00
assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same");
let chunk_size = a[0].len();
2021-08-23 12:13:31 -07:00
match a.len() {
// Two empty lists are permutations of one another, trivially.
0 => (),
// Two singleton lists are permutations of one another as long as their items are equal.
1 => {
2021-09-01 16:38:10 -07:00
for e in 0..chunk_size {
2021-09-02 15:03:03 -07:00
self.connect(a[0][e], b[0][e])
2021-08-23 12:13:31 -07:00
}
}
2021-09-01 16:38:10 -07:00
2 => {
self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone())
}
2021-08-23 12:13:31 -07:00
// For larger lists, we recursively use two smaller permutation networks.
//_ => self.assert_permutation_recursive(a, b)
_ => self.assert_permutation_recursive(a, b),
}
}
2021-08-30 11:41:57 -07:00
/// Assert that [a1, a2] is a permutation of [b1, b2].
2021-09-01 16:38:10 -07:00
fn assert_permutation_2x2(
2021-08-23 12:13:31 -07:00
&mut self,
2021-09-01 16:38:10 -07:00
a1: Vec<Target>,
a2: Vec<Target>,
b1: Vec<Target>,
b2: Vec<Target>,
2021-08-23 12:13:31 -07:00
) {
2021-09-01 16:38:10 -07:00
assert!(
a1.len() == a2.len() && a2.len() == b1.len() && b1.len() == b2.len(),
"Chunk size must be the same"
);
let chunk_size = a1.len();
2021-09-03 16:43:33 -07:00
let (_switch, gate_out1, gate_out2) = self.create_switch(a1, a2);
2021-09-01 16:38:10 -07:00
for e in 0..chunk_size {
2021-09-02 15:03:03 -07:00
self.connect(b1[e], gate_out1[e]);
self.connect(b2[e], gate_out2[e]);
2021-08-23 14:03:34 -07:00
}
}
2021-09-06 08:38:47 -07:00
/// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch
/// gate). Returns the wire for the switch boolean, and the two output wire chunks.
2021-09-01 16:38:10 -07:00
fn create_switch(
2021-08-23 14:03:34 -07:00
&mut self,
2021-09-01 16:38:10 -07:00
a1: Vec<Target>,
a2: Vec<Target>,
) -> (Target, Vec<Target>, Vec<Target>) {
2021-09-06 08:38:47 -07:00
assert_eq!(a1.len(), a2.len(), "Chunk size must be the same");
2021-09-01 16:38:10 -07:00
let chunk_size = a1.len();
if self.current_switch_gates.len() < chunk_size {
2021-08-24 18:35:30 -07:00
self.current_switch_gates
2021-09-01 16:38:10 -07:00
.extend(vec![None; chunk_size - self.current_switch_gates.len()]);
2021-08-24 18:35:30 -07:00
}
2021-09-01 16:38:10 -07:00
let (gate, gate_index, mut next_copy) =
match self.current_switch_gates[chunk_size - 1].clone() {
None => {
let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
2021-09-01 16:38:10 -07:00
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index, 0)
}
Some((gate, idx, next_copy)) => (gate, idx, next_copy),
};
2021-08-24 18:35:30 -07:00
2021-09-01 16:38:10 -07:00
let num_copies = gate.num_copies;
2021-08-23 12:13:31 -07:00
2021-08-23 14:03:34 -07:00
let mut c = Vec::new();
let mut d = Vec::new();
2021-09-01 16:38:10 -07:00
for e in 0..chunk_size {
2021-09-02 15:03:03 -07:00
self.connect(
2021-08-30 11:41:57 -07:00
a1[e],
2021-09-01 16:38:10 -07:00
Target::wire(gate_index, gate.wire_first_input(next_copy, e)),
2021-08-24 18:35:30 -07:00
);
2021-09-02 15:03:03 -07:00
self.connect(
2021-08-30 11:41:57 -07:00
a2[e],
2021-09-01 16:38:10 -07:00
Target::wire(gate_index, gate.wire_second_input(next_copy, e)),
2021-08-24 18:35:30 -07:00
);
c.push(Target::wire(
gate_index,
2021-09-01 16:38:10 -07:00
gate.wire_first_output(next_copy, e),
2021-08-24 18:35:30 -07:00
));
d.push(Target::wire(
gate_index,
2021-09-01 16:38:10 -07:00
gate.wire_second_output(next_copy, e),
2021-08-24 18:35:30 -07:00
));
2021-08-23 12:13:31 -07:00
}
2021-08-23 14:03:34 -07:00
2021-09-01 16:38:10 -07:00
let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy));
2021-08-23 14:03:34 -07:00
2021-08-24 18:35:30 -07:00
next_copy += 1;
if next_copy == num_copies {
2021-09-06 08:38:47 -07:00
self.current_switch_gates[chunk_size - 1] = None;
2021-08-24 18:35:30 -07:00
} else {
2021-09-01 16:38:10 -07:00
self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy));
2021-08-24 18:35:30 -07:00
}
2021-09-01 16:38:10 -07:00
(switch, c, d)
2021-08-23 12:13:31 -07:00
}
2021-09-01 16:38:10 -07:00
fn assert_permutation_recursive(&mut self, a: Vec<Vec<Target>>, b: Vec<Vec<Target>>) {
assert_eq!(
a.len(),
b.len(),
"Permutation must have same number of inputs and outputs"
);
assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same");
2021-08-23 14:03:34 -07:00
let n = a.len();
let even = n % 2 == 0;
let mut child_1_a = Vec::new();
let mut child_1_b = Vec::new();
let mut child_2_a = Vec::new();
let mut child_2_b = Vec::new();
// See Figure 8 in the AS-Waksman paper.
let a_num_switches = n / 2;
2021-08-23 14:53:32 -07:00
let b_num_switches = if even {
a_num_switches - 1
} else {
a_num_switches
};
2021-08-23 14:03:34 -07:00
2021-08-27 14:34:53 -07:00
let mut a_switches = Vec::new();
let mut b_switches = Vec::new();
2021-08-23 14:03:34 -07:00
for i in 0..a_num_switches {
2021-09-01 16:38:10 -07:00
let (switch, out_1, out_2) = self.create_switch(a[i * 2].clone(), a[i * 2 + 1].clone());
2021-08-27 14:34:53 -07:00
a_switches.push(switch);
2021-08-23 14:53:32 -07:00
child_1_a.push(out_1);
child_2_a.push(out_2);
}
for i in 0..b_num_switches {
2021-09-01 16:38:10 -07:00
let (switch, out_1, out_2) = self.create_switch(b[i * 2].clone(), b[i * 2 + 1].clone());
2021-08-27 14:34:53 -07:00
b_switches.push(switch);
2021-08-23 14:53:32 -07:00
child_1_b.push(out_1);
child_2_b.push(out_2);
2021-08-23 14:03:34 -07:00
}
2021-08-23 14:53:32 -07:00
// See Figure 8 in the AS-Waksman paper.
if even {
child_1_b.push(b[n - 2].clone());
child_2_b.push(b[n - 1].clone());
} else {
child_2_a.push(a[n - 1].clone());
child_2_b.push(b[n - 1].clone());
}
self.assert_permutation(child_1_a, child_1_b);
self.assert_permutation(child_2_a, child_2_b);
2021-08-27 11:45:26 -07:00
2021-09-02 15:03:03 -07:00
self.add_simple_generator(PermutationGenerator::<F> {
2021-08-27 14:34:53 -07:00
a,
b,
a_switches,
b_switches,
_phantom: PhantomData,
});
2021-08-23 12:13:31 -07:00
}
}
2021-09-01 16:38:10 -07:00
/// Computes the switch settings for one level of the AS-Waksman network built by
/// `assert_permutation_recursive`, and writes them to `out_buffer`.
///
/// `a_values` / `b_values` are the chunk values on the left and right sides of this level, and
/// `a_switches` / `b_switches` are the switch-bool targets of the left and right switch layers.
/// Switch `i` on a side carries that side's wires `2 * i` and `2 * i + 1`. A switch whose bool
/// is already present in `witness` is treated as routed and is not written again.
///
/// Child networks are encoded as `bool`: `false` is the first child and `true` the second. A
/// switch value of `true` (a swap) sends the switch's first wire to the second child.
///
/// Panics if the two sides have different lengths, or if a routing conflict is detected. It also
/// panics (via `unwrap`) if the bimap built by `bimap_from_lists` has no entry for a looked-up
/// index; presumably that happens when the inputs are not a permutation (confirm against
/// `bimap_from_lists`).
fn route<F: Field>(
    a_values: Vec<Vec<F>>,
    b_values: Vec<Vec<F>>,
    a_switches: Vec<Target>,
    b_switches: Vec<Target>,
    witness: &PartitionWitness<F>,
    out_buffer: &mut GeneratedValues<F>,
) {
    assert_eq!(a_values.len(), b_values.len());
    let n = a_values.len();
    let even = n % 2 == 0;
    // We use a bimap to match indices of values in a to indices of the same values in b.
    // This means that given a wire on one side, we can easily find the matching wire on the other side.
    let ab_map = bimap_from_lists(a_values, b_values);

    // `switches[0]` is the left (a) switch layer and `switches[1]` the right (b) switch layer.
    let switches = [a_switches, b_switches];
    // We keep track of the switches we've routed in this call. After routing some switches, we
    // need to check both `witness` and `newly_set` instead of just `witness`. Entries are indexed by
    // switch index, so only the first `switches[side].len()` entries of each vector are used.
    let mut newly_set = [vec![false; n], vec![false; n]];
    // Given a side and an index, returns the index in the other side that corresponds to the same value.
    let ab_map_by_side = |side: usize, index: usize| -> usize {
        *match side {
            0 => ab_map.get_by_left(&index),
            1 => ab_map.get_by_right(&index),
            _ => panic!("Expected side to be 0 or 1"),
        }
        .unwrap()
    };
    // We maintain two maps for wires which have been routed to a particular subnetwork on one side
    // of the network (left or right) but not the other. The keys are wire indices, and the values
    // are subnetwork indices (`false` = first child, `true` = second child).
    let mut partial_routes = [BTreeMap::new(), BTreeMap::new()];
    // After we route a wire on one side, we find the corresponding wire on the other side and check
    // if it still needs to be routed. If so, we add it to partial_routes.
    let enqueue_other_side = |partial_routes: &mut [BTreeMap<usize, bool>],
                              witness: &PartitionWitness<F>,
                              newly_set: &mut [Vec<bool>],
                              side: usize,
                              this_i: usize,
                              subnet: bool| {
        let other_side = 1 - side;
        let other_i = ab_map_by_side(side, this_i);
        let other_switch_i = other_i / 2;
        if other_switch_i >= switches[other_side].len() {
            // The other wire doesn't go through a switch, so there's no routing to be done.
            // This happens for the last wire when n is odd, and for the last two right-side
            // wires when n is even.
            return;
        }
        if witness.contains(switches[other_side][other_switch_i])
            || newly_set[other_side][other_switch_i]
        {
            // The other switch has already been routed.
            return;
        }
        // Index of the other wire on the same switch: 2s maps to 2s + 1 and vice versa.
        let other_i_sibling = 4 * other_switch_i + 1 - other_i;
        if let Some(&sibling_subnet) = partial_routes[other_side].get(&other_i_sibling) {
            // The other switch's sibling is already pending routing.
            assert_ne!(subnet, sibling_subnet);
        } else {
            let opt_old_subnet = partial_routes[other_side].insert(other_i, subnet);
            if let Some(old_subnet) = opt_old_subnet {
                assert_eq!(subnet, old_subnet, "Routing conflict (should never happen)");
            }
        }
    };
    // Wires that bypass a switch layer are fixed to a child network, so their partners on the
    // other side are enqueued up front. See Figure 8 in the AS-Waksman paper.
    if even {
        enqueue_other_side(
            &mut partial_routes,
            witness,
            &mut newly_set,
            1,
            n - 2,
            false,
        );
        enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true);
    } else {
        enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 0, n - 1, true);
        enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true);
    }
    // Sets switch `switch_index` on `side` and propagates the resulting constraints to the other side.
    let route_switch = |partial_routes: &mut [BTreeMap<usize, bool>],
                        witness: &PartitionWitness<F>,
                        out_buffer: &mut GeneratedValues<F>,
                        newly_set: &mut [Vec<bool>],
                        side: usize,
                        switch_index: usize,
                        swap: bool| {
        // First, we actually set the switch configuration.
        out_buffer.set_target(switches[side][switch_index], F::from_bool(swap));
        newly_set[side][switch_index] = true;
        // Then, we enqueue the two corresponding wires on the other side of the network, to ensure
        // that they get routed in the next step.
        let this_i_1 = switch_index * 2;
        let this_i_2 = this_i_1 + 1;
        enqueue_other_side(partial_routes, witness, newly_set, side, this_i_1, swap);
        enqueue_other_side(partial_routes, witness, newly_set, side, this_i_2, !swap);
    };
    // If `partial_routes[side]` is empty, then we can route any switch on that side next. For
    // efficiency, we simply do top-down scans (one on the left side, one on the right side) for
    // switches which have not yet been routed. These variables represent the positions of those
    // two scans.
    let mut scan_index = [0, 0];
    // Until both scans complete, we alternate back and forth between the left and right switch
    // layers. We process any partially routed wires for that side, or if there aren't any, we route
    // the next switch in our scan.
    while scan_index[0] < switches[0].len() || scan_index[1] < switches[1].len() {
        for side in 0..=1 {
            if !partial_routes[side].is_empty() {
                // Iterate over a snapshot because `route_switch` needs `&mut partial_routes`. It only
                // inserts into the other side's map, so the snapshot matches this side's live map.
                for (this_i, subnet) in partial_routes[side].clone().into_iter() {
                    // Swap exactly when the wire must leave through the output opposite its input.
                    let this_first_switch_input = this_i % 2 == 0;
                    let swap = this_first_switch_input == subnet;
                    let this_switch_i = this_i / 2;
                    route_switch(
                        &mut partial_routes,
                        witness,
                        out_buffer,
                        &mut newly_set,
                        side,
                        this_switch_i,
                        swap,
                    );
                }
                partial_routes[side].clear();
            } else {
                // We can route any switch next. Continue our scan for pending switches.
                while scan_index[side] < switches[side].len()
                    && (witness.contains(switches[side][scan_index[side]])
                        || newly_set[side][scan_index[side]])
                {
                    scan_index[side] += 1;
                }
                if scan_index[side] < switches[side].len() {
                    // Either switch configuration would work; we arbitrarily choose to not swap.
                    route_switch(
                        &mut partial_routes,
                        witness,
                        out_buffer,
                        &mut newly_set,
                        side,
                        scan_index[side],
                        false,
                    );
                    scan_index[side] += 1;
                }
            }
        }
    }
}
2021-09-02 15:03:03 -07:00
/// Witness generator for one level of the permutation network built by
/// `assert_permutation_recursive`. It reads the chunk values on both sides and sets that level's
/// switch bools (see `route`).
#[derive(Debug)]
struct PermutationGenerator<F: Field> {
    /// Chunks on the left (input) side of this level.
    a: Vec<Vec<Target>>,
    /// Chunks on the right (output) side of this level.
    b: Vec<Vec<Target>>,
    /// Switch-bool targets of the left switch layer, in the order the switches were created.
    a_switches: Vec<Target>,
    /// Switch-bool targets of the right switch layer, in the order the switches were created.
    b_switches: Vec<Target>,
    _phantom: PhantomData<F>,
}
2021-09-01 16:38:10 -07:00
impl<F: Field> SimpleGenerator<F> for PermutationGenerator<F> {
2021-08-23 12:13:31 -07:00
fn dependencies(&self) -> Vec<Target> {
2021-09-06 08:38:47 -07:00
self.a.iter().chain(&self.b).flatten().cloned().collect()
2021-08-23 12:13:31 -07:00
}
2021-09-02 15:03:03 -07:00
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
2021-08-27 15:55:13 -07:00
let a_values = self
.a
.iter()
2021-09-01 16:38:10 -07:00
.map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect())
2021-08-27 15:55:13 -07:00
.collect();
let b_values = self
.b
.iter()
2021-09-01 16:38:10 -07:00
.map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect())
2021-08-27 15:55:13 -07:00
.collect();
2021-08-26 14:23:30 -07:00
route(
2021-08-27 14:34:53 -07:00
a_values,
b_values,
self.a_switches.clone(),
self.b_switches.clone(),
2021-08-26 14:23:30 -07:00
witness,
out_buffer,
);
2021-08-23 12:13:31 -07:00
}
}
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use rand::{seq::SliceRandom, thread_rng, Rng};

    use super::*;
    use crate::field::field_types::Field;
    use crate::field::goldilocks_field::GoldilocksField;
    use crate::iop::witness::PartialWitness;
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::verifier::verify;

    type F = GoldilocksField;
    const D: usize = 4;

    /// Creates one constant target per value and groups the targets into chunks of two.
    fn constant_pairs(builder: &mut CircuitBuilder<F, D>, values: &[F]) -> Vec<Vec<Target>> {
        values
            .chunks(2)
            .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
            .collect()
    }

    /// Chunks `values` into pairs of constants and shuffles those chunks to get the right-hand
    /// side. Then asserts the permutation, proves the circuit, and verifies the proof.
    fn prove_shuffled_permutation(values: &[F]) -> Result<()> {
        let config = CircuitConfig::large_config();
        let pw = PartialWitness::new();
        let mut builder = CircuitBuilder::<F, D>::new(config);

        let a = constant_pairs(&mut builder, values);
        let mut b = a.clone();
        b.shuffle(&mut thread_rng());
        builder.assert_permutation(a, b);

        let data = builder.build();
        let proof = data.prove(pw)?;
        verify(proof, &data.verifier_only, &data.common)
    }

    fn test_permutation_good(size: usize) -> Result<()> {
        let values: Vec<F> = (0..size * 2).map(F::from_canonical_usize).collect();
        prove_shuffled_permutation(&values)
    }

    fn test_permutation_duplicates(size: usize) -> Result<()> {
        // Values drawn from {0, 1}, so the lists contain many repeated chunks.
        let mut rng = thread_rng();
        let values: Vec<F> = (0..size * 2)
            .map(|_| F::from_canonical_usize(rng.gen_range(0..2usize)))
            .collect();
        prove_shuffled_permutation(&values)
    }

    /// Two independent random lists are (with overwhelming probability) not permutations of one
    /// another, so proving is expected to fail.
    fn test_permutation_bad(size: usize) -> Result<()> {
        let config = CircuitConfig::large_config();
        let pw = PartialWitness::new();
        let mut builder = CircuitBuilder::<F, D>::new(config);

        let lst1: Vec<F> = F::rand_vec(size * 2);
        let lst2: Vec<F> = F::rand_vec(size * 2);
        let a = constant_pairs(&mut builder, &lst1);
        let b = constant_pairs(&mut builder, &lst2);
        builder.assert_permutation(a, b);

        let data = builder.build();
        data.prove(pw)?;

        Ok(())
    }

    #[test]
    fn test_permutations_duplicates() -> Result<()> {
        for n in 2..9 {
            test_permutation_duplicates(n)?;
        }
        Ok(())
    }

    #[test]
    fn test_permutations_good() -> Result<()> {
        for n in 2..9 {
            test_permutation_good(n)?;
        }

        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_permutation_bad_small() {
        let size = 2;
        test_permutation_bad(size).unwrap()
    }

    #[test]
    #[should_panic]
    fn test_permutation_bad_medium() {
        let size = 6;
        test_permutation_bad(size).unwrap()
    }

    #[test]
    #[should_panic]
    fn test_permutation_bad_large() {
        let size = 10;
        test_permutation_bad(size).unwrap()
    }
}