From b0a855a9c3764acd5233ed0ba5e8d614bb550657 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 23 Aug 2021 12:13:31 -0700 Subject: [PATCH] progress on permutation --- src/gadgets/arithmetic.rs | 2 +- src/gadgets/mod.rs | 1 + src/gadgets/permutation.rs | 115 ++++++++++++++++++++++++++++++++++++ src/gates/exponentiation.rs | 14 +++-- src/gates/switch.rs | 69 +++++++++++----------- 5 files changed, 163 insertions(+), 38 deletions(-) create mode 100644 src/gadgets/permutation.rs diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 960f139f..2d354567 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -107,7 +107,7 @@ impl, const D: usize> CircuitBuilder { exponent_bits: impl IntoIterator>, ) -> Target { let _false = self._false(); - let gate = ExponentiationGate::new(self.config.clone()); + let gate = ExponentiationGate::new_from_config(self.config.clone()); let num_power_bits = gate.num_power_bits; let mut exp_bits_vec: Vec = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 0eb42e27..4b3371ef 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod hash; pub mod insert; pub mod interpolation; +pub mod permutation; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs new file mode 100644 index 00000000..b477a553 --- /dev/null +++ b/src/gadgets/permutation.rs @@ -0,0 +1,115 @@ +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gates::switch::SwitchGate; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::PartialWitness; +use crate::plonk::circuit_builder::CircuitBuilder; + +impl, const D: usize> CircuitBuilder { + /// Assert that two lists of expressions evaluate to permutations of one another. + pub fn assert_permutation( + &mut self, + a: Vec<[Target; CHUNK_SIZE]>, + b: Vec<[Target; CHUNK_SIZE]>, + ) { + assert_eq!( + a.len(), + b.len(), + "Permutation must have same number of inputs and outputs" + ); + assert_eq!(a[0].len(), b[0].len(), "Chunk sizes must be the same"); + + match a.len() { + // Two empty lists are permutations of one another, trivially. + 0 => (), + // Two singleton lists are permutations of one another as long as their items are equal. + 1 => { + for e in 0..CHUNK_SIZE { + self.assert_equal(a[0][e], b[0][e]) + } + } + 2 => self.assert_permutation_2x2(a[0], a[1], b[0], b[1]), + // For larger lists, we recursively use two smaller permutation networks. + //_ => self.assert_permutation_recursive(a, b) + _ => self.assert_permutation_recursive(a, b), + } + } + + /// Assert that [a, b] is a permutation of [c, d]. + fn assert_permutation_2x2( + &mut self, + a: [Target; CHUNK_SIZE], + b: [Target; CHUNK_SIZE], + c: [Target; CHUNK_SIZE], + d: [Target; CHUNK_SIZE], + ) { + let gate = SwitchGate::::new(1); + let gate_index = self.add_gate(gate.clone(), vec![]); + + for e in 0..CHUNK_SIZE { + self.route(a[e], Target::wire(gate_index, gate.wire_first_input(0, e))); + self.route(b[e], Target::wire(gate_index, gate.wire_second_input(0, e))); + self.route(c[e], Target::wire(gate_index, gate.wire_first_output(0, e))); + self.route( + d[e], + Target::wire(gate_index, gate.wire_second_output(0, e)), + ); + } + } + + fn assert_permutation_recursive( + &mut self, + a: Vec<[Target; CHUNK_SIZE]>, + b: Vec<[Target; CHUNK_SIZE]>, + ) { + } +} + +struct PermutationGenerator { + gate_index: usize, +} + +impl SimpleGenerator for PermutationGenerator { + fn dependencies(&self) -> Vec { + todo!() + } + + fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { + todo!() + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use super::*; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + fn test_permutation(size: usize) -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + let len = 1 << len_log; + let config = CircuitConfig::large_config(); + let pw = PartialWitness::new(config.num_wires); + let mut builder = CircuitBuilder::::new(config); + let vec = FF::rand_vec(len); + let v: Vec<_> = vec.iter().map(|x| builder.constant_extension(*x)).collect(); + + for i in 0..len { + let it = builder.constant(F::from_canonical_usize(i)); + let elem = builder.constant_extension(vec[i]); + builder.random_access(it, elem, v.clone()); + } + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index d57c9001..1a3a6ea3 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -20,14 +20,18 @@ pub(crate) struct ExponentiationGate, const D: usize> { } impl, const D: usize> ExponentiationGate { - pub fn new(config: CircuitConfig) -> Self { - let num_power_bits = Self::max_power_bits(config.num_wires, config.num_routed_wires); + pub fn new(num_power_bits: usize) -> Self { Self { num_power_bits, _phantom: PhantomData, } } + pub fn new_from_config(config: CircuitConfig) -> Self { + let num_power_bits = Self::max_power_bits(config.num_wires, config.num_routed_wires); + Self::new(num_power_bits) + } + fn max_power_bits(num_wires: usize, num_routed_wires: usize) -> usize { // 2 wires are reserved for the base and output. let max_for_routed_wires = num_routed_wires - 2; @@ -296,12 +300,14 @@ mod tests { ..CircuitConfig::large_config() }; - test_low_degree::(ExponentiationGate::new(config)); + test_low_degree::(ExponentiationGate::new_from_config(config)); } #[test] fn eval_fns() -> Result<()> { - test_eval_fns::(ExponentiationGate::new(CircuitConfig::large_config())) + test_eval_fns::(ExponentiationGate::new_from_config( + CircuitConfig::large_config(), + )) } #[test] diff --git a/src/gates/switch.rs b/src/gates/switch.rs index 6ef29341..a48adb61 100644 --- a/src/gates/switch.rs +++ b/src/gates/switch.rs @@ -20,45 +20,49 @@ pub(crate) struct SwitchGate, const D: usize, const CHUNK_SIZE: } impl, const D: usize, const CHUNK_SIZE: usize> SwitchGate { - pub fn new(config: CircuitConfig) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires); + pub fn new(num_copies: usize) -> Self { Self { num_copies, _phantom: PhantomData, } } - fn max_num_copies(num_routed_wires: usize) -> usize { - num_routed_wires / (4 * CHUNK_SIZE + 1) + pub fn new_from_config(config: CircuitConfig) -> Self { + let num_copies = Self::max_num_copies(config.num_routed_wires); + Self::new(num_copies) } - pub fn wire_switch_bool(&self, copy: usize) -> usize { - debug_assert!(copy < self.num_copies); - copy * (4 * CHUNK_SIZE + 1) + fn max_num_copies(num_routed_wires: usize) -> usize { + num_routed_wires / (4 * CHUNK_SIZE) } pub fn wire_first_input(&self, copy: usize, element: usize) -> usize { debug_assert!(copy < self.num_copies); debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + element + copy * (4 * CHUNK_SIZE) + element } pub fn wire_second_input(&self, copy: usize, element: usize) -> usize { debug_assert!(copy < self.num_copies); debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + CHUNK_SIZE + element + copy * (4 * CHUNK_SIZE) + CHUNK_SIZE + element } pub fn wire_first_output(&self, copy: usize, element: usize) -> usize { debug_assert!(copy < self.num_copies); debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + 2 * CHUNK_SIZE + element + copy * (4 * CHUNK_SIZE) + 2 * CHUNK_SIZE + element } pub fn wire_second_output(&self, copy: usize, element: usize) -> usize { debug_assert!(copy < self.num_copies); debug_assert!(element < CHUNK_SIZE); - copy * (4 * CHUNK_SIZE + 1) + 1 + 3 * CHUNK_SIZE + element + copy * (4 * CHUNK_SIZE) + 3 * CHUNK_SIZE + element + } + + pub fn wire_switch_bool(&self, copy: usize) -> usize { + debug_assert!(copy < self.num_copies); + self.num_copies * (4 * CHUNK_SIZE) + copy } } @@ -200,10 +204,11 @@ impl, const D: usize, const CHUNK_SIZE: usize> SimpleGenerator< let mut deps = Vec::new(); for c in 0..self.gate.num_copies { - deps.push(local_target(self.gate.wire_switch_bool(c))); for e in 0..CHUNK_SIZE { deps.push(local_target(self.gate.wire_first_input(c, e))); deps.push(local_target(self.gate.wire_second_input(c, e))); + deps.push(local_target(self.gate.wire_first_output(c, e))); + deps.push(local_target(self.gate.wire_second_output(c, e))); } } @@ -219,19 +224,17 @@ impl, const D: usize, const CHUNK_SIZE: usize> SimpleGenerator< let get_local_wire = |input| witness.get_wire(local_wire(input)); for c in 0..self.gate.num_copies { - let switch_bool = get_local_wire(self.gate.wire_switch_bool(c)); + let switch_bool_wire = local_wire(self.gate.wire_switch_bool(c)); for e in 0..CHUNK_SIZE { let first_input = get_local_wire(self.gate.wire_first_input(c, e)); let second_input = get_local_wire(self.gate.wire_second_input(c, e)); - let first_output_wire = local_wire(self.gate.wire_first_output(c, e)); - let second_output_wire = local_wire(self.gate.wire_second_output(c, e)); + let first_output = get_local_wire(self.gate.wire_first_output(c, e)); + let second_output = get_local_wire(self.gate.wire_second_output(c, e)); - if switch_bool == F::ONE { - out_buffer.set_wire(first_output_wire, second_input); - out_buffer.set_wire(second_output_wire, first_input); + if first_input == first_output { + out_buffer.set_wire(switch_bool_wire, F::ONE); } else { - out_buffer.set_wire(first_output_wire, first_input); - out_buffer.set_wire(second_output_wire, second_input); + out_buffer.set_wire(switch_bool_wire, F::ZERO); } } } @@ -261,19 +264,19 @@ mod tests { _phantom: PhantomData, }; - assert_eq!(gate.wire_switch_bool(0), 0); - assert_eq!(gate.wire_first_input(0, 0), 1); - assert_eq!(gate.wire_first_input(0, 2), 3); - assert_eq!(gate.wire_second_input(0, 0), 4); - assert_eq!(gate.wire_second_input(0, 2), 6); - assert_eq!(gate.wire_first_output(0, 0), 7); - assert_eq!(gate.wire_second_output(0, 2), 12); - assert_eq!(gate.wire_switch_bool(1), 13); - assert_eq!(gate.wire_first_input(1, 0), 14); - assert_eq!(gate.wire_second_output(1, 2), 25); - assert_eq!(gate.wire_switch_bool(2), 26); - assert_eq!(gate.wire_first_input(2, 0), 27); - assert_eq!(gate.wire_second_output(2, 2), 38); + assert_eq!(gate.wire_first_input(0, 0), 0); + assert_eq!(gate.wire_first_input(0, 2), 2); + assert_eq!(gate.wire_second_input(0, 0), 3); + assert_eq!(gate.wire_second_input(0, 2), 5); + assert_eq!(gate.wire_first_output(0, 0), 6); + assert_eq!(gate.wire_second_output(0, 2), 11); + assert_eq!(gate.wire_first_input(1, 0), 12); + assert_eq!(gate.wire_second_output(1, 2), 23); + assert_eq!(gate.wire_first_input(2, 0), 24); + assert_eq!(gate.wire_second_output(2, 2), 35); + assert_eq!(gate.wire_switch_bool(0), 36); + assert_eq!(gate.wire_switch_bool(1), 37); + assert_eq!(gate.wire_switch_bool(2), 38); } #[test]