mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-10 01:33:07 +00:00
Merge pull request #249 from mir-protocol/sorting_gadget
Memory sorting gadget
This commit is contained in:
commit
23b1161d27
@ -8,5 +8,6 @@ pub mod polynomial;
|
||||
pub mod random_access;
|
||||
pub mod range_check;
|
||||
pub mod select;
|
||||
pub mod sorting;
|
||||
pub mod split_base;
|
||||
pub(crate) mod split_join;
|
||||
|
||||
@ -384,7 +384,7 @@ impl<F: Field> SimpleGenerator<F> for PermutationGenerator<F> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
use rand::{seq::SliceRandom, thread_rng, Rng};
|
||||
|
||||
use super::*;
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
@ -418,6 +418,35 @@ mod tests {
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
fn test_permutation_duplicates(size: usize) -> Result<()> {
|
||||
type F = CrandallField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::large_zk_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let lst: Vec<F> = (0..size * 2)
|
||||
.map(|_| F::from_canonical_usize(rng.gen_range(0..2usize)))
|
||||
.collect();
|
||||
let a: Vec<Vec<Target>> = lst[..]
|
||||
.chunks(2)
|
||||
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
|
||||
.collect();
|
||||
|
||||
let mut b = a.clone();
|
||||
b.shuffle(&mut thread_rng());
|
||||
|
||||
builder.assert_permutation(a, b);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
fn test_permutation_bad(size: usize) -> Result<()> {
|
||||
type F = CrandallField;
|
||||
const D: usize = 4;
|
||||
@ -446,6 +475,15 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permutations_duplicates() -> Result<()> {
|
||||
for n in 2..9 {
|
||||
test_permutation_duplicates(n)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permutations_good() -> Result<()> {
|
||||
for n in 2..9 {
|
||||
|
||||
263
src/gadgets/sorting.rs
Normal file
263
src/gadgets/sorting.rs
Normal file
@ -0,0 +1,263 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use itertools::izip;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::comparison::ComparisonGate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::util::ceil_div_usize;
|
||||
|
||||
pub struct MemoryOp<F: Field> {
|
||||
is_write: bool,
|
||||
address: F,
|
||||
timestamp: F,
|
||||
value: F,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MemoryOpTarget {
|
||||
is_write: BoolTarget,
|
||||
address: Target,
|
||||
timestamp: Target,
|
||||
value: Target,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
pub fn assert_permutation_memory_ops(&mut self, a: &[MemoryOpTarget], b: &[MemoryOpTarget]) {
|
||||
let a_chunks: Vec<Vec<Target>> = a
|
||||
.iter()
|
||||
.map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value])
|
||||
.collect();
|
||||
let b_chunks: Vec<Vec<Target>> = b
|
||||
.iter()
|
||||
.map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value])
|
||||
.collect();
|
||||
|
||||
self.assert_permutation(a_chunks, b_chunks);
|
||||
}
|
||||
|
||||
/// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
|
||||
pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) {
|
||||
let gate = ComparisonGate::new(bits, num_chunks);
|
||||
let gate_index = self.add_gate(gate.clone(), vec![]);
|
||||
|
||||
self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs);
|
||||
self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs);
|
||||
}
|
||||
|
||||
/// Sort memory operations by address value, then by timestamp value.
|
||||
/// This is done by combining address and timestamp into one field element (using their given bit lengths).
|
||||
pub fn sort_memory_ops(
|
||||
&mut self,
|
||||
ops: &[MemoryOpTarget],
|
||||
address_bits: usize,
|
||||
timestamp_bits: usize,
|
||||
) -> Vec<MemoryOpTarget> {
|
||||
let n = ops.len();
|
||||
|
||||
let combined_bits = address_bits + timestamp_bits;
|
||||
let chunk_bits = 3;
|
||||
let num_chunks = ceil_div_usize(combined_bits, chunk_bits);
|
||||
|
||||
// This is safe because `assert_permutation` will force these targets (in the output list) to match the boolean values from the input list.
|
||||
let is_write_targets: Vec<_> = self
|
||||
.add_virtual_targets(n)
|
||||
.iter()
|
||||
.map(|&t| BoolTarget::new_unsafe(t))
|
||||
.collect();
|
||||
|
||||
let address_targets = self.add_virtual_targets(n);
|
||||
let timestamp_targets = self.add_virtual_targets(n);
|
||||
let value_targets = self.add_virtual_targets(n);
|
||||
|
||||
let output_targets: Vec<_> = izip!(
|
||||
is_write_targets,
|
||||
address_targets,
|
||||
timestamp_targets,
|
||||
value_targets
|
||||
)
|
||||
.map(|(i, a, t, v)| MemoryOpTarget {
|
||||
is_write: i,
|
||||
address: a,
|
||||
timestamp: t,
|
||||
value: v,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let two_n = self.constant(F::from_canonical_usize(1 << timestamp_bits));
|
||||
let address_timestamp_combined: Vec<_> = output_targets
|
||||
.iter()
|
||||
.map(|op| self.mul_add(op.address, two_n, op.timestamp))
|
||||
.collect();
|
||||
|
||||
for i in 1..n {
|
||||
self.assert_le(
|
||||
address_timestamp_combined[i - 1],
|
||||
address_timestamp_combined[i],
|
||||
combined_bits,
|
||||
num_chunks,
|
||||
);
|
||||
}
|
||||
|
||||
self.assert_permutation_memory_ops(ops, &output_targets);
|
||||
|
||||
self.add_simple_generator(MemoryOpSortGenerator::<F, D> {
|
||||
input_ops: ops.to_vec(),
|
||||
output_ops: output_targets.clone(),
|
||||
_phantom: PhantomData,
|
||||
});
|
||||
|
||||
output_targets
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MemoryOpSortGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
input_ops: Vec<MemoryOpTarget>,
|
||||
output_ops: Vec<MemoryOpTarget>,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for MemoryOpSortGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.input_ops
|
||||
.iter()
|
||||
.map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
|
||||
.flatten()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let n = self.input_ops.len();
|
||||
debug_assert!(self.output_ops.len() == n);
|
||||
|
||||
let mut ops: Vec<_> = self
|
||||
.input_ops
|
||||
.iter()
|
||||
.map(|op| {
|
||||
let is_write = witness.get_bool_target(op.is_write);
|
||||
let address = witness.get_target(op.address);
|
||||
let timestamp = witness.get_target(op.timestamp);
|
||||
let value = witness.get_target(op.value);
|
||||
MemoryOp {
|
||||
is_write,
|
||||
address,
|
||||
timestamp,
|
||||
value,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
ops.sort_unstable_by_key(|op| {
|
||||
(
|
||||
op.address.to_canonical_u64(),
|
||||
op.timestamp.to_canonical_u64(),
|
||||
)
|
||||
});
|
||||
|
||||
for (op, out_op) in ops.iter().zip(&self.output_ops) {
|
||||
out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write));
|
||||
out_buffer.set_target(out_op.address, op.address);
|
||||
out_buffer.set_target(out_op.timestamp, op.timestamp);
|
||||
out_buffer.set_target(out_op.value, op.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::*;
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use crate::iop::witness::PartialWitness;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::verifier::verify;
|
||||
|
||||
fn test_sorting(size: usize, address_bits: usize, timestamp_bits: usize) -> Result<()> {
|
||||
type F = CrandallField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::large_zk_config();
|
||||
|
||||
let mut pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let is_write_vals: Vec<_> = (0..size).map(|_| rng.gen_range(0..2) != 0).collect();
|
||||
let address_vals: Vec<_> = (0..size)
|
||||
.map(|_| F::from_canonical_u64(rng.gen_range(0..1 << address_bits as u64)))
|
||||
.collect();
|
||||
let timestamp_vals: Vec<_> = (0..size)
|
||||
.map(|_| F::from_canonical_u64(rng.gen_range(0..1 << timestamp_bits as u64)))
|
||||
.collect();
|
||||
let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect();
|
||||
|
||||
let input_ops: Vec<MemoryOpTarget> = izip!(
|
||||
is_write_vals.clone(),
|
||||
address_vals.clone(),
|
||||
timestamp_vals.clone(),
|
||||
value_vals.clone()
|
||||
)
|
||||
.map(|(is_write, address, timestamp, value)| MemoryOpTarget {
|
||||
is_write: builder.constant_bool(is_write),
|
||||
address: builder.constant(address),
|
||||
timestamp: builder.constant(timestamp),
|
||||
value: builder.constant(value),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let combined_vals_u64: Vec<_> = timestamp_vals
|
||||
.iter()
|
||||
.zip(&address_vals)
|
||||
.map(|(&t, &a)| (a.to_canonical_u64() << timestamp_bits as u64) + t.to_canonical_u64())
|
||||
.collect();
|
||||
let mut input_ops_and_keys: Vec<_> =
|
||||
izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
|
||||
.zip(combined_vals_u64)
|
||||
.collect::<Vec<_>>();
|
||||
input_ops_and_keys.sort_by_key(|(_, val)| val.clone());
|
||||
let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect();
|
||||
|
||||
let output_ops =
|
||||
builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits);
|
||||
|
||||
for i in 0..size {
|
||||
pw.set_bool_target(output_ops[i].is_write, input_ops_sorted[i].0);
|
||||
pw.set_target(output_ops[i].address, input_ops_sorted[i].1);
|
||||
pw.set_target(output_ops[i].timestamp, input_ops_sorted[i].2);
|
||||
pw.set_target(output_ops[i].value, input_ops_sorted[i].3);
|
||||
}
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sorting_small() -> Result<()> {
|
||||
let size = 5;
|
||||
let address_bits = 20;
|
||||
let timestamp_bits = 20;
|
||||
|
||||
test_sorting(size, address_bits, timestamp_bits)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sorting_large() -> Result<()> {
|
||||
let size = 20;
|
||||
let address_bits = 20;
|
||||
let timestamp_bits = 20;
|
||||
|
||||
test_sorting(size, address_bits, timestamp_bits)
|
||||
}
|
||||
}
|
||||
@ -13,9 +13,9 @@ use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recu
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::util::ceil_div_usize;
|
||||
|
||||
/// A gate for checking that one value is less than another.
|
||||
/// A gate for checking that one value is less than or equal to another.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ComparisonGate<F: PrimeField + Extendable<D>, const D: usize> {
|
||||
pub struct ComparisonGate<F: PrimeField + Extendable<D>, const D: usize> {
|
||||
pub(crate) num_bits: usize,
|
||||
pub(crate) num_chunks: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
@ -137,7 +137,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (1..chunk_size)
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
@ -205,7 +205,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (1..chunk_size)
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
@ -286,7 +286,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let mut product = builder.one_extension();
|
||||
for x in 1..chunk_size {
|
||||
for x in 0..chunk_size {
|
||||
let x_F = builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(most_significant_diff, x_F);
|
||||
product = builder.mul_extension(product, diff);
|
||||
@ -386,7 +386,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
|
||||
let mut most_significant_diff_so_far = F::ZERO;
|
||||
let mut intermediate_values = Vec::new();
|
||||
for i in 1..self.gate.num_chunks {
|
||||
for i in 0..self.gate.num_chunks {
|
||||
if first_input_chunks[i] != second_input_chunks[i] {
|
||||
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
|
||||
intermediate_values.push(F::ZERO);
|
||||
@ -552,7 +552,7 @@ mod tests {
|
||||
let first_input_u64 = rng.gen_range(0..max);
|
||||
let second_input_u64 = {
|
||||
let mut val = rng.gen_range(0..max);
|
||||
while val <= first_input_u64 {
|
||||
while val < first_input_u64 {
|
||||
val = rng.gen_range(0..max);
|
||||
}
|
||||
val
|
||||
@ -561,20 +561,39 @@ mod tests {
|
||||
let first_input = F::from_canonical_u64(first_input_u64);
|
||||
let second_input = F::from_canonical_u64(second_input_u64);
|
||||
|
||||
let gate = ComparisonGate::<F, D> {
|
||||
let less_than_gate = ComparisonGate::<F, D> {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
let less_than_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(first_input, second_input),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
less_than_gate
|
||||
.eval_unfiltered(less_than_vars)
|
||||
.iter()
|
||||
.all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
|
||||
let equal_gate = ComparisonGate::<F, D> {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
let equal_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(first_input, first_input),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
assert!(
|
||||
equal_gate
|
||||
.eval_unfiltered(equal_vars)
|
||||
.iter()
|
||||
.all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
|
||||
@ -236,20 +236,25 @@ impl<F: RichField + Extendable<D>, const D: usize> SwitchGenerator<F, D> {
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
|
||||
for e in 0..self.gate.chunk_size {
|
||||
let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy));
|
||||
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
|
||||
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
|
||||
let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e));
|
||||
let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e));
|
||||
let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy));
|
||||
|
||||
if first_output == first_input && second_output == second_input {
|
||||
out_buffer.set_wire(switch_bool_wire, F::ZERO);
|
||||
} else if first_output == second_input && second_output == first_input {
|
||||
out_buffer.set_wire(switch_bool_wire, F::ONE);
|
||||
} else {
|
||||
panic!("No permutation from given inputs to given outputs");
|
||||
}
|
||||
let mut first_inputs = Vec::new();
|
||||
let mut second_inputs = Vec::new();
|
||||
let mut first_outputs = Vec::new();
|
||||
let mut second_outputs = Vec::new();
|
||||
for e in 0..self.gate.chunk_size {
|
||||
first_inputs.push(get_local_wire(self.gate.wire_first_input(self.copy, e)));
|
||||
second_inputs.push(get_local_wire(self.gate.wire_second_input(self.copy, e)));
|
||||
first_outputs.push(get_local_wire(self.gate.wire_first_output(self.copy, e)));
|
||||
second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e)));
|
||||
}
|
||||
|
||||
if first_outputs == first_inputs && second_outputs == second_inputs {
|
||||
out_buffer.set_wire(switch_bool_wire, F::ZERO);
|
||||
} else if first_outputs == second_inputs && second_outputs == first_inputs {
|
||||
out_buffer.set_wire(switch_bool_wire, F::ONE);
|
||||
} else {
|
||||
panic!("No permutation from given inputs to given outputs");
|
||||
}
|
||||
}
|
||||
|
||||
@ -261,12 +266,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SwitchGenerator<F, D> {
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
|
||||
let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy));
|
||||
for e in 0..self.gate.chunk_size {
|
||||
let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e));
|
||||
let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e));
|
||||
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
|
||||
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
|
||||
let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy));
|
||||
|
||||
let (first_output, second_output) = if switch_bool == F::ZERO {
|
||||
(first_input, second_input)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user