diff --git a/Cargo.toml b/Cargo.toml index c36e3023..cc070d96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,10 @@ [workspace] -members = ["field", "insertion", "plonky2", "util", "waksman"] +members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman"] + +[profile.release] +opt-level = 3 +#lto = "fat" +#codegen-units = 1 + +[profile.bench] +opt-level = 3 diff --git a/README.md b/README.md index f677a7c6..4dbd5906 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,24 @@ # Plonky2 -Plonky2 is an implementation of recursive arguments based on Plonk and FRI. It uses FRI to check systems of polynomial constraints, similar to the DEEP-ALI method described in the [DEEP-FRI](https://arxiv.org/abs/1903.12243) paper. It is the successor of [plonky](https://github.com/mir-protocol/plonky), which was based on Plonk and Halo. +Plonky2 is a SNARK implementation based on techniques from PLONK and FRI. It is the successor of [Plonky](https://github.com/mir-protocol/plonky), which was based on PLONK and Halo. -Plonky2 is largely focused on recursion performance. We use custom gates to mitigate the bottlenecks of FRI verification, such as hashing and interpolation. We also encode witness data in a ~64 bit field, so field operations take just a few cycles. To achieve 128-bit security, we repeat certain checks, and run certain parts of the argument in an extension field. +Plonky2 is built for speed, and features a highly efficient recursive circuit. On a Macbook Pro, recursive proofs can be generated in about 170 ms. + + +## Documentation + +For more details about the Plonky2 argument system, see this [writeup](plonky2.pdf). + + +## Building + +Plonky2 requires a recent nightly toolchain, although we plan to transition to stable in the future. + +To use a nightly toolchain for Plonky2 by default, you can run +``` +rustup override set nightly +``` +in the Plonky2 directory. ## Running @@ -10,10 +26,17 @@ Plonky2 is largely focused on recursion performance. We use custom gates to miti To see recursion performance, one can run this test, which generates a chain of three recursion proofs: ```sh -RUST_LOG=debug RUSTFLAGS=-Ctarget-cpu=native cargo test --release test_recursive_recursive_verifier -- --ignored +RUST_LOG=debug RUSTFLAGS=-Ctarget-cpu=native cargo test --release test_recursive_recursive_verifier ``` +## Jemalloc + +By default, Plonky2 uses the [Jemalloc](http://jemalloc.net) memory allocator due to its superior performance. Currently, it changes the default allocator of any binary to which it is linked. You can disable this behavior by removing the corresponding lines in [`plonky2/src/lib.rs`](https://github.com/mir-protocol/plonky2/blob/main/plonky2/src/lib.rs). + +Jemalloc is known to cause crashes when a binary compiled for x86 is run on an Apple silicon-based Mac under [Rosetta 2](https://support.apple.com/en-us/HT211861). If you are experiencing crashes on your Apple silicon Mac, run `rustc --print target-libdir`. The output should contain `aarch64-apple-darwin`. If the output contains `x86_64-apple-darwin`, then you are running the Rust toolchain for x86; we recommend switching to the native ARM version. + + ## Copyright Plonky2 was developed by Polygon Zero (formerly Mir). While we plan to adopt an open source license, we haven't selected one yet, so all rights are reserved for the time being. Please reach out to us if you have thoughts on licensing. @@ -21,5 +44,5 @@ Plonky2 was developed by Polygon Zero (formerly Mir). While we plan to adopt an ## Disclaimer -This code has not been thoroughly reviewed or tested, and should not be used in any production systems. +This code has not yet been audited, and should not be used in any production systems. diff --git a/field/Cargo.toml b/field/Cargo.toml index 1a974852..6abffc5d 100644 --- a/field/Cargo.toml +++ b/field/Cargo.toml @@ -1,5 +1,6 @@ [package] name = "plonky2_field" +description = "Finite field arithmetic" version = "0.1.0" edition = "2021" diff --git a/field/src/arch/x86_64/avx2_goldilocks_field.rs b/field/src/arch/x86_64/avx2_goldilocks_field.rs index b9336cee..e185cb4c 100644 --- a/field/src/arch/x86_64/avx2_goldilocks_field.rs +++ b/field/src/arch/x86_64/avx2_goldilocks_field.rs @@ -5,7 +5,7 @@ use std::iter::{Product, Sum}; use std::mem::transmute; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use crate::field_types::{Field, PrimeField}; +use crate::field_types::{Field, Field64}; use crate::goldilocks_field::GoldilocksField; use crate::ops::Square; use crate::packed_field::PackedField; @@ -510,7 +510,7 @@ unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) { #[cfg(test)] mod tests { use crate::arch::x86_64::avx2_goldilocks_field::Avx2GoldilocksField; - use crate::field_types::PrimeField; + use crate::field_types::Field64; use crate::goldilocks_field::GoldilocksField; use crate::ops::Square; use crate::packed_field::PackedField; diff --git a/field/src/arch/x86_64/avx512_goldilocks_field.rs b/field/src/arch/x86_64/avx512_goldilocks_field.rs new file mode 100644 index 00000000..aaa05e93 --- /dev/null +++ b/field/src/arch/x86_64/avx512_goldilocks_field.rs @@ -0,0 +1,656 @@ +use core::arch::x86_64::*; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::iter::{Product, Sum}; +use std::mem::transmute; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use crate::field_types::{Field, Field64}; +use crate::goldilocks_field::GoldilocksField; +use crate::ops::Square; +use crate::packed_field::PackedField; + +// Ideally `Avx512GoldilocksField` would wrap `__m512i`. Unfortunately, `__m512i` has an alignment +// of 64B, which would preclude us from casting `[GoldilocksField; 8]` (alignment 8B) to +// `Avx512GoldilocksField`. We need to ensure that `Avx512GoldilocksField` has the same alignment as +// `GoldilocksField`. Thus we wrap `[GoldilocksField; 8]` and use the `new` and `get` methods to +// convert to and from `__m512i`. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Avx512GoldilocksField(pub [GoldilocksField; 8]); + +impl Avx512GoldilocksField { + #[inline] + fn new(x: __m512i) -> Self { + unsafe { transmute(x) } + } + #[inline] + fn get(&self) -> __m512i { + unsafe { transmute(*self) } + } +} + +unsafe impl PackedField for Avx512GoldilocksField { + const WIDTH: usize = 8; + + type Scalar = GoldilocksField; + + const ZEROS: Self = Self([GoldilocksField::ZERO; 8]); + const ONES: Self = Self([GoldilocksField::ONE; 8]); + + #[inline] + fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { + Self(arr) + } + + #[inline] + fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] { + self.0 + } + + #[inline] + fn from_slice(slice: &[Self::Scalar]) -> &Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &*slice.as_ptr().cast() } + } + #[inline] + fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self { + assert_eq!(slice.len(), Self::WIDTH); + unsafe { &mut *slice.as_mut_ptr().cast() } + } + #[inline] + fn as_slice(&self) -> &[Self::Scalar] { + &self.0[..] + } + #[inline] + fn as_slice_mut(&mut self) -> &mut [Self::Scalar] { + &mut self.0[..] + } + + #[inline] + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + let (v0, v1) = (self.get(), other.get()); + let (res0, res1) = match block_len { + 1 => unsafe { interleave1(v0, v1) }, + 2 => unsafe { interleave2(v0, v1) }, + 4 => unsafe { interleave4(v0, v1) }, + 8 => (v0, v1), + _ => panic!("unsupported block_len"), + }; + (Self::new(res0), Self::new(res1)) + } +} + +impl Add for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new(unsafe { add(self.get(), rhs.get()) }) + } +} +impl Add for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn add(self, rhs: GoldilocksField) -> Self { + self + Self::from(rhs) + } +} +impl Add for GoldilocksField { + type Output = Avx512GoldilocksField; + #[inline] + fn add(self, rhs: Self::Output) -> Self::Output { + Self::Output::from(self) + rhs + } +} +impl AddAssign for Avx512GoldilocksField { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} +impl AddAssign for Avx512GoldilocksField { + #[inline] + fn add_assign(&mut self, rhs: GoldilocksField) { + *self = *self + rhs; + } +} + +impl Debug for Avx512GoldilocksField { + #[inline] + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "({:?})", self.get()) + } +} + +impl Default for Avx512GoldilocksField { + #[inline] + fn default() -> Self { + Self::ZEROS + } +} + +impl Div for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn div(self, rhs: GoldilocksField) -> Self { + self * rhs.inverse() + } +} +impl DivAssign for Avx512GoldilocksField { + #[inline] + fn div_assign(&mut self, rhs: GoldilocksField) { + *self *= rhs.inverse(); + } +} + +impl From for Avx512GoldilocksField { + fn from(x: GoldilocksField) -> Self { + Self([x; 8]) + } +} + +impl Mul for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self::new(unsafe { mul(self.get(), rhs.get()) }) + } +} +impl Mul for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn mul(self, rhs: GoldilocksField) -> Self { + self * Self::from(rhs) + } +} +impl Mul for GoldilocksField { + type Output = Avx512GoldilocksField; + #[inline] + fn mul(self, rhs: Avx512GoldilocksField) -> Self::Output { + Self::Output::from(self) * rhs + } +} +impl MulAssign for Avx512GoldilocksField { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} +impl MulAssign for Avx512GoldilocksField { + #[inline] + fn mul_assign(&mut self, rhs: GoldilocksField) { + *self = *self * rhs; + } +} + +impl Neg for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self::new(unsafe { neg(self.get()) }) + } +} + +impl Product for Avx512GoldilocksField { + #[inline] + fn product>(iter: I) -> Self { + iter.reduce(|x, y| x * y).unwrap_or(Self::ONES) + } +} + +impl Square for Avx512GoldilocksField { + #[inline] + fn square(&self) -> Self { + Self::new(unsafe { square(self.get()) }) + } +} + +impl Sub for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new(unsafe { sub(self.get(), rhs.get()) }) + } +} +impl Sub for Avx512GoldilocksField { + type Output = Self; + #[inline] + fn sub(self, rhs: GoldilocksField) -> Self { + self - Self::from(rhs) + } +} +impl Sub for GoldilocksField { + type Output = Avx512GoldilocksField; + #[inline] + fn sub(self, rhs: Avx512GoldilocksField) -> Self::Output { + Self::Output::from(self) - rhs + } +} +impl SubAssign for Avx512GoldilocksField { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} +impl SubAssign for Avx512GoldilocksField { + #[inline] + fn sub_assign(&mut self, rhs: GoldilocksField) { + *self = *self - rhs; + } +} + +impl Sum for Avx512GoldilocksField { + #[inline] + fn sum>(iter: I) -> Self { + iter.reduce(|x, y| x + y).unwrap_or(Self::ZEROS) + } +} + +const FIELD_ORDER: __m512i = unsafe { transmute([GoldilocksField::ORDER; 8]) }; +const EPSILON: __m512i = unsafe { transmute([GoldilocksField::ORDER.wrapping_neg(); 8]) }; + +#[inline] +unsafe fn canonicalize(x: __m512i) -> __m512i { + let mask = _mm512_cmpge_epu64_mask(x, FIELD_ORDER); + _mm512_mask_sub_epi64(x, mask, x, FIELD_ORDER) +} + +#[inline] +unsafe fn add_no_double_overflow_64_64(x: __m512i, y: __m512i) -> __m512i { + let res_wrapped = _mm512_add_epi64(x, y); + let mask = _mm512_cmplt_epu64_mask(res_wrapped, y); // mask set if add overflowed + let res = _mm512_mask_sub_epi64(res_wrapped, mask, res_wrapped, FIELD_ORDER); + res +} + +#[inline] +unsafe fn sub_no_double_overflow_64_64(x: __m512i, y: __m512i) -> __m512i { + let mask = _mm512_cmplt_epu64_mask(x, y); // mask set if sub will underflow (x < y) + let res_wrapped = _mm512_sub_epi64(x, y); + let res = _mm512_mask_add_epi64(res_wrapped, mask, res_wrapped, FIELD_ORDER); + res +} + +#[inline] +unsafe fn add(x: __m512i, y: __m512i) -> __m512i { + add_no_double_overflow_64_64(x, canonicalize(y)) +} + +#[inline] +unsafe fn sub(x: __m512i, y: __m512i) -> __m512i { + sub_no_double_overflow_64_64(x, canonicalize(y)) +} + +#[inline] +unsafe fn neg(y: __m512i) -> __m512i { + _mm512_sub_epi64(FIELD_ORDER, canonicalize(y)) +} + +const LO_32_BITS_MASK: __mmask16 = unsafe { transmute(0b0101010101010101u16) }; + +#[inline] +unsafe fn mul64_64(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + // We want to move the high 32 bits to the low position. The multiplication instruction ignores + // the high 32 bits, so it's ok to just duplicate it into the low position. This duplication can + // be done on port 5; bitshifts run on port 0, competing with multiplication. + // This instruction is only provided for 32-bit floats, not integers. Idk why Intel makes the + // distinction; the casts are free and it guarantees that the exact bit pattern is preserved. + // Using a swizzle instruction of the wrong domain (float vs int) does not increase latency + // since Haswell. + let x_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))); + let y_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(y))); + + // All four pairwise multiplications + let mul_ll = _mm512_mul_epu32(x, y); + let mul_lh = _mm512_mul_epu32(x, y_hi); + let mul_hl = _mm512_mul_epu32(x_hi, y); + let mul_hh = _mm512_mul_epu32(x_hi, y_hi); + + // Bignum addition + // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow. + let mul_ll_hi = _mm512_srli_epi64::<32>(mul_ll); + let t0 = _mm512_add_epi64(mul_hl, mul_ll_hi); + // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow. + // Also, extract high 32 bits of t0 and add to mul_hh. + let t0_lo = _mm512_and_si512(t0, EPSILON); + let t0_hi = _mm512_srli_epi64::<32>(t0); + let t1 = _mm512_add_epi64(mul_lh, t0_lo); + let t2 = _mm512_add_epi64(mul_hh, t0_hi); + // Lastly, extract the high 32 bits of t1 and add to t2. + let t1_hi = _mm512_srli_epi64::<32>(t1); + let res_hi = _mm512_add_epi64(t2, t1_hi); + + // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high + // position). + let t1_lo = _mm512_castps_si512(_mm512_moveldup_ps(_mm512_castsi512_ps(t1))); + let res_lo = _mm512_mask_blend_epi32(LO_32_BITS_MASK, t1_lo, mul_ll); + + (res_hi, res_lo) +} + +#[inline] +unsafe fn square64(x: __m512i) -> (__m512i, __m512i) { + // Get high 32 bits of x. See comment in mul64_64_s. + let x_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))); + + // All pairwise multiplications. + let mul_ll = _mm512_mul_epu32(x, x); + let mul_lh = _mm512_mul_epu32(x, x_hi); + let mul_hh = _mm512_mul_epu32(x_hi, x_hi); + + // Bignum addition, but mul_lh is shifted by 33 bits (not 32). + let mul_ll_hi = _mm512_srli_epi64::<33>(mul_ll); + let t0 = _mm512_add_epi64(mul_lh, mul_ll_hi); + let t0_hi = _mm512_srli_epi64::<31>(t0); + let res_hi = _mm512_add_epi64(mul_hh, t0_hi); + + // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high + // position). + let mul_lh_lo = _mm512_slli_epi64::<33>(mul_lh); + let res_lo = _mm512_add_epi64(mul_ll, mul_lh_lo); + + (res_hi, res_lo) +} + +#[inline] +unsafe fn reduce128(x: (__m512i, __m512i)) -> __m512i { + let (hi0, lo0) = x; + let hi_hi0 = _mm512_srli_epi64::<32>(hi0); + let lo1 = sub_no_double_overflow_64_64(lo0, hi_hi0); + let t1 = _mm512_mul_epu32(hi0, EPSILON); + let lo2 = add_no_double_overflow_64_64(lo1, t1); + lo2 +} + +#[inline] +unsafe fn mul(x: __m512i, y: __m512i) -> __m512i { + reduce128(mul64_64(x, y)) +} + +#[inline] +unsafe fn square(x: __m512i) -> __m512i { + reduce128(square64(x)) +} + +#[inline] +unsafe fn interleave1(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + let a = _mm512_unpacklo_epi64(x, y); + let b = _mm512_unpackhi_epi64(x, y); + (a, b) +} + +const INTERLEAVE2_IDX_A: __m512i = unsafe { + transmute([ + 0o00u64, 0o01u64, 0o10u64, 0o11u64, 0o04u64, 0o05u64, 0o14u64, 0o15u64, + ]) +}; +const INTERLEAVE2_IDX_B: __m512i = unsafe { + transmute([ + 0o02u64, 0o03u64, 0o12u64, 0o13u64, 0o06u64, 0o07u64, 0o16u64, 0o17u64, + ]) +}; + +#[inline] +unsafe fn interleave2(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + let a = _mm512_permutex2var_epi64(x, INTERLEAVE2_IDX_A, y); + let b = _mm512_permutex2var_epi64(x, INTERLEAVE2_IDX_B, y); + (a, b) +} + +#[inline] +unsafe fn interleave4(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + let a = _mm512_shuffle_i64x2::<0x44>(x, y); + let b = _mm512_shuffle_i64x2::<0xee>(x, y); + (a, b) +} + +#[cfg(test)] +mod tests { + use crate::arch::x86_64::avx512_goldilocks_field::Avx512GoldilocksField; + use crate::field_types::Field64; + use crate::goldilocks_field::GoldilocksField; + use crate::ops::Square; + use crate::packed_field::PackedField; + + fn test_vals_a() -> [GoldilocksField; 8] { + [ + GoldilocksField::from_noncanonical_u64(14479013849828404771), + GoldilocksField::from_noncanonical_u64(9087029921428221768), + GoldilocksField::from_noncanonical_u64(2441288194761790662), + GoldilocksField::from_noncanonical_u64(5646033492608483824), + GoldilocksField::from_noncanonical_u64(2779181197214900072), + GoldilocksField::from_noncanonical_u64(2989742820063487116), + GoldilocksField::from_noncanonical_u64(727880025589250743), + GoldilocksField::from_noncanonical_u64(3803926346107752679), + ] + } + fn test_vals_b() -> [GoldilocksField; 8] { + [ + GoldilocksField::from_noncanonical_u64(17891926589593242302), + GoldilocksField::from_noncanonical_u64(11009798273260028228), + GoldilocksField::from_noncanonical_u64(2028722748960791447), + GoldilocksField::from_noncanonical_u64(7929433601095175579), + GoldilocksField::from_noncanonical_u64(6632528436085461172), + GoldilocksField::from_noncanonical_u64(2145438710786785567), + GoldilocksField::from_noncanonical_u64(11821483668392863016), + GoldilocksField::from_noncanonical_u64(15638272883309521929), + ] + } + + #[test] + fn test_add() { + let a_arr = test_vals_a(); + let b_arr = test_vals_b(); + + let packed_a = Avx512GoldilocksField::from_arr(a_arr); + let packed_b = Avx512GoldilocksField::from_arr(b_arr); + let packed_res = packed_a + packed_b; + let arr_res = packed_res.as_arr(); + + let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b); + for (exp, res) in expected.zip(arr_res) { + assert_eq!(res, exp); + } + } + + #[test] + fn test_mul() { + let a_arr = test_vals_a(); + let b_arr = test_vals_b(); + + let packed_a = Avx512GoldilocksField::from_arr(a_arr); + let packed_b = Avx512GoldilocksField::from_arr(b_arr); + let packed_res = packed_a * packed_b; + let arr_res = packed_res.as_arr(); + + let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b); + for (exp, res) in expected.zip(arr_res) { + assert_eq!(res, exp); + } + } + + #[test] + fn test_square() { + let a_arr = test_vals_a(); + + let packed_a = Avx512GoldilocksField::from_arr(a_arr); + let packed_res = packed_a.square(); + let arr_res = packed_res.as_arr(); + + let expected = a_arr.iter().map(|&a| a.square()); + for (exp, res) in expected.zip(arr_res) { + assert_eq!(res, exp); + } + } + + #[test] + fn test_neg() { + let a_arr = test_vals_a(); + + let packed_a = Avx512GoldilocksField::from_arr(a_arr); + let packed_res = -packed_a; + let arr_res = packed_res.as_arr(); + + let expected = a_arr.iter().map(|&a| -a); + for (exp, res) in expected.zip(arr_res) { + assert_eq!(res, exp); + } + } + + #[test] + fn test_sub() { + let a_arr = test_vals_a(); + let b_arr = test_vals_b(); + + let packed_a = Avx512GoldilocksField::from_arr(a_arr); + let packed_b = Avx512GoldilocksField::from_arr(b_arr); + let packed_res = packed_a - packed_b; + let arr_res = packed_res.as_arr(); + + let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b); + for (exp, res) in expected.zip(arr_res) { + assert_eq!(res, exp); + } + } + + #[test] + fn test_interleave_is_involution() { + let a_arr = test_vals_a(); + let b_arr = test_vals_b(); + + let packed_a = Avx512GoldilocksField::from_arr(a_arr); + let packed_b = Avx512GoldilocksField::from_arr(b_arr); + { + // Interleave, then deinterleave. + let (x, y) = packed_a.interleave(packed_b, 1); + let (res_a, res_b) = x.interleave(y, 1); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 2); + let (res_a, res_b) = x.interleave(y, 2); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 4); + let (res_a, res_b) = x.interleave(y, 4); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + { + let (x, y) = packed_a.interleave(packed_b, 8); + let (res_a, res_b) = x.interleave(y, 8); + assert_eq!(res_a.as_arr(), a_arr); + assert_eq!(res_b.as_arr(), b_arr); + } + } + + #[test] + fn test_interleave() { + let in_a: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(00), + GoldilocksField::from_noncanonical_u64(01), + GoldilocksField::from_noncanonical_u64(02), + GoldilocksField::from_noncanonical_u64(03), + GoldilocksField::from_noncanonical_u64(04), + GoldilocksField::from_noncanonical_u64(05), + GoldilocksField::from_noncanonical_u64(06), + GoldilocksField::from_noncanonical_u64(07), + ]; + let in_b: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(10), + GoldilocksField::from_noncanonical_u64(11), + GoldilocksField::from_noncanonical_u64(12), + GoldilocksField::from_noncanonical_u64(13), + GoldilocksField::from_noncanonical_u64(14), + GoldilocksField::from_noncanonical_u64(15), + GoldilocksField::from_noncanonical_u64(16), + GoldilocksField::from_noncanonical_u64(17), + ]; + let int1_a: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(00), + GoldilocksField::from_noncanonical_u64(10), + GoldilocksField::from_noncanonical_u64(02), + GoldilocksField::from_noncanonical_u64(12), + GoldilocksField::from_noncanonical_u64(04), + GoldilocksField::from_noncanonical_u64(14), + GoldilocksField::from_noncanonical_u64(06), + GoldilocksField::from_noncanonical_u64(16), + ]; + let int1_b: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(01), + GoldilocksField::from_noncanonical_u64(11), + GoldilocksField::from_noncanonical_u64(03), + GoldilocksField::from_noncanonical_u64(13), + GoldilocksField::from_noncanonical_u64(05), + GoldilocksField::from_noncanonical_u64(15), + GoldilocksField::from_noncanonical_u64(07), + GoldilocksField::from_noncanonical_u64(17), + ]; + let int2_a: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(00), + GoldilocksField::from_noncanonical_u64(01), + GoldilocksField::from_noncanonical_u64(10), + GoldilocksField::from_noncanonical_u64(11), + GoldilocksField::from_noncanonical_u64(04), + GoldilocksField::from_noncanonical_u64(05), + GoldilocksField::from_noncanonical_u64(14), + GoldilocksField::from_noncanonical_u64(15), + ]; + let int2_b: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(02), + GoldilocksField::from_noncanonical_u64(03), + GoldilocksField::from_noncanonical_u64(12), + GoldilocksField::from_noncanonical_u64(13), + GoldilocksField::from_noncanonical_u64(06), + GoldilocksField::from_noncanonical_u64(07), + GoldilocksField::from_noncanonical_u64(16), + GoldilocksField::from_noncanonical_u64(17), + ]; + let int4_a: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(00), + GoldilocksField::from_noncanonical_u64(01), + GoldilocksField::from_noncanonical_u64(02), + GoldilocksField::from_noncanonical_u64(03), + GoldilocksField::from_noncanonical_u64(10), + GoldilocksField::from_noncanonical_u64(11), + GoldilocksField::from_noncanonical_u64(12), + GoldilocksField::from_noncanonical_u64(13), + ]; + let int4_b: [GoldilocksField; 8] = [ + GoldilocksField::from_noncanonical_u64(04), + GoldilocksField::from_noncanonical_u64(05), + GoldilocksField::from_noncanonical_u64(06), + GoldilocksField::from_noncanonical_u64(07), + GoldilocksField::from_noncanonical_u64(14), + GoldilocksField::from_noncanonical_u64(15), + GoldilocksField::from_noncanonical_u64(16), + GoldilocksField::from_noncanonical_u64(17), + ]; + + let packed_a = Avx512GoldilocksField::from_arr(in_a); + let packed_b = Avx512GoldilocksField::from_arr(in_b); + { + let (x1, y1) = packed_a.interleave(packed_b, 1); + assert_eq!(x1.as_arr(), int1_a); + assert_eq!(y1.as_arr(), int1_b); + } + { + let (x2, y2) = packed_a.interleave(packed_b, 2); + assert_eq!(x2.as_arr(), int2_a); + assert_eq!(y2.as_arr(), int2_b); + } + { + let (x4, y4) = packed_a.interleave(packed_b, 4); + assert_eq!(x4.as_arr(), int4_a); + assert_eq!(y4.as_arr(), int4_b); + } + { + let (x8, y8) = packed_a.interleave(packed_b, 8); + assert_eq!(x8.as_arr(), in_a); + assert_eq!(y8.as_arr(), in_b); + } + } +} diff --git a/field/src/arch/x86_64/mod.rs b/field/src/arch/x86_64/mod.rs index bd9dccae..326deb78 100644 --- a/field/src/arch/x86_64/mod.rs +++ b/field/src/arch/x86_64/mod.rs @@ -1,2 +1,20 @@ -#[cfg(target_feature = "avx2")] +#[cfg(all( + target_feature = "avx2", + not(all( + target_feature = "avx512bw", + target_feature = "avx512cd", + target_feature = "avx512dq", + target_feature = "avx512f", + target_feature = "avx512vl" + )) +))] pub mod avx2_goldilocks_field; + +#[cfg(all( + target_feature = "avx512bw", + target_feature = "avx512cd", + target_feature = "avx512dq", + target_feature = "avx512f", + target_feature = "avx512vl" +))] +pub mod avx512_goldilocks_field; diff --git a/field/src/extension_field/quadratic.rs b/field/src/extension_field/quadratic.rs index e072d323..488304d2 100644 --- a/field/src/extension_field/quadratic.rs +++ b/field/src/extension_field/quadratic.rs @@ -95,10 +95,6 @@ impl> Field for QuadraticExtension { Self([F::from_biguint(low), F::from_biguint(high)]) } - fn to_biguint(&self) -> BigUint { - self.0[0].to_biguint() + F::order() * self.0[1].to_biguint() - } - fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/field/src/extension_field/quartic.rs b/field/src/extension_field/quartic.rs index 4e9cebf9..7b4a6950 100644 --- a/field/src/extension_field/quartic.rs +++ b/field/src/extension_field/quartic.rs @@ -107,14 +107,6 @@ impl> Field for QuarticExtension { ]) } - fn to_biguint(&self) -> BigUint { - let mut result = self.0[3].to_biguint(); - result = result * F::order() + self.0[2].to_biguint(); - result = result * F::order() + self.0[1].to_biguint(); - result = result * F::order() + self.0[0].to_biguint(); - result - } - fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/field/src/fft.rs b/field/src/fft.rs index 8428d3fb..c548d51e 100644 --- a/field/src/fft.rs +++ b/field/src/fft.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min}; use std::option::Option; -use plonky2_util::{log2_strict, reverse_index_bits}; +use plonky2_util::{log2_strict, reverse_index_bits_in_place}; use unroll::unroll_for_loops; use crate::field_types::Field; @@ -34,10 +34,10 @@ pub fn fft_root_table(n: usize) -> FftRootTable { #[inline] fn fft_dispatch( - input: &[F], + input: &mut [F], zero_factor: Option, root_table: Option<&FftRootTable>, -) -> Vec { +) { let computed_root_table = if root_table.is_some() { None } else { @@ -45,33 +45,32 @@ fn fft_dispatch( }; let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); - fft_classic(input, zero_factor.unwrap_or(0), used_root_table) + fft_classic(input, zero_factor.unwrap_or(0), used_root_table); } #[inline] -pub fn fft(poly: &PolynomialCoeffs) -> PolynomialValues { +pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { fft_with_options(poly, None, None) } #[inline] pub fn fft_with_options( - poly: &PolynomialCoeffs, + poly: PolynomialCoeffs, zero_factor: Option, root_table: Option<&FftRootTable>, ) -> PolynomialValues { - let PolynomialCoeffs { coeffs } = poly; - PolynomialValues { - values: fft_dispatch(coeffs, zero_factor, root_table), - } + let PolynomialCoeffs { coeffs: mut buffer } = poly; + fft_dispatch(&mut buffer, zero_factor, root_table); + PolynomialValues { values: buffer } } #[inline] -pub fn ifft(poly: &PolynomialValues) -> PolynomialCoeffs { +pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { ifft_with_options(poly, None, None) } pub fn ifft_with_options( - poly: &PolynomialValues, + poly: PolynomialValues, zero_factor: Option, root_table: Option<&FftRootTable>, ) -> PolynomialCoeffs { @@ -79,20 +78,20 @@ pub fn ifft_with_options( let lg_n = log2_strict(n); let n_inv = F::inverse_2exp(lg_n); - let PolynomialValues { values } = poly; - let mut coeffs = fft_dispatch(values, zero_factor, root_table); + let PolynomialValues { values: mut buffer } = poly; + fft_dispatch(&mut buffer, zero_factor, root_table); // We reverse all values except the first, and divide each by n. - coeffs[0] *= n_inv; - coeffs[n / 2] *= n_inv; + buffer[0] *= n_inv; + buffer[n / 2] *= n_inv; for i in 1..(n / 2) { let j = n - i; - let coeffs_i = coeffs[j] * n_inv; - let coeffs_j = coeffs[i] * n_inv; - coeffs[i] = coeffs_i; - coeffs[j] = coeffs_j; + let coeffs_i = buffer[j] * n_inv; + let coeffs_j = buffer[i] * n_inv; + buffer[i] = coeffs_i; + buffer[j] = coeffs_j; } - PolynomialCoeffs { coeffs } + PolynomialCoeffs { coeffs: buffer } } /// Generic FFT implementation that works with both scalar and packed inputs. @@ -167,8 +166,8 @@ fn fft_classic_simd( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootTable) -> Vec { - let mut values = reverse_index_bits(input); +pub(crate) fn fft_classic(values: &mut [F], r: usize, root_table: &FftRootTable) { + reverse_index_bits_in_place(values); let n = values.len(); let lg_n = log2_strict(n); @@ -200,11 +199,10 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. - fft_classic_simd::(&mut values[..], r, lg_n, root_table); + fft_classic_simd::(values, r, lg_n, root_table); } else { - fft_classic_simd::<::Packing>(&mut values[..], r, lg_n, root_table); + fft_classic_simd::<::Packing>(values, r, lg_n, root_table); } - values } #[cfg(test)] @@ -231,10 +229,10 @@ mod tests { assert_eq!(coeffs.len(), degree_padded); let coefficients = PolynomialCoeffs { coeffs }; - let points = fft(&coefficients); + let points = fft(coefficients.clone()); assert_eq!(points, evaluate_naive(&coefficients)); - let interpolated_coefficients = ifft(&points); + let interpolated_coefficients = ifft(points); for i in 0..degree { assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]); } @@ -245,7 +243,10 @@ mod tests { for r in 0..4 { // expand coefficients by factor 2^r by filling with zeros let zero_tail = coefficients.lde(r); - assert_eq!(fft(&zero_tail), fft_with_options(&zero_tail, Some(r), None)); + assert_eq!( + fft(zero_tail.clone()), + fft_with_options(zero_tail, Some(r), None) + ); } } diff --git a/field/src/field_types.rs b/field/src/field_types.rs index 0d7b314f..83826b9f 100644 --- a/field/src/field_types.rs +++ b/field/src/field_types.rs @@ -264,17 +264,28 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } - // TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones) + // TODO: The current behavior for composite fields doesn't seem natural or useful. + // Rename to `from_noncanonical_biguint` and have it return `n % Self::characteristic()`. fn from_biguint(n: BigUint) -> Self; - fn to_biguint(&self) -> BigUint; - + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. + // TODO: Should probably be unsafe. fn from_canonical_u64(n: u64) -> Self; + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. + // TODO: Should probably be unsafe. fn from_canonical_u32(n: u32) -> Self { Self::from_canonical_u64(n as u64) } + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. + // TODO: Should probably be unsafe. + fn from_canonical_u16(n: u16) -> Self { + Self::from_canonical_u64(n as u64) + } + + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. + // TODO: Should probably be unsafe. fn from_canonical_usize(n: usize) -> Self { Self::from_canonical_u64(n as u64) } @@ -283,11 +294,11 @@ pub trait Field: Self::from_canonical_u64(b as u64) } - /// Returns `n % Self::CHARACTERISTIC`. + /// Returns `n % Self::characteristic()`. fn from_noncanonical_u128(n: u128) -> Self; - /// Returns `n % Self::CHARACTERISTIC`. May be cheaper than from_noncanonical_u128 when we know - /// that n < 2 ** 96. + /// Returns `n % Self::characteristic()`. May be cheaper than from_noncanonical_u128 when we know + /// that `n < 2 ** 96`. #[inline] fn from_noncanonical_u96((n_lo, n_hi): (u64, u32)) -> Self { // Default implementation. @@ -399,22 +410,26 @@ pub trait Field: } } -/// A finite field of prime order less than 2^64. pub trait PrimeField: Field { + fn to_canonical_biguint(&self) -> BigUint; +} + +/// A finite field of order less than 2^64. +pub trait Field64: Field { const ORDER: u64; - fn to_canonical_u64(&self) -> u64; - - fn to_noncanonical_u64(&self) -> u64; - + /// Returns `x % Self::CHARACTERISTIC`. + // TODO: Move to `Field`. fn from_noncanonical_u64(n: u64) -> Self; #[inline] + // TODO: Move to `Field`. fn add_one(&self) -> Self { unsafe { self.add_canonical_u64(1) } } #[inline] + // TODO: Move to `Field`. fn sub_one(&self) -> Self { unsafe { self.sub_canonical_u64(1) } } @@ -423,6 +438,7 @@ pub trait PrimeField: Field { /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// precondition is not met. It is marked unsafe for this reason. + // TODO: Move to `Field`. #[inline] unsafe fn add_canonical_u64(&self, rhs: u64) -> Self { // Default implementation. @@ -433,6 +449,7 @@ pub trait PrimeField: Field { /// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// precondition is not met. It is marked unsafe for this reason. + // TODO: Move to `Field`. #[inline] unsafe fn sub_canonical_u64(&self, rhs: u64) -> Self { // Default implementation. @@ -440,6 +457,13 @@ pub trait PrimeField: Field { } } +/// A finite field of prime order less than 2^64. +pub trait PrimeField64: PrimeField + Field64 { + fn to_canonical_u64(&self) -> u64; + + fn to_noncanonical_u64(&self) -> u64; +} + /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. #[derive(Clone)] pub struct Powers { diff --git a/field/src/goldilocks_field.rs b/field/src/goldilocks_field.rs index 54866b1f..6c033bb2 100644 --- a/field/src/goldilocks_field.rs +++ b/field/src/goldilocks_field.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use crate::extension_field::quadratic::QuadraticExtension; use crate::extension_field::quartic::QuarticExtension; use crate::extension_field::{Extendable, Frobenius}; -use crate::field_types::{Field, PrimeField}; +use crate::field_types::{Field, Field64, PrimeField, PrimeField64}; use crate::inversion::try_inverse_u64; const EPSILON: u64 = (1 << 32) - 1; @@ -98,10 +98,6 @@ impl Field for GoldilocksField { Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) } - fn to_biguint(&self) -> BigUint { - self.to_canonical_u64().into() - } - #[inline] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); @@ -124,22 +120,14 @@ impl Field for GoldilocksField { } impl PrimeField for GoldilocksField { + fn to_canonical_biguint(&self) -> BigUint { + self.to_canonical_u64().into() + } +} + +impl Field64 for GoldilocksField { const ORDER: u64 = 0xFFFFFFFF00000001; - #[inline] - fn to_canonical_u64(&self) -> u64 { - let mut c = self.0; - // We only need one condition subtraction, since 2 * ORDER would not fit in a u64. - if c >= Self::ORDER { - c -= Self::ORDER; - } - c - } - - fn to_noncanonical_u64(&self) -> u64 { - self.0 - } - #[inline] fn from_noncanonical_u64(n: u64) -> Self { Self(n) @@ -160,6 +148,22 @@ impl PrimeField for GoldilocksField { } } +impl PrimeField64 for GoldilocksField { + #[inline] + fn to_canonical_u64(&self) -> u64 { + let mut c = self.0; + // We only need one condition subtraction, since 2 * ORDER would not fit in a u64. + if c >= Self::ORDER { + c -= Self::ORDER; + } + c + } + + fn to_noncanonical_u64(&self) -> u64 { + self.0 + } +} + impl Neg for GoldilocksField { type Output = Self; diff --git a/field/src/interpolation.rs b/field/src/interpolation.rs index ac6f6437..1a2e37df 100644 --- a/field/src/interpolation.rs +++ b/field/src/interpolation.rs @@ -19,7 +19,7 @@ pub fn interpolant(points: &[(F, F)]) -> PolynomialCoeffs { .map(|x| interpolate(points, x, &barycentric_weights)) .collect(); - let mut coeffs = ifft(&PolynomialValues { + let mut coeffs = ifft(PolynomialValues { values: subgroup_evals, }); coeffs.trim(); diff --git a/field/src/inversion.rs b/field/src/inversion.rs index bbfb8e0d..5eabc45c 100644 --- a/field/src/inversion.rs +++ b/field/src/inversion.rs @@ -1,4 +1,4 @@ -use crate::field_types::PrimeField; +use crate::field_types::PrimeField64; /// This is a 'safe' iteration for the modular inversion algorithm. It /// is safe in the sense that it will produce the right answer even @@ -63,7 +63,7 @@ unsafe fn unsafe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, /// Elliptic and Hyperelliptic Cryptography, Algorithms 11.6 /// and 11.12. #[allow(clippy::many_single_char_names)] -pub(crate) fn try_inverse_u64(x: &F) -> Option { +pub(crate) fn try_inverse_u64(x: &F) -> Option { let mut f = x.to_noncanonical_u64(); let mut g = F::ORDER; // NB: These two are very rarely such that their absolute diff --git a/field/src/lib.rs b/field/src/lib.rs index 47dd9ccb..2c89aab3 100644 --- a/field/src/lib.rs +++ b/field/src/lib.rs @@ -7,6 +7,7 @@ #![allow(clippy::return_self_not_must_use)] #![feature(generic_const_exprs)] #![feature(specialization)] +#![feature(stdsimd)] pub(crate) mod arch; pub mod batch_util; @@ -23,6 +24,7 @@ pub mod packed_field; pub mod polynomial; pub mod secp256k1_base; pub mod secp256k1_scalar; +pub mod zero_poly_coset; #[cfg(test)] mod field_testing; diff --git a/field/src/packable.rs b/field/src/packable.rs index 754a7fb6..18fe07f7 100644 --- a/field/src/packable.rs +++ b/field/src/packable.rs @@ -12,7 +12,29 @@ impl Packable for F { default type Packing = Self; } -#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx2", + not(all( + target_feature = "avx512bw", + target_feature = "avx512cd", + target_feature = "avx512dq", + target_feature = "avx512f", + target_feature = "avx512vl" + )) +))] impl Packable for crate::goldilocks_field::GoldilocksField { type Packing = crate::arch::x86_64::avx2_goldilocks_field::Avx2GoldilocksField; } + +#[cfg(all( + target_arch = "x86_64", + target_feature = "avx512bw", + target_feature = "avx512cd", + target_feature = "avx512dq", + target_feature = "avx512f", + target_feature = "avx512vl" +))] +impl Packable for crate::goldilocks_field::GoldilocksField { + type Packing = crate::arch::x86_64::avx512_goldilocks_field::Avx512GoldilocksField; +} diff --git a/field/src/polynomial/division.rs b/field/src/polynomial/division.rs index 4f3cafae..d761ab50 100644 --- a/field/src/polynomial/division.rs +++ b/field/src/polynomial/division.rs @@ -67,9 +67,9 @@ impl PolynomialCoeffs { } } - /// Let `self=p(X)`, this returns `(p(X)-p(z))/(X-z)` and `p(z)`. + /// Let `self=p(X)`, this returns `(p(X)-p(z))/(X-z)`. /// See https://en.wikipedia.org/wiki/Horner%27s_method - pub fn divide_by_linear(&self, z: F) -> (PolynomialCoeffs, F) { + pub fn divide_by_linear(&self, z: F) -> PolynomialCoeffs { let mut bs = self .coeffs .iter() @@ -79,9 +79,9 @@ impl PolynomialCoeffs { Some(*acc) }) .collect::>(); - let ev = bs.pop().unwrap_or(F::ZERO); + bs.pop(); bs.reverse(); - (Self { coeffs: bs }, ev) + Self { coeffs: bs } } /// Computes the inverse of `self` modulo `x^n`. @@ -125,7 +125,7 @@ impl PolynomialCoeffs { #[cfg(test)] mod tests { - use std::time::Instant; + use rand::{thread_rng, Rng}; use crate::extension_field::quartic::QuarticExtension; use crate::field_types::Field; @@ -133,47 +133,17 @@ mod tests { use crate::polynomial::PolynomialCoeffs; #[test] - #[ignore] fn test_division_by_linear() { type F = QuarticExtension; - let n = 1_000_000; + let n = thread_rng().gen_range(1..1000); let poly = PolynomialCoeffs::new(F::rand_vec(n)); let z = F::rand(); let ev = poly.eval(z); - let timer = Instant::now(); - let (_quotient, ev2) = poly.div_rem(&PolynomialCoeffs::new(vec![-z, F::ONE])); - println!("{:.3}s for usual", timer.elapsed().as_secs_f32()); - assert_eq!(ev2.trimmed().coeffs, vec![ev]); - - let timer = Instant::now(); - let (quotient, ev3) = poly.div_rem_long_division(&PolynomialCoeffs::new(vec![-z, F::ONE])); - println!("{:.3}s for long division", timer.elapsed().as_secs_f32()); - assert_eq!(ev3.trimmed().coeffs, vec![ev]); - - let timer = Instant::now(); - let horn = poly.divide_by_linear(z); - println!("{:.3}s for Horner", timer.elapsed().as_secs_f32()); - assert_eq!((quotient, ev), horn); - } - - #[test] - #[ignore] - fn test_division_by_quadratic() { - type F = QuarticExtension; - let n = 1_000_000; - let poly = PolynomialCoeffs::new(F::rand_vec(n)); - let quad = PolynomialCoeffs::new(F::rand_vec(2)); - - let timer = Instant::now(); - let (quotient0, rem0) = poly.div_rem(&quad); - println!("{:.3}s for usual", timer.elapsed().as_secs_f32()); - - let timer = Instant::now(); - let (quotient1, rem1) = poly.div_rem_long_division(&quad); - println!("{:.3}s for long division", timer.elapsed().as_secs_f32()); - - assert_eq!(quotient0.trimmed(), quotient1.trimmed()); - assert_eq!(rem0.trimmed(), rem1.trimmed()); + let quotient = poly.divide_by_linear(z); + assert_eq!( + poly, + &("ient * &vec![-z, F::ONE].into()) + &vec![ev].into() // `quotient * (X-z) + ev` + ); } } diff --git a/field/src/polynomial/mod.rs b/field/src/polynomial/mod.rs index 1f777ca3..ac3beb8e 100644 --- a/field/src/polynomial/mod.rs +++ b/field/src/polynomial/mod.rs @@ -26,17 +26,28 @@ impl PolynomialValues { PolynomialValues { values } } + pub fn zero(len: usize) -> Self { + Self::new(vec![F::ZERO; len]) + } + + /// Returns the polynomial whole value is one at the given index, and zero elsewhere. + pub fn selector(len: usize, index: usize) -> Self { + let mut result = Self::zero(len); + result.values[index] = F::ONE; + result + } + /// The number of values stored. - pub(crate) fn len(&self) -> usize { + pub fn len(&self) -> usize { self.values.len() } - pub fn ifft(&self) -> PolynomialCoeffs { + pub fn ifft(self) -> PolynomialCoeffs { ifft(self) } /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. - pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { + pub fn coset_ifft(self, shift: F) -> PolynomialCoeffs { let mut shifted_coeffs = self.ifft(); shifted_coeffs .coeffs @@ -52,9 +63,15 @@ impl PolynomialValues { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } - pub fn lde(&self, rate_bits: usize) -> Self { + pub fn lde(self, rate_bits: usize) -> Self { let coeffs = ifft(self).lde(rate_bits); - fft_with_options(&coeffs, Some(rate_bits), None) + fft_with_options(coeffs, Some(rate_bits), None) + } + + /// Low-degree extend `Self` (seen as evaluations over the subgroup) onto a coset. + pub fn lde_onto_coset(self, rate_bits: usize) -> Self { + let coeffs = ifft(self).lde(rate_bits); + coeffs.coset_fft_with_options(F::coset_shift(), Some(rate_bits), None) } pub fn degree(&self) -> usize { @@ -64,7 +81,7 @@ impl PolynomialValues { } pub fn degree_plus_one(&self) -> usize { - self.ifft().degree_plus_one() + self.clone().ifft().degree_plus_one() } } @@ -180,12 +197,21 @@ impl PolynomialCoeffs { poly } - /// Removes leading zero coefficients. + /// Removes any leading zero coefficients. pub fn trim(&mut self) { self.coeffs.truncate(self.degree_plus_one()); } - /// Removes leading zero coefficients. + /// Removes some leading zero coefficients, such that a desired length is reached. Fails if a + /// nonzero coefficient is encountered before then. + pub fn trim_to_len(&mut self, len: usize) -> Result<()> { + ensure!(self.len() >= len); + ensure!(self.coeffs[len..].iter().all(F::is_zero)); + self.coeffs.truncate(len); + Ok(()) + } + + /// Removes any leading zero coefficients. pub fn trimmed(&self) -> Self { let coeffs = self.coeffs[..self.degree_plus_one()].to_vec(); Self { coeffs } @@ -213,12 +239,12 @@ impl PolynomialCoeffs { Self::new(self.trimmed().coeffs.into_iter().rev().collect()) } - pub fn fft(&self) -> PolynomialValues { + pub fn fft(self) -> PolynomialValues { fft(self) } pub fn fft_with_options( - &self, + self, zero_factor: Option, root_table: Option<&FftRootTable>, ) -> PolynomialValues { @@ -386,7 +412,7 @@ impl Mul for &PolynomialCoeffs { .zip(b_evals.values) .map(|(pa, pb)| pa * pb) .collect(); - ifft(&mul_evals.into()) + ifft(mul_evals.into()) } } @@ -454,7 +480,7 @@ mod tests { let n = 1 << k; let evals = PolynomialValues::new(F::rand_vec(n)); let shift = F::rand(); - let coeffs = evals.coset_ifft(shift); + let coeffs = evals.clone().coset_ifft(shift); let generator = F::primitive_root_of_unity(k); let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) diff --git a/field/src/prime_field_testing.rs b/field/src/prime_field_testing.rs index 4aec6712..24d5e3c7 100644 --- a/field/src/prime_field_testing.rs +++ b/field/src/prime_field_testing.rs @@ -1,4 +1,4 @@ -use crate::field_types::PrimeField; +use crate::field_types::PrimeField64; /// Generates a series of non-negative integers less than `modulus` which cover a range of /// interesting test values. @@ -19,7 +19,7 @@ pub fn test_inputs(modulus: u64) -> Vec { /// word_bits)` and panic if the two resulting vectors differ. pub fn run_unaryop_test_cases(op: UnaryOp, expected_op: ExpectedOp) where - F: PrimeField, + F: PrimeField64, UnaryOp: Fn(F) -> F, ExpectedOp: Fn(u64) -> u64, { @@ -43,7 +43,7 @@ where /// Apply the binary functions `op` and `expected_op` to each pair of inputs. pub fn run_binaryop_test_cases(op: BinaryOp, expected_op: ExpectedOp) where - F: PrimeField, + F: PrimeField64, BinaryOp: Fn(F, F) -> F, ExpectedOp: Fn(u64, u64) -> u64, { @@ -70,7 +70,7 @@ macro_rules! test_prime_field_arithmetic { mod prime_field_arithmetic { use std::ops::{Add, Mul, Neg, Sub}; - use crate::field_types::{Field, PrimeField}; + use crate::field_types::{Field, Field64}; use crate::ops::Square; #[test] diff --git a/field/src/secp256k1_base.rs b/field/src/secp256k1_base.rs index 23702420..1972aed7 100644 --- a/field/src/secp256k1_base.rs +++ b/field/src/secp256k1_base.rs @@ -10,7 +10,7 @@ use num::{Integer, One}; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::field_types::Field; +use crate::field_types::{Field, PrimeField}; /// The base field of the secp256k1 elliptic curve. /// @@ -42,7 +42,7 @@ impl Default for Secp256K1Base { impl PartialEq for Secp256K1Base { fn eq(&self, other: &Self) -> bool { - self.to_biguint() == other.to_biguint() + self.to_canonical_biguint() == other.to_canonical_biguint() } } @@ -50,19 +50,19 @@ impl Eq for Secp256K1Base {} impl Hash for Secp256K1Base { fn hash(&self, state: &mut H) { - self.to_biguint().hash(state) + self.to_canonical_biguint().hash(state) } } impl Display for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_biguint(), f) + Display::fmt(&self.to_canonical_biguint(), f) } } impl Debug for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_biguint(), f) + Debug::fmt(&self.to_canonical_biguint(), f) } } @@ -107,14 +107,6 @@ impl Field for Secp256K1Base { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } - fn to_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - fn from_biguint(val: BigUint) -> Self { Self( val.to_u64_digits() @@ -146,6 +138,16 @@ impl Field for Secp256K1Base { } } +impl PrimeField for Secp256K1Base { + fn to_canonical_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } +} + impl Neg for Secp256K1Base { type Output = Self; @@ -154,7 +156,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_biguint()) + Self::from_biguint(Self::order() - self.to_canonical_biguint()) } } } @@ -164,7 +166,7 @@ impl Add for Secp256K1Base { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_biguint() + rhs.to_biguint(); + let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -207,7 +209,9 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) + Self::from_biguint( + (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), + ) } } diff --git a/field/src/secp256k1_scalar.rs b/field/src/secp256k1_scalar.rs index f10892af..1e506426 100644 --- a/field/src/secp256k1_scalar.rs +++ b/field/src/secp256k1_scalar.rs @@ -11,7 +11,7 @@ use num::{Integer, One}; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::field_types::Field; +use crate::field_types::{Field, PrimeField}; /// The base field of the secp256k1 elliptic curve. /// @@ -45,7 +45,7 @@ impl Default for Secp256K1Scalar { impl PartialEq for Secp256K1Scalar { fn eq(&self, other: &Self) -> bool { - self.to_biguint() == other.to_biguint() + self.to_canonical_biguint() == other.to_canonical_biguint() } } @@ -53,19 +53,19 @@ impl Eq for Secp256K1Scalar {} impl Hash for Secp256K1Scalar { fn hash(&self, state: &mut H) { - self.to_biguint().hash(state) + self.to_canonical_biguint().hash(state) } } impl Display for Secp256K1Scalar { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_biguint(), f) + Display::fmt(&self.to_canonical_biguint(), f) } } impl Debug for Secp256K1Scalar { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_biguint(), f) + Debug::fmt(&self.to_canonical_biguint(), f) } } @@ -116,14 +116,6 @@ impl Field for Secp256K1Scalar { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } - fn to_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - fn from_biguint(val: BigUint) -> Self { Self( val.to_u64_digits() @@ -155,6 +147,16 @@ impl Field for Secp256K1Scalar { } } +impl PrimeField for Secp256K1Scalar { + fn to_canonical_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } +} + impl Neg for Secp256K1Scalar { type Output = Self; @@ -163,7 +165,7 @@ impl Neg for Secp256K1Scalar { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_biguint()) + Self::from_biguint(Self::order() - self.to_canonical_biguint()) } } } @@ -173,7 +175,7 @@ impl Add for Secp256K1Scalar { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_biguint() + rhs.to_biguint(); + let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -216,7 +218,9 @@ impl Mul for Secp256K1Scalar { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) + Self::from_biguint( + (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), + ) } } diff --git a/field/src/zero_poly_coset.rs b/field/src/zero_poly_coset.rs new file mode 100644 index 00000000..0b7452f5 --- /dev/null +++ b/field/src/zero_poly_coset.rs @@ -0,0 +1,47 @@ +use crate::field_types::Field; + +/// Precomputations of the evaluation of `Z_H(X) = X^n - 1` on a coset `gK` with `H <= K`. +pub struct ZeroPolyOnCoset { + /// `n = |H|`. + n: F, + /// `rate = |K|/|H|`. + rate: usize, + /// Holds `g^n * (w^n)^i - 1 = g^n * v^i - 1` for `i in 0..rate`, with `w` a generator of `K` and `v` a + /// `rate`-primitive root of unity. + evals: Vec, + /// Holds the multiplicative inverses of `evals`. + inverses: Vec, +} + +impl ZeroPolyOnCoset { + pub fn new(n_log: usize, rate_bits: usize) -> Self { + let g_pow_n = F::coset_shift().exp_power_of_2(n_log); + let evals = F::two_adic_subgroup(rate_bits) + .into_iter() + .map(|x| g_pow_n * x - F::ONE) + .collect::>(); + let inverses = F::batch_multiplicative_inverse(&evals); + Self { + n: F::from_canonical_usize(1 << n_log), + rate: 1 << rate_bits, + evals, + inverses, + } + } + + /// Returns `Z_H(g * w^i)`. + pub fn eval(&self, i: usize) -> F { + self.evals[i % self.rate] + } + + /// Returns `1 / Z_H(g * w^i)`. + pub fn eval_inverse(&self, i: usize) -> F { + self.inverses[i % self.rate] + } + + /// Returns `L_1(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. + pub fn eval_l1(&self, i: usize, x: F) -> F { + // Could also precompute the inverses using Montgomery. + self.eval(i) * (self.n * (x - F::ONE)).inverse() + } +} diff --git a/insertion/src/insertion_gate.rs b/insertion/src/insertion_gate.rs index 33859e70..6fd98307 100644 --- a/insertion/src/insertion_gate.rs +++ b/insertion/src/insertion_gate.rs @@ -415,7 +415,7 @@ mod tests { v.extend(equality_dummy_vals); v.extend(insert_here_vals); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } let orig_vec = vec![FF::rand(); 3]; diff --git a/plonky2.pdf b/plonky2.pdf new file mode 100644 index 00000000..299d1724 Binary files /dev/null and b/plonky2.pdf differ diff --git a/plonky2/Cargo.toml b/plonky2/Cargo.toml index 0dbfa2d7..af82a622 100644 --- a/plonky2/Cargo.toml +++ b/plonky2/Cargo.toml @@ -1,15 +1,14 @@ [package] name = "plonky2" -description = "Recursive SNARKs based on Plonk and FRI" +description = "Recursive SNARKs based on PLONK and FRI" version = "0.1.0" -authors = ["Daniel Lubarov "] +authors = ["Polygon Zero "] readme = "README.md" -license = "MIT OR Apache-2.0" repository = "https://github.com/mir-protocol/plonky2" -keywords = ["cryptography", "SNARK", "FRI"] +keywords = ["cryptography", "SNARK", "PLONK", "FRI"] categories = ["cryptography"] edition = "2021" -default-run = "bench_recursion" +default-run = "generate_constants" [dependencies] plonky2_field = { path = "../field" } @@ -48,14 +47,14 @@ harness = false name = "hashing" harness = false +[[bench]] +name = "merkle" +harness = false + [[bench]] name = "transpose" harness = false -[profile.release] -opt-level = 3 -#lto = "fat" -#codegen-units = 1 - -[profile.bench] -opt-level = 3 +[[bench]] +name = "reverse_index_bits" +harness = false diff --git a/plonky2/benches/ffts.rs b/plonky2/benches/ffts.rs index cfa02a25..63ac9c85 100644 --- a/plonky2/benches/ffts.rs +++ b/plonky2/benches/ffts.rs @@ -11,7 +11,7 @@ pub(crate) fn bench_ffts(c: &mut Criterion) { let size = 1 << size_log; group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { let coeffs = PolynomialCoeffs::new(F::rand_vec(size)); - b.iter(|| coeffs.fft_with_options(None, None)); + b.iter(|| coeffs.clone().fft_with_options(None, None)); }); } } diff --git a/plonky2/benches/hashing.rs b/plonky2/benches/hashing.rs index b1193516..a968d957 100644 --- a/plonky2/benches/hashing.rs +++ b/plonky2/benches/hashing.rs @@ -1,17 +1,20 @@ +#![allow(incomplete_features)] #![feature(generic_const_exprs)] use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::hash::gmimc::GMiMC; +use plonky2::hash::hash_types::{BytesHash, RichField}; use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::hash::keccak::KeccakHash; use plonky2::hash::poseidon::Poseidon; +use plonky2::plonk::config::Hasher; use tynm::type_name; -pub(crate) fn bench_gmimc, const WIDTH: usize>(c: &mut Criterion) { - c.bench_function(&format!("gmimc<{}, {}>", type_name::(), WIDTH), |b| { +pub(crate) fn bench_keccak(c: &mut Criterion) { + c.bench_function("keccak256", |b| { b.iter_batched( - || F::rand_arr::(), - |state| F::gmimc_permute(state), + || (BytesHash::<32>::rand(), BytesHash::<32>::rand()), + |(left, right)| as Hasher>::two_to_one(left, right), BatchSize::SmallInput, ) }); @@ -31,8 +34,8 @@ pub(crate) fn bench_poseidon(c: &mut Criterion) { } fn criterion_benchmark(c: &mut Criterion) { - bench_gmimc::(c); bench_poseidon::(c); + bench_keccak::(c); } criterion_group!(benches, criterion_benchmark); diff --git a/plonky2/benches/merkle.rs b/plonky2/benches/merkle.rs new file mode 100644 index 00000000..8bc43730 --- /dev/null +++ b/plonky2/benches/merkle.rs @@ -0,0 +1,40 @@ +#![feature(generic_const_exprs)] + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::hash::merkle_tree::MerkleTree; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::config::Hasher; +use tynm::type_name; + +const ELEMS_PER_LEAF: usize = 135; + +pub(crate) fn bench_merkle_tree>(c: &mut Criterion) +where + [(); H::HASH_SIZE]:, +{ + let mut group = c.benchmark_group(&format!( + "merkle-tree<{}, {}>", + type_name::(), + type_name::() + )); + group.sample_size(10); + + for size_log in [13, 14, 15] { + let size = 1 << size_log; + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + let leaves = vec![F::rand_vec(ELEMS_PER_LEAF); size]; + b.iter(|| MerkleTree::::new(leaves.clone(), 0)); + }); + } +} + +fn criterion_benchmark(c: &mut Criterion) { + bench_merkle_tree::(c); + bench_merkle_tree::>(c); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/plonky2/benches/reverse_index_bits.rs b/plonky2/benches/reverse_index_bits.rs new file mode 100644 index 00000000..90f1e285 --- /dev/null +++ b/plonky2/benches/reverse_index_bits.rs @@ -0,0 +1,30 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use plonky2::field::field_types::Field; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2_util::{reverse_index_bits, reverse_index_bits_in_place}; + +type F = GoldilocksField; + +fn benchmark_in_place(c: &mut Criterion) { + let mut group = c.benchmark_group("reverse-index-bits-in-place"); + for width in [1 << 8, 1 << 16, 1 << 24] { + group.bench_with_input(BenchmarkId::from_parameter(width), &width, |b, _| { + let mut values = F::rand_vec(width); + b.iter(|| reverse_index_bits_in_place(&mut values)); + }); + } +} + +fn benchmark_out_of_place(c: &mut Criterion) { + let mut group = c.benchmark_group("reverse-index-bits"); + for width in [1 << 8, 1 << 16, 1 << 24] { + group.bench_with_input(BenchmarkId::from_parameter(width), &width, |b, _| { + let values = F::rand_vec(width); + b.iter(|| reverse_index_bits(&values)); + }); + } +} + +criterion_group!(benches_in_place, benchmark_in_place); +criterion_group!(benches_out_of_place, benchmark_out_of_place); +criterion_main!(benches_in_place, benches_out_of_place); diff --git a/plonky2/src/bin/bench_ldes.rs b/plonky2/src/bin/bench_ldes.rs deleted file mode 100644 index 57f31290..00000000 --- a/plonky2/src/bin/bench_ldes.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::time::Instant; - -use plonky2_field::field_types::Field; -use plonky2_field::goldilocks_field::GoldilocksField; -use plonky2_field::polynomial::PolynomialValues; -use rayon::prelude::*; - -type F = GoldilocksField; - -// This is an estimate of how many LDEs the prover will compute. The biggest component, 86, comes -// from wire polynomials which "store" the outputs of S-boxes in our Poseidon gate. -const NUM_LDES: usize = 8 + 8 + 3 + 86 + 3 + 8; - -const DEGREE: usize = 1 << 14; - -const RATE_BITS: usize = 3; - -fn main() { - // We start with random polynomials. - let all_poly_values = (0..NUM_LDES) - .map(|_| PolynomialValues::new(F::rand_vec(DEGREE))) - .collect::>(); - - let start = Instant::now(); - - all_poly_values.into_par_iter().for_each(|poly_values| { - let start = Instant::now(); - let lde = poly_values.lde(RATE_BITS); - let duration = start.elapsed(); - println!("LDE took {:?}", duration); - println!("LDE result: {:?}", lde.values[0]); - }); - println!("All LDEs took {:?}", start.elapsed()); -} diff --git a/plonky2/src/bin/bench_recursion.rs b/plonky2/src/bin/bench_recursion.rs deleted file mode 100644 index cb8eaca9..00000000 --- a/plonky2/src/bin/bench_recursion.rs +++ /dev/null @@ -1,60 +0,0 @@ -use anyhow::Result; -use env_logger::Env; -use log::info; -use plonky2::fri::reduction_strategies::FriReductionStrategy; -use plonky2::fri::FriConfig; -use plonky2::hash::hashing::SPONGE_WIDTH; -use plonky2::iop::witness::PartialWitness; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - -fn main() -> Result<()> { - // Set the default log filter. This can be overridden using the `RUST_LOG` environment variable, - // e.g. `RUST_LOG=debug`. - // We default to debug for now, since there aren't many logs anyway, but we should probably - // change this to info or warn later. - env_logger::Builder::from_env(Env::default().default_filter_or("debug")).init(); - - bench_prove::() -} - -fn bench_prove, const D: usize>() -> Result<()> { - let config = CircuitConfig { - num_wires: 126, - num_routed_wires: 33, - constant_gate_size: 6, - use_base_arithmetic_gate: false, - security_bits: 128, - num_challenges: 3, - zero_knowledge: false, - fri_config: FriConfig { - rate_bits: 3, - cap_height: 1, - proof_of_work_bits: 15, - reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), - num_query_rounds: 35, - }, - }; - - let inputs = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let zero = builder.zero(); - let zero_ext = builder.zero_extension(); - - let mut state = [zero; SPONGE_WIDTH]; - for _ in 0..10000 { - state = builder.permute::<>::InnerHasher>(state); - } - - // Random other gates. - builder.add(zero, zero); - builder.add_extension(zero_ext, zero_ext); - - let circuit = builder.build::(); - let proof_with_pis = circuit.prove(inputs)?; - let proof_bytes = serde_cbor::to_vec(&proof_with_pis).unwrap(); - info!("Proof length: {} bytes", proof_bytes.len()); - circuit.verify(proof_with_pis) -} diff --git a/plonky2/src/bin/generate_constants.rs b/plonky2/src/bin/generate_constants.rs index eb35aec3..d2744991 100644 --- a/plonky2/src/bin/generate_constants.rs +++ b/plonky2/src/bin/generate_constants.rs @@ -2,7 +2,7 @@ #![allow(clippy::needless_range_loop)] -use plonky2_field::field_types::PrimeField; +use plonky2_field::field_types::Field64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; @@ -11,7 +11,6 @@ use rand_chacha::ChaCha8Rng; // range of GoldilocksField, then verify that each constant also fits in GoldilocksField. const SAMPLE_RANGE_END: u64 = 0xffffffff70000001; -// const N: usize = 101; // For GMiMC // const N: usize = 8 * 30; // For Posiedon-8 const N: usize = 12 * 30; // For Posiedon-12 diff --git a/plonky2/src/curve/curve_msm.rs b/plonky2/src/curve/curve_msm.rs index 388c0321..4c274c1c 100644 --- a/plonky2/src/curve/curve_msm.rs +++ b/plonky2/src/curve/curve_msm.rs @@ -1,5 +1,6 @@ use itertools::Itertools; use plonky2_field::field_types::Field; +use plonky2_field::field_types::PrimeField; use rayon::prelude::*; use crate::curve::curve_summation::affine_multisummation_best; @@ -160,7 +161,7 @@ pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { // Convert x to a bool array. let x_canonical: Vec<_> = x - .to_biguint() + .to_canonical_biguint() .to_u64_digits() .iter() .cloned() @@ -187,6 +188,7 @@ pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { mod tests { use num::BigUint; use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits}; @@ -206,7 +208,7 @@ mod tests { 0b11111111111111111111111111111111, ]; let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); - assert_eq!(x.to_biguint().to_u32_digits(), x_canonical); + assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); assert_eq!( to_digits::(&x, 17), vec![ diff --git a/plonky2/src/curve/curve_multiplication.rs b/plonky2/src/curve/curve_multiplication.rs index 30da4973..c6fbbd83 100644 --- a/plonky2/src/curve/curve_multiplication.rs +++ b/plonky2/src/curve/curve_multiplication.rs @@ -1,6 +1,7 @@ use std::ops::Mul; use plonky2_field::field_types::Field; +use plonky2_field::field_types::PrimeField; use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; @@ -88,7 +89,7 @@ fn to_digits(x: &C::ScalarField) -> Vec { ); let digits_per_u64 = 64 / WINDOW_BITS; let mut digits = Vec::with_capacity(digits_per_scalar::()); - for limb in x.to_biguint().to_u64_digits() { + for limb in x.to_canonical_biguint().to_u64_digits() { for j in 0..digits_per_u64 { digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); } diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs index b7ee34e6..264120c7 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/plonky2/src/curve/curve_types.rs @@ -1,8 +1,10 @@ use std::fmt::Debug; +use std::hash::Hash; use std::ops::Neg; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::ops::Square; +use serde::{Deserialize, Serialize}; // To avoid implementation conflicts from associated types, // see https://github.com/rust-lang/rust/issues/20400 @@ -10,8 +12,8 @@ pub struct CurveScalar(pub ::ScalarField); /// A short Weierstrass curve. pub trait Curve: 'static + Sync + Sized + Copy + Debug { - type BaseField: Field; - type ScalarField: Field; + type BaseField: PrimeField; + type ScalarField: PrimeField; const A: Self::BaseField; const B: Self::BaseField; @@ -36,7 +38,7 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { } /// A point on a short Weierstrass curve, represented in affine coordinates. -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] pub struct AffinePoint { pub x: C::BaseField, pub y: C::BaseField, @@ -119,6 +121,17 @@ impl PartialEq for AffinePoint { impl Eq for AffinePoint {} +impl Hash for AffinePoint { + fn hash(&self, state: &mut H) { + if self.zero { + self.zero.hash(state); + } else { + self.x.hash(state); + self.y.hash(state); + } + } +} + /// A point on a short Weierstrass curve, represented in projective coordinates. #[derive(Copy, Clone, Debug)] pub struct ProjectivePoint { @@ -259,3 +272,11 @@ impl Neg for ProjectivePoint { ProjectivePoint { x, y: -y, z } } } + +pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { + C::ScalarField::from_biguint(x.to_canonical_biguint()) +} + +pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { + C::BaseField::from_biguint(x.to_canonical_biguint()) +} diff --git a/plonky2/src/curve/ecdsa.rs b/plonky2/src/curve/ecdsa.rs new file mode 100644 index 00000000..cabe038a --- /dev/null +++ b/plonky2/src/curve/ecdsa.rs @@ -0,0 +1,78 @@ +use serde::{Deserialize, Serialize}; + +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; +use crate::field::field_types::Field; + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSASignature { + pub r: C::ScalarField, + pub s: C::ScalarField, +} + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSASecretKey(pub C::ScalarField); + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSAPublicKey(pub AffinePoint); + +pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { + let (k, rr) = { + let mut k = C::ScalarField::rand(); + let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + while rr.x == C::BaseField::ZERO { + k = C::ScalarField::rand(); + rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + } + (k, rr) + }; + let r = base_to_scalar::(rr.x); + + let s = k.inverse() * (msg + r * sk.0); + + ECDSASignature { r, s } +} + +pub fn verify_message( + msg: C::ScalarField, + sig: ECDSASignature, + pk: ECDSAPublicKey, +) -> bool { + let ECDSASignature { r, s } = sig; + + assert!(pk.0.is_valid()); + + let c = s.inverse(); + let u1 = msg * c; + let u2 = r * c; + + let g = C::GENERATOR_PROJECTIVE; + let w = 5; // Experimentally fastest + let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w); + let point = point_proj.to_affine(); + + let x = base_to_scalar::(point.x); + r == x +} + +#[cfg(test)] +mod tests { + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, verify_message, ECDSAPublicKey, ECDSASecretKey}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + + #[test] + fn test_ecdsa_native() { + type C = Secp256K1; + + let msg = Secp256K1Scalar::rand(); + let sk = ECDSASecretKey(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine()); + + let sig = sign_message(msg, sk); + let result = verify_message(msg, sig, pk); + assert!(result); + } +} diff --git a/plonky2/src/curve/mod.rs b/plonky2/src/curve/mod.rs index d31e373e..8dd6f0d6 100644 --- a/plonky2/src/curve/mod.rs +++ b/plonky2/src/curve/mod.rs @@ -3,4 +3,5 @@ pub mod curve_msm; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; +pub mod ecdsa; pub mod secp256k1; diff --git a/plonky2/src/curve/secp256k1.rs b/plonky2/src/curve/secp256k1.rs index d9039719..18040dae 100644 --- a/plonky2/src/curve/secp256k1.rs +++ b/plonky2/src/curve/secp256k1.rs @@ -1,10 +1,11 @@ use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; +use serde::{Deserialize, Serialize}; use crate::curve::curve_types::{AffinePoint, Curve}; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct Secp256K1; impl Curve for Secp256K1 { @@ -40,6 +41,7 @@ const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ mod tests { use num::BigUint; use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; @@ -86,7 +88,7 @@ mod tests { ) -> ProjectivePoint { let mut g = rhs; let mut sum = ProjectivePoint::ZERO; - for limb in lhs.to_biguint().to_u64_digits().iter() { + for limb in lhs.to_canonical_biguint().to_u64_digits().iter() { for j in 0..64 { if (limb >> j & 1u64) != 0u64 { sum = sum + g; diff --git a/plonky2/src/fri/challenges.rs b/plonky2/src/fri/challenges.rs new file mode 100644 index 00000000..82438383 --- /dev/null +++ b/plonky2/src/fri/challenges.rs @@ -0,0 +1,131 @@ +use plonky2_field::extension_field::Extendable; +use plonky2_field::polynomial::PolynomialCoeffs; + +use crate::fri::proof::{FriChallenges, FriChallengesTarget}; +use crate::fri::structure::{FriOpenings, FriOpeningsTarget}; +use crate::fri::FriConfig; +use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; +use crate::hash::hash_types::{MerkleCapTarget, RichField}; +use crate::hash::merkle_tree::MerkleCap; +use crate::iop::challenger::{Challenger, RecursiveChallenger}; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CommonCircuitData; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; + +impl> Challenger { + pub fn observe_openings(&mut self, openings: &FriOpenings) + where + F: RichField + Extendable, + { + for v in &openings.batches { + self.observe_extension_elements(&v.values); + } + } + + pub fn fri_challenges, const D: usize>( + &mut self, + commit_phase_merkle_caps: &[MerkleCap], + final_poly: &PolynomialCoeffs, + pow_witness: F, + degree_bits: usize, + config: &FriConfig, + ) -> FriChallenges + where + F: RichField + Extendable, + { + let num_fri_queries = config.num_query_rounds; + let lde_size = 1 << (degree_bits + config.rate_bits); + // Scaling factor to combine polynomials. + let fri_alpha = self.get_extension_challenge::(); + + // Recover the random betas used in the FRI reductions. + let fri_betas = commit_phase_merkle_caps + .iter() + .map(|cap| { + self.observe_cap(cap); + self.get_extension_challenge::() + }) + .collect(); + + self.observe_extension_elements(&final_poly.coeffs); + + let fri_pow_response = C::InnerHasher::hash_no_pad( + &self + .get_hash() + .elements + .iter() + .copied() + .chain(Some(pow_witness)) + .collect::>(), + ) + .elements[0]; + + let fri_query_indices = (0..num_fri_queries) + .map(|_| self.get_challenge().to_canonical_u64() as usize % lde_size) + .collect(); + + FriChallenges { + fri_alpha, + fri_betas, + fri_pow_response, + fri_query_indices, + } + } +} + +impl, H: AlgebraicHasher, const D: usize> + RecursiveChallenger +{ + pub fn observe_openings(&mut self, openings: &FriOpeningsTarget) { + for v in &openings.batches { + self.observe_extension_elements(&v.values); + } + } + + pub fn fri_challenges>( + &mut self, + builder: &mut CircuitBuilder, + commit_phase_merkle_caps: &[MerkleCapTarget], + final_poly: &PolynomialCoeffsExtTarget, + pow_witness: Target, + inner_common_data: &CommonCircuitData, + ) -> FriChallengesTarget { + let num_fri_queries = inner_common_data.config.fri_config.num_query_rounds; + // Scaling factor to combine polynomials. + let fri_alpha = self.get_extension_challenge(builder); + + // Recover the random betas used in the FRI reductions. + let fri_betas = commit_phase_merkle_caps + .iter() + .map(|cap| { + self.observe_cap(cap); + self.get_extension_challenge(builder) + }) + .collect(); + + self.observe_extension_elements(&final_poly.0); + + let pow_inputs = self + .get_hash(builder) + .elements + .iter() + .copied() + .chain(Some(pow_witness)) + .collect(); + let fri_pow_response = builder + .hash_n_to_hash_no_pad::(pow_inputs) + .elements[0]; + + let fri_query_indices = (0..num_fri_queries) + .map(|_| self.get_challenge(builder)) + .collect(); + + FriChallengesTarget { + fri_alpha, + fri_betas, + fri_pow_response, + fri_query_indices, + } + } +} diff --git a/plonky2/src/fri/mod.rs b/plonky2/src/fri/mod.rs index c50d1ff7..a0cd428b 100644 --- a/plonky2/src/fri/mod.rs +++ b/plonky2/src/fri/mod.rs @@ -1,11 +1,14 @@ use crate::fri::reduction_strategies::FriReductionStrategy; -pub mod commitment; +mod challenges; +pub mod oracle; pub mod proof; pub mod prover; pub mod recursive_verifier; pub mod reduction_strategies; +pub mod structure; pub mod verifier; +pub mod witness_util; #[derive(Debug, Clone, Eq, PartialEq)] pub struct FriConfig { @@ -23,6 +26,12 @@ pub struct FriConfig { pub num_query_rounds: usize, } +impl FriConfig { + pub fn rate(&self) -> f64 { + 1.0 / ((1 << self.rate_bits) as f64) + } +} + /// FRI parameters, including generated parameters which are specific to an instance size, in /// contrast to `FriConfig` which is user-specified and independent of instance size. #[derive(Debug)] @@ -30,6 +39,9 @@ pub struct FriParams { /// User-specified FRI configuration. pub config: FriConfig, + /// Whether to use a hiding variant of Merkle trees (where random salts are added to leaves). + pub hiding: bool, + /// The degree of the purported codeword, measured in bits. pub degree_bits: usize, diff --git a/plonky2/src/fri/commitment.rs b/plonky2/src/fri/oracle.rs similarity index 52% rename from plonky2/src/fri/commitment.rs rename to plonky2/src/fri/oracle.rs index 9d7ecf43..bd1e9ac5 100644 --- a/plonky2/src/fri/commitment.rs +++ b/plonky2/src/fri/oracle.rs @@ -7,13 +7,12 @@ use rayon::prelude::*; use crate::fri::proof::FriProof; use crate::fri::prover::fri_proof; +use crate::fri::structure::{FriBatchInfo, FriInstanceInfo}; +use crate::fri::FriParams; use crate::hash::hash_types::RichField; use crate::hash::merkle_tree::MerkleTree; use crate::iop::challenger::Challenger; -use crate::plonk::circuit_data::CommonCircuitData; -use crate::plonk::config::GenericConfig; -use crate::plonk::plonk_common::PlonkPolynomials; -use crate::plonk::proof::OpeningSet; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::timed; use crate::util::reducing::ReducingFactor; use crate::util::reverse_bits; @@ -23,12 +22,9 @@ use crate::util::transpose; /// Four (~64 bit) field elements gives ~128 bit security. pub const SALT_SIZE: usize = 4; -/// Represents a batch FRI based commitment to a list of polynomials. -pub struct PolynomialBatchCommitment< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, -> { +/// Represents a FRI oracle, i.e. a batch of polynomials which have been Merklized. +pub struct PolynomialBatch, C: GenericConfig, const D: usize> +{ pub polynomials: Vec>, pub merkle_tree: MerkleTree, pub degree_log: usize, @@ -37,21 +33,24 @@ pub struct PolynomialBatchCommitment< } impl, C: GenericConfig, const D: usize> - PolynomialBatchCommitment + PolynomialBatch { /// Creates a list polynomial commitment for the polynomials interpolating the values in `values`. - pub(crate) fn from_values( + pub fn from_values( values: Vec>, rate_bits: usize, blinding: bool, cap_height: usize, timing: &mut TimingTree, fft_root_table: Option<&FftRootTable>, - ) -> Self { + ) -> Self + where + [(); C::Hasher::HASH_SIZE]:, + { let coeffs = timed!( timing, "IFFT", - values.par_iter().map(|v| v.ifft()).collect::>() + values.into_par_iter().map(|v| v.ifft()).collect::>() ); Self::from_coeffs( @@ -65,14 +64,17 @@ impl, C: GenericConfig, const D: usize> } /// Creates a list polynomial commitment for the polynomials `polynomials`. - pub(crate) fn from_coeffs( + pub fn from_coeffs( polynomials: Vec>, rate_bits: usize, blinding: bool, cap_height: usize, timing: &mut TimingTree, fft_root_table: Option<&FftRootTable>, - ) -> Self { + ) -> Self + where + [(); C::Hasher::HASH_SIZE]:, + { let degree = polynomials[0].len(); let lde_values = timed!( timing, @@ -130,78 +132,42 @@ impl, C: GenericConfig, const D: usize> &slice[..slice.len() - if self.blinding { SALT_SIZE } else { 0 }] } - /// Takes the commitments to the constants - sigmas - wires - zs - quotient — polynomials, - /// and an opening point `zeta` and produces a batched opening proof + opening set. - pub(crate) fn open_plonk( - commitments: &[&Self; 4], - zeta: F::Extension, + /// Produces a batch opening proof. + pub fn prove_openings( + instance: &FriInstanceInfo, + oracles: &[&Self], challenger: &mut Challenger, - common_data: &CommonCircuitData, + fri_params: &FriParams, timing: &mut TimingTree, - ) -> (FriProof, OpeningSet) { - let config = &common_data.config; + ) -> FriProof + where + [(); C::Hasher::HASH_SIZE]:, + { assert!(D > 1, "Not implemented for D=1."); - let degree_log = commitments[0].degree_log; - let g = F::Extension::primitive_root_of_unity(degree_log); - for p in &[zeta, g * zeta] { - assert_ne!( - p.exp_u64(1 << degree_log as u64), - F::Extension::ONE, - "Opening point is in the subgroup." - ); - } - - let os = timed!( - timing, - "construct the opening set", - OpeningSet::new( - zeta, - g, - commitments[0], - commitments[1], - commitments[2], - commitments[3], - common_data, - ) - ); - challenger.observe_opening_set(&os); - let alpha = challenger.get_extension_challenge::(); let mut alpha = ReducingFactor::new(alpha); // Final low-degree polynomial that goes into FRI. let mut final_poly = PolynomialCoeffs::empty(); - // All polynomials are opened at `zeta`. - let single_polys = [ - PlonkPolynomials::CONSTANTS_SIGMAS, - PlonkPolynomials::WIRES, - PlonkPolynomials::ZS_PARTIAL_PRODUCTS, - PlonkPolynomials::QUOTIENT, - ] - .iter() - .flat_map(|&p| &commitments[p.index].polynomials); - let single_composition_poly = timed!( - timing, - "reduce single polys", - alpha.reduce_polys_base(single_polys) - ); + for FriBatchInfo { point, polynomials } in &instance.batches { + let polys_coeff = polynomials.iter().map(|fri_poly| { + &oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index] + }); + let composition_poly = timed!( + timing, + &format!("reduce batch of {} polynomials", polynomials.len()), + alpha.reduce_polys_base(polys_coeff) + ); + let quotient = composition_poly.divide_by_linear(*point); + alpha.shift_poly(&mut final_poly); + final_poly += quotient; + } + // Multiply the final polynomial by `X`, so that `final_poly` has the maximum degree for + // which the LDT will pass. See github.com/mir-protocol/plonky2/pull/436 for details. + final_poly.coeffs.insert(0, F::Extension::ZERO); - let single_quotient = Self::compute_quotient([zeta], single_composition_poly); - final_poly += single_quotient; - alpha.reset(); - - // Z polynomials have an additional opening at `g zeta`. - let zs_polys = &commitments[PlonkPolynomials::ZS_PARTIAL_PRODUCTS.index].polynomials - [common_data.zs_range()]; - let zs_composition_poly = - timed!(timing, "reduce Z polys", alpha.reduce_polys_base(zs_polys)); - - let zs_quotient = Self::compute_quotient([g * zeta], zs_composition_poly); - alpha.shift_poly(&mut final_poly); - final_poly += zs_quotient; - - let lde_final_poly = final_poly.lde(config.fri_config.rate_bits); + let lde_final_poly = final_poly.lde(fri_params.config.rate_bits); let lde_final_values = timed!( timing, &format!("perform final FFT {}", lde_final_poly.len()), @@ -209,41 +175,17 @@ impl, C: GenericConfig, const D: usize> ); let fri_proof = fri_proof::( - &commitments + &oracles .par_iter() .map(|c| &c.merkle_tree) .collect::>(), lde_final_poly, lde_final_values, challenger, - &common_data.fri_params, + fri_params, timing, ); - (fri_proof, os) - } - - /// Given `points=(x_i)`, `evals=(y_i)` and `poly=P` with `P(x_i)=y_i`, computes the polynomial - /// `Q=(P-I)/Z` where `I` interpolates `(x_i, y_i)` and `Z` is the vanishing polynomial on `(x_i)`. - fn compute_quotient( - points: [F::Extension; N], - poly: PolynomialCoeffs, - ) -> PolynomialCoeffs { - let quotient = if N == 1 { - poly.divide_by_linear(points[0]).0 - } else if N == 2 { - // The denominator is `(X - p0)(X - p1) = p0 p1 - (p0 + p1) X + X^2`. - let denominator = vec![ - points[0] * points[1], - -points[0] - points[1], - F::Extension::ONE, - ] - .into(); - poly.div_rem_long_division(&denominator).0 // Could also use `divide_by_linear` twice. - } else { - unreachable!("This shouldn't happen. Plonk should open polynomials at 1 or 2 points.") - }; - - quotient.padded(quotient.degree_plus_one().next_power_of_two()) + fri_proof } } diff --git a/plonky2/src/fri/proof.rs b/plonky2/src/fri/proof.rs index 784f3286..9c6961a4 100644 --- a/plonky2/src/fri/proof.rs +++ b/plonky2/src/fri/proof.rs @@ -5,6 +5,7 @@ use plonky2_field::extension_field::{flatten, unflatten, Extendable}; use plonky2_field::polynomial::PolynomialCoeffs; use serde::{Deserialize, Serialize}; +use crate::fri::FriParams; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::hash::hash_types::MerkleCapTarget; use crate::hash::hash_types::RichField; @@ -13,9 +14,8 @@ use crate::hash::merkle_tree::MerkleCap; use crate::hash::path_compression::{compress_merkle_proofs, decompress_merkle_proofs}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; -use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::{GenericConfig, Hasher}; -use crate::plonk::plonk_common::PolynomialsIndexBlinding; +use crate::plonk::plonk_common::salt_size; use crate::plonk::proof::{FriInferredElements, ProofChallenges}; /// Evaluations and Merkle proof produced by the prover in a FRI query step. @@ -26,7 +26,7 @@ pub struct FriQueryStep, H: Hasher, const D: usi pub merkle_proof: MerkleProof, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct FriQueryStepTarget { pub evals: Vec>, pub merkle_proof: MerkleProofTarget, @@ -41,29 +41,34 @@ pub struct FriInitialTreeProof> { } impl> FriInitialTreeProof { - pub(crate) fn unsalted_evals( - &self, - polynomials: PolynomialsIndexBlinding, - zero_knowledge: bool, - ) -> &[F] { - let evals = &self.evals_proofs[polynomials.index].0; - &evals[..evals.len() - polynomials.salt_size(zero_knowledge)] + pub(crate) fn unsalted_eval(&self, oracle_index: usize, poly_index: usize, salted: bool) -> F { + self.unsalted_evals(oracle_index, salted)[poly_index] + } + + fn unsalted_evals(&self, oracle_index: usize, salted: bool) -> &[F] { + let evals = &self.evals_proofs[oracle_index].0; + &evals[..evals.len() - salt_size(salted)] } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct FriInitialTreeProofTarget { pub evals_proofs: Vec<(Vec, MerkleProofTarget)>, } impl FriInitialTreeProofTarget { - pub(crate) fn unsalted_evals( + pub(crate) fn unsalted_eval( &self, - polynomials: PolynomialsIndexBlinding, - zero_knowledge: bool, - ) -> &[Target] { - let evals = &self.evals_proofs[polynomials.index].0; - &evals[..evals.len() - polynomials.salt_size(zero_knowledge)] + oracle_index: usize, + poly_index: usize, + salted: bool, + ) -> Target { + self.unsalted_evals(oracle_index, salted)[poly_index] + } + + fn unsalted_evals(&self, oracle_index: usize, salted: bool) -> &[Target] { + let evals = &self.evals_proofs[oracle_index].0; + &evals[..evals.len() - salt_size(salted)] } } @@ -75,7 +80,7 @@ pub struct FriQueryRound, H: Hasher, const D: us pub steps: Vec>, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct FriQueryRoundTarget { pub initial_trees_proof: FriInitialTreeProofTarget, pub steps: Vec>, @@ -106,6 +111,7 @@ pub struct FriProof, H: Hasher, const D: usize> pub pow_witness: F, } +#[derive(Debug)] pub struct FriProofTarget { pub commit_phase_merkle_caps: Vec, pub query_round_proofs: Vec>, @@ -131,7 +137,7 @@ impl, H: Hasher, const D: usize> FriProof>( self, indices: &[usize], - common_data: &CommonCircuitData, + params: &FriParams, ) -> CompressedFriProof { let FriProof { commit_phase_merkle_caps, @@ -140,8 +146,8 @@ impl, H: Hasher, const D: usize> FriProof, H: Hasher, const D: usize> CompressedFriPr self, challenges: &ProofChallenges, fri_inferred_elements: FriInferredElements, - common_data: &CommonCircuitData, - ) -> FriProof { + params: &FriParams, + ) -> FriProof + where + [(); H::HASH_SIZE]:, + { let CompressedFriProof { commit_phase_merkle_caps, query_round_proofs, @@ -247,13 +256,13 @@ impl, H: Hasher, const D: usize> CompressedFriPr pow_witness, .. } = self; - let ProofChallenges { + let FriChallenges { fri_query_indices: indices, .. - } = challenges; + } = &challenges.fri_challenges; let mut fri_inferred_elements = fri_inferred_elements.0.into_iter(); - let cap_height = common_data.config.fri_config.cap_height; - let reduction_arity_bits = &common_data.fri_params.reduction_arity_bits; + let cap_height = params.config.cap_height; + let reduction_arity_bits = ¶ms.reduction_arity_bits; let num_reductions = reduction_arity_bits.len(); let num_initial_trees = query_round_proofs .initial_trees_proofs @@ -270,7 +279,7 @@ impl, H: Hasher, const D: usize> CompressedFriPr let mut steps_indices = vec![vec![]; num_reductions]; let mut steps_evals = vec![vec![]; num_reductions]; let mut steps_proofs = vec![vec![]; num_reductions]; - let height = common_data.degree_bits + common_data.config.fri_config.rate_bits; + let height = params.degree_bits + params.config.rate_bits; let heights = reduction_arity_bits .iter() .scan(height, |acc, &bits| { @@ -280,10 +289,8 @@ impl, H: Hasher, const D: usize> CompressedFriPr .collect::>(); // Holds the `evals` vectors that have already been reconstructed at each reduction depth. - let mut evals_by_depth = vec![ - HashMap::>::new(); - common_data.fri_params.reduction_arity_bits.len() - ]; + let mut evals_by_depth = + vec![HashMap::>::new(); params.reduction_arity_bits.len()]; for &(mut index) in indices { let initial_trees_proof = query_round_proofs.initial_trees_proofs[&index].clone(); for (i, (leaves_data, proof)) in @@ -358,3 +365,23 @@ impl, H: Hasher, const D: usize> CompressedFriPr } } } + +pub struct FriChallenges, const D: usize> { + // Scaling factor to combine polynomials. + pub fri_alpha: F::Extension, + + // Betas used in the FRI commit phase reductions. + pub fri_betas: Vec, + + pub fri_pow_response: F, + + // Indices at which the oracle is queried in FRI. + pub fri_query_indices: Vec, +} + +pub struct FriChallengesTarget { + pub fri_alpha: ExtensionTarget, + pub fri_betas: Vec>, + pub fri_pow_response: Target, + pub fri_query_indices: Vec, +} diff --git a/plonky2/src/fri/prover.rs b/plonky2/src/fri/prover.rs index 05135c91..5a20ab9d 100644 --- a/plonky2/src/fri/prover.rs +++ b/plonky2/src/fri/prover.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use plonky2_field::extension_field::{flatten, unflatten, Extendable}; use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2_util::reverse_index_bits_in_place; @@ -23,9 +24,12 @@ pub fn fri_proof, C: GenericConfig, const challenger: &mut Challenger, fri_params: &FriParams, timing: &mut TimingTree, -) -> FriProof { - let n = lde_polynomial_values.values.len(); - assert_eq!(lde_polynomial_coeffs.coeffs.len(), n); +) -> FriProof +where + [(); C::Hasher::HASH_SIZE]:, +{ + let n = lde_polynomial_values.len(); + assert_eq!(lde_polynomial_coeffs.len(), n); // Commit phase let (trees, final_coeffs) = timed!( @@ -67,13 +71,15 @@ fn fri_committed_trees, C: GenericConfig, ) -> ( Vec>, PolynomialCoeffs, -) { +) +where + [(); C::Hasher::HASH_SIZE]:, +{ let mut trees = Vec::new(); let mut shift = F::MULTIPLICATIVE_GROUP_GENERATOR; - let num_reductions = fri_params.reduction_arity_bits.len(); - for i in 0..num_reductions { - let arity = 1 << fri_params.reduction_arity_bits[i]; + for arity_bits in &fri_params.reduction_arity_bits { + let arity = 1 << arity_bits; reverse_index_bits_in_place(&mut values.values); let chunked_values = values @@ -115,14 +121,13 @@ fn fri_proof_of_work, C: GenericConfig, c (0..=F::NEG_ONE.to_canonical_u64()) .into_par_iter() .find_any(|&i| { - C::InnerHasher::hash( - current_hash + C::InnerHasher::hash_no_pad( + ¤t_hash .elements .iter() .copied() .chain(Some(F::from_canonical_u64(i))) - .collect(), - false, + .collect_vec(), ) .elements[0] .to_canonical_u64() diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index 276adc2c..f51b8fe6 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -1,9 +1,13 @@ +use itertools::Itertools; use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::Field; use plonky2_util::{log2_strict, reverse_index_bits_in_place}; -use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget}; -use crate::fri::FriConfig; +use crate::fri::proof::{ + FriChallengesTarget, FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, + FriQueryStepTarget, +}; +use crate::fri::structure::{FriBatchInfoTarget, FriInstanceInfoTarget, FriOpeningsTarget}; +use crate::fri::{FriConfig, FriParams}; use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; use crate::gates::interpolation::HighDegreeInterpolationGate; @@ -11,14 +15,10 @@ use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::MerkleCapTarget; use crate::hash::hash_types::RichField; -use crate::iop::challenger::RecursiveChallenger; use crate::iop::ext_target::{flatten_target, ExtensionTarget}; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; -use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig}; -use crate::plonk::plonk_common::PlonkPolynomials; -use crate::plonk::proof::OpeningSetTarget; +use crate::plonk::config::{AlgebraicHasher, GenericConfig}; use crate::util::reducing::ReducingFactorTarget; use crate::with_context; @@ -32,7 +32,6 @@ impl, const D: usize> CircuitBuilder { arity_bits: usize, evals: &[ExtensionTarget], beta: ExtensionTarget, - common_data: &CommonCircuitData, ) -> ExtensionTarget { let arity = 1 << arity_bits; debug_assert_eq!(evals.len(), arity); @@ -51,7 +50,7 @@ impl, const D: usize> CircuitBuilder { // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. // `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if // the arity is too large. - if arity > common_data.quotient_degree_factor { + if arity > self.config.max_quotient_degree_factor { self.interpolate_coset::>( arity_bits, coset_start, @@ -71,17 +70,13 @@ impl, const D: usize> CircuitBuilder { /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check /// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more /// helpful errors. - fn check_recursion_config>( - &self, - max_fri_arity_bits: usize, - common_data: &CommonCircuitData, - ) { + fn check_recursion_config>(&self, max_fri_arity_bits: usize) { let random_access = RandomAccessGate::::new_from_config( &self.config, max_fri_arity_bits.max(self.config.fri_config.cap_height), ); let (interpolation_wires, interpolation_routed_wires) = - if 1 << max_fri_arity_bits > common_data.quotient_degree_factor { + if 1 << max_fri_arity_bits > self.config.max_quotient_degree_factor { let gate = LowDegreeInterpolationGate::::new(max_fri_arity_bits); (gate.num_wires(), gate.num_routed_wires()) } else { @@ -111,74 +106,48 @@ impl, const D: usize> CircuitBuilder { fn fri_verify_proof_of_work>( &mut self, - proof: &FriProofTarget, - challenger: &mut RecursiveChallenger, + fri_pow_response: Target, config: &FriConfig, ) { - let mut inputs = challenger.get_hash(self).elements.to_vec(); - inputs.push(proof.pow_witness); - - let hash = self.hash_n_to_m::(inputs, 1, false)[0]; self.assert_leading_zeros( - hash, + fri_pow_response, config.proof_of_work_bits + (64 - F::order().bits()) as u32, ); } - pub fn verify_fri_proof>( + pub fn verify_fri_proof>( &mut self, - // Openings of the PLONK polynomials. - os: &OpeningSetTarget, - // Point at which the PLONK polynomials are opened. - zeta: ExtensionTarget, + instance: &FriInstanceInfoTarget, + openings: &FriOpeningsTarget, + challenges: &FriChallengesTarget, initial_merkle_caps: &[MerkleCapTarget], proof: &FriProofTarget, - challenger: &mut RecursiveChallenger, - common_data: &CommonCircuitData, - ) { - let config = &common_data.config; - - if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() { - self.check_recursion_config(max_arity_bits, common_data); + params: &FriParams, + ) where + C::Hasher: AlgebraicHasher, + { + if let Some(max_arity_bits) = params.max_arity_bits() { + self.check_recursion_config::(max_arity_bits); } debug_assert_eq!( - common_data.fri_params.final_poly_len(), + params.final_poly_len(), proof.final_poly.len(), "Final polynomial has wrong degree." ); // Size of the LDE domain. - let n = common_data.lde_size(); - - challenger.observe_opening_set(os); - - // Scaling factor to combine polynomials. - let alpha = challenger.get_extension_challenge(self); - - let betas = with_context!( - self, - "recover the random betas used in the FRI reductions.", - proof - .commit_phase_merkle_caps - .iter() - .map(|cap| { - challenger.observe_cap(cap); - challenger.get_extension_challenge(self) - }) - .collect::>() - ); - challenger.observe_extension_elements(&proof.final_poly.0); + let n = params.lde_size(); with_context!( self, "check PoW", - self.fri_verify_proof_of_work::(proof, challenger, &config.fri_config) + self.fri_verify_proof_of_work::(challenges.fri_pow_response, ¶ms.config) ); // Check that parameters are coherent. debug_assert_eq!( - config.fri_config.num_query_rounds, + params.config.num_query_rounds, proof.query_round_proofs.len(), "Number of query rounds does not match config." ); @@ -186,11 +155,9 @@ impl, const D: usize> CircuitBuilder { let precomputed_reduced_evals = with_context!( self, "precompute reduced evaluations", - PrecomputedReducedEvalsTarget::from_os_and_alpha( - os, - alpha, - common_data.degree_bits, - zeta, + PrecomputedReducedOpeningsTarget::from_os_and_alpha( + openings, + challenges.fri_alpha, self ) ); @@ -210,17 +177,16 @@ impl, const D: usize> CircuitBuilder { self, level, &format!("verify one (of {}) query rounds", num_queries), - self.fri_verifier_query_round( - zeta, - alpha, - precomputed_reduced_evals, + self.fri_verifier_query_round::( + instance, + challenges, + &precomputed_reduced_evals, initial_merkle_caps, proof, - challenger, + challenges.fri_query_indices[i], n, - &betas, round_proof, - common_data, + params, ) ); } @@ -255,85 +221,73 @@ impl, const D: usize> CircuitBuilder { fn fri_combine_initial>( &mut self, + instance: &FriInstanceInfoTarget, proof: &FriInitialTreeProofTarget, alpha: ExtensionTarget, subgroup_x: Target, - vanish_zeta: ExtensionTarget, - precomputed_reduced_evals: PrecomputedReducedEvalsTarget, - common_data: &CommonCircuitData, + precomputed_reduced_evals: &PrecomputedReducedOpeningsTarget, + params: &FriParams, ) -> ExtensionTarget { assert!(D > 1, "Not implemented for D=1."); - let config = &common_data.config; - let degree_log = common_data.degree_bits; + let degree_log = params.degree_bits; debug_assert_eq!( degree_log, - common_data.config.fri_config.cap_height + proof.evals_proofs[0].1.siblings.len() - - config.fri_config.rate_bits + params.config.cap_height + proof.evals_proofs[0].1.siblings.len() + - params.config.rate_bits ); let subgroup_x = self.convert_to_ext(subgroup_x); let mut alpha = ReducingFactorTarget::new(alpha); let mut sum = self.zero_extension(); - // We will add two terms to `sum`: one for openings at `x`, and one for openings at `g x`. - // All polynomials are opened at `x`. - let single_evals = [ - PlonkPolynomials::CONSTANTS_SIGMAS, - PlonkPolynomials::WIRES, - PlonkPolynomials::ZS_PARTIAL_PRODUCTS, - PlonkPolynomials::QUOTIENT, - ] - .iter() - .flat_map(|&p| proof.unsalted_evals(p, config.zero_knowledge)) - .copied() - .collect::>(); - let single_composition_eval = alpha.reduce_base(&single_evals, self); - let single_numerator = - self.sub_extension(single_composition_eval, precomputed_reduced_evals.single); - sum = self.div_add_extension(single_numerator, vanish_zeta, sum); - alpha.reset(); - - // Polynomials opened at `x` and `g x`, i.e., the Zs polynomials. - let zs_evals = proof - .unsalted_evals(PlonkPolynomials::ZS_PARTIAL_PRODUCTS, config.zero_knowledge) + for (batch, reduced_openings) in instance + .batches .iter() - .take(common_data.zs_range().end) - .copied() - .collect::>(); - let zs_composition_eval = alpha.reduce_base(&zs_evals, self); + .zip(&precomputed_reduced_evals.reduced_openings_at_point) + { + let FriBatchInfoTarget { point, polynomials } = batch; + let evals = polynomials + .iter() + .map(|p| { + let poly_blinding = instance.oracles[p.oracle_index].blinding; + let salted = params.hiding && poly_blinding; + proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted) + }) + .collect_vec(); + let reduced_evals = alpha.reduce_base(&evals, self); + let numerator = self.sub_extension(reduced_evals, *reduced_openings); + let denominator = self.sub_extension(subgroup_x, *point); + sum = alpha.shift(sum, self); + sum = self.div_add_extension(numerator, denominator, sum); + } - let zs_numerator = - self.sub_extension(zs_composition_eval, precomputed_reduced_evals.zs_right); - let zs_denominator = self.sub_extension(subgroup_x, precomputed_reduced_evals.zeta_right); - sum = alpha.shift(sum, self); // TODO: alpha^count could be precomputed. - sum = self.div_add_extension(zs_numerator, zs_denominator, sum); - - sum + // Multiply the final polynomial by `X`, so that `final_poly` has the maximum degree for + // which the LDT will pass. See github.com/mir-protocol/plonky2/pull/436 for details. + self.mul_extension(sum, subgroup_x) } - fn fri_verifier_query_round>( + fn fri_verifier_query_round>( &mut self, - zeta: ExtensionTarget, - alpha: ExtensionTarget, - precomputed_reduced_evals: PrecomputedReducedEvalsTarget, + instance: &FriInstanceInfoTarget, + challenges: &FriChallengesTarget, + precomputed_reduced_evals: &PrecomputedReducedOpeningsTarget, initial_merkle_caps: &[MerkleCapTarget], proof: &FriProofTarget, - challenger: &mut RecursiveChallenger, + x_index: Target, n: usize, - betas: &[ExtensionTarget], round_proof: &FriQueryRoundTarget, - common_data: &CommonCircuitData, - ) { + params: &FriParams, + ) where + C::Hasher: AlgebraicHasher, + { let n_log = log2_strict(n); // Note that this `low_bits` decomposition permits non-canonical binary encodings. Here we // verify that this has a negligible impact on soundness error. - Self::assert_noncanonical_indices_ok(&common_data.config); - let x_index = challenger.get_challenge(self); + Self::assert_noncanonical_indices_ok(¶ms.config); let mut x_index_bits = self.low_bits(x_index, n_log, F::BITS); - let cap_index = self.le_sum( - x_index_bits[x_index_bits.len() - common_data.config.fri_config.cap_height..].iter(), - ); + let cap_index = + self.le_sum(x_index_bits[x_index_bits.len() - params.config.cap_height..].iter()); with_context!( self, "check FRI initial proof", @@ -346,16 +300,12 @@ impl, const D: usize> CircuitBuilder { ); // `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain. - let (mut subgroup_x, vanish_zeta) = with_context!(self, "compute x from its index", { + let mut subgroup_x = with_context!(self, "compute x from its index", { let g = self.constant(F::coset_shift()); let phi = F::primitive_root_of_unity(n_log); let phi = self.exp_from_bits_const_base(phi, x_index_bits.iter().rev()); - let g_ext = self.convert_to_ext(g); - let phi_ext = self.convert_to_ext(phi); - // `subgroup_x = g*phi, vanish_zeta = g*phi - zeta` - let subgroup_x = self.mul(g, phi); - let vanish_zeta = self.mul_sub_extension(g_ext, phi_ext, zeta); - (subgroup_x, vanish_zeta) + // subgroup_x = g * phi + self.mul(g, phi) }); // old_eval is the last derived evaluation; it will be checked for consistency with its @@ -363,22 +313,17 @@ impl, const D: usize> CircuitBuilder { let mut old_eval = with_context!( self, "combine initial oracles", - self.fri_combine_initial( + self.fri_combine_initial::( + instance, &round_proof.initial_trees_proof, - alpha, + challenges.fri_alpha, subgroup_x, - vanish_zeta, precomputed_reduced_evals, - common_data, + params, ) ); - for (i, &arity_bits) in common_data - .fri_params - .reduction_arity_bits - .iter() - .enumerate() - { + for (i, &arity_bits) in params.reduction_arity_bits.iter().enumerate() { let evals = &round_proof.steps[i].evals; // Split x_index into the index of the coset x is in, and the index of x within that coset. @@ -393,13 +338,12 @@ impl, const D: usize> CircuitBuilder { old_eval = with_context!( self, "infer evaluation using interpolation", - self.compute_evaluation( + self.compute_evaluation::( subgroup_x, x_index_within_coset_bits, arity_bits, evals, - betas[i], - common_data + challenges.fri_betas[i], ) ); @@ -446,52 +390,110 @@ impl, const D: usize> CircuitBuilder { /// Thus ambiguous elements contribute a negligible amount to soundness error. /// /// Here we compare the probabilities as a sanity check, to verify the claim above. - fn assert_noncanonical_indices_ok(config: &CircuitConfig) { + fn assert_noncanonical_indices_ok(config: &FriConfig) { let num_ambiguous_elems = u64::MAX - F::ORDER + 1; let query_error = config.rate(); let p_ambiguous = (num_ambiguous_elems as f64) / (F::ORDER as f64); assert!(p_ambiguous < query_error * 1e-5, "A non-negligible portion of field elements are in the range that permits non-canonical encodings. Need to do more analysis or enforce canonical encodings."); } -} -#[derive(Copy, Clone)] -struct PrecomputedReducedEvalsTarget { - pub single: ExtensionTarget, - pub zs_right: ExtensionTarget, - pub zeta_right: ExtensionTarget, -} + pub(crate) fn add_virtual_fri_proof( + &mut self, + num_leaves_per_oracle: &[usize], + params: &FriParams, + ) -> FriProofTarget { + let cap_height = params.config.cap_height; + let num_queries = params.config.num_query_rounds; + let commit_phase_merkle_caps = (0..params.reduction_arity_bits.len()) + .map(|_| self.add_virtual_cap(cap_height)) + .collect(); + let query_round_proofs = (0..num_queries) + .map(|_| self.add_virtual_fri_query(num_leaves_per_oracle, params)) + .collect(); + let final_poly = self.add_virtual_poly_coeff_ext(params.final_poly_len()); + let pow_witness = self.add_virtual_target(); + FriProofTarget { + commit_phase_merkle_caps, + query_round_proofs, + final_poly, + pow_witness, + } + } -impl PrecomputedReducedEvalsTarget { - fn from_os_and_alpha>( - os: &OpeningSetTarget, - alpha: ExtensionTarget, - degree_log: usize, - zeta: ExtensionTarget, - builder: &mut CircuitBuilder, - ) -> Self { - let mut alpha = ReducingFactorTarget::new(alpha); - let single = alpha.reduce( - &os.constants - .iter() - .chain(&os.plonk_sigmas) - .chain(&os.wires) - .chain(&os.plonk_zs) - .chain(&os.partial_products) - .chain(&os.quotient_polys) - .copied() - .collect::>(), - builder, - ); - let zs_right = alpha.reduce(&os.plonk_zs_right, builder); + fn add_virtual_fri_query( + &mut self, + num_leaves_per_oracle: &[usize], + params: &FriParams, + ) -> FriQueryRoundTarget { + let cap_height = params.config.cap_height; + assert!(params.lde_bits() >= cap_height); + let mut merkle_proof_len = params.lde_bits() - cap_height; - let g = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); - let zeta_right = builder.mul_extension(g, zeta); + let initial_trees_proof = + self.add_virtual_fri_initial_trees_proof(num_leaves_per_oracle, merkle_proof_len); - Self { - single, - zs_right, - zeta_right, + let mut steps = vec![]; + for &arity_bits in ¶ms.reduction_arity_bits { + assert!(merkle_proof_len >= arity_bits); + merkle_proof_len -= arity_bits; + steps.push(self.add_virtual_fri_query_step(arity_bits, merkle_proof_len)); + } + + FriQueryRoundTarget { + initial_trees_proof, + steps, + } + } + + fn add_virtual_fri_initial_trees_proof( + &mut self, + num_leaves_per_oracle: &[usize], + initial_merkle_proof_len: usize, + ) -> FriInitialTreeProofTarget { + let evals_proofs = num_leaves_per_oracle + .iter() + .map(|&num_oracle_leaves| { + let leaves = self.add_virtual_targets(num_oracle_leaves); + let merkle_proof = self.add_virtual_merkle_proof(initial_merkle_proof_len); + (leaves, merkle_proof) + }) + .collect(); + FriInitialTreeProofTarget { evals_proofs } + } + + fn add_virtual_fri_query_step( + &mut self, + arity_bits: usize, + merkle_proof_len: usize, + ) -> FriQueryStepTarget { + FriQueryStepTarget { + evals: self.add_virtual_extension_targets(1 << arity_bits), + merkle_proof: self.add_virtual_merkle_proof(merkle_proof_len), + } + } +} + +/// For each opening point, holds the reduced (by `alpha`) evaluations of each polynomial that's +/// opened at that point. +#[derive(Clone)] +struct PrecomputedReducedOpeningsTarget { + reduced_openings_at_point: Vec>, +} + +impl PrecomputedReducedOpeningsTarget { + fn from_os_and_alpha>( + openings: &FriOpeningsTarget, + alpha: ExtensionTarget, + builder: &mut CircuitBuilder, + ) -> Self { + let reduced_openings_at_point = openings + .batches + .iter() + .map(|batch| ReducingFactorTarget::new(alpha).reduce(&batch.values, builder)) + .collect(); + Self { + reduced_openings_at_point, } } } diff --git a/plonky2/src/fri/reduction_strategies.rs b/plonky2/src/fri/reduction_strategies.rs index c0423c2c..49eda3ba 100644 --- a/plonky2/src/fri/reduction_strategies.rs +++ b/plonky2/src/fri/reduction_strategies.rs @@ -22,7 +22,7 @@ pub enum FriReductionStrategy { impl FriReductionStrategy { /// The arity of each FRI reduction step, expressed as the log2 of the actual arity. - pub(crate) fn reduction_arity_bits( + pub fn reduction_arity_bits( &self, mut degree_bits: usize, rate_bits: usize, diff --git a/plonky2/src/fri/structure.rs b/plonky2/src/fri/structure.rs new file mode 100644 index 00000000..240abd5d --- /dev/null +++ b/plonky2/src/fri/structure.rs @@ -0,0 +1,83 @@ +//! Information about the structure of a FRI instance, in terms of the oracles and polynomials +//! involved, and the points they are opened at. + +use std::ops::Range; + +use crate::field::extension_field::Extendable; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; + +/// Describes an instance of a FRI-based batch opening. +pub struct FriInstanceInfo, const D: usize> { + /// The oracles involved, not counting oracles created during the commit phase. + pub oracles: Vec, + /// Batches of openings, where each batch is associated with a particular point. + pub batches: Vec>, +} + +/// Describes an instance of a FRI-based batch opening. +pub struct FriInstanceInfoTarget { + /// The oracles involved, not counting oracles created during the commit phase. + pub oracles: Vec, + /// Batches of openings, where each batch is associated with a particular point. + pub batches: Vec>, +} + +#[derive(Copy, Clone)] +pub struct FriOracleInfo { + pub blinding: bool, +} + +/// A batch of openings at a particular point. +pub struct FriBatchInfo, const D: usize> { + pub point: F::Extension, + pub polynomials: Vec, +} + +/// A batch of openings at a particular point. +pub struct FriBatchInfoTarget { + pub point: ExtensionTarget, + pub polynomials: Vec, +} + +#[derive(Copy, Clone, Debug)] +pub struct FriPolynomialInfo { + /// Index into `FriInstanceInfoTarget`'s `oracles` list. + pub oracle_index: usize, + /// Index of the polynomial within the oracle. + pub polynomial_index: usize, +} + +impl FriPolynomialInfo { + pub fn from_range( + oracle_index: usize, + polynomial_indices: Range, + ) -> Vec { + polynomial_indices + .map(|polynomial_index| FriPolynomialInfo { + oracle_index, + polynomial_index, + }) + .collect() + } +} + +/// Opened values of each polynomial. +pub struct FriOpenings, const D: usize> { + pub batches: Vec>, +} + +/// Opened values of each polynomial that's opened at a particular point. +pub struct FriOpeningBatch, const D: usize> { + pub values: Vec, +} + +/// Opened values of each polynomial. +pub struct FriOpeningsTarget { + pub batches: Vec>, +} + +/// Opened values of each polynomial that's opened at a particular point. +pub struct FriOpeningBatchTarget { + pub values: Vec>, +} diff --git a/plonky2/src/fri/verifier.rs b/plonky2/src/fri/verifier.rs index 4c14f32a..2607ab0d 100644 --- a/plonky2/src/fri/verifier.rs +++ b/plonky2/src/fri/verifier.rs @@ -4,15 +4,13 @@ use plonky2_field::field_types::Field; use plonky2_field::interpolation::{barycentric_weights, interpolate}; use plonky2_util::{log2_strict, reverse_index_bits_in_place}; -use crate::fri::proof::{FriInitialTreeProof, FriProof, FriQueryRound}; -use crate::fri::FriConfig; +use crate::fri::proof::{FriChallenges, FriInitialTreeProof, FriProof, FriQueryRound}; +use crate::fri::structure::{FriBatchInfo, FriInstanceInfo, FriOpenings}; +use crate::fri::{FriConfig, FriParams}; use crate::hash::hash_types::RichField; use crate::hash::merkle_proofs::verify_merkle_proof; use crate::hash::merkle_tree::MerkleCap; -use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::{GenericConfig, Hasher}; -use crate::plonk::plonk_common::PlonkPolynomials; -use crate::plonk::proof::{OpeningSet, ProofChallenges}; use crate::util::reducing::ReducingFactor; use crate::util::reverse_bits; @@ -58,52 +56,51 @@ pub(crate) fn fri_verify_proof_of_work, const D: us Ok(()) } -pub(crate) fn verify_fri_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, ->( - // Openings of the PLONK polynomials. - os: &OpeningSet, - challenges: &ProofChallenges, +pub fn verify_fri_proof, C: GenericConfig, const D: usize>( + instance: &FriInstanceInfo, + openings: &FriOpenings, + challenges: &FriChallenges, initial_merkle_caps: &[MerkleCap], proof: &FriProof, - common_data: &CommonCircuitData, -) -> Result<()> { - let config = &common_data.config; + params: &FriParams, +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ ensure!( - common_data.fri_params.final_poly_len() == proof.final_poly.len(), + params.final_poly_len() == proof.final_poly.len(), "Final polynomial has wrong degree." ); // Size of the LDE domain. - let n = common_data.lde_size(); + let n = params.lde_size(); // Check PoW. - fri_verify_proof_of_work(challenges.fri_pow_response, &config.fri_config)?; + fri_verify_proof_of_work(challenges.fri_pow_response, ¶ms.config)?; // Check that parameters are coherent. ensure!( - config.fri_config.num_query_rounds == proof.query_round_proofs.len(), + params.config.num_query_rounds == proof.query_round_proofs.len(), "Number of query rounds does not match config." ); let precomputed_reduced_evals = - PrecomputedReducedEvals::from_os_and_alpha(os, challenges.fri_alpha); + PrecomputedReducedOpenings::from_os_and_alpha(openings, challenges.fri_alpha); for (&x_index, round_proof) in challenges .fri_query_indices .iter() .zip(&proof.query_round_proofs) { fri_verifier_query_round::( + instance, challenges, - precomputed_reduced_evals, + &precomputed_reduced_evals, initial_merkle_caps, proof, x_index, n, round_proof, - common_data, + params, )?; } @@ -114,7 +111,10 @@ fn fri_verify_initial_proof>( x_index: usize, proof: &FriInitialTreeProof, initial_merkle_caps: &[MerkleCap], -) -> Result<()> { +) -> Result<()> +where + [(); H::HASH_SIZE]:, +{ for ((evals, merkle_proof), cap) in proof.evals_proofs.iter().zip(initial_merkle_caps) { verify_merkle_proof::(evals.clone(), x_index, cap, merkle_proof)?; } @@ -127,51 +127,42 @@ pub(crate) fn fri_combine_initial< C: GenericConfig, const D: usize, >( + instance: &FriInstanceInfo, proof: &FriInitialTreeProof, alpha: F::Extension, - zeta: F::Extension, subgroup_x: F, - precomputed_reduced_evals: PrecomputedReducedEvals, - common_data: &CommonCircuitData, + precomputed_reduced_evals: &PrecomputedReducedOpenings, + params: &FriParams, ) -> F::Extension { - let config = &common_data.config; assert!(D > 1, "Not implemented for D=1."); - let degree_log = common_data.degree_bits; let subgroup_x = F::Extension::from_basefield(subgroup_x); let mut alpha = ReducingFactor::new(alpha); let mut sum = F::Extension::ZERO; - // We will add two terms to `sum`: one for openings at `x`, and one for openings at `g x`. - // All polynomials are opened at `x`. - let single_evals = [ - PlonkPolynomials::CONSTANTS_SIGMAS, - PlonkPolynomials::WIRES, - PlonkPolynomials::ZS_PARTIAL_PRODUCTS, - PlonkPolynomials::QUOTIENT, - ] - .iter() - .flat_map(|&p| proof.unsalted_evals(p, config.zero_knowledge)) - .map(|&e| F::Extension::from_basefield(e)); - let single_composition_eval = alpha.reduce(single_evals); - let single_numerator = single_composition_eval - precomputed_reduced_evals.single; - let single_denominator = subgroup_x - zeta; - sum += single_numerator / single_denominator; - alpha.reset(); - - // Z polynomials have an additional opening at `g x`. - let zs_evals = proof - .unsalted_evals(PlonkPolynomials::ZS_PARTIAL_PRODUCTS, config.zero_knowledge) + for (batch, reduced_openings) in instance + .batches .iter() - .map(|&e| F::Extension::from_basefield(e)) - .take(common_data.zs_range().end); - let zs_composition_eval = alpha.reduce(zs_evals); - let zeta_right = F::Extension::primitive_root_of_unity(degree_log) * zeta; - let zs_numerator = zs_composition_eval - precomputed_reduced_evals.zs_right; - let zs_denominator = subgroup_x - zeta_right; - sum = alpha.shift(sum); - sum += zs_numerator / zs_denominator; + .zip(&precomputed_reduced_evals.reduced_openings_at_point) + { + let FriBatchInfo { point, polynomials } = batch; + let evals = polynomials + .iter() + .map(|p| { + let poly_blinding = instance.oracles[p.oracle_index].blinding; + let salted = params.hiding && poly_blinding; + proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted) + }) + .map(F::Extension::from_basefield); + let reduced_evals = alpha.reduce(evals); + let numerator = reduced_evals - *reduced_openings; + let denominator = subgroup_x - *point; + sum = alpha.shift(sum); + sum += numerator / denominator; + } - sum + // Multiply the final polynomial by `X`, so that `final_poly` has the maximum degree for + // which the LDT will pass. See github.com/mir-protocol/plonky2/pull/436 for details. + sum * subgroup_x } fn fri_verifier_query_round< @@ -179,15 +170,19 @@ fn fri_verifier_query_round< C: GenericConfig, const D: usize, >( - challenges: &ProofChallenges, - precomputed_reduced_evals: PrecomputedReducedEvals, + instance: &FriInstanceInfo, + challenges: &FriChallenges, + precomputed_reduced_evals: &PrecomputedReducedOpenings, initial_merkle_caps: &[MerkleCap], proof: &FriProof, mut x_index: usize, n: usize, round_proof: &FriQueryRound, - common_data: &CommonCircuitData, -) -> Result<()> { + params: &FriParams, +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ fri_verify_initial_proof::( x_index, &round_proof.initial_trees_proof, @@ -200,21 +195,16 @@ fn fri_verifier_query_round< // old_eval is the last derived evaluation; it will be checked for consistency with its // committed "parent" value in the next iteration. - let mut old_eval = fri_combine_initial( + let mut old_eval = fri_combine_initial::( + instance, &round_proof.initial_trees_proof, challenges.fri_alpha, - challenges.plonk_zeta, subgroup_x, precomputed_reduced_evals, - common_data, + params, ); - for (i, &arity_bits) in common_data - .fri_params - .reduction_arity_bits - .iter() - .enumerate() - { + for (i, &arity_bits) in params.reduction_arity_bits.iter().enumerate() { let arity = 1 << arity_bits; let evals = &round_proof.steps[i].evals; @@ -257,28 +247,22 @@ fn fri_verifier_query_round< Ok(()) } -/// Holds the reduced (by `alpha`) evaluations at `zeta` for the polynomial opened just at -/// zeta, for `Z` at zeta and for `Z` at `g*zeta`. -#[derive(Copy, Clone, Debug)] -pub(crate) struct PrecomputedReducedEvals, const D: usize> { - pub single: F::Extension, - pub zs_right: F::Extension, +/// For each opening point, holds the reduced (by `alpha`) evaluations of each polynomial that's +/// opened at that point. +#[derive(Clone, Debug)] +pub(crate) struct PrecomputedReducedOpenings, const D: usize> { + pub reduced_openings_at_point: Vec, } -impl, const D: usize> PrecomputedReducedEvals { - pub(crate) fn from_os_and_alpha(os: &OpeningSet, alpha: F::Extension) -> Self { - let mut alpha = ReducingFactor::new(alpha); - let single = alpha.reduce( - os.constants - .iter() - .chain(&os.plonk_sigmas) - .chain(&os.wires) - .chain(&os.plonk_zs) - .chain(&os.partial_products) - .chain(&os.quotient_polys), - ); - let zs_right = alpha.reduce(os.plonk_zs_right.iter()); - - Self { single, zs_right } +impl, const D: usize> PrecomputedReducedOpenings { + pub(crate) fn from_os_and_alpha(openings: &FriOpenings, alpha: F::Extension) -> Self { + let reduced_openings_at_point = openings + .batches + .iter() + .map(|batch| ReducingFactor::new(alpha).reduce(batch.values.iter())) + .collect(); + Self { + reduced_openings_at_point, + } } } diff --git a/plonky2/src/fri/witness_util.rs b/plonky2/src/fri/witness_util.rs new file mode 100644 index 00000000..741f839d --- /dev/null +++ b/plonky2/src/fri/witness_util.rs @@ -0,0 +1,71 @@ +use itertools::Itertools; +use plonky2_field::extension_field::Extendable; + +use crate::fri::proof::{FriProof, FriProofTarget}; +use crate::hash::hash_types::RichField; +use crate::iop::witness::Witness; +use crate::plonk::config::AlgebraicHasher; + +/// Set the targets in a `FriProofTarget` to their corresponding values in a `FriProof`. +pub fn set_fri_proof_target( + witness: &mut W, + fri_proof_target: &FriProofTarget, + fri_proof: &FriProof, +) where + F: RichField + Extendable, + W: Witness + ?Sized, + H: AlgebraicHasher, +{ + witness.set_target(fri_proof_target.pow_witness, fri_proof.pow_witness); + + for (&t, &x) in fri_proof_target + .final_poly + .0 + .iter() + .zip_eq(&fri_proof.final_poly.coeffs) + { + witness.set_extension_target(t, x); + } + + for (t, x) in fri_proof_target + .commit_phase_merkle_caps + .iter() + .zip_eq(&fri_proof.commit_phase_merkle_caps) + { + witness.set_cap_target(t, x); + } + + for (qt, q) in fri_proof_target + .query_round_proofs + .iter() + .zip_eq(&fri_proof.query_round_proofs) + { + for (at, a) in qt + .initial_trees_proof + .evals_proofs + .iter() + .zip_eq(&q.initial_trees_proof.evals_proofs) + { + for (&t, &x) in at.0.iter().zip_eq(&a.0) { + witness.set_target(t, x); + } + for (&t, &x) in at.1.siblings.iter().zip_eq(&a.1.siblings) { + witness.set_hash_target(t, x); + } + } + + for (st, s) in qt.steps.iter().zip_eq(&q.steps) { + for (&t, &x) in st.evals.iter().zip_eq(&s.evals) { + witness.set_extension_target(t, x); + } + for (&t, &x) in st + .merkle_proof + .siblings + .iter() + .zip_eq(&s.merkle_proof.siblings) + { + witness.set_hash_target(t, x); + } + } + } +} diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 734e2705..d8dbaf22 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -1,7 +1,7 @@ use std::borrow::Borrow; use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::PrimeField; +use plonky2_field::field_types::Field64; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; @@ -317,11 +317,17 @@ impl, const D: usize> CircuitBuilder { let x_ext = self.convert_to_ext(x); self.inverse_extension(x_ext).0[0] } + + pub fn not(&mut self, b: BoolTarget) -> BoolTarget { + let one = self.one(); + let res = self.sub(one, b.target); + BoolTarget::new_unsafe(res) + } } /// Represents a base arithmetic operation in the circuit. Used to memoize results. #[derive(Copy, Clone, Eq, PartialEq, Hash)] -pub(crate) struct BaseArithmeticOperation { +pub(crate) struct BaseArithmeticOperation { const_0: F, const_1: F, multiplicand_0: Target, diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index d29f8f8f..ea3e8b13 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -1,6 +1,6 @@ use plonky2_field::extension_field::FieldExtension; use plonky2_field::extension_field::{Extendable, OEF}; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::{Field, Field64}; use plonky2_util::bits_u64; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; @@ -548,7 +548,7 @@ impl, const D: usize> CircuitBuilder { /// Represents an extension arithmetic operation in the circuit. Used to memoize results. #[derive(Copy, Clone, Eq, PartialEq, Hash)] -pub(crate) struct ExtensionArithmeticOperation, const D: usize> { +pub(crate) struct ExtensionArithmeticOperation, const D: usize> { const_0: F, const_1: F, multiplicand_0: ExtensionTarget, diff --git a/plonky2/src/gadgets/arithmetic_u32.rs b/plonky2/src/gadgets/arithmetic_u32.rs index dfdbb5fb..649c0624 100644 --- a/plonky2/src/gadgets/arithmetic_u32.rs +++ b/plonky2/src/gadgets/arithmetic_u32.rs @@ -1,9 +1,14 @@ +use std::marker::PhantomData; + use plonky2_field::extension_field::Extendable; +use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; use crate::hash::hash_types::RichField; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Copy, Debug)] @@ -113,18 +118,57 @@ impl, const D: usize> CircuitBuilder { 1 => (to_add[0], self.zero_u32()), 2 => self.add_u32(to_add[0], to_add[1]), _ => { - let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]); - for i in 2..to_add.len() { - let (new_low, new_carry) = self.add_u32(to_add[i], low); - let (combined_carry, _zero) = self.add_u32(carry, new_carry); - low = new_low; - carry = combined_carry; + let num_addends = to_add.len(); + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); + + for j in 0..num_addends { + self.connect( + Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), + to_add[j].0, + ); } - (low, carry) + let zero = self.zero(); + self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero); + + let output_low = + U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_high = + U32Target(Target::wire(gate_index, gate.wire_ith_output_carry(copy))); + + (output_low, output_high) } } } + pub fn add_u32s_with_carry( + &mut self, + to_add: &[U32Target], + carry: U32Target, + ) -> (U32Target, U32Target) { + if to_add.len() == 1 { + return self.add_u32(to_add[0], carry); + } + + let num_addends = to_add.len(); + + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); + + for j in 0..num_addends { + self.connect( + Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), + to_add[j].0, + ); + } + self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), carry.0); + + let output = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_carry = U32Target(Target::wire(gate_index, gate.wire_ith_output_carry(copy))); + + (output, output_carry) + } + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { let zero = self.zero_u32(); self.mul_add_u32(a, b, zero) @@ -153,3 +197,75 @@ impl, const D: usize> CircuitBuilder { (output_result, output_borrow) } } + +#[derive(Debug)] +struct SplitToU32Generator, const D: usize> { + x: Target, + low: U32Target, + high: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for SplitToU32Generator +{ + fn dependencies(&self) -> Vec { + vec![self.x] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x); + let x_u64 = x.to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 >> 32) as u32; + + out_buffer.set_u32_target(self.low, low); + out_buffer.set_u32_target(self.high, high); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::{thread_rng, Rng}; + + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + pub fn test_add_many_u32s() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + const NUM_ADDENDS: usize = 15; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = thread_rng(); + let mut to_add = Vec::new(); + let mut sum = 0u64; + for _ in 0..NUM_ADDENDS { + let x: u32 = rng.gen(); + sum += x as u64; + to_add.push(builder.constant_u32(x)); + } + let carry = builder.zero_u32(); + let (result_low, result_high) = builder.add_u32s_with_carry(&to_add, carry); + let expected_low = builder.constant_u32((sum % (1 << 32)) as u32); + let expected_high = builder.constant_u32((sum >> 32) as u32); + + builder.connect_u32(result_low, expected_low); + builder.connect_u32(result_high, expected_high); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/biguint.rs b/plonky2/src/gadgets/biguint.rs index 77013d27..c9ad7280 100644 --- a/plonky2/src/gadgets/biguint.rs +++ b/plonky2/src/gadgets/biguint.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use num::{BigUint, Integer}; +use num::{BigUint, Integer, Zero}; use plonky2_field::extension_field::Extendable; use crate::gadgets::arithmetic_u32::U32Target; @@ -33,6 +33,10 @@ impl, const D: usize> CircuitBuilder { BigUintTarget { limbs } } + pub fn zero_biguint(&mut self) -> BigUintTarget { + self.constant_biguint(&BigUint::zero()) + } + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { @@ -76,9 +80,7 @@ impl, const D: usize> CircuitBuilder { } pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { - let limbs = (0..num_limbs) - .map(|_| self.add_virtual_u32_target()) - .collect(); + let limbs = self.add_virtual_u32_targets(num_limbs); BigUintTarget { limbs } } @@ -143,8 +145,7 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); for summands in &mut to_add { - summands.push(carry); - let (new_result, new_carry) = self.add_many_u32(summands); + let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); combined_limbs.push(new_result); carry = new_carry; } @@ -155,6 +156,18 @@ impl, const D: usize> CircuitBuilder { } } + pub fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { + let t = b.target; + + BigUintTarget { + limbs: a + .limbs + .iter() + .map(|&l| U32Target(self.mul(l.0, t))) + .collect(), + } + } + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). pub fn mul_add_biguint( &mut self, diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 63e96721..8c182345 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -104,29 +104,17 @@ impl, const D: usize> CircuitBuilder { let AffinePointTarget { x: x2, y: y2 } = p2; let u = self.sub_nonnative(y2, y1); - let uu = self.mul_nonnative(&u, &u); let v = self.sub_nonnative(x2, x1); - let vv = self.mul_nonnative(&v, &v); - let vvv = self.mul_nonnative(&v, &vv); - let r = self.mul_nonnative(&vv, x1); - let diff = self.sub_nonnative(&uu, &vvv); - let r2 = self.add_nonnative(&r, &r); - let a = self.sub_nonnative(&diff, &r2); - let x3 = self.mul_nonnative(&v, &a); + let v_inv = self.inv_nonnative(&v); + let s = self.mul_nonnative(&u, &v_inv); + let s_squared = self.mul_nonnative(&s, &s); + let x_sum = self.add_nonnative(x2, x1); + let x3 = self.sub_nonnative(&s_squared, &x_sum); + let x_diff = self.sub_nonnative(x1, &x3); + let prod = self.mul_nonnative(&s, &x_diff); + let y3 = self.sub_nonnative(&prod, y1); - let r_a = self.sub_nonnative(&r, &a); - let y3_first = self.mul_nonnative(&u, &r_a); - let y3_second = self.mul_nonnative(&vvv, y1); - let y3 = self.sub_nonnative(&y3_first, &y3_second); - - let z3_inv = self.inv_nonnative(&vvv); - let x3_norm = self.mul_nonnative(&x3, &z3_inv); - let y3_norm = self.mul_nonnative(&y3, &z3_inv); - - AffinePointTarget { - x: x3_norm, - y: y3_norm, - } + AffinePointTarget { x: x3, y: y3 } } pub fn curve_scalar_mul( @@ -134,11 +122,7 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, n: &NonNativeTarget, ) -> AffinePointTarget { - let one = self.constant_nonnative(C::BaseField::ONE); - let bits = self.split_nonnative_to_bits(n); - let bits_as_base: Vec> = - bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); let randot = self.constant_affine_point(rando); @@ -149,15 +133,15 @@ impl, const D: usize> CircuitBuilder { let mut two_i_times_p = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &two_i_times_p); - for bit in bits_as_base.iter() { - let not_bit = self.sub_nonnative(&one, bit); + for &bit in bits.iter() { + let not_bit = self.not(bit); let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); - let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x); - let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result.x); - let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y); - let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result.y); + let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit); + let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit); + let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit); + let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit); let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); @@ -177,6 +161,8 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use std::ops::Neg; + use anyhow::Result; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; @@ -196,7 +182,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -221,7 +207,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -248,7 +234,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -285,7 +271,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -316,27 +302,25 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig { - num_routed_wires: 33, - ..CircuitConfig::standard_recursion_config() - }; + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let g = Secp256K1::GENERATOR_AFFINE; let five = Secp256K1Scalar::from_canonical_usize(5); - let five_scalar = CurveScalar::(five); - let five_g = (five_scalar * g.to_projective()).to_affine(); - let five_g_expected = builder.constant_affine_point(five_g); - builder.curve_assert_valid(&five_g_expected); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); let g_target = builder.constant_affine_point(g); - let five_target = builder.constant_nonnative(five); - let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); - builder.curve_assert_valid(&five_g_actual); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); - builder.connect_affine_point(&five_g_expected, &five_g_actual); + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); let data = builder.build::(); let proof = data.prove(pw).unwrap(); @@ -345,16 +329,12 @@ mod tests { } #[test] - #[ignore] fn test_curve_random() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig { - num_routed_wires: 33, - ..CircuitConfig::standard_recursion_config() - }; + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs new file mode 100644 index 00000000..0a95e189 --- /dev/null +++ b/plonky2/src/gadgets/ecdsa.rs @@ -0,0 +1,104 @@ +use std::marker::PhantomData; + +use crate::curve::curve_types::Curve; +use crate::field::extension_field::Extendable; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::plonk::circuit_builder::CircuitBuilder; + +#[derive(Clone, Debug)] +pub struct ECDSASecretKeyTarget(NonNativeTarget); + +#[derive(Clone, Debug)] +pub struct ECDSAPublicKeyTarget(AffinePointTarget); + +#[derive(Clone, Debug)] +pub struct ECDSASignatureTarget { + pub r: NonNativeTarget, + pub s: NonNativeTarget, +} + +impl, const D: usize> CircuitBuilder { + pub fn verify_message( + &mut self, + msg: NonNativeTarget, + sig: ECDSASignatureTarget, + pk: ECDSAPublicKeyTarget, + ) { + let ECDSASignatureTarget { r, s } = sig; + + self.curve_assert_valid(&pk.0); + + let c = self.inv_nonnative(&s); + let u1 = self.mul_nonnative(&msg, &c); + let u2 = self.mul_nonnative(&r, &c); + + let g = self.constant_affine_point(C::GENERATOR_AFFINE); + let point1 = self.curve_scalar_mul(&g, &u1); + let point2 = self.curve_scalar_mul(&pk.0, &u2); + let point = self.curve_add(&point1, &point2); + + let x = NonNativeTarget:: { + value: point.x.value, + _phantom: PhantomData, + }; + self.connect_nonnative(&r, &x); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + use crate::gadgets::ecdsa::{ECDSAPublicKeyTarget, ECDSASignatureTarget}; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + #[ignore] + fn test_ecdsa_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + type Curve = Secp256K1; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let msg = Secp256K1Scalar::rand(); + let msg_target = builder.constant_nonnative(msg); + + let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); + + let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); + + let sig = sign_message(msg, sk); + + let ECDSASignature { r, s } = sig; + let r_target = builder.constant_nonnative(r); + let s_target = builder.constant_nonnative(s); + let sig_target = ECDSASignatureTarget { + r: r_target, + s: s_target, + }; + + builder.verify_message(msg_target, sig_target, pk_target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index b73e2a7f..ec4d1263 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; +pub mod ecdsa; pub mod hash; pub mod interpolation; pub mod multiple_comparison; diff --git a/plonky2/src/gadgets/multiple_comparison.rs b/plonky2/src/gadgets/multiple_comparison.rs index 7d637d87..112b0113 100644 --- a/plonky2/src/gadgets/multiple_comparison.rs +++ b/plonky2/src/gadgets/multiple_comparison.rs @@ -60,8 +60,9 @@ impl, const D: usize> CircuitBuilder { /// Helper function for comparing, specifically, lists of `U32Target`s. pub fn list_le_u32(&mut self, a: Vec, b: Vec) -> BoolTarget { - let a_targets = a.iter().map(|&t| t.0).collect(); - let b_targets = b.iter().map(|&t| t.0).collect(); + let a_targets: Vec = a.iter().map(|&t| t.0).collect(); + let b_targets: Vec = b.iter().map(|&t| t.0).collect(); + self.list_le(a_targets, b_targets, 32) } } diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 16fd022e..3f8d29e8 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; -use num::{BigUint, Zero}; +use num::{BigUint, Integer, One, Zero}; +use plonky2_field::field_types::PrimeField; use plonky2_field::{extension_field::Extendable, field_types::Field}; use plonky2_util::ceil_div_usize; @@ -15,7 +16,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Debug)] pub struct NonNativeTarget { pub(crate) value: BigUintTarget, - _phantom: PhantomData, + pub(crate) _phantom: PhantomData, } impl, const D: usize> CircuitBuilder { @@ -34,11 +35,15 @@ impl, const D: usize> CircuitBuilder { x.value.clone() } - pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { - let x_biguint = self.constant_biguint(&x.to_biguint()); + pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { + let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); self.biguint_to_nonnative(&x_biguint) } + pub fn zero_nonnative(&mut self) -> NonNativeTarget { + self.constant_nonnative(FF::ZERO) + } + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. pub fn connect_nonnative( &mut self, @@ -58,82 +63,204 @@ impl, const D: usize> CircuitBuilder { } } - // Add two `NonNativeTarget`s. - pub fn add_nonnative( + pub fn add_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let result = self.add_biguint(&a.value, &b.value); + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); - // TODO: reduce add result with only one conditional subtraction - self.reduce(&result) + self.add_simple_generator(NonNativeAdditionGenerator:: { + a: a.clone(), + b: b.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + let sum_expected = self.add_biguint(&a.value, &b.value); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum + } + + pub fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + NonNativeTarget { + value: self.mul_biguint_by_bool(&a.value, b), + _phantom: PhantomData, + } + } + + pub fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_add.len() == 1 { + return to_add[0].clone(); + } + + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_u32_target(); + let summands = to_add.to_vec(); + + self.add_simple_generator(NonNativeMultipleAddsGenerator:: { + summands: summands.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + self.range_check_u32(sum.value.limbs.clone()); + self.range_check_u32(vec![overflow]); + + let sum_expected = summands + .iter() + .fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); + + let modulus = self.constant_biguint(&FF::order()); + let overflow_biguint = BigUintTarget { + limbs: vec![overflow], + }; + let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum } // Subtract two `NonNativeTarget`s. - pub fn sub_nonnative( + pub fn sub_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let order = self.constant_biguint(&FF::order()); - let a_plus_order = self.add_biguint(&order, &a.value); - let result = self.sub_biguint(&a_plus_order, &b.value); + let diff = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); - // TODO: reduce sub result with only one conditional addition? - self.reduce(&result) + self.add_simple_generator(NonNativeSubtractionGenerator:: { + a: a.clone(), + b: b.clone(), + diff: diff.clone(), + overflow, + _phantom: PhantomData, + }); + + self.range_check_u32(diff.value.limbs.clone()); + self.assert_bool(overflow); + + let diff_plus_b = self.add_biguint(&diff.value, &b.value); + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); + self.connect_biguint(&a.value, &diff_plus_b_reduced); + + diff } - pub fn mul_nonnative( + pub fn mul_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let result = self.mul_biguint(&a.value, &b.value); + let prod = self.add_virtual_nonnative_target::(); + let modulus = self.constant_biguint(&FF::order()); + let overflow = self.add_virtual_biguint_target( + a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(), + ); - self.reduce(&result) + self.add_simple_generator(NonNativeMultiplicationGenerator:: { + a: a.clone(), + b: b.clone(), + prod: prod.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + self.range_check_u32(prod.value.limbs.clone()); + self.range_check_u32(overflow.limbs.clone()); + + let prod_expected = self.mul_biguint(&a.value, &b.value); + + let mod_times_overflow = self.mul_biguint(&modulus, &overflow); + let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); + self.connect_biguint(&prod_expected, &prod_actual); + + prod } - pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_mul.len() == 1 { + return to_mul[0].clone(); + } + + let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); + for i in 2..to_mul.len() { + accumulator = self.mul_nonnative(&accumulator, &to_mul[i]); + } + accumulator + } + + pub fn neg_nonnative( + &mut self, + x: &NonNativeTarget, + ) -> NonNativeTarget { let zero_target = self.constant_biguint(&BigUint::zero()); let zero_ff = self.biguint_to_nonnative(&zero_target); self.sub_nonnative(&zero_ff, x) } - pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn inv_nonnative( + &mut self, + x: &NonNativeTarget, + ) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); - let inv = NonNativeTarget:: { - value: inv_biguint, - _phantom: PhantomData, - }; + let div = self.add_virtual_biguint_target(num_limbs); self.add_simple_generator(NonNativeInverseGenerator:: { x: x.clone(), - inv: inv.clone(), + inv: inv_biguint.clone(), + div: div.clone(), _phantom: PhantomData, }); - let product = self.mul_nonnative(x, &inv); - let one = self.constant_nonnative(FF::ONE); - self.connect_nonnative(&product, &one); + let product = self.mul_biguint(&x.value, &inv_biguint); - inv - } + let modulus = self.constant_biguint(&FF::order()); + let mod_times_div = self.mul_biguint(&modulus, &div); + let one = self.constant_biguint(&BigUint::one()); + let expected_product = self.add_biguint(&mod_times_div, &one); + self.connect_biguint(&product, &expected_product); - pub fn div_rem_nonnative( - &mut self, - x: &NonNativeTarget, - y: &NonNativeTarget, - ) -> (NonNativeTarget, NonNativeTarget) { - let x_biguint = self.nonnative_to_biguint(x); - let y_biguint = self.nonnative_to_biguint(y); - - let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint); - let div = self.biguint_to_nonnative(&div_biguint); - let rem = self.biguint_to_nonnative(&rem_biguint); - (div, rem) + NonNativeTarget:: { + value: inv_biguint, + _phantom: PhantomData, + } } /// Returns `x % |FF|` as a `NonNativeTarget`. @@ -148,8 +275,7 @@ impl, const D: usize> CircuitBuilder { } } - #[allow(dead_code)] - fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } @@ -188,13 +314,178 @@ impl, const D: usize> CircuitBuilder { } #[derive(Debug)] -struct NonNativeInverseGenerator, const D: usize, FF: Field> { - x: NonNativeTarget, - inv: NonNativeTarget, +struct NonNativeAdditionGenerator, const D: usize, FF: PrimeField> { + a: NonNativeTarget, + b: NonNativeTarget, + sum: NonNativeTarget, + overflow: BoolTarget, _phantom: PhantomData, } -impl, const D: usize, FF: Field> SimpleGenerator +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeAdditionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + let sum_biguint = a_biguint + b_biguint; + let modulus = FF::order(); + let (overflow, sum_reduced) = if sum_biguint > modulus { + (true, sum_biguint - modulus) + } else { + (false, sum_biguint) + }; + + out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultipleAddsGenerator, const D: usize, FF: PrimeField> +{ + summands: Vec>, + sum: NonNativeTarget, + overflow: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeMultipleAddsGenerator +{ + fn dependencies(&self) -> Vec { + self.summands + .iter() + .flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let summands: Vec<_> = self + .summands + .iter() + .map(|summand| witness.get_nonnative_target(summand.clone())) + .collect(); + let summand_biguints: Vec<_> = summands + .iter() + .map(|summand| summand.to_canonical_biguint()) + .collect(); + + let sum_biguint = summand_biguints + .iter() + .fold(BigUint::zero(), |a, b| a + b.clone()); + + let modulus = FF::order(); + let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); + let overflow = overflow_biguint.to_u64_digits()[0] as u32; + + out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + diff: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeSubtractionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + + let modulus = FF::order(); + let (diff_biguint, overflow) = if a_biguint > b_biguint { + (a_biguint - b_biguint, false) + } else { + (modulus + a_biguint - b_biguint, true) + }; + + out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + prod: NonNativeTarget, + overflow: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeMultiplicationGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + + let prod_biguint = a_biguint * b_biguint; + + let modulus = FF::order(); + let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); + + out_buffer.set_biguint_target(self.prod.value.clone(), prod_reduced); + out_buffer.set_biguint_target(self.overflow.clone(), overflow_biguint); + } +} + +#[derive(Debug)] +struct NonNativeInverseGenerator, const D: usize, FF: PrimeField> { + x: NonNativeTarget, + inv: BigUintTarget, + div: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator for NonNativeInverseGenerator { fn dependencies(&self) -> Vec { @@ -205,14 +496,21 @@ impl, const D: usize, FF: Field> SimpleGenerator let x = witness.get_nonnative_target(self.x.clone()); let inv = x.inverse(); - out_buffer.set_nonnative_target(self.inv.clone(), inv); + let x_biguint = x.to_canonical_biguint(); + let inv_biguint = inv.to_canonical_biguint(); + let prod = x_biguint * &inv_biguint; + let modulus = FF::order(); + let (div, _rem) = prod.div_rem(&modulus); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.inv.clone(), inv_biguint); } } #[cfg(test)] mod tests { use anyhow::Result; - use plonky2_field::field_types::Field; + use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::secp256k1_base::Secp256K1Base; use crate::iop::witness::PartialWitness; @@ -227,11 +525,12 @@ mod tests { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; + let x_ff = FF::rand(); let y_ff = FF::rand(); let sum_ff = x_ff + y_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -247,20 +546,61 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + #[test] + fn test_nonnative_many_adds() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let a_ff = FF::rand(); + let b_ff = FF::rand(); + let c_ff = FF::rand(); + let d_ff = FF::rand(); + let e_ff = FF::rand(); + let f_ff = FF::rand(); + let g_ff = FF::rand(); + let h_ff = FF::rand(); + let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let a = builder.constant_nonnative(a_ff); + let b = builder.constant_nonnative(b_ff); + let c = builder.constant_nonnative(c_ff); + let d = builder.constant_nonnative(d_ff); + let e = builder.constant_nonnative(e_ff); + let f = builder.constant_nonnative(f_ff); + let g = builder.constant_nonnative(g_ff); + let h = builder.constant_nonnative(h_ff); + let all = [a, b, c, d, e, f, g, h]; + let sum = builder.add_many_nonnative(&all); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + #[test] fn test_nonnative_sub() -> Result<()> { type FF = Secp256K1Base; const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; + let x_ff = FF::rand(); let mut y_ff = FF::rand(); - while y_ff.to_biguint() > x_ff.to_biguint() { + while y_ff.to_canonical_biguint() > x_ff.to_canonical_biguint() { y_ff = FF::rand(); } let diff_ff = x_ff - y_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -286,7 +626,7 @@ mod tests { let y_ff = FF::rand(); let product_ff = x_ff * y_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -311,7 +651,7 @@ mod tests { let x_ff = FF::rand(); let neg_x_ff = -x_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -335,7 +675,7 @@ mod tests { let x_ff = FF::rand(); let inv_x_ff = x_ff.inverse(); - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/plonky2/src/gadgets/polynomial.rs b/plonky2/src/gadgets/polynomial.rs index 195eabd3..6e4a9bb4 100644 --- a/plonky2/src/gadgets/polynomial.rs +++ b/plonky2/src/gadgets/polynomial.rs @@ -6,6 +6,7 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::reducing::ReducingFactorTarget; +#[derive(Debug)] pub struct PolynomialCoeffsExtTarget(pub Vec>); impl PolynomialCoeffsExtTarget { diff --git a/plonky2/src/gadgets/range_check.rs b/plonky2/src/gadgets/range_check.rs index f8ada106..0776fc68 100644 --- a/plonky2/src/gadgets/range_check.rs +++ b/plonky2/src/gadgets/range_check.rs @@ -1,5 +1,7 @@ use plonky2_field::extension_field::Extendable; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gates::range_check_u32::U32RangeCheckGate; use crate::hash::hash_types::RichField; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; @@ -41,6 +43,25 @@ impl, const D: usize> CircuitBuilder { (low, high) } + + pub fn range_check_u32(&mut self, vals: Vec) { + let num_input_limbs = vals.len(); + let gate = U32RangeCheckGate::::new(num_input_limbs); + let gate_index = self.add_gate(gate, vec![]); + + for i in 0..num_input_limbs { + self.connect( + Target::wire(gate_index, gate.wire_ith_input_limb(i)), + vals[i].0, + ); + } + } + + pub fn assert_bool(&mut self, b: BoolTarget) { + let z = self.mul_sub(b.target, b.target, b.target); + let zero = self.zero(); + self.connect(z, zero); + } } #[derive(Debug)] diff --git a/plonky2/src/gates/add_many_u32.rs b/plonky2/src/gates/add_many_u32.rs new file mode 100644 index 00000000..4f9c4293 --- /dev/null +++ b/plonky2/src/gates/add_many_u32.rs @@ -0,0 +1,461 @@ +use std::marker::PhantomData; + +use itertools::unfold; +use plonky2_util::ceil_div_usize; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::Field; +use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +const LOG2_MAX_NUM_ADDENDS: usize = 4; +const MAX_NUM_ADDENDS: usize = 16; + +/// A gate to perform addition on `num_addends` different 32-bit values, plus a small carry +#[derive(Copy, Clone, Debug)] +pub struct U32AddManyGate, const D: usize> { + pub num_addends: usize, + pub num_ops: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32AddManyGate { + pub fn new_from_config(config: &CircuitConfig, num_addends: usize) -> Self { + Self { + num_addends, + num_ops: Self::num_ops(num_addends, config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(num_addends: usize, config: &CircuitConfig) -> usize { + debug_assert!(num_addends <= MAX_NUM_ADDENDS); + let wires_per_op = (num_addends + 3) + Self::num_limbs(); + let routed_wires_per_op = num_addends + 3; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_op_jth_addend(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < self.num_addends); + (self.num_addends + 3) * i + j + } + pub fn wire_ith_carry(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + } + + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + 1 + } + pub fn wire_ith_output_carry(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + 2 + } + + pub fn limb_bits() -> usize { + 2 + } + pub fn num_result_limbs() -> usize { + ceil_div_usize(32, Self::limb_bits()) + } + pub fn num_carry_limbs() -> usize { + ceil_div_usize(LOG2_MAX_NUM_ADDENDS, Self::limb_bits()) + } + pub fn num_limbs() -> usize { + Self::num_result_limbs() + Self::num_carry_limbs() + } + + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < Self::num_limbs()); + (self.num_addends + 3) * self.num_ops + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32AddManyGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let addends: Vec = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let computed_output = addends.iter().fold(F::Extension::ZERO, |x, &y| x + y) + carry; + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base = F::Extension::from_canonical_u64(1 << 32u64); + let combined_output = output_carry * base + output_result; + + constraints.push(combined_output - computed_output); + + let mut combined_result_limbs = F::Extension::ZERO; + let mut combined_carry_limbs = F::Extension::ZERO; + let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = base * combined_result_limbs + this_limb; + } else { + combined_carry_limbs = base * combined_carry_limbs + this_limb; + } + } + constraints.push(combined_result_limbs - output_result); + constraints.push(combined_carry_limbs - output_carry); + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + for i in 0..self.num_ops { + let addends: Vec = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let computed_output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base = F::from_canonical_u64(1 << 32u64); + let combined_output = output_carry * base + output_result; + + yield_constr.one(combined_output - computed_output); + + let mut combined_result_limbs = F::ZERO; + let mut combined_carry_limbs = F::ZERO; + let base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = base * combined_result_limbs + this_limb; + } else { + combined_carry_limbs = base * combined_carry_limbs + this_limb; + } + } + yield_constr.one(combined_result_limbs - output_result); + yield_constr.one(combined_carry_limbs - output_carry); + } + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + for i in 0..self.num_ops { + let addends: Vec> = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let mut computed_output = carry; + for addend in addends { + computed_output = builder.add_extension(computed_output, addend); + } + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); + let base_target = builder.constant_extension(base); + let combined_output = + builder.mul_add_extension(output_carry, base_target, output_result); + + constraints.push(builder.sub_extension(combined_output, computed_output)); + + let mut combined_result_limbs = builder.zero_extension(); + let mut combined_carry_limbs = builder.zero_extension(); + let base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = + builder.mul_add_extension(base, combined_result_limbs, this_limb); + } else { + combined_carry_limbs = + builder.mul_add_extension(base, combined_carry_limbs, this_limb); + } + } + constraints.push(builder.sub_extension(combined_result_limbs, output_result)); + constraints.push(builder.sub_extension(combined_carry_limbs, output_carry)); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + U32AddManyGenerator { + gate: *self, + gate_index, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + (self.num_addends + 3) * self.num_ops + Self::num_limbs() * self.num_ops + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + self.num_ops * (3 + Self::num_limbs()) + } +} + +#[derive(Clone, Debug)] +struct U32AddManyGenerator, const D: usize> { + gate: U32AddManyGate, + gate_index: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32AddManyGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + (0..self.gate.num_addends) + .map(|j| local_target(self.gate.wire_ith_op_jth_addend(self.i, j))) + .chain([local_target(self.gate.wire_ith_carry(self.i))]) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let addends: Vec<_> = (0..self.gate.num_addends) + .map(|j| get_local_wire(self.gate.wire_ith_op_jth_addend(self.i, j))) + .collect(); + let carry = get_local_wire(self.gate.wire_ith_carry(self.i)); + + let output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; + let output_u64 = output.to_canonical_u64(); + + let output_carry_u64 = output_u64 >> 32; + let output_result_u64 = output_u64 & ((1 << 32) - 1); + + let output_carry = F::from_canonical_u64(output_carry_u64); + let output_result = F::from_canonical_u64(output_result_u64); + + let output_carry_wire = local_wire(self.gate.wire_ith_output_carry(self.i)); + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + + out_buffer.set_wire(output_carry_wire, output_carry); + out_buffer.set_wire(output_result_wire, output_result); + + let num_result_limbs = U32AddManyGate::::num_result_limbs(); + let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); + let limb_base = 1 << U32AddManyGate::::limb_bits(); + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % limb_base; + val /= limb_base; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let result_limbs = split_to_limbs(output_result_u64, num_result_limbs); + let carry_limbs = split_to_limbs(output_carry_u64, num_carry_limbs); + + for (j, limb) in result_limbs.chain(carry_limbs).enumerate() { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, limb); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use itertools::unfold; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::add_many_u32::U32AddManyGate; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32AddManyGate:: { + num_addends: 4, + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32AddManyGate:: { + num_addends: 4, + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const NUM_ADDENDS: usize = 10; + const NUM_U32_ADD_MANY_OPS: usize = 3; + + fn get_wires(addends: Vec>, carries: Vec) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let num_result_limbs = U32AddManyGate::::num_result_limbs(); + let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); + let limb_base = 1 << U32AddManyGate::::limb_bits(); + for op in 0..NUM_U32_ADD_MANY_OPS { + let adds = &addends[op]; + let ca = carries[op]; + + let output = adds.iter().sum::() + ca; + let output_result = output & ((1 << 32) - 1); + let output_carry = output >> 32; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % limb_base; + val /= limb_base; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut result_limbs: Vec<_> = + split_to_limbs(output_result, num_result_limbs).collect(); + let mut carry_limbs: Vec<_> = + split_to_limbs(output_carry, num_carry_limbs).collect(); + + for a in adds { + v0.push(F::from_canonical_u64(*a)); + } + v0.push(F::from_canonical_u64(ca)); + v0.push(F::from_canonical_u64(output_result)); + v0.push(F::from_canonical_u64(output_carry)); + v1.append(&mut result_limbs); + v1.append(&mut carry_limbs); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + + let mut rng = rand::thread_rng(); + let addends: Vec> = (0..NUM_U32_ADD_MANY_OPS) + .map(|_| (0..NUM_ADDENDS).map(|_| rng.gen::() as u64).collect()) + .collect(); + let carries: Vec<_> = (0..NUM_U32_ADD_MANY_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + + let gate = U32AddManyGate:: { + num_addends: NUM_ADDENDS, + num_ops: NUM_U32_ADD_MANY_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(addends, carries), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/plonky2/src/gates/arithmetic_base.rs b/plonky2/src/gates/arithmetic_base.rs index 455dbc81..006f32ef 100644 --- a/plonky2/src/gates/arithmetic_base.rs +++ b/plonky2/src/gates/arithmetic_base.rs @@ -132,7 +132,7 @@ impl, const D: usize> Gate for ArithmeticGate ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/arithmetic_extension.rs b/plonky2/src/gates/arithmetic_extension.rs index 530b030c..acb480d5 100644 --- a/plonky2/src/gates/arithmetic_extension.rs +++ b/plonky2/src/gates/arithmetic_extension.rs @@ -139,7 +139,7 @@ impl, const D: usize> Gate for ArithmeticExte ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/arithmetic_u32.rs b/plonky2/src/gates/arithmetic_u32.rs index 10d16ec2..6f61b7b4 100644 --- a/plonky2/src/gates/arithmetic_u32.rs +++ b/plonky2/src/gates/arithmetic_u32.rs @@ -213,7 +213,7 @@ impl, const D: usize> Gate for U32ArithmeticG ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { @@ -440,10 +440,7 @@ mod tests { v1.append(&mut output_limbs_f); } - v0.iter() - .chain(v1.iter()) - .map(|&x| x.into()) - .collect::>() + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() } let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/assert_le.rs b/plonky2/src/gates/assert_le.rs index 6d7ad508..b240df85 100644 --- a/plonky2/src/gates/assert_le.rs +++ b/plonky2/src/gates/assert_le.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::{Field, Field64}; use plonky2_field::packed_field::PackedField; use plonky2_util::{bits_u64, ceil_div_usize}; @@ -26,7 +26,7 @@ use crate::plonk::vars::{ /// A gate for checking that one value is less than or equal to another. #[derive(Clone, Debug)] -pub struct AssertLessThanGate, const D: usize> { +pub struct AssertLessThanGate, const D: usize> { pub(crate) num_bits: usize, pub(crate) num_chunks: usize, _phantom: PhantomData, @@ -466,7 +466,8 @@ mod tests { use anyhow::Result; use plonky2_field::extension_field::quartic::QuarticExtension; - use plonky2_field::field_types::{Field, PrimeField}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; @@ -589,7 +590,7 @@ mod tests { v.append(&mut chunks_equal); v.append(&mut intermediate_values); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() }; let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/base_sum.rs b/plonky2/src/gates/base_sum.rs index c1cf6a49..c9b0b0f6 100644 --- a/plonky2/src/gates/base_sum.rs +++ b/plonky2/src/gates/base_sum.rs @@ -1,7 +1,7 @@ use std::ops::Range; use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::{Field, Field64}; use plonky2_field::packed_field::PackedField; use crate::gates::batchable::MultiOpsGate; @@ -32,7 +32,7 @@ impl BaseSumGate { Self { num_limbs } } - pub fn new_from_config(config: &CircuitConfig) -> Self { + pub fn new_from_config(config: &CircuitConfig) -> Self { let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS); Self::new(num_limbs) } diff --git a/plonky2/src/gates/comparison.rs b/plonky2/src/gates/comparison.rs index 7d3dd70f..b64a5394 100644 --- a/plonky2/src/gates/comparison.rs +++ b/plonky2/src/gates/comparison.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::{Field, Field64}; use plonky2_field::packed_field::PackedField; use plonky2_util::{bits_u64, ceil_div_usize}; @@ -24,7 +24,7 @@ use crate::plonk::vars::{ /// A gate for checking that one value is less than or equal to another. #[derive(Clone, Debug)] -pub struct ComparisonGate, const D: usize> { +pub struct ComparisonGate, const D: usize> { pub(crate) num_bits: usize, pub(crate) num_chunks: usize, _phantom: PhantomData, @@ -541,7 +541,8 @@ mod tests { use std::marker::PhantomData; use anyhow::Result; - use plonky2_field::field_types::{Field, PrimeField}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; @@ -679,7 +680,7 @@ mod tests { v.append(&mut intermediate_values); v.append(&mut msd_bits); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() }; let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index e5d3943f..5e9f28a8 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -112,7 +112,7 @@ pub trait Gate, const D: usize>: 'static + Send + S builder: &mut CircuitBuilder, mut vars: EvaluationTargets, prefix: &[bool], - combined_gate_constraints: &mut Vec>, + combined_gate_constraints: &mut [ExtensionTarget], ) { let filter = compute_filter_recursively(builder, prefix, vars.local_constants); vars.remove_prefix(prefix); diff --git a/plonky2/src/gates/gate_testing.rs b/plonky2/src/gates/gate_testing.rs index ea1ef9a4..51768ba8 100644 --- a/plonky2/src/gates/gate_testing.rs +++ b/plonky2/src/gates/gate_testing.rs @@ -10,7 +10,7 @@ use crate::hash::hash_types::RichField; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::config::GenericConfig; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::plonk::verifier::verify; use crate::util::transpose; @@ -92,7 +92,10 @@ pub fn test_eval_fns< const D: usize, >( gate: G, -) -> Result<()> { +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ // Test that `eval_unfiltered` and `eval_unfiltered_base` are coherent. let wires_base = F::rand_vec(gate.num_wires()); let constants_base = F::rand_vec(gate.num_constants()); diff --git a/plonky2/src/gates/gate_tree.rs b/plonky2/src/gates/gate_tree.rs index 8c8ecfe1..66161333 100644 --- a/plonky2/src/gates/gate_tree.rs +++ b/plonky2/src/gates/gate_tree.rs @@ -228,9 +228,9 @@ mod tests { use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; - use crate::gates::gmimc::GMiMCGate; use crate::gates::interpolation::HighDegreeInterpolationGate; use crate::gates::noop::NoopGate; + use crate::gates::poseidon::PoseidonGate; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; #[test] @@ -245,7 +245,7 @@ mod tests { GateRef::new(ConstantGate { num_consts: 4 }), GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), GateRef::new(BaseSumGate::<4>::new(4)), - GateRef::new(GMiMCGate::::new()), + GateRef::new(PoseidonGate::::new()), GateRef::new(HighDegreeInterpolationGate::new(2)), ]; @@ -276,7 +276,7 @@ mod tests { assert!( gates_with_prefix .iter() - .all(|(g, p)| g.0.degree() + g.0.num_constants() + p.len() <= 8), + .all(|(g, p)| g.0.degree() + g.0.num_constants() + p.len() <= 9), "Total degree is larger than 8." ); diff --git a/plonky2/src/gates/gmimc.rs b/plonky2/src/gates/gmimc.rs index b3ff8969..8b137891 100644 --- a/plonky2/src/gates/gmimc.rs +++ b/plonky2/src/gates/gmimc.rs @@ -1,445 +1 @@ -use std::marker::PhantomData; -use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::Field; -use plonky2_field::packed_field::PackedField; - -use crate::gates::batchable::MultiOpsGate; -use crate::gates::gate::Gate; -use crate::gates::packed_util::PackedEvaluableBase; -use crate::gates::util::StridedConstraintConsumer; -use crate::hash::gmimc; -use crate::hash::gmimc::GMiMC; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use crate::iop::target::Target; -use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; - -/// Evaluates a full GMiMC permutation with 12 state elements. -/// -/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. -/// It has a flag which can be used to swap the first four inputs with the next four, for ordering -/// sibling digests. -#[derive(Debug)] -pub struct GMiMCGate< - F: RichField + Extendable + GMiMC, - const D: usize, - const WIDTH: usize, -> { - _phantom: PhantomData, -} - -impl + GMiMC, const D: usize, const WIDTH: usize> - GMiMCGate -{ - pub fn new() -> Self { - GMiMCGate { - _phantom: PhantomData, - } - } - - /// The wire index for the `i`th input to the permutation. - pub fn wire_input(i: usize) -> usize { - i - } - - /// The wire index for the `i`th output to the permutation. - pub fn wire_output(i: usize) -> usize { - WIDTH + i - } - - /// If this is set to 1, the first four inputs will be swapped with the next four inputs. This - /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. - pub const WIRE_SWAP: usize = 2 * WIDTH; - - /// A wire which stores the input to the `i`th cubing. - fn wire_cubing_input(i: usize) -> usize { - 2 * WIDTH + 1 + i - } - - /// End of wire indices, exclusive. - fn end() -> usize { - 2 * WIDTH + 1 + gmimc::NUM_ROUNDS - } -} - -impl + GMiMC, const D: usize, const WIDTH: usize> Gate - for GMiMCGate -{ - fn id(&self) -> String { - format!(" {:?}", WIDTH, self) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - // Assert that `swap` is binary. - let swap = vars.local_wires[Self::WIRE_SWAP]; - constraints.push(swap * (swap - F::Extension::ONE)); - - let mut state = Vec::with_capacity(12); - for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..12 { - state.push(vars.local_wires[i]); - } - - // Value that is implicitly added to each element. - // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = F::Extension::ZERO; - - for r in 0..gmimc::NUM_ROUNDS { - let active = r % WIDTH; - let constant = F::from_canonical_u64(>::ROUND_CONSTANTS[r]); - let cubing_input = state[active] + addition_buffer + constant.into(); - let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; - constraints.push(cubing_input - cubing_input_wire); - let f = cubing_input_wire.cube(); - addition_buffer += f; - state[active] -= f; - } - - for i in 0..WIDTH { - state[i] += addition_buffer; - constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - _vars: EvaluationVarsBase, - _yield_constr: StridedConstraintConsumer, - ) { - panic!("use eval_unfiltered_base_packed instead"); - } - - fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { - self.eval_unfiltered_base_batch_packed(vars_base) - } - - fn eval_unfiltered_recursively( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let swap = vars.local_wires[Self::WIRE_SWAP]; - constraints.push(builder.mul_sub_extension(swap, swap, swap)); - - let mut state = Vec::with_capacity(12); - for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - let delta = builder.sub_extension(b, a); - state.push(builder.mul_add_extension(swap, delta, a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - let delta = builder.sub_extension(b, a); - state.push(builder.mul_add_extension(swap, delta, a)); - } - for i in 8..12 { - state.push(vars.local_wires[i]); - } - - // Value that is implicitly added to each element. - // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = builder.zero_extension(); - - for r in 0..gmimc::NUM_ROUNDS { - let active = r % WIDTH; - - let constant = F::from_canonical_u64(>::ROUND_CONSTANTS[r]); - let constant = builder.constant_extension(constant.into()); - let cubing_input = - builder.add_many_extension(&[state[active], addition_buffer, constant]); - let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; - constraints.push(builder.sub_extension(cubing_input, cubing_input_wire)); - let f = builder.cube_extension(cubing_input_wire); - addition_buffer = builder.add_extension(addition_buffer, f); - state[active] = builder.sub_extension(state[active], f); - } - - for i in 0..WIDTH { - state[i] = builder.add_extension(state[i], addition_buffer); - constraints - .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); - } - - constraints - } - - fn generators( - &self, - gate_index: usize, - _local_constants: &[F], - ) -> Vec>> { - let gen = GMiMCGenerator:: { - gate_index, - _phantom: PhantomData, - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - Self::end() - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 3 - } - - fn num_constraints(&self) -> usize { - gmimc::NUM_ROUNDS + WIDTH + 1 - } -} -impl + GMiMC, const D: usize, const WIDTH: usize> - MultiOpsGate for GMiMCGate -{ - fn num_ops(&self) -> usize { - 1 - } - - fn dependencies_ith_op(&self, _gate_index: usize, _i: usize) -> Vec { - unreachable!() - } -} - -impl + GMiMC, const D: usize, const WIDTH: usize> - PackedEvaluableBase for GMiMCGate -{ - fn eval_unfiltered_base_packed>( - &self, - vars: EvaluationVarsBasePacked

, - mut yield_constr: StridedConstraintConsumer

, - ) { - // Assert that `swap` is binary. - let swap = vars.local_wires[Self::WIRE_SWAP]; - yield_constr.one(swap * (swap - F::ONE)); - - let mut state = Vec::with_capacity(12); - for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..12 { - state.push(vars.local_wires[i]); - } - - // Value that is implicitly added to each element. - // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = P::ZEROS; - - for r in 0..gmimc::NUM_ROUNDS { - let active = r % WIDTH; - let constant = F::from_canonical_u64(>::ROUND_CONSTANTS[r]); - let cubing_input = state[active] + addition_buffer + constant; - let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; - yield_constr.one(cubing_input - cubing_input_wire); - let f = cubing_input_wire.square() * cubing_input_wire; - addition_buffer += f; - state[active] -= f; - } - - for i in 0..WIDTH { - state[i] += addition_buffer; - yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); - } - } -} - -#[derive(Debug)] -struct GMiMCGenerator< - F: RichField + Extendable + GMiMC, - const D: usize, - const WIDTH: usize, -> { - gate_index: usize, - _phantom: PhantomData, -} - -impl + GMiMC, const D: usize, const WIDTH: usize> - SimpleGenerator for GMiMCGenerator -{ - fn dependencies(&self) -> Vec { - let mut dep_input_indices = Vec::with_capacity(WIDTH + 1); - for i in 0..WIDTH { - dep_input_indices.push(GMiMCGate::::wire_input(i)); - } - dep_input_indices.push(GMiMCGate::::WIRE_SWAP); - - dep_input_indices - .into_iter() - .map(|input| { - Target::Wire(Wire { - gate: self.gate_index, - input, - }) - }) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let mut state = (0..WIDTH) - .map(|i| { - witness.get_wire(Wire { - gate: self.gate_index, - input: GMiMCGate::::wire_input(i), - }) - }) - .collect::>(); - - let swap_value = witness.get_wire(Wire { - gate: self.gate_index, - input: GMiMCGate::::WIRE_SWAP, - }); - debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); - if swap_value == F::ONE { - for i in 0..4 { - state.swap(i, 4 + i); - } - } - - // Value that is implicitly added to each element. - // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = F::ZERO; - - for r in 0..gmimc::NUM_ROUNDS { - let active = r % WIDTH; - let constant = F::from_canonical_u64(>::ROUND_CONSTANTS[r]); - let cubing_input = state[active] + addition_buffer + constant; - out_buffer.set_wire( - Wire { - gate: self.gate_index, - input: GMiMCGate::::wire_cubing_input(r), - }, - cubing_input, - ); - let f = cubing_input.cube(); - addition_buffer += f; - state[active] -= f; - } - - for i in 0..WIDTH { - state[i] += addition_buffer; - out_buffer.set_wire( - Wire { - gate: self.gate_index, - input: GMiMCGate::::wire_output(i), - }, - state[i], - ); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::field_types::Field; - use plonky2_field::goldilocks_field::GoldilocksField; - - use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::gmimc::GMiMCGate; - use crate::hash::gmimc::GMiMC; - use crate::iop::generator::generate_partial_witness; - use crate::iop::wire::Wire; - use crate::iop::witness::{PartialWitness, Witness}; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - #[test] - fn generated_output() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - const WIDTH: usize = 12; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - type Gate = GMiMCGate; - let gate = Gate::new(); - let gate_index = builder.add_gate(gate, vec![], vec![]); - let circuit = builder.build_prover::(); - - let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); - - let mut inputs = PartialWitness::new(); - inputs.set_wire( - Wire { - gate: gate_index, - input: Gate::WIRE_SWAP, - }, - F::ZERO, - ); - for i in 0..WIDTH { - inputs.set_wire( - Wire { - gate: gate_index, - input: Gate::wire_input(i), - }, - permutation_inputs[i], - ); - } - - let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); - - let expected_outputs: [F; WIDTH] = - F::gmimc_permute_naive(permutation_inputs.try_into().unwrap()); - for i in 0..WIDTH { - let out = witness.get_wire(Wire { - gate: 0, - input: Gate::wire_output(i), - }); - assert_eq!(out, expected_outputs[i]); - } - } - - #[test] - fn low_degree() { - type F = GoldilocksField; - const WIDTH: usize = 12; - let gate = GMiMCGate::::new(); - test_low_degree(gate) - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - const WIDTH: usize = 12; - let gate = GMiMCGate::::new(); - test_eval_fns::(gate) - } -} diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index 325624b4..c8743050 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -355,7 +355,7 @@ mod tests { for i in 0..coeffs.len() { v.extend(coeffs.coeffs[i].0); } - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } // Get a working row for InterpolationGate. diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index 5c640775..8e7d91c8 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -456,7 +456,7 @@ mod tests { .take(gate.num_points() - 2) .flat_map(|ff| ff.0), ); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } // Get a working row for LowDegreeInterpolationGate. diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 3070b087..4f01b7ce 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -1,6 +1,7 @@ // Gates have `new` methods that return `GateRef`s. #![allow(clippy::new_ret_no_self)] +pub mod add_many_u32; pub mod arithmetic_base; pub mod arithmetic_extension; pub mod arithmetic_u32; @@ -12,7 +13,6 @@ pub mod constant; pub mod exponentiation; pub mod gate; pub mod gate_tree; -pub mod gmimc; pub mod interpolation; pub mod low_degree_interpolation; pub mod multiplication_extension; @@ -22,6 +22,7 @@ pub mod poseidon; pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; +pub mod range_check_u32; pub mod reducing; pub mod reducing_extension; pub mod subtraction_u32; diff --git a/plonky2/src/gates/multiplication_extension.rs b/plonky2/src/gates/multiplication_extension.rs index 0b93359d..02e93eb1 100644 --- a/plonky2/src/gates/multiplication_extension.rs +++ b/plonky2/src/gates/multiplication_extension.rs @@ -126,7 +126,7 @@ impl, const D: usize> Gate for MulExtensionGa ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 6a93b259..4a17aae8 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -210,7 +210,7 @@ impl, const D: usize> Gate for RandomAccessGa ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/range_check_u32.rs b/plonky2/src/gates/range_check_u32.rs new file mode 100644 index 00000000..79e91de8 --- /dev/null +++ b/plonky2/src/gates/range_check_u32.rs @@ -0,0 +1,322 @@ +use std::marker::PhantomData; + +use plonky2_util::ceil_div_usize; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::Field; +use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can decompose a number into base B little-endian limbs. +#[derive(Copy, Clone, Debug)] +pub struct U32RangeCheckGate, const D: usize> { + pub num_input_limbs: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32RangeCheckGate { + pub fn new(num_input_limbs: usize) -> Self { + Self { + num_input_limbs, + _phantom: PhantomData, + } + } + + pub const AUX_LIMB_BITS: usize = 2; + pub const BASE: usize = 1 << Self::AUX_LIMB_BITS; + + fn aux_limbs_per_input_limb(&self) -> usize { + ceil_div_usize(32, Self::AUX_LIMB_BITS) + } + pub fn wire_ith_input_limb(&self, i: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + i + } + pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + debug_assert!(j < self.aux_limbs_per_input_limb()); + self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j + } +} + +impl, const D: usize> Gate for U32RangeCheckGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = F::Extension::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + constraints.push(computed_sum - input_limb); + for aux_limb in aux_limbs { + constraints.push( + (0..Self::BASE) + .map(|i| aux_limb - F::Extension::from_canonical_usize(i)) + .product(), + ); + } + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let base = F::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + yield_constr.one(computed_sum - input_limb); + for aux_limb in aux_limbs { + yield_constr.one( + (0..Self::BASE) + .map(|i| aux_limb - F::from_canonical_usize(i)) + .product(), + ); + } + } + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = builder.constant(F::from_canonical_usize(Self::BASE)); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers_ext_recursive(builder, &aux_limbs, base); + + constraints.push(builder.sub_extension(computed_sum, input_limb)); + for aux_limb in aux_limbs { + constraints.push({ + let mut acc = builder.one_extension(); + (0..Self::BASE).for_each(|i| { + // We update our accumulator as: + // acc' = acc (x - i) + // = acc x + (-i) acc + // Since -i is constant, we can do this in one arithmetic_extension call. + let neg_i = -F::from_canonical_usize(i); + acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc) + }); + acc + }); + } + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = U32RangeCheckGenerator { + gate: *self, + gate_index, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } + + fn num_constants(&self) -> usize { + 0 + } + + // Bounded by the range-check (x-0)*(x-1)*...*(x-BASE+1). + fn degree(&self) -> usize { + Self::BASE + } + + // 1 for checking the each sum of aux limbs, plus a range check for each aux limb. + fn num_constraints(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } +} + +#[derive(Debug)] +pub struct U32RangeCheckGenerator, const D: usize> { + gate: U32RangeCheckGate, + gate_index: usize, +} + +impl, const D: usize> SimpleGenerator + for U32RangeCheckGenerator +{ + fn dependencies(&self) -> Vec { + let num_input_limbs = self.gate.num_input_limbs; + (0..num_input_limbs) + .map(|i| Target::wire(self.gate_index, self.gate.wire_ith_input_limb(i))) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let num_input_limbs = self.gate.num_input_limbs; + for i in 0..num_input_limbs { + let sum_value = witness + .get_target(Target::wire( + self.gate_index, + self.gate.wire_ith_input_limb(i), + )) + .to_canonical_u64() as u32; + + let base = U32RangeCheckGate::::BASE as u32; + let limbs = (0..self.gate.aux_limbs_per_input_limb()).map(|j| { + Target::wire( + self.gate_index, + self.gate.wire_ith_input_limb_jth_aux_limb(i, j), + ) + }); + let limbs_value = (0..self.gate.aux_limbs_per_input_limb()) + .scan(sum_value, |acc, _| { + let tmp = *acc % base; + *acc /= base; + Some(F::from_canonical_u32(tmp)) + }) + .collect::>(); + + for (b, b_value) in limbs.zip(limbs_value) { + out_buffer.set_target(b, b_value); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use itertools::unfold; + use plonky2_util::ceil_div_usize; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::range_check_u32::U32RangeCheckGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32RangeCheckGate::new(8)) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32RangeCheckGate::new(8)) + } + + fn test_gate_constraint(input_limbs: Vec) { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const AUX_LIMB_BITS: usize = 2; + const BASE: usize = 1 << AUX_LIMB_BITS; + const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS); + + fn get_wires(input_limbs: Vec) -> Vec { + let num_input_limbs = input_limbs.len(); + let mut v = Vec::new(); + + for i in 0..num_input_limbs { + let input_limb = input_limbs[i]; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % (BASE as u64); + val /= BASE as u64; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut aux_limbs: Vec<_> = + split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect(); + + v.append(&mut aux_limbs); + } + + input_limbs + .iter() + .cloned() + .map(F::from_canonical_u64) + .chain(v.iter().cloned()) + .map(|x| x.into()) + .collect() + } + + let gate = U32RangeCheckGate:: { + num_input_limbs: 8, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(input_limbs), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_gate_constraint_good() { + let mut rng = rand::thread_rng(); + let input_limbs: Vec<_> = (0..8).map(|_| rng.gen::() as u64).collect(); + + test_gate_constraint(input_limbs); + } + + #[test] + #[should_panic] + fn test_gate_constraint_bad() { + let mut rng = rand::thread_rng(); + let input_limbs: Vec<_> = (0..8).map(|_| rng.gen()).collect(); + + test_gate_constraint(input_limbs); + } +} diff --git a/plonky2/src/gates/subtraction_u32.rs b/plonky2/src/gates/subtraction_u32.rs index a6fa5e74..a37648c4 100644 --- a/plonky2/src/gates/subtraction_u32.rs +++ b/plonky2/src/gates/subtraction_u32.rs @@ -355,7 +355,8 @@ mod tests { use anyhow::Result; use plonky2_field::extension_field::quartic::QuarticExtension; - use plonky2_field::field_types::{Field, PrimeField}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; @@ -433,10 +434,7 @@ mod tests { v1.append(&mut output_limbs); } - v0.iter() - .chain(v1.iter()) - .map(|&x| x.into()) - .collect::>() + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() } let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/switch.rs b/plonky2/src/gates/switch.rs index 62209720..ae4419df 100644 --- a/plonky2/src/gates/switch.rs +++ b/plonky2/src/gates/switch.rs @@ -448,7 +448,7 @@ mod tests { v.push(F::from_bool(switch)); } - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } let first_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); diff --git a/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index a7f61bf5..f2276506 100644 --- a/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -3,8 +3,9 @@ use std::arch::aarch64::*; use std::arch::asm; -use plonky2_field::field_types::PrimeField; +use plonky2_field::field_types::Field64; use plonky2_field::goldilocks_field::GoldilocksField; +use plonky2_util::branch_hint; use static_assertions::const_assert; use unroll::unroll_for_loops; @@ -108,6 +109,8 @@ const_assert!(check_round_const_bounds_init()); // ====================================== SCALAR ARITHMETIC ======================================= +const EPSILON: u64 = 0xffffffff; + /// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER. #[inline(always)] unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 { @@ -124,39 +127,36 @@ unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 { adj = lateout(reg) adj, options(pure, nomem, nostack), ); - res.wrapping_add(adj) // adj is EPSILON if wraparound occured and 0 otherwise + res + adj // adj is EPSILON if wraparound occured and 0 otherwise } -/// Addition of a and (b >> 32) modulo ORDER accounting for wraparound. +/// Subtraction of a and (b >> 32) modulo ORDER accounting for wraparound. #[inline(always)] unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 { - let res: u64; - let adj: u64; - asm!( - "subs {res}, {a}, {b}, lsr #32", - // Set adj to 0xffffffff if subtraction underflowed and 0 otherwise. - // 'cc' for 'carry clear'. - // NB: The CF in ARM subtraction is the opposite of x86: CF set == underflow did not occur. - "csetm {adj:w}, cc", - a = in(reg) a, - b = in(reg) b, - res = lateout(reg) res, - adj = lateout(reg) adj, - options(pure, nomem, nostack), - ); - res.wrapping_sub(adj) // adj is EPSILON if underflow occured and 0 otherwise. + let b_hi = b >> 32; + // This could be done with a.overflowing_add(b_hi), but `checked_sub` signals to the compiler + // that overflow is unlikely (note: this is a standard library implementation detail, not part + // of the spec). + match a.checked_sub(b_hi) { + Some(res) => res, + None => { + // Super rare. Better off branching. + branch_hint(); + let res_wrapped = a.wrapping_sub(b_hi); + res_wrapped - EPSILON + } + } } /// Multiplication of the low word (i.e., x as u32) by EPSILON. #[inline(always)] unsafe fn mul_epsilon(x: u64) -> u64 { let res; - let epsilon: u64 = 0xffffffff; asm!( // Use UMULL to save one instruction. The compiler emits two: extract the low word and then multiply. "umull {res}, {x:w}, {epsilon:w}", x = in(reg) x, - epsilon = in(reg) epsilon, + epsilon = in(reg) EPSILON, res = lateout(reg) res, options(pure, nomem, nostack, preserves_flags), ); diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs index 0fddeba7..804524ee 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs @@ -4,6 +4,7 @@ use std::mem::size_of; use plonky2_field::field_types::Field; use plonky2_field::goldilocks_field::GoldilocksField; +use plonky2_util::branch_hint; use static_assertions::const_assert; use crate::hash::poseidon::{ @@ -141,6 +142,16 @@ macro_rules! map3 { ($f:ident::<$l:literal>, $v:ident) => { ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2)) }; + ($f:ident::<$l:literal>, $v1:ident, $v2:ident) => { + ( + $f::<$l>($v1.0, $v2.0), + $f::<$l>($v1.1, $v2.1), + $f::<$l>($v1.2, $v2.2), + ) + }; + ($f:ident, $v:ident) => { + ($f($v.0), $f($v.1), $f($v.2)) + }; ($f:ident, $v0:ident, $v1:ident) => { ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2)) }; @@ -188,19 +199,32 @@ unsafe fn const_layer( unsafe fn square3( x: (__m256i, __m256i, __m256i), ) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { - let sign_bit = _mm256_set1_epi64x(i64::MIN); - let x_hi = map3!(_mm256_srli_epi64::<32>, x); + let x_hi = { + // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than + // bitshift. This instruction only has a floating-point flavor, so we cast to/from float. + // This is safe and free. + let x_ps = map3!(_mm256_castsi256_ps, x); + let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps); + map3!(_mm256_castps_si256, x_hi_ps) + }; + + // All pairwise multiplications. let mul_ll = map3!(_mm256_mul_epu32, x, x); let mul_lh = map3!(_mm256_mul_epu32, x, x_hi); let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi); - let res_lo0_s = map3!(_mm256_xor_si256, mul_ll, rep sign_bit); + + // Bignum addition, but mul_lh is shifted by 33 bits (not 32). + let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll); + let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi); + let t0_hi = map3!(_mm256_srli_epi64::<31>, t0); + let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi); + + // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high + // position). let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh); - let res_lo1_s = map3!(_mm256_add_epi64, res_lo0_s, mul_lh_lo); - let carry = map3!(_mm256_cmpgt_epi64, res_lo0_s, res_lo1_s); - let mul_lh_hi = map3!(_mm256_srli_epi64::<31>, mul_lh); - let res_hi0 = map3!(_mm256_add_epi64, mul_hh, mul_lh_hi); - let res_hi1 = map3!(_mm256_sub_epi64, res_hi0, carry); - (res_lo1_s, res_hi1) + let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo); + + (res_lo, res_hi) } #[inline(always)] @@ -208,49 +232,110 @@ unsafe fn mul3( x: (__m256i, __m256i, __m256i), y: (__m256i, __m256i, __m256i), ) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { - let sign_bit = _mm256_set1_epi64x(i64::MIN); - let y_hi = map3!(_mm256_srli_epi64::<32>, y); - let x_hi = map3!(_mm256_srli_epi64::<32>, x); + let epsilon = _mm256_set1_epi64x(0xffffffff); + let x_hi = { + // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than + // bitshift. This instruction only has a floating-point flavor, so we cast to/from float. + // This is safe and free. + let x_ps = map3!(_mm256_castsi256_ps, x); + let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps); + map3!(_mm256_castps_si256, x_hi_ps) + }; + let y_hi = { + let y_ps = map3!(_mm256_castsi256_ps, y); + let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps); + map3!(_mm256_castps_si256, y_hi_ps) + }; + + // All four pairwise multiplications let mul_ll = map3!(_mm256_mul_epu32, x, y); let mul_lh = map3!(_mm256_mul_epu32, x, y_hi); let mul_hl = map3!(_mm256_mul_epu32, x_hi, y); let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi); - let mul_lh_lo = map3!(_mm256_slli_epi64::<32>, mul_lh); - let res_lo0_s = map3!(_mm256_xor_si256, mul_ll, rep sign_bit); - let mul_hl_lo = map3!(_mm256_slli_epi64::<32>, mul_hl); - let res_lo1_s = map3!(_mm256_add_epi64, res_lo0_s, mul_lh_lo); - let carry0 = map3!(_mm256_cmpgt_epi64, res_lo0_s, res_lo1_s); - let mul_lh_hi = map3!(_mm256_srli_epi64::<32>, mul_lh); - let res_lo2_s = map3!(_mm256_add_epi64, res_lo1_s, mul_hl_lo); - let carry1 = map3!(_mm256_cmpgt_epi64, res_lo1_s, res_lo2_s); - let mul_hl_hi = map3!(_mm256_srli_epi64::<32>, mul_hl); - let res_hi0 = map3!(_mm256_add_epi64, mul_hh, mul_lh_hi); - let res_hi1 = map3!(_mm256_add_epi64, res_hi0, mul_hl_hi); - let res_hi2 = map3!(_mm256_sub_epi64, res_hi1, carry0); - let res_hi3 = map3!(_mm256_sub_epi64, res_hi2, carry1); - (res_lo2_s, res_hi3) + + // Bignum addition + // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow. + let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll); + let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi); + // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow. + // Also, extract high 32 bits of t0 and add to mul_hh. + let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon); + let t0_hi = map3!(_mm256_srli_epi64::<32>, t0); + let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo); + let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi); + // Lastly, extract the high 32 bits of t1 and add to t2. + let t1_hi = map3!(_mm256_srli_epi64::<32>, t1); + let res_hi = map3!(_mm256_add_epi64, t2, t1_hi); + + // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high + // position). + let t1_lo = { + let t1_ps = map3!(_mm256_castsi256_ps, t1); + let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps); + map3!(_mm256_castps_si256, t1_lo_ps) + }; + let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo); + + (res_lo, res_hi) +} + +/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`. +#[inline(always)] +unsafe fn add_small( + x_s: (__m256i, __m256i, __m256i), + y: (__m256i, __m256i, __m256i), +) -> (__m256i, __m256i, __m256i) { + let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y); + let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s); + let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0. + let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt); + res_s +} + +#[inline(always)] +unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i { + // The subtraction is very unlikely to overflow so we're best off branching. + // The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd` + // branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to + // floating-point (this is free). + let mask_pd = _mm256_castsi256_pd(mask); + // `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow + // did not occur for any of the vector elements. + if _mm256_testz_pd(mask_pd, mask_pd) == 1 { + res_wrapped_s + } else { + branch_hint(); + // Highly unlikely: underflow did occur. Find adjustment per element and apply it. + let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow. + _mm256_sub_epi64(res_wrapped_s, adj_amount) + } +} + +/// Addition, where the second operand is much smaller than `0xffffffff00000001`. +#[inline(always)] +unsafe fn sub_tiny( + x_s: (__m256i, __m256i, __m256i), + y: (__m256i, __m256i, __m256i), +) -> (__m256i, __m256i, __m256i) { + let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y); + let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s); + let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask); + res_s } #[inline(always)] unsafe fn reduce3( - (x_lo_s, x_hi): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)), + (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)), ) -> (__m256i, __m256i, __m256i) { - let epsilon = _mm256_set1_epi64x(0xffffffff); let sign_bit = _mm256_set1_epi64x(i64::MIN); - let x_hi_hi = map3!(_mm256_srli_epi64::<32>, x_hi); - let res0_s = map3!(_mm256_sub_epi64, x_lo_s, x_hi_hi); - let wraparound_mask0 = map3!(_mm256_cmpgt_epi32, res0_s, x_lo_s); - let wraparound_adj0 = map3!(_mm256_srli_epi64::<32>, wraparound_mask0); - let x_hi_lo = map3!(_mm256_and_si256, x_hi, rep epsilon); - let x_hi_lo_shifted = map3!(_mm256_slli_epi64::<32>, x_hi); - let res1_s = map3!(_mm256_sub_epi64, res0_s, wraparound_adj0); - let x_hi_lo_mul_epsilon = map3!(_mm256_sub_epi64, x_hi_lo_shifted, x_hi_lo); - let res2_s = map3!(_mm256_add_epi64, res1_s, x_hi_lo_mul_epsilon); - let wraparound_mask2 = map3!(_mm256_cmpgt_epi32, res1_s, res2_s); - let wraparound_adj2 = map3!(_mm256_srli_epi64::<32>, wraparound_mask2); - let res3_s = map3!(_mm256_add_epi64, res2_s, wraparound_adj2); - let res3 = map3!(_mm256_xor_si256, res3_s, rep sign_bit); - res3 + let epsilon = _mm256_set1_epi64x(0xffffffff); + let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit); + let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0); + let lo1_s = sub_tiny(lo0_s, hi_hi0); + let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon); + let lo2_s = add_small(lo1_s, t1); + let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit); + lo2 } #[inline(always)] @@ -757,9 +842,9 @@ unsafe fn partial_round( // multiplication where we've set the first element to 0.) Add the remaining bits now. // TODO: This is a bit of an afterthought, which is why these constants are loaded 22 // times... There's likely a better way of merging those results. - "vmovdqu ymm6, {mds_matrix}[rip]", - "vmovdqu ymm7, {mds_matrix}[rip + 32]", - "vmovdqu ymm8, {mds_matrix}[rip + 64]", + "vmovdqu ymm6, [{mds_matrix}]", + "vmovdqu ymm7, [{mds_matrix} + 32]", + "vmovdqu ymm8, [{mds_matrix} + 64]", "vpsllvq ymm9, ymm13, ymm6", "vpsllvq ymm10, ymm13, ymm7", "vpsllvq ymm11, ymm13, ymm8", @@ -775,7 +860,7 @@ unsafe fn partial_round( // Reduction required. state0a = in(reg) state0a, - mds_matrix = sym TOP_ROW_EXPS, + mds_matrix = in(reg) &TOP_ROW_EXPS, inout("ymm0") unreduced_lo0_s, inout("ymm1") unreduced_lo1_s, inout("ymm2") unreduced_lo2_s, diff --git a/plonky2/src/hash/gmimc.rs b/plonky2/src/hash/gmimc.rs index 13f7807f..8b137891 100644 --- a/plonky2/src/hash/gmimc.rs +++ b/plonky2/src/hash/gmimc.rs @@ -1,168 +1 @@ -use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::Field; -use plonky2_field::goldilocks_field::GoldilocksField; -use unroll::unroll_for_loops; -use crate::gates::gmimc::GMiMCGate; -use crate::hash::hash_types::{HashOut, RichField}; -use crate::hash::hashing::{compress, hash_n_to_hash, PlonkyPermutation, SPONGE_WIDTH}; -use crate::iop::target::{BoolTarget, Target}; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::config::{AlgebraicHasher, Hasher}; - -pub(crate) const NUM_ROUNDS: usize = 101; - -pub trait GMiMC: Field -where - [u64; NUM_ROUNDS]: Sized, -{ - const ROUND_CONSTANTS: [u64; NUM_ROUNDS]; - - #[unroll_for_loops] - fn gmimc_permute(mut xs: [Self; WIDTH]) -> [Self; WIDTH] { - // Value that is implicitly added to each element. - // See https://affine.group/2020/02/starkware-challenge - let mut addition_buffer = Self::ZERO; - - for (r, &constant) in Self::ROUND_CONSTANTS.iter().enumerate() { - let active = r % WIDTH; - let f = (xs[active] + addition_buffer + Self::from_canonical_u64(constant)).cube(); - addition_buffer += f; - xs[active] -= f; - } - - for i in 0..WIDTH { - xs[i] += addition_buffer; - } - - xs - } - - #[unroll_for_loops] - fn gmimc_permute_naive(mut xs: [Self; WIDTH]) -> [Self; WIDTH] { - for (r, &constant) in Self::ROUND_CONSTANTS.iter().enumerate() { - let active = r % WIDTH; - let f = (xs[active] + Self::from_canonical_u64(constant)).cube(); - for i in 0..WIDTH { - if i != active { - xs[i] += f; - } - } - } - - xs - } -} - -/// See `generate_constants` about how these were generated. -#[rustfmt::skip] -const GOLDILOCKS_ROUND_CONSTANTS: [u64; NUM_ROUNDS] = [ - 0xb585f767417ee042, 0x7746a55f77c10331, 0xb2fb0d321d356f7a, 0x0f6760a486f1621f, - 0xe10d6666b36abcdf, 0x8cae14cb455cc50b, 0xd438539cf2cee334, 0xef781c7d4c1fd8b4, - 0xcdc4a23a0aca4b1f, 0x277fa208d07b52e3, 0xe17653a300493d38, 0xc54302f27c287dc1, - 0x8628782231d47d10, 0x59cd1a8a690b49f2, 0xc3b919ad9efec0b0, 0xa484c4c637641d97, - 0x308bbd23f191398b, 0x6e4a40c1bf713cf1, 0x9a2eedb7510414fb, 0xe360c6e111c2c63b, - 0xd5c771901d4d89aa, 0xc35eae076e7d6b2f, 0x849c2656d0a09cad, 0xc0572c8c5cf1df2b, - 0xe9fa634a883b8bf3, 0xf56f6d4900fb1fdd, 0xf7d713e872a72a1b, 0x8297132b6ba47612, - 0xad6805e12ee8af1c, 0xac51d9f6485c22b9, 0x502ad7dc3bd56bf8, 0x57a1550c3761c577, - 0x66bbd30e99d311da, 0x0da2abef5e948f87, 0xf0612750443f8e94, 0x28b8ec3afb937d8c, - 0x92a756e6be54ca18, 0x70e741ec304e925d, 0x019d5ee2b037c59f, 0x6f6f2ed7a30707d1, - 0x7cf416d01e8c169c, 0x61df517bb17617df, 0x85dc499b4c67dbaa, 0x4b959b48dad27b23, - 0xe8be3e5e0dd779a0, 0xf5c0bc1e525ed8e6, 0x40b12cbf263cf853, 0xa637093f13e2ea3c, - 0x3cc3f89232e3b0c8, 0x2e479dc16bfe86c0, 0x6f49de07d6d39469, 0x213ce7beecc232de, - 0x5b043134851fc00a, 0xa2de45784a861506, 0x7103aaf97bed8dd5, 0x5326fc0dbb88a147, - 0xa9ceb750364cb77a, 0x27f8ec88cc9e991f, 0xfceb4fda8c93fb83, 0xfac6ff13b45b260e, - 0x7131aa455813380b, 0x93510360d5d68119, 0xad535b24fb96e3db, 0x4627f5c6b7efc045, - 0x645cf794e4da78a9, 0x241c70ed1ac2877f, 0xacb8e076b009e825, 0x3737e9db6477bd9d, - 0xe7ea5e344cd688ed, 0x90dee4a009214640, 0xd1b1edf7c77e74af, 0x0b65481bab42158e, - 0x99ad1aab4b4fe3e7, 0x438a7c91f1a360cd, 0xb60de3bd159088bf, 0xc99cab6b47a3e3bb, - 0x69a5ed92d5677cef, 0x5e7b329c482a9396, 0x5fc0ac0829f893c9, 0x32db82924fb757ea, - 0x0ade699c5cf24145, 0x7cc5583b46d7b5bb, 0x85df9ed31bf8abcb, 0x6604df501ad4de64, - 0xeb84f60941611aec, 0xda60883523989bd4, 0x8f97fe40bf3470bf, 0xa93f485ce0ff2b32, - 0x6704e8eebc2afb4b, 0xcee3e9ac788ad755, 0x510d0e66062a270d, 0xf6323f48d74634a0, - 0x0b508cdf04990c90, 0xf241708a4ef7ddf9, 0x60e75c28bb368f82, 0xa6217d8c3f0f9989, - 0x7159cd30f5435b53, 0x839b4e8fe97ec79f, 0x0d3f3e5e885db625, 0x8f7d83be1daea54b, - 0x780f22441e8dbc04, -]; - -impl GMiMC<8> for GoldilocksField { - const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = GOLDILOCKS_ROUND_CONSTANTS; -} - -impl GMiMC<12> for GoldilocksField { - const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = GOLDILOCKS_ROUND_CONSTANTS; -} - -pub struct GMiMCPermutation; -impl PlonkyPermutation for GMiMCPermutation { - fn permute(input: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] { - F::gmimc_permute(input) - } -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct GMiMCHash; -impl Hasher for GMiMCHash { - const HASH_SIZE: usize = 4 * 8; - type Hash = HashOut; - type Permutation = GMiMCPermutation; - - fn hash(input: Vec, pad: bool) -> Self::Hash { - hash_n_to_hash::(input, pad) - } - - fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { - compress::(left, right) - } -} - -impl AlgebraicHasher for GMiMCHash { - fn permute_swapped( - inputs: [Target; SPONGE_WIDTH], - swap: BoolTarget, - builder: &mut CircuitBuilder, - ) -> [Target; SPONGE_WIDTH] - where - F: RichField + Extendable, - { - let gate_type = GMiMCGate::::new(); - let gate = builder.add_gate(gate_type, vec![], vec![]); - - let swap_wire = GMiMCGate::::WIRE_SWAP; - let swap_wire = Target::wire(gate, swap_wire); - builder.connect(swap.target, swap_wire); - - // Route input wires. - for i in 0..SPONGE_WIDTH { - let in_wire = GMiMCGate::::wire_input(i); - let in_wire = Target::wire(gate, in_wire); - builder.connect(inputs[i], in_wire); - } - - // Collect output wires. - (0..SPONGE_WIDTH) - .map(|i| Target::wire(gate, GMiMCGate::::wire_output(i))) - .collect::>() - .try_into() - .unwrap() - } -} - -#[cfg(test)] -mod tests { - use plonky2_field::goldilocks_field::GoldilocksField; - - use crate::hash::gmimc::GMiMC; - - fn check_consistency, const WIDTH: usize>() { - let xs = F::rand_arr::(); - let out = F::gmimc_permute(xs); - let out_naive = F::gmimc_permute_naive(xs); - assert_eq!(out, out_naive); - } - - #[test] - fn consistency() { - check_consistency::(); - } -} diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index 1cfa7de8..01062960 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -1,15 +1,14 @@ -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::{Field, PrimeField64}; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::hash::gmimc::GMiMC; use crate::hash::poseidon::Poseidon; use crate::iop::target::Target; use crate::plonk::config::GenericHashOut; /// A prime order field with the features we need to use it as a base field in our argument system. -pub trait RichField: PrimeField + GMiMC<12> + Poseidon {} +pub trait RichField: PrimeField64 + Poseidon {} impl RichField for GoldilocksField {} @@ -32,14 +31,10 @@ impl HashOut { } } - pub fn from_partial(mut elements: Vec) -> Self { - debug_assert!(elements.len() <= 4); - while elements.len() < 4 { - elements.push(F::ZERO); - } - Self { - elements: [elements[0], elements[1], elements[2], elements[3]], - } + pub fn from_partial(elements_in: &[F]) -> Self { + let mut elements = [F::ZERO; 4]; + elements[0..elements_in.len()].copy_from_slice(elements_in); + Self { elements } } pub fn rand_from_rng(rng: &mut R) -> Self { @@ -94,25 +89,21 @@ impl Default for HashOut { /// Represents a ~256 bit hash output. #[derive(Copy, Clone, Debug)] pub struct HashOutTarget { - pub(crate) elements: [Target; 4], + pub elements: [Target; 4], } impl HashOutTarget { - pub(crate) fn from_vec(elements: Vec) -> Self { + pub fn from_vec(elements: Vec) -> Self { debug_assert!(elements.len() == 4); Self { elements: elements.try_into().unwrap(), } } - pub(crate) fn from_partial(mut elements: Vec, zero: Target) -> Self { - debug_assert!(elements.len() <= 4); - while elements.len() < 4 { - elements.push(zero); - } - Self { - elements: [elements[0], elements[1], elements[2], elements[3]], - } + pub fn from_partial(elements_in: &[Target], zero: Target) -> Self { + let mut elements = [zero; 4]; + elements[0..elements_in.len()].copy_from_slice(elements_in); + Self { elements } } } @@ -123,6 +114,18 @@ pub struct MerkleCapTarget(pub Vec); #[derive(Eq, PartialEq, Copy, Clone, Debug)] pub struct BytesHash(pub [u8; N]); +impl BytesHash { + pub fn rand_from_rng(rng: &mut R) -> Self { + let mut buf = [0; N]; + rng.fill_bytes(&mut buf); + Self(buf) + } + + pub fn rand() -> Self { + Self::rand_from_rng(&mut rand::thread_rng()) + } +} + impl GenericHashOut for BytesHash { fn to_bytes(&self) -> Vec { self.0.to_vec() diff --git a/plonky2/src/hash/hashing.rs b/plonky2/src/hash/hashing.rs index 2f6a725c..9d043ea3 100644 --- a/plonky2/src/hash/hashing.rs +++ b/plonky2/src/hash/hashing.rs @@ -12,50 +12,29 @@ pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; -/// Hash the vector if necessary to reduce its length to ~256 bits. If it already fits, this is a -/// no-op. -pub fn hash_or_noop>(inputs: Vec) -> HashOut { - if inputs.len() <= 4 { - HashOut::from_partial(inputs) - } else { - hash_n_to_hash::(inputs, false) - } -} - impl, const D: usize> CircuitBuilder { pub fn hash_or_noop>(&mut self, inputs: Vec) -> HashOutTarget { let zero = self.zero(); if inputs.len() <= 4 { - HashOutTarget::from_partial(inputs, zero) + HashOutTarget::from_partial(&inputs, zero) } else { - self.hash_n_to_hash::(inputs, false) + self.hash_n_to_hash_no_pad::(inputs) } } - pub fn hash_n_to_hash>( + pub fn hash_n_to_hash_no_pad>( &mut self, inputs: Vec, - pad: bool, ) -> HashOutTarget { - HashOutTarget::from_vec(self.hash_n_to_m::(inputs, 4, pad)) + HashOutTarget::from_vec(self.hash_n_to_m_no_pad::(inputs, 4)) } - pub fn hash_n_to_m>( + pub fn hash_n_to_m_no_pad>( &mut self, - mut inputs: Vec, + inputs: Vec, num_outputs: usize, - pad: bool, ) -> Vec { let zero = self.zero(); - let one = self.one(); - - if pad { - inputs.push(zero); - while (inputs.len() + 1) % SPONGE_WIDTH != 0 { - inputs.push(one); - } - inputs.push(zero); - } let mut state = [zero; SPONGE_WIDTH]; @@ -69,7 +48,7 @@ impl, const D: usize> CircuitBuilder { } // Squeeze until we have the desired number of outputs. - let mut outputs = Vec::new(); + let mut outputs = Vec::with_capacity(num_outputs); loop { for i in 0..SPONGE_RATE { outputs.push(state[i]); @@ -97,22 +76,12 @@ pub trait PlonkyPermutation { fn permute(input: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH]; } -/// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required -/// for the hash to be secure, but it can safely be disabled in certain cases, like if the input -/// length is fixed. -pub fn hash_n_to_m>( - mut inputs: Vec, +/// Hash a message without any padding step. Note that this can enable length-extension attacks. +/// However, it is still collision-resistant in cases where the input has a fixed length. +pub fn hash_n_to_m_no_pad>( + inputs: &[F], num_outputs: usize, - pad: bool, ) -> Vec { - if pad { - inputs.push(F::ZERO); - while (inputs.len() + 1) % SPONGE_WIDTH != 0 { - inputs.push(F::ONE); - } - inputs.push(F::ZERO); - } - let mut state = [F::ZERO; SPONGE_WIDTH]; // Absorb all input chunks. @@ -134,9 +103,6 @@ pub fn hash_n_to_m>( } } -pub fn hash_n_to_hash>( - inputs: Vec, - pad: bool, -) -> HashOut { - HashOut::from_vec(hash_n_to_m::(inputs, 4, pad)) +pub fn hash_n_to_hash_no_pad>(inputs: &[F]) -> HashOut { + HashOut::from_vec(hash_n_to_m_no_pad::(inputs, 4)) } diff --git a/plonky2/src/hash/keccak.rs b/plonky2/src/hash/keccak.rs index 78cf5dc3..9a061d82 100644 --- a/plonky2/src/hash/keccak.rs +++ b/plonky2/src/hash/keccak.rs @@ -56,9 +56,9 @@ impl Hasher for KeccakHash { type Hash = BytesHash; type Permutation = KeccakPermutation; - fn hash(input: Vec, _pad: bool) -> Self::Hash { + fn hash_no_pad(input: &[F]) -> Self::Hash { let mut buffer = Buffer::new(Vec::new()); - buffer.write_field_vec(&input).unwrap(); + buffer.write_field_vec(input).unwrap(); let mut arr = [0; N]; let hash_bytes = keccak(buffer.bytes()).0; arr.copy_from_slice(&hash_bytes[..N]); diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index 543c06fd..c3ebf406 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -17,7 +17,7 @@ pub struct MerkleProof> { pub siblings: Vec, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct MerkleProofTarget { /// The Merkle digest of each sibling subtree, staying from the bottommost layer. pub siblings: Vec, @@ -30,9 +30,12 @@ pub(crate) fn verify_merkle_proof>( leaf_index: usize, merkle_cap: &MerkleCap, proof: &MerkleProof, -) -> Result<()> { +) -> Result<()> +where + [(); H::HASH_SIZE]:, +{ let mut index = leaf_index; - let mut current_digest = H::hash(leaf_data, false); + let mut current_digest = H::hash_or_noop(&leaf_data); for &sibling_digest in proof.siblings.iter() { let bit = index & 1; index >>= 1; diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 88c1ebdc..5fbc441c 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -1,3 +1,7 @@ +use std::mem::MaybeUninit; +use std::slice; + +use plonky2_util::log2_strict; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -27,33 +31,131 @@ pub struct MerkleTree> { /// The data in the leaves of the Merkle tree. pub leaves: Vec>, - /// The layers of hashes in the tree. The first layer is the one at the bottom. - pub layers: Vec>, + /// The digests in the tree. Consists of `cap.len()` sub-trees, each corresponding to one + /// element in `cap`. Each subtree is contiguous and located at + /// `digests[digests.len() / cap.len() * i..digests.len() / cap.len() * (i + 1)]`. + /// Within each subtree, siblings are stored next to each other. The layout is, + /// left_child_subtree || left_child_digest || right_child_digest || right_child_subtree, where + /// left_child_digest and right_child_digest are H::Hash and left_child_subtree and + /// right_child_subtree recurse. Observe that the digest of a node is stored by its _parent_. + /// Consequently, the digests of the roots are not stored here (they can be found in `cap`). + pub digests: Vec, /// The Merkle cap. pub cap: MerkleCap, } +fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { + assert!(v.capacity() >= len); + let v_ptr = v.as_mut_ptr().cast::>(); + unsafe { + // SAFETY: `v_ptr` is a valid pointer to a buffer of length at least `len`. Upon return, the + // lifetime will be bound to that of `v`. The underlying memory will not be deallocated as + // we hold the sole mutable reference to `v`. The contents of the slice may be + // uninitialized, but the `MaybeUninit` makes it safe. + slice::from_raw_parts_mut(v_ptr, len) + } +} + +fn fill_subtree>( + digests_buf: &mut [MaybeUninit], + leaves: &[Vec], +) -> H::Hash +where + [(); H::HASH_SIZE]:, +{ + assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); + if digests_buf.is_empty() { + H::hash_or_noop(&leaves[0]) + } else { + // Layout is: left recursive output || left child digest + // || right child digest || right recursive output. + // Split `digests_buf` into the two recursive outputs (slices) and two child digests + // (references). + let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2); + let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap(); + let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); + // Split `leaves` between both children. + let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); + let (left_digest, right_digest) = rayon::join( + || fill_subtree::(left_digests_buf, left_leaves), + || fill_subtree::(right_digests_buf, right_leaves), + ); + left_digest_mem.write(left_digest); + right_digest_mem.write(right_digest); + H::two_to_one(left_digest, right_digest) + } +} + +fn fill_digests_buf>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &[Vec], + cap_height: usize, +) where + [(); H::HASH_SIZE]:, +{ + // Special case of a tree that's all cap. The usual case will panic because we'll try to split + // an empty slice into chunks of `0`. (We would not need this if there was a way to split into + // `blah` chunks as opposed to chunks _of_ `blah`.) + if digests_buf.is_empty() { + debug_assert_eq!(cap_buf.len(), leaves.len()); + cap_buf + .par_iter_mut() + .zip(leaves) + .for_each(|(cap_buf, leaf)| { + cap_buf.write(H::hash_or_noop(leaf)); + }); + return; + } + + let subtree_digests_len = digests_buf.len() >> cap_height; + let subtree_leaves_len = leaves.len() >> cap_height; + let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); + let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len); + assert_eq!(digests_chunks.len(), cap_buf.len()); + assert_eq!(digests_chunks.len(), leaves_chunks.len()); + digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( + |((subtree_digests, subtree_cap), subtree_leaves)| { + // We have `1 << cap_height` sub-trees, one for each entry in `cap`. They are totally + // independent, so we schedule one task for each. `digests_buf` and `leaves` are split + // into `1 << cap_height` slices, one for each sub-tree. + subtree_cap.write(fill_subtree::(subtree_digests, subtree_leaves)); + }, + ); +} + impl> MerkleTree { - pub fn new(leaves: Vec>, cap_height: usize) -> Self { - let mut layers = vec![leaves - .par_iter() - .map(|l| H::hash(l.clone(), false)) - .collect::>()]; - while let Some(l) = layers.last() { - if l.len() == 1 << cap_height { - break; - } - let next_layer = l - .par_chunks(2) - .map(|chunk| H::two_to_one(chunk[0], chunk[1])) - .collect::>(); - layers.push(next_layer); + pub fn new(leaves: Vec>, cap_height: usize) -> Self + where + [(); H::HASH_SIZE]:, + { + let log2_leaves_len = log2_strict(leaves.len()); + assert!( + cap_height <= log2_leaves_len, + "cap height should be at most log2(leaves.len())" + ); + + let num_digests = 2 * (leaves.len() - (1 << cap_height)); + let mut digests = Vec::with_capacity(num_digests); + + let len_cap = 1 << cap_height; + let mut cap = Vec::with_capacity(len_cap); + + let digests_buf = capacity_up_to_mut(&mut digests, num_digests); + let cap_buf = capacity_up_to_mut(&mut cap, len_cap); + fill_digests_buf::(digests_buf, cap_buf, &leaves[..], cap_height); + + unsafe { + // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to + // `num_digests` and `len_cap`, resp. + digests.set_len(num_digests); + cap.set_len(len_cap); } - let cap = layers.pop().unwrap(); + Self { leaves, - layers, + digests, cap: MerkleCap(cap), } } @@ -64,17 +166,40 @@ impl> MerkleTree { /// Create a Merkle proof from a leaf index. pub fn prove(&self, leaf_index: usize) -> MerkleProof { - MerkleProof { - siblings: self - .layers - .iter() - .scan(leaf_index, |acc, layer| { - let index = *acc ^ 1; - *acc >>= 1; - Some(layer[index]) - }) - .collect(), - } + let cap_height = log2_strict(self.cap.len()); + let num_layers = log2_strict(self.leaves.len()) - cap_height; + debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0); + + let digest_tree = { + let tree_index = leaf_index >> num_layers; + let tree_len = self.digests.len() >> cap_height; + &self.digests[tree_len * tree_index..tree_len * (tree_index + 1)] + }; + + // Mask out high bits to get the index within the sub-tree. + let mut pair_index = leaf_index & ((1 << num_layers) - 1); + let siblings = (0..num_layers) + .into_iter() + .map(|i| { + let parity = pair_index & 1; + pair_index >>= 1; + + // The layers' data is interleaved as follows: + // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...]. + // Each of the above is a pair of siblings. + // `pair_index` is the index of the pair within layer `i`. + // The index of that the pair within `digests` is + // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`. + let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1; + // We have an index for the _pair_, but we want the index of the _sibling_. + // Double the pair index to get the index of the left sibling. Conditionally add `1` + // if we are to retrieve the right sibling. + let sibling_index = 2 * siblings_index + (1 - parity); + digest_tree[sibling_index] + }) + .collect(); + + MerkleProof { siblings } } } @@ -91,22 +216,50 @@ mod tests { (0..n).map(|_| F::rand_vec(k)).collect() } - fn verify_all_leaves< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + fn verify_all_leaves, C: GenericConfig, const D: usize>( leaves: Vec>, - n: usize, - ) -> Result<()> { - let tree = MerkleTree::::new(leaves.clone(), 1); - for i in 0..n { + cap_height: usize, + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { + let tree = MerkleTree::::new(leaves.clone(), cap_height); + for (i, leaf) in leaves.into_iter().enumerate() { let proof = tree.prove(i); - verify_merkle_proof(leaves[i].clone(), i, &tree.cap, &proof)?; + verify_merkle_proof(leaf, i, &tree.cap, &proof)?; } Ok(()) } + #[test] + #[should_panic] + fn test_cap_height_too_big() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let cap_height = log_n + 1; // Should panic if `cap_height > len_n`. + + let leaves = random_data::(1 << log_n, 7); + let _ = MerkleTree::>::Hasher>::new(leaves, cap_height); + } + + #[test] + fn test_cap_height_eq_log2_len() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let n = 1 << log_n; + let leaves = random_data::(n, 7); + + verify_all_leaves::(leaves, log_n)?; + + Ok(()) + } + #[test] fn test_merkle_trees() -> Result<()> { const D: usize = 2; @@ -117,7 +270,7 @@ mod tests { let n = 1 << log_n; let leaves = random_data::(n, 7); - verify_all_leaves::(leaves, n)?; + verify_all_leaves::(leaves, 1)?; Ok(()) } diff --git a/plonky2/src/hash/mod.rs b/plonky2/src/hash/mod.rs index 5a8ccb3f..b8293920 100644 --- a/plonky2/src/hash/mod.rs +++ b/plonky2/src/hash/mod.rs @@ -1,5 +1,4 @@ mod arch; -pub mod gmimc; pub mod hash_types; pub mod hashing; pub mod keccak; @@ -8,4 +7,3 @@ pub mod merkle_tree; pub mod path_compression; pub mod poseidon; pub mod poseidon_goldilocks; -pub mod rescue; diff --git a/plonky2/src/hash/path_compression.rs b/plonky2/src/hash/path_compression.rs index 75c63331..6dae3d94 100644 --- a/plonky2/src/hash/path_compression.rs +++ b/plonky2/src/hash/path_compression.rs @@ -57,7 +57,10 @@ pub(crate) fn decompress_merkle_proofs>( compressed_proofs: &[MerkleProof], height: usize, cap_height: usize, -) -> Vec> { +) -> Vec> +where + [(); H::HASH_SIZE]:, +{ let num_leaves = 1 << height; let compressed_proofs = compressed_proofs.to_vec(); let mut decompressed_proofs = Vec::with_capacity(compressed_proofs.len()); @@ -66,7 +69,7 @@ pub(crate) fn decompress_merkle_proofs>( for (&i, v) in leaves_indices.iter().zip(leaves_data) { // Observe the leaves. - seen.insert(i + num_leaves, H::hash(v.to_vec(), false)); + seen.insert(i + num_leaves, H::hash_or_noop(v)); } // Iterators over the siblings. diff --git a/plonky2/src/hash/poseidon.rs b/plonky2/src/hash/poseidon.rs index 8e8d6ba5..83946b29 100644 --- a/plonky2/src/hash/poseidon.rs +++ b/plonky2/src/hash/poseidon.rs @@ -2,14 +2,14 @@ //! https://eprint.iacr.org/2019/458.pdf use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::{Field, PrimeField64}; use unroll::unroll_for_loops; use crate::gates::gate::Gate; use crate::gates::poseidon::PoseidonGate; use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::hash::hash_types::{HashOut, RichField}; -use crate::hash::hashing::{compress, hash_n_to_hash, PlonkyPermutation, SPONGE_WIDTH}; +use crate::hash::hashing::{compress, hash_n_to_hash_no_pad, PlonkyPermutation, SPONGE_WIDTH}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -21,10 +21,10 @@ use crate::plonk::config::{AlgebraicHasher, Hasher}; // // NB: Changing any of these values will require regenerating all of // the precomputed constant arrays in this file. -pub(crate) const HALF_N_FULL_ROUNDS: usize = 4; +pub const HALF_N_FULL_ROUNDS: usize = 4; pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; -pub(crate) const N_PARTIAL_ROUNDS: usize = 22; -pub(crate) const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; +pub const N_PARTIAL_ROUNDS: usize = 22; +pub const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) #[inline(always)] @@ -35,7 +35,7 @@ fn add_u160_u128((x_lo, x_hi): (u128, u32), y: u128) -> (u128, u32) { } #[inline(always)] -fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { +fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { let n_lo_hi = (n_lo >> 64) as u64; let n_lo_lo = n_lo as u64; let reduced_hi: u64 = F::from_noncanonical_u96((n_lo_hi, n_hi)).to_noncanonical_u64(); @@ -148,7 +148,7 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ ]; const WIDTH: usize = SPONGE_WIDTH; -pub trait Poseidon: PrimeField { +pub trait Poseidon: PrimeField64 { // Total number of round constants required: width of the input // times number of rounds. const N_ROUND_CONSTANTS: usize = WIDTH * N_ROUNDS; @@ -633,8 +633,8 @@ impl Hasher for PoseidonHash { type Hash = HashOut; type Permutation = PoseidonPermutation; - fn hash(input: Vec, pad: bool) -> Self::Hash { - hash_n_to_hash::(input, pad) + fn hash_no_pad(input: &[F]) -> Self::Hash { + hash_n_to_hash_no_pad::(input) } fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index b8f63ab4..7b82bb01 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -270,7 +270,8 @@ impl Poseidon for GoldilocksField { #[cfg(test)] mod tests { - use plonky2_field::field_types::{Field, PrimeField}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField as F; use crate::hash::poseidon::test_helpers::{check_consistency, check_test_vectors}; diff --git a/plonky2/src/hash/rescue.rs b/plonky2/src/hash/rescue.rs deleted file mode 100644 index 59e9d265..00000000 --- a/plonky2/src/hash/rescue.rs +++ /dev/null @@ -1,457 +0,0 @@ -//! Implements Rescue Prime. - -use plonky2_field::field_types::Field; -use unroll::unroll_for_loops; - -const ROUNDS: usize = 8; - -const W: usize = 12; - -const MDS: [[u64; W]; W] = [ - [ - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - 3074457345215605419, - 11068046442776179508, - 13835058053470224385, - 6148914690431210838, - 9223372035646816257, - 1, - ], - [ - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - 3074457345215605419, - 11068046442776179508, - 13835058053470224385, - 6148914690431210838, - 9223372035646816257, - ], - [ - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - 3074457345215605419, - 11068046442776179508, - 13835058053470224385, - 6148914690431210838, - ], - [ - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - 3074457345215605419, - 11068046442776179508, - 13835058053470224385, - ], - [ - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - 3074457345215605419, - 11068046442776179508, - ], - [ - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - 3074457345215605419, - ], - [ - 1024819115071868473, - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - 2635249153041947502, - ], - [ - 9708812669101911849, - 1024819115071868473, - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - 16140901062381928449, - ], - [ - 2767011610694044877, - 9708812669101911849, - 1024819115071868473, - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - 2049638230143736946, - ], - [ - 878416384347315834, - 2767011610694044877, - 9708812669101911849, - 1024819115071868473, - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - 5534023221388089754, - ], - [ - 17608255704416649217, - 878416384347315834, - 2767011610694044877, - 9708812669101911849, - 1024819115071868473, - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - 16769767337539665921, - ], - [ - 15238614667590392076, - 17608255704416649217, - 878416384347315834, - 2767011610694044877, - 9708812669101911849, - 1024819115071868473, - 3255307777287111620, - 17293822566837780481, - 15987178195121148178, - 1317624576520973751, - 5675921252705733081, - 10760600708254618966, - ], -]; - -const RESCUE_CONSTANTS: [[u64; W]; ROUNDS * 2] = [ - [ - 12050887499329086906, - 1748247961703512657, - 315780861775001585, - 2827656358919812970, - 13335864861236723579, - 3010729529365640897, - 8463534053828271146, - 2528500966106598845, - 8969871077123422281, - 1002624930202741107, - 599979829006456404, - 4386170815218774254, - ], - [ - 5771413917591851532, - 11946802620311685142, - 4759792267858670262, - 6879094914431255667, - 3985911073214909073, - 1542850118294175816, - 5393560436452023029, - 8331250756632997735, - 3395511836281190608, - 17601255793194446503, - 12848459944475727152, - 11995465655754698601, - ], - [ - 14063960046551560130, - 14790209580166185143, - 5509023472758717841, - 1274395897760495573, - 16719545989415697758, - 17865948122414223407, - 3919263713959798649, - 5633741078654387163, - 15665612362287352054, - 3418834727998553015, - 5324019631954832682, - 17962066557010997431, - ], - [ - 3282193104189649752, - 18423507935939999211, - 9035104445528866459, - 30842260240043277, - 3896337933354935129, - 6615548113269323045, - 6625827707190475694, - 6677757329269550670, - 11419013193186889337, - 17111888851716383760, - 12075517898615128691, - 8139844272075088233, - ], - [ - 8872892112814161072, - 17529364346566228604, - 7526576514327158912, - 850359069964902700, - 9679332912197531902, - 10591229741059812071, - 12759208863825924546, - 14552519355635838750, - 16066249893409806278, - 11283035366525176262, - 1047378652379935387, - 17032498397644511356, - ], - [ - 2938626421478254042, - 10375267398354586672, - 13728514869380643947, - 16707318479225743731, - 9785828188762698567, - 8610686976269299752, - 5478372191917042178, - 12716344455538470365, - 9968276048553747246, - 14746805727771473956, - 4822070620124107028, - 9901161649549513416, - ], - [ - 13458162407040644078, - 4045792126424269312, - 9709263167782315020, - 2163173014916005515, - 17079206331095671215, - 2556388076102629669, - 6582772486087242347, - 1239959540200663058, - 18268236910639895687, - 12499012548657350745, - 17213068585339946119, - 7641451088868756688, - ], - [ - 14674555473338434116, - 14624532976317185113, - 13625541984298615970, - 7612892294159054770, - 12294028208969561574, - 6067206081581804358, - 5778082506883496792, - 7389487446513884800, - 12929525660730020877, - 18244350162788654296, - 15285920877034454694, - 3640669683987215349, - ], - [ - 6737585134029996281, - 1826890539455248546, - 289376081355380231, - 10782622161517803787, - 12978425540147835172, - 9828233103297278473, - 16384075371934678711, - 3187492301890791304, - 12985433735185968457, - 9470935291631377473, - 16328323199113140151, - 16218490552434224203, - ], - [ - 6188809977565251499, - 18437718710937437067, - 4530469469895539008, - 9596355277372723349, - 13602518824447658705, - 8759976068576854281, - 10504320064094929535, - 3980760429843656150, - 14609448298151012462, - 5839843841558860609, - 10283805260656050418, - 7239168159249274821, - ], - [ - 3604243611640027441, - 5237321927316578323, - 5071861664926666316, - 13025405632646149705, - 3285281651566464074, - 12121596060272825779, - 1900602777802961569, - 8122527981264852045, - 6731303887159752901, - 9197659817406857040, - 844741616904786364, - 14249777686667858094, - ], - [ - 8602844218963499297, - 10133401373828451640, - 11618292280328565166, - 8828272598402499582, - 4252246265076774689, - 9760449011955070998, - 10233981507028897480, - 10427510555228840014, - 1007817664531124790, - 4465396600980659145, - 7727267420665314215, - 7904022788946844554, - ], - [ - 11418297156527169222, - 15865399053509010196, - 1727198235391450850, - 16557095577717348672, - 1524052121709169653, - 14531367160053894310, - 4071756280138432327, - 10333204220115446291, - 16584144375833061215, - 12237566480526488368, - 11090440024401607208, - 18281335018830792766, - ], - [ - 16152169547074248135, - 18338155611216027761, - 15842640128213925612, - 14687926435880145351, - 13259626900273707210, - 6187877366876303234, - 10312881470701795438, - 1924945292721719446, - 2278209355262975917, - 3250749056007953206, - 11589006946114672195, - 241829012299953928, - ], - [ - 11244459446597052449, - 7319043416418482137, - 8148526814449636806, - 9054933038587901070, - 550333919248348827, - 5513167392062632770, - 12644459803778263764, - 9903621375535446226, - 16390581784506871871, - 14586524717888286021, - 6975796306584548762, - 5200407948555191573, - ], - [ - 2855794043288846965, - 1259443213892506318, - 6145351706926586935, - 3853784494234324998, - 5871277378086513850, - 9414363368707862566, - 11946957446931890832, - 308083693687568600, - 12712587722369770461, - 6792392698104204991, - 16465224002344550280, - 10282380383506806095, - ], -]; - -pub fn rescue(mut xs: [F; W]) -> [F; W] { - for r in 0..8 { - xs = sbox_layer_a(xs); - xs = mds_layer(xs); - xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2]); - - xs = sbox_layer_b(xs); - xs = mds_layer(xs); - xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2 + 1]); - } - xs -} - -#[unroll_for_loops] -fn sbox_layer_a(x: [F; W]) -> [F; W] { - let mut result = [F::ZERO; W]; - for i in 0..W { - result[i] = x[i].cube(); - } - result -} - -#[unroll_for_loops] -fn sbox_layer_b(x: [F; W]) -> [F; W] { - let mut result = [F::ZERO; W]; - for i in 0..W { - result[i] = x[i].cube_root(); - } - result -} - -#[unroll_for_loops] -fn mds_layer(x: [F; W]) -> [F; W] { - let mut result = [F::ZERO; W]; - for r in 0..W { - for c in 0..W { - result[r] += F::from_canonical_u64(MDS[r][c]) * x[c]; - } - } - result -} - -#[unroll_for_loops] -fn constant_layer(xs: [F; W], con: &[u64; W]) -> [F; W] { - let mut result = [F::ZERO; W]; - for i in 0..W { - result[i] = xs[i] + F::from_canonical_u64(con[i]); - } - result -} diff --git a/plonky2/src/iop/challenger.rs b/plonky2/src/iop/challenger.rs index b8ca4fb7..1519f6ec 100644 --- a/plonky2/src/iop/challenger.rs +++ b/plonky2/src/iop/challenger.rs @@ -11,7 +11,6 @@ use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{AlgebraicHasher, GenericHashOut, Hasher}; -use crate::plonk::proof::{OpeningSet, OpeningSetTarget}; /// Observes prover messages, and generates challenges by hashing the transcript, a la Fiat-Shamir. #[derive(Clone)] @@ -69,32 +68,6 @@ impl> Challenger { } } - pub fn observe_opening_set(&mut self, os: &OpeningSet) - where - F: RichField + Extendable, - { - let OpeningSet { - constants, - plonk_sigmas, - wires, - plonk_zs, - plonk_zs_right, - partial_products, - quotient_polys, - } = os; - for v in &[ - constants, - plonk_sigmas, - wires, - plonk_zs, - plonk_zs_right, - partial_products, - quotient_polys, - ] { - self.observe_extension_elements(v); - } - } - pub fn observe_hash>(&mut self, hash: OH::Hash) { self.observe_elements(&hash.to_vec()) } @@ -215,29 +188,6 @@ impl, H: AlgebraicHasher, const D: usize> } } - pub fn observe_opening_set(&mut self, os: &OpeningSetTarget) { - let OpeningSetTarget { - constants, - plonk_sigmas, - wires, - plonk_zs, - plonk_zs_right, - partial_products, - quotient_polys, - } = os; - for v in &[ - constants, - plonk_sigmas, - wires, - plonk_zs, - plonk_zs_right, - partial_products, - quotient_polys, - ] { - self.observe_extension_elements(v); - } - } - pub fn observe_hash(&mut self, hash: &HashOutTarget) { self.observe_elements(&hash.elements) } diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 1c7779c6..0105e031 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -3,14 +3,14 @@ use std::marker::PhantomData; use num::BigUint; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; @@ -166,12 +166,17 @@ impl GeneratedValues { self.target_values.push((target, value)) } - fn set_u32_target(&mut self, target: U32Target, value: u32) { + pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) { + self.set_target(target.target, F::from_bool(value)) + } + + pub fn set_u32_target(&mut self, target: U32Target, value: u32) { self.set_target(target.0, F::from_canonical_u32(value)) } pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); limbs.resize(target.num_limbs(), 0); @@ -180,8 +185,8 @@ impl GeneratedValues { } } - pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { - self.set_biguint_target(target.value, value.to_biguint()) + pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { + self.set_biguint_target(target.value, value.to_canonical_biguint()) } pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { diff --git a/plonky2/src/iop/mod.rs b/plonky2/src/iop/mod.rs index cc11fb56..de315a09 100644 --- a/plonky2/src/iop/mod.rs +++ b/plonky2/src/iop/mod.rs @@ -1,5 +1,5 @@ //! Logic common to multiple IOPs. -pub(crate) mod challenger; +pub mod challenger; pub mod ext_target; pub mod generator; pub mod target; diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index 29ad513e..e1bdf06e 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -1,9 +1,11 @@ use std::collections::HashMap; +use itertools::Itertools; use num::{BigUint, FromPrimitive, Zero}; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; +use crate::fri::witness_util::set_fri_proof_target; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; @@ -14,7 +16,8 @@ use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; -use crate::plonk::config::AlgebraicHasher; +use crate::plonk::config::{AlgebraicHasher, GenericConfig}; +use crate::plonk::proof::{Proof, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget}; /// A witness holds information on the values of targets in a circuit. pub trait Witness { @@ -59,20 +62,26 @@ pub trait Witness { panic!("not a bool") } - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint + where + F: PrimeField, + { let mut result = BigUint::zero(); let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); for i in (0..target.num_limbs()).rev() { let limb = target.get_limb(i); result *= &limb_base; - result += self.get_target(limb.0).to_biguint(); + result += self.get_target(limb.0).to_canonical_biguint(); } result } - fn get_nonnative_target(&self, target: NonNativeTarget) -> FF { + fn get_nonnative_target(&self, target: NonNativeTarget) -> FF + where + F: PrimeField, + { let val = self.get_biguint_target(target.value); FF::from_biguint(val) } @@ -155,6 +164,109 @@ pub trait Witness { } } + /// Set the targets in a `ProofWithPublicInputsTarget` to their corresponding values in a + /// `ProofWithPublicInputs`. + fn set_proof_with_pis_target, const D: usize>( + &mut self, + proof_with_pis_target: &ProofWithPublicInputsTarget, + proof_with_pis: &ProofWithPublicInputs, + ) where + F: RichField + Extendable, + C::Hasher: AlgebraicHasher, + { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + let ProofWithPublicInputsTarget { + proof: pt, + public_inputs: pi_targets, + } = proof_with_pis_target; + + // Set public inputs. + for (&pi_t, &pi) in pi_targets.iter().zip_eq(public_inputs) { + self.set_target(pi_t, pi); + } + + self.set_proof_target(pt, proof); + } + + /// Set the targets in a `ProofTarget` to their corresponding values in a `Proof`. + fn set_proof_target, const D: usize>( + &mut self, + proof_target: &ProofTarget, + proof: &Proof, + ) where + F: RichField + Extendable, + C::Hasher: AlgebraicHasher, + { + self.set_cap_target(&proof_target.wires_cap, &proof.wires_cap); + self.set_cap_target( + &proof_target.plonk_zs_partial_products_cap, + &proof.plonk_zs_partial_products_cap, + ); + self.set_cap_target(&proof_target.quotient_polys_cap, &proof.quotient_polys_cap); + + for (&t, &x) in proof_target + .openings + .wires + .iter() + .zip_eq(&proof.openings.wires) + { + self.set_extension_target(t, x); + } + for (&t, &x) in proof_target + .openings + .constants + .iter() + .zip_eq(&proof.openings.constants) + { + self.set_extension_target(t, x); + } + for (&t, &x) in proof_target + .openings + .plonk_sigmas + .iter() + .zip_eq(&proof.openings.plonk_sigmas) + { + self.set_extension_target(t, x); + } + for (&t, &x) in proof_target + .openings + .plonk_zs + .iter() + .zip_eq(&proof.openings.plonk_zs) + { + self.set_extension_target(t, x); + } + for (&t, &x) in proof_target + .openings + .plonk_zs_right + .iter() + .zip_eq(&proof.openings.plonk_zs_right) + { + self.set_extension_target(t, x); + } + for (&t, &x) in proof_target + .openings + .partial_products + .iter() + .zip_eq(&proof.openings.partial_products) + { + self.set_extension_target(t, x); + } + for (&t, &x) in proof_target + .openings + .quotient_polys + .iter() + .zip_eq(&proof.openings.quotient_polys) + { + self.set_extension_target(t, x); + } + + set_fri_proof_target(self, &proof_target.opening_proof, &proof.opening_proof); + } + fn set_wire(&mut self, wire: Wire, value: F) { self.set_target(Target::Wire(wire), value) } diff --git a/plonky2/src/lib.rs b/plonky2/src/lib.rs index 3bddec82..e5e77bb9 100644 --- a/plonky2/src/lib.rs +++ b/plonky2/src/lib.rs @@ -6,7 +6,6 @@ #![allow(clippy::len_without_is_empty)] #![allow(clippy::needless_range_loop)] #![allow(clippy::return_self_not_must_use)] -#![feature(asm_sym)] #![feature(generic_const_exprs)] #![feature(specialization)] #![feature(stdsimd)] diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index b658e339..b5db44a7 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -10,11 +10,13 @@ use plonky2_field::field_types::Field; use plonky2_field::polynomial::PolynomialValues; use plonky2_util::{log2_ceil, log2_strict}; -use crate::fri::commitment::PolynomialBatchCommitment; +use crate::fri::oracle::PolynomialBatch; use crate::fri::{FriConfig, FriParams}; use crate::gadgets::arithmetic::BaseArithmeticOperation; use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; +use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::batchable::{BatchableGate, CurrentSlot, GateRef}; @@ -24,6 +26,7 @@ use crate::gates::gate_tree::Tree; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; +use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::ext_target::ExtensionTarget; use crate::iop::generator::{ CopyGenerator, RandomValueGenerator, SimpleGenerator, WitnessGenerator, @@ -37,7 +40,8 @@ use crate::plonk::circuit_data::{ use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::copy_constraint::CopyConstraint; use crate::plonk::permutation_argument::Forest; -use crate::plonk::plonk_common::PlonkPolynomials; +use crate::plonk::plonk_common::PlonkOracle; +use crate::timed; use crate::util::context_tree::ContextTree; use crate::util::marking::{Markable, MarkedTargets}; use crate::util::partial_products::num_partial_products; @@ -167,6 +171,12 @@ impl, const D: usize> CircuitBuilder { (0..n).map(|_i| self.add_virtual_hash()).collect() } + pub(crate) fn add_virtual_merkle_proof(&mut self, len: usize) -> MerkleProofTarget { + MerkleProofTarget { + siblings: self.add_virtual_hashes(len), + } + } + pub fn add_virtual_extension_target(&mut self) -> ExtensionTarget { ExtensionTarget(self.add_virtual_targets(D).try_into().unwrap()) } @@ -177,11 +187,25 @@ impl, const D: usize> CircuitBuilder { .collect() } + pub(crate) fn add_virtual_poly_coeff_ext( + &mut self, + num_coeffs: usize, + ) -> PolynomialCoeffsExtTarget { + let coeffs = self.add_virtual_extension_targets(num_coeffs); + PolynomialCoeffsExtTarget(coeffs) + } + // TODO: Unsafe pub fn add_virtual_bool_target(&mut self) -> BoolTarget { BoolTarget::new_unsafe(self.add_virtual_target()) } + pub fn add_virtual_bool_target_safe(&mut self) -> BoolTarget { + let b = BoolTarget::new_unsafe(self.add_virtual_target()); + self.assert_bool(b); + b + } + /// Adds a gate to the circuit, and returns its index. pub fn add_gate>( &mut self, @@ -219,7 +243,7 @@ impl, const D: usize> CircuitBuilder { fn check_gate_compatibility>(&self, gate: &G) { assert!( gate.num_wires() <= self.config.num_wires, - "{:?} requires {} wires, but our GateConfig has only {}", + "{:?} requires {} wires, but our CircuitConfig has only {}", gate.id(), gate.num_wires(), self.config.num_wires @@ -418,11 +442,12 @@ impl, const D: usize> CircuitBuilder { let fri_config = &self.config.fri_config; let reduction_arity_bits = fri_config.reduction_strategy.reduction_arity_bits( degree_bits, - self.config.fri_config.rate_bits, + fri_config.rate_bits, fri_config.num_query_rounds, ); FriParams { config: fri_config.clone(), + hiding: self.config.zero_knowledge, degree_bits, reduction_arity_bits, } @@ -631,16 +656,21 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "full circuit", with both prover and verifier data. - pub fn build>(mut self) -> CircuitData { + pub fn build>(mut self) -> CircuitData + where + [(); C::Hasher::HASH_SIZE]:, + { let mut timing = TimingTree::new("preprocess", Level::Trace); let start = Instant::now(); + let rate_bits = self.config.fri_config.rate_bits; self.fill_batched_gates(); // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. + let num_public_inputs = self.public_inputs.len(); let public_inputs_hash = - self.hash_n_to_hash::(self.public_inputs.clone(), true); + self.hash_n_to_hash_no_pad::(self.public_inputs.clone()); let pi_gate = self.add_gate(PublicInputGate, vec![], vec![]); for (&hash_part, wire) in public_inputs_hash .elements @@ -666,31 +696,41 @@ impl, const D: usize> CircuitBuilder { let gates = self.gates.iter().cloned().collect(); let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates); + let prefixed_gates = PrefixedGate::from_tree(gate_tree); + // `quotient_degree_factor` has to be between `max_filtered_constraint_degree-1` and `1<, const D: usize> CircuitBuilder { constants_sigmas_cap.flatten(), vec![/* Add other circuit data here */], ]; - let circuit_digest = C::Hasher::hash(circuit_digest_parts.concat(), false); + let circuit_digest = C::Hasher::hash_no_pad(&circuit_digest_parts.concat()); let common = CommonCircuitData { config: self.config, @@ -774,11 +814,13 @@ impl, const D: usize> CircuitBuilder { num_gate_constraints, num_constants, num_virtual_targets: self.virtual_target_index, + num_public_inputs, k_is, num_partial_products, circuit_digest, }; + timing.print(); debug!("Building circuit took {}s", start.elapsed().as_secs_f32()); CircuitData { prover_only, @@ -788,7 +830,10 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "prover circuit", with data needed to generate proofs but not verify them. - pub fn build_prover>(self) -> ProverCircuitData { + pub fn build_prover>(self) -> ProverCircuitData + where + [(); C::Hasher::HASH_SIZE]:, + { // TODO: Can skip parts of this. let CircuitData { prover_only, @@ -802,7 +847,10 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. - pub fn build_verifier>(self) -> VerifierCircuitData { + pub fn build_verifier>(self) -> VerifierCircuitData + where + [(); C::Hasher::HASH_SIZE]:, + { // TODO: Can skip parts of this. let CircuitData { verifier_only, @@ -817,332 +865,7 @@ impl, const D: usize> CircuitBuilder { } impl, const D: usize> CircuitBuilder { - // /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - // /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - // /// `g` and the gate's `i`-th operation is available. - // pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - // let (gate, i) = self - // .batched_gates - // .free_base_arithmetic - // .get(&(const_0, const_1)) - // .copied() - // .unwrap_or_else(|| { - // let gate = self.add_gate( - // ArithmeticGate::new_from_config(&self.config), - // vec![const_0, const_1], - // ); - // (gate, 0) - // }); - // - // // Update `free_arithmetic` with new values. - // if i < ArithmeticGate::num_ops(&self.config) - 1 { - // self.batched_gates - // .free_base_arithmetic - // .insert((const_0, const_1), (gate, i + 1)); - // } else { - // self.batched_gates - // .free_base_arithmetic - // .remove(&(const_0, const_1)); - // } - // - // (gate, i) - // } - // - // /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - // /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - // /// `g` and the gate's `i`-th operation is available. - // pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - // let (gate, i) = self - // .batched_gates - // .free_arithmetic - // .get(&(const_0, const_1)) - // .copied() - // .unwrap_or_else(|| { - // let gate = self.add_gate( - // ArithmeticExtensionGate::new_from_config(&self.config), - // vec![const_0, const_1], - // ); - // (gate, 0) - // }); - // - // // Update `free_arithmetic` with new values. - // if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { - // self.batched_gates - // .free_arithmetic - // .insert((const_0, const_1), (gate, i + 1)); - // } else { - // self.batched_gates - // .free_arithmetic - // .remove(&(const_0, const_1)); - // } - // - // (gate, i) - // } - // - // /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - // /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - // /// `g` and the gate's `i`-th operation is available. - // pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { - // let (gate, i) = self - // .batched_gates - // .free_mul - // .get(&const_0) - // .copied() - // .unwrap_or_else(|| { - // let gate = self.add_gate( - // MulExtensionGate::new_from_config(&self.config), - // vec![const_0], - // ); - // (gate, 0) - // }); - // - // // Update `free_arithmetic` with new values. - // if i < MulExtensionGate::::num_ops(&self.config) - 1 { - // self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); - // } else { - // self.batched_gates.free_mul.remove(&const_0); - // } - // - // (gate, i) - // } - // - // /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. - // /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index - // /// `g` and the gate's `i`-th random access is available. - // pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { - // let (gate, i) = self - // .batched_gates - // .free_random_access - // .get(&bits) - // .copied() - // .unwrap_or_else(|| { - // let gate = self.add_gate( - // RandomAccessGate::new_from_config(&self.config, bits), - // vec![], - // ); - // (gate, 0) - // }); - // - // // Update `free_random_access` with new values. - // if i + 1 < RandomAccessGate::::new_from_config(&self.config, bits).num_copies { - // self.batched_gates - // .free_random_access - // .insert(bits, (gate, i + 1)); - // } else { - // self.batched_gates.free_random_access.remove(&bits); - // } - // - // (gate, i) - // } - // - // pub fn find_switch_gate(&mut self, chunk_size: usize) -> (SwitchGate, usize, usize) { - // if self.batched_gates.current_switch_gates.len() < chunk_size { - // self.batched_gates.current_switch_gates.extend(vec![ - // None; - // chunk_size - // - self - // .batched_gates - // .current_switch_gates - // .len() - // ]); - // } - // - // let (gate, gate_index, next_copy) = - // match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { - // None => { - // let gate = SwitchGate::::new_from_config(&self.config, chunk_size); - // let gate_index = self.add_gate(gate.clone(), vec![]); - // (gate, gate_index, 0) - // } - // Some((gate, idx, next_copy)) => (gate, idx, next_copy), - // }; - // - // let num_copies = gate.num_copies; - // - // if next_copy == num_copies - 1 { - // self.batched_gates.current_switch_gates[chunk_size - 1] = None; - // } else { - // self.batched_gates.current_switch_gates[chunk_size - 1] = - // Some((gate.clone(), gate_index, next_copy + 1)); - // } - // - // (gate, gate_index, next_copy) - // } - // - // pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { - // let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { - // None => { - // let gate = U32ArithmeticGate::new_from_config(&self.config); - // let gate_index = self.add_gate(gate, vec![]); - // (gate_index, 0) - // } - // Some((gate_index, copy)) => (gate_index, copy), - // }; - // - // if copy == U32ArithmeticGate::::num_ops(&self.config) - 1 { - // self.batched_gates.current_u32_arithmetic_gate = None; - // } else { - // self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); - // } - // - // (gate_index, copy) - // } - // - // pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { - // let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { - // None => { - // let gate = U32SubtractionGate::new_from_config(&self.config); - // let gate_index = self.add_gate(gate, vec![]); - // (gate_index, 0) - // } - // Some((gate_index, copy)) => (gate_index, copy), - // }; - // - // if copy == U32SubtractionGate::::num_ops(&self.config) - 1 { - // self.batched_gates.current_u32_subtraction_gate = None; - // } else { - // self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); - // } - // - // (gate_index, copy) - // } - // - // /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a - // /// new `ConstantGate` if needed. - // fn constant_gate_instance(&mut self) -> (usize, usize) { - // if self.batched_gates.free_constant.is_none() { - // let num_consts = self.config.constant_gate_size; - // // We will fill this `ConstantGate` with zero constants initially. - // // These will be overwritten by `constant` as the gate instances are filled. - // let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); - // self.batched_gates.free_constant = Some((gate, 0)); - // } - // - // let (gate, instance) = self.batched_gates.free_constant.unwrap(); - // if instance + 1 < self.config.constant_gate_size { - // self.batched_gates.free_constant = Some((gate, instance + 1)); - // } else { - // self.batched_gates.free_constant = None; - // } - // (gate, instance) - // } - // - // /// Fill the remaining unused arithmetic operations with zeros, so that all - // /// `ArithmeticGate` are run. - // fn fill_base_arithmetic_gates(&mut self) { - // let zero = self.zero(); - // for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() { - // for _ in i..ArithmeticGate::num_ops(&self.config) { - // // If we directly wire in zero, an optimization will skip doing anything and return - // // zero. So we pass in a virtual target and connect it to zero afterward. - // let dummy = self.add_virtual_target(); - // self.arithmetic(c0, c1, dummy, dummy, dummy); - // self.connect(dummy, zero); - // } - // } - // assert!(self.batched_gates.free_base_arithmetic.is_empty()); - // } - // - // /// Fill the remaining unused arithmetic operations with zeros, so that all - // /// `ArithmeticExtensionGenerator`s are run. - // fn fill_arithmetic_gates(&mut self) { - // let zero = self.zero_extension(); - // for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() { - // for _ in i..ArithmeticExtensionGate::::num_ops(&self.config) { - // // If we directly wire in zero, an optimization will skip doing anything and return - // // zero. So we pass in a virtual target and connect it to zero afterward. - // let dummy = self.add_virtual_extension_target(); - // self.arithmetic_extension(c0, c1, dummy, dummy, dummy); - // self.connect_extension(dummy, zero); - // } - // } - // assert!(self.batched_gates.free_arithmetic.is_empty()); - // } - // - // /// Fill the remaining unused arithmetic operations with zeros, so that all - // /// `ArithmeticExtensionGenerator`s are run. - // fn fill_mul_gates(&mut self) { - // let zero = self.zero_extension(); - // for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { - // for _ in i..MulExtensionGate::::num_ops(&self.config) { - // // If we directly wire in zero, an optimization will skip doing anything and return - // // zero. So we pass in a virtual target and connect it to zero afterward. - // let dummy = self.add_virtual_extension_target(); - // self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); - // self.connect_extension(dummy, zero); - // } - // } - // assert!(self.batched_gates.free_mul.is_empty()); - // } - // - // /// Fill the remaining unused random access operations with zeros, so that all - // /// `RandomAccessGenerator`s are run. - // fn fill_random_access_gates(&mut self) { - // let zero = self.zero(); - // for (bits, (_, i)) in self.batched_gates.free_random_access.clone() { - // let max_copies = - // RandomAccessGate::::new_from_config(&self.config, bits).num_copies; - // for _ in i..max_copies { - // self.random_access(zero, zero, vec![zero; 1 << bits]); - // } - // } - // } - // - // /// Fill the remaining unused switch gates with dummy values, so that all - // /// `SwitchGenerator`s are run. - // fn fill_switch_gates(&mut self) { - // let zero = self.zero(); - // - // for chunk_size in 1..=self.batched_gates.current_switch_gates.len() { - // if let Some((gate, gate_index, mut copy)) = - // self.batched_gates.current_switch_gates[chunk_size - 1].clone() - // { - // while copy < gate.num_copies { - // for element in 0..chunk_size { - // let wire_first_input = - // Target::wire(gate_index, gate.wire_first_input(copy, element)); - // let wire_second_input = - // Target::wire(gate_index, gate.wire_second_input(copy, element)); - // let wire_switch_bool = - // Target::wire(gate_index, gate.wire_switch_bool(copy)); - // self.connect(zero, wire_first_input); - // self.connect(zero, wire_second_input); - // self.connect(zero, wire_switch_bool); - // } - // copy += 1; - // } - // } - // } - // } - // - // /// Fill the remaining unused U32 arithmetic operations with zeros, so that all - // /// `U32ArithmeticGenerator`s are run. - // fn fill_u32_arithmetic_gates(&mut self) { - // let zero = self.zero_u32(); - // if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { - // for _ in copy..U32ArithmeticGate::::num_ops(&self.config) { - // let dummy = self.add_virtual_u32_target(); - // self.mul_add_u32(dummy, dummy, dummy); - // self.connect_u32(dummy, zero); - // } - // } - // } - // - // /// Fill the remaining unused U32 subtraction operations with zeros, so that all - // /// `U32SubtractionGenerator`s are run. - // fn fill_u32_subtraction_gates(&mut self) { - // let zero = self.zero_u32(); - // if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { - // for _i in copy..U32SubtractionGate::::num_ops(&self.config) { - // let dummy = self.add_virtual_u32_target(); - // self.sub_u32(dummy, dummy, dummy); - // self.connect_u32(dummy, zero); - // } - // } - // } - // fn fill_batched_gates(&mut self) { - dbg!(&self.current_slots); let instances = self.gate_instances.clone(); for gate in instances { if let Some(slot) = self.current_slots.get(&gate.gate_ref) { diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 74308fc3..3d4ee2df 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -5,16 +5,23 @@ use anyhow::Result; use plonky2_field::extension_field::Extendable; use plonky2_field::fft::FftRootTable; -use crate::fri::commitment::PolynomialBatchCommitment; +use crate::field::field_types::Field; +use crate::fri::oracle::PolynomialBatch; use crate::fri::reduction_strategies::FriReductionStrategy; +use crate::fri::structure::{ + FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriPolynomialInfo, +}; use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::PrefixedGate; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; +use crate::iop::ext_target::ExtensionTarget; use crate::iop::generator::WitnessGenerator; use crate::iop::target::Target; use crate::iop::witness::PartialWitness; +use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{GenericConfig, Hasher}; +use crate::plonk::plonk_common::{PlonkOracle, FRI_ORACLES}; use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use crate::plonk::prover::prove; use crate::plonk::verifier::verify; @@ -34,21 +41,19 @@ pub struct CircuitConfig { /// `degree / |F|`. pub num_challenges: usize, pub zero_knowledge: bool, - + /// A cap on the quotient polynomial's degree factor. The actual degree factor is derived + /// systematically, but will never exceed this value. + pub max_quotient_degree_factor: usize, pub fri_config: FriConfig, } impl Default for CircuitConfig { fn default() -> Self { - CircuitConfig::standard_recursion_config() + Self::standard_recursion_config() } } impl CircuitConfig { - pub fn rate(&self) -> f64 { - 1.0 / ((1 << self.fri_config.rate_bits) as f64) - } - pub fn num_advice_wires(&self) -> usize { self.num_wires - self.num_routed_wires } @@ -63,6 +68,7 @@ impl CircuitConfig { security_bits: 100, num_challenges: 2, zero_knowledge: false, + max_quotient_degree_factor: 8, fri_config: FriConfig { rate_bits: 3, cap_height: 4, @@ -73,6 +79,13 @@ impl CircuitConfig { } } + pub fn standard_ecc_config() -> Self { + Self { + num_wires: 136, + ..Self::standard_recursion_config() + } + } + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, @@ -91,7 +104,10 @@ pub struct CircuitData, C: GenericConfig, impl, C: GenericConfig, const D: usize> CircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> + where + [(); C::Hasher::HASH_SIZE]:, + { prove( &self.prover_only, &self.common, @@ -100,14 +116,20 @@ impl, C: GenericConfig, const D: usize> ) } - pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { verify(proof_with_pis, &self.verifier_only, &self.common) } pub fn verify_compressed( &self, compressed_proof_with_pis: CompressedProofWithPublicInputs, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { compressed_proof_with_pis.verify(&self.verifier_only, &self.common) } } @@ -131,7 +153,10 @@ pub struct ProverCircuitData< impl, C: GenericConfig, const D: usize> ProverCircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> + where + [(); C::Hasher::HASH_SIZE]:, + { prove( &self.prover_only, &self.common, @@ -155,14 +180,20 @@ pub struct VerifierCircuitData< impl, C: GenericConfig, const D: usize> VerifierCircuitData { - pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { verify(proof_with_pis, &self.verifier_only, &self.common) } pub fn verify_compressed( &self, compressed_proof_with_pis: CompressedProofWithPublicInputs, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { compressed_proof_with_pis.verify(&self.verifier_only, &self.common) } } @@ -178,7 +209,7 @@ pub(crate) struct ProverOnlyCircuitData< /// they watch. pub generator_indices_by_watches: BTreeMap>, /// Commitments to the constants polynomials and sigma polynomials. - pub constants_sigmas_commitment: PolynomialBatchCommitment, + pub constants_sigmas_commitment: PolynomialBatch, /// The transpose of the list of sigma polynomials. pub sigmas: Vec>, /// Subgroup of order `degree`. @@ -228,12 +259,13 @@ pub struct CommonCircuitData< pub(crate) num_virtual_targets: usize, + pub(crate) num_public_inputs: usize, + /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, - /// The number of partial products needed to compute the `Z` polynomials and - /// the number of original elements consumed in `partial_products()`. - pub(crate) num_partial_products: (usize, usize), + /// The number of partial products needed to compute the `Z` polynomials. + pub(crate) num_partial_products: usize, /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to /// seed Fiat-Shamir. @@ -286,6 +318,103 @@ impl, C: GenericConfig, const D: usize> pub fn partial_products_range(&self) -> RangeFrom { self.config.num_challenges.. } + + pub(crate) fn get_fri_instance(&self, zeta: F::Extension) -> FriInstanceInfo { + // All polynomials are opened at zeta. + let zeta_batch = FriBatchInfo { + point: zeta, + polynomials: self.fri_all_polys(), + }; + + // The Z polynomials are also opened at g * zeta. + let g = F::Extension::primitive_root_of_unity(self.degree_bits); + let zeta_right = g * zeta; + let zeta_right_batch = FriBatchInfo { + point: zeta_right, + polynomials: self.fri_zs_polys(), + }; + + let openings = vec![zeta_batch, zeta_right_batch]; + FriInstanceInfo { + oracles: FRI_ORACLES.to_vec(), + batches: openings, + } + } + + pub(crate) fn get_fri_instance_target( + &self, + builder: &mut CircuitBuilder, + zeta: ExtensionTarget, + ) -> FriInstanceInfoTarget { + // All polynomials are opened at zeta. + let zeta_batch = FriBatchInfoTarget { + point: zeta, + polynomials: self.fri_all_polys(), + }; + + // The Z polynomials are also opened at g * zeta. + let g = F::primitive_root_of_unity(self.degree_bits); + let zeta_right = builder.mul_const_extension(g, zeta); + let zeta_right_batch = FriBatchInfoTarget { + point: zeta_right, + polynomials: self.fri_zs_polys(), + }; + + let openings = vec![zeta_batch, zeta_right_batch]; + FriInstanceInfoTarget { + oracles: FRI_ORACLES.to_vec(), + batches: openings, + } + } + + fn fri_preprocessed_polys(&self) -> Vec { + FriPolynomialInfo::from_range( + PlonkOracle::CONSTANTS_SIGMAS.index, + 0..self.num_preprocessed_polys(), + ) + } + + pub(crate) fn num_preprocessed_polys(&self) -> usize { + self.sigmas_range().end + } + + fn fri_wire_polys(&self) -> Vec { + let num_wire_polys = self.config.num_wires; + FriPolynomialInfo::from_range(PlonkOracle::WIRES.index, 0..num_wire_polys) + } + + fn fri_zs_partial_products_polys(&self) -> Vec { + FriPolynomialInfo::from_range( + PlonkOracle::ZS_PARTIAL_PRODUCTS.index, + 0..self.num_zs_partial_products_polys(), + ) + } + + pub(crate) fn num_zs_partial_products_polys(&self) -> usize { + self.config.num_challenges * (1 + self.num_partial_products) + } + + fn fri_zs_polys(&self) -> Vec { + FriPolynomialInfo::from_range(PlonkOracle::ZS_PARTIAL_PRODUCTS.index, self.zs_range()) + } + + fn fri_quotient_polys(&self) -> Vec { + FriPolynomialInfo::from_range(PlonkOracle::QUOTIENT.index, 0..self.num_quotient_polys()) + } + + pub(crate) fn num_quotient_polys(&self) -> usize { + self.config.num_challenges * self.quotient_degree_factor + } + + fn fri_all_polys(&self) -> Vec { + [ + self.fri_preprocessed_polys(), + self.fri_wire_polys(), + self.fri_zs_partial_products_polys(), + self.fri_quotient_polys(), + ] + .concat() + } } /// The `Target` version of `VerifierCircuitData`, for use inside recursive circuits. Note that this diff --git a/plonky2/src/plonk/config.rs b/plonky2/src/plonk/config.rs index 34f92f58..cb6d9a9b 100644 --- a/plonky2/src/plonk/config.rs +++ b/plonky2/src/plonk/config.rs @@ -5,7 +5,6 @@ use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::goldilocks_field::GoldilocksField; use serde::{de::DeserializeOwned, Serialize}; -use crate::hash::gmimc::GMiMCHash; use crate::hash::hash_types::HashOut; use crate::hash::hash_types::RichField; use crate::hash::hashing::{PlonkyPermutation, SPONGE_WIDTH}; @@ -32,7 +31,39 @@ pub trait Hasher: Sized + Clone + Debug + Eq + PartialEq { /// Permutation used in the sponge construction. type Permutation: PlonkyPermutation; - fn hash(input: Vec, pad: bool) -> Self::Hash; + /// Hash a message without any padding step. Note that this can enable length-extension attacks. + /// However, it is still collision-resistant in cases where the input has a fixed length. + fn hash_no_pad(input: &[F]) -> Self::Hash; + + /// Pad the message using the `pad10*1` rule, then hash it. + fn hash_pad(input: &[F]) -> Self::Hash { + let mut padded_input = input.to_vec(); + padded_input.push(F::ONE); + while (padded_input.len() + 1) % SPONGE_WIDTH != 0 { + padded_input.push(F::ZERO); + } + padded_input.push(F::ONE); + Self::hash_no_pad(&padded_input) + } + + /// Hash the slice if necessary to reduce its length to ~256 bits. If it already fits, this is a + /// no-op. + fn hash_or_noop(inputs: &[F]) -> Self::Hash + where + [(); Self::HASH_SIZE]:, + { + if inputs.len() <= 4 { + let mut inputs_bytes = [0u8; Self::HASH_SIZE]; + for i in 0..inputs.len() { + inputs_bytes[i * 8..(i + 1) * 8] + .copy_from_slice(&inputs[i].to_canonical_u64().to_le_bytes()); + } + Self::Hash::from_bytes(&inputs_bytes) + } else { + Self::hash_no_pad(inputs) + } + } + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash; } @@ -66,45 +97,16 @@ pub trait GenericConfig: type InnerHasher: AlgebraicHasher; } -/// Configuration trait for "algebraic" configurations, i.e., those using an algebraic hash function -/// in Merkle trees. -/// Same as `GenericConfig` trait but with `InnerHasher: AlgebraicHasher`. -pub trait AlgebraicConfig: - Debug + Clone + Sync + Sized + Send + Eq + PartialEq -{ - type F: RichField + Extendable; - type FE: FieldExtension; - type Hasher: AlgebraicHasher; - type InnerHasher: AlgebraicHasher; -} - -impl, const D: usize> GenericConfig for A { - type F = >::F; - type FE = >::FE; - type Hasher = >::Hasher; - type InnerHasher = >::InnerHasher; -} - /// Configuration using Poseidon over the Goldilocks field. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct PoseidonGoldilocksConfig; -impl AlgebraicConfig<2> for PoseidonGoldilocksConfig { +impl GenericConfig<2> for PoseidonGoldilocksConfig { type F = GoldilocksField; type FE = QuadraticExtension; type Hasher = PoseidonHash; type InnerHasher = PoseidonHash; } -/// Configuration using GMiMC over the Goldilocks field. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct GMiMCGoldilocksConfig; -impl AlgebraicConfig<2> for GMiMCGoldilocksConfig { - type F = GoldilocksField; - type FE = QuadraticExtension; - type Hasher = GMiMCHash; - type InnerHasher = GMiMCHash; -} - /// Configuration using truncated Keccak over the Goldilocks field. #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct KeccakGoldilocksConfig; diff --git a/plonky2/src/plonk/get_challenges.rs b/plonky2/src/plonk/get_challenges.rs index 23e7f454..a67a6207 100644 --- a/plonky2/src/plonk/get_challenges.rs +++ b/plonky2/src/plonk/get_challenges.rs @@ -3,16 +3,20 @@ use std::collections::HashSet; use plonky2_field::extension_field::Extendable; use plonky2_field::polynomial::PolynomialCoeffs; -use crate::fri::proof::{CompressedFriProof, FriProof}; -use crate::fri::verifier::{compute_evaluation, fri_combine_initial, PrecomputedReducedEvals}; -use crate::hash::hash_types::RichField; +use crate::fri::proof::{CompressedFriProof, FriChallenges, FriProof, FriProofTarget}; +use crate::fri::verifier::{compute_evaluation, fri_combine_initial, PrecomputedReducedOpenings}; +use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; +use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; -use crate::iop::challenger::Challenger; +use crate::iop::challenger::{Challenger, RecursiveChallenger}; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; -use crate::plonk::config::{GenericConfig, Hasher}; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use crate::plonk::proof::{ - CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof, - ProofChallenges, ProofWithPublicInputs, + CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, + OpeningSetTarget, Proof, ProofChallenges, ProofChallengesTarget, ProofTarget, + ProofWithPublicInputs, ProofWithPublicInputsTarget, }; use crate::util::reverse_bits; @@ -29,8 +33,6 @@ fn get_challenges, C: GenericConfig, cons ) -> anyhow::Result> { let config = &common_data.config; let num_challenges = config.num_challenges; - let num_fri_queries = config.fri_config.num_query_rounds; - let lde_size = common_data.lde_size(); let mut challenger = Challenger::::new(); @@ -48,47 +50,20 @@ fn get_challenges, C: GenericConfig, cons challenger.observe_cap(quotient_polys_cap); let plonk_zeta = challenger.get_extension_challenge::(); - challenger.observe_opening_set(openings); - - // Scaling factor to combine polynomials. - let fri_alpha = challenger.get_extension_challenge::(); - - // Recover the random betas used in the FRI reductions. - let fri_betas = commit_phase_merkle_caps - .iter() - .map(|cap| { - challenger.observe_cap(cap); - challenger.get_extension_challenge::() - }) - .collect(); - - challenger.observe_extension_elements(&final_poly.coeffs); - - let fri_pow_response = C::InnerHasher::hash( - challenger - .get_hash() - .elements - .iter() - .copied() - .chain(Some(pow_witness)) - .collect(), - false, - ) - .elements[0]; - - let fri_query_indices = (0..num_fri_queries) - .map(|_| challenger.get_challenge().to_canonical_u64() as usize % lde_size) - .collect(); + challenger.observe_openings(&openings.to_fri_openings()); Ok(ProofChallenges { plonk_betas, plonk_gammas, plonk_alphas, plonk_zeta, - fri_alpha, - fri_betas, - fri_pow_response, - fri_query_indices, + fri_challenges: challenger.fri_challenges::( + commit_phase_merkle_caps, + final_poly, + pow_witness, + common_data.degree_bits, + &config.fri_config, + ), }) } @@ -99,12 +74,16 @@ impl, C: GenericConfig, const D: usize> &self, common_data: &CommonCircuitData, ) -> anyhow::Result> { - Ok(self.get_challenges(common_data)?.fri_query_indices) + Ok(self + .get_challenges(self.get_public_inputs_hash(), common_data)? + .fri_challenges + .fri_query_indices) } /// Computes all Fiat-Shamir challenges used in the Plonk proof. pub(crate) fn get_challenges( &self, + public_inputs_hash: <>::InnerHasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { let Proof { @@ -122,7 +101,7 @@ impl, C: GenericConfig, const D: usize> } = &self.proof; get_challenges( - self.get_public_inputs_hash(), + public_inputs_hash, wires_cap, plonk_zs_partial_products_cap, quotient_polys_cap, @@ -141,6 +120,7 @@ impl, C: GenericConfig, const D: usize> /// Computes all Fiat-Shamir challenges used in the Plonk proof. pub(crate) fn get_challenges( &self, + public_inputs_hash: <>::InnerHasher as Hasher>::Hash, common_data: &CommonCircuitData, ) -> anyhow::Result> { let CompressedProof { @@ -158,7 +138,7 @@ impl, C: GenericConfig, const D: usize> } = &self.proof; get_challenges( - self.get_public_inputs_hash(), + public_inputs_hash, wires_cap, plonk_zs_partial_products_cap, quotient_polys_cap, @@ -178,34 +158,40 @@ impl, C: GenericConfig, const D: usize> ) -> FriInferredElements { let ProofChallenges { plonk_zeta, - fri_alpha, - fri_betas, - fri_query_indices, + fri_challenges: + FriChallenges { + fri_alpha, + fri_betas, + fri_query_indices, + .. + }, .. } = challenges; let mut fri_inferred_elements = Vec::new(); // Holds the indices that have already been seen at each reduction depth. let mut seen_indices_by_depth = vec![HashSet::new(); common_data.fri_params.reduction_arity_bits.len()]; - let precomputed_reduced_evals = - PrecomputedReducedEvals::from_os_and_alpha(&self.proof.openings, *fri_alpha); + let precomputed_reduced_evals = PrecomputedReducedOpenings::from_os_and_alpha( + &self.proof.openings.to_fri_openings(), + *fri_alpha, + ); let log_n = common_data.degree_bits + common_data.config.fri_config.rate_bits; // Simulate the proof verification and collect the inferred elements. // The content of the loop is basically the same as the `fri_verifier_query_round` function. for &(mut x_index) in fri_query_indices { let mut subgroup_x = F::MULTIPLICATIVE_GROUP_GENERATOR * F::primitive_root_of_unity(log_n).exp_u64(reverse_bits(x_index, log_n) as u64); - let mut old_eval = fri_combine_initial( + let mut old_eval = fri_combine_initial::( + &common_data.get_fri_instance(*plonk_zeta), &self .proof .opening_proof .query_round_proofs .initial_trees_proofs[&x_index], *fri_alpha, - *plonk_zeta, subgroup_x, - precomputed_reduced_evals, - common_data, + &precomputed_reduced_evals, + &common_data.fri_params, ); for (i, &arity_bits) in common_data .fri_params @@ -239,3 +225,96 @@ impl, C: GenericConfig, const D: usize> FriInferredElements(fri_inferred_elements) } } + +impl, const D: usize> CircuitBuilder { + fn get_challenges>( + &mut self, + public_inputs_hash: HashOutTarget, + wires_cap: &MerkleCapTarget, + plonk_zs_partial_products_cap: &MerkleCapTarget, + quotient_polys_cap: &MerkleCapTarget, + openings: &OpeningSetTarget, + commit_phase_merkle_caps: &[MerkleCapTarget], + final_poly: &PolynomialCoeffsExtTarget, + pow_witness: Target, + inner_common_data: &CommonCircuitData, + ) -> ProofChallengesTarget + where + C::Hasher: AlgebraicHasher, + { + let config = &inner_common_data.config; + let num_challenges = config.num_challenges; + + let mut challenger = RecursiveChallenger::::new(self); + + // Observe the instance. + let digest = + HashOutTarget::from_vec(self.constants(&inner_common_data.circuit_digest.elements)); + challenger.observe_hash(&digest); + challenger.observe_hash(&public_inputs_hash); + + challenger.observe_cap(wires_cap); + let plonk_betas = challenger.get_n_challenges(self, num_challenges); + let plonk_gammas = challenger.get_n_challenges(self, num_challenges); + + challenger.observe_cap(plonk_zs_partial_products_cap); + let plonk_alphas = challenger.get_n_challenges(self, num_challenges); + + challenger.observe_cap(quotient_polys_cap); + let plonk_zeta = challenger.get_extension_challenge(self); + + challenger.observe_openings(&openings.to_fri_openings()); + + ProofChallengesTarget { + plonk_betas, + plonk_gammas, + plonk_alphas, + plonk_zeta, + fri_challenges: challenger.fri_challenges::( + self, + commit_phase_merkle_caps, + final_poly, + pow_witness, + inner_common_data, + ), + } + } +} + +impl ProofWithPublicInputsTarget { + pub(crate) fn get_challenges, C: GenericConfig>( + &self, + builder: &mut CircuitBuilder, + public_inputs_hash: HashOutTarget, + inner_common_data: &CommonCircuitData, + ) -> ProofChallengesTarget + where + C::Hasher: AlgebraicHasher, + { + let ProofTarget { + wires_cap, + plonk_zs_partial_products_cap, + quotient_polys_cap, + openings, + opening_proof: + FriProofTarget { + commit_phase_merkle_caps, + final_poly, + pow_witness, + .. + }, + } = &self.proof; + + builder.get_challenges( + public_inputs_hash, + wires_cap, + plonk_zs_partial_products_cap, + quotient_polys_cap, + openings, + commit_phase_merkle_caps, + final_poly, + *pow_witness, + inner_common_data, + ) + } +} diff --git a/plonky2/src/plonk/mod.rs b/plonky2/src/plonk/mod.rs index b2d1ed03..4f2fa4e1 100644 --- a/plonky2/src/plonk/mod.rs +++ b/plonky2/src/plonk/mod.rs @@ -4,7 +4,7 @@ pub mod config; pub(crate) mod copy_constraint; mod get_challenges; pub(crate) mod permutation_argument; -pub(crate) mod plonk_common; +pub mod plonk_common; pub mod proof; pub mod prover; pub mod recursive_verifier; diff --git a/plonky2/src/plonk/plonk_common.rs b/plonky2/src/plonk/plonk_common.rs index 5b8119aa..09cf2652 100644 --- a/plonky2/src/plonk/plonk_common.rs +++ b/plonky2/src/plonk/plonk_common.rs @@ -2,49 +2,59 @@ use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; use plonky2_field::packed_field::PackedField; -use crate::fri::commitment::SALT_SIZE; +use crate::fri::oracle::SALT_SIZE; +use crate::fri::structure::FriOracleInfo; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::reducing::ReducingFactorTarget; +pub(crate) const FRI_ORACLES: [FriOracleInfo; 4] = [ + PlonkOracle::CONSTANTS_SIGMAS.as_fri_oracle(), + PlonkOracle::WIRES.as_fri_oracle(), + PlonkOracle::ZS_PARTIAL_PRODUCTS.as_fri_oracle(), + PlonkOracle::QUOTIENT.as_fri_oracle(), +]; + /// Holds the Merkle tree index and blinding flag of a set of polynomials used in FRI. #[derive(Debug, Copy, Clone)] -pub struct PolynomialsIndexBlinding { +pub struct PlonkOracle { pub(crate) index: usize, pub(crate) blinding: bool, } -impl PolynomialsIndexBlinding { - pub fn salt_size(&self, zero_knowledge: bool) -> usize { - if zero_knowledge & self.blinding { - SALT_SIZE - } else { - 0 + +impl PlonkOracle { + pub const CONSTANTS_SIGMAS: PlonkOracle = PlonkOracle { + index: 0, + blinding: false, + }; + pub const WIRES: PlonkOracle = PlonkOracle { + index: 1, + blinding: true, + }; + pub const ZS_PARTIAL_PRODUCTS: PlonkOracle = PlonkOracle { + index: 2, + blinding: true, + }; + pub const QUOTIENT: PlonkOracle = PlonkOracle { + index: 3, + blinding: true, + }; + + pub(crate) const fn as_fri_oracle(&self) -> FriOracleInfo { + FriOracleInfo { + blinding: self.blinding, } } } -/// Holds the indices and blinding flags of the Plonk polynomials. -pub struct PlonkPolynomials; - -impl PlonkPolynomials { - pub const CONSTANTS_SIGMAS: PolynomialsIndexBlinding = PolynomialsIndexBlinding { - index: 0, - blinding: false, - }; - pub const WIRES: PolynomialsIndexBlinding = PolynomialsIndexBlinding { - index: 1, - blinding: true, - }; - pub const ZS_PARTIAL_PRODUCTS: PolynomialsIndexBlinding = PolynomialsIndexBlinding { - index: 2, - blinding: true, - }; - pub const QUOTIENT: PolynomialsIndexBlinding = PolynomialsIndexBlinding { - index: 3, - blinding: true, - }; +pub fn salt_size(salted: bool) -> usize { + if salted { + SALT_SIZE + } else { + 0 + } } /// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`. @@ -53,52 +63,6 @@ pub(crate) fn eval_zero_poly(n: usize, x: F) -> F { x.exp_u64(n as u64) - F::ONE } -/// Precomputations of the evaluation of `Z_H(X) = X^n - 1` on a coset `gK` with `H <= K`. -pub(crate) struct ZeroPolyOnCoset { - /// `n = |H|`. - n: F, - /// `rate = |K|/|H|`. - rate: usize, - /// Holds `g^n * (w^n)^i - 1 = g^n * v^i - 1` for `i in 0..rate`, with `w` a generator of `K` and `v` a - /// `rate`-primitive root of unity. - evals: Vec, - /// Holds the multiplicative inverses of `evals`. - inverses: Vec, -} - -impl ZeroPolyOnCoset { - pub fn new(n_log: usize, rate_bits: usize) -> Self { - let g_pow_n = F::coset_shift().exp_power_of_2(n_log); - let evals = F::two_adic_subgroup(rate_bits) - .into_iter() - .map(|x| g_pow_n * x - F::ONE) - .collect::>(); - let inverses = F::batch_multiplicative_inverse(&evals); - Self { - n: F::from_canonical_usize(1 << n_log), - rate: 1 << rate_bits, - evals, - inverses, - } - } - - /// Returns `Z_H(g * w^i)`. - pub fn eval(&self, i: usize) -> F { - self.evals[i % self.rate] - } - - /// Returns `1 / Z_H(g * w^i)`. - pub fn eval_inverse(&self, i: usize) -> F { - self.inverses[i % self.rate] - } - - /// Returns `L_1(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. - pub fn eval_l1(&self, i: usize, x: F) -> F { - // Could also precompute the inverses using Montgomery. - self.eval(i) * (self.n * (x - F::ONE)).inverse() - } -} - /// Evaluate the Lagrange basis `L_1` with `L_1(1) = 1`, and `L_1(x) = 0` for other members of an /// order `n` multiplicative subgroup. pub(crate) fn eval_l_1(n: usize, x: F) -> F { @@ -160,7 +124,7 @@ pub(crate) fn reduce_with_powers_multi< cumul } -pub(crate) fn reduce_with_powers<'a, P: PackedField, T: IntoIterator>( +pub fn reduce_with_powers<'a, P: PackedField, T: IntoIterator>( terms: T, alpha: P::Scalar, ) -> P @@ -174,7 +138,7 @@ where sum } -pub(crate) fn reduce_with_powers_ext_recursive, const D: usize>( +pub fn reduce_with_powers_ext_recursive, const D: usize>( builder: &mut CircuitBuilder, terms: &[ExtensionTarget], alpha: Target, diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index b494f324..27cfd2bb 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -1,9 +1,16 @@ +use anyhow::ensure; use plonky2_field::extension_field::Extendable; use rayon::prelude::*; use serde::{Deserialize, Serialize}; -use crate::fri::commitment::PolynomialBatchCommitment; -use crate::fri::proof::{CompressedFriProof, FriProof, FriProofTarget}; +use crate::fri::oracle::PolynomialBatch; +use crate::fri::proof::{ + CompressedFriProof, FriChallenges, FriChallengesTarget, FriProof, FriProofTarget, +}; +use crate::fri::structure::{ + FriOpeningBatch, FriOpeningBatchTarget, FriOpenings, FriOpeningsTarget, +}; +use crate::fri::FriParams; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; @@ -28,6 +35,7 @@ pub struct Proof, C: GenericConfig, const pub opening_proof: FriProof, } +#[derive(Debug)] pub struct ProofTarget { pub wires_cap: MerkleCapTarget, pub plonk_zs_partial_products_cap: MerkleCapTarget, @@ -38,11 +46,7 @@ pub struct ProofTarget { impl, C: GenericConfig, const D: usize> Proof { /// Compress the proof. - pub fn compress( - self, - indices: &[usize], - common_data: &CommonCircuitData, - ) -> CompressedProof { + pub fn compress(self, indices: &[usize], params: &FriParams) -> CompressedProof { let Proof { wires_cap, plonk_zs_partial_products_cap, @@ -56,7 +60,7 @@ impl, C: GenericConfig, const D: usize> P plonk_zs_partial_products_cap, quotient_polys_cap, openings, - opening_proof: opening_proof.compress(indices, common_data), + opening_proof: opening_proof.compress::(indices, params), } } } @@ -80,7 +84,7 @@ impl, C: GenericConfig, const D: usize> common_data: &CommonCircuitData, ) -> anyhow::Result> { let indices = self.fri_query_indices(common_data)?; - let compressed_proof = self.proof.compress(&indices, common_data); + let compressed_proof = self.proof.compress(&indices, &common_data.fri_params); Ok(CompressedProofWithPublicInputs { public_inputs: self.public_inputs, proof: compressed_proof, @@ -90,7 +94,7 @@ impl, C: GenericConfig, const D: usize> pub(crate) fn get_public_inputs_hash( &self, ) -> <>::InnerHasher as Hasher>::Hash { - C::InnerHasher::hash(self.public_inputs.clone(), true) + C::InnerHasher::hash_no_pad(&self.public_inputs) } pub fn to_bytes(&self) -> anyhow::Result> { @@ -133,8 +137,11 @@ impl, C: GenericConfig, const D: usize> self, challenges: &ProofChallenges, fri_inferred_elements: FriInferredElements, - common_data: &CommonCircuitData, - ) -> Proof { + params: &FriParams, + ) -> Proof + where + [(); C::Hasher::HASH_SIZE]:, + { let CompressedProof { wires_cap, plonk_zs_partial_products_cap, @@ -148,7 +155,7 @@ impl, C: GenericConfig, const D: usize> plonk_zs_partial_products_cap, quotient_polys_cap, openings, - opening_proof: opening_proof.decompress(challenges, fri_inferred_elements, common_data), + opening_proof: opening_proof.decompress::(challenges, fri_inferred_elements, params), } } } @@ -170,12 +177,15 @@ impl, C: GenericConfig, const D: usize> pub fn decompress( self, common_data: &CommonCircuitData, - ) -> anyhow::Result> { - let challenges = self.get_challenges(common_data)?; + ) -> anyhow::Result> + where + [(); C::Hasher::HASH_SIZE]:, + { + let challenges = self.get_challenges(self.get_public_inputs_hash(), common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let decompressed_proof = self.proof - .decompress(&challenges, fri_inferred_elements, common_data); + .decompress(&challenges, fri_inferred_elements, &common_data.fri_params); Ok(ProofWithPublicInputs { public_inputs: self.public_inputs, proof: decompressed_proof, @@ -186,17 +196,23 @@ impl, C: GenericConfig, const D: usize> self, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, - ) -> anyhow::Result<()> { - let challenges = self.get_challenges(common_data)?; + ) -> anyhow::Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { + ensure!( + self.public_inputs.len() == common_data.num_public_inputs, + "Number of public inputs doesn't match circuit data." + ); + let public_inputs_hash = self.get_public_inputs_hash(); + let challenges = self.get_challenges(public_inputs_hash, common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let decompressed_proof = self.proof - .decompress(&challenges, fri_inferred_elements, common_data); + .decompress(&challenges, fri_inferred_elements, &common_data.fri_params); verify_with_challenges( - ProofWithPublicInputs { - public_inputs: self.public_inputs, - proof: decompressed_proof, - }, + decompressed_proof, + public_inputs_hash, challenges, verifier_data, common_data, @@ -206,7 +222,7 @@ impl, C: GenericConfig, const D: usize> pub(crate) fn get_public_inputs_hash( &self, ) -> <>::InnerHasher as Hasher>::Hash { - C::InnerHasher::hash(self.public_inputs.clone(), true) + C::InnerHasher::hash_no_pad(&self.public_inputs) } pub fn to_bytes(&self) -> anyhow::Result> { @@ -226,28 +242,27 @@ impl, C: GenericConfig, const D: usize> } pub(crate) struct ProofChallenges, const D: usize> { - // Random values used in Plonk's permutation argument. + /// Random values used in Plonk's permutation argument. pub plonk_betas: Vec, - // Random values used in Plonk's permutation argument. + /// Random values used in Plonk's permutation argument. pub plonk_gammas: Vec, - // Random values used to combine PLONK constraints. + /// Random values used to combine PLONK constraints. pub plonk_alphas: Vec, - // Point at which the PLONK polynomials are opened. + /// Point at which the PLONK polynomials are opened. pub plonk_zeta: F::Extension, - // Scaling factor to combine polynomials. - pub fri_alpha: F::Extension, + pub fri_challenges: FriChallenges, +} - // Betas used in the FRI commit phase reductions. - pub fri_betas: Vec, - - pub fri_pow_response: F, - - // Indices at which the oracle is queried in FRI. - pub fri_query_indices: Vec, +pub(crate) struct ProofChallengesTarget { + pub plonk_betas: Vec, + pub plonk_gammas: Vec, + pub plonk_alphas: Vec, + pub plonk_zeta: ExtensionTarget, + pub fri_challenges: FriChallengesTarget, } /// Coset elements that can be inferred in the FRI reduction steps. @@ -255,6 +270,7 @@ pub(crate) struct FriInferredElements, const D: usi pub Vec, ); +#[derive(Debug)] pub struct ProofWithPublicInputsTarget { pub proof: ProofTarget, pub public_inputs: Vec, @@ -274,33 +290,53 @@ pub struct OpeningSet, const D: usize> { impl, const D: usize> OpeningSet { pub fn new>( - z: F::Extension, + zeta: F::Extension, g: F::Extension, - constants_sigmas_commitment: &PolynomialBatchCommitment, - wires_commitment: &PolynomialBatchCommitment, - zs_partial_products_commitment: &PolynomialBatchCommitment, - quotient_polys_commitment: &PolynomialBatchCommitment, + constants_sigmas_commitment: &PolynomialBatch, + wires_commitment: &PolynomialBatch, + zs_partial_products_commitment: &PolynomialBatch, + quotient_polys_commitment: &PolynomialBatch, common_data: &CommonCircuitData, ) -> Self { - let eval_commitment = |z: F::Extension, c: &PolynomialBatchCommitment| { + let eval_commitment = |z: F::Extension, c: &PolynomialBatch| { c.polynomials .par_iter() .map(|p| p.to_extension().eval(z)) .collect::>() }; - let constants_sigmas_eval = eval_commitment(z, constants_sigmas_commitment); - let zs_partial_products_eval = eval_commitment(z, zs_partial_products_commitment); + let constants_sigmas_eval = eval_commitment(zeta, constants_sigmas_commitment); + let zs_partial_products_eval = eval_commitment(zeta, zs_partial_products_commitment); Self { constants: constants_sigmas_eval[common_data.constants_range()].to_vec(), plonk_sigmas: constants_sigmas_eval[common_data.sigmas_range()].to_vec(), - wires: eval_commitment(z, wires_commitment), + wires: eval_commitment(zeta, wires_commitment), plonk_zs: zs_partial_products_eval[common_data.zs_range()].to_vec(), - plonk_zs_right: eval_commitment(g * z, zs_partial_products_commitment) + plonk_zs_right: eval_commitment(g * zeta, zs_partial_products_commitment) [common_data.zs_range()] .to_vec(), partial_products: zs_partial_products_eval[common_data.partial_products_range()] .to_vec(), - quotient_polys: eval_commitment(z, quotient_polys_commitment), + quotient_polys: eval_commitment(zeta, quotient_polys_commitment), + } + } + + pub(crate) fn to_fri_openings(&self) -> FriOpenings { + let zeta_batch = FriOpeningBatch { + values: [ + self.constants.as_slice(), + self.plonk_sigmas.as_slice(), + self.wires.as_slice(), + self.plonk_zs.as_slice(), + self.partial_products.as_slice(), + self.quotient_polys.as_slice(), + ] + .concat(), + }; + let zeta_right_batch = FriOpeningBatch { + values: self.plonk_zs_right.clone(), + }; + FriOpenings { + batches: vec![zeta_batch, zeta_right_batch], } } } @@ -317,6 +353,28 @@ pub struct OpeningSetTarget { pub quotient_polys: Vec>, } +impl OpeningSetTarget { + pub(crate) fn to_fri_openings(&self) -> FriOpeningsTarget { + let zeta_batch = FriOpeningBatchTarget { + values: [ + self.constants.as_slice(), + self.plonk_sigmas.as_slice(), + self.wires.as_slice(), + self.plonk_zs.as_slice(), + self.partial_products.as_slice(), + self.quotient_polys.as_slice(), + ] + .concat(), + }; + let zeta_right_batch = FriOpeningBatchTarget { + values: self.plonk_zs_right.clone(), + }; + FriOpeningsTarget { + batches: vec![zeta_batch, zeta_right_batch], + } + } +} + #[cfg(test)] mod tests { use anyhow::Result; diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index f7196270..1d99b60a 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -1,20 +1,23 @@ use std::mem::swap; +use anyhow::ensure; use anyhow::Result; use plonky2_field::extension_field::Extendable; use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues}; -use plonky2_util::log2_ceil; +use plonky2_field::zero_poly_coset::ZeroPolyOnCoset; +use plonky2_util::{ceil_div_usize, log2_ceil}; use rayon::prelude::*; -use crate::fri::commitment::PolynomialBatchCommitment; +use crate::field::field_types::Field; +use crate::fri::oracle::PolynomialBatch; use crate::hash::hash_types::RichField; use crate::iop::challenger::Challenger; use crate::iop::generator::generate_partial_witness; use crate::iop::witness::{MatrixWitness, PartialWitness, Witness}; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::plonk::config::{GenericConfig, Hasher}; -use crate::plonk::plonk_common::PlonkPolynomials; -use crate::plonk::plonk_common::ZeroPolyOnCoset; +use crate::plonk::plonk_common::PlonkOracle; +use crate::plonk::proof::OpeningSet; use crate::plonk::proof::{Proof, ProofWithPublicInputs}; use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBaseBatch; @@ -28,7 +31,10 @@ pub(crate) fn prove, C: GenericConfig, co common_data: &CommonCircuitData, inputs: PartialWitness, timing: &mut TimingTree, -) -> Result> { +) -> Result> +where + [(); C::Hasher::HASH_SIZE]:, +{ let config = &common_data.config; let num_challenges = config.num_challenges; let quotient_degree = common_data.quotient_degree(); @@ -41,7 +47,7 @@ pub(crate) fn prove, C: GenericConfig, co ); let public_inputs = partition_witness.get_targets(&prover_data.public_inputs); - let public_inputs_hash = C::InnerHasher::hash(public_inputs.clone(), true); + let public_inputs_hash = C::InnerHasher::hash_no_pad(&public_inputs); if cfg!(debug_assertions) { // Display the marked targets for debugging purposes. @@ -69,10 +75,10 @@ pub(crate) fn prove, C: GenericConfig, co let wires_commitment = timed!( timing, "compute wires commitment", - PolynomialBatchCommitment::from_values( + PolynomialBatch::from_values( wires_values, config.fri_config.rate_bits, - config.zero_knowledge & PlonkPolynomials::WIRES.blinding, + config.zero_knowledge && PlonkOracle::WIRES.blinding, config.fri_config.cap_height, timing, prover_data.fft_root_table.as_ref(), @@ -109,10 +115,10 @@ pub(crate) fn prove, C: GenericConfig, co let partial_products_and_zs_commitment = timed!( timing, "commit to partial products and Z's", - PolynomialBatchCommitment::from_values( + PolynomialBatch::from_values( zs_partial_products, config.fri_config.rate_bits, - config.zero_knowledge & PlonkPolynomials::ZS_PARTIAL_PRODUCTS.blinding, + config.zero_knowledge && PlonkOracle::ZS_PARTIAL_PRODUCTS.blinding, config.fri_config.cap_height, timing, prover_data.fft_root_table.as_ref(), @@ -145,11 +151,10 @@ pub(crate) fn prove, C: GenericConfig, co quotient_polys .into_par_iter() .flat_map(|mut quotient_poly| { - quotient_poly.trim(); - quotient_poly.pad(quotient_degree).expect( - "Quotient has failed, the vanishing polynomial is not divisible by `Z_H", + quotient_poly.trim_to_len(quotient_degree).expect( + "Quotient has failed, the vanishing polynomial is not divisible by Z_H", ); - // Split t into degree-n chunks. + // Split quotient into degree-n chunks. quotient_poly.chunks(degree) }) .collect() @@ -158,10 +163,10 @@ pub(crate) fn prove, C: GenericConfig, co let quotient_polys_commitment = timed!( timing, "commit to quotient polys", - PolynomialBatchCommitment::from_coeffs( + PolynomialBatch::from_coeffs( all_quotient_poly_chunks, config.fri_config.rate_bits, - config.zero_knowledge & PlonkPolynomials::QUOTIENT.blinding, + config.zero_knowledge && PlonkOracle::QUOTIENT.blinding, config.fri_config.cap_height, timing, prover_data.fft_root_table.as_ref(), @@ -171,20 +176,43 @@ pub(crate) fn prove, C: GenericConfig, co challenger.observe_cap("ient_polys_commitment.merkle_tree.cap); let zeta = challenger.get_extension_challenge::(); + // To avoid leaking witness data, we want to ensure that our opening locations, `zeta` and + // `g * zeta`, are not in our subgroup `H`. It suffices to check `zeta` only, since + // `(g * zeta)^n = zeta^n`, where `n` is the order of `g`. + let g = F::Extension::primitive_root_of_unity(common_data.degree_bits); + ensure!( + zeta.exp_power_of_2(common_data.degree_bits) != F::Extension::ONE, + "Opening point is in the subgroup." + ); - let (opening_proof, openings) = timed!( + let openings = timed!( + timing, + "construct the opening set", + OpeningSet::new( + zeta, + g, + &prover_data.constants_sigmas_commitment, + &wires_commitment, + &partial_products_and_zs_commitment, + "ient_polys_commitment, + common_data, + ) + ); + challenger.observe_openings(&openings.to_fri_openings()); + + let opening_proof = timed!( timing, "compute opening proofs", - PolynomialBatchCommitment::open_plonk( + PolynomialBatch::prove_openings( + &common_data.get_fri_instance(zeta), &[ &prover_data.constants_sigmas_commitment, &wires_commitment, &partial_products_and_zs_commitment, "ient_polys_commitment, ], - zeta, &mut challenger, - common_data, + &common_data.fri_params, timing, ) ); @@ -244,7 +272,7 @@ fn wires_permutation_partial_products_and_zs< let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; - let (num_prods, _final_num_prod) = common_data.num_partial_products; + let num_prods = common_data.num_partial_products; let all_quotient_chunk_products = subgroup .par_iter() .enumerate() @@ -300,44 +328,49 @@ fn compute_quotient_polys< common_data: &CommonCircuitData, prover_data: &'a ProverOnlyCircuitData, public_inputs_hash: &<>::InnerHasher as Hasher>::Hash, - wires_commitment: &'a PolynomialBatchCommitment, - zs_partial_products_commitment: &'a PolynomialBatchCommitment, + wires_commitment: &'a PolynomialBatch, + zs_partial_products_commitment: &'a PolynomialBatch, betas: &[F], gammas: &[F], alphas: &[F], ) -> Vec> { let num_challenges = common_data.config.num_challenges; - let max_degree_bits = log2_ceil(common_data.quotient_degree_factor); + let quotient_degree_bits = log2_ceil(common_data.quotient_degree_factor); assert!( - max_degree_bits <= common_data.config.fri_config.rate_bits, + quotient_degree_bits <= common_data.config.fri_config.rate_bits, "Having constraints of degree higher than the rate is not supported yet. \ - If we need this in the future, we can precompute the larger LDE before computing the `ListPolynomialCommitment`s." + If we need this in the future, we can precompute the larger LDE before computing the `PolynomialBatch`s." ); - // We reuse the LDE computed in `ListPolynomialCommitment` and extract every `step` points to get + // We reuse the LDE computed in `PolynomialBatch` and extract every `step` points to get // an LDE matching `max_filtered_constraint_degree`. - let step = 1 << (common_data.config.fri_config.rate_bits - max_degree_bits); + let step = 1 << (common_data.config.fri_config.rate_bits - quotient_degree_bits); // When opening the `Z`s polys at the "next" point in Plonk, need to look at the point `next_step` // steps away since we work on an LDE of degree `max_filtered_constraint_degree`. - let next_step = 1 << max_degree_bits; + let next_step = 1 << quotient_degree_bits; - let points = F::two_adic_subgroup(common_data.degree_bits + max_degree_bits); + let points = F::two_adic_subgroup(common_data.degree_bits + quotient_degree_bits); let lde_size = points.len(); // Retrieve the LDE values at index `i`. - let get_at_index = |comm: &'a PolynomialBatchCommitment, i: usize| -> &'a [F] { - comm.get_lde_values(i * step) - }; + let get_at_index = + |comm: &'a PolynomialBatch, i: usize| -> &'a [F] { comm.get_lde_values(i * step) }; - let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, max_degree_bits); + let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, quotient_degree_bits); let points_batches = points.par_chunks(BATCH_SIZE); + let num_batches = ceil_div_usize(points.len(), BATCH_SIZE); let quotient_values: Vec> = points_batches .enumerate() .map(|(batch_i, xs_batch)| { - assert_eq!(xs_batch.len(), BATCH_SIZE); + // Each batch must be the same size, except the last one, which may be smaller. + debug_assert!( + xs_batch.len() == BATCH_SIZE + || (batch_i == num_batches - 1 && xs_batch.len() <= BATCH_SIZE) + ); + let indices_batch: Vec = - (BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect(); + (BATCH_SIZE * batch_i..BATCH_SIZE * batch_i + xs_batch.len()).collect(); let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len()); let mut local_zs_batch = Vec::with_capacity(xs_batch.len()); @@ -379,17 +412,17 @@ fn compute_quotient_polys< // NB (JN): I'm not sure how (in)efficient the below is. It needs measuring. let mut local_constants_batch = vec![F::ZERO; xs_batch.len() * local_constants_batch_refs[0].len()]; - for (i, constants) in local_constants_batch_refs.iter().enumerate() { - for (j, &constant) in constants.iter().enumerate() { - local_constants_batch[i + j * xs_batch.len()] = constant; + for i in 0..local_constants_batch_refs[0].len() { + for (j, constants) in local_constants_batch_refs.iter().enumerate() { + local_constants_batch[i * xs_batch.len() + j] = constants[i]; } } let mut local_wires_batch = vec![F::ZERO; xs_batch.len() * local_wires_batch_refs[0].len()]; - for (i, wires) in local_wires_batch_refs.iter().enumerate() { - for (j, &wire) in wires.iter().enumerate() { - local_wires_batch[i + j * xs_batch.len()] = wire; + for i in 0..local_wires_batch_refs[0].len() { + for (j, wires) in local_wires_batch_refs.iter().enumerate() { + local_wires_batch[i * xs_batch.len() + j] = wires[i]; } } diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/plonk/recursive_verifier.rs index a9af478c..f5fbd2ac 100644 --- a/plonky2/src/plonk/recursive_verifier.rs +++ b/plonky2/src/plonk/recursive_verifier.rs @@ -1,11 +1,12 @@ use plonky2_field::extension_field::Extendable; use crate::hash::hash_types::{HashOutTarget, RichField}; -use crate::iop::challenger::RecursiveChallenger; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; -use crate::plonk::config::AlgebraicConfig; -use crate::plonk::proof::ProofWithPublicInputsTarget; +use crate::plonk::circuit_data::{CommonCircuitData, VerifierCircuitTarget}; +use crate::plonk::config::{AlgebraicHasher, GenericConfig}; +use crate::plonk::proof::{ + OpeningSetTarget, ProofChallengesTarget, ProofTarget, ProofWithPublicInputsTarget, +}; use crate::plonk::vanishing_poly::eval_vanishing_poly_recursively; use crate::plonk::vars::EvaluationTargets; use crate::util::reducing::ReducingFactorTarget; @@ -13,76 +14,74 @@ use crate::with_context; impl, const D: usize> CircuitBuilder { /// Recursively verifies an inner proof. - pub fn add_recursive_verifier>( + pub fn verify_proof>( &mut self, proof_with_pis: ProofWithPublicInputsTarget, - inner_config: &CircuitConfig, inner_verifier_data: &VerifierCircuitTarget, inner_common_data: &CommonCircuitData, - ) { - let ProofWithPublicInputsTarget { - proof, - public_inputs, - } = proof_with_pis; + ) where + C::Hasher: AlgebraicHasher, + { + assert_eq!( + proof_with_pis.public_inputs.len(), + inner_common_data.num_public_inputs + ); + let public_inputs_hash = + self.hash_n_to_hash_no_pad::(proof_with_pis.public_inputs.clone()); + let challenges = proof_with_pis.get_challenges(self, public_inputs_hash, inner_common_data); + + self.verify_proof_with_challenges( + proof_with_pis.proof, + public_inputs_hash, + challenges, + inner_verifier_data, + inner_common_data, + ); + } + + /// Recursively verifies an inner proof. + fn verify_proof_with_challenges>( + &mut self, + proof: ProofTarget, + public_inputs_hash: HashOutTarget, + challenges: ProofChallengesTarget, + inner_verifier_data: &VerifierCircuitTarget, + inner_common_data: &CommonCircuitData, + ) where + C::Hasher: AlgebraicHasher, + { let one = self.one_extension(); - let num_challenges = inner_config.num_challenges; - - let public_inputs_hash = &self.hash_n_to_hash::(public_inputs, true); - - let mut challenger = RecursiveChallenger::new(self); - - let (betas, gammas, alphas, zeta) = - with_context!(self, "observe proof and generates challenges", { - // Observe the instance. - let digest = HashOutTarget::from_vec( - self.constants(&inner_common_data.circuit_digest.elements), - ); - challenger.observe_hash(&digest); - challenger.observe_hash(public_inputs_hash); - - challenger.observe_cap(&proof.wires_cap); - let betas = challenger.get_n_challenges(self, num_challenges); - let gammas = challenger.get_n_challenges(self, num_challenges); - - challenger.observe_cap(&proof.plonk_zs_partial_products_cap); - let alphas = challenger.get_n_challenges(self, num_challenges); - - challenger.observe_cap(&proof.quotient_polys_cap); - let zeta = challenger.get_extension_challenge(self); - - (betas, gammas, alphas, zeta) - }); - let local_constants = &proof.openings.constants; let local_wires = &proof.openings.wires; let vars = EvaluationTargets { local_constants, local_wires, - public_inputs_hash, + public_inputs_hash: &public_inputs_hash, }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_right; let s_sigmas = &proof.openings.plonk_sigmas; let partial_products = &proof.openings.partial_products; - let zeta_pow_deg = self.exp_power_of_2_extension(zeta, inner_common_data.degree_bits); + let zeta_pow_deg = + self.exp_power_of_2_extension(challenges.plonk_zeta, inner_common_data.degree_bits); let vanishing_polys_zeta = with_context!( self, "evaluate the vanishing polynomial at our challenge point, zeta.", eval_vanishing_poly_recursively( self, inner_common_data, - zeta, + challenges.plonk_zeta, zeta_pow_deg, vars, local_zs, next_zs, partial_products, s_sigmas, - &betas, - &gammas, - &alphas, + &challenges.plonk_betas, + &challenges.plonk_gammas, + &challenges.plonk_alphas, ) ); @@ -107,277 +106,95 @@ impl, const D: usize> CircuitBuilder { proof.quotient_polys_cap, ]; + let fri_instance = inner_common_data.get_fri_instance_target(self, challenges.plonk_zeta); with_context!( self, "verify FRI proof", - self.verify_fri_proof( - &proof.openings, - zeta, + self.verify_fri_proof::( + &fri_instance, + &proof.openings.to_fri_openings(), + &challenges.fri_challenges, merkle_caps, &proof.opening_proof, - &mut challenger, - inner_common_data, + &inner_common_data.fri_params, ) ); } + + pub fn add_virtual_proof_with_pis>( + &mut self, + common_data: &CommonCircuitData, + ) -> ProofWithPublicInputsTarget { + let proof = self.add_virtual_proof(common_data); + let public_inputs = self.add_virtual_targets(common_data.num_public_inputs); + ProofWithPublicInputsTarget { + proof, + public_inputs, + } + } + + fn add_virtual_proof>( + &mut self, + common_data: &CommonCircuitData, + ) -> ProofTarget { + let config = &common_data.config; + let fri_params = &common_data.fri_params; + let cap_height = fri_params.config.cap_height; + + let num_leaves_per_oracle = &[ + common_data.num_preprocessed_polys(), + config.num_wires, + common_data.num_zs_partial_products_polys(), + common_data.num_quotient_polys(), + ]; + + ProofTarget { + wires_cap: self.add_virtual_cap(cap_height), + plonk_zs_partial_products_cap: self.add_virtual_cap(cap_height), + quotient_polys_cap: self.add_virtual_cap(cap_height), + openings: self.add_opening_set(common_data), + opening_proof: self.add_virtual_fri_proof(num_leaves_per_oracle, fri_params), + } + } + + fn add_opening_set>( + &mut self, + common_data: &CommonCircuitData, + ) -> OpeningSetTarget { + let config = &common_data.config; + let num_challenges = config.num_challenges; + let total_partial_products = num_challenges * common_data.num_partial_products; + OpeningSetTarget { + constants: self.add_virtual_extension_targets(common_data.num_constants), + plonk_sigmas: self.add_virtual_extension_targets(config.num_routed_wires), + wires: self.add_virtual_extension_targets(config.num_wires), + plonk_zs: self.add_virtual_extension_targets(num_challenges), + plonk_zs_right: self.add_virtual_extension_targets(num_challenges), + partial_products: self.add_virtual_extension_targets(total_partial_products), + quotient_polys: self.add_virtual_extension_targets(common_data.num_quotient_polys()), + } + } } #[cfg(test)] mod tests { use anyhow::Result; use log::{info, Level}; - use plonky2_util::log2_strict; use super::*; - use crate::fri::proof::{ - FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget, - }; use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::FriConfig; - use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::gates::noop::NoopGate; - use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::witness::{PartialWitness, Witness}; - use crate::plonk::circuit_data::VerifierOnlyCircuitData; + use crate::plonk::circuit_data::{CircuitConfig, VerifierOnlyCircuitData}; use crate::plonk::config::{ - GMiMCGoldilocksConfig, GenericConfig, KeccakGoldilocksConfig, PoseidonGoldilocksConfig, - }; - use crate::plonk::proof::{ - CompressedProofWithPublicInputs, OpeningSetTarget, Proof, ProofTarget, - ProofWithPublicInputs, + GenericConfig, Hasher, KeccakGoldilocksConfig, PoseidonGoldilocksConfig, }; + use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use crate::plonk::prover::prove; use crate::util::timing::TimingTree; - // Construct a `FriQueryRoundTarget` with the same dimensions as the ones in `proof`. - fn get_fri_query_round< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - proof: &Proof, - builder: &mut CircuitBuilder, - ) -> FriQueryRoundTarget { - let mut query_round = FriQueryRoundTarget { - initial_trees_proof: FriInitialTreeProofTarget { - evals_proofs: vec![], - }, - steps: vec![], - }; - for (v, merkle_proof) in &proof.opening_proof.query_round_proofs[0] - .initial_trees_proof - .evals_proofs - { - query_round.initial_trees_proof.evals_proofs.push(( - builder.add_virtual_targets(v.len()), - MerkleProofTarget { - siblings: builder.add_virtual_hashes(merkle_proof.siblings.len()), - }, - )); - } - for step in &proof.opening_proof.query_round_proofs[0].steps { - query_round.steps.push(FriQueryStepTarget { - evals: builder.add_virtual_extension_targets(step.evals.len()), - merkle_proof: MerkleProofTarget { - siblings: builder.add_virtual_hashes(step.merkle_proof.siblings.len()), - }, - }); - } - query_round - } - - // Construct a `ProofTarget` with the same dimensions as `proof`. - fn proof_to_proof_target< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - proof_with_pis: &ProofWithPublicInputs, - builder: &mut CircuitBuilder, - ) -> ProofWithPublicInputsTarget { - let ProofWithPublicInputs { - proof, - public_inputs, - } = proof_with_pis; - - let wires_cap = builder.add_virtual_cap(log2_strict(proof.wires_cap.0.len())); - let plonk_zs_cap = - builder.add_virtual_cap(log2_strict(proof.plonk_zs_partial_products_cap.0.len())); - let quotient_polys_cap = - builder.add_virtual_cap(log2_strict(proof.quotient_polys_cap.0.len())); - - let openings = OpeningSetTarget { - constants: builder.add_virtual_extension_targets(proof.openings.constants.len()), - plonk_sigmas: builder.add_virtual_extension_targets(proof.openings.plonk_sigmas.len()), - wires: builder.add_virtual_extension_targets(proof.openings.wires.len()), - plonk_zs: builder.add_virtual_extension_targets(proof.openings.plonk_zs.len()), - plonk_zs_right: builder - .add_virtual_extension_targets(proof.openings.plonk_zs_right.len()), - partial_products: builder - .add_virtual_extension_targets(proof.openings.partial_products.len()), - quotient_polys: builder - .add_virtual_extension_targets(proof.openings.quotient_polys.len()), - }; - let query_round_proofs = (0..proof.opening_proof.query_round_proofs.len()) - .map(|_| get_fri_query_round(proof, builder)) - .collect(); - let commit_phase_merkle_caps = proof - .opening_proof - .commit_phase_merkle_caps - .iter() - .map(|r| builder.add_virtual_cap(log2_strict(r.0.len()))) - .collect(); - let opening_proof = FriProofTarget { - commit_phase_merkle_caps, - query_round_proofs, - final_poly: PolynomialCoeffsExtTarget( - builder.add_virtual_extension_targets(proof.opening_proof.final_poly.len()), - ), - pow_witness: builder.add_virtual_target(), - }; - - let proof = ProofTarget { - wires_cap, - plonk_zs_partial_products_cap: plonk_zs_cap, - quotient_polys_cap, - openings, - opening_proof, - }; - - let public_inputs = builder.add_virtual_targets(public_inputs.len()); - ProofWithPublicInputsTarget { - proof, - public_inputs, - } - } - - // Set the targets in a `ProofTarget` to their corresponding values in a `Proof`. - fn set_proof_target< - F: RichField + Extendable, - C: AlgebraicConfig, - const D: usize, - >( - proof: &ProofWithPublicInputs, - pt: &ProofWithPublicInputsTarget, - pw: &mut PartialWitness, - ) { - let ProofWithPublicInputs { - proof, - public_inputs, - } = proof; - let ProofWithPublicInputsTarget { - proof: pt, - public_inputs: pi_targets, - } = pt; - - // Set public inputs. - for (&pi_t, &pi) in pi_targets.iter().zip(public_inputs) { - pw.set_target(pi_t, pi); - } - - pw.set_cap_target(&pt.wires_cap, &proof.wires_cap); - pw.set_cap_target( - &pt.plonk_zs_partial_products_cap, - &proof.plonk_zs_partial_products_cap, - ); - pw.set_cap_target(&pt.quotient_polys_cap, &proof.quotient_polys_cap); - - for (&t, &x) in pt.openings.wires.iter().zip(&proof.openings.wires) { - pw.set_extension_target(t, x); - } - for (&t, &x) in pt.openings.constants.iter().zip(&proof.openings.constants) { - pw.set_extension_target(t, x); - } - for (&t, &x) in pt - .openings - .plonk_sigmas - .iter() - .zip(&proof.openings.plonk_sigmas) - { - pw.set_extension_target(t, x); - } - for (&t, &x) in pt.openings.plonk_zs.iter().zip(&proof.openings.plonk_zs) { - pw.set_extension_target(t, x); - } - for (&t, &x) in pt - .openings - .plonk_zs_right - .iter() - .zip(&proof.openings.plonk_zs_right) - { - pw.set_extension_target(t, x); - } - for (&t, &x) in pt - .openings - .partial_products - .iter() - .zip(&proof.openings.partial_products) - { - pw.set_extension_target(t, x); - } - for (&t, &x) in pt - .openings - .quotient_polys - .iter() - .zip(&proof.openings.quotient_polys) - { - pw.set_extension_target(t, x); - } - - let fri_proof = &proof.opening_proof; - let fpt = &pt.opening_proof; - - pw.set_target(fpt.pow_witness, fri_proof.pow_witness); - - for (&t, &x) in fpt.final_poly.0.iter().zip(&fri_proof.final_poly.coeffs) { - pw.set_extension_target(t, x); - } - - for (t, x) in fpt - .commit_phase_merkle_caps - .iter() - .zip(&fri_proof.commit_phase_merkle_caps) - { - pw.set_cap_target(t, x); - } - - for (qt, q) in fpt - .query_round_proofs - .iter() - .zip(&fri_proof.query_round_proofs) - { - for (at, a) in qt - .initial_trees_proof - .evals_proofs - .iter() - .zip(&q.initial_trees_proof.evals_proofs) - { - for (&t, &x) in at.0.iter().zip(&a.0) { - pw.set_target(t, x); - } - for (&t, &x) in at.1.siblings.iter().zip(&a.1.siblings) { - pw.set_hash_target(t, x); - } - } - - for (st, s) in qt.steps.iter().zip(&q.steps) { - for (&t, &x) in st.evals.iter().zip(&s.evals) { - pw.set_extension_target(t, x); - } - for (&t, &x) in st - .merkle_proof - .siblings - .iter() - .zip(&s.merkle_proof.siblings) - { - pw.set_hash_target(t, x); - } - } - } - } - #[test] - #[ignore] fn test_recursive_verifier() -> Result<()> { init_logger(); const D: usize = 2; @@ -387,14 +204,13 @@ mod tests { let (proof, vd, cd) = dummy_proof::(&config, 4_000)?; let (proof, _vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, None, true, true)?; + recursive_proof::(proof, vd, cd, &config, None, true, true)?; test_serialization(&proof, &cd)?; Ok(()) } #[test] - #[ignore] fn test_recursive_recursive_verifier() -> Result<()> { init_logger(); const D: usize = 2; @@ -409,12 +225,12 @@ mod tests { // Shrink it to 2^13. let (proof, vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, Some(13), false, false)?; + recursive_proof::(proof, vd, cd, &config, Some(13), false, false)?; assert_eq!(cd.degree_bits, 13); // Shrink it to 2^12. let (proof, _vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, None, true, true)?; + recursive_proof::(proof, vd, cd, &config, None, true, true)?; assert_eq!(cd.degree_bits, 12); test_serialization(&proof, &cd)?; @@ -440,16 +256,7 @@ mod tests { assert_eq!(cd.degree_bits, 12); // A standard recursive proof. - let (proof, vd, cd) = recursive_proof( - proof, - vd, - cd, - &standard_config, - &standard_config, - None, - false, - false, - )?; + let (proof, vd, cd) = recursive_proof(proof, vd, cd, &standard_config, None, false, false)?; assert_eq!(cd.degree_bits, 12); // A high-rate recursive proof, designed to be verifiable with fewer routed wires. @@ -462,16 +269,8 @@ mod tests { }, ..standard_config }; - let (proof, vd, cd) = recursive_proof::( - proof, - vd, - cd, - &standard_config, - &high_rate_config, - None, - true, - true, - )?; + let (proof, vd, cd) = + recursive_proof::(proof, vd, cd, &high_rate_config, None, true, true)?; assert_eq!(cd.degree_bits, 12); // A final proof, optimized for size. @@ -486,16 +285,8 @@ mod tests { }, ..high_rate_config }; - let (proof, _vd, cd) = recursive_proof::( - proof, - vd, - cd, - &high_rate_config, - &final_config, - None, - true, - true, - )?; + let (proof, _vd, cd) = + recursive_proof::(proof, vd, cd, &final_config, None, true, true)?; assert_eq!(cd.degree_bits, 12, "final proof too large"); test_serialization(&proof, &cd)?; @@ -504,12 +295,10 @@ mod tests { } #[test] - #[ignore] fn test_recursive_verifier_multi_hash() -> Result<()> { init_logger(); const D: usize = 2; type PC = PoseidonGoldilocksConfig; - type GC = GMiMCGoldilocksConfig; type KC = KeccakGoldilocksConfig; type F = >::F; @@ -517,19 +306,11 @@ mod tests { let (proof, vd, cd) = dummy_proof::(&config, 4_000)?; let (proof, vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, None, false, false)?; - test_serialization(&proof, &cd)?; - - let (proof, vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, None, false, false)?; - test_serialization(&proof, &cd)?; - - let (proof, vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, None, false, false)?; + recursive_proof::(proof, vd, cd, &config, None, false, false)?; test_serialization(&proof, &cd)?; let (proof, _vd, cd) = - recursive_proof::(proof, vd, cd, &config, &config, None, false, false)?; + recursive_proof::(proof, vd, cd, &config, None, false, false)?; test_serialization(&proof, &cd)?; Ok(()) @@ -543,7 +324,10 @@ mod tests { ProofWithPublicInputs, VerifierOnlyCircuitData, CommonCircuitData, - )> { + )> + where + [(); C::Hasher::HASH_SIZE]:, + { let mut builder = CircuitBuilder::::new(config.clone()); for _ in 0..num_dummy_gates { builder.add_gate(NoopGate, vec![], vec![]); @@ -560,13 +344,12 @@ mod tests { fn recursive_proof< F: RichField + Extendable, C: GenericConfig, - InnerC: AlgebraicConfig, + InnerC: GenericConfig, const D: usize, >( inner_proof: ProofWithPublicInputs, inner_vd: VerifierOnlyCircuitData, inner_cd: CommonCircuitData, - inner_config: &CircuitConfig, config: &CircuitConfig, min_degree_bits: Option, print_gate_counts: bool, @@ -575,21 +358,25 @@ mod tests { ProofWithPublicInputs, VerifierOnlyCircuitData, CommonCircuitData, - )> { + )> + where + InnerC::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, + { let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); - let pt = proof_to_proof_target(&inner_proof, &mut builder); - set_proof_target(&inner_proof, &pt, &mut pw); + let pt = builder.add_virtual_proof_with_pis(&inner_cd); + pw.set_proof_with_pis_target(&pt, &inner_proof); let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_config.fri_config.cap_height), + constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), }; pw.set_cap_target( &inner_data.constants_sigmas_cap, &inner_vd.constants_sigmas_cap, ); - builder.add_recursive_verifier(pt, inner_config, &inner_data, &inner_cd); + builder.verify_proof(pt, &inner_data, &inner_cd); if print_gate_counts { builder.print_gate_counts(0); @@ -626,7 +413,10 @@ mod tests { >( proof: &ProofWithPublicInputs, cd: &CommonCircuitData, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { let proof_bytes = proof.to_bytes()?; info!("Proof length: {} bytes", proof_bytes.len()); let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?; diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index d4c227de..70de5833 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -1,6 +1,7 @@ use plonky2_field::batch_util::batch_add_inplace; use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::field_types::Field; +use plonky2_field::zero_poly_coset::ZeroPolyOnCoset; use crate::gates::gate::PrefixedGate; use crate::hash::hash_types::RichField; @@ -10,7 +11,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common; -use crate::plonk::plonk_common::{eval_l_1_recursively, ZeroPolyOnCoset}; +use crate::plonk::plonk_common::eval_l_1_recursively; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::util::partial_products::{check_partial_products, check_partial_products_recursively}; use crate::util::reducing::ReducingFactorTarget; @@ -37,7 +38,7 @@ pub(crate) fn eval_vanishing_poly< alphas: &[F], ) -> Vec { let max_degree = common_data.quotient_degree_factor; - let (num_prods, _final_num_prod) = common_data.num_partial_products; + let num_prods = common_data.num_partial_products; let constraint_terms = evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); @@ -123,7 +124,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< assert_eq!(s_sigmas_batch.len(), n); let max_degree = common_data.quotient_degree_factor; - let (num_prods, _final_num_prod) = common_data.num_partial_products; + let num_prods = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -302,7 +303,7 @@ pub(crate) fn eval_vanishing_poly_recursively< alphas: &[Target], ) -> Vec> { let max_degree = common_data.quotient_degree_factor; - let (num_prods, _final_num_prod) = common_data.num_partial_products; + let num_prods = common_data.num_partial_products; let constraint_terms = with_context!( builder, diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index c3cf4988..ee0e976f 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -5,9 +5,9 @@ use plonky2_field::field_types::Field; use crate::fri::verifier::verify_fri_proof; use crate::hash::hash_types::RichField; use crate::plonk::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; -use crate::plonk::config::GenericConfig; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::reduce_with_powers; -use crate::plonk::proof::{ProofChallenges, ProofWithPublicInputs}; +use crate::plonk::proof::{Proof, ProofChallenges, ProofWithPublicInputs}; use crate::plonk::vanishing_poly::eval_vanishing_poly; use crate::plonk::vars::EvaluationVars; @@ -15,9 +15,24 @@ pub(crate) fn verify, C: GenericConfig, c proof_with_pis: ProofWithPublicInputs, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, -) -> Result<()> { - let challenges = proof_with_pis.get_challenges(common_data)?; - verify_with_challenges(proof_with_pis, challenges, verifier_data, common_data) +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ + ensure!( + proof_with_pis.public_inputs.len() == common_data.num_public_inputs, + "Number of public inputs doesn't match circuit data." + ); + let public_inputs_hash = proof_with_pis.get_public_inputs_hash(); + let challenges = proof_with_pis.get_challenges(public_inputs_hash, common_data)?; + + verify_with_challenges( + proof_with_pis.proof, + public_inputs_hash, + challenges, + verifier_data, + common_data, + ) } pub(crate) fn verify_with_challenges< @@ -25,21 +40,21 @@ pub(crate) fn verify_with_challenges< C: GenericConfig, const D: usize, >( - proof_with_pis: ProofWithPublicInputs, + proof: Proof, + public_inputs_hash: <>::InnerHasher as Hasher>::Hash, challenges: ProofChallenges, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, -) -> Result<()> { - let public_inputs_hash = &proof_with_pis.get_public_inputs_hash(); - - let ProofWithPublicInputs { proof, .. } = proof_with_pis; - +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ let local_constants = &proof.openings.constants; let local_wires = &proof.openings.wires; let vars = EvaluationVars { local_constants, local_wires, - public_inputs_hash, + public_inputs_hash: &public_inputs_hash, }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_right; @@ -85,12 +100,13 @@ pub(crate) fn verify_with_challenges< proof.quotient_polys_cap, ]; - verify_fri_proof( - &proof.openings, - &challenges, + verify_fri_proof::( + &common_data.get_fri_instance(challenges.plonk_zeta), + &proof.openings.to_fri_openings(), + &challenges.fri_challenges, merkle_caps, &proof.opening_proof, - common_data, + &common_data.fri_params, )?; Ok(()) diff --git a/plonky2/src/util/mod.rs b/plonky2/src/util/mod.rs index 13a72f78..9342a75e 100644 --- a/plonky2/src/util/mod.rs +++ b/plonky2/src/util/mod.rs @@ -6,8 +6,8 @@ pub(crate) mod marking; pub(crate) mod partial_products; pub mod reducing; pub mod serialization; -pub(crate) mod strided_view; -pub(crate) mod timing; +pub mod strided_view; +pub mod timing; pub(crate) fn transpose_poly_values(polys: Vec>) -> Vec> { let poly_values = polys.into_iter().map(|p| p.values).collect::>(); @@ -99,7 +99,18 @@ mod tests { } #[test] - fn test_reverse_index_bits_in_place() { + fn test_reverse_index_bits_in_place_trivial() { + let mut arr1: Vec = vec![10]; + reverse_index_bits_in_place(&mut arr1); + assert_eq!(arr1, vec![10]); + + let mut arr2: Vec = vec![10, 20]; + reverse_index_bits_in_place(&mut arr2); + assert_eq!(arr2, vec![10, 20]); + } + + #[test] + fn test_reverse_index_bits_in_place_small() { let mut arr4: Vec = vec![10, 20, 30, 40]; reverse_index_bits_in_place(&mut arr4); assert_eq!(arr4, vec![10, 30, 20, 40]); @@ -127,4 +138,26 @@ mod tests { reverse_index_bits_in_place(&mut arr256); assert_eq!(arr256, output256); } + + #[test] + fn test_reverse_index_bits_in_place_big_even() { + let mut arr: Vec = (0..1 << 16).collect(); + let target = reverse_index_bits(&arr); + reverse_index_bits_in_place(&mut arr); + assert_eq!(arr, target); + reverse_index_bits_in_place(&mut arr); + let range: Vec = (0..1 << 16).collect(); + assert_eq!(arr, range); + } + + #[test] + fn test_reverse_index_bits_in_place_big_odd() { + let mut arr: Vec = (0..1 << 17).collect(); + let target = reverse_index_bits(&arr); + reverse_index_bits_in_place(&mut arr); + assert_eq!(arr, target); + reverse_index_bits_in_place(&mut arr); + let range: Vec = (0..1 << 17).collect(); + assert_eq!(arr, range); + } } diff --git a/plonky2/src/util/partial_products.rs b/plonky2/src/util/partial_products.rs index cc9012ed..56e9d6ed 100644 --- a/plonky2/src/util/partial_products.rs +++ b/plonky2/src/util/partial_products.rs @@ -35,16 +35,14 @@ pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_product res } -/// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a -/// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. -pub(crate) fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { +/// Returns the length of the output of `partial_products()` on a vector of length `n`. +pub(crate) fn num_partial_products(n: usize, max_degree: usize) -> usize { debug_assert!(max_degree > 1); let chunk_size = max_degree; // We'll split the product into `ceil_div_usize(n, chunk_size)` chunks, but the last chunk will // be associated with Z(gx) itself. Thus we subtract one to get the chunks associated with // partial products. - let num_chunks = ceil_div_usize(n, chunk_size) - 1; - (num_chunks, num_chunks * chunk_size) + ceil_div_usize(n, chunk_size) - 1 } /// Checks the relationship between each pair of partial product accumulators. In particular, this @@ -127,7 +125,7 @@ mod tests { assert_eq!(pps_and_z_gx, field_vec(&[2, 24, 720])); let nums = num_partial_products(v.len(), 2); - assert_eq!(pps.len(), nums.0); + assert_eq!(pps.len(), nums); assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 2) .iter() .all(|x| x.is_zero())); @@ -138,7 +136,7 @@ mod tests { let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; assert_eq!(pps_and_z_gx, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); - assert_eq!(pps.len(), nums.0); + assert_eq!(pps.len(), nums); assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 3) .iter() .all(|x| x.is_zero())); diff --git a/plonky2/src/util/reducing.rs b/plonky2/src/util/reducing.rs index a2d4e4cf..5edd8e1e 100644 --- a/plonky2/src/util/reducing.rs +++ b/plonky2/src/util/reducing.rs @@ -238,7 +238,14 @@ impl ReducingFactorTarget { where F: RichField + Extendable, { - let exp = builder.exp_u64_extension(self.base, self.count); + let zero_ext = builder.zero_extension(); + let exp = if x == zero_ext { + // The result will get zeroed out, so don't actually compute the exponentiation. + zero_ext + } else { + builder.exp_u64_extension(self.base, self.count) + }; + self.count = 0; builder.mul_extension(exp, x) } @@ -253,7 +260,7 @@ mod tests { use anyhow::Result; use super::*; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; @@ -266,7 +273,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let alpha = FF::rand(); @@ -276,7 +283,10 @@ mod tests { let manual_reduce = builder.constant_extension(manual_reduce); let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); - let vs_t = vs.iter().map(|&v| builder.constant(v)).collect::>(); + let vs_t = builder.add_virtual_targets(vs.len()); + for (&v, &v_t) in vs.iter().zip(&vs_t) { + pw.set_target(v_t, v); + } let circuit_reduce = alpha_t.reduce_base(&vs_t, &mut builder); builder.connect_extension(manual_reduce, circuit_reduce); @@ -295,7 +305,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); + let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let alpha = FF::rand(); @@ -305,10 +315,8 @@ mod tests { let manual_reduce = builder.constant_extension(manual_reduce); let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); - let vs_t = vs - .iter() - .map(|&v| builder.constant_extension(v)) - .collect::>(); + let vs_t = builder.add_virtual_extension_targets(vs.len()); + pw.set_extension_targets(&vs_t, &vs); let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder); builder.connect_extension(manual_reduce, circuit_reduce); diff --git a/plonky2/src/util/serialization.rs b/plonky2/src/util/serialization.rs index a9284bf4..d0326073 100644 --- a/plonky2/src/util/serialization.rs +++ b/plonky2/src/util/serialization.rs @@ -3,7 +3,7 @@ use std::io::Cursor; use std::io::{Read, Result, Write}; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::PrimeField; +use plonky2_field::field_types::{Field64, PrimeField64}; use plonky2_field::polynomial::PolynomialCoeffs; use crate::fri::proof::{ @@ -53,10 +53,10 @@ impl Buffer { Ok(u32::from_le_bytes(buf)) } - fn write_field(&mut self, x: F) -> Result<()> { + fn write_field(&mut self, x: F) -> Result<()> { self.0.write_all(&x.to_canonical_u64().to_le_bytes()) } - fn read_field(&mut self) -> Result { + fn read_field(&mut self) -> Result { let mut buf = [0; std::mem::size_of::()]; self.0.read_exact(&mut buf)?; Ok(F::from_canonical_u64(u64::from_le_bytes( @@ -116,13 +116,13 @@ impl Buffer { )) } - pub fn write_field_vec(&mut self, v: &[F]) -> Result<()> { + pub fn write_field_vec(&mut self, v: &[F]) -> Result<()> { for &a in v { self.write_field(a)?; } Ok(()) } - pub fn read_field_vec(&mut self, length: usize) -> Result> { + pub fn read_field_vec(&mut self, length: usize) -> Result> { (0..length) .map(|_| self.read_field()) .collect::>>() @@ -172,9 +172,8 @@ impl Buffer { let wires = self.read_field_ext_vec::(config.num_wires)?; let plonk_zs = self.read_field_ext_vec::(config.num_challenges)?; let plonk_zs_right = self.read_field_ext_vec::(config.num_challenges)?; - let partial_products = self.read_field_ext_vec::( - common_data.num_partial_products.0 * config.num_challenges, - )?; + let partial_products = self + .read_field_ext_vec::(common_data.num_partial_products * config.num_challenges)?; let quotient_polys = self.read_field_ext_vec::( common_data.quotient_degree_factor * config.num_challenges, )?; @@ -248,7 +247,7 @@ impl Buffer { evals_proofs.push((wires_v, wires_p)); let zs_partial_v = - self.read_field_vec(config.num_challenges * (1 + common_data.num_partial_products.0))?; + self.read_field_vec(config.num_challenges * (1 + common_data.num_partial_products))?; let zs_partial_p = self.read_merkle_proof()?; evals_proofs.push((zs_partial_v, zs_partial_p)); diff --git a/starky/Cargo.toml b/starky/Cargo.toml new file mode 100644 index 00000000..4e67856d --- /dev/null +++ b/starky/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "starky" +description = "Implementation of STARKs" +version = "0.1.0" +edition = "2021" + +[dependencies] +plonky2 = { path = "../plonky2" } +plonky2_util = { path = "../util" } +anyhow = "1.0.40" +env_logger = "0.9.0" +itertools = "0.10.0" +log = "0.4.14" +rayon = "1.5.1" diff --git a/starky/src/config.rs b/starky/src/config.rs new file mode 100644 index 00000000..24fb725a --- /dev/null +++ b/starky/src/config.rs @@ -0,0 +1,45 @@ +use plonky2::fri::reduction_strategies::FriReductionStrategy; +use plonky2::fri::{FriConfig, FriParams}; + +pub struct StarkConfig { + pub security_bits: usize, + + /// The number of challenge points to generate, for IOPs that have soundness errors of (roughly) + /// `degree / |F|`. + pub num_challenges: usize, + + pub fri_config: FriConfig, +} + +impl StarkConfig { + /// A typical configuration with a rate of 2, resulting in fast but large proofs. + /// Targets ~100 bit conjectured security. + pub fn standard_fast_config() -> Self { + Self { + security_bits: 100, + num_challenges: 2, + fri_config: FriConfig { + rate_bits: 1, + cap_height: 4, + proof_of_work_bits: 10, + reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5), + num_query_rounds: 90, + }, + } + } + + pub(crate) fn fri_params(&self, degree_bits: usize) -> FriParams { + let fri_config = &self.fri_config; + let reduction_arity_bits = fri_config.reduction_strategy.reduction_arity_bits( + degree_bits, + fri_config.rate_bits, + fri_config.num_query_rounds, + ); + FriParams { + config: fri_config.clone(), + hiding: false, + degree_bits, + reduction_arity_bits, + } + } +} diff --git a/starky/src/constraint_consumer.rs b/starky/src/constraint_consumer.rs new file mode 100644 index 00000000..c909b520 --- /dev/null +++ b/starky/src/constraint_consumer.rs @@ -0,0 +1,144 @@ +use std::marker::PhantomData; + +use plonky2::field::extension_field::Extendable; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +pub struct ConstraintConsumer { + /// Random values used to combine multiple constraints into one. + alphas: Vec, + + /// Running sums of constraints that have been emitted so far, scaled by powers of alpha. + // TODO(JN): This is pub so it can be used in a test. Once we have an API for accessing this + // result, it should be made private. + pub constraint_accs: Vec

, + + /// The evaluation of `X - g^(n-1)`. + z_last: P, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the first trace row, and zero at other points in the subgroup. + lagrange_basis_first: P, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the last trace row, and zero at other points in the subgroup. + lagrange_basis_last: P, +} + +impl ConstraintConsumer

{ + pub fn new( + alphas: Vec, + z_last: P, + lagrange_basis_first: P, + lagrange_basis_last: P, + ) -> Self { + Self { + constraint_accs: vec![P::ZEROS; alphas.len()], + alphas, + z_last, + lagrange_basis_first, + lagrange_basis_last, + } + } + + // TODO: Do this correctly. + pub fn accumulators(self) -> Vec { + self.constraint_accs + .into_iter() + .map(|acc| acc.as_slice()[0]) + .collect() + } + + /// Add one constraint valid on all rows except the last. + pub fn constraint(&mut self, constraint: P) { + self.constraint_wrapping(constraint * self.z_last); + } + + /// Add one constraint on all rows. + pub fn constraint_wrapping(&mut self, constraint: P) { + for (&alpha, acc) in self.alphas.iter().zip(&mut self.constraint_accs) { + *acc *= alpha; + *acc += constraint; + } + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// first row of the trace. + pub fn constraint_first_row(&mut self, constraint: P) { + self.constraint_wrapping(constraint * self.lagrange_basis_first); + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// last row of the trace. + pub fn constraint_last_row(&mut self, constraint: P) { + self.constraint_wrapping(constraint * self.lagrange_basis_last); + } +} + +pub struct RecursiveConstraintConsumer, const D: usize> { + /// A random value used to combine multiple constraints into one. + alpha: Target, + + /// A running sum of constraints that have been emitted so far, scaled by powers of alpha. + constraint_acc: ExtensionTarget, + + /// The evaluation of `X - g^(n-1)`. + z_last: ExtensionTarget, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the first trace row, and zero at other points in the subgroup. + lagrange_basis_first: ExtensionTarget, + + /// The evaluation of the Lagrange basis polynomial which is nonzero at the point associated + /// with the last trace row, and zero at other points in the subgroup. + lagrange_basis_last: ExtensionTarget, + + _phantom: PhantomData, +} + +impl, const D: usize> RecursiveConstraintConsumer { + /// Add one constraint valid on all rows except the last. + pub fn constraint( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + let filtered_constraint = builder.mul_extension(constraint, self.z_last); + self.constraint_wrapping(builder, filtered_constraint); + } + + /// Add one constraint valid on all rows. + pub fn constraint_wrapping( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + self.constraint_acc = + builder.scalar_mul_add_extension(self.alpha, self.constraint_acc, constraint); + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// first row of the trace. + pub fn constraint_first_row( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_first); + self.constraint(builder, filtered_constraint); + } + + /// Add one constraint, but first multiply it by a filter such that it will only apply to the + /// last row of the trace. + pub fn constraint_last_row( + &mut self, + builder: &mut CircuitBuilder, + constraint: ExtensionTarget, + ) { + let filtered_constraint = builder.mul_extension(constraint, self.lagrange_basis_last); + self.constraint(builder, filtered_constraint); + } +} diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs new file mode 100644 index 00000000..c77775e8 --- /dev/null +++ b/starky/src/fibonacci_stark.rs @@ -0,0 +1,142 @@ +use std::marker::PhantomData; + +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::stark::Stark; +use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; + +/// Toy STARK system used for testing. +/// Computes a Fibonacci sequence with state `[x0, x1]` using the state transition +/// `x0 <- x1, x1 <- x0 + x1`. +#[derive(Copy, Clone)] +struct FibonacciStark, const D: usize> { + num_rows: usize, + _phantom: PhantomData, +} + +impl, const D: usize> FibonacciStark { + // The first public input is `x0`. + const PI_INDEX_X0: usize = 0; + // The second public input is `x1`. + const PI_INDEX_X1: usize = 1; + // The third public input is the second element of the last row, which should be equal to the + // `num_rows`-th Fibonacci number. + const PI_INDEX_RES: usize = 2; + + fn new(num_rows: usize) -> Self { + Self { + num_rows, + _phantom: PhantomData, + } + } + + /// Generate the trace using `x0, x1` as inital state values. + fn generate_trace(&self, x0: F, x1: F) -> Vec<[F; Self::COLUMNS]> { + (0..self.num_rows) + .scan([x0, x1], |acc, _| { + let tmp = *acc; + acc[0] = tmp[1]; + acc[1] = tmp[0] + tmp[1]; + Some(tmp) + }) + .collect() + } +} + +impl, const D: usize> Stark for FibonacciStark { + const COLUMNS: usize = 2; + const PUBLIC_INPUTS: usize = 3; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField, + { + // Check public inputs. + yield_constr + .constraint_first_row(vars.local_values[0] - vars.public_inputs[Self::PI_INDEX_X0]); + yield_constr + .constraint_first_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_X1]); + yield_constr + .constraint_last_row(vars.local_values[1] - vars.public_inputs[Self::PI_INDEX_RES]); + + // x0 <- x1 + yield_constr.constraint(vars.next_values[0] - vars.local_values[1]); + // x1 <- x0 + x1 + yield_constr.constraint(vars.next_values[1] - vars.local_values[0] - vars.local_values[1]); + } + + fn eval_ext_recursively( + &self, + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + todo!() + } + + fn constraint_degree(&self) -> usize { + 2 + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::field_types::Field; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::util::timing::TimingTree; + + use crate::config::StarkConfig; + use crate::fibonacci_stark::FibonacciStark; + use crate::prover::prove; + use crate::stark_testing::test_stark_low_degree; + use crate::verifier::verify; + + fn fibonacci(n: usize, x0: F, x1: F) -> F { + (0..n).fold((x0, x1), |x, _| (x.1, x.0 + x.1)).1 + } + + #[test] + fn test_fibonacci_stark() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let public_inputs = [F::ZERO, F::ONE, fibonacci(num_rows - 1, F::ZERO, F::ONE)]; + let stark = S::new(num_rows); + let trace = stark.generate_trace(public_inputs[0], public_inputs[1]); + let proof = prove::( + stark, + &config, + trace, + public_inputs, + &mut TimingTree::default(), + )?; + + verify(stark, proof, &config) + } + + #[test] + fn test_fibonacci_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = FibonacciStark; + + let config = StarkConfig::standard_fast_config(); + let num_rows = 1 << 5; + let stark = S::new(num_rows); + test_stark_low_degree(stark) + } +} diff --git a/starky/src/get_challenges.rs b/starky/src/get_challenges.rs new file mode 100644 index 00000000..7e89ca3e --- /dev/null +++ b/starky/src/get_challenges.rs @@ -0,0 +1,203 @@ +use anyhow::Result; +use plonky2::field::extension_field::Extendable; +use plonky2::field::polynomial::PolynomialCoeffs; +use plonky2::fri::proof::FriProof; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::merkle_tree::MerkleCap; +use plonky2::iop::challenger::Challenger; +use plonky2::plonk::config::GenericConfig; + +use crate::config::StarkConfig; +use crate::proof::{StarkOpeningSet, StarkProof, StarkProofChallenges, StarkProofWithPublicInputs}; + +#[allow(clippy::too_many_arguments)] +fn get_challenges, C: GenericConfig, const D: usize>( + trace_cap: &MerkleCap, + quotient_polys_cap: &MerkleCap, + openings: &StarkOpeningSet, + commit_phase_merkle_caps: &[MerkleCap], + final_poly: &PolynomialCoeffs, + pow_witness: F, + config: &StarkConfig, + degree_bits: usize, +) -> Result> { + let num_challenges = config.num_challenges; + let num_fri_queries = config.fri_config.num_query_rounds; + let lde_size = 1 << (degree_bits + config.fri_config.rate_bits); + + let mut challenger = Challenger::::new(); + + challenger.observe_cap(trace_cap); + let stark_alphas = challenger.get_n_challenges(num_challenges); + + challenger.observe_cap(quotient_polys_cap); + let stark_zeta = challenger.get_extension_challenge::(); + + challenger.observe_openings(&openings.to_fri_openings()); + + Ok(StarkProofChallenges { + stark_alphas, + stark_zeta, + fri_challenges: challenger.fri_challenges::( + commit_phase_merkle_caps, + final_poly, + pow_witness, + degree_bits, + &config.fri_config, + ), + }) +} + +impl, C: GenericConfig, const D: usize> + StarkProofWithPublicInputs +{ + pub(crate) fn fri_query_indices( + &self, + config: &StarkConfig, + degree_bits: usize, + ) -> anyhow::Result> { + Ok(self + .get_challenges(config, degree_bits)? + .fri_challenges + .fri_query_indices) + } + + /// Computes all Fiat-Shamir challenges used in the STARK proof. + pub(crate) fn get_challenges( + &self, + config: &StarkConfig, + degree_bits: usize, + ) -> Result> { + let StarkProof { + trace_cap, + quotient_polys_cap, + openings, + opening_proof: + FriProof { + commit_phase_merkle_caps, + final_poly, + pow_witness, + .. + }, + } = &self.proof; + + get_challenges::( + trace_cap, + quotient_polys_cap, + openings, + commit_phase_merkle_caps, + final_poly, + *pow_witness, + config, + degree_bits, + ) + } +} + +// TODO: Deal with the compressed stuff. +// impl, C: GenericConfig, const D: usize> +// CompressedProofWithPublicInputs +// { +// /// Computes all Fiat-Shamir challenges used in the Plonk proof. +// pub(crate) fn get_challenges( +// &self, +// common_data: &CommonCircuitData, +// ) -> anyhow::Result> { +// let CompressedProof { +// wires_cap, +// plonk_zs_partial_products_cap, +// quotient_polys_cap, +// openings, +// opening_proof: +// CompressedFriProof { +// commit_phase_merkle_caps, +// final_poly, +// pow_witness, +// .. +// }, +// } = &self.proof; +// +// get_challenges( +// self.get_public_inputs_hash(), +// wires_cap, +// plonk_zs_partial_products_cap, +// quotient_polys_cap, +// openings, +// commit_phase_merkle_caps, +// final_poly, +// *pow_witness, +// common_data, +// ) +// } +// +// /// Computes all coset elements that can be inferred in the FRI reduction steps. +// pub(crate) fn get_inferred_elements( +// &self, +// challenges: &ProofChallenges, +// common_data: &CommonCircuitData, +// ) -> FriInferredElements { +// let ProofChallenges { +// plonk_zeta, +// fri_alpha, +// fri_betas, +// fri_query_indices, +// .. +// } = challenges; +// let mut fri_inferred_elements = Vec::new(); +// // Holds the indices that have already been seen at each reduction depth. +// let mut seen_indices_by_depth = +// vec![HashSet::new(); common_data.fri_params.reduction_arity_bits.len()]; +// let precomputed_reduced_evals = PrecomputedReducedOpenings::from_os_and_alpha( +// &self.proof.openings.to_fri_openings(), +// *fri_alpha, +// ); +// let log_n = common_data.degree_bits + common_data.config.fri_config.rate_bits; +// // Simulate the proof verification and collect the inferred elements. +// // The content of the loop is basically the same as the `fri_verifier_query_round` function. +// for &(mut x_index) in fri_query_indices { +// let mut subgroup_x = F::MULTIPLICATIVE_GROUP_GENERATOR +// * F::primitive_root_of_unity(log_n).exp_u64(reverse_bits(x_index, log_n) as u64); +// let mut old_eval = fri_combine_initial::( +// &common_data.get_fri_instance(*plonk_zeta), +// &self +// .proof +// .opening_proof +// .query_round_proofs +// .initial_trees_proofs[&x_index], +// *fri_alpha, +// subgroup_x, +// &precomputed_reduced_evals, +// &common_data.fri_params, +// ); +// for (i, &arity_bits) in common_data +// .fri_params +// .reduction_arity_bits +// .iter() +// .enumerate() +// { +// let coset_index = x_index >> arity_bits; +// if !seen_indices_by_depth[i].insert(coset_index) { +// // If this index has already been seen, we can skip the rest of the reductions. +// break; +// } +// fri_inferred_elements.push(old_eval); +// let arity = 1 << arity_bits; +// let mut evals = self.proof.opening_proof.query_round_proofs.steps[i][&coset_index] +// .evals +// .clone(); +// let x_index_within_coset = x_index & (arity - 1); +// evals.insert(x_index_within_coset, old_eval); +// old_eval = compute_evaluation( +// subgroup_x, +// x_index_within_coset, +// arity_bits, +// &evals, +// fri_betas[i], +// ); +// subgroup_x = subgroup_x.exp_power_of_2(arity_bits); +// x_index = coset_index; +// } +// } +// FriInferredElements(fri_inferred_elements) +// } +// } diff --git a/starky/src/lib.rs b/starky/src/lib.rs new file mode 100644 index 00000000..eefab529 --- /dev/null +++ b/starky/src/lib.rs @@ -0,0 +1,18 @@ +// TODO: Remove these when crate is closer to being finished. +#![allow(dead_code)] +#![allow(unused_variables)] +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] + +pub mod config; +pub mod constraint_consumer; +mod get_challenges; +pub mod proof; +pub mod prover; +pub mod stark; +pub mod stark_testing; +pub mod vars; +pub mod verifier; + +#[cfg(test)] +pub mod fibonacci_stark; diff --git a/starky/src/proof.rs b/starky/src/proof.rs new file mode 100644 index 00000000..b7ecd912 --- /dev/null +++ b/starky/src/proof.rs @@ -0,0 +1,114 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::fri::oracle::PolynomialBatch; +use plonky2::fri::proof::{CompressedFriProof, FriChallenges, FriProof}; +use plonky2::fri::structure::{FriOpeningBatch, FriOpenings}; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::merkle_tree::MerkleCap; +use plonky2::plonk::config::GenericConfig; +use rayon::prelude::*; + +pub struct StarkProof, C: GenericConfig, const D: usize> { + /// Merkle cap of LDEs of trace values. + pub trace_cap: MerkleCap, + /// Merkle cap of LDEs of trace values. + pub quotient_polys_cap: MerkleCap, + /// Purported values of each polynomial at the challenge point. + pub openings: StarkOpeningSet, + /// A batch FRI argument for all openings. + pub opening_proof: FriProof, +} + +pub struct StarkProofWithPublicInputs< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub proof: StarkProof, + // TODO: Maybe make it generic over a `S: Stark` and replace with `[F; S::PUBLIC_INPUTS]`. + pub public_inputs: Vec, +} + +pub struct CompressedStarkProof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + /// Merkle cap of LDEs of trace values. + pub trace_cap: MerkleCap, + /// Purported values of each polynomial at the challenge point. + pub openings: StarkOpeningSet, + /// A batch FRI argument for all openings. + pub opening_proof: CompressedFriProof, +} + +pub struct CompressedStarkProofWithPublicInputs< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub proof: CompressedStarkProof, + pub public_inputs: Vec, +} + +pub(crate) struct StarkProofChallenges, const D: usize> { + /// Random values used to combine STARK constraints. + pub stark_alphas: Vec, + + /// Point at which the STARK polynomials are opened. + pub stark_zeta: F::Extension, + + pub fri_challenges: FriChallenges, +} + +/// Purported values of each polynomial at the challenge point. +pub struct StarkOpeningSet, const D: usize> { + pub local_values: Vec, + pub next_values: Vec, + pub permutation_zs: Vec, + pub permutation_zs_right: Vec, + pub quotient_polys: Vec, +} + +impl, const D: usize> StarkOpeningSet { + pub fn new>( + zeta: F::Extension, + g: F::Extension, + trace_commitment: &PolynomialBatch, + quotient_commitment: &PolynomialBatch, + ) -> Self { + let eval_commitment = |z: F::Extension, c: &PolynomialBatch| { + c.polynomials + .par_iter() + .map(|p| p.to_extension().eval(z)) + .collect::>() + }; + Self { + local_values: eval_commitment(zeta, trace_commitment), + next_values: eval_commitment(zeta * g, trace_commitment), + permutation_zs: vec![/*TODO*/], + permutation_zs_right: vec![/*TODO*/], + quotient_polys: eval_commitment(zeta, quotient_commitment), + } + } + + pub(crate) fn to_fri_openings(&self) -> FriOpenings { + let zeta_batch = FriOpeningBatch { + values: [ + self.local_values.as_slice(), + self.quotient_polys.as_slice(), + self.permutation_zs.as_slice(), + ] + .concat(), + }; + let zeta_right_batch = FriOpeningBatch { + values: [ + self.next_values.as_slice(), + self.permutation_zs_right.as_slice(), + ] + .concat(), + }; + FriOpenings { + batches: vec![zeta_batch, zeta_right_batch], + } + } +} diff --git a/starky/src/prover.rs b/starky/src/prover.rs new file mode 100644 index 00000000..e88aa619 --- /dev/null +++ b/starky/src/prover.rs @@ -0,0 +1,227 @@ +use anyhow::{ensure, Result}; +use itertools::Itertools; +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues}; +use plonky2::field::zero_poly_coset::ZeroPolyOnCoset; +use plonky2::fri::oracle::PolynomialBatch; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::challenger::Challenger; +use plonky2::plonk::config::{GenericConfig, Hasher}; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; +use plonky2_util::{log2_ceil, log2_strict}; +use rayon::prelude::*; + +use crate::config::StarkConfig; +use crate::constraint_consumer::ConstraintConsumer; +use crate::proof::{StarkOpeningSet, StarkProof, StarkProofWithPublicInputs}; +use crate::stark::Stark; +use crate::vars::StarkEvaluationVars; + +pub fn prove( + stark: S, + config: &StarkConfig, + trace: Vec<[F; S::COLUMNS]>, + public_inputs: [F; S::PUBLIC_INPUTS], + timing: &mut TimingTree, +) -> Result> +where + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + [(); C::Hasher::HASH_SIZE]:, +{ + let degree = trace.len(); + let degree_bits = log2_strict(degree); + + let trace_vecs = trace.into_iter().map(|row| row.to_vec()).collect_vec(); + let trace_col_major: Vec> = transpose(&trace_vecs); + + let trace_poly_values: Vec> = timed!( + timing, + "compute trace polynomials", + trace_col_major + .par_iter() + .map(|column| PolynomialValues::new(column.clone())) + .collect() + ); + + let rate_bits = config.fri_config.rate_bits; + let cap_height = config.fri_config.cap_height; + let trace_commitment = timed!( + timing, + "compute trace commitment", + PolynomialBatch::::from_values( + trace_poly_values, + rate_bits, + false, + cap_height, + timing, + None, + ) + ); + + let trace_cap = trace_commitment.merkle_tree.cap.clone(); + let mut challenger = Challenger::new(); + challenger.observe_cap(&trace_cap); + + let alphas = challenger.get_n_challenges(config.num_challenges); + let quotient_polys = compute_quotient_polys::( + &stark, + &trace_commitment, + public_inputs, + alphas, + degree_bits, + rate_bits, + ); + let all_quotient_chunks = quotient_polys + .into_par_iter() + .flat_map(|mut quotient_poly| { + quotient_poly + .trim_to_len(degree * stark.quotient_degree_factor()) + .expect("Quotient has failed, the vanishing polynomial is not divisible by Z_H"); + // Split quotient into degree-n chunks. + quotient_poly.chunks(degree) + }) + .collect(); + let quotient_commitment = timed!( + timing, + "compute quotient commitment", + PolynomialBatch::from_coeffs( + all_quotient_chunks, + rate_bits, + false, + config.fri_config.cap_height, + timing, + None, + ) + ); + let quotient_polys_cap = quotient_commitment.merkle_tree.cap.clone(); + challenger.observe_cap("ient_polys_cap); + + let zeta = challenger.get_extension_challenge::(); + // To avoid leaking witness data, we want to ensure that our opening locations, `zeta` and + // `g * zeta`, are not in our subgroup `H`. It suffices to check `zeta` only, since + // `(g * zeta)^n = zeta^n`, where `n` is the order of `g`. + let g = F::Extension::primitive_root_of_unity(degree_bits); + ensure!( + zeta.exp_power_of_2(degree_bits) != F::Extension::ONE, + "Opening point is in the subgroup." + ); + let openings = StarkOpeningSet::new(zeta, g, &trace_commitment, "ient_commitment); + challenger.observe_openings(&openings.to_fri_openings()); + + // TODO: Add permuation checks + let initial_merkle_trees = &[&trace_commitment, "ient_commitment]; + let fri_params = config.fri_params(degree_bits); + + let opening_proof = timed!( + timing, + "compute openings proof", + PolynomialBatch::prove_openings( + &stark.fri_instance(zeta, g, rate_bits, config.num_challenges), + initial_merkle_trees, + &mut challenger, + &fri_params, + timing, + ) + ); + let proof = StarkProof { + trace_cap, + quotient_polys_cap, + openings, + opening_proof, + }; + + Ok(StarkProofWithPublicInputs { + proof, + public_inputs: public_inputs.to_vec(), + }) +} + +/// Computes the quotient polynomials `(sum alpha^i C_i(x)) / Z_H(x)` for `alpha` in `alphas`, +/// where the `C_i`s are the Stark constraints. +fn compute_quotient_polys( + stark: &S, + trace_commitment: &PolynomialBatch, + public_inputs: [F; S::PUBLIC_INPUTS], + alphas: Vec, + degree_bits: usize, + rate_bits: usize, +) -> Vec> +where + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, +{ + let degree = 1 << degree_bits; + + let quotient_degree_bits = log2_ceil(stark.quotient_degree_factor()); + assert!( + quotient_degree_bits <= rate_bits, + "Having constraints of degree higher than the rate is not supported yet." + ); + let step = 1 << (rate_bits - quotient_degree_bits); + // When opening the `Z`s polys at the "next" point, need to look at the point `next_step` steps away. + let next_step = 1 << quotient_degree_bits; + + // Evaluation of the first Lagrange polynomial on the LDE domain. + let lagrange_first = PolynomialValues::selector(degree, 0).lde_onto_coset(quotient_degree_bits); + // Evaluation of the last Lagrange polynomial on the LDE domain. + let lagrange_last = + PolynomialValues::selector(degree, degree - 1).lde_onto_coset(quotient_degree_bits); + + let z_h_on_coset = ZeroPolyOnCoset::::new(degree_bits, quotient_degree_bits); + + // Retrieve the LDE values at index `i`. + let get_at_index = |comm: &PolynomialBatch, i: usize| -> [F; S::COLUMNS] { + comm.get_lde_values(i * step).try_into().unwrap() + }; + // Last element of the subgroup. + let last = F::primitive_root_of_unity(degree_bits).inverse(); + let size = degree << quotient_degree_bits; + let coset = F::cyclic_subgroup_coset_known_order( + F::primitive_root_of_unity(degree_bits + quotient_degree_bits), + F::coset_shift(), + size, + ); + + let quotient_values = (0..size) + .into_par_iter() + .map(|i| { + // TODO: Set `P` to a genuine `PackedField` here. + let mut consumer = ConstraintConsumer::::new( + alphas.clone(), + coset[i] - last, + lagrange_first.values[i], + lagrange_last.values[i], + ); + let vars = StarkEvaluationVars:: { + local_values: &get_at_index(trace_commitment, i), + next_values: &get_at_index(trace_commitment, (i + next_step) % size), + public_inputs: &public_inputs, + }; + stark.eval_packed_base(vars, &mut consumer); + // TODO: Fix this once we use a genuine `PackedField`. + let mut constraints_evals = consumer.accumulators(); + // We divide the constraints evaluations by `Z_H(x)`. + let denominator_inv = z_h_on_coset.eval_inverse(i); + for eval in &mut constraints_evals { + *eval *= denominator_inv; + } + constraints_evals + }) + .collect::>(); + + transpose("ient_values) + .into_par_iter() + .map(PolynomialValues::new) + .map(|values| values.coset_ifft(F::coset_shift())) + .collect() +} diff --git a/starky/src/stark.rs b/starky/src/stark.rs new file mode 100644 index 00000000..4b20553e --- /dev/null +++ b/starky/src/stark.rs @@ -0,0 +1,99 @@ +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::packed_field::PackedField; +use plonky2::fri::structure::{FriBatchInfo, FriInstanceInfo, FriOracleInfo, FriPolynomialInfo}; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::vars::StarkEvaluationTargets; +use crate::vars::StarkEvaluationVars; + +/// Represents a STARK system. +// TODO: Add a `constraint_degree` fn that returns the maximum constraint degree. +pub trait Stark, const D: usize>: Sync { + /// The total number of columns in the trace. + const COLUMNS: usize; + /// The number of public inputs. + const PUBLIC_INPUTS: usize; + + /// Evaluate constraints at a vector of points. + /// + /// The points are elements of a field `FE`, a degree `D2` extension of `F`. This lets us + /// evaluate constraints over a larger domain if desired. This can also be called with `FE = F` + /// and `D2 = 1`, in which case we are using the trivial extension, i.e. just evaluating + /// constraints over `F`. + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField; + + /// Evaluate constraints at a vector of points from the base field `F`. + fn eval_packed_base>( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, + ) { + self.eval_packed_generic(vars, yield_constr) + } + + /// Evaluate constraints at a single point from the degree `D` extension field. + fn eval_ext( + &self, + vars: StarkEvaluationVars< + F::Extension, + F::Extension, + { Self::COLUMNS }, + { Self::PUBLIC_INPUTS }, + >, + yield_constr: &mut ConstraintConsumer, + ) { + self.eval_packed_generic(vars, yield_constr) + } + + /// Evaluate constraints at a vector of points from the degree `D` extension field. This is like + /// `eval_ext`, except in the context of a recursive circuit. + fn eval_ext_recursively( + &self, + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ); + + /// The maximum constraint degree. + fn constraint_degree(&self) -> usize; + + /// The maximum constraint degree. + fn quotient_degree_factor(&self) -> usize { + 1.max(self.constraint_degree() - 1) + } + + /// Computes the FRI instance used to prove this Stark. + // TODO: Permutation polynomials. + fn fri_instance( + &self, + zeta: F::Extension, + g: F::Extension, + rate_bits: usize, + num_challenges: usize, + ) -> FriInstanceInfo { + let no_blinding_oracle = FriOracleInfo { blinding: false }; + let trace_info = FriPolynomialInfo::from_range(0, 0..Self::COLUMNS); + let quotient_info = + FriPolynomialInfo::from_range(1, 0..self.quotient_degree_factor() * num_challenges); + let zeta_batch = FriBatchInfo { + point: zeta, + polynomials: [trace_info.clone(), quotient_info].concat(), + }; + let zeta_right_batch = FriBatchInfo:: { + point: zeta * g, + polynomials: trace_info, + }; + FriInstanceInfo { + oracles: vec![no_blinding_oracle; 3], + batches: vec![zeta_batch], + } + } +} diff --git a/starky/src/stark_testing.rs b/starky/src/stark_testing.rs new file mode 100644 index 00000000..222ebf39 --- /dev/null +++ b/starky/src/stark_testing.rs @@ -0,0 +1,87 @@ +use anyhow::{ensure, Result}; +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues}; +use plonky2::hash::hash_types::RichField; +use plonky2::util::transpose; +use plonky2_util::{log2_ceil, log2_strict}; + +use crate::constraint_consumer::ConstraintConsumer; +use crate::stark::Stark; +use crate::vars::StarkEvaluationVars; + +const WITNESS_SIZE: usize = 1 << 5; + +/// Tests that the constraints imposed by the given STARK are low-degree by applying them to random +/// low-degree witness polynomials. +pub fn test_stark_low_degree, S: Stark, const D: usize>( + stark: S, +) -> Result<()> +where + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, +{ + let rate_bits = log2_ceil(stark.constraint_degree() + 1); + + let trace_ldes = random_low_degree_matrix::(S::COLUMNS, rate_bits); + let size = trace_ldes.len(); + let public_inputs = F::rand_arr::<{ S::PUBLIC_INPUTS }>(); + + let lagrange_first = PolynomialValues::selector(WITNESS_SIZE, 0).lde(rate_bits); + let lagrange_last = PolynomialValues::selector(WITNESS_SIZE, WITNESS_SIZE - 1).lde(rate_bits); + + let last = F::primitive_root_of_unity(log2_strict(WITNESS_SIZE)).inverse(); + let subgroup = + F::cyclic_subgroup_known_order(F::primitive_root_of_unity(log2_strict(size)), size); + let alpha = F::rand(); + let constraint_evals = (0..size) + .map(|i| { + let vars = StarkEvaluationVars { + local_values: &trace_ldes[i].clone().try_into().unwrap(), + next_values: &trace_ldes[(i + (1 << rate_bits)) % size] + .clone() + .try_into() + .unwrap(), + public_inputs: &public_inputs, + }; + + let mut consumer = ConstraintConsumer::::new( + vec![alpha], + subgroup[i] - last, + lagrange_first.values[i], + lagrange_last.values[i], + ); + stark.eval_packed_base(vars, &mut consumer); + consumer.accumulators()[0] + }) + .collect::>(); + + let constraint_eval_degree = PolynomialValues::new(constraint_evals).degree(); + let maximum_degree = WITNESS_SIZE * stark.constraint_degree() - 1; + + ensure!( + constraint_eval_degree <= maximum_degree, + "Expected degrees at most {} * {} - 1 = {}, actual {:?}", + WITNESS_SIZE, + stark.constraint_degree(), + maximum_degree, + constraint_eval_degree + ); + + Ok(()) +} + +fn random_low_degree_matrix(num_polys: usize, rate_bits: usize) -> Vec> { + let polys = (0..num_polys) + .map(|_| random_low_degree_values(rate_bits)) + .collect::>(); + + transpose(&polys) +} + +fn random_low_degree_values(rate_bits: usize) -> Vec { + PolynomialCoeffs::new(F::rand_vec(WITNESS_SIZE)) + .lde(rate_bits) + .fft() + .values +} diff --git a/starky/src/vars.rs b/starky/src/vars.rs new file mode 100644 index 00000000..cb83aeb7 --- /dev/null +++ b/starky/src/vars.rs @@ -0,0 +1,26 @@ +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::iop::ext_target::ExtensionTarget; + +#[derive(Debug, Copy, Clone)] +pub struct StarkEvaluationVars<'a, F, P, const COLUMNS: usize, const PUBLIC_INPUTS: usize> +where + F: Field, + P: PackedField, +{ + pub local_values: &'a [P; COLUMNS], + pub next_values: &'a [P; COLUMNS], + pub public_inputs: &'a [P::Scalar; PUBLIC_INPUTS], +} + +#[derive(Debug, Copy, Clone)] +pub struct StarkEvaluationTargets< + 'a, + const D: usize, + const COLUMNS: usize, + const PUBLIC_INPUTS: usize, +> { + pub local_values: &'a [ExtensionTarget; COLUMNS], + pub next_values: &'a [ExtensionTarget; COLUMNS], + pub public_inputs: &'a [ExtensionTarget; PUBLIC_INPUTS], +} diff --git a/starky/src/verifier.rs b/starky/src/verifier.rs new file mode 100644 index 00000000..8bf1faab --- /dev/null +++ b/starky/src/verifier.rs @@ -0,0 +1,176 @@ +use anyhow::{ensure, Result}; +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::field_types::Field; +use plonky2::fri::verifier::verify_fri_proof; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::config::{GenericConfig, Hasher}; +use plonky2::plonk::plonk_common::reduce_with_powers; +use plonky2_util::log2_strict; + +use crate::config::StarkConfig; +use crate::constraint_consumer::ConstraintConsumer; +use crate::proof::{StarkOpeningSet, StarkProof, StarkProofChallenges, StarkProofWithPublicInputs}; +use crate::stark::Stark; +use crate::vars::StarkEvaluationVars; + +pub fn verify< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + const D: usize, +>( + stark: S, + proof_with_pis: StarkProofWithPublicInputs, + config: &StarkConfig, +) -> Result<()> +where + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + [(); C::Hasher::HASH_SIZE]:, +{ + let degree_bits = log2_strict(recover_degree(&proof_with_pis.proof, config)); + let challenges = proof_with_pis.get_challenges(config, degree_bits)?; + verify_with_challenges(stark, proof_with_pis, challenges, degree_bits, config) +} + +pub(crate) fn verify_with_challenges< + F: RichField + Extendable, + C: GenericConfig, + S: Stark, + const D: usize, +>( + stark: S, + proof_with_pis: StarkProofWithPublicInputs, + challenges: StarkProofChallenges, + degree_bits: usize, + config: &StarkConfig, +) -> Result<()> +where + [(); S::COLUMNS]:, + [(); S::PUBLIC_INPUTS]:, + [(); C::Hasher::HASH_SIZE]:, +{ + let StarkProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + let local_values = &proof.openings.local_values; + let next_values = &proof.openings.local_values; + let StarkOpeningSet { + local_values, + next_values, + permutation_zs, + permutation_zs_right, + quotient_polys, + } = &proof.openings; + let vars = StarkEvaluationVars { + local_values: &local_values.to_vec().try_into().unwrap(), + next_values: &next_values.to_vec().try_into().unwrap(), + public_inputs: &public_inputs + .into_iter() + .map(F::Extension::from_basefield) + .collect::>() + .try_into() + .unwrap(), + }; + + let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta); + let last = F::primitive_root_of_unity(degree_bits).inverse(); + let z_last = challenges.stark_zeta - last.into(); + let mut consumer = ConstraintConsumer::::new( + challenges + .stark_alphas + .iter() + .map(|&alpha| F::Extension::from_basefield(alpha)) + .collect::>(), + z_last, + l_1, + l_last, + ); + stark.eval_ext(vars, &mut consumer); + let vanishing_polys_zeta = consumer.accumulators(); + + // Check each polynomial identity, of the form `vanishing(x) = Z_H(x) quotient(x)`, at zeta. + let quotient_polys_zeta = &proof.openings.quotient_polys; + let zeta_pow_deg = challenges.stark_zeta.exp_power_of_2(degree_bits); + let z_h_zeta = zeta_pow_deg - F::Extension::ONE; + // `quotient_polys_zeta` holds `num_challenges * quotient_degree_factor` evaluations. + // Each chunk of `quotient_degree_factor` holds the evaluations of `t_0(zeta),...,t_{quotient_degree_factor-1}(zeta)` + // where the "real" quotient polynomial is `t(X) = t_0(X) + t_1(X)*X^n + t_2(X)*X^{2n} + ...`. + // So to reconstruct `t(zeta)` we can compute `reduce_with_powers(chunk, zeta^n)` for each + // `quotient_degree_factor`-sized chunk of the original evaluations. + for (i, chunk) in quotient_polys_zeta + .chunks(stark.quotient_degree_factor()) + .enumerate() + { + ensure!(vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg)); + } + + // TODO: Permutation polynomials. + let merkle_caps = &[proof.trace_cap, proof.quotient_polys_cap]; + + verify_fri_proof::( + &stark.fri_instance( + challenges.stark_zeta, + F::primitive_root_of_unity(degree_bits).into(), + config.fri_config.rate_bits, + config.num_challenges, + ), + &proof.openings.to_fri_openings(), + &challenges.fri_challenges, + merkle_caps, + &proof.opening_proof, + &config.fri_params(degree_bits), + )?; + + Ok(()) +} + +/// Evaluate the Lagrange polynomials `L_1` and `L_n` at a point `x`. +/// `L_1(x) = (x^n - 1)/(n * (x - 1))` +/// `L_n(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. +fn eval_l_1_and_l_last(log_n: usize, x: F) -> (F, F) { + let n = F::from_canonical_usize(1 << log_n); + let g = F::primitive_root_of_unity(log_n); + let z_x = x.exp_power_of_2(log_n) - F::ONE; + let invs = F::batch_multiplicative_inverse(&[n * (x - F::ONE), n * (g * x - F::ONE)]); + + (z_x * invs[0], z_x * invs[1]) +} + +/// Recover the length of the trace from a STARK proof and a STARK config. +fn recover_degree, C: GenericConfig, const D: usize>( + proof: &StarkProof, + config: &StarkConfig, +) -> usize { + let initial_merkle_proof = &proof.opening_proof.query_round_proofs[0] + .initial_trees_proof + .evals_proofs[0] + .1; + let lde_bits = config.fri_config.cap_height + initial_merkle_proof.siblings.len(); + 1 << (lde_bits - config.fri_config.rate_bits) +} + +#[cfg(test)] +mod tests { + use plonky2::field::field_types::Field; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::polynomial::PolynomialValues; + + use crate::verifier::eval_l_1_and_l_last; + + #[test] + fn test_eval_l_1_and_l_last() { + type F = GoldilocksField; + let log_n = 5; + let n = 1 << log_n; + + let x = F::rand(); // challenge point + let expected_l_first_x = PolynomialValues::selector(n, 0).ifft().eval(x); + let expected_l_last_x = PolynomialValues::selector(n, n - 1).ifft().eval(x); + + let (l_first_x, l_last_x) = eval_l_1_and_l_last(log_n, x); + assert_eq!(l_first_x, expected_l_first_x); + assert_eq!(l_last_x, expected_l_last_x); + } +} diff --git a/system_zero/Cargo.toml b/system_zero/Cargo.toml new file mode 100644 index 00000000..e5b617c9 --- /dev/null +++ b/system_zero/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "system_zero" +description = "A VM whose execution can be verified with STARKs; designed for proving Ethereum transactions" +version = "0.1.0" +edition = "2021" + +[dependencies] +plonky2 = { path = "../plonky2" } +starky = { path = "../starky" } +anyhow = "1.0.40" +env_logger = "0.9.0" +log = "0.4.14" +rand = "0.8.4" +rand_chacha = "0.3.1" diff --git a/system_zero/src/arithmetic/addition.rs b/system_zero/src/arithmetic/addition.rs new file mode 100644 index 00000000..7aa0d81a --- /dev/null +++ b/system_zero/src/arithmetic/addition.rs @@ -0,0 +1,70 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::plonk_common::reduce_with_powers_ext_recursive; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { + let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64(); + let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64(); + let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64(); + let output = in_1 + in_2 + in_3; + + values[COL_ADD_OUTPUT_1] = F::from_canonical_u16(output as u16); + values[COL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 16) as u16); + values[COL_ADD_OUTPUT_3] = F::from_canonical_u16((output >> 32) as u16); +} + +pub(crate) fn eval_addition>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_add = local_values[IS_ADD]; + let in_1 = local_values[COL_ADD_INPUT_1]; + let in_2 = local_values[COL_ADD_INPUT_2]; + let in_3 = local_values[COL_ADD_INPUT_3]; + let out_1 = local_values[COL_ADD_OUTPUT_1]; + let out_2 = local_values[COL_ADD_OUTPUT_2]; + let out_3 = local_values[COL_ADD_OUTPUT_3]; + + let weight_2 = F::from_canonical_u64(1 << 16); + let weight_3 = F::from_canonical_u64(1 << 32); + // Note that this can't overflow. Since each output limb has been range checked as 16-bits, + // this sum can be around 48 bits at most. + let out = out_1 + out_2 * weight_2 + out_3 * weight_3; + + let computed_out = in_1 + in_2 + in_3; + + yield_constr.constraint_wrapping(is_add * (out - computed_out)); +} + +pub(crate) fn eval_addition_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_add = local_values[IS_ADD]; + let in_1 = local_values[COL_ADD_INPUT_1]; + let in_2 = local_values[COL_ADD_INPUT_2]; + let in_3 = local_values[COL_ADD_INPUT_3]; + let out_1 = local_values[COL_ADD_OUTPUT_1]; + let out_2 = local_values[COL_ADD_OUTPUT_2]; + let out_3 = local_values[COL_ADD_OUTPUT_3]; + + let limb_base = builder.constant(F::from_canonical_u64(1 << 16)); + // Note that this can't overflow. Since each output limb has been range checked as 16-bits, + // this sum can be around 48 bits at most. + let out = reduce_with_powers_ext_recursive(builder, &[out_1, out_2, out_3], limb_base); + + let computed_out = builder.add_many_extension(&[in_1, in_2, in_3]); + + let diff = builder.sub_extension(out, computed_out); + let filtered_diff = builder.mul_extension(is_add, diff); + yield_constr.constraint_wrapping(builder, filtered_diff); +} diff --git a/system_zero/src/arithmetic/division.rs b/system_zero/src/arithmetic/division.rs new file mode 100644 index 00000000..e91288b9 --- /dev/null +++ b/system_zero/src/arithmetic/division.rs @@ -0,0 +1,31 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_division(values: &mut [F; NUM_COLUMNS]) { + // TODO +} + +pub(crate) fn eval_division>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_div = local_values[IS_DIV]; + // TODO +} + +pub(crate) fn eval_division_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_div = local_values[IS_DIV]; + // TODO +} diff --git a/system_zero/src/arithmetic/mod.rs b/system_zero/src/arithmetic/mod.rs new file mode 100644 index 00000000..a2b3a4f8 --- /dev/null +++ b/system_zero/src/arithmetic/mod.rs @@ -0,0 +1,75 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::vars::StarkEvaluationTargets; +use starky::vars::StarkEvaluationVars; + +use crate::arithmetic::addition::{eval_addition, eval_addition_recursively, generate_addition}; +use crate::arithmetic::division::{eval_division, eval_division_recursively, generate_division}; +use crate::arithmetic::multiplication::{ + eval_multiplication, eval_multiplication_recursively, generate_multiplication, +}; +use crate::arithmetic::subtraction::{ + eval_subtraction, eval_subtraction_recursively, generate_subtraction, +}; +use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +mod addition; +mod division; +mod multiplication; +mod subtraction; + +pub(crate) fn generate_arithmetic_unit(values: &mut [F; NUM_COLUMNS]) { + if values[IS_ADD].is_one() { + generate_addition(values); + } else if values[IS_SUB].is_one() { + generate_subtraction(values); + } else if values[IS_MUL].is_one() { + generate_multiplication(values); + } else if values[IS_DIV].is_one() { + generate_division(values); + } +} + +pub(crate) fn eval_arithmetic_unit>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) { + let local_values = &vars.local_values; + + // Check that the operation flag values are binary. + for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { + let val = local_values[col]; + yield_constr.constraint_wrapping(val * val - val); + } + + eval_addition(local_values, yield_constr); + eval_subtraction(local_values, yield_constr); + eval_multiplication(local_values, yield_constr); + eval_division(local_values, yield_constr); +} + +pub(crate) fn eval_arithmetic_unit_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let local_values = &vars.local_values; + + // Check that the operation flag values are binary. + for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { + let val = local_values[col]; + let constraint = builder.mul_sub_extension(val, val, val); + yield_constr.constraint_wrapping(builder, constraint); + } + + eval_addition_recursively(builder, local_values, yield_constr); + eval_subtraction_recursively(builder, local_values, yield_constr); + eval_multiplication_recursively(builder, local_values, yield_constr); + eval_division_recursively(builder, local_values, yield_constr); +} diff --git a/system_zero/src/arithmetic/multiplication.rs b/system_zero/src/arithmetic/multiplication.rs new file mode 100644 index 00000000..70c181d8 --- /dev/null +++ b/system_zero/src/arithmetic/multiplication.rs @@ -0,0 +1,31 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { + // TODO +} + +pub(crate) fn eval_multiplication>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_mul = local_values[IS_MUL]; + // TODO +} + +pub(crate) fn eval_multiplication_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_mul = local_values[IS_MUL]; + // TODO +} diff --git a/system_zero/src/arithmetic/subtraction.rs b/system_zero/src/arithmetic/subtraction.rs new file mode 100644 index 00000000..267bac72 --- /dev/null +++ b/system_zero/src/arithmetic/subtraction.rs @@ -0,0 +1,31 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_subtraction(values: &mut [F; NUM_COLUMNS]) { + // TODO +} + +pub(crate) fn eval_subtraction>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_sub = local_values[IS_SUB]; + // TODO +} + +pub(crate) fn eval_subtraction_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_sub = local_values[IS_SUB]; + // TODO +} diff --git a/system_zero/src/core_registers.rs b/system_zero/src/core_registers.rs new file mode 100644 index 00000000..c8c6533b --- /dev/null +++ b/system_zero/src/core_registers.rs @@ -0,0 +1,93 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::vars::StarkEvaluationTargets; +use starky::vars::StarkEvaluationVars; + +use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::core::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_first_row_core_registers(first_values: &mut [F; NUM_COLUMNS]) { + first_values[COL_CLOCK] = F::ZERO; + first_values[COL_RANGE_16] = F::ZERO; + first_values[COL_INSTRUCTION_PTR] = F::ZERO; + first_values[COL_FRAME_PTR] = F::ZERO; + first_values[COL_STACK_PTR] = F::ZERO; +} + +pub(crate) fn generate_next_row_core_registers( + local_values: &[F; NUM_COLUMNS], + next_values: &mut [F; NUM_COLUMNS], +) { + // We increment the clock by 1. + next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE; + + // We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in + // which case we repeat that value. + let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64(); + let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1); + next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16); + + // next_values[COL_INSTRUCTION_PTR] = todo!(); + + // next_values[COL_FRAME_PTR] = todo!(); + + // next_values[COL_STACK_PTR] = todo!(); +} + +#[inline] +pub(crate) fn eval_core_registers>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) { + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = next_clock - local_clock; + yield_constr.constraint_first_row(local_clock); + yield_constr.constraint(delta_clock - F::ONE); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. + let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = next_range_16 - local_range_16; + yield_constr.constraint_first_row(local_range_16); + yield_constr.constraint_last_row(local_range_16 - F::from_canonical_u64((1 << 16) - 1)); + yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); + + // TODO constraints for stack etc. +} + +pub(crate) fn eval_core_registers_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one_ext = builder.one_extension(); + let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1)); + let max_u16_ext = builder.convert_to_ext(max_u16); + + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = builder.sub_extension(next_clock, local_clock); + yield_constr.constraint_first_row(builder, local_clock); + let constraint = builder.sub_extension(delta_clock, one_ext); + yield_constr.constraint(builder, constraint); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. + let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = builder.sub_extension(next_range_16, local_range_16); + yield_constr.constraint_first_row(builder, local_range_16); + let constraint = builder.sub_extension(local_range_16, max_u16_ext); + yield_constr.constraint_last_row(builder, constraint); + let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); + yield_constr.constraint(builder, constraint); + + // TODO constraints for stack etc. +} diff --git a/system_zero/src/lib.rs b/system_zero/src/lib.rs new file mode 100644 index 00000000..1c097573 --- /dev/null +++ b/system_zero/src/lib.rs @@ -0,0 +1,11 @@ +// TODO: Remove these when crate is closer to being finished. +#![allow(dead_code)] +#![allow(unused_variables)] + +mod arithmetic; +mod core_registers; +mod memory; +mod permutation_unit; +mod public_input_layout; +mod registers; +pub mod system_zero; diff --git a/system_zero/src/memory.rs b/system_zero/src/memory.rs new file mode 100644 index 00000000..0cc42d30 --- /dev/null +++ b/system_zero/src/memory.rs @@ -0,0 +1,16 @@ +#[derive(Default)] +pub struct TransactionMemory { + pub calls: Vec, +} + +/// A virtual memory space specific to the current contract call. +pub struct ContractMemory { + pub code: MemorySegment, + pub main: MemorySegment, + pub calldata: MemorySegment, + pub returndata: MemorySegment, +} + +pub struct MemorySegment { + pub content: Vec, +} diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs new file mode 100644 index 00000000..366cff65 --- /dev/null +++ b/system_zero/src/permutation_unit.rs @@ -0,0 +1,328 @@ +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::hash::poseidon::{Poseidon, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::vars::StarkEvaluationTargets; +use starky::vars::StarkEvaluationVars; + +use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::permutation::*; +use crate::registers::NUM_COLUMNS; + +fn constant_layer( + mut state: [P; SPONGE_WIDTH], + round: usize, +) -> [P; SPONGE_WIDTH] +where + F: Poseidon, + FE: FieldExtension, + P: PackedField, +{ + // One day I might actually vectorize this, but today is not that day. + for i in 0..P::WIDTH { + let mut unpacked_state = [P::Scalar::default(); SPONGE_WIDTH]; + for j in 0..SPONGE_WIDTH { + unpacked_state[j] = state[j].as_slice()[i]; + } + F::constant_layer_field(&mut unpacked_state, round); + for j in 0..SPONGE_WIDTH { + state[j].as_slice_mut()[i] = unpacked_state[j]; + } + } + state +} + +fn mds_layer(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH] +where + F: Poseidon, + FE: FieldExtension, + P: PackedField, +{ + for i in 0..P::WIDTH { + let mut unpacked_state = [P::Scalar::default(); SPONGE_WIDTH]; + for j in 0..SPONGE_WIDTH { + unpacked_state[j] = state[j].as_slice()[i]; + } + unpacked_state = F::mds_layer_field(&unpacked_state); + for j in 0..SPONGE_WIDTH { + state[j].as_slice_mut()[i] = unpacked_state[j]; + } + } + state +} + +pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) { + // Load inputs. + let mut state = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i].cube(); + values[col_full_first_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = F::mds_layer(&state); + + for i in 0..SPONGE_WIDTH { + values[col_full_first_after_mds(r, i)] = state[i]; + } + } + + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0].cube(); + values[col_partial_mid_sbox(r)] = state0_cubed; + state[0] *= state0_cubed.square(); // Form state ** 7. + values[col_partial_after_sbox(r)] = state[0]; + + state = F::mds_layer(&state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i].cube(); + values[col_full_second_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = F::mds_layer(&state); + + for i in 0..SPONGE_WIDTH { + values[col_full_second_after_mds(r, i)] = state[i]; + } + } +} + +#[inline] +pub(crate) fn eval_permutation_unit( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) where + F: Poseidon, + FE: FieldExtension, + P: PackedField, +{ + let local_values = &vars.local_values; + + // Load inputs. + let mut state = [P::ZEROS; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = local_values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = mds_layer(state); + + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); + state[i] = local_values[col_full_first_after_mds(r, i)]; + } + } + + for r in 0..N_PARTIAL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0] * state[0].square(); + yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] *= state0_cubed.square(); // Form state ** 7. + yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = mds_layer(state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_second_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = mds_layer(state); + + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); + state[i] = local_values[col_full_second_after_mds(r, i)]; + } + } +} + +pub(crate) fn eval_permutation_unit_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let zero = builder.zero_extension(); + let local_values = &vars.local_values; + + // Load inputs. + let mut state = [zero; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = local_values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. + } + + state = F::mds_layer_recursive(builder, &state); + + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_first_after_mds(r, i)]; + } + } + + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = builder.cube_extension(state[0]); + let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. + let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = F::mds_layer_recursive(builder, &state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive( + builder, + &mut state, + HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r, + ); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. + } + + state = F::mds_layer_recursive(builder, &state); + + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_second_after_mds(r, i)]; + } + } +} + +#[cfg(test)] +mod tests { + use plonky2::field::field_types::Field; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::hash::poseidon::Poseidon; + use rand::SeedableRng; + use rand_chacha::ChaCha8Rng; + use starky::constraint_consumer::ConstraintConsumer; + use starky::vars::StarkEvaluationVars; + + use crate::permutation_unit::{eval_permutation_unit, generate_permutation_unit, SPONGE_WIDTH}; + use crate::public_input_layout::NUM_PUBLIC_INPUTS; + use crate::registers::permutation::{col_input, col_output}; + use crate::registers::NUM_COLUMNS; + + #[test] + fn generate_eval_consistency() { + const D: usize = 1; + type F = GoldilocksField; + + let mut values = [F::default(); NUM_COLUMNS]; + generate_permutation_unit(&mut values); + + let vars = StarkEvaluationVars { + local_values: &values, + next_values: &[F::default(); NUM_COLUMNS], + public_inputs: &[F::default(); NUM_PUBLIC_INPUTS], + }; + + let mut constrant_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + eval_permutation_unit(vars, &mut constrant_consumer); + for &acc in &constrant_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + + #[test] + fn poseidon_result() { + const D: usize = 1; + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let state = [F::default(); SPONGE_WIDTH].map(|_| F::rand_from_rng(&mut rng)); + + // Get true Poseidon hash + let target = GoldilocksField::poseidon(state); + + // Get result from `generate_permutation_unit` + // Initialize `values` with randomness to test that the code doesn't rely on zero-filling. + let mut values = [F::default(); NUM_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + for i in 0..SPONGE_WIDTH { + values[col_input(i)] = state[i]; + } + generate_permutation_unit(&mut values); + let mut result = [F::default(); SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + result[i] = values[col_output(i)]; + } + + assert_eq!(target, result); + } + + // TODO(JN): test degree + // TODO(JN): test `eval_permutation_unit_recursively` +} diff --git a/system_zero/src/public_input_layout.rs b/system_zero/src/public_input_layout.rs new file mode 100644 index 00000000..225b3814 --- /dev/null +++ b/system_zero/src/public_input_layout.rs @@ -0,0 +1,7 @@ +/// The previous state root, before these transactions were executed. +const PI_OLD_STATE_ROOT: usize = 0; + +/// The updated state root, after these transactions were executed. +const PI_NEW_STATE_ROOT: usize = PI_OLD_STATE_ROOT + 1; + +pub(crate) const NUM_PUBLIC_INPUTS: usize = PI_NEW_STATE_ROOT + 1; diff --git a/system_zero/src/registers/arithmetic.rs b/system_zero/src/registers/arithmetic.rs new file mode 100644 index 00000000..92c0d2c3 --- /dev/null +++ b/system_zero/src/registers/arithmetic.rs @@ -0,0 +1,37 @@ +//! Arithmetic unit. + +pub(crate) const IS_ADD: usize = super::START_ARITHMETIC; +pub(crate) const IS_SUB: usize = IS_ADD + 1; +pub(crate) const IS_MUL: usize = IS_SUB + 1; +pub(crate) const IS_DIV: usize = IS_MUL + 1; + +const START_SHARED_COLS: usize = IS_DIV + 1; + +/// Within the arithmetic unit, there are shared columns which can be used by any arithmetic +/// circuit, depending on which one is active this cycle. +// Can be increased as needed as other operations are implemented. +const NUM_SHARED_COLS: usize = 3; + +const fn shared_col(i: usize) -> usize { + debug_assert!(i < NUM_SHARED_COLS); + START_SHARED_COLS + i +} + +/// The first value to be added; treated as an unsigned u32. +pub(crate) const COL_ADD_INPUT_1: usize = shared_col(0); +/// The second value to be added; treated as an unsigned u32. +pub(crate) const COL_ADD_INPUT_2: usize = shared_col(1); +/// The third value to be added; treated as an unsigned u32. +pub(crate) const COL_ADD_INPUT_3: usize = shared_col(2); + +// Note: Addition outputs three 16-bit chunks, and since these values need to be range-checked +// anyway, we might as well use the range check unit's columns as our addition outputs. So the +// three proceeding columns are basically aliases, not columns owned by the arithmetic unit. +/// The first 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(0); +/// The second 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(1); +/// The third 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_ADD_OUTPUT_3: usize = super::range_check_16::col_rc_16_input(2); + +pub(super) const END: usize = super::START_ARITHMETIC + NUM_SHARED_COLS; diff --git a/system_zero/src/registers/boolean.rs b/system_zero/src/registers/boolean.rs new file mode 100644 index 00000000..c59af8d4 --- /dev/null +++ b/system_zero/src/registers/boolean.rs @@ -0,0 +1,10 @@ +//! Boolean unit. Contains columns whose values must be 0 or 1. + +const NUM_BITS: usize = 128; + +pub const fn col_bit(index: usize) -> usize { + debug_assert!(index < NUM_BITS); + super::START_BOOLEAN + index +} + +pub(super) const END: usize = super::START_BOOLEAN + NUM_BITS; diff --git a/system_zero/src/registers/core.rs b/system_zero/src/registers/core.rs new file mode 100644 index 00000000..3fafab55 --- /dev/null +++ b/system_zero/src/registers/core.rs @@ -0,0 +1,20 @@ +//! Core registers. + +/// A cycle counter. Starts at 0; increments by 1. +pub(crate) const COL_CLOCK: usize = super::START_CORE; + +/// A column which contains the values `[0, ... 2^16 - 1]`, potentially with duplicates. Used for +/// 16-bit range checks. +/// +/// For ease of verification, we enforce that it must begin with 0 and end with `2^16 - 1`, and each +/// delta must be either 0 or 1. +pub(crate) const COL_RANGE_16: usize = COL_CLOCK + 1; + +/// Pointer to the current instruction. +pub(crate) const COL_INSTRUCTION_PTR: usize = COL_RANGE_16 + 1; +/// Pointer to the base of the current call's stack frame. +pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1; +/// Pointer to the tip of the current call's stack frame. +pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; + +pub(super) const END: usize = COL_STACK_PTR + 1; diff --git a/system_zero/src/registers/logic.rs b/system_zero/src/registers/logic.rs new file mode 100644 index 00000000..07f3f0e0 --- /dev/null +++ b/system_zero/src/registers/logic.rs @@ -0,0 +1,3 @@ +//! Logic unit. + +pub(super) const END: usize = super::START_LOGIC; diff --git a/system_zero/src/registers/lookup.rs b/system_zero/src/registers/lookup.rs new file mode 100644 index 00000000..eb773acf --- /dev/null +++ b/system_zero/src/registers/lookup.rs @@ -0,0 +1,21 @@ +//! Lookup unit. +//! See https://zcash.github.io/halo2/design/proving-system/lookup.html + +const START_UNIT: usize = super::START_LOOKUP; + +const NUM_LOOKUPS: usize = + super::range_check_16::NUM_RANGE_CHECKS + super::range_check_degree::NUM_RANGE_CHECKS; + +/// This column contains a permutation of the input values. +const fn col_permuted_input(i: usize) -> usize { + debug_assert!(i < NUM_LOOKUPS); + START_UNIT + 2 * i +} + +/// This column contains a permutation of the table values. +const fn col_permuted_table(i: usize) -> usize { + debug_assert!(i < NUM_LOOKUPS); + START_UNIT + 2 * i + 1 +} + +pub(super) const END: usize = START_UNIT + NUM_LOOKUPS; diff --git a/system_zero/src/registers/memory.rs b/system_zero/src/registers/memory.rs new file mode 100644 index 00000000..1373d0d8 --- /dev/null +++ b/system_zero/src/registers/memory.rs @@ -0,0 +1,3 @@ +//! Memory unit. + +pub(super) const END: usize = super::START_MEMORY; diff --git a/system_zero/src/registers/mod.rs b/system_zero/src/registers/mod.rs new file mode 100644 index 00000000..134a28bf --- /dev/null +++ b/system_zero/src/registers/mod.rs @@ -0,0 +1,20 @@ +pub(crate) mod arithmetic; +pub(crate) mod boolean; +pub(crate) mod core; +pub(crate) mod logic; +pub(crate) mod lookup; +pub(crate) mod memory; +pub(crate) mod permutation; +pub(crate) mod range_check_16; +pub(crate) mod range_check_degree; + +const START_ARITHMETIC: usize = 0; +const START_BOOLEAN: usize = arithmetic::END; +const START_CORE: usize = boolean::END; +const START_LOGIC: usize = core::END; +const START_LOOKUP: usize = logic::END; +const START_MEMORY: usize = lookup::END; +const START_PERMUTATION: usize = memory::END; +const START_RANGE_CHECK_16: usize = permutation::END; +const START_RANGE_CHECK_DEGREE: usize = range_check_16::END; +pub(crate) const NUM_COLUMNS: usize = range_check_degree::END; diff --git a/system_zero/src/registers/permutation.rs b/system_zero/src/registers/permutation.rs new file mode 100644 index 00000000..cde76af2 --- /dev/null +++ b/system_zero/src/registers/permutation.rs @@ -0,0 +1,57 @@ +//! Permutation unit. + +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::hash::poseidon; + +const START_FULL_FIRST: usize = super::START_PERMUTATION + SPONGE_WIDTH; + +pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i +} + +pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i +} + +const START_PARTIAL: usize = + col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; + +pub const fn col_partial_mid_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round +} + +pub const fn col_partial_after_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round + 1 +} + +const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; + +pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i +} + +pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i +} + +pub const fn col_input(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + super::START_PERMUTATION + i +} + +pub const fn col_output(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) +} + +pub(super) const END: usize = col_output(SPONGE_WIDTH - 1) + 1; diff --git a/system_zero/src/registers/range_check_16.rs b/system_zero/src/registers/range_check_16.rs new file mode 100644 index 00000000..c44db494 --- /dev/null +++ b/system_zero/src/registers/range_check_16.rs @@ -0,0 +1,11 @@ +//! Range check unit which checks that values are in `[0, 2^16)`. + +pub(super) const NUM_RANGE_CHECKS: usize = 5; + +/// The input of the `i`th range check, i.e. the value being range checked. +pub(crate) const fn col_rc_16_input(i: usize) -> usize { + debug_assert!(i < NUM_RANGE_CHECKS); + super::START_RANGE_CHECK_16 + i +} + +pub(super) const END: usize = super::START_RANGE_CHECK_16 + NUM_RANGE_CHECKS; diff --git a/system_zero/src/registers/range_check_degree.rs b/system_zero/src/registers/range_check_degree.rs new file mode 100644 index 00000000..6d61e6e2 --- /dev/null +++ b/system_zero/src/registers/range_check_degree.rs @@ -0,0 +1,11 @@ +//! Range check unit which checks that values are in `[0, degree)`. + +pub(super) const NUM_RANGE_CHECKS: usize = 5; + +/// The input of the `i`th range check, i.e. the value being range checked. +pub(crate) const fn col_rc_degree_input(i: usize) -> usize { + debug_assert!(i < NUM_RANGE_CHECKS); + super::START_RANGE_CHECK_DEGREE + i +} + +pub(super) const END: usize = super::START_RANGE_CHECK_DEGREE + NUM_RANGE_CHECKS; diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs new file mode 100644 index 00000000..2eeb4697 --- /dev/null +++ b/system_zero/src/system_zero.rs @@ -0,0 +1,152 @@ +use std::marker::PhantomData; + +use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::stark::Stark; +use starky::vars::StarkEvaluationTargets; +use starky::vars::StarkEvaluationVars; + +use crate::arithmetic::{ + eval_arithmetic_unit, eval_arithmetic_unit_recursively, generate_arithmetic_unit, +}; +use crate::core_registers::{ + eval_core_registers, eval_core_registers_recursively, generate_first_row_core_registers, + generate_next_row_core_registers, +}; +use crate::memory::TransactionMemory; +use crate::permutation_unit::{ + eval_permutation_unit, eval_permutation_unit_recursively, generate_permutation_unit, +}; +use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::NUM_COLUMNS; + +/// We require at least 2^16 rows as it helps support efficient 16-bit range checks. +const MIN_TRACE_ROWS: usize = 1 << 16; + +#[derive(Copy, Clone)] +pub struct SystemZero, const D: usize> { + _phantom: PhantomData, +} + +impl, const D: usize> SystemZero { + fn generate_trace(&self) -> Vec<[F; NUM_COLUMNS]> { + let memory = TransactionMemory::default(); + + let mut row = [F::ZERO; NUM_COLUMNS]; + generate_first_row_core_registers(&mut row); + generate_arithmetic_unit(&mut row); + generate_permutation_unit(&mut row); + + let mut trace = Vec::with_capacity(MIN_TRACE_ROWS); + + loop { + let mut next_row = [F::ZERO; NUM_COLUMNS]; + generate_next_row_core_registers(&row, &mut next_row); + generate_arithmetic_unit(&mut next_row); + generate_permutation_unit(&mut next_row); + + trace.push(row); + row = next_row; + + // TODO: Replace with proper termination condition. + if trace.len() == (1 << 16) - 1 { + break; + } + } + + trace.push(row); + trace + } +} + +impl, const D: usize> Default for SystemZero { + fn default() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl, const D: usize> Stark for SystemZero { + const COLUMNS: usize = NUM_COLUMNS; + const PUBLIC_INPUTS: usize = NUM_PUBLIC_INPUTS; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField, + { + eval_core_registers(vars, yield_constr); + eval_arithmetic_unit(vars, yield_constr); + eval_permutation_unit::(vars, yield_constr); + // TODO: Other units + } + + fn eval_ext_recursively( + &self, + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + eval_core_registers_recursively(builder, vars, yield_constr); + eval_arithmetic_unit_recursively(builder, vars, yield_constr); + eval_permutation_unit_recursively(builder, vars, yield_constr); + // TODO: Other units + } + + fn constraint_degree(&self) -> usize { + 3 + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use log::Level; + use plonky2::field::field_types::Field; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::plonk::config::PoseidonGoldilocksConfig; + use plonky2::util::timing::TimingTree; + use starky::config::StarkConfig; + use starky::prover::prove; + use starky::stark::Stark; + use starky::stark_testing::test_stark_low_degree; + use starky::verifier::verify; + + use crate::system_zero::SystemZero; + + #[test] + #[ignore] // A bit slow. + fn run() -> Result<()> { + type F = GoldilocksField; + type C = PoseidonGoldilocksConfig; + const D: usize = 2; + + type S = SystemZero; + let system = S::default(); + let public_inputs = [F::ZERO; S::PUBLIC_INPUTS]; + let config = StarkConfig::standard_fast_config(); + let mut timing = TimingTree::new("prove", Level::Debug); + let trace = system.generate_trace(); + let proof = prove::(system, &config, trace, public_inputs, &mut timing)?; + + verify(system, proof, &config) + } + + #[test] + fn degree() -> Result<()> { + type F = GoldilocksField; + type C = PoseidonGoldilocksConfig; + const D: usize = 2; + + type S = SystemZero; + let system = S::default(); + test_stark_low_degree(system) + } +} diff --git a/util/Cargo.toml b/util/Cargo.toml index 4d6e735c..a1ab402a 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -1,5 +1,6 @@ [package] name = "plonky2_util" +description = "Utilities used by Plonky2" version = "0.1.0" edition = "2021" diff --git a/util/src/lib.rs b/util/src/lib.rs index 8cc60a27..61677ff0 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -7,6 +7,12 @@ use std::arch::asm; use std::hint::unreachable_unchecked; +use std::mem::size_of; +use std::ptr::{swap, swap_nonoverlapping}; + +mod transpose_util; + +use crate::transpose_util::transpose_in_place_square; pub fn bits_u64(n: u64) -> usize { (64 - n.leading_zeros()) as usize @@ -17,14 +23,19 @@ pub const fn ceil_div_usize(a: usize, b: usize) -> usize { } /// Computes `ceil(log_2(n))`. +#[must_use] pub fn log2_ceil(n: usize) -> usize { - n.next_power_of_two().trailing_zeros() as usize + (usize::BITS - n.saturating_sub(1).leading_zeros()) as usize } /// Computes `log_2(n)`, panicking if `n` is not a power of two. pub fn log2_strict(n: usize) -> usize { - assert!(n.is_power_of_two(), "Not a power of two: {}", n); - log2_ceil(n) + let res = n.trailing_zeros(); + assert!(n.wrapping_shr(res) == 1, "Not a power of two: {}", n); + // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with + // `1 << res` and vice versa. + assume(n == 1 << res); + res as usize } /// Permutes `arr` such that each index is mapped to its reverse in binary. @@ -78,57 +89,129 @@ fn reverse_index_bits_large(arr: &[T], n_power: usize) -> Vec { result } -pub fn reverse_index_bits_in_place(arr: &mut Vec) { - let n = arr.len(); - let n_power = log2_strict(n); - - if n_power <= 6 { - reverse_index_bits_in_place_small(arr, n_power); +/// Bit-reverse the order of elements in `arr`. +/// SAFETY: ensure that `arr.len() == 1 << lb_n`. +#[cfg(not(target_arch = "aarch64"))] +unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lb_n: usize) { + if lb_n <= 6 { + // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses. + let dst_shr_amt = 6 - lb_n; + for src in 0..arr.len() { + let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt; + if src < dst { + swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst)); + } + } } else { - reverse_index_bits_in_place_large(arr, n_power); + // LLVM does not know that it does not need to reverse src at each iteration (which is + // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high + // bits of dst are dependent only on the low bits of src. + let dst_lo_shr_amt = 64 - (lb_n - 6); + let dst_hi_shl_amt = lb_n - 6; + for src_chunk in 0..(arr.len() >> 6) { + let src_hi = src_chunk << 6; + let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt; + for src_lo in 0..(1 << 6) { + let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt; + let src = src_hi + src_lo; + let dst = dst_hi + dst_lo; + if src < dst { + swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst)); + } + } + } } } -/* Both functions below are semantically equivalent to: - for src in 0..n { - let dst = reverse_bits(src, n_power); - if src < dst { - arr.swap(src, dst); - } - } - where reverse_bits(src, n_power) computes the n_power-bit reverse. -*/ - -fn reverse_index_bits_in_place_small(arr: &mut Vec, n_power: usize) { - let n = arr.len(); - // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses. - let dst_shr_amt = 6 - n_power; - for src in 0..n { - let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt; +/// Bit-reverse the order of elements in `arr`. +/// SAFETY: ensure that `arr.len() == 1 << lb_n`. +#[cfg(target_arch = "aarch64")] +unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lb_n: usize) { + // Aarch64 can reverse bits in one instruction, so the trivial version works best. + for src in 0..arr.len() { + // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so + // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the + // correct result. + let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32); if src < dst { - arr.swap(src, dst); + swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst)); } } } -fn reverse_index_bits_in_place_large(arr: &mut Vec, n_power: usize) { - let n = arr.len(); - // LLVM does not know that it does not need to reverse src at each iteration (which is expensive - // on x86). We take advantage of the fact that the low bits of dst change rarely and the high - // bits of dst are dependent only on the low bits of src. - let dst_lo_shr_amt = 64 - (n_power - 6); - let dst_hi_shl_amt = n_power - 6; - for src_chunk in 0..(n >> 6) { - let src_hi = src_chunk << 6; - let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt; - for src_lo in 0..(1 << 6) { - let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt; +/// Split `arr` chunks and bit-reverse the order of the chunks. There are `1 << lb_num_chunks` +/// chunks, each of length `1 << lb_chunk_size`. +/// SAFETY: ensure that `arr.len() == 1 << lb_num_chunks + lb_chunk_size`. +unsafe fn reverse_index_bits_in_place_chunks( + arr: &mut [T], + lb_num_chunks: usize, + lb_chunk_size: usize, +) { + for i in 0..1usize << lb_num_chunks { + // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`. + let j = i + .reverse_bits() + .wrapping_shr(usize::BITS - lb_num_chunks as u32); + if i < j { + swap_nonoverlapping( + arr.get_unchecked_mut(i << lb_chunk_size), + arr.get_unchecked_mut(j << lb_chunk_size), + 1 << lb_chunk_size, + ); + } + } +} - let src = src_hi + src_lo; - let dst = dst_hi + dst_lo; - if src < dst { - arr.swap(src, dst); +// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE. +const BIG_T_SIZE: usize = 1 << 14; +const SMALL_ARR_SIZE: usize = 1 << 16; +pub fn reverse_index_bits_in_place(arr: &mut [T]) { + let n = arr.len(); + let lb_n = log2_strict(n); + // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if + // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the + // array. + if size_of::() << lb_n <= SMALL_ARR_SIZE || size_of::() >= BIG_T_SIZE { + unsafe { + reverse_index_bits_in_place_small(arr, lb_n); + } + } else { + debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`. + + // Algorithm: + // + // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is + // even, i.e., `n` is a square number.) To perform bit-order reversal we: + // 1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is + // basically a series of large `memcpy`s.) + // 2. Transpose the matrix. + // 3. Bit-reverse the order of the rows. + // This is equivalent to, for every index `0 <= i < n`: + // 1. bit-reversing `i[lb_n / 2..lb_n]`, + // 2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`, + // 3. bit-reversing `i[lb_n / 2..lb_n]`. + // + // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires + // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n`, of the + // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we + // perform _two_ transposes. We treat `arr` as two matrices, one where the middle bit of the + // index is `0` and another, where the middle bit is `1`; we transpose each individually. + + let lb_num_chunks = lb_n >> 1; + let lb_chunk_size = lb_n - lb_num_chunks; + unsafe { + reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size); + transpose_in_place_square(arr, lb_chunk_size, lb_num_chunks, 0); + if lb_num_chunks != lb_chunk_size { + // `arr` cannot be interpreted as a square matrix. We instead interpret it as a + // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order. + // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit + // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance + // arr by `1 << lb_num_chunks` effectively, adding that to every index. + let arr_with_offset = &mut arr[1 << lb_num_chunks..]; + transpose_in_place_square(arr_with_offset, lb_chunk_size, lb_num_chunks, 0); } + reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size); } } } @@ -171,3 +254,59 @@ pub fn branch_hint() { asm!("", options(nomem, nostack, preserves_flags)); } } + +#[cfg(test)] +mod tests { + use crate::{log2_ceil, log2_strict}; + + #[test] + fn test_log2_strict() { + assert_eq!(log2_strict(1), 0); + assert_eq!(log2_strict(2), 1); + assert_eq!(log2_strict(1 << 18), 18); + assert_eq!(log2_strict(1 << 31), 31); + assert_eq!( + log2_strict(1 << (usize::BITS - 1)), + usize::BITS as usize - 1 + ); + } + + #[test] + #[should_panic] + fn test_log2_strict_zero() { + log2_strict(0); + } + + #[test] + #[should_panic] + fn test_log2_strict_nonpower_2() { + log2_strict(0x78c341c65ae6d262); + } + + #[test] + #[should_panic] + fn test_log2_strict_usize_max() { + log2_strict(usize::MAX); + } + + #[test] + fn test_log2_ceil() { + // Powers of 2 + assert_eq!(log2_ceil(0), 0); + assert_eq!(log2_ceil(1), 0); + assert_eq!(log2_ceil(2), 1); + assert_eq!(log2_ceil(1 << 18), 18); + assert_eq!(log2_ceil(1 << 31), 31); + assert_eq!(log2_ceil(1 << (usize::BITS - 1)), usize::BITS as usize - 1); + + // Nonpowers; want to round up + assert_eq!(log2_ceil(3), 2); + assert_eq!(log2_ceil(0x14fe901b), 29); + assert_eq!( + log2_ceil((1 << (usize::BITS - 1)) + 1), + usize::BITS as usize + ); + assert_eq!(log2_ceil(usize::MAX - 1), usize::BITS as usize); + assert_eq!(log2_ceil(usize::MAX), usize::BITS as usize); + } +} diff --git a/util/src/transpose_util.rs b/util/src/transpose_util.rs new file mode 100644 index 00000000..1c8280a8 --- /dev/null +++ b/util/src/transpose_util.rs @@ -0,0 +1,112 @@ +use std::ptr::swap; + +const LB_BLOCK_SIZE: usize = 3; + +/// Transpose square matrix in-place +/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies +/// `M[i, j] == arr[(i + x << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The transposition +/// swaps `M[i, j]` and `M[j, i]`. +/// +/// SAFETY: +/// Make sure that `(i + x << lb_stride) + j + x` is a valid index in `arr` for all +/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap. +unsafe fn transpose_in_place_square_small( + arr: &mut [T], + lb_stride: usize, + lb_size: usize, + x: usize, +) { + for i in x..x + (1 << lb_size) { + for j in x..i { + swap( + arr.get_unchecked_mut(i + (j << lb_stride)), + arr.get_unchecked_mut((i << lb_stride) + j), + ); + } + } +} + +/// Transpose square matrices and swap +/// The matrices are of of size `1 << lb_size` by `1 << lb_size`. They occupy +/// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]` +/// for `0 <= i, j < 1 << lb_size. The transposition swaps `M0[i, j]` and `M1[j, i]`. +/// +/// SAFETY: +/// Make sure that `(i + x << lb_stride) + j + y` and `i + x + (j + y << lb_stride)` are valid +/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to +/// prevent overlap. +unsafe fn transpose_swap_square_small( + arr: &mut [T], + lb_stride: usize, + lb_size: usize, + x: usize, + y: usize, +) { + for i in x..x + (1 << lb_size) { + for j in y..y + (1 << lb_size) { + swap( + arr.get_unchecked_mut(i + (j << lb_stride)), + arr.get_unchecked_mut((i << lb_stride) + j), + ); + } + } +} + +/// Transpose square matrices and swap +/// The matrices are of of size `1 << lb_size` by `1 << lb_size`. They occupy +/// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]` +/// for `0 <= i, j < 1 << lb_size. The transposition swaps `M0[i, j]` and `M1[j, i]`. +/// +/// SAFETY: +/// Make sure that `(i + x << lb_stride) + j + y` and `i + x + (j + y << lb_stride)` are valid +/// indices in `arr` for all `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to +/// prevent overlap. +unsafe fn transpose_swap_square( + arr: &mut [T], + lb_stride: usize, + lb_size: usize, + x: usize, + y: usize, +) { + if lb_size <= LB_BLOCK_SIZE { + transpose_swap_square_small(arr, lb_stride, lb_size, x, y); + } else { + let lb_block_size = lb_size - 1; + let block_size = 1 << lb_block_size; + transpose_swap_square(arr, lb_stride, lb_block_size, x, y); + transpose_swap_square(arr, lb_stride, lb_block_size, x + block_size, y); + transpose_swap_square(arr, lb_stride, lb_block_size, x, y + block_size); + transpose_swap_square( + arr, + lb_stride, + lb_block_size, + x + block_size, + y + block_size, + ); + } +} + +/// Transpose square matrix in-place +/// The matrix is of size `1 << lb_size` by `1 << lb_size`. It occupies +/// `M[i, j] == arr[(i + x << lb_stride) + j + x]` for `0 <= i, j < 1 << lb_size`. The transposition +/// swaps `M[i, j]` and `M[j, i]`. +/// +/// SAFETY: +/// Make sure that `(i + x << lb_stride) + j + x` is a valid index in `arr` for all +/// `0 <= i, j < 1 << lb_size`. Ensure also that `lb_size <= lb_stride` to prevent overlap. +pub(crate) unsafe fn transpose_in_place_square( + arr: &mut [T], + lb_stride: usize, + lb_size: usize, + x: usize, +) { + if lb_size <= LB_BLOCK_SIZE { + transpose_in_place_square_small(arr, lb_stride, lb_size, x); + } else { + let lb_block_size = lb_size - 1; + let block_size = 1 << lb_block_size; + transpose_in_place_square(arr, lb_stride, lb_block_size, x); + transpose_swap_square(arr, lb_stride, lb_block_size, x, x + block_size); + transpose_in_place_square(arr, lb_stride, lb_block_size, x + block_size); + } +} diff --git a/waksman/src/sorting.rs b/waksman/src/sorting.rs index 4270ebc7..775df9c5 100644 --- a/waksman/src/sorting.rs +++ b/waksman/src/sorting.rs @@ -183,7 +183,7 @@ impl, const D: usize> SimpleGenerator #[cfg(test)] mod tests { use anyhow::Result; - use plonky2::field::field_types::{Field, PrimeField}; + use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};