From c207a028520598f88b701b4093f6663af4060ce1 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 15 Sep 2021 16:41:29 -0700 Subject: [PATCH] changes and fixes (z --> most_significant_diff) --- src/gates/comparison.rs | 287 ++++++++++++++++++---------------------- 1 file changed, 127 insertions(+), 160 deletions(-) diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs index 8b8a5da4..9e8393f1 100644 --- a/src/gates/comparison.rs +++ b/src/gates/comparison.rs @@ -42,29 +42,28 @@ impl, const D: usize> ComparisonGate { 1 } - pub fn wire_z_bit(&self, bit_index: usize) -> usize { - debug_assert!(bit_index < self.chunk_bits() + 1); - 2 + bit_index + pub fn wire_most_significant_diff(&self) -> usize { + 2 } pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + self.chunk_bits() + chunk + 3 + chunk } pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + self.chunk_bits() + self.num_chunks + chunk + 3 + self.num_chunks + chunk } pub fn wire_equality_dummy(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + self.chunk_bits() + 2 * self.num_chunks + chunk + 3 + 2 * self.num_chunks + chunk } pub fn wire_chunks_equal(&self, chunk: usize) -> usize { debug_assert!(chunk < self.num_chunks); - 3 + self.chunk_bits() + 3 * self.num_chunks + chunk + 3 + 3 * self.num_chunks + chunk } } @@ -99,14 +98,15 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(first_chunks_combined - first_input); constraints.push(second_chunks_combined - second_input); - let mut most_significant_diff = F::Extension::ZERO; + let max_chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::Extension::ZERO; // Find the chosen chunk. for i in 0..self.num_chunks { - let max_chunk_size = 1 << self.chunk_bits(); let mut first_product = F::Extension::ONE; let mut second_product = F::Extension::ONE; - for x in 1..max_chunk_size { + for x in 0..max_chunk_size { let x_F = F::Extension::from_canonical_usize(x); first_product = first_product * (first_chunks[i] - x_F); second_product = second_product * (second_chunks[i] - x_F); @@ -114,7 +114,7 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(first_product); constraints.push(second_product); - let difference = first_chunks[i] - second_chunks[i]; + let difference = second_chunks[i] - first_chunks[i]; let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; @@ -122,20 +122,21 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); constraints.push(chunks_equal * difference); - let this_diff = first_chunks[i] - second_chunks[i]; - most_significant_diff = chunks_equal * most_significant_diff + let this_diff = second_chunks[i] - first_chunks[i]; + most_significant_diff_so_far = chunks_equal * most_significant_diff_so_far + (F::Extension::ONE - chunks_equal) * this_diff; } - let z_bits: Vec = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_z_bit(i)]) - .collect(); - let z_bits_combined = reduce_with_powers(&z_bits, F::Extension::TWO); + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); - let two_n = F::Extension::TWO.exp_u64(self.chunk_bits() as u64); - constraints.push(z_bits_combined - (two_n + most_significant_diff)); - - constraints.push(z_bits[self.chunk_bits()]); + // Range check + let mut product = F::Extension::ONE; + for x in 0..max_chunk_size { + let x_F = F::Extension::from_canonical_usize(x); + product = product * (most_significant_diff - x_F); + } + constraints.push(product); constraints } @@ -166,14 +167,15 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(first_chunks_combined - first_input); constraints.push(second_chunks_combined - second_input); - let mut most_significant_diff = F::ZERO; + let max_chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::ZERO; // Find the chosen chunk. for i in 0..self.num_chunks { - let max_chunk_size = 1 << self.chunk_bits(); let mut first_product = F::ONE; let mut second_product = F::ONE; - for x in 1..max_chunk_size { + for x in 0..max_chunk_size { let x_F = F::from_canonical_usize(x); first_product = first_product * (first_chunks[i] - x_F); second_product = second_product * (second_chunks[i] - x_F); @@ -190,19 +192,20 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(chunks_equal * difference); let this_diff = first_chunks[i] - second_chunks[i]; - most_significant_diff = - chunks_equal * most_significant_diff + (F::ONE - chunks_equal) * this_diff; + most_significant_diff_so_far = + chunks_equal * most_significant_diff_so_far + (F::ONE - chunks_equal) * this_diff; } - let z_bits: Vec = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_z_bit(i)]) - .collect(); - let z_bits_combined = reduce_with_powers(&z_bits, F::TWO); + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); - let two_n = F::TWO.exp_u64(self.chunk_bits() as u64); - constraints.push(z_bits_combined - (two_n + most_significant_diff)); - - constraints.push(z_bits[self.chunk_bits()]); + // Range check + let mut product = F::ONE; + for x in 0..max_chunk_size { + let x_F = F::from_canonical_usize(x); + product = product * (most_significant_diff - x_F); + } + constraints.push(product); constraints } @@ -234,15 +237,16 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(builder.sub_extension(first_chunks_combined, first_input)); constraints.push(builder.sub_extension(second_chunks_combined, second_input)); - let mut most_significant_diff = builder.zero_extension(); + let max_chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = builder.zero_extension(); let one = builder.one_extension(); // Find the chosen chunk. for i in 0..self.num_chunks { - let max_chunk_size = 1 << self.chunk_bits(); let mut first_product = one; let mut second_product = one; - for x in 1..max_chunk_size { + for x in 0..max_chunk_size { let x_F = builder.constant_extension(F::Extension::from_canonical_usize(x)); let first_diff = builder.sub_extension(first_chunks[i], x_F); let second_diff = builder.sub_extension(second_chunks[i], x_F); @@ -252,7 +256,7 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(first_product); constraints.push(second_product); - let difference = builder.sub_extension(first_chunks[i], second_chunks[i]); + let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; @@ -262,25 +266,24 @@ impl, const D: usize> Gate for ComparisonGate constraints.push(builder.sub_extension(diff_times_equal, not_equal)); constraints.push(builder.mul_extension(chunks_equal, difference)); - let this_diff = builder.sub_extension(first_chunks[i], second_chunks[i]); - let old_diff = builder.mul_extension(chunks_equal, most_significant_diff); + let this_diff = builder.sub_extension(second_chunks[i], first_chunks[i]); + let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); let not_equal = builder.sub_extension(one, chunks_equal); let new_diff = builder.mul_extension(not_equal, this_diff); - most_significant_diff = builder.add_extension(old_diff, new_diff); + most_significant_diff_so_far = builder.add_extension(old_diff, new_diff); } - let two = builder.constant(F::TWO); - let z_bits: Vec> = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_z_bit(i)]) - .collect(); - let z_bits_combined = reduce_with_powers_ext_recursive(builder, &z_bits, two); + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); - let two_n = builder.constant_extension(F::Extension::TWO.exp_u64(self.chunk_bits() as u64)); - let expected_z = builder.add_extension(two_n, most_significant_diff); - let z_diff = builder.sub_extension(z_bits_combined, expected_z); - constraints.push(z_diff); - - constraints.push(z_bits[self.chunk_bits()]); + // Range check + let mut product = builder.one_extension(); + for x in 0..max_chunk_size { + let x_F = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(most_significant_diff, x_F); + product = builder.mul_extension(product, diff); + } + constraints.push(product); constraints } @@ -378,7 +381,7 @@ impl, const D: usize> SimpleGenerator let equality_dummies: Vec = first_input_chunks .iter() .zip(second_input_chunks.iter()) - .map(|(f, s)| if *f == *s { F::ONE } else { F::ONE / (*f - *s) }) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) .collect(); let mut diff_index = 0; @@ -387,21 +390,10 @@ impl, const D: usize> SimpleGenerator diff_index = i; } } - let most_significant_diff = - first_input_chunks[diff_index] - second_input_chunks[diff_index]; + + let most_significant_diff = second_input_chunks[diff_index] - first_input_chunks[diff_index]; - let z = F::TWO.exp_u64(self.gate.chunk_bits() as u64) + most_significant_diff; - let z_bits: Vec = (0..self.gate.chunk_bits() + 1) - .scan(z.to_canonical_u64(), |acc, _| { - let tmp = *acc % 2; - *acc /= 2; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - for b in 0..self.gate.chunk_bits() + 1 { - out_buffer.set_wire(local_wire(self.gate.wire_z_bit(b)), z_bits[b]); - } + out_buffer.set_wire(local_wire(self.gate.wire_most_significant_diff()), most_significant_diff); for i in 0..self.gate.num_chunks { out_buffer.set_wire( local_wire(self.gate.wire_first_chunk_val(i)), @@ -411,11 +403,11 @@ impl, const D: usize> SimpleGenerator local_wire(self.gate.wire_second_chunk_val(i)), second_input_chunks[i], ); - out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); out_buffer.set_wire( local_wire(self.gate.wire_equality_dummy(i)), equality_dummies[i], ); + out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); } } } @@ -451,16 +443,15 @@ mod tests { assert_eq!(gate.wire_first_input(), 0); assert_eq!(gate.wire_second_input(), 1); - assert_eq!(gate.wire_z_bit(0), 2); - assert_eq!(gate.wire_z_bit(8), 10); - assert_eq!(gate.wire_first_chunk_val(0), 11); - assert_eq!(gate.wire_first_chunk_val(4), 15); - assert_eq!(gate.wire_second_chunk_val(0), 16); - assert_eq!(gate.wire_second_chunk_val(4), 20); - assert_eq!(gate.wire_equality_dummy(0), 21); - assert_eq!(gate.wire_equality_dummy(4), 25); - assert_eq!(gate.wire_chunks_equal(0), 26); - assert_eq!(gate.wire_chunks_equal(4), 30); + assert_eq!(gate.wire_most_significant_diff(), 2); + assert_eq!(gate.wire_first_chunk_val(0), 3); + assert_eq!(gate.wire_first_chunk_val(4), 7); + assert_eq!(gate.wire_second_chunk_val(0), 8); + assert_eq!(gate.wire_second_chunk_val(4), 12); + assert_eq!(gate.wire_equality_dummy(0), 13); + assert_eq!(gate.wire_equality_dummy(4), 17); + assert_eq!(gate.wire_chunks_equal(0), 18); + assert_eq!(gate.wire_chunks_equal(4), 22); } #[test] @@ -485,107 +476,83 @@ mod tests { type FF = QuarticExtension; const D: usize = 4; - let num_copies = 3; let num_bits = 40; let num_chunks = 5; let chunk_bits = num_bits / num_chunks; // Returns the local wires for a comparison gate given the two inputs. - let get_wires = |first_inputs: Vec, second_inputs: Vec| -> Vec { - let num_copies = first_inputs.len(); - + let get_wires = |first_input: F, second_input: F| -> Vec { let mut v = Vec::new(); - for c in 0..num_copies { - let first_input = first_inputs[c]; - let second_input = second_inputs[c]; - let first_input_u64 = first_input.to_canonical_u64(); - let second_input_u64 = second_input.to_canonical_u64(); + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); - let first_input_bits: Vec = (0..num_bits) - .scan(first_input_u64, |acc, _| { - let tmp = *acc % 2; - *acc /= 2; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - let second_input_bits: Vec = (0..num_bits) - .scan(second_input_u64, |acc, _| { - let tmp = *acc % 2; - *acc /= 2; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); + let first_input_bits: Vec = (0..num_bits) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let second_input_bits: Vec = (0..num_bits) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); - let mut first_input_chunks: Vec = first_input_bits - .chunks(chunk_bits) - .map(|bits| reduce_with_powers(&bits, F::TWO)) - .collect(); - let mut second_input_chunks: Vec = second_input_bits - .chunks(chunk_bits) - .map(|bits| reduce_with_powers(&bits, F::TWO)) - .collect(); + let mut first_input_chunks: Vec = first_input_bits + .chunks(chunk_bits) + .map(|bits| reduce_with_powers(&bits, F::TWO)) + .collect(); + let mut second_input_chunks: Vec = second_input_bits + .chunks(chunk_bits) + .map(|bits| reduce_with_powers(&bits, F::TWO)) + .collect(); - let mut chunks_equal: Vec = (0..num_chunks) - .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) - .collect(); - let mut equality_dummies: Vec = first_input_chunks - .iter() - .zip(second_input_chunks.iter()) - .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (f - s) }) - .collect(); + let mut chunks_equal: Vec = (0..num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let mut equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); - let mut diff_index = 0; - for i in 1..num_chunks { - if first_input_chunks[i] != second_input_chunks[i] { - diff_index = i; - } + let mut diff_index = 0; + for i in 1..num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + diff_index = i; } - let most_significant_diff = - first_input_chunks[diff_index] - second_input_chunks[diff_index]; - - let z = F::TWO.exp_u64(chunk_bits as u64) + most_significant_diff; - let mut z_bits: Vec = (0..chunk_bits + 1) - .scan(z.to_canonical_u64(), |acc, _| { - let tmp = *acc % 2; - *acc /= 2; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - v.push(first_input); - v.push(second_input); - v.append(&mut z_bits); - v.append(&mut first_input_chunks); - v.append(&mut second_input_chunks); - v.append(&mut equality_dummies); - v.append(&mut chunks_equal); } + let most_significant_diff = second_input_chunks[diff_index] - first_input_chunks[diff_index]; + + v.push(first_input); + v.push(second_input); + v.push(most_significant_diff); + v.append(&mut first_input_chunks); + v.append(&mut second_input_chunks); + v.append(&mut equality_dummies); + v.append(&mut chunks_equal); + v.iter().map(|&x| x.into()).collect::>() }; let mut rng = rand::thread_rng(); let max: u64 = 1 << num_bits - 1; - let first_inputs_u64: Vec = (0..num_copies).map(|_| rng.gen_range(0..max)).collect(); - let second_inputs_u64: Vec = (0..num_copies) - .map(|i| { - let mut val = rng.gen_range(0..max); - while val <= first_inputs_u64[i] { - val = rng.gen_range(0..max); - } - val - }) - .collect(); + let first_input_u64 = rng.gen_range(0..max); + let second_input_u64 = { + let mut val = rng.gen_range(0..max); + while val <= first_input_u64 { + val = rng.gen_range(0..max); + } + val + }; - let first_inputs = first_inputs_u64 - .iter() - .map(|&x| F::from_canonical_u64(x)) - .collect(); - let second_inputs = second_inputs_u64 - .iter() - .map(|&x| F::from_canonical_u64(x)) - .collect(); + let first_input = F::from_canonical_u64(first_input_u64); + let second_input = F::from_canonical_u64(second_input_u64); let gate = ComparisonGate:: { num_bits, @@ -595,7 +562,7 @@ mod tests { let vars = EvaluationVars { local_constants: &[], - local_wires: &get_wires(first_inputs, second_inputs), + local_wires: &get_wires(first_input, second_input), public_inputs_hash: &HashOut::rand(), };