From 8dd00b8d41ed3b60d43b1fa17902d6f499571e71 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 17 Sep 2021 13:40:07 -0700 Subject: [PATCH] added generator --- src/gadgets/sorting.rs | 87 +++++++++++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 26 deletions(-) diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 1547295b..102f30cd 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; -use itertools::izip; +use itertools::{izip, Itertools}; -use crate::field::field_types::RichField; +use crate::field::field_types::{PrimeField, RichField}; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gates::comparison::ComparisonGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; @@ -10,6 +10,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +#[derive(Debug)] pub struct MemoryOpTarget { is_write: BoolTarget, address: Target, @@ -94,41 +95,75 @@ impl, const D: usize> CircuitBuilder { } } -/*#[derive(Debug)] -struct MemoryOpSortGenerator { - a: Vec>, - b: Vec>, - a_switches: Vec, - b_switches: Vec, +#[derive(Debug)] +struct MemoryOpSortGenerator { + input_ops: Vec, + output_ops: Vec, + address_bits: usize, + timestamp_bits: usize, _phantom: PhantomData, } -impl SimpleGenerator for MemoryOpSortGenerator { +impl SimpleGenerator for MemoryOpSortGenerator { fn dependencies(&self) -> Vec { - self.a.iter().chain(&self.b).flatten().cloned().collect() + self.input_ops + .iter() + .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) + .flatten() + .collect() } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a_values = self - .a + let n = self.input_ops.len(); + debug_assert!(self.output_ops.len() == n); + + let (timestamp_values, address_values): (Vec<_>, Vec<_>) = self + .input_ops .iter() - .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) - .collect(); - let b_values = self - .b + .map(|op| { + ( + witness.get_target(op.timestamp), + witness.get_target(op.address), + ) + }) + .unzip(); + + let combined_values_u64: Vec<_> = timestamp_values .iter() - .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) + .zip(address_values.iter()) + .map(|(&t, &a)| { + a.to_canonical_u64() * (1 << self.timestamp_bits as u64) + t.to_canonical_u64() + }) .collect(); - route( - a_values, - b_values, - self.a_switches.clone(), - self.b_switches.clone(), - witness, - out_buffer, - ); + + let mut input_ops_and_keys: Vec<_> = self + .input_ops + .iter() + .zip(combined_values_u64) + .collect::>(); + input_ops_and_keys.sort_by(|(_, a_val), (_, b_val)| a_val.cmp(b_val)); + let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(op, _)| op).collect(); + + for i in 0..n { + out_buffer.set_target( + self.output_ops[i].is_write.target, + witness.get_target(input_ops_sorted[i].is_write.target), + ); + out_buffer.set_target( + self.output_ops[i].address, + witness.get_target(input_ops_sorted[i].address), + ); + out_buffer.set_target( + self.output_ops[i].timestamp, + witness.get_target(input_ops_sorted[i].timestamp), + ); + out_buffer.set_target( + self.output_ops[i].value, + witness.get_target(input_ops_sorted[i].value), + ); + } } -}*/ +} #[cfg(test)] mod tests {