From 1d60431992ab3cc90addfa12b43f851e35ad97cb Mon Sep 17 00:00:00 2001 From: Hamy Ratoanina Date: Wed, 11 Oct 2023 22:28:49 +0200 Subject: [PATCH] Store top of the stack in memory channel 0 (#1215) * Store top of the stack in memory channel 0 * Fix interpreter * Apply comments * Remove debugging code * Merge commit * Remove debugging comments * Apply comments * Fix witness generation for exceptions * Fix witness generation for exceptions (again) * Fix modfp254 constraint --- evm/src/cpu/columns/general.rs | 20 + evm/src/cpu/columns/ops.rs | 3 +- evm/src/cpu/contextops.rs | 226 ++++++----- evm/src/cpu/control_flow.rs | 8 +- evm/src/cpu/cpu_stark.rs | 12 +- evm/src/cpu/decode.rs | 5 +- evm/src/cpu/dup_swap.rs | 131 ++++--- evm/src/cpu/gas.rs | 3 +- evm/src/cpu/jumps.rs | 107 +++++- evm/src/cpu/kernel/interpreter.rs | 54 ++- evm/src/cpu/kernel/tests/signed_syscalls.rs | 4 +- evm/src/cpu/memio.rs | 178 +++++++-- evm/src/cpu/mod.rs | 2 +- evm/src/cpu/modfp254.rs | 4 +- evm/src/cpu/pc.rs | 15 +- evm/src/cpu/push0.rs | 9 +- evm/src/cpu/stack.rs | 394 +++++++++++++++----- evm/src/cpu/syscalls_exceptions.rs | 21 +- evm/src/witness/memory.rs | 12 + evm/src/witness/operation.rs | 316 +++++++++++----- evm/src/witness/state.rs | 7 + evm/src/witness/transition.rs | 96 ++++- evm/src/witness/util.rs | 125 +++++-- 23 files changed, 1256 insertions(+), 496 deletions(-) diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 57eb16fc..d4f34473 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -10,6 +10,7 @@ pub(crate) union CpuGeneralColumnsView { logic: CpuLogicView, jumps: CpuJumpsView, shift: CpuShiftView, + stack: CpuStackView, } impl CpuGeneralColumnsView { @@ -52,6 +53,16 @@ impl CpuGeneralColumnsView { pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView { unsafe { &mut self.shift } } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn stack(&self) -> &CpuStackView { + unsafe { &self.stack } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn stack_mut(&mut self) -> &mut CpuStackView { + unsafe { &mut self.stack } + } } impl PartialEq for CpuGeneralColumnsView { @@ -110,5 +121,14 @@ pub(crate) struct CpuShiftView { pub(crate) high_limb_sum_inv: T, } +#[derive(Copy, Clone)] +pub(crate) struct CpuStackView { + // Used for conditionally enabling and disabling channels when reading the next `stack_top`. + _unused: [T; 5], + pub(crate) stack_inv: T, + pub(crate) stack_inv_aux: T, + pub(crate) stack_inv_aux_2: T, +} + // `u8` is guaranteed to have a `size_of` of 1. 
pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 64474c98..feeb3f5f 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -24,7 +24,8 @@ pub struct OpsColumnsView { pub push: T, pub dup: T, pub swap: T, - pub context_op: T, + pub get_context: T, + pub set_context: T, pub mstore_32bytes: T, pub mload_32bytes: T, pub exit_kernel: T, diff --git a/evm/src/cpu/contextops.rs b/evm/src/cpu/contextops.rs index 55f14820..1683c30e 100644 --- a/evm/src/cpu/contextops.rs +++ b/evm/src/cpu/contextops.rs @@ -8,99 +8,38 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; -use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory::segments::Segment; fn eval_packed_get( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0 - let filter = lv.op.context_op * (P::ONES - lv.opcode_bits[0]); - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - yield_constr.constraint(filter * (push_channel.value[0] - lv.context)); - for &limb in &push_channel.value[1..] { + let filter = lv.op.get_context; + let new_stack_top = nv.mem_channels[0].value; + yield_constr.constraint(filter * (new_stack_top[0] - lv.context)); + for &limb in &new_stack_top[1..] { yield_constr.constraint(filter * limb); } - - // Stack constraints - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * channel.is_read); - - yield_constr.constraint(filter * (channel.addr_context - lv.context)); - yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), - ); - let addr_virtual = lv.stack_len; - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); - - // Unused channels - for i in 0..NUM_GP_CHANNELS - 1 { - let channel = lv.mem_channels[i]; - yield_constr.constraint(filter * channel.used); - } } fn eval_ext_circuit_get, const D: usize>( builder: &mut CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.context_op; - let one = builder.one_extension(); - let minus = builder.sub_extension(one, lv.opcode_bits[0]); - filter = builder.mul_extension(filter, minus); - - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; + let filter = lv.op.get_context; + let new_stack_top = nv.mem_channels[0].value; { - let diff = builder.sub_extension(push_channel.value[0], lv.context); + let diff = builder.sub_extension(new_stack_top[0], lv.context); let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for &limb in &push_channel.value[1..] { + for &limb in &new_stack_top[1..] { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } - - // Stack constraints - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - - { - let constr = builder.mul_sub_extension(filter, channel.used, filter); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_extension(filter, channel.is_read); - yield_constr.constraint(builder, constr); - } - - { - let diff = builder.sub_extension(channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.arithmetic_extension( - F::ONE, - -F::from_canonical_u64(Segment::Stack as u64), - filter, - channel.addr_segment, - filter, - ); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); - let constr = builder.arithmetic_extension(F::ONE, F::ZERO, filter, diff, filter); - yield_constr.constraint(builder, constr); - } - - for i in 0..NUM_GP_CHANNELS - 1 { - let channel = lv.mem_channels[i]; - let constr = builder.mul_extension(filter, channel.used); - yield_constr.constraint(builder, constr); - } } fn eval_packed_set( @@ -108,22 +47,16 @@ fn eval_packed_set( nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.context_op * lv.opcode_bits[0]; - let pop_channel = lv.mem_channels[0]; + let filter = lv.op.set_context; + let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; - let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64); let ctx_metadata_segment = P::Scalar::from_canonical_u64(Segment::ContextMetadata as u64); let stack_size_field = P::Scalar::from_canonical_u64(ContextMetadata::StackSize as u64); let local_sp_dec = lv.stack_len - P::ONES; - // The next row's context is read from memory channel 0. - yield_constr.constraint(filter * (pop_channel.value[0] - nv.context)); - yield_constr.constraint(filter * (pop_channel.used - P::ONES)); - yield_constr.constraint(filter * (pop_channel.is_read - P::ONES)); - yield_constr.constraint(filter * (pop_channel.addr_context - lv.context)); - yield_constr.constraint(filter * (pop_channel.addr_segment - stack_segment)); - yield_constr.constraint(filter * (pop_channel.addr_virtual - local_sp_dec)); + // The next row's context is read from stack_top. + yield_constr.constraint(filter * (stack_top[0] - nv.context)); // The old SP is decremented (since the new context was popped) and written to memory. yield_constr.constraint(filter * (write_old_sp_channel.value[0] - local_sp_dec)); @@ -144,10 +77,34 @@ fn eval_packed_set( yield_constr.constraint(filter * (read_new_sp_channel.addr_segment - ctx_metadata_segment)); yield_constr.constraint(filter * (read_new_sp_channel.addr_virtual - stack_size_field)); - // Disable unused memory channels - for &channel in &lv.mem_channels[3..] { - yield_constr.constraint(filter * channel.used); + // The next row's stack top is loaded from memory (if the stack isn't empty). + yield_constr.constraint(filter * nv.mem_channels[0].used); + + let read_new_stack_top_channel = lv.mem_channels[3]; + let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64); + let new_filter = filter * nv.stack_len; + + for (limb_channel, limb_top) in read_new_stack_top_channel + .value + .iter() + .zip(nv.mem_channels[0].value) + { + yield_constr.constraint(new_filter * (*limb_channel - limb_top)); } + yield_constr.constraint(new_filter * (read_new_stack_top_channel.used - P::ONES)); + yield_constr.constraint(new_filter * (read_new_stack_top_channel.is_read - P::ONES)); + yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_context - nv.context)); + yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_segment - stack_segment)); + yield_constr.constraint( + new_filter * (read_new_stack_top_channel.addr_virtual - (nv.stack_len - P::ONES)), + ); + + // If the new stack is empty, disable the channel read. 
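+ // (The witness is expected to set `stack_inv` to the inverse of
+ // `nv.stack_len` when the new stack is non-empty, making `stack_inv_aux`
+ // equal to 1; when `nv.stack_len` is 0, the first constraint below forces
+ // `stack_inv_aux` to 0, and the second constraint then forces the read
+ // channel to be unused.)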
+ yield_constr.constraint( + filter * (nv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint(empty_stack_filter * read_new_stack_top_channel.used); } fn eval_ext_circuit_set, const D: usize>( @@ -156,13 +113,10 @@ fn eval_ext_circuit_set, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.context_op; - filter = builder.mul_extension(filter, lv.opcode_bits[0]); - let pop_channel = lv.mem_channels[0]; + let filter = lv.op.set_context; + let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; - let stack_segment = - builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32)); let ctx_metadata_segment = builder.constant_extension(F::Extension::from_canonical_u32( Segment::ContextMetadata as u32, )); @@ -172,32 +126,9 @@ fn eval_ext_circuit_set, const D: usize>( let one = builder.one_extension(); let local_sp_dec = builder.sub_extension(lv.stack_len, one); - // The next row's context is read from memory channel 0. + // The next row's context is read from stack_top. { - let diff = builder.sub_extension(pop_channel.value[0], nv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_sub_extension(filter, pop_channel.used, filter); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_sub_extension(filter, pop_channel.is_read, filter); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(pop_channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(pop_channel.addr_segment, stack_segment); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(pop_channel.addr_virtual, local_sp_dec); + let diff = builder.sub_extension(stack_top[0], nv.context); let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } @@ -266,9 +197,66 @@ fn eval_ext_circuit_set, const D: usize>( yield_constr.constraint(builder, constr); } - // Disable unused memory channels - for &channel in &lv.mem_channels[3..] { - let constr = builder.mul_extension(filter, channel.used); + // The next row's stack top is loaded from memory (if the stack isn't empty). 
+ { + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + yield_constr.constraint(builder, constr); + } + + let read_new_stack_top_channel = lv.mem_channels[3]; + let stack_segment = + builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32)); + + let new_filter = builder.mul_extension(filter, nv.stack_len); + + for (limb_channel, limb_top) in read_new_stack_top_channel + .value + .iter() + .zip(nv.mem_channels[0].value) + { + let diff = builder.sub_extension(*limb_channel, limb_top); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + { + let constr = + builder.mul_sub_extension(new_filter, read_new_stack_top_channel.used, new_filter); + yield_constr.constraint(builder, constr); + } + { + let constr = + builder.mul_sub_extension(new_filter, read_new_stack_top_channel.is_read, new_filter); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(read_new_stack_top_channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(read_new_stack_top_channel.addr_segment, stack_segment); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(nv.stack_len, one); + let diff = builder.sub_extension(read_new_stack_top_channel.addr_virtual, diff); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + + // If the new stack is empty, disable the channel read. + { + let diff = builder.mul_extension(nv.stack_len, lv.general.stack().stack_inv); + let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, read_new_stack_top_channel.used); yield_constr.constraint(builder, constr); } } @@ -278,7 +266,7 @@ pub fn eval_packed( nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - eval_packed_get(lv, yield_constr); + eval_packed_get(lv, nv, yield_constr); eval_packed_set(lv, nv, yield_constr); } @@ -288,6 +276,6 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - eval_ext_circuit_get(builder, lv, yield_constr); + eval_ext_circuit_get(builder, lv, nv, yield_constr); eval_ext_circuit_set(builder, lv, nv, yield_constr); } diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 9c17367a..a192ffb1 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 17] = [ +const NATIVE_INSTRUCTIONS: [usize; 18] = [ COL_MAP.op.binary_op, COL_MAP.op.ternary_op, COL_MAP.op.fp254_op, @@ -19,15 +19,15 @@ const NATIVE_INSTRUCTIONS: [usize; 17] = [ COL_MAP.op.keccak_general, COL_MAP.op.prover_input, COL_MAP.op.pop, - // not JUMP (need to jump) - // not JUMPI (possible need to jump) + // not JUMPS (possible need to jump) COL_MAP.op.pc, COL_MAP.op.jumpdest, COL_MAP.op.push0, // not PUSH (need to increment by more than 1) COL_MAP.op.dup, COL_MAP.op.swap, - COL_MAP.op.context_op, + COL_MAP.op.get_context, + COL_MAP.op.set_context, // not EXIT_KERNEL (performs a jump) COL_MAP.op.m_op_general, // not SYSCALL (performs a jump) diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index a77adbcb..64a2db9c 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -224,15 +224,15 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark, const D: usize>( fn eval_packed_dup( n: P, lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { let filter = lv.op.dup; - let in_channel = &lv.mem_channels[0]; - let out_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let write_channel = &lv.mem_channels[1]; + let read_channel = &lv.mem_channels[2]; - channels_equal_packed(filter, in_channel, out_channel, yield_constr); + channels_equal_packed(filter, write_channel, &lv.mem_channels[0], yield_constr); + constrain_channel_packed(false, filter, P::ZEROS, write_channel, lv, yield_constr); - constrain_channel_packed(true, filter, n, in_channel, lv, yield_constr); - constrain_channel_packed( - false, - filter, - P::Scalar::NEG_ONE.into(), - out_channel, - lv, - yield_constr, - ); + channels_equal_packed(filter, read_channel, &nv.mem_channels[0], yield_constr); + constrain_channel_packed(true, filter, n, read_channel, lv, yield_constr); + + // Constrain nv.stack_len. + yield_constr.constraint_transition(filter * (nv.stack_len - lv.stack_len - P::ONES)); + + // TODO: Constrain unused channels? } fn eval_ext_circuit_dup, const D: usize>( builder: &mut CircuitBuilder, n: ExtensionTarget, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let neg_one = builder.constant_extension(F::NEG_ONE.into()); + let zero = builder.zero_extension(); let filter = lv.op.dup; - let in_channel = &lv.mem_channels[0]; - let out_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let write_channel = &lv.mem_channels[1]; + let read_channel = &lv.mem_channels[2]; - channels_equal_ext_circuit(builder, filter, in_channel, out_channel, yield_constr); - - constrain_channel_ext_circuit(builder, true, filter, n, in_channel, lv, yield_constr); + channels_equal_ext_circuit( + builder, + filter, + write_channel, + &lv.mem_channels[0], + yield_constr, + ); constrain_channel_ext_circuit( builder, false, filter, - neg_one, - out_channel, + zero, + write_channel, lv, yield_constr, ); + + channels_equal_ext_circuit( + builder, + filter, + read_channel, + &nv.mem_channels[0], + yield_constr, + ); + constrain_channel_ext_circuit(builder, true, filter, n, read_channel, lv, yield_constr); + + // Constrain nv.stack_len. + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_sub_extension(filter, diff, filter); + yield_constr.constraint_transition(builder, constr); + + // TODO: Constrain unused channels? } fn eval_packed_swap( n: P, lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { let n_plus_one = n + P::ONES; @@ -170,25 +191,27 @@ fn eval_packed_swap( let in1_channel = &lv.mem_channels[0]; let in2_channel = &lv.mem_channels[1]; - let out1_channel = &lv.mem_channels[NUM_GP_CHANNELS - 2]; - let out2_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let out_channel = &lv.mem_channels[2]; - channels_equal_packed(filter, in1_channel, out1_channel, yield_constr); - channels_equal_packed(filter, in2_channel, out2_channel, yield_constr); + channels_equal_packed(filter, in1_channel, out_channel, yield_constr); + constrain_channel_packed(false, filter, n_plus_one, out_channel, lv, yield_constr); - constrain_channel_packed(true, filter, P::ZEROS, in1_channel, lv, yield_constr); + channels_equal_packed(filter, in2_channel, &nv.mem_channels[0], yield_constr); constrain_channel_packed(true, filter, n_plus_one, in2_channel, lv, yield_constr); - constrain_channel_packed(false, filter, n_plus_one, out1_channel, lv, yield_constr); - constrain_channel_packed(false, filter, P::ZEROS, out2_channel, lv, yield_constr); + + // Constrain nv.stack_len; + yield_constr.constraint(filter * (nv.stack_len - lv.stack_len)); + + // TODO: Constrain unused channels? } fn eval_ext_circuit_swap, const D: usize>( builder: &mut CircuitBuilder, n: ExtensionTarget, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let zero = builder.zero_extension(); let one = builder.one_extension(); let n_plus_one = builder.add_extension(n, one); @@ -196,13 +219,26 @@ fn eval_ext_circuit_swap, const D: usize>( let in1_channel = &lv.mem_channels[0]; let in2_channel = &lv.mem_channels[1]; - let out1_channel = &lv.mem_channels[NUM_GP_CHANNELS - 2]; - let out2_channel = &lv.mem_channels[NUM_GP_CHANNELS - 1]; + let out_channel = &lv.mem_channels[2]; - channels_equal_ext_circuit(builder, filter, in1_channel, out1_channel, yield_constr); - channels_equal_ext_circuit(builder, filter, in2_channel, out2_channel, yield_constr); + channels_equal_ext_circuit(builder, filter, in1_channel, out_channel, yield_constr); + constrain_channel_ext_circuit( + builder, + false, + filter, + n_plus_one, + out_channel, + lv, + yield_constr, + ); - constrain_channel_ext_circuit(builder, true, filter, zero, in1_channel, lv, yield_constr); + channels_equal_ext_circuit( + builder, + filter, + in2_channel, + &nv.mem_channels[0], + yield_constr, + ); constrain_channel_ext_circuit( builder, true, @@ -212,20 +248,18 @@ fn eval_ext_circuit_swap, const D: usize>( lv, yield_constr, ); - constrain_channel_ext_circuit( - builder, - false, - filter, - n_plus_one, - out1_channel, - lv, - yield_constr, - ); - constrain_channel_ext_circuit(builder, false, filter, zero, out2_channel, lv, yield_constr); + + // Constrain nv.stack_len. + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // TODO: Constrain unused channels? } pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { let n = lv.opcode_bits[0] @@ -233,13 +267,14 @@ pub fn eval_packed( + lv.opcode_bits[2] * P::Scalar::from_canonical_u64(4) + lv.opcode_bits[3] * P::Scalar::from_canonical_u64(8); - eval_packed_dup(n, lv, yield_constr); - eval_packed_swap(n, lv, yield_constr); + eval_packed_dup(n, lv, nv, yield_constr); + eval_packed_swap(n, lv, nv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( builder: &mut CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let n = lv.opcode_bits[..4].iter().enumerate().fold( @@ -249,6 +284,6 @@ pub fn eval_ext_circuit, const D: usize>( }, ); - eval_ext_circuit_dup(builder, n, lv, yield_constr); - eval_ext_circuit_swap(builder, n, lv, yield_constr); + eval_ext_circuit_dup(builder, n, lv, nv, yield_constr); + eval_ext_circuit_swap(builder, n, lv, nv, yield_constr); } diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 694fb0f4..1434efd9 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -36,7 +36,8 @@ const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { push: G_VERYLOW, dup: G_VERYLOW, swap: G_VERYLOW, - context_op: KERNEL_ONLY_INSTR, + get_context: KERNEL_ONLY_INSTR, + set_context: KERNEL_ONLY_INSTR, mstore_32bytes: KERNEL_ONLY_INSTR, mload_32bytes: KERNEL_ONLY_INSTR, exit_kernel: None, diff --git a/evm/src/cpu/jumps.rs b/evm/src/cpu/jumps.rs index 18291773..0c03e2d1 100644 --- a/evm/src/cpu/jumps.rs +++ b/evm/src/cpu/jumps.rs @@ -7,7 +7,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; -use crate::cpu::stack; use crate::memory::segments::Segment; pub fn eval_packed_exit_kernel( @@ -74,8 +73,26 @@ pub fn eval_packed_jump_jumpi( let is_jumpi = filter * lv.opcode_bits[0]; // Stack constraints. - stack::eval_packed_one(lv, nv, is_jump, stack::JUMP_OP.unwrap(), yield_constr); - stack::eval_packed_one(lv, nv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr); + // If (JUMP and stack_len != 1) or (JUMPI and stack_len != 2)... + let len_diff = lv.stack_len - P::ONES - lv.opcode_bits[0]; + let new_filter = len_diff * filter; + // Read an extra element. + let channel = nv.mem_channels[0]; + yield_constr.constraint_transition(new_filter * (channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (channel.addr_virtual - addr_virtual)); + // Constrain `stack_inv_aux`. + yield_constr.constraint( + filter * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + // Disable channel if stack_len == N. + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint_transition(empty_stack_filter * channel.used); // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. @@ -123,6 +140,12 @@ pub fn eval_packed_jump_jumpi( // Channel 1 is unused by the `JUMP` instruction. yield_constr.constraint(is_jump * lv.mem_channels[1].used); + // Update stack length. 
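+ // (JUMP pops only its destination while JUMPI also pops its predicate, so
+ // the stack shrinks by 1 or 2 elements respectively.)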
+ yield_constr.constraint_transition(is_jump * (nv.stack_len - lv.stack_len + P::ONES)); + yield_constr.constraint_transition( + is_jumpi * (nv.stack_len - lv.stack_len + P::Scalar::from_canonical_u64(2)), + ); + // Finally, set the next program counter. let fallthrough_dst = lv.program_counter + P::ONES; let jump_dest = dst[0]; @@ -150,22 +173,55 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> let is_jumpi = builder.mul_extension(filter, lv.opcode_bits[0]); // Stack constraints. - stack::eval_ext_circuit_one( - builder, - lv, - nv, - is_jump, - stack::JUMP_OP.unwrap(), - yield_constr, - ); - stack::eval_ext_circuit_one( - builder, - lv, - nv, - is_jumpi, - stack::JUMPI_OP.unwrap(), - yield_constr, - ); + // If (JUMP and stack_len != 1) or (JUMPI and stack_len != 2)... + let len_diff = builder.sub_extension(lv.stack_len, one_extension); + let len_diff = builder.sub_extension(len_diff, lv.opcode_bits[0]); + let new_filter = builder.mul_extension(len_diff, filter); + // Read an extra element. + let channel = nv.mem_channels[0]; + + { + let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + new_filter, + channel.addr_segment, + new_filter, + ); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_virtual, nv.stack_len); + let constr = builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, new_filter); + yield_constr.constraint_transition(builder, constr); + } + // Constrain `stack_inv_aux`. + { + let prod = builder.mul_extension(len_diff, lv.general.stack().stack_inv); + let diff = builder.sub_extension(prod, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Disable channel if stack_len == N. + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, channel.used); + yield_constr.constraint_transition(builder, constr); + } // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. @@ -267,6 +323,19 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> yield_constr.constraint(builder, constr); } + // Update stack length. + { + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let constr = builder.mul_add_extension(is_jump, diff, is_jump); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(nv.stack_len, lv.stack_len); + let diff = builder.add_const_extension(diff, F::TWO); + let constr = builder.mul_extension(is_jumpi, diff); + yield_constr.constraint_transition(builder, constr); + } + // Finally, set the next program counter. 
let fallthrough_dst = builder.add_const_extension(lv.program_counter, F::ONE); let jump_dest = dst[0]; diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 8f19a072..315e93f1 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -126,8 +126,14 @@ impl<'a> Interpreter<'a> { opcode_count: [0; 0x100], }; result.generation_state.registers.program_counter = initial_offset; - result.generation_state.registers.stack_len = initial_stack.len(); - *result.stack_mut() = initial_stack; + let initial_stack_len = initial_stack.len(); + result.generation_state.registers.stack_len = initial_stack_len; + if !initial_stack.is_empty() { + result.generation_state.registers.stack_top = initial_stack[initial_stack_len - 1]; + *result.stack_segment_mut() = initial_stack; + result.stack_segment_mut().truncate(initial_stack_len - 1); + } + result } @@ -262,12 +268,18 @@ impl<'a> Interpreter<'a> { self.generation_state.registers.program_counter += n; } - pub(crate) fn stack(&self) -> &[U256] { - &self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] + pub(crate) fn stack(&self) -> Vec { + let mut stack = self.generation_state.memory.contexts[self.context].segments + [Segment::Stack as usize] .content + .clone(); + if self.stack_len() > 0 { + stack.push(self.stack_top()); + } + stack } - fn stack_mut(&mut self) -> &mut Vec { + fn stack_segment_mut(&mut self) -> &mut Vec { &mut self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] .content } @@ -285,7 +297,11 @@ impl<'a> Interpreter<'a> { } pub(crate) fn push(&mut self, x: U256) { - self.stack_mut().push(x); + if self.stack_len() > 0 { + let top = self.stack_top(); + self.stack_segment_mut().push(top); + } + self.generation_state.registers.stack_top = x; self.generation_state.registers.stack_len += 1; } @@ -295,9 +311,17 @@ impl<'a> Interpreter<'a> { pub(crate) fn pop(&mut self) -> U256 { let result = stack_peek(&self.generation_state, 0); + if self.stack_len() > 1 { + let top = stack_peek(&self.generation_state, 1).unwrap(); + self.generation_state.registers.stack_top = top; + } self.generation_state.registers.stack_len -= 1; let new_len = self.stack_len(); - self.stack_mut().truncate(new_len); + if new_len > 0 { + self.stack_segment_mut().truncate(new_len - 1); + } else { + self.stack_segment_mut().truncate(0); + } result.expect("Empty stack") } @@ -1007,13 +1031,19 @@ impl<'a> Interpreter<'a> { } fn run_dup(&mut self, n: u8) { - self.push(self.stack()[self.stack_len() - n as usize]); + if n == 0 { + self.push(self.stack_top()); + } else { + self.push(stack_peek(&self.generation_state, n as usize - 1).unwrap()); + } } fn run_swap(&mut self, n: u8) -> anyhow::Result<()> { let len = self.stack_len(); ensure!(len > n as usize); - self.stack_mut().swap(len - 1, len - n as usize - 1); + let to_swap = stack_peek(&self.generation_state, n as usize).unwrap(); + self.stack_segment_mut()[len - n as usize - 1] = self.stack_top(); + self.generation_state.registers.stack_top = to_swap; Ok(()) } @@ -1084,9 +1114,13 @@ impl<'a> Interpreter<'a> { } } - fn stack_len(&self) -> usize { + pub(crate) fn stack_len(&self) -> usize { self.generation_state.registers.stack_len } + + pub(crate) fn stack_top(&self) -> U256 { + self.generation_state.registers.stack_top + } } // Computes the two's complement of the given integer. 
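The interpreter changes above all preserve one invariant: the top of the stack lives in `registers.stack_top`, and the `Segment::Stack` memory only holds the elements beneath it. A minimal standalone sketch of that discipline (hypothetical model types, not the interpreter's actual ones):

    struct StackModel {
        stack_top: u64,    // plays the role of registers.stack_top
        stack_len: usize,  // plays the role of registers.stack_len
        segment: Vec<u64>, // plays the role of the Segment::Stack contents
    }

    impl StackModel {
        fn push(&mut self, x: u64) {
            if self.stack_len > 0 {
                // Spill the old top into memory before replacing it.
                self.segment.push(self.stack_top);
            }
            self.stack_top = x;
            self.stack_len += 1;
        }

        fn pop(&mut self) -> u64 {
            let result = self.stack_top;
            self.stack_len -= 1;
            if self.stack_len > 0 {
                // Reload the new top from memory.
                self.stack_top = self.segment.pop().expect("segment holds stack_len - 1 items");
            }
            result
        }
    }

`run_dup` and `run_swap` follow the same pattern: DUP(n) re-pushes either `stack_top` (for n == 0) or a peeked element, and SWAP(n) exchanges `stack_top` with the appropriate slot of the memory segment.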
diff --git a/evm/src/cpu/kernel/tests/signed_syscalls.rs b/evm/src/cpu/kernel/tests/signed_syscalls.rs index 728d5565..93391cf6 100644 --- a/evm/src/cpu/kernel/tests/signed_syscalls.rs +++ b/evm/src/cpu/kernel/tests/signed_syscalls.rs @@ -119,8 +119,8 @@ fn run_test(fn_label: &str, expected_fn: fn(U256, U256) -> U256, opname: &str) { let stack = vec![retdest, y, x]; let mut interpreter = Interpreter::new_with_kernel(fn_label, stack); interpreter.run().unwrap(); - assert_eq!(interpreter.stack().len(), 1usize, "unexpected stack size"); - let output = interpreter.stack()[0]; + assert_eq!(interpreter.stack_len(), 1usize, "unexpected stack size"); + let output = interpreter.stack_top(); let expected_output = expected_fn(x, y); assert_eq!( output, expected_output, diff --git a/evm/src/cpu/memio.rs b/evm/src/cpu/memio.rs index aa3749ca..f70f3fdb 100644 --- a/evm/src/cpu/memio.rs +++ b/evm/src/cpu/memio.rs @@ -1,6 +1,7 @@ use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -8,6 +9,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::stack; +use crate::memory::segments::Segment; fn get_addr(lv: &CpuColumnsView) -> (T, T, T) { let addr_context = lv.mem_channels[0].value[0]; @@ -27,18 +29,14 @@ fn eval_packed_load( let (addr_context, addr_segment, addr_virtual) = get_addr(lv); let load_channel = lv.mem_channels[3]; - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; yield_constr.constraint(filter * (load_channel.used - P::ONES)); yield_constr.constraint(filter * (load_channel.is_read - P::ONES)); yield_constr.constraint(filter * (load_channel.addr_context - addr_context)); yield_constr.constraint(filter * (load_channel.addr_segment - addr_segment)); yield_constr.constraint(filter * (load_channel.addr_virtual - addr_virtual)); - for (load_limb, push_limb) in izip!(load_channel.value, push_channel.value) { - yield_constr.constraint(filter * (load_limb - push_limb)); - } // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS - 1] { + for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS] { yield_constr.constraint(filter * channel.used); } @@ -64,7 +62,6 @@ fn eval_ext_circuit_load, const D: usize>( let (addr_context, addr_segment, addr_virtual) = get_addr(lv); let load_channel = lv.mem_channels[3]; - let push_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; { let constr = builder.mul_sub_extension(filter, load_channel.used, filter); yield_constr.constraint(builder, constr); @@ -85,14 +82,9 @@ fn eval_ext_circuit_load, const D: usize>( let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for (load_limb, push_limb) in izip!(load_channel.value, push_channel.value) { - let diff = builder.sub_extension(load_limb, push_limb); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS - 1] { + for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS] { let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } @@ -113,7 +105,7 @@ fn eval_packed_store( nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.m_op_general * (P::ONES - lv.opcode_bits[0]); + let filter = lv.op.m_op_general * (lv.opcode_bits[0] - P::ONES); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -133,14 +125,50 @@ fn eval_packed_store( yield_constr.constraint(filter * channel.used); } - // Stack constraints - stack::eval_packed_one( - lv, - nv, - filter, - stack::MSTORE_GENERAL_OP.unwrap(), - yield_constr, + // Stack constraints. + // Pops. + for i in 1..4 { + let channel = lv.mem_channels[i]; + + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * (channel.is_read - P::ONES)); + + yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint( + filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. + let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); + yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + } + // Constrain `stack_inv_aux`. + let len_diff = lv.stack_len - P::Scalar::from_canonical_usize(4); + yield_constr.constraint( + lv.op.m_op_general + * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), ); + // If stack_len != 4 and MSTORE, read new top of the stack in nv.mem_channels[0]. + let top_read_channel = nv.mem_channels[0]; + let is_top_read = lv.general.stack().stack_inv_aux * (P::ONES - lv.opcode_bits[0]); + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. + yield_constr + .constraint(lv.op.m_op_general * (lv.general.stack().stack_inv_aux_2 - is_top_read)); + let new_filter = lv.op.m_op_general * lv.general.stack().stack_inv_aux_2; + yield_constr.constraint_transition(new_filter * (top_read_channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (top_read_channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (top_read_channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter + * (top_read_channel.addr_segment + - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (top_read_channel.addr_virtual - addr_virtual)); + // If stack_len == 4 or MLOAD, disable the channel. + yield_constr.constraint( + lv.op.m_op_general * (lv.general.stack().stack_inv_aux - P::ONES) * top_read_channel.used, + ); + yield_constr.constraint(lv.op.m_op_general * lv.opcode_bits[0] * top_read_channel.used); } fn eval_ext_circuit_store, const D: usize>( @@ -149,10 +177,8 @@ fn eval_ext_circuit_store, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.m_op_general; - let one = builder.one_extension(); - let minus = builder.sub_extension(one, lv.opcode_bits[0]); - filter = builder.mul_extension(filter, minus); + let filter = + builder.mul_sub_extension(lv.op.m_op_general, lv.opcode_bits[0], lv.op.m_op_general); let (addr_context, addr_segment, addr_virtual) = get_addr(lv); @@ -191,14 +217,102 @@ fn eval_ext_circuit_store, const D: usize>( } // Stack constraints - stack::eval_ext_circuit_one( - builder, - lv, - nv, - filter, - stack::MSTORE_GENERAL_OP.unwrap(), - yield_constr, - ); + // Pops. 
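+ // (Channel 0 already holds the current top of the stack, so the explicit
+ // reads for the remaining popped operands start at channel 1, with the
+ // second stack element.)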
+ for i in 1..4 { + let channel = lv.mem_channels[i]; + + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.add_const_extension( + channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. + let addr_virtual = + builder.add_const_extension(lv.stack_len, -F::from_canonical_usize(i + 1)); + let diff = builder.sub_extension(channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Constrain `stack_inv_aux`. + { + let len_diff = builder.add_const_extension(lv.stack_len, -F::from_canonical_usize(4)); + let diff = builder.mul_sub_extension( + len_diff, + lv.general.stack().stack_inv, + lv.general.stack().stack_inv_aux, + ); + let constr = builder.mul_extension(lv.op.m_op_general, diff); + yield_constr.constraint(builder, constr); + } + // If stack_len != 4 and MSTORE, read new top of the stack in nv.mem_channels[0]. + let top_read_channel = nv.mem_channels[0]; + let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]); + let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read); + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. + { + let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read); + let constr = builder.mul_extension(lv.op.m_op_general, diff); + yield_constr.constraint(builder, constr); + } + let new_filter = builder.mul_extension(lv.op.m_op_general, lv.general.stack().stack_inv_aux_2); + { + let constr = builder.mul_sub_extension(new_filter, top_read_channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, top_read_channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(top_read_channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.add_const_extension( + top_read_channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let addr_virtual = builder.add_const_extension(nv.stack_len, -F::ONE); + let diff = builder.sub_extension(top_read_channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + // If stack_len == 4 or MLOAD, disable the channel. 
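+ // (MSTORE_GENERAL pops four elements, so `stack_len == 4` means the stack
+ // is empty after the operation and there is no new top to read.)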
+ { + let diff = builder.mul_sub_extension( + lv.op.m_op_general, + lv.general.stack().stack_inv_aux, + lv.op.m_op_general, + ); + let constr = builder.mul_extension(diff, top_read_channel.used); + yield_constr.constraint(builder, constr); + } + { + let mul = builder.mul_extension(lv.op.m_op_general, lv.opcode_bits[0]); + let constr = builder.mul_extension(mul, top_read_channel.used); + yield_constr.constraint(builder, constr); + } } pub fn eval_packed( diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index b7312147..0885f644 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -16,6 +16,6 @@ mod pc; mod push0; mod shift; pub(crate) mod simple_logic; -mod stack; +pub(crate) mod stack; pub(crate) mod stack_bounds; mod syscalls_exceptions; diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs index 86f08052..eed497f5 100644 --- a/evm/src/cpu/modfp254.rs +++ b/evm/src/cpu/modfp254.rs @@ -22,7 +22,7 @@ pub fn eval_packed( let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read - // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // the modulus from the stack. We simply constrain `mem_channels[1]` to be our prime (that's // where the modulus goes in the generalized operations). let channel_val = lv.mem_channels[2].value; for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { @@ -39,7 +39,7 @@ pub fn eval_ext_circuit, const D: usize>( let filter = lv.op.fp254_op; // We want to use all the same logic as the usual mod operations, but without needing to read - // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // the modulus from the stack. We simply constrain `mem_channels[1]` to be our prime (that's // where the modulus goes in the generalized operations). let channel_val = lv.mem_channels[2].value; for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { diff --git a/evm/src/cpu/pc.rs b/evm/src/cpu/pc.rs index 26731c92..5271ad81 100644 --- a/evm/src/cpu/pc.rs +++ b/evm/src/cpu/pc.rs @@ -5,16 +5,16 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -use crate::cpu::membus::NUM_GP_CHANNELS; pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { let filter = lv.op.pc; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - yield_constr.constraint(filter * (push_value[0] - lv.program_counter)); - for &limb in &push_value[1..] { + let new_stack_top = nv.mem_channels[0].value; + yield_constr.constraint(filter * (new_stack_top[0] - lv.program_counter)); + for &limb in &new_stack_top[1..] { yield_constr.constraint(filter * limb); } } @@ -22,16 +22,17 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let filter = lv.op.pc; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; + let new_stack_top = nv.mem_channels[0].value; { - let diff = builder.sub_extension(push_value[0], lv.program_counter); + let diff = builder.sub_extension(new_stack_top[0], lv.program_counter); let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for &limb in &push_value[1..] { + for &limb in &new_stack_top[1..] { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/push0.rs b/evm/src/cpu/push0.rs index 30f6d0ae..d49446cc 100644 --- a/evm/src/cpu/push0.rs +++ b/evm/src/cpu/push0.rs @@ -5,15 +5,14 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; -use crate::cpu::membus::NUM_GP_CHANNELS; pub fn eval_packed( lv: &CpuColumnsView
<P>
, + nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { let filter = lv.op.push0; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - for limb in push_value { + for limb in nv.mem_channels[0].value { yield_constr.constraint(filter * limb); } } @@ -21,11 +20,11 @@ pub fn eval_packed( pub fn eval_ext_circuit, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, lv: &CpuColumnsView>, + nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { let filter = lv.op.push0; - let push_value = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - for limb in push_value { + for limb in nv.mem_channels[0].value { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 28abf077..31d0405c 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -1,3 +1,5 @@ +use std::cmp::max; + use itertools::izip; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -13,46 +15,41 @@ use crate::memory::segments::Segment; #[derive(Clone, Copy)] pub(crate) struct StackBehavior { - num_pops: usize, - pushes: bool, + pub(crate) num_pops: usize, + pub(crate) pushes: bool, + new_top_stack_channel: Option, disable_other_channels: bool, } -const BASIC_UNARY_OP: Option = Some(StackBehavior { - num_pops: 1, - pushes: true, - disable_other_channels: true, -}); const BASIC_BINARY_OP: Option = Some(StackBehavior { num_pops: 2, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }); const BASIC_TERNARY_OP: Option = Some(StackBehavior { num_pops: 3, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }); pub(crate) const JUMP_OP: Option = Some(StackBehavior { num_pops: 1, pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }); pub(crate) const JUMPI_OP: Option = Some(StackBehavior { num_pops: 2, pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }); pub(crate) const MLOAD_GENERAL_OP: Option = Some(StackBehavior { num_pops: 3, pushes: true, - disable_other_channels: false, -}); - -pub(crate) const MSTORE_GENERAL_OP: Option = Some(StackBehavior { - num_pops: 4, - pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }); @@ -61,79 +58,111 @@ pub(crate) const MSTORE_GENERAL_OP: Option = Some(StackBehavior { // propertly constrained. The same applies when `disable_other_channels` is set to `false`, // except the first `num_pops` and the last `pushes as usize` channels have their read flag and // address constrained automatically in this file. -const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { +pub(crate) const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { binary_op: BASIC_BINARY_OP, ternary_op: BASIC_TERNARY_OP, fp254_op: BASIC_BINARY_OP, eq_iszero: None, // EQ is binary, IS_ZERO is unary. 
logic_op: BASIC_BINARY_OP, - not: BASIC_UNARY_OP, + not: Some(StackBehavior { + num_pops: 1, + pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), + disable_other_channels: true, + }), shift: Some(StackBehavior { num_pops: 2, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: false, }), keccak_general: Some(StackBehavior { num_pops: 4, pushes: true, + new_top_stack_channel: Some(NUM_GP_CHANNELS - 1), disable_other_channels: true, }), prover_input: None, // TODO pop: Some(StackBehavior { num_pops: 1, pushes: false, + new_top_stack_channel: None, disable_other_channels: true, }), jumps: None, // Depends on whether it's a JUMP or a JUMPI. pc: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: true, }), jumpdest: Some(StackBehavior { num_pops: 0, pushes: false, + new_top_stack_channel: None, disable_other_channels: true, }), push0: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: true, }), push: None, // TODO dup: None, swap: None, - context_op: None, // SET_CONTEXT is special since it involves the old and the new stack. - mstore_32bytes: Some(StackBehavior { - num_pops: 5, - pushes: false, - disable_other_channels: false, + get_context: Some(StackBehavior { + num_pops: 0, + pushes: true, + new_top_stack_channel: None, + disable_other_channels: true, }), + set_context: None, // SET_CONTEXT is special since it involves the old and the new stack. mload_32bytes: Some(StackBehavior { num_pops: 4, pushes: true, + new_top_stack_channel: Some(4), + disable_other_channels: false, + }), + mstore_32bytes: Some(StackBehavior { + num_pops: 5, + pushes: false, + new_top_stack_channel: None, disable_other_channels: false, }), exit_kernel: Some(StackBehavior { num_pops: 1, pushes: false, + new_top_stack_channel: None, disable_other_channels: true, }), m_op_general: None, syscall: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: false, }), exception: Some(StackBehavior { num_pops: 0, pushes: true, + new_top_stack_channel: None, disable_other_channels: false, }), }; -pub(crate) const EQ_STACK_BEHAVIOR: Option = BASIC_BINARY_OP; -pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = BASIC_UNARY_OP; +pub(crate) const EQ_STACK_BEHAVIOR: Option = Some(StackBehavior { + num_pops: 2, + pushes: true, + new_top_stack_channel: Some(2), + disable_other_channels: true, +}); +pub(crate) const IS_ZERO_STACK_BEHAVIOR: Option = Some(StackBehavior { + num_pops: 1, + pushes: true, + new_top_stack_channel: Some(2), + disable_other_channels: true, +}); pub(crate) fn eval_packed_one( lv: &CpuColumnsView
<P>
, @@ -142,43 +171,109 @@ pub(crate) fn eval_packed_one( stack_behavior: StackBehavior, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let num_operands = stack_behavior.num_pops + (stack_behavior.pushes as usize); - assert!(num_operands <= NUM_GP_CHANNELS); + // If you have pops. + if stack_behavior.num_pops > 0 { + for i in 1..stack_behavior.num_pops { + let channel = lv.mem_channels[i]; - // Pops - for i in 0..stack_behavior.num_pops { - let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * (channel.is_read - P::ONES)); - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * (channel.is_read - P::ONES)); + yield_constr.constraint(filter * (channel.addr_context - lv.context)); + yield_constr.constraint( + filter + * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. + let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); + yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + } - yield_constr.constraint(filter * (channel.addr_context - lv.context)); + // If you also push, you don't need to read the new top of the stack. + // If you don't: + // - if the stack isn't empty after the pops, you read the new top from an extra pop. + // - if not, the extra read is disabled. + // These are transition constraints: they don't apply to the last row. + if !stack_behavior.pushes { + // If stack_len != N... + let len_diff = lv.stack_len - P::Scalar::from_canonical_usize(stack_behavior.num_pops); + let new_filter = len_diff * filter; + // Read an extra element. + let channel = nv.mem_channels[0]; + yield_constr.constraint_transition(new_filter * (channel.used - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.is_read - P::ONES)); + yield_constr.constraint_transition(new_filter * (channel.addr_context - nv.context)); + yield_constr.constraint_transition( + new_filter + * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + let addr_virtual = nv.stack_len - P::ONES; + yield_constr.constraint_transition(new_filter * (channel.addr_virtual - addr_virtual)); + // Constrain `stack_inv_aux`. + yield_constr.constraint( + filter + * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + // Disable channel if stack_len == N. + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint_transition(empty_stack_filter * channel.used); + } + } + // If the op only pushes, you only need to constrain the top of the stack if the stack isn't empty. + else if stack_behavior.pushes { + // If len > 0... + let new_filter = lv.stack_len * filter; + // You write the previous top of the stack in memory, in the last channel. + let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; + yield_constr.constraint(new_filter * (channel.used - P::ONES)); + yield_constr.constraint(new_filter * channel.is_read); + yield_constr.constraint(new_filter * (channel.addr_context - lv.context)); yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + new_filter + * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), ); - // E.g. if `stack_len == 1` and `i == 0`, we want `add_virtual == 0`. 
- let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + let addr_virtual = lv.stack_len - P::ONES; + yield_constr.constraint(new_filter * (channel.addr_virtual - addr_virtual)); + for (limb_ch, limb_top) in channel.value.iter().zip(lv.mem_channels[0].value.iter()) { + yield_constr.constraint(new_filter * (*limb_ch - *limb_top)); + } + // Else you disable the channel. + yield_constr.constraint( + filter + * (lv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); + yield_constr.constraint(empty_stack_filter * channel.used); + } + // If the op doesn't pop nor push, the top of the stack must not change. + else { + yield_constr.constraint(filter * nv.mem_channels[0].used); + for (limb_old, limb_new) in lv.mem_channels[0] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + yield_constr.constraint(filter * (*limb_old - *limb_new)); + } } - // Pushes - if stack_behavior.pushes { - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * channel.is_read); - - yield_constr.constraint(filter * (channel.addr_context - lv.context)); - yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), - ); - let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(stack_behavior.num_pops); - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); + // Maybe constrain next stack_top. + // These are transition constraints: they don't apply to the last row. + if let Some(next_top_ch) = stack_behavior.new_top_stack_channel { + for (limb_ch, limb_top) in lv.mem_channels[next_top_ch] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + yield_constr.constraint_transition(filter * (*limb_ch - *limb_top)); + } } // Unused channels if stack_behavior.disable_other_channels { - for i in stack_behavior.num_pops..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) { + // The first channel contains (or not) the top od the stack and is constrained elsewhere. + for i in max(1, stack_behavior.num_pops)..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) + { let channel = lv.mem_channels[i]; yield_constr.constraint(filter * channel.used); } @@ -210,94 +305,199 @@ pub(crate) fn eval_ext_circuit_one, const D: usize> stack_behavior: StackBehavior, yield_constr: &mut RecursiveConstraintConsumer, ) { - let num_operands = stack_behavior.num_pops + (stack_behavior.pushes as usize); - assert!(num_operands <= NUM_GP_CHANNELS); + // If you have pops. 
+ if stack_behavior.num_pops > 0 { + for i in 1..stack_behavior.num_pops { + let channel = lv.mem_channels[i]; - // Pops - for i in 0..stack_behavior.num_pops { - let channel = lv.mem_channels[i]; + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + filter, + channel.addr_segment, + filter, + ); + yield_constr.constraint(builder, constr); + } + // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 2]`. + { + let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); + let constr = builder.arithmetic_extension( + F::ONE, + F::from_canonical_usize(i + 1), + filter, + diff, + filter, + ); + yield_constr.constraint(builder, constr); + } + } + // If you also push, you don't need to read the new top of the stack. + // If you don't: + // - if the stack isn't empty after the pops, you read the new top from an extra pop. + // - otherwise, the extra read is disabled. + // These are transition constraints: they don't apply to the last row. + if !stack_behavior.pushes { + // If stack_len != N... + let target_num_pops = + builder.constant_extension(F::from_canonical_usize(stack_behavior.num_pops).into()); + let len_diff = builder.sub_extension(lv.stack_len, target_num_pops); + let new_filter = builder.mul_extension(filter, len_diff); + // Read an extra element. + let channel = nv.mem_channels[0]; + + { + let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.mul_sub_extension(new_filter, channel.is_read, new_filter); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint_transition(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_u64(Segment::Stack as u64), + new_filter, + channel.addr_segment, + new_filter, + ); + yield_constr.constraint_transition(builder, constr); + } + { + let diff = builder.sub_extension(channel.addr_virtual, nv.stack_len); + let constr = + builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, new_filter); + yield_constr.constraint_transition(builder, constr); + } + // Constrain `stack_inv_aux`. + { + let prod = builder.mul_extension(len_diff, lv.general.stack().stack_inv); + let diff = builder.sub_extension(prod, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Disable channel if stack_len == N. + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, channel.used); + yield_constr.constraint_transition(builder, constr); + } + } + } + // If the op only pushes, you only need to constrain the top of the stack if the stack isn't empty. + else if stack_behavior.pushes { + // If len > 0...
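// A note on the gadget used throughout this hunk, stated as an assumption
// about plonky2's API: `builder.arithmetic_extension(c0, c1, a, b, c)`
// returns `c0 * a * b + c1 * c`, so a packed constraint `filter * (x - K)`
// costs a single gate:
//
//     let constr = builder.arithmetic_extension(
//         F::ONE,                    // c0
//         -F::from_canonical_u64(K), // c1, with K an illustrative constant
//         filter,                    // a
//         x,                         // b
//         filter,                    // c
//     ); // = filter * x - K * filter = filter * (x - K)
//
// which is why the `addr_segment` and `addr_virtual` checks above need no
// separate subtraction.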
+ let new_filter = builder.mul_extension(lv.stack_len, filter); + // You write the previous top of the stack in memory, in the last channel. + let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; { - let constr = builder.mul_sub_extension(filter, channel.used, filter); + let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter); yield_constr.constraint(builder, constr); } { - let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + let constr = builder.mul_extension(new_filter, channel.is_read); yield_constr.constraint(builder, constr); } { let diff = builder.sub_extension(channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); + let constr = builder.mul_extension(new_filter, diff); yield_constr.constraint(builder, constr); } { let constr = builder.arithmetic_extension( F::ONE, -F::from_canonical_u64(Segment::Stack as u64), - filter, + new_filter, channel.addr_segment, - filter, + new_filter, ); yield_constr.constraint(builder, constr); } { let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); - let constr = builder.arithmetic_extension( - F::ONE, - F::from_canonical_usize(i + 1), - filter, - diff, - filter, - ); + let constr = builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, new_filter); + yield_constr.constraint(builder, constr); + } + for (limb_ch, limb_top) in channel.value.iter().zip(lv.mem_channels[0].value.iter()) { + let diff = builder.sub_extension(*limb_ch, *limb_top); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // Else you disable the channel. + { + let diff = builder.mul_extension(lv.stack_len, lv.general.stack().stack_inv); + let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + { + let empty_stack_filter = + builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); + let constr = builder.mul_extension(empty_stack_filter, channel.used); yield_constr.constraint(builder, constr); } } - - // Pushes - if stack_behavior.pushes { - let channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; - + // If the op neither pops nor pushes, the top of the stack must not change. + else { { - let constr = builder.mul_sub_extension(filter, channel.used, filter); + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); yield_constr.constraint(builder, constr); } { - let constr = builder.mul_extension(filter, channel.is_read); - yield_constr.constraint(builder, constr); + for (limb_old, limb_new) in lv.mem_channels[0] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) + { + let diff = builder.sub_extension(*limb_old, *limb_new); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } } + } + // Maybe constrain next stack_top. + // These are transition constraints: they don't apply to the last row.
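// For reference, a sketch of the descriptor consumed here; the actual
// definition lives in `evm/src/cpu/stack.rs`, so the exact field types and
// visibility are assumptions read off their uses in this file:
//
//     pub(crate) struct StackBehavior {
//         pub(crate) num_pops: usize,
//         pub(crate) pushes: bool,
//         // GP channel whose value, if any, becomes the next row's `stack_top`.
//         pub(crate) new_top_stack_channel: Option<usize>,
//         pub(crate) disable_other_channels: bool,
//     }
//
// The copy constraint just below ties that channel, limb by limb, to channel 0
// of the next row.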
+ if let Some(next_top_ch) = stack_behavior.new_top_stack_channel { + for (limb_ch, limb_top) in lv.mem_channels[next_top_ch] + .value + .iter() + .zip(nv.mem_channels[0].value.iter()) { - let diff = builder.sub_extension(channel.addr_context, lv.context); + let diff = builder.sub_extension(*limb_ch, *limb_top); let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.arithmetic_extension( - F::ONE, - -F::from_canonical_u64(Segment::Stack as u64), - filter, - channel.addr_segment, - filter, - ); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(channel.addr_virtual, lv.stack_len); - let constr = builder.arithmetic_extension( - F::ONE, - F::from_canonical_usize(stack_behavior.num_pops), - filter, - diff, - filter, - ); - yield_constr.constraint(builder, constr); + yield_constr.constraint_transition(builder, constr); } } // Unused channels if stack_behavior.disable_other_channels { - for i in stack_behavior.num_pops..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) { + // The first channel contains (or not) the top of the stack and is constrained elsewhere. + for i in max(1, stack_behavior.num_pops)..NUM_GP_CHANNELS - (stack_behavior.pushes as usize) + { let channel = lv.mem_channels[i]; let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); diff --git a/evm/src/cpu/syscalls_exceptions.rs b/evm/src/cpu/syscalls_exceptions.rs index f9ea9a0a..1437fba0 100644 --- a/evm/src/cpu/syscalls_exceptions.rs +++ b/evm/src/cpu/syscalls_exceptions.rs @@ -64,7 +64,7 @@ pub fn eval_packed<P: PackedField>( let exc_handler_addr_start = exc_jumptable_start + exc_code * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); - for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { + for (i, channel) in lv.mem_channels[1..BYTES_PER_OFFSET + 1].iter().enumerate() { yield_constr.constraint(total_filter * (channel.used - P::ONES)); yield_constr.constraint(total_filter * (channel.is_read - P::ONES)); @@ -81,13 +81,13 @@ pub fn eval_packed<P: PackedField>( } // Disable unused channels (the last channel is used to push to the stack) - for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { + for channel in &lv.mem_channels[BYTES_PER_OFFSET + 1..NUM_GP_CHANNELS - 1] { yield_constr.constraint(total_filter * channel.used); } // Set program counter to the handler address // The addresses are big-endian in memory - let target = lv.mem_channels[0..BYTES_PER_OFFSET] + let target = lv.mem_channels[1..BYTES_PER_OFFSET + 1] .iter() .map(|channel| channel.value[0]) .fold(P::ZEROS, |cumul, limb| { @@ -102,9 +102,8 @@ pub fn eval_packed<P: PackedField>( yield_constr.constraint_transition(total_filter * nv.gas[0]); yield_constr.constraint_transition(total_filter * nv.gas[1]); - // This memory channel is constrained in `stack.rs`. - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - // Push to stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). + let output = nv.mem_channels[0].value; + // New top of the stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7).
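// A self-contained sketch of that layout (`exit_info` is an illustrative
// helper name, assuming the 32-bit limb convention `limb i = bits
// 32*i..32*(i+1)` of a `U256`):

use ethereum_types::U256;

fn exit_info(pc: usize, is_kernel: bool, gas_used: u64) -> U256 {
    U256::from(pc)                             // limb 0: program counter
        + (U256::from(is_kernel as u64) << 32) // limb 1: kernel-mode flag
        + (U256::from(gas_used) << 192)        // limbs 6 and 7: gas counter
}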
yield_constr.constraint(filter_syscall * (output[0] - (lv.program_counter + P::ONES))); yield_constr.constraint(filter_exception * (output[0] - lv.program_counter)); // Check the kernel mode, for syscalls only @@ -182,7 +181,7 @@ pub fn eval_ext_circuit, const D: usize>( exc_jumptable_start, ); - for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { + for (i, channel) in lv.mem_channels[1..BYTES_PER_OFFSET + 1].iter().enumerate() { { let constr = builder.mul_sub_extension(total_filter, channel.used, total_filter); yield_constr.constraint(builder, constr); @@ -235,7 +234,7 @@ pub fn eval_ext_circuit, const D: usize>( } // Disable unused channels (the last channel is used to push to the stack) - for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { + for channel in &lv.mem_channels[BYTES_PER_OFFSET + 1..NUM_GP_CHANNELS - 1] { let constr = builder.mul_extension(total_filter, channel.used); yield_constr.constraint(builder, constr); } @@ -243,7 +242,7 @@ pub fn eval_ext_circuit, const D: usize>( // Set program counter to the handler address // The addresses are big-endian in memory { - let target = lv.mem_channels[0..BYTES_PER_OFFSET] + let target = lv.mem_channels[1..BYTES_PER_OFFSET + 1] .iter() .map(|channel| channel.value[0]) .fold(builder.zero_extension(), |cumul, limb| { @@ -272,8 +271,8 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint_transition(builder, constr); } - // This memory channel is constrained in `stack.rs`. - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; + // New top of the stack. + let output = nv.mem_channels[0].value; // Push to stack (syscall): current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). { let pc_plus_1 = builder.add_const_extension(lv.program_counter, F::ONE); diff --git a/evm/src/witness/memory.rs b/evm/src/witness/memory.rs index 3b62c945..5d589934 100644 --- a/evm/src/witness/memory.rs +++ b/evm/src/witness/memory.rs @@ -88,6 +88,18 @@ pub struct MemoryOp { pub value: U256, } +pub static DUMMY_MEMOP: MemoryOp = MemoryOp { + filter: false, + timestamp: 0, + address: MemoryAddress { + context: 0, + segment: 0, + virt: 0, + }, + kind: MemoryOpKind::Read, + value: U256::zero(), +}; + impl MemoryOp { pub fn new( channel: MemoryChannel, diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index f4dc03e8..a503ab49 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -3,7 +3,7 @@ use itertools::Itertools; use keccak_hash::keccak; use plonky2::field::types::Field; -use super::util::{byte_packing_log, byte_unpacking_log}; +use super::util::{byte_packing_log, byte_unpacking_log, push_no_write, push_with_write}; use crate::arithmetic::BinaryOperator; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; @@ -20,9 +20,10 @@ use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, Virt use crate::witness::errors::ProgramError; use crate::witness::errors::ProgramError::MemoryError; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; +use crate::witness::operation::MemoryChannel::GeneralPurpose; use crate::witness::util::{ keccak_sponge_log, mem_read_gp_with_log_and_fill, mem_write_gp_log_and_fill, - stack_pop_with_log_and_fill, stack_push_log_and_fill, + stack_pop_with_log_and_fill, }; use crate::{arithmetic, logic}; @@ -59,14 +60,13 @@ pub(crate) fn generate_binary_logic_op( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), 
ProgramError> { - let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(in0, _), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let operation = logic::Operation::new(op, in0, in1); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result)?; + + push_no_write(state, &mut row, operation.result, Some(NUM_GP_CHANNELS - 1)); state.traces.push_logic(operation); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -76,10 +76,8 @@ pub(crate) fn generate_binary_arithmetic_op( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1)] = - stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(input0, _), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let operation = arithmetic::Operation::binary(operator, input0, input1); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; if operator == arithmetic::BinaryOperator::AddFp254 || operator == arithmetic::BinaryOperator::MulFp254 @@ -94,10 +92,15 @@ pub(crate) fn generate_binary_arithmetic_op( } } + push_no_write( + state, + &mut row, + operation.result(), + Some(NUM_GP_CHANNELS - 1), + ); + state.traces.push_arithmetic(operation); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -107,16 +110,20 @@ pub(crate) fn generate_ternary_arithmetic_op( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1), (input2, log_in2)] = + let [(input0, _), (input1, log_in1), (input2, log_in2)] = stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; let operation = arithmetic::Operation::ternary(operator, input0, input1, input2); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; + + push_no_write( + state, + &mut row, + operation.result(), + Some(NUM_GP_CHANNELS - 1), + ); state.traces.push_arithmetic(operation); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -126,7 +133,7 @@ pub(crate) fn generate_keccak_general( mut row: CpuColumnsView, ) -> Result<(), ProgramError> { row.is_keccak_sponge = F::ONE; - let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + let [(context, _), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -144,15 +151,13 @@ pub(crate) fn generate_keccak_general( log::debug!("Hashing {:?}", input); let hash = keccak(&input); - let log_push = stack_push_log_and_fill(state, &mut row, hash.into_uint())?; + push_no_write(state, &mut row, hash.into_uint(), Some(NUM_GP_CHANNELS - 1)); keccak_sponge_log(state, base_address, input); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); - state.traces.push_memory(log_push); state.traces.push_cpu(row); Ok(()) } @@ -164,9 +169,7 @@ pub(crate) fn generate_prover_input( let pc = state.registers.program_counter; let input_fn = &KERNEL.prover_inputs[&pc]; let input = state.prover_input(input_fn)?; - let write = stack_push_log_and_fill(state, &mut row, 
input)?; - - state.traces.push_memory(write); + push_with_write(state, &mut row, input)?; state.traces.push_cpu(row); Ok(()) } @@ -175,10 +178,10 @@ pub(crate) fn generate_pop( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(_, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(_, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; - state.traces.push_memory(log_in); state.traces.push_cpu(row); + Ok(()) } @@ -186,7 +189,8 @@ pub(crate) fn generate_jump( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(dst, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let dst: u32 = dst .try_into() .map_err(|_| ProgramError::InvalidJumpDestination)?; @@ -216,7 +220,15 @@ pub(crate) fn generate_jump( row.general.jumps_mut().should_jump = F::ONE; row.general.jumps_mut().cond_sum_pinv = F::ONE; - state.traces.push_memory(log_in0); + let diff = row.stack_len - F::ONE; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.traces.push_cpu(row); state.jump_to(dst as usize)?; Ok(()) @@ -226,7 +238,7 @@ pub(crate) fn generate_jumpi( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(dst, _), (cond, log_cond)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let should_jump = !cond.is_zero(); if should_jump { @@ -271,8 +283,16 @@ pub(crate) fn generate_jumpi( state.traces.push_memory(jumpdest_bit_log); } - state.traces.push_memory(log_in0); - state.traces.push_memory(log_in1); + let diff = row.stack_len - F::TWO; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + state.traces.push_memory(log_cond); state.traces.push_cpu(row); Ok(()) } @@ -281,8 +301,7 @@ pub(crate) fn generate_pc( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let write = stack_push_log_and_fill(state, &mut row, state.registers.program_counter.into())?; - state.traces.push_memory(write); + push_with_write(state, &mut row, state.registers.program_counter.into())?; state.traces.push_cpu(row); Ok(()) } @@ -299,9 +318,7 @@ pub(crate) fn generate_get_context( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let ctx = state.registers.context.into(); - let write = stack_push_log_and_fill(state, &mut row, ctx)?; - state.traces.push_memory(write); + push_with_write(state, &mut row, state.registers.context.into())?; state.traces.push_cpu(row); Ok(()) } @@ -310,8 +327,10 @@ pub(crate) fn generate_set_context( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(ctx, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(ctx, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let sp_to_save = state.registers.stack_len.into(); + let old_ctx = state.registers.context; let new_ctx = u256_to_usize(ctx)?; @@ -347,10 +366,31 @@ pub(crate) fn 
generate_set_context( mem_read_gp_with_log_and_fill(2, new_sp_addr, state, &mut row) }; + // If the new stack isn't empty, read stack_top from memory. + let new_sp = new_sp.as_usize(); + if new_sp > 0 { + // Set up columns to disable the channel if it *is* empty. + let new_sp_field = F::from_canonical_usize(new_sp); + if let Some(inv) = new_sp_field.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + + let new_top_addr = MemoryAddress::new(new_ctx, Segment::Stack, new_sp - 1); + let (new_top, log_read_new_top) = + mem_read_gp_with_log_and_fill(3, new_top_addr, state, &mut row); + state.registers.stack_top = new_top; + state.traces.push_memory(log_read_new_top); + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.registers.context = new_ctx; - let new_sp = u256_to_usize(new_sp)?; state.registers.stack_len = new_sp; - state.traces.push_memory(log_in); state.traces.push_memory(log_write_old_sp); state.traces.push_memory(log_read_new_sp); state.traces.push_cpu(row); @@ -386,31 +426,76 @@ pub(crate) fn generate_push( .collect_vec(); let val = U256::from_big_endian(&bytes); - let write = stack_push_log_and_fill(state, &mut row, val)?; - - state.traces.push_memory(write); + push_with_write(state, &mut row, val)?; state.traces.push_cpu(row); Ok(()) } +// This instruction is special. The order of the operations is: +// - Write `stack_top` at `stack[stack_len - 1]` +// - Read `val` at `stack[stack_len - 1 - n]` +// - Update `stack_top` with `val` and add 1 to `stack_len` +// Since the write must happen before the read, the normal way of assigning +// GP channels doesn't work and we must handle them manually. pub(crate) fn generate_dup<F: Field>( n: u8, state: &mut GenerationState<F>, mut row: CpuColumnsView<F>, ) -> Result<(), ProgramError> { - let other_addr_lo = state - .registers - .stack_len - .checked_sub(1 + (n as usize)) - .ok_or(ProgramError::StackUnderflow)?; - let other_addr = MemoryAddress::new(state.registers.context, Segment::Stack, other_addr_lo); + // Same logic as in `push_with_write`, but we use channel GP(1) instead. + if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { + return Err(ProgramError::StackOverflow); + } + if n as usize >= state.registers.stack_len { + return Err(ProgramError::StackUnderflow); + } + let stack_top = state.registers.stack_top; + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1, + ); + let log_push = mem_write_gp_log_and_fill(1, address, state, &mut row, stack_top); + state.traces.push_memory(log_push); - let (val, log_in) = mem_read_gp_with_log_and_fill(0, other_addr, state, &mut row); - let log_out = stack_push_log_and_fill(state, &mut row, val)?; + let other_addr = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1 - n as usize, + ); - state.traces.push_memory(log_in); - state.traces.push_memory(log_out); + // If n = 0, we read a value that hasn't been written to memory: the corresponding write + // is buffered in the mem_ops queue, but hasn't been applied yet.
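// (Concretely, for DUP1 both operations hit the same cell in the same cycle.
// With illustrative values stack_len = 3, stack_top = 7 and a stale
// stack[2] = 5 in `state.memory`: the write `stack[2] <- 7` queued above on
// channel 1 has not been applied to `state.memory` yet, so reading the cell
// through `mem_read_gp_with_log_and_fill` would return 5. The branch below
// therefore builds the read log by hand, using `stack_top` itself as the
// value.)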
+ let (val, log_read) = if n == 0 { + let op = MemoryOp::new( + MemoryChannel::GeneralPurpose(2), + state.traces.clock(), + other_addr, + MemoryOpKind::Read, + stack_top, + ); + + let channel = &mut row.mem_channels[2]; + assert_eq!(channel.used, F::ZERO); + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(other_addr.context); + channel.addr_segment = F::from_canonical_usize(other_addr.segment); + channel.addr_virtual = F::from_canonical_usize(other_addr.virt); + let val_limbs: [u64; 4] = state.registers.stack_top.0; + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } + + (stack_top, op) + } else { + mem_read_gp_with_log_and_fill(2, other_addr, state, &mut row) + }; + push_no_write(state, &mut row, val, None); + + state.traces.push_memory(log_read); state.traces.push_cpu(row); Ok(()) } @@ -427,15 +512,13 @@ pub(crate) fn generate_swap( .ok_or(ProgramError::StackUnderflow)?; let other_addr = MemoryAddress::new(state.registers.context, Segment::Stack, other_addr_lo); - let [(in0, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(in0, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row); - let log_out0 = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 2, other_addr, state, &mut row, in0); - let log_out1 = stack_push_log_and_fill(state, &mut row, in1)?; + let log_out0 = mem_write_gp_log_and_fill(2, other_addr, state, &mut row, in0); + push_no_write(state, &mut row, in1, None); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_out0); - state.traces.push_memory(log_out1); state.traces.push_cpu(row); Ok(()) } @@ -444,12 +527,10 @@ pub(crate) fn generate_not( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(x, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let result = !x; - let log_out = stack_push_log_and_fill(state, &mut row, result)?; + push_no_write(state, &mut row, result, Some(NUM_GP_CHANNELS - 1)); - state.traces.push_memory(log_in); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -458,18 +539,16 @@ pub(crate) fn generate_iszero( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(x, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; let is_zero = x.is_zero(); let result = { let t: u64 = is_zero.into(); t.into() }; - let log_out = stack_push_log_and_fill(state, &mut row, result)?; generate_pinv_diff(x, U256::zero(), &mut row); - state.traces.push_memory(log_in); - state.traces.push_memory(log_out); + push_no_write(state, &mut row, result, None); state.traces.push_cpu(row); Ok(()) } @@ -480,12 +559,9 @@ fn append_shift( is_shl: bool, input0: U256, input1: U256, - log_in0: MemoryOp, log_in1: MemoryOp, result: U256, ) -> Result<(), ProgramError> { - let log_out = stack_push_log_and_fill(state, &mut row, result)?; - const LOOKUP_CHANNEL: usize = 2; let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); if input0.bits() <= 32 { @@ -511,9 +587,8 @@ fn append_shift( let operation = arithmetic::Operation::binary(operator, 
input0, input1); state.traces.push_arithmetic(operation); - state.traces.push_memory(log_in0); + push_no_write(state, &mut row, result, Some(NUM_GP_CHANNELS - 1)); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -522,30 +597,28 @@ pub(crate) fn generate_shl( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1)] = - stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(input0, _), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let result = if input0 > U256::from(255u64) { U256::zero() } else { input1 << input0 }; - append_shift(state, row, true, input0, input1, log_in0, log_in1, result) + append_shift(state, row, true, input0, input1, log_in1, result) } pub(crate) fn generate_shr( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(input0, log_in0), (input1, log_in1)] = - stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(input0, _), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let result = if input0 > U256::from(255u64) { U256::zero() } else { input1 >> input0 }; - append_shift(state, row, false, input0, input1, log_in0, log_in1, result) + append_shift(state, row, false, input0, input1, log_in1, result) } pub(crate) fn generate_syscall( @@ -574,19 +647,19 @@ pub(crate) fn generate_syscall( handler_jumptable_addr + (opcode as usize) * (BYTES_PER_OFFSET as usize); assert_eq!(BYTES_PER_OFFSET, 3, "Code below assumes 3 bytes per offset"); let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill( - 0, + 1, MemoryAddress::new(0, Segment::Code, handler_addr_addr), state, &mut row, ); let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill( - 1, + 2, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1), state, &mut row, ); let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill( - 2, + 3, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2), state, &mut row, @@ -606,14 +679,13 @@ pub(crate) fn generate_syscall( state.registers.is_kernel = true; state.registers.gas_used = 0; - let log_out = stack_push_log_and_fill(state, &mut row, syscall_info)?; + push_with_write(state, &mut row, syscall_info)?; log::debug!("Syscall to {}", KERNEL.offset_name(new_program_counter)); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) @@ -623,16 +695,14 @@ pub(crate) fn generate_eq( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let [(in0, _), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let eq = in0 == in1; let result = U256::from(u64::from(eq)); - let log_out = stack_push_log_and_fill(state, &mut row, result)?; generate_pinv_diff(in0, in1, &mut row); - state.traces.push_memory(log_in0); + push_no_write(state, &mut row, result, None); state.traces.push_memory(log_in1); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -641,7 +711,7 @@ pub(crate) fn generate_exit_kernel( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(kexit_info, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let [(kexit_info, _)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; 
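// `kexit_info` is the word that `generate_syscall` / `generate_exception`
// packed into the new top of the stack. A self-contained sketch of the
// decoding performed below (`decode_exit_info` is an illustrative helper; it
// assumes `U256.0` exposes four little-endian u64 limbs, as used elsewhere in
// this patch):

fn decode_exit_info(kexit_info: ethereum_types::U256) -> (usize, bool, u64) {
    let low = kexit_info.0[0];               // 32-bit limbs 0 and 1
    let pc = low as u32 as usize;            // limb 0: saved program counter
    let is_kernel = (low >> 32) as u32 != 0; // limb 1: kernel-mode flag
    let gas = kexit_info.0[3];               // limbs 6 and 7: gas counter
    (pc, is_kernel, gas)
}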
let kexit_info_u64 = kexit_info.0[0]; let program_counter = kexit_info_u64 as u32 as usize; let is_kernel_mode_val = (kexit_info_u64 >> 32) as u32; @@ -661,7 +731,6 @@ pub(crate) fn generate_exit_kernel( is_kernel_mode ); - state.traces.push_memory(log_in); state.traces.push_cpu(row); Ok(()) @@ -671,7 +740,7 @@ pub(crate) fn generate_mload_general( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (virt, log_in2)] = + let [(context, _), (segment, log_in1), (virt, log_in2)] = stack_pop_with_log_and_fill::<3, _>(state, &mut row)?; let (val, log_read) = mem_read_gp_with_log_and_fill( @@ -680,14 +749,20 @@ pub(crate) fn generate_mload_general( state, &mut row, ); + push_no_write(state, &mut row, val, None); - let log_out = stack_push_log_and_fill(state, &mut row, val)?; + let diff = row.stack_len - F::from_canonical_usize(4); + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_read); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -696,7 +771,7 @@ pub(crate) fn generate_mload_32bytes( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = + let [(context, _), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; let len = u256_to_usize(len)?; if len > 32 { @@ -722,15 +797,13 @@ pub(crate) fn generate_mload_32bytes( .collect_vec(); let packed_int = U256::from_big_endian(&bytes); - let log_out = stack_push_log_and_fill(state, &mut row, packed_int)?; + push_no_write(state, &mut row, packed_int, Some(4)); byte_packing_log(state, base_address, bytes); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) } @@ -739,7 +812,7 @@ pub(crate) fn generate_mstore_general( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (virt, log_in2), (val, log_in3)] = + let [(context, _), (segment, log_in1), (virt, log_in2), (val, log_in3)] = stack_pop_with_log_and_fill::<4, _>(state, &mut row)?; let address = MemoryAddress { @@ -755,12 +828,23 @@ pub(crate) fn generate_mstore_general( }; let log_write = mem_write_gp_log_and_fill(4, address, state, &mut row, val); - state.traces.push_memory(log_in0); + let diff = row.stack_len - F::from_canonical_usize(4); + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + row.general.stack_mut().stack_inv_aux_2 = F::ONE; + state.registers.is_stack_top_read = true; + } else { + row.general.stack_mut().stack_inv = F::ZERO; + row.general.stack_mut().stack_inv_aux = F::ZERO; + } + state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); state.traces.push_memory(log_write); state.traces.push_cpu(row); + Ok(()) } @@ -768,7 +852,7 @@ pub(crate) fn generate_mstore_32bytes( state: &mut GenerationState, mut row: 
CpuColumnsView, ) -> Result<(), ProgramError> { - let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = + let [(context, _), (segment, log_in1), (base_virt, log_in2), (val, log_in3), (len, log_in4)] = stack_pop_with_log_and_fill::<5, _>(state, &mut row)?; let len = u256_to_usize(len)?; @@ -776,7 +860,6 @@ pub(crate) fn generate_mstore_32bytes( byte_unpacking_log(state, base_address, val, len); - state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); @@ -805,6 +888,36 @@ pub(crate) fn generate_exception( return Err(ProgramError::InterpreterError); } + if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + + if state.registers.is_stack_top_read { + let channel = &mut row.mem_channels[0]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(state.registers.context); + channel.addr_segment = F::from_canonical_usize(Segment::Stack as usize); + channel.addr_virtual = F::from_canonical_usize(state.registers.stack_len - 1); + + let address = MemoryAddress { + context: state.registers.context, + segment: Segment::Stack as usize, + virt: state.registers.stack_len - 1, + }; + + let mem_op = MemoryOp::new( + GeneralPurpose(0), + state.traces.clock(), + address, + MemoryOpKind::Read, + state.registers.stack_top, + ); + state.traces.push_memory(mem_op); + state.registers.is_stack_top_read = false; + } + row.general.exception_mut().exc_code_bits = [ F::from_bool(exc_code & 1 != 0), F::from_bool(exc_code & 2 != 0), @@ -816,19 +929,19 @@ pub(crate) fn generate_exception( handler_jumptable_addr + (exc_code as usize) * (BYTES_PER_OFFSET as usize); assert_eq!(BYTES_PER_OFFSET, 3, "Code below assumes 3 bytes per offset"); let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill( - 0, + 1, MemoryAddress::new(0, Segment::Code, handler_addr_addr), state, &mut row, ); let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill( - 1, + 2, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1), state, &mut row, ); let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill( - 2, + 3, MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2), state, &mut row, @@ -847,14 +960,13 @@ pub(crate) fn generate_exception( state.registers.is_kernel = true; state.registers.gas_used = 0; - let log_out = stack_push_log_and_fill(state, &mut row, exc_info)?; + push_with_write(state, &mut row, exc_info)?; log::debug!("Exception to {}", KERNEL.offset_name(new_program_counter)); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_memory(log_in2); - state.traces.push_memory(log_out); state.traces.push_cpu(row); Ok(()) diff --git a/evm/src/witness/state.rs b/evm/src/witness/state.rs index 3b37b01e..406ae856 100644 --- a/evm/src/witness/state.rs +++ b/evm/src/witness/state.rs @@ -1,3 +1,5 @@ +use ethereum_types::U256; + use crate::cpu::kernel::aggregator::KERNEL; const KERNEL_CONTEXT: usize = 0; @@ -7,6 +9,9 @@ pub struct RegistersState { pub program_counter: usize, pub is_kernel: bool, pub stack_len: usize, + pub stack_top: U256, + // Indicates if you read the new stack_top from memory to set the channel accordingly. 
+ pub is_stack_top_read: bool, pub context: usize, pub gas_used: u64, } @@ -27,6 +32,8 @@ impl Default for RegistersState { program_counter: KERNEL.global_labels["main"], is_kernel: true, stack_len: 0, + stack_top: U256::zero(), + is_stack_top_read: false, context: 0, gas_used: 0, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 2a710f4b..00030110 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -2,14 +2,20 @@ use anyhow::bail; use log::log_enabled; use plonky2::field::types::Field; +use super::memory::{MemoryOp, MemoryOpKind}; +use super::util::fill_channel_with_value; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::stack::{ + EQ_STACK_BEHAVIOR, IS_ZERO_STACK_BEHAVIOR, JUMPI_OP, JUMP_OP, STACK_BEHAVIORS, +}; use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::witness::errors::ProgramError; use crate::witness::gas::gas_to_charge; use crate::witness::memory::MemoryAddress; +use crate::witness::memory::MemoryChannel::GeneralPurpose; use crate::witness::operation::*; use crate::witness::state::RegistersState; use crate::witness::util::mem_read_code_with_log_and_fill; @@ -175,7 +181,8 @@ fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) { Operation::Jump | Operation::Jumpi => &mut flags.jumps, Operation::Pc => &mut flags.pc, Operation::Jumpdest => &mut flags.jumpdest, - Operation::GetContext | Operation::SetContext => &mut flags.context_op, + Operation::GetContext => &mut flags.get_context, + Operation::SetContext => &mut flags.set_context, Operation::Mload32Bytes => &mut flags.mload_32bytes, Operation::Mstore32Bytes => &mut flags.mstore_32bytes, Operation::ExitKernel => &mut flags.exit_kernel, @@ -183,6 +190,52 @@ fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) { } = F::ONE; } +// Returns the number of pops if the operation pops without pushing, and `None` otherwise. +fn get_op_special_length(op: Operation) -> Option<usize> { + let behavior_opt = match op { + Operation::Push(0) => STACK_BEHAVIORS.push0, + Operation::Push(1..)
=> STACK_BEHAVIORS.push, + Operation::Dup(_) => STACK_BEHAVIORS.dup, + Operation::Swap(_) => STACK_BEHAVIORS.swap, + Operation::Iszero => IS_ZERO_STACK_BEHAVIOR, + Operation::Not => STACK_BEHAVIORS.not, + Operation::Syscall(_, _, _) => STACK_BEHAVIORS.syscall, + Operation::Eq => EQ_STACK_BEHAVIOR, + Operation::BinaryLogic(_) => STACK_BEHAVIORS.logic_op, + Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => { + STACK_BEHAVIORS.fp254_op + } + Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) + | Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => STACK_BEHAVIORS.shift, + Operation::BinaryArithmetic(_) => STACK_BEHAVIORS.binary_op, + Operation::TernaryArithmetic(_) => STACK_BEHAVIORS.ternary_op, + Operation::KeccakGeneral => STACK_BEHAVIORS.keccak_general, + Operation::ProverInput => STACK_BEHAVIORS.prover_input, + Operation::Pop => STACK_BEHAVIORS.pop, + Operation::Jump => JUMP_OP, + Operation::Jumpi => JUMPI_OP, + Operation::Pc => STACK_BEHAVIORS.pc, + Operation::Jumpdest => STACK_BEHAVIORS.jumpdest, + Operation::GetContext => STACK_BEHAVIORS.get_context, + Operation::SetContext => None, + Operation::Mload32Bytes => STACK_BEHAVIORS.mload_32bytes, + Operation::Mstore32Bytes => STACK_BEHAVIORS.mstore_32bytes, + Operation::ExitKernel => STACK_BEHAVIORS.exit_kernel, + Operation::MloadGeneral | Operation::MstoreGeneral => STACK_BEHAVIORS.m_op_general, + }; + if let Some(behavior) = behavior_opt { + if behavior.num_pops > 0 && !behavior.pushes { + Some(behavior.num_pops) + } else { + None + } + } else { + None + } +} + fn perform_op( state: &mut GenerationState, op: Operation, @@ -247,6 +300,7 @@ fn base_row(state: &mut GenerationState) -> (CpuColumnsView, u8) F::from_canonical_u32((state.registers.gas_used >> 32) as u32), ]; row.stack_len = F::from_canonical_usize(state.registers.stack_len); + fill_channel_with_value(&mut row, 0, state.registers.stack_top); let opcode = read_code_memory(state, &mut row); (row, opcode) @@ -264,6 +318,31 @@ fn try_perform_instruction(state: &mut GenerationState) -> Result<( fill_op_flag(op, &mut row); + if state.registers.is_stack_top_read { + let channel = &mut row.mem_channels[0]; + channel.used = F::ONE; + channel.is_read = F::ONE; + channel.addr_context = F::from_canonical_usize(state.registers.context); + channel.addr_segment = F::from_canonical_usize(Segment::Stack as usize); + channel.addr_virtual = F::from_canonical_usize(state.registers.stack_len - 1); + + let address = MemoryAddress { + context: state.registers.context, + segment: Segment::Stack as usize, + virt: state.registers.stack_len - 1, + }; + + let mem_op = MemoryOp::new( + GeneralPurpose(0), + state.traces.clock(), + address, + MemoryOpKind::Read, + state.registers.stack_top, + ); + state.traces.push_memory(mem_op); + state.registers.is_stack_top_read = false; + } + if state.registers.is_kernel { row.stack_len_bounds_aux = F::ZERO; } else { @@ -277,6 +356,21 @@ fn try_perform_instruction(state: &mut GenerationState) -> Result<( } } + // Might write in general CPU columns when it shouldn't, but the correct values will + // overwrite these ones during the op generation. 
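// Concretely: `Operation::Pop` has `num_pops == 1` and `pushes == false`, so
// `get_op_special_length` returns `Some(1)`. If `stack_len != 1`, the block
// below witnesses an inverse of `stack_len - 1` and sets `is_stack_top_read`,
// so the next row's channel 0 reads the new top; if `stack_len == 1`, the
// stack becomes empty and no read is scheduled.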
+ if let Some(special_len) = get_op_special_length(op) { + let special_len = F::from_canonical_usize(special_len); + let diff = row.stack_len - special_len; + if let Some(inv) = diff.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + state.registers.is_stack_top_read = true; + } + } else if let Some(inv) = row.stack_len.try_inverse() { + row.general.stack_mut().stack_inv = inv; + row.general.stack_mut().stack_inv_aux = F::ONE; + } + perform_op(state, op, row) } diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index dbe4c0ed..24970361 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -1,6 +1,7 @@ use ethereum_types::U256; use plonky2::field::types::Field; +use super::memory::DUMMY_MEMOP; use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::keccak_util::keccakf_u8s; @@ -36,6 +37,10 @@ pub(crate) fn stack_peek( if i >= state.registers.stack_len { return Err(ProgramError::StackUnderflow); } + if i == 0 { + return Ok(state.registers.stack_top); + } + Ok(state.memory.get(MemoryAddress::new( state.registers.context, Segment::Stack, @@ -53,6 +58,77 @@ pub(crate) fn current_context_peek( state.memory.get(MemoryAddress::new(context, segment, virt)) } +pub(crate) fn fill_channel_with_value(row: &mut CpuColumnsView, n: usize, val: U256) { + let channel = &mut row.mem_channels[n]; + let val_limbs: [u64; 4] = val.0; + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } +} + +/// Pushes without writing in memory. This happens in opcodes where a push immediately follows a pop. +/// The pushed value may be loaded in a memory channel, without creating a memory operation. +pub(crate) fn push_no_write( + state: &mut GenerationState, + row: &mut CpuColumnsView, + val: U256, + channel_opt: Option, +) { + state.registers.stack_top = val; + state.registers.stack_len += 1; + + if let Some(channel) = channel_opt { + let val_limbs: [u64; 4] = val.0; + + let channel = &mut row.mem_channels[channel]; + assert_eq!(channel.used, F::ZERO); + channel.used = F::ZERO; + channel.is_read = F::ZERO; + channel.addr_context = F::from_canonical_usize(0); + channel.addr_segment = F::from_canonical_usize(0); + channel.addr_virtual = F::from_canonical_usize(0); + for (i, limb) in val_limbs.into_iter().enumerate() { + channel.value[2 * i] = F::from_canonical_u32(limb as u32); + channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32); + } + } +} + +/// Pushes and (maybe) writes the previous stack top in memory. This happens in opcodes which only push. 
+pub(crate) fn push_with_write( + state: &mut GenerationState, + row: &mut CpuColumnsView, + val: U256, +) -> Result<(), ProgramError> { + if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { + return Err(ProgramError::StackOverflow); + } + + let write = if state.registers.stack_len == 0 { + None + } else { + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1, + ); + let res = mem_write_gp_log_and_fill( + NUM_GP_CHANNELS - 1, + address, + state, + row, + state.registers.stack_top, + ); + Some(res) + }; + push_no_write(state, row, val, None); + if let Some(log) = write { + state.traces.push_memory(log); + } + Ok(()) +} + pub(crate) fn mem_read_with_log( channel: MemoryChannel, address: MemoryAddress, @@ -146,6 +222,9 @@ pub(crate) fn mem_write_gp_log_and_fill( op } +// Channel 0 already contains the top of the stack. You only need to read +// from the second popped element. +// If the resulting stack isn't empty, update `stack_top`. pub(crate) fn stack_pop_with_log_and_fill( state: &mut GenerationState, row: &mut CpuColumnsView, @@ -154,39 +233,33 @@ pub(crate) fn stack_pop_with_log_and_fill( return Err(ProgramError::StackUnderflow); } + let new_stack_top = if state.registers.stack_len == N { + None + } else { + Some(stack_peek(state, N)?) + }; + let result = core::array::from_fn(|i| { - let address = MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len - 1 - i, - ); - mem_read_gp_with_log_and_fill(i, address, state, row) + if i == 0 { + (state.registers.stack_top, DUMMY_MEMOP) + } else { + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1 - i, + ); + + mem_read_gp_with_log_and_fill(i, address, state, row) + } }); state.registers.stack_len -= N; - Ok(result) -} - -pub(crate) fn stack_push_log_and_fill( - state: &mut GenerationState, - row: &mut CpuColumnsView, - val: U256, -) -> Result { - if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE { - return Err(ProgramError::StackOverflow); + if let Some(val) = new_stack_top { + state.registers.stack_top = val; } - let address = MemoryAddress::new( - state.registers.context, - Segment::Stack, - state.registers.stack_len, - ); - let res = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 1, address, state, row, val); - - state.registers.stack_len += 1; - - Ok(res) + Ok(result) } fn xor_into_sponge(