Merge pull request #217 from mir-protocol/permutation

Permutation network code
This commit is contained in:
Nicholas Ward 2021-09-06 21:39:10 -07:00 committed by GitHub
commit 50274883c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 832 additions and 133 deletions

View File

@ -12,6 +12,8 @@ edition = "2018"
default-run = "bench_recursion"
[dependencies]
array_tool = "1.0.3"
bimap = "0.4.0"
env_logger = "0.9.0"
log = "0.4.14"
itertools = "0.10.0"

View File

@ -108,7 +108,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
exponent_bits: impl IntoIterator<Item = impl Borrow<BoolTarget>>,
) -> Target {
let _false = self._false();
let gate = ExponentiationGate::new(self.config.clone());
let gate = ExponentiationGate::new_from_config(self.config.clone());
let num_power_bits = gate.num_power_bits;
let mut exp_bits_vec: Vec<BoolTarget> =
exponent_bits.into_iter().map(|b| *b.borrow()).collect();

View File

@ -413,7 +413,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
) -> ExtensionTarget<D> {
let inv = self.add_virtual_extension_target();
let one = self.one_extension();
self.add_generator(QuotientGeneratorExtension {
self.add_simple_generator(QuotientGeneratorExtension {
numerator: one,
denominator: y,
quotient: inv,

View File

@ -3,6 +3,7 @@ pub mod arithmetic_extension;
pub mod hash;
pub mod insert;
pub mod interpolation;
pub mod permutation;
pub mod polynomial;
pub mod random_access;
pub mod range_check;

483
src/gadgets/permutation.rs Normal file
View File

@ -0,0 +1,483 @@
use std::collections::BTreeMap;
use std::marker::PhantomData;
use crate::field::{
extension_field::Extendable,
field_types::{Field, PrimeField},
};
use crate::gates::switch::SwitchGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::bimap::bimap_from_lists;
impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Assert that two lists of expressions evaluate to permutations of one another.
pub fn assert_permutation(&mut self, a: Vec<Vec<Target>>, b: Vec<Vec<Target>>) {
assert_eq!(
a.len(),
b.len(),
"Permutation must have same number of inputs and outputs"
);
assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same");
let chunk_size = a[0].len();
match a.len() {
// Two empty lists are permutations of one another, trivially.
0 => (),
// Two singleton lists are permutations of one another as long as their items are equal.
1 => {
for e in 0..chunk_size {
self.connect(a[0][e], b[0][e])
}
}
2 => {
self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone())
}
// For larger lists, we recursively use two smaller permutation networks.
//_ => self.assert_permutation_recursive(a, b)
_ => self.assert_permutation_recursive(a, b),
}
}
/// Assert that [a1, a2] is a permutation of [b1, b2].
fn assert_permutation_2x2(
&mut self,
a1: Vec<Target>,
a2: Vec<Target>,
b1: Vec<Target>,
b2: Vec<Target>,
) {
assert!(
a1.len() == a2.len() && a2.len() == b1.len() && b1.len() == b2.len(),
"Chunk size must be the same"
);
let chunk_size = a1.len();
let (_switch, gate_out1, gate_out2) = self.create_switch(a1, a2);
for e in 0..chunk_size {
self.connect(b1[e], gate_out1[e]);
self.connect(b2[e], gate_out2[e]);
}
}
/// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch
/// gate). Returns the wire for the switch boolean, and the two output wire chunks.
fn create_switch(
&mut self,
a1: Vec<Target>,
a2: Vec<Target>,
) -> (Target, Vec<Target>, Vec<Target>) {
assert_eq!(a1.len(), a2.len(), "Chunk size must be the same");
let chunk_size = a1.len();
if self.current_switch_gates.len() < chunk_size {
self.current_switch_gates
.extend(vec![None; chunk_size - self.current_switch_gates.len()]);
}
let (gate, gate_index, mut next_copy) =
match self.current_switch_gates[chunk_size - 1].clone() {
None => {
let gate = SwitchGate::<F, D>::new_from_config(self.config.clone(), chunk_size);
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index, 0)
}
Some((gate, idx, next_copy)) => (gate, idx, next_copy),
};
let num_copies = gate.num_copies;
let mut c = Vec::new();
let mut d = Vec::new();
for e in 0..chunk_size {
self.connect(
a1[e],
Target::wire(gate_index, gate.wire_first_input(next_copy, e)),
);
self.connect(
a2[e],
Target::wire(gate_index, gate.wire_second_input(next_copy, e)),
);
c.push(Target::wire(
gate_index,
gate.wire_first_output(next_copy, e),
));
d.push(Target::wire(
gate_index,
gate.wire_second_output(next_copy, e),
));
}
let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy));
next_copy += 1;
if next_copy == num_copies {
self.current_switch_gates[chunk_size - 1] = None;
} else {
self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy));
}
(switch, c, d)
}
fn assert_permutation_recursive(&mut self, a: Vec<Vec<Target>>, b: Vec<Vec<Target>>) {
assert_eq!(
a.len(),
b.len(),
"Permutation must have same number of inputs and outputs"
);
assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same");
let n = a.len();
let even = n % 2 == 0;
let mut child_1_a = Vec::new();
let mut child_1_b = Vec::new();
let mut child_2_a = Vec::new();
let mut child_2_b = Vec::new();
// See Figure 8 in the AS-Waksman paper.
let a_num_switches = n / 2;
let b_num_switches = if even {
a_num_switches - 1
} else {
a_num_switches
};
let mut a_switches = Vec::new();
let mut b_switches = Vec::new();
for i in 0..a_num_switches {
let (switch, out_1, out_2) = self.create_switch(a[i * 2].clone(), a[i * 2 + 1].clone());
a_switches.push(switch);
child_1_a.push(out_1);
child_2_a.push(out_2);
}
for i in 0..b_num_switches {
let (switch, out_1, out_2) = self.create_switch(b[i * 2].clone(), b[i * 2 + 1].clone());
b_switches.push(switch);
child_1_b.push(out_1);
child_2_b.push(out_2);
}
// See Figure 8 in the AS-Waksman paper.
if even {
child_1_b.push(b[n - 2].clone());
child_2_b.push(b[n - 1].clone());
} else {
child_2_a.push(a[n - 1].clone());
child_2_b.push(b[n - 1].clone());
}
self.assert_permutation(child_1_a, child_1_b);
self.assert_permutation(child_2_a, child_2_b);
self.add_simple_generator(PermutationGenerator::<F> {
a,
b,
a_switches,
b_switches,
_phantom: PhantomData,
});
}
}
fn route<F: Field>(
a_values: Vec<Vec<F>>,
b_values: Vec<Vec<F>>,
a_switches: Vec<Target>,
b_switches: Vec<Target>,
witness: &PartitionWitness<F>,
out_buffer: &mut GeneratedValues<F>,
) {
assert_eq!(a_values.len(), b_values.len());
let n = a_values.len();
let even = n % 2 == 0;
// We use a bimap to match indices of values in a to indices of the same values in b.
// This means that given a wire on one side, we can easily find the matching wire on the other side.
let ab_map = bimap_from_lists(a_values, b_values);
let switches = [a_switches, b_switches];
// We keep track of the new wires we've routed (after routing some wires, we need to check `witness`
// and `newly_set` instead of just `witness`.
let mut newly_set = [vec![false; n], vec![false; n]];
// Given a side and an index, returns the index in the other side that corresponds to the same value.
let ab_map_by_side = |side: usize, index: usize| -> usize {
*match side {
0 => ab_map.get_by_left(&index),
1 => ab_map.get_by_right(&index),
_ => panic!("Expected side to be 0 or 1"),
}
.unwrap()
};
// We maintain two maps for wires which have been routed to a particular subnetwork on one side
// of the network (left or right) but not the other. The keys are wire indices, and the values
// are subnetwork indices.
let mut partial_routes = [BTreeMap::new(), BTreeMap::new()];
// After we route a wire on one side, we find the corresponding wire on the other side and check
// if it still needs to be routed. If so, we add it to partial_routes.
let enqueue_other_side = |partial_routes: &mut [BTreeMap<usize, bool>],
witness: &PartitionWitness<F>,
newly_set: &mut [Vec<bool>],
side: usize,
this_i: usize,
subnet: bool| {
let other_side = 1 - side;
let other_i = ab_map_by_side(side, this_i);
let other_switch_i = other_i / 2;
if other_switch_i >= switches[other_side].len() {
// The other wire doesn't go through a switch, so there's no routing to be done.
// This happens in the case of the very last wire.
return;
}
if witness.contains(switches[other_side][other_switch_i])
|| newly_set[other_side][other_switch_i]
{
// The other switch has already been routed.
return;
}
let other_i_sibling = 4 * other_switch_i + 1 - other_i;
if let Some(&sibling_subnet) = partial_routes[other_side].get(&other_i_sibling) {
// The other switch's sibling is already pending routing.
assert_ne!(subnet, sibling_subnet);
} else {
let opt_old_subnet = partial_routes[other_side].insert(other_i, subnet);
if let Some(old_subnet) = opt_old_subnet {
assert_eq!(subnet, old_subnet, "Routing conflict (should never happen)");
}
}
};
// See Figure 8 in the AS-Waksman paper.
if even {
enqueue_other_side(
&mut partial_routes,
witness,
&mut newly_set,
1,
n - 2,
false,
);
enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true);
} else {
enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 0, n - 1, true);
enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true);
}
let route_switch = |partial_routes: &mut [BTreeMap<usize, bool>],
witness: &PartitionWitness<F>,
out_buffer: &mut GeneratedValues<F>,
newly_set: &mut [Vec<bool>],
side: usize,
switch_index: usize,
swap: bool| {
// First, we actually set the switch configuration.
out_buffer.set_target(switches[side][switch_index], F::from_bool(swap));
newly_set[side][switch_index] = true;
// Then, we enqueue the two corresponding wires on the other side of the network, to ensure
// that they get routed in the next step.
let this_i_1 = switch_index * 2;
let this_i_2 = this_i_1 + 1;
enqueue_other_side(partial_routes, witness, newly_set, side, this_i_1, swap);
enqueue_other_side(partial_routes, witness, newly_set, side, this_i_2, !swap);
};
// If {a,b}_only_routes is empty, then we can route any switch next. For efficiency, we will
// simply do top-down scans (one on the left side, one on the right side) for switches which
// have not yet been routed. These variables represent the positions of those two scans.
let mut scan_index = [0, 0];
// Until both scans complete, we alternate back and worth between the left and right switch
// layers. We process any partially routed wires for that side, or if there aren't any, we route
// the next switch in our scan.
while scan_index[0] < switches[0].len() || scan_index[1] < switches[1].len() {
for side in 0..=1 {
if !partial_routes[side].is_empty() {
for (this_i, subnet) in partial_routes[side].clone().into_iter() {
let this_first_switch_input = this_i % 2 == 0;
let swap = this_first_switch_input == subnet;
let this_switch_i = this_i / 2;
route_switch(
&mut partial_routes,
witness,
out_buffer,
&mut newly_set,
side,
this_switch_i,
swap,
);
}
partial_routes[side].clear();
} else {
// We can route any switch next. Continue our scan for pending switches.
while scan_index[side] < switches[side].len()
&& (witness.contains(switches[side][scan_index[side]])
|| newly_set[side][scan_index[side]])
{
scan_index[side] += 1;
}
if scan_index[side] < switches[side].len() {
// Either switch configuration would work; we arbitrarily choose to not swap.
route_switch(
&mut partial_routes,
witness,
out_buffer,
&mut newly_set,
side,
scan_index[side],
false,
);
scan_index[side] += 1;
}
}
}
}
}
#[derive(Debug)]
struct PermutationGenerator<F: Field> {
a: Vec<Vec<Target>>,
b: Vec<Vec<Target>>,
a_switches: Vec<Target>,
b_switches: Vec<Target>,
_phantom: PhantomData<F>,
}
impl<F: Field> SimpleGenerator<F> for PermutationGenerator<F> {
fn dependencies(&self) -> Vec<Target> {
self.a.iter().chain(&self.b).flatten().cloned().collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a_values = self
.a
.iter()
.map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect())
.collect();
let b_values = self
.b
.iter()
.map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect())
.collect();
route(
a_values,
b_values,
self.a_switches.clone(),
self.b_switches.clone(),
witness,
out_buffer,
);
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use rand::{seq::SliceRandom, thread_rng};
use super::*;
use crate::field::crandall_field::CrandallField;
use crate::field::field_types::Field;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
fn test_permutation_good(size: usize) -> Result<()> {
type F = CrandallField;
const D: usize = 4;
let config = CircuitConfig::large_zk_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let lst: Vec<F> = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect();
let a: Vec<Vec<Target>> = lst[..]
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
let mut b = a.clone();
b.shuffle(&mut thread_rng());
builder.assert_permutation(a, b);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
fn test_permutation_bad(size: usize) -> Result<()> {
type F = CrandallField;
const D: usize = 4;
let config = CircuitConfig::large_zk_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let lst1: Vec<F> = F::rand_vec(size * 2);
let lst2: Vec<F> = F::rand_vec(size * 2);
let a: Vec<Vec<Target>> = lst1[..]
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
let b: Vec<Vec<Target>> = lst2[..]
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
builder.assert_permutation(a, b);
let data = builder.build();
data.prove(pw).unwrap();
Ok(())
}
#[test]
fn test_permutations_good() -> Result<()> {
for n in 2..9 {
test_permutation_good(n)?;
}
Ok(())
}
#[test]
#[should_panic]
fn test_permutation_bad_small() {
let size = 2;
test_permutation_bad(size).unwrap()
}
#[test]
#[should_panic]
fn test_permutation_bad_medium() {
let size = 6;
test_permutation_bad(size).unwrap()
}
#[test]
#[should_panic]
fn test_permutation_bad_large() {
let size = 10;
test_permutation_bad(size).unwrap()
}
}

View File

@ -28,7 +28,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let high_gate = self.add_gate(BaseSumGate::<2>::new(num_bits - n_log), vec![]);
let low = Target::wire(low_gate, BaseSumGate::<2>::WIRE_SUM);
let high = Target::wire(high_gate, BaseSumGate::<2>::WIRE_SUM);
self.add_generator(LowHighGenerator {
self.add_simple_generator(LowHighGenerator {
integer: x,
n_log,
low,

View File

@ -44,7 +44,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.connect(limb.borrow().target, Target::wire(gate_index, wire));
}
self.add_generator(BaseSumGenerator::<2> {
self.add_simple_generator(BaseSumGenerator::<2> {
gate_index,
limbs: bits.map(|l| *l.borrow()).collect(),
});

View File

@ -47,7 +47,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
self.connect(acc, integer);
self.add_generator(WireSplitGenerator {
self.add_simple_generator(WireSplitGenerator {
integer,
gates,
num_limbs: bits_per_gate,

View File

@ -111,12 +111,15 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExt
) -> Vec<Box<dyn WitnessGenerator<F>>> {
(0..NUM_ARITHMETIC_OPS)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(ArithmeticExtensionGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
i,
});
let g: Box<dyn WitnessGenerator<F>> = Box::new(
ArithmeticExtensionGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
i,
}
.adapter(),
);
g
})
.collect::<Vec<_>>()

View File

@ -105,7 +105,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const B: usize> Gate<F, D> f
gate_index,
num_limbs: self.num_limbs,
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
// 1 for the sum then `num_limbs` for the limbs.

View File

@ -54,7 +54,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for ConstantGate
gate_index,
constant: local_constants[0],
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {

View File

@ -20,14 +20,18 @@ pub(crate) struct ExponentiationGate<F: PrimeField + Extendable<D>, const D: usi
}
impl<F: PrimeField + Extendable<D>, const D: usize> ExponentiationGate<F, D> {
pub fn new(config: CircuitConfig) -> Self {
let num_power_bits = Self::max_power_bits(config.num_wires, config.num_routed_wires);
pub fn new(num_power_bits: usize) -> Self {
Self {
num_power_bits,
_phantom: PhantomData,
}
}
pub fn new_from_config(config: CircuitConfig) -> Self {
let num_power_bits = Self::max_power_bits(config.num_wires, config.num_routed_wires);
Self::new(num_power_bits)
}
fn max_power_bits(num_wires: usize, num_routed_wires: usize) -> usize {
// 2 wires are reserved for the base and output.
let max_for_routed_wires = num_routed_wires - 2;
@ -180,7 +184,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for Exponentiatio
gate_index,
gate: self.clone(),
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {
@ -298,12 +302,14 @@ mod tests {
..CircuitConfig::large_config()
};
test_low_degree::<CrandallField, _, 4>(ExponentiationGate::new(config));
test_low_degree::<CrandallField, _, 4>(ExponentiationGate::new_from_config(config));
}
#[test]
fn eval_fns() -> Result<()> {
test_eval_fns::<CrandallField, _, 4>(ExponentiationGate::new(CircuitConfig::large_config()))
test_eval_fns::<CrandallField, _, 4>(ExponentiationGate::new_from_config(
CircuitConfig::large_config(),
))
}
#[test]

View File

@ -219,7 +219,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const R: usize> Gate<F, D>
gate_index,
constants: self.constants.clone(),
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {

View File

@ -220,7 +220,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for InsertionGate
gate_index,
gate: self.clone(),
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {

View File

@ -190,7 +190,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for Interpolation
gate: self.clone(),
_phantom: PhantomData,
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {

View File

@ -167,7 +167,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for RandomAccessG
gate_index,
gate: self.clone(),
};
vec![Box::new(gen)]
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {

View File

@ -137,10 +137,13 @@ impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for ReducingGate<
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
vec![Box::new(ReducingGenerator {
gate_index,
gate: self.clone(),
})]
vec![Box::new(
ReducingGenerator {
gate_index,
gate: self.clone(),
}
.adapter(),
)]
}
fn num_wires(&self) -> usize {

View File

@ -1,10 +1,12 @@
use std::marker::PhantomData;
use array_tool::vec::Union;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField};
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::generator::{GeneratedValues, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
@ -14,62 +16,59 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate for conditionally swapping input values based on a boolean.
#[derive(Clone, Debug)]
pub(crate) struct SwitchGate<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize>
{
num_copies: usize,
pub(crate) struct SwitchGate<F: PrimeField + Extendable<D>, const D: usize> {
pub(crate) chunk_size: usize,
pub(crate) num_copies: usize,
_phantom: PhantomData<F>,
}
impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize>
SwitchGate<F, D, CHUNK_SIZE>
{
pub fn new(config: CircuitConfig) -> Self {
let num_copies = Self::max_num_copies(config.num_routed_wires);
impl<F: PrimeField + Extendable<D>, const D: usize> SwitchGate<F, D> {
pub fn new(num_copies: usize, chunk_size: usize) -> Self {
Self {
chunk_size,
num_copies,
_phantom: PhantomData,
}
}
fn max_num_copies(num_routed_wires: usize) -> usize {
num_routed_wires / (4 * CHUNK_SIZE + 1)
pub fn new_from_config(config: CircuitConfig, chunk_size: usize) -> Self {
let num_copies = Self::max_num_copies(config.num_routed_wires, chunk_size);
Self::new(num_copies, chunk_size)
}
pub fn max_num_copies(num_routed_wires: usize, chunk_size: usize) -> usize {
num_routed_wires / (4 * chunk_size + 1)
}
pub fn wire_first_input(&self, copy: usize, element: usize) -> usize {
debug_assert!(element < self.chunk_size);
copy * (4 * self.chunk_size + 1) + element
}
pub fn wire_second_input(&self, copy: usize, element: usize) -> usize {
debug_assert!(element < self.chunk_size);
copy * (4 * self.chunk_size + 1) + self.chunk_size + element
}
pub fn wire_first_output(&self, copy: usize, element: usize) -> usize {
debug_assert!(element < self.chunk_size);
copy * (4 * self.chunk_size + 1) + 2 * self.chunk_size + element
}
pub fn wire_second_output(&self, copy: usize, element: usize) -> usize {
debug_assert!(element < self.chunk_size);
copy * (4 * self.chunk_size + 1) + 3 * self.chunk_size + element
}
pub fn wire_switch_bool(&self, copy: usize) -> usize {
debug_assert!(copy < self.num_copies);
copy * (4 * CHUNK_SIZE + 1)
}
pub fn wire_first_input(&self, copy: usize, element: usize) -> usize {
debug_assert!(copy < self.num_copies);
debug_assert!(element < CHUNK_SIZE);
copy * (4 * CHUNK_SIZE + 1) + 1 + element
}
pub fn wire_second_input(&self, copy: usize, element: usize) -> usize {
debug_assert!(copy < self.num_copies);
debug_assert!(element < CHUNK_SIZE);
copy * (4 * CHUNK_SIZE + 1) + 1 + CHUNK_SIZE + element
}
pub fn wire_first_output(&self, copy: usize, element: usize) -> usize {
debug_assert!(copy < self.num_copies);
debug_assert!(element < CHUNK_SIZE);
copy * (4 * CHUNK_SIZE + 1) + 1 + 2 * CHUNK_SIZE + element
}
pub fn wire_second_output(&self, copy: usize, element: usize) -> usize {
debug_assert!(copy < self.num_copies);
debug_assert!(element < CHUNK_SIZE);
copy * (4 * CHUNK_SIZE + 1) + 1 + 3 * CHUNK_SIZE + element
copy * (4 * self.chunk_size + 1) + 4 * self.chunk_size
}
}
impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Gate<F, D>
for SwitchGate<F, D, CHUNK_SIZE>
{
impl<F: PrimeField + Extendable<D>, const D: usize> Gate<F, D> for SwitchGate<F, D> {
fn id(&self) -> String {
format!("{:?}<D={},CHUNK_SIZE={}>", self, D, CHUNK_SIZE)
format!("{:?}<D={}>", self, D)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
@ -79,7 +78,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Gat
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
let not_switch = F::Extension::ONE - switch_bool;
for e in 0..CHUNK_SIZE {
for e in 0..self.chunk_size {
let first_input = vars.local_wires[self.wire_first_input(c, e)];
let second_input = vars.local_wires[self.wire_second_input(c, e)];
let first_output = vars.local_wires[self.wire_first_output(c, e)];
@ -102,7 +101,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Gat
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
let not_switch = F::ONE - switch_bool;
for e in 0..CHUNK_SIZE {
for e in 0..self.chunk_size {
let first_input = vars.local_wires[self.wire_first_input(c, e)];
let second_input = vars.local_wires[self.wire_second_input(c, e)];
let first_output = vars.local_wires[self.wire_first_output(c, e)];
@ -130,7 +129,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Gat
let switch_bool = vars.local_wires[self.wire_switch_bool(c)];
let not_switch = builder.sub_extension(one, switch_bool);
for e in 0..CHUNK_SIZE {
for e in 0..self.chunk_size {
let first_input = vars.local_wires[self.wire_first_input(c, e)];
let second_input = vars.local_wires[self.wire_second_input(c, e)];
let first_output = vars.local_wires[self.wire_first_output(c, e)];
@ -165,15 +164,20 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Gat
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = SwitchGenerator::<F, D, CHUNK_SIZE> {
gate_index,
gate: self.clone(),
};
vec![Box::new(gen)]
(0..self.num_copies)
.map(|c| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(SwitchGenerator::<F, D> {
gate_index,
gate: self.clone(),
copy: c,
});
g
})
.collect()
}
fn num_wires(&self) -> usize {
self.wire_second_output(self.num_copies - 1, CHUNK_SIZE - 1) + 1
self.wire_switch_bool(self.num_copies - 1) + 1
}
fn num_constants(&self) -> usize {
@ -185,35 +189,46 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Gat
}
fn num_constraints(&self) -> usize {
4 * self.num_copies * CHUNK_SIZE
4 * self.num_copies * self.chunk_size
}
}
#[derive(Debug)]
struct SwitchGenerator<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> {
struct SwitchGenerator<F: PrimeField + Extendable<D>, const D: usize> {
gate_index: usize,
gate: SwitchGate<F, D, CHUNK_SIZE>,
gate: SwitchGate<F, D>,
copy: usize,
}
impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> SimpleGenerator<F>
for SwitchGenerator<F, D, CHUNK_SIZE>
{
fn dependencies(&self) -> Vec<Target> {
impl<F: PrimeField + Extendable<D>, const D: usize> SwitchGenerator<F, D> {
fn in_out_dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new();
for c in 0..self.gate.num_copies {
deps.push(local_target(self.gate.wire_switch_bool(c)));
for e in 0..CHUNK_SIZE {
deps.push(local_target(self.gate.wire_first_input(c, e)));
deps.push(local_target(self.gate.wire_second_input(c, e)));
}
for e in 0..self.gate.chunk_size {
deps.push(local_target(self.gate.wire_first_input(self.copy, e)));
deps.push(local_target(self.gate.wire_second_input(self.copy, e)));
deps.push(local_target(self.gate.wire_first_output(self.copy, e)));
deps.push(local_target(self.gate.wire_second_output(self.copy, e)));
}
deps
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
fn in_switch_dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new();
for e in 0..self.gate.chunk_size {
deps.push(local_target(self.gate.wire_first_input(self.copy, e)));
deps.push(local_target(self.gate.wire_second_input(self.copy, e)));
deps.push(local_target(self.gate.wire_switch_bool(self.copy)));
}
deps
}
fn run_in_out(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
@ -221,24 +236,69 @@ impl<F: PrimeField + Extendable<D>, const D: usize, const CHUNK_SIZE: usize> Sim
let get_local_wire = |input| witness.get_wire(local_wire(input));
for c in 0..self.gate.num_copies {
let switch_bool = get_local_wire(self.gate.wire_switch_bool(c));
for e in 0..CHUNK_SIZE {
let first_input = get_local_wire(self.gate.wire_first_input(c, e));
let second_input = get_local_wire(self.gate.wire_second_input(c, e));
let first_output_wire = local_wire(self.gate.wire_first_output(c, e));
let second_output_wire = local_wire(self.gate.wire_second_output(c, e));
for e in 0..self.gate.chunk_size {
let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy));
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e));
let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e));
if switch_bool == F::ONE {
out_buffer.set_wire(first_output_wire, second_input);
out_buffer.set_wire(second_output_wire, first_input);
} else {
out_buffer.set_wire(first_output_wire, first_input);
out_buffer.set_wire(second_output_wire, second_input);
}
if first_output == first_input && second_output == second_input {
out_buffer.set_wire(switch_bool_wire, F::ZERO);
} else if first_output == second_input && second_output == first_input {
out_buffer.set_wire(switch_bool_wire, F::ONE);
} else {
panic!("No permutation from given inputs to given outputs");
}
}
}
fn run_in_switch(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
for e in 0..self.gate.chunk_size {
let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e));
let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e));
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy));
let (first_output, second_output) = if switch_bool == F::ZERO {
(first_input, second_input)
} else if switch_bool == F::ONE {
(second_input, first_input)
} else {
panic!("Invalid switch bool value");
};
out_buffer.set_wire(first_output_wire, first_output);
out_buffer.set_wire(second_output_wire, second_output);
}
}
}
impl<F: PrimeField + Extendable<D>, const D: usize> WitnessGenerator<F> for SwitchGenerator<F, D> {
fn watch_list(&self) -> Vec<Target> {
self.in_out_dependencies()
.union(self.in_switch_dependencies())
}
fn run(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) -> bool {
if witness.contains_all(&self.in_out_dependencies()) {
self.run_in_out(witness, out_buffer);
true
} else if witness.contains_all(&self.in_switch_dependencies()) {
self.run_in_switch(witness, out_buffer);
true
} else {
false
}
}
}
#[cfg(test)]
@ -259,37 +319,44 @@ mod tests {
#[test]
fn wire_indices() {
let gate = SwitchGate::<CrandallField, 4, 3> {
num_copies: 3,
type SG = SwitchGate<CrandallField, 4>;
let num_copies = 3;
let chunk_size = 3;
let gate = SG {
chunk_size,
num_copies,
_phantom: PhantomData,
};
assert_eq!(gate.wire_switch_bool(0), 0);
assert_eq!(gate.wire_first_input(0, 0), 1);
assert_eq!(gate.wire_first_input(0, 2), 3);
assert_eq!(gate.wire_second_input(0, 0), 4);
assert_eq!(gate.wire_second_input(0, 2), 6);
assert_eq!(gate.wire_first_output(0, 0), 7);
assert_eq!(gate.wire_second_output(0, 2), 12);
assert_eq!(gate.wire_switch_bool(1), 13);
assert_eq!(gate.wire_first_input(1, 0), 14);
assert_eq!(gate.wire_second_output(1, 2), 25);
assert_eq!(gate.wire_switch_bool(2), 26);
assert_eq!(gate.wire_first_input(2, 0), 27);
assert_eq!(gate.wire_second_output(2, 2), 38);
assert_eq!(gate.wire_first_input(0, 0), 0);
assert_eq!(gate.wire_first_input(0, 2), 2);
assert_eq!(gate.wire_second_input(0, 0), 3);
assert_eq!(gate.wire_second_input(0, 2), 5);
assert_eq!(gate.wire_first_output(0, 0), 6);
assert_eq!(gate.wire_second_output(0, 2), 11);
assert_eq!(gate.wire_switch_bool(0), 12);
assert_eq!(gate.wire_first_input(1, 0), 13);
assert_eq!(gate.wire_second_output(1, 2), 24);
assert_eq!(gate.wire_switch_bool(1), 25);
assert_eq!(gate.wire_first_input(2, 0), 26);
assert_eq!(gate.wire_second_output(2, 2), 37);
assert_eq!(gate.wire_switch_bool(2), 38);
}
#[test]
fn low_degree() {
test_low_degree::<CrandallField, _, 4>(SwitchGate::<_, 4, 3>::new(
test_low_degree::<CrandallField, _, 4>(SwitchGate::<_, 4>::new_from_config(
CircuitConfig::large_config(),
3,
));
}
#[test]
fn eval_fns() -> Result<()> {
test_eval_fns::<CrandallField, _, 4>(SwitchGate::<_, 4, 3>::new(
test_eval_fns::<CrandallField, _, 4>(SwitchGate::<_, 4>::new_from_config(
CircuitConfig::large_config(),
3,
))
}
@ -312,7 +379,7 @@ mod tests {
let mut v = Vec::new();
for c in 0..num_copies {
let switch = switch_bools[c];
v.push(F::from_bool(switch));
let mut first_input_chunk = Vec::with_capacity(CHUNK_SIZE);
let mut second_input_chunk = Vec::with_capacity(CHUNK_SIZE);
let mut first_output_chunk = Vec::with_capacity(CHUNK_SIZE);
@ -331,6 +398,8 @@ mod tests {
v.append(&mut second_input_chunk);
v.append(&mut first_output_chunk);
v.append(&mut second_output_chunk);
v.push(F::from_bool(switch));
}
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
@ -340,7 +409,8 @@ mod tests {
let second_inputs: Vec<Vec<F>> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect();
let switch_bools = vec![true, false, true];
let gate = SwitchGate::<F, D, CHUNK_SIZE> {
let gate = SwitchGate::<F, D> {
chunk_size: CHUNK_SIZE,
num_copies,
_phantom: PhantomData,
};

View File

@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
@ -186,16 +187,32 @@ pub trait SimpleGenerator<F: Field>: 'static + Send + Sync + Debug {
fn dependencies(&self) -> Vec<Target>;
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>);
fn adapter(self) -> SimpleGeneratorAdapter<F, Self>
where
Self: Sized,
{
SimpleGeneratorAdapter {
inner: self,
_phantom: PhantomData,
}
}
}
impl<F: Field, SG: SimpleGenerator<F>> WitnessGenerator<F> for SG {
#[derive(Debug)]
pub struct SimpleGeneratorAdapter<F: Field, SG: SimpleGenerator<F> + ?Sized> {
_phantom: PhantomData<F>,
inner: SG,
}
impl<F: Field, SG: SimpleGenerator<F>> WitnessGenerator<F> for SimpleGeneratorAdapter<F, SG> {
fn watch_list(&self) -> Vec<Target> {
self.dependencies()
self.inner.dependencies()
}
fn run(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) -> bool {
if witness.contains_all(&self.dependencies()) {
self.run_once(witness, out_buffer);
if witness.contains_all(&self.inner.dependencies()) {
self.inner.run_once(witness, out_buffer);
true
} else {
false

View File

@ -15,9 +15,12 @@ use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
use crate::gates::gate_tree::Tree;
use crate::gates::noop::NoopGate;
use crate::gates::public_input::PublicInputGate;
use crate::gates::switch::SwitchGate;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget};
use crate::hash::hashing::hash_n_to_hash;
use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator};
use crate::iop::generator::{
CopyGenerator, RandomValueGenerator, SimpleGenerator, WitnessGenerator,
};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::iop::witness::PartitionWitness;
@ -41,7 +44,7 @@ pub struct CircuitBuilder<F: PrimeField + Extendable<D>, const D: usize> {
gates: HashSet<GateRef<F, D>>,
/// The concrete placement of each gate.
gate_instances: Vec<GateInstance<F, D>>,
pub(crate) gate_instances: Vec<GateInstance<F, D>>,
/// Targets to be made public.
public_inputs: Vec<Target>,
@ -66,6 +69,11 @@ pub struct CircuitBuilder<F: PrimeField + Extendable<D>, const D: usize> {
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value
// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies
// of switches
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
}
impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
@ -83,6 +91,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
constants_to_targets: HashMap::new(),
targets_to_constants: HashMap::new(),
free_arithmetic: HashMap::new(),
current_switch_gates: Vec::new(),
}
}
@ -182,7 +191,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Adds a generator which will copy `src` to `dst`.
pub fn generate_copy(&mut self, src: Target, dst: Target) {
self.add_generator(CopyGenerator { src, dst });
self.add_simple_generator(CopyGenerator { src, dst });
}
/// Uses Plonk's permutation argument to require that two elements be equal.
@ -209,8 +218,8 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.generators.extend(generators);
}
pub fn add_generator<G: WitnessGenerator<F>>(&mut self, generator: G) {
self.generators.push(Box::new(generator));
pub fn add_simple_generator<G: SimpleGenerator<F>>(&mut self, generator: G) {
self.generators.push(Box::new(generator.adapter()));
}
/// Returns a routable target with a value of 0.
@ -383,7 +392,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
for _ in 0..regular_poly_openings {
let gate = self.add_gate(NoopGate, vec![]);
for w in 0..num_wires {
self.add_generator(RandomValueGenerator {
self.add_simple_generator(RandomValueGenerator {
target: Target::Wire(Wire { gate, input: w }),
});
}
@ -397,7 +406,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let gate_2 = self.add_gate(NoopGate, vec![]);
for w in 0..num_routed_wires {
self.add_generator(RandomValueGenerator {
self.add_simple_generator(RandomValueGenerator {
target: Target::Wire(Wire {
gate: gate_1,
input: w,
@ -507,6 +516,33 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Fill the remaining unused switch gates with dummy values, so that all
/// `SwitchGenerator` are run.
fn fill_switch_gates(&mut self) {
let zero = self.zero();
for chunk_size in 1..=self.current_switch_gates.len() {
if let Some((gate, gate_index, mut copy)) =
self.current_switch_gates[chunk_size - 1].clone()
{
while copy < gate.num_copies {
for element in 0..chunk_size {
let wire_first_input =
Target::wire(gate_index, gate.wire_first_input(copy, element));
let wire_second_input =
Target::wire(gate_index, gate.wire_second_input(copy, element));
let wire_switch_bool =
Target::wire(gate_index, gate.wire_switch_bool(copy));
self.connect(zero, wire_first_input);
self.connect(zero, wire_second_input);
self.connect(zero, wire_switch_bool);
}
copy += 1;
}
}
}
}
pub fn print_gate_counts(&self, min_delta: usize) {
self.context_log
.filter(self.num_gates(), min_delta)
@ -519,6 +555,7 @@ impl<F: PrimeField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let start = Instant::now();
self.fill_arithmetic_gates();
self.fill_switch_gates();
// Hash the public inputs, and route them to a `PublicInputGate` which will enforce that
// those hash wires match the claimed public inputs.

76
src/util/bimap.rs Normal file
View File

@ -0,0 +1,76 @@
use std::collections::HashMap;
use std::hash::Hash;
use bimap::BiMap;
use itertools::enumerate;
/// Given two lists which are permutations of one another, creates a BiMap which maps an index in
/// one list to an index in the other list with the same associated value.
///
/// If the lists contain duplicates, then multiple permutations with this property exist, and an
/// arbitrary one of them will be returned.
pub fn bimap_from_lists<T: Eq + Hash>(a: Vec<T>, b: Vec<T>) -> BiMap<usize, usize> {
assert_eq!(a.len(), b.len(), "Vectors differ in length");
let mut b_values_to_indices = HashMap::new();
for (i, value) in enumerate(b) {
b_values_to_indices
.entry(value)
.or_insert_with(Vec::new)
.push(i);
}
let mut bimap = BiMap::new();
for (i, value) in enumerate(a) {
if let Some(j) = b_values_to_indices.get_mut(&value).and_then(Vec::pop) {
bimap.insert(i, j);
} else {
panic!("Value in first list not found in second list");
}
}
bimap
}
#[cfg(test)]
mod tests {
use crate::util::bimap::bimap_from_lists;
#[test]
fn empty_lists() {
let empty: Vec<char> = Vec::new();
let bimap = bimap_from_lists(empty.clone(), empty);
assert!(bimap.is_empty());
}
#[test]
fn without_duplicates() {
let bimap = bimap_from_lists(vec!['a', 'b', 'c'], vec!['b', 'c', 'a']);
assert_eq!(bimap.get_by_left(&0), Some(&2));
assert_eq!(bimap.get_by_left(&1), Some(&0));
assert_eq!(bimap.get_by_left(&2), Some(&1));
}
#[test]
fn with_duplicates() {
let first = vec!['a', 'a', 'b'];
let second = vec!['a', 'b', 'a'];
let bimap = bimap_from_lists(first.clone(), second.clone());
for i in 0..3 {
let j = *bimap.get_by_left(&i).unwrap();
assert_eq!(first[i], second[j]);
}
}
#[test]
#[should_panic]
fn lengths_differ() {
bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b']);
}
#[test]
#[should_panic]
fn not_a_permutation() {
bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b', 'b']);
}
}

View File

@ -1,6 +1,7 @@
use crate::field::field_types::Field;
use crate::polynomial::polynomial::PolynomialValues;
pub(crate) mod bimap;
pub(crate) mod context_tree;
pub(crate) mod marking;
pub(crate) mod partial_products;