mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-09 09:13:09 +00:00
Merge pull request #267 from mir-protocol/sorting_gen_refactor
Sorting tweaks
This commit is contained in:
commit
3b6d4cbeea
@ -1,13 +1,21 @@
|
||||
use itertools::izip;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::comparison::ComparisonGate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::util::ceil_div_usize;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub struct MemoryOp<F: Field> {
|
||||
is_write: bool,
|
||||
address: F,
|
||||
timestamp: F,
|
||||
value: F,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MemoryOpTarget {
|
||||
@ -39,14 +47,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
rhs: Target,
|
||||
bits: usize,
|
||||
num_chunks: usize,
|
||||
) -> (ComparisonGate<F, D>, usize) {
|
||||
) {
|
||||
let gate = ComparisonGate::new(bits, num_chunks);
|
||||
let gate_index = self.add_gate(gate.clone(), vec![]);
|
||||
|
||||
self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs);
|
||||
self.connect(Target::wire(gate_index, gate.wire_second_input()), rhs);
|
||||
|
||||
(gate, gate_index)
|
||||
}
|
||||
|
||||
/// Sort memory operations by address value, then by timestamp value.
|
||||
@ -94,29 +100,21 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
.map(|op| self.mul_add(op.address, two_n, op.timestamp))
|
||||
.collect();
|
||||
|
||||
let mut gates = Vec::new();
|
||||
let mut gate_indices = Vec::new();
|
||||
for i in 1..n {
|
||||
let (gate, gate_index) = self.assert_le(
|
||||
self.assert_le(
|
||||
address_timestamp_combined[i - 1],
|
||||
address_timestamp_combined[i],
|
||||
combined_bits,
|
||||
num_chunks,
|
||||
);
|
||||
|
||||
gate_indices.push(gate_index);
|
||||
gates.push(gate);
|
||||
}
|
||||
|
||||
self.assert_permutation_memory_ops(ops, output_targets.as_slice());
|
||||
self.assert_permutation_memory_ops(ops, &output_targets);
|
||||
|
||||
self.add_simple_generator(MemoryOpSortGenerator::<F, D> {
|
||||
input_ops: ops.to_vec(),
|
||||
gate_indices,
|
||||
gates: gates.clone(),
|
||||
output_ops: output_targets.clone(),
|
||||
address_bits,
|
||||
timestamp_bits,
|
||||
_phantom: PhantomData,
|
||||
});
|
||||
|
||||
output_targets
|
||||
@ -126,11 +124,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
#[derive(Debug)]
|
||||
struct MemoryOpSortGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
input_ops: Vec<MemoryOpTarget>,
|
||||
gate_indices: Vec<usize>,
|
||||
gates: Vec<ComparisonGate<F, D>>,
|
||||
output_ops: Vec<MemoryOpTarget>,
|
||||
address_bits: usize,
|
||||
timestamp_bits: usize,
|
||||
_phantom: PhantomData<F::Extension>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
@ -148,61 +143,43 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
let n = self.input_ops.len();
|
||||
debug_assert!(self.output_ops.len() == n);
|
||||
|
||||
let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self
|
||||
let mut ops: Vec<_> = self
|
||||
.input_ops
|
||||
.iter()
|
||||
.map(|op| {
|
||||
(
|
||||
witness.get_target(op.timestamp),
|
||||
witness.get_target(op.address),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
let combined_values: Vec<_> = timestamp_values
|
||||
.iter()
|
||||
.zip(&address_values)
|
||||
.map(|(&t, &a)| {
|
||||
F::from_canonical_u64(
|
||||
(a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(),
|
||||
)
|
||||
let is_write = witness.get_bool_target(op.is_write);
|
||||
let address = witness.get_target(op.address);
|
||||
let timestamp = witness.get_target(op.timestamp);
|
||||
let value = witness.get_target(op.value);
|
||||
MemoryOp {
|
||||
is_write,
|
||||
address,
|
||||
timestamp,
|
||||
value,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut input_ops_and_keys: Vec<_> = self
|
||||
.input_ops
|
||||
.iter()
|
||||
.zip(combined_values)
|
||||
.collect::<Vec<_>>();
|
||||
input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64());
|
||||
ops.sort_unstable_by_key(|op| {
|
||||
(
|
||||
op.address.to_canonical_u64(),
|
||||
op.timestamp.to_canonical_u64(),
|
||||
)
|
||||
});
|
||||
|
||||
for i in 0..n {
|
||||
out_buffer.set_target(
|
||||
self.output_ops[i].is_write.target,
|
||||
witness.get_target(input_ops_and_keys[i].0.is_write.target),
|
||||
);
|
||||
out_buffer.set_target(
|
||||
self.output_ops[i].address,
|
||||
witness.get_target(input_ops_and_keys[i].0.address),
|
||||
);
|
||||
out_buffer.set_target(
|
||||
self.output_ops[i].timestamp,
|
||||
witness.get_target(input_ops_and_keys[i].0.timestamp),
|
||||
);
|
||||
out_buffer.set_target(
|
||||
self.output_ops[i].value,
|
||||
witness.get_target(input_ops_and_keys[i].0.value),
|
||||
);
|
||||
for (op, out_op) in ops.iter().zip(&self.output_ops) {
|
||||
out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write));
|
||||
out_buffer.set_target(out_op.address, op.address);
|
||||
out_buffer.set_target(out_op.timestamp, op.timestamp);
|
||||
out_buffer.set_target(out_op.value, op.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use anyhow::Result;
|
||||
use rand::{seq::SliceRandom, thread_rng, Rng};
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::*;
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user