diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index ba2ad1bd..a0ac3ec7 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -124,5 +124,5 @@ jobs: command: clippy args: --all-features --all-targets -- -D warnings -A incomplete-features env: - CARGO_INCREMENTAL: 1 - + # Seems necessary until https://github.com/rust-lang/rust/pull/115819 is merged. + CARGO_INCREMENTAL: 0 diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index b0f3056c..068b0bcb 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -6,6 +6,7 @@ use plonky2::hash::hash_types::RichField; use crate::arithmetic::arithmetic_stark; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::{self, BytePackingStark}; use crate::config::StarkConfig; use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; @@ -25,6 +26,7 @@ use crate::stark::Stark; #[derive(Clone)] pub struct AllStark, const D: usize> { pub arithmetic_stark: ArithmeticStark, + pub byte_packing_stark: BytePackingStark, pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, pub keccak_sponge_stark: KeccakSpongeStark, @@ -37,6 +39,7 @@ impl, const D: usize> Default for AllStark { fn default() -> Self { Self { arithmetic_stark: ArithmeticStark::default(), + byte_packing_stark: BytePackingStark::default(), cpu_stark: CpuStark::default(), keccak_stark: KeccakStark::default(), keccak_sponge_stark: KeccakSpongeStark::default(), @@ -51,6 +54,7 @@ impl, const D: usize> AllStark { pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { [ self.arithmetic_stark.num_permutation_batches(config), + self.byte_packing_stark.num_permutation_batches(config), self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), self.keccak_sponge_stark.num_permutation_batches(config), @@ -62,6 +66,7 @@ impl, const D: usize> AllStark { pub(crate) fn permutation_batch_sizes(&self) -> [usize; NUM_TABLES] { [ self.arithmetic_stark.permutation_batch_size(), + self.byte_packing_stark.permutation_batch_size(), self.cpu_stark.permutation_batch_size(), self.keccak_stark.permutation_batch_size(), self.keccak_sponge_stark.permutation_batch_size(), @@ -74,11 +79,12 @@ impl, const D: usize> AllStark { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Table { Arithmetic = 0, - Cpu = 1, - Keccak = 2, - KeccakSponge = 3, - Logic = 4, - Memory = 5, + BytePacking = 1, + Cpu = 2, + Keccak = 3, + KeccakSponge = 4, + Logic = 5, + Memory = 6, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -87,6 +93,7 @@ impl Table { pub(crate) fn all() -> [Self; NUM_TABLES] { [ Self::Arithmetic, + Self::BytePacking, Self::Cpu, Self::Keccak, Self::KeccakSponge, @@ -99,6 +106,7 @@ impl Table { pub(crate) fn all_cross_table_lookups() -> Vec> { vec![ ctl_arithmetic(), + ctl_byte_packing(), ctl_keccak_sponge(), ctl_keccak(), ctl_logic(), @@ -116,6 +124,28 @@ fn ctl_arithmetic() -> CrossTableLookup { ) } +fn ctl_byte_packing() -> CrossTableLookup { + let cpu_packing_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_byte_packing(), + Some(cpu_stark::ctl_filter_byte_packing()), + ); + let cpu_unpacking_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_byte_unpacking(), + Some(cpu_stark::ctl_filter_byte_unpacking()), + ); + let byte_packing_looked = TableWithColumns::new( + Table::BytePacking, 
+ byte_packing_stark::ctl_looked_data(), + Some(byte_packing_stark::ctl_looked_filter()), + ); + CrossTableLookup::new( + vec![cpu_packing_looking, cpu_unpacking_looking], + byte_packing_looked, + ) +} + fn ctl_keccak() -> CrossTableLookup { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, @@ -184,9 +214,17 @@ fn ctl_memory() -> CrossTableLookup { Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), ) }); + let byte_packing_ops = (0..32).map(|i| { + TableWithColumns::new( + Table::BytePacking, + byte_packing_stark::ctl_looking_memory(i), + Some(byte_packing_stark::ctl_looking_memory_filter(i)), + ) + }); let all_lookers = iter::once(cpu_memory_code_read) .chain(cpu_memory_gp_ops) .chain(keccak_sponge_reads) + .chain(byte_packing_ops) .collect(); let memory_looked = TableWithColumns::new( Table::Memory, diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 4695798a..5441cf27 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -27,10 +27,17 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// This is done by taking pairs of columns (x, y) of the arithmetic /// table and combining them as x + y*2^16 to ensure they equal the /// corresponding 32-bit number in the CPU table. -fn cpu_arith_data_link(ops: &[usize], regs: &[Range]) -> Vec> { +fn cpu_arith_data_link( + combined_ops: &[(usize, u8)], + regs: &[Range], +) -> Vec> { let limb_base = F::from_canonical_u64(1 << columns::LIMB_BITS); - let mut res = Column::singles(ops).collect_vec(); + let mut res = vec![Column::linear_combination( + combined_ops + .iter() + .map(|&(col, code)| (col, F::from_canonical_u8(code))), + )]; // The inner for loop below assumes N_LIMBS is even. const_assert!(columns::N_LIMBS % 2 == 0); @@ -49,21 +56,27 @@ fn cpu_arith_data_link(ops: &[usize], regs: &[Range]) -> Vec() -> TableWithColumns { - const ARITH_OPS: [usize; 14] = [ - columns::IS_ADD, - columns::IS_SUB, - columns::IS_MUL, - columns::IS_LT, - columns::IS_GT, - columns::IS_ADDFP254, - columns::IS_MULFP254, - columns::IS_SUBFP254, - columns::IS_ADDMOD, - columns::IS_MULMOD, - columns::IS_SUBMOD, - columns::IS_DIV, - columns::IS_MOD, - columns::IS_BYTE, + // We scale each filter flag with the associated opcode value. + // If an arithmetic operation is happening on the CPU side, + // the CTL will enforce that the reconstructed opcode value + // from the opcode bits matches. + const COMBINED_OPS: [(usize, u8); 16] = [ + (columns::IS_ADD, 0x01), + (columns::IS_MUL, 0x02), + (columns::IS_SUB, 0x03), + (columns::IS_DIV, 0x04), + (columns::IS_MOD, 0x06), + (columns::IS_ADDMOD, 0x08), + (columns::IS_MULMOD, 0x09), + (columns::IS_ADDFP254, 0x0c), + (columns::IS_MULFP254, 0x0d), + (columns::IS_SUBFP254, 0x0e), + (columns::IS_SUBMOD, 0x0f), + (columns::IS_LT, 0x10), + (columns::IS_GT, 0x11), + (columns::IS_BYTE, 0x1a), + (columns::IS_SHL, 0x1b), + (columns::IS_SHR, 0x1c), ]; const REGISTER_MAP: [Range; 4] = [ @@ -73,6 +86,8 @@ pub fn ctl_arithmetic_rows() -> TableWithColumns { columns::OUTPUT_REGISTER, ]; + let filter_column = Some(Column::sum(COMBINED_OPS.iter().map(|(c, _v)| *c))); + // Create the Arithmetic Table whose columns are those of the // operations listed in `ops` whose inputs and outputs are given // by `regs`, where each element of `regs` is a range of columns @@ -80,8 +95,8 @@ pub fn ctl_arithmetic_rows() -> TableWithColumns { // is used as the operation filter). 
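    // For instance, on a row whose only active flag is `IS_MUL`, the combined-ops
    // column evaluates to 1 * 0x02 = 0x02, i.e. exactly the MUL opcode, which the
    // CTL can then match against the opcode reconstructed from the CPU's opcode bits.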
TableWithColumns::new( Table::Arithmetic, - cpu_arith_data_link(&ARITH_OPS, &REGISTER_MAP), - Some(Column::sum(ARITH_OPS)), + cpu_arith_data_link(&COMBINED_OPS, &REGISTER_MAP), + filter_column, ) } diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index afdd5832..48e00f8e 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -36,8 +36,10 @@ pub(crate) const IS_SUBMOD: usize = IS_SUBFP254 + 1; pub(crate) const IS_LT: usize = IS_SUBMOD + 1; pub(crate) const IS_GT: usize = IS_LT + 1; pub(crate) const IS_BYTE: usize = IS_GT + 1; +pub(crate) const IS_SHL: usize = IS_BYTE + 1; +pub(crate) const IS_SHR: usize = IS_SHL + 1; -pub(crate) const START_SHARED_COLS: usize = IS_BYTE + 1; +pub(crate) const START_SHARED_COLS: usize = IS_SHR + 1; /// Within the Arithmetic Unit, there are shared columns which can be /// used by any arithmetic circuit, depending on which one is active diff --git a/evm/src/arithmetic/divmod.rs b/evm/src/arithmetic/divmod.rs index 4f2dd748..258c131f 100644 --- a/evm/src/arithmetic/divmod.rs +++ b/evm/src/arithmetic/divmod.rs @@ -45,7 +45,7 @@ pub(crate) fn generate( } match filter { - IS_DIV => { + IS_DIV | IS_SHR => { debug_assert!( lv[OUTPUT_REGISTER] .iter() @@ -104,11 +104,14 @@ pub(crate) fn eval_packed( nv: &[P; NUM_ARITH_COLUMNS], yield_constr: &mut ConstraintConsumer
<P>
, ) { + // Constrain IS_SHR independently, so that it doesn't impact the + // constraints when combining the flag with IS_DIV. + yield_constr.constraint_last_row(lv[IS_SHR]); eval_packed_divmod_helper( lv, nv, yield_constr, - lv[IS_DIV], + lv[IS_DIV] + lv[IS_SHR], OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -161,12 +164,14 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { + yield_constr.constraint_last_row(builder, lv[IS_SHR]); + let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]); eval_ext_circuit_divmod_helper( builder, lv, nv, yield_constr, - lv[IS_DIV], + div_shr_flag, OUTPUT_REGISTER, AUX_INPUT_REGISTER_0, ); @@ -209,6 +214,8 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + // Deactivate the SHR flag so that a DIV operation is not triggered. + lv[IS_SHR] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -240,6 +247,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; let input0 = U256::from(rng.gen::<[u8; 32]>()); @@ -300,6 +308,7 @@ mod tests { for op in MODULAR_OPS { lv[op] = F::ZERO; } + lv[IS_SHR] = F::ZERO; lv[op_filter] = F::ONE; let input0 = U256::from(rng.gen::<[u8; 32]>()); diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index d9d63a0b..bd6d56e8 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -27,15 +27,17 @@ pub(crate) enum BinaryOperator { MulFp254, SubFp254, Byte, + Shl, // simulated with MUL + Shr, // simulated with DIV } impl BinaryOperator { pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 { match self { BinaryOperator::Add => input0.overflowing_add(input1).0, - BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0, BinaryOperator::Sub => input0.overflowing_sub(input1).0, - BinaryOperator::Div => { + BinaryOperator::Div | BinaryOperator::Shr => { if input1.is_zero() { U256::zero() } else { @@ -77,6 +79,8 @@ impl BinaryOperator { BinaryOperator::MulFp254 => columns::IS_MULFP254, BinaryOperator::SubFp254 => columns::IS_SUBFP254, BinaryOperator::Byte => columns::IS_BYTE, + BinaryOperator::Shl => columns::IS_SHL, + BinaryOperator::Shr => columns::IS_SHR, } } } @@ -107,6 +111,7 @@ impl TernaryOperator { } } +/// An enum representing arithmetic operations that can be either binary or ternary. #[derive(Debug)] pub(crate) enum Operation { BinaryOperation { @@ -125,6 +130,21 @@ pub(crate) enum Operation { } impl Operation { + /// Create a binary operator with given inputs. + /// + /// NB: This works as you would expect, EXCEPT for SHL and SHR, + /// whose inputs need a small amount of preprocessing. Specifically, + /// to create `SHL(shift, value)`, call (note the reversal of + /// argument order): + /// + /// `Operation::binary(BinaryOperator::Shl, value, 1 << shift)` + /// + /// Similarly, to create `SHR(shift, value)`, call + /// + /// `Operation::binary(BinaryOperator::Shr, value, 1 << shift)` + /// + /// See witness/operation.rs::append_shift() for an example (indeed + /// the only call site for such inputs). 
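+    /// As a small illustrative sketch (not a compiled doctest, since `Operation`
+    /// is crate-private):
+    ///
+    /// ```ignore
+    /// // `5 << 3`, i.e. SHL(shift = 3, value = 5), is encoded as a MUL by 2^shift:
+    /// let op = Operation::binary(BinaryOperator::Shl, 5.into(), (1 << 3).into());
+    /// // `op` evaluates to 5 * 8 = 40, which is indeed 5 << 3.
+    /// ```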
pub(crate) fn binary(operator: BinaryOperator, input0: U256, input1: U256) -> Self { let result = operator.result(input0, input1); Self::BinaryOperation { @@ -164,6 +184,10 @@ impl Operation { /// use vectors because that's what utils::transpose (who consumes /// the result of this function as part of the range check code) /// expects. + /// + /// The `is_simulated` bool indicates whether we use a native arithmetic + /// operation or simulate one with another. This is used to distinguish + /// SHL and SHR operations that are simulated through MUL and DIV respectively. fn to_rows(&self) -> (Vec, Option>) { match *self { Operation::BinaryOperation { @@ -214,11 +238,11 @@ fn binary_op_to_rows( addcy::generate(&mut row, op.row_filter(), input0, input1); (row, None) } - BinaryOperator::Mul => { + BinaryOperator::Mul | BinaryOperator::Shl => { mul::generate(&mut row, input0, input1); (row, None) } - BinaryOperator::Div | BinaryOperator::Mod => { + BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => { let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result); (row, Some(nv)) diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index 597d4051..efb4d822 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -121,7 +121,7 @@ pub fn eval_packed_generic( ) { let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); - let is_mul = lv[IS_MUL]; + let is_mul = lv[IS_MUL] + lv[IS_SHL]; let input0_limbs = read_value::(lv, INPUT_REGISTER_0); let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); @@ -173,7 +173,7 @@ pub fn eval_ext_circuit, const D: usize>( lv: &[ExtensionTarget; NUM_ARITH_COLUMNS], yield_constr: &mut RecursiveConstraintConsumer, ) { - let is_mul = lv[IS_MUL]; + let is_mul = builder.add_extension(lv[IS_MUL], lv[IS_SHL]); let input0_limbs = read_value::(lv, INPUT_REGISTER_0); let input1_limbs = read_value::(lv, INPUT_REGISTER_1); let output_limbs = read_value::(lv, OUTPUT_REGISTER); @@ -229,6 +229,8 @@ mod tests { // if `IS_MUL == 0`, then the constraints should be met even // if all values are garbage. lv[IS_MUL] = F::ZERO; + // Deactivate the SHL flag so that a MUL operation is not triggered. + lv[IS_SHL] = F::ZERO; let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], diff --git a/evm/src/byte_packing/byte_packing_stark.rs b/evm/src/byte_packing/byte_packing_stark.rs new file mode 100644 index 00000000..aa6a2dcf --- /dev/null +++ b/evm/src/byte_packing/byte_packing_stark.rs @@ -0,0 +1,590 @@ +//! This crate enforces the correctness of reading and writing sequences +//! of bytes in Big-Endian ordering from and to the memory. +//! +//! The trace layout consists in N consecutive rows for an `N` byte sequence, +//! with the byte values being cumulatively written to the trace as they are +//! being processed. +//! +//! At row `i` of such a group (starting from 0), the `i`-th byte flag will be activated +//! (to indicate which byte we are going to be processing), but all bytes with index +//! 0 to `i` may have non-zero values, as they have already been processed. +//! +//! The length of a sequence is stored within each group of rows corresponding to that +//! sequence in a dedicated `SEQUENCE_LEN` column. At any row `i`, the remaining length +//! of the sequence being processed is retrieved from that column and the active byte flag +//! as: +//! +//! 
remaining_length = sequence_length - \sum_{i=0}^31 b[i] * i +//! +//! where b[i] is the `i`-th byte flag. +//! +//! Because of the discrepancy in endianness between the different tables, the byte sequences +//! are actually written in the trace in reverse order from the order they are provided. +//! As such, the memory virtual address for a group of rows corresponding to a sequence starts +//! with the final virtual address, corresponding to the final byte being read/written, and +//! is being decremented at each step. +//! +//! Note that, when writing a sequence of bytes to memory, both the `U256` value and the +//! corresponding sequence length are being read from the stack. Because of the endianness +//! discrepancy mentioned above, we first convert the value to a byte sequence in Little-Endian, +//! then resize the sequence to prune unneeded zeros before reverting the sequence order. +//! This means that the higher-order bytes will be thrown away during the process, if the value +//! is greater than 256^length, and as a result a different value will be stored in memory. + +use std::marker::PhantomData; + +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::timed; +use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; + +use super::NUM_BYTES; +use crate::byte_packing::columns::{ + index_bytes, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, BYTE_INDICES_COLS, IS_READ, + NUM_COLUMNS, RANGE_COUNTER, RC_COLS, SEQUENCE_END, TIMESTAMP, +}; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::Column; +use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; +use crate::stark::Stark; +use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +use crate::witness::memory::MemoryAddress; + +/// Strict upper bound for the individual bytes range-check. +const BYTE_RANGE_MAX: usize = 1usize << 8; + +pub(crate) fn ctl_looked_data() -> Vec> { + // Reconstruct the u32 limbs composing the final `U256` word + // being read/written from the underlying byte values. For each, + // we pack 4 consecutive bytes and shift them accordingly to + // obtain the corresponding limb. + let outputs: Vec> = (0..8) + .map(|i| { + let range = (value_bytes(i * 4)..value_bytes(i * 4) + 4).collect_vec(); + Column::linear_combination( + range + .iter() + .enumerate() + .map(|(j, &c)| (c, F::from_canonical_u64(1 << (8 * j)))), + ) + }) + .collect(); + + // This will correspond to the actual sequence length when the `SEQUENCE_END` flag is on. + let sequence_len: Column = Column::linear_combination( + (0..NUM_BYTES).map(|i| (index_bytes(i), F::from_canonical_usize(i + 1))), + ); + + Column::singles([ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]) + .chain([sequence_len]) + .chain(Column::singles(&[TIMESTAMP])) + .chain(outputs) + .collect() +} + +pub fn ctl_looked_filter() -> Column { + // The CPU table is only interested in our sequence end rows, + // since those contain the final limbs of our packed int. + Column::single(SEQUENCE_END) +} + +pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { + let mut res = + Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); + + // The i'th input byte being read/written. 
+ res.push(Column::single(value_bytes(i))); + + // Since we're reading a single byte, the higher limbs must be zero. + res.extend((1..8).map(|_| Column::zero())); + + res.push(Column::single(TIMESTAMP)); + + res +} + +/// CTL filter for reading/writing the `i`th byte of the byte sequence from/to memory. +pub(crate) fn ctl_looking_memory_filter(i: usize) -> Column { + Column::single(index_bytes(i)) +} + +/// Information about a byte packing operation needed for witness generation. +#[derive(Clone, Debug)] +pub(crate) struct BytePackingOp { + /// Whether this is a read (packing) or write (unpacking) operation. + pub(crate) is_read: bool, + + /// The base address at which inputs are read/written. + pub(crate) base_address: MemoryAddress, + + /// The timestamp at which inputs are read/written. + pub(crate) timestamp: usize, + + /// The byte sequence that was read/written. + /// Its length is required to be at most 32. + pub(crate) bytes: Vec, +} + +#[derive(Copy, Clone, Default)] +pub struct BytePackingStark { + pub(crate) f: PhantomData, +} + +impl, const D: usize> BytePackingStark { + pub(crate) fn generate_trace( + &self, + ops: Vec, + min_rows: usize, + timing: &mut TimingTree, + ) -> Vec> { + // Generate most of the trace in row-major form. + let trace_rows = timed!( + timing, + "generate trace rows", + self.generate_trace_rows(ops, min_rows) + ); + let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect(); + + let mut trace_cols = transpose(&trace_row_vecs); + self.generate_range_checks(&mut trace_cols); + + trace_cols.into_iter().map(PolynomialValues::new).collect() + } + + fn generate_trace_rows( + &self, + ops: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_COLUMNS]> { + let base_len: usize = ops.iter().map(|op| op.bytes.len()).sum(); + let num_rows = core::cmp::max(base_len.max(BYTE_RANGE_MAX), min_rows).next_power_of_two(); + let mut rows = Vec::with_capacity(num_rows); + + for op in ops { + rows.extend(self.generate_rows_for_op(op)); + } + + for _ in rows.len()..num_rows { + rows.push(self.generate_padding_row()); + } + + rows + } + + fn generate_rows_for_op(&self, op: BytePackingOp) -> Vec<[F; NUM_COLUMNS]> { + let BytePackingOp { + is_read, + base_address, + timestamp, + bytes, + } = op; + + let MemoryAddress { + context, + segment, + virt, + } = base_address; + + let mut rows = Vec::with_capacity(bytes.len()); + let mut row = [F::ZERO; NUM_COLUMNS]; + row[IS_READ] = F::from_bool(is_read); + + row[ADDR_CONTEXT] = F::from_canonical_usize(context); + row[ADDR_SEGMENT] = F::from_canonical_usize(segment); + // Because of the endianness, we start by the final virtual address value + // and decrement it at each step. Similarly, we process the byte sequence + // in reverse order. 
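+        // For example, writing bytes = [0xaa, 0xbb, 0xcc] at virt = 100 produces
+        // three rows with ADDR_VIRTUAL = 102 (byte 0xcc), then 101 (0xbb), then
+        // 100 (0xaa), with SEQUENCE_END set on that last row.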
+ row[ADDR_VIRTUAL] = F::from_canonical_usize(virt + bytes.len() - 1); + + row[TIMESTAMP] = F::from_canonical_usize(timestamp); + + for (i, &byte) in bytes.iter().rev().enumerate() { + if i == bytes.len() - 1 { + row[SEQUENCE_END] = F::ONE; + } + row[value_bytes(i)] = F::from_canonical_u8(byte); + row[index_bytes(i)] = F::ONE; + + rows.push(row.into()); + row[index_bytes(i)] = F::ZERO; + row[ADDR_VIRTUAL] -= F::ONE; + } + + rows + } + + fn generate_padding_row(&self) -> [F; NUM_COLUMNS] { + [F::ZERO; NUM_COLUMNS] + } + + /// Expects input in *column*-major layout + fn generate_range_checks(&self, cols: &mut Vec>) { + debug_assert!(cols.len() == NUM_COLUMNS); + + let n_rows = cols[0].len(); + debug_assert!(cols.iter().all(|col| col.len() == n_rows)); + + for i in 0..BYTE_RANGE_MAX { + cols[RANGE_COUNTER][i] = F::from_canonical_usize(i); + } + for i in BYTE_RANGE_MAX..n_rows { + cols[RANGE_COUNTER][i] = F::from_canonical_usize(BYTE_RANGE_MAX - 1); + } + + // For each column c in cols, generate the range-check + // permutations and put them in the corresponding range-check + // columns rc_c and rc_c+1. + for (i, rc_c) in (0..NUM_BYTES).zip(RC_COLS.step_by(2)) { + let c = value_bytes(i); + let (col_perm, table_perm) = permuted_cols(&cols[c], &cols[RANGE_COUNTER]); + cols[rc_c].copy_from_slice(&col_perm); + cols[rc_c + 1].copy_from_slice(&table_perm); + } + } + + /// There is only one `i` for which `vars.local_values[index_bytes(i)]` is non-zero, + /// and `i+1` is the current position: + fn get_active_position(&self, row: &[P; NUM_COLUMNS]) -> P + where + FE: FieldExtension, + P: PackedField, + { + (0..NUM_BYTES) + .map(|i| row[index_bytes(i)] * P::Scalar::from_canonical_usize(i + 1)) + .sum() + } + + /// Recursive version of `get_active_position`. + fn get_active_position_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + row: &[ExtensionTarget; NUM_COLUMNS], + ) -> ExtensionTarget { + let mut current_position = row[index_bytes(0)]; + + for i in 1..NUM_BYTES { + current_position = builder.mul_const_add_extension( + F::from_canonical_usize(i + 1), + row[index_bytes(i)], + current_position, + ); + } + + current_position + } +} + +impl, const D: usize> Stark for BytePackingStark { + const COLUMNS: usize = NUM_COLUMNS; + + fn eval_packed_generic( + &self, + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer
<P>
, + ) where + FE: FieldExtension, + P: PackedField, + { + // Range check all the columns + for col in RC_COLS.step_by(2) { + eval_lookups(vars, yield_constr, col, col + 1); + } + + let one = P::ONES; + + // We filter active columns by summing all the byte indices. + // Constraining each of them to be boolean is done later on below. + let current_filter = vars.local_values[BYTE_INDICES_COLS] + .iter() + .copied() + .sum::
<P>
(); + yield_constr.constraint(current_filter * (current_filter - one)); + + // The filter column must start by one. + yield_constr.constraint_first_row(current_filter - one); + + // The is_read flag must be boolean. + let current_is_read = vars.local_values[IS_READ]; + yield_constr.constraint(current_is_read * (current_is_read - one)); + + // Each byte index must be boolean. + for i in 0..NUM_BYTES { + let idx_i = vars.local_values[index_bytes(i)]; + yield_constr.constraint(idx_i * (idx_i - one)); + } + + // The sequence start flag column must start by one. + let current_sequence_start = vars.local_values[index_bytes(0)]; + yield_constr.constraint_first_row(current_sequence_start - one); + + // The sequence end flag must be boolean + let current_sequence_end = vars.local_values[SEQUENCE_END]; + yield_constr.constraint(current_sequence_end * (current_sequence_end - one)); + + // If filter is off, all flags and byte indices must be off. + let byte_indices = vars.local_values[BYTE_INDICES_COLS] + .iter() + .copied() + .sum::
<P>
(); + yield_constr.constraint( + (current_filter - one) * (current_is_read + current_sequence_end + byte_indices), + ); + + // Only padding rows have their filter turned off. + let next_filter = vars.next_values[BYTE_INDICES_COLS] + .iter() + .copied() + .sum::
<P>
(); + yield_constr.constraint_transition(next_filter * (next_filter - current_filter)); + + // Unless the current sequence end flag is activated, the is_read filter must remain unchanged. + let next_is_read = vars.next_values[IS_READ]; + yield_constr + .constraint_transition((current_sequence_end - one) * (next_is_read - current_is_read)); + + // If the sequence end flag is activated, the next row must be a new sequence or filter must be off. + let next_sequence_start = vars.next_values[index_bytes(0)]; + yield_constr.constraint_transition( + current_sequence_end * next_filter * (next_sequence_start - one), + ); + + // The active position in a byte sequence must increase by one on every row + // or be one on the next row (i.e. at the start of a new sequence). + let current_position = self.get_active_position(vars.local_values); + let next_position = self.get_active_position(vars.next_values); + yield_constr.constraint_transition( + next_filter * (next_position - one) * (next_position - current_position - one), + ); + + // The last row must be the end of a sequence or a padding row. + yield_constr.constraint_last_row(current_filter * (current_sequence_end - one)); + + // If the next position is one in an active row, the current end flag must be one. + yield_constr + .constraint_transition(next_filter * current_sequence_end * (next_position - one)); + + // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. + // The virtual address must decrement by one at each step of a sequence. + let current_context = vars.local_values[ADDR_CONTEXT]; + let next_context = vars.next_values[ADDR_CONTEXT]; + let current_segment = vars.local_values[ADDR_SEGMENT]; + let next_segment = vars.next_values[ADDR_SEGMENT]; + let current_virtual = vars.local_values[ADDR_VIRTUAL]; + let next_virtual = vars.next_values[ADDR_VIRTUAL]; + let current_timestamp = vars.local_values[TIMESTAMP]; + let next_timestamp = vars.next_values[TIMESTAMP]; + yield_constr.constraint_transition( + next_filter * (next_sequence_start - one) * (next_context - current_context), + ); + yield_constr.constraint_transition( + next_filter * (next_sequence_start - one) * (next_segment - current_segment), + ); + yield_constr.constraint_transition( + next_filter * (next_sequence_start - one) * (next_timestamp - current_timestamp), + ); + yield_constr.constraint_transition( + next_filter * (next_sequence_start - one) * (current_virtual - next_virtual - one), + ); + + // If not at the end of a sequence, each next byte must equal the current one + // when reading through the sequence, or the next byte index must be one. + for i in 0..NUM_BYTES { + let current_byte = vars.local_values[value_bytes(i)]; + let next_byte = vars.next_values[value_bytes(i)]; + let next_byte_index = vars.next_values[index_bytes(i)]; + yield_constr.constraint_transition( + (current_sequence_end - one) * (next_byte_index - one) * (next_byte - current_byte), + ); + } + } + + fn eval_ext_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + // Range check all the columns + for col in RC_COLS.step_by(2) { + eval_lookups_circuit(builder, vars, yield_constr, col, col + 1); + } + + // We filter active columns by summing all the byte indices. + // Constraining each of them to be boolean is done later on below. 
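+        // The x * x - x pattern below (built with `mul_sub_extension(x, x, x)`)
+        // vanishes exactly when x is 0 or 1; applied to the sum of the byte-index
+        // flags, it ensures that at most one index is active on each row.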
+ let current_filter = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); + let constraint = builder.mul_sub_extension(current_filter, current_filter, current_filter); + yield_constr.constraint(builder, constraint); + + // The filter column must start by one. + let constraint = builder.add_const_extension(current_filter, F::NEG_ONE); + yield_constr.constraint_first_row(builder, constraint); + + // The is_read flag must be boolean. + let current_is_read = vars.local_values[IS_READ]; + let constraint = + builder.mul_sub_extension(current_is_read, current_is_read, current_is_read); + yield_constr.constraint(builder, constraint); + + // Each byte index must be boolean. + for i in 0..NUM_BYTES { + let idx_i = vars.local_values[index_bytes(i)]; + let constraint = builder.mul_sub_extension(idx_i, idx_i, idx_i); + yield_constr.constraint(builder, constraint); + } + + // The sequence start flag column must start by one. + let current_sequence_start = vars.local_values[index_bytes(0)]; + let constraint = builder.add_const_extension(current_sequence_start, F::NEG_ONE); + yield_constr.constraint_first_row(builder, constraint); + + // The sequence end flag must be boolean + let current_sequence_end = vars.local_values[SEQUENCE_END]; + let constraint = builder.mul_sub_extension( + current_sequence_end, + current_sequence_end, + current_sequence_end, + ); + yield_constr.constraint(builder, constraint); + + // If filter is off, all flags and byte indices must be off. + let byte_indices = builder.add_many_extension(&vars.local_values[BYTE_INDICES_COLS]); + let constraint = builder.add_extension(current_sequence_end, byte_indices); + let constraint = builder.add_extension(constraint, current_is_read); + let constraint = builder.mul_sub_extension(constraint, current_filter, constraint); + yield_constr.constraint(builder, constraint); + + // Only padding rows have their filter turned off. + let next_filter = builder.add_many_extension(&vars.next_values[BYTE_INDICES_COLS]); + let constraint = builder.sub_extension(next_filter, current_filter); + let constraint = builder.mul_extension(next_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + + // Unless the current sequence end flag is activated, the is_read filter must remain unchanged. + let next_is_read = vars.next_values[IS_READ]; + let diff_is_read = builder.sub_extension(next_is_read, current_is_read); + let constraint = + builder.mul_sub_extension(diff_is_read, current_sequence_end, diff_is_read); + yield_constr.constraint_transition(builder, constraint); + + // If the sequence end flag is activated, the next row must be a new sequence or filter must be off. + let next_sequence_start = vars.next_values[index_bytes(0)]; + let constraint = builder.mul_sub_extension( + current_sequence_end, + next_sequence_start, + current_sequence_end, + ); + let constraint = builder.mul_extension(next_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + + // The active position in a byte sequence must increase by one on every row + // or be one on the next row (i.e. at the start of a new sequence). 
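+        // Concretely, this enforces
+        //     next_filter * (next_position - 1) * (next_position - current_position - 1) == 0,
+        // mirroring the packed constraint above.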
+ let current_position = self.get_active_position_circuit(builder, vars.local_values); + let next_position = self.get_active_position_circuit(builder, vars.next_values); + + let position_diff = builder.sub_extension(next_position, current_position); + let is_new_or_inactive = builder.mul_sub_extension(next_filter, next_position, next_filter); + let constraint = + builder.mul_sub_extension(is_new_or_inactive, position_diff, is_new_or_inactive); + yield_constr.constraint_transition(builder, constraint); + + // The last row must be the end of a sequence or a padding row. + let constraint = + builder.mul_sub_extension(current_filter, current_sequence_end, current_filter); + yield_constr.constraint_last_row(builder, constraint); + + // If the next position is one in an active row, the current end flag must be one. + let constraint = builder.mul_extension(next_filter, current_sequence_end); + let constraint = builder.mul_sub_extension(constraint, next_position, constraint); + yield_constr.constraint_transition(builder, constraint); + + // The context, segment and timestamp fields must remain unchanged throughout a byte sequence. + // The virtual address must decrement by one at each step of a sequence. + let current_context = vars.local_values[ADDR_CONTEXT]; + let next_context = vars.next_values[ADDR_CONTEXT]; + let current_segment = vars.local_values[ADDR_SEGMENT]; + let next_segment = vars.next_values[ADDR_SEGMENT]; + let current_virtual = vars.local_values[ADDR_VIRTUAL]; + let next_virtual = vars.next_values[ADDR_VIRTUAL]; + let current_timestamp = vars.local_values[TIMESTAMP]; + let next_timestamp = vars.next_values[TIMESTAMP]; + let addr_filter = builder.mul_sub_extension(next_filter, next_sequence_start, next_filter); + { + let constraint = builder.sub_extension(next_context, current_context); + let constraint = builder.mul_extension(addr_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + } + { + let constraint = builder.sub_extension(next_segment, current_segment); + let constraint = builder.mul_extension(addr_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + } + { + let constraint = builder.sub_extension(next_timestamp, current_timestamp); + let constraint = builder.mul_extension(addr_filter, constraint); + yield_constr.constraint_transition(builder, constraint); + } + { + let constraint = builder.sub_extension(current_virtual, next_virtual); + let constraint = builder.mul_sub_extension(addr_filter, constraint, addr_filter); + yield_constr.constraint_transition(builder, constraint); + } + + // If not at the end of a sequence, each next byte must equal the current one + // when reading through the sequence, or the next byte index must be one. 
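+        // That is, for each i:
+        //     (current_sequence_end - 1) * (next_byte_index - 1) * (next_byte - current_byte) == 0,
+        // which keeps previously written byte values stable through the sequence.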
+ for i in 0..NUM_BYTES { + let current_byte = vars.local_values[value_bytes(i)]; + let next_byte = vars.next_values[value_bytes(i)]; + let next_byte_index = vars.next_values[index_bytes(i)]; + let byte_diff = builder.sub_extension(next_byte, current_byte); + let constraint = builder.mul_sub_extension(byte_diff, next_byte_index, byte_diff); + let constraint = + builder.mul_sub_extension(constraint, current_sequence_end, constraint); + yield_constr.constraint_transition(builder, constraint); + } + } + + fn constraint_degree(&self) -> usize { + 3 + } +} + +#[cfg(test)] +pub(crate) mod tests { + use anyhow::Result; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::byte_packing::byte_packing_stark::BytePackingStark; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + + #[test] + fn test_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = BytePackingStark; + + let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = BytePackingStark; + + let stark = S { + f: Default::default(), + }; + test_stark_circuit_constraints::(stark) + } +} diff --git a/evm/src/byte_packing/columns.rs b/evm/src/byte_packing/columns.rs new file mode 100644 index 00000000..feb8f2e2 --- /dev/null +++ b/evm/src/byte_packing/columns.rs @@ -0,0 +1,45 @@ +//! Byte packing registers. + +use core::ops::Range; + +use crate::byte_packing::NUM_BYTES; + +/// 1 if this is a READ operation, and 0 if this is a WRITE operation. +pub(crate) const IS_READ: usize = 0; +/// 1 if this is the end of a sequence of bytes. +/// This is also used as filter for the CTL. +pub(crate) const SEQUENCE_END: usize = IS_READ + 1; + +pub(super) const BYTES_INDICES_START: usize = SEQUENCE_END + 1; +pub(crate) const fn index_bytes(i: usize) -> usize { + debug_assert!(i < NUM_BYTES); + BYTES_INDICES_START + i +} + +// Note: Those are used as filter for distinguishing active vs padding rows, +// and also to obtain the length of a sequence of bytes being processed. +pub(crate) const BYTE_INDICES_COLS: Range = + BYTES_INDICES_START..BYTES_INDICES_START + NUM_BYTES; + +pub(crate) const ADDR_CONTEXT: usize = BYTES_INDICES_START + NUM_BYTES; +pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; +pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; +pub(crate) const TIMESTAMP: usize = ADDR_VIRTUAL + 1; + +// 32 byte limbs hold a total of 256 bits. +const BYTES_VALUES_START: usize = TIMESTAMP + 1; +pub(crate) const fn value_bytes(i: usize) -> usize { + debug_assert!(i < NUM_BYTES); + BYTES_VALUES_START + i +} + +// We need one column for the table, then two columns for every value +// that needs to be range checked in the trace (all written bytes), +// namely the permutation of the column and the permutation of the range. +// The two permutations associated to the byte in column i will be in +// columns RC_COLS[2i] and RC_COLS[2i+1]. 
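+// For example, the byte in column value_bytes(0) has its permuted copy in
+// RC_COLS[0] and the matching permutation of RANGE_COUNTER in RC_COLS[1],
+// as filled in by generate_range_checks() in byte_packing_stark.rs.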
+pub(crate) const RANGE_COUNTER: usize = BYTES_VALUES_START + NUM_BYTES; +pub(crate) const NUM_RANGE_CHECK_COLS: usize = 1 + 2 * NUM_BYTES; +pub(crate) const RC_COLS: Range = RANGE_COUNTER + 1..RANGE_COUNTER + NUM_RANGE_CHECK_COLS; + +pub(crate) const NUM_COLUMNS: usize = RANGE_COUNTER + NUM_RANGE_CHECK_COLS; diff --git a/evm/src/byte_packing/mod.rs b/evm/src/byte_packing/mod.rs new file mode 100644 index 00000000..7cc93374 --- /dev/null +++ b/evm/src/byte_packing/mod.rs @@ -0,0 +1,9 @@ +//! Byte packing / unpacking unit for the EVM. +//! +//! This module handles reading / writing to memory byte sequences of +//! length at most 32 in Big-Endian ordering. + +pub mod byte_packing_stark; +pub mod columns; + +pub(crate) const NUM_BYTES: usize = 32; diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs index 66f88d3a..4aee617c 100644 --- a/evm/src/cpu/bootstrap_kernel.rs +++ b/evm/src/cpu/bootstrap_kernel.rs @@ -25,6 +25,7 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState for chunk in &KERNEL.code.iter().enumerate().chunks(NUM_GP_CHANNELS) { let mut cpu_row = CpuColumnsView::default(); cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + cpu_row.is_bootstrap_kernel = F::ONE; // Write this chunk to memory, while simultaneously packing its bytes into a u32 word. for (channel, (addr, &byte)) in chunk.enumerate() { @@ -39,6 +40,7 @@ pub(crate) fn generate_bootstrap_kernel(state: &mut GenerationState let mut final_cpu_row = CpuColumnsView::default(); final_cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + final_cpu_row.is_bootstrap_kernel = F::ONE; final_cpu_row.is_keccak_sponge = F::ONE; // The Keccak sponge CTL uses memory value columns for its inputs and outputs. final_cpu_row.mem_channels[0].value[0] = F::ZERO; // context @@ -64,8 +66,8 @@ pub(crate) fn eval_bootstrap_kernel>( let next_values: &CpuColumnsView<_> = vars.next_values.borrow(); // IS_BOOTSTRAP_KERNEL must have an init value of 1, a final value of 0, and a delta in {0, -1}. - let local_is_bootstrap = P::ONES - local_values.op.into_iter().sum::
<P>
(); - let next_is_bootstrap = P::ONES - next_values.op.into_iter().sum::
<P>
(); + let local_is_bootstrap = local_values.is_bootstrap_kernel; + let next_is_bootstrap = next_values.is_bootstrap_kernel; yield_constr.constraint_first_row(local_is_bootstrap - P::ONES); yield_constr.constraint_last_row(local_is_bootstrap); let delta_is_bootstrap = next_is_bootstrap - local_is_bootstrap; @@ -111,10 +113,8 @@ pub(crate) fn eval_bootstrap_kernel_circuit, const let one = builder.one_extension(); // IS_BOOTSTRAP_KERNEL must have an init value of 1, a final value of 0, and a delta in {0, -1}. - let local_is_bootstrap = builder.add_many_extension(local_values.op.iter()); - let local_is_bootstrap = builder.sub_extension(one, local_is_bootstrap); - let next_is_bootstrap = builder.add_many_extension(next_values.op.iter()); - let next_is_bootstrap = builder.sub_extension(one, next_is_bootstrap); + let local_is_bootstrap = local_values.is_bootstrap_kernel; + let next_is_bootstrap = next_values.is_bootstrap_kernel; let constraint = builder.sub_extension(local_is_bootstrap, one); yield_constr.constraint_first_row(builder, constraint); yield_constr.constraint_last_row(builder, local_is_bootstrap); diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 134ab02b..fecc8df9 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -35,6 +35,9 @@ pub struct MemoryChannelView { #[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] pub struct CpuColumnsView { + /// Filter. 1 if the row is part of bootstrapping the kernel code, 0 otherwise. + pub is_bootstrap_kernel: T, + /// If CPU cycle: Current context. // TODO: this is currently unconstrained pub context: T, diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 81d8414a..d4d753f7 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -7,33 +7,17 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; #[repr(C)] #[derive(Clone, Copy, Eq, PartialEq, Debug)] pub struct OpsColumnsView { - // TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag - pub add: T, - pub mul: T, - pub sub: T, - pub div: T, - pub mod_: T, - // TODO: combine ADDMOD, MULMOD and SUBMOD into one flag - pub addmod: T, - pub mulmod: T, - pub addfp254: T, - pub mulfp254: T, - pub subfp254: T, - pub submod: T, - pub lt: T, - pub gt: T, - pub eq_iszero: T, // Combines EQ and ISZERO flags. - pub logic_op: T, // Combines AND, OR and XOR flags. + pub binary_op: T, // Combines ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags. + pub ternary_op: T, // Combines ADDMOD, MULMOD and SUBMOD flags. + pub fp254_op: T, // Combines ADD_FP254, MUL_FP254 and SUB_FP254 flags. + pub eq_iszero: T, // Combines EQ and ISZERO flags. + pub logic_op: T, // Combines AND, OR and XOR flags. pub not: T, - pub byte: T, - // TODO: combine SHL and SHR into one flag - pub shl: T, - pub shr: T, + pub shift: T, // Combines SHL and SHR flags. pub keccak_general: T, pub prover_input: T, pub pop: T, - // TODO: combine JUMP and JUMPI into one flag - pub jumps: T, // Note: This column must be 0 when is_cpu_cycle = 0. + pub jumps: T, // Combines JUMP and JUMPI flags. 
pub pc: T, pub jumpdest: T, pub push0: T, @@ -41,10 +25,10 @@ pub struct OpsColumnsView { pub dup: T, pub swap: T, pub context_op: T, + pub mstore_32bytes: T, + pub mload_32bytes: T, pub exit_kernel: T, - // TODO: combine MLOAD_GENERAL and MSTORE_GENERAL into one flag - pub mload_general: T, - pub mstore_general: T, + pub m_op_general: T, pub syscall: T, pub exception: T, diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 0bea5c7c..9c17367a 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,24 +8,14 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 28] = [ - COL_MAP.op.add, - COL_MAP.op.mul, - COL_MAP.op.sub, - COL_MAP.op.div, - COL_MAP.op.mod_, - COL_MAP.op.addmod, - COL_MAP.op.mulmod, - COL_MAP.op.addfp254, - COL_MAP.op.mulfp254, - COL_MAP.op.subfp254, - COL_MAP.op.lt, - COL_MAP.op.gt, +const NATIVE_INSTRUCTIONS: [usize; 17] = [ + COL_MAP.op.binary_op, + COL_MAP.op.ternary_op, + COL_MAP.op.fp254_op, COL_MAP.op.eq_iszero, COL_MAP.op.logic_op, COL_MAP.op.not, - COL_MAP.op.shl, - COL_MAP.op.shr, + COL_MAP.op.shift, COL_MAP.op.keccak_general, COL_MAP.op.prover_input, COL_MAP.op.pop, @@ -39,20 +29,14 @@ const NATIVE_INSTRUCTIONS: [usize; 28] = [ COL_MAP.op.swap, COL_MAP.op.context_op, // not EXIT_KERNEL (performs a jump) - COL_MAP.op.mload_general, - COL_MAP.op.mstore_general, + COL_MAP.op.m_op_general, // not SYSCALL (performs a jump) // not exceptions (also jump) ]; -pub(crate) fn get_halt_pcs() -> (F, F) { - let halt_pc0 = KERNEL.global_labels["halt_pc0"]; - let halt_pc1 = KERNEL.global_labels["halt_pc1"]; - - ( - F::from_canonical_usize(halt_pc0), - F::from_canonical_usize(halt_pc1), - ) +pub(crate) fn get_halt_pc() -> F { + let halt_pc = KERNEL.global_labels["halt"]; + F::from_canonical_usize(halt_pc) } pub(crate) fn get_start_pc() -> F { @@ -68,8 +52,15 @@ pub fn eval_packed_generic( ) { let is_cpu_cycle: P = COL_MAP.op.iter().map(|&col_i| lv[col_i]).sum(); let is_cpu_cycle_next: P = COL_MAP.op.iter().map(|&col_i| nv[col_i]).sum(); - // Once we start executing instructions, then we continue until the end of the table. - yield_constr.constraint_transition(is_cpu_cycle * (is_cpu_cycle_next - P::ONES)); + + let next_halt_state = P::ONES - nv.is_bootstrap_kernel - is_cpu_cycle_next; + + // Once we start executing instructions, then we continue until the end of the table + // or we reach dummy padding rows. This, along with the constraints on the first row, + // enforces that operation flags and the halt flag are mutually exclusive over the entire + // CPU trace. + yield_constr + .constraint_transition(is_cpu_cycle * (is_cpu_cycle_next + next_halt_state - P::ONES)); // If a row is a CPU cycle and executing a native instruction (implemented as a table row; not // microcoded) then the program counter is incremented by 1 to obtain the next row's program @@ -90,16 +81,6 @@ pub fn eval_packed_generic( yield_constr.constraint_transition(is_last_noncpu_cycle * pc_diff); yield_constr.constraint_transition(is_last_noncpu_cycle * (nv.is_kernel_mode - P::ONES)); yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len); - - // The last row must be a CPU cycle row. - yield_constr.constraint_last_row(is_cpu_cycle - P::ONES); - // Also, the last row's `program_counter` must be inside the `halt` infinite loop. 
Note that - // that loop consists of two instructions, so we must check for `halt` and `halt_inner` labels. - let (halt_pc0, halt_pc1) = get_halt_pcs::(); - yield_constr - .constraint_last_row((lv.program_counter - halt_pc0) * (lv.program_counter - halt_pc1)); - // Finally, the last row must be in kernel mode. - yield_constr.constraint_last_row(lv.is_kernel_mode - P::ONES); } pub fn eval_ext_circuit, const D: usize>( @@ -108,11 +89,21 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { + let one = builder.one_extension(); + let is_cpu_cycle = builder.add_many_extension(COL_MAP.op.iter().map(|&col_i| lv[col_i])); let is_cpu_cycle_next = builder.add_many_extension(COL_MAP.op.iter().map(|&col_i| nv[col_i])); - // Once we start executing instructions, then we continue until the end of the table. + + let next_halt_state = builder.add_extension(nv.is_bootstrap_kernel, is_cpu_cycle_next); + let next_halt_state = builder.sub_extension(one, next_halt_state); + + // Once we start executing instructions, then we continue until the end of the table + // or we reach dummy padding rows. This, along with the constraints on the first row, + // enforces that operation flags and the halt flag are mutually exclusive over the entire + // CPU trace. { - let constr = builder.mul_sub_extension(is_cpu_cycle, is_cpu_cycle_next, is_cpu_cycle); + let constr = builder.add_extension(is_cpu_cycle_next, next_halt_state); + let constr = builder.mul_sub_extension(is_cpu_cycle, constr, is_cpu_cycle); yield_constr.constraint_transition(builder, constr); } @@ -155,30 +146,4 @@ pub fn eval_ext_circuit, const D: usize>( let kernel_constr = builder.mul_extension(is_last_noncpu_cycle, nv.stack_len); yield_constr.constraint_transition(builder, kernel_constr); } - - // The last row must be a CPU cycle row. - { - let one = builder.one_extension(); - let constr = builder.sub_extension(is_cpu_cycle, one); - yield_constr.constraint_last_row(builder, constr); - } - // Also, the last row's `program_counter` must be inside the `halt` infinite loop. Note that - // that loop consists of two instructions, so we must check for `halt` and `halt_inner` labels. - { - let (halt_pc0, halt_pc1) = get_halt_pcs(); - let halt_pc0_target = builder.constant_extension(halt_pc0); - let halt_pc1_target = builder.constant_extension(halt_pc1); - - let halt_pc0_offset = builder.sub_extension(lv.program_counter, halt_pc0_target); - let halt_pc1_offset = builder.sub_extension(lv.program_counter, halt_pc1_target); - let constr = builder.mul_extension(halt_pc0_offset, halt_pc1_offset); - - yield_constr.constraint_last_row(builder, constr); - } - // Finally, the last row must be in kernel mode. 
- { - let one = builder.one_extension(); - let constr = builder.sub_extension(lv.is_kernel_mode, one); - yield_constr.constraint_last_row(builder, constr); - } } diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 7fd0c76f..bd2fcf19 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -8,6 +8,7 @@ use plonky2::field::packed::PackedField; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use super::halt; use crate::all_stark::Table; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; @@ -48,9 +49,8 @@ pub fn ctl_filter_keccak_sponge() -> Column { /// Create the vector of Columns corresponding to the two inputs and /// one output of a binary operation. -fn ctl_data_binops(ops: &[usize]) -> Vec> { - let mut res = Column::singles(ops).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_channels[0].value)); +fn ctl_data_binops() -> Vec> { + let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[1].value)); res.extend(Column::singles( COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value, @@ -70,10 +70,9 @@ fn ctl_data_binops(ops: &[usize]) -> Vec> { /// case of shift operations, which will skip the first memory channel and use the /// next three as ternary inputs. Because both `MUL` and `DIV` are binary operations, /// the last memory channel used for the inputs will be safely ignored. -fn ctl_data_ternops(ops: &[usize], is_shift: bool) -> Vec> { +fn ctl_data_ternops(is_shift: bool) -> Vec> { let offset = is_shift as usize; - let mut res = Column::singles(ops).collect_vec(); - res.extend(Column::singles(COL_MAP.mem_channels[offset].value)); + let mut res = Column::singles(COL_MAP.mem_channels[offset].value).collect_vec(); res.extend(Column::singles(COL_MAP.mem_channels[offset + 1].value)); res.extend(Column::singles(COL_MAP.mem_channels[offset + 2].value)); res.extend(Column::singles( @@ -85,7 +84,7 @@ fn ctl_data_ternops(ops: &[usize], is_shift: bool) -> Vec> { pub fn ctl_data_logic() -> Vec> { // Instead of taking single columns, we reconstruct the entire opcode value directly. let mut res = vec![Column::le_bits(COL_MAP.opcode_bits)]; - res.extend(ctl_data_binops(&[])); + res.extend(ctl_data_binops()); res } @@ -94,22 +93,9 @@ pub fn ctl_filter_logic() -> Column { } pub fn ctl_arithmetic_base_rows() -> TableWithColumns { - const OPS: [usize; 14] = [ - COL_MAP.op.add, - COL_MAP.op.sub, - COL_MAP.op.mul, - COL_MAP.op.lt, - COL_MAP.op.gt, - COL_MAP.op.addfp254, - COL_MAP.op.mulfp254, - COL_MAP.op.subfp254, - COL_MAP.op.addmod, - COL_MAP.op.mulmod, - COL_MAP.op.submod, - COL_MAP.op.div, - COL_MAP.op.mod_, - COL_MAP.op.byte, - ]; + // Instead of taking single columns, we reconstruct the entire opcode value directly. + let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; + columns.extend(ctl_data_ternops(false)); // Create the CPU Table whose columns are those with the three // inputs and one output of the ternary operations listed in `ops` // (also `ops` is used as the operation filter). The list of @@ -117,40 +103,59 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { // the third input. 
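    // Note that `Column::le_bits(COL_MAP.opcode_bits)` reconstructs the opcode
    // as sum_i bit_i * 2^i, so a single CTL column now carries what previously
    // required fourteen separate operation-flag columns.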
TableWithColumns::new( Table::Cpu, - ctl_data_ternops(&OPS, false), - Some(Column::sum(OPS)), + columns, + Some(Column::sum([ + COL_MAP.op.binary_op, + COL_MAP.op.fp254_op, + COL_MAP.op.ternary_op, + ])), ) } pub fn ctl_arithmetic_shift_rows() -> TableWithColumns { - const OPS: [usize; 14] = [ - COL_MAP.op.add, - COL_MAP.op.sub, - // SHL is interpreted as MUL on the arithmetic side - COL_MAP.op.shl, - COL_MAP.op.lt, - COL_MAP.op.gt, - COL_MAP.op.addfp254, - COL_MAP.op.mulfp254, - COL_MAP.op.subfp254, - COL_MAP.op.addmod, - COL_MAP.op.mulmod, - COL_MAP.op.submod, - // SHR is interpreted as DIV on the arithmetic side - COL_MAP.op.shr, - COL_MAP.op.mod_, - COL_MAP.op.byte, - ]; + // Instead of taking single columns, we reconstruct the entire opcode value directly. + let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)]; + columns.extend(ctl_data_ternops(true)); // Create the CPU Table whose columns are those with the three // inputs and one output of the ternary operations listed in `ops` // (also `ops` is used as the operation filter). The list of // operations includes binary operations which will simply ignore // the third input. - TableWithColumns::new( - Table::Cpu, - ctl_data_ternops(&OPS, true), - Some(Column::sum([COL_MAP.op.shl, COL_MAP.op.shr])), - ) + TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift))) +} + +pub fn ctl_data_byte_packing() -> Vec> { + ctl_data_keccak_sponge() +} + +pub fn ctl_filter_byte_packing() -> Column { + Column::single(COL_MAP.op.mload_32bytes) +} + +pub fn ctl_data_byte_unpacking() -> Vec> { + // When executing MSTORE_32BYTES, the GP memory channels are used as follows: + // GP channel 0: stack[-1] = context + // GP channel 1: stack[-2] = segment + // GP channel 2: stack[-3] = virt + // GP channel 3: stack[-4] = val + // GP channel 4: stack[-5] = len + let context = Column::single(COL_MAP.mem_channels[0].value[0]); + let segment = Column::single(COL_MAP.mem_channels[1].value[0]); + let virt = Column::single(COL_MAP.mem_channels[2].value[0]); + let val = Column::singles(COL_MAP.mem_channels[3].value); + let len = Column::single(COL_MAP.mem_channels[4].value[0]); + + let num_channels = F::from_canonical_usize(NUM_CHANNELS); + let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + + let mut res = vec![context, segment, virt, len, timestamp]; + res.extend(val); + + res +} + +pub fn ctl_filter_byte_unpacking() -> Column { + Column::single(COL_MAP.op.mstore_32bytes) } pub const MEM_CODE_CHANNEL_IDX: usize = 0; @@ -240,15 +245,16 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark(lv: &mut CpuColumnsView) { @@ -97,6 +100,10 @@ pub fn generate(lv: &mut CpuColumnsView) { let flag = available && opcode_match; lv[col] = F::from_bool(flag); } + + if opcode == 0xfb || opcode == 0xfc { + lv.op.m_op_general = F::from_bool(kernel); + } } /// Break up an opcode (which is 8 bits long) into its eight bits. @@ -132,17 +139,17 @@ pub fn eval_packed_generic( let flag = lv[flag_col]; yield_constr.constraint(flag * (flag - P::ONES)); } - // Manually check the logic_op flag combining AND, OR and XOR. - let flag = lv.op.logic_op; - yield_constr.constraint(flag * (flag - P::ONES)); + // Also check that the combined instruction flags are valid. + for flag_idx in COMBINED_OPCODES { + yield_constr.constraint(lv[flag_idx] * (lv[flag_idx] - P::ONES)); + } - // Now check that they sum to 0 or 1. - // Includes the logic_op flag encompassing AND, OR and XOR opcodes. 
+ // Now check that they sum to 0 or 1, including the combined flags. let flag_sum: P = OPCODES .into_iter() .map(|(_, _, _, flag_col)| lv[flag_col]) - .sum::
<P>
() - + lv.op.logic_op; + .chain(COMBINED_OPCODES.map(|op| lv[op])) + .sum::
<P>
(); yield_constr.constraint(flag_sum * (flag_sum - P::ONES)); // Finally, classify all opcodes, together with the kernel flag, into blocks @@ -171,6 +178,20 @@ pub fn eval_packed_generic( // correct mode. yield_constr.constraint(lv[col] * (unavailable + opcode_mismatch)); } + + // Manually check lv.op.m_op_constr + let opcode: P = lv + .opcode_bits + .into_iter() + .enumerate() + .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) + .sum(); + yield_constr.constraint((P::ONES - kernel_mode) * lv.op.m_op_general); + + let m_op_constr = (opcode - P::Scalar::from_canonical_usize(0xfb_usize)) + * (opcode - P::Scalar::from_canonical_usize(0xfc_usize)) + * lv.op.m_op_general; + yield_constr.constraint(m_op_constr); } pub fn eval_ext_circuit, const D: usize>( @@ -202,15 +223,16 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_sub_extension(flag, flag, flag); yield_constr.constraint(builder, constr); } - // Manually check the logic_op flag combining AND, OR and XOR. - let flag = lv.op.logic_op; - let constr = builder.mul_sub_extension(flag, flag, flag); - yield_constr.constraint(builder, constr); + // Also check that the combined instruction flags are valid. + for flag_idx in COMBINED_OPCODES { + let constr = builder.mul_sub_extension(lv[flag_idx], lv[flag_idx], lv[flag_idx]); + yield_constr.constraint(builder, constr); + } - // Now check that they sum to 0 or 1. - // Includes the logic_op flag encompassing AND, OR and XOR opcodes. + // Now check that they sum to 0 or 1, including the combined flags. { - let mut flag_sum = lv.op.logic_op; + let mut flag_sum = + builder.add_many_extension(COMBINED_OPCODES.into_iter().map(|idx| lv[idx])); for (_, _, _, flag_col) in OPCODES { let flag = lv[flag_col]; flag_sum = builder.add_extension(flag_sum, flag); @@ -248,4 +270,28 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_extension(lv[col], constr); yield_constr.constraint(builder, constr); } + + // Manually check lv.op.m_op_constr + let opcode = lv + .opcode_bits + .into_iter() + .rev() + .fold(builder.zero_extension(), |cumul, bit| { + builder.mul_const_add_extension(F::TWO, cumul, bit) + }); + + let mload_opcode = builder.constant_extension(F::Extension::from_canonical_usize(0xfb_usize)); + let mstore_opcode = builder.constant_extension(F::Extension::from_canonical_usize(0xfc_usize)); + + let one_extension = builder.constant_extension(F::Extension::ONE); + let is_not_kernel_mode = builder.sub_extension(one_extension, kernel_mode); + let constr = builder.mul_extension(is_not_kernel_mode, lv.op.m_op_general); + yield_constr.constraint(builder, constr); + + let mload_constr = builder.sub_extension(opcode, mload_opcode); + let mstore_constr = builder.sub_extension(opcode, mstore_opcode); + let mut m_op_constr = builder.mul_extension(mload_constr, mstore_constr); + m_op_constr = builder.mul_extension(m_op_constr, lv.op.m_op_general); + + yield_constr.constraint(builder, m_op_constr); } diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 61690005..51f375c0 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -19,25 +19,13 @@ const G_MID: Option = Some(8); const G_HIGH: Option = Some(10); const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { - add: G_VERYLOW, - mul: G_LOW, - sub: G_VERYLOW, - div: G_LOW, - mod_: G_LOW, - addmod: G_MID, - mulmod: G_MID, - addfp254: KERNEL_ONLY_INSTR, - mulfp254: KERNEL_ONLY_INSTR, - subfp254: KERNEL_ONLY_INSTR, - submod: KERNEL_ONLY_INSTR, - lt: G_VERYLOW, - gt: G_VERYLOW, + binary_op: None, // This is handled 
manually below + ternary_op: None, // This is handled manually below + fp254_op: KERNEL_ONLY_INSTR, eq_iszero: G_VERYLOW, logic_op: G_VERYLOW, not: G_VERYLOW, - byte: G_VERYLOW, - shl: G_VERYLOW, - shr: G_VERYLOW, + shift: G_VERYLOW, keccak_general: KERNEL_ONLY_INSTR, prover_input: KERNEL_ONLY_INSTR, pop: G_BASE, @@ -49,9 +37,10 @@ const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { dup: G_VERYLOW, swap: G_VERYLOW, context_op: KERNEL_ONLY_INSTR, + mstore_32bytes: KERNEL_ONLY_INSTR, + mload_32bytes: KERNEL_ONLY_INSTR, exit_kernel: None, - mload_general: KERNEL_ONLY_INSTR, - mstore_general: KERNEL_ONLY_INSTR, + m_op_general: KERNEL_ONLY_INSTR, syscall: None, exception: None, }; @@ -95,6 +84,21 @@ fn eval_packed_accumulate( let jump_gas_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) + lv.opcode_bits[0] * P::Scalar::from_canonical_u32(G_HIGH.unwrap() - G_MID.unwrap()); yield_constr.constraint_transition(lv.op.jumps * (nv.gas - lv.gas - jump_gas_cost)); + + // For binary_ops. + // MUL, DIV and MOD are differentiated from ADD, SUB, LT, GT and BYTE by their first and fifth bits set to 0. + let cost_filter = lv.opcode_bits[0] + lv.opcode_bits[4] - lv.opcode_bits[0] * lv.opcode_bits[4]; + let binary_op_cost = P::Scalar::from_canonical_u32(G_LOW.unwrap()) + + cost_filter + * (P::Scalar::from_canonical_u32(G_VERYLOW.unwrap()) + - P::Scalar::from_canonical_u32(G_LOW.unwrap())); + yield_constr.constraint_transition(lv.op.binary_op * (nv.gas - lv.gas - binary_op_cost)); + + // For ternary_ops. + // SUBMOD is differentiated by its second bit set to 1. + let ternary_op_cost = P::Scalar::from_canonical_u32(G_MID.unwrap()) + - lv.opcode_bits[1] * P::Scalar::from_canonical_u32(G_MID.unwrap()); + yield_constr.constraint_transition(lv.op.ternary_op * (nv.gas - lv.gas - ternary_op_cost)); } fn eval_packed_init( @@ -184,6 +188,41 @@ fn eval_ext_circuit_accumulate, const D: usize>( let gas_diff = builder.sub_extension(nv_lv_diff, jump_gas_cost); let constr = builder.mul_extension(filter, gas_diff); yield_constr.constraint_transition(builder, constr); + + // For binary_ops. + // MUL, DIV and MOD are differentiated from ADD, SUB, LT, GT and BYTE by their first and fifth bits set to 0. + let filter = lv.op.binary_op; + let cost_filter = { + let a = builder.add_extension(lv.opcode_bits[0], lv.opcode_bits[4]); + let b = builder.mul_extension(lv.opcode_bits[0], lv.opcode_bits[4]); + builder.sub_extension(a, b) + }; + let binary_op_cost = builder.mul_const_extension( + F::from_canonical_u32(G_VERYLOW.unwrap()) - F::from_canonical_u32(G_LOW.unwrap()), + cost_filter, + ); + let binary_op_cost = + builder.add_const_extension(binary_op_cost, F::from_canonical_u32(G_LOW.unwrap())); + + let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); + let gas_diff = builder.sub_extension(nv_lv_diff, binary_op_cost); + let constr = builder.mul_extension(filter, gas_diff); + yield_constr.constraint_transition(builder, constr); + + // For ternary_ops. + // SUBMOD is differentiated by its second bit set to 1. 
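A standalone sketch of the opcode-bit arithmetic these gas constraints (and the m_op_general check in decode.rs above) rely on, with plain integers standing in for field elements. The opcode values are the standard EVM ones; everything else here is illustrative, not code from the diff.

// Recompute an opcode from its little-endian bit decomposition, as done
// with `opcode_bits`: opcode = sum(bit_i * 2^i).
fn opcode_from_bits(bits: [i64; 8]) -> i64 {
    bits.iter().enumerate().map(|(i, &b)| b << i).sum()
}

fn main() {
    // 0xfb (MLOAD_GENERAL) decomposes as (1, 1, 0, 1, 1, 1, 1, 1), low bit first.
    assert_eq!(opcode_from_bits([1, 1, 0, 1, 1, 1, 1, 1]), 0xfb);

    // (opcode - 0xfb)(opcode - 0xfc) vanishes exactly on MLOAD_GENERAL and
    // MSTORE_GENERAL, so multiplying it by the shared m_op_general flag
    // forces that flag to zero on every other opcode.
    for op in [0xfbi64, 0xfc] {
        assert_eq!((op - 0xfb) * (op - 0xfc), 0);
    }
    assert_ne!((0xfai64 - 0xfb) * (0xfa - 0xfc), 0); // any other opcode

    const G_LOW: i64 = 5;
    const G_VERYLOW: i64 = 3;
    const G_MID: i64 = 8;

    // binary_op gas: MUL (0x02), DIV (0x04) and MOD (0x06) have bits 0 and 4
    // clear and cost G_LOW; ADD, SUB, LT, GT and BYTE cost G_VERYLOW.
    // OR(bit0, bit4) is expressed as the degree-2 polynomial b0 + b4 - b0*b4.
    for (opcode, expected) in [
        (0x01i64, G_VERYLOW), // ADD
        (0x02, G_LOW),        // MUL
        (0x03, G_VERYLOW),    // SUB
        (0x04, G_LOW),        // DIV
        (0x06, G_LOW),        // MOD
        (0x10, G_VERYLOW),    // LT
        (0x11, G_VERYLOW),    // GT
        (0x1a, G_VERYLOW),    // BYTE
    ] {
        let (bit0, bit4) = (opcode & 1, (opcode >> 4) & 1);
        let cost_filter = bit0 + bit4 - bit0 * bit4;
        assert_eq!(G_LOW + cost_filter * (G_VERYLOW - G_LOW), expected);
    }

    // ternary_op gas: G_MID unless bit 1 is set (SUBMOD, kernel-only, charged 0).
    for (bit1, expected) in [(0i64, G_MID), (1, 0)] {
        assert_eq!(G_MID - bit1 * G_MID, expected);
    }
}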
+ let filter = lv.op.ternary_op; + let ternary_op_cost = builder.mul_const_extension( + F::from_canonical_u32(G_MID.unwrap()).neg(), + lv.opcode_bits[1], + ); + let ternary_op_cost = + builder.add_const_extension(ternary_op_cost, F::from_canonical_u32(G_MID.unwrap())); + + let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas); + let gas_diff = builder.sub_extension(nv_lv_diff, ternary_op_cost); + let constr = builder.mul_extension(filter, gas_diff); + yield_constr.constraint_transition(builder, constr); } fn eval_ext_circuit_init, const D: usize>( diff --git a/evm/src/cpu/halt.rs b/evm/src/cpu/halt.rs new file mode 100644 index 00000000..9ad34344 --- /dev/null +++ b/evm/src/cpu/halt.rs @@ -0,0 +1,98 @@ +//! Once the CPU execution is over (i.e. reached the `halt` label in the kernel), +//! the CPU trace will be padded with special dummy rows, incurring no memory overhead. + +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use super::control_flow::get_halt_pc; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::{CpuColumnsView, COL_MAP}; +use crate::cpu::membus::NUM_GP_CHANNELS; + +pub fn eval_packed( + lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, + yield_constr: &mut ConstraintConsumer
<P>
, +) { + let is_cpu_cycle: P = COL_MAP.op.iter().map(|&col_i| lv[col_i]).sum(); + let is_cpu_cycle_next: P = COL_MAP.op.iter().map(|&col_i| nv[col_i]).sum(); + + let halt_state = P::ONES - lv.is_bootstrap_kernel - is_cpu_cycle; + let next_halt_state = P::ONES - nv.is_bootstrap_kernel - is_cpu_cycle_next; + + // The halt flag must be boolean. + yield_constr.constraint(halt_state * (halt_state - P::ONES)); + // Once we reach a padding row, there must be only padding rows. + yield_constr.constraint_transition(halt_state * (next_halt_state - P::ONES)); + + // Padding rows should have their memory channels disabled. + for i in 0..NUM_GP_CHANNELS { + let channel = lv.mem_channels[i]; + yield_constr.constraint(halt_state * channel.used); + } + + // The last row must be a dummy padding row. + yield_constr.constraint_last_row(halt_state - P::ONES); + + // Also, a padding row's `program_counter` must be at the `halt` label. + // In particular, it ensures that the first padding row may only be added + // after we jumped to the `halt` function. Subsequent padding rows may set + // the `program_counter` to arbitrary values (there's no transition + // constraints) so we can place this requirement on them too. + let halt_pc = get_halt_pc::(); + yield_constr.constraint(halt_state * (lv.program_counter - halt_pc)); +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + nv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one = builder.one_extension(); + + let is_cpu_cycle = builder.add_many_extension(COL_MAP.op.iter().map(|&col_i| lv[col_i])); + let is_cpu_cycle_next = builder.add_many_extension(COL_MAP.op.iter().map(|&col_i| nv[col_i])); + + let halt_state = builder.add_extension(lv.is_bootstrap_kernel, is_cpu_cycle); + let halt_state = builder.sub_extension(one, halt_state); + let next_halt_state = builder.add_extension(nv.is_bootstrap_kernel, is_cpu_cycle_next); + let next_halt_state = builder.sub_extension(one, next_halt_state); + + // The halt flag must be boolean. + let constr = builder.mul_sub_extension(halt_state, halt_state, halt_state); + yield_constr.constraint(builder, constr); + // Once we reach a padding row, there must be only padding rows. + let constr = builder.mul_sub_extension(halt_state, next_halt_state, halt_state); + yield_constr.constraint_transition(builder, constr); + + // Padding rows should have their memory channels disabled. + for i in 0..NUM_GP_CHANNELS { + let channel = lv.mem_channels[i]; + let constr = builder.mul_extension(halt_state, channel.used); + yield_constr.constraint(builder, constr); + } + + // The last row must be a dummy padding row. + { + let one = builder.one_extension(); + let constr = builder.sub_extension(halt_state, one); + yield_constr.constraint_last_row(builder, constr); + } + + // Also, a padding row's `program_counter` must be at the `halt` label. + // In particular, it ensures that the first padding row may only be added + // after we jumped to the `halt` function. Subsequent padding rows may set + // the `program_counter` to arbitrary values (there's no transition + // constraints) so we can place this requirement on them too. 
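As an aside, the padding scheme these constraints encode is a one-way switch: halt_state is boolean, and a padding row can never be followed by a cycle row. A minimal integer sketch of the two packed constraints above (trace values are illustrative):

fn main() {
    // A well-formed trace: cycle rows, then only padding rows.
    let halt_flags = [0i64, 0, 0, 1, 1, 1];
    for &h in &halt_flags {
        assert_eq!(h * (h - 1), 0); // halt_state is boolean
    }
    for w in halt_flags.windows(2) {
        assert_eq!(w[0] * (w[1] - 1), 0); // padding is followed only by padding
    }
    // Leaving the padding state would violate the transition constraint:
    let bad = [1i64, 0];
    assert_ne!(bad[0] * (bad[1] - 1), 0);
}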
+ { + let halt_pc = get_halt_pc(); + let halt_pc_target = builder.constant_extension(halt_pc); + let constr = builder.sub_extension(lv.program_counter, halt_pc_target); + let constr = builder.mul_extension(halt_state, constr); + + yield_constr.constraint(builder, constr); + } +} diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index a3c38a90..62d9bdfd 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -75,8 +75,8 @@ pub fn eval_packed_jump_jumpi( let is_jumpi = filter * lv.opcode_bits[0]; // Stack constraints. - stack::eval_packed_one(lv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); - stack::eval_packed_one(lv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr); + stack::eval_packed_one(lv, nv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); + stack::eval_packed_one(lv, nv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr); // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. @@ -151,10 +151,18 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> let is_jumpi = builder.mul_extension(filter, lv.opcode_bits[0]); // Stack constraints. - stack::eval_ext_circuit_one(builder, lv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); stack::eval_ext_circuit_one( builder, lv, + nv, + is_jump, + stack::JUMP_OP.unwrap(), + yield_constr, + ); + stack::eval_ext_circuit_one( + builder, + lv, + nv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr, diff --git a/evm/src/cpu/kernel/asm/halt.asm b/evm/src/cpu/kernel/asm/halt.asm index 906ce51a..49561fd6 100644 --- a/evm/src/cpu/kernel/asm/halt.asm +++ b/evm/src/cpu/kernel/asm/halt.asm @@ -1,6 +1,2 @@ global halt: - PUSH halt_pc0 -global halt_pc0: - DUP1 -global halt_pc1: - JUMP + PANIC diff --git a/evm/src/cpu/kernel/asm/memory/packing.asm b/evm/src/cpu/kernel/asm/memory/packing.asm index 0f802335..1dbbf393 100644 --- a/evm/src/cpu/kernel/asm/memory/packing.asm +++ b/evm/src/cpu/kernel/asm/memory/packing.asm @@ -7,40 +7,10 @@ // NOTE: addr: 3 denotes a (context, segment, virtual) tuple global mload_packing: // stack: addr: 3, len, retdest - DUP3 DUP3 DUP3 MLOAD_GENERAL DUP5 %eq_const(1) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(1) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(2) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(2) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(3) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(3) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(4) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(4) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(5) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(5) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(6) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(6) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(7) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(7) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(8) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(8) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(9) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(9) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(10) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(10) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(11) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(11) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(12) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(12) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(13) %jumpi(mload_packing_return) 
%shl_const(8) - DUP4 %add_const(13) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(14) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(14) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(15) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(15) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(16) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(16) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(17) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(17) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(18) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(18) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(19) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(19) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(20) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(20) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(21) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(21) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(22) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(22) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(23) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(23) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(24) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(24) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(25) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(25) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(26) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(26) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(27) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(27) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(28) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(28) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(29) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(29) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(30) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(30) DUP4 DUP4 MLOAD_GENERAL ADD DUP5 %eq_const(31) %jumpi(mload_packing_return) %shl_const(8) - DUP4 %add_const(31) DUP4 DUP4 MLOAD_GENERAL ADD -mload_packing_return: - %stack (packed_value, addr: 3, len, retdest) -> (retdest, packed_value) + MLOAD_32BYTES + // stack: packed_value, retdest + SWAP1 + // stack: retdest, packed_value JUMP %macro mload_packing @@ -72,40 +42,12 @@ global mload_packing_u64_LE: // Post stack: offset' global mstore_unpacking: // stack: context, segment, offset, value, len, retdest - // We will enumerate i in (32 - len)..32. - // That way BYTE(i, value) will give us the bytes we want. - DUP5 // len - PUSH 32 - SUB - -mstore_unpacking_loop: - // stack: i, context, segment, offset, value, len, retdest - // If i == 32, finish. - DUP1 - %eq_const(32) - %jumpi(mstore_unpacking_finish) - - // stack: i, context, segment, offset, value, len, retdest - DUP5 // value - DUP2 // i - BYTE - // stack: value[i], i, context, segment, offset, value, len, retdest - DUP5 DUP5 DUP5 // context, segment, offset - // stack: context, segment, offset, value[i], i, context, segment, offset, value, len, retdest - MSTORE_GENERAL - // stack: i, context, segment, offset, value, len, retdest - - // Increment offset. - SWAP3 %increment SWAP3 - // Increment i. 
- %increment - - %jump(mstore_unpacking_loop) - -mstore_unpacking_finish: - // stack: i, context, segment, offset, value, len, retdest - %pop3 - %stack (offset, value, len, retdest) -> (retdest, offset) + %stack(context, segment, offset, value, len, retdest) -> (context, segment, offset, value, len, offset, len, retdest) + // stack: context, segment, offset, value, len, offset, len, retdest + MSTORE_32BYTES + // stack: offset, len, retdest + ADD SWAP1 + // stack: retdest, offset' JUMP %macro mstore_unpacking diff --git a/evm/src/cpu/kernel/asm/memory/syscalls.asm b/evm/src/cpu/kernel/asm/memory/syscalls.asm index 5f02382f..3548930c 100644 --- a/evm/src/cpu/kernel/asm/memory/syscalls.asm +++ b/evm/src/cpu/kernel/asm/memory/syscalls.asm @@ -8,41 +8,12 @@ global sys_mload: // stack: expanded_num_bytes, kexit_info, offset %update_mem_bytes // stack: kexit_info, offset - PUSH 0 // acc = 0 - // stack: acc, kexit_info, offset - DUP3 %add_const( 0) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xf8) ADD - DUP3 %add_const( 1) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xf0) ADD - DUP3 %add_const( 2) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xe8) ADD - DUP3 %add_const( 3) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xe0) ADD - DUP3 %add_const( 4) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xd8) ADD - DUP3 %add_const( 5) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xd0) ADD - DUP3 %add_const( 6) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xc8) ADD - DUP3 %add_const( 7) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xc0) ADD - DUP3 %add_const( 8) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xb8) ADD - DUP3 %add_const( 9) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xb0) ADD - DUP3 %add_const(10) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xa8) ADD - DUP3 %add_const(11) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0xa0) ADD - DUP3 %add_const(12) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x98) ADD - DUP3 %add_const(13) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x90) ADD - DUP3 %add_const(14) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x88) ADD - DUP3 %add_const(15) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x80) ADD - DUP3 %add_const(16) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x78) ADD - DUP3 %add_const(17) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x70) ADD - DUP3 %add_const(18) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x68) ADD - DUP3 %add_const(19) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x60) ADD - DUP3 %add_const(20) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x58) ADD - DUP3 %add_const(21) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x50) ADD - DUP3 %add_const(22) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x48) ADD - DUP3 %add_const(23) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x40) ADD - DUP3 %add_const(24) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x38) ADD - DUP3 %add_const(25) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x30) ADD - DUP3 %add_const(26) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x28) ADD - DUP3 %add_const(27) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x20) ADD - DUP3 %add_const(28) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x18) ADD - DUP3 %add_const(29) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x10) ADD - DUP3 %add_const(30) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x08) ADD - DUP3 %add_const(31) %mload_current(@SEGMENT_MAIN_MEMORY) %shl_const(0x00) ADD - %stack (acc, kexit_info, offset) -> (kexit_info, acc) + 
%stack(kexit_info, offset) -> (offset, 32, kexit_info) + PUSH @SEGMENT_MAIN_MEMORY + GET_CONTEXT + // stack: addr: 3, len, kexit_info + MLOAD_32BYTES + %stack (value, kexit_info) -> (kexit_info, value) EXIT_KERNEL global sys_mstore: @@ -55,39 +26,12 @@ global sys_mstore: // stack: expanded_num_bytes, kexit_info, offset, value %update_mem_bytes // stack: kexit_info, offset, value - DUP3 PUSH 0 BYTE DUP3 %add_const( 0) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 1 BYTE DUP3 %add_const( 1) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 2 BYTE DUP3 %add_const( 2) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 3 BYTE DUP3 %add_const( 3) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 4 BYTE DUP3 %add_const( 4) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 5 BYTE DUP3 %add_const( 5) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 6 BYTE DUP3 %add_const( 6) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 7 BYTE DUP3 %add_const( 7) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 8 BYTE DUP3 %add_const( 8) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 9 BYTE DUP3 %add_const( 9) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 10 BYTE DUP3 %add_const(10) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 11 BYTE DUP3 %add_const(11) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 12 BYTE DUP3 %add_const(12) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 13 BYTE DUP3 %add_const(13) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 14 BYTE DUP3 %add_const(14) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 15 BYTE DUP3 %add_const(15) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 16 BYTE DUP3 %add_const(16) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 17 BYTE DUP3 %add_const(17) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 18 BYTE DUP3 %add_const(18) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 19 BYTE DUP3 %add_const(19) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 20 BYTE DUP3 %add_const(20) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 21 BYTE DUP3 %add_const(21) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 22 BYTE DUP3 %add_const(22) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 23 BYTE DUP3 %add_const(23) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 24 BYTE DUP3 %add_const(24) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 25 BYTE DUP3 %add_const(25) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 26 BYTE DUP3 %add_const(26) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 27 BYTE DUP3 %add_const(27) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 28 BYTE DUP3 %add_const(28) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 29 BYTE DUP3 %add_const(29) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 30 BYTE DUP3 %add_const(30) %mstore_current(@SEGMENT_MAIN_MEMORY) - DUP3 PUSH 31 BYTE DUP3 %add_const(31) %mstore_current(@SEGMENT_MAIN_MEMORY) - %stack (kexit_info, offset, value) -> (kexit_info) + %stack(kexit_info, offset, value) -> (offset, value, 32, kexit_info) + PUSH @SEGMENT_MAIN_MEMORY + GET_CONTEXT + // stack: addr: 3, value, len, kexit_info + MSTORE_32BYTES + // stack: kexit_info EXIT_KERNEL global sys_mstore8: diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 98ea3cc2..c4deba99 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -200,7 +200,7 @@ impl<'a> Interpreter<'a> { self.generation_state.memory.contexts[0].segments[segment as usize] .content .iter() - .map(|x| x.as_u32() as u8) + .map(|x| x.low_u32() as u8) .collect() } @@ 
-391,6 +391,7 @@ impl<'a> Interpreter<'a> { self.stack(), self.get_kernel_general_memory() ), // "PANIC", + 0xee => self.run_mstore_32bytes(), // "MSTORE_32BYTES", 0xf0 => todo!(), // "CREATE", 0xf1 => todo!(), // "CALL", 0xf2 => todo!(), // "CALLCODE", @@ -399,6 +400,7 @@ impl<'a> Interpreter<'a> { 0xf5 => todo!(), // "CREATE2", 0xf6 => self.run_get_context(), // "GET_CONTEXT", 0xf7 => self.run_set_context(), // "SET_CONTEXT", + 0xf8 => self.run_mload_32bytes(), // "MLOAD_32BYTES", 0xf9 => todo!(), // "EXIT_KERNEL", 0xfa => todo!(), // "STATICCALL", 0xfb => self.run_mload_general(), // "MLOAD_GENERAL", @@ -1024,8 +1026,7 @@ impl<'a> Interpreter<'a> { fn run_mload_general(&mut self) { let context = self.pop().as_usize(); let segment = Segment::all()[self.pop().as_usize()]; - let offset_u256 = self.pop(); - let offset = offset_u256.as_usize(); + let offset = self.pop().as_usize(); let value = self .generation_state .memory @@ -1034,6 +1035,23 @@ impl<'a> Interpreter<'a> { self.push(value); } + fn run_mload_32bytes(&mut self) { + let context = self.pop().as_usize(); + let segment = Segment::all()[self.pop().as_usize()]; + let offset = self.pop().as_usize(); + let len = self.pop().as_usize(); + let bytes: Vec = (0..len) + .map(|i| { + self.generation_state + .memory + .mload_general(context, segment, offset + i) + .low_u32() as u8 + }) + .collect(); + let value = U256::from_big_endian(&bytes); + self.push(value); + } + fn run_mstore_general(&mut self) { let context = self.pop().as_usize(); let segment = Segment::all()[self.pop().as_usize()]; @@ -1044,6 +1062,25 @@ impl<'a> Interpreter<'a> { .mstore_general(context, segment, offset, value); } + fn run_mstore_32bytes(&mut self) { + let context = self.pop().as_usize(); + let segment = Segment::all()[self.pop().as_usize()]; + let offset = self.pop().as_usize(); + let value = self.pop(); + let len = self.pop().as_usize(); + + let mut bytes = vec![0; 32]; + value.to_little_endian(&mut bytes); + bytes.resize(len, 0); + bytes.reverse(); + + for (i, &byte) in bytes.iter().enumerate() { + self.generation_state + .memory + .mstore_general(context, segment, offset + i, byte.into()); + } + } + fn stack_len(&self) -> usize { self.generation_state.registers.stack_len } @@ -1270,6 +1307,7 @@ fn get_mnemonic(opcode: u8) -> &'static str { 0xa3 => "LOG3", 0xa4 => "LOG4", 0xa5 => "PANIC", + 0xee => "MSTORE_32BYTES", 0xf0 => "CREATE", 0xf1 => "CALL", 0xf2 => "CALLCODE", @@ -1278,6 +1316,7 @@ fn get_mnemonic(opcode: u8) -> &'static str { 0xf5 => "CREATE2", 0xf6 => "GET_CONTEXT", 0xf7 => "SET_CONTEXT", + 0xf8 => "MLOAD_32BYTES", 0xf9 => "EXIT_KERNEL", 0xfa => "STATICCALL", 0xfb => "MLOAD_GENERAL", diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index 09c493e0..2503a92e 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ b/evm/src/cpu/kernel/opcodes.rs @@ -113,6 +113,7 @@ pub fn get_opcode(mnemonic: &str) -> u8 { "LOG3" => 0xa3, "LOG4" => 0xa4, "PANIC" => 0xa5, + "MSTORE_32BYTES" => 0xee, "CREATE" => 0xf0, "CALL" => 0xf1, "CALLCODE" => 0xf2, @@ -121,6 +122,7 @@ pub fn get_opcode(mnemonic: &str) -> u8 { "CREATE2" => 0xf5, "GET_CONTEXT" => 0xf6, "SET_CONTEXT" => 0xf7, + "MLOAD_32BYTES" => 0xf8, "EXIT_KERNEL" => 0xf9, "STATICCALL" => 0xfa, "MLOAD_GENERAL" => 0xfb, diff --git a/evm/src/cpu/kernel/tests/blake2_f.rs b/evm/src/cpu/kernel/tests/blake2_f.rs index 1b465bed..b12c9f32 100644 --- a/evm/src/cpu/kernel/tests/blake2_f.rs +++ b/evm/src/cpu/kernel/tests/blake2_f.rs @@ -75,7 +75,7 @@ fn run_blake2_f( Ok(hash .iter() - .map(|&x| x.as_u64()) + 
.map(|&x| x.low_u64()) .collect::>() .try_into() .unwrap()) diff --git a/evm/src/cpu/memio.rs b/evm/src/cpu/memio.rs index 09490e87..aa3749ca 100644 --- a/evm/src/cpu/memio.rs +++ b/evm/src/cpu/memio.rs @@ -7,6 +7,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::cpu::stack; fn get_addr(lv: &CpuColumnsView) -> (T, T, T) { let addr_context = lv.mem_channels[0].value[0]; @@ -17,9 +18,11 @@ fn get_addr(lv: &CpuColumnsView) -> (T, T, T) { fn eval_packed_load( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.mload_general; + // The opcode for MLOAD_GENERAL is 0xfb. If the operation is MLOAD_GENERAL, lv.opcode_bits[0] = 1 + let filter = lv.op.m_op_general * lv.opcode_bits[0]; let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -38,14 +41,25 @@ fn eval_packed_load( for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS - 1] { yield_constr.constraint(filter * channel.used); } + + // Stack constraints + stack::eval_packed_one( + lv, + nv, + filter, + stack::MLOAD_GENERAL_OP.unwrap(), + yield_constr, + ); } fn eval_ext_circuit_load, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.mload_general; + let mut filter = lv.op.m_op_general; + filter = builder.mul_extension(filter, lv.opcode_bits[0]); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -82,13 +96,24 @@ fn eval_ext_circuit_load, const D: usize>( let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } + + // Stack constraints + stack::eval_ext_circuit_one( + builder, + lv, + nv, + filter, + stack::MLOAD_GENERAL_OP.unwrap(), + yield_constr, + ); } fn eval_packed_store( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.mstore_general; + let filter = lv.op.m_op_general * (P::ONES - lv.opcode_bits[0]); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -107,14 +132,27 @@ fn eval_packed_store( for &channel in &lv.mem_channels[5..] { yield_constr.constraint(filter * channel.used); } + + // Stack constraints + stack::eval_packed_one( + lv, + nv, + filter, + stack::MSTORE_GENERAL_OP.unwrap(), + yield_constr, + ); } fn eval_ext_circuit_store, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.mstore_general; + let mut filter = lv.op.m_op_general; + let one = builder.one_extension(); + let minus = builder.sub_extension(one, lv.opcode_bits[0]); + filter = builder.mul_extension(filter, minus); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -151,21 +189,33 @@ fn eval_ext_circuit_store, const D: usize>( let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } + + // Stack constraints + stack::eval_ext_circuit_one( + builder, + lv, + nv, + filter, + stack::MSTORE_GENERAL_OP.unwrap(), + yield_constr, + ); } pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - eval_packed_load(lv, yield_constr); - eval_packed_store(lv, yield_constr); + eval_packed_load(lv, nv, yield_constr); + eval_packed_store(lv, nv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - eval_ext_circuit_load(builder, lv, yield_constr); - eval_ext_circuit_store(builder, lv, yield_constr); + eval_ext_circuit_load(builder, lv, nv, yield_constr); + eval_ext_circuit_store(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index 91b04cf4..b7312147 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -6,6 +6,7 @@ pub mod cpu_stark; pub(crate) mod decode; mod dup_swap; mod gas; +mod halt; mod jumps; pub mod kernel; pub(crate) mod membus; diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs index e6a2815d..86f08052 100644 --- a/evm/src/cpu/modfp254.rs +++ b/evm/src/cpu/modfp254.rs @@ -19,7 +19,7 @@ pub fn eval_packed( lv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.addfp254 + lv.op.mulfp254 + lv.op.subfp254; + let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's @@ -36,7 +36,7 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = builder.add_many_extension([lv.op.addfp254, lv.op.mulfp254, lv.op.subfp254]); + let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs index a8acf5d4..a4249297 100644 --- a/evm/src/cpu/shift.rs +++ b/evm/src/cpu/shift.rs @@ -13,7 +13,7 @@ pub(crate) fn eval_packed( lv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let is_shift = lv.op.shl + lv.op.shr; + let is_shift = lv.op.shift; let displacement = lv.mem_channels[0]; // holds the shift displacement d let two_exp = lv.mem_channels[2]; // holds 2^d @@ -64,7 +64,7 @@ pub(crate) fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let is_shift = builder.add_extension(lv.op.shl, lv.op.shr); + let is_shift = lv.op.shift; let displacement = lv.mem_channels[0]; let two_exp = lv.mem_channels[2]; diff --git a/evm/src/cpu/simple_logic/eq_iszero.rs b/evm/src/cpu/simple_logic/eq_iszero.rs index f16901f5..7be021ca 100644 --- a/evm/src/cpu/simple_logic/eq_iszero.rs +++ b/evm/src/cpu/simple_logic/eq_iszero.rs @@ -51,6 +51,7 @@ pub fn generate_pinv_diff(val0: U256, val1: U256, lv: &mut CpuColumnsV pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { let logic = lv.general.logic(); @@ -94,9 +95,10 @@ pub fn eval_packed( yield_constr.constraint(eq_or_iszero_filter * (dot - unequal)); // Stack constraints. - stack::eval_packed_one(lv, eq_filter, EQ_STACK_BEHAVIOR.unwrap(), yield_constr); + stack::eval_packed_one(lv, nv, eq_filter, EQ_STACK_BEHAVIOR.unwrap(), yield_constr); stack::eval_packed_one( lv, + nv, iszero_filter, IS_ZERO_STACK_BEHAVIOR.unwrap(), yield_constr, @@ -106,6 +108,7 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let zero = builder.zero_extension(); @@ -173,6 +176,7 @@ pub fn eval_ext_circuit, const D: usize>( stack::eval_ext_circuit_one( builder, lv, + nv, eq_filter, EQ_STACK_BEHAVIOR.unwrap(), yield_constr, @@ -180,6 +184,7 @@ pub fn eval_ext_circuit, const D: usize>( stack::eval_ext_circuit_one( builder, lv, + nv, iszero_filter, IS_ZERO_STACK_BEHAVIOR.unwrap(), yield_constr, diff --git a/evm/src/cpu/simple_logic/mod.rs b/evm/src/cpu/simple_logic/mod.rs index 03d2dd15..9b4e60b0 100644 --- a/evm/src/cpu/simple_logic/mod.rs +++ b/evm/src/cpu/simple_logic/mod.rs @@ -11,17 +11,19 @@ use crate::cpu::columns::CpuColumnsView; pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { not::eval_packed(lv, yield_constr); - eq_iszero::eval_packed(lv, yield_constr); + eq_iszero::eval_packed(lv, nv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { not::eval_ext_circuit(builder, lv, yield_constr); - eq_iszero::eval_ext_circuit(builder, lv, yield_constr); + eq_iszero::eval_ext_circuit(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 8ffc152d..28abf077 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -44,35 +44,31 @@ pub(crate) const JUMPI_OP: Option = Some(StackBehavior { disable_other_channels: false, }); +pub(crate) const MLOAD_GENERAL_OP: Option = Some(StackBehavior { + num_pops: 3, + pushes: true, + disable_other_channels: false, +}); + +pub(crate) const MSTORE_GENERAL_OP: Option = Some(StackBehavior { + num_pops: 4, + pushes: false, + disable_other_channels: false, +}); + // AUDITORS: If the value below is `None`, then the operation must be manually checked to ensure // that every general-purpose memory channel is either disabled or has its read flag and address // propertly constrained. The same applies when `disable_other_channels` is set to `false`, // except the first `num_pops` and the last `pushes as usize` channels have their read flag and // address constrained automatically in this file. const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { - add: BASIC_BINARY_OP, - mul: BASIC_BINARY_OP, - sub: BASIC_BINARY_OP, - div: BASIC_BINARY_OP, - mod_: BASIC_BINARY_OP, - addmod: BASIC_TERNARY_OP, - mulmod: BASIC_TERNARY_OP, - addfp254: BASIC_BINARY_OP, - mulfp254: BASIC_BINARY_OP, - subfp254: BASIC_BINARY_OP, - submod: BASIC_TERNARY_OP, - lt: BASIC_BINARY_OP, - gt: BASIC_BINARY_OP, + binary_op: BASIC_BINARY_OP, + ternary_op: BASIC_TERNARY_OP, + fp254_op: BASIC_BINARY_OP, eq_iszero: None, // EQ is binary, IS_ZERO is unary. logic_op: BASIC_BINARY_OP, not: BASIC_UNARY_OP, - byte: BASIC_BINARY_OP, - shl: Some(StackBehavior { - num_pops: 2, - pushes: true, - disable_other_channels: false, - }), - shr: Some(StackBehavior { + shift: Some(StackBehavior { num_pops: 2, pushes: true, disable_other_channels: false, @@ -108,21 +104,22 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { dup: None, swap: None, context_op: None, // SET_CONTEXT is special since it involves the old and the new stack. + mstore_32bytes: Some(StackBehavior { + num_pops: 5, + pushes: false, + disable_other_channels: false, + }), + mload_32bytes: Some(StackBehavior { + num_pops: 4, + pushes: true, + disable_other_channels: false, + }), exit_kernel: Some(StackBehavior { num_pops: 1, pushes: false, disable_other_channels: true, }), - mload_general: Some(StackBehavior { - num_pops: 3, - pushes: true, - disable_other_channels: false, - }), - mstore_general: Some(StackBehavior { - num_pops: 4, - pushes: false, - disable_other_channels: false, - }), + m_op_general: None, syscall: Some(StackBehavior { num_pops: 0, pushes: true, @@ -140,6 +137,7 @@ pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = BASIC_UNARY_OP; pub(crate) fn eval_packed_one( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, filter: P, stack_behavior: StackBehavior, yield_constr: &mut ConstraintConsumer
<P>
, @@ -185,15 +183,21 @@ pub(crate) fn eval_packed_one( yield_constr.constraint(filter * channel.used); } } + + // Constrain new stack length. + let num_pops = P::Scalar::from_canonical_usize(stack_behavior.num_pops); + let push = P::Scalar::from_canonical_usize(stack_behavior.pushes as usize); + yield_constr.constraint_transition(filter * (nv.stack_len - (lv.stack_len - num_pops + push))); } pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { for (op, stack_behavior) in izip!(lv.op.into_iter(), STACK_BEHAVIORS.into_iter()) { if let Some(stack_behavior) = stack_behavior { - eval_packed_one(lv, op, stack_behavior, yield_constr); + eval_packed_one(lv, nv, op, stack_behavior, yield_constr); } } } @@ -201,6 +205,7 @@ pub fn eval_packed( pub(crate) fn eval_ext_circuit_one, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, filter: ExtensionTarget, stack_behavior: StackBehavior, yield_constr: &mut RecursiveConstraintConsumer, @@ -298,16 +303,27 @@ pub(crate) fn eval_ext_circuit_one, const D: usize> yield_constr.constraint(builder, constr); } } + + // Constrain new stack length. + let diff = builder.constant_extension( + F::Extension::from_canonical_usize(stack_behavior.num_pops) + - F::Extension::from_canonical_usize(stack_behavior.pushes as usize), + ); + let diff = builder.sub_extension(lv.stack_len, diff); + let diff = builder.sub_extension(nv.stack_len, diff); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint_transition(builder, constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { for (op, stack_behavior) in izip!(lv.op.into_iter(), STACK_BEHAVIORS.into_iter()) { if let Some(stack_behavior) = stack_behavior { - eval_ext_circuit_one(builder, lv, op, stack_behavior, yield_constr); + eval_ext_circuit_one(builder, lv, nv, op, stack_behavior, yield_constr); } } } diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 8f481325..a2dad1ab 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -25,6 +25,7 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; #[derive(Clone, Debug)] pub struct Column { linear_combination: Vec<(usize, F)>, + next_row_linear_combination: Vec<(usize, F)>, constant: F, } @@ -32,6 +33,7 @@ impl Column { pub fn single(c: usize) -> Self { Self { linear_combination: vec![(c, F::ONE)], + next_row_linear_combination: vec![], constant: F::ZERO, } } @@ -42,9 +44,24 @@ impl Column { cs.into_iter().map(|c| Self::single(*c.borrow())) } + pub fn single_next_row(c: usize) -> Self { + Self { + linear_combination: vec![], + next_row_linear_combination: vec![(c, F::ONE)], + constant: F::ZERO, + } + } + + pub fn singles_next_row>>( + cs: I, + ) -> impl Iterator { + cs.into_iter().map(|c| Self::single_next_row(*c.borrow())) + } + pub fn constant(constant: F) -> Self { Self { linear_combination: vec![], + next_row_linear_combination: vec![], constant, } } @@ -70,6 +87,34 @@ impl Column { ); Self { linear_combination: v, + next_row_linear_combination: vec![], + constant, + } + } + + pub fn linear_combination_and_next_row_with_constant>( + iter: I, + next_row_iter: I, + constant: F, + ) -> Self { + let v = iter.into_iter().collect::>(); + let next_row_v = next_row_iter.into_iter().collect::>(); + + assert!(!v.is_empty() || !next_row_v.is_empty()); + debug_assert_eq!( + v.iter().map(|(c, _)| c).unique().count(), + v.len(), + "Duplicate columns." + ); + debug_assert_eq!( + next_row_v.iter().map(|(c, _)| c).unique().count(), + next_row_v.len(), + "Duplicate columns." 
+ ); + + Self { + linear_combination: v, + next_row_linear_combination: next_row_v, constant, } } @@ -106,13 +151,43 @@ impl Column { + FE::from_basefield(self.constant) } + pub fn eval_with_next(&self, v: &[P], next_v: &[P]) -> P + where + FE: FieldExtension, + P: PackedField, + { + self.linear_combination + .iter() + .map(|&(c, f)| v[c] * FE::from_basefield(f)) + .sum::
<P>
() + + self + .next_row_linear_combination + .iter() + .map(|&(c, f)| next_v[c] * FE::from_basefield(f)) + .sum::
<P>
() + + FE::from_basefield(self.constant) + } + /// Evaluate on an row of a table given in column-major form. pub fn eval_table(&self, table: &[PolynomialValues], row: usize) -> F { - self.linear_combination + let mut res = self + .linear_combination .iter() .map(|&(c, f)| table[c].values[row] * f) .sum::() - + self.constant + + self.constant; + + // If we access the next row at the last row, for sanity, we consider the next row's values to be 0. + // If CTLs are correctly written, the filter should be 0 in that case anyway. + if !self.next_row_linear_combination.is_empty() && row < table[0].values.len() - 1 { + res += self + .next_row_linear_combination + .iter() + .map(|&(c, f)| table[c].values[row + 1] * f) + .sum::(); + } + + res } pub fn eval_circuit( @@ -136,6 +211,36 @@ impl Column { let constant = builder.constant_extension(F::Extension::from_basefield(self.constant)); builder.inner_product_extension(F::ONE, constant, pairs) } + + pub fn eval_with_next_circuit( + &self, + builder: &mut CircuitBuilder, + v: &[ExtensionTarget], + next_v: &[ExtensionTarget], + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + let mut pairs = self + .linear_combination + .iter() + .map(|&(c, f)| { + ( + v[c], + builder.constant_extension(F::Extension::from_basefield(f)), + ) + }) + .collect::>(); + let next_row_pairs = self.next_row_linear_combination.iter().map(|&(c, f)| { + ( + next_v[c], + builder.constant_extension(F::Extension::from_basefield(f)), + ) + }); + pairs.extend(next_row_pairs); + let constant = builder.constant_extension(F::Extension::from_basefield(self.constant)); + builder.inner_product_extension(F::ONE, constant, pairs) + } } #[derive(Clone, Debug)] @@ -276,7 +381,7 @@ fn partial_products( let mut partial_prod = F::ONE; let degree = trace[0].len(); let mut res = Vec::with_capacity(degree); - for i in 0..degree { + for i in (0..degree).rev() { let filter = if let Some(column) = filter_column { column.eval_table(trace, i) } else { @@ -293,6 +398,7 @@ fn partial_products( }; res.push(partial_prod); } + res.reverse(); res.into() } @@ -362,6 +468,10 @@ impl<'a, F: RichField + Extendable, const D: usize> } } +/// CTL Z partial products are upside down: the complete product is on the first row, and +/// the first term is on the last row. This allows the transition constraint to be: +/// Z(w) = Z(gw) * combine(w) where combine is called on the local row +/// and not the next. This enables CTLs across two rows. 
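To see the reversed accumulation on a toy trace: build Z from the last row upward, so the full product lands on the first row and each transition multiplies by the local row's combination only. A minimal sketch with plain integers standing in for field elements and the filter assumed always active:

fn main() {
    let combined = [3u64, 5, 7, 11]; // combine(w) evaluated on each row
    let n = combined.len();
    let mut z = vec![0u64; n];
    let mut acc = 1u64;
    for i in (0..n).rev() {
        acc *= combined[i];
        z[i] = acc;
    }
    // Z on the first row is the complete product...
    assert_eq!(z[0], 3 * 5 * 7 * 11);
    // ...the transition Z(w) = combine(w) * Z(gw) only reads the local row...
    for i in 0..n - 1 {
        assert_eq!(z[i], combined[i] * z[i + 1]);
    }
    // ...and the last row anchors Z(g^(n-1)) to the local combination alone.
    assert_eq!(z[n - 1], combined[n - 1]);
}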
pub(crate) fn eval_cross_table_lookup_checks( vars: StarkEvaluationVars, ctl_vars: &[CtlCheckVars], @@ -380,27 +490,23 @@ pub(crate) fn eval_cross_table_lookup_checks P { - let evals = columns.iter().map(|c| c.eval(v)).collect::>(); - challenges.combine(evals.iter()) - }; - let filter = |v: &[P]| -> P { - if let Some(column) = filter_column { - column.eval(v) - } else { - P::ONES - } - }; - let local_filter = filter(vars.local_values); - let next_filter = filter(vars.next_values); - let select = |filter, x| filter * x + P::ONES - filter; - // Check value of `Z(1)` - consumer.constraint_first_row(*local_z - select(local_filter, combine(vars.local_values))); - // Check `Z(gw) = combination * Z(w)` - consumer.constraint_transition( - *local_z * select(next_filter, combine(vars.next_values)) - *next_z, - ); + let evals = columns + .iter() + .map(|c| c.eval_with_next(vars.local_values, vars.next_values)) + .collect::>(); + let combined = challenges.combine(evals.iter()); + let local_filter = if let Some(column) = filter_column { + column.eval_with_next(vars.local_values, vars.next_values) + } else { + P::ONES + }; + let select = local_filter * combined + P::ONES - local_filter; + + // Check value of `Z(g^(n-1))` + consumer.constraint_last_row(*local_z - select); + // Check `Z(w) = combination * Z(gw)` + consumer.constraint_transition(*next_z * select - *local_z); } } @@ -493,11 +599,6 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } else { one }; - let next_filter = if let Some(column) = filter_column { - column.eval_circuit(builder, vars.next_values) - } else { - one - }; fn select, const D: usize>( builder: &mut CircuitBuilder, filter: ExtensionTarget, @@ -508,38 +609,37 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< builder.mul_add_extension(filter, x, tmp) // filter * x + 1 - filter } - // Check value of `Z(1)` - let local_columns_eval = columns + let evals = columns .iter() - .map(|c| c.eval_circuit(builder, vars.local_values)) + .map(|c| c.eval_with_next_circuit(builder, vars.local_values, vars.next_values)) .collect::>(); - let combined_local = challenges.combine_circuit(builder, &local_columns_eval); - let selected_local = select(builder, local_filter, combined_local); - let first_row = builder.sub_extension(*local_z, selected_local); - consumer.constraint_first_row(builder, first_row); - // Check `Z(gw) = combination * Z(w)` - let next_columns_eval = columns - .iter() - .map(|c| c.eval_circuit(builder, vars.next_values)) - .collect::>(); - let combined_next = challenges.combine_circuit(builder, &next_columns_eval); - let selected_next = select(builder, next_filter, combined_next); - let transition = builder.mul_sub_extension(*local_z, selected_next, *next_z); + + let combined = challenges.combine_circuit(builder, &evals); + let select = select(builder, local_filter, combined); + + // Check value of `Z(g^(n-1))` + let last_row = builder.sub_extension(*local_z, select); + consumer.constraint_last_row(builder, last_row); + // Check `Z(w) = combination * Z(gw)` + let transition = builder.mul_sub_extension(*next_z, select, *local_z); consumer.constraint_transition(builder, transition); } } pub(crate) fn verify_cross_table_lookups, const D: usize>( cross_table_lookups: &[CrossTableLookup], - ctl_zs_lasts: [Vec; NUM_TABLES], + ctl_zs_first: [Vec; NUM_TABLES], ctl_extra_looking_products: Vec>, config: &StarkConfig, ) -> Result<()> { - let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); - for CrossTableLookup { - looking_tables, - looked_table, - } 
in cross_table_lookups.iter() + let mut ctl_zs_openings = ctl_zs_first.iter().map(|v| v.iter()).collect::>(); + for ( + index, + CrossTableLookup { + looking_tables, + looked_table, + }, + ) in cross_table_lookups.iter().enumerate() { let extra_product_vec = &ctl_extra_looking_products[looked_table.table as usize]; for c in 0..config.num_challenges { @@ -552,7 +652,8 @@ pub(crate) fn verify_cross_table_lookups, const D: let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); ensure!( looking_zs_prod == looked_z, - "Cross-table lookup verification failed." + "Cross-table lookup {:?} verification failed.", + index ); } } @@ -564,11 +665,11 @@ pub(crate) fn verify_cross_table_lookups, const D: pub(crate) fn verify_cross_table_lookups_circuit, const D: usize>( builder: &mut CircuitBuilder, cross_table_lookups: Vec>, - ctl_zs_lasts: [Vec; NUM_TABLES], + ctl_zs_first: [Vec; NUM_TABLES], ctl_extra_looking_products: Vec>, inner_config: &StarkConfig, ) { - let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); + let mut ctl_zs_openings = ctl_zs_first.iter().map(|v| v.iter()).collect::>(); for CrossTableLookup { looking_tables, looked_table, diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 4d3f4d3d..55577fb2 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -29,6 +29,7 @@ use plonky2_util::log2_ceil; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::cpu::cpu_stark::CpuStark; use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; @@ -46,9 +47,8 @@ use crate::proof::{ use crate::prover::prove; use crate::recursive_verifier::{ add_common_recursion_gates, add_virtual_public_values, - get_memory_extra_looking_products_circuit, recursive_stark_circuit, set_block_hashes_target, - set_block_metadata_target, set_extra_public_values_target, set_public_value_targets, - set_trie_roots_target, PlonkWrapperCircuit, PublicInputs, StarkWrapperCircuit, + get_memory_extra_looking_products_circuit, recursive_stark_circuit, set_public_value_targets, + PlonkWrapperCircuit, PublicInputs, StarkWrapperCircuit, }; use crate::stark::Stark; use crate::util::h256_limbs; @@ -298,6 +298,7 @@ where C: GenericConfig + 'static, C::Hasher: AlgebraicHasher, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -374,47 +375,62 @@ where let arithmetic = RecursiveCircuitsForTable::new( Table::Arithmetic, &all_stark.arithmetic_stark, - degree_bits_ranges[0].clone(), + degree_bits_ranges[Table::Arithmetic as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let byte_packing = RecursiveCircuitsForTable::new( + Table::BytePacking, + &all_stark.byte_packing_stark, + degree_bits_ranges[Table::BytePacking as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let cpu = RecursiveCircuitsForTable::new( Table::Cpu, &all_stark.cpu_stark, - degree_bits_ranges[1].clone(), + degree_bits_ranges[Table::Cpu as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let keccak = RecursiveCircuitsForTable::new( Table::Keccak, &all_stark.keccak_stark, - degree_bits_ranges[2].clone(), + degree_bits_ranges[Table::Keccak as usize].clone(), 
&all_stark.cross_table_lookups, stark_config, ); let keccak_sponge = RecursiveCircuitsForTable::new( Table::KeccakSponge, &all_stark.keccak_sponge_stark, - degree_bits_ranges[3].clone(), + degree_bits_ranges[Table::KeccakSponge as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let logic = RecursiveCircuitsForTable::new( Table::Logic, &all_stark.logic_stark, - degree_bits_ranges[4].clone(), + degree_bits_ranges[Table::Logic as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); let memory = RecursiveCircuitsForTable::new( Table::Memory, &all_stark.memory_stark, - degree_bits_ranges[5].clone(), + degree_bits_ranges[Table::Memory as usize].clone(), &all_stark.cross_table_lookups, stark_config, ); - let by_table = [arithmetic, cpu, keccak, keccak_sponge, logic, memory]; + let by_table = [ + arithmetic, + byte_packing, + cpu, + keccak, + keccak_sponge, + logic, + memory, + ]; let root = Self::create_root_circuit(&by_table, stark_config); let aggregation = Self::create_aggregation_circuit(&root); let block = Self::create_block_circuit(&aggregation); @@ -489,13 +505,13 @@ where } } - // Extra products to add to the looked last value - // Arithmetic, KeccakSponge, Keccak, Logic + // Extra products to add to the looked last value. + // Only necessary for the Memory values. let mut extra_looking_products = - vec![vec![builder.constant(F::ONE); stark_config.num_challenges]; NUM_TABLES - 1]; + vec![vec![builder.one(); stark_config.num_challenges]; NUM_TABLES]; // Memory - let memory_looking_products = (0..stark_config.num_challenges) + extra_looking_products[Table::Memory as usize] = (0..stark_config.num_challenges) .map(|c| { get_memory_extra_looking_products_circuit( &mut builder, @@ -504,13 +520,12 @@ where ) }) .collect_vec(); - extra_looking_products.push(memory_looking_products); // Verify the CTL checks. 
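Concretely, the check performed per challenge has this shape: the product of the looking tables' first-row Z openings, times any extra looking product (only Memory contributes one here), must equal the looked table's opening. An illustrative sketch, not the verifier's actual API:

fn ctl_product_check(looking_zs: &[u64], extra_looking_product: u64, looked_z: u64) -> bool {
    looking_zs.iter().product::<u64>() * extra_looking_product == looked_z
}

fn main() {
    // Two looking tables opening to 6 and 35, no extra product: 6 * 35 = 210.
    assert!(ctl_product_check(&[6, 35], 1, 210));
    // Any mismatch fails verification, reported with the CTL's index as above.
    assert!(!ctl_product_check(&[6, 35], 1, 211));
}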
verify_cross_table_lookups_circuit::( &mut builder, all_cross_table_lookups(), - pis.map(|p| p.ctl_zs_last), + pis.map(|p| p.ctl_zs_first), extra_looking_products, stark_config, ); @@ -914,7 +929,10 @@ where &mut root_inputs, &self.root.public_values, &all_proof.public_values, - ); + ) + .map_err(|_| { + anyhow::Error::msg("Invalid conversion when setting public values targets.") + })?; let root_proof = self.root.circuit.prove(root_inputs)?; @@ -948,32 +966,15 @@ where &self.aggregation.circuit.verifier_only, ); - set_block_hashes_target( + set_public_value_targets( &mut agg_inputs, - &self.aggregation.public_values.block_hashes, - &public_values.block_hashes, - ); - set_block_metadata_target( - &mut agg_inputs, - &self.aggregation.public_values.block_metadata, - &public_values.block_metadata, - ); + &self.aggregation.public_values, + &public_values, + ) + .map_err(|_| { + anyhow::Error::msg("Invalid conversion when setting public values targets.") + })?; - set_trie_roots_target( - &mut agg_inputs, - &self.aggregation.public_values.trie_roots_before, - &public_values.trie_roots_before, - ); - set_trie_roots_target( - &mut agg_inputs, - &self.aggregation.public_values.trie_roots_after, - &public_values.trie_roots_after, - ); - set_extra_public_values_target( - &mut agg_inputs, - &self.aggregation.public_values.extra_block_data, - &public_values.extra_block_data, - ); let aggregation_proof = self.aggregation.circuit.prove(agg_inputs)?; Ok((aggregation_proof, public_values)) } @@ -1049,32 +1050,10 @@ where block_inputs .set_verifier_data_target(&self.block.cyclic_vk, &self.block.circuit.verifier_only); - set_block_hashes_target( - &mut block_inputs, - &self.block.public_values.block_hashes, - &public_values.block_hashes, - ); - set_extra_public_values_target( - &mut block_inputs, - &self.block.public_values.extra_block_data, - &public_values.extra_block_data, - ); - set_block_metadata_target( - &mut block_inputs, - &self.block.public_values.block_metadata, - &public_values.block_metadata, - ); - - set_trie_roots_target( - &mut block_inputs, - &self.block.public_values.trie_roots_before, - &public_values.trie_roots_before, - ); - set_trie_roots_target( - &mut block_inputs, - &self.block.public_values.trie_roots_after, - &public_values.trie_roots_after, - ); + set_public_value_targets(&mut block_inputs, &self.block.public_values, &public_values) + .map_err(|_| { + anyhow::Error::msg("Invalid conversion when setting public values targets.") + })?; let block_proof = self.block.circuit.prove(block_inputs)?; Ok((block_proof, public_values)) diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 01e3209d..85f19431 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -16,6 +16,7 @@ use GlobalMetadata::{ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel; +use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::outputs::{get_outputs, GenerationOutputs}; @@ -281,26 +282,36 @@ pub fn generate_traces, const D: usize>( fn simulate_cpu, const D: usize>( state: &mut GenerationState, ) -> anyhow::Result<()> { - let halt_pc0 = KERNEL.global_labels["halt_pc0"]; - let halt_pc1 = KERNEL.global_labels["halt_pc1"]; + let halt_pc = KERNEL.global_labels["halt"]; - let mut already_in_halt_loop = false; loop { // If we've reached the kernel's halt routine, and our trace length 
is a power of 2, stop. let pc = state.registers.program_counter; - let in_halt_loop = state.registers.is_kernel && (pc == halt_pc0 || pc == halt_pc1); - if in_halt_loop && !already_in_halt_loop { + let halt = state.registers.is_kernel && pc == halt_pc; + if halt { log::info!("CPU halted after {} cycles", state.traces.clock()); + + // Padding + let mut row = CpuColumnsView::::default(); + row.clock = F::from_canonical_usize(state.traces.clock()); + row.context = F::from_canonical_usize(state.registers.context); + row.program_counter = F::from_canonical_usize(pc); + row.is_kernel_mode = F::ONE; + row.gas = F::from_canonical_u64(state.registers.gas_used); + row.stack_len = F::from_canonical_usize(state.registers.stack_len); + + loop { + state.traces.push_cpu(row); + row.clock += F::ONE; + if state.traces.clock().is_power_of_two() { + break; + } + } + log::info!("CPU trace padded to {} cycles", state.traces.clock()); + + return Ok(()); } - already_in_halt_loop |= in_halt_loop; transition(state)?; - - if already_in_halt_loop && state.traces.clock().is_power_of_two() { - log::info!("CPU trace padded to {} cycles", state.traces.clock()); - break; - } } - - Ok(()) } diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index c0c03e28..2b85821f 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -116,7 +116,7 @@ impl GenerationState { let code = self.memory.contexts[ctx].segments[Segment::Returndata as usize].content [..returndata_size] .iter() - .map(|x| x.as_u32() as u8) + .map(|x| x.low_u32() as u8) .collect::>(); debug_assert_eq!(keccak(&code), codehash); diff --git a/evm/src/generation/trie_extractor.rs b/evm/src/generation/trie_extractor.rs index 8311f692..a508a720 100644 --- a/evm/src/generation/trie_extractor.rs +++ b/evm/src/generation/trie_extractor.rs @@ -20,7 +20,7 @@ pub(crate) struct AccountTrieRecord { pub(crate) fn read_state_trie_value(slice: &[U256]) -> AccountTrieRecord { AccountTrieRecord { - nonce: slice[0].as_u64(), + nonce: slice[0].low_u64(), balance: slice[1], storage_ptr: slice[2].as_usize(), code_hash: H256::from_uint(&slice[3]), diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 0afa1d80..459100c8 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -13,7 +13,8 @@ use crate::permutation::{ get_n_grand_product_challenge_sets_target, }; use crate::proof::*; -use crate::util::{h256_limbs, u256_limbs}; +use crate::util::{h256_limbs, u256_limbs, u256_to_u32, u256_to_u64}; +use crate::witness::errors::ProgramError; fn observe_root, C: GenericConfig, const D: usize>( challenger: &mut Challenger, @@ -56,35 +57,24 @@ fn observe_block_metadata< >( challenger: &mut Challenger, block_metadata: &BlockMetadata, -) { +) -> Result<(), ProgramError> { challenger.observe_elements( &u256_limbs::(U256::from_big_endian(&block_metadata.block_beneficiary.0))[..5], ); - challenger.observe_element(F::from_canonical_u32( - block_metadata.block_timestamp.as_u32(), - )); - challenger.observe_element(F::from_canonical_u32(block_metadata.block_number.as_u32())); - challenger.observe_element(F::from_canonical_u32( - block_metadata.block_difficulty.as_u32(), - )); - challenger.observe_element(F::from_canonical_u32( - block_metadata.block_gaslimit.as_u32(), - )); - challenger.observe_element(F::from_canonical_u32( - block_metadata.block_chain_id.as_u32(), - )); - challenger.observe_element(F::from_canonical_u32( - block_metadata.block_base_fee.as_u64() as u32, - )); - challenger.observe_element(F::from_canonical_u32( - 
(block_metadata.block_base_fee.as_u64() >> 32) as u32, - )); - challenger.observe_element(F::from_canonical_u32( - block_metadata.block_gas_used.as_u32(), - )); + challenger.observe_element(u256_to_u32(block_metadata.block_timestamp)?); + challenger.observe_element(u256_to_u32(block_metadata.block_number)?); + challenger.observe_element(u256_to_u32(block_metadata.block_difficulty)?); + challenger.observe_element(u256_to_u32(block_metadata.block_gaslimit)?); + challenger.observe_element(u256_to_u32(block_metadata.block_chain_id)?); + let basefee = u256_to_u64(block_metadata.block_base_fee)?; + challenger.observe_element(basefee.0); + challenger.observe_element(basefee.1); + challenger.observe_element(u256_to_u32(block_metadata.block_gas_used)?); for i in 0..8 { challenger.observe_elements(&u256_limbs(block_metadata.block_bloom[i])); } + + Ok(()) } fn observe_block_metadata_target< @@ -115,18 +105,20 @@ fn observe_extra_block_data< >( challenger: &mut Challenger, extra_data: &ExtraBlockData, -) { +) -> Result<(), ProgramError> { challenger.observe_elements(&h256_limbs(extra_data.genesis_state_root)); - challenger.observe_element(F::from_canonical_u32(extra_data.txn_number_before.as_u32())); - challenger.observe_element(F::from_canonical_u32(extra_data.txn_number_after.as_u32())); - challenger.observe_element(F::from_canonical_u32(extra_data.gas_used_before.as_u32())); - challenger.observe_element(F::from_canonical_u32(extra_data.gas_used_after.as_u32())); + challenger.observe_element(u256_to_u32(extra_data.txn_number_before)?); + challenger.observe_element(u256_to_u32(extra_data.txn_number_after)?); + challenger.observe_element(u256_to_u32(extra_data.gas_used_before)?); + challenger.observe_element(u256_to_u32(extra_data.gas_used_after)?); for i in 0..8 { challenger.observe_elements(&u256_limbs(extra_data.block_bloom_before[i])); } for i in 0..8 { challenger.observe_elements(&u256_limbs(extra_data.block_bloom_after[i])); } + + Ok(()) } fn observe_extra_block_data_target< @@ -183,12 +175,12 @@ pub(crate) fn observe_public_values< >( challenger: &mut Challenger, public_values: &PublicValues, -) { +) -> Result<(), ProgramError> { observe_trie_roots::(challenger, &public_values.trie_roots_before); observe_trie_roots::(challenger, &public_values.trie_roots_after); - observe_block_metadata::(challenger, &public_values.block_metadata); + observe_block_metadata::(challenger, &public_values.block_metadata)?; observe_block_hashes::(challenger, &public_values.block_hashes); - observe_extra_block_data::(challenger, &public_values.extra_block_data); + observe_extra_block_data::(challenger, &public_values.extra_block_data) } pub(crate) fn observe_public_values_target< @@ -214,14 +206,14 @@ impl, C: GenericConfig, const D: usize> A &self, all_stark: &AllStark, config: &StarkConfig, - ) -> AllProofChallenges { + ) -> Result, ProgramError> { let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { challenger.observe_cap(&proof.proof.trace_cap); } - observe_public_values::(&mut challenger, &self.public_values); + observe_public_values::(&mut challenger, &self.public_values)?; let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); @@ -229,7 +221,7 @@ impl, C: GenericConfig, const D: usize> A let num_permutation_zs = all_stark.nums_permutation_zs(config); let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); - AllProofChallenges { + Ok(AllProofChallenges { stark_challenges: core::array::from_fn(|i| { challenger.compact(); 
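As a sanity check on the limb layout used for `block_base_fee` above, a standalone sketch (plain `u64` in place of `U256` and field elements) of the low/high 32-bit split performed by `u256_to_u64`:

fn to_u32_limbs(x: u64) -> (u32, u32) {
    // Little-endian split: limb 0 holds the low 32 bits, limb 1 the high 32 bits.
    (x as u32, (x >> 32) as u32)
}

fn main() {
    let (lo, hi) = to_u32_limbs(0x0000_0001_0000_000A);
    assert_eq!((lo, hi), (0x0000_000A, 0x0000_0001));
}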
self.stark_proofs[i].proof.get_challenges( @@ -240,7 +232,7 @@ impl, C: GenericConfig, const D: usize> A ) }), ctl_challenges, - } + }) } #[allow(unused)] // TODO: should be used soon diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 7bd3c385..74f92622 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -261,6 +261,8 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark { pub xored_rate_u32s: [T; KECCAK_RATE_U32S], /// The entire state (rate + capacity) of the sponge, encoded as 32-bit chunks, after the - /// permutation is applied. - pub updated_state_u32s: [T; KECCAK_WIDTH_U32S], + /// permutation is applied, minus the first limbs where the digest is extracted from. + /// Those missing limbs can be recomputed from their corresponding bytes stored in + /// `updated_digest_state_bytes`. + pub partial_updated_state_u32s: [T; KECCAK_WIDTH_MINUS_DIGEST_U32S], - pub updated_state_bytes: [T; KECCAK_DIGEST_BYTES], + /// The first part of the state of the sponge, seen as bytes, after the permutation is applied. + /// This also represents the output digest of the Keccak sponge during the squeezing phase. + pub updated_digest_state_bytes: [T; KECCAK_DIGEST_BYTES], } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/evm/src/keccak_sponge/keccak_sponge_stark.rs b/evm/src/keccak_sponge/keccak_sponge_stark.rs index 5f1a49cc..d78e9651 100644 --- a/evm/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm/src/keccak_sponge/keccak_sponge_stark.rs @@ -28,7 +28,7 @@ pub(crate) fn ctl_looked_data() -> Vec> { let mut outputs = Vec::with_capacity(8); for i in (0..8).rev() { let cur_col = Column::linear_combination( - cols.updated_state_bytes[i * 4..(i + 1) * 4] + cols.updated_digest_state_bytes[i * 4..(i + 1) * 4] .iter() .enumerate() .map(|(j, &c)| (c, F::from_canonical_u64(1 << (24 - 8 * j)))), @@ -49,15 +49,30 @@ pub(crate) fn ctl_looked_data() -> Vec> { pub(crate) fn ctl_looking_keccak() -> Vec> { let cols = KECCAK_SPONGE_COL_MAP; - Column::singles( + let mut res: Vec<_> = Column::singles( [ cols.xored_rate_u32s.as_slice(), &cols.original_capacity_u32s, - &cols.updated_state_u32s, ] .concat(), ) - .collect() + .collect(); + + // We recover the 32-bit digest limbs from their corresponding bytes, + // and then append them to the rest of the updated state limbs. + let digest_u32s = cols.updated_digest_state_bytes.chunks_exact(4).map(|c| { + Column::linear_combination( + c.iter() + .enumerate() + .map(|(i, &b)| (b, F::from_canonical_usize(1 << (8 * i)))), + ) + }); + + res.extend(digest_u32s); + + res.extend(Column::singles(&cols.partial_updated_state_u32s)); + + res } pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { @@ -239,7 +254,21 @@ impl, const D: usize> KeccakSpongeStark { block.try_into().unwrap(), ); - sponge_state = row.updated_state_u32s.map(|f| f.to_canonical_u64() as u32); + sponge_state[..KECCAK_DIGEST_U32S] + .iter_mut() + .zip(row.updated_digest_state_bytes.chunks_exact(4)) + .for_each(|(s, bs)| { + *s = bs + .iter() + .enumerate() + .map(|(i, b)| (b.to_canonical_u64() as u32) << (8 * i)) + .sum(); + }); + + sponge_state[KECCAK_DIGEST_U32S..] 
+ .iter_mut() + .zip(row.partial_updated_state_u32s) + .for_each(|(s, x)| *s = x.to_canonical_u64() as u32); rows.push(row.into()); already_absorbed_bytes += KECCAK_RATE_BYTES; @@ -357,24 +386,33 @@ impl, const D: usize> KeccakSpongeStark { row.xored_rate_u32s = xored_rate_u32s.map(F::from_canonical_u32); keccakf_u32s(&mut sponge_state); - row.updated_state_u32s = sponge_state.map(F::from_canonical_u32); - let is_final_block = row.is_final_input_len.iter().copied().sum::() == F::ONE; - if is_final_block { - for (l, &elt) in row.updated_state_u32s[..8].iter().enumerate() { + // Store all but the first `KECCAK_DIGEST_U32S` limbs in the updated state. + // Those missing limbs will be broken down into bytes and stored separately. + row.partial_updated_state_u32s.copy_from_slice( + &sponge_state[KECCAK_DIGEST_U32S..] + .iter() + .copied() + .map(|i| F::from_canonical_u32(i)) + .collect::>(), + ); + sponge_state[..KECCAK_DIGEST_U32S] + .iter() + .enumerate() + .for_each(|(l, &elt)| { let mut cur_elt = elt; (0..4).for_each(|i| { - row.updated_state_bytes[l * 4 + i] = - F::from_canonical_u32((cur_elt.to_canonical_u64() & 0xFF) as u32); - cur_elt = F::from_canonical_u64(cur_elt.to_canonical_u64() >> 8); + row.updated_digest_state_bytes[l * 4 + i] = + F::from_canonical_u32(cur_elt & 0xFF); + cur_elt >>= 8; }); - let mut s = row.updated_state_bytes[l * 4].to_canonical_u64(); + // 32-bit limb reconstruction consistency check. + let mut s = row.updated_digest_state_bytes[l * 4].to_canonical_u64(); for i in 1..4 { - s += row.updated_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); + s += row.updated_digest_state_bytes[l * 4 + i].to_canonical_u64() << (8 * i); } - assert_eq!(elt, F::from_canonical_u64(s), "not equal"); - } - } + assert_eq!(elt as u64, s, "not equal"); + }) } fn generate_padding_row(&self) -> [F; NUM_KECCAK_SPONGE_COLUMNS] { @@ -445,26 +483,39 @@ impl, const D: usize> Stark for KeccakSpongeS ); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after += + current_bytes_after[i] * P::from(FE::from_canonical_usize(1 << (8 * i))); + } + yield_constr + .constraint_transition(is_full_input_block * (*next_before - current_after)); + } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { yield_constr.constraint_transition(is_full_input_block * (next_before - current_after)); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. 
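The byte-to-limb recombination used above, both in `ctl_looking_keccak` and when rebuilding `sponge_state`, follows one little-endian convention; a standalone sketch:

fn limb_from_le_bytes(bytes: &[u8; 4]) -> u32 {
    // Computes sum(b_i << (8 * i)): byte 0 is the least significant.
    bytes
        .iter()
        .enumerate()
        .map(|(i, &b)| (b as u32) << (8 * i))
        .sum()
}

fn main() {
    assert_eq!(limb_from_le_bytes(&[0x78, 0x56, 0x34, 0x12]), 0x1234_5678);
}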
yield_constr.constraint_transition( is_full_input_block - * (already_absorbed_bytes + P::from(FE::from_canonical_u64(136)) + * (already_absorbed_bytes + P::from(FE::from_canonical_usize(KECCAK_RATE_BYTES)) - next_values.already_absorbed_bytes), ); @@ -481,16 +532,6 @@ impl, const D: usize> Stark for KeccakSpongeS let entry_match = offset - P::from(FE::from_canonical_usize(i)); yield_constr.constraint(is_final_len * entry_match); } - - // Adding constraints for byte columns. - for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() { - let mut s = local_values.updated_state_bytes[l * 4]; - for i in 1..4 { - s += local_values.updated_state_bytes[l * 4 + i] - * P::from(FE::from_canonical_usize(1 << (8 * i))); - } - yield_constr.constraint(is_final_block * (s - elt)); - } } fn eval_ext_circuit( @@ -566,19 +607,36 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); // If this is a full-input block, the next row's "before" should match our "after" state. + for (current_bytes_after, next_before) in local_values + .updated_digest_state_bytes + .chunks_exact(4) + .zip(&next_values.original_rate_u32s[..KECCAK_DIGEST_U32S]) + { + let mut current_after = current_bytes_after[0]; + for i in 1..4 { + current_after = builder.mul_const_add_extension( + F::from_canonical_usize(1 << (8 * i)), + current_bytes_after[i], + current_after, + ); + } + let diff = builder.sub_extension(*next_before, current_after); + let constraint = builder.mul_extension(is_full_input_block, diff); + yield_constr.constraint_transition(builder, constraint); + } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .zip(next_values.original_rate_u32s.iter()) + .zip(next_values.original_rate_u32s[KECCAK_DIGEST_U32S..].iter()) { let diff = builder.sub_extension(next_before, current_after); let constraint = builder.mul_extension(is_full_input_block, diff); yield_constr.constraint_transition(builder, constraint); } for (¤t_after, &next_before) in local_values - .updated_state_u32s + .partial_updated_state_u32s .iter() - .skip(KECCAK_RATE_U32S) + .skip(KECCAK_RATE_U32S - KECCAK_DIGEST_U32S) .zip(next_values.original_capacity_u32s.iter()) { let diff = builder.sub_extension(next_before, current_after); @@ -586,9 +644,11 @@ impl, const D: usize> Stark for KeccakSpongeS yield_constr.constraint_transition(builder, constraint); } - // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus 136. - let absorbed_bytes = - builder.add_const_extension(already_absorbed_bytes, F::from_canonical_u64(136)); + // If this is a full-input block, the next row's already_absorbed_bytes should be ours plus `KECCAK_RATE_BYTES`. + let absorbed_bytes = builder.add_const_extension( + already_absorbed_bytes, + F::from_canonical_usize(KECCAK_RATE_BYTES), + ); let absorbed_diff = builder.sub_extension(absorbed_bytes, next_values.already_absorbed_bytes); let constraint = builder.mul_extension(is_full_input_block, absorbed_diff); @@ -615,21 +675,6 @@ impl, const D: usize> Stark for KeccakSpongeS let constraint = builder.mul_extension(is_final_len, entry_match); yield_constr.constraint(builder, constraint); } - - // Adding constraints for byte columns. 
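To make the absorption bookkeeping concrete: Keccak-256 has a rate of 136 bytes, so the transition constraint above forces `already_absorbed_bytes` to advance by exactly `KECCAK_RATE_BYTES` on each full-input row. A worked example:

const KECCAK_RATE_BYTES: usize = 136;

fn main() {
    // A 300-byte input yields two full rate blocks plus a final partial block,
    // so `already_absorbed_bytes` takes the values 0, 136, 272 across the rows.
    let absorbed: Vec<usize> = (0..3).map(|i| i * KECCAK_RATE_BYTES).collect();
    assert_eq!(absorbed, vec![0, 136, 272]);
}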
- for (l, &elt) in local_values.updated_state_u32s[..8].iter().enumerate() { - let mut s = local_values.updated_state_bytes[l * 4]; - for i in 1..4 { - s = builder.mul_const_add_extension( - F::from_canonical_usize(1 << (8 * i)), - local_values.updated_state_bytes[l * 4 + i], - s, - ); - } - let constraint = builder.sub_extension(s, elt); - let constraint = builder.mul_extension(is_final_block, constraint); - yield_constr.constraint(builder, constraint); - } } fn constraint_degree(&self) -> usize { @@ -698,9 +743,10 @@ mod tests { let rows = stark.generate_rows_for_op(op); assert_eq!(rows.len(), 1); let last_row: &KeccakSpongeColumnsView = rows.last().unwrap().borrow(); - let output = last_row.updated_state_u32s[..8] + let output = last_row + .updated_digest_state_bytes .iter() - .flat_map(|x| (x.to_canonical_u64() as u32).to_le_bytes()) + .map(|x| x.to_canonical_u64() as u8) .collect_vec(); assert_eq!(output, expected_output.0); diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 29ad6738..ab48cda0 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -8,6 +8,7 @@ pub mod all_stark; pub mod arithmetic; +pub mod byte_packing; pub mod config; pub mod constraint_consumer; pub mod cpu; diff --git a/evm/src/proof.rs b/evm/src/proof.rs index a5bb2b3d..71b2feb4 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -74,33 +74,68 @@ impl Default for BlockHashes { } } +/// User-provided helper values needed to compute the `BLOCKHASH` opcode. +/// The proofs across consecutive blocks ensure that these values +/// are consistent (i.e. shifted by one to the left). +/// +/// When the block number is less than 256, dummy values, i.e. `H256::default()`, +/// should be used for the additional block hashes. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct BlockHashes { + /// The hashes of the previous 256 blocks. The leftmost hash, i.e. `prev_hashes[0]`, + /// is the oldest, and the rightmost, i.e. `prev_hashes[255]`, is the hash of the parent block. pub prev_hashes: Vec, + /// The hash of the current block. pub cur_hash: H256, } +/// Metadata contained in a block header. These are identical across +/// all state transition proofs within the same block. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct BlockMetadata { + /// The address of this block's producer. pub block_beneficiary: Address, + /// The timestamp of this block. pub block_timestamp: U256, + /// The index of this block. pub block_number: U256, + /// The difficulty (before the PoS transition) of this block. pub block_difficulty: U256, + /// The gas limit of this block. It must fit in a `u32`. pub block_gaslimit: U256, + /// The chain id of this block. pub block_chain_id: U256, + /// The base fee of this block. pub block_base_fee: U256, + /// The total gas used in this block. It must fit in a `u32`. pub block_gas_used: U256, + /// The block bloom of this block, represented as the consecutive + /// 32-byte chunks of a block's final bloom filter string. pub block_bloom: [U256; 8], } +/// Additional block data that are specific to the local transaction being proven, +/// unlike `BlockMetadata`. #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ExtraBlockData { + /// The state trie digest of the genesis block. pub genesis_state_root: H256, + /// The transaction count prior to execution of the local state transition, starting + at 0 for the initial transaction of a block. pub txn_number_before: U256, + /// The transaction count after execution of the local state transition.
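The "before"/"after" pairs documented here chain across the transactions of a block; a small sketch of that invariant, with plain integers standing in for the `U256` fields:

#[derive(Clone, Copy)]
struct TxnData {
    gas_used_before: u64,
    gas_used_after: u64,
}

fn main() {
    // Each transaction's "after" value is the next one's "before" value,
    // starting from 0 for the first transaction of the block.
    let txns = [
        TxnData { gas_used_before: 0, gas_used_after: 21_000 },
        TxnData { gas_used_before: 21_000, gas_used_after: 63_000 },
    ];
    assert_eq!(txns[0].gas_used_before, 0);
    for w in txns.windows(2) {
        assert_eq!(w[0].gas_used_after, w[1].gas_used_before);
    }
}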
pub txn_number_after: U256, + /// The accumulated gas used prior to execution of the local state transition, starting + at 0 for the initial transaction of a block. pub gas_used_before: U256, + /// The accumulated gas used after execution of the local state transition. It should + match the `block_gas_used` value after execution of the last transaction in a block. pub gas_used_after: U256, + /// The accumulated bloom filter of this block prior to execution of the local state transition, + starting with all zeros for the initial transaction of a block. pub block_bloom_before: [U256; 8], + /// The accumulated bloom filter after execution of the local state transition. It should + match the `block_bloom` value after execution of the last transaction in a block. pub block_bloom_after: [U256; 8], } @@ -640,7 +675,7 @@ impl, C: GenericConfig, const D: usize> S } pub fn num_ctl_zs(&self) -> usize { - self.openings.ctl_zs_last.len() + self.openings.ctl_zs_first.len() } } @@ -721,8 +756,8 @@ pub struct StarkOpeningSet, const D: usize> { pub permutation_ctl_zs: Vec, /// Openings of permutations and cross-table lookups `Z` polynomials at `g * zeta`. pub permutation_ctl_zs_next: Vec, - /// Openings of cross-table lookups `Z` polynomials at `g^-1`. - pub ctl_zs_last: Vec, + /// Openings of cross-table lookups `Z` polynomials at `1`. + pub ctl_zs_first: Vec, /// Openings of quotient polynomials at `zeta`. pub quotient_polys: Vec, } @@ -734,7 +769,6 @@ impl, const D: usize> StarkOpeningSet { trace_commitment: &PolynomialBatch, permutation_ctl_zs_commitment: &PolynomialBatch, quotient_commitment: &PolynomialBatch, - degree_bits: usize, num_permutation_zs: usize, ) -> Self { let eval_commitment = |z: F::Extension, c: &PolynomialBatch| { @@ -755,10 +789,8 @@ impl, const D: usize> StarkOpeningSet { next_values: eval_commitment(zeta_next, trace_commitment), permutation_ctl_zs: eval_commitment(zeta, permutation_ctl_zs_commitment), permutation_ctl_zs_next: eval_commitment(zeta_next, permutation_ctl_zs_commitment), - ctl_zs_last: eval_commitment_base( - F::primitive_root_of_unity(degree_bits).inverse(), - permutation_ctl_zs_commitment, - )[num_permutation_zs..] + ctl_zs_first: eval_commitment_base(F::ONE, permutation_ctl_zs_commitment) + [num_permutation_zs..]
.to_vec(), quotient_polys: eval_commitment(zeta, quotient_commitment), } @@ -782,10 +814,10 @@ impl, const D: usize> StarkOpeningSet { .copied() .collect_vec(), }; - debug_assert!(!self.ctl_zs_last.is_empty()); - let ctl_last_batch = FriOpeningBatch { + debug_assert!(!self.ctl_zs_first.is_empty()); + let ctl_first_batch = FriOpeningBatch { values: self - .ctl_zs_last + .ctl_zs_first .iter() .copied() .map(F::Extension::from_basefield) @@ -793,7 +825,7 @@ impl, const D: usize> StarkOpeningSet { }; FriOpenings { - batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], + batches: vec![zeta_batch, zeta_next_batch, ctl_first_batch], } } } @@ -804,7 +836,7 @@ pub struct StarkOpeningSetTarget { pub next_values: Vec>, pub permutation_ctl_zs: Vec>, pub permutation_ctl_zs_next: Vec>, - pub ctl_zs_last: Vec, + pub ctl_zs_first: Vec, pub quotient_polys: Vec>, } @@ -814,7 +846,7 @@ impl StarkOpeningSetTarget { buffer.write_target_ext_vec(&self.next_values)?; buffer.write_target_ext_vec(&self.permutation_ctl_zs)?; buffer.write_target_ext_vec(&self.permutation_ctl_zs_next)?; - buffer.write_target_vec(&self.ctl_zs_last)?; + buffer.write_target_vec(&self.ctl_zs_first)?; buffer.write_target_ext_vec(&self.quotient_polys)?; Ok(()) } @@ -824,7 +856,7 @@ impl StarkOpeningSetTarget { let next_values = buffer.read_target_ext_vec::()?; let permutation_ctl_zs = buffer.read_target_ext_vec::()?; let permutation_ctl_zs_next = buffer.read_target_ext_vec::()?; - let ctl_zs_last = buffer.read_target_vec()?; + let ctl_zs_first = buffer.read_target_vec()?; let quotient_polys = buffer.read_target_ext_vec::()?; Ok(Self { @@ -832,7 +864,7 @@ impl StarkOpeningSetTarget { next_values, permutation_ctl_zs, permutation_ctl_zs_next, - ctl_zs_last, + ctl_zs_first, quotient_polys, }) } @@ -855,10 +887,10 @@ impl StarkOpeningSetTarget { .copied() .collect_vec(), }; - debug_assert!(!self.ctl_zs_last.is_empty()); - let ctl_last_batch = FriOpeningBatchTarget { + debug_assert!(!self.ctl_zs_first.is_empty()); + let ctl_first_batch = FriOpeningBatchTarget { values: self - .ctl_zs_last + .ctl_zs_first .iter() .copied() .map(|t| t.to_ext_target(zero)) @@ -866,7 +898,7 @@ impl StarkOpeningSetTarget { }; FriOpeningsTarget { - batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], + batches: vec![zeta_batch, zeta_next_batch, ctl_first_batch], } } } diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 31be89e7..7b960c95 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -21,6 +21,7 @@ use plonky2_util::{log2_ceil, log2_strict}; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -53,6 +54,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -75,6 +77,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -103,6 +106,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ 
-146,7 +150,8 @@ where challenger.observe_cap(cap); } - observe_public_values::(&mut challenger, &public_values); + observe_public_values::(&mut challenger, &public_values) + .map_err(|_| anyhow::Error::msg("Invalid conversion of public values."))?; let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); let ctl_data_per_table = timed!( @@ -193,6 +198,7 @@ where F: RichField + Extendable, C: GenericConfig, [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -212,6 +218,19 @@ where timing, )? ); + let byte_packing_proof = timed!( + timing, + "prove byte packing STARK", + prove_single_table( + &all_stark.byte_packing_stark, + config, + &trace_poly_values[Table::BytePacking as usize], + &trace_commitments[Table::BytePacking as usize], + &ctl_data_per_table[Table::BytePacking as usize], + challenger, + timing, + )? + ); let cpu_proof = timed!( timing, "prove CPU STARK", @@ -277,8 +296,10 @@ where timing, )? ); + Ok([ arithmetic_proof, + byte_packing_proof, cpu_proof, keccak_proof, keccak_sponge_proof, @@ -433,7 +454,6 @@ where trace_commitment, &permutation_ctl_zs_commitment, "ient_commitment, - degree_bits, stark.num_permutation_batches(config), ); challenger.observe_openings(&openings.to_fri_openings()); @@ -448,7 +468,7 @@ where timing, "compute openings proof", PolynomialBatch::prove_openings( - &stark.fri_instance(zeta, g, degree_bits, ctl_data.len(), config), + &stark.fri_instance(zeta, g, ctl_data.len(), config), &initial_merkle_trees, challenger, &fri_params, diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 3558dc9a..60548e92 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -43,9 +43,10 @@ use crate::proof::{ TrieRootsTarget, }; use crate::stark::Stark; -use crate::util::{h256_limbs, u256_limbs}; +use crate::util::{h256_limbs, u256_limbs, u256_to_u32, u256_to_u64}; use crate::vanishing_poly::eval_vanishing_poly_circuit; use crate::vars::StarkEvaluationTargets; +use crate::witness::errors::ProgramError; /// Table-wise recursive proofs of an `AllProof`. pub struct RecursiveAllProof< @@ -59,7 +60,7 @@ pub struct RecursiveAllProof< pub(crate) struct PublicInputs> { pub(crate) trace_cap: Vec>, - pub(crate) ctl_zs_last: Vec, + pub(crate) ctl_zs_first: Vec, pub(crate) ctl_challenges: GrandProductChallengeSet, pub(crate) challenger_state_before: P, pub(crate) challenger_state_after: P, @@ -85,11 +86,11 @@ impl> Public }; let challenger_state_before = P::new(&mut iter); let challenger_state_after = P::new(&mut iter); - let ctl_zs_last: Vec<_> = iter.collect(); + let ctl_zs_first: Vec<_> = iter.collect(); Self { trace_cap, - ctl_zs_last, + ctl_zs_first, ctl_challenges, challenger_state_before, challenger_state_after, @@ -150,7 +151,7 @@ impl, C: GenericConfig, const D: usize> // Verify the CTL checks. 
verify_cross_table_lookups::( &cross_table_lookups, - pis.map(|p| p.ctl_zs_last), + pis.map(|p| p.ctl_zs_first), extra_looking_products, inner_config, )?; @@ -350,7 +351,7 @@ where let challenger_state = challenger.compact(&mut builder); builder.register_public_inputs(challenger_state.as_ref()); - builder.register_public_inputs(&proof_target.openings.ctl_zs_last); + builder.register_public_inputs(&proof_target.openings.ctl_zs_first); verify_stark_proof_with_challenges_circuit::( &mut builder, @@ -414,7 +415,7 @@ fn verify_stark_proof_with_challenges_circuit< next_values, permutation_ctl_zs, permutation_ctl_zs_next, - ctl_zs_last, + ctl_zs_first, quotient_polys, } = &proof.openings; let vars = StarkEvaluationTargets { @@ -484,8 +485,7 @@ fn verify_stark_proof_with_challenges_circuit< builder, challenges.stark_zeta, F::primitive_root_of_unity(degree_bits), - degree_bits, - ctl_zs_last.len(), + ctl_zs_first.len(), inner_config, ); builder.verify_fri_proof::( @@ -871,7 +871,7 @@ fn add_virtual_stark_opening_set, S: Stark, c .add_virtual_extension_targets(stark.num_permutation_batches(config) + num_ctl_zs), permutation_ctl_zs_next: builder .add_virtual_extension_targets(stark.num_permutation_batches(config) + num_ctl_zs), - ctl_zs_last: builder.add_virtual_targets(num_ctl_zs), + ctl_zs_first: builder.add_virtual_targets(num_ctl_zs), quotient_polys: builder .add_virtual_extension_targets(stark.quotient_degree_factor() * num_challenges), } @@ -907,7 +907,8 @@ pub(crate) fn set_public_value_targets( witness: &mut W, public_values_target: &PublicValuesTarget, public_values: &PublicValues, -) where +) -> Result<(), ProgramError> +where F: RichField + Extendable, W: Witness, { @@ -925,7 +926,7 @@ pub(crate) fn set_public_value_targets( witness, &public_values_target.block_metadata, &public_values.block_metadata, - ); + )?; set_block_hashes_target( witness, &public_values_target.block_hashes, @@ -936,6 +937,8 @@ pub(crate) fn set_public_value_targets( &public_values_target.extra_block_data, &public_values.extra_block_data, ); + + Ok(()) } pub(crate) fn set_trie_roots_target( @@ -996,7 +999,8 @@ pub(crate) fn set_block_metadata_target( witness: &mut W, block_metadata_target: &BlockMetadataTarget, block_metadata: &BlockMetadata, -) where +) -> Result<(), ProgramError> +where F: RichField + Extendable, W: Witness, { @@ -1007,42 +1011,39 @@ pub(crate) fn set_block_metadata_target( witness.set_target_arr(&block_metadata_target.block_beneficiary, &beneficiary_limbs); witness.set_target( block_metadata_target.block_timestamp, - F::from_canonical_u32(block_metadata.block_timestamp.as_u32()), + u256_to_u32(block_metadata.block_timestamp)?, ); witness.set_target( block_metadata_target.block_number, - F::from_canonical_u32(block_metadata.block_number.as_u32()), + u256_to_u32(block_metadata.block_number)?, ); witness.set_target( block_metadata_target.block_difficulty, - F::from_canonical_u32(block_metadata.block_difficulty.as_u32()), + u256_to_u32(block_metadata.block_difficulty)?, ); witness.set_target( block_metadata_target.block_gaslimit, - F::from_canonical_u32(block_metadata.block_gaslimit.as_u32()), + u256_to_u32(block_metadata.block_gaslimit)?, ); witness.set_target( block_metadata_target.block_chain_id, - F::from_canonical_u32(block_metadata.block_chain_id.as_u32()), + u256_to_u32(block_metadata.block_chain_id)?, ); // Basefee fits in 2 limbs - witness.set_target( - block_metadata_target.block_base_fee[0], - F::from_canonical_u32(block_metadata.block_base_fee.as_u64() as u32), - ); - witness.set_target( - 
block_metadata_target.block_base_fee[1], - F::from_canonical_u32((block_metadata.block_base_fee.as_u64() >> 32) as u32), - ); + let basefee = u256_to_u64(block_metadata.block_base_fee)?; + witness.set_target(block_metadata_target.block_base_fee[0], basefee.0); + witness.set_target(block_metadata_target.block_base_fee[1], basefee.1); witness.set_target( block_metadata_target.block_gas_used, - F::from_canonical_u64(block_metadata.block_gas_used.as_u64()), + u256_to_u32(block_metadata.block_gas_used)?, ); let mut block_bloom_limbs = [F::ZERO; 64]; for (i, limbs) in block_bloom_limbs.chunks_exact_mut(8).enumerate() { limbs.copy_from_slice(&u256_limbs(block_metadata.block_bloom[i])); } witness.set_target_arr(&block_metadata_target.block_bloom, &block_bloom_limbs); + + Ok(()) } pub(crate) fn set_block_hashes_target( diff --git a/evm/src/stark.rs b/evm/src/stark.rs index 72cee0ad..73b51ada 100644 --- a/evm/src/stark.rs +++ b/evm/src/stark.rs @@ -84,7 +84,6 @@ pub trait Stark, const D: usize>: Sync { &self, zeta: F::Extension, g: F, - degree_bits: usize, num_ctl_zs: usize, config: &StarkConfig, ) -> FriInstanceInfo { @@ -131,13 +130,13 @@ pub trait Stark, const D: usize>: Sync { point: zeta.scalar_mul(g), polynomials: [trace_info, permutation_ctl_zs_info].concat(), }; - let ctl_last_batch = FriBatchInfo { - point: F::Extension::primitive_root_of_unity(degree_bits).inverse(), + let ctl_first_batch = FriBatchInfo { + point: F::Extension::ONE, polynomials: ctl_zs_info, }; FriInstanceInfo { oracles: vec![trace_oracle, permutation_ctl_oracle, quotient_oracle], - batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], + batches: vec![zeta_batch, zeta_next_batch, ctl_first_batch], } } @@ -147,7 +146,6 @@ pub trait Stark, const D: usize>: Sync { builder: &mut CircuitBuilder, zeta: ExtensionTarget, g: F, - degree_bits: usize, num_ctl_zs: usize, inner_config: &StarkConfig, ) -> FriInstanceInfoTarget { @@ -195,14 +193,13 @@ pub trait Stark, const D: usize>: Sync { point: zeta_next, polynomials: [trace_info, permutation_ctl_zs_info].concat(), }; - let ctl_last_batch = FriBatchInfoTarget { - point: builder - .constant_extension(F::Extension::primitive_root_of_unity(degree_bits).inverse()), + let ctl_first_batch = FriBatchInfoTarget { + point: builder.one_extension(), polynomials: ctl_zs_info, }; FriInstanceInfoTarget { oracles: vec![trace_oracle, permutation_ctl_oracle, quotient_oracle], - batches: vec![zeta_batch, zeta_next_batch, ctl_last_batch], + batches: vec![zeta_batch, zeta_next_batch, ctl_first_batch], } } diff --git a/evm/src/util.rs b/evm/src/util.rs index 5fa085dc..a3f6d050 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -11,6 +11,8 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::util::transpose; +use crate::witness::errors::ProgramError; + /// Construct an integer from its constituent bits (in little-endian order) pub fn limb_from_bits_le(iter: impl IntoIterator) -> P { // TODO: This is technically wrong, as 1 << i won't be canonical for all fields... @@ -45,6 +47,29 @@ pub fn trace_rows_to_poly_values( .collect() } +/// Returns the lowest LE 32-bit limb of a `U256` as a field element, +/// and errors if the integer does not fit in a `u32`.
+pub(crate) fn u256_to_u32(u256: U256) -> Result { + if TryInto::::try_into(u256).is_err() { + return Err(ProgramError::IntegerTooLarge); + } + + Ok(F::from_canonical_u32(u256.low_u32())) +} + +/// Returns the lowest LE 64-bit word of a `U256` as two field elements +/// each storing a 32-bit limb, and errors if the integer does not fit in a `u64`. +pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> { + if TryInto::::try_into(u256).is_err() { + return Err(ProgramError::IntegerTooLarge); + } + + Ok(( + F::from_canonical_u32(u256.low_u64() as u32), + F::from_canonical_u32((u256.low_u64() >> 32) as u32), + )) +} + #[allow(unused)] // TODO: Remove? /// Returns the 32-bit little-endian limbs of a `U256`. pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 49225f14..11f8155d 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -2,6 +2,7 @@ use std::any::type_name; use anyhow::{ensure, Result}; use ethereum_types::U256; +use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::types::Field; use plonky2::fri::verifier::verify_fri_proof; @@ -11,6 +12,7 @@ use plonky2::plonk::plonk_common::reduce_with_powers; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::byte_packing::byte_packing_stark::BytePackingStark; use crate::config::StarkConfig; use crate::constraint_consumer::ConstraintConsumer; use crate::cpu::cpu_stark::CpuStark; @@ -38,6 +40,7 @@ pub fn verify_proof, C: GenericConfig, co ) -> Result<()> where [(); ArithmeticStark::::COLUMNS]:, + [(); BytePackingStark::::COLUMNS]:, [(); CpuStark::::COLUMNS]:, [(); KeccakStark::::COLUMNS]:, [(); KeccakSpongeStark::::COLUMNS]:, @@ -47,12 +50,15 @@ where let AllProofChallenges { stark_challenges, ctl_challenges, - } = all_proof.get_challenges(all_stark, config); + } = all_proof + .get_challenges(all_stark, config) + .map_err(|_| anyhow::Error::msg("Invalid sampling of proof challenges."))?; let nums_permutation_zs = all_stark.nums_permutation_zs(config); let AllStark { arithmetic_stark, + byte_packing_stark, cpu_stark, keccak_stark, keccak_sponge_stark, @@ -75,6 +81,13 @@ where &ctl_vars_per_table[Table::Arithmetic as usize], config, )?; + verify_stark_proof_with_challenges( + byte_packing_stark, + &all_proof.stark_proofs[Table::BytePacking as usize].proof, + &stark_challenges[Table::BytePacking as usize], + &ctl_vars_per_table[Table::BytePacking as usize], + config, + )?; verify_stark_proof_with_challenges( cpu_stark, &all_proof.stark_proofs[Table::Cpu as usize].proof, @@ -96,13 +109,6 @@ where &ctl_vars_per_table[Table::KeccakSponge as usize], config, )?; - verify_stark_proof_with_challenges( - memory_stark, - &all_proof.stark_proofs[Table::Memory as usize].proof, - &stark_challenges[Table::Memory as usize], - &ctl_vars_per_table[Table::Memory as usize], - config, - )?; verify_stark_proof_with_challenges( logic_stark, &all_proof.stark_proofs[Table::Logic as usize].proof, @@ -110,25 +116,30 @@ where &ctl_vars_per_table[Table::Logic as usize], config, )?; + verify_stark_proof_with_challenges( + memory_stark, + &all_proof.stark_proofs[Table::Memory as usize].proof, + &stark_challenges[Table::Memory as usize], + &ctl_vars_per_table[Table::Memory as usize], + config, + )?; let public_values = all_proof.public_values; - // Extra products to add to the looked last value - // Arithmetic, KeccakSponge, Keccak, Logic - let mut extra_looking_products =
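A standalone illustration of why these checked conversions exist: the `as_u32()` calls they replace silently truncated oversized values, whereas the new helpers reject them (sketched with `u128`/`u32` in place of `U256` and field elements):

fn checked_u32(x: u128) -> Result<u32, &'static str> {
    // Mirrors u256_to_u32: error out instead of truncating.
    u32::try_from(x).map_err(|_| "IntegerTooLarge")
}

fn main() {
    assert_eq!(checked_u32(0xFFFF_FFFF), Ok(0xFFFF_FFFF));
    assert!(checked_u32(0x1_0000_0000).is_err());
    assert_eq!(0x1_0000_0000u128 as u32, 0); // the old truncating behavior
}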
vec![vec![F::ONE; config.num_challenges]; NUM_TABLES - 1]; + // Extra products to add to the looked last value. + // Only necessary for the Memory values. + let mut extra_looking_products = vec![vec![F::ONE; config.num_challenges]; NUM_TABLES]; // Memory - extra_looking_products.push(Vec::new()); - for c in 0..config.num_challenges { - extra_looking_products[Table::Memory as usize].push(get_memory_extra_looking_products( - &public_values, - ctl_challenges.challenges[c], - )); - } + extra_looking_products[Table::Memory as usize] = (0..config.num_challenges) + .map(|i| get_memory_extra_looking_products(&public_values, ctl_challenges.challenges[i])) + .collect_vec(); verify_cross_table_lookups::( cross_table_lookups, - all_proof.stark_proofs.map(|p| p.proof.openings.ctl_zs_last), + all_proof + .stark_proofs + .map(|p| p.proof.openings.ctl_zs_first), extra_looking_products, config, ) @@ -301,7 +312,7 @@ where next_values, permutation_ctl_zs, permutation_ctl_zs_next, - ctl_zs_last, + ctl_zs_first, quotient_polys, } = &proof.openings; let vars = StarkEvaluationVars { @@ -367,8 +378,7 @@ where &stark.fri_instance( challenges.stark_zeta, F::primitive_root_of_unity(degree_bits), - degree_bits, - ctl_zs_last.len(), + ctl_zs_first.len(), config, ), &proof.openings.to_fri_openings(), @@ -408,7 +418,7 @@ where next_values, permutation_ctl_zs, permutation_ctl_zs_next, - ctl_zs_last, + ctl_zs_first, quotient_polys, } = openings; @@ -425,7 +435,7 @@ where ensure!(next_values.len() == S::COLUMNS); ensure!(permutation_ctl_zs.len() == num_zs); ensure!(permutation_ctl_zs_next.len() == num_zs); - ensure!(ctl_zs_last.len() == num_ctl_zs); + ensure!(ctl_zs_first.len() == num_ctl_zs); ensure!(quotient_polys.len() == stark.num_quotient_polys(config)); Ok(()) diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 44693a33..1ab99eae 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -13,6 +13,7 @@ pub enum ProgramError { MemoryError(MemoryError), GasLimitError, InterpreterError, + IntegerTooLarge, } #[allow(clippy::enum_variant_names)] diff --git a/evm/src/witness/gas.rs b/evm/src/witness/gas.rs index 4c7947bb..aa312078 100644 --- a/evm/src/witness/gas.rs +++ b/evm/src/witness/gas.rs @@ -25,8 +25,8 @@ pub(crate) fn gas_to_charge(op: Operation) -> u64 { BinaryArithmetic(Lt) => G_VERYLOW, BinaryArithmetic(Gt) => G_VERYLOW, BinaryArithmetic(Byte) => G_VERYLOW, - Shl => G_VERYLOW, - Shr => G_VERYLOW, + BinaryArithmetic(Shl) => G_VERYLOW, + BinaryArithmetic(Shr) => G_VERYLOW, BinaryArithmetic(AddFp254) => KERNEL_ONLY_INSTR, BinaryArithmetic(MulFp254) => KERNEL_ONLY_INSTR, BinaryArithmetic(SubFp254) => KERNEL_ONLY_INSTR, @@ -44,6 +44,8 @@ pub(crate) fn gas_to_charge(op: Operation) -> u64 { Swap(_) => G_VERYLOW, GetContext => KERNEL_ONLY_INSTR, SetContext => KERNEL_ONLY_INSTR, + Mload32Bytes => KERNEL_ONLY_INSTR, + Mstore32Bytes => KERNEL_ONLY_INSTR, ExitKernel => KERNEL_ONLY_INSTR, MloadGeneral => KERNEL_ONLY_INSTR, MstoreGeneral => KERNEL_ONLY_INSTR, diff --git a/evm/src/witness/mod.rs b/evm/src/witness/mod.rs index 7d491e4e..fbb88a71 100644 --- a/evm/src/witness/mod.rs +++ b/evm/src/witness/mod.rs @@ -1,4 +1,4 @@ -mod errors; +pub(crate) mod errors; mod gas; pub(crate) mod memory; mod operation; diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index 13619b96..8349d56d 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -3,6 +3,7 @@ use itertools::Itertools; use keccak_hash::keccak; use plonky2::field::types::Field; +use 
super::util::{byte_packing_log, byte_unpacking_log}; use crate::arithmetic::BinaryOperator; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; @@ -28,8 +29,6 @@ use crate::{arithmetic, logic}; pub(crate) enum Operation { Iszero, Not, - Shl, - Shr, Syscall(u8, usize, bool), // (syscall number, minimum stack length, increases stack length) Eq, BinaryLogic(logic::Op), @@ -47,6 +46,8 @@ pub(crate) enum Operation { Swap(u8), GetContext, SetContext, + Mload32Bytes, + Mstore32Bytes, ExitKernel, MloadGeneral, MstoreGeneral, @@ -136,7 +137,7 @@ pub(crate) fn generate_keccak_general( ..base_address }; let val = state.memory.get(address); - val.as_u32() as u8 + val.low_u32() as u8 }) .collect_vec(); log::debug!("Hashing {:?}", input); @@ -374,7 +375,7 @@ pub(crate) fn generate_push( Segment::Code, initial_offset + i, )) - .as_u32() as u8 + .low_u32() as u8 }) .collect_vec(); @@ -470,6 +471,7 @@ pub(crate) fn generate_iszero( fn append_shift( state: &mut GenerationState, mut row: CpuColumnsView, + is_shl: bool, input0: U256, input1: U256, log_in0: MemoryOp, @@ -497,10 +499,10 @@ fn append_shift( } else { U256::one() << input0 }; - let operator = if row.op.shl.is_one() { - BinaryOperator::Mul + let operator = if is_shl { + BinaryOperator::Shl } else { - BinaryOperator::Div + BinaryOperator::Shr }; let operation = arithmetic::Operation::binary(operator, input1, input0); @@ -524,7 +526,7 @@ pub(crate) fn generate_shl( } else { input1 << input0 }; - append_shift(state, row, input0, input1, log_in0, log_in1, result) + append_shift(state, row, true, input0, input1, log_in0, log_in1, result) } pub(crate) fn generate_shr( @@ -539,7 +541,7 @@ pub(crate) fn generate_shr( } else { input1 >> input0 }; - append_shift(state, row, input0, input1, log_in0, log_in1, result) + append_shift(state, row, false, input0, input1, log_in0, log_in1, result) } pub(crate) fn generate_syscall( @@ -686,6 +688,45 @@ pub(crate) fn generate_mload_general( Ok(()) } +pub(crate) fn generate_mload_32bytes( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; + let len = len.as_usize(); + + let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; + if usize::MAX - base_address.virt < len { + return Err(ProgramError::MemoryError(VirtTooLarge { + virt: base_address.virt.into(), + })); + } + let bytes = (0..len) + .map(|i| { + let address = MemoryAddress { + virt: base_address.virt + i, + ..base_address + }; + let val = state.memory.get(address); + val.low_u32() as u8 + }) + .collect_vec(); + + let packed_int = U256::from_big_endian(&bytes); + let log_out = stack_push_log_and_fill(state, &mut row, packed_int)?; + + byte_packing_log(state, base_address, bytes); + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + pub(crate) fn generate_mstore_general( state: &mut GenerationState, mut row: CpuColumnsView, @@ -715,6 +756,27 @@ pub(crate) fn generate_mstore_general( Ok(()) } +pub(crate) fn generate_mstore_32bytes( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = + stack_pop_with_log_and_fill::<5, 
_>(state, &mut row)?; + let len = len.as_usize(); + + let base_address = MemoryAddress::new_u256s(context, segment, base_virt)?; + + byte_unpacking_log(state, base_address, val, len); + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_in4); + state.traces.push_cpu(row); + Ok(()) +} + pub(crate) fn generate_exception( exc_code: u8, state: &mut GenerationState, diff --git a/evm/src/witness/traces.rs b/evm/src/witness/traces.rs index 2cc1c500..c4cf832d 100644 --- a/evm/src/witness/traces.rs +++ b/evm/src/witness/traces.rs @@ -9,6 +9,7 @@ use plonky2::util::timing::TimingTree; use crate::all_stark::{AllStark, NUM_TABLES}; use crate::arithmetic::{BinaryOperator, Operation}; +use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; @@ -20,6 +21,7 @@ use crate::{arithmetic, keccak, keccak_sponge, logic}; #[derive(Clone, Copy, Debug)] pub struct TraceCheckpoint { pub(self) arithmetic_len: usize, + pub(self) byte_packing_len: usize, pub(self) cpu_len: usize, pub(self) keccak_len: usize, pub(self) keccak_sponge_len: usize, @@ -30,6 +32,7 @@ pub struct TraceCheckpoint { #[derive(Debug)] pub(crate) struct Traces { pub(crate) arithmetic_ops: Vec, + pub(crate) byte_packing_ops: Vec, pub(crate) cpu: Vec>, pub(crate) logic_ops: Vec, pub(crate) memory_ops: Vec, @@ -41,6 +44,7 @@ impl Traces { pub fn new() -> Self { Traces { arithmetic_ops: vec![], + byte_packing_ops: vec![], cpu: vec![], logic_ops: vec![], memory_ops: vec![], @@ -64,6 +68,7 @@ impl Traces { }, }) .sum(), + byte_packing_len: self.byte_packing_ops.iter().map(|op| op.bytes.len()).sum(), cpu_len: self.cpu.len(), keccak_len: self.keccak_inputs.len() * keccak::keccak_stark::NUM_ROUNDS, keccak_sponge_len: self @@ -82,6 +87,7 @@ impl Traces { pub fn checkpoint(&self) -> TraceCheckpoint { TraceCheckpoint { arithmetic_len: self.arithmetic_ops.len(), + byte_packing_len: self.byte_packing_ops.len(), cpu_len: self.cpu.len(), keccak_len: self.keccak_inputs.len(), keccak_sponge_len: self.keccak_sponge_ops.len(), @@ -92,6 +98,7 @@ impl Traces { pub fn rollback(&mut self, checkpoint: TraceCheckpoint) { self.arithmetic_ops.truncate(checkpoint.arithmetic_len); + self.byte_packing_ops.truncate(checkpoint.byte_packing_len); self.cpu.truncate(checkpoint.cpu_len); self.keccak_inputs.truncate(checkpoint.keccak_len); self.keccak_sponge_ops @@ -120,6 +127,10 @@ impl Traces { self.memory_ops.push(op); } + pub fn push_byte_packing(&mut self, op: BytePackingOp) { + self.byte_packing_ops.push(op); + } + pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) { self.keccak_inputs.push(input); } @@ -154,6 +165,7 @@ impl Traces { let cap_elements = config.fri_config.num_cap_elements(); let Traces { arithmetic_ops, + byte_packing_ops, cpu, logic_ops, memory_ops, @@ -166,7 +178,13 @@ impl Traces { "generate arithmetic trace", all_stark.arithmetic_stark.generate_trace(arithmetic_ops) ); - + let byte_packing_trace = timed!( + timing, + "generate byte packing trace", + all_stark + .byte_packing_stark + .generate_trace(byte_packing_ops, cap_elements, timing) + ); let cpu_rows = cpu.into_iter().map(|x| x.into()).collect(); let cpu_trace = trace_rows_to_poly_values(cpu_rows); let keccak_trace = timed!( @@ -198,6 +216,7 @@ impl Traces { [ arithmetic_trace, + byte_packing_trace, cpu_trace, keccak_trace, 
keccak_sponge_trace, diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 3ee8d4f5..1418beba 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -70,8 +70,8 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::BinaryArithmetic( arithmetic::BinaryOperator::Byte, )), - (0x1b, _) => Ok(Operation::Shl), - (0x1c, _) => Ok(Operation::Shr), + (0x1b, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl)), + (0x1c, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr)), (0x1d, _) => Ok(Operation::Syscall(opcode, 2, false)), // SAR (0x20, _) => Ok(Operation::Syscall(opcode, 2, false)), // KECCAK256 (0x21, true) => Ok(Operation::KeccakGeneral), @@ -128,6 +128,7 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::Mstore32Bytes), (0xf0, _) => Ok(Operation::Syscall(opcode, 3, false)), // CREATE (0xf1, _) => Ok(Operation::Syscall(opcode, 7, false)), // CALL (0xf2, _) => Ok(Operation::Syscall(opcode, 7, false)), // CALLCODE @@ -136,6 +137,7 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::Syscall(opcode, 4, false)), // CREATE2 (0xf6, true) => Ok(Operation::GetContext), (0xf7, true) => Ok(Operation::SetContext), + (0xf8, true) => Ok(Operation::Mload32Bytes), (0xf9, true) => Ok(Operation::ExitKernel), (0xfa, _) => Ok(Operation::Syscall(opcode, 6, false)), // STATICCALL (0xfb, true) => Ok(Operation::MloadGeneral), @@ -160,22 +162,13 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::Not => &mut flags.not, Operation::Syscall(_, _, _) => &mut flags.syscall, Operation::BinaryLogic(_) => &mut flags.logic_op, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add) => &mut flags.add, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul) => &mut flags.mul, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Sub) => &mut flags.sub, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Div) => &mut flags.div, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod) => &mut flags.mod_, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt) => &mut flags.lt, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt) => &mut flags.gt, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Byte) => &mut flags.byte, - Operation::Shl => &mut flags.shl, - Operation::Shr => &mut flags.shr, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) => &mut flags.addfp254, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) => &mut flags.mulfp254, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.subfp254, - Operation::TernaryArithmetic(arithmetic::TernaryOperator::AddMod) => &mut flags.addmod, - Operation::TernaryArithmetic(arithmetic::TernaryOperator::MulMod) => &mut flags.mulmod, - Operation::TernaryArithmetic(arithmetic::TernaryOperator::SubMod) => &mut flags.submod, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.fp254_op, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => &mut flags.shift, + Operation::BinaryArithmetic(_) => &mut flags.binary_op, + Operation::TernaryArithmetic(_) => &mut flags.ternary_op, Operation::KeccakGeneral => &mut flags.keccak_general, Operation::ProverInput => &mut 
flags.prover_input, Operation::Pop => &mut flags.pop, @@ -183,9 +176,10 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::Pc => &mut flags.pc, Operation::Jumpdest => &mut flags.jumpdest, Operation::GetContext | Operation::SetContext => &mut flags.context_op, + Operation::Mload32Bytes => &mut flags.mload_32bytes, + Operation::Mstore32Bytes => &mut flags.mstore_32bytes, Operation::ExitKernel => &mut flags.exit_kernel, - Operation::MloadGeneral => &mut flags.mload_general, - Operation::MstoreGeneral => &mut flags.mstore_general, + Operation::MloadGeneral | Operation::MstoreGeneral => &mut flags.m_op_general, } = F::ONE; } @@ -200,8 +194,8 @@ fn perform_op( Operation::Swap(n) => generate_swap(n, state, row)?, Operation::Iszero => generate_iszero(state, row)?, Operation::Not => generate_not(state, row)?, - Operation::Shl => generate_shl(state, row)?, - Operation::Shr => generate_shr(state, row)?, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => generate_shl(state, row)?, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => generate_shr(state, row)?, Operation::Syscall(opcode, stack_values_read, stack_len_increased) => { generate_syscall(opcode, stack_values_read, stack_len_increased, state, row)? } @@ -220,6 +214,8 @@ fn perform_op( Operation::Jumpdest => generate_jumpdest(state, row)?, Operation::GetContext => generate_get_context(state, row)?, Operation::SetContext => generate_set_context(state, row)?, + Operation::Mload32Bytes => generate_mload_32bytes(state, row)?, + Operation::Mstore32Bytes => generate_mstore_32bytes(state, row)?, Operation::ExitKernel => generate_exit_kernel(state, row)?, Operation::MloadGeneral => generate_mload_general(state, row)?, Operation::MstoreGeneral => generate_mstore_general(state, row)?, @@ -290,7 +286,7 @@ fn log_kernel_instruction(state: &GenerationState, op: Operation) { let pc = state.registers.program_counter; let is_interesting_offset = KERNEL .offset_label(pc) - .filter(|label| !label.starts_with("halt_pc")) + .filter(|label| !label.starts_with("halt")) .is_some(); let level = if is_interesting_offset { log::Level::Debug diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 0e2b3660..94488614 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -1,6 +1,7 @@ use ethereum_types::U256; use plonky2::field::types::Field; +use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::keccak_util::keccakf_u8s; use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS}; @@ -258,3 +259,63 @@ pub(crate) fn keccak_sponge_log( input, }); } + +pub(crate) fn byte_packing_log( + state: &mut GenerationState, + base_address: MemoryAddress, + bytes: Vec, +) { + let clock = state.traces.clock(); + + let mut address = base_address; + for &byte in &bytes { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::Code, + clock, + address, + MemoryOpKind::Read, + byte.into(), + )); + address.increment(); + } + + state.traces.push_byte_packing(BytePackingOp { + is_read: true, + base_address, + timestamp: clock * NUM_CHANNELS, + bytes, + }); +} + +pub(crate) fn byte_unpacking_log( + state: &mut GenerationState, + base_address: MemoryAddress, + val: U256, + len: usize, +) { + let clock = state.traces.clock(); + + let mut bytes = vec![0; 32]; + val.to_little_endian(&mut bytes); + bytes.resize(len, 0); + bytes.reverse(); + + let mut address = base_address; + for &byte in &bytes { + state.traces.push_memory(MemoryOp::new( + 
MemoryChannel::Code, + clock, + address, + MemoryOpKind::Write, + byte.into(), + )); + address.increment(); + } + + state.traces.push_byte_packing(BytePackingOp { + is_read: false, + base_address, + timestamp: clock * NUM_CHANNELS, + bytes, + }); +} diff --git a/evm/tests/basic_smart_contract.rs b/evm/tests/basic_smart_contract.rs index 3118a34a..18bf9bd0 100644 --- a/evm/tests/basic_smart_contract.rs +++ b/evm/tests/basic_smart_contract.rs @@ -53,7 +53,10 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { let code_gas = 3 + 3 + 3; let code_hash = keccak(code); - let beneficiary_account_before = AccountRlp::default(); + let beneficiary_account_before = AccountRlp { + nonce: 1.into(), + ..AccountRlp::default() + }; let sender_account_before = AccountRlp { nonce: 5.into(), balance: eth_to_wei(100_000.into()), @@ -66,6 +69,11 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { let state_trie_before = { let mut children = core::array::from_fn(|_| Node::Empty.into()); + children[beneficiary_nibbles.get_nibble(0) as usize] = Node::Leaf { + nibbles: beneficiary_nibbles.truncate_n_nibbles_front(1), + value: rlp::encode(&beneficiary_account_before).to_vec(), + } + .into(); children[sender_nibbles.get_nibble(0) as usize] = Node::Leaf { nibbles: sender_nibbles.truncate_n_nibbles_front(1), value: rlp::encode(&sender_account_before).to_vec(), @@ -90,25 +98,33 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { storage_tries: vec![], }; + let txdata_gas = 2 * 16; + let gas_used = 21_000 + code_gas + txdata_gas; + // Generated using a little py-evm script. let txn = hex!("f861050a8255f094a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0648242421ba02c89eb757d9deeb1f5b3859a9d4d679951ef610ac47ad4608dc142beb1b7e313a05af7e9fbab825455d36c36c7f4cfcafbeafa9a77bdff936b52afb36d4fe4bcdd"); let value = U256::from(100u32); let block_metadata = BlockMetadata { block_beneficiary: Address::from(beneficiary), - ..BlockMetadata::default() + block_difficulty: 0x20000.into(), + block_number: 1.into(), + block_chain_id: 1.into(), + block_timestamp: 0x03e8.into(), + block_gaslimit: 0xff112233u32.into(), + block_gas_used: gas_used.into(), + block_bloom: [0.into(); 8], + block_base_fee: 0xa.into(), }; let mut contract_code = HashMap::new(); contract_code.insert(keccak(vec![]), vec![]); contract_code.insert(code_hash, code.to_vec()); - let txdata_gas = 2 * 16; - let gas_used = 21_000 + code_gas + txdata_gas; let expected_state_trie_after: HashedPartialTrie = { let beneficiary_account_after = AccountRlp { - balance: beneficiary_account_before.balance + gas_used * 10, - ..beneficiary_account_before + nonce: 1.into(), + ..AccountRlp::default() }; let sender_account_after = AccountRlp { balance: sender_account_before.balance - value - gas_used * 10, diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index 977f3efd..dd4e624b 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -73,7 +73,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let all_circuits = AllRecursiveCircuits::::new( &all_stark, - &[16..17, 15..16, 14..15, 9..10, 12..13, 18..19], // Minimal ranges to prove an empty list + &[16..17, 10..11, 15..16, 14..15, 9..10, 12..13, 18..19], // Minimal ranges to prove an empty list &config, ); diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index d86379ca..ab990746 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -441,7 +441,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { // Preprocess all circuits. 
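Returning to `byte_unpacking_log` above: the value's low `len` bytes are written to memory most-significant first. A standalone sketch of that ordering, with `u128` in place of `U256`:

fn unpack_be_prefix(val: u128, len: usize) -> Vec<u8> {
    let mut bytes = val.to_le_bytes().to_vec(); // little-endian bytes of the value
    bytes.resize(len, 0);                       // keep only the `len` low bytes
    bytes.reverse();                            // emit them big-endian
    bytes
}

fn main() {
    assert_eq!(unpack_be_prefix(0x0102_0304, 4), vec![1, 2, 3, 4]);
}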
diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs
index d86379ca..ab990746 100644
--- a/evm/tests/log_opcode.rs
+++ b/evm/tests/log_opcode.rs
@@ -441,7 +441,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> {
     // Preprocess all circuits.
     let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
         &all_stark,
-        &[16..17, 17..19, 14..15, 9..11, 12..13, 20..21],
+        &[16..17, 11..13, 17..19, 14..15, 9..11, 12..13, 19..21],
         &config,
     );
diff --git a/evm/tests/self_balance_gas_cost.rs b/evm/tests/self_balance_gas_cost.rs
index d0e95e11..1c0d166e 100644
--- a/evm/tests/self_balance_gas_cost.rs
+++ b/evm/tests/self_balance_gas_cost.rs
@@ -5,7 +5,7 @@ use std::time::Duration;
 use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV};
 use eth_trie_utils::nibbles::Nibbles;
 use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie};
-use ethereum_types::{Address, H256};
+use ethereum_types::{Address, H256, U256};
 use hex_literal::hex;
 use keccak_hash::keccak;
 use plonky2::field::goldilocks_field::GoldilocksField;
@@ -62,7 +62,10 @@ fn self_balance_gas_cost() -> anyhow::Result<()> {
         + 22100; // SSTORE
     let code_hash = keccak(code);
 
-    let beneficiary_account_before = AccountRlp::default();
+    let beneficiary_account_before = AccountRlp {
+        nonce: 1.into(),
+        ..AccountRlp::default()
+    };
     let sender_account_before = AccountRlp {
         balance: 0x3635c9adc5dea00000u128.into(),
         ..AccountRlp::default()
@@ -89,10 +92,18 @@ fn self_balance_gas_cost() -> anyhow::Result<()> {
 
     let txn = hex!("f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509b");
 
+    let gas_used = 21_000 + code_gas;
+
     let block_metadata = BlockMetadata {
         block_beneficiary: Address::from(beneficiary),
+        block_difficulty: 0x20000.into(),
+        block_number: 1.into(),
+        block_chain_id: 1.into(),
+        block_timestamp: 0x03e8.into(),
+        block_gaslimit: 0xff112233u32.into(),
+        block_gas_used: gas_used.into(),
+        block_bloom: [0.into(); 8],
         block_base_fee: 0xa.into(),
-        ..BlockMetadata::default()
     };
 
     let mut contract_code = HashMap::new();
@@ -100,9 +111,12 @@ fn self_balance_gas_cost() -> anyhow::Result<()> {
     contract_code.insert(code_hash, code.to_vec());
 
     let expected_state_trie_after = {
-        let beneficiary_account_after = AccountRlp::default();
+        let beneficiary_account_after = AccountRlp {
+            nonce: 1.into(),
+            ..AccountRlp::default()
+        };
         let sender_account_after = AccountRlp {
-            balance: 999999999999999568680u128.into(),
+            balance: sender_account_before.balance - U256::from(gas_used) * U256::from(10),
             nonce: 1.into(),
             ..AccountRlp::default()
         };
@@ -132,7 +146,6 @@ fn self_balance_gas_cost() -> anyhow::Result<()> {
         expected_state_trie_after
     };
 
-    let gas_used = 21_000 + code_gas;
     let receipt_0 = LegacyReceiptRlp {
         status: true,
         cum_gas_used: gas_used.into(),
diff --git a/plonky2/src/lib.rs b/plonky2/src/lib.rs
index b4a4b6af..c2913023 100644
--- a/plonky2/src/lib.rs
+++ b/plonky2/src/lib.rs
@@ -12,7 +12,9 @@ pub mod gadgets;
 pub mod gates;
 pub mod hash;
 pub mod iop;
-pub mod lookup_test;
 pub mod plonk;
 pub mod recursion;
 pub mod util;
+
+#[cfg(test)]
+mod lookup_test;
diff --git a/plonky2/src/lookup_test.rs b/plonky2/src/lookup_test.rs
index 165d5dfe..bca90d59 100644
--- a/plonky2/src/lookup_test.rs
+++ b/plonky2/src/lookup_test.rs
@@ -1,523 +1,477 @@
-#[cfg(test)]
-mod tests {
-    static LOGGER_INITIALIZED: Once = Once::new();
-
-    use alloc::sync::Arc;
-    use std::sync::Once;
-
-    use itertools::Itertools;
-    use log::{Level, LevelFilter};
-
-    use crate::gadgets::lookup::{OTHER_TABLE, SMALLER_TABLE, TIP5_TABLE};
-    use crate::gates::lookup_table::LookupTable;
-    use crate::gates::noop::NoopGate;
-    use crate::plonk::prover::prove;
-    use crate::util::timing::TimingTree;
-
-    #[test]
-    fn test_no_lookup() -> anyhow::Result<()> {
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        use crate::iop::witness::PartialWitness;
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-        builder.add_gate(NoopGate, vec![]);
-        let pw = PartialWitness::new();
-
-        let data = builder.build::<C>();
-        let mut timing = TimingTree::new("prove first", Level::Debug);
-        let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
-        timing.print();
-        data.verify(proof)?;
-
-        Ok(())
-    }
-
-    #[should_panic]
-    #[test]
-    fn test_lookup_table_not_used() {
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let tip5_table = TIP5_TABLE.to_vec();
-        let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
-        builder.add_lookup_table_from_pairs(table);
-
-        builder.build::<C>();
-    }
-
-    #[should_panic]
-    #[test]
-    fn test_lookup_without_table() {
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let dummy = builder.add_virtual_target();
-        builder.add_lookup_from_index(dummy, 0);
-
-        builder.build::<C>();
-    }
-
-    // Tests two lookups in one lookup table.
-    #[test]
-    fn test_one_lookup() -> anyhow::Result<()> {
-        use crate::field::types::Field;
-        use crate::iop::witness::{PartialWitness, WitnessWrite};
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        let tip5_table = TIP5_TABLE.to_vec();
-        let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let initial_a = builder.add_virtual_target();
-        let initial_b = builder.add_virtual_target();
-
-        let look_val_a = 1;
-        let look_val_b = 2;
-
-        let out_a = table[look_val_a].1;
-        let out_b = table[look_val_b].1;
-        let table_index = builder.add_lookup_table_from_pairs(table);
-        let output_a = builder.add_lookup_from_index(initial_a, table_index);
-
-        let output_b = builder.add_lookup_from_index(initial_b, table_index);
-
-        builder.register_public_input(initial_a);
-        builder.register_public_input(initial_b);
-        builder.register_public_input(output_a);
-        builder.register_public_input(output_b);
-
-        let mut pw = PartialWitness::new();
-
-        pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
-        pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
-
-        let data = builder.build::<C>();
-        let mut timing = TimingTree::new("prove one lookup", Level::Debug);
-        let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
-        timing.print();
-        data.verify(proof.clone())?;
-
-        assert!(
-            proof.public_inputs[2] == F::from_canonical_u16(out_a),
-            "First lookup, at index {} in the Tip5 table gives an incorrect output.",
-            proof.public_inputs[0]
-        );
-        assert!(
-            proof.public_inputs[3] == F::from_canonical_u16(out_b),
-            "Second lookup, at index {} in the Tip5 table gives an incorrect output.",
-            proof.public_inputs[1]
-        );
-
-        Ok(())
-    }
-
-    // Tests one lookup in two different lookup tables.
-    #[test]
-    pub fn test_two_luts() -> anyhow::Result<()> {
-        use crate::field::types::Field;
-        use crate::iop::witness::{PartialWitness, WitnessWrite};
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let initial_a = builder.add_virtual_target();
-        let initial_b = builder.add_virtual_target();
-
-        let look_val_a = 1;
-        let look_val_b = 2;
-
-        let tip5_table = TIP5_TABLE.to_vec();
-
-        let first_out = tip5_table[look_val_a];
-        let second_out = tip5_table[look_val_b];
-
-        let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
-
-        let other_table = OTHER_TABLE.to_vec();
-
-        let table_index = builder.add_lookup_table_from_pairs(table);
-        let output_a = builder.add_lookup_from_index(initial_a, table_index);
-
-        let output_b = builder.add_lookup_from_index(initial_b, table_index);
-        let sum = builder.add(output_a, output_b);
-
-        let s = first_out + second_out;
-        let final_out = other_table[s as usize];
-
-        let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect());
-        let table2_index = builder.add_lookup_table_from_pairs(table2);
-
-        let output_final = builder.add_lookup_from_index(sum, table2_index);
-
-        builder.register_public_input(initial_a);
-        builder.register_public_input(initial_b);
-        builder.register_public_input(sum);
-        builder.register_public_input(output_a);
-        builder.register_public_input(output_b);
-        builder.register_public_input(output_final);
-
-        let mut pw = PartialWitness::new();
-        pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
-        pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
-        let data = builder.build::<C>();
-        let mut timing = TimingTree::new("prove two_luts", Level::Debug);
-        let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
-        data.verify(proof.clone())?;
-        timing.print();
-
-        assert!(
-            proof.public_inputs[3] == F::from_canonical_u16(first_out),
-            "First lookup, at index {} in the Tip5 table gives an incorrect output.",
-            proof.public_inputs[0]
-        );
-        assert!(
-            proof.public_inputs[4] == F::from_canonical_u16(second_out),
-            "Second lookup, at index {} in the Tip5 table gives an incorrect output.",
-            proof.public_inputs[1]
-        );
-        assert!(
-            proof.public_inputs[2] == F::from_canonical_u16(s),
-            "Sum between the first two LUT outputs is incorrect."
-        );
-        assert!(
-            proof.public_inputs[5] == F::from_canonical_u16(final_out),
-            "Output of the second LUT at index {} is incorrect.",
-            s
-        );
-
-        Ok(())
-    }
-
-    #[test]
-    pub fn test_different_inputs() -> anyhow::Result<()> {
-        use crate::field::types::Field;
-        use crate::iop::witness::{PartialWitness, WitnessWrite};
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let initial_a = builder.add_virtual_target();
-        let initial_b = builder.add_virtual_target();
-
-        let init_a = 1;
-        let init_b = 2;
-
-        let tab: Vec<u16> = SMALLER_TABLE.to_vec();
-        let table: LookupTable = Arc::new((2..10).zip_eq(tab).collect());
-
-        let other_table = OTHER_TABLE.to_vec();
-
-        let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect());
-        let small_index = builder.add_lookup_table_from_pairs(table.clone());
-        let output_a = builder.add_lookup_from_index(initial_a, small_index);
-
-        let output_b = builder.add_lookup_from_index(initial_b, small_index);
-        let sum = builder.add(output_a, output_b);
-
-        let other_index = builder.add_lookup_table_from_pairs(table2.clone());
-        let output_final = builder.add_lookup_from_index(sum, other_index);
-
-        builder.register_public_input(initial_a);
-        builder.register_public_input(initial_b);
-        builder.register_public_input(sum);
-        builder.register_public_input(output_a);
-        builder.register_public_input(output_b);
-        builder.register_public_input(output_final);
-
-        let mut pw = PartialWitness::new();
-
-        let look_val_a = table[init_a].0;
-        let look_val_b = table[init_b].0;
-        pw.set_target(initial_a, F::from_canonical_u16(look_val_a));
-        pw.set_target(initial_b, F::from_canonical_u16(look_val_b));
-
-        let data = builder.build::<C>();
-        let mut timing = TimingTree::new("prove different lookups", Level::Debug);
-        let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
-        data.verify(proof.clone())?;
-        timing.print();
-
-        let out_a = table[init_a].1;
-        let out_b = table[init_b].1;
-        let s = out_a + out_b;
-        let out_final = table2[s as usize].1;
-
-        assert!(
-            proof.public_inputs[3] == F::from_canonical_u16(out_a),
-            "First lookup, at index {} in the smaller LUT gives an incorrect output.",
-            proof.public_inputs[0]
-        );
-        assert!(
-            proof.public_inputs[4] == F::from_canonical_u16(out_b),
-            "Second lookup, at index {} in the smaller LUT gives an incorrect output.",
-            proof.public_inputs[1]
-        );
-        assert!(
-            proof.public_inputs[2] == F::from_canonical_u16(s),
-            "Sum between the first two LUT outputs is incorrect."
-        );
-        assert!(
-            proof.public_inputs[5] == F::from_canonical_u16(out_final),
-            "Output of the second LUT at index {} is incorrect.",
-            s
-        );
-
-        Ok(())
-    }
-
-    // This test looks up over 514 values for one LookupTableGate, which means that several LookupGates are created.
-    #[test]
-    pub fn test_many_lookups() -> anyhow::Result<()> {
-        use crate::field::types::Field;
-        use crate::iop::witness::{PartialWitness, WitnessWrite};
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let initial_a = builder.add_virtual_target();
-        let initial_b = builder.add_virtual_target();
-
-        let look_val_a = 1;
-        let look_val_b = 2;
-
-        let tip5_table = TIP5_TABLE.to_vec();
-        let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
-
-        let out_a = table[look_val_a].1;
-        let out_b = table[look_val_b].1;
-
-        let tip5_index = builder.add_lookup_table_from_pairs(table);
-        let output_a = builder.add_lookup_from_index(initial_a, tip5_index);
-
-        let output_b = builder.add_lookup_from_index(initial_b, tip5_index);
-        let sum = builder.add(output_a, output_b);
-
-        for _ in 0..514 {
-            builder.add_lookup_from_index(initial_a, tip5_index);
-        }
-
-        let other_table = OTHER_TABLE.to_vec();
-
-        let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect());
-
-        let s = out_a + out_b;
-        let out_final = table2[s as usize].1;
-
-        let other_index = builder.add_lookup_table_from_pairs(table2);
-        let output_final = builder.add_lookup_from_index(sum, other_index);
-
-        builder.register_public_input(initial_a);
-        builder.register_public_input(initial_b);
-        builder.register_public_input(sum);
-        builder.register_public_input(output_a);
-        builder.register_public_input(output_b);
-        builder.register_public_input(output_final);
-
-        let mut pw = PartialWitness::new();
-
-        pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
-        pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
-
-        let data = builder.build::<C>();
-        let mut timing = TimingTree::new("prove different lookups", Level::Debug);
-        let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
-
-        data.verify(proof.clone())?;
-        timing.print();
-
-        assert!(
-            proof.public_inputs[3] == F::from_canonical_u16(out_a),
-            "First lookup, at index {} in the Tip5 table gives an incorrect output.",
-            proof.public_inputs[0]
-        );
-        assert!(
-            proof.public_inputs[4] == F::from_canonical_u16(out_b),
-            "Second lookup, at index {} in the Tip5 table gives an incorrect output.",
-            proof.public_inputs[1]
-        );
-        assert!(
-            proof.public_inputs[2] == F::from_canonical_u16(s),
-            "Sum between the first two LUT outputs is incorrect."
-        );
-        assert!(
-            proof.public_inputs[5] == F::from_canonical_u16(out_final),
-            "Output of the second LUT at index {} is incorrect.",
-            s
-        );
-
-        Ok(())
-    }
-
-    // Tests whether, when adding the same LUT to the circuit, the circuit only adds one copy, with the same index.
-    #[test]
-    pub fn test_same_luts() -> anyhow::Result<()> {
-        use crate::field::types::Field;
-        use crate::iop::witness::{PartialWitness, WitnessWrite};
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        let initial_a = builder.add_virtual_target();
-        let initial_b = builder.add_virtual_target();
-
-        let look_val_a = 1;
-        let look_val_b = 2;
-
-        let tip5_table = TIP5_TABLE.to_vec();
-        let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
-
-        let table_index = builder.add_lookup_table_from_pairs(table.clone());
-        let output_a = builder.add_lookup_from_index(initial_a, table_index);
-
-        let output_b = builder.add_lookup_from_index(initial_b, table_index);
-        let sum = builder.add(output_a, output_b);
-
-        let table2_index = builder.add_lookup_table_from_pairs(table);
-
-        let output_final = builder.add_lookup_from_index(sum, table2_index);
-
-        builder.register_public_input(initial_a);
-        builder.register_public_input(initial_b);
-        builder.register_public_input(sum);
-        builder.register_public_input(output_a);
-        builder.register_public_input(output_b);
-        builder.register_public_input(output_final);
-
-        let luts_length = builder.get_luts_length();
-
-        assert!(
-            luts_length == 1,
-            "There are {} LUTs when there should be only one",
-            luts_length
-        );
-
-        let mut pw = PartialWitness::new();
-
-        pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
-        pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
-
-        let data = builder.build::<C>();
-        let mut timing = TimingTree::new("prove two_luts", Level::Debug);
-        let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
-        data.verify(proof)?;
-        timing.print();
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_circuit_build_mock() {
-        // This code is taken from examples/fibonacci.rs
-        use crate::field::types::Field;
-        use crate::iop::witness::{PartialWitness, Witness, WitnessWrite};
-        use crate::plonk::circuit_builder::CircuitBuilder;
-        use crate::plonk::circuit_data::CircuitConfig;
-        use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
-
-        const D: usize = 2;
-        type C = PoseidonGoldilocksConfig;
-        type F = <C as GenericConfig<D>>::F;
-
-        let config = CircuitConfig::standard_recursion_config();
-        let mut builder = CircuitBuilder::<F, D>::new(config);
-
-        // The arithmetic circuit.
-        let initial_a = builder.add_virtual_target();
-        let initial_b = builder.add_virtual_target();
-        let mut prev_target = initial_a;
-        let mut cur_target = initial_b;
-        for _ in 0..99 {
-            let temp = builder.add(prev_target, cur_target);
-            prev_target = cur_target;
-            cur_target = temp;
-        }
-
-        // Public inputs are the two initial values (provided below) and the result (which is generated).
-        builder.register_public_input(initial_a);
-        builder.register_public_input(initial_b);
-        builder.register_public_input(cur_target);
-
-        // Provide initial values.
-        let mut pw = PartialWitness::new();
-        pw.set_target(initial_a, F::ZERO);
-        pw.set_target(initial_b, F::ONE);
-
-        let data = builder.mock_build::<C>();
-        let partition_witness = data.generate_witness(pw);
-        let result = partition_witness.try_get_target(cur_target).unwrap();
-        assert_eq!(result, F::from_canonical_u64(3736710860384812976));
-    }
-
-    fn init_logger() -> anyhow::Result<()> {
-        let mut builder = env_logger::Builder::from_default_env();
-        builder.format_timestamp(None);
-        builder.filter_level(LevelFilter::Debug);
-
-        builder.try_init()?;
-        Ok(())
-    }
+static LOGGER_INITIALIZED: Once = Once::new();
+
+use alloc::sync::Arc;
+use std::sync::Once;
+
+use itertools::Itertools;
+use log::{Level, LevelFilter};
+
+use crate::gadgets::lookup::{OTHER_TABLE, SMALLER_TABLE, TIP5_TABLE};
+use crate::gates::lookup_table::LookupTable;
+use crate::gates::noop::NoopGate;
+use crate::plonk::prover::prove;
+use crate::util::timing::TimingTree;
+
+#[test]
+fn test_no_lookup() -> anyhow::Result<()> {
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    use crate::iop::witness::PartialWitness;
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+    builder.add_gate(NoopGate, vec![]);
+    let pw = PartialWitness::new();
+
+    let data = builder.build::<C>();
+    let mut timing = TimingTree::new("prove first", Level::Debug);
+    let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
+    timing.print();
+    data.verify(proof)?;
+
+    Ok(())
+}
+
+#[should_panic]
+#[test]
+fn test_lookup_table_not_used() {
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let tip5_table = TIP5_TABLE.to_vec();
+    let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
+    builder.add_lookup_table_from_pairs(table);
+
+    builder.build::<C>();
+}
+
+#[should_panic]
+#[test]
+fn test_lookup_without_table() {
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let dummy = builder.add_virtual_target();
+    builder.add_lookup_from_index(dummy, 0);
+
+    builder.build::<C>();
+}
+
+// Tests two lookups in one lookup table.
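+// The lookup inputs and outputs are registered as public inputs, so the
+// assertions at the end can compare them directly against the table contents.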
+#[test]
+fn test_one_lookup() -> anyhow::Result<()> {
+    use crate::field::types::Field;
+    use crate::iop::witness::{PartialWitness, WitnessWrite};
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    let tip5_table = TIP5_TABLE.to_vec();
+    let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let initial_a = builder.add_virtual_target();
+    let initial_b = builder.add_virtual_target();
+
+    let look_val_a = 1;
+    let look_val_b = 2;
+
+    let out_a = table[look_val_a].1;
+    let out_b = table[look_val_b].1;
+    let table_index = builder.add_lookup_table_from_pairs(table);
+    let output_a = builder.add_lookup_from_index(initial_a, table_index);
+
+    let output_b = builder.add_lookup_from_index(initial_b, table_index);
+
+    builder.register_public_input(initial_a);
+    builder.register_public_input(initial_b);
+    builder.register_public_input(output_a);
+    builder.register_public_input(output_b);
+
+    let mut pw = PartialWitness::new();
+
+    pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
+    pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
+
+    let data = builder.build::<C>();
+    let mut timing = TimingTree::new("prove one lookup", Level::Debug);
+    let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
+    timing.print();
+    data.verify(proof.clone())?;
+
+    assert!(
+        proof.public_inputs[2] == F::from_canonical_u16(out_a),
+        "First lookup, at index {} in the Tip5 table gives an incorrect output.",
+        proof.public_inputs[0]
+    );
+    assert!(
+        proof.public_inputs[3] == F::from_canonical_u16(out_b),
+        "Second lookup, at index {} in the Tip5 table gives an incorrect output.",
+        proof.public_inputs[1]
+    );
+
+    Ok(())
+}
+
+// Tests one lookup in two different lookup tables.
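+// The sum of the two outputs of the first table serves as the index for a
+// final lookup into the second table, chaining the two LUTs.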
+#[test]
+pub fn test_two_luts() -> anyhow::Result<()> {
+    use crate::field::types::Field;
+    use crate::iop::witness::{PartialWitness, WitnessWrite};
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let initial_a = builder.add_virtual_target();
+    let initial_b = builder.add_virtual_target();
+
+    let look_val_a = 1;
+    let look_val_b = 2;
+
+    let tip5_table = TIP5_TABLE.to_vec();
+
+    let first_out = tip5_table[look_val_a];
+    let second_out = tip5_table[look_val_b];
+
+    let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
+
+    let other_table = OTHER_TABLE.to_vec();
+
+    let table_index = builder.add_lookup_table_from_pairs(table);
+    let output_a = builder.add_lookup_from_index(initial_a, table_index);
+
+    let output_b = builder.add_lookup_from_index(initial_b, table_index);
+    let sum = builder.add(output_a, output_b);
+
+    let s = first_out + second_out;
+    let final_out = other_table[s as usize];
+
+    let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect());
+    let table2_index = builder.add_lookup_table_from_pairs(table2);
+
+    let output_final = builder.add_lookup_from_index(sum, table2_index);
+
+    builder.register_public_input(initial_a);
+    builder.register_public_input(initial_b);
+    builder.register_public_input(sum);
+    builder.register_public_input(output_a);
+    builder.register_public_input(output_b);
+    builder.register_public_input(output_final);
+
+    let mut pw = PartialWitness::new();
+    pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
+    pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
+    let data = builder.build::<C>();
+    let mut timing = TimingTree::new("prove two_luts", Level::Debug);
+    let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
+    data.verify(proof.clone())?;
+    timing.print();
+
+    assert!(
+        proof.public_inputs[3] == F::from_canonical_u16(first_out),
+        "First lookup, at index {} in the Tip5 table gives an incorrect output.",
+        proof.public_inputs[0]
+    );
+    assert!(
+        proof.public_inputs[4] == F::from_canonical_u16(second_out),
+        "Second lookup, at index {} in the Tip5 table gives an incorrect output.",
+        proof.public_inputs[1]
+    );
+    assert!(
+        proof.public_inputs[2] == F::from_canonical_u16(s),
+        "Sum between the first two LUT outputs is incorrect."
+    );
+    assert!(
+        proof.public_inputs[5] == F::from_canonical_u16(final_out),
+        "Output of the second LUT at index {} is incorrect.",
+        s
+    );
+
+    Ok(())
+}
+
+#[test]
+pub fn test_different_inputs() -> anyhow::Result<()> {
+    use crate::field::types::Field;
+    use crate::iop::witness::{PartialWitness, WitnessWrite};
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let initial_a = builder.add_virtual_target();
+    let initial_b = builder.add_virtual_target();
+
+    let init_a = 1;
+    let init_b = 2;
+
+    let tab: Vec<u16> = SMALLER_TABLE.to_vec();
+    let table: LookupTable = Arc::new((2..10).zip_eq(tab).collect());
+
+    let other_table = OTHER_TABLE.to_vec();
+
+    let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect());
+    let small_index = builder.add_lookup_table_from_pairs(table.clone());
+    let output_a = builder.add_lookup_from_index(initial_a, small_index);
+
+    let output_b = builder.add_lookup_from_index(initial_b, small_index);
+    let sum = builder.add(output_a, output_b);
+
+    let other_index = builder.add_lookup_table_from_pairs(table2.clone());
+    let output_final = builder.add_lookup_from_index(sum, other_index);
+
+    builder.register_public_input(initial_a);
+    builder.register_public_input(initial_b);
+    builder.register_public_input(sum);
+    builder.register_public_input(output_a);
+    builder.register_public_input(output_b);
+    builder.register_public_input(output_final);
+
+    let mut pw = PartialWitness::new();
+
+    let look_val_a = table[init_a].0;
+    let look_val_b = table[init_b].0;
+    pw.set_target(initial_a, F::from_canonical_u16(look_val_a));
+    pw.set_target(initial_b, F::from_canonical_u16(look_val_b));
+
+    let data = builder.build::<C>();
+    let mut timing = TimingTree::new("prove different lookups", Level::Debug);
+    let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
+    data.verify(proof.clone())?;
+    timing.print();
+
+    let out_a = table[init_a].1;
+    let out_b = table[init_b].1;
+    let s = out_a + out_b;
+    let out_final = table2[s as usize].1;
+
+    assert!(
+        proof.public_inputs[3] == F::from_canonical_u16(out_a),
+        "First lookup, at index {} in the smaller LUT gives an incorrect output.",
+        proof.public_inputs[0]
+    );
+    assert!(
+        proof.public_inputs[4] == F::from_canonical_u16(out_b),
+        "Second lookup, at index {} in the smaller LUT gives an incorrect output.",
+        proof.public_inputs[1]
+    );
+    assert!(
+        proof.public_inputs[2] == F::from_canonical_u16(s),
+        "Sum between the first two LUT outputs is incorrect."
+    );
+    assert!(
+        proof.public_inputs[5] == F::from_canonical_u16(out_final),
+        "Output of the second LUT at index {} is incorrect.",
+        s
+    );
+
+    Ok(())
+}
+
+// This test looks up over 514 values for one LookupTableGate, which means that several LookupGates are created.
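+// The 514 extra lookups below all reuse initial_a and discard their outputs;
+// they exist only to force the builder to allocate additional LookupGates.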
+#[test]
+pub fn test_many_lookups() -> anyhow::Result<()> {
+    use crate::field::types::Field;
+    use crate::iop::witness::{PartialWitness, WitnessWrite};
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let initial_a = builder.add_virtual_target();
+    let initial_b = builder.add_virtual_target();
+
+    let look_val_a = 1;
+    let look_val_b = 2;
+
+    let tip5_table = TIP5_TABLE.to_vec();
+    let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
+
+    let out_a = table[look_val_a].1;
+    let out_b = table[look_val_b].1;
+
+    let tip5_index = builder.add_lookup_table_from_pairs(table);
+    let output_a = builder.add_lookup_from_index(initial_a, tip5_index);
+
+    let output_b = builder.add_lookup_from_index(initial_b, tip5_index);
+    let sum = builder.add(output_a, output_b);
+
+    for _ in 0..514 {
+        builder.add_lookup_from_index(initial_a, tip5_index);
+    }
+
+    let other_table = OTHER_TABLE.to_vec();
+
+    let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect());
+
+    let s = out_a + out_b;
+    let out_final = table2[s as usize].1;
+
+    let other_index = builder.add_lookup_table_from_pairs(table2);
+    let output_final = builder.add_lookup_from_index(sum, other_index);
+
+    builder.register_public_input(initial_a);
+    builder.register_public_input(initial_b);
+    builder.register_public_input(sum);
+    builder.register_public_input(output_a);
+    builder.register_public_input(output_b);
+    builder.register_public_input(output_final);
+
+    let mut pw = PartialWitness::new();
+
+    pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
+    pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
+
+    let data = builder.build::<C>();
+    let mut timing = TimingTree::new("prove different lookups", Level::Debug);
+    let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
+
+    data.verify(proof.clone())?;
+    timing.print();
+
+    assert!(
+        proof.public_inputs[3] == F::from_canonical_u16(out_a),
+        "First lookup, at index {} in the Tip5 table gives an incorrect output.",
+        proof.public_inputs[0]
+    );
+    assert!(
+        proof.public_inputs[4] == F::from_canonical_u16(out_b),
+        "Second lookup, at index {} in the Tip5 table gives an incorrect output.",
+        proof.public_inputs[1]
+    );
+    assert!(
+        proof.public_inputs[2] == F::from_canonical_u16(s),
+        "Sum between the first two LUT outputs is incorrect."
+    );
+    assert!(
+        proof.public_inputs[5] == F::from_canonical_u16(out_final),
+        "Output of the second LUT at index {} is incorrect.",
+        s
+    );
+
+    Ok(())
+}
+
+// Tests whether, when adding the same LUT to the circuit, the circuit only adds one copy, with the same index.
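+// add_lookup_table_from_pairs is called twice with the same table; the
+// get_luts_length() assertion checks that only one copy was registered.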
+#[test]
+pub fn test_same_luts() -> anyhow::Result<()> {
+    use crate::field::types::Field;
+    use crate::iop::witness::{PartialWitness, WitnessWrite};
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::CircuitConfig;
+    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+
+    const D: usize = 2;
+    type C = PoseidonGoldilocksConfig;
+    type F = <C as GenericConfig<D>>::F;
+
+    LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
+    let config = CircuitConfig::standard_recursion_config();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+
+    let initial_a = builder.add_virtual_target();
+    let initial_b = builder.add_virtual_target();
+
+    let look_val_a = 1;
+    let look_val_b = 2;
+
+    let tip5_table = TIP5_TABLE.to_vec();
+    let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect());
+
+    let table_index = builder.add_lookup_table_from_pairs(table.clone());
+    let output_a = builder.add_lookup_from_index(initial_a, table_index);
+
+    let output_b = builder.add_lookup_from_index(initial_b, table_index);
+    let sum = builder.add(output_a, output_b);
+
+    let table2_index = builder.add_lookup_table_from_pairs(table);
+
+    let output_final = builder.add_lookup_from_index(sum, table2_index);
+
+    builder.register_public_input(initial_a);
+    builder.register_public_input(initial_b);
+    builder.register_public_input(sum);
+    builder.register_public_input(output_a);
+    builder.register_public_input(output_b);
+    builder.register_public_input(output_final);
+
+    let luts_length = builder.get_luts_length();
+
+    assert!(
+        luts_length == 1,
+        "There are {} LUTs when there should be only one",
+        luts_length
+    );
+
+    let mut pw = PartialWitness::new();
+
+    pw.set_target(initial_a, F::from_canonical_usize(look_val_a));
+    pw.set_target(initial_b, F::from_canonical_usize(look_val_b));
+
+    let data = builder.build::<C>();
+    let mut timing = TimingTree::new("prove two_luts", Level::Debug);
+    let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
+    data.verify(proof)?;
+    timing.print();
+
+    Ok(())
+}
+
+fn init_logger() -> anyhow::Result<()> {
+    let mut builder = env_logger::Builder::from_default_env();
+    builder.format_timestamp(None);
+    builder.filter_level(LevelFilter::Debug);
+
+    builder.try_init()?;
+    Ok(())
 }
diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs
index b77f7aa5..41aebdb1 100644
--- a/plonky2/src/plonk/prover.rs
+++ b/plonky2/src/plonk/prover.rs
@@ -113,6 +113,29 @@ pub fn prove<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D:
     inputs: PartialWitness<F>,
     timing: &mut TimingTree,
 ) -> Result<ProofWithPublicInputs<F, C, D>>
+where
+    C::Hasher: Hasher<F>,
+    C::InnerHasher: Hasher<F>,
+{
+    let partition_witness = timed!(
+        timing,
+        &format!("run {} generators", prover_data.generators.len()),
+        generate_partial_witness(inputs, prover_data, common_data)
+    );
+
+    prove_with_partition_witness(prover_data, common_data, partition_witness, timing)
+}
+
+pub fn prove_with_partition_witness<
+    F: RichField + Extendable<D>,
+    C: GenericConfig<D, F = F>,
+    const D: usize,
+>(
+    prover_data: &ProverOnlyCircuitData<F, C, D>,
+    common_data: &CommonCircuitData<F, D>,
+    mut partition_witness: PartitionWitness<F>,
+    timing: &mut TimingTree,
+) -> Result<ProofWithPublicInputs<F, C, D>>
 where
     C::Hasher: Hasher<F>,
     C::InnerHasher: Hasher<F>,
@@ -123,12 +146,6 @@ where
     let quotient_degree = common_data.quotient_degree();
     let degree = common_data.degree();
 
-    let mut partition_witness = timed!(
-        timing,
-        &format!("run {} generators", prover_data.generators.len()),
-        generate_partial_witness(inputs, prover_data, common_data)
-    );
-
     set_lookup_wires(prover_data, common_data, &mut partition_witness);
 
     let public_inputs = partition_witness.get_targets(&prover_data.public_inputs);
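The refactor above splits proving into two stages, so a caller that already holds a completed PartitionWitness can hand it straight to the prover. A rough sketch of the resulting call pattern (illustrative only: `data` stands for a CircuitData built as in the lookup tests above, `pw` for a filled PartialWitness, and it assumes generate_partial_witness from plonky2::iop::generator is reachable from the caller's context):

    use plonky2::iop::generator::generate_partial_witness;
    use plonky2::plonk::prover::prove_with_partition_witness;
    use plonky2::util::timing::TimingTree;

    // Stage 1: run every witness generator once, outside the prover.
    let partition_witness = generate_partial_witness(pw, &data.prover_only, &data.common);

    // Stage 2: prove from the completed witness; no generators are re-run.
    let mut timing = TimingTree::new("prove from witness", log::Level::Debug);
    let proof =
        prove_with_partition_witness(&data.prover_only, &data.common, partition_witness, &mut timing)?;
    data.verify(proof)?;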