diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index 1f1db14b..38231892 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -17,12 +17,24 @@ jobs: uses: actions/checkout@v2 - name: Install nightly toolchain + id: rustc-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: nightly override: true + - name: rust-cache + uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: rustc-test-${{ steps.rustc-toolchain.outputs.rustc_hash }}-cargo-${{ hashFiles('**/Cargo.toml') }} + - name: Run cargo test uses: actions-rs/cargo@v1 with: @@ -30,7 +42,7 @@ jobs: args: --all env: RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -Cprefer-dynamic=y - CARGO_INCREMENTAL: 0 + CARGO_INCREMENTAL: 1 lints: name: Formatting and Clippy @@ -41,6 +53,7 @@ jobs: uses: actions/checkout@v2 - name: Install nightly toolchain + id: rustc-toolchain uses: actions-rs/toolchain@v1 with: profile: minimal @@ -48,15 +61,30 @@ jobs: override: true components: rustfmt, clippy + - name: rust-cache + uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: rustc-lints-${{ steps.rustc-toolchain.outputs.rustc_hash }}-cargo-${{ hashFiles('**/Cargo.toml') }} + - name: Run cargo fmt uses: actions-rs/cargo@v1 with: command: fmt args: --all -- --check + env: + CARGO_INCREMENTAL: 1 - name: Run cargo clippy uses: actions-rs/cargo@v1 with: command: clippy args: --all-features --all-targets -- -D warnings -A incomplete-features + env: + CARGO_INCREMENTAL: 1 diff --git a/ecdsa/src/curve/curve_msm.rs b/ecdsa/src/curve/curve_msm.rs index 6d07c097..f681deb2 100644 --- a/ecdsa/src/curve/curve_msm.rs +++ b/ecdsa/src/curve/curve_msm.rs @@ -207,7 +207,7 @@ mod tests { 0b00001111111111111111111111111111, 0b11111111111111111111111111111111, ]; - let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); + let x = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&x_canonical)); assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); assert_eq!( to_digits::(&x, 17), @@ -240,13 +240,13 @@ mod tests { let generator_2 = generator_1 + generator_1; let generator_3 = generator_1 + generator_2; - let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + let scalar_1 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ 11111111, 22222222, 33333333, 44444444, ])); - let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + let scalar_2 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ 22222222, 22222222, 33333333, 44444444, ])); - let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + let scalar_3 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ 33333333, 22222222, 33333333, 44444444, ])); diff --git a/ecdsa/src/curve/curve_types.rs b/ecdsa/src/curve/curve_types.rs index bbf66d65..96821672 100644 --- a/ecdsa/src/curve/curve_types.rs +++ b/ecdsa/src/curve/curve_types.rs @@ -277,9 +277,9 @@ impl Neg for ProjectivePoint { } pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { - C::ScalarField::from_biguint(x.to_canonical_biguint()) + C::ScalarField::from_noncanonical_biguint(x.to_canonical_biguint()) } pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { - C::BaseField::from_biguint(x.to_canonical_biguint()) + C::BaseField::from_noncanonical_biguint(x.to_canonical_biguint()) } diff --git a/ecdsa/src/curve/glv.rs b/ecdsa/src/curve/glv.rs index 05ecea44..c58032ec 100644 --- a/ecdsa/src/curve/glv.rs +++ b/ecdsa/src/curve/glv.rs @@ -45,14 +45,14 @@ pub fn decompose_secp256k1_scalar( ) .round() .to_integer(); - let c1 = Secp256K1Scalar::from_biguint(c1_biguint); + let c1 = Secp256K1Scalar::from_noncanonical_biguint(c1_biguint); let c2_biguint = Ratio::new( MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), p.clone(), ) .round() .to_integer(); - let c2 = Secp256K1Scalar::from_biguint(c2_biguint); + let c2 = Secp256K1Scalar::from_noncanonical_biguint(c2_biguint); let k1_raw = k - c1 * A1 - c2 * A2; let k2_raw = c1 * MINUS_B1 - c2 * B2; @@ -61,13 +61,13 @@ pub fn decompose_secp256k1_scalar( let two = BigUint::from_slice(&[2]); let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); let k1 = if k1_neg { - Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_canonical_biguint()) + Secp256K1Scalar::from_noncanonical_biguint(p.clone() - k1_raw.to_canonical_biguint()) } else { k1_raw }; let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; let k2 = if k2_neg { - Secp256K1Scalar::from_biguint(p - k2_raw.to_canonical_biguint()) + Secp256K1Scalar::from_noncanonical_biguint(p - k2_raw.to_canonical_biguint()) } else { k2_raw }; diff --git a/ecdsa/src/curve/secp256k1.rs b/ecdsa/src/curve/secp256k1.rs index e46fbb3d..8f7bccf3 100644 --- a/ecdsa/src/curve/secp256k1.rs +++ b/ecdsa/src/curve/secp256k1.rs @@ -71,7 +71,7 @@ mod tests { #[test] fn test_g1_multiplication() { - let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + let lhs = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, ])); assert_eq!( diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs index 1dbe4657..faae365c 100644 --- a/ecdsa/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -7,10 +7,10 @@ use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartitionWitness, Witness}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::extension::Extendable; -use plonky2_field::types::PrimeField; +use plonky2_field::types::{PrimeField, PrimeField64}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; -use plonky2_u32::witness::{generated_values_set_u32_target, witness_set_u32_target}; +use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32}; #[derive(Clone, Debug)] pub struct BigUintTarget { @@ -270,41 +270,44 @@ impl, const D: usize> CircuitBuilderBiguint } } -pub fn witness_get_biguint_target, F: PrimeField>( - witness: &W, - bt: BigUintTarget, -) -> BigUint { - bt.limbs - .into_iter() - .rev() - .fold(BigUint::zero(), |acc, limb| { - (acc << 32) + witness.get_target(limb.0).to_canonical_biguint() - }) +pub trait WitnessBigUint: Witness { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint; + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); } -pub fn witness_set_biguint_target, F: PrimeField>( - witness: &mut W, - target: &BigUintTarget, - value: &BigUint, -) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - witness_set_u32_target(witness, target.limbs[i], limbs[i]); +impl, F: PrimeField64> WitnessBigUint for T { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + target + .limbs + .into_iter() + .rev() + .fold(BigUint::zero(), |acc, limb| { + (acc << 32) + self.get_target(limb.0).to_canonical_biguint() + }) + } + + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.limbs[i], limbs[i]); + } } } -pub fn buffer_set_biguint_target( - buffer: &mut GeneratedValues, - target: &BigUintTarget, - value: &BigUint, -) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - generated_values_set_u32_target(buffer, target.get_limb(i), limbs[i]); +pub trait GeneratedValuesBigUint { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); +} + +impl GeneratedValuesBigUint for GeneratedValues { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } } } @@ -330,12 +333,12 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness_get_biguint_target(witness, self.a.clone()); - let b = witness_get_biguint_target(witness, self.b.clone()); + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); let (div, rem) = a.div_rem(&b); - buffer_set_biguint_target(out_buffer, &self.div, &div); - buffer_set_biguint_target(out_buffer, &self.rem, &rem); + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.rem, &rem); } } @@ -350,7 +353,7 @@ mod tests { }; use rand::Rng; - use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint}; + use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint}; #[test] fn test_biguint_add() -> Result<()> { @@ -373,9 +376,9 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - witness_set_biguint_target(&mut pw, &x, &x_value); - witness_set_biguint_target(&mut pw, &y, &y_value); - witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); @@ -433,9 +436,9 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - witness_set_biguint_target(&mut pw, &x, &x_value); - witness_set_biguint_target(&mut pw, &y, &y_value); - witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/ecdsa/src/gadgets/curve_fixed_base.rs b/ecdsa/src/gadgets/curve_fixed_base.rs index d99d5760..0fd8e841 100644 --- a/ecdsa/src/gadgets/curve_fixed_base.rs +++ b/ecdsa/src/gadgets/curve_fixed_base.rs @@ -30,7 +30,7 @@ pub fn fixed_base_curve_mul_circuit, cons let limbs = builder.split_nonnative_to_4_bit_limbs(scalar); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); @@ -76,7 +76,7 @@ mod tests { use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::biguint::witness_set_biguint_target; + use crate::gadgets::biguint::WitnessBigUint; use crate::gadgets::curve::CircuitBuilderCurve; use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; use crate::gadgets::nonnative::CircuitBuilderNonNative; @@ -101,7 +101,7 @@ mod tests { builder.curve_assert_valid(&res_expected); let n_target = builder.add_virtual_nonnative_target::(); - witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint()); + pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); builder.curve_assert_valid(&res_target); diff --git a/ecdsa/src/gadgets/curve_msm.rs b/ecdsa/src/gadgets/curve_msm.rs index 1265d399..e059638c 100644 --- a/ecdsa/src/gadgets/curve_msm.rs +++ b/ecdsa/src/gadgets/curve_msm.rs @@ -29,7 +29,7 @@ pub fn curve_msm_circuit, const D: usize> let num_limbs = limbs_n.len(); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); diff --git a/ecdsa/src/gadgets/curve_windowed_mul.rs b/ecdsa/src/gadgets/curve_windowed_mul.rs index d9dcc734..bc4e1caf 100644 --- a/ecdsa/src/gadgets/curve_windowed_mul.rs +++ b/ecdsa/src/gadgets/curve_windowed_mul.rs @@ -131,7 +131,7 @@ impl, const D: usize> CircuitBuilderWindowedMul, ) -> AffinePointTarget { let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; diff --git a/ecdsa/src/gadgets/ecdsa.rs b/ecdsa/src/gadgets/ecdsa.rs index b287ff05..3ed6342d 100644 --- a/ecdsa/src/gadgets/ecdsa.rs +++ b/ecdsa/src/gadgets/ecdsa.rs @@ -13,10 +13,10 @@ use crate::gadgets::glv::CircuitBuilderGlv; use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; #[derive(Clone, Debug)] -pub struct ECDSASecretKeyTarget(NonNativeTarget); +pub struct ECDSASecretKeyTarget(pub NonNativeTarget); #[derive(Clone, Debug)] -pub struct ECDSAPublicKeyTarget(AffinePointTarget); +pub struct ECDSAPublicKeyTarget(pub AffinePointTarget); #[derive(Clone, Debug)] pub struct ECDSASignatureTarget { diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs index 746d661f..4302023e 100644 --- a/ecdsa/src/gadgets/glv.rs +++ b/ecdsa/src/gadgets/glv.rs @@ -12,7 +12,7 @@ use plonky2_field::types::{Field, PrimeField}; use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; use crate::curve::secp256k1::Secp256K1; -use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target}; +use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint}; use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; use crate::gadgets::curve_msm::curve_msm_circuit; use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; @@ -116,15 +116,14 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let k = Secp256K1Scalar::from_biguint(witness_get_biguint_target( - witness, - self.k.value.clone(), - )); + let k = Secp256K1Scalar::from_noncanonical_biguint( + witness.get_biguint_target(self.k.value.clone()), + ); let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint()); - buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint()); out_buffer.set_bool_target(self.k1_neg, k1_neg); out_buffer.set_bool_target(self.k2_neg, k2_neg); } diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs index 3c2e2ed6..c6ff4753 100644 --- a/ecdsa/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -10,11 +10,11 @@ use plonky2_field::types::PrimeField; use plonky2_field::{extension::Extendable, types::Field}; use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; use plonky2_u32::gadgets::range_check::range_check_u32_circuit; -use plonky2_u32::witness::generated_values_set_u32_target; +use plonky2_u32::witness::GeneratedValuesU32; use plonky2_util::ceil_div_usize; use crate::gadgets::biguint::{ - buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint, + BigUintTarget, CircuitBuilderBiguint, GeneratedValuesBigUint, WitnessBigUint, }; #[derive(Clone, Debug)] @@ -467,8 +467,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); - let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); let sum_biguint = a_biguint + b_biguint; @@ -479,7 +479,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (false, sum_biguint) }; - buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -508,7 +508,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat .summands .iter() .map(|summand| { - FF::from_biguint(witness_get_biguint_target(witness, summand.value.clone())) + FF::from_noncanonical_biguint(witness.get_biguint_target(summand.value.clone())) }) .collect(); let summand_biguints: Vec<_> = summands @@ -524,8 +524,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); let overflow = overflow_biguint.to_u64_digits()[0] as u32; - buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); - generated_values_set_u32_target(out_buffer, self.overflow, overflow); + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); } } @@ -553,8 +553,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); - let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -565,7 +565,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (modulus + a_biguint - b_biguint, true) }; - buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint); + out_buffer.set_biguint_target(&self.diff.value, &diff_biguint); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -594,8 +594,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); - let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -604,8 +604,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); - buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced); - buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint); + out_buffer.set_biguint_target(&self.prod.value, &prod_reduced); + out_buffer.set_biguint_target(&self.overflow, &overflow_biguint); } } @@ -625,7 +625,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = FF::from_biguint(witness_get_biguint_target(witness, self.x.value.clone())); + let x = FF::from_noncanonical_biguint(witness.get_biguint_target(self.x.value.clone())); let inv = x.inverse(); let x_biguint = x.to_canonical_biguint(); @@ -634,8 +634,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (div, _rem) = prod.div_rem(&modulus); - buffer_set_biguint_target(out_buffer, &self.div, &div); - buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint); + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.inv, &inv_biguint); } } diff --git a/evm/Cargo.toml b/evm/Cargo.toml index e844da3a..9d1bfa02 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -5,8 +5,10 @@ version = "0.1.0" edition = "2021" [dependencies] -plonky2 = { path = "../plonky2" } +plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } +eth-trie-utils = { git = "https://github.com/mir-protocol/eth-trie-utils.git", rev = "3ca443fd18e3f6d209dd96cbad851e05ae058b34" } +maybe_rayon = { path = "../maybe_rayon" } anyhow = "1.0.40" env_logger = "0.9.0" ethereum-types = "0.13.1" @@ -17,12 +19,12 @@ log = "0.4.14" once_cell = "1.13.0" pest = "2.1.3" pest_derive = "2.1.0" -maybe_rayon = { path = "../maybe_rayon" } rand = "0.8.5" rand_chacha = "0.3.1" rlp = "0.5.1" -keccak-rust = { git = "https://github.com/npwardberkeley/keccak-rust" } +serde = { version = "1.0.144", features = ["derive"] } keccak-hash = "0.9.0" +tiny-keccak = "2.0.2" [dev-dependencies] criterion = "0.3.5" @@ -31,7 +33,7 @@ hex = "0.4.3" [features] default = ["parallel"] asmtools = ["hex"] -parallel = ["maybe_rayon/parallel"] +parallel = ["plonky2/parallel", "maybe_rayon/parallel"] [[bin]] name = "assemble" diff --git a/evm/spec/.gitignore b/evm/spec/.gitignore new file mode 100644 index 00000000..ba6d4007 --- /dev/null +++ b/evm/spec/.gitignore @@ -0,0 +1,7 @@ +## Files generated by pdflatex, bibtex, etc. +*.aux +*.log +*.out +*.toc +*.bbl +*.blg diff --git a/evm/spec/Makefile b/evm/spec/Makefile new file mode 100644 index 00000000..97954528 --- /dev/null +++ b/evm/spec/Makefile @@ -0,0 +1,20 @@ +DOCNAME=zkevm + +all: pdf + +.PHONY: clean + +quick: + pdflatex $(DOCNAME).tex + +pdf: + pdflatex $(DOCNAME).tex + bibtex $(DOCNAME).aux + pdflatex $(DOCNAME).tex + pdflatex $(DOCNAME).tex + +view: pdf + open $(DOCNAME).pdf + +clean: + rm -f *.blg *.bbl *.aux *.log diff --git a/evm/spec/bibliography.bib b/evm/spec/bibliography.bib new file mode 100644 index 00000000..41fa56b8 --- /dev/null +++ b/evm/spec/bibliography.bib @@ -0,0 +1,20 @@ +@misc{stark, + author = {Eli Ben-Sasson and + Iddo Bentov and + Yinon Horesh and + Michael Riabzev}, + title = {Scalable, transparent, and post-quantum secure computational integrity}, + howpublished = {Cryptology ePrint Archive, Report 2018/046}, + year = {2018}, + note = {\url{https://ia.cr/2018/046}}, +} + +@misc{plonk, + author = {Ariel Gabizon and + Zachary J. Williamson and + Oana Ciobotaru}, + title = {PLONK: Permutations over Lagrange-bases for Oecumenical Noninteractive arguments of Knowledge}, + howpublished = {Cryptology ePrint Archive, Report 2019/953}, + year = {2019}, + note = {\url{https://ia.cr/2019/953}}, +} diff --git a/evm/spec/framework.tex b/evm/spec/framework.tex new file mode 100644 index 00000000..122e1c75 --- /dev/null +++ b/evm/spec/framework.tex @@ -0,0 +1,39 @@ +\section{STARK framework} +\label{framework} + + +\subsection{Cost model} + +Our zkEVM is designed for efficient verification by STARKs \cite{stark}, particularly by an AIR with degree 3 constraints. In this model, the prover bottleneck is typically constructing Merkle trees, particularly constructing the tree containing low-degree extensions of witness polynomials. + +More specifically, we target a constraint system of degree 3. + + +\subsection{Field selection} +\label{field} +Our zkEVM is designed to have its execution traces encoded in a particular prime field $\mathbb{F}_p$, with $p = 2^{64} - 2^{32} + 1$. A nice property of this field is that it can represent the results of many common \texttt{u32} operations. For example, (widening) \texttt{u32} multiplication has a maximum value of $(2^{32} - 1)^2$, which is less than $p$. In fact a \texttt{u32} multiply-add has a maximum value of $p - 1$, so the result can be represented with a single field element, although if we were to add a carry in bit, this no longer holds. + +This field also enables a very efficient reduction method. Observe that +$$ +2^{64} \equiv 2^{32} - 1 \pmod p +$$ +and consequently +\begin{align*} + 2^{96} &\equiv 2^{32} (2^{32} - 1) \pmod p \\ + &\equiv 2^{64} - 2^{32} \pmod p \\ + &\equiv -1 \pmod p. +\end{align*} +To reduce a 128-bit number $n$, we first rewrite $n$ as $n_0 + 2^{64} n_1 + 2^{96} n_2$, where $n_0$ is 64 bits and $n_1, n_2$ are 32 bits each. Then +\begin{align*} + n &\equiv n_0 + 2^{64} n_1 + 2^{96} n_2 \pmod p \\ + &\equiv n_0 + (2^{32} - 1) n_1 - n_2 \pmod p +\end{align*} +After computing $(2^{32} - 1) n_1$, which can be done with a shift and subtraction, we add the first two terms, subtracting $p$ if overflow occurs. We then subtract $n_2$, adding $p$ if underflow occurs. + +At this point we have reduced $n$ to a \texttt{u64}. This partial reduction is adequate for most purposes, but if we needed the result in canonical form, we would perform a final conditional subtraction. + + +\subsection{Cross-table lookups} +\label{ctl} + +TODO diff --git a/evm/spec/instructions.tex b/evm/spec/instructions.tex new file mode 100644 index 00000000..ea096982 --- /dev/null +++ b/evm/spec/instructions.tex @@ -0,0 +1,8 @@ +\section{Privileged instructions} +\label{privileged-instructions} + +\begin{enumerate} + \item[0xFB.] \texttt{MLOAD\_GENERAL}. Returns + \item[0xFC.] \texttt{MSTORE\_GENERAL}. Returns + \item[TODO.] \texttt{STACK\_SIZE}. Returns +\end{enumerate} diff --git a/evm/spec/introduction.tex b/evm/spec/introduction.tex new file mode 100644 index 00000000..cb969a16 --- /dev/null +++ b/evm/spec/introduction.tex @@ -0,0 +1,3 @@ +\section{Introduction} + +TODO diff --git a/evm/spec/tables.tex b/evm/spec/tables.tex new file mode 100644 index 00000000..92ee1d2a --- /dev/null +++ b/evm/spec/tables.tex @@ -0,0 +1,9 @@ +\section{Tables} +\label{tables} + +\input{tables/cpu} +\input{tables/arithmetic} +\input{tables/logic} +\input{tables/memory} +\input{tables/keccak-f} +\input{tables/keccak-sponge} diff --git a/evm/spec/tables/arithmetic.tex b/evm/spec/tables/arithmetic.tex new file mode 100644 index 00000000..eafed3ba --- /dev/null +++ b/evm/spec/tables/arithmetic.tex @@ -0,0 +1,4 @@ +\subsection{Arithmetic} +\label{arithmetic} + +TODO diff --git a/evm/spec/tables/cpu.tex b/evm/spec/tables/cpu.tex new file mode 100644 index 00000000..76c8be07 --- /dev/null +++ b/evm/spec/tables/cpu.tex @@ -0,0 +1,4 @@ +\subsection{CPU} +\label{cpu} + +TODO diff --git a/evm/spec/tables/keccak-f.tex b/evm/spec/tables/keccak-f.tex new file mode 100644 index 00000000..76e9e9f4 --- /dev/null +++ b/evm/spec/tables/keccak-f.tex @@ -0,0 +1,4 @@ +\subsection{Keccak-f} +\label{keccak-f} + +This table computes the Keccak-f[1600] permutation. diff --git a/evm/spec/tables/keccak-sponge.tex b/evm/spec/tables/keccak-sponge.tex new file mode 100644 index 00000000..29f71ba1 --- /dev/null +++ b/evm/spec/tables/keccak-sponge.tex @@ -0,0 +1,4 @@ +\subsection{Keccak sponge} +\label{keccak-sponge} + +This table computes the Keccak256 hash, a sponge-based hash built on top of the Keccak-f[1600] permutation. diff --git a/evm/spec/tables/logic.tex b/evm/spec/tables/logic.tex new file mode 100644 index 00000000..b430c95d --- /dev/null +++ b/evm/spec/tables/logic.tex @@ -0,0 +1,4 @@ +\subsection{Logic} +\label{logic} + +TODO diff --git a/evm/spec/tables/memory.tex b/evm/spec/tables/memory.tex new file mode 100644 index 00000000..9653f391 --- /dev/null +++ b/evm/spec/tables/memory.tex @@ -0,0 +1,61 @@ +\subsection{Memory} +\label{memory} + +For simplicity, let's treat addresses and values as individual field elements. The generalization to multi-element addresses and values is straightforward. + +Each row of the memory table corresponds to a single memory operation (a read or a write), and contains the following columns: + +\begin{enumerate} + \item $a$, the target address + \item $r$, an ``is read'' flag, which should be 1 for a read or 0 for a write + \item $v$, the value being read or written + \item $\tau$, the timestamp of the operation +\end{enumerate} +The memory table should be ordered by $(a, \tau)$. Note that the correctness memory could be checked as follows: +\begin{enumerate} + \item Verify the ordering by checking that $(a_i, \tau_i) < (a_{i+1}, \tau_{i+1})$ for each consecutive pair. + \item Enumerate the purportedly-ordered log while tracking a ``current'' value $c$, which is initially zero.\footnote{EVM memory is zero-initialized.} + \begin{enumerate} + \item Upon observing an address which doesn't match that of the previous row, set $c \leftarrow 0$. + \item Upon observing a write, set $c \leftarrow v$. + \item Upon observing a read, check that $v = c$. + \end{enumerate} +\end{enumerate} + +The ordering check is slightly involved since we are comparing multiple columns. To facilitate this, we add an additional column $e$, where the prover can indicate whether two consecutive addresses are equal. An honest prover will set +$$ +e_i \leftarrow \begin{cases} + 1 & \text{if } a_i = a_{i + 1}, \\ + 0 & \text{otherwise}. +\end{cases} +$$ +We then impose the following transition constraints: +\begin{enumerate} + \item $e_i (e_i - 1) = 0$, + \item $e_i (a_i - a_{i + 1}) = 0$, + \item $e_i (\tau_{i + 1} - \tau_i) + (1 - e_i) (a_{i + 1} - a_i - 1) < 2^{32}$. +\end{enumerate} +The last constraint emulates a comparison between two addresses or timestamps by bounding their difference; this assumes that all addresses and timestamps fit in 32 bits and that the field is larger than that. + +Finally, the iterative checks can be arithmetized by introducing a trace column for the current value $c$. We add a boundary constraint $c_0 = 0$, and the following transition constraints: +\todo{This is out of date, we don't actually need a $c$ column.} +\begin{enumerate} + \item $v_{\text{from},i} = c_i$, + \item $c_{i + 1} = e_i v_{\text{to},i}$. +\end{enumerate} + + +\subsubsection{Virtual memory} + +In the EVM, each contract call has its own address space. Within that address space, there are separate segments for code, main memory, stack memory, calldata, and returndata. Thus each address actually has three compoments: +\begin{enumerate} + \item an execution context, representing a contract call, + \item a segment ID, used to separate code, main memory, and so forth, and so on + \item a virtual address. +\end{enumerate} +The comparisons now involve several columns, which requires some minor adaptations to the technique described above; we will leave these as an exercise to the reader. + + +\subsubsection{Timestamps} + +TODO: Explain $\tau = \texttt{NUM\_CHANNELS} \times \texttt{cycle} + \texttt{channel}$. diff --git a/evm/spec/tries.tex b/evm/spec/tries.tex new file mode 100644 index 00000000..d8fc2674 --- /dev/null +++ b/evm/spec/tries.tex @@ -0,0 +1,4 @@ +\section{Merkle Patricia tries} +\label{tries} + +TODO diff --git a/evm/spec/zkevm.pdf b/evm/spec/zkevm.pdf new file mode 100644 index 00000000..8501cfbb Binary files /dev/null and b/evm/spec/zkevm.pdf differ diff --git a/evm/spec/zkevm.tex b/evm/spec/zkevm.tex new file mode 100644 index 00000000..f87f02f3 --- /dev/null +++ b/evm/spec/zkevm.tex @@ -0,0 +1,59 @@ +\documentclass[12pt]{article} +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{cite} +\usepackage{draftwatermark} +\usepackage[margin=1.5in]{geometry} +\usepackage{hyperref} +\usepackage{makecell} +\usepackage{mathtools} +\usepackage{tabularx} +\usepackage[textwidth=1.25in]{todonotes} + +% Scale for DRAFT watermark. +\SetWatermarkFontSize{24cm} +\SetWatermarkScale{5} +\SetWatermarkLightness{0.92} + +% Hyperlink colors. +\hypersetup{ + colorlinks=true, + linkcolor=blue, + citecolor=blue, + urlcolor=blue, +} + +% We want all section autorefs to say "Section". +\def\sectionautorefname{Section} +\let\subsectionautorefname\sectionautorefname +\let\subsubsectionautorefname\sectionautorefname + +% \floor{...} and \ceil{...} +\DeclarePairedDelimiter\ceil{\lceil}{\rceil} +\DeclarePairedDelimiter\floor{\lfloor}{\rfloor} + +\title{The Polygon Zero zkEVM} +%\author{Polygon Zero Team} +\date{DRAFT\\\today} + +\begin{document} +\maketitle + +\begin{abstract} + We describe the design of Polygon Zero's zkEVM, ... +\end{abstract} + +\newpage +{\hypersetup{hidelinks} \tableofcontents} +\newpage + +\input{introduction} +\input{framework} +\input{tables} +\input{tries} +\input{instructions} + +\bibliography{bibliography}{} +\bibliographystyle{ieeetr} + +\end{document} diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 4b8c7d0a..5fd262ac 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -8,6 +8,9 @@ use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns}; use crate::keccak::keccak_stark; use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_memory::columns::KECCAK_WIDTH_BYTES; +use crate::keccak_memory::keccak_memory_stark; +use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; @@ -18,6 +21,7 @@ use crate::stark::Stark; pub struct AllStark, const D: usize> { pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, + pub keccak_memory_stark: KeccakMemoryStark, pub logic_stark: LogicStark, pub memory_stark: MemoryStark, pub cross_table_lookups: Vec>, @@ -28,6 +32,7 @@ impl, const D: usize> Default for AllStark { Self { cpu_stark: CpuStark::default(), keccak_stark: KeccakStark::default(), + keccak_memory_stark: KeccakMemoryStark::default(), logic_stark: LogicStark::default(), memory_stark: MemoryStark::default(), cross_table_lookups: all_cross_table_lookups(), @@ -36,26 +41,24 @@ impl, const D: usize> Default for AllStark { } impl, const D: usize> AllStark { - pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> Vec { - let ans = vec![ + pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { + [ self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), + self.keccak_memory_stark.num_permutation_batches(config), self.logic_stark.num_permutation_batches(config), self.memory_stark.num_permutation_batches(config), - ]; - debug_assert_eq!(ans.len(), Table::num_tables()); - ans + ] } - pub(crate) fn permutation_batch_sizes(&self) -> Vec { - let ans = vec![ + pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] { + [ self.cpu_stark.permutation_batch_size(), self.keccak_stark.permutation_batch_size(), + self.keccak_memory_stark.permutation_batch_size(), self.logic_stark.permutation_batch_size(), self.memory_stark.permutation_batch_size(), - ]; - debug_assert_eq!(ans.len(), Table::num_tables()); - ans + ] } } @@ -63,30 +66,31 @@ impl, const D: usize> AllStark { pub enum Table { Cpu = 0, Keccak = 1, - Logic = 2, - Memory = 3, + KeccakMemory = 2, + Logic = 3, + Memory = 4, } -impl Table { - pub(crate) fn num_tables() -> usize { - Table::Memory as usize + 1 - } -} +pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; #[allow(unused)] // TODO: Should be used soon. pub(crate) fn all_cross_table_lookups() -> Vec> { - let mut cross_table_lookups = vec![ctl_keccak(), ctl_logic()]; - cross_table_lookups.extend((0..NUM_CHANNELS).map(ctl_memory)); - cross_table_lookups + vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_memory()] } fn ctl_keccak() -> CrossTableLookup { + let cpu_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_keccak(), + Some(cpu_stark::ctl_filter_keccak()), + ); + let keccak_memory_looking = TableWithColumns::new( + Table::KeccakMemory, + keccak_memory_stark::ctl_looking_keccak(), + Some(keccak_memory_stark::ctl_filter()), + ); CrossTableLookup::new( - vec![TableWithColumns::new( - Table::Cpu, - cpu_stark::ctl_data_keccak(), - Some(cpu_stark::ctl_filter_keccak()), - )], + vec![cpu_looking, keccak_memory_looking], TableWithColumns::new( Table::Keccak, keccak_stark::ctl_data(), @@ -96,6 +100,22 @@ fn ctl_keccak() -> CrossTableLookup { ) } +fn ctl_keccak_memory() -> CrossTableLookup { + CrossTableLookup::new( + vec![TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_keccak_memory(), + Some(cpu_stark::ctl_filter_keccak_memory()), + )], + TableWithColumns::new( + Table::KeccakMemory, + keccak_memory_stark::ctl_looked_data(), + Some(keccak_memory_stark::ctl_filter()), + ), + None, + ) +} + fn ctl_logic() -> CrossTableLookup { CrossTableLookup::new( vec![TableWithColumns::new( @@ -108,17 +128,38 @@ fn ctl_logic() -> CrossTableLookup { ) } -fn ctl_memory(channel: usize) -> CrossTableLookup { - CrossTableLookup::new( - vec![TableWithColumns::new( +fn ctl_memory() -> CrossTableLookup { + let cpu_memory_ops = (0..NUM_CHANNELS).map(|channel| { + TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_memory(channel), Some(cpu_stark::ctl_filter_memory(channel)), - )], + ) + }); + let keccak_memory_reads = (0..KECCAK_WIDTH_BYTES).map(|i| { + TableWithColumns::new( + Table::KeccakMemory, + keccak_memory_stark::ctl_looking_memory(i, true), + Some(keccak_memory_stark::ctl_filter()), + ) + }); + let keccak_memory_writes = (0..KECCAK_WIDTH_BYTES).map(|i| { + TableWithColumns::new( + Table::KeccakMemory, + keccak_memory_stark::ctl_looking_memory(i, false), + Some(keccak_memory_stark::ctl_filter()), + ) + }); + let all_lookers = cpu_memory_ops + .chain(keccak_memory_reads) + .chain(keccak_memory_writes) + .collect(); + CrossTableLookup::new( + all_lookers, TableWithColumns::new( Table::Memory, memory_stark::ctl_data(), - Some(memory_stark::ctl_filter(channel)), + Some(memory_stark::ctl_filter()), ), None, ) @@ -146,12 +187,13 @@ mod tests { use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::testutils::check_ctls; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; + use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic::{self, LogicStark, Operation}; use crate::memory::memory_stark::tests::generate_random_memory_ops; use crate::memory::memory_stark::MemoryStark; use crate::memory::NUM_CHANNELS; - use crate::proof::AllProof; - use crate::prover::prove; + use crate::proof::{AllProof, PublicValues}; + use crate::prover::prove_with_traces; use crate::recursive_verifier::{ add_virtual_all_proof, set_all_proof_target, verify_proof_circuit, }; @@ -175,6 +217,13 @@ mod tests { keccak_stark.generate_trace(keccak_inputs) } + fn make_keccak_memory_trace( + keccak_memory_stark: &KeccakMemoryStark, + config: &StarkConfig, + ) -> Vec> { + keccak_memory_stark.generate_trace(vec![], 1 << config.fri_config.cap_height) + } + fn make_logic_trace( num_rows: usize, logic_stark: &LogicStark, @@ -203,6 +252,19 @@ mod tests { (trace, num_ops) } + fn bits_from_opcode(opcode: u8) -> [F; 8] { + [ + F::from_bool(opcode & (1 << 0) != 0), + F::from_bool(opcode & (1 << 1) != 0), + F::from_bool(opcode & (1 << 2) != 0), + F::from_bool(opcode & (1 << 3) != 0), + F::from_bool(opcode & (1 << 4) != 0), + F::from_bool(opcode & (1 << 5) != 0), + F::from_bool(opcode & (1 << 6) != 0), + F::from_bool(opcode & (1 << 7) != 0), + ] + } + fn make_cpu_trace( num_keccak_perms: usize, num_logic_rows: usize, @@ -256,39 +318,18 @@ mod tests { cpu_trace_rows.push(row.into()); } - for i in 0..num_logic_rows { + // Pad to `num_memory_ops` for memory testing. + for _ in cpu_trace_rows.len()..num_memory_ops { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); + row.opcode_bits = bits_from_opcode(0x5b); row.is_cpu_cycle = F::ONE; row.is_kernel_mode = F::ONE; - row.program_counter = F::from_canonical_usize(i); - row.opcode = [ - (logic::columns::IS_AND, 0x16), - (logic::columns::IS_OR, 0x17), - (logic::columns::IS_XOR, 0x18), - ] - .into_iter() - .map(|(col, opcode)| logic_trace[col].values[i] * F::from_canonical_u64(opcode)) - .sum(); - let logic = row.general.logic_mut(); - - let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0); - for (col_cpu, limb_cols_logic) in logic.input0.iter_mut().zip(input0_bit_cols) { - *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); - } - - let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1); - for (col_cpu, limb_cols_logic) in logic.input1.iter_mut().zip(input1_bit_cols) { - *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); - } - - for (col_cpu, col_logic) in logic.output.iter_mut().zip(logic::columns::RESULT) { - *col_cpu = logic_trace[col_logic].values[i]; - } - + row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"]); cpu_stark.generate(row.borrow_mut()); cpu_trace_rows.push(row.into()); } + for i in 0..num_memory_ops { let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i] .to_canonical_u64() @@ -297,29 +338,72 @@ mod tests { let clock = mem_timestamp / NUM_CHANNELS; let channel = mem_timestamp % NUM_CHANNELS; - let is_padding_row = (0..NUM_CHANNELS) - .map(|c| memory_trace[memory::columns::is_channel(c)].values[i]) - .all(|x| x == F::ZERO); + let filter = memory_trace[memory::columns::FILTER].values[i]; + assert!(filter.is_one() || filter.is_zero()); + let is_actual_op = filter.is_one(); - if !is_padding_row { + if is_actual_op { let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); - - row.mem_channel_used[channel] = F::ONE; row.clock = F::from_canonical_usize(clock); - row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i]; - row.mem_addr_context[channel] = - memory_trace[memory::columns::ADDR_CONTEXT].values[i]; - row.mem_addr_segment[channel] = - memory_trace[memory::columns::ADDR_SEGMENT].values[i]; - row.mem_addr_virtual[channel] = - memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; + + let channel = &mut row.mem_channels[channel]; + channel.used = F::ONE; + channel.is_read = memory_trace[memory::columns::IS_READ].values[i]; + channel.addr_context = memory_trace[memory::columns::ADDR_CONTEXT].values[i]; + channel.addr_segment = memory_trace[memory::columns::ADDR_SEGMENT].values[i]; + channel.addr_virtual = memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; for j in 0..8 { - row.mem_value[channel][j] = - memory_trace[memory::columns::value_limb(j)].values[i]; + channel.value[j] = memory_trace[memory::columns::value_limb(j)].values[i]; } } } + for i in 0..num_logic_rows { + let mut row: cpu::columns::CpuColumnsView = + [F::ZERO; CpuStark::::COLUMNS].into(); + row.is_cpu_cycle = F::ONE; + row.is_kernel_mode = F::ONE; + + // Since these are the first cycle rows, we must start with PC=route_txn then increment. + row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"] + i); + row.opcode_bits = bits_from_opcode( + if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO { + 0x16 + } else if logic_trace[logic::columns::IS_OR].values[i] != F::ZERO { + 0x17 + } else if logic_trace[logic::columns::IS_XOR].values[i] != F::ZERO { + 0x18 + } else { + panic!() + }, + ); + + let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0); + for (col_cpu, limb_cols_logic) in + row.mem_channels[0].value.iter_mut().zip(input0_bit_cols) + { + *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); + } + + let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1); + for (col_cpu, limb_cols_logic) in + row.mem_channels[1].value.iter_mut().zip(input1_bit_cols) + { + *col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i])); + } + + for (col_cpu, col_logic) in row.mem_channels[2] + .value + .iter_mut() + .zip(logic::columns::RESULT) + { + *col_cpu = logic_trace[col_logic].values[i]; + } + + cpu_stark.generate(row.borrow_mut()); + cpu_trace_rows.push(row.into()); + } + // Trap to kernel { let mut row: cpu::columns::CpuColumnsView = @@ -327,10 +411,10 @@ mod tests { let last_row: cpu::columns::CpuColumnsView = cpu_trace_rows[cpu_trace_rows.len() - 1].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x0a); // `EXP` is implemented in software + row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software row.is_kernel_mode = F::ONE; row.program_counter = last_row.program_counter + F::ONE; - row.general.syscalls_mut().output = [ + row.mem_channels[0].value = [ row.program_counter, F::ONE, F::ZERO, @@ -349,10 +433,10 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0xf9); + row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(15682), F::ONE, F::ZERO, @@ -371,10 +455,10 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x56); + row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15682); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(15106), F::ZERO, F::ZERO, @@ -384,7 +468,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -408,10 +492,10 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0xf9); + row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15106); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(63064), F::ZERO, F::ZERO, @@ -430,10 +514,10 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x56); + row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(63064); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(3754), F::ZERO, F::ZERO, @@ -443,7 +527,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -468,10 +552,10 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x57); + row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(3754); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -481,7 +565,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ZERO, F::ZERO, F::ZERO, @@ -506,10 +590,10 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x57); + row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(37543); - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -535,10 +619,10 @@ mod tests { let last_row: cpu::columns::CpuColumnsView = cpu_trace_rows[cpu_trace_rows.len() - 1].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x56); + row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = last_row.program_counter + F::ONE; - row.general.jumps_mut().input0 = [ + row.mem_channels[0].value = [ F::from_canonical_u16(37543), F::ZERO, F::ZERO, @@ -548,7 +632,7 @@ mod tests { F::ZERO, F::ZERO, ]; - row.general.jumps_mut().input1 = [ + row.mem_channels[1].value = [ F::ONE, F::ZERO, F::ZERO, @@ -572,7 +656,7 @@ mod tests { for i in 0..cpu_trace_rows.len().next_power_of_two() - cpu_trace_rows.len() { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); - row.opcode = F::from_canonical_u8(0xff); + row.opcode_bits = bits_from_opcode(0xff); row.is_cpu_cycle = F::ONE; row.is_kernel_mode = F::ONE; row.program_counter = @@ -604,6 +688,7 @@ mod tests { let num_keccak_perms = 2; let keccak_trace = make_keccak_trace(num_keccak_perms, &all_stark.keccak_stark, &mut rng); + let keccak_memory_trace = make_keccak_memory_trace(&all_stark.keccak_memory_stark, config); let logic_trace = make_logic_trace(num_logic_rows, &all_stark.logic_stark, &mut rng); let mem_trace = make_memory_trace(num_memory_ops, &all_stark.memory_stark, &mut rng); let mut memory_trace = mem_trace.0; @@ -618,14 +703,21 @@ mod tests { &mut memory_trace, ); - let traces = vec![cpu_trace, keccak_trace, logic_trace, memory_trace]; + let traces = [ + cpu_trace, + keccak_trace, + keccak_memory_trace, + logic_trace, + memory_trace, + ]; check_ctls(&traces, &all_stark.cross_table_lookups); - let proof = prove::( + let public_values = PublicValues::default(); + let proof = prove_with_traces::( &all_stark, config, traces, - vec![vec![]; 4], + public_values, &mut TimingTree::default(), )?; diff --git a/evm/src/arithmetic/add.rs b/evm/src/arithmetic/add.rs index 80f03d63..e87566b6 100644 --- a/evm/src/arithmetic/add.rs +++ b/evm/src/arithmetic/add.rs @@ -9,6 +9,21 @@ use crate::arithmetic::columns::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; +pub(crate) fn u256_add_cc(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u64; N_LIMBS], u64) { + // Input and output have 16-bit limbs + let mut output = [0u64; N_LIMBS]; + + const MASK: u64 = (1u64 << LIMB_BITS) - 1u64; + let mut cy = 0u64; + for (i, a, b) in izip!(0.., input0, input1) { + let s = a + b + cy; + cy = s >> LIMB_BITS; + assert!(cy <= 1u64, "input limbs were larger than 16 bits"); + output[i] = s & MASK; + } + (output, cy) +} + /// Given two sequences `larger` and `smaller` of equal length (not /// checked), verifies that \sum_i larger[i] 2^(LIMB_BITS * i) == /// \sum_i smaller[i] 2^(LIMB_BITS * i), taking care of carry propagation. @@ -19,7 +34,8 @@ pub(crate) fn eval_packed_generic_are_equal( is_op: P, larger: I, smaller: J, -) where +) -> P +where P: PackedField, I: Iterator, J: Iterator, @@ -36,6 +52,7 @@ pub(crate) fn eval_packed_generic_are_equal( // increase the degree of the constraint. cy = t * overflow_inv; } + cy } pub(crate) fn eval_ext_circuit_are_equal( @@ -44,7 +61,8 @@ pub(crate) fn eval_ext_circuit_are_equal( is_op: ExtensionTarget, larger: I, smaller: J, -) where +) -> ExtensionTarget +where F: RichField + Extendable, I: Iterator>, J: Iterator>, @@ -72,6 +90,7 @@ pub(crate) fn eval_ext_circuit_are_equal( cy = builder.mul_const_extension(overflow_inv, t); } + cy } pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { @@ -79,17 +98,7 @@ pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { let input1_limbs = ADD_INPUT_1.map(|c| lv[c].to_canonical_u64()); // Input and output have 16-bit limbs - let mut output_limbs = [0u64; N_LIMBS]; - - const MASK: u64 = (1u64 << LIMB_BITS) - 1u64; - let mut cy = 0u64; - for (i, a, b) in izip!(0.., input0_limbs, input1_limbs) { - let s = a + b + cy; - cy = s >> LIMB_BITS; - assert!(cy <= 1u64, "input limbs were larger than 16 bits"); - output_limbs[i] = s & MASK; - } - // last carry is dropped because this is addition modulo 2^256. + let (output_limbs, _) = u256_add_cc(input0_limbs, input1_limbs); for (&c, output_limb) in ADD_OUTPUT.iter().zip(output_limbs) { lv[c] = F::from_canonical_u64(output_limb); diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index ce8c7528..58b8afff 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -8,6 +8,7 @@ use plonky2::hash::hash_types::RichField; use crate::arithmetic::add; use crate::arithmetic::columns; +use crate::arithmetic::compare; use crate::arithmetic::mul; use crate::arithmetic::sub; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; @@ -45,6 +46,10 @@ impl ArithmeticStark { sub::generate(local_values); } else if local_values[columns::IS_MUL].is_one() { mul::generate(local_values); + } else if local_values[columns::IS_LT].is_one() { + compare::generate(local_values, columns::IS_LT); + } else if local_values[columns::IS_GT].is_one() { + compare::generate(local_values, columns::IS_GT); } else { todo!("the requested operation has not yet been implemented"); } @@ -53,11 +58,10 @@ impl ArithmeticStark { impl, const D: usize> Stark for ArithmeticStark { const COLUMNS: usize = columns::NUM_ARITH_COLUMNS; - const PUBLIC_INPUTS: usize = 0; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -67,18 +71,20 @@ impl, const D: usize> Stark for ArithmeticSta add::eval_packed_generic(lv, yield_constr); sub::eval_packed_generic(lv, yield_constr); mul::eval_packed_generic(lv, yield_constr); + compare::eval_packed_generic(lv, yield_constr); } fn eval_ext_circuit( &self, builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let lv = vars.local_values; add::eval_ext_circuit(builder, lv, yield_constr); sub::eval_ext_circuit(builder, lv, yield_constr); mul::eval_ext_circuit(builder, lv, yield_constr); + compare::eval_ext_circuit(builder, lv, yield_constr); } fn constraint_degree(&self) -> usize { diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index e51419a8..7b44adc1 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -79,4 +79,9 @@ pub(crate) const MUL_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; pub(crate) const MUL_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; pub(crate) const MUL_AUX_INPUT: [usize; N_LIMBS] = AUX_INPUT_0; +pub(crate) const CMP_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; +pub(crate) const CMP_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; +pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2[0]; +pub(crate) const CMP_AUX_INPUT: [usize; N_LIMBS] = AUX_INPUT_0; + pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; diff --git a/evm/src/arithmetic/compare.rs b/evm/src/arithmetic/compare.rs new file mode 100644 index 00000000..a6566db5 --- /dev/null +++ b/evm/src/arithmetic/compare.rs @@ -0,0 +1,233 @@ +//! Support for EVM LT and GT instructions +//! +//! This crate verifies EVM LT and GT instructions (i.e. for unsigned +//! inputs). The difference between LT and GT is of course just a +//! matter of the order of the inputs. The verification is essentially +//! identical to the SUB instruction: For both SUB and LT we have values +//! +//! - `input0` +//! - `input1` +//! - `difference` (mod 2^256) +//! - `borrow` (= 0 or 1) +//! +//! satisfying `input0 - input1 = difference + borrow * 2^256`. Where +//! SUB verifies `difference` and ignores `borrow`, LT verifies +//! `borrow` (and uses `difference` as an auxiliary input). + +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::arithmetic::add::{eval_ext_circuit_are_equal, eval_packed_generic_are_equal}; +use crate::arithmetic::columns::*; +use crate::arithmetic::sub::u256_sub_br; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::range_check_error; + +pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize) { + let input0 = CMP_INPUT_0.map(|c| lv[c].to_canonical_u64()); + let input1 = CMP_INPUT_1.map(|c| lv[c].to_canonical_u64()); + + let (diff, br) = match op { + // input0 - input1 == diff + br*2^256 + IS_LT => u256_sub_br(input0, input1), + // input1 - input0 == diff + br*2^256 + IS_GT => u256_sub_br(input1, input0), + IS_SLT => todo!(), + IS_SGT => todo!(), + _ => panic!("op code not a comparison"), + }; + + for (&c, diff_limb) in CMP_AUX_INPUT.iter().zip(diff) { + lv[c] = F::from_canonical_u64(diff_limb); + } + lv[CMP_OUTPUT] = F::from_canonical_u64(br); +} + +fn eval_packed_generic_check_is_one_bit( + yield_constr: &mut ConstraintConsumer

, + filter: P, + x: P, +) { + yield_constr.constraint(filter * x * (x - P::ONES)); +} + +pub(crate) fn eval_packed_generic_lt( + yield_constr: &mut ConstraintConsumer

, + is_op: P, + input0: [P; N_LIMBS], + input1: [P; N_LIMBS], + aux: [P; N_LIMBS], + output: P, +) { + // Verify (input0 < input1) == output by providing aux such that + // input0 - input1 == aux + output*2^256. + let lhs_limbs = input0.iter().zip(input1).map(|(&a, b)| a - b); + let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.into_iter(), lhs_limbs); + // We don't need to check that cy is 0 or 1, since output has + // already been checked to be 0 or 1. + yield_constr.constraint(is_op * (cy - output)); +} + +pub fn eval_packed_generic( + lv: &[P; NUM_ARITH_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + range_check_error!(CMP_INPUT_0, 16); + range_check_error!(CMP_INPUT_1, 16); + range_check_error!(CMP_AUX_INPUT, 16); + + let is_lt = lv[IS_LT]; + let is_gt = lv[IS_GT]; + + let input0 = CMP_INPUT_0.map(|c| lv[c]); + let input1 = CMP_INPUT_1.map(|c| lv[c]); + let aux = CMP_AUX_INPUT.map(|c| lv[c]); + let output = lv[CMP_OUTPUT]; + + let is_cmp = is_lt + is_gt; + eval_packed_generic_check_is_one_bit(yield_constr, is_cmp, output); + + eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output); + eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output); +} + +fn eval_ext_circuit_check_is_one_bit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + yield_constr: &mut RecursiveConstraintConsumer, + filter: ExtensionTarget, + x: ExtensionTarget, +) { + let constr = builder.mul_sub_extension(x, x, x); + let filtered_constr = builder.mul_extension(filter, constr); + yield_constr.constraint(builder, filtered_constr); +} + +#[allow(clippy::needless_collect)] +pub(crate) fn eval_ext_circuit_lt, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + yield_constr: &mut RecursiveConstraintConsumer, + is_op: ExtensionTarget, + input0: [ExtensionTarget; N_LIMBS], + input1: [ExtensionTarget; N_LIMBS], + aux: [ExtensionTarget; N_LIMBS], + output: ExtensionTarget, +) { + // Since `map` is lazy and the closure passed to it borrows + // `builder`, we can't then borrow builder again below in the call + // to `eval_ext_circuit_are_equal`. The solution is to force + // evaluation with `collect`. + let lhs_limbs = input0 + .iter() + .zip(input1) + .map(|(&a, b)| builder.sub_extension(a, b)) + .collect::>>(); + + let cy = eval_ext_circuit_are_equal( + builder, + yield_constr, + is_op, + aux.into_iter(), + lhs_limbs.into_iter(), + ); + let good_output = builder.sub_extension(cy, output); + let filter = builder.mul_extension(is_op, good_output); + yield_constr.constraint(builder, filter); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_lt = lv[IS_LT]; + let is_gt = lv[IS_GT]; + + let input0 = CMP_INPUT_0.map(|c| lv[c]); + let input1 = CMP_INPUT_1.map(|c| lv[c]); + let aux = CMP_AUX_INPUT.map(|c| lv[c]); + let output = lv[CMP_OUTPUT]; + + let is_cmp = builder.add_extension(is_lt, is_gt); + eval_ext_circuit_check_is_one_bit(builder, yield_constr, is_cmp, output); + + eval_ext_circuit_lt(builder, yield_constr, is_lt, input0, input1, aux, output); + eval_ext_circuit_lt(builder, yield_constr, is_gt, input1, input0, aux, output); +} + +#[cfg(test)] +mod tests { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Field; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + use super::*; + use crate::arithmetic::columns::NUM_ARITH_COLUMNS; + use crate::constraint_consumer::ConstraintConsumer; + + // TODO: Should be able to refactor this test to apply to all operations. + #[test] + fn generate_eval_consistency_not_compare() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + + // if `IS_LT == 0`, then the constraints should be met even if + // all values are garbage. `eval_packed_generic` handles IS_GT + // at the same time, so we check both at once. + lv[IS_LT] = F::ZERO; + lv[IS_GT] = F::ZERO; + + let mut constrant_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + F::ONE, + F::ONE, + F::ONE, + ); + eval_packed_generic(&lv, &mut constrant_consumer); + for &acc in &constrant_consumer.constraint_accs { + assert_eq!(acc, F::ZERO); + } + } + + #[test] + fn generate_eval_consistency_compare() { + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + const N_ITERS: usize = 1000; + + for _ in 0..N_ITERS { + for (op, other_op) in [(IS_LT, IS_GT), (IS_GT, IS_LT)] { + // set op == 1 and ensure all constraints are satisfied. + // we have to explicitly set the other op to zero since both + // are treated by the call. + lv[op] = F::ONE; + lv[other_op] = F::ZERO; + + // set inputs to random values + for (&ai, bi) in CMP_INPUT_0.iter().zip(CMP_INPUT_1) { + lv[ai] = F::from_canonical_u16(rng.gen()); + lv[bi] = F::from_canonical_u16(rng.gen()); + } + + generate(&mut lv, op); + + let mut constrant_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + F::ONE, + F::ONE, + F::ONE, + ); + eval_packed_generic(&lv, &mut constrant_consumer); + for &acc in &constrant_consumer.constraint_accs { + assert_eq!(acc, F::ZERO); + } + } + } + } +} diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index 07c4c5a9..69fbda09 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -1,4 +1,5 @@ mod add; +mod compare; mod mul; mod sub; mod utils; diff --git a/evm/src/arithmetic/sub.rs b/evm/src/arithmetic/sub.rs index ce7932e2..c632eb94 100644 --- a/evm/src/arithmetic/sub.rs +++ b/evm/src/arithmetic/sub.rs @@ -9,26 +9,29 @@ use crate::arithmetic::columns::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; -pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = SUB_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = SUB_INPUT_1.map(|c| lv[c].to_canonical_u64()); - - // Input and output have 16-bit limbs - let mut output_limbs = [0u64; N_LIMBS]; - +pub(crate) fn u256_sub_br(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u64; N_LIMBS], u64) { const LIMB_BOUNDARY: u64 = 1 << LIMB_BITS; const MASK: u64 = LIMB_BOUNDARY - 1u64; + let mut output = [0u64; N_LIMBS]; let mut br = 0u64; - for (i, a, b) in izip!(0.., input0_limbs, input1_limbs) { + for (i, a, b) in izip!(0.., input0, input1) { let d = LIMB_BOUNDARY + a - b - br; // if a < b, then d < 2^16 so br = 1 // if a >= b, then d >= 2^16 so br = 0 br = 1u64 - (d >> LIMB_BITS); assert!(br <= 1u64, "input limbs were larger than 16 bits"); - output_limbs[i] = d & MASK; + output[i] = d & MASK; } - // last borrow is dropped because this is subtraction modulo 2^256. + + (output, br) +} + +pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { + let input0_limbs = SUB_INPUT_0.map(|c| lv[c].to_canonical_u64()); + let input1_limbs = SUB_INPUT_1.map(|c| lv[c].to_canonical_u64()); + + let (output_limbs, _) = u256_sub_br(input0_limbs, input1_limbs); for (&c, output_limb) in SUB_OUTPUT.iter().zip(output_limbs) { lv[c] = F::from_canonical_u64(output_limb); diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index dc9a0a2f..c50481f3 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -19,4 +19,7 @@ macro_rules! range_check_error { ($cols:ident, $rc_bits:expr) => { $crate::arithmetic::utils::_range_check_error::<$rc_bits>(file!(), line!(), &$cols); }; + ([$cols:ident], $rc_bits:expr) => { + $crate::arithmetic::utils::_range_check_error::<$rc_bits>(file!(), line!(), &[$cols]); + }; } diff --git a/evm/src/constraint_consumer.rs b/evm/src/constraint_consumer.rs index ebe0637a..49dc018c 100644 --- a/evm/src/constraint_consumer.rs +++ b/evm/src/constraint_consumer.rs @@ -44,12 +44,8 @@ impl ConstraintConsumer

{ } } - // TODO: Do this correctly. - pub fn accumulators(self) -> Vec { + pub fn accumulators(self) -> Vec

{ self.constraint_accs - .into_iter() - .map(|acc| acc.as_slice()[0]) - .collect() } /// Add one constraint valid on all rows except the last. diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index 2c6afb51..533589af 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -15,7 +15,6 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::keccak_util::keccakf_u32s; -use crate::cpu::public_inputs::NUM_PUBLIC_INPUTS; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::NUM_CHANNELS; @@ -50,7 +49,7 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState let mut packed_bytes: u32 = 0; for (addr, byte) in chunk { let channel = addr % NUM_CHANNELS; - state.set_mem_current(channel, Segment::Code, addr, byte.into()); + state.set_mem_cpu_current(channel, Segment::Code, addr, byte.into()); packed_bytes = (packed_bytes << 8) | byte as u32; } @@ -73,7 +72,7 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState } pub(crate) fn eval_bootstrap_kernel>( - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) { let local_values: &CpuColumnsView<_> = vars.local_values.borrow(); @@ -109,7 +108,7 @@ pub(crate) fn eval_bootstrap_kernel>( pub(crate) fn eval_bootstrap_kernel_circuit, const D: usize>( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let local_values: &CpuColumnsView<_> = vars.local_values.borrow(); diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index db7436ba..134788dc 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -9,7 +9,6 @@ pub(crate) union CpuGeneralColumnsView { arithmetic: CpuArithmeticView, logic: CpuLogicView, jumps: CpuJumpsView, - syscalls: CpuSyscallsView, } impl CpuGeneralColumnsView { @@ -52,16 +51,6 @@ impl CpuGeneralColumnsView { pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView { unsafe { &mut self.jumps } } - - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn syscalls(&self) -> &CpuSyscallsView { - unsafe { &self.syscalls } - } - - // SAFETY: Each view is a valid interpretation of the underlying array. - pub(crate) fn syscalls_mut(&mut self) -> &mut CpuSyscallsView { - unsafe { &mut self.syscalls } - } } impl PartialEq for CpuGeneralColumnsView { @@ -107,20 +96,16 @@ pub(crate) struct CpuArithmeticView { #[derive(Copy, Clone)] pub(crate) struct CpuLogicView { - // Assuming a limb size of 16 bits. This can be changed, but it must be <= 28 bits. - pub(crate) input0: [T; 16], - pub(crate) input1: [T; 16], - pub(crate) output: [T; 16], + // Pseudoinverse of `(input0 - input1)`. Used prove that they are unequal. Assumes 32-bit limbs. + pub(crate) diff_pinv: [T; 8], } #[derive(Copy, Clone)] pub(crate) struct CpuJumpsView { - /// Assuming a limb size of 32 bits. - /// The top stack value at entry (for jumps, the address; for `EXIT_KERNEL`, the address and new - /// privilege level). - pub(crate) input0: [T; 8], - /// For `JUMPI`, the second stack value (the predicate). For `JUMP`, 1. - pub(crate) input1: [T; 8], + /// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the + /// address; for `EXIT_KERNEL`, the address and new privilege level). + /// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the + /// predicate). For `JUMP`, 1. /// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value. /// Needed to prove that `input0` is nonzero. @@ -159,15 +144,5 @@ pub(crate) struct CpuJumpsView { pub(crate) should_trap: T, } -#[derive(Copy, Clone)] -pub(crate) struct CpuSyscallsView { - /// Assuming a limb size of 32 bits. - /// The output contains the context that is required to from the system call in `EXIT_KERNEL`. - /// `output[0]` contains the program counter at the time the system call was made (the address - /// of the syscall instruction). `output[1]` is 1 if we were in kernel mode at the time and 0 - /// otherwise. `output[2]`, ..., `output[7]` are zero. - pub(crate) output: [T; 8], -} - // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 824ae13d..93e93ce6 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -3,14 +3,28 @@ use std::borrow::{Borrow, BorrowMut}; use std::fmt::Debug; -use std::mem::{size_of, transmute, transmute_copy, ManuallyDrop}; +use std::mem::{size_of, transmute}; use std::ops::{Index, IndexMut}; use crate::cpu::columns::general::CpuGeneralColumnsView; use crate::memory; +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; mod general; +#[repr(C)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct MemoryChannelView { + /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise + /// 0. + pub used: T, + pub is_read: T, + pub addr_context: T, + pub addr_segment: T, + pub addr_virtual: T, + pub value: [T; memory::VALUE_LIMBS], +} + #[repr(C)] #[derive(Eq, PartialEq, Debug)] pub struct CpuColumnsView { @@ -24,12 +38,16 @@ pub struct CpuColumnsView { /// If CPU cycle: The program counter for the current instruction. pub program_counter: T, + /// If CPU cycle: The stack length. + pub stack_len: T, + + /// If CPU cycle: A prover-provided value needed to show that the instruction does not cause the + /// stack to underflow or overflow. + pub stack_len_bounds_aux: T, + /// If CPU cycle: We're in kernel (privileged) mode. pub is_kernel_mode: T, - /// If CPU cycle: The opcode being decoded, in {0, ..., 255}. - pub opcode: T, - // If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. // Invalid opcodes are split between a number of flags for practical reasons. Exactly one of // these flags must be 1. @@ -127,62 +145,26 @@ pub struct CpuColumnsView { pub is_revert: T, pub is_selfdestruct: T, - // An instruction is invalid if _any_ of the below flags is 1. - pub is_invalid_0: T, - pub is_invalid_1: T, - pub is_invalid_2: T, - pub is_invalid_3: T, - pub is_invalid_4: T, - pub is_invalid_5: T, - pub is_invalid_6: T, - pub is_invalid_7: T, - pub is_invalid_8: T, - pub is_invalid_9: T, - pub is_invalid_10: T, - pub is_invalid_11: T, - pub is_invalid_12: T, - pub is_invalid_13: T, - pub is_invalid_14: T, - pub is_invalid_15: T, - pub is_invalid_16: T, - pub is_invalid_17: T, - pub is_invalid_18: T, - pub is_invalid_19: T, - pub is_invalid_20: T, + pub is_invalid: T, /// If CPU cycle: the opcode, broken up into bits in little-endian order. pub opcode_bits: [T; 8], - /// Filter. 1 iff a Keccak permutation is computed on this row. + /// Filter. 1 iff a Keccak lookup is performed on this row. pub is_keccak: T, + /// Filter. 1 iff a Keccak memory lookup is performed on this row. + pub is_keccak_memory: T, + pub(crate) general: CpuGeneralColumnsView, - pub simple_logic_diff: T, - pub simple_logic_diff_inv: T, - pub(crate) clock: T, - /// 1 if this row includes a memory operation in the `i`th channel of the memory bus, otherwise - /// 0. - pub mem_channel_used: [T; memory::NUM_CHANNELS], - pub mem_is_read: [T; memory::NUM_CHANNELS], - pub mem_addr_context: [T; memory::NUM_CHANNELS], - pub mem_addr_segment: [T; memory::NUM_CHANNELS], - pub mem_addr_virtual: [T; memory::NUM_CHANNELS], - pub mem_value: [[T; memory::VALUE_LIMBS]; memory::NUM_CHANNELS], + pub mem_channels: [MemoryChannelView; memory::NUM_CHANNELS], } // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_CPU_COLUMNS: usize = size_of::>(); -unsafe fn transmute_no_compile_time_size_checks(value: T) -> U { - debug_assert_eq!(size_of::(), size_of::()); - // Need ManuallyDrop so that `value` is not dropped by this function. - let value = ManuallyDrop::new(value); - // Copy the bit pattern. The original value is no longer safe to use. - transmute_copy(&value) -} - impl From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView { fn from(value: [T; NUM_CPU_COLUMNS]) -> Self { unsafe { transmute_no_compile_time_size_checks(value) } @@ -242,12 +224,7 @@ where } const fn make_col_map() -> CpuColumnsView { - let mut indices_arr = [0; NUM_CPU_COLUMNS]; - let mut i = 0; - while i < NUM_CPU_COLUMNS { - indices_arr[i] = i; - i += 1; - } + let indices_arr = indices_arr::(); unsafe { transmute::<[usize; NUM_CPU_COLUMNS], CpuColumnsView>(indices_arr) } } diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index a157653f..5a43f7cf 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -68,20 +68,16 @@ pub fn eval_packed_generic( lv.is_cpu_cycle * is_native_instruction * (lv.is_kernel_mode - nv.is_kernel_mode), ); - // If a non-CPU cycle row is followed by a CPU cycle row, then the `program_counter` of the CPU - // cycle row is 0 and it is in kernel mode. - yield_constr - .constraint_transition((lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle * nv.program_counter); - yield_constr.constraint_transition( - (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle * (nv.is_kernel_mode - P::ONES), - ); - - // The first row has nowhere to continue execution from, so if it's a cycle row, then its - // `program_counter` must be 0. - // NB: I know the first few rows will be used for initialization and will not be CPU cycle rows. - // Once that's done, then this constraint can be removed. Until then, it is needed to ensure - // that execution starts at 0 and not at any arbitrary offset. - yield_constr.constraint_first_row(lv.is_cpu_cycle * lv.program_counter); + // If a non-CPU cycle row is followed by a CPU cycle row, then: + // - the `program_counter` of the CPU cycle row is `route_txn` (the entry point of our kernel), + // - execution is in kernel mode, and + // - the stack is empty. + let is_last_noncpu_cycle = (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle; + let pc_diff = + nv.program_counter - P::Scalar::from_canonical_usize(KERNEL.global_labels["route_txn"]); + yield_constr.constraint_transition(is_last_noncpu_cycle * pc_diff); + yield_constr.constraint_transition(is_last_noncpu_cycle * (nv.is_kernel_mode - P::ONES)); + yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len); // The last row must be a CPU cycle row. yield_constr.constraint_last_row(lv.is_cpu_cycle - P::ONES); @@ -121,24 +117,33 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint_transition(builder, kernel_constr); } - // If a non-CPU cycle row is followed by a CPU cycle row, then the `program_counter` of the CPU - // cycle row is 0 and it is in kernel mode. + // If a non-CPU cycle row is followed by a CPU cycle row, then: + // - the `program_counter` of the CPU cycle row is `route_txn` (the entry point of our kernel), + // - execution is in kernel mode, and + // - the stack is empty. { - let filter = builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle); - let pc_constr = builder.mul_extension(filter, nv.program_counter); - yield_constr.constraint_transition(builder, pc_constr); - let kernel_constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); - yield_constr.constraint_transition(builder, kernel_constr); - } + let is_last_noncpu_cycle = + builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle); - // The first row has nowhere to continue execution from, so if it's a cycle row, then its - // `program_counter` must be 0. - // NB: I know the first few rows will be used for initialization and will not be CPU cycle rows. - // Once that's done, then this constraint can be removed. Until then, it is needed to ensure - // that execution starts at 0 and not at any arbitrary offset. - { - let constr = builder.mul_extension(lv.is_cpu_cycle, lv.program_counter); - yield_constr.constraint_first_row(builder, constr); + // Start at `route_txn`. + let route_txn = builder.constant_extension(F::Extension::from_canonical_usize( + KERNEL.global_labels["route_txn"], + )); + let pc_diff = builder.sub_extension(nv.program_counter, route_txn); + let pc_constr = builder.mul_extension(is_last_noncpu_cycle, pc_diff); + yield_constr.constraint_transition(builder, pc_constr); + + // Start in kernel mode + let kernel_constr = builder.mul_sub_extension( + is_last_noncpu_cycle, + nv.is_kernel_mode, + is_last_noncpu_cycle, + ); + yield_constr.constraint_transition(builder, kernel_constr); + + // Start with empty stack + let kernel_constr = builder.mul_extension(is_last_noncpu_cycle, nv.stack_len); + yield_constr.constraint_transition(builder, kernel_constr); } // The last row must be a CPU cycle row. diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 918f7d9b..9949b044 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -9,7 +9,9 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; -use crate::cpu::{bootstrap_kernel, control_flow, decode, jumps, simple_logic, syscalls}; +use crate::cpu::{ + bootstrap_kernel, control_flow, decode, jumps, simple_logic, stack_bounds, syscalls, +}; use crate::cross_table_lookup::Column; use crate::memory::NUM_CHANNELS; use crate::stark::Stark; @@ -22,16 +24,35 @@ pub fn ctl_data_keccak() -> Vec> { res } +pub fn ctl_data_keccak_memory() -> Vec> { + // When executing KECCAK_GENERAL, the memory channels are used as follows: + // channel 0: instruction + // channel 1: stack[-1] = context + // channel 2: stack[-2] = segment + // channel 3: stack[-3] = virtual + let context = Column::single(COL_MAP.mem_channels[1].value[0]); + let segment = Column::single(COL_MAP.mem_channels[2].value[0]); + let virt = Column::single(COL_MAP.mem_channels[3].value[0]); + + let num_channels = F::from_canonical_usize(NUM_CHANNELS); + let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]); + + vec![context, segment, virt, clock] +} + pub fn ctl_filter_keccak() -> Column { Column::single(COL_MAP.is_keccak) } +pub fn ctl_filter_keccak_memory() -> Column { + Column::single(COL_MAP.is_keccak_memory) +} + pub fn ctl_data_logic() -> Vec> { let mut res = Column::singles([COL_MAP.is_and, COL_MAP.is_or, COL_MAP.is_xor]).collect_vec(); - let logic = COL_MAP.general.logic(); - res.extend(Column::singles(logic.input0)); - res.extend(Column::singles(logic.input1)); - res.extend(Column::singles(logic.output)); + res.extend(Column::singles(COL_MAP.mem_channels[0].value)); + res.extend(Column::singles(COL_MAP.mem_channels[1].value)); + res.extend(Column::singles(COL_MAP.mem_channels[2].value)); res } @@ -41,19 +62,20 @@ pub fn ctl_filter_logic() -> Column { pub fn ctl_data_memory(channel: usize) -> Vec> { debug_assert!(channel < NUM_CHANNELS); + let channel_map = COL_MAP.mem_channels[channel]; let mut cols: Vec> = Column::singles([ - COL_MAP.mem_is_read[channel], - COL_MAP.mem_addr_context[channel], - COL_MAP.mem_addr_segment[channel], - COL_MAP.mem_addr_virtual[channel], + channel_map.is_read, + channel_map.addr_context, + channel_map.addr_segment, + channel_map.addr_virtual, ]) .collect_vec(); - cols.extend(Column::singles(COL_MAP.mem_value[channel])); + cols.extend(Column::singles(channel_map.value)); let scalar = F::from_canonical_usize(NUM_CHANNELS); let addend = F::from_canonical_usize(channel); cols.push(Column::linear_combination_with_constant( - vec![(COL_MAP.clock, scalar)], + [(COL_MAP.clock, scalar)], addend, )); @@ -61,7 +83,7 @@ pub fn ctl_data_memory(channel: usize) -> Vec> { } pub fn ctl_filter_memory(channel: usize) -> Column { - Column::single(COL_MAP.mem_channel_used[channel]) + Column::single(COL_MAP.mem_channels[channel].used) } #[derive(Copy, Clone, Default)] @@ -74,16 +96,16 @@ impl CpuStark { let local_values: &mut CpuColumnsView<_> = local_values.borrow_mut(); decode::generate(local_values); simple_logic::generate(local_values); + stack_bounds::generate(local_values); // Must come after `decode`. } } impl, const D: usize> Stark for CpuStark { const COLUMNS: usize = NUM_CPU_COLUMNS; - const PUBLIC_INPUTS: usize = 0; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -96,13 +118,14 @@ impl, const D: usize> Stark for CpuStark, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let local_values = vars.local_values.borrow(); @@ -112,6 +135,7 @@ impl, const D: usize> Stark for CpuStark [u8; 32] { + let mut res = [u8::MAX; 32]; // Start with all opcodes marked invalid. + + let mut i = 0; + while i < OPCODES.len() { + let (block_start, lb_block_len, kernel_only, _) = OPCODES[i]; + i += 1; + + if kernel_only { + continue; + } + + let block_len = 1 << lb_block_len; + let block_start = block_start as usize; + let block_end = block_start + block_len; + let mut j = block_start; + while j < block_end { + let byte = j / u8::BITS as usize; + let bit = j % u8::BITS as usize; + res[byte] &= !(1 << bit); // Mark opcode as invalid by zeroing the bit. + j += 1; + } + } + res +} + pub fn generate(lv: &mut CpuColumnsView) { let cycle_filter = lv.is_cpu_cycle; if cycle_filter == F::ZERO { @@ -158,13 +157,16 @@ pub fn generate(lv: &mut CpuColumnsView) { // This assert is not _strictly_ necessary, but I include it as a sanity check. assert_eq!(cycle_filter, F::ONE, "cycle_filter should be 0 or 1"); - let opcode = lv.opcode.to_canonical_u64(); - assert!(opcode < 256, "opcode should be in {{0, ..., 255}}"); - let opcode = opcode as u8; - - for (i, bit) in lv.opcode_bits.iter_mut().enumerate() { - *bit = F::from_bool(opcode & (1 << i) != 0); + // Validate all opcode bits. + for bit in lv.opcode_bits.into_iter() { + assert!(bit.to_canonical_u64() <= 1); } + let opcode = lv + .opcode_bits + .into_iter() + .enumerate() + .map(|(i, bit)| bit.to_canonical_u64() << i) + .sum::() as u8; let top_bits: [u8; 9] = [ 0, @@ -182,15 +184,19 @@ pub fn generate(lv: &mut CpuColumnsView) { assert!(kernel <= 1); let kernel = kernel != 0; - for (oc, block_length, availability, col) in OPCODES { - let available = match availability { - All => true, - User => !kernel, - Kernel => kernel, - }; + let mut any_flag_set = false; + for (oc, block_length, kernel_only, col) in OPCODES { + let available = !kernel_only || kernel; let opcode_match = top_bits[8 - block_length] == oc; - lv[col] = F::from_bool(available && opcode_match); + let flag = available && opcode_match; + lv[col] = F::from_bool(flag); + if flag && any_flag_set { + panic!("opcode matched multiple flags"); + } + any_flag_set = any_flag_set || flag; } + // is_invalid is a catch-all for opcodes we can't decode. + lv.is_invalid = F::from_bool(!any_flag_set); } /// Break up an opcode (which is 8 bits long) into its eight bits. @@ -217,23 +223,10 @@ pub fn eval_packed_generic( let kernel_mode = lv.is_kernel_mode; yield_constr.constraint(cycle_filter * kernel_mode * (kernel_mode - P::ONES)); - // Ensure that the opcode bits are valid: each has to be either 0 or 1, and they must match - // the opcode. Note that this also implicitly range-checks the opcode. - let bits = lv.opcode_bits; - // First check that the bits are either 0 or 1. - for bit in bits { + // Ensure that the opcode bits are valid: each has to be either 0 or 1. + for bit in lv.opcode_bits { yield_constr.constraint(cycle_filter * bit * (bit - P::ONES)); } - // Now check that they match the opcode. - { - let opcode = lv.opcode; - let reconstructed_opcode: P = bits - .into_iter() - .enumerate() - .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) - .sum(); - yield_constr.constraint(cycle_filter * (opcode - reconstructed_opcode)); - } // Check that the instruction flags are valid. // First, check that they are all either 0 or 1. @@ -241,24 +234,26 @@ pub fn eval_packed_generic( let flag = lv[flag_col]; yield_constr.constraint(cycle_filter * flag * (flag - P::ONES)); } + yield_constr.constraint(cycle_filter * lv.is_invalid * (lv.is_invalid - P::ONES)); // Now check that exactly one is 1. let flag_sum: P = OPCODES .into_iter() .map(|(_, _, _, flag_col)| lv[flag_col]) - .sum(); + .sum::

() + + lv.is_invalid; yield_constr.constraint(cycle_filter * (P::ONES - flag_sum)); // Finally, classify all opcodes, together with the kernel flag, into blocks - for (oc, block_length, availability, col) in OPCODES { - // 0 if the block/flag is available to us (is always available, is user-only and we are in - // user mode, or kernel-only and we are in kernel mode) and 1 otherwise. - let unavailable = match availability { - All => P::ZEROS, - User => kernel_mode, - Kernel => P::ONES - kernel_mode, + for (oc, block_length, kernel_only, col) in OPCODES { + // 0 if the block/flag is available to us (is always available or we are in kernel mode) and + // 1 otherwise. + let unavailable = match kernel_only { + false => P::ZEROS, + true => P::ONES - kernel_mode, }; // 0 if all the opcode bits match, and something in {1, ..., 8}, otherwise. - let opcode_mismatch: P = bits + let opcode_mismatch: P = lv + .opcode_bits .into_iter() .zip(bits_from_opcode(oc)) .rev() @@ -294,28 +289,12 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint(builder, constr); } - // Ensure that the opcode bits are valid: each has to be either 0 or 1, and they must match - // the opcode. Note that this also implicitly range-checks the opcode. - let bits = lv.opcode_bits; - // First check that the bits are either 0 or 1. - for bit in bits { + // Ensure that the opcode bits are valid: each has to be either 0 or 1. + for bit in lv.opcode_bits { let constr = builder.mul_sub_extension(bit, bit, bit); let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } - // Now check that they match the opcode. - { - let opcode = lv.opcode; - let reconstructed_opcode = - bits.into_iter() - .enumerate() - .fold(builder.zero_extension(), |cumul, (i, bit)| { - builder.mul_const_add_extension(F::from_canonical_u64(1 << i), bit, cumul) - }); - let diff = builder.sub_extension(opcode, reconstructed_opcode); - let constr = builder.mul_extension(cycle_filter, diff); - yield_constr.constraint(builder, constr); - } // Check that the instruction flags are valid. // First, check that they are all either 0 or 1. @@ -325,6 +304,11 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } + { + let constr = builder.mul_sub_extension(lv.is_invalid, lv.is_invalid, lv.is_invalid); + let constr = builder.mul_extension(cycle_filter, constr); + yield_constr.constraint(builder, constr); + } // Now check that exactly one is 1. { let mut constr = builder.one_extension(); @@ -332,21 +316,22 @@ pub fn eval_ext_circuit, const D: usize>( let flag = lv[flag_col]; constr = builder.sub_extension(constr, flag); } + constr = builder.sub_extension(constr, lv.is_invalid); constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } // Finally, classify all opcodes, together with the kernel flag, into blocks - for (oc, block_length, availability, col) in OPCODES { - // 0 if the block/flag is available to us (is always available, is user-only and we are in - // user mode, or kernel-only and we are in kernel mode) and 1 otherwise. - let unavailable = match availability { - All => builder.zero_extension(), - User => kernel_mode, - Kernel => builder.sub_extension(one, kernel_mode), + for (oc, block_length, kernel_only, col) in OPCODES { + // 0 if the block/flag is available to us (is always available or we are in kernel mode) and + // 1 otherwise. + let unavailable = match kernel_only { + false => builder.zero_extension(), + true => builder.sub_extension(one, kernel_mode), }; // 0 if all the opcode bits match, and something in {1, ..., 8}, otherwise. - let opcode_mismatch = bits + let opcode_mismatch = lv + .opcode_bits .into_iter() .zip(bits_from_opcode(oc)) .rev() diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 10c9503a..219b39dd 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -17,16 +17,16 @@ pub fn eval_packed_exit_kernel( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let jumps_lv = lv.general.jumps(); + let input = lv.mem_channels[0].value; // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the // kernel to set them to zero). yield_constr.constraint_transition( - lv.is_cpu_cycle * lv.is_exit_kernel * (jumps_lv.input0[0] - nv.program_counter), + lv.is_cpu_cycle * lv.is_exit_kernel * (input[0] - nv.program_counter), ); yield_constr.constraint_transition( - lv.is_cpu_cycle * lv.is_exit_kernel * (jumps_lv.input0[1] - nv.is_kernel_mode), + lv.is_cpu_cycle * lv.is_exit_kernel * (input[1] - nv.is_kernel_mode), ); } @@ -36,18 +36,18 @@ pub fn eval_ext_circuit_exit_kernel, const D: usize nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let jumps_lv = lv.general.jumps(); + let input = lv.mem_channels[0].value; let filter = builder.mul_extension(lv.is_cpu_cycle, lv.is_exit_kernel); // If we are executing `EXIT_KERNEL` then we simply restore the program counter and kernel mode // flag. The top 6 (32-bit) limbs are ignored (this is not part of the spec, but we trust the // kernel to set them to zero). - let pc_constr = builder.sub_extension(jumps_lv.input0[0], nv.program_counter); + let pc_constr = builder.sub_extension(input[0], nv.program_counter); let pc_constr = builder.mul_extension(filter, pc_constr); yield_constr.constraint_transition(builder, pc_constr); - let kernel_constr = builder.sub_extension(jumps_lv.input0[1], nv.is_kernel_mode); + let kernel_constr = builder.sub_extension(input[1], nv.is_kernel_mode); let kernel_constr = builder.mul_extension(filter, kernel_constr); yield_constr.constraint_transition(builder, kernel_constr); } @@ -58,12 +58,14 @@ pub fn eval_packed_jump_jumpi( yield_constr: &mut ConstraintConsumer

, ) { let jumps_lv = lv.general.jumps(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; let filter = lv.is_jump + lv.is_jumpi; // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. - yield_constr.constraint(lv.is_jump * (jumps_lv.input1[0] - P::ONES)); - for &limb in &jumps_lv.input1[1..] { + yield_constr.constraint(lv.is_jump * (input1[0] - P::ONES)); + for &limb in &input1[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum // `input1[0] + ... + input1[7]` cannot overflow. @@ -75,7 +77,7 @@ pub fn eval_packed_jump_jumpi( yield_constr .constraint(filter * jumps_lv.input0_upper_zero * (jumps_lv.input0_upper_zero - P::ONES)); // The below sum cannot overflow due to the limb size. - let input0_upper_sum: P = jumps_lv.input0[1..].iter().copied().sum(); + let input0_upper_sum: P = input0[1..].iter().copied().sum(); // `input0_upper_zero` = 1 implies `input0_upper_sum` = 0. yield_constr.constraint(filter * jumps_lv.input0_upper_zero * input0_upper_sum); // `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can only @@ -113,7 +115,7 @@ pub fn eval_packed_jump_jumpi( // Validate `should_continue` // This sum cannot overflow (due to limb size). - let input1_sum: P = jumps_lv.input1.into_iter().sum(); + let input1_sum: P = input1.into_iter().sum(); // `should_continue` = 1 implies `input1_sum` = 0. yield_constr.constraint(filter * jumps_lv.should_continue * input1_sum); // `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if @@ -147,9 +149,8 @@ pub fn eval_packed_jump_jumpi( yield_constr.constraint_transition( filter * jumps_lv.should_continue * (nv.program_counter - lv.program_counter - P::ONES), ); - yield_constr.constraint_transition( - filter * jumps_lv.should_jump * (nv.program_counter - jumps_lv.input0[0]), - ); + yield_constr + .constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - input0[0])); } pub fn eval_ext_circuit_jump_jumpi, const D: usize>( @@ -159,15 +160,17 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> yield_constr: &mut RecursiveConstraintConsumer, ) { let jumps_lv = lv.general.jumps(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; let filter = builder.add_extension(lv.is_jump, lv.is_jumpi); // `JUMP` or `JUMPI` // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. { - let constr = builder.mul_sub_extension(lv.is_jump, jumps_lv.input1[0], lv.is_jump); + let constr = builder.mul_sub_extension(lv.is_jump, input1[0], lv.is_jump); yield_constr.constraint(builder, constr); } - for &limb in &jumps_lv.input1[1..] { + for &limb in &input1[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum // `input1[0] + ... + input1[7]` cannot overflow. @@ -188,7 +191,7 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> } { // The below sum cannot overflow due to the limb size. - let input0_upper_sum = builder.add_many_extension(jumps_lv.input0[1..].iter()); + let input0_upper_sum = builder.add_many_extension(input0[1..].iter()); // `input0_upper_zero` = 1 implies `input0_upper_sum` = 0. let constr = builder.mul_extension(jumps_lv.input0_upper_zero, input0_upper_sum); @@ -251,7 +254,7 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> // Validate `should_continue` { // This sum cannot overflow (due to limb size). - let input1_sum = builder.add_many_extension(jumps_lv.input1.into_iter()); + let input1_sum = builder.add_many_extension(input1.into_iter()); // `should_continue` = 1 implies `input1_sum` = 0. let constr = builder.mul_extension(jumps_lv.should_continue, input1_sum); @@ -326,7 +329,7 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> } // ...or jumping. { - let constr = builder.sub_extension(nv.program_counter, jumps_lv.input0[0]); + let constr = builder.sub_extension(nv.program_counter, input0[0]); let constr = builder.mul_extension(jumps_lv.should_jump, constr); let constr = builder.mul_extension(filter, constr); yield_constr.constraint_transition(builder, constr); diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index eb55238b..dda006e6 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -15,6 +15,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/core/create.asm"), include_str!("asm/core/create_addresses.asm"), include_str!("asm/core/intrinsic_gas.asm"), + include_str!("asm/core/invalid.asm"), include_str!("asm/core/nonce.asm"), include_str!("asm/core/process_txn.asm"), include_str!("asm/core/terminate.asm"), diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 3cbbb441..1b8a535f 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -2,21 +2,21 @@ // Creates a new sub context and executes the code of the given account. global call: - // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size, retdest %address %stack (self, gas, address, value) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (1, value, 0, gas, self, address, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (0, 1, value, self, address, address, gas) %jump(call_common) // Creates a new sub context as if calling itself, but with the code of the // given account. In particular the storage remains the same. global call_code: - // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size, retdest %address %stack (self, gas, address, value) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (1, value, 0, gas, self, self, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (0, 1, value, self, self, address, gas) %jump(call_common) // Creates a new sub context and executes the code of the given account. @@ -25,35 +25,86 @@ global call_code: // are CREATE, CREATE2, LOG0, LOG1, LOG2, LOG3, LOG4, SSTORE, SELFDESTRUCT and // CALL if the value sent is not 0. global static_all: - // stack: gas, address, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, args_offset, args_size, ret_offset, ret_size, retdest %address %stack (self, gas, address) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (0, 0, 1, gas, self, address, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (1, 0, 0, self, address, address, gas) %jump(call_common) // Creates a new sub context as if calling itself, but with the code of the // given account. In particular the storage, the current sender and the current // value remain the same. global delegate_call: - // stack: gas, address, args_offset, args_size, ret_offset, ret_size + // stack: gas, address, args_offset, args_size, ret_offset, ret_size, retdest %address %sender %callvalue %stack (self, sender, value, gas, address) - // These are (should_transfer_value, value, static, gas, sender, storage, code_addr) - -> (0, value, 0, gas, sender, self, address) + // These are (static, should_transfer_value, value, sender, address, code_addr, gas) + -> (0, 0, value, sender, self, address, gas) %jump(call_common) call_common: - // stack: should_transfer_value, value, static, gas, sender, storage, code_addr, args_offset, args_size, ret_offset, ret_size - // TODO: Set all the appropriate metadata fields... + // stack: static, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest %create_context - // stack: new_ctx, after_call + // Store the static flag in metadata. + %stack (new_ctx, static) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_STATIC, static, new_ctx) + MSTORE_GENERAL + // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store the address in metadata. + %stack (new_ctx, should_transfer_value, value, sender, address) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_ADDRESS, address, + new_ctx, should_transfer_value, value, sender, address) + MSTORE_GENERAL + // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store the caller in metadata. + %stack (new_ctx, should_transfer_value, value, sender) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALLER, sender, + new_ctx, should_transfer_value, value, sender) + MSTORE_GENERAL + // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store the call value field in metadata. + %stack (new_ctx, should_transfer_value, value, sender, address) = + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALL_VALUE, value, + should_transfer_value, sender, address, value, new_ctx) + MSTORE_GENERAL + // stack: should_transfer_value, sender, address, value, new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + %maybe_transfer_eth + // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store parent context in metadata. + GET_CONTEXT + PUSH @CTX_METADATA_PARENT_CONTEXT + PUSH @SEGMENT_CONTEXT_METADATA + DUP4 // new_ctx + MSTORE_GENERAL + // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // Store parent PC = after_call. + %stack (new_ctx) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_PARENT_PC, after_call, new_ctx) + MSTORE_GENERAL + // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + + // TODO: Populate CALLDATA + // TODO: Save parent gas and set child gas + // TODO: Populate code + + // TODO: Temporary, remove after above steps are done. + %stack (new_ctx, code_addr, gas, args_offset, args_size) -> (new_ctx) + // stack: new_ctx, ret_offset, ret_size, retdest + // Now, switch to the new context and go to usermode with PC=0. + DUP1 // new_ctx SET_CONTEXT - PUSH 0 + PUSH 0 // jump dest EXIT_KERNEL after_call: - // TODO: Set RETURNDATA etc. + // stack: new_ctx, ret_offset, ret_size, retdest + // TODO: Set RETURNDATA. + // TODO: Return to caller w/ EXIT_KERNEL. diff --git a/evm/src/cpu/kernel/asm/core/invalid.asm b/evm/src/cpu/kernel/asm/core/invalid.asm new file mode 100644 index 00000000..6a7f4c17 --- /dev/null +++ b/evm/src/cpu/kernel/asm/core/invalid.asm @@ -0,0 +1,26 @@ +global handle_invalid: + // stack: trap_info + + // if the kernel is trying to execute an invalid instruction, then we've already screwed up and + // there's no chance of getting a useful proof, so we just panic + DUP1 + // stack: trap_info, trap_info + %shr_const(32) + // stack: is_kernel, trap_info + %jumpi(panic) + + // check if the opcode that triggered this trap is _actually_ invalid + // stack: program_counter (is_kernel == 0, so trap_info == program_counter) + %mload_current_code + // stack: opcode + PUSH @INVALID_OPCODES_USER + // stack: invalid_opcodes_user, opcode + SWAP1 + // stack: opcode, invalid_opcodes_user + SHR + %and_const(1) + // stack: opcode_is_invalid + // if the opcode is indeed invalid, then perform an exceptional exit + %jumpi(fault_exception) + // otherwise, panic because this trap should not have been entered + PANIC diff --git a/evm/src/cpu/kernel/asm/core/transfer.asm b/evm/src/cpu/kernel/asm/core/transfer.asm index 41057aff..0ed48f4d 100644 --- a/evm/src/cpu/kernel/asm/core/transfer.asm +++ b/evm/src/cpu/kernel/asm/core/transfer.asm @@ -14,3 +14,15 @@ global transfer_eth: %jump(transfer_eth) %%after: %endmacro + +// Pre stack: should_transfer, from, to, amount +// Post stack: (empty) +%macro maybe_transfer_eth + %jumpi(%%transfer) + // We're skipping the transfer, so just pop the arguments and return. + %pop3 + %jump(%%after) +%%transfer: + %transfer_eth +%%after: +%endmacro diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm index 15f9df05..dda82109 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_add.asm @@ -9,7 +9,6 @@ global ec_add: // PUSH 1 // PUSH 0x1bf9384aa3f0b3ad763aee81940cacdde1af71617c06f46e11510f14f3d5d121 // PUSH 0xe7313274bb29566ff0c8220eb9841de1d96c2923c6a4028f7dd3c6a14cee770 - JUMPDEST // stack: x0, y0, x1, y1, retdest // Check if points are valid BN254 points. @@ -38,7 +37,6 @@ global ec_add: // BN254 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points. global ec_add_valid_points: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Check if the first point is the identity. @@ -92,7 +90,6 @@ global ec_add_valid_points: // BN254 elliptic curve addition. // Assumption: (x0,y0) == (0,0) ec_add_first_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) %stack (x0, y0, x1, y1, retdest) -> (retdest, x1, y1) @@ -101,7 +98,6 @@ ec_add_first_zero: // BN254 elliptic curve addition. // Assumption: (x1,y1) == (0,0) ec_add_snd_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x0,y0) @@ -111,7 +107,6 @@ ec_add_snd_zero: // BN254 elliptic curve addition. // Assumption: lambda = (y0 - y1)/(x0 - x1) ec_add_valid_points_with_lambda: - JUMPDEST // stack: lambda, x0, y0, x1, y1, retdest // Compute x2 = lambda^2 - x1 - x0 @@ -159,7 +154,6 @@ ec_add_valid_points_with_lambda: // BN254 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points and x0 == x1 ec_add_equal_first_coord: - JUMPDEST // stack: x0, y0, x1, y1, retdest with x0 == x1 // Check if the points are equal @@ -188,7 +182,6 @@ ec_add_equal_first_coord: // Assumption: x0 == x1 and y0 == y1 // Standard doubling formula. ec_add_equal_points: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Compute lambda = 3/2 * x0^2 / y0 @@ -216,7 +209,6 @@ ec_add_equal_points: // Assumption: (x0,y0) is a valid point. // Standard doubling formula. global ec_double: - JUMPDEST // stack: x0, y0, retdest DUP2 // stack: y0, x0, y0, retdest diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm index 62cf2235..b1472812 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_mul.asm @@ -6,7 +6,6 @@ global ec_mul: // PUSH 0xd // PUSH 2 // PUSH 1 - JUMPDEST // stack: x, y, s, retdest DUP2 // stack: y, x, y, s, retdest @@ -29,7 +28,6 @@ global ec_mul: // Same algorithm as in `exp.asm` ec_mul_valid_point: - JUMPDEST // stack: x, y, s, retdest DUP3 // stack: s, x, y, s, retdest @@ -38,7 +36,6 @@ ec_mul_valid_point: %jump(ret_zero_ec_mul) step_case: - JUMPDEST // stack: x, y, s, retdest PUSH recursion_return // stack: recursion_return, x, y, s, retdest @@ -58,12 +55,10 @@ step_case: // Assumption: 2(x,y) = (x',y') step_case_contd: - JUMPDEST // stack: x', y', s / 2, recursion_return, x, y, s, retdest %jump(ec_mul_valid_point) recursion_return: - JUMPDEST // stack: x', y', x, y, s, retdest SWAP4 // stack: s, y', x, y, x', retdest @@ -96,6 +91,5 @@ recursion_return: JUMP odd_scalar: - JUMPDEST // stack: x', y', x, y, retdest %jump(ec_add_valid_points) diff --git a/evm/src/cpu/kernel/asm/curve/common.asm b/evm/src/cpu/kernel/asm/curve/common.asm index 107dc63c..9e273c15 100644 --- a/evm/src/cpu/kernel/asm/curve/common.asm +++ b/evm/src/cpu/kernel/asm/curve/common.asm @@ -1,5 +1,4 @@ global ret_zero_ec_mul: - JUMPDEST // stack: x, y, s, retdest %pop3 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm index 7f9c1fff..790fb116 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_add.asm @@ -3,7 +3,6 @@ // Secp256k1 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points. global ec_add_valid_points_secp: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Check if the first point is the identity. @@ -57,7 +56,6 @@ global ec_add_valid_points_secp: // Secp256k1 elliptic curve addition. // Assumption: (x0,y0) == (0,0) ec_add_first_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) @@ -72,7 +70,6 @@ ec_add_first_zero: // Secp256k1 elliptic curve addition. // Assumption: (x1,y1) == (0,0) ec_add_snd_zero: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Just return (x1,y1) @@ -93,7 +90,6 @@ ec_add_snd_zero: // Secp256k1 elliptic curve addition. // Assumption: lambda = (y0 - y1)/(x0 - x1) ec_add_valid_points_with_lambda: - JUMPDEST // stack: lambda, x0, y0, x1, y1, retdest // Compute x2 = lambda^2 - x1 - x0 @@ -150,7 +146,6 @@ ec_add_valid_points_with_lambda: // Secp256k1 elliptic curve addition. // Assumption: (x0,y0) and (x1,y1) are valid points and x0 == x1 ec_add_equal_first_coord: - JUMPDEST // stack: x0, y0, x1, y1, retdest with x0 == x1 // Check if the points are equal @@ -179,7 +174,6 @@ ec_add_equal_first_coord: // Assumption: x0 == x1 and y0 == y1 // Standard doubling formula. ec_add_equal_points: - JUMPDEST // stack: x0, y0, x1, y1, retdest // Compute lambda = 3/2 * x0^2 / y0 @@ -207,7 +201,6 @@ ec_add_equal_points: // Assumption: (x0,y0) is a valid point. // Standard doubling formula. global ec_double_secp: - JUMPDEST // stack: x0, y0, retdest DUP2 // stack: y0, x0, y0, retdest diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm index f0825e88..892d57c0 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/curve_mul.asm @@ -1,6 +1,5 @@ // Same algorithm as in `exp.asm` global ec_mul_valid_point_secp: - JUMPDEST // stack: x, y, s, retdest %stack (x,y) -> (x,y,x,y) %ec_isidentity @@ -13,7 +12,6 @@ global ec_mul_valid_point_secp: %jump(ret_zero_ec_mul) step_case: - JUMPDEST // stack: x, y, s, retdest PUSH recursion_return // stack: recursion_return, x, y, s, retdest @@ -33,12 +31,10 @@ step_case: // Assumption: 2(x,y) = (x',y') step_case_contd: - JUMPDEST // stack: x', y', s / 2, recursion_return, x, y, s, retdest %jump(ec_mul_valid_point_secp) recursion_return: - JUMPDEST // stack: x', y', x, y, s, retdest SWAP4 // stack: s, y', x, y, x', retdest @@ -71,6 +67,5 @@ recursion_return: JUMP odd_scalar: - JUMPDEST // stack: x', y', x, y, retdest %jump(ec_add_valid_points_secp) diff --git a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm index 538a86dc..96e177ff 100644 --- a/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm +++ b/evm/src/cpu/kernel/asm/curve/secp256k1/ecrecover.asm @@ -1,6 +1,5 @@ // ecrecover precompile. global ecrecover: - JUMPDEST // stack: hash, v, r, s, retdest // Check if inputs are valid. @@ -47,7 +46,6 @@ global ecrecover: // let u2 = -hash * r_inv; // return u1*P + u2*GENERATOR; ecrecover_valid_input: - JUMPDEST // stack: hash, y, r, s, retdest // Compute u1 = s * r^(-1) @@ -83,7 +81,6 @@ ecrecover_valid_input: // ecrecover precompile. // Assumption: (X,Y) = u1 * P. Result is (X,Y) + u2*GENERATOR ecrecover_with_first_point: - JUMPDEST // stack: X, Y, hash, r^(-1), retdest %secp_scalar // stack: p, X, Y, hash, r^(-1), retdest @@ -132,7 +129,6 @@ ecrecover_with_first_point: // Take a public key (PKx, PKy) and return the associated address KECCAK256(PKx || PKy)[-20:]. pubkey_to_addr: - JUMPDEST // stack: PKx, PKy, retdest PUSH 0 // stack: 0, PKx, PKy, retdest diff --git a/evm/src/cpu/kernel/asm/exp.asm b/evm/src/cpu/kernel/asm/exp.asm index 3640b2f6..f025e312 100644 --- a/evm/src/cpu/kernel/asm/exp.asm +++ b/evm/src/cpu/kernel/asm/exp.asm @@ -10,7 +10,6 @@ /// Note that this correctly handles exp(0, 0) == 1. global exp: - jumpdest // stack: x, e, retdest dup2 // stack: e, x, e, retdest @@ -27,7 +26,6 @@ global exp: jump step_case: - jumpdest // stack: x, e, retdest push recursion_return // stack: recursion_return, x, e, retdest @@ -43,7 +41,6 @@ step_case: // stack: x * x, e / 2, recursion_return, x, e, retdest %jump(exp) recursion_return: - jumpdest // stack: exp(x * x, e / 2), x, e, retdest push 2 // stack: 2, exp(x * x, e / 2), x, e, retdest diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index 2c896345..73bafbee 100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -26,6 +26,13 @@ // stack: (empty) %endmacro +// Load a single byte from user code. +%macro mload_current_code + // stack: offset + %mload_current(@SEGMENT_CODE) + // stack: value +%endmacro + // Load a single value from the given segment of kernel (context 0) memory. %macro mload_kernel(segment) // stack: offset diff --git a/evm/src/cpu/kernel/asm/memory/memcpy.asm b/evm/src/cpu/kernel/asm/memory/memcpy.asm index 0a390736..3feca35d 100644 --- a/evm/src/cpu/kernel/asm/memory/memcpy.asm +++ b/evm/src/cpu/kernel/asm/memory/memcpy.asm @@ -4,7 +4,6 @@ // DST = (dst_ctx, dst_segment, dst_addr). // These tuple definitions are used for brevity in the stack comments below. global memcpy: - JUMPDEST // stack: DST, SRC, count, retdest DUP7 // stack: count, DST, SRC, count, retdest @@ -44,7 +43,6 @@ global memcpy: %jump(memcpy) memcpy_finish: - JUMPDEST // stack: DST, SRC, count, retdest %pop7 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/rlp/decode.asm b/evm/src/cpu/kernel/asm/rlp/decode.asm index 0388276a..5749aee7 100644 --- a/evm/src/cpu/kernel/asm/rlp/decode.asm +++ b/evm/src/cpu/kernel/asm/rlp/decode.asm @@ -12,7 +12,6 @@ // Pre stack: pos, retdest // Post stack: pos', len global decode_rlp_string_len: - JUMPDEST // stack: pos, retdest DUP1 %mload_current(@SEGMENT_RLP_RAW) @@ -32,7 +31,6 @@ global decode_rlp_string_len: JUMP decode_rlp_string_len_medium: - JUMPDEST // String is 0-55 bytes long. First byte contains the len. // stack: first_byte, pos, retdest %sub_const(0x80) @@ -44,7 +42,6 @@ decode_rlp_string_len_medium: JUMP decode_rlp_string_len_large: - JUMPDEST // String is >55 bytes long. First byte contains the len of the len. // stack: first_byte, pos, retdest %sub_const(0xb7) @@ -69,7 +66,6 @@ decode_rlp_string_len_large: // bytes, so that the result can be returned as a single word on the stack. // As per the spec, scalars must not have leading zeros. global decode_rlp_scalar: - JUMPDEST // stack: pos, retdest PUSH decode_int_given_len // stack: decode_int_given_len, pos, retdest @@ -91,7 +87,6 @@ global decode_rlp_scalar: // Pre stack: pos, retdest // Post stack: pos', len global decode_rlp_list_len: - JUMPDEST // stack: pos, retdest DUP1 %mload_current(@SEGMENT_RLP_RAW) @@ -116,7 +111,6 @@ global decode_rlp_list_len: JUMP decode_rlp_list_len_big: - JUMPDEST // The length of the length is first_byte - 0xf7. // stack: first_byte, pos', retdest %sub_const(0xf7) @@ -137,7 +131,6 @@ decode_rlp_list_len_big: // Pre stack: pos, len, retdest // Post stack: pos', int decode_int_given_len: - JUMPDEST %stack (pos, len, retdest) -> (pos, len, pos, retdest) ADD // stack: end_pos, pos, retdest @@ -147,7 +140,6 @@ decode_int_given_len: // stack: acc, pos, end_pos, retdest decode_int_given_len_loop: - JUMPDEST // stack: acc, pos, end_pos, retdest DUP3 DUP3 @@ -171,6 +163,5 @@ decode_int_given_len_loop: %jump(decode_int_given_len_loop) decode_int_given_len_finish: - JUMPDEST %stack (acc, pos, end_pos, retdest) -> (retdest, pos, acc) JUMP diff --git a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm index ae75e3d7..db474b9b 100644 --- a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm +++ b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm @@ -5,7 +5,6 @@ // Post stack: (empty) global read_rlp_to_memory: - JUMPDEST // stack: retdest PROVER_INPUT // Read the RLP blob length from the prover tape. // stack: len, retdest @@ -13,7 +12,6 @@ global read_rlp_to_memory: // stack: pos, len, retdest read_rlp_to_memory_loop: - JUMPDEST // stack: pos, len, retdest DUP2 DUP2 @@ -32,7 +30,6 @@ read_rlp_to_memory_loop: %jump(read_rlp_to_memory_loop) read_rlp_to_memory_finish: - JUMPDEST // stack: pos, len, retdest %pop2 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/transactions/router.asm b/evm/src/cpu/kernel/asm/transactions/router.asm index 01a65fec..47a899c9 100644 --- a/evm/src/cpu/kernel/asm/transactions/router.asm +++ b/evm/src/cpu/kernel/asm/transactions/router.asm @@ -3,7 +3,6 @@ // jump to the appropriate transaction parsing method. global route_txn: - JUMPDEST // stack: (empty) // First load transaction data into memory, where it will be parsed. PUSH read_txn_from_memory @@ -11,7 +10,6 @@ global route_txn: // At this point, the raw txn data is in memory. read_txn_from_memory: - JUMPDEST // stack: (empty) // We will peak at the first byte to determine what type of transaction this is. diff --git a/evm/src/cpu/kernel/asm/transactions/type_0.asm b/evm/src/cpu/kernel/asm/transactions/type_0.asm index 7c8488f7..3f258624 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_0.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_0.asm @@ -12,7 +12,6 @@ // keccak256(rlp([nonce, gas_price, gas_limit, to, value, data])) global process_type_0_txn: - JUMPDEST // stack: (empty) PUSH 0 // initial pos // stack: pos diff --git a/evm/src/cpu/kernel/asm/transactions/type_1.asm b/evm/src/cpu/kernel/asm/transactions/type_1.asm index 5b9d2cdf..9d45c1e4 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_1.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_1.asm @@ -7,6 +7,5 @@ // data, access_list])) global process_type_1_txn: - JUMPDEST // stack: (empty) PANIC // TODO: Unfinished diff --git a/evm/src/cpu/kernel/asm/transactions/type_2.asm b/evm/src/cpu/kernel/asm/transactions/type_2.asm index 9807f88f..b2a862c1 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_2.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_2.asm @@ -8,6 +8,5 @@ // access_list])) global process_type_2_txn: - JUMPDEST // stack: (empty) PANIC // TODO: Unfinished diff --git a/evm/src/cpu/kernel/asm/util/assertions.asm b/evm/src/cpu/kernel/asm/util/assertions.asm index 69193e5f..0051219c 100644 --- a/evm/src/cpu/kernel/asm/util/assertions.asm +++ b/evm/src/cpu/kernel/asm/util/assertions.asm @@ -1,7 +1,6 @@ // It is convenient to have a single panic routine, which we can jump to from // anywhere. global panic: - JUMPDEST PANIC // Consumes the top element and asserts that it is zero. diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index 56001dc1..d1b7bff3 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -44,6 +44,13 @@ %endrep %endmacro +%macro and_const(c) + // stack: input, ... + PUSH $c + AND + // stack: input & c, ... +%endmacro + %macro add_const(c) // stack: input, ... PUSH $c @@ -101,6 +108,13 @@ // stack: input << c, ... %endmacro +%macro shr_const(c) + // stack: input, ... + PUSH $c + SHR + // stack: input >> c, ... +%endmacro + %macro eq_const(c) // stack: input, ... PUSH $c diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index f5175c41..0471bf99 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -551,6 +551,7 @@ mod tests { let dup1 = get_opcode("DUP1"); let swap1 = get_opcode("SWAP1"); let swap2 = get_opcode("SWAP2"); + let swap3 = get_opcode("SWAP3"); let push_label = get_push_opcode(BYTES_PER_OFFSET); let kernel = parse_and_assemble(&["%stack (a) -> (a)"]); @@ -562,6 +563,17 @@ mod tests { let kernel = parse_and_assemble(&["%stack (a, b, c) -> (b)"]); assert_eq!(kernel.code, vec![pop, swap1, pop]); + let kernel = parse_and_assemble(&["%stack (a, b: 3, c) -> (c)"]); + assert_eq!(kernel.code, vec![pop, pop, pop, pop]); + + let kernel = parse_and_assemble(&["%stack (a: 2, b: 2) -> (b, a)"]); + assert_eq!(kernel.code, vec![swap1, swap3, swap1, swap2]); + + let kernel1 = parse_and_assemble(&["%stack (a: 3, b: 3, c) -> (c, b, a)"]); + let kernel2 = + parse_and_assemble(&["%stack (a, b, c, d, e, f, g) -> (g, d, e, f, a, b, c)"]); + assert_eq!(kernel1.code, kernel2.code); + let mut consts = HashMap::new(); consts.insert("LIFE".into(), 42.into()); parse_and_assemble_ext(&["%stack (a, b) -> (b, @LIFE)"], consts, true); diff --git a/evm/src/cpu/kernel/ast.rs b/evm/src/cpu/kernel/ast.rs index 24cf01e1..bad60d03 100644 --- a/evm/src/cpu/kernel/ast.rs +++ b/evm/src/cpu/kernel/ast.rs @@ -19,7 +19,7 @@ pub(crate) enum Item { /// The first list gives names to items on the top of the stack. /// The second list specifies replacement items. /// Example: `(a, b, c) -> (c, 5, 0x20, @SOME_CONST, a)`. - StackManipulation(Vec, Vec), + StackManipulation(Vec, Vec), /// Declares a global label. GlobalLabelDeclaration(String), /// Declares a label that is local to the current file. @@ -36,6 +36,14 @@ pub(crate) enum Item { Bytes(Vec), } +/// The left hand side of a %stack stack-manipulation macro. +#[derive(Eq, PartialEq, Clone, Debug)] +pub(crate) enum StackPlaceholder { + Identifier(String), + Block(String, usize), +} + +/// The right hand side of a %stack stack-manipulation macro. #[derive(Eq, PartialEq, Clone, Debug)] pub(crate) enum StackReplacement { /// Can be either a named item or a label. diff --git a/evm/src/cpu/kernel/constants.rs b/evm/src/cpu/kernel/constants.rs index 5bc5908e..98fe57c6 100644 --- a/evm/src/cpu/kernel/constants.rs +++ b/evm/src/cpu/kernel/constants.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use ethereum_types::U256; use hex_literal::hex; +use crate::cpu::decode::invalid_opcodes_user; use crate::cpu::kernel::context_metadata::ContextMetadata; use crate::cpu::kernel::global_metadata::GlobalMetadata; use crate::cpu::kernel::txn_fields::NormalizedTxnField; @@ -29,6 +30,10 @@ pub fn evm_constants() -> HashMap { for txn_field in ContextMetadata::all() { c.insert(txn_field.var_name().into(), (txn_field as u32).into()); } + c.insert( + "INVALID_OPCODES_USER".into(), + U256::from_little_endian(&invalid_opcodes_user()), + ); c } diff --git a/evm/src/cpu/kernel/context_metadata.rs b/evm/src/cpu/kernel/context_metadata.rs index 26bd541f..17945d98 100644 --- a/evm/src/cpu/kernel/context_metadata.rs +++ b/evm/src/cpu/kernel/context_metadata.rs @@ -20,10 +20,13 @@ pub(crate) enum ContextMetadata { /// Whether this context was created by `STATICCALL`, in which case state changes are /// prohibited. Static = 8, + /// Pointer to the initial version of the state trie, at the creation of this context. Used when + /// we need to revert a context. See also `StorageTrieCheckpointPointers`. + StateTrieCheckpointPointer = 9, } impl ContextMetadata { - pub(crate) const COUNT: usize = 9; + pub(crate) const COUNT: usize = 10; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -36,6 +39,7 @@ impl ContextMetadata { Self::Caller, Self::CallValue, Self::Static, + Self::StateTrieCheckpointPointer, ] } @@ -51,6 +55,7 @@ impl ContextMetadata { ContextMetadata::Caller => "CTX_METADATA_CALLER", ContextMetadata::CallValue => "CTX_METADATA_CALL_VALUE", ContextMetadata::Static => "CTX_METADATA_STATIC", + ContextMetadata::StateTrieCheckpointPointer => "CTX_METADATA_STATE_TRIE_CHECKPOINT_PTR", } } } diff --git a/evm/src/cpu/kernel/evm_asm.pest b/evm/src/cpu/kernel/evm_asm.pest index 8ea7de4b..89d06e74 100644 --- a/evm/src/cpu/kernel/evm_asm.pest +++ b/evm/src/cpu/kernel/evm_asm.pest @@ -21,7 +21,10 @@ macro_call = ${ "%" ~ !(^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") repeat = { ^"%rep" ~ literal ~ item* ~ ^"%endrep" } paramlist = { "(" ~ identifier ~ ("," ~ identifier)* ~ ")" } macro_arglist = !{ "(" ~ push_target ~ ("," ~ push_target)* ~ ")" } -stack = { ^"%stack" ~ paramlist ~ "->" ~ stack_replacements } +stack = { ^"%stack" ~ stack_placeholders ~ "->" ~ stack_replacements } +stack_placeholders = { "(" ~ stack_placeholder ~ ("," ~ stack_placeholder)* ~ ")" } +stack_placeholder = { stack_block | identifier } +stack_block = { identifier ~ ":" ~ literal_decimal } stack_replacements = { "(" ~ stack_replacement ~ ("," ~ stack_replacement)* ~ ")" } stack_replacement = { literal | identifier | constant | macro_label | variable } global_label_decl = ${ ^"GLOBAL " ~ identifier ~ ":" } diff --git a/evm/src/cpu/kernel/global_metadata.rs b/evm/src/cpu/kernel/global_metadata.rs index 6343a2e6..ddc3c839 100644 --- a/evm/src/cpu/kernel/global_metadata.rs +++ b/evm/src/cpu/kernel/global_metadata.rs @@ -9,13 +9,50 @@ pub(crate) enum GlobalMetadata { Origin = 1, /// The size of active memory, in bytes. MemorySize = 2, + /// The size of the `TrieData` segment, in bytes. In other words, the next address available for + /// appending additional trie data. + TrieDataSize = 3, + /// A pointer to the root of the state trie within the `TrieData` buffer. + StateTrieRoot = 4, + /// A pointer to the root of the transaction trie within the `TrieData` buffer. + TransactionTrieRoot = 5, + /// A pointer to the root of the receipt trie within the `TrieData` buffer. + ReceiptTrieRoot = 6, + /// The number of storage tries involved in these transactions. I.e. the number of values in + /// `StorageTrieAddresses`, `StorageTriePointers` and `StorageTrieCheckpointPointers`. + NumStorageTries = 7, + + // The root digests of each Merkle trie before these transactions. + StateTrieRootDigestBefore = 8, + TransactionsTrieRootDigestBefore = 9, + ReceiptsTrieRootDigestBefore = 10, + + // The root digests of each Merkle trie after these transactions. + StateTrieRootDigestAfter = 11, + TransactionsTrieRootDigestAfter = 12, + ReceiptsTrieRootDigestAfter = 13, } impl GlobalMetadata { - pub(crate) const COUNT: usize = 3; + pub(crate) const COUNT: usize = 14; pub(crate) fn all() -> [Self; Self::COUNT] { - [Self::LargestContext, Self::Origin, Self::MemorySize] + [ + Self::LargestContext, + Self::Origin, + Self::MemorySize, + Self::TrieDataSize, + Self::StateTrieRoot, + Self::TransactionTrieRoot, + Self::ReceiptTrieRoot, + Self::NumStorageTries, + Self::StateTrieRootDigestBefore, + Self::TransactionsTrieRootDigestBefore, + Self::ReceiptsTrieRootDigestBefore, + Self::StateTrieRootDigestAfter, + Self::TransactionsTrieRootDigestAfter, + Self::ReceiptsTrieRootDigestAfter, + ] } /// The variable name that gets passed into kernel assembly code. @@ -24,6 +61,25 @@ impl GlobalMetadata { GlobalMetadata::LargestContext => "GLOBAL_METADATA_LARGEST_CONTEXT", GlobalMetadata::Origin => "GLOBAL_METADATA_ORIGIN", GlobalMetadata::MemorySize => "GLOBAL_METADATA_MEMORY_SIZE", + GlobalMetadata::TrieDataSize => "GLOBAL_METADATA_TRIE_DATA_SIZE", + GlobalMetadata::StateTrieRoot => "GLOBAL_METADATA_STATE_TRIE_ROOT", + GlobalMetadata::TransactionTrieRoot => "GLOBAL_METADATA_TXN_TRIE_ROOT", + GlobalMetadata::ReceiptTrieRoot => "GLOBAL_METADATA_RECEIPT_TRIE_ROOT", + GlobalMetadata::NumStorageTries => "GLOBAL_METADATA_NUM_STORAGE_TRIES", + GlobalMetadata::StateTrieRootDigestBefore => "GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE", + GlobalMetadata::TransactionsTrieRootDigestBefore => { + "GLOBAL_METADATA_TXNS_TRIE_DIGEST_BEFORE" + } + GlobalMetadata::ReceiptsTrieRootDigestBefore => { + "GLOBAL_METADATA_RECEIPTS_TRIE_DIGEST_BEFORE" + } + GlobalMetadata::StateTrieRootDigestAfter => "GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER", + GlobalMetadata::TransactionsTrieRootDigestAfter => { + "GLOBAL_METADATA_TXNS_TRIE_DIGEST_AFTER" + } + GlobalMetadata::ReceiptsTrieRootDigestAfter => { + "GLOBAL_METADATA_RECEIPTS_TRIE_DIGEST_AFTER" + } } } } diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 17be0523..64d70529 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -250,7 +250,7 @@ impl<'a> Interpreter<'a> { 0x58 => todo!(), // "GETPC", 0x59 => todo!(), // "MSIZE", 0x5a => todo!(), // "GAS", - 0x5b => (), // "JUMPDEST", + 0x5b => self.run_jumpdest(), // "JUMPDEST", 0x5c => todo!(), // "GET_STATE_ROOT", 0x5d => todo!(), // "SET_STATE_ROOT", 0x5e => todo!(), // "GET_RECEIPT_ROOT", @@ -490,6 +490,10 @@ impl<'a> Interpreter<'a> { } } + fn run_jumpdest(&mut self) { + assert!(!self.kernel_mode, "JUMPDEST is not needed in kernel code"); + } + fn jump_to(&mut self, offset: usize) { // The JUMPDEST rule is not enforced in kernel mode. if !self.kernel_mode && self.jumpdests.binary_search(&offset).is_err() { diff --git a/evm/src/cpu/kernel/keccak_util.rs b/evm/src/cpu/kernel/keccak_util.rs index 1498ba08..52cc0f08 100644 --- a/evm/src/cpu/kernel/keccak_util.rs +++ b/evm/src/cpu/kernel/keccak_util.rs @@ -1,3 +1,5 @@ +use tiny_keccak::keccakf; + /// A Keccak-f based hash. /// /// This hash does not use standard Keccak padding, since we don't care about extra zeros at the @@ -9,6 +11,42 @@ pub(crate) fn hash_kernel(_code: &[u8]) -> [u32; 8] { } /// Like tiny-keccak's `keccakf`, but deals with `u32` limbs instead of `u64` limbs. -pub(crate) fn keccakf_u32s(_state: &mut [u32; 50]) { - // TODO: Implement +pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; 50]) { + let mut state_u64s: [u64; 25] = std::array::from_fn(|i| { + let lo = state_u32s[i * 2] as u64; + let hi = state_u32s[i * 2 + 1] as u64; + lo | (hi << 32) + }); + keccakf(&mut state_u64s); + *state_u32s = std::array::from_fn(|i| { + let u64_limb = state_u64s[i / 2]; + let is_hi = i % 2; + (u64_limb >> (is_hi * 32)) as u32 + }); +} + +#[cfg(test)] +mod tests { + use tiny_keccak::keccakf; + + use crate::cpu::kernel::keccak_util::keccakf_u32s; + + #[test] + #[rustfmt::skip] + fn test_consistency() { + // We will hash the same data using keccakf and keccakf_u32s. + // The inputs were randomly generated in Python. + let mut state_u64s: [u64; 25] = [0x5dc43ed05dc64048, 0x7bb9e18cdc853880, 0xc1fde300665b008f, 0xeeab85e089d5e431, 0xf7d61298e9ef27ea, 0xc2c5149d1a492455, 0x37a2f4eca0c2d2f2, 0xa35e50c015b3e85c, 0xd2daeced29446ebe, 0x245845f1bac1b98e, 0x3b3aa8783f30a9bf, 0x209ca9a81956d241, 0x8b8ea714da382165, 0x6063e67e202c6d29, 0xf4bac2ded136b907, 0xb17301b461eae65, 0xa91ff0e134ed747c, 0xcc080b28d0c20f1d, 0xf0f79cbec4fb551c, 0x25e04cb0aa930cad, 0x803113d1b541a202, 0xfaf1e4e7cd23b7ec, 0x36a03bbf2469d3b0, 0x25217341908cdfc0, 0xe9cd83f88fdcd500]; + let mut state_u32s: [u32; 50] = [0x5dc64048, 0x5dc43ed0, 0xdc853880, 0x7bb9e18c, 0x665b008f, 0xc1fde300, 0x89d5e431, 0xeeab85e0, 0xe9ef27ea, 0xf7d61298, 0x1a492455, 0xc2c5149d, 0xa0c2d2f2, 0x37a2f4ec, 0x15b3e85c, 0xa35e50c0, 0x29446ebe, 0xd2daeced, 0xbac1b98e, 0x245845f1, 0x3f30a9bf, 0x3b3aa878, 0x1956d241, 0x209ca9a8, 0xda382165, 0x8b8ea714, 0x202c6d29, 0x6063e67e, 0xd136b907, 0xf4bac2de, 0x461eae65, 0xb17301b, 0x34ed747c, 0xa91ff0e1, 0xd0c20f1d, 0xcc080b28, 0xc4fb551c, 0xf0f79cbe, 0xaa930cad, 0x25e04cb0, 0xb541a202, 0x803113d1, 0xcd23b7ec, 0xfaf1e4e7, 0x2469d3b0, 0x36a03bbf, 0x908cdfc0, 0x25217341, 0x8fdcd500, 0xe9cd83f8]; + + // The first output was generated using tiny-keccak; the second was derived from it. + let out_u64s: [u64; 25] = [0x8a541df597e79a72, 0x5c26b8c84faaebb3, 0xc0e8f4e67ca50497, 0x95d98a688de12dec, 0x1c837163975ffaed, 0x9481ec7ef948900e, 0x6a072c65d050a9a1, 0x3b2817da6d615bee, 0x7ffb3c4f8b94bf21, 0x85d6c418cced4a11, 0x18edbe0442884135, 0x2bf265ef3204b7fd, 0xc1e12ce30630d105, 0x8c554dbc61844574, 0x5504db652ce9e42c, 0x2217f3294d0dabe5, 0x7df8eebbcf5b74df, 0x3a56ebb61956f501, 0x7840219dc6f37cc, 0x23194159c967947, 0x9da289bf616ba14d, 0x5a90aaeeca9e9e5b, 0x885dcdc4a549b4e3, 0x46cb188c20947df7, 0x1ef285948ee3d8ab]; + let out_u32s: [u32; 50] = [0x97e79a72, 0x8a541df5, 0x4faaebb3, 0x5c26b8c8, 0x7ca50497, 0xc0e8f4e6, 0x8de12dec, 0x95d98a68, 0x975ffaed, 0x1c837163, 0xf948900e, 0x9481ec7e, 0xd050a9a1, 0x6a072c65, 0x6d615bee, 0x3b2817da, 0x8b94bf21, 0x7ffb3c4f, 0xcced4a11, 0x85d6c418, 0x42884135, 0x18edbe04, 0x3204b7fd, 0x2bf265ef, 0x630d105, 0xc1e12ce3, 0x61844574, 0x8c554dbc, 0x2ce9e42c, 0x5504db65, 0x4d0dabe5, 0x2217f329, 0xcf5b74df, 0x7df8eebb, 0x1956f501, 0x3a56ebb6, 0xdc6f37cc, 0x7840219, 0x9c967947, 0x2319415, 0x616ba14d, 0x9da289bf, 0xca9e9e5b, 0x5a90aaee, 0xa549b4e3, 0x885dcdc4, 0x20947df7, 0x46cb188c, 0x8ee3d8ab, 0x1ef28594]; + + keccakf(&mut state_u64s); + keccakf_u32s(&mut state_u32s); + + assert_eq!(state_u64s, out_u64s); + assert_eq!(state_u32s, out_u32s); + } } diff --git a/evm/src/cpu/kernel/mod.rs b/evm/src/cpu/kernel/mod.rs index eceba813..ef5a9ba0 100644 --- a/evm/src/cpu/kernel/mod.rs +++ b/evm/src/cpu/kernel/mod.rs @@ -2,9 +2,9 @@ pub mod aggregator; pub mod assembler; mod ast; mod constants; -mod context_metadata; +pub(crate) mod context_metadata; mod cost_estimator; -mod global_metadata; +pub(crate) mod global_metadata; pub(crate) mod keccak_util; mod opcodes; mod optimizer; diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index 9ed578d4..89da016c 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -4,6 +4,7 @@ use ethereum_types::U256; use pest::iterators::Pair; use pest::Parser; +use super::ast::StackPlaceholder; use crate::cpu::kernel::ast::{File, Item, PushTarget, StackReplacement}; /// Parses EVM assembly code. @@ -89,24 +90,21 @@ fn parse_macro_call(item: Pair) -> Item { fn parse_repeat(item: Pair) -> Item { assert_eq!(item.as_rule(), Rule::repeat); - let mut inner = item.into_inner().peekable(); + let mut inner = item.into_inner(); let count = parse_literal_u256(inner.next().unwrap()); Item::Repeat(count, inner.map(parse_item).collect()) } fn parse_stack(item: Pair) -> Item { assert_eq!(item.as_rule(), Rule::stack); - let mut inner = item.into_inner().peekable(); + let mut inner = item.into_inner(); let params = inner.next().unwrap(); - assert_eq!(params.as_rule(), Rule::paramlist); + assert_eq!(params.as_rule(), Rule::stack_placeholders); let replacements = inner.next().unwrap(); assert_eq!(replacements.as_rule(), Rule::stack_replacements); - let params = params - .into_inner() - .map(|param| param.as_str().to_string()) - .collect(); + let params = params.into_inner().map(parse_stack_placeholder).collect(); let replacements = replacements .into_inner() .map(parse_stack_replacement) @@ -114,6 +112,21 @@ fn parse_stack(item: Pair) -> Item { Item::StackManipulation(params, replacements) } +fn parse_stack_placeholder(target: Pair) -> StackPlaceholder { + assert_eq!(target.as_rule(), Rule::stack_placeholder); + let inner = target.into_inner().next().unwrap(); + match inner.as_rule() { + Rule::identifier => StackPlaceholder::Identifier(inner.as_str().into()), + Rule::stack_block => { + let mut block = inner.into_inner(); + let identifier = block.next().unwrap().as_str(); + let length = block.next().unwrap().as_str().parse().unwrap(); + StackPlaceholder::Block(identifier.to_string(), length) + } + _ => panic!("Unexpected {:?}", inner.as_rule()), + } +} + fn parse_stack_replacement(target: Pair) -> StackReplacement { assert_eq!(target.as_rule(), Rule::stack_replacement); let inner = target.into_inner().next().unwrap(); diff --git a/evm/src/cpu/kernel/stack/stack_manipulation.rs b/evm/src/cpu/kernel/stack/stack_manipulation.rs index 9f685953..faec7e04 100644 --- a/evm/src/cpu/kernel/stack/stack_manipulation.rs +++ b/evm/src/cpu/kernel/stack/stack_manipulation.rs @@ -1,13 +1,13 @@ use std::cmp::Ordering; use std::collections::hash_map::Entry::{Occupied, Vacant}; -use std::collections::{BinaryHeap, HashMap}; +use std::collections::{BinaryHeap, HashMap, HashSet}; use std::hash::Hash; use itertools::Itertools; use crate::cpu::columns::NUM_CPU_COLUMNS; use crate::cpu::kernel::assembler::BYTES_PER_OFFSET; -use crate::cpu::kernel::ast::{Item, PushTarget, StackReplacement}; +use crate::cpu::kernel::ast::{Item, PushTarget, StackPlaceholder, StackReplacement}; use crate::cpu::kernel::stack::permutations::{get_stack_ops_for_perm, is_permutation}; use crate::cpu::kernel::stack::stack_manipulation::StackOp::Pop; use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes; @@ -25,25 +25,50 @@ pub(crate) fn expand_stack_manipulation(body: Vec) -> Vec { expanded } -fn expand(names: Vec, replacements: Vec) -> Vec { +fn expand(names: Vec, replacements: Vec) -> Vec { + let mut stack_blocks = HashMap::new(); + let mut stack_names = HashSet::new(); + let mut src = names .iter() .cloned() - .map(StackItem::NamedItem) + .flat_map(|item| match item { + StackPlaceholder::Identifier(name) => { + stack_names.insert(name.clone()); + vec![StackItem::NamedItem(name)] + } + StackPlaceholder::Block(name, n) => { + stack_blocks.insert(name.clone(), n); + (0..n) + .map(|i| { + let literal_name = format!("block_{}_{}", name, i); + StackItem::NamedItem(literal_name) + }) + .collect_vec() + } + }) .collect_vec(); let mut dst = replacements .into_iter() - .map(|item| match item { + .flat_map(|item| match item { StackReplacement::Identifier(name) => { // May be either a named item or a label. Named items have precedence. - if names.contains(&name) { - StackItem::NamedItem(name) + if stack_blocks.contains_key(&name) { + let n = *stack_blocks.get(&name).unwrap(); + (0..n) + .map(|i| { + let literal_name = format!("block_{}_{}", name, i); + StackItem::NamedItem(literal_name) + }) + .collect_vec() + } else if stack_names.contains(&name) { + vec![StackItem::NamedItem(name)] } else { - StackItem::PushTarget(PushTarget::Label(name)) + vec![StackItem::PushTarget(PushTarget::Label(name))] } } - StackReplacement::Literal(n) => StackItem::PushTarget(PushTarget::Literal(n)), + StackReplacement::Literal(n) => vec![StackItem::PushTarget(PushTarget::Literal(n))], StackReplacement::MacroLabel(_) | StackReplacement::MacroVar(_) | StackReplacement::Constant(_) => { diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index 5950c837..bda044b7 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -5,6 +5,6 @@ pub mod cpu_stark; pub(crate) mod decode; mod jumps; pub mod kernel; -pub mod public_inputs; mod simple_logic; +mod stack_bounds; mod syscalls; diff --git a/evm/src/cpu/public_inputs.rs b/evm/src/cpu/public_inputs.rs deleted file mode 100644 index 0a02e406..00000000 --- a/evm/src/cpu/public_inputs.rs +++ /dev/null @@ -1 +0,0 @@ -pub const NUM_PUBLIC_INPUTS: usize = 0; // PIs will be added later. diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index 75bb8bb6..6b7294a8 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -1,3 +1,4 @@ +use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::hash::hash_types::RichField; @@ -6,44 +7,49 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -const LIMB_SIZE: usize = 16; - pub fn generate(lv: &mut CpuColumnsView) { - let logic = lv.general.logic_mut(); + let input0 = lv.mem_channels[0].value; + let eq_filter = lv.is_eq.to_canonical_u64(); let iszero_filter = lv.is_iszero.to_canonical_u64(); assert!(eq_filter <= 1); assert!(iszero_filter <= 1); assert!(eq_filter + iszero_filter <= 1); - if eq_filter != 1 && iszero_filter != 1 { + if eq_filter + iszero_filter == 0 { return; } - let diffs = if eq_filter == 1 { - logic - .input0 - .into_iter() - .zip(logic.input1) - .map(|(in0, in1)| { - assert_eq!(in0.to_canonical_u64() >> LIMB_SIZE, 0); - assert_eq!(in1.to_canonical_u64() >> LIMB_SIZE, 0); - let diff = in0 - in1; - diff.square() - }) - .sum() - } else if iszero_filter == 1 { - logic.input0.into_iter().sum() - } else { - panic!() - }; + let input1 = &mut lv.mem_channels[1].value; + if iszero_filter != 0 { + for limb in input1.iter_mut() { + *limb = F::ZERO; + } + } - lv.simple_logic_diff = diffs; - lv.simple_logic_diff_inv = diffs.try_inverse().unwrap_or(F::ZERO); + let input1 = lv.mem_channels[1].value; + let num_unequal_limbs = izip!(input0, input1) + .map(|(limb0, limb1)| (limb0 != limb1) as usize) + .sum(); + let equal = num_unequal_limbs == 0; - logic.output[0] = F::from_bool(diffs == F::ZERO); - for out_limb_ref in logic.output[1..].iter_mut() { - *out_limb_ref = F::ZERO; + let output = &mut lv.mem_channels[2].value; + output[0] = F::from_bool(equal); + for limb in &mut output[1..] { + *limb = F::ZERO; + } + + // Form `diff_pinv`. + // Let `diff = input0 - input1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise. + // Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set + // `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have + // `diff @ diff_pinv = 1 - equal` as desired. + let logic = lv.general.logic_mut(); + let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs) + .try_inverse() + .unwrap_or(F::ZERO); + for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), input0, input1) { + *limb_pinv = (limb0 - limb1).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv; } } @@ -52,40 +58,43 @@ pub fn eval_packed( yield_constr: &mut ConstraintConsumer

, ) { let logic = lv.general.logic(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; + let output = lv.mem_channels[2].value; + let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = eq_filter + iszero_filter; - let ls_bit = logic.output[0]; + let equal = output[0]; + let unequal = P::ONES - equal; - // Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is + // Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. - yield_constr.constraint(eq_or_iszero_filter * ls_bit * (ls_bit - P::ONES)); - - for &bit in &logic.output[1..] { - yield_constr.constraint(eq_or_iszero_filter * bit); + yield_constr.constraint(eq_or_iszero_filter * equal * unequal); + for &limb in &output[1..] { + yield_constr.constraint(eq_or_iszero_filter * limb); } - // Check SIMPLE_LOGIC_DIFF - let diffs = lv.simple_logic_diff; - let diffs_inv = lv.simple_logic_diff_inv; - { - let input0_sum: P = logic.input0.into_iter().sum(); - yield_constr.constraint(iszero_filter * (diffs - input0_sum)); - - let sum_squared_diffs: P = logic - .input0 - .into_iter() - .zip(logic.input1) - .map(|(in0, in1)| (in0 - in1).square()) - .sum(); - yield_constr.constraint(eq_filter * (diffs - sum_squared_diffs)); + // If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0). + for limb in input1 { + yield_constr.constraint(iszero_filter * limb); } - // diffs != 0 => ls_bit == 0 - yield_constr.constraint(eq_or_iszero_filter * diffs * ls_bit); - // ls_bit == 0 => diffs != 0 (we provide a diffs_inv) - yield_constr.constraint(eq_or_iszero_filter * (diffs * diffs_inv + ls_bit - P::ONES)); + // `equal` implies `input0[i] == input1[i]` for all `i`. + for (limb0, limb1) in izip!(input0, input1) { + let diff = limb0 - limb1; + yield_constr.constraint(eq_or_iszero_filter * equal * diff); + } + + // `input0[i] == input1[i]` for all `i` implies `equal`. + // If `unequal`, find `diff_pinv` such that `(input0 - input1) @ diff_pinv == 1`, where `@` + // denotes the dot product (there will be many such `diff_pinv`). This can only be done if + // `input0 != input1`. + let dot: P = izip!(input0, input1, logic.diff_pinv) + .map(|(limb0, limb1, diff_pinv_el)| (limb0 - limb1) * diff_pinv_el) + .sum(); + yield_constr.constraint(eq_or_iszero_filter * (dot - unequal)); } pub fn eval_ext_circuit, const D: usize>( @@ -93,61 +102,61 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { + let zero = builder.zero_extension(); + let one = builder.one_extension(); + let logic = lv.general.logic(); + let input0 = lv.mem_channels[0].value; + let input1 = lv.mem_channels[1].value; + let output = lv.mem_channels[2].value; + let eq_filter = lv.is_eq; let iszero_filter = lv.is_iszero; let eq_or_iszero_filter = builder.add_extension(eq_filter, iszero_filter); - let ls_bit = logic.output[0]; + let equal = output[0]; + let unequal = builder.sub_extension(one, equal); - // Handle EQ and ISZERO. Most limbs of the output are 0, but the least-significant one is + // Handle `EQ` and `ISZERO`. Most limbs of the output are 0, but the least-significant one is // either 0 or 1. { - let constr = builder.mul_sub_extension(ls_bit, ls_bit, ls_bit); + let constr = builder.mul_extension(equal, unequal); + let constr = builder.mul_extension(eq_or_iszero_filter, constr); + yield_constr.constraint(builder, constr); + } + for &limb in &output[1..] { + let constr = builder.mul_extension(eq_or_iszero_filter, limb); + yield_constr.constraint(builder, constr); + } + + // If `ISZERO`, constrain input1 to be zero, effectively implementing ISZERO(x) as EQ(x, 0). + for limb in input1 { + let constr = builder.mul_extension(iszero_filter, limb); + yield_constr.constraint(builder, constr); + } + + // `equal` implies `input0[i] == input1[i]` for all `i`. + for (limb0, limb1) in izip!(input0, input1) { + let diff = builder.sub_extension(limb0, limb1); + let constr = builder.mul_extension(equal, diff); let constr = builder.mul_extension(eq_or_iszero_filter, constr); yield_constr.constraint(builder, constr); } - for &bit in &logic.output[1..] { - let constr = builder.mul_extension(eq_or_iszero_filter, bit); - yield_constr.constraint(builder, constr); - } - - // Check SIMPLE_LOGIC_DIFF - let diffs = lv.simple_logic_diff; - let diffs_inv = lv.simple_logic_diff_inv; + // `input0[i] == input1[i]` for all `i` implies `equal`. + // If `unequal`, find `diff_pinv` such that `(input0 - input1) @ diff_pinv == 1`, where `@` + // denotes the dot product (there will be many such `diff_pinv`). This can only be done if + // `input0 != input1`. { - let input0_sum = builder.add_many_extension(logic.input0); - { - let constr = builder.sub_extension(diffs, input0_sum); - let constr = builder.mul_extension(iszero_filter, constr); - yield_constr.constraint(builder, constr); - } - - let sum_squared_diffs = logic.input0.into_iter().zip(logic.input1).fold( - builder.zero_extension(), - |acc, (in0, in1)| { - let diff = builder.sub_extension(in0, in1); - builder.mul_add_extension(diff, diff, acc) + let dot: ExtensionTarget = izip!(input0, input1, logic.diff_pinv).fold( + zero, + |cumul, (limb0, limb1, diff_pinv_el)| { + let diff = builder.sub_extension(limb0, limb1); + builder.mul_add_extension(diff, diff_pinv_el, cumul) }, ); - { - let constr = builder.sub_extension(diffs, sum_squared_diffs); - let constr = builder.mul_extension(eq_filter, constr); - yield_constr.constraint(builder, constr); - } - } - - { - // diffs != 0 => ls_bit == 0 - let constr = builder.mul_extension(diffs, ls_bit); + let constr = builder.sub_extension(dot, unequal); let constr = builder.mul_extension(eq_or_iszero_filter, constr); yield_constr.constraint(builder, constr); } - { - // ls_bit == 0 => diffs != 0 (we provide a diffs_inv) - let constr = builder.mul_add_extension(diffs, diffs_inv, ls_bit); - let constr = builder.mul_sub_extension(eq_or_iszero_filter, constr, eq_or_iszero_filter); - yield_constr.constraint(builder, constr); - } } diff --git a/evm/src/cpu/simple_logic/not.rs b/evm/src/cpu/simple_logic/not.rs index efbf51a6..83d43276 100644 --- a/evm/src/cpu/simple_logic/not.rs +++ b/evm/src/cpu/simple_logic/not.rs @@ -7,7 +7,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -const LIMB_SIZE: usize = 16; +const LIMB_SIZE: usize = 32; const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1; pub fn generate(lv: &mut CpuColumnsView) { @@ -17,8 +17,9 @@ pub fn generate(lv: &mut CpuColumnsView) { } assert_eq!(is_not_filter, 1); - let logic = lv.general.logic_mut(); - for (input, output_ref) in logic.input0.into_iter().zip(logic.output.iter_mut()) { + let input = lv.mem_channels[0].value; + let output = &mut lv.mem_channels[1].value; + for (input, output_ref) in input.into_iter().zip(output.iter_mut()) { let input = input.to_canonical_u64(); assert_eq!(input >> LIMB_SIZE, 0); let output = input ^ ALL_1_LIMB; @@ -30,14 +31,16 @@ pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - // This is simple: just do output = 0xffff - input. - let logic = lv.general.logic(); + // This is simple: just do output = 0xffffffff - input. + let input = lv.mem_channels[0].value; + let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.is_not; let filter = cycle_filter * is_not_filter; - for (input, output) in logic.input0.into_iter().zip(logic.output) { - yield_constr - .constraint(filter * (output + input - P::Scalar::from_canonical_u64(ALL_1_LIMB))); + for (input_limb, output_limb) in input.into_iter().zip(output) { + yield_constr.constraint( + filter * (output_limb + input_limb - P::Scalar::from_canonical_u64(ALL_1_LIMB)), + ); } } @@ -46,12 +49,13 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let logic = lv.general.logic(); + let input = lv.mem_channels[0].value; + let output = lv.mem_channels[1].value; let cycle_filter = lv.is_cpu_cycle; let is_not_filter = lv.is_not; let filter = builder.mul_extension(cycle_filter, is_not_filter); - for (input, output) in logic.input0.into_iter().zip(logic.output) { - let constr = builder.add_extension(output, input); + for (input_limb, output_limb) in input.into_iter().zip(output) { + let constr = builder.add_extension(output_limb, input_limb); let constr = builder.arithmetic_extension( F::ONE, -F::from_canonical_u64(ALL_1_LIMB), diff --git a/evm/src/cpu/stack_bounds.rs b/evm/src/cpu/stack_bounds.rs new file mode 100644 index 00000000..2c9c46eb --- /dev/null +++ b/evm/src/cpu/stack_bounds.rs @@ -0,0 +1,157 @@ +//! Checks for stack underflow and overflow. +//! +//! The constraints defined herein validate that stack exceptions (underflow and overflow) do not +//! occur. For example, if `is_add` is set but an addition would underflow, these constraints would +//! make the proof unverifiable. +//! +//! Faults are handled under a separate operation flag, `is_exception` (this is still TODO), which +//! traps to the kernel. The kernel then handles the exception. However, before it may do so, the +//! kernel must verify in software that an exception did in fact occur (i.e. the trap was +//! warranted) and `PANIC` otherwise; this prevents the prover from faking an exception on a valid +//! operation. + +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::{CpuColumnsView, COL_MAP}; + +const MAX_USER_STACK_SIZE: u64 = 1024; + +// Below only includes the operations that pop the top of the stack **without reading the value from +// memory**, i.e. `POP`. +// Other operations that have a minimum stack size (e.g. `MULMOD`, which has three inputs) read +// all their inputs from memory. On underflow, the cross-table lookup fails, as -1, ..., -17 are +// invalid memory addresses. +const DECREMENTING_FLAGS: [usize; 1] = [COL_MAP.is_pop]; + +// Operations that increase the stack length by 1, but excluding: +// - privileged (kernel-only) operations (superfluous; doesn't affect correctness), +// - operations that from userspace to the kernel (required for correctness). +// TODO: This list is incomplete. +const INCREMENTING_FLAGS: [usize; 2] = [COL_MAP.is_pc, COL_MAP.is_dup]; + +/// Calculates `lv.stack_len_bounds_aux`. Note that this must be run after decode. +pub fn generate(lv: &mut CpuColumnsView) { + let cycle_filter = lv.is_cpu_cycle; + if cycle_filter == F::ZERO { + return; + } + + let check_underflow: F = DECREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let check_overflow: F = INCREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let no_check = F::ONE - (check_underflow + check_overflow); + + let disallowed_len = check_overflow * F::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check; + let diff = lv.stack_len - disallowed_len; + + let user_mode = F::ONE - lv.is_kernel_mode; + let rhs = user_mode + check_underflow; + + lv.stack_len_bounds_aux = match diff.try_inverse() { + Some(diff_inv) => diff_inv * rhs, // `rhs` may be a value other than 1 or 0 + None => { + assert_eq!(rhs, F::ZERO); + F::ZERO + } + } +} + +pub fn eval_packed( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + // `check_underflow`, `check_overflow`, and `no_check` are mutually exclusive. + let check_underflow: P = DECREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let check_overflow: P = INCREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum(); + let no_check = P::ONES - (check_underflow + check_overflow); + + // If `check_underflow`, then the instruction we are executing pops a value from the stack + // without reading it from memory, and the usual underflow checks do not work. We must show that + // `lv.stack_len` is not 0. We choose to perform this check whether or not we're in kernel mode. + // (The check in kernel mode is not necessary if the kernel is correct, but this is an easy + // sanity check. + // If `check_overflow`, then the instruction we are executing increases the stack length by 1. + // If we are in user mode, then we must show that the stack length is not currently + // `MAX_USER_STACK_SIZE`, as this is the maximum for the user stack. Note that this check must + // not run in kernel mode as the kernel's stack length is unrestricted. + // If `no_check`, then we don't need to check anything. The constraint is written to always + // test that `lv.stack_len` does not equal _something_ so we just show that it's not -1, which + // is always true. + + // 0 if `check_underflow`, `MAX_USER_STACK_SIZE` if `check_overflow`, and -1 if `no_check`. + let disallowed_len = + check_overflow * P::Scalar::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check; + // This `lhs` must equal some `rhs`. If `rhs` is nonzero, then this shows that `lv.stack_len` is + // not `disallowed_len`. + let lhs = (lv.stack_len - disallowed_len) * lv.stack_len_bounds_aux; + + // We want this constraint to be active if we're in user mode OR the instruction might overflow. + // (In other words, we want to _skip_ overflow checks in kernel mode). + let user_mode = P::ONES - lv.is_kernel_mode; + // `rhs` is may be 0, 1, or 2. It's 0 if we're in kernel mode and we would be checking for + // overflow. + // Note: if `user_mode` and `check_underflow` then, `rhs` is 2. This is fine: we're still + // showing that `lv.stack_len - disallowed_len` is nonzero. + let rhs = user_mode + check_underflow; + + yield_constr.constraint(lv.is_cpu_cycle * (lhs - rhs)); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one = builder.one_extension(); + let max_stack_size = + builder.constant_extension(F::from_canonical_u64(MAX_USER_STACK_SIZE).into()); + + // `check_underflow`, `check_overflow`, and `no_check` are mutually exclusive. + let check_underflow = builder.add_many_extension(DECREMENTING_FLAGS.map(|i| lv[i])); + let check_overflow = builder.add_many_extension(INCREMENTING_FLAGS.map(|i| lv[i])); + let no_check = { + let any_check = builder.add_extension(check_underflow, check_overflow); + builder.sub_extension(one, any_check) + }; + + // If `check_underflow`, then the instruction we are executing pops a value from the stack + // without reading it from memory, and the usual underflow checks do not work. We must show that + // `lv.stack_len` is not 0. We choose to perform this check whether or not we're in kernel mode. + // (The check in kernel mode is not necessary if the kernel is correct, but this is an easy + // sanity check. + // If `check_overflow`, then the instruction we are executing increases the stack length by 1. + // If we are in user mode, then we must show that the stack length is not currently + // `MAX_USER_STACK_SIZE`, as this is the maximum for the user stack. Note that this check must + // not run in kernel mode as the kernel's stack length is unrestricted. + // If `no_check`, then we don't need to check anything. The constraint is written to always + // test that `lv.stack_len` does not equal _something_ so we just show that it's not -1, which + // is always true. + + // 0 if `check_underflow`, `MAX_USER_STACK_SIZE` if `check_overflow`, and -1 if `no_check`. + let disallowed_len = builder.mul_sub_extension(check_overflow, max_stack_size, no_check); + // This `lhs` must equal some `rhs`. If `rhs` is nonzero, then this shows that `lv.stack_len` is + // not `disallowed_len`. + let lhs = { + let diff = builder.sub_extension(lv.stack_len, disallowed_len); + builder.mul_extension(diff, lv.stack_len_bounds_aux) + }; + + // We want this constraint to be active if we're in user mode OR the instruction might overflow. + // (In other words, we want to _skip_ overflow checks in kernel mode). + let user_mode = builder.sub_extension(one, lv.is_kernel_mode); + // `rhs` is may be 0, 1, or 2. It's 0 if we're in kernel mode and we would be checking for + // overflow. + // Note: if `user_mode` and `check_underflow` then, `rhs` is 2. This is fine: we're still + // showing that `lv.stack_len - disallowed_len` is nonzero. + let rhs = builder.add_extension(user_mode, check_underflow); + + let constr = { + let diff = builder.sub_extension(lhs, rhs); + builder.mul_extension(lv.is_cpu_cycle, diff) + }; + yield_constr.constraint(builder, constr); +} diff --git a/evm/src/cpu/syscalls.rs b/evm/src/cpu/syscalls.rs index a676a6a2..b0b63be8 100644 --- a/evm/src/cpu/syscalls.rs +++ b/evm/src/cpu/syscalls.rs @@ -13,12 +13,16 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NUM_SYSCALLS: usize = 2; +const NUM_SYSCALLS: usize = 3; fn make_syscall_list() -> [(usize, usize); NUM_SYSCALLS] { let kernel = Lazy::force(&KERNEL); - [(COL_MAP.is_stop, "sys_stop"), (COL_MAP.is_exp, "sys_exp")] - .map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name])) + [ + (COL_MAP.is_stop, "sys_stop"), + (COL_MAP.is_exp, "sys_exp"), + (COL_MAP.is_invalid, "handle_invalid"), + ] + .map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name])) } static TRAP_LIST: Lazy<[(usize, usize); NUM_SYSCALLS]> = Lazy::new(make_syscall_list); @@ -28,7 +32,6 @@ pub fn eval_packed( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let lv_syscalls = lv.general.syscalls(); let syscall_list = Lazy::force(&TRAP_LIST); // 1 if _any_ syscall, else 0. let should_syscall: P = syscall_list @@ -48,12 +51,14 @@ pub fn eval_packed( yield_constr.constraint_transition(filter * (nv.program_counter - syscall_dst)); // If syscall: set kernel mode yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES)); + + let output = lv.mem_channels[0].value; // If syscall: push current PC to stack - yield_constr.constraint(filter * (lv_syscalls.output[0] - lv.program_counter)); + yield_constr.constraint(filter * (output[0] - lv.program_counter)); // If syscall: push current kernel flag to stack (share register with PC) - yield_constr.constraint(filter * (lv_syscalls.output[1] - lv.is_kernel_mode)); + yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode)); // If syscall: zero the rest of that register - for &limb in &lv_syscalls.output[2..] { + for &limb in &output[2..] { yield_constr.constraint(filter * limb); } } @@ -64,7 +69,6 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let lv_syscalls = lv.general.syscalls(); let syscall_list = Lazy::force(&TRAP_LIST); // 1 if _any_ syscall, else 0. let should_syscall = @@ -90,20 +94,22 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); yield_constr.constraint_transition(builder, constr); } + + let output = lv.mem_channels[0].value; // If syscall: push current PC to stack { - let constr = builder.sub_extension(lv_syscalls.output[0], lv.program_counter); + let constr = builder.sub_extension(output[0], lv.program_counter); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } // If syscall: push current kernel flag to stack (share register with PC) { - let constr = builder.sub_extension(lv_syscalls.output[1], lv.is_kernel_mode); + let constr = builder.sub_extension(output[1], lv.is_kernel_mode); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } // If syscall: zero the rest of that register - for &limb in &lv_syscalls.output[2..] { + for &limb in &output[2..] { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 4097df7b..83f2083d 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::iter::repeat; use anyhow::{ensure, Result}; @@ -13,18 +14,18 @@ use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::GenericConfig; -use crate::all_stark::Table; +use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::permutation::{ get_grand_product_challenge_set, GrandProductChallenge, GrandProductChallengeSet, }; -use crate::proof::{StarkProofWithPublicInputs, StarkProofWithPublicInputsTarget}; +use crate::proof::{StarkProof, StarkProofTarget}; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Represent a linear combination of columns. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Column { linear_combination: Vec<(usize, F)>, constant: F, @@ -38,8 +39,25 @@ impl Column { } } - pub fn singles>(cs: I) -> impl Iterator { - cs.into_iter().map(Self::single) + pub fn singles>>( + cs: I, + ) -> impl Iterator { + cs.into_iter().map(|c| Self::single(*c.borrow())) + } + + pub fn constant(constant: F) -> Self { + Self { + linear_combination: vec![], + constant, + } + } + + pub fn zero() -> Self { + Self::constant(F::ZERO) + } + + pub fn one() -> Self { + Self::constant(F::ONE) } pub fn linear_combination_with_constant>( @@ -63,12 +81,20 @@ impl Column { Self::linear_combination_with_constant(iter, F::ZERO) } - pub fn le_bits>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(F::TWO.powers())) + pub fn le_bits>>(cs: I) -> Self { + Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(F::TWO.powers())) } - pub fn sum>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(repeat(F::ONE))) + pub fn le_bytes>>(cs: I) -> Self { + Self::linear_combination( + cs.into_iter() + .map(|c| *c.borrow()) + .zip(F::from_canonical_u16(256).powers()), + ) + } + + pub fn sum>>(cs: I) -> Self { + Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(repeat(F::ONE))) } pub fn eval(&self, v: &[P]) -> P @@ -115,7 +141,7 @@ impl Column { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct TableWithColumns { table: Table, columns: Vec>, @@ -168,23 +194,21 @@ impl CrossTableLookup { } /// Cross-table lookup data for one table. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct CtlData { - /// Challenges used in the argument. - pub(crate) challenges: GrandProductChallengeSet, - /// Vector of `(Z, columns, filter_columns)` where `Z` is a Z-polynomial for a lookup - /// on columns `columns` with filter columns `filter_columns`. - pub zs_columns: Vec<(PolynomialValues, Vec>, Option>)>, + pub(crate) zs_columns: Vec>, +} + +/// Cross-table lookup data associated with one Z(x) polynomial. +#[derive(Clone)] +pub(crate) struct CtlZData { + pub(crate) z: PolynomialValues, + pub(crate) challenge: GrandProductChallenge, + pub(crate) columns: Vec>, + pub(crate) filter_column: Option>, } impl CtlData { - pub(crate) fn new(challenges: GrandProductChallengeSet) -> Self { - Self { - challenges, - zs_columns: vec![], - } - } - pub fn len(&self) -> usize { self.zs_columns.len() } @@ -194,18 +218,21 @@ impl CtlData { } pub fn z_polys(&self) -> Vec> { - self.zs_columns.iter().map(|(p, _, _)| p.clone()).collect() + self.zs_columns + .iter() + .map(|zs_columns| zs_columns.z.clone()) + .collect() } } pub fn cross_table_lookup_data, const D: usize>( config: &StarkConfig, - trace_poly_values: &[Vec>], + trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], challenger: &mut Challenger, -) -> Vec> { +) -> [CtlData; NUM_TABLES] { let challenges = get_grand_product_challenge_set(challenger, config.num_challenges); - let mut ctl_data_per_table = vec![CtlData::new(challenges.clone()); trace_poly_values.len()]; + let mut ctl_data_per_table = [0; NUM_TABLES].map(|_| CtlData::default()); for CrossTableLookup { looking_tables, looked_table, @@ -252,19 +279,23 @@ pub fn cross_table_lookup_data, const D ); for (table, z) in looking_tables.iter().zip(zs_looking) { - ctl_data_per_table[table.table as usize].zs_columns.push(( - z, - table.columns.clone(), - table.filter_column.clone(), - )); + ctl_data_per_table[table.table as usize] + .zs_columns + .push(CtlZData { + z, + challenge, + columns: table.columns.clone(), + filter_column: table.filter_column.clone(), + }); } ctl_data_per_table[looked_table.table as usize] .zs_columns - .push(( - z_looked, - looked_table.columns.clone(), - looked_table.filter_column.clone(), - )); + .push(CtlZData { + z: z_looked, + challenge, + columns: looked_table.columns.clone(), + filter_column: looked_table.filter_column.clone(), + }); } } ctl_data_per_table @@ -317,24 +348,23 @@ impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { pub(crate) fn from_proofs>( - proofs: &[StarkProofWithPublicInputs], + proofs: &[StarkProof; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: &[usize], - ) -> Vec> { - debug_assert_eq!(proofs.len(), num_permutation_zs.len()); + num_permutation_zs: &[usize; NUM_TABLES], + ) -> [Vec; NUM_TABLES] { let mut ctl_zs = proofs .iter() .zip(num_permutation_zs) .map(|(p, &num_perms)| { - let openings = &p.proof.openings; + let openings = &p.openings; let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); ctl_zs.zip(ctl_zs_next) }) .collect::>(); - let mut ctl_vars_per_table = vec![vec![]; proofs.len()]; + let mut ctl_vars_per_table = [0; NUM_TABLES].map(|_| vec![]); for CrossTableLookup { looking_tables, looked_table, @@ -368,7 +398,7 @@ impl<'a, F: RichField + Extendable, const D: usize> } pub(crate) fn eval_cross_table_lookup_checks( - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, ctl_vars: &[CtlCheckVars], consumer: &mut ConstraintConsumer

, ) where @@ -421,24 +451,23 @@ pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> { impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { pub(crate) fn from_proofs( - proofs: &[StarkProofWithPublicInputsTarget], + proofs: &[StarkProofTarget; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_permutation_zs: &[usize], - ) -> Vec> { - debug_assert_eq!(proofs.len(), num_permutation_zs.len()); + num_permutation_zs: &[usize; NUM_TABLES], + ) -> [Vec; NUM_TABLES] { let mut ctl_zs = proofs .iter() .zip(num_permutation_zs) .map(|(p, &num_perms)| { - let openings = &p.proof.openings; + let openings = &p.openings; let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); ctl_zs.zip(ctl_zs_next) }) .collect::>(); - let mut ctl_vars_per_table = vec![vec![]; proofs.len()]; + let mut ctl_vars_per_table = [0; NUM_TABLES].map(|_| vec![]); for CrossTableLookup { looking_tables, looked_table, @@ -477,7 +506,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< const D: usize, >( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, ) { @@ -539,17 +568,17 @@ pub(crate) fn verify_cross_table_lookups< const D: usize, >( cross_table_lookups: Vec>, - proofs: &[StarkProofWithPublicInputs], + proofs: &[StarkProof; NUM_TABLES], challenges: GrandProductChallengeSet, config: &StarkConfig, ) -> Result<()> { let degrees_bits = proofs .iter() - .map(|p| p.proof.recover_degree_bits(config)) + .map(|p| p.recover_degree_bits(config)) .collect::>(); let mut ctl_zs_openings = proofs .iter() - .map(|p| p.proof.openings.ctl_zs_last.iter()) + .map(|p| p.openings.ctl_zs_last.iter()) .collect::>(); for ( i, @@ -597,17 +626,17 @@ pub(crate) fn verify_cross_table_lookups_circuit< >( builder: &mut CircuitBuilder, cross_table_lookups: Vec>, - proofs: &[StarkProofWithPublicInputsTarget], + proofs: &[StarkProofTarget; NUM_TABLES], challenges: GrandProductChallengeSet, inner_config: &StarkConfig, ) { let degrees_bits = proofs .iter() - .map(|p| p.proof.recover_degree_bits(inner_config)) + .map(|p| p.recover_degree_bits(inner_config)) .collect::>(); let mut ctl_zs_openings = proofs .iter() - .map(|p| p.proof.openings.ctl_zs_last.iter()) + .map(|p| p.openings.ctl_zs_last.iter()) .collect::>(); for ( i, diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 02c91d16..baf2ec32 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,73 +1,121 @@ -use ethereum_types::U256; +use eth_trie_utils::partial_trie::PartialTrie; +use ethereum_types::Address; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use serde::{Deserialize, Serialize}; -use crate::all_stark::AllStark; +use crate::all_stark::{AllStark, NUM_TABLES}; +use crate::config::StarkConfig; use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel; use crate::cpu::columns::NUM_CPU_COLUMNS; +use crate::cpu::kernel::global_metadata::GlobalMetadata; use crate::generation::state::GenerationState; +use crate::memory::segments::Segment; +use crate::memory::NUM_CHANNELS; +use crate::proof::{BlockMetadata, PublicValues, TrieRoots}; use crate::util::trace_rows_to_poly_values; pub(crate) mod memory; pub(crate) mod state; -/// A piece of data which has been encoded using Recursive Length Prefix (RLP) serialization. -/// See https://ethereum.org/en/developers/docs/data-structures-and-encoding/rlp/ -pub type RlpBlob = Vec; +#[derive(Clone, Debug, Deserialize, Serialize)] +/// Inputs needed for trace generation. +pub struct GenerationInputs { + pub signed_txns: Vec>, -/// Merkle proofs are encoded using an RLP blob for each node in the path. -pub type RlpMerkleProof = Vec; + /// A partial version of the state trie prior to these transactions. It should include all nodes + /// that will be accessed by these transactions. + pub state_trie: PartialTrie, -#[allow(unused)] // TODO: Should be used soon. -pub struct TransactionData { - pub signed_txn: Vec, + /// A partial version of the transaction trie prior to these transactions. It should include all + /// nodes that will be accessed by these transactions. + pub transactions_trie: PartialTrie, - /// A Merkle proof for each interaction with the state trie, ordered chronologically. - pub trie_proofs: Vec, + /// A partial version of the receipt trie prior to these transactions. It should include all nodes + /// that will be accessed by these transactions. + pub receipts_trie: PartialTrie, + + /// A partial version of each storage trie prior to these transactions. It should include all + /// storage tries, and nodes therein, that will be accessed by these transactions. + pub storage_tries: Vec<(Address, PartialTrie)>, + + pub block_metadata: BlockMetadata, } -#[allow(unused)] // TODO: Should be used soon. -pub fn generate_traces, const D: usize>( +pub(crate) fn generate_traces, const D: usize>( all_stark: &AllStark, - txns: &[TransactionData], -) -> Vec>> { + inputs: GenerationInputs, + config: &StarkConfig, +) -> ([Vec>; NUM_TABLES], PublicValues) { let mut state = GenerationState::::default(); generate_bootstrap_kernel::(&mut state); - for txn in txns { + for txn in &inputs.signed_txns { generate_txn(&mut state, txn); } + // TODO: Pad to a power of two, ending in the `halt` kernel function. + + let cpu_rows = state.cpu_rows.len(); + let mem_end_timestamp = cpu_rows * NUM_CHANNELS; + let mut read_metadata = |field| { + state.get_mem( + 0, + Segment::GlobalMetadata, + field as usize, + mem_end_timestamp, + ) + }; + + let trie_roots_before = TrieRoots { + state_root: read_metadata(GlobalMetadata::StateTrieRootDigestBefore), + transactions_root: read_metadata(GlobalMetadata::TransactionsTrieRootDigestBefore), + receipts_root: read_metadata(GlobalMetadata::ReceiptsTrieRootDigestBefore), + }; + let trie_roots_after = TrieRoots { + state_root: read_metadata(GlobalMetadata::StateTrieRootDigestAfter), + transactions_root: read_metadata(GlobalMetadata::TransactionsTrieRootDigestAfter), + receipts_root: read_metadata(GlobalMetadata::ReceiptsTrieRootDigestAfter), + }; + let GenerationState { cpu_rows, current_cpu_row, memory, keccak_inputs, + keccak_memory_inputs, logic_ops, - prover_inputs, .. } = state; assert_eq!(current_cpu_row, [F::ZERO; NUM_CPU_COLUMNS].into()); - assert_eq!(prover_inputs, vec![], "Not all prover inputs were consumed"); let cpu_trace = trace_rows_to_poly_values(cpu_rows); let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs); + let keccak_memory_trace = all_stark + .keccak_memory_stark + .generate_trace(keccak_memory_inputs, 1 << config.fri_config.cap_height); let logic_trace = all_stark.logic_stark.generate_trace(logic_ops); let memory_trace = all_stark.memory_stark.generate_trace(memory.log); - vec![cpu_trace, keccak_trace, logic_trace, memory_trace] + let traces = [ + cpu_trace, + keccak_trace, + keccak_memory_trace, + logic_trace, + memory_trace, + ]; + + let public_values = PublicValues { + trie_roots_before, + trie_roots_after, + block_metadata: inputs.block_metadata, + }; + + (traces, public_values) } -fn generate_txn(state: &mut GenerationState, txn: &TransactionData) { - // TODO: Add transaction RLP to prover_input. - - // Supply Merkle trie proofs as prover inputs. - for proof in &txn.trie_proofs { - let proof = proof - .iter() - .flat_map(|node_rlp| node_rlp.iter().map(|byte| U256::from(*byte))); - state.prover_inputs.extend(proof); - } +fn generate_txn(_state: &mut GenerationState, _signed_txn: &[u8]) { + // TODO } diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index c7f1003e..4cbe61c8 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -2,11 +2,15 @@ use std::mem; use ethereum_types::U256; use plonky2::field::types::Field; +use tiny_keccak::keccakf; use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS}; use crate::generation::memory::MemoryState; +use crate::keccak_memory::keccak_memory_stark::KeccakMemoryOp; use crate::memory::memory_stark::MemoryOp; use crate::memory::segments::Segment; +use crate::memory::NUM_CHANNELS; +use crate::util::u256_limbs; use crate::{keccak, logic}; #[derive(Debug)] @@ -18,10 +22,8 @@ pub(crate) struct GenerationState { pub(crate) memory: MemoryState, pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>, + pub(crate) keccak_memory_inputs: Vec, pub(crate) logic_ops: Vec, - - /// Non-deterministic inputs provided by the prover. - pub(crate) prover_inputs: Vec, } impl GenerationState { @@ -51,19 +53,52 @@ impl GenerationState { result } - /// Read some memory within the current execution context, and log the operation. + /// Like `get_mem_cpu`, but reads from the current context specifically. #[allow(unused)] // TODO: Should be used soon. - pub(crate) fn get_mem_current( + pub(crate) fn get_mem_cpu_current( &mut self, channel_index: usize, segment: Segment, virt: usize, ) -> U256 { - let timestamp = self.cpu_rows.len(); let context = self.current_context; + self.get_mem_cpu(channel_index, context, segment, virt) + } + + /// Simulates the CPU reading some memory through the given channel. Besides logging the memory + /// operation, this also generates the associated registers in the current CPU row. + pub(crate) fn get_mem_cpu( + &mut self, + channel_index: usize, + context: usize, + segment: Segment, + virt: usize, + ) -> U256 { + let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; + let value = self.get_mem(context, segment, virt, timestamp); + + let channel = &mut self.current_cpu_row.mem_channels[channel_index]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(context); + channel.addr_segment = F::from_canonical_usize(segment as usize); + channel.addr_virtual = F::from_canonical_usize(virt); + channel.value = u256_limbs(value); + + value + } + + /// Read some memory, and log the operation. + pub(crate) fn get_mem( + &mut self, + context: usize, + segment: Segment, + virt: usize, + timestamp: usize, + ) -> U256 { let value = self.memory.contexts[context].segments[segment as usize].get(virt); self.memory.log.push(MemoryOp { - channel_index: Some(channel_index), + filter: true, timestamp, is_read: true, context, @@ -75,17 +110,49 @@ impl GenerationState { } /// Write some memory within the current execution context, and log the operation. - pub(crate) fn set_mem_current( + pub(crate) fn set_mem_cpu_current( &mut self, channel_index: usize, segment: Segment, virt: usize, value: U256, ) { - let timestamp = self.cpu_rows.len(); let context = self.current_context; + self.set_mem_cpu(channel_index, context, segment, virt, value); + } + + /// Write some memory, and log the operation. + pub(crate) fn set_mem_cpu( + &mut self, + channel_index: usize, + context: usize, + segment: Segment, + virt: usize, + value: U256, + ) { + let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index; + self.set_mem(context, segment, virt, value, timestamp); + + let channel = &mut self.current_cpu_row.mem_channels[channel_index]; + channel.used = F::ONE; + channel.is_read = F::ZERO; // For clarity; should already be 0. + channel.addr_context = F::from_canonical_usize(context); + channel.addr_segment = F::from_canonical_usize(segment as usize); + channel.addr_virtual = F::from_canonical_usize(virt); + channel.value = u256_limbs(value); + } + + /// Write some memory, and log the operation. + pub(crate) fn set_mem( + &mut self, + context: usize, + segment: Segment, + virt: usize, + value: U256, + timestamp: usize, + ) { self.memory.log.push(MemoryOp { - channel_index: Some(channel_index), + filter: true, timestamp, is_read: false, context, @@ -96,6 +163,54 @@ impl GenerationState { self.memory.contexts[context].segments[segment as usize].set(virt, value) } + /// Evaluate the Keccak-f permutation in-place on some data in memory, and record the operations + /// for the purpose of witness generation. + #[allow(unused)] // TODO: Should be used soon. + pub(crate) fn keccak_memory( + &mut self, + context: usize, + segment: Segment, + virt: usize, + ) -> [u64; keccak::keccak_stark::NUM_INPUTS] { + let read_timestamp = self.cpu_rows.len() * NUM_CHANNELS; + let _write_timestamp = read_timestamp + 1; + let input = (0..25) + .map(|i| { + let bytes = [0, 1, 2, 3, 4, 5, 6, 7].map(|j| { + let virt = virt + i * 8 + j; + let byte = self.get_mem(context, segment, virt, read_timestamp); + debug_assert!(byte.bits() <= 8); + byte.as_u32() as u8 + }); + u64::from_le_bytes(bytes) + }) + .collect::>() + .try_into() + .unwrap(); + let output = self.keccak(input); + self.keccak_memory_inputs.push(KeccakMemoryOp { + context, + segment, + virt, + read_timestamp, + input, + output, + }); + // TODO: Write output to memory. + output + } + + /// Evaluate the Keccak-f permutation, and record the operation for the purpose of witness + /// generation. + pub(crate) fn keccak( + &mut self, + mut input: [u64; keccak::keccak_stark::NUM_INPUTS], + ) -> [u64; keccak::keccak_stark::NUM_INPUTS] { + self.keccak_inputs.push(input); + keccakf(&mut input); + input + } + pub(crate) fn commit_cpu_row(&mut self) { let mut swapped_row = [F::ZERO; NUM_CPU_COLUMNS].into(); mem::swap(&mut self.current_cpu_row, &mut swapped_row); @@ -113,8 +228,8 @@ impl Default for GenerationState { current_context: 0, memory: MemoryState::default(), keccak_inputs: vec![], + keccak_memory_inputs: vec![], logic_ops: vec![], - prover_inputs: vec![], } } } diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 88727ad3..6545a1af 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -1,4 +1,3 @@ -use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::fri::proof::{FriProof, FriProofTarget}; use plonky2::hash::hash_types::RichField; @@ -24,22 +23,26 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.proof.trace_cap); + challenger.observe_cap(&proof.trace_cap); } + // TODO: Observe public values. + let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); + let num_permutation_zs = all_stark.nums_permutation_zs(config); + let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); + AllProofChallenges { - stark_challenges: izip!( - &self.stark_proofs, - all_stark.nums_permutation_zs(config), - all_stark.permutation_batch_sizes() - ) - .map(|(proof, num_perm, batch_size)| { - proof.get_challenges(&mut challenger, num_perm > 0, batch_size, config) - }) - .collect(), + stark_challenges: std::array::from_fn(|i| { + self.stark_proofs[i].get_challenges( + &mut challenger, + num_permutation_zs[i] > 0, + num_permutation_batch_sizes[i], + config, + ) + }), ctl_challenges, } } @@ -58,34 +61,31 @@ impl AllProofTarget { let mut challenger = RecursiveChallenger::::new(builder); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.proof.trace_cap); + challenger.observe_cap(&proof.trace_cap); } let ctl_challenges = get_grand_product_challenge_set_target(builder, &mut challenger, config.num_challenges); + let num_permutation_zs = all_stark.nums_permutation_zs(config); + let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); + AllProofChallengesTarget { - stark_challenges: izip!( - &self.stark_proofs, - all_stark.nums_permutation_zs(config), - all_stark.permutation_batch_sizes() - ) - .map(|(proof, num_perm, batch_size)| { - proof.get_challenges::( + stark_challenges: std::array::from_fn(|i| { + self.stark_proofs[i].get_challenges::( builder, &mut challenger, - num_perm > 0, - batch_size, + num_permutation_zs[i] > 0, + num_permutation_batch_sizes[i], config, ) - }) - .collect(), + }), ctl_challenges, } } } -impl StarkProofWithPublicInputs +impl StarkProof where F: RichField + Extendable, C: GenericConfig, @@ -98,7 +98,7 @@ where stark_permutation_batch_size: usize, config: &StarkConfig, ) -> StarkProofChallenges { - let degree_bits = self.proof.recover_degree_bits(config); + let degree_bits = self.recover_degree_bits(config); let StarkProof { permutation_ctl_zs_cap, @@ -112,7 +112,7 @@ where .. }, .. - } = &self.proof; + } = &self; let num_challenges = config.num_challenges; @@ -148,7 +148,7 @@ where } } -impl StarkProofWithPublicInputsTarget { +impl StarkProofTarget { pub(crate) fn get_challenges, C: GenericConfig>( &self, builder: &mut CircuitBuilder, @@ -172,7 +172,7 @@ impl StarkProofWithPublicInputsTarget { .. }, .. - } = &self.proof; + } = &self; let num_challenges = config.num_challenges; diff --git a/evm/src/keccak/columns.rs b/evm/src/keccak/columns.rs index 39116b4a..8313c676 100644 --- a/evm/src/keccak/columns.rs +++ b/evm/src/keccak/columns.rs @@ -15,8 +15,11 @@ pub const fn reg_step(i: usize) -> usize { pub fn reg_input_limb(i: usize) -> Column { debug_assert!(i < 2 * NUM_INPUTS); let i_u64 = i / 2; // The index of the 64-bit chunk. - let x = i_u64 / 5; - let y = i_u64 % 5; + + // The 5x5 state is treated as y-major, as per the Keccak spec. + let y = i_u64 / 5; + let x = i_u64 % 5; + let reg_low_limb = reg_a(x, y); let is_high_limb = i % 2; Column::single(reg_low_limb + is_high_limb) @@ -28,8 +31,11 @@ pub fn reg_input_limb(i: usize) -> Column { pub const fn reg_output_limb(i: usize) -> usize { debug_assert!(i < 2 * NUM_INPUTS); let i_u64 = i / 2; // The index of the 64-bit chunk. - let x = i_u64 / 5; - let y = i_u64 % 5; + + // The 5x5 state is treated as y-major, as per the Keccak spec. + let y = i_u64 / 5; + let x = i_u64 % 5; + let is_high_limb = i % 2; reg_a_prime_prime_prime(x, y) + is_high_limb } diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 94fa795d..23ffe0e9 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -32,8 +32,6 @@ pub(crate) const NUM_ROUNDS: usize = 24; /// Number of 64-bit elements in the Keccak permutation input. pub(crate) const NUM_INPUTS: usize = 25; -pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; - pub fn ctl_data() -> Vec> { let mut res: Vec<_> = (0..2 * NUM_INPUTS).map(reg_input_limb).collect(); res.extend(Column::singles((0..2 * NUM_INPUTS).map(reg_output_limb))); @@ -76,7 +74,7 @@ impl, const D: usize> KeccakStark { for x in 0..5 { for y in 0..5 { - let input_xy = input[x * 5 + y]; + let input_xy = input[y * 5 + x]; let reg_lo = reg_a(x, y); let reg_hi = reg_lo + 1; rows[0][reg_lo] = F::from_canonical_u64(input_xy & 0xFFFFFFFF); @@ -134,9 +132,10 @@ impl, const D: usize> KeccakStark { } } - // Populate A'. - // A'[x, y] = xor(A[x, y], D[x]) - // = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) + // Populate A'. To avoid shifting indices, we rewrite + // A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1]) + // as + // A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]). for x in 0..5 { for y in 0..5 { for z in 0..64 { @@ -145,11 +144,8 @@ impl, const D: usize> KeccakStark { let reg_a_limb = reg_a(x, y) + is_high_limb; let a_limb = row[reg_a_limb].to_canonical_u64() as u32; let a_bit = F::from_bool(((a_limb >> bit_in_limb) & 1) != 0); - row[reg_a_prime(x, y, z)] = xor([ - a_bit, - row[reg_c((x + 4) % 5, z)], - row[reg_c((x + 1) % 5, (z + 64 - 1) % 64)], - ]); + row[reg_a_prime(x, y, z)] = + xor([a_bit, row[reg_c(x, z)], row[reg_c_prime(x, z)]]); } } } @@ -228,11 +224,10 @@ impl, const D: usize> KeccakStark { impl, const D: usize> Stark for KeccakStark { const COLUMNS: usize = NUM_COLUMNS; - const PUBLIC_INPUTS: usize = NUM_PUBLIC_INPUTS; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -380,7 +375,7 @@ impl, const D: usize> Stark for KeccakStark, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let two = builder.two(); @@ -547,9 +542,9 @@ impl, const D: usize> Stark for KeccakStark>(); - let mut keccak_input: [[u64; 5]; 5] = [ - input[0..5].try_into().unwrap(), - input[5..10].try_into().unwrap(), - input[10..15].try_into().unwrap(), - input[15..20].try_into().unwrap(), - input[20..25].try_into().unwrap(), - ]; - - let keccak = KeccakF::new(StateBitsWidth::F1600); - keccak.permutations(&mut keccak_input); - let expected: Vec<_> = keccak_input - .iter() - .flatten() - .map(|&x| F::from_canonical_u64(x)) - .collect(); + let expected = { + let mut state = input; + keccakf(&mut state); + state + }; assert_eq!(output, expected); diff --git a/evm/src/keccak/round_flags.rs b/evm/src/keccak/round_flags.rs index 6a4d03b6..920ca4c8 100644 --- a/evm/src/keccak/round_flags.rs +++ b/evm/src/keccak/round_flags.rs @@ -7,12 +7,12 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::keccak::columns::reg_step; use crate::keccak::columns::NUM_COLUMNS; -use crate::keccak::keccak_stark::{NUM_PUBLIC_INPUTS, NUM_ROUNDS}; +use crate::keccak::keccak_stark::NUM_ROUNDS; use crate::vars::StarkEvaluationTargets; use crate::vars::StarkEvaluationVars; pub(crate) fn eval_round_flags>( - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) { // Initially, the first step flag should be 1 while the others should be 0. @@ -30,7 +30,7 @@ pub(crate) fn eval_round_flags>( pub(crate) fn eval_round_flags_recursively, const D: usize>( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let one = builder.one_extension(); diff --git a/evm/src/keccak_memory/columns.rs b/evm/src/keccak_memory/columns.rs new file mode 100644 index 00000000..92bdbf2b --- /dev/null +++ b/evm/src/keccak_memory/columns.rs @@ -0,0 +1,29 @@ +pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; + +/// 1 if this row represents a real operation; 0 if it's a padding row. +pub(crate) const COL_IS_REAL: usize = 0; + +// The address at which we will read inputs and write outputs. +pub(crate) const COL_CONTEXT: usize = 1; +pub(crate) const COL_SEGMENT: usize = 2; +pub(crate) const COL_VIRTUAL: usize = 3; + +/// The timestamp at which inputs should be read from memory. +/// Outputs will be written at the following timestamp. +pub(crate) const COL_READ_TIMESTAMP: usize = 4; + +const START_INPUT_LIMBS: usize = 5; +/// A byte of the input. +pub(crate) fn col_input_byte(i: usize) -> usize { + debug_assert!(i < KECCAK_WIDTH_BYTES); + START_INPUT_LIMBS + i +} + +const START_OUTPUT_LIMBS: usize = START_INPUT_LIMBS + KECCAK_WIDTH_BYTES; +/// A byte of the output. +pub(crate) fn col_output_byte(i: usize) -> usize { + debug_assert!(i < KECCAK_WIDTH_BYTES); + START_OUTPUT_LIMBS + i +} + +pub const NUM_COLUMNS: usize = START_OUTPUT_LIMBS + KECCAK_WIDTH_BYTES; diff --git a/evm/src/keccak_memory/keccak_memory_stark.rs b/evm/src/keccak_memory/keccak_memory_stark.rs new file mode 100644 index 00000000..cf8955b3 --- /dev/null +++ b/evm/src/keccak_memory/keccak_memory_stark.rs @@ -0,0 +1,227 @@ +use std::marker::PhantomData; + +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::timed; +use plonky2::util::timing::TimingTree; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::Column; +use crate::keccak::keccak_stark::NUM_INPUTS; +use crate::keccak_memory::columns::*; +use crate::memory::segments::Segment; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; +use crate::vars::StarkEvaluationTargets; +use crate::vars::StarkEvaluationVars; + +pub(crate) fn ctl_looked_data() -> Vec> { + Column::singles([COL_CONTEXT, COL_SEGMENT, COL_VIRTUAL, COL_READ_TIMESTAMP]).collect() +} + +pub(crate) fn ctl_looking_keccak() -> Vec> { + let input_cols = (0..50).map(|i| { + Column::le_bytes((0..4).map(|j| { + let byte_index = i * 4 + j; + col_input_byte(byte_index) + })) + }); + let output_cols = (0..50).map(|i| { + Column::le_bytes((0..4).map(|j| { + let byte_index = i * 4 + j; + col_output_byte(byte_index) + })) + }); + input_cols.chain(output_cols).collect() +} + +pub(crate) fn ctl_looking_memory(i: usize, is_read: bool) -> Vec> { + let mut res = vec![Column::constant(F::from_bool(is_read))]; + res.extend(Column::singles([COL_CONTEXT, COL_SEGMENT, COL_VIRTUAL])); + + res.push(Column::single(col_input_byte(i))); + // Since we're reading or writing a single byte, the higher limbs must be zero. + res.extend((1..8).map(|_| Column::zero())); + + // Since COL_READ_TIMESTAMP is the read time, we add 1 if this is a write. + let is_write_f = F::from_bool(!is_read); + res.push(Column::linear_combination_with_constant( + [(COL_READ_TIMESTAMP, F::ONE)], + is_write_f, + )); + + assert_eq!( + res.len(), + crate::memory::memory_stark::ctl_data::().len() + ); + res +} + +/// CTL filter used for both directions (looked and looking). +pub(crate) fn ctl_filter() -> Column { + Column::single(COL_IS_REAL) +} + +/// Information about a Keccak memory operation needed for witness generation. +#[derive(Debug)] +pub(crate) struct KeccakMemoryOp { + // The address at which we will read inputs and write outputs. + pub(crate) context: usize, + pub(crate) segment: Segment, + pub(crate) virt: usize, + + /// The timestamp at which inputs should be read from memory. + /// Outputs will be written at the following timestamp. + pub(crate) read_timestamp: usize, + + /// The input that was read at that address. + pub(crate) input: [u64; NUM_INPUTS], + pub(crate) output: [u64; NUM_INPUTS], +} + +#[derive(Copy, Clone, Default)] +pub struct KeccakMemoryStark { + pub(crate) f: PhantomData, +} + +impl, const D: usize> KeccakMemoryStark { + #[allow(unused)] // TODO: Should be used soon. + pub(crate) fn generate_trace( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec> { + let mut timing = TimingTree::new("generate trace", log::Level::Debug); + + // Generate the witness row-wise. + let trace_rows = timed!( + &mut timing, + "generate trace rows", + self.generate_trace_rows(operations, min_rows) + ); + + let trace_polys = timed!( + &mut timing, + "convert to PolynomialValues", + trace_rows_to_poly_values(trace_rows) + ); + + timing.print(); + trace_polys + } + + fn generate_trace_rows( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_COLUMNS]> { + let num_rows = operations.len().max(min_rows).next_power_of_two(); + let mut rows = Vec::with_capacity(num_rows); + for op in operations { + rows.push(self.generate_row_for_op(op)); + } + + let padding_row = self.generate_padding_row(); + for _ in rows.len()..num_rows { + rows.push(padding_row); + } + rows + } + + fn generate_row_for_op(&self, op: KeccakMemoryOp) -> [F; NUM_COLUMNS] { + let mut row = [F::ZERO; NUM_COLUMNS]; + row[COL_IS_REAL] = F::ONE; + row[COL_CONTEXT] = F::from_canonical_usize(op.context); + row[COL_SEGMENT] = F::from_canonical_usize(op.segment as usize); + row[COL_VIRTUAL] = F::from_canonical_usize(op.virt); + row[COL_READ_TIMESTAMP] = F::from_canonical_usize(op.read_timestamp); + for i in 0..25 { + let input_u64 = op.input[i]; + let output_u64 = op.output[i]; + for j in 0..8 { + let byte_index = i * 8 + j; + row[col_input_byte(byte_index)] = F::from_canonical_u8(input_u64.to_le_bytes()[j]); + row[col_output_byte(byte_index)] = + F::from_canonical_u8(output_u64.to_le_bytes()[j]); + } + } + row + } + + fn generate_padding_row(&self) -> [F; NUM_COLUMNS] { + // We just need COL_IS_REAL to be zero, which it is by default. + // The other fields will have no effect. + [F::ZERO; NUM_COLUMNS] + } +} + +impl, const D: usize> Stark for KeccakMemoryStark { + const COLUMNS: usize = NUM_COLUMNS; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField, + { + // is_real must be 0 or 1. + let is_real = vars.local_values[COL_IS_REAL]; + yield_constr.constraint(is_real * (is_real - P::ONES)); + } + + fn eval_ext_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + // is_real must be 0 or 1. + let is_real = vars.local_values[COL_IS_REAL]; + let constraint = builder.mul_sub_extension(is_real, is_real, is_real); + yield_constr.constraint(builder, constraint); + } + + fn constraint_degree(&self) -> usize { + 2 + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + + #[test] + fn test_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakMemoryStark; + + let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakMemoryStark; + + let stark = S { + f: Default::default(), + }; + test_stark_circuit_constraints::(stark) + } +} diff --git a/evm/src/keccak_memory/mod.rs b/evm/src/keccak_memory/mod.rs new file mode 100644 index 00000000..7b5e3d01 --- /dev/null +++ b/evm/src/keccak_memory/mod.rs @@ -0,0 +1,2 @@ +pub mod columns; +pub mod keccak_memory_stark; diff --git a/evm/src/keccak_sponge/columns.rs b/evm/src/keccak_sponge/columns.rs new file mode 100644 index 00000000..08194e87 --- /dev/null +++ b/evm/src/keccak_sponge/columns.rs @@ -0,0 +1,114 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::mem::{size_of, transmute}; + +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + +pub(crate) const KECCAK_WIDTH_BYTES: usize = 200; +pub(crate) const KECCAK_WIDTH_U32S: usize = KECCAK_WIDTH_BYTES / 4; +pub(crate) const KECCAK_RATE_BYTES: usize = 136; +pub(crate) const KECCAK_RATE_U32S: usize = KECCAK_RATE_BYTES / 4; +pub(crate) const KECCAK_CAPACITY_BYTES: usize = 64; +pub(crate) const KECCAK_CAPACITY_U32S: usize = KECCAK_CAPACITY_BYTES / 4; + +#[repr(C)] +#[derive(Eq, PartialEq, Debug)] +pub(crate) struct KeccakSpongeColumnsView { + /// 1 if this row represents a full input block, i.e. one in which each byte is an input byte, + /// not a padding byte; 0 otherwise. + pub is_full_input_block: T, + + /// 1 if this row represents the final block of a sponge, in which case some or all of the bytes + /// in the block will be padding bytes; 0 otherwise. + pub is_final_block: T, + + // The address at which we will read the input block. + pub context: T, + pub segment: T, + pub virt: T, + + /// The timestamp at which inputs should be read from memory. + pub timestamp: T, + + /// The length of the original input, in bytes. + pub len: T, + + /// The number of input bytes that have already been absorbed prior to this block. + pub already_absorbed_bytes: T, + + /// If this row represents a final block row, the `i`th entry should be 1 if the final chunk of + /// input has length `i` (in other words if `len - already_absorbed == i`), otherwise 0. + /// + /// If this row represents a full input block, this should contain all 0s. + pub is_final_input_len: [T; KECCAK_RATE_BYTES], + + /// The initial rate part of the sponge, at the start of this step. + pub original_rate_u32s: [T; KECCAK_RATE_U32S], + + /// The capacity part of the sponge, encoded as 32-bit chunks, at the start of this step. + pub original_capacity_u32s: [T; KECCAK_CAPACITY_U32S], + + /// The block being absorbed, which may contain input bytes and/or padding bytes. + pub block_bytes: [T; KECCAK_RATE_BYTES], + + /// The rate part of the sponge, encoded as 32-bit chunks, after the current block is xor'd in, + /// but before the permutation is applied. + pub xored_rate_u32s: [T; KECCAK_RATE_U32S], + + /// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the + /// permutation is applied. + pub updated_state_u32s: [T; KECCAK_WIDTH_U32S], +} + +// `u8` is guaranteed to have a `size_of` of 1. +pub const NUM_KECCAK_SPONGE_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { + fn from(value: [T; NUM_KECCAK_SPONGE_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl From> for [T; NUM_KECCAK_SPONGE_COLUMNS] { + fn from(value: KeccakSpongeColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_KECCAK_SPONGE_COLUMNS] { + fn borrow(&self) -> &KeccakSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_KECCAK_SPONGE_COLUMNS] { + fn borrow_mut(&mut self) -> &mut KeccakSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl Borrow<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { + fn borrow(&self) -> &[T; NUM_KECCAK_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_KECCAK_SPONGE_COLUMNS]> for KeccakSpongeColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_KECCAK_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Default for KeccakSpongeColumnsView { + fn default() -> Self { + [T::default(); NUM_KECCAK_SPONGE_COLUMNS].into() + } +} + +const fn make_col_map() -> KeccakSpongeColumnsView { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_KECCAK_SPONGE_COLUMNS], KeccakSpongeColumnsView>(indices_arr) + } +} + +pub(crate) const KECCAK_SPONGE_COL_MAP: KeccakSpongeColumnsView = make_col_map(); diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs new file mode 100644 index 00000000..afde02c2 --- /dev/null +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -0,0 +1,468 @@ +use std::borrow::Borrow; +use std::iter::{once, repeat}; +use std::marker::PhantomData; +use std::mem::size_of; + +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2_util::ceil_div_usize; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::kernel::keccak_util::keccakf_u32s; +use crate::cross_table_lookup::Column; +use crate::keccak_sponge::columns::*; +use crate::memory::segments::Segment; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; +use crate::vars::StarkEvaluationTargets; +use crate::vars::StarkEvaluationVars; + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looked_data() -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; + let outputs = Column::singles(&cols.updated_state_u32s[..8]); + Column::singles([ + cols.context, + cols.segment, + cols.virt, + cols.timestamp, + cols.len, + ]) + .chain(outputs) + .collect() +} + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looking_keccak() -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; + Column::singles( + [ + cols.original_rate_u32s.as_slice(), + &cols.original_capacity_u32s, + &cols.updated_state_u32s, + ] + .concat(), + ) + .collect() +} + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { + let cols = KECCAK_SPONGE_COL_MAP; + + let mut res = vec![Column::constant(F::ONE)]; // is_read + + res.extend(Column::singles([cols.context, cols.segment])); + + // The address of the byte being read is `virt + already_absorbed_bytes + i`. + res.push(Column::linear_combination_with_constant( + [(cols.virt, F::ONE), (cols.already_absorbed_bytes, F::ONE)], + F::from_canonical_usize(i), + )); + + // The i'th input byte being read. + res.push(Column::single(cols.block_bytes[i])); + + // Since we're reading a single byte, the higher limbs must be zero. + res.extend((1..8).map(|_| Column::zero())); + + res.push(Column::single(cols.timestamp)); + + assert_eq!( + res.len(), + crate::memory::memory_stark::ctl_data::().len() + ); + res +} + +/// CTL for performing the `i`th logic CTL. Since we need to do 136 byte XORs, and the logic CTL can +/// XOR 32 bytes per CTL, there are 5 such CTLs. +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looking_logic(i: usize) -> Vec> { + const U32S_PER_CTL: usize = 8; + const U8S_PER_CTL: usize = 32; + + debug_assert!(i < ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL)); + let cols = KECCAK_SPONGE_COL_MAP; + + let mut res = vec![ + Column::zero(), // is_and + Column::zero(), // is_or + Column::one(), // is_xor + ]; + + // Input 0 contains some of the sponge's original rate chunks. If this is the last CTL, we won't + // need to use all of the CTL's inputs, so we will pass some zeros. + res.extend( + Column::singles(&cols.original_rate_u32s[i * U32S_PER_CTL..]) + .chain(repeat(Column::zero())) + .take(U32S_PER_CTL), + ); + + // Input 1 contains some of block's chunks. Again, for the last CTL it will include some zeros. + res.extend( + cols.block_bytes[i * U8S_PER_CTL..] + .chunks(size_of::()) + .map(|chunk| Column::le_bytes(chunk)) + .chain(repeat(Column::zero())) + .take(U8S_PER_CTL), + ); + + // The output contains the XOR'd rate part. + res.extend( + Column::singles(&cols.xored_rate_u32s[i * U32S_PER_CTL..]) + .chain(repeat(Column::zero())) + .take(U32S_PER_CTL), + ); + + res +} + +#[allow(unused)] // TODO: Should be used soon. +pub(crate) fn ctl_looked_filter() -> Column { + // The CPU table is only interested in our final-block rows, since those contain the final + // sponge output. + Column::single(KECCAK_SPONGE_COL_MAP.is_final_block) +} + +#[allow(unused)] // TODO: Should be used soon. +/// CTL filter for reading the `i`th byte of input from memory. +pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { + // We perform the `i`th read if either + // - this is a full input block, or + // - this is a final block of length `i` or greater + let cols = KECCAK_SPONGE_COL_MAP; + Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len[i..])) +} + +/// Information about a Keccak sponge operation needed for witness generation. +#[derive(Debug)] +pub(crate) struct KeccakSpongeOp { + // The address at which inputs are read. + pub(crate) context: usize, + pub(crate) segment: Segment, + pub(crate) virt: usize, + + /// The timestamp at which inputs are read. + pub(crate) timestamp: usize, + + /// The length of the input, in bytes. + pub(crate) len: usize, + + /// The input that was read. + pub(crate) input: Vec, +} + +#[derive(Copy, Clone, Default)] +pub(crate) struct KeccakSpongeStark { + f: PhantomData, +} + +impl, const D: usize> KeccakSpongeStark { + #[allow(unused)] // TODO: Should be used soon. + pub(crate) fn generate_trace( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec> { + let mut timing = TimingTree::new("generate trace", log::Level::Debug); + + // Generate the witness row-wise. + let trace_rows = timed!( + &mut timing, + "generate trace rows", + self.generate_trace_rows(operations, min_rows) + ); + + let trace_polys = timed!( + &mut timing, + "convert to PolynomialValues", + trace_rows_to_poly_values(trace_rows) + ); + + timing.print(); + trace_polys + } + + fn generate_trace_rows( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { + let num_rows = operations.len().max(min_rows).next_power_of_two(); + operations + .into_iter() + .flat_map(|op| self.generate_rows_for_op(op)) + .chain(repeat(self.generate_padding_row())) + .take(num_rows) + .collect() + } + + fn generate_rows_for_op(&self, op: KeccakSpongeOp) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { + let mut rows = vec![]; + + let mut sponge_state = [0u32; KECCAK_WIDTH_U32S]; + + let mut input_blocks = op.input.chunks_exact(KECCAK_RATE_BYTES); + let mut already_absorbed_bytes = 0; + for block in input_blocks.by_ref() { + let row = self.generate_full_input_row( + &op, + already_absorbed_bytes, + sponge_state, + block.try_into().unwrap(), + ); + + sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32); + + rows.push(row.into()); + already_absorbed_bytes += KECCAK_RATE_BYTES; + } + + rows.push( + self.generate_final_row( + &op, + already_absorbed_bytes, + sponge_state, + input_blocks.remainder(), + ) + .into(), + ); + + rows + } + + fn generate_full_input_row( + &self, + op: &KeccakSpongeOp, + already_absorbed_bytes: usize, + sponge_state: [u32; KECCAK_WIDTH_U32S], + block: [u8; KECCAK_RATE_BYTES], + ) -> KeccakSpongeColumnsView { + let mut row = KeccakSpongeColumnsView { + is_full_input_block: F::ONE, + ..Default::default() + }; + + row.block_bytes = block.map(F::from_canonical_u8); + + Self::generate_common_fields(&mut row, op, already_absorbed_bytes, sponge_state); + row + } + + fn generate_final_row( + &self, + op: &KeccakSpongeOp, + already_absorbed_bytes: usize, + sponge_state: [u32; KECCAK_WIDTH_U32S], + final_inputs: &[u8], + ) -> KeccakSpongeColumnsView { + assert_eq!(already_absorbed_bytes + final_inputs.len(), op.len); + + let mut row = KeccakSpongeColumnsView { + is_final_block: F::ONE, + ..Default::default() + }; + + for (block_byte, input_byte) in row.block_bytes.iter_mut().zip(final_inputs) { + *block_byte = F::from_canonical_u8(*input_byte); + } + + // pad10*1 rule + if final_inputs.len() == KECCAK_RATE_BYTES - 1 { + // Both 1s are placed in the same byte. + row.block_bytes[final_inputs.len()] = F::from_canonical_u8(0b10000001); + } else { + row.block_bytes[final_inputs.len()] = F::ONE; + row.block_bytes[KECCAK_RATE_BYTES - 1] = F::from_canonical_u8(0b10000000); + } + + row.is_final_input_len[final_inputs.len()] = F::ONE; + + Self::generate_common_fields(&mut row, op, already_absorbed_bytes, sponge_state); + row + } + + /// Generate fields that are common to both full-input-block rows and final-block rows. + /// Also updates the sponge state with a single absorption. + fn generate_common_fields( + row: &mut KeccakSpongeColumnsView, + op: &KeccakSpongeOp, + already_absorbed_bytes: usize, + mut sponge_state: [u32; KECCAK_WIDTH_U32S], + ) { + row.context = F::from_canonical_usize(op.context); + row.segment = F::from_canonical_usize(op.segment as usize); + row.virt = F::from_canonical_usize(op.virt); + row.timestamp = F::from_canonical_usize(op.timestamp); + row.len = F::from_canonical_usize(op.len); + row.already_absorbed_bytes = F::from_canonical_usize(already_absorbed_bytes); + + row.original_rate_u32s = sponge_state[..KECCAK_RATE_U32S] + .iter() + .map(|x| F::from_canonical_u32(*x)) + .collect_vec() + .try_into() + .unwrap(); + + row.original_capacity_u32s = sponge_state[KECCAK_RATE_U32S..] + .iter() + .map(|x| F::from_canonical_u32(*x)) + .collect_vec() + .try_into() + .unwrap(); + + let block_u32s = (0..KECCAK_RATE_U32S).map(|i| { + u32::from_le_bytes( + row.block_bytes[i * 4..(i + 1) * 4] + .iter() + .map(|x| x.to_canonical_u64() as u8) + .collect_vec() + .try_into() + .unwrap(), + ) + }); + + // xor in the block + for (state_i, block_i) in sponge_state.iter_mut().zip(block_u32s) { + *state_i ^= block_i; + } + let xored_rate_u32s: [u32; KECCAK_RATE_U32S] = sponge_state[..KECCAK_RATE_U32S] + .to_vec() + .try_into() + .unwrap(); + row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32); + + keccakf_u32s(&mut sponge_state); + row.updated_state_u32s = sponge_state.map(F::from_canonical_u32); + } + + fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] { + // The default instance has is_full_input_block = is_final_block = 0, + // indicating that it's a dummy/padding row. + KeccakSpongeColumnsView::default().into() + } +} + +impl, const D: usize> Stark for KeccakSpongeStark { + const COLUMNS: usize = NUM_KECCAK_SPONGE_COLUMNS; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + _yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField, + { + let _local_values: &KeccakSpongeColumnsView

= vars.local_values.borrow(); + + // TODO: Each flag (full-input block, final block or implied dummy flag) must be boolean. + // TODO: before_rate_bits, block_bits and is_final_input_len must contain booleans. + + // TODO: Sum of is_final_input_len should equal is_final_block (which will be 0 or 1). + + // TODO: If this is the first row, the original sponge state should be 0 and already_absorbed_bytes = 0. + // TODO: If this is a final block, the next row's original sponge state should be 0 and already_absorbed_bytes = 0. + + // TODO: If this is a full-input block, the next row's address, time and len must match. + // TODO: If this is a full-input block, the next row's "before" should match our "after" state. + // TODO: If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. + + // TODO: A dummy row is always followed by another dummy row, so the prover can't put dummy rows "in between" to avoid the above checks. + + // TODO: is_final_input_len implies `len - already_absorbed == i`. + } + + fn eval_ext_circuit( + &self, + _builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: StarkEvaluationTargets, + _yield_constr: &mut RecursiveConstraintConsumer, + ) { + let _local_values: &KeccakSpongeColumnsView> = + vars.local_values.borrow(); + + // TODO + } + + fn constraint_degree(&self) -> usize { + 3 + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Borrow; + + use anyhow::Result; + use itertools::Itertools; + use keccak_hash::keccak; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::PrimeField64; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::keccak_sponge::columns::KeccakSpongeColumnsView; + use crate::keccak_sponge::keccak_sponge_stark::{KeccakSpongeOp, KeccakSpongeStark}; + use crate::memory::segments::Segment; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + + #[test] + fn test_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakSpongeStark; + + let stark = S::default(); + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = KeccakSpongeStark; + + let stark = S::default(); + test_stark_circuit_constraints::(stark) + } + + #[test] + fn test_generation() -> Result<()> { + const D: usize = 2; + type F = GoldilocksField; + type S = KeccakSpongeStark; + + let input = vec![1, 2, 3]; + let expected_output = keccak(&input); + + let op = KeccakSpongeOp { + context: 0, + segment: Segment::Code, + virt: 0, + timestamp: 0, + len: input.len(), + input, + }; + let stark = S::default(); + let rows = stark.generate_rows_for_op(op); + assert_eq!(rows.len(), 1); + let last_row: &KeccakSpongeColumnsView = rows.last().unwrap().borrow(); + let output = last_row.updated_state_u32s[..8] + .iter() + .flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes()) + .collect_vec(); + + assert_eq!(output, expected_output.0); + Ok(()) + } +} diff --git a/evm/src/keccak_sponge/mod.rs b/evm/src/keccak_sponge/mod.rs new file mode 100644 index 00000000..92b7f0c1 --- /dev/null +++ b/evm/src/keccak_sponge/mod.rs @@ -0,0 +1,6 @@ +//! The Keccak sponge STARK is used to hash a variable amount of data which is read from memory. +//! It connects to the memory STARK to read input data, and to the Keccak-f STARK to evaluate the +//! permutation at each absorption step. + +pub mod columns; +pub mod keccak_sponge_stark; diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 47335db2..6f332b59 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::needless_range_loop)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] +#![feature(let_chains)] #![feature(generic_const_exprs)] pub mod all_stark; @@ -13,6 +14,8 @@ pub mod cross_table_lookup; pub mod generation; mod get_challenges; pub mod keccak; +pub mod keccak_memory; +pub mod keccak_sponge; pub mod logic; pub mod lookup; pub mod memory; diff --git a/evm/src/logic.rs b/evm/src/logic.rs index bde5d645..2fa9c810 100644 --- a/evm/src/logic.rs +++ b/evm/src/logic.rs @@ -7,6 +7,7 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use plonky2_util::ceil_div_usize; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; @@ -17,9 +18,9 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; // Total number of bits per input/output. const VAL_BITS: usize = 256; // Number of bits stored per field element. Ensure that this fits; it is not checked. -pub(crate) const PACKED_LIMB_BITS: usize = 16; +pub(crate) const PACKED_LIMB_BITS: usize = 32; // Number of field elements needed to store each input/output at the specified packing. -const PACKED_LEN: usize = (VAL_BITS + PACKED_LIMB_BITS - 1) / PACKED_LIMB_BITS; +const PACKED_LEN: usize = ceil_div_usize(VAL_BITS, PACKED_LIMB_BITS); pub(crate) mod columns { use std::cmp::min; @@ -140,11 +141,10 @@ impl LogicStark { impl, const D: usize> Stark for LogicStark { const COLUMNS: usize = columns::NUM_COLUMNS; - const PUBLIC_INPUTS: usize = 0; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -196,7 +196,7 @@ impl, const D: usize> Stark for LogicStark, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let lv = &vars.local_values; diff --git a/evm/src/lookup.rs b/evm/src/lookup.rs index 2c93143f..ae92e864 100644 --- a/evm/src/lookup.rs +++ b/evm/src/lookup.rs @@ -10,13 +10,8 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; -pub(crate) fn eval_lookups< - F: Field, - P: PackedField, - const COLS: usize, - const PUB_INPUTS: usize, ->( - vars: StarkEvaluationVars, +pub(crate) fn eval_lookups, const COLS: usize>( + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, col_permuted_input: usize, col_permuted_table: usize, @@ -42,10 +37,9 @@ pub(crate) fn eval_lookups_circuit< F: RichField + Extendable, const D: usize, const COLS: usize, - const PUB_INPUTS: usize, >( builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, col_permuted_input: usize, col_permuted_table: usize, diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index 7229a834..91cc8754 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -3,7 +3,9 @@ use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; // Columns for memory operations, ordered by (addr, timestamp). -pub(crate) const TIMESTAMP: usize = 0; +/// 1 if this is an actual memory operation, or 0 if it's a padding row. +pub(crate) const FILTER: usize = 0; +pub(crate) const TIMESTAMP: usize = FILTER + 1; pub(crate) const IS_READ: usize = TIMESTAMP + 1; pub(crate) const ADDR_CONTEXT: usize = IS_READ + 1; pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; @@ -25,15 +27,8 @@ pub(crate) const CONTEXT_FIRST_CHANGE: usize = VALUE_START + VALUE_LIMBS; pub(crate) const SEGMENT_FIRST_CHANGE: usize = CONTEXT_FIRST_CHANGE + 1; pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1; -// Flags to indicate if this operation came from the `i`th channel of the memory bus. -const IS_CHANNEL_START: usize = VIRTUAL_FIRST_CHANGE + 1; -pub(crate) const fn is_channel(channel: usize) -> usize { - debug_assert!(channel < NUM_CHANNELS); - IS_CHANNEL_START + channel -} - // We use a range check to enforce the ordering. -pub(crate) const RANGE_CHECK: usize = IS_CHANNEL_START + NUM_CHANNELS; +pub(crate) const RANGE_CHECK: usize = VIRTUAL_FIRST_CHANGE + NUM_CHANNELS; // The counter column (used for the range check) starts from 0 and increments. pub(crate) const COUNTER: usize = RANGE_CHECK + 1; // Helper columns for the permutation argument used to enforce the range check. diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 5a17ed20..1ec0c11c 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -16,18 +16,16 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cross_table_lookup::Column; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::memory::columns::{ - is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, - COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, + value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, + COUNTER_PERMUTED, FILTER, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, }; use crate::memory::segments::Segment; -use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; +use crate::memory::VALUE_LIMBS; use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; -pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; - pub fn ctl_data() -> Vec> { let mut res = Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); @@ -36,8 +34,8 @@ pub fn ctl_data() -> Vec> { res } -pub fn ctl_filter(channel: usize) -> Column { - Column::single(is_channel(channel)) +pub fn ctl_filter() -> Column { + Column::single(FILTER) } #[derive(Copy, Clone, Default)] @@ -47,8 +45,8 @@ pub struct MemoryStark { #[derive(Clone, Debug)] pub(crate) struct MemoryOp { - /// The channel this operation came from, or `None` if it's a dummy operation for padding. - pub channel_index: Option, + /// true if this is an actual memory operation, or false if it's a padding row. + pub filter: bool, pub timestamp: usize, pub is_read: bool, pub context: usize, @@ -64,9 +62,7 @@ impl MemoryOp { /// trace has been transposed into column-major form. fn to_row(&self) -> [F; NUM_COLUMNS] { let mut row = [F::ZERO; NUM_COLUMNS]; - if let Some(channel) = self.channel_index { - row[is_channel(channel)] = F::ONE; - } + row[FILTER] = F::from_bool(self.filter); row[TIMESTAMP] = F::from_canonical_usize(self.timestamp); row[IS_READ] = F::from_bool(self.is_read); row[ADDR_CONTEXT] = F::from_canonical_usize(self.context); @@ -178,12 +174,12 @@ impl, const D: usize> MemoryStark { // We essentially repeat the last operation until our operation list has the desired size, // with a few changes: - // - We change its channel to `None` to indicate that this is a dummy operation. + // - We change its filter to 0 to indicate that this is a dummy operation. // - We increment its timestamp in order to pass the ordering check. // - We make sure it's a read, sine dummy operations must be reads. for i in 0..to_pad { memory_ops.push(MemoryOp { - channel_index: None, + filter: false, timestamp: last_op.timestamp + i + 1, is_read: true, ..last_op @@ -220,11 +216,10 @@ impl, const D: usize> MemoryStark { impl, const D: usize> Stark for MemoryStark { const COLUMNS: usize = NUM_COLUMNS; - const PUBLIC_INPUTS: usize = NUM_PUBLIC_INPUTS; fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -245,21 +240,13 @@ impl, const D: usize> Stark for MemoryStark = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); - // Each `is_channel` value must be 0 or 1. - for c in 0..NUM_CHANNELS { - let is_channel = vars.local_values[is_channel(c)]; - yield_constr.constraint(is_channel * (is_channel - P::ONES)); - } + // The filter must be 0 or 1. + let filter = vars.local_values[FILTER]; + yield_constr.constraint(filter * (filter - P::ONES)); - // The sum of `is_channel` flags, `has_channel`, must also be 0 or 1. - let has_channel: P = (0..NUM_CHANNELS) - .map(|c| vars.local_values[is_channel(c)]) - .sum(); - yield_constr.constraint(has_channel * (has_channel - P::ONES)); - - // If this is a dummy row (with no channel), it must be a read. This means the prover can + // If this is a dummy row (filter is off), it must be a read. This means the prover can // insert reads which never appear in the CPU trace (which are harmless), but not writes. - let is_dummy = P::ONES - has_channel; + let is_dummy = P::ONES - filter; let is_write = P::ONES - vars.local_values[IS_READ]; yield_constr.constraint(is_dummy * is_write); @@ -312,7 +299,7 @@ impl, const D: usize> Stark for MemoryStark, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { let one = builder.one_extension(); @@ -330,22 +317,14 @@ impl, const D: usize> Stark for MemoryStark [Self; Self::COUNT] { [ @@ -44,7 +57,11 @@ impl Segment { Self::TxnFields, Self::TxnData, Self::RlpRaw, - Self::RipeMD, + Self::TrieData, + Self::StorageTrieAddresses, + Self::StorageTriePointers, + Self::StorageTrieCheckpointPointers, + Self::RipeMd, ] } @@ -62,7 +79,11 @@ impl Segment { Segment::TxnFields => "SEGMENT_NORMALIZED_TXN", Segment::TxnData => "SEGMENT_TXN_DATA", Segment::RlpRaw => "SEGMENT_RLP_RAW", - Segment::RipeMD => "SEGMENT_RIPEMD" + Segment::TrieData => "SEGMENT_TRIE_DATA", + Segment::StorageTrieAddresses => "SEGMENT_STORAGE_TRIE_ADDRS", + Segment::StorageTriePointers => "SEGMENT_STORAGE_TRIE_PTRS", + Segment::StorageTrieCheckpointPointers => "SEGMENT_STORAGE_TRIE_CHECKPOINT_PTRS", + Segment::RipeMd => "SEGMENT_RIPEMD" } } @@ -80,7 +101,11 @@ impl Segment { Segment::TxnFields => 256, Segment::TxnData => 256, Segment::RlpRaw => 8, - Segment::RipeMD => 8, + Segment::TrieData => 256, + Segment::StorageTrieAddresses => 160, + Segment::StorageTriePointers => 32, + Segment::StorageTrieCheckpointPointers => 32, + Segment::RipeMd => 8, } } } diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs index c21a06de..0bb8ab1d 100644 --- a/evm/src/permutation.rs +++ b/evm/src/permutation.rs @@ -298,7 +298,7 @@ where pub(crate) fn eval_permutation_checks( stark: &S, config: &StarkConfig, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, permutation_vars: PermutationCheckVars, consumer: &mut ConstraintConsumer

, ) where @@ -365,14 +365,13 @@ pub(crate) fn eval_permutation_checks_circuit( builder: &mut CircuitBuilder, stark: &S, config: &StarkConfig, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, permutation_data: PermutationCheckDataTarget, consumer: &mut RecursiveConstraintConsumer, ) where F: RichField + Extendable, S: Stark, [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { let PermutationCheckDataTarget { local_zs, diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 4f81308d..81614e67 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -1,10 +1,9 @@ +use ethereum_types::{Address, U256}; use itertools::Itertools; use maybe_rayon::*; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::fri::oracle::PolynomialBatch; -use plonky2::fri::proof::{ - CompressedFriProof, FriChallenges, FriChallengesTarget, FriProof, FriProofTarget, -}; +use plonky2::fri::proof::{FriChallenges, FriChallengesTarget, FriProof, FriProofTarget}; use plonky2::fri::structure::{ FriOpeningBatch, FriOpeningBatchTarget, FriOpenings, FriOpeningsTarget, }; @@ -13,42 +12,90 @@ use plonky2::hash::merkle_tree::MerkleCap; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::config::GenericConfig; +use serde::{Deserialize, Serialize}; +use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; use crate::permutation::GrandProductChallengeSet; #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { - pub stark_proofs: Vec>, + pub stark_proofs: [StarkProof; NUM_TABLES], + pub public_values: PublicValues, } impl, C: GenericConfig, const D: usize> AllProof { - pub fn degree_bits(&self, config: &StarkConfig) -> Vec { - self.stark_proofs - .iter() - .map(|proof| proof.proof.recover_degree_bits(config)) - .collect() + pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { + std::array::from_fn(|i| self.stark_proofs[i].recover_degree_bits(config)) } - pub fn nums_ctl_zs(&self) -> Vec { - self.stark_proofs - .iter() - .map(|proof| proof.proof.openings.ctl_zs_last.len()) - .collect() + pub fn nums_ctl_zs(&self) -> [usize; NUM_TABLES] { + std::array::from_fn(|i| self.stark_proofs[i].openings.ctl_zs_last.len()) } } pub(crate) struct AllProofChallenges, const D: usize> { - pub stark_challenges: Vec>, + pub stark_challenges: [StarkProofChallenges; NUM_TABLES], pub ctl_challenges: GrandProductChallengeSet, } pub struct AllProofTarget { - pub stark_proofs: Vec>, + pub stark_proofs: [StarkProofTarget; NUM_TABLES], + pub public_values: PublicValuesTarget, +} + +/// Memory values which are public. +#[derive(Debug, Clone, Default)] +pub struct PublicValues { + pub trie_roots_before: TrieRoots, + pub trie_roots_after: TrieRoots, + pub block_metadata: BlockMetadata, +} + +#[derive(Debug, Clone, Default)] +pub struct TrieRoots { + pub state_root: U256, + pub transactions_root: U256, + pub receipts_root: U256, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct BlockMetadata { + pub block_beneficiary: Address, + pub block_timestamp: U256, + pub block_number: U256, + pub block_difficulty: U256, + pub block_gaslimit: U256, + pub block_chain_id: U256, + pub block_base_fee: U256, +} + +/// Memory values which are public. +/// Note: All the larger integers are encoded with 32-bit limbs in little-endian order. +pub struct PublicValuesTarget { + pub trie_roots_before: TrieRootsTarget, + pub trie_roots_after: TrieRootsTarget, + pub block_metadata: BlockMetadataTarget, +} + +pub struct TrieRootsTarget { + pub state_root: [Target; 8], + pub transactions_root: [Target; 8], + pub receipts_root: [Target; 8], +} + +pub struct BlockMetadataTarget { + pub block_beneficiary: [Target; 5], + pub block_timestamp: Target, + pub block_number: Target, + pub block_difficulty: Target, + pub block_gaslimit: Target, + pub block_chain_id: Target, + pub block_base_fee: Target, } pub(crate) struct AllProofChallengesTarget { - pub stark_challenges: Vec>, + pub stark_challenges: [StarkProofChallengesTarget; NUM_TABLES], pub ctl_challenges: GrandProductChallengeSet, } @@ -98,44 +145,6 @@ impl StarkProofTarget { } } -#[derive(Debug, Clone)] -pub struct StarkProofWithPublicInputs< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, -> { - pub proof: StarkProof, - // TODO: Maybe make it generic over a `S: Stark` and replace with `[F; S::PUBLIC_INPUTS]`. - pub public_inputs: Vec, -} - -pub struct StarkProofWithPublicInputsTarget { - pub proof: StarkProofTarget, - pub public_inputs: Vec, -} - -pub struct CompressedStarkProof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, -> { - /// Merkle cap of LDEs of trace values. - pub trace_cap: MerkleCap, - /// Purported values of each polynomial at the challenge point. - pub openings: StarkOpeningSet, - /// A batch FRI argument for all openings. - pub opening_proof: CompressedFriProof, -} - -pub struct CompressedStarkProofWithPublicInputs< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, -> { - pub proof: CompressedStarkProof, - pub public_inputs: Vec, -} - pub(crate) struct StarkProofChallenges, const D: usize> { /// Randomness used in any permutation arguments. pub permutation_challenge_sets: Option>>, diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 8be39b6c..31e76a1c 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -17,29 +17,30 @@ use plonky2::util::timing::TimingTree; use plonky2::util::transpose; use plonky2_util::{log2_ceil, log2_strict}; -use crate::all_stark::{AllStark, Table}; +use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData}; +use crate::generation::{generate_traces, GenerationInputs}; use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckVars; use crate::permutation::{ compute_permutation_z_polys, get_n_grand_product_challenge_sets, GrandProductChallengeSet, }; -use crate::proof::{AllProof, StarkOpeningSet, StarkProof, StarkProofWithPublicInputs}; +use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; -/// Compute all STARK proofs. +/// Generate traces, then create all STARK proofs. pub fn prove( all_stark: &AllStark, config: &StarkConfig, - trace_poly_values: Vec>>, - public_inputs: Vec>, + inputs: GenerationInputs, timing: &mut TimingTree, ) -> Result> where @@ -47,18 +48,33 @@ where C: GenericConfig, [(); C::Hasher::HASH_SIZE]:, [(); CpuStark::::COLUMNS]:, - [(); CpuStark::::PUBLIC_INPUTS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakStark::::PUBLIC_INPUTS]:, + [(); KeccakMemoryStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, - [(); LogicStark::::PUBLIC_INPUTS]:, [(); MemoryStark::::COLUMNS]:, - [(); MemoryStark::::PUBLIC_INPUTS]:, { - let num_starks = Table::num_tables(); - debug_assert_eq!(num_starks, trace_poly_values.len()); - debug_assert_eq!(num_starks, public_inputs.len()); + let (traces, public_values) = generate_traces(all_stark, inputs, config); + prove_with_traces(all_stark, config, traces, public_values, timing) +} +/// Compute all STARK proofs. +pub(crate) fn prove_with_traces( + all_stark: &AllStark, + config: &StarkConfig, + trace_poly_values: [Vec>; NUM_TABLES], + public_values: PublicValues, + timing: &mut TimingTree, +) -> Result> +where + F: RichField + Extendable, + C: GenericConfig, + [(); C::Hasher::HASH_SIZE]:, + [(); CpuStark::::COLUMNS]:, + [(); KeccakStark::::COLUMNS]:, + [(); KeccakMemoryStark::::COLUMNS]:, + [(); LogicStark::::COLUMNS]:, + [(); MemoryStark::::COLUMNS]:, +{ let rate_bits = config.fri_config.rate_bits; let cap_height = config.fri_config.cap_height; @@ -104,10 +120,6 @@ where &trace_poly_values[Table::Cpu as usize], &trace_commitments[Table::Cpu as usize], &ctl_data_per_table[Table::Cpu as usize], - public_inputs[Table::Cpu as usize] - .clone() - .try_into() - .unwrap(), &mut challenger, timing, )?; @@ -117,10 +129,15 @@ where &trace_poly_values[Table::Keccak as usize], &trace_commitments[Table::Keccak as usize], &ctl_data_per_table[Table::Keccak as usize], - public_inputs[Table::Keccak as usize] - .clone() - .try_into() - .unwrap(), + &mut challenger, + timing, + )?; + let keccak_memory_proof = prove_single_table( + &all_stark.keccak_memory_stark, + config, + &trace_poly_values[Table::KeccakMemory as usize], + &trace_commitments[Table::KeccakMemory as usize], + &ctl_data_per_table[Table::KeccakMemory as usize], &mut challenger, timing, )?; @@ -130,10 +147,6 @@ where &trace_poly_values[Table::Logic as usize], &trace_commitments[Table::Logic as usize], &ctl_data_per_table[Table::Logic as usize], - public_inputs[Table::Logic as usize] - .clone() - .try_into() - .unwrap(), &mut challenger, timing, )?; @@ -143,18 +156,22 @@ where &trace_poly_values[Table::Memory as usize], &trace_commitments[Table::Memory as usize], &ctl_data_per_table[Table::Memory as usize], - public_inputs[Table::Memory as usize] - .clone() - .try_into() - .unwrap(), &mut challenger, timing, )?; - let stark_proofs = vec![cpu_proof, keccak_proof, logic_proof, memory_proof]; - debug_assert_eq!(stark_proofs.len(), num_starks); + let stark_proofs = [ + cpu_proof, + keccak_proof, + keccak_memory_proof, + logic_proof, + memory_proof, + ]; - Ok(AllProof { stark_proofs }) + Ok(AllProof { + stark_proofs, + public_values, + }) } /// Compute proof for a single STARK table. @@ -164,17 +181,15 @@ fn prove_single_table( trace_poly_values: &[PolynomialValues], trace_commitment: &PolynomialBatch, ctl_data: &CtlData, - public_inputs: [F; S::PUBLIC_INPUTS], challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result> +) -> Result> where F: RichField + Extendable, C: GenericConfig, S: Stark, [(); C::Hasher::HASH_SIZE]:, [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { let degree = trace_poly_values[0].len(); let degree_bits = log2_strict(degree); @@ -228,7 +243,6 @@ where &permutation_ctl_zs_commitment, permutation_challenges.as_ref(), ctl_data, - public_inputs, alphas.clone(), degree_bits, num_permutation_zs, @@ -241,7 +255,6 @@ where &permutation_ctl_zs_commitment, permutation_challenges.as_ref(), ctl_data, - public_inputs, alphas, degree_bits, num_permutation_zs, @@ -310,17 +323,13 @@ where timing, ) ); - let proof = StarkProof { + + Ok(StarkProof { trace_cap: trace_commitment.merkle_tree.cap.clone(), permutation_ctl_zs_cap, quotient_polys_cap, openings, opening_proof, - }; - - Ok(StarkProofWithPublicInputs { - proof, - public_inputs: public_inputs.to_vec(), }) } @@ -332,7 +341,6 @@ fn compute_quotient_polys<'a, F, P, C, S, const D: usize>( permutation_ctl_zs_commitment: &'a PolynomialBatch, permutation_challenges: Option<&'a Vec>>, ctl_data: &CtlData, - public_inputs: [F; S::PUBLIC_INPUTS], alphas: Vec, degree_bits: usize, num_permutation_zs: usize, @@ -344,7 +352,6 @@ where C: GenericConfig, S: Stark, [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { let degree = 1 << degree_bits; let rate_bits = config.fri_config.rate_bits; @@ -388,7 +395,7 @@ where let quotient_values = (0..size) .into_par_iter() .step_by(P::WIDTH) - .map(|i_start| { + .flat_map_iter(|i_start| { let i_next_start = (i_start + next_step) % size; let i_range = i_start..i_start + P::WIDTH; @@ -406,7 +413,6 @@ where let vars = StarkEvaluationVars { local_values: &get_trace_values_packed(i_start), next_values: &get_trace_values_packed(i_next_start), - public_inputs: &public_inputs, }; let permutation_check_vars = permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { @@ -422,17 +428,15 @@ where .zs_columns .iter() .enumerate() - .map( - |(i, (_, columns, filter_column))| CtlCheckVars:: { - local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step) - [num_permutation_zs + i], - next_z: permutation_ctl_zs_commitment - .get_lde_values_packed(i_next_start, step)[num_permutation_zs + i], - challenges: ctl_data.challenges.challenges[i % config.num_challenges], - columns, - filter_column, - }, - ) + .map(|(i, zs_columns)| CtlCheckVars:: { + local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step) + [num_permutation_zs + i], + next_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_next_start, step) + [num_permutation_zs + i], + challenges: zs_columns.challenge, + columns: &zs_columns.columns, + filter_column: &zs_columns.filter_column, + }) .collect::>(); eval_vanishing_poly::( stark, @@ -444,11 +448,18 @@ where ); let mut constraints_evals = consumer.accumulators(); // We divide the constraints evaluations by `Z_H(x)`. - let denominator_inv = z_h_on_coset.eval_inverse_packed(i_start); + let denominator_inv: P = z_h_on_coset.eval_inverse_packed(i_start); for eval in &mut constraints_evals { *eval *= denominator_inv; } - constraints_evals + + let num_challenges = alphas.len(); + + (0..P::WIDTH).into_iter().map(move |i| { + (0..num_challenges) + .map(|j| constraints_evals[j].as_slice()[i]) + .collect() + }) }) .collect::>(); @@ -467,7 +478,6 @@ fn check_constraints<'a, F, C, S, const D: usize>( permutation_ctl_zs_commitment: &'a PolynomialBatch, permutation_challenges: Option<&'a Vec>>, ctl_data: &CtlData, - public_inputs: [F; S::PUBLIC_INPUTS], alphas: Vec, degree_bits: usize, num_permutation_zs: usize, @@ -477,7 +487,6 @@ fn check_constraints<'a, F, C, S, const D: usize>( C: GenericConfig, S: Stark, [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { let degree = 1 << degree_bits; let rate_bits = 0; // Set this to higher value to check constraint degree. @@ -526,7 +535,6 @@ fn check_constraints<'a, F, C, S, const D: usize>( let vars = StarkEvaluationVars { local_values: trace_subgroup_evals[i].as_slice().try_into().unwrap(), next_values: trace_subgroup_evals[i_next].as_slice().try_into().unwrap(), - public_inputs: &public_inputs, }; let permutation_check_vars = permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { @@ -540,15 +548,13 @@ fn check_constraints<'a, F, C, S, const D: usize>( .zs_columns .iter() .enumerate() - .map( - |(iii, (_, columns, filter_column))| CtlCheckVars:: { - local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii], - next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii], - challenges: ctl_data.challenges.challenges[iii % config.num_challenges], - columns, - filter_column, - }, - ) + .map(|(iii, zs_columns)| CtlCheckVars:: { + local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii], + next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii], + challenges: zs_columns.challenge, + columns: &zs_columns.columns, + filter_column: &zs_columns.filter_column, + }) .collect::>(); eval_vanishing_poly::( stark, diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index b69a5519..000efce9 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -17,15 +17,17 @@ use crate::constraint_consumer::RecursiveConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CtlCheckVarsTarget}; use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckDataTarget; use crate::proof::{ - AllProof, AllProofChallengesTarget, AllProofTarget, StarkOpeningSetTarget, StarkProof, - StarkProofChallengesTarget, StarkProofTarget, StarkProofWithPublicInputs, - StarkProofWithPublicInputsTarget, + AllProof, AllProofChallengesTarget, AllProofTarget, BlockMetadata, BlockMetadataTarget, + PublicValues, PublicValuesTarget, StarkOpeningSetTarget, StarkProof, + StarkProofChallengesTarget, StarkProofTarget, TrieRoots, TrieRootsTarget, }; use crate::stark::Stark; +use crate::util::{h160_limbs, u256_limbs}; use crate::vanishing_poly::eval_vanishing_poly_circuit; use crate::vars::StarkEvaluationTargets; @@ -40,13 +42,10 @@ pub fn verify_proof_circuit< inner_config: &StarkConfig, ) where [(); CpuStark::::COLUMNS]:, - [(); CpuStark::::PUBLIC_INPUTS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakStark::::PUBLIC_INPUTS]:, + [(); KeccakMemoryStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, - [(); LogicStark::::PUBLIC_INPUTS]:, [(); MemoryStark::::COLUMNS]:, - [(); MemoryStark::::PUBLIC_INPUTS]:, C::Hasher: AlgebraicHasher, { let AllProofChallengesTarget { @@ -59,6 +58,7 @@ pub fn verify_proof_circuit< let AllStark { cpu_stark, keccak_stark, + keccak_memory_stark, logic_stark, memory_stark, cross_table_lookups, @@ -95,6 +95,18 @@ pub fn verify_proof_circuit< inner_config, ) ); + with_context!( + builder, + "verify Keccak memory proof", + verify_stark_proof_with_challenges_circuit::( + builder, + keccak_memory_stark, + &all_proof.stark_proofs[Table::KeccakMemory as usize], + &stark_challenges[Table::KeccakMemory as usize], + &ctl_vars_per_table[Table::KeccakMemory as usize], + inner_config, + ) + ); with_context!( builder, "verify logic proof", @@ -142,23 +154,17 @@ fn verify_stark_proof_with_challenges_circuit< >( builder: &mut CircuitBuilder, stark: S, - proof_with_pis: &StarkProofWithPublicInputsTarget, + proof: &StarkProofTarget, challenges: &StarkProofChallengesTarget, ctl_vars: &[CtlCheckVarsTarget], inner_config: &StarkConfig, ) where C::Hasher: AlgebraicHasher, [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { let zero = builder.zero(); let one = builder.one_extension(); - let StarkProofWithPublicInputsTarget { - proof, - public_inputs, - } = proof_with_pis; - assert_eq!(public_inputs.len(), S::PUBLIC_INPUTS); let StarkOpeningSetTarget { local_values, next_values, @@ -170,19 +176,13 @@ fn verify_stark_proof_with_challenges_circuit< let vars = StarkEvaluationTargets { local_values: &local_values.to_vec().try_into().unwrap(), next_values: &next_values.to_vec().try_into().unwrap(), - public_inputs: &public_inputs - .iter() - .map(|&t| builder.convert_to_ext(t)) - .collect::>() - .try_into() - .unwrap(), }; let degree_bits = proof.recover_degree_bits(inner_config); let zeta_pow_deg = builder.exp_power_of_2_extension(challenges.stark_zeta, degree_bits); let z_h_zeta = builder.sub_extension(zeta_pow_deg, one); - let (l_1, l_last) = - eval_l_1_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); + let (l_0, l_last) = + eval_l_0_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); let last = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_bits).inverse()); let z_last = builder.sub_extension(challenges.stark_zeta, last); @@ -191,7 +191,7 @@ fn verify_stark_proof_with_challenges_circuit< builder.zero_extension(), challenges.stark_alphas.clone(), z_last, - l_1, + l_0, l_last, ); @@ -254,7 +254,7 @@ fn verify_stark_proof_with_challenges_circuit< ); } -fn eval_l_1_and_l_last_circuit, const D: usize>( +fn eval_l_0_and_l_last_circuit, const D: usize>( builder: &mut CircuitBuilder, log_n: usize, x: ExtensionTarget, @@ -263,12 +263,12 @@ fn eval_l_1_and_l_last_circuit, const D: usize>( let n = builder.constant_extension(F::Extension::from_canonical_usize(1 << log_n)); let g = builder.constant_extension(F::Extension::primitive_root_of_unity(log_n)); let one = builder.one_extension(); - let l_1_deno = builder.mul_sub_extension(n, x, n); + let l_0_deno = builder.mul_sub_extension(n, x, n); let l_last_deno = builder.mul_sub_extension(g, x, one); let l_last_deno = builder.mul_extension(n, l_last_deno); ( - builder.div_extension(z_x, l_1_deno), + builder.div_extension(z_x, l_0_deno), builder.div_extension(z_x, l_last_deno), ) } @@ -280,85 +280,95 @@ pub fn add_virtual_all_proof, const D: usize>( degree_bits: &[usize], nums_ctl_zs: &[usize], ) -> AllProofTarget { - let stark_proofs = vec![ - { - let proof = add_virtual_stark_proof( - builder, - all_stark.cpu_stark, - config, - degree_bits[Table::Cpu as usize], - nums_ctl_zs[Table::Cpu as usize], - ); - let public_inputs = builder.add_virtual_targets(CpuStark::::PUBLIC_INPUTS); - StarkProofWithPublicInputsTarget { - proof, - public_inputs, - } - }, - { - let proof = add_virtual_stark_proof( - builder, - all_stark.keccak_stark, - config, - degree_bits[Table::Keccak as usize], - nums_ctl_zs[Table::Keccak as usize], - ); - let public_inputs = builder.add_virtual_targets(KeccakStark::::PUBLIC_INPUTS); - StarkProofWithPublicInputsTarget { - proof, - public_inputs, - } - }, - { - let proof = add_virtual_stark_proof( - builder, - all_stark.logic_stark, - config, - degree_bits[Table::Logic as usize], - nums_ctl_zs[Table::Logic as usize], - ); - let public_inputs = builder.add_virtual_targets(LogicStark::::PUBLIC_INPUTS); - StarkProofWithPublicInputsTarget { - proof, - public_inputs, - } - }, - { - let proof = add_virtual_stark_proof( - builder, - all_stark.memory_stark, - config, - degree_bits[Table::Memory as usize], - nums_ctl_zs[Table::Memory as usize], - ); - let public_inputs = builder.add_virtual_targets(KeccakStark::::PUBLIC_INPUTS); - StarkProofWithPublicInputsTarget { - proof, - public_inputs, - } - }, + let stark_proofs = [ + add_virtual_stark_proof( + builder, + all_stark.cpu_stark, + config, + degree_bits[Table::Cpu as usize], + nums_ctl_zs[Table::Cpu as usize], + ), + add_virtual_stark_proof( + builder, + all_stark.keccak_stark, + config, + degree_bits[Table::Keccak as usize], + nums_ctl_zs[Table::Keccak as usize], + ), + add_virtual_stark_proof( + builder, + all_stark.keccak_memory_stark, + config, + degree_bits[Table::KeccakMemory as usize], + nums_ctl_zs[Table::KeccakMemory as usize], + ), + add_virtual_stark_proof( + builder, + all_stark.logic_stark, + config, + degree_bits[Table::Logic as usize], + nums_ctl_zs[Table::Logic as usize], + ), + add_virtual_stark_proof( + builder, + all_stark.memory_stark, + config, + degree_bits[Table::Memory as usize], + nums_ctl_zs[Table::Memory as usize], + ), ]; - assert_eq!(stark_proofs.len(), Table::num_tables()); - AllProofTarget { stark_proofs } + let public_values = add_virtual_public_values(builder); + AllProofTarget { + stark_proofs, + public_values, + } } -pub fn add_virtual_stark_proof_with_pis< - F: RichField + Extendable, - S: Stark, - const D: usize, ->( +pub fn add_virtual_public_values, const D: usize>( builder: &mut CircuitBuilder, - stark: S, - config: &StarkConfig, - degree_bits: usize, - num_ctl_zs: usize, -) -> StarkProofWithPublicInputsTarget { - let proof = add_virtual_stark_proof::(builder, stark, config, degree_bits, num_ctl_zs); - let public_inputs = builder.add_virtual_targets(S::PUBLIC_INPUTS); - StarkProofWithPublicInputsTarget { - proof, - public_inputs, +) -> PublicValuesTarget { + let trie_roots_before = add_virtual_trie_roots(builder); + let trie_roots_after = add_virtual_trie_roots(builder); + let block_metadata = add_virtual_block_metadata(builder); + PublicValuesTarget { + trie_roots_before, + trie_roots_after, + block_metadata, + } +} + +pub fn add_virtual_trie_roots, const D: usize>( + builder: &mut CircuitBuilder, +) -> TrieRootsTarget { + let state_root = builder.add_virtual_target_arr(); + let transactions_root = builder.add_virtual_target_arr(); + let receipts_root = builder.add_virtual_target_arr(); + TrieRootsTarget { + state_root, + transactions_root, + receipts_root, + } +} + +pub fn add_virtual_block_metadata, const D: usize>( + builder: &mut CircuitBuilder, +) -> BlockMetadataTarget { + let block_beneficiary = builder.add_virtual_target_arr(); + let block_timestamp = builder.add_virtual_target(); + let block_number = builder.add_virtual_target(); + let block_difficulty = builder.add_virtual_target(); + let block_gaslimit = builder.add_virtual_target(); + let block_chain_id = builder.add_virtual_target(); + let block_base_fee = builder.add_virtual_target(); + BlockMetadataTarget { + block_beneficiary, + block_timestamp, + block_number, + block_difficulty, + block_gaslimit, + block_chain_id, + block_base_fee, } } @@ -424,35 +434,13 @@ pub fn set_all_proof_target, W, const D: usize>( .iter() .zip_eq(&all_proof.stark_proofs) { - set_stark_proof_with_pis_target(witness, pt, p, zero); + set_stark_proof_target(witness, pt, p, zero); } -} - -pub fn set_stark_proof_with_pis_target, W, const D: usize>( - witness: &mut W, - stark_proof_with_pis_target: &StarkProofWithPublicInputsTarget, - stark_proof_with_pis: &StarkProofWithPublicInputs, - zero: Target, -) where - F: RichField + Extendable, - C::Hasher: AlgebraicHasher, - W: Witness, -{ - let StarkProofWithPublicInputs { - proof, - public_inputs, - } = stark_proof_with_pis; - let StarkProofWithPublicInputsTarget { - proof: pt, - public_inputs: pi_targets, - } = stark_proof_with_pis_target; - - // Set public inputs. - for (&pi_t, &pi) in pi_targets.iter().zip_eq(public_inputs) { - witness.set_target(pi_t, pi); - } - - set_stark_proof_target(witness, pt, proof, zero); + set_public_value_targets( + witness, + &all_proof_target.public_values, + &all_proof.public_values, + ) } pub fn set_stark_proof_target, W, const D: usize>( @@ -480,3 +468,88 @@ pub fn set_stark_proof_target, W, const D: usize>( set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof); } + +pub fn set_public_value_targets( + witness: &mut W, + public_values_target: &PublicValuesTarget, + public_values: &PublicValues, +) where + F: RichField + Extendable, + W: Witness, +{ + set_trie_roots_target( + witness, + &public_values_target.trie_roots_before, + &public_values.trie_roots_before, + ); + set_trie_roots_target( + witness, + &public_values_target.trie_roots_after, + &public_values.trie_roots_after, + ); + set_block_metadata_target( + witness, + &public_values_target.block_metadata, + &public_values.block_metadata, + ); +} + +pub fn set_trie_roots_target( + witness: &mut W, + trie_roots_target: &TrieRootsTarget, + trie_roots: &TrieRoots, +) where + F: RichField + Extendable, + W: Witness, +{ + witness.set_target_arr( + trie_roots_target.state_root, + u256_limbs(trie_roots.state_root), + ); + witness.set_target_arr( + trie_roots_target.transactions_root, + u256_limbs(trie_roots.transactions_root), + ); + witness.set_target_arr( + trie_roots_target.receipts_root, + u256_limbs(trie_roots.receipts_root), + ); +} + +pub fn set_block_metadata_target( + witness: &mut W, + block_metadata_target: &BlockMetadataTarget, + block_metadata: &BlockMetadata, +) where + F: RichField + Extendable, + W: Witness, +{ + witness.set_target_arr( + block_metadata_target.block_beneficiary, + h160_limbs(block_metadata.block_beneficiary), + ); + witness.set_target( + block_metadata_target.block_timestamp, + F::from_canonical_u64(block_metadata.block_timestamp.as_u64()), + ); + witness.set_target( + block_metadata_target.block_number, + F::from_canonical_u64(block_metadata.block_number.as_u64()), + ); + witness.set_target( + block_metadata_target.block_difficulty, + F::from_canonical_u64(block_metadata.block_difficulty.as_u64()), + ); + witness.set_target( + block_metadata_target.block_gaslimit, + F::from_canonical_u64(block_metadata.block_gaslimit.as_u64()), + ); + witness.set_target( + block_metadata_target.block_chain_id, + F::from_canonical_u64(block_metadata.block_chain_id.as_u64()), + ); + witness.set_target( + block_metadata_target.block_base_fee, + F::from_canonical_u64(block_metadata.block_base_fee.as_u64()), + ); +} diff --git a/evm/src/stark.rs b/evm/src/stark.rs index 8935655b..a205547a 100644 --- a/evm/src/stark.rs +++ b/evm/src/stark.rs @@ -20,8 +20,6 @@ use crate::vars::StarkEvaluationVars; pub trait Stark, const D: usize>: Sync { /// The total number of columns in the trace. const COLUMNS: usize; - /// The number of public inputs. - const PUBLIC_INPUTS: usize; /// Evaluate constraints at a vector of points. /// @@ -31,7 +29,7 @@ pub trait Stark, const D: usize>: Sync { /// constraints over `F`. fn eval_packed_generic( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where FE: FieldExtension, @@ -40,7 +38,7 @@ pub trait Stark, const D: usize>: Sync { /// Evaluate constraints at a vector of points from the base field `F`. fn eval_packed_base>( &self, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) { self.eval_packed_generic(vars, yield_constr) @@ -49,12 +47,7 @@ pub trait Stark, const D: usize>: Sync { /// Evaluate constraints at a single point from the degree `D` extension field. fn eval_ext( &self, - vars: StarkEvaluationVars< - F::Extension, - F::Extension, - { Self::COLUMNS }, - { Self::PUBLIC_INPUTS }, - >, + vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer, ) { self.eval_packed_generic(vars, yield_constr) @@ -67,7 +60,7 @@ pub trait Stark, const D: usize>: Sync { fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ); diff --git a/evm/src/stark_testing.rs b/evm/src/stark_testing.rs index 809423d4..81b0f68f 100644 --- a/evm/src/stark_testing.rs +++ b/evm/src/stark_testing.rs @@ -26,13 +26,11 @@ pub fn test_stark_low_degree, S: Stark, const ) -> Result<()> where [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { let rate_bits = log2_ceil(stark.constraint_degree() + 1); let trace_ldes = random_low_degree_matrix::(S::COLUMNS, rate_bits); let size = trace_ldes.len(); - let public_inputs = F::rand_arr::<{ S::PUBLIC_INPUTS }>(); let lagrange_first = PolynomialValues::selector(WITNESS_SIZE, 0).lde(rate_bits); let lagrange_last = PolynomialValues::selector(WITNESS_SIZE, WITNESS_SIZE - 1).lde(rate_bits); @@ -49,7 +47,6 @@ where .clone() .try_into() .unwrap(), - public_inputs: &public_inputs, }; let mut consumer = ConstraintConsumer::::new( @@ -63,17 +60,20 @@ where }) .collect::>(); - let constraint_eval_degree = PolynomialValues::new(constraint_evals).degree(); - let maximum_degree = WITNESS_SIZE * stark.constraint_degree() - 1; + let constraint_poly_values = PolynomialValues::new(constraint_evals); + if !constraint_poly_values.is_zero() { + let constraint_eval_degree = constraint_poly_values.degree(); + let maximum_degree = WITNESS_SIZE * stark.constraint_degree() - 1; - ensure!( - constraint_eval_degree <= maximum_degree, - "Expected degrees at most {} * {} - 1 = {}, actual {:?}", - WITNESS_SIZE, - stark.constraint_degree(), - maximum_degree, - constraint_eval_degree - ); + ensure!( + constraint_eval_degree <= maximum_degree, + "Expected degrees at most {} * {} - 1 = {}, actual {:?}", + WITNESS_SIZE, + stark.constraint_degree(), + maximum_degree, + constraint_eval_degree + ); + } Ok(()) } @@ -89,14 +89,12 @@ pub fn test_stark_circuit_constraints< ) -> Result<()> where [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, [(); C::Hasher::HASH_SIZE]:, { // Compute native constraint evaluation on random values. let vars = StarkEvaluationVars { local_values: &F::Extension::rand_arr::<{ S::COLUMNS }>(), next_values: &F::Extension::rand_arr::<{ S::COLUMNS }>(), - public_inputs: &F::Extension::rand_arr::<{ S::PUBLIC_INPUTS }>(), }; let alphas = F::rand_vec(1); let z_last = F::Extension::rand(); @@ -124,8 +122,6 @@ where pw.set_extension_targets(&locals_t, vars.local_values); let nexts_t = builder.add_virtual_extension_targets(S::COLUMNS); pw.set_extension_targets(&nexts_t, vars.next_values); - let pis_t = builder.add_virtual_extension_targets(S::PUBLIC_INPUTS); - pw.set_extension_targets(&pis_t, vars.public_inputs); let alphas_t = builder.add_virtual_targets(1); pw.set_target(alphas_t[0], alphas[0]); let z_last_t = builder.add_virtual_extension_target(); @@ -135,10 +131,9 @@ where let lagrange_last_t = builder.add_virtual_extension_target(); pw.set_extension_target(lagrange_last_t, lagrange_last); - let vars = StarkEvaluationTargets:: { + let vars = StarkEvaluationTargets:: { local_values: &locals_t.try_into().unwrap(), next_values: &nexts_t.try_into().unwrap(), - public_inputs: &pis_t.try_into().unwrap(), }; let mut consumer = RecursiveConstraintConsumer::::new( builder.zero_extension(), diff --git a/evm/src/util.rs b/evm/src/util.rs index 5bc85f99..12aead46 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -1,3 +1,6 @@ +use std::mem::{size_of, transmute_copy, ManuallyDrop}; + +use ethereum_types::{H160, U256}; use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -40,3 +43,47 @@ pub fn trace_rows_to_poly_values( .map(|column| PolynomialValues::new(column)) .collect() } + +/// Returns the 32-bit little-endian limbs of a `U256`. +pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { + u256.0 + .into_iter() + .flat_map(|limb_64| { + let lo = limb_64 as u32; + let hi = (limb_64 >> 32) as u32; + [lo, hi] + }) + .map(F::from_canonical_u32) + .collect_vec() + .try_into() + .unwrap() +} + +/// Returns the 32-bit limbs of a `U160`. +pub(crate) fn h160_limbs(h160: H160) -> [F; 5] { + h160.0 + .chunks(4) + .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) + .map(F::from_canonical_u32) + .collect_vec() + .try_into() + .unwrap() +} + +pub(crate) const fn indices_arr() -> [usize; N] { + let mut indices_arr = [0; N]; + let mut i = 0; + while i < N { + indices_arr[i] = i; + i += 1; + } + indices_arr +} + +pub(crate) unsafe fn transmute_no_compile_time_size_checks(value: T) -> U { + debug_assert_eq!(size_of::(), size_of::()); + // Need ManuallyDrop so that `value` is not dropped by this function. + let value = ManuallyDrop::new(value); + // Copy the bit pattern. The original value is no longer safe to use. + transmute_copy(&value) +} diff --git a/evm/src/vanishing_poly.rs b/evm/src/vanishing_poly.rs index c0a6534b..e776fa5c 100644 --- a/evm/src/vanishing_poly.rs +++ b/evm/src/vanishing_poly.rs @@ -20,7 +20,7 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) fn eval_vanishing_poly( stark: &S, config: &StarkConfig, - vars: StarkEvaluationVars, + vars: StarkEvaluationVars, permutation_vars: Option>, ctl_vars: &[CtlCheckVars], consumer: &mut ConstraintConsumer

, @@ -48,7 +48,7 @@ pub(crate) fn eval_vanishing_poly_circuit( builder: &mut CircuitBuilder, stark: &S, config: &StarkConfig, - vars: StarkEvaluationTargets, + vars: StarkEvaluationTargets, permutation_data: Option>, ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, @@ -57,7 +57,6 @@ pub(crate) fn eval_vanishing_poly_circuit( C: GenericConfig, S: Stark, [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, { stark.eval_ext_circuit(builder, vars, consumer); if let Some(permutation_data) = permutation_data { diff --git a/evm/src/vars.rs b/evm/src/vars.rs index 682ac837..6c82675c 100644 --- a/evm/src/vars.rs +++ b/evm/src/vars.rs @@ -3,24 +3,17 @@ use plonky2::field::types::Field; use plonky2::iop::ext_target::ExtensionTarget; #[derive(Debug, Copy, Clone)] -pub struct StarkEvaluationVars<'a, F, P, const COLUMNS: usize, const PUBLIC_INPUTS: usize> +pub struct StarkEvaluationVars<'a, F, P, const COLUMNS: usize> where F: Field, P: PackedField, { pub local_values: &'a [P; COLUMNS], pub next_values: &'a [P; COLUMNS], - pub public_inputs: &'a [P::Scalar; PUBLIC_INPUTS], } #[derive(Debug, Copy, Clone)] -pub struct StarkEvaluationTargets< - 'a, - const D: usize, - const COLUMNS: usize, - const PUBLIC_INPUTS: usize, -> { +pub struct StarkEvaluationTargets<'a, const D: usize, const COLUMNS: usize> { pub local_values: &'a [ExtensionTarget; COLUMNS], pub next_values: &'a [ExtensionTarget; COLUMNS], - pub public_inputs: &'a [ExtensionTarget; PUBLIC_INPUTS], } diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 1b46dc90..3f5a5a88 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -12,11 +12,12 @@ use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars}; use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::PermutationCheckVars; use crate::proof::{ - AllProof, AllProofChallenges, StarkOpeningSet, StarkProofChallenges, StarkProofWithPublicInputs, + AllProof, AllProofChallenges, StarkOpeningSet, StarkProof, StarkProofChallenges, }; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; @@ -29,13 +30,10 @@ pub fn verify_proof, C: GenericConfig, co ) -> Result<()> where [(); CpuStark::::COLUMNS]:, - [(); CpuStark::::PUBLIC_INPUTS]:, [(); KeccakStark::::COLUMNS]:, - [(); KeccakStark::::PUBLIC_INPUTS]:, + [(); KeccakMemoryStark::::COLUMNS]:, [(); LogicStark::::COLUMNS]:, - [(); LogicStark::::PUBLIC_INPUTS]:, [(); MemoryStark::::COLUMNS]:, - [(); MemoryStark::::PUBLIC_INPUTS]:, [(); C::Hasher::HASH_SIZE]:, { let AllProofChallenges { @@ -48,6 +46,7 @@ where let AllStark { cpu_stark, keccak_stark, + keccak_memory_stark, logic_stark, memory_stark, cross_table_lookups, @@ -74,6 +73,13 @@ where &ctl_vars_per_table[Table::Keccak as usize], config, )?; + verify_stark_proof_with_challenges( + keccak_memory_stark, + &all_proof.stark_proofs[Table::KeccakMemory as usize], + &stark_challenges[Table::KeccakMemory as usize], + &ctl_vars_per_table[Table::KeccakMemory as usize], + config, + )?; verify_stark_proof_with_challenges( memory_stark, &all_proof.stark_proofs[Table::Memory as usize], @@ -104,21 +110,15 @@ pub(crate) fn verify_stark_proof_with_challenges< const D: usize, >( stark: S, - proof_with_pis: &StarkProofWithPublicInputs, + proof: &StarkProof, challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], config: &StarkConfig, ) -> Result<()> where [(); S::COLUMNS]:, - [(); S::PUBLIC_INPUTS]:, [(); C::Hasher::HASH_SIZE]:, { - let StarkProofWithPublicInputs { - proof, - public_inputs, - } = proof_with_pis; - ensure!(public_inputs.len() == S::PUBLIC_INPUTS); let StarkOpeningSet { local_values, next_values, @@ -130,17 +130,10 @@ where let vars = StarkEvaluationVars { local_values: &local_values.to_vec().try_into().unwrap(), next_values: &next_values.to_vec().try_into().unwrap(), - public_inputs: &public_inputs - .iter() - .copied() - .map(F::Extension::from_basefield) - .collect::>() - .try_into() - .unwrap(), }; let degree_bits = proof.recover_degree_bits(config); - let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta); + let (l_0, l_last) = eval_l_0_and_l_last(degree_bits, challenges.stark_zeta); let last = F::primitive_root_of_unity(degree_bits).inverse(); let z_last = challenges.stark_zeta - last.into(); let mut consumer = ConstraintConsumer::::new( @@ -150,7 +143,7 @@ where .map(|&alpha| F::Extension::from_basefield(alpha)) .collect::>(), z_last, - l_1, + l_0, l_last, ); let num_permutation_zs = stark.num_permutation_batches(config); @@ -211,10 +204,10 @@ where Ok(()) } -/// Evaluate the Lagrange polynomials `L_1` and `L_n` at a point `x`. -/// `L_1(x) = (x^n - 1)/(n * (x - 1))` -/// `L_n(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. -fn eval_l_1_and_l_last(log_n: usize, x: F) -> (F, F) { +/// Evaluate the Lagrange polynomials `L_0` and `L_(n-1)` at a point `x`. +/// `L_0(x) = (x^n - 1)/(n * (x - 1))` +/// `L_(n-1)(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. +fn eval_l_0_and_l_last(log_n: usize, x: F) -> (F, F) { let n = F::from_canonical_usize(1 << log_n); let g = F::primitive_root_of_unity(log_n); let z_x = x.exp_power_of_2(log_n) - F::ONE; @@ -229,10 +222,10 @@ mod tests { use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; - use crate::verifier::eval_l_1_and_l_last; + use crate::verifier::eval_l_0_and_l_last; #[test] - fn test_eval_l_1_and_l_last() { + fn test_eval_l_0_and_l_last() { type F = GoldilocksField; let log_n = 5; let n = 1 << log_n; @@ -241,7 +234,7 @@ mod tests { let expected_l_first_x = PolynomialValues::selector(n, 0).ifft().eval(x); let expected_l_last_x = PolynomialValues::selector(n, n - 1).ifft().eval(x); - let (l_first_x, l_last_x) = eval_l_1_and_l_last(log_n, x); + let (l_first_x, l_last_x) = eval_l_0_and_l_last(log_n, x); assert_eq!(l_first_x, expected_l_first_x); assert_eq!(l_last_x, expected_l_last_x); } diff --git a/evm/tests/transfer_to_new_addr.rs b/evm/tests/transfer_to_new_addr.rs index c30e7b7b..1cd79194 100644 --- a/evm/tests/transfer_to_new_addr.rs +++ b/evm/tests/transfer_to_new_addr.rs @@ -1,10 +1,12 @@ +use eth_trie_utils::partial_trie::PartialTrie; use hex_literal::hex; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; -use plonky2_evm::generation::{generate_traces, TransactionData}; +use plonky2_evm::generation::GenerationInputs; +use plonky2_evm::proof::BlockMetadata; use plonky2_evm::prover::prove; use plonky2_evm::verifier::verify_proof; @@ -17,30 +19,22 @@ type C = PoseidonGoldilocksConfig; #[ignore] // TODO: Won't work until txn parsing, storage, etc. are implemented. fn test_simple_transfer() -> anyhow::Result<()> { let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); - let txn = TransactionData { - signed_txn: hex!("f85f050a82520894000000000000000000000000000000000000000064801ca0fa56df5d988638fad8798e5ef75a1e1125dc7fb55d2ac4bce25776a63f0c2967a02cb47a5579eb5f83a1cabe4662501c0059f1b58e60ef839a1b0da67af6b9fb38").to_vec(), - trie_proofs: vec![ - vec![ - hex!("f874a1202f93d0dfb1562c03c825a33eec4438e468c17fff649ae844c004065985ae2945b850f84e058a152d02c7e14af6800000a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").to_vec(), - ], - vec![ - hex!("f8518080a0d36b8b6b60021940d5553689fb33e5d45e649dd8f4f211d26566238a83169da58080a0c62aa627943b70321f89a8b2fea274ecd47116e62042077dcdc0bdca7c1f66738080808080808080808080").to_vec(), - hex!("f873a03f93d0dfb1562c03c825a33eec4438e468c17fff649ae844c004065985ae2945b850f84e068a152d02c7e14af67ccb4ca056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a0c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").to_vec(), - ], - ] + let block_metadata = BlockMetadata::default(); + + let txn = hex!("f85f050a82520894000000000000000000000000000000000000000064801ca0fa56df5d988638fad8798e5ef75a1e1125dc7fb55d2ac4bce25776a63f0c2967a02cb47a5579eb5f83a1cabe4662501c0059f1b58e60ef839a1b0da67af6b9fb38"); + + let inputs = GenerationInputs { + signed_txns: vec![txn.to_vec()], + state_trie: PartialTrie::Empty, + transactions_trie: PartialTrie::Empty, + receipts_trie: PartialTrie::Empty, + storage_tries: vec![], + block_metadata, }; - let traces = generate_traces(&all_stark, &[txn]); - - let config = StarkConfig::standard_fast_config(); - let proof = prove::( - &all_stark, - &config, - traces, - vec![vec![]; 4], - &mut TimingTree::default(), - )?; + let proof = prove::(&all_stark, &config, inputs, &mut TimingTree::default())?; verify_proof(all_stark, proof, &config) } diff --git a/field/src/extension/mod.rs b/field/src/extension/mod.rs index f54d669c..ed596764 100644 --- a/field/src/extension/mod.rs +++ b/field/src/extension/mod.rs @@ -22,8 +22,8 @@ pub trait OEF: FieldExtension { } impl OEF<1> for F { - const W: Self::BaseField = F::ZERO; - const DTH_ROOT: Self::BaseField = F::ZERO; + const W: Self::BaseField = F::ONE; + const DTH_ROOT: Self::BaseField = F::ONE; } pub trait Frobenius: OEF { @@ -80,8 +80,8 @@ pub trait Extendable: Field + Sized { impl + FieldExtension<1, BaseField = F>> Extendable<1> for F { type Extension = F; - const W: Self = F::ZERO; - const DTH_ROOT: Self = F::ZERO; + const W: Self = F::ONE; + const DTH_ROOT: Self = F::ONE; const EXT_MULTIPLICATIVE_GROUP_GENERATOR: [Self; 1] = [F::MULTIPLICATIVE_GROUP_GENERATOR]; const EXT_POWER_OF_TWO_GENERATOR: [Self; 1] = [F::POWER_OF_TWO_GENERATOR]; } diff --git a/field/src/extension/quadratic.rs b/field/src/extension/quadratic.rs index d68df42e..278abba9 100644 --- a/field/src/extension/quadratic.rs +++ b/field/src/extension/quadratic.rs @@ -3,7 +3,6 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use num::bigint::BigUint; -use num::Integer; use serde::{Deserialize, Serialize}; use crate::extension::{Extendable, FieldExtension, Frobenius, OEF}; @@ -89,9 +88,8 @@ impl> Field for QuadraticExtension { )) } - fn from_biguint(n: BigUint) -> Self { - let (high, low) = n.div_rem(&F::order()); - Self([F::from_biguint(low), F::from_biguint(high)]) + fn from_noncanonical_biguint(n: BigUint) -> Self { + F::from_noncanonical_biguint(n).into() } fn from_canonical_u64(n: u64) -> Self { diff --git a/field/src/extension/quartic.rs b/field/src/extension/quartic.rs index fc0cbcf8..6df39903 100644 --- a/field/src/extension/quartic.rs +++ b/field/src/extension/quartic.rs @@ -4,7 +4,6 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num::bigint::BigUint; use num::traits::Pow; -use num::Integer; use serde::{Deserialize, Serialize}; use crate::extension::{Extendable, FieldExtension, Frobenius, OEF}; @@ -94,16 +93,8 @@ impl> Field for QuarticExtension { )) } - fn from_biguint(n: BigUint) -> Self { - let (rest, first) = n.div_rem(&F::order()); - let (rest, second) = rest.div_rem(&F::order()); - let (rest, third) = rest.div_rem(&F::order()); - Self([ - F::from_biguint(first), - F::from_biguint(second), - F::from_biguint(third), - F::from_biguint(rest), - ]) + fn from_noncanonical_biguint(n: BigUint) -> Self { + F::from_noncanonical_biguint(n).into() } fn from_canonical_u64(n: u64) -> Self { diff --git a/field/src/extension/quintic.rs b/field/src/extension/quintic.rs index 564674c3..6680ebc7 100644 --- a/field/src/extension/quintic.rs +++ b/field/src/extension/quintic.rs @@ -99,8 +99,8 @@ impl> Field for QuinticExtension { Some(FieldExtension::<5>::scalar_mul(&f, g.inverse())) } - fn from_biguint(n: BigUint) -> Self { - Self([F::from_biguint(n), F::ZERO, F::ZERO, F::ZERO, F::ZERO]) + fn from_noncanonical_biguint(n: BigUint) -> Self { + F::from_noncanonical_biguint(n).into() } fn from_canonical_u64(n: u64) -> Self { diff --git a/field/src/fft.rs b/field/src/fft.rs index 7e9deae5..6ede8af6 100644 --- a/field/src/fft.rs +++ b/field/src/fft.rs @@ -61,7 +61,7 @@ pub fn fft_with_options( ) -> PolynomialValues { let PolynomialCoeffs { coeffs: mut buffer } = poly; fft_dispatch(&mut buffer, zero_factor, root_table); - PolynomialValues { values: buffer } + PolynomialValues::new(buffer) } #[inline] diff --git a/field/src/goldilocks_extensions.rs b/field/src/goldilocks_extensions.rs index e684c7cb..2175494f 100644 --- a/field/src/goldilocks_extensions.rs +++ b/field/src/goldilocks_extensions.rs @@ -112,14 +112,14 @@ impl Mul for QuinticExtension { * result coefficient is necessary. */ -/// Return a, b such that a + b*2^128 = 3*x with a < 2^128 and b < 2^32. +/// Return `a`, `b` such that `a + b*2^128 = 3*(x + y*2^128)` with `a < 2^128` and `b < 2^32`. #[inline(always)] fn u160_times_3(x: u128, y: u32) -> (u128, u32) { let (s, cy) = x.overflowing_add(x << 1); (s, 3 * y + (x >> 127) as u32 + cy as u32) } -/// Return a, b such that a + b*2^128 = 7*x with a < 2^128 and b < 2^32. +/// Return `a`, `b` such that `a + b*2^128 = 7*(x + y*2^128)` with `a < 2^128` and `b < 2^32`. #[inline(always)] fn u160_times_7(x: u128, y: u32) -> (u128, u32) { let (d, br) = (x << 3).overflowing_sub(x); diff --git a/field/src/goldilocks_field.rs b/field/src/goldilocks_field.rs index c5075b5d..c1bb60b0 100644 --- a/field/src/goldilocks_field.rs +++ b/field/src/goldilocks_field.rs @@ -90,7 +90,7 @@ impl Field for GoldilocksField { try_inverse_u64(self) } - fn from_biguint(n: BigUint) -> Self { + fn from_noncanonical_biguint(n: BigUint) -> Self { Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) } diff --git a/field/src/interpolation.rs b/field/src/interpolation.rs index d0675715..8f64e9d7 100644 --- a/field/src/interpolation.rs +++ b/field/src/interpolation.rs @@ -19,9 +19,7 @@ pub fn interpolant(points: &[(F, F)]) -> PolynomialCoeffs { .map(|x| interpolate(points, x, &barycentric_weights)) .collect(); - let mut coeffs = ifft(PolynomialValues { - values: subgroup_evals, - }); + let mut coeffs = ifft(PolynomialValues::new(subgroup_evals)); coeffs.trim(); coeffs } diff --git a/field/src/polynomial/mod.rs b/field/src/polynomial/mod.rs index 82c4a41c..09ed69c7 100644 --- a/field/src/polynomial/mod.rs +++ b/field/src/polynomial/mod.rs @@ -24,6 +24,8 @@ pub struct PolynomialValues { impl PolynomialValues { pub fn new(values: Vec) -> Self { + // Check that a subgroup exists of this size, which should be a power of two. + debug_assert!(log2_strict(values.len()) <= F::TWO_ADICITY); PolynomialValues { values } } @@ -35,6 +37,10 @@ impl PolynomialValues { Self::constant(F::ZERO, len) } + pub fn is_zero(&self) -> bool { + self.values.iter().all(|x| x.is_zero()) + } + /// Returns the polynomial whole value is one at the given index, and zero elsewhere. pub fn selector(len: usize, index: usize) -> Self { let mut result = Self::zero(len); @@ -116,6 +122,7 @@ impl PolynomialCoeffs { PolynomialCoeffs { coeffs } } + /// The empty list of coefficients, which is the smallest encoding of the zero polynomial. pub fn empty() -> Self { Self::new(Vec::new()) } diff --git a/field/src/secp256k1_base.rs b/field/src/secp256k1_base.rs index 9e39b982..504d63d7 100644 --- a/field/src/secp256k1_base.rs +++ b/field/src/secp256k1_base.rs @@ -106,7 +106,7 @@ impl Field for Secp256K1Base { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } - fn from_biguint(val: BigUint) -> Self { + fn from_noncanonical_biguint(val: BigUint) -> Self { Self( val.to_u64_digits() .into_iter() @@ -135,7 +135,7 @@ impl Field for Secp256K1Base { #[cfg(feature = "rand")] fn rand_from_rng(rng: &mut R) -> Self { use num::bigint::RandBigInt; - Self::from_biguint(rng.gen_biguint_below(&Self::order())) + Self::from_noncanonical_biguint(rng.gen_biguint_below(&Self::order())) } } @@ -157,7 +157,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_canonical_biguint()) + Self::from_noncanonical_biguint(Self::order() - self.to_canonical_biguint()) } } } @@ -171,7 +171,7 @@ impl Add for Secp256K1Base { if result >= Self::order() { result -= Self::order(); } - Self::from_biguint(result) + Self::from_noncanonical_biguint(result) } } @@ -210,7 +210,7 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint( + Self::from_noncanonical_biguint( (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), ) } diff --git a/field/src/secp256k1_scalar.rs b/field/src/secp256k1_scalar.rs index eea67fab..e70b154d 100644 --- a/field/src/secp256k1_scalar.rs +++ b/field/src/secp256k1_scalar.rs @@ -115,7 +115,7 @@ impl Field for Secp256K1Scalar { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } - fn from_biguint(val: BigUint) -> Self { + fn from_noncanonical_biguint(val: BigUint) -> Self { Self( val.to_u64_digits() .into_iter() @@ -144,7 +144,7 @@ impl Field for Secp256K1Scalar { #[cfg(feature = "rand")] fn rand_from_rng(rng: &mut R) -> Self { use num::bigint::RandBigInt; - Self::from_biguint(rng.gen_biguint_below(&Self::order())) + Self::from_noncanonical_biguint(rng.gen_biguint_below(&Self::order())) } } @@ -166,7 +166,7 @@ impl Neg for Secp256K1Scalar { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_canonical_biguint()) + Self::from_noncanonical_biguint(Self::order() - self.to_canonical_biguint()) } } } @@ -180,7 +180,7 @@ impl Add for Secp256K1Scalar { if result >= Self::order() { result -= Self::order(); } - Self::from_biguint(result) + Self::from_noncanonical_biguint(result) } } @@ -219,7 +219,7 @@ impl Mul for Secp256K1Scalar { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint( + Self::from_noncanonical_biguint( (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), ) } diff --git a/field/src/types.rs b/field/src/types.rs index 87fd8dd4..ac94bcfa 100644 --- a/field/src/types.rs +++ b/field/src/types.rs @@ -270,9 +270,8 @@ pub trait Field: subgroup.into_iter().map(|x| x * shift).collect() } - // TODO: The current behavior for composite fields doesn't seem natural or useful. - // Rename to `from_noncanonical_biguint` and have it return `n % Self::characteristic()`. - fn from_biguint(n: BigUint) -> Self; + /// Returns `n % Self::characteristic()`. + fn from_noncanonical_biguint(n: BigUint) -> Self; /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. // TODO: Should probably be unsafe. diff --git a/field/src/zero_poly_coset.rs b/field/src/zero_poly_coset.rs index 18cc3238..8d63bc69 100644 --- a/field/src/zero_poly_coset.rs +++ b/field/src/zero_poly_coset.rs @@ -51,8 +51,8 @@ impl ZeroPolyOnCoset { packed } - /// Returns `L_1(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. - pub fn eval_l1(&self, i: usize, x: F) -> F { + /// Returns `L_0(x) = Z_H(x)/(n * (x - 1))` with `x = w^i`. + pub fn eval_l_0(&self, i: usize, x: F) -> F { // Could also precompute the inverses using Montgomery. self.eval(i) * (self.n * (x - F::ONE)).inverse() } diff --git a/maybe_rayon/src/lib.rs b/maybe_rayon/src/lib.rs index 1a9bd823..d24ba2e5 100644 --- a/maybe_rayon/src/lib.rs +++ b/maybe_rayon/src/lib.rs @@ -1,6 +1,6 @@ #[cfg(not(feature = "parallel"))] use std::{ - iter::{IntoIterator, Iterator}, + iter::{FlatMap, IntoIterator, Iterator}, slice::{Chunks, ChunksExact, ChunksExactMut, ChunksMut}, }; @@ -223,13 +223,21 @@ impl MaybeParChunksMut for [T] { } } +#[cfg(not(feature = "parallel"))] pub trait ParallelIteratorMock { type Item; fn find_any

(self, predicate: P) -> Option where P: Fn(&Self::Item) -> bool + Sync + Send; + + fn flat_map_iter(self, map_op: F) -> FlatMap + where + Self: Sized, + U: IntoIterator, + F: Fn(Self::Item) -> U; } +#[cfg(not(feature = "parallel"))] impl ParallelIteratorMock for T { type Item = T::Item; @@ -239,6 +247,15 @@ impl ParallelIteratorMock for T { { self.find(predicate) } + + fn flat_map_iter(self, map_op: F) -> FlatMap + where + Self: Sized, + U: IntoIterator, + F: Fn(Self::Item) -> U, + { + self.flat_map(map_op) + } } #[cfg(feature = "parallel")] diff --git a/plonky2/plonky2.pdf b/plonky2/plonky2.pdf index 349b22a6..8f0f9ece 100644 Binary files a/plonky2/plonky2.pdf and b/plonky2/plonky2.pdf differ diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 1f5b648f..75f8847a 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -180,7 +180,15 @@ impl, C: GenericConfig, const D: usize> // Final low-degree polynomial that goes into FRI. let mut final_poly = PolynomialCoeffs::empty(); + // Each batch `i` consists of an opening point `z_i` and polynomials `{f_ij}_j` to be opened at that point. + // For each batch, we compute the composition polynomial `F_i = sum alpha^j f_ij`, + // where `alpha` is a random challenge in the extension field. + // The final polynomial is then computed as `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i)` + // where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum. + // There are usually two batches for the openings at `zeta` and `g * zeta`. + // The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`. for FriBatchInfo { point, polynomials } in &instance.batches { + // Collect the coefficients of all the polynomials in `polynomials`. let polys_coeff = polynomials.iter().map(|fri_poly| { &oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index] }); diff --git a/plonky2/src/fri/prover.rs b/plonky2/src/fri/prover.rs index 39e25869..71efe98a 100644 --- a/plonky2/src/fri/prover.rs +++ b/plonky2/src/fri/prover.rs @@ -149,15 +149,12 @@ fn fri_prover_query_rounds< n: usize, fri_params: &FriParams, ) -> Vec> { - (0..fri_params.config.num_query_rounds) - .map(|_| { - fri_prover_query_round::( - initial_merkle_trees, - trees, - challenger, - n, - fri_params, - ) + challenger + .get_n_challenges(fri_params.config.num_query_rounds) + .into_par_iter() + .map(|rand| { + let x_index = rand.to_canonical_u64() as usize % n; + fri_prover_query_round::(initial_merkle_trees, trees, x_index, fri_params) }) .collect() } @@ -169,13 +166,10 @@ fn fri_prover_query_round< >( initial_merkle_trees: &[&MerkleTree], trees: &[MerkleTree], - challenger: &mut Challenger, - n: usize, + mut x_index: usize, fri_params: &FriParams, ) -> FriQueryRound { let mut query_steps = Vec::new(); - let x = challenger.get_challenge(); - let mut x_index = x.to_canonical_u64() as usize % n; let initial_proof = initial_merkle_trees .iter() .map(|t| (t.get(x_index).to_vec(), t.prove(x_index))) diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index 1a3739b4..ac7e3a87 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -8,9 +8,9 @@ use crate::fri::proof::{ }; use crate::fri::structure::{FriBatchInfoTarget, FriInstanceInfoTarget, FriOpeningsTarget}; use crate::fri::{FriConfig, FriParams}; -use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; -use crate::gates::interpolation::HighDegreeInterpolationGate; +use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; +use crate::gates::interpolation::InterpolationGate; use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::MerkleCapTarget; diff --git a/plonky2/src/fri/structure.rs b/plonky2/src/fri/structure.rs index 1a37a1b2..d5c2c81c 100644 --- a/plonky2/src/fri/structure.rs +++ b/plonky2/src/fri/structure.rs @@ -42,7 +42,7 @@ pub struct FriBatchInfoTarget { #[derive(Copy, Clone, Debug)] pub struct FriPolynomialInfo { - /// Index into `FriInstanceInfoTarget`'s `oracles` list. + /// Index into `FriInstanceInfo`'s `oracles` list. pub oracle_index: usize, /// Index of the polynomial within the oracle. pub polynomial_index: usize, diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index 97dedf28..23caeac1 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -505,7 +505,7 @@ impl, const D: usize> SimpleGenerator { fn dependencies(&self) -> Vec { let mut deps = self.numerator.to_target_array().to_vec(); - deps.extend(&self.denominator.to_target_array()); + deps.extend(self.denominator.to_target_array()); deps } diff --git a/plonky2/src/gadgets/interpolation.rs b/plonky2/src/gadgets/interpolation.rs deleted file mode 100644 index b22f3b59..00000000 --- a/plonky2/src/gadgets/interpolation.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::ops::Range; - -use plonky2_field::extension::Extendable; - -use crate::gates::gate::Gate; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::target::Target; -use crate::plonk::circuit_builder::CircuitBuilder; - -/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup -/// with the given size, and whose values are extension field elements, given by input wires. -/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. -pub(crate) trait InterpolationGate, const D: usize>: - Gate + Copy -{ - fn new(subgroup_bits: usize) -> Self; - - fn num_points(&self) -> usize; - - /// Wire index of the coset shift. - fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - - fn end_coeffs(&self) -> usize { - self.start_coeffs() + D * self.num_points() - } -} - -impl, const D: usize> CircuitBuilder { - /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the - /// given size, and whose values are given. Returns the evaluation of the interpolant at - /// `evaluation_point`. - pub(crate) fn interpolate_coset>( - &mut self, - subgroup_bits: usize, - coset_shift: Target, - values: &[ExtensionTarget], - evaluation_point: ExtensionTarget, - ) -> ExtensionTarget { - let gate = G::new(subgroup_bits); - let row = self.add_gate(gate, vec![]); - self.connect(coset_shift, Target::wire(row, gate.wire_shift())); - for (i, &v) in values.iter().enumerate() { - self.connect_extension(v, ExtensionTarget::from_range(row, gate.wires_value(i))); - } - self.connect_extension( - evaluation_point, - ExtensionTarget::from_range(row, gate.wires_evaluation_point()), - ); - - ExtensionTarget::from_range(row, gate.wires_evaluation_value()) - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::extension::FieldExtension; - use plonky2_field::interpolation::interpolant; - use plonky2_field::types::Field; - - use crate::gates::interpolation::HighDegreeInterpolationGate; - use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_interpolate() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let subgroup_bits = 2; - let len = 1 << subgroup_bits; - let coset_shift = F::rand(); - let g = F::primitive_root_of_unity(subgroup_bits); - let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); - let values = FF::rand_vec(len); - - let homogeneous_points = points - .iter() - .zip(values.iter()) - .map(|(&a, &b)| (>::from_basefield(a), b)) - .collect::>(); - - let true_interpolant = interpolant(&homogeneous_points); - - let z = FF::rand(); - let true_eval = true_interpolant.eval(z); - - let coset_shift_target = builder.constant(coset_shift); - - let value_targets = values - .iter() - .map(|&v| (builder.constant_extension(v))) - .collect::>(); - - let zt = builder.constant_extension(z); - - let eval_hd = builder.interpolate_coset::>( - subgroup_bits, - coset_shift_target, - &value_targets, - zt, - ); - let eval_ld = builder.interpolate_coset::>( - subgroup_bits, - coset_shift_target, - &value_targets, - zt, - ); - let true_eval_target = builder.constant_extension(true_eval); - builder.connect_extension(eval_hd, true_eval_target); - builder.connect_extension(eval_ld, true_eval_target); - - let data = builder.build::(); - let proof = data.prove(pw)?; - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index 6309eb3d..a3e50c4e 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -1,7 +1,6 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod hash; -pub mod interpolation; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/plonky2/src/gates/high_degree_interpolation.rs b/plonky2/src/gates/high_degree_interpolation.rs new file mode 100644 index 00000000..bcdf2276 --- /dev/null +++ b/plonky2/src/gates/high_degree_interpolation.rs @@ -0,0 +1,363 @@ +use std::marker::PhantomData; +use std::ops::Range; + +use plonky2_field::extension::algebra::PolynomialCoeffsAlgebra; +use plonky2_field::extension::{Extendable, FieldExtension}; +use plonky2_field::interpolation::interpolant; +use plonky2_field::polynomial::PolynomialCoeffs; + +use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; +use crate::gates::gate::Gate; +use crate::gates::interpolation::InterpolationGate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// One of the instantiations of `InterpolationGate`: allows constraints of variable +/// degree, up to `1<, const D: usize> { + pub subgroup_bits: usize, + _phantom: PhantomData, +} + +impl, const D: usize> InterpolationGate + for HighDegreeInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self { + subgroup_bits, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } +} + +impl, const D: usize> HighDegreeInterpolationGate { + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.start_coeffs() + self.num_points() * D + } + + /// The domain of the points we're interpolating. + fn coset(&self, shift: F) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. + g.powers().take(size).map(move |x| x * shift) + } + + /// The domain of the points we're interpolating. + fn coset_ext(&self, shift: F::Extension) -> impl Iterator { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers().take(size).map(move |x| shift.scalar_mul(x)) + } + + /// The domain of the points we're interpolating. + fn coset_ext_circuit( + &self, + builder: &mut CircuitBuilder, + shift: ExtensionTarget, + ) -> Vec> { + let g = F::primitive_root_of_unity(self.subgroup_bits); + let size = 1 << self.subgroup_bits; + g.powers() + .take(size) + .map(move |x| { + let subgroup_element = builder.constant(x); + builder.scalar_mul_ext(subgroup_element, shift) + }) + .collect() + } +} + +impl, const D: usize> Gate + for HighDegreeInterpolationGate +{ + fn id(&self) -> String { + format!("{:?}", self, D) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffsAlgebra::new(coeffs); + + let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let computed_value = interpolant.eval_base(point); + constraints.extend((value - computed_value).to_basefield_array()); + } + + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(evaluation_point); + constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffs::new(coeffs); + + let coset = self.coset(vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { + let value = vars.get_local_ext(self.wires_value(i)); + let computed_value = interpolant.eval_base(point); + yield_constr.many((value - computed_value).to_basefield_array()); + } + + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(evaluation_point); + yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); + + let coset = self.coset_ext_circuit(builder, vars.local_wires[self.wire_shift()]); + for (i, point) in coset.into_iter().enumerate() { + let value = vars.get_local_ext_algebra(self.wires_value(i)); + let computed_value = interpolant.eval_scalar(builder, point); + constraints.extend( + builder + .sub_ext_algebra(value, computed_value) + .to_ext_target_array(), + ); + } + + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + constraints.extend( + builder + .sub_ext_algebra(evaluation_value, computed_evaluation_value) + .to_ext_target_array(), + ); + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = InterpolationGenerator:: { + row, + gate: *self, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + // The highest power of x is `num_points - 1`, and then multiplication by the coefficient + // adds 1. + self.num_points() + } + + fn num_constraints(&self) -> usize { + // num_points * D constraints to check for consistency between the coefficients and the + // point-value pairs, plus D constraints for the evaluation value. + self.num_points() * D + D + } +} + +#[derive(Debug)] +struct InterpolationGenerator, const D: usize> { + row: usize, + gate: HighDegreeInterpolationGate, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for InterpolationGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| { + Target::Wire(Wire { + row: self.row, + column, + }) + }; + + let local_targets = |columns: Range| columns.map(local_target); + + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + for i in 0..num_points { + deps.extend(local_targets(self.gate.wires_value(i))); + } + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let get_local_ext = |wire_range: Range| { + debug_assert_eq!(wire_range.len(), D); + let values = wire_range.map(get_local_wire).collect::>(); + let arr = values.try_into().unwrap(); + F::Extension::from_basefield_array(arr) + }; + + // Compute the interpolant. + let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); + let points = points + .into_iter() + .enumerate() + .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) + .collect::>(); + let interpolant = interpolant(&points); + + for (i, &coeff) in interpolant.coeffs.iter().enumerate() { + let wires = self.gate.wires_coeff(i).map(local_wire); + out_buffer.set_ext_wires(wires, coeff); + } + + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + let evaluation_value = interpolant.eval(evaluation_point); + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use plonky2_field::goldilocks_field::GoldilocksField; + use plonky2_field::polynomial::PolynomialCoeffs; + use plonky2_field::types::Field; + + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; + use crate::gates::interpolation::InterpolationGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn wire_indices() { + let gate = HighDegreeInterpolationGate:: { + subgroup_bits: 1, + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_evaluation_point(), 9..13); + assert_eq!(gate.wires_evaluation_value(), 13..17); + assert_eq!(gate.wires_coeff(0), 17..21); + assert_eq!(gate.wires_coeff(1), 21..25); + assert_eq!(gate.num_wires(), 25); + } + + #[test] + fn low_degree() { + test_low_degree::(HighDegreeInterpolationGate::new(2)); + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(HighDegreeInterpolationGate::new(2)) + } + + #[test] + fn test_gate_constraint() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + + /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. + fn get_wires( + gate: &HighDegreeInterpolationGate, + shift: F, + coeffs: PolynomialCoeffs, + eval_point: FF, + ) -> Vec { + let points = gate.coset(shift); + let mut v = vec![shift]; + for x in points { + v.extend(coeffs.eval(x.into()).0); + } + v.extend(eval_point.0); + v.extend(coeffs.eval(eval_point).0); + for i in 0..coeffs.len() { + v.extend(coeffs.coeffs[i].0); + } + v.iter().map(|&x| x.into()).collect() + } + + // Get a working row for InterpolationGate. + let shift = F::rand(); + let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); + let eval_point = FF::rand(); + let gate = HighDegreeInterpolationGate::::new(1); + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(&gate, shift, coeffs, eval_point), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index 1983e5aa..d417fa6b 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -1,361 +1,178 @@ -use std::marker::PhantomData; use std::ops::Range; -use plonky2_field::extension::algebra::PolynomialCoeffsAlgebra; -use plonky2_field::extension::{Extendable, FieldExtension}; -use plonky2_field::interpolation::interpolant; -use plonky2_field::polynomial::PolynomialCoeffs; +use plonky2_field::extension::Extendable; -use crate::gadgets::interpolation::InterpolationGate; -use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; -use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; -use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Interpolation gate with constraints of degree at most `1<, const D: usize> { - pub subgroup_bits: usize, - _phantom: PhantomData, -} - -impl, const D: usize> InterpolationGate - for HighDegreeInterpolationGate +/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup +/// with the given size, and whose values are extension field elements, given by input wires. +/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. +pub(crate) trait InterpolationGate, const D: usize>: + Gate + Copy { - fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } + fn new(subgroup_bits: usize) -> Self; - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } -} + fn num_points(&self) -> usize; -impl, const D: usize> HighDegreeInterpolationGate { - /// End of wire indices, exclusive. - fn end(&self) -> usize { - self.start_coeffs() + self.num_points() * D - } - - /// The domain of the points we're interpolating. - fn coset(&self, shift: F) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. - g.powers().take(size).map(move |x| x * shift) - } - - /// The domain of the points we're interpolating. - fn coset_ext(&self, shift: F::Extension) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers().take(size).map(move |x| shift.scalar_mul(x)) - } - - /// The domain of the points we're interpolating. - fn coset_ext_circuit( - &self, - builder: &mut CircuitBuilder, - shift: ExtensionTarget, - ) -> Vec> { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers() - .take(size) - .map(move |x| { - let subgroup_element = builder.constant(x); - builder.scalar_mul_ext(subgroup_element, shift) - }) - .collect() - } -} - -impl, const D: usize> Gate - for HighDegreeInterpolationGate -{ - fn id(&self) -> String { - format!("{:?}", self, D) - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - - let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); - constraints.extend(&(value - computed_value).to_basefield_array()); - } - - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); - constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffs::new(coeffs); - - let coset = self.coset(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); - yield_constr.many((value - computed_value).to_basefield_array()); - } - - let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); - yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - - let coset = self.coset_ext_circuit(builder, vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_scalar(builder, point); - constraints.extend( - &builder - .sub_ext_algebra(value, computed_value) - .to_ext_target_array(), - ); - } - - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(builder, evaluation_point); - constraints.extend( - &builder - .sub_ext_algebra(evaluation_value, computed_evaluation_value) - .to_ext_target_array(), - ); - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = InterpolationGenerator:: { - row, - gate: *self, - _phantom: PhantomData, - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.end() - } - - fn num_constants(&self) -> usize { + /// Wire index of the coset shift. + fn wire_shift(&self) -> usize { 0 } - fn degree(&self) -> usize { - // The highest power of x is `num_points - 1`, and then multiplication by the coefficient - // adds 1. - self.num_points() + fn start_values(&self) -> usize { + 1 } - fn num_constraints(&self) -> usize { - // num_points * D constraints to check for consistency between the coefficients and the - // point-value pairs, plus D constraints for the evaluation value. - self.num_points() * D + D + /// Wire indices of the `i`th interpolant value. + fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_coeffs(&self) -> usize { + self.start_evaluation_value() + D + } + + /// The number of routed wires required in the typical usage of this gate, where the points to + /// interpolate, the evaluation point, and the corresponding value are all routed. + fn num_routed_wires(&self) -> usize { + self.start_coeffs() + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_coeffs() + i * D; + start..start + D + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + D * self.num_points() } } -#[derive(Debug)] -struct InterpolationGenerator, const D: usize> { - row: usize, - gate: HighDegreeInterpolationGate, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for InterpolationGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| { - Target::Wire(Wire { - row: self.row, - column, - }) - }; - - let local_targets = |columns: Range| columns.map(local_target); - - let num_points = self.gate.num_points(); - let mut deps = Vec::with_capacity(1 + D + num_points * D); - - deps.push(local_target(self.gate.wire_shift())); - deps.extend(local_targets(self.gate.wires_evaluation_point())); - for i in 0..num_points { - deps.extend(local_targets(self.gate.wires_value(i))); +impl, const D: usize> CircuitBuilder { + /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the + /// given size, and whose values are given. Returns the evaluation of the interpolant at + /// `evaluation_point`. + pub(crate) fn interpolate_coset>( + &mut self, + subgroup_bits: usize, + coset_shift: Target, + values: &[ExtensionTarget], + evaluation_point: ExtensionTarget, + ) -> ExtensionTarget { + let gate = G::new(subgroup_bits); + let row = self.add_gate(gate, vec![]); + self.connect(coset_shift, Target::wire(row, gate.wire_shift())); + for (i, &v) in values.iter().enumerate() { + self.connect_extension(v, ExtensionTarget::from_range(row, gate.wires_value(i))); } - deps - } + self.connect_extension( + evaluation_point, + ExtensionTarget::from_range(row, gate.wires_evaluation_point()), + ); - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let get_local_ext = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let values = wire_range.map(get_local_wire).collect::>(); - let arr = values.try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - - // Compute the interpolant. - let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); - let points = points - .into_iter() - .enumerate() - .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) - .collect::>(); - let interpolant = interpolant(&points); - - for (i, &coeff) in interpolant.coeffs.iter().enumerate() { - let wires = self.gate.wires_coeff(i).map(local_wire); - out_buffer.set_ext_wires(wires, coeff); - } - - let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - let evaluation_value = interpolant.eval(evaluation_point); - let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); - out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); + ExtensionTarget::from_range(row, gate.wires_evaluation_value()) } } #[cfg(test)] mod tests { - use std::marker::PhantomData; - use anyhow::Result; - use plonky2_field::goldilocks_field::GoldilocksField; - use plonky2_field::polynomial::PolynomialCoeffs; + use plonky2_field::extension::FieldExtension; + use plonky2_field::interpolation::interpolant; use plonky2_field::types::Field; - use crate::gadgets::interpolation::InterpolationGate; - use crate::gates::gate::Gate; - use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::interpolation::HighDegreeInterpolationGate; - use crate::hash::hash_types::HashOut; + use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; + use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::vars::EvaluationVars; + use crate::plonk::verifier::verify; #[test] - fn wire_indices() { - let gate = HighDegreeInterpolationGate:: { - subgroup_bits: 1, - _phantom: PhantomData, - }; - - // The exact indices aren't really important, but we want to make sure we don't have any - // overlaps or gaps. - assert_eq!(gate.wire_shift(), 0); - assert_eq!(gate.wires_value(0), 1..5); - assert_eq!(gate.wires_value(1), 5..9); - assert_eq!(gate.wires_evaluation_point(), 9..13); - assert_eq!(gate.wires_evaluation_value(), 13..17); - assert_eq!(gate.wires_coeff(0), 17..21); - assert_eq!(gate.wires_coeff(1), 21..25); - assert_eq!(gate.num_wires(), 25); - } - - #[test] - fn low_degree() { - test_low_degree::(HighDegreeInterpolationGate::new(2)); - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(HighDegreeInterpolationGate::new(2)) - } - - #[test] - fn test_gate_constraint() { + fn test_interpolate() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; type FF = >::FE; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); - /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. - fn get_wires( - gate: &HighDegreeInterpolationGate, - shift: F, - coeffs: PolynomialCoeffs, - eval_point: FF, - ) -> Vec { - let points = gate.coset(shift); - let mut v = vec![shift]; - for x in points { - v.extend(coeffs.eval(x.into()).0); - } - v.extend(eval_point.0); - v.extend(coeffs.eval(eval_point).0); - for i in 0..coeffs.len() { - v.extend(coeffs.coeffs[i].0); - } - v.iter().map(|&x| x.into()).collect() - } + let subgroup_bits = 2; + let len = 1 << subgroup_bits; + let coset_shift = F::rand(); + let g = F::primitive_root_of_unity(subgroup_bits); + let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); + let values = FF::rand_vec(len); - // Get a working row for InterpolationGate. - let shift = F::rand(); - let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); - let eval_point = FF::rand(); - let gate = HighDegreeInterpolationGate::::new(1); - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(&gate, shift, coeffs, eval_point), - public_inputs_hash: &HashOut::rand(), - }; + let homogeneous_points = points + .iter() + .zip(values.iter()) + .map(|(&a, &b)| (>::from_basefield(a), b)) + .collect::>(); - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." + let true_interpolant = interpolant(&homogeneous_points); + + let z = FF::rand(); + let true_eval = true_interpolant.eval(z); + + let coset_shift_target = builder.constant(coset_shift); + + let value_targets = values + .iter() + .map(|&v| (builder.constant_extension(v))) + .collect::>(); + + let zt = builder.constant_extension(z); + + let eval_hd = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, ); + let eval_ld = builder.interpolate_coset::>( + subgroup_bits, + coset_shift_target, + &value_targets, + zt, + ); + let true_eval_target = builder.constant_extension(true_eval); + builder.connect_extension(eval_hd, true_eval_target); + builder.connect_extension(eval_ld, true_eval_target); + + let data = builder.build::(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) } } diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index 217f4f0a..3edc4175 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -7,9 +7,9 @@ use plonky2_field::interpolation::interpolant; use plonky2_field::polynomial::PolynomialCoeffs; use plonky2_field::types::Field; -use crate::gadgets::interpolation::InterpolationGate; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::Gate; +use crate::gates::interpolation::InterpolationGate; use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; @@ -20,8 +20,9 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Interpolation gate with constraints of degree 2. -/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`. +/// One of the instantiations of `InterpolationGate`: all constraints are degree <= 2. +/// The lower degree is a tradeoff for more gates (`eval_unfiltered_recursively` for +/// this version uses more gates than `LowDegreeInterpolationGate`). #[derive(Copy, Clone, Debug)] pub struct LowDegreeInterpolationGate, const D: usize> { pub subgroup_bits: usize, @@ -113,7 +114,7 @@ impl, const D: usize> Gate for LowDegreeInter { let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = altered_interpolant.eval_base(point); - constraints.extend(&(value - computed_value).to_basefield_array()); + constraints.extend((value - computed_value).to_basefield_array()); } let evaluation_point_powers = (1..self.num_points()) @@ -128,7 +129,7 @@ impl, const D: usize> Gate for LowDegreeInter } let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); - constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); constraints } @@ -225,7 +226,7 @@ impl, const D: usize> Gate for LowDegreeInter let point = builder.constant_extension(point); let computed_value = altered_interpolant.eval_scalar(builder, point); constraints.extend( - &builder + builder .sub_ext_algebra(value, computed_value) .to_ext_target_array(), ); @@ -253,7 +254,7 @@ impl, const D: usize> Gate for LowDegreeInter // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); constraints.extend( - &builder + builder .sub_ext_algebra(evaluation_value, computed_evaluation_value) .to_ext_target_array(), ); @@ -387,9 +388,9 @@ mod tests { use plonky2_field::polynomial::PolynomialCoeffs; use plonky2_field::types::Field; - use crate::gadgets::interpolation::InterpolationGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::interpolation::InterpolationGate; use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::hash::hash_types::HashOut; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index df65b44c..1d2fc058 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -3,11 +3,11 @@ pub mod arithmetic_base; pub mod arithmetic_extension; -pub mod assert_le; pub mod base_sum; pub mod constant; pub mod exponentiation; pub mod gate; +pub mod high_degree_interpolation; pub mod interpolation; pub mod low_degree_interpolation; pub mod multiplication_extension; diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 2df392bc..fa365f16 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -24,9 +24,15 @@ use crate::plonk::vars::{ /// A gate for checking that a particular element of a list matches a given value. #[derive(Copy, Clone, Debug)] pub struct RandomAccessGate, const D: usize> { + /// Number of bits in the index (log2 of the list size). pub bits: usize, + + /// How many separate copies are packed into one gate. pub num_copies: usize, + + /// Leftover wires are used as global scratch space to store constants. pub num_extra_constants: usize, + _phantom: PhantomData, } @@ -41,13 +47,18 @@ impl, const D: usize> RandomAccessGate { } pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self { + // We can access a list of 2^bits elements. let vec_size = 1 << bits; - // Need `(2 + vec_size) * num_copies` routed wires + + // We need `(2 + vec_size) * num_copies` routed wires. let max_copies = (config.num_routed_wires / (2 + vec_size)).min( - // Need `(2 + vec_size + bits) * num_copies` wires + // We need `(2 + vec_size + bits) * num_copies` wires in total. config.num_wires / (2 + vec_size + bits), ); + + // Any leftover wires can be used for constants. let max_extra_constants = config.num_routed_wires - (2 + vec_size) * max_copies; + Self::new( max_copies, bits, @@ -55,20 +66,24 @@ impl, const D: usize> RandomAccessGate { ) } + /// Length of the list being accessed. fn vec_size(&self) -> usize { 1 << self.bits } + /// For each copy, a wire containing the claimed index of the element. pub fn wire_access_index(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); (2 + self.vec_size()) * copy } + /// For each copy, a wire containing the element claimed to be at the index. pub fn wire_claimed_element(&self, copy: usize) -> usize { debug_assert!(copy < self.num_copies); (2 + self.vec_size()) * copy + 1 } + /// For each copy, wires containing the entire list. pub fn wire_list_item(&self, i: usize, copy: usize) -> usize { debug_assert!(i < self.vec_size()); debug_assert!(copy < self.num_copies); @@ -84,6 +99,7 @@ impl, const D: usize> RandomAccessGate { self.start_extra_constants() + i } + /// All above wires are routed. pub fn num_routed_wires(&self) -> usize { self.start_extra_constants() + self.num_extra_constants } @@ -202,10 +218,12 @@ impl, const D: usize> Gate for RandomAccessGa .collect() } + // Check that the one remaining element after the folding is the claimed element. debug_assert_eq!(list_items.len(), 1); constraints.push(builder.sub_extension(list_items[0], claimed_element)); } + // Check the constant values. constraints.extend((0..self.num_extra_constants).map(|i| { builder.sub_extension( vars.local_constants[i], diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index 14303ad3..f416732a 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -115,7 +115,7 @@ pub struct MerkleCapTarget(pub Vec); pub struct BytesHash(pub [u8; N]); impl BytesHash { - #[cfg(feature = "parallel")] + #[cfg(feature = "rand")] pub fn rand_from_rng(rng: &mut R) -> Self { let mut buf = [0; N]; rng.fill_bytes(&mut buf); diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 5bedf13d..3614b2e4 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -31,7 +31,6 @@ pub(crate) fn generate_partial_witness< let mut witness = PartitionWitness::new( config.num_wires, common_data.degree(), - common_data.num_virtual_targets, &prover_data.representative_map, ); diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index 871f303f..e7f21241 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -104,9 +104,12 @@ pub trait Witness { where F: RichField + Extendable, { - let limbs = value.to_basefield_array(); - (0..D).for_each(|i| { - self.set_target(et.0[i], limbs[i]); + self.set_target_arr(et.0, value.to_basefield_array()); + } + + fn set_target_arr(&mut self, targets: [Target; N], values: [F; N]) { + (0..N).for_each(|i| { + self.set_target(targets[i], values[i]); }); } @@ -275,14 +278,9 @@ pub struct PartitionWitness<'a, F: Field> { } impl<'a, F: Field> PartitionWitness<'a, F> { - pub fn new( - num_wires: usize, - degree: usize, - num_virtual_targets: usize, - representative_map: &'a [usize], - ) -> Self { + pub fn new(num_wires: usize, degree: usize, representative_map: &'a [usize]) -> Self { Self { - values: vec![None; degree * num_wires + num_virtual_targets], + values: vec![None; representative_map.len()], representative_map, num_wires, degree, diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 9f97251a..579a9017 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -2,6 +2,7 @@ use std::cmp::max; use std::collections::{BTreeMap, HashMap, HashSet}; use std::time::Instant; +use itertools::Itertools; use log::{debug, info, Level}; use plonky2_field::cosets::get_unique_coset_shifts; use plonky2_field::extension::{Extendable, FieldExtension}; @@ -95,9 +96,9 @@ impl, const D: usize> CircuitBuilder { context_log: ContextTree::new(), generators: Vec::new(), constants_to_targets: HashMap::new(), + targets_to_constants: HashMap::new(), base_arithmetic_results: HashMap::new(), arithmetic_results: HashMap::new(), - targets_to_constants: HashMap::new(), current_slots: HashMap::new(), constant_generators: Vec::new(), }; @@ -156,6 +157,10 @@ impl, const D: usize> CircuitBuilder { (0..n).map(|_i| self.add_virtual_target()).collect() } + pub fn add_virtual_target_arr(&mut self) -> [Target; N] { + [0; N].map(|_| self.add_virtual_target()) + } + pub fn add_virtual_hash(&mut self) -> HashOutTarget { HashOutTarget::from_vec(self.add_virtual_targets(4)) } @@ -665,6 +670,9 @@ impl, const D: usize> CircuitBuilder { .constants_to_targets .clone() .into_iter() + // We need to enumerate constants_to_targets in some deterministic order to ensure that + // building a circuit is deterministic. + .sorted_by_key(|(c, _t)| c.to_canonical_u64()) .zip(self.constant_generators.clone()) { // Set the constant in the constant polynomial. @@ -790,7 +798,10 @@ impl, const D: usize> CircuitBuilder { // TODO: This should also include an encoding of gate constraints. let circuit_digest_parts = [ constants_sigmas_cap.flatten(), - vec![/* Add other circuit data here */], + vec![ + F::from_canonical_usize(degree_bits), + /* Add other circuit data here */ + ], ]; let circuit_digest = C::Hasher::hash_no_pad(&circuit_digest_parts.concat()); @@ -803,7 +814,6 @@ impl, const D: usize> CircuitBuilder { quotient_degree_factor, num_gate_constraints, num_constants, - num_virtual_targets: self.virtual_target_index, num_public_inputs, k_is, num_partial_products, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index fb839978..20697d36 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -265,8 +265,6 @@ pub struct CommonCircuitData< /// The number of constant wires. pub(crate) num_constants: usize, - pub(crate) num_virtual_targets: usize, - pub(crate) num_public_inputs: usize, /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. diff --git a/plonky2/src/plonk/plonk_common.rs b/plonky2/src/plonk/plonk_common.rs index 4f92d732..e947353b 100644 --- a/plonky2/src/plonk/plonk_common.rs +++ b/plonky2/src/plonk/plonk_common.rs @@ -64,31 +64,31 @@ pub(crate) fn eval_zero_poly(n: usize, x: F) -> F { x.exp_u64(n as u64) - F::ONE } -/// Evaluate the Lagrange basis `L_1` with `L_1(1) = 1`, and `L_1(x) = 0` for other members of an +/// Evaluate the Lagrange basis `L_0` with `L_0(1) = 1`, and `L_0(x) = 0` for other members of the /// order `n` multiplicative subgroup. -pub(crate) fn eval_l_1(n: usize, x: F) -> F { +pub(crate) fn eval_l_0(n: usize, x: F) -> F { if x.is_one() { // The code below would divide by zero, since we have (x - 1) in both the numerator and // denominator. return F::ONE; } - // L_1(x) = (x^n - 1) / (n * (x - 1)) + // L_0(x) = (x^n - 1) / (n * (x - 1)) // = Z(x) / (n * (x - 1)) eval_zero_poly(n, x) / (F::from_canonical_usize(n) * (x - F::ONE)) } -/// Evaluates the Lagrange basis L_1(x), which has L_1(1) = 1 and vanishes at all other points in +/// Evaluates the Lagrange basis L_0(x), which has L_0(1) = 1 and vanishes at all other points in /// the order-`n` subgroup. /// /// Assumes `x != 1`; if `x` could be 1 then this is unsound. -pub(crate) fn eval_l_1_circuit, const D: usize>( +pub(crate) fn eval_l_0_circuit, const D: usize>( builder: &mut CircuitBuilder, n: usize, x: ExtensionTarget, x_pow_n: ExtensionTarget, ) -> ExtensionTarget { - // L_1(x) = (x^n - 1) / (n * (x - 1)) + // L_0(x) = (x^n - 1) / (n * (x - 1)) // = Z(x) / (n * (x - 1)) let one = builder.one_extension(); let neg_one = builder.neg_one(); diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index ab0ba53b..303f698b 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -10,7 +10,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::config::GenericConfig; use crate::plonk::plonk_common; -use crate::plonk::plonk_common::eval_l_1_circuit; +use crate::plonk::plonk_common::eval_l_0_circuit; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::util::partial_products::{check_partial_products, check_partial_products_circuit}; use crate::util::reducing::ReducingFactorTarget; @@ -41,17 +41,17 @@ pub(crate) fn eval_vanishing_poly< let constraint_terms = evaluate_gate_constraints(common_data, vars); - // The L_1(x) (Z(x) - 1) vanishing terms. + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - let l1_x = plonk_common::eval_l_1(common_data.degree(), x); + let l_0_x = plonk_common::eval_l_0(common_data.degree(), x); for i in 0..common_data.config.num_challenges { let z_x = local_zs[i]; let z_gx = next_zs[i]; - vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE)); + vanishing_z_1_terms.push(l_0_x * (z_x - F::Extension::ONE)); let numerator_values = (0..common_data.config.num_routed_wires) .map(|j| { @@ -135,7 +135,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< let mut numerator_values = Vec::with_capacity(num_routed_wires); let mut denominator_values = Vec::with_capacity(num_routed_wires); - // The L_1(x) (Z(x) - 1) vanishing terms. + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); @@ -152,11 +152,11 @@ pub(crate) fn eval_vanishing_poly_base_batch< let constraint_terms = PackedStridedView::new(&constraint_terms_batch, n, k); - let l1_x = z_h_on_coset.eval_l1(index, x); + let l_0_x = z_h_on_coset.eval_l_0(index, x); for i in 0..num_challenges { let z_x = local_zs[i]; let z_gx = next_zs[i]; - vanishing_z_1_terms.push(l1_x * z_x.sub_one()); + vanishing_z_1_terms.push(l_0_x * z_x.sub_one()); numerator_values.extend((0..num_routed_wires).map(|j| { let wire_value = vars.local_wires[j]; @@ -332,12 +332,12 @@ pub(crate) fn eval_vanishing_poly_circuit< evaluate_gate_constraints_circuit(builder, common_data, vars,) ); - // The L_1(x) (Z(x) - 1) vanishing terms. + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - let l1_x = eval_l_1_circuit(builder, common_data.degree(), x, x_pow_deg); + let l_0_x = eval_l_0_circuit(builder, common_data.degree(), x, x_pow_deg); // Holds `k[i] * x`. let mut s_ids = Vec::new(); @@ -350,8 +350,8 @@ pub(crate) fn eval_vanishing_poly_circuit< let z_x = local_zs[i]; let z_gx = next_zs[i]; - // L_1(x) Z(x) = 0. - vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x)); + // L_0(x) (Z(x) - 1) = 0. + vanishing_z_1_terms.push(builder.mul_sub_extension(l_0_x, z_x, l_0_x)); let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); diff --git a/plonky2/src/util/serialization.rs b/plonky2/src/util/serialization.rs index da03213c..978134b6 100644 --- a/plonky2/src/util/serialization.rs +++ b/plonky2/src/util/serialization.rs @@ -282,7 +282,7 @@ impl Buffer { arity: usize, compressed: bool, ) -> Result> { - let evals = self.read_field_ext_vec::(arity - if compressed { 1 } else { 0 })?; + let evals = self.read_field_ext_vec::(arity - usize::from(compressed))?; let merkle_proof = self.read_merkle_proof()?; Ok(FriQueryStep { evals, diff --git a/projects/cache-friendly-fft/__init__.py b/projects/cache-friendly-fft/__init__.py new file mode 100644 index 00000000..08f1acac --- /dev/null +++ b/projects/cache-friendly-fft/__init__.py @@ -0,0 +1,229 @@ +import numpy as np + +from transpose import transpose_square +from util import lb_exact + + +def _interleave(x, scratch): + """Interleave the elements in an array in-place. + + For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its + contents will be rearranged to `array([1, 5, 2, 6, 3, 7, 4, 8])`. + + `scratch` is an externally-allocated buffer, whose `dtype` matches + `x` and whose length is at least half the length of `x`. + """ + assert len(x.shape) == len(scratch.shape) == 1 + + n, = x.shape + assert n % 2 == 0 + + half_n = n // 2 + assert scratch.shape[0] >= half_n + + assert x.dtype == scratch.dtype + scratch = scratch[:half_n] + + scratch[:] = x[:half_n] # Save the first half of `x`. + for i in range(half_n): + x[2 * i] = scratch[i] + x[2 * i + 1] = x[half_n + i] + + +def _deinterleave(x, scratch): + """Deinterleave the elements in an array in-place. + + For example, if `x` is `array([1, 2, 3, 4, 5, 6, 7, 8])`, then its + contents will be rearranged to `array([1, 3, 5, 7, 2, 4, 6, 8])`. + + `scratch` is an externally-allocated buffer, whose `dtype` matches + `x` and whose length is at least half the length of `x`. + """ + assert len(x.shape) == len(scratch.shape) == 1 + + n, = x.shape + assert n % 2 == 0 + + half_n = n // 2 + assert scratch.shape[0] >= half_n + + assert x.dtype == scratch.dtype + scratch = scratch[:half_n] + + for i in range(half_n): + x[i] = x[2 * i] + scratch[i] = x[2 * i + 1] + x[half_n:] = scratch + + +def _fft_inplace_evenpow(x, scratch): + """In-place FFT of length 2^even""" + # Reshape `x` to a square matrix in row-major order. + vec_len = x.shape[0] + n = 1 << (lb_exact(vec_len) >> 1) # Matrix dimension + x.shape = n, n, 1 + + # We want to recursively apply FFT to every column. Because `x` is + # in row-major order, we transpose it to make the columns contiguous + # in memory, then recurse, and finally transpose it back. While the + # row is in cache, we also multiply by the twiddle factors. + transpose_square(x) + for i, row in enumerate(x[..., 0]): + _fft_inplace(row, scratch) + # Multiply by the twiddle factors + for j in range(n): + row[j] *= np.exp(-2j * np.pi * (i * j) / vec_len) + transpose_square(x) + + # Now recursively apply FFT to the rows. + for row in x[..., 0]: + _fft_inplace(row, scratch) + + # Transpose again before returning. + transpose_square(x) + + +def _fft_inplace_oddpow(x, scratch): + """In-place FFT of length 2^odd""" + # This code is based on `_fft_inplace_evenpow`, but it has to + # account for some additional complications. + + vec_len = x.shape[0] + # `vec_len` is an odd power of 2, so we cannot reshape `x` to a + # matrix square. Instead, we'll (conceptually) reshape it to a + # matrix that's twice as wide as it is high. E.g., `[1 ... 8]` + # becomes `[1 2 3 4]` + # `[5 6 7 8]`. + col_len = 1 << (lb_exact(vec_len) >> 1) + row_len = col_len << 1 + + # We can only perform efficient, in-place transposes on square + # matrices, so we will actually treat this as a square matrix of + # 2-tuples, e.g. `[(1 2) (3 4)]` + # `[(5 6) (7 8)]`. + # Note that we can currently `.reshape` it to our intended wide + # matrix (although this is broken by transposition). + x.shape = col_len, col_len, 2 + + # We want to apply FFT to each column. We transpose our + # matrix-of-tuples and get something like `[(1 2) (5 6)]` + # `[(3 4) (7 8)]`. + # Note that each row of the transposed matrix represents two columns + # of the original matrix. We can deinterleave the values to recover + # the original columns. + transpose_square(x) + + for i, row_pair in enumerate(x): + # `row_pair` represents two columns of the original matrix. + # Their values must be deinterleaved to recover the columns. + row_pair.shape = row_len, + _deinterleave(row_pair, scratch) + # The below are rows of the transposed matrix(/cols of the + # original matrix. + row0 = row_pair[:col_len] + row1 = row_pair[col_len:] + + # Apply FFT and twiddle factors to each. + _fft_inplace(row0, scratch) + for j in range(col_len): + row0[j] *= np.exp(-2j * np.pi * ((2 * i) * j) / vec_len) + _fft_inplace(row1, scratch) + for j in range(col_len): + row1[j] *= np.exp(-2j * np.pi * ((2 * i + 1) * j) / vec_len) + + # Re-interleave them and transpose back. + _interleave(row_pair, scratch) + + transpose_square(x) + + # Recursively apply FFT to each row of the matrix. + for row in x: + # Turn vec of 2-tuples into vec of single elements. + row.shape = row_len, + _fft_inplace(row, scratch) + + # Transpose again before returning. This again involves + # deinterleaving. + transpose_square(x) + for row_pair in x: + row_pair.shape = row_len, + _deinterleave(row_pair, scratch) + + +def _fft_inplace(x, scratch): + """In-place FFT.""" + # Avoid modifying the shape of the original. + # This does not copy the buffer. + x = x.view() + assert x.flags['C_CONTIGUOUS'] + + n, = x.shape + if n == 1: + return + if n == 2: + x0, x1 = x + x[0] = x0 + x1 + x[1] = x0 - x1 + return + + lb_n = lb_exact(n) + is_odd = lb_n & 1 != 0 + if is_odd: + _fft_inplace_oddpow(x, scratch) + else: + _fft_inplace_evenpow(x, scratch) + + +def _scrach_length(lb_n): + """Find the amount of scratch space required to run the FFT. + + Layers where the input's length is an even power of two do not + require scratch space, but the layers where that power is odd do. + """ + if lb_n == 0: + # Length-1 input. + return 0 + # Repeatedly halve lb_n as long as it's even. This is the same as + # `n = sqrt(n)`, where the `sqrt` is exact. + while lb_n & 1 == 0: + lb_n >>= 1 + # `lb_n` is now odd, so `n` is not an even power of 2. + lb_res = (lb_n - 1) >> 1 + if lb_res == 0: + # Special case (n == 2 or n == 4): no scratch needed. + return 0 + return 1 << lb_res + + +def fft(x): + """Returns the FFT of `x`. + + This is a wrapper around an in-place routine, provided for user + convenience. + """ + n, = x.shape + lb_n = lb_exact(n) # Raises if not a power of 2. + # We have one scratch buffer for the whole algorithm. If we were to + # parallelize it, we'd need one thread-local buffer for each worker + # thread. + scratch_len = _scrach_length(lb_n) + if scratch_len == 0: + scratch = None + else: + scratch = np.empty_like(x, shape=scratch_len, order='C', subok=False) + + res = x.copy(order='C') + _fft_inplace(res, scratch) + + return res + + +if __name__ == "__main__": + LENGTH = 1 << 10 + v = np.random.normal(size=LENGTH).astype(complex) + print(v) + numpy_fft = np.fft.fft(v) + print(numpy_fft) + our_fft = fft(v) + print(our_fft) + print(np.isclose(numpy_fft, our_fft).all()) diff --git a/projects/cache-friendly-fft/transpose.py b/projects/cache-friendly-fft/transpose.py new file mode 100644 index 00000000..ea20bf6b --- /dev/null +++ b/projects/cache-friendly-fft/transpose.py @@ -0,0 +1,61 @@ +from util import lb_exact + + +def _swap_transpose_square(a, b): + """Transpose two square matrices in-place and swap them. + + The matrices must be a of shape `(n, n, m)`, where the `m` dimension + may be of arbitrary length and is not moved. + """ + assert len(a.shape) == len(b.shape) == 3 + n = a.shape[0] + m = a.shape[2] + assert n == a.shape[1] == b.shape[0] == b.shape[1] + assert m == b.shape[2] + + if n == 0: + return + if n == 1: + # Swap the two matrices (transposition is a no-op). + a = a[0, 0] + b = b[0, 0] + # Recall that each element of the matrix is an `m`-vector. Swap + # all `m` elements. + for i in range(m): + a[i], b[i] = b[i], a[i] + return + + half_n = n >> 1 + # Transpose and swap top-left of `a` with top-left of `b`. + _swap_transpose_square(a[:half_n, :half_n], b[:half_n, :half_n]) + # ...top-right of `a` with bottom-left of `b`. + _swap_transpose_square(a[:half_n, half_n:], b[half_n:, :half_n]) + # ...bottom-left of `a` with top-right of `b`. + _swap_transpose_square(a[half_n:, :half_n], b[:half_n, half_n:]) + # ...bottom-right of `a` with bottom-right of `b`. + _swap_transpose_square(a[half_n:, half_n:], b[half_n:, half_n:]) + + +def transpose_square(a): + """In-place transpose of a square matrix. + + The matrix must be a of shape `(n, n, m)`, where the `m` dimension + may be of arbitrary length and is not moved. + """ + if len(a.shape) != 3: + raise ValueError("a must be a matrix of batches") + n, n_, _ = a.shape + if n != n_: + raise ValueError("a must be square") + lb_exact(n) + + if n <= 1: + return # Base case: no-op + + half_n = n >> 1 + # Transpose top-left quarter in-place. + transpose_square(a[:half_n, :half_n]) + # Transpose top-right and bottom-left quarters and swap them. + _swap_transpose_square(a[:half_n, half_n:], a[half_n:, :half_n]) + # Transpose bottom-right quarter in-place. + transpose_square(a[half_n:, half_n:]) diff --git a/projects/cache-friendly-fft/util.py b/projects/cache-friendly-fft/util.py new file mode 100644 index 00000000..50118827 --- /dev/null +++ b/projects/cache-friendly-fft/util.py @@ -0,0 +1,6 @@ +def lb_exact(n): + """Returns `log2(n)`, raising if `n` is not a power of 2.""" + lb = n.bit_length() - 1 + if lb < 0 or n != 1 << lb: + raise ValueError(f"{n} is not a power of 2") + return lb diff --git a/starky/Cargo.toml b/starky/Cargo.toml index 80a26bfc..43bea53e 100644 --- a/starky/Cargo.toml +++ b/starky/Cargo.toml @@ -6,13 +6,13 @@ edition = "2021" [features] default = ["parallel"] -parallel = ["maybe_rayon/parallel"] +parallel = ["plonky2/parallel", "maybe_rayon/parallel"] [dependencies] -plonky2 = { path = "../plonky2" } +plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } +maybe_rayon = { path = "../maybe_rayon"} anyhow = "1.0.40" env_logger = "0.9.0" itertools = "0.10.0" log = "0.4.14" -maybe_rayon = { path = "../maybe_rayon"} diff --git a/starky/src/constraint_consumer.rs b/starky/src/constraint_consumer.rs index c9368ba3..1a061c20 100644 --- a/starky/src/constraint_consumer.rs +++ b/starky/src/constraint_consumer.rs @@ -44,12 +44,8 @@ impl ConstraintConsumer

{ } } - // TODO: Do this correctly. - pub fn accumulators(self) -> Vec { + pub fn accumulators(self) -> Vec

{ self.constraint_accs - .into_iter() - .map(|acc| acc.as_slice()[0]) - .collect() } /// Add one constraint valid on all rows except the last. diff --git a/starky/src/prover.rs b/starky/src/prover.rs index 24593b45..0d291cf3 100644 --- a/starky/src/prover.rs +++ b/starky/src/prover.rs @@ -258,7 +258,7 @@ where let quotient_values = (0..size) .into_par_iter() .step_by(P::WIDTH) - .map(|i_start| { + .flat_map_iter(|i_start| { let i_next_start = (i_start + next_step) % size; let i_range = i_start..i_start + P::WIDTH; @@ -292,13 +292,22 @@ where permutation_check_data, &mut consumer, ); + let mut constraints_evals = consumer.accumulators(); // We divide the constraints evaluations by `Z_H(x)`. - let denominator_inv = z_h_on_coset.eval_inverse_packed(i_start); + let denominator_inv: P = z_h_on_coset.eval_inverse_packed(i_start); + for eval in &mut constraints_evals { *eval *= denominator_inv; } - constraints_evals + + let num_challenges = alphas.len(); + + (0..P::WIDTH).into_iter().map(move |i| { + (0..num_challenges) + .map(|j| constraints_evals[j].as_slice()[i]) + .collect() + }) }) .collect::>(); diff --git a/starky/src/recursive_verifier.rs b/starky/src/recursive_verifier.rs index 7f20d89b..04858d55 100644 --- a/starky/src/recursive_verifier.rs +++ b/starky/src/recursive_verifier.rs @@ -102,8 +102,8 @@ fn verify_stark_proof_with_challenges_circuit< let zeta_pow_deg = builder.exp_power_of_2_extension(challenges.stark_zeta, degree_bits); let z_h_zeta = builder.sub_extension(zeta_pow_deg, one); - let (l_1, l_last) = - eval_l_1_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); + let (l_0, l_last) = + eval_l_0_and_l_last_circuit(builder, degree_bits, challenges.stark_zeta, z_h_zeta); let last = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_bits).inverse()); let z_last = builder.sub_extension(challenges.stark_zeta, last); @@ -112,7 +112,7 @@ fn verify_stark_proof_with_challenges_circuit< builder.zero_extension(), challenges.stark_alphas, z_last, - l_1, + l_0, l_last, ); @@ -170,7 +170,7 @@ fn verify_stark_proof_with_challenges_circuit< ); } -fn eval_l_1_and_l_last_circuit, const D: usize>( +fn eval_l_0_and_l_last_circuit, const D: usize>( builder: &mut CircuitBuilder, log_n: usize, x: ExtensionTarget, @@ -179,12 +179,12 @@ fn eval_l_1_and_l_last_circuit, const D: usize>( let n = builder.constant_extension(F::Extension::from_canonical_usize(1 << log_n)); let g = builder.constant_extension(F::Extension::primitive_root_of_unity(log_n)); let one = builder.one_extension(); - let l_1_deno = builder.mul_sub_extension(n, x, n); + let l_0_deno = builder.mul_sub_extension(n, x, n); let l_last_deno = builder.mul_sub_extension(g, x, one); let l_last_deno = builder.mul_extension(n, l_last_deno); ( - builder.div_extension(z_x, l_1_deno), + builder.div_extension(z_x, l_0_deno), builder.div_extension(z_x, l_last_deno), ) } diff --git a/starky/src/verifier.rs b/starky/src/verifier.rs index 306d3d14..efb3d29c 100644 --- a/starky/src/verifier.rs +++ b/starky/src/verifier.rs @@ -78,7 +78,7 @@ where .unwrap(), }; - let (l_1, l_last) = eval_l_1_and_l_last(degree_bits, challenges.stark_zeta); + let (l_0, l_last) = eval_l_0_and_l_last(degree_bits, challenges.stark_zeta); let last = F::primitive_root_of_unity(degree_bits).inverse(); let z_last = challenges.stark_zeta - last.into(); let mut consumer = ConstraintConsumer::::new( @@ -88,7 +88,7 @@ where .map(|&alpha| F::Extension::from_basefield(alpha)) .collect::>(), z_last, - l_1, + l_0, l_last, ); let permutation_data = stark.uses_permutation_args().then(|| PermutationCheckVars { @@ -144,10 +144,10 @@ where Ok(()) } -/// Evaluate the Lagrange polynomials `L_1` and `L_n` at a point `x`. -/// `L_1(x) = (x^n - 1)/(n * (x - 1))` -/// `L_n(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. -fn eval_l_1_and_l_last(log_n: usize, x: F) -> (F, F) { +/// Evaluate the Lagrange polynomials `L_0` and `L_(n-1)` at a point `x`. +/// `L_0(x) = (x^n - 1)/(n * (x - 1))` +/// `L_(n-1)(x) = (x^n - 1)/(n * (g * x - 1))`, with `g` the first element of the subgroup. +fn eval_l_0_and_l_last(log_n: usize, x: F) -> (F, F) { let n = F::from_canonical_usize(1 << log_n); let g = F::primitive_root_of_unity(log_n); let z_x = x.exp_power_of_2(log_n) - F::ONE; @@ -189,10 +189,10 @@ mod tests { use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; - use crate::verifier::eval_l_1_and_l_last; + use crate::verifier::eval_l_0_and_l_last; #[test] - fn test_eval_l_1_and_l_last() { + fn test_eval_l_0_and_l_last() { type F = GoldilocksField; let log_n = 5; let n = 1 << log_n; @@ -201,7 +201,7 @@ mod tests { let expected_l_first_x = PolynomialValues::selector(n, 0).ifft().eval(x); let expected_l_last_x = PolynomialValues::selector(n, n - 1).ifft().eval(x); - let (l_first_x, l_last_x) = eval_l_1_and_l_last(log_n, x); + let (l_first_x, l_last_x) = eval_l_0_and_l_last(log_n, x); assert_eq!(l_first_x, expected_l_first_x); assert_eq!(l_last_x, expected_l_last_x); } diff --git a/u32/src/gadgets/arithmetic_u32.rs b/u32/src/gadgets/arithmetic_u32.rs index 7a7731b1..7475681c 100644 --- a/u32/src/gadgets/arithmetic_u32.rs +++ b/u32/src/gadgets/arithmetic_u32.rs @@ -10,7 +10,7 @@ use plonky2_field::extension::Extendable; use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; -use crate::witness::generated_values_set_u32_target; +use crate::witness::GeneratedValuesU32; #[derive(Clone, Copy, Debug)] pub struct U32Target(pub Target); @@ -249,8 +249,8 @@ impl, const D: usize> SimpleGenerator let low = x_u64 as u32; let high = (x_u64 >> 32) as u32; - generated_values_set_u32_target(out_buffer, self.low, low); - generated_values_set_u32_target(out_buffer, self.high, high); + out_buffer.set_u32_target(self.low, low); + out_buffer.set_u32_target(self.high, high); } } diff --git a/u32/src/witness.rs b/u32/src/witness.rs index 1b88d60d..ddc3432f 100644 --- a/u32/src/witness.rs +++ b/u32/src/witness.rs @@ -1,21 +1,33 @@ use plonky2::iop::generator::GeneratedValues; use plonky2::iop::witness::Witness; -use plonky2_field::types::Field; +use plonky2_field::types::{Field, PrimeField64}; use crate::gadgets::arithmetic_u32::U32Target; -pub fn generated_values_set_u32_target( - buffer: &mut GeneratedValues, - target: U32Target, - value: u32, -) { - buffer.set_target(target.0, F::from_canonical_u32(value)) +pub trait WitnessU32: Witness { + fn set_u32_target(&mut self, target: U32Target, value: u32); + fn get_u32_target(&self, target: U32Target) -> (u32, u32); } -pub fn witness_set_u32_target, F: Field>( - witness: &mut W, - target: U32Target, - value: u32, -) { - witness.set_target(target.0, F::from_canonical_u32(value)) +impl, F: PrimeField64> WitnessU32 for T { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)); + } + + fn get_u32_target(&self, target: U32Target) -> (u32, u32) { + let x_u64 = self.get_target(target.0).to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 >> 32) as u32; + (low, high) + } +} + +pub trait GeneratedValuesU32 { + fn set_u32_target(&mut self, target: U32Target, value: u32); +} + +impl GeneratedValuesU32 for GeneratedValues { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } } diff --git a/plonky2/src/gates/assert_le.rs b/waksman/src/gates/assert_le.rs similarity index 96% rename from plonky2/src/gates/assert_le.rs rename to waksman/src/gates/assert_le.rs index 19bff044..c67a7125 100644 --- a/plonky2/src/gates/assert_le.rs +++ b/waksman/src/gates/assert_le.rs @@ -1,26 +1,25 @@ use std::marker::PhantomData; +use plonky2::gates::gate::Gate; +use plonky2::gates::packed_util::PackedEvaluableBase; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; +use plonky2::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; use plonky2_field::extension::Extendable; use plonky2_field::packed::PackedField; use plonky2_field::types::{Field, Field64}; use plonky2_util::{bits_u64, ceil_div_usize}; -use crate::gates::gate::Gate; -use crate::gates::packed_util::PackedEvaluableBase; -use crate::gates::util::StridedConstraintConsumer; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use crate::iop::target::Target; -use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; -use crate::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; - // TODO: replace/merge this gate with `ComparisonGate`. /// A gate for checking that one value is less than or equal to another. @@ -450,6 +449,11 @@ mod tests { use std::marker::PhantomData; use anyhow::Result; + use plonky2::gates::gate::Gate; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::plonk::vars::EvaluationVars; use plonky2_field::extension::quartic::QuarticExtension; use plonky2_field::goldilocks_field::GoldilocksField; use plonky2_field::types::Field; @@ -457,11 +461,6 @@ mod tests { use rand::Rng; use crate::gates::assert_le::AssertLessThanGate; - use crate::gates::gate::Gate; - use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::hash::hash_types::HashOut; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::vars::EvaluationVars; #[test] fn wire_indices() { diff --git a/waksman/src/gates/mod.rs b/waksman/src/gates/mod.rs index 5a2a8f48..c73890b1 100644 --- a/waksman/src/gates/mod.rs +++ b/waksman/src/gates/mod.rs @@ -1 +1,2 @@ +pub mod assert_le; pub mod switch; diff --git a/waksman/src/sorting.rs b/waksman/src/sorting.rs index ac598dc8..010bc8b9 100644 --- a/waksman/src/sorting.rs +++ b/waksman/src/sorting.rs @@ -3,7 +3,6 @@ use std::marker::PhantomData; use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; -use plonky2::gates::assert_le::AssertLessThanGate; use plonky2::hash::hash_types::RichField; use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; use plonky2::iop::target::{BoolTarget, Target}; @@ -11,6 +10,7 @@ use plonky2::iop::witness::{PartitionWitness, Witness}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_util::ceil_div_usize; +use crate::gates::assert_le::AssertLessThanGate; use crate::permutation::assert_permutation_circuit; pub struct MemoryOp {