From 4b40bc0313cf8bbf7c92a7e70adcc0641f1cb7c4 Mon Sep 17 00:00:00 2001 From: Hamy Ratoanina Date: Mon, 30 Oct 2023 12:56:11 -0400 Subject: [PATCH] Remerge context flags (#1292) * Remerge context flags * Apply comments and revert some unwanted changes --- evm/src/cpu/columns/ops.rs | 5 +- evm/src/cpu/contextops.rs | 252 +++++++++++++++++++++++----------- evm/src/cpu/control_flow.rs | 5 +- evm/src/cpu/decode.rs | 5 +- evm/src/cpu/gas.rs | 3 +- evm/src/cpu/memio.rs | 2 +- evm/src/cpu/stack.rs | 7 +- evm/src/witness/operation.rs | 20 ++- evm/src/witness/transition.rs | 6 +- 9 files changed, 199 insertions(+), 106 deletions(-) diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 270b0ab8..b51c1a9e 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -22,9 +22,8 @@ pub struct OpsColumnsView { pub jumpdest: T, pub push0: T, pub push: T, - pub dup_swap: T, - pub get_context: T, - pub set_context: T, + pub dup_swap: T, // Combines DUP and SWAP flags. + pub context_op: T, // Combines GET_CONTEXT and SET_CONTEXT flags. pub mstore_32bytes: T, pub mload_32bytes: T, pub exit_kernel: T, diff --git a/evm/src/cpu/contextops.rs b/evm/src/cpu/contextops.rs index 1683c30e..edc07e7a 100644 --- a/evm/src/cpu/contextops.rs +++ b/evm/src/cpu/contextops.rs @@ -5,6 +5,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; +use super::membus::NUM_GP_CHANNELS; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; @@ -15,12 +16,25 @@ fn eval_packed_get( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let filter = lv.op.get_context; + // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0. + let filter = lv.op.context_op * (P::ONES - lv.opcode_bits[0]); let new_stack_top = nv.mem_channels[0].value; yield_constr.constraint(filter * (new_stack_top[0] - lv.context)); for &limb in &new_stack_top[1..] { yield_constr.constraint(filter * limb); } + + // Constrain new stack length. + yield_constr.constraint(filter * (nv.stack_len - (lv.stack_len + P::ONES))); + + // Unused channels. + for i in 1..NUM_GP_CHANNELS { + if i != 3 { + let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * channel.used); + } + } + yield_constr.constraint(filter * nv.mem_channels[0].used); } fn eval_ext_circuit_get, const D: usize>( @@ -29,7 +43,9 @@ fn eval_ext_circuit_get, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.get_context; + // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0. + let prod = builder.mul_extension(lv.op.context_op, lv.opcode_bits[0]); + let filter = builder.sub_extension(lv.op.context_op, prod); let new_stack_top = nv.mem_channels[0].value; { let diff = builder.sub_extension(new_stack_top[0], lv.context); @@ -40,6 +56,27 @@ fn eval_ext_circuit_get, const D: usize>( let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } + + // Constrain new stack length. + { + let new_len = builder.add_const_extension(lv.stack_len, F::ONE); + let diff = builder.sub_extension(nv.stack_len, new_len); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + + // Unused channels. + for i in 1..NUM_GP_CHANNELS { + if i != 3 { + let channel = lv.mem_channels[i]; + let constr = builder.mul_extension(filter, channel.used); + yield_constr.constraint(builder, constr); + } + } + { + let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + yield_constr.constraint(builder, constr); + } } fn eval_packed_set( @@ -47,7 +84,7 @@ fn eval_packed_set( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let filter = lv.op.set_context; + let filter = lv.op.context_op * lv.opcode_bits[0]; let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; @@ -77,34 +114,29 @@ fn eval_packed_set( yield_constr.constraint(filter * (read_new_sp_channel.addr_segment - ctx_metadata_segment)); yield_constr.constraint(filter * (read_new_sp_channel.addr_virtual - stack_size_field)); - // The next row's stack top is loaded from memory (if the stack isn't empty). - yield_constr.constraint(filter * nv.mem_channels[0].used); - - let read_new_stack_top_channel = lv.mem_channels[3]; - let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64); - let new_filter = filter * nv.stack_len; - - for (limb_channel, limb_top) in read_new_stack_top_channel - .value - .iter() - .zip(nv.mem_channels[0].value) - { - yield_constr.constraint(new_filter * (*limb_channel - limb_top)); + // Constrain stack_inv_aux_2. + let new_top_channel = nv.mem_channels[0]; + yield_constr.constraint( + lv.op.context_op + * (lv.general.stack().stack_inv_aux * lv.opcode_bits[0] + - lv.general.stack().stack_inv_aux_2), + ); + // The new top is loaded in memory channel 3, if the stack isn't empty (see eval_packed). + yield_constr.constraint( + lv.op.context_op + * lv.general.stack().stack_inv_aux_2 + * (lv.mem_channels[3].value[0] - new_top_channel.value[0]), + ); + for &limb in &new_top_channel.value[1..] { + yield_constr.constraint(lv.op.context_op * lv.general.stack().stack_inv_aux_2 * limb); } - yield_constr.constraint(new_filter * (read_new_stack_top_channel.used - P::ONES)); - yield_constr.constraint(new_filter * (read_new_stack_top_channel.is_read - P::ONES)); - yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_context - nv.context)); - yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_segment - stack_segment)); - yield_constr.constraint( - new_filter * (read_new_stack_top_channel.addr_virtual - (nv.stack_len - P::ONES)), - ); - // If the new stack is empty, disable the channel read. - yield_constr.constraint( - filter * (nv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), - ); - let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES); - yield_constr.constraint(empty_stack_filter * read_new_stack_top_channel.used); + // Unused channels. + for i in 4..NUM_GP_CHANNELS { + let channel = lv.mem_channels[i]; + yield_constr.constraint(filter * channel.used); + } + yield_constr.constraint(filter * new_top_channel.used); } fn eval_ext_circuit_set, const D: usize>( @@ -113,7 +145,7 @@ fn eval_ext_circuit_set, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.set_context; + let filter = builder.mul_extension(lv.op.context_op, lv.opcode_bits[0]); let stack_top = lv.mem_channels[0].value; let write_old_sp_channel = lv.mem_channels[1]; let read_new_sp_channel = lv.mem_channels[2]; @@ -197,66 +229,38 @@ fn eval_ext_circuit_set, const D: usize>( yield_constr.constraint(builder, constr); } - // The next row's stack top is loaded from memory (if the stack isn't empty). + // Constrain stack_inv_aux_2. + let new_top_channel = nv.mem_channels[0]; { - let constr = builder.mul_extension(filter, nv.mem_channels[0].used); + let diff = builder.mul_sub_extension( + lv.general.stack().stack_inv_aux, + lv.opcode_bits[0], + lv.general.stack().stack_inv_aux_2, + ); + let constr = builder.mul_extension(lv.op.context_op, diff); + yield_constr.constraint(builder, constr); + } + // The new top is loaded in memory channel 3, if the stack isn't empty (see eval_packed). + { + let diff = builder.sub_extension(lv.mem_channels[3].value[0], new_top_channel.value[0]); + let prod = builder.mul_extension(lv.general.stack().stack_inv_aux_2, diff); + let constr = builder.mul_extension(lv.op.context_op, prod); + yield_constr.constraint(builder, constr); + } + for &limb in &new_top_channel.value[1..] { + let prod = builder.mul_extension(lv.general.stack().stack_inv_aux_2, limb); + let constr = builder.mul_extension(lv.op.context_op, prod); yield_constr.constraint(builder, constr); } - let read_new_stack_top_channel = lv.mem_channels[3]; - let stack_segment = - builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32)); - - let new_filter = builder.mul_extension(filter, nv.stack_len); - - for (limb_channel, limb_top) in read_new_stack_top_channel - .value - .iter() - .zip(nv.mem_channels[0].value) - { - let diff = builder.sub_extension(*limb_channel, limb_top); - let constr = builder.mul_extension(new_filter, diff); + // Unused channels. + for i in 4..NUM_GP_CHANNELS { + let channel = lv.mem_channels[i]; + let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } { - let constr = - builder.mul_sub_extension(new_filter, read_new_stack_top_channel.used, new_filter); - yield_constr.constraint(builder, constr); - } - { - let constr = - builder.mul_sub_extension(new_filter, read_new_stack_top_channel.is_read, new_filter); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(read_new_stack_top_channel.addr_context, nv.context); - let constr = builder.mul_extension(new_filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(read_new_stack_top_channel.addr_segment, stack_segment); - let constr = builder.mul_extension(new_filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(nv.stack_len, one); - let diff = builder.sub_extension(read_new_stack_top_channel.addr_virtual, diff); - let constr = builder.mul_extension(new_filter, diff); - yield_constr.constraint(builder, constr); - } - - // If the new stack is empty, disable the channel read. - { - let diff = builder.mul_extension(nv.stack_len, lv.general.stack().stack_inv); - let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - - { - let empty_stack_filter = - builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter); - let constr = builder.mul_extension(empty_stack_filter, read_new_stack_top_channel.used); + let constr = builder.mul_extension(filter, new_top_channel.used); yield_constr.constraint(builder, constr); } } @@ -268,6 +272,33 @@ pub fn eval_packed( ) { eval_packed_get(lv, nv, yield_constr); eval_packed_set(lv, nv, yield_constr); + + // Stack constraints. + // Both operations use memory channel 3. The operations are similar enough that + // we can constrain both at the same time. + let filter = lv.op.context_op; + let channel = lv.mem_channels[3]; + // For get_context, we check if lv.stack_len is 0. For set_context, we check if nv.stack_len is 0. + // However, for get_context, we can deduce lv.stack_len from nv.stack_len since the operation only pushes. + let stack_len = nv.stack_len - (P::ONES - lv.opcode_bits[0]); + // Constrain stack_inv_aux. It's 0 if the relevant stack is empty, 1 otherwise. + yield_constr.constraint( + filter * (stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), + ); + // Enable or disable the channel. + yield_constr.constraint(filter * (lv.general.stack().stack_inv_aux - channel.used)); + let new_filter = filter * lv.general.stack().stack_inv_aux; + // It's a write for get_context, a read for set_context. + yield_constr.constraint(new_filter * (channel.is_read - lv.opcode_bits[0])); + // In both cases, next row's context works. + yield_constr.constraint(new_filter * (channel.addr_context - nv.context)); + // Same segment for both. + yield_constr.constraint( + new_filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), + ); + // The address is one less than stack_len. + let addr_virtual = stack_len - P::ONES; + yield_constr.constraint(new_filter * (channel.addr_virtual - addr_virtual)); } pub fn eval_ext_circuit, const D: usize>( @@ -278,4 +309,59 @@ pub fn eval_ext_circuit, const D: usize>( ) { eval_ext_circuit_get(builder, lv, nv, yield_constr); eval_ext_circuit_set(builder, lv, nv, yield_constr); + + // Stack constraints. + // Both operations use memory channel 3. The operations are similar enough that + // we can constrain both at the same time. + let filter = lv.op.context_op; + let channel = lv.mem_channels[3]; + // For get_context, we check if lv.stack_len is 0. For set_context, we check if nv.stack_len is 0. + // However, for get_context, we can deduce lv.stack_len from nv.stack_len since the operation only pushes. + let diff = builder.add_const_extension(lv.opcode_bits[0], -F::ONE); + let stack_len = builder.add_extension(nv.stack_len, diff); + // Constrain stack_inv_aux. It's 0 if the relevant stack is empty, 1 otherwise. + { + let diff = builder.mul_sub_extension( + stack_len, + lv.general.stack().stack_inv, + lv.general.stack().stack_inv_aux, + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + // Enable or disable the channel. + { + let diff = builder.sub_extension(lv.general.stack().stack_inv_aux, channel.used); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + } + let new_filter = builder.mul_extension(filter, lv.general.stack().stack_inv_aux); + // It's a write for get_context, a read for set_context. + { + let diff = builder.sub_extension(channel.is_read, lv.opcode_bits[0]); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // In both cases, next row's context works. + { + let diff = builder.sub_extension(channel.addr_context, nv.context); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // Same segment for both. + { + let diff = builder.add_const_extension( + channel.addr_segment, + -F::from_canonical_u64(Segment::Stack as u64), + ); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } + // The address is one less than stack_len. + { + let addr_virtual = builder.add_const_extension(stack_len, -F::ONE); + let diff = builder.sub_extension(channel.addr_virtual, addr_virtual); + let constr = builder.mul_extension(new_filter, diff); + yield_constr.constraint(builder, constr); + } } diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 2f496b51..6bc0562f 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 17] = [ +const NATIVE_INSTRUCTIONS: [usize; 16] = [ COL_MAP.op.binary_op, COL_MAP.op.ternary_op, COL_MAP.op.fp254_op, @@ -25,8 +25,7 @@ const NATIVE_INSTRUCTIONS: [usize; 17] = [ COL_MAP.op.push0, // not PUSH (need to increment by more than 1) COL_MAP.op.dup_swap, - COL_MAP.op.get_context, - COL_MAP.op.set_context, + COL_MAP.op.context_op, // not EXIT_KERNEL (performs a jump) COL_MAP.op.m_op_general, // not SYSCALL (performs a jump) diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index a4756684..83f05c14 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -23,7 +23,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; /// behavior. /// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to /// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid. -const OPCODES: [(u8, usize, bool, usize); 16] = [ +const OPCODES: [(u8, usize, bool, usize); 15] = [ // (start index of block, number of top bits to check (log2), kernel-only, flag column) // ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags are handled partly manually here, and partly through the Arithmetic table CTL. // ADDMOD, MULMOD and SUBMOD flags are handled partly manually here, and partly through the Arithmetic table CTL. @@ -42,8 +42,7 @@ const OPCODES: [(u8, usize, bool, usize); 16] = [ (0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f (0x80, 5, false, COL_MAP.op.dup_swap), // 0x80-0x9f (0xee, 0, true, COL_MAP.op.mstore_32bytes), - (0xf6, 0, true, COL_MAP.op.get_context), - (0xf7, 0, true, COL_MAP.op.set_context), + (0xf6, 1, true, COL_MAP.op.context_op), //0xf6-0xf7 (0xf8, 0, true, COL_MAP.op.mload_32bytes), (0xf9, 0, true, COL_MAP.op.exit_kernel), // MLOAD_GENERAL and MSTORE_GENERAL flags are handled manually here. diff --git a/evm/src/cpu/gas.rs b/evm/src/cpu/gas.rs index 1a908d6d..c59c32d5 100644 --- a/evm/src/cpu/gas.rs +++ b/evm/src/cpu/gas.rs @@ -35,8 +35,7 @@ const SIMPLE_OPCODES: OpsColumnsView> = OpsColumnsView { push0: G_BASE, push: G_VERYLOW, dup_swap: G_VERYLOW, - get_context: KERNEL_ONLY_INSTR, - set_context: KERNEL_ONLY_INSTR, + context_op: KERNEL_ONLY_INSTR, mstore_32bytes: KERNEL_ONLY_INSTR, mload_32bytes: KERNEL_ONLY_INSTR, exit_kernel: None, diff --git a/evm/src/cpu/memio.rs b/evm/src/cpu/memio.rs index f70f3fdb..c0daa8f1 100644 --- a/evm/src/cpu/memio.rs +++ b/evm/src/cpu/memio.rs @@ -264,7 +264,7 @@ fn eval_ext_circuit_store, const D: usize>( let top_read_channel = nv.mem_channels[0]; let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]); let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read); - // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. + // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * (1 - opcode_bits[0])`. { let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read); let constr = builder.mul_extension(lv.op.m_op_general, diff); diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index bc3d381f..fa8415e4 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -97,12 +97,7 @@ pub(crate) const STACK_BEHAVIORS: OpsColumnsView> = OpsCol }), push: None, // TODO dup_swap: None, - get_context: Some(StackBehavior { - num_pops: 0, - pushes: true, - disable_other_channels: true, - }), - set_context: None, // SET_CONTEXT is special since it involves the old and the new stack. + context_op: None, mload_32bytes: Some(StackBehavior { num_pops: 4, pushes: true, diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index cc7911a9..66ba8460 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -308,7 +308,22 @@ pub(crate) fn generate_get_context( state: &mut GenerationState, mut row: CpuColumnsView, ) -> Result<(), ProgramError> { - push_with_write(state, &mut row, state.registers.context.into())?; + // Same logic as push_with_write, but we have to use channel 3 for stack constraint reasons. + let write = if state.registers.stack_len == 0 { + None + } else { + let address = MemoryAddress::new( + state.registers.context, + Segment::Stack, + state.registers.stack_len - 1, + ); + let res = mem_write_gp_log_and_fill(3, address, state, &mut row, state.registers.stack_top); + Some(res) + }; + push_no_write(state, state.registers.context.into()); + if let Some(log) = write { + state.traces.push_memory(log); + } state.traces.push_cpu(row); Ok(()) } @@ -364,9 +379,11 @@ pub(crate) fn generate_set_context( if let Some(inv) = new_sp_field.try_inverse() { row.general.stack_mut().stack_inv = inv; row.general.stack_mut().stack_inv_aux = F::ONE; + row.general.stack_mut().stack_inv_aux_2 = F::ONE; } else { row.general.stack_mut().stack_inv = F::ZERO; row.general.stack_mut().stack_inv_aux = F::ZERO; + row.general.stack_mut().stack_inv_aux_2 = F::ZERO; } let new_top_addr = MemoryAddress::new(new_ctx, Segment::Stack, new_sp - 1); @@ -833,6 +850,7 @@ pub(crate) fn generate_mstore_general( state.traces.push_memory(log_in2); state.traces.push_memory(log_in3); state.traces.push_memory(log_write); + state.traces.push_cpu(row); Ok(()) diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 312b8591..1f9752d4 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -180,8 +180,7 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::Jump | Operation::Jumpi => &mut flags.jumps, Operation::Pc => &mut flags.pc, Operation::Jumpdest => &mut flags.jumpdest, - Operation::GetContext => &mut flags.get_context, - Operation::SetContext => &mut flags.set_context, + Operation::GetContext | Operation::SetContext => &mut flags.context_op, Operation::Mload32Bytes => &mut flags.mload_32bytes, Operation::Mstore32Bytes => &mut flags.mstore_32bytes, Operation::ExitKernel => &mut flags.exit_kernel, @@ -216,8 +215,7 @@ fn get_op_special_length(op: Operation) -> Option { Operation::Jumpi => JUMPI_OP, Operation::Pc => STACK_BEHAVIORS.pc, Operation::Jumpdest => STACK_BEHAVIORS.jumpdest, - Operation::GetContext => STACK_BEHAVIORS.get_context, - Operation::SetContext => None, + Operation::GetContext | Operation::SetContext => None, Operation::Mload32Bytes => STACK_BEHAVIORS.mload_32bytes, Operation::Mstore32Bytes => STACK_BEHAVIORS.mstore_32bytes, Operation::ExitKernel => STACK_BEHAVIORS.exit_kernel,