mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-07 08:13:11 +00:00
more efficient nonnative add and multi-add
This commit is contained in:
parent
facb5661f3
commit
50c24dfe8a
@ -40,6 +40,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
BigUintTarget { limbs }
|
||||
}
|
||||
|
||||
pub fn zero_biguint(&mut self) -> BigUintTarget {
|
||||
self.constant_biguint(&BigUint::zero())
|
||||
}
|
||||
|
||||
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
|
||||
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
|
||||
for i in 0..min_limbs {
|
||||
@ -159,6 +163,18 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mul_biguint_by_bool(
|
||||
&mut self,
|
||||
a: &BigUintTarget,
|
||||
b: BoolTarget,
|
||||
) -> BigUintTarget {
|
||||
let t = b.target;
|
||||
|
||||
BigUintTarget {
|
||||
limbs: a.limbs.iter().map(|l| U32Target(self.mul(l.0, t))).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function).
|
||||
pub fn mul_add_biguint(
|
||||
&mut self,
|
||||
@ -396,11 +412,11 @@ mod tests {
|
||||
let y = builder.constant_biguint(&y_value);
|
||||
let (div, rem) = builder.div_rem_biguint(&x, &y);
|
||||
|
||||
// let expected_div = builder.constant_biguint(&expected_div_value);
|
||||
// let expected_rem = builder.constant_biguint(&expected_rem_value);
|
||||
let expected_div = builder.constant_biguint(&expected_div_value);
|
||||
let expected_rem = builder.constant_biguint(&expected_rem_value);
|
||||
|
||||
// builder.connect_biguint(&div, &expected_div);
|
||||
// builder.connect_biguint(&rem, &expected_rem);
|
||||
builder.connect_biguint(&div, &expected_div);
|
||||
builder.connect_biguint(&rem, &expected_rem);
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
@ -100,6 +100,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
p1: &AffinePointTarget<C>,
|
||||
p2: &AffinePointTarget<C>,
|
||||
) -> AffinePointTarget<C> {
|
||||
let before = self.num_gates();
|
||||
let AffinePointTarget { x: x1, y: y1 } = p1;
|
||||
let AffinePointTarget { x: x2, y: y2 } = p2;
|
||||
|
||||
@ -123,6 +124,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
|
||||
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
|
||||
|
||||
println!("NUM GATES: {}", self.num_gates() - before);
|
||||
AffinePointTarget {
|
||||
x: x3_norm,
|
||||
y: y3_norm,
|
||||
@ -310,7 +312,6 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_curve_mul() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
@ -345,7 +346,6 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_curve_random() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
|
||||
@ -60,16 +60,45 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
// Add two `NonNativeTarget`s.
|
||||
pub fn add_nonnative<FF: Field>(
|
||||
&mut self,
|
||||
a: &NonNativeTarget<FF>,
|
||||
b: &NonNativeTarget<FF>,
|
||||
) -> NonNativeTarget<FF> {
|
||||
let result = self.add_biguint(&a.value, &b.value);
|
||||
let sum = self.add_virtual_nonnative_target::<FF>();
|
||||
let overflow = self.add_virtual_bool_target();
|
||||
|
||||
// TODO: reduce add result with only one conditional subtraction
|
||||
self.reduce(&result)
|
||||
self.add_simple_generator(NonNativeAdditionGenerator::<F, D, FF> {
|
||||
a: a.clone(),
|
||||
b: b.clone(),
|
||||
sum: sum.clone(),
|
||||
overflow: overflow.clone(),
|
||||
_phantom: PhantomData,
|
||||
});
|
||||
|
||||
let sum_expected = self.add_biguint(&a.value, &b.value);
|
||||
|
||||
let modulus = self.constant_biguint(&FF::order());
|
||||
let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow);
|
||||
let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow);
|
||||
self.connect_biguint(&sum_expected, &sum_actual);
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
pub fn mul_nonnative_by_bool<FF: Field>(
|
||||
&mut self,
|
||||
a: &NonNativeTarget<FF>,
|
||||
b: BoolTarget,
|
||||
) -> NonNativeTarget<FF> {
|
||||
let t = b.target;
|
||||
|
||||
NonNativeTarget {
|
||||
value: BigUintTarget {
|
||||
limbs: a.value.limbs.iter().map(|l| U32Target(self.mul(l.0, t))).collect()
|
||||
},
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_many_nonnative<FF: Field>(
|
||||
@ -80,12 +109,28 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
return to_add[0].clone();
|
||||
}
|
||||
|
||||
let mut result = self.add_biguint(&to_add[0].value, &to_add[1].value);
|
||||
for i in 2..to_add.len() {
|
||||
result = self.add_biguint(&result, &to_add[i].value);
|
||||
}
|
||||
let sum = self.add_virtual_nonnative_target::<FF>();
|
||||
let overflow = self.add_virtual_u32_target();
|
||||
let summands = to_add.to_vec();
|
||||
|
||||
self.reduce(&result)
|
||||
self.add_simple_generator(NonNativeMultipleAddsGenerator::<F, D, FF> {
|
||||
summands: summands.clone(),
|
||||
sum: sum.clone(),
|
||||
overflow: overflow.clone(),
|
||||
_phantom: PhantomData,
|
||||
});
|
||||
|
||||
let sum_expected = summands.iter().fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value));
|
||||
|
||||
let modulus = self.constant_biguint(&FF::order());
|
||||
let overflow_biguint = BigUintTarget {
|
||||
limbs: vec![overflow],
|
||||
};
|
||||
let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint);
|
||||
let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow);
|
||||
self.connect_biguint(&sum_expected, &sum_actual);
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
// Subtract two `NonNativeTarget`s.
|
||||
@ -188,59 +233,6 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `x % |FF|` as a `NonNativeTarget`.
|
||||
/*fn reduce_by_bits<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
|
||||
let before = self.num_gates();
|
||||
|
||||
let mut powers_of_two = Vec::new();
|
||||
let mut cur_power_of_two = FF::ONE;
|
||||
let two = FF::TWO;
|
||||
let mut max_num_limbs = 0;
|
||||
for _ in 0..(x.limbs.len() * 32) {
|
||||
let cur_power = self.constant_biguint(&cur_power_of_two.to_biguint());
|
||||
max_num_limbs = max_num_limbs.max(cur_power.limbs.len());
|
||||
powers_of_two.push(cur_power.limbs);
|
||||
|
||||
cur_power_of_two *= two;
|
||||
}
|
||||
|
||||
let mut result_limbs_unreduced = vec![self.zero(); max_num_limbs];
|
||||
for i in 0..x.limbs.len() {
|
||||
let this_limb = x.limbs[i];
|
||||
let bits = self.split_le(this_limb.0, 32);
|
||||
for b in 0..bits.len() {
|
||||
let this_power = powers_of_two[32 * i + b].clone();
|
||||
for x in 0..this_power.len() {
|
||||
result_limbs_unreduced[x] = self.mul_add(bits[b].target, this_power[x].0, result_limbs_unreduced[x]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut result_limbs_reduced = Vec::new();
|
||||
let mut carry = self.zero_u32();
|
||||
for i in 0..result_limbs_unreduced.len() {
|
||||
println!("{}", i);
|
||||
let (low, high) = self.split_to_u32(result_limbs_unreduced[i]);
|
||||
let (cur, overflow) = self.add_u32(carry, low);
|
||||
let (new_carry, _) = self.add_many_u32(&[overflow, high, carry]);
|
||||
result_limbs_reduced.push(cur);
|
||||
carry = new_carry;
|
||||
}
|
||||
result_limbs_reduced.push(carry);
|
||||
|
||||
let value = BigUintTarget {
|
||||
limbs: result_limbs_reduced,
|
||||
};
|
||||
|
||||
println!("NUMBER OF GATES: {}", self.num_gates() - before);
|
||||
println!("OUTPUT LIMBS: {}", value.limbs.len());
|
||||
|
||||
NonNativeTarget {
|
||||
value,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}*/
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
|
||||
let x_biguint = self.nonnative_to_biguint(x);
|
||||
@ -280,6 +272,74 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NonNativeAdditionGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
|
||||
a: NonNativeTarget<FF>,
|
||||
b: NonNativeTarget<FF>,
|
||||
sum: NonNativeTarget<FF>,
|
||||
overflow: BoolTarget,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
|
||||
for NonNativeAdditionGenerator<F, D, FF>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.a.value.limbs.iter().cloned().chain(self.b.value.limbs.clone())
|
||||
.map(|l| l.0)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let a = witness.get_nonnative_target(self.a.clone());
|
||||
let b = witness.get_nonnative_target(self.b.clone());
|
||||
let a_biguint = a.to_biguint();
|
||||
let b_biguint = b.to_biguint();
|
||||
let sum_biguint = a_biguint + b_biguint;
|
||||
let modulus = FF::order();
|
||||
let (overflow, sum_reduced) = if sum_biguint > modulus {
|
||||
(true, sum_biguint - modulus)
|
||||
} else {
|
||||
(false, sum_biguint)
|
||||
};
|
||||
|
||||
out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced);
|
||||
out_buffer.set_bool_target(self.overflow, overflow);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NonNativeMultipleAddsGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
|
||||
summands: Vec<NonNativeTarget<FF>>,
|
||||
sum: NonNativeTarget<FF>,
|
||||
overflow: U32Target,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
|
||||
for NonNativeMultipleAddsGenerator<F, D, FF>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.summands.iter().map(|summand| summand.value.limbs.iter().map(|limb| limb.0))
|
||||
.flatten()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let summands: Vec<_> = self.summands.iter().map(|summand| witness.get_nonnative_target(summand.clone())).collect();
|
||||
let summand_biguints: Vec<_> = summands.iter().map(|summand| summand.to_biguint()).collect();
|
||||
|
||||
let sum_biguint = summand_biguints.iter().fold(BigUint::zero(), |a, b| a + b.clone());
|
||||
|
||||
let modulus = FF::order();
|
||||
let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus);
|
||||
let overflow = overflow_biguint.to_u64_digits()[0] as u32;
|
||||
|
||||
out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced);
|
||||
out_buffer.set_u32_target(self.overflow, overflow);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
|
||||
x: NonNativeTarget<FF>,
|
||||
@ -310,6 +370,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
@ -11,6 +11,7 @@ pub mod binary_arithmetic;
|
||||
pub mod binary_subtraction;
|
||||
pub mod comparison;
|
||||
pub mod constant;
|
||||
// pub mod curve_double;
|
||||
pub mod exponentiation;
|
||||
pub mod gate;
|
||||
pub mod gate_tree;
|
||||
|
||||
@ -162,6 +162,10 @@ impl<F: Field> GeneratedValues<F> {
|
||||
self.target_values.push((target, value))
|
||||
}
|
||||
|
||||
pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) {
|
||||
self.set_target(target.target, F::from_bool(value))
|
||||
}
|
||||
|
||||
pub fn set_u32_target(&mut self, target: U32Target, value: u32) {
|
||||
self.set_target(target.0, F::from_canonical_u32(value))
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user