fixes galore

This commit is contained in:
Nicholas Ward 2021-09-21 18:01:21 -07:00
parent 3d93766cc8
commit 644d87e495
4 changed files with 140 additions and 28 deletions

View File

@ -384,7 +384,7 @@ impl<F: Field> SimpleGenerator<F> for PermutationGenerator<F> {
#[cfg(test)]
mod tests {
use anyhow::Result;
use rand::{seq::SliceRandom, thread_rng};
use rand::{seq::SliceRandom, thread_rng, Rng};
use super::*;
use crate::field::crandall_field::CrandallField;
@ -418,6 +418,35 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
fn test_permutation_duplicates(size: usize) -> Result<()> {
type F = CrandallField;
const D: usize = 4;
let config = CircuitConfig::large_zk_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut rng = thread_rng();
let lst: Vec<F> = (0..size * 2)
.map(|_| F::from_canonical_usize(rng.gen_range(0..2usize)))
.collect();
let a: Vec<Vec<Target>> = lst[..]
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
.collect();
let mut b = a.clone();
b.shuffle(&mut thread_rng());
builder.assert_permutation(a, b);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
fn test_permutation_bad(size: usize) -> Result<()> {
type F = CrandallField;
const D: usize = 4;
@ -446,6 +475,15 @@ mod tests {
Ok(())
}
#[test]
fn test_permutations_duplicates() -> Result<()> {
for n in 2..9 {
test_permutation_duplicates(n)?;
}
Ok(())
}
#[test]
fn test_permutations_good() -> Result<()> {
for n in 2..9 {

View File

@ -1,16 +1,15 @@
use std::marker::PhantomData;
use itertools::izip;
use itertools::{izip, Itertools};
use crate::field::field_types::{PrimeField, RichField};
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::field::field_types::RichField;
use crate::field::extension_field::Extendable;
use crate::gates::comparison::ComparisonGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct MemoryOpTarget {
is_write: BoolTarget,
address: Target,
@ -41,7 +40,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let n = ops.len();
let combined_bits = address_bits + timestamp_bits;
let chunk_size = 3;
let chunk_bits = 3;
let num_chunks = ceil_div_usize(combined_bits, chunk_bits);
let is_write_targets: Vec<_> = self
.add_virtual_targets(n)
@ -69,12 +69,14 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let two_n = self.constant(F::from_canonical_usize(1 << timestamp_bits));
let address_timestamp_combined: Vec<_> = output_targets
.iter()
.map(|op| self.mul_add(op.timestamp, two_n, op.address))
.map(|op| self.mul_add(op.address, two_n, op.timestamp))
.collect();
let mut gate_indices = Vec::new();
let mut gates = Vec::new();
for i in 1..n {
let (gate, gate_index) = {
let gate = ComparisonGate::new(combined_bits, chunk_size);
let gate = ComparisonGate::new(combined_bits, num_chunks);
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index)
};
@ -87,24 +89,39 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
Target::wire(gate_index, gate.wire_second_input()),
address_timestamp_combined[i],
);
gate_indices.push(gate_index);
gates.push(gate);
}
self.assert_permutation_memory_ops(ops, output_targets.as_slice());
self.add_simple_generator(MemoryOpSortGenerator::<F, D> {
input_ops: ops.to_vec(),
gate_indices,
gates: gates.clone(),
output_ops: output_targets.clone(),
address_bits,
timestamp_bits,
});
output_targets
}
}
#[derive(Debug)]
struct MemoryOpSortGenerator<F: RichField> {
struct MemoryOpSortGenerator<F: RichField + Extendable<D>, const D: usize> {
input_ops: Vec<MemoryOpTarget>,
gate_indices: Vec<usize>,
gates: Vec<ComparisonGate<F, D>>,
output_ops: Vec<MemoryOpTarget>,
address_bits: usize,
timestamp_bits: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField> SimpleGenerator<F> for MemoryOpSortGenerator<F> {
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for MemoryOpSortGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
self.input_ops
.iter()
@ -136,6 +153,13 @@ impl<F: RichField> SimpleGenerator<F> for MemoryOpSortGenerator<F> {
})
.collect();
let mut combined_values_sorted = combined_values_u64.clone();
combined_values_sorted.sort();
let combined_values: Vec<_> = combined_values_sorted
.iter()
.map(|&x| F::from_canonical_u64(x))
.collect();
let mut input_ops_and_keys: Vec<_> = self
.input_ops
.iter()
@ -161,12 +185,30 @@ impl<F: RichField> SimpleGenerator<F> for MemoryOpSortGenerator<F> {
self.output_ops[i].value,
witness.get_target(input_ops_sorted[i].value),
);
if i > 0 {
out_buffer.set_target(
Target::wire(
self.gate_indices[i - 1],
self.gates[i - 1].wire_second_input(),
),
combined_values[i],
);
}
if i < n - 1 {
out_buffer.set_target(
Target::wire(self.gate_indices[i], self.gates[i].wire_first_input()),
combined_values[i],
);
}
}
}
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use anyhow::Result;
use rand::{seq::SliceRandom, thread_rng, Rng};
@ -223,4 +265,13 @@ mod tests {
test_sorting(size, address_bits, timestamp_bits)
}
#[test]
fn test_sorting_large() -> Result<()> {
let size = 20;
let address_bits = 20;
let timestamp_bits = 20;
test_sorting(size, address_bits, timestamp_bits)
}
}

View File

@ -386,7 +386,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let mut most_significant_diff_so_far = F::ZERO;
let mut intermediate_values = Vec::new();
for i in 1..self.gate.num_chunks {
for i in 0..self.gate.num_chunks {
if first_input_chunks[i] != second_input_chunks[i] {
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
intermediate_values.push(F::ZERO);

View File

@ -236,20 +236,43 @@ impl<F: RichField + Extendable<D>, const D: usize> SwitchGenerator<F, D> {
let get_local_wire = |input| witness.get_wire(local_wire(input));
for e in 0..self.gate.chunk_size {
let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy));
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
let first_output = get_local_wire(self.gate.wire_first_output(self.copy, e));
let second_output = get_local_wire(self.gate.wire_second_output(self.copy, e));
let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy));
if first_output == first_input && second_output == second_input {
out_buffer.set_wire(switch_bool_wire, F::ZERO);
} else if first_output == second_input && second_output == first_input {
out_buffer.set_wire(switch_bool_wire, F::ONE);
} else {
panic!("No permutation from given inputs to given outputs");
}
let mut first_inputs = Vec::new();
let mut second_inputs = Vec::new();
let mut first_outputs = Vec::new();
let mut second_outputs = Vec::new();
for e in 0..self.gate.chunk_size {
first_inputs.push(get_local_wire(self.gate.wire_first_input(self.copy, e)));
second_inputs.push(get_local_wire(self.gate.wire_second_input(self.copy, e)));
first_outputs.push(get_local_wire(self.gate.wire_first_output(self.copy, e)));
second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e)));
}
let first_keep = first_outputs
.iter()
.zip(first_inputs.iter())
.all(|(x, y)| x == y);
let second_keep = second_outputs
.iter()
.zip(second_inputs.iter())
.all(|(x, y)| x == y);
let first_swap = first_outputs
.iter()
.zip(second_inputs.iter())
.all(|(x, y)| x == y);
let second_swap = second_outputs
.iter()
.zip(first_inputs.iter())
.all(|(x, y)| x == y);
if first_keep && second_keep {
out_buffer.set_wire(switch_bool_wire, F::ZERO);
} else if first_swap && second_swap {
out_buffer.set_wire(switch_bool_wire, F::ONE);
} else {
panic!("No permutation from given inputs to given outputs");
}
}
@ -261,12 +284,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SwitchGenerator<F, D> {
let get_local_wire = |input| witness.get_wire(local_wire(input));
let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy));
for e in 0..self.gate.chunk_size {
let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e));
let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e));
let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e));
let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e));
let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy));
let (first_output, second_output) = if switch_bool == F::ZERO {
(first_input, second_input)