plonky2/src/permutation_argument.rs

175 lines
5.3 KiB
Rust
Raw Normal View History

2021-07-02 14:45:05 +02:00
use std::collections::HashMap;
2021-07-02 14:13:57 +02:00
use std::fmt::Debug;
use std::hash::Hash;
2021-04-25 17:05:27 -07:00
use rayon::prelude::*;
use crate::field::field::Field;
use crate::polynomial::polynomial::PolynomialValues;
use crate::target::Target;
use crate::wire::Wire;
2021-07-02 14:42:40 +02:00
/// Node in the Disjoint Set Forest.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ForestNode<T: Debug + Copy + Eq + PartialEq> {
    /// The element stored at this node.
    t: T,
    /// Position (in the forest) of this node's parent; equal to `index` for a root.
    parent: usize,
    /// Number of elements in the set rooted at this node. Only maintained on roots.
    size: usize,
    /// This node's own position in the forest.
    index: usize,
}

/// Disjoint Set Forest data-structure following
/// <https://en.wikipedia.org/wiki/Disjoint-set_data_structure>, using union by size and path
/// halving.
#[derive(Debug, Clone)]
pub struct TargetPartition<T: Debug + Copy + Eq + PartialEq + Hash, F: Fn(T) -> usize> {
    forest: Vec<ForestNode<T>>,
    /// Function to compute a node's index in the forest.
    indices: F,
}

impl<T: Debug + Copy + Eq + PartialEq + Hash, F: Fn(T) -> usize> TargetPartition<T, F> {
    /// Creates an empty partition. `f` must map each element to the forest slot it occupies,
    /// i.e. elements have to be passed to [`Self::add`] in the order given by `f`.
    pub fn new(f: F) -> Self {
        Self {
            forest: Vec::new(),
            indices: f,
        }
    }

    /// Add a new partition with a single member.
    ///
    /// In debug builds this asserts that `t` is being stored at the slot `(self.indices)(t)`.
    pub fn add(&mut self, t: T) {
        let index = self.forest.len();
        debug_assert_eq!((self.indices)(t), index);
        self.forest.push(ForestNode {
            t,
            parent: index,
            size: 1,
            index,
        });
    }

    /// Returns the root (representative) of the set containing `x`.
    ///
    /// Uses path halving, see
    /// <https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives>:
    /// every visited node is re-pointed at its grandparent *in the forest*, which shortens
    /// later lookups. Only `x.index` is read from the argument, so a stale copy of a node is
    /// fine; the walk always uses the current forest contents.
    pub fn find(&mut self, x: ForestNode<T>) -> ForestNode<T> {
        let mut index = x.index;
        while self.forest[index].parent != index {
            let parent = self.forest[index].parent;
            let grandparent = self.forest[parent].parent;
            // Write the halved pointer back; mutating a local copy would compress nothing.
            self.forest[index].parent = grandparent;
            index = grandparent;
        }
        self.forest[index]
    }

    /// Merge the sets containing `tx` and `ty` (union by size). Does nothing if they are
    /// already in the same set.
    ///
    /// # Panics
    /// Panics if either element maps to a slot that has not been [`Self::add`]ed.
    pub fn merge(&mut self, tx: T, ty: T) {
        let node_x = self.forest[(self.indices)(tx)];
        let node_y = self.forest[(self.indices)(ty)];

        // Union must operate on the roots, not on the elements or their direct parents.
        let mut x = self.find(node_x);
        let mut y = self.find(node_y);
        if x.index == y.index {
            return;
        }

        // Attach the smaller tree below the larger one to keep trees shallow.
        if x.size < y.size {
            std::mem::swap(&mut x, &mut y);
        }
        y.parent = x.index;
        x.size += y.size;

        // Store the updated roots back at their own positions.
        self.forest[x.index] = x;
        self.forest[y.index] = y;
    }
}
impl<F: Fn(Target) -> usize> TargetPartition<Target, F> {
2021-07-02 14:45:05 +02:00
pub fn wire_partition(&mut self) -> WirePartitions {
2021-07-02 14:13:57 +02:00
let mut partition = HashMap::<_, Vec<_>>::new();
let nodes = self.forest.clone();
for x in nodes {
let v = partition.entry(self.find(x).t).or_default();
v.push(x.t);
2021-04-25 17:05:27 -07:00
}
2021-07-02 14:13:57 +02:00
let mut indices = HashMap::new();
2021-07-02 14:26:49 +02:00
// // Here we keep just the Wire targets, filtering out everything else.
2021-07-02 14:13:57 +02:00
let partition = partition
.into_values()
.map(|v| {
v.into_iter()
.filter_map(|t| match t {
Target::Wire(w) => Some(w),
_ => None,
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
partition.iter().enumerate().for_each(|(i, v)| {
v.iter().for_each(|t| {
indices.insert(*t, i);
});
});
WirePartitions { partition, indices }
2021-04-25 17:05:27 -07:00
}
}
pub struct WirePartitions {
2021-07-02 14:13:57 +02:00
partition: Vec<Vec<Wire>>,
2021-04-25 17:05:27 -07:00
indices: HashMap<Wire, usize>,
}
impl WirePartitions {
/// Find a wire's "neighbor" in the context of Plonk's "extended copy constraints" check. In
/// other words, find the next wire in the given wire's partition. If the given wire is last in
/// its partition, this will loop around. If the given wire has a partition all to itself, it
/// is considered its own neighbor.
fn get_neighbor(&self, wire: Wire) -> Wire {
2021-07-02 14:13:57 +02:00
let partition = &self.partition[self.indices[&wire]];
2021-04-25 17:05:27 -07:00
let n = partition.len();
for i in 0..n {
if partition[i] == wire {
let neighbor_index = (i + 1) % n;
return partition[neighbor_index];
}
}
panic!("Wire not found in the expected partition")
}
pub(crate) fn get_sigma_polys<F: Field>(
&self,
degree_log: usize,
k_is: &[F],
2021-06-16 17:43:41 +02:00
subgroup: &[F],
2021-04-25 17:05:27 -07:00
) -> Vec<PolynomialValues<F>> {
let degree = 1 << degree_log;
let sigma = self.get_sigma_map(degree);
sigma
.chunks(degree)
.map(|chunk| {
let values = chunk
.par_iter()
2021-06-16 17:43:41 +02:00
.map(|&x| k_is[x / degree] * subgroup[x % degree])
2021-04-25 17:05:27 -07:00
.collect::<Vec<_>>();
PolynomialValues::new(values)
})
.collect()
}
/// Generates sigma in the context of Plonk, which is a map from `[kn]` to `[kn]`, where `k` is
/// the number of routed wires and `n` is the number of gates.
fn get_sigma_map(&self, degree: usize) -> Vec<usize> {
debug_assert_eq!(self.indices.len() % degree, 0);
let num_routed_wires = self.indices.len() / degree;
let mut sigma = Vec::new();
for input in 0..num_routed_wires {
for gate in 0..degree {
let wire = Wire { gate, input };
let neighbor = self.get_neighbor(wire);
sigma.push(neighbor.input * degree + neighbor.gate);
}
}
sigma
}
}