diff --git a/src/gadgets/sorting.rs b/src/gadgets/sorting.rs index 57f65a51..21601988 100644 --- a/src/gadgets/sorting.rs +++ b/src/gadgets/sorting.rs @@ -159,62 +159,40 @@ impl, const D: usize> SimpleGenerator }) .unzip(); - let combined_values_u64: Vec<_> = timestamp_values + let combined_values: Vec<_> = timestamp_values .iter() - .zip(address_values.iter()) + .zip(&address_values) .map(|(&t, &a)| { - a.to_canonical_u64() * (1 << self.timestamp_bits as u64) + t.to_canonical_u64() + F::from_canonical_u64( + (a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(), + ) }) .collect(); - let mut combined_values_sorted = combined_values_u64.clone(); - combined_values_sorted.sort(); - let combined_values: Vec<_> = combined_values_sorted - .iter() - .map(|&x| F::from_canonical_u64(x)) - .collect(); - let mut input_ops_and_keys: Vec<_> = self .input_ops .iter() - .zip(combined_values_u64) + .zip(combined_values) .collect::>(); - input_ops_and_keys.sort_by(|(_, a_val), (_, b_val)| a_val.cmp(b_val)); - let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(op, _)| op).collect(); + input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64()); for i in 0..n { out_buffer.set_target( self.output_ops[i].is_write.target, - witness.get_target(input_ops_sorted[i].is_write.target), + witness.get_target(input_ops_and_keys[i].0.is_write.target), ); out_buffer.set_target( self.output_ops[i].address, - witness.get_target(input_ops_sorted[i].address), + witness.get_target(input_ops_and_keys[i].0.address), ); out_buffer.set_target( self.output_ops[i].timestamp, - witness.get_target(input_ops_sorted[i].timestamp), + witness.get_target(input_ops_and_keys[i].0.timestamp), ); out_buffer.set_target( self.output_ops[i].value, - witness.get_target(input_ops_sorted[i].value), + witness.get_target(input_ops_and_keys[i].0.value), ); - - if i > 0 { - out_buffer.set_target( - Target::wire( - self.gate_indices[i - 1], - self.gates[i - 1].wire_second_input(), - ), - combined_values[i], - ); - } - if i < n - 1 { - out_buffer.set_target( - Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()), - combined_values[i], - ); - } } } } @@ -228,7 +206,7 @@ mod tests { use super::*; use crate::field::crandall_field::CrandallField; - use crate::field::field_types::Field; + use crate::field::field_types::{Field, PrimeField}; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; @@ -239,7 +217,7 @@ mod tests { let config = CircuitConfig::large_zk_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let mut rng = thread_rng(); @@ -252,19 +230,42 @@ mod tests { .collect(); let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect(); - let input_ops: Vec = - izip!(is_write_vals, address_vals, timestamp_vals, value_vals) - .map(|(is_write, address, timestamp, value)| MemoryOpTarget { - is_write: builder.constant_bool(is_write), - address: builder.constant(address), - timestamp: builder.constant(timestamp), - value: builder.constant(value), - }) - .collect(); + let input_ops: Vec = izip!( + is_write_vals.clone(), + address_vals.clone(), + timestamp_vals.clone(), + value_vals.clone() + ) + .map(|(is_write, address, timestamp, value)| MemoryOpTarget { + is_write: builder.constant_bool(is_write), + address: builder.constant(address), + timestamp: builder.constant(timestamp), + value: builder.constant(value), + }) + .collect(); - let _output_ops = + let combined_vals_u64: Vec<_> = timestamp_vals + .iter() + .zip(&address_vals) + .map(|(&t, &a)| (a.to_canonical_u64() << timestamp_bits as u64) + t.to_canonical_u64()) + .collect(); + let mut input_ops_and_keys: Vec<_> = + izip!(is_write_vals, address_vals, timestamp_vals, value_vals) + .zip(combined_vals_u64) + .collect::>(); + input_ops_and_keys.sort_by_key(|(_, val)| val.clone()); + let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect(); + + let output_ops = builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits); + for i in 0..size { + pw.set_bool_target(output_ops[i].is_write, input_ops_sorted[i].0); + pw.set_target(output_ops[i].address, input_ops_sorted[i].1); + pw.set_target(output_ops[i].timestamp, input_ops_sorted[i].2); + pw.set_target(output_ops[i].value, input_ops_sorted[i].3); + } + let data = builder.build(); let proof = data.prove(pw).unwrap();