BigUint arithmetic, and cleanup

This commit is contained in:
Nicholas Ward 2021-10-15 16:47:29 -07:00
parent 72aea53d13
commit 9077c7fa3c
6 changed files with 119 additions and 49 deletions

View File

@ -1,8 +1,6 @@
use std::marker::PhantomData;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS};
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;

82
src/gadgets/biguint.rs Normal file
View File

@ -0,0 +1,82 @@
use std::marker::PhantomData;
use num::BigUint;
use crate::field::field_types::RichField;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gadgets::arithmetic_u32::U32Target;
use crate::plonk::circuit_builder::CircuitBuilder;
pub struct BigUintTarget {
limbs: Vec<U32Target>,
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// Add two `BigUintTarget`s.
pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for i in 0..num_limbs {
let (new_limb, new_carry) =
self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone());
carry = new_carry;
combined_limbs.push(new_limb);
}
combined_limbs[num_limbs] = carry;
BigUintTarget {
limbs: combined_limbs,
}
}
// Subtract two `BigUintTarget`s. We assume that the first is larger than the second.
pub fn sub_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
let mut result_limbs = vec![];
let mut borrow = self.zero_u32();
for i in 0..num_limbs {
let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow);
result_limbs[i] = result;
borrow = new_borrow;
}
// Borrow should be zero here.
BigUintTarget {
limbs: result_limbs,
}
}
pub fn mul_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
let mut to_add = vec![vec![]; 2 * num_limbs];
for i in 0..num_limbs {
for j in 0..num_limbs {
let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]);
to_add[i + j].push(product);
to_add[i + j + 1].push(carry);
}
}
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for i in 0..2 * num_limbs {
to_add[i].push(carry);
let (new_result, new_carry) = self.add_many_u32(to_add[i].clone());
combined_limbs.push(new_result);
carry = new_carry;
}
combined_limbs.push(carry);
BigUintTarget {
limbs: combined_limbs,
}
}
}

View File

@ -1,11 +1,12 @@
pub mod arithmetic;
pub mod arithmetic_extension;
pub mod arithmetic_u32;
pub mod biguint;
pub mod hash;
pub mod insert;
pub mod interpolation;
pub mod multiple_comparison;
pub mod nonnative;
//pub mod nonnative;
pub mod permutation;
pub mod polynomial;
pub mod random_access;

View File

@ -1,19 +1,13 @@
use std::marker::PhantomData;
use itertools::izip;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::field::field_types::RichField;
use crate::gates::comparison::ComparisonGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Returns true if a is less than or equal to b, considered as limbs of a large value.
pub fn compare_lists(&mut self, a: Vec<Target>, b: Vec<Target>, num_bits: usize) -> BoolTarget {
pub fn list_le(&mut self, a: Vec<Target>, b: Vec<Target>, num_bits: usize) -> BoolTarget {
assert_eq!(
a.len(),
b.len(),
@ -22,7 +16,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let n = a.len();
let chunk_size = 4;
let num_chunks = ceil_div_usize(num_bits, 4);
let num_chunks = ceil_div_usize(num_bits, chunk_size);
let one = self.one();
let mut result = self.one();

View File

@ -73,8 +73,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let chunk_size = a1.len();
let (gate, gate_index, mut next_copy) =
self.find_switch_gate(chunk_size);
let (gate, gate_index, mut next_copy) = self.find_switch_gate(chunk_size);
let num_copies = gate.num_copies;

View File

@ -13,7 +13,7 @@ use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Number of arithmetic operations performed by an arithmetic gate.
/// Maximum number of subtractions operations performed by a single gate.
pub const NUM_U32_SUBTRACTION_OPS: usize = 3;
/// A gate to perform a subtraction .
@ -28,7 +28,7 @@ impl<F: RichField + Extendable<D>, const D: usize> U32SubtractionGate<F, D> {
_phantom: PhantomData,
}
}
pub fn wire_ith_input_x(i: usize) -> usize {
debug_assert!(i < NUM_U32_SUBTRACTION_OPS);
5 * i
@ -168,7 +168,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32Subtraction
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = builder.zero_extension();
let limb_base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
let limb_base = builder
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
@ -245,15 +246,15 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::with_capacity(3);
deps.push(local_target(
U32SubtractionGate::<F, D>::wire_ith_input_x(self.i),
));
deps.push(local_target(
U32SubtractionGate::<F, D>::wire_ith_input_y(self.i),
));
deps.push(local_target(U32SubtractionGate::<F, D>::wire_ith_input_borrow(
deps.push(local_target(U32SubtractionGate::<F, D>::wire_ith_input_x(
self.i,
)));
deps.push(local_target(U32SubtractionGate::<F, D>::wire_ith_input_y(
self.i,
)));
deps.push(local_target(
U32SubtractionGate::<F, D>::wire_ith_input_borrow(self.i),
));
deps
}
@ -265,11 +266,10 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let get_local_wire = |input| witness.get_wire(local_wire(input));
let input_x =
get_local_wire(U32SubtractionGate::<F, D>::wire_ith_input_x(self.i));
let input_y =
get_local_wire(U32SubtractionGate::<F, D>::wire_ith_input_y(self.i));
let input_borrow = get_local_wire(U32SubtractionGate::<F, D>::wire_ith_input_borrow(self.i));
let input_x = get_local_wire(U32SubtractionGate::<F, D>::wire_ith_input_x(self.i));
let input_y = get_local_wire(U32SubtractionGate::<F, D>::wire_ith_input_y(self.i));
let input_borrow =
get_local_wire(U32SubtractionGate::<F, D>::wire_ith_input_borrow(self.i));
let result_initial = input_x - input_y - input_borrow;
let result_initial_u64 = result_initial.to_canonical_u64();
@ -281,7 +281,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let base = F::from_canonical_u64(1 << 32u64);
let output_result = result_initial + base * output_borrow;
let output_result_wire =
local_wire(U32SubtractionGate::<F, D>::wire_ith_output_result(self.i));
let output_borrow_wire =
@ -295,12 +295,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
let limb_base = 1 << U32SubtractionGate::<F, D>::limb_bits();
let output_limbs: Vec<_> = (0..num_limbs)
.scan(output_result_u64, |acc, _| {
let tmp = *acc % limb_base;
*acc /= limb_base;
Some(F::from_canonical_u64(tmp))
})
.collect();
.scan(output_result_u64, |acc, _| {
let tmp = *acc % limb_base;
*acc /= limb_base;
Some(F::from_canonical_u64(tmp))
})
.collect();
for j in 0..num_limbs {
let wire = local_wire(U32SubtractionGate::<F, D>::wire_ith_output_jth_limb(
@ -321,9 +321,9 @@ mod tests {
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::{Field, PrimeField};
use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS};
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS};
use crate::hash::hash_types::HashOut;
use crate::plonk::vars::EvaluationVars;
@ -347,11 +347,7 @@ mod tests {
type FF = QuarticExtension<CrandallField>;
const D: usize = 4;
fn get_wires(
inputs_x: Vec<u64>,
inputs_y: Vec<u64>,
borrows: Vec<u64>,
) -> Vec<FF> {
fn get_wires(inputs_x: Vec<u64>, inputs_y: Vec<u64>, borrows: Vec<u64>) -> Vec<FF> {
let mut v0 = Vec::new();
let mut v1 = Vec::new();
@ -377,12 +373,12 @@ mod tests {
let output_result_u64 = output_result.to_canonical_u64();
let mut output_limbs: Vec<_> = (0..num_limbs)
.scan(output_result_u64, |acc, _| {
let tmp = *acc % limb_base;
*acc /= limb_base;
Some(F::from_canonical_u64(tmp))
})
.collect();
.scan(output_result_u64, |acc, _| {
let tmp = *acc % limb_base;
*acc /= limb_base;
Some(F::from_canonical_u64(tmp))
})
.collect();
v0.push(input_x);
v0.push(input_y);