From 9077c7fa3c517ffd0086e982585732b22aef3f09 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 15 Oct 2021 16:47:29 -0700 Subject: [PATCH] BigUint arithmetic, and cleanup --- src/gadgets/arithmetic_u32.rs | 4 +- src/gadgets/biguint.rs | 82 ++++++++++++++++++++++++++++++ src/gadgets/mod.rs | 3 +- src/gadgets/multiple_comparison.rs | 12 ++--- src/gadgets/permutation.rs | 3 +- src/gates/subtraction_u32.rs | 64 +++++++++++------------ 6 files changed, 119 insertions(+), 49 deletions(-) create mode 100644 src/gadgets/biguint.rs diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index e53c1761..ce35d4f0 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -1,8 +1,6 @@ -use std::marker::PhantomData; - use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; -use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; +use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs new file mode 100644 index 00000000..4adc2382 --- /dev/null +++ b/src/gadgets/biguint.rs @@ -0,0 +1,82 @@ +use std::marker::PhantomData; + +use num::BigUint; + +use crate::field::field_types::RichField; +use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +pub struct BigUintTarget { + limbs: Vec, +} + +impl, const D: usize> CircuitBuilder { + // Add two `BigUintTarget`s. + pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..num_limbs { + let (new_limb, new_carry) = + self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); + carry = new_carry; + combined_limbs.push(new_limb); + } + combined_limbs[num_limbs] = carry; + + BigUintTarget { + limbs: combined_limbs, + } + } + + // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + pub fn sub_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + result_limbs[i] = result; + borrow = new_borrow; + } + // Borrow should be zero here. + + BigUintTarget { + limbs: result_limbs, + } + } + + pub fn mul_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + let num_limbs = a.limbs.len(); + debug_assert!(b.limbs.len() == num_limbs); + + let mut to_add = vec![vec![]; 2 * num_limbs]; + for i in 0..num_limbs { + for j in 0..num_limbs { + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); + } + } + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..2 * num_limbs { + to_add[i].push(carry); + let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 9fb572c9..cf6f6ed4 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,11 +1,12 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod arithmetic_u32; +pub mod biguint; pub mod hash; pub mod insert; pub mod interpolation; pub mod multiple_comparison; -pub mod nonnative; +//pub mod nonnative; pub mod permutation; pub mod polynomial; pub mod random_access; diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs index 5291323d..11225ca5 100644 --- a/src/gadgets/multiple_comparison.rs +++ b/src/gadgets/multiple_comparison.rs @@ -1,19 +1,13 @@ -use std::marker::PhantomData; - -use itertools::izip; - use crate::field::extension_field::Extendable; -use crate::field::field_types::{Field, RichField}; +use crate::field::field_types::RichField; use crate::gates::comparison::ComparisonGate; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; impl, const D: usize> CircuitBuilder { /// Returns true if a is less than or equal to b, considered as limbs of a large value. - pub fn compare_lists(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { + pub fn list_le(&mut self, a: Vec, b: Vec, num_bits: usize) -> BoolTarget { assert_eq!( a.len(), b.len(), @@ -22,7 +16,7 @@ impl, const D: usize> CircuitBuilder { let n = a.len(); let chunk_size = 4; - let num_chunks = ceil_div_usize(num_bits, 4); + let num_chunks = ceil_div_usize(num_bits, chunk_size); let one = self.one(); let mut result = self.one(); diff --git a/src/gadgets/permutation.rs b/src/gadgets/permutation.rs index a0c9b087..ae0e411b 100644 --- a/src/gadgets/permutation.rs +++ b/src/gadgets/permutation.rs @@ -73,8 +73,7 @@ impl, const D: usize> CircuitBuilder { let chunk_size = a1.len(); - let (gate, gate_index, mut next_copy) = - self.find_switch_gate(chunk_size); + let (gate, gate_index, mut next_copy) = self.find_switch_gate(chunk_size); let num_copies = gate.num_copies; diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index de79b67f..fc2009be 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -13,7 +13,7 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Number of arithmetic operations performed by an arithmetic gate. +/// Maximum number of subtractions operations performed by a single gate. pub const NUM_U32_SUBTRACTION_OPS: usize = 3; /// A gate to perform a subtraction . @@ -28,7 +28,7 @@ impl, const D: usize> U32SubtractionGate { _phantom: PhantomData, } } - + pub fn wire_ith_input_x(i: usize) -> usize { debug_assert!(i < NUM_U32_SUBTRACTION_OPS); 5 * i @@ -168,7 +168,8 @@ impl, const D: usize> Gate for U32Subtraction // Range-check output_result to be at most 32 bits. let mut combined_limbs = builder.zero_extension(); - let limb_base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + let limb_base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); @@ -245,15 +246,15 @@ impl, const D: usize> SimpleGenerator let local_target = |input| Target::wire(self.gate_index, input); let mut deps = Vec::with_capacity(3); - deps.push(local_target( - U32SubtractionGate::::wire_ith_input_x(self.i), - )); - deps.push(local_target( - U32SubtractionGate::::wire_ith_input_y(self.i), - )); - deps.push(local_target(U32SubtractionGate::::wire_ith_input_borrow( + deps.push(local_target(U32SubtractionGate::::wire_ith_input_x( self.i, ))); + deps.push(local_target(U32SubtractionGate::::wire_ith_input_y( + self.i, + ))); + deps.push(local_target( + U32SubtractionGate::::wire_ith_input_borrow(self.i), + )); deps } @@ -265,11 +266,10 @@ impl, const D: usize> SimpleGenerator let get_local_wire = |input| witness.get_wire(local_wire(input)); - let input_x = - get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); - let input_y = - get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); - let input_borrow = get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); + let input_x = get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); + let input_y = get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); + let input_borrow = + get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); let result_initial = input_x - input_y - input_borrow; let result_initial_u64 = result_initial.to_canonical_u64(); @@ -281,7 +281,7 @@ impl, const D: usize> SimpleGenerator let base = F::from_canonical_u64(1 << 32u64); let output_result = result_initial + base * output_borrow; - + let output_result_wire = local_wire(U32SubtractionGate::::wire_ith_output_result(self.i)); let output_borrow_wire = @@ -295,12 +295,12 @@ impl, const D: usize> SimpleGenerator let num_limbs = U32SubtractionGate::::num_limbs(); let limb_base = 1 << U32SubtractionGate::::limb_bits(); let output_limbs: Vec<_> = (0..num_limbs) - .scan(output_result_u64, |acc, _| { - let tmp = *acc % limb_base; - *acc /= limb_base; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); for j in 0..num_limbs { let wire = local_wire(U32SubtractionGate::::wire_ith_output_jth_limb( @@ -321,9 +321,9 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::{Field, PrimeField}; - use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; use crate::hash::hash_types::HashOut; use crate::plonk::vars::EvaluationVars; @@ -347,11 +347,7 @@ mod tests { type FF = QuarticExtension; const D: usize = 4; - fn get_wires( - inputs_x: Vec, - inputs_y: Vec, - borrows: Vec, - ) -> Vec { + fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { let mut v0 = Vec::new(); let mut v1 = Vec::new(); @@ -377,12 +373,12 @@ mod tests { let output_result_u64 = output_result.to_canonical_u64(); let mut output_limbs: Vec<_> = (0..num_limbs) - .scan(output_result_u64, |acc, _| { - let tmp = *acc % limb_base; - *acc /= limb_base; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); v0.push(input_x); v0.push(input_y);