addressed comments (set sorted values in partial witness; no more directly setting gate inputs)

This commit is contained in:
Nicholas Ward 2021-09-22 14:03:27 -07:00
parent 2ec3b29741
commit 8aa4376360

View File

@ -159,62 +159,40 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
})
.unzip();
let combined_values_u64: Vec<_> = timestamp_values
let combined_values: Vec<_> = timestamp_values
.iter()
.zip(address_values.iter())
.zip(&address_values)
.map(|(&t, &a)| {
a.to_canonical_u64() * (1 << self.timestamp_bits as u64) + t.to_canonical_u64()
F::from_canonical_u64(
(a.to_canonical_u64() << self.timestamp_bits as u64) + t.to_canonical_u64(),
)
})
.collect();
let mut combined_values_sorted = combined_values_u64.clone();
combined_values_sorted.sort();
let combined_values: Vec<_> = combined_values_sorted
.iter()
.map(|&x| F::from_canonical_u64(x))
.collect();
let mut input_ops_and_keys: Vec<_> = self
.input_ops
.iter()
.zip(combined_values_u64)
.zip(combined_values)
.collect::<Vec<_>>();
input_ops_and_keys.sort_by(|(_, a_val), (_, b_val)| a_val.cmp(b_val));
let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(op, _)| op).collect();
input_ops_and_keys.sort_by_key(|(_, val)| val.to_canonical_u64());
for i in 0..n {
out_buffer.set_target(
self.output_ops[i].is_write.target,
witness.get_target(input_ops_sorted[i].is_write.target),
witness.get_target(input_ops_and_keys[i].0.is_write.target),
);
out_buffer.set_target(
self.output_ops[i].address,
witness.get_target(input_ops_sorted[i].address),
witness.get_target(input_ops_and_keys[i].0.address),
);
out_buffer.set_target(
self.output_ops[i].timestamp,
witness.get_target(input_ops_sorted[i].timestamp),
witness.get_target(input_ops_and_keys[i].0.timestamp),
);
out_buffer.set_target(
self.output_ops[i].value,
witness.get_target(input_ops_sorted[i].value),
witness.get_target(input_ops_and_keys[i].0.value),
);
if i > 0 {
out_buffer.set_target(
Target::wire(
self.gate_indices[i - 1],
self.gates[i - 1].wire_second_input(),
),
combined_values[i],
);
}
if i < n - 1 {
out_buffer.set_target(
Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()),
combined_values[i],
);
}
}
}
}
@ -228,7 +206,7 @@ mod tests {
use super::*;
use crate::field::crandall_field::CrandallField;
use crate::field::field_types::Field;
use crate::field::field_types::{Field, PrimeField};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
@ -239,7 +217,7 @@ mod tests {
let config = CircuitConfig::large_zk_config();
let pw = PartialWitness::new();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut rng = thread_rng();
@ -252,19 +230,42 @@ mod tests {
.collect();
let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect();
let input_ops: Vec<MemoryOpTarget> =
izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
.map(|(is_write, address, timestamp, value)| MemoryOpTarget {
is_write: builder.constant_bool(is_write),
address: builder.constant(address),
timestamp: builder.constant(timestamp),
value: builder.constant(value),
})
.collect();
let input_ops: Vec<MemoryOpTarget> = izip!(
is_write_vals.clone(),
address_vals.clone(),
timestamp_vals.clone(),
value_vals.clone()
)
.map(|(is_write, address, timestamp, value)| MemoryOpTarget {
is_write: builder.constant_bool(is_write),
address: builder.constant(address),
timestamp: builder.constant(timestamp),
value: builder.constant(value),
})
.collect();
let _output_ops =
let combined_vals_u64: Vec<_> = timestamp_vals
.iter()
.zip(&address_vals)
.map(|(&t, &a)| (a.to_canonical_u64() << timestamp_bits as u64) + t.to_canonical_u64())
.collect();
let mut input_ops_and_keys: Vec<_> =
izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
.zip(combined_vals_u64)
.collect::<Vec<_>>();
input_ops_and_keys.sort_by_key(|(_, val)| val.clone());
let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect();
let output_ops =
builder.sort_memory_ops(input_ops.as_slice(), address_bits, timestamp_bits);
for i in 0..size {
pw.set_bool_target(output_ops[i].is_write, input_ops_sorted[i].0);
pw.set_target(output_ops[i].address, input_ops_sorted[i].1);
pw.set_target(output_ops[i].timestamp, input_ops_sorted[i].2);
pw.set_target(output_ops[i].value, input_ops_sorted[i].3);
}
let data = builder.build();
let proof = data.prove(pw).unwrap();