plonky2/src/gadgets/sorting.rs

278 lines
9.0 KiB
Rust
Raw Normal View History

2021-09-21 18:01:21 -07:00
use itertools::izip;
2021-09-16 20:44:09 -07:00
2021-09-21 18:01:21 -07:00
use crate::field::field_types::RichField;
use crate::field::extension_field::Extendable;
2021-09-16 20:44:09 -07:00
use crate::gates::comparison::ComparisonGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
2021-09-21 18:01:21 -07:00
use crate::util::ceil_div_usize;
2021-09-16 20:44:09 -07:00
2021-09-21 18:01:21 -07:00
#[derive(Clone, Debug)]
2021-09-16 20:44:09 -07:00
pub struct MemoryOpTarget {
is_write: BoolTarget,
address: Target,
timestamp: Target,
value: Target,
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn assert_permutation_memory_ops(&mut self, a: &[MemoryOpTarget], b: &[MemoryOpTarget]) {
2021-09-16 21:06:54 -07:00
let a_chunks: Vec<Vec<Target>> = a
.iter()
.map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value])
.collect();
let b_chunks: Vec<Vec<Target>> = b
.iter()
.map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value])
.collect();
2021-09-16 20:44:09 -07:00
self.assert_permutation(a_chunks, b_chunks);
}
2021-09-16 21:06:54 -07:00
pub fn sort_memory_ops(
&mut self,
ops: &[MemoryOpTarget],
address_bits: usize,
timestamp_bits: usize,
) -> Vec<MemoryOpTarget> {
2021-09-16 20:44:09 -07:00
let n = ops.len();
2021-09-17 13:09:24 -07:00
let combined_bits = address_bits + timestamp_bits;
2021-09-21 18:01:21 -07:00
let chunk_bits = 3;
let num_chunks = ceil_div_usize(combined_bits, chunk_bits);
2021-09-16 20:44:09 -07:00
2021-09-16 21:06:54 -07:00
let is_write_targets: Vec<_> = self
.add_virtual_targets(n)
.iter()
.map(|&t| BoolTarget::new_unsafe(t))
.collect();
2021-09-16 20:44:09 -07:00
let address_targets = self.add_virtual_targets(n);
let timestamp_targets = self.add_virtual_targets(n);
let value_targets = self.add_virtual_targets(n);
2021-09-16 21:06:54 -07:00
let output_targets: Vec<_> = izip!(
is_write_targets,
address_targets,
timestamp_targets,
value_targets
)
.map(|(i, a, t, v)| MemoryOpTarget {
is_write: i,
address: a,
timestamp: t,
value: v,
})
.collect();
2021-09-16 20:44:09 -07:00
2021-09-17 13:09:24 -07:00
let two_n = self.constant(F::from_canonical_usize(1 << timestamp_bits));
let address_timestamp_combined: Vec<_> = output_targets
.iter()
2021-09-21 18:01:21 -07:00
.map(|op| self.mul_add(op.address, two_n, op.timestamp))
2021-09-17 13:09:24 -07:00
.collect();
2021-09-16 20:44:09 -07:00
2021-09-21 18:01:21 -07:00
let mut gate_indices = Vec::new();
let mut gates = Vec::new();
2021-09-17 13:09:24 -07:00
for i in 1..n {
let (gate, gate_index) = {
2021-09-21 18:01:21 -07:00
let gate = ComparisonGate::new(combined_bits, num_chunks);
2021-09-16 20:44:09 -07:00
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index)
};
self.connect(
2021-09-17 13:09:24 -07:00
Target::wire(gate_index, gate.wire_first_input()),
address_timestamp_combined[i - 1],
2021-09-16 20:44:09 -07:00
);
self.connect(
2021-09-17 13:09:24 -07:00
Target::wire(gate_index, gate.wire_second_input()),
address_timestamp_combined[i],
2021-09-16 20:44:09 -07:00
);
2021-09-21 18:01:21 -07:00
gate_indices.push(gate_index);
gates.push(gate);
2021-09-16 20:44:09 -07:00
}
self.assert_permutation_memory_ops(ops, output_targets.as_slice());
2021-09-21 18:01:21 -07:00
self.add_simple_generator(MemoryOpSortGenerator::<F, D> {
input_ops: ops.to_vec(),
gate_indices,
gates: gates.clone(),
output_ops: output_targets.clone(),
address_bits,
timestamp_bits,
});
2021-09-16 20:44:09 -07:00
output_targets
}
}
2021-09-17 13:40:07 -07:00
#[derive(Debug)]
2021-09-21 18:01:21 -07:00
struct MemoryOpSortGenerator<F: RichField + Extendable<D>, const D: usize> {
2021-09-17 13:40:07 -07:00
input_ops: Vec<MemoryOpTarget>,
2021-09-21 18:01:21 -07:00
gate_indices: Vec<usize>,
gates: Vec<ComparisonGate<F, D>>,
2021-09-17 13:40:07 -07:00
output_ops: Vec<MemoryOpTarget>,
address_bits: usize,
timestamp_bits: usize,
2021-09-16 20:44:09 -07:00
}
2021-09-21 18:01:21 -07:00
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for MemoryOpSortGenerator<F, D>
{
2021-09-16 20:44:09 -07:00
fn dependencies(&self) -> Vec<Target> {
2021-09-17 13:40:07 -07:00
self.input_ops
.iter()
.map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
.flatten()
.collect()
2021-09-16 20:44:09 -07:00
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
2021-09-17 13:40:07 -07:00
let n = self.input_ops.len();
debug_assert!(self.output_ops.len() == n);
let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self
.input_ops
2021-09-16 20:44:09 -07:00
.iter()
2021-09-17 13:40:07 -07:00
.map(|op| {
(
witness.get_target(op.timestamp),
witness.get_target(op.address),
)
})
.unzip();
let combined_values_u64: Vec<_> = timestamp_values
2021-09-16 20:44:09 -07:00
.iter()
2021-09-17 13:40:07 -07:00
.zip(address_values.iter())
.map(|(&t, &a)| {
a.to_canonical_u64() * (1 << self.timestamp_bits as u64) + t.to_canonical_u64()
})
2021-09-16 20:44:09 -07:00
.collect();
2021-09-17 13:40:07 -07:00
2021-09-21 18:01:21 -07:00
let mut combined_values_sorted = combined_values_u64.clone();
combined_values_sorted.sort();
let combined_values: Vec<_> = combined_values_sorted
.iter()
.map(|&x| F::from_canonical_u64(x))
.collect();
2021-09-17 13:40:07 -07:00
let mut input_ops_and_keys: Vec<_> = self
.input_ops
.iter()
.zip(combined_values_u64)
.collect::<Vec<_>>();
input_ops_and_keys.sort_by(|(_, a_val), (_, b_val)| a_val.cmp(b_val));
let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(op, _)| op).collect();
for i in 0..n {
out_buffer.set_target(
self.output_ops[i].is_write.target,
witness.get_target(input_ops_sorted[i].is_write.target),
);
out_buffer.set_target(
self.output_ops[i].address,
witness.get_target(input_ops_sorted[i].address),
);
out_buffer.set_target(
self.output_ops[i].timestamp,
witness.get_target(input_ops_sorted[i].timestamp),
);
out_buffer.set_target(
self.output_ops[i].value,
witness.get_target(input_ops_sorted[i].value),
);
2021-09-21 18:01:21 -07:00
if i > 0 {
out_buffer.set_target(
Target::wire(
self.gate_indices[i - 1],
self.gates[i - 1].wire_second_input(),
),
combined_values[i],
);
}
if i < n - 1 {
out_buffer.set_target(
Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()),
combined_values[i],
);
}
2021-09-17 13:40:07 -07:00
}
2021-09-16 20:44:09 -07:00
}
2021-09-17 13:40:07 -07:00
}
2021-09-16 20:44:09 -07:00
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use rand::{thread_rng, Rng};

    use super::*;
    use crate::field::crandall_field::CrandallField;
    use crate::field::field_types::Field;
    use crate::iop::witness::PartialWitness;
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::verifier::verify;

    /// Builds a circuit that sorts `size` random constant memory operations, proves it, and
    /// verifies the proof.
    fn test_sorting(size: usize, address_bits: usize, timestamp_bits: usize) -> Result<()> {
        type F = CrandallField;
        const D: usize = 4;
        let config = CircuitConfig::large_zk_config();
        let pw = PartialWitness::new();
        let mut builder = CircuitBuilder::<F, D>::new(config);

        let mut rng = thread_rng();
        let is_write_vals: Vec<_> = (0..size).map(|_| rng.gen_range(0..2) != 0).collect();
        let address_vals: Vec<_> = (0..size)
            .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << address_bits as u64)))
            .collect();
        let timestamp_vals: Vec<_> = (0..size)
            .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << timestamp_bits as u64)))
            .collect();
        let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect();
        let input_ops: Vec<MemoryOpTarget> =
            izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
                .map(|(is_write, address, timestamp, value)| MemoryOpTarget {
                    is_write: builder.constant_bool(is_write),
                    address: builder.constant(address),
                    timestamp: builder.constant(timestamp),
                    value: builder.constant(value),
                })
                .collect();

        let _output_ops =
            builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits);

        let data = builder.build();
        let proof = data.prove(pw)?;
        verify(proof, &data.verifier_only, &data.common)
    }

    #[test]
    fn test_sorting_small() -> Result<()> {
        let size = 5;
        let address_bits = 20;
        let timestamp_bits = 20;

        test_sorting(size, address_bits, timestamp_bits)
    }

    #[test]
    fn test_sorting_large() -> Result<()> {
        let size = 20;
        let address_bits = 20;
        let timestamp_bits = 20;

        test_sorting(size, address_bits, timestamp_bits)
    }

    #[test]
    fn test_sorting_repeated_addresses() -> Result<()> {
        // 20 ops over only 2^4 = 16 addresses must share an address somewhere (pigeonhole),
        // so the timestamp has to break ties in the sort order.
        let size = 20;
        let address_bits = 4;
        let timestamp_bits = 20;

        test_sorting(size, address_bits, timestamp_bits)
    }
}