From 16227f90b9beb15249a6e5f82895717169533065 Mon Sep 17 00:00:00 2001 From: Linda Guiga Date: Fri, 21 Jul 2023 10:55:09 +0100 Subject: [PATCH] Merge syscall and exceptions constraints. --- evm/src/cpu/cpu_stark.rs | 10 +- evm/src/cpu/mod.rs | 3 +- evm/src/cpu/syscalls.rs | 222 ------------------ .../{exceptions.rs => syscalls_exceptions.rs} | 173 ++++++++++---- 4 files changed, 127 insertions(+), 281 deletions(-) delete mode 100644 evm/src/cpu/syscalls.rs rename evm/src/cpu/{exceptions.rs => syscalls_exceptions.rs} (52%) diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index e4224bad..8a345692 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -13,8 +13,8 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::{ - bootstrap_kernel, contextops, control_flow, decode, dup_swap, exceptions, gas, jumps, membus, - memio, modfp254, pc, push0, shift, simple_logic, stack, stack_bounds, syscalls, + bootstrap_kernel, contextops, control_flow, decode, dup_swap, gas, jumps, membus, memio, + modfp254, pc, push0, shift, simple_logic, stack, stack_bounds, syscalls_exceptions, }; use crate::cross_table_lookup::{Column, TableWithColumns}; use crate::memory::segments::Segment; @@ -190,7 +190,6 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark usize { diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index 411ddb76..91b04cf4 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -5,7 +5,6 @@ pub(crate) mod control_flow; pub mod cpu_stark; pub(crate) mod decode; mod dup_swap; -mod exceptions; mod gas; mod jumps; pub mod kernel; @@ -18,4 +17,4 @@ mod shift; pub(crate) mod simple_logic; mod stack; pub(crate) mod stack_bounds; -mod syscalls; +mod syscalls_exceptions; diff --git a/evm/src/cpu/syscalls.rs b/evm/src/cpu/syscalls.rs deleted file mode 100644 index 0686bf48..00000000 --- a/evm/src/cpu/syscalls.rs +++ /dev/null @@ -1,222 +0,0 @@ -//! Handle instructions that are implemented in terms of system calls. -//! -//! These are usually the ones that are too complicated to implement in one CPU table row. - -use plonky2::field::extension::Extendable; -use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use static_assertions::const_assert; - -use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns::CpuColumnsView; -use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::membus::NUM_GP_CHANNELS; -use crate::memory::segments::Segment; - -// Copy the constant but make it `usize`. -const BYTES_PER_OFFSET: usize = crate::cpu::kernel::assembler::BYTES_PER_OFFSET as usize; -const_assert!(BYTES_PER_OFFSET < NUM_GP_CHANNELS); // Reserve one channel for stack push - -pub fn eval_packed( - lv: &CpuColumnsView

, - nv: &CpuColumnsView

, - yield_constr: &mut ConstraintConsumer

, -) { - let filter = lv.op.syscall; - - // Look up the handler in memory - let code_segment = P::Scalar::from_canonical_usize(Segment::Code as usize); - let syscall_jumptable_start = - P::Scalar::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]); - let opcode: P = lv - .opcode_bits - .into_iter() - .enumerate() - .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) - .sum(); - let opcode_handler_addr_start = - syscall_jumptable_start + opcode * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); - for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * (channel.is_read - P::ONES)); - - // Set kernel context and code segment - yield_constr.constraint(filter * channel.addr_context); - yield_constr.constraint(filter * (channel.addr_segment - code_segment)); - - // Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs. - let limb_address = opcode_handler_addr_start + P::Scalar::from_canonical_usize(i); - yield_constr.constraint(filter * (channel.addr_virtual - limb_address)); - } - - // Disable unused channels (the last channel is used to push to the stack) - for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { - yield_constr.constraint(filter * channel.used); - } - - // Set program counter to the handler address - // The addresses are big-endian in memory - let target = lv.mem_channels[0..BYTES_PER_OFFSET] - .iter() - .map(|channel| channel.value[0]) - .fold(P::ZEROS, |cumul, limb| { - cumul * P::Scalar::from_canonical_u64(256) + limb - }); - yield_constr.constraint_transition(filter * (nv.program_counter - target)); - // Set kernel mode - yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES)); - // Maintain current context - yield_constr.constraint_transition(filter * (nv.context - lv.context)); - // Reset gas counter to zero. - yield_constr.constraint_transition(filter * nv.gas); - - // This memory channel is constrained in `stack.rs`. - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - // Push to stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). - yield_constr.constraint(filter * (output[0] - (lv.program_counter + P::ONES))); - yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode)); - yield_constr.constraint(filter * (output[6] - lv.gas)); - // TODO: Range check `output[6]`. - yield_constr.constraint(filter * output[7]); // High limb of gas is zero. - - // Zero the rest of that register - for &limb in &output[2..6] { - yield_constr.constraint(filter * limb); - } -} - -pub fn eval_ext_circuit, const D: usize>( - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - lv: &CpuColumnsView>, - nv: &CpuColumnsView>, - yield_constr: &mut RecursiveConstraintConsumer, -) { - let filter = lv.op.syscall; - - // Look up the handler in memory - let code_segment = F::from_canonical_usize(Segment::Code as usize); - let syscall_jumptable_start = builder.constant_extension( - F::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]).into(), - ); - let opcode = lv - .opcode_bits - .into_iter() - .rev() - .fold(builder.zero_extension(), |cumul, bit| { - builder.mul_const_add_extension(F::TWO, cumul, bit) - }); - let opcode_handler_addr_start = builder.mul_const_add_extension( - F::from_canonical_usize(BYTES_PER_OFFSET), - opcode, - syscall_jumptable_start, - ); - for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { - { - let constr = builder.mul_sub_extension(filter, channel.used, filter); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_sub_extension(filter, channel.is_read, filter); - yield_constr.constraint(builder, constr); - } - - // Set kernel context and code segment - { - let constr = builder.mul_extension(filter, channel.addr_context); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.arithmetic_extension( - F::ONE, - -code_segment, - filter, - channel.addr_segment, - filter, - ); - yield_constr.constraint(builder, constr); - } - - // Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs. - { - let diff = builder.sub_extension(channel.addr_virtual, opcode_handler_addr_start); - let constr = builder.arithmetic_extension( - F::ONE, - -F::from_canonical_usize(i), - filter, - diff, - filter, - ); - yield_constr.constraint(builder, constr); - } - } - - // Disable unused channels (the last channel is used to push to the stack) - for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { - let constr = builder.mul_extension(filter, channel.used); - yield_constr.constraint(builder, constr); - } - - // Set program counter to the handler address - // The addresses are big-endian in memory - { - let target = lv.mem_channels[0..BYTES_PER_OFFSET] - .iter() - .map(|channel| channel.value[0]) - .fold(builder.zero_extension(), |cumul, limb| { - builder.mul_const_add_extension(F::from_canonical_u64(256), cumul, limb) - }); - let diff = builder.sub_extension(nv.program_counter, target); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint_transition(builder, constr); - } - // Set kernel mode - { - let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); - yield_constr.constraint_transition(builder, constr); - } - // Maintain current context - { - let diff = builder.sub_extension(nv.context, lv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint_transition(builder, constr); - } - // Reset gas counter to zero. - { - let constr = builder.mul_extension(filter, nv.gas); - yield_constr.constraint_transition(builder, constr); - } - - // This memory channel is constrained in `stack.rs`. - let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - // Push to stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). - { - let pc_plus_1 = builder.add_const_extension(lv.program_counter, F::ONE); - let diff = builder.sub_extension(output[0], pc_plus_1); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(output[1], lv.is_kernel_mode); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(output[6], lv.gas); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - // TODO: Range check `output[6]`. - { - // High limb of gas is zero. - let constr = builder.mul_extension(filter, output[7]); - yield_constr.constraint(builder, constr); - } - - // Zero the rest of that register - for &limb in &output[2..6] { - let constr = builder.mul_extension(filter, limb); - yield_constr.constraint(builder, constr); - } -} diff --git a/evm/src/cpu/exceptions.rs b/evm/src/cpu/syscalls_exceptions.rs similarity index 52% rename from evm/src/cpu/exceptions.rs rename to evm/src/cpu/syscalls_exceptions.rs index 485f6888..abc47baf 100644 --- a/evm/src/cpu/exceptions.rs +++ b/evm/src/cpu/syscalls_exceptions.rs @@ -24,12 +24,12 @@ pub fn eval_packed( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - // TODO: There's heaps of overlap between this and syscalls. They could be merged. + let filter_syscall = lv.op.syscall; + let filter_exception = lv.op.exception; + let total_filter = filter_syscall + filter_exception; - let filter = lv.op.exception; - - // Ensure we are not in kernel mode - yield_constr.constraint(filter * lv.is_kernel_mode); + // If exception, ensure we are not in kernel mode + yield_constr.constraint(filter_exception * lv.is_kernel_mode); // Get the exception code as an value in {0, ..., 7}. let exc_code_bits = lv.general.exception().exc_code_bits; @@ -40,31 +40,49 @@ pub fn eval_packed( .sum(); // Ensure that all bits are either 0 or 1. for bit in exc_code_bits { - yield_constr.constraint(filter * bit * (bit - P::ONES)); + yield_constr.constraint(filter_exception * bit * (bit - P::ONES)); } // Look up the handler in memory let code_segment = P::Scalar::from_canonical_usize(Segment::Code as usize); + + let opcode: P = lv + .opcode_bits + .into_iter() + .enumerate() + .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) + .sum(); + + // Syscall handler + let syscall_jumptable_start = + P::Scalar::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]); + let opcode_handler_addr_start = + syscall_jumptable_start + opcode * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); + // Exceptions handler let exc_jumptable_start = P::Scalar::from_canonical_usize(KERNEL.global_labels["exception_jumptable"]); let exc_handler_addr_start = exc_jumptable_start + exc_code * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); + for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * (channel.is_read - P::ONES)); + yield_constr.constraint(total_filter * (channel.used - P::ONES)); + yield_constr.constraint(total_filter * (channel.is_read - P::ONES)); // Set kernel context and code segment - yield_constr.constraint(filter * channel.addr_context); - yield_constr.constraint(filter * (channel.addr_segment - code_segment)); + yield_constr.constraint(total_filter * channel.addr_context); + yield_constr.constraint(total_filter * (channel.addr_segment - code_segment)); // Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs. - let limb_address = exc_handler_addr_start + P::Scalar::from_canonical_usize(i); - yield_constr.constraint(filter * (channel.addr_virtual - limb_address)); + let limb_address_syscall = opcode_handler_addr_start + P::Scalar::from_canonical_usize(i); + let limb_address_exception = exc_handler_addr_start + P::Scalar::from_canonical_usize(i); + + yield_constr.constraint(filter_syscall * (channel.addr_virtual - limb_address_syscall)); + yield_constr.constraint(filter_exception * (channel.addr_virtual - limb_address_exception)); } // Disable unused channels (the last channel is used to push to the stack) for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { - yield_constr.constraint(filter * channel.used); + yield_constr.constraint(total_filter * channel.used); } // Set program counter to the handler address @@ -75,25 +93,30 @@ pub fn eval_packed( .fold(P::ZEROS, |cumul, limb| { cumul * P::Scalar::from_canonical_u64(256) + limb }); - yield_constr.constraint_transition(filter * (nv.program_counter - target)); + yield_constr.constraint_transition(total_filter * (nv.program_counter - target)); // Set kernel mode - yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES)); + yield_constr.constraint_transition(total_filter * (nv.is_kernel_mode - P::ONES)); // Maintain current context - yield_constr.constraint_transition(filter * (nv.context - lv.context)); + yield_constr.constraint_transition(total_filter * (nv.context - lv.context)); // Reset gas counter to zero. - yield_constr.constraint_transition(filter * nv.gas); + yield_constr.constraint_transition(total_filter * nv.gas); // This memory channel is constrained in `stack.rs`. let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - // Push to stack: current PC (limb 0), gas counter (limbs 6 and 7). - yield_constr.constraint(filter * (output[0] - lv.program_counter)); - yield_constr.constraint(filter * (output[6] - lv.gas)); + // Push to stack: current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). + yield_constr.constraint(filter_syscall * (output[0] - (lv.program_counter + P::ONES))); + yield_constr.constraint(filter_exception * (output[0] - lv.program_counter)); + // Check the kernel mode, for syscalls only + yield_constr.constraint(filter_syscall * (output[1] - lv.is_kernel_mode)); + yield_constr.constraint(total_filter * (output[6] - lv.gas)); // TODO: Range check `output[6]`. - yield_constr.constraint(filter * output[7]); // High limb of gas is zero. + yield_constr.constraint(total_filter * output[7]); // High limb of gas is zero. // Zero the rest of that register - for &limb in &output[1..6] { - yield_constr.constraint(filter * limb); + // output[1] is 0 for exceptions, but not for syscalls + yield_constr.constraint(filter_exception * output[1]); + for &limb in &output[2..6] { + yield_constr.constraint(total_filter * limb); } } @@ -103,15 +126,14 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let filter = lv.op.exception; + let filter_syscall = lv.op.syscall; + let filter_exception = lv.op.exception; + let total_filter = builder.add_extension(filter_syscall, filter_exception); - // Ensure we are not in kernel mode - { - let constr = builder.mul_extension(filter, lv.is_kernel_mode); - yield_constr.constraint(builder, constr); - } + // Ensure that, if exception, we are not in kernel mode + let constr = builder.mul_extension(filter_exception, lv.is_kernel_mode); + yield_constr.constraint(builder, constr); - // Get the exception code as an value in {0, ..., 7}. let exc_code_bits = lv.general.exception().exc_code_bits; let exc_code = exc_code_bits @@ -120,15 +142,36 @@ pub fn eval_ext_circuit, const D: usize>( .fold(builder.zero_extension(), |cumul, (i, bit)| { builder.mul_const_add_extension(F::from_canonical_u64(1 << i), bit, cumul) }); + // Ensure that all bits are either 0 or 1. for bit in exc_code_bits { let constr = builder.mul_sub_extension(bit, bit, bit); - let constr = builder.mul_extension(filter, constr); + let constr = builder.mul_extension(filter_exception, constr); yield_constr.constraint(builder, constr); } // Look up the handler in memory let code_segment = F::from_canonical_usize(Segment::Code as usize); + + let opcode = lv + .opcode_bits + .into_iter() + .rev() + .fold(builder.zero_extension(), |cumul, bit| { + builder.mul_const_add_extension(F::TWO, cumul, bit) + }); + + // Syscall handler + let syscall_jumptable_start = builder.constant_extension( + F::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]).into(), + ); + let opcode_handler_addr_start = builder.mul_const_add_extension( + F::from_canonical_usize(BYTES_PER_OFFSET), + opcode, + syscall_jumptable_start, + ); + + // Exceptions handler let exc_jumptable_start = builder.constant_extension( F::from_canonical_usize(KERNEL.global_labels["exception_jumptable"]).into(), ); @@ -137,41 +180,54 @@ pub fn eval_ext_circuit, const D: usize>( exc_code, exc_jumptable_start, ); + for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { { - let constr = builder.mul_sub_extension(filter, channel.used, filter); + let constr = builder.mul_sub_extension(total_filter, channel.used, total_filter); yield_constr.constraint(builder, constr); } { - let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + let constr = builder.mul_sub_extension(total_filter, channel.is_read, total_filter); yield_constr.constraint(builder, constr); } // Set kernel context and code segment { - let constr = builder.mul_extension(filter, channel.addr_context); + let constr = builder.mul_extension(total_filter, channel.addr_context); yield_constr.constraint(builder, constr); } { let constr = builder.arithmetic_extension( F::ONE, -code_segment, - filter, + total_filter, channel.addr_segment, - filter, + total_filter, ); yield_constr.constraint(builder, constr); } // Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs. { - let diff = builder.sub_extension(channel.addr_virtual, exc_handler_addr_start); + let diff_syscall = + builder.sub_extension(channel.addr_virtual, opcode_handler_addr_start); let constr = builder.arithmetic_extension( F::ONE, -F::from_canonical_usize(i), - filter, - diff, - filter, + filter_syscall, + diff_syscall, + filter_syscall, + ); + yield_constr.constraint(builder, constr); + + let diff_exception = + builder.sub_extension(channel.addr_virtual, exc_handler_addr_start); + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_usize(i), + filter_exception, + diff_exception, + filter_exception, ); yield_constr.constraint(builder, constr); } @@ -179,7 +235,7 @@ pub fn eval_ext_circuit, const D: usize>( // Disable unused channels (the last channel is used to push to the stack) for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { - let constr = builder.mul_extension(filter, channel.used); + let constr = builder.mul_extension(total_filter, channel.used); yield_constr.constraint(builder, constr); } @@ -193,49 +249,64 @@ pub fn eval_ext_circuit, const D: usize>( builder.mul_const_add_extension(F::from_canonical_u64(256), cumul, limb) }); let diff = builder.sub_extension(nv.program_counter, target); - let constr = builder.mul_extension(filter, diff); + let constr = builder.mul_extension(total_filter, diff); yield_constr.constraint_transition(builder, constr); } // Set kernel mode { - let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); + let constr = builder.mul_sub_extension(total_filter, nv.is_kernel_mode, total_filter); yield_constr.constraint_transition(builder, constr); } // Maintain current context { let diff = builder.sub_extension(nv.context, lv.context); - let constr = builder.mul_extension(filter, diff); + let constr = builder.mul_extension(total_filter, diff); yield_constr.constraint_transition(builder, constr); } // Reset gas counter to zero. { - let constr = builder.mul_extension(filter, nv.gas); + let constr = builder.mul_extension(total_filter, nv.gas); yield_constr.constraint_transition(builder, constr); } // This memory channel is constrained in `stack.rs`. let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; - // Push to stack: current PC (limb 0), gas counter (limbs 6 and 7). + // Push to stack (syscall): current PC + 1 (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). + { + let pc_plus_1 = builder.add_const_extension(lv.program_counter, F::ONE); + let diff = builder.sub_extension(output[0], pc_plus_1); + let constr = builder.mul_extension(filter_syscall, diff); + yield_constr.constraint(builder, constr); + } + // Push to stack (exception): current PC (limb 0), kernel flag (limb 1), gas counter (limbs 6 and 7). { let diff = builder.sub_extension(output[0], lv.program_counter); - let constr = builder.mul_extension(filter, diff); + let constr = builder.mul_extension(filter_exception, diff); + yield_constr.constraint(builder, constr); + } + // Push to stack(exception): current PC (limb 0), gas counter (limbs 6 and 7). + { + let diff = builder.sub_extension(output[1], lv.is_kernel_mode); + let constr = builder.mul_extension(filter_syscall, diff); yield_constr.constraint(builder, constr); } { let diff = builder.sub_extension(output[6], lv.gas); - let constr = builder.mul_extension(filter, diff); + let constr = builder.mul_extension(total_filter, diff); yield_constr.constraint(builder, constr); } // TODO: Range check `output[6]`. { // High limb of gas is zero. - let constr = builder.mul_extension(filter, output[7]); + let constr = builder.mul_extension(total_filter, output[7]); yield_constr.constraint(builder, constr); } // Zero the rest of that register - for &limb in &output[1..6] { - let constr = builder.mul_extension(filter, limb); + let constr = builder.mul_extension(filter_exception, output[1]); + yield_constr.constraint(builder, constr); + for &limb in &output[2..6] { + let constr = builder.mul_extension(total_filter, limb); yield_constr.constraint(builder, constr); } }