diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index ebad5025..e2794330 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -3,6 +3,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -99,6 +100,15 @@ impl> Field for QuadraticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (high, low) = n.div_rem(&F::order()); + Self([F::from_biguint(low), F::from_biguint(high)]) + } + + fn to_biguint(&self) -> BigUint { + self.0[0].to_biguint() + F::order() * self.0[1].to_biguint() + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 001da821..01918ff3 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num::bigint::BigUint; use num::traits::Pow; +use num::Integer; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -104,6 +105,26 @@ impl> Field for QuarticExtension { )) } + fn from_biguint(n: BigUint) -> Self { + let (rest, first) = n.div_rem(&F::order()); + let (rest, second) = rest.div_rem(&F::order()); + let (rest, third) = rest.div_rem(&F::order()); + Self([ + F::from_biguint(first), + F::from_biguint(second), + F::from_biguint(third), + F::from_biguint(rest), + ]) + } + + fn to_biguint(&self) -> BigUint { + let mut result = self.0[3].to_biguint(); + result = result * F::order() + self.0[2].to_biguint(); + result = result * F::order() + self.0[1].to_biguint(); + result = result * F::order() + self.0[0].to_biguint(); + result + } + fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 4fe10b17..481d87ba 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -206,6 +206,10 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } + fn from_biguint(n: BigUint) -> Self; + + fn to_biguint(&self) -> BigUint; + fn from_canonical_u64(n: u64) -> Self; fn from_canonical_u32(n: u32) -> Self { diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 45164506..cb85d56d 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher}; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use num::BigUint; +use num::{BigUint, Integer}; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -91,6 +91,14 @@ impl Field for GoldilocksField { try_inverse_u64(self) } + fn from_biguint(n: BigUint) -> Self { + Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) + } + + fn to_biguint(&self) -> BigUint { + self.to_canonical_u64().into() + } + #[inline] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); diff --git a/src/field/secp256k1.rs b/src/field/secp256k1.rs index 56d506d6..5f8e1b4e 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1.rs @@ -36,27 +36,6 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint { ]) } -impl Secp256K1Base { - fn to_canonical_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - - fn from_biguint(val: BigUint) -> Self { - Self( - val.to_u64_digits() - .into_iter() - .pad_using(4, |_| 0) - .collect::>()[..] - .try_into() - .expect("error converting to u64 array"), - ) - } -} - impl Default for Secp256K1Base { fn default() -> Self { Self::ZERO @@ -65,7 +44,7 @@ impl Default for Secp256K1Base { impl PartialEq for Secp256K1Base { fn eq(&self, other: &Self) -> bool { - self.to_canonical_biguint() == other.to_canonical_biguint() + self.to_biguint() == other.to_biguint() } } @@ -73,19 +52,19 @@ impl Eq for Secp256K1Base {} impl Hash for Secp256K1Base { fn hash(&self, state: &mut H) { - self.to_canonical_biguint().hash(state) + self.to_biguint().hash(state) } } impl Display for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_canonical_biguint(), f) + Display::fmt(&self.to_biguint(), f) } } impl Debug for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_canonical_biguint(), f) + Debug::fmt(&self.to_biguint(), f) } } @@ -129,6 +108,25 @@ impl Field for Secp256K1Base { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } + fn to_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } + + fn from_biguint(val: BigUint) -> Self { + Self( + val.to_u64_digits() + .into_iter() + .pad_using(4, |_| 0) + .collect::>()[..] + .try_into() + .expect("error converting to u64 array"), + ) + } + #[inline] fn from_canonical_u64(n: u64) -> Self { Self([n, 0, 0, 0]) @@ -157,7 +155,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_canonical_biguint()) + Self::from_biguint(Self::order() - self.to_biguint()) } } } @@ -167,7 +165,7 @@ impl Add for Secp256K1Base { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); + let mut result = self.to_biguint() + rhs.to_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -210,9 +208,7 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint( - (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), - ) + Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) } } diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 82bfd91e..9e00c562 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -1,14 +1,13 @@ use std::marker::PhantomData; -use std::ops::Neg; -use num::{BigUint, Zero}; +use num::Integer; +use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; -use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::PartitionWitness; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Debug)] @@ -197,5 +196,18 @@ impl, const D: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) {} + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); + let (div, rem) = a.div_rem(&b); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.rem.clone(), rem); + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_biguint_add() {} } diff --git a/src/iop/generator.rs b/src/iop/generator.rs index eb2c95f7..c395ad73 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -1,9 +1,13 @@ use std::fmt::Debug; use std::marker::PhantomData; +use num::BigUint; + use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::{Field, RichField}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -150,6 +154,17 @@ impl GeneratedValues { self.target_values.push((target, value)) } + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } + + pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { + let limbs = value.to_u32_digits(); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } + } + pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 13a374e2..c1f877cb 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -57,6 +57,19 @@ pub trait Witness { panic!("not a bool") } + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + let mut result = BigUint::zero(); + + let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); + for i in (0..target.num_limbs()).rev() { + let limb = target.get_limb(i); + result *= &limb_base; + result += self.get_target(limb.0).to_biguint(); + } + + result + } + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(),