From 626c2583deef6a8c9cc89c7c0ee4df1dfb79926c Mon Sep 17 00:00:00 2001 From: Jacqueline Nabaglo Date: Mon, 7 Nov 2022 12:29:28 -0800 Subject: [PATCH] Combine all syscalls into one flag (#802) * Combine all syscalls into one flag * Minor: typo * Daniel PR comments --- evm/Cargo.toml | 1 + evm/src/cpu/columns/ops.rs | 67 +----- evm/src/cpu/control_flow.rs | 31 ++- evm/src/cpu/decode.rs | 87 +------- evm/src/cpu/kernel/aggregator.rs | 2 + evm/src/cpu/kernel/asm/core/syscall.asm | 160 ++++++++++++++ evm/src/cpu/kernel/asm/core/syscall_stubs.asm | 53 +++++ evm/src/cpu/kernel/assembler.rs | 37 ++-- evm/src/cpu/kernel/ast.rs | 2 + evm/src/cpu/kernel/evm_asm.pest | 3 +- evm/src/cpu/kernel/parser.rs | 3 + evm/src/cpu/stack.rs | 62 +----- evm/src/cpu/syscalls.rs | 204 ++++++++++++------ 13 files changed, 439 insertions(+), 273 deletions(-) create mode 100644 evm/src/cpu/kernel/asm/core/syscall.asm create mode 100644 evm/src/cpu/kernel/asm/core/syscall_stubs.asm diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 650874b0..5b48f524 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -28,6 +28,7 @@ rlp = "0.5.1" rlp-derive = "0.1.0" serde = { version = "1.0.144", features = ["derive"] } sha2 = "0.10.2" +static_assertions = "1.1.0" tiny-keccak = "2.0.2" [dev-dependencies] diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index 04d4d0f2..c265be44 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -7,106 +7,59 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; #[repr(C)] #[derive(Eq, PartialEq, Debug)] pub struct OpsColumnsView { - pub stop: T, + // TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag pub add: T, pub mul: T, pub sub: T, pub div: T, - pub sdiv: T, pub mod_: T, - pub smod: T, + // TODO: combine ADDMOD, MULMOD into one flag pub addmod: T, pub mulmod: T, - pub exp: T, - pub signextend: T, pub addfp254: T, pub mulfp254: T, pub subfp254: T, pub lt: T, pub gt: T, - pub slt: T, - pub sgt: T, pub eq: T, // Note: This column must be 0 when is_cpu_cycle = 0. pub iszero: T, // Note: This column must be 0 when is_cpu_cycle = 0. + // TODO: combine AND, OR, and XOR into one flag pub and: T, pub or: T, pub xor: T, pub not: T, pub byte: T, + // TODO: combine SHL and SHR into one flag pub shl: T, pub shr: T, - pub sar: T, - pub keccak256: T, pub keccak_general: T, - pub address: T, - pub balance: T, - pub origin: T, - pub caller: T, - pub callvalue: T, - pub calldataload: T, - pub calldatasize: T, - pub calldatacopy: T, - pub codesize: T, - pub codecopy: T, - pub gasprice: T, - pub extcodesize: T, - pub extcodecopy: T, - pub returndatasize: T, - pub returndatacopy: T, - pub extcodehash: T, - pub blockhash: T, - pub coinbase: T, - pub timestamp: T, - pub number: T, - pub difficulty: T, - pub gaslimit: T, - pub chainid: T, - pub selfbalance: T, - pub basefee: T, pub prover_input: T, pub pop: T, - pub mload: T, - pub mstore: T, - pub mstore8: T, - pub sload: T, - pub sstore: T, + // TODO: combine JUMP and JUMPI into one flag pub jump: T, // Note: This column must be 0 when is_cpu_cycle = 0. pub jumpi: T, // Note: This column must be 0 when is_cpu_cycle = 0. pub pc: T, - pub msize: T, pub gas: T, pub jumpdest: T, + // TODO: combine GET_STATE_ROOT and SET_STATE_ROOT into one flag pub get_state_root: T, pub set_state_root: T, + // TODO: combine GET_RECEIPT_ROOT and SET_RECEIPT_ROOT into one flag pub get_receipt_root: T, pub set_receipt_root: T, pub push: T, pub dup: T, pub swap: T, - pub log0: T, - pub log1: T, - pub log2: T, - pub log3: T, - pub log4: T, - // PANIC does not get a flag; it fails at the decode stage. - pub create: T, - pub call: T, - pub callcode: T, - pub return_: T, - pub delegatecall: T, - pub create2: T, + // TODO: combine GET_CONTEXT and SET_CONTEXT into one flag pub get_context: T, pub set_context: T, pub consume_gas: T, pub exit_kernel: T, - pub staticcall: T, + // TODO: combine MLOAD_GENERAL and MSTORE_GENERAL into one flag pub mload_general: T, pub mstore_general: T, - pub revert: T, - pub selfdestruct: T, - // TODO: this doesn't actually need its own flag. We can just do `1 - sum(all other flags)`. - pub invalid: T, + pub syscall: T, } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index c7b7c6bb..ba0bbd3b 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -8,36 +8,49 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; -// TODO: This list is incomplete. -const NATIVE_INSTRUCTIONS: [usize; 28] = [ +const NATIVE_INSTRUCTIONS: [usize; 37] = [ COL_MAP.op.add, COL_MAP.op.mul, COL_MAP.op.sub, COL_MAP.op.div, - COL_MAP.op.sdiv, COL_MAP.op.mod_, - COL_MAP.op.smod, COL_MAP.op.addmod, COL_MAP.op.mulmod, - COL_MAP.op.signextend, COL_MAP.op.addfp254, COL_MAP.op.mulfp254, COL_MAP.op.subfp254, COL_MAP.op.lt, COL_MAP.op.gt, - COL_MAP.op.slt, - COL_MAP.op.sgt, COL_MAP.op.eq, COL_MAP.op.iszero, COL_MAP.op.and, COL_MAP.op.or, COL_MAP.op.xor, COL_MAP.op.not, - COL_MAP.op.byte, COL_MAP.op.shl, COL_MAP.op.shr, - COL_MAP.op.sar, + COL_MAP.op.keccak_general, + COL_MAP.op.prover_input, COL_MAP.op.pop, + // not JUMP (need to jump) + // not JUMPI (possible need to jump) + COL_MAP.op.pc, + COL_MAP.op.gas, + COL_MAP.op.jumpdest, + COL_MAP.op.get_state_root, + COL_MAP.op.set_state_root, + COL_MAP.op.get_receipt_root, + COL_MAP.op.set_receipt_root, + // not PUSH (need to increment by more than 1) + COL_MAP.op.dup, + COL_MAP.op.swap, + COL_MAP.op.get_context, + COL_MAP.op.set_context, + COL_MAP.op.consume_gas, + // not EXIT_KERNEL (performs a jump) + COL_MAP.op.mload_general, + COL_MAP.op.mstore_general, + // not SYSCALL (performs a jump) ]; fn get_halt_pcs() -> (F, F) { diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index 8d7cf3f6..feb672d0 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -22,27 +22,20 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; /// behavior. /// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to /// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid. -const OPCODES: [(u8, usize, bool, usize); 96] = [ +const OPCODES: [(u8, usize, bool, usize); 42] = [ // (start index of block, number of top bits to check (log2), kernel-only, flag column) - (0x00, 0, false, COL_MAP.op.stop), (0x01, 0, false, COL_MAP.op.add), (0x02, 0, false, COL_MAP.op.mul), (0x03, 0, false, COL_MAP.op.sub), (0x04, 0, false, COL_MAP.op.div), - (0x05, 0, false, COL_MAP.op.sdiv), (0x06, 0, false, COL_MAP.op.mod_), - (0x07, 0, false, COL_MAP.op.smod), (0x08, 0, false, COL_MAP.op.addmod), (0x09, 0, false, COL_MAP.op.mulmod), - (0x0a, 0, false, COL_MAP.op.exp), - (0x0b, 0, false, COL_MAP.op.signextend), (0x0c, 0, true, COL_MAP.op.addfp254), (0x0d, 0, true, COL_MAP.op.mulfp254), (0x0e, 0, true, COL_MAP.op.subfp254), (0x10, 0, false, COL_MAP.op.lt), (0x11, 0, false, COL_MAP.op.gt), - (0x12, 0, false, COL_MAP.op.slt), - (0x13, 0, false, COL_MAP.op.sgt), (0x14, 0, false, COL_MAP.op.eq), (0x15, 0, false, COL_MAP.op.iszero), (0x16, 0, false, COL_MAP.op.and), @@ -52,45 +45,12 @@ const OPCODES: [(u8, usize, bool, usize); 96] = [ (0x1a, 0, false, COL_MAP.op.byte), (0x1b, 0, false, COL_MAP.op.shl), (0x1c, 0, false, COL_MAP.op.shr), - (0x1d, 0, false, COL_MAP.op.sar), - (0x20, 0, false, COL_MAP.op.keccak256), (0x21, 0, true, COL_MAP.op.keccak_general), - (0x30, 0, false, COL_MAP.op.address), - (0x31, 0, false, COL_MAP.op.balance), - (0x32, 0, false, COL_MAP.op.origin), - (0x33, 0, false, COL_MAP.op.caller), - (0x34, 0, false, COL_MAP.op.callvalue), - (0x35, 0, false, COL_MAP.op.calldataload), - (0x36, 0, false, COL_MAP.op.calldatasize), - (0x37, 0, false, COL_MAP.op.calldatacopy), - (0x38, 0, false, COL_MAP.op.codesize), - (0x39, 0, false, COL_MAP.op.codecopy), - (0x3a, 0, false, COL_MAP.op.gasprice), - (0x3b, 0, false, COL_MAP.op.extcodesize), - (0x3c, 0, false, COL_MAP.op.extcodecopy), - (0x3d, 0, false, COL_MAP.op.returndatasize), - (0x3e, 0, false, COL_MAP.op.returndatacopy), - (0x3f, 0, false, COL_MAP.op.extcodehash), - (0x40, 0, false, COL_MAP.op.blockhash), - (0x41, 0, false, COL_MAP.op.coinbase), - (0x42, 0, false, COL_MAP.op.timestamp), - (0x43, 0, false, COL_MAP.op.number), - (0x44, 0, false, COL_MAP.op.difficulty), - (0x45, 0, false, COL_MAP.op.gaslimit), - (0x46, 0, false, COL_MAP.op.chainid), - (0x47, 0, false, COL_MAP.op.selfbalance), - (0x48, 0, false, COL_MAP.op.basefee), (0x49, 0, true, COL_MAP.op.prover_input), (0x50, 0, false, COL_MAP.op.pop), - (0x51, 0, false, COL_MAP.op.mload), - (0x52, 0, false, COL_MAP.op.mstore), - (0x53, 0, false, COL_MAP.op.mstore8), - (0x54, 0, false, COL_MAP.op.sload), - (0x55, 0, false, COL_MAP.op.sstore), (0x56, 0, false, COL_MAP.op.jump), (0x57, 0, false, COL_MAP.op.jumpi), (0x58, 0, false, COL_MAP.op.pc), - (0x59, 0, false, COL_MAP.op.msize), (0x5a, 0, false, COL_MAP.op.gas), (0x5b, 0, false, COL_MAP.op.jumpdest), (0x5c, 0, true, COL_MAP.op.get_state_root), @@ -100,27 +60,12 @@ const OPCODES: [(u8, usize, bool, usize); 96] = [ (0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f (0x80, 4, false, COL_MAP.op.dup), // 0x80-0x8f (0x90, 4, false, COL_MAP.op.swap), // 0x90-0x9f - (0xa0, 0, false, COL_MAP.op.log0), - (0xa1, 0, false, COL_MAP.op.log1), - (0xa2, 0, false, COL_MAP.op.log2), - (0xa3, 0, false, COL_MAP.op.log3), - (0xa4, 0, false, COL_MAP.op.log4), - // Opcode 0xa5 is PANIC when Kernel. Make the proof unverifiable by giving it no flag to decode to. - (0xf0, 0, false, COL_MAP.op.create), - (0xf1, 0, false, COL_MAP.op.call), - (0xf2, 0, false, COL_MAP.op.callcode), - (0xf3, 0, false, COL_MAP.op.return_), - (0xf4, 0, false, COL_MAP.op.delegatecall), - (0xf5, 0, false, COL_MAP.op.create2), (0xf6, 0, true, COL_MAP.op.get_context), (0xf7, 0, true, COL_MAP.op.set_context), (0xf8, 0, true, COL_MAP.op.consume_gas), (0xf9, 0, true, COL_MAP.op.exit_kernel), - (0xfa, 0, false, COL_MAP.op.staticcall), (0xfb, 0, true, COL_MAP.op.mload_general), (0xfc, 0, true, COL_MAP.op.mstore_general), - (0xfd, 0, false, COL_MAP.op.revert), - (0xff, 0, false, COL_MAP.op.selfdestruct), ]; /// Bitfield of invalid opcodes, in little-endian order. @@ -188,19 +133,12 @@ pub fn generate(lv: &mut CpuColumnsView) { assert!(kernel <= 1); let kernel = kernel != 0; - let mut any_flag_set = false; for (oc, block_length, kernel_only, col) in OPCODES { let available = !kernel_only || kernel; let opcode_match = top_bits[8 - block_length] == oc; let flag = available && opcode_match; lv[col] = F::from_bool(flag); - if flag && any_flag_set { - panic!("opcode matched multiple flags"); - } - any_flag_set = any_flag_set || flag; } - // is_invalid is a catch-all for opcodes we can't decode. - lv.op.invalid = F::from_bool(!any_flag_set); } /// Break up an opcode (which is 8 bits long) into its eight bits. @@ -238,14 +176,12 @@ pub fn eval_packed_generic( let flag = lv[flag_col]; yield_constr.constraint(cycle_filter * flag * (flag - P::ONES)); } - yield_constr.constraint(cycle_filter * lv.op.invalid * (lv.op.invalid - P::ONES)); - // Now check that exactly one is 1. + // Now check that they sum to 0 or 1. let flag_sum: P = OPCODES .into_iter() .map(|(_, _, _, flag_col)| lv[flag_col]) - .sum::

() - + lv.op.invalid; - yield_constr.constraint(cycle_filter * (P::ONES - flag_sum)); + .sum::

(); + yield_constr.constraint(cycle_filter * flag_sum * (flag_sum - P::ONES)); // Finally, classify all opcodes, together with the kernel flag, into blocks for (oc, block_length, kernel_only, col) in OPCODES { @@ -308,20 +244,15 @@ pub fn eval_ext_circuit, const D: usize>( let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } + // Now check that they sum to 0 or 1. { - let constr = builder.mul_sub_extension(lv.op.invalid, lv.op.invalid, lv.op.invalid); - let constr = builder.mul_extension(cycle_filter, constr); - yield_constr.constraint(builder, constr); - } - // Now check that exactly one is 1. - { - let mut constr = builder.one_extension(); + let mut flag_sum = builder.zero_extension(); for (_, _, _, flag_col) in OPCODES { let flag = lv[flag_col]; - constr = builder.sub_extension(constr, flag); + flag_sum = builder.add_extension(flag_sum, flag); } - constr = builder.sub_extension(constr, lv.op.invalid); - constr = builder.mul_extension(cycle_filter, constr); + let constr = builder.mul_sub_extension(flag_sum, flag_sum, flag_sum); + let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 49398288..f44dbcb2 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -18,6 +18,8 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/core/invalid.asm"), include_str!("asm/core/nonce.asm"), include_str!("asm/core/process_txn.asm"), + include_str!("asm/core/syscall.asm"), + include_str!("asm/core/syscall_stubs.asm"), include_str!("asm/core/terminate.asm"), include_str!("asm/core/transfer.asm"), include_str!("asm/core/util.asm"), diff --git a/evm/src/cpu/kernel/asm/core/syscall.asm b/evm/src/cpu/kernel/asm/core/syscall.asm new file mode 100644 index 00000000..2de50993 --- /dev/null +++ b/evm/src/cpu/kernel/asm/core/syscall.asm @@ -0,0 +1,160 @@ +global syscall_jumptable: + // 0x00-0x0f + JUMPTABLE sys_stop + JUMPTABLE panic // add is implemented natively + JUMPTABLE panic // mul is implemented natively + JUMPTABLE panic // sub is implemented natively + JUMPTABLE panic // div is implemented natively + JUMPTABLE sys_sdiv + JUMPTABLE panic // mod is implemented natively + JUMPTABLE sys_smod + JUMPTABLE panic // addmod is implemented natively + JUMPTABLE panic // mulmod is implemented natively + JUMPTABLE sys_exp + JUMPTABLE sys_signextend + JUMPTABLE panic // 0x0c is an invalid opcode + JUMPTABLE panic // 0x0d is an invalid opcode + JUMPTABLE panic // 0x0e is an invalid opcode + JUMPTABLE panic // 0x0f is an invalid opcode + + // 0x10-0x1f + JUMPTABLE panic // lt is implemented natively + JUMPTABLE panic // gt is implemented natively + JUMPTABLE sys_slt + JUMPTABLE sys_sgt + JUMPTABLE panic // eq is implemented natively + JUMPTABLE panic // iszero is implemented natively + JUMPTABLE panic // and is implemented natively + JUMPTABLE panic // or is implemented natively + JUMPTABLE panic // xor is implemented natively + JUMPTABLE panic // not is implemented natively + JUMPTABLE panic // byte is implemented natively + JUMPTABLE panic // shl is implemented natively + JUMPTABLE panic // shr is implemented natively + JUMPTABLE sys_sar + JUMPTABLE panic // 0x1e is an invalid opcode + JUMPTABLE panic // 0x1f is an invalid opcode + + // 0x20-0x2f + JUMPTABLE sys_keccak256 + %rep 15 + JUMPTABLE panic // 0x21-0x2f are invalid opcodes + %endrep + + // 0x30-0x3f + JUMPTABLE sys_address + JUMPTABLE sys_balance + JUMPTABLE sys_origin + JUMPTABLE sys_caller + JUMPTABLE sys_callvalue + JUMPTABLE sys_calldataload + JUMPTABLE sys_calldatasize + JUMPTABLE sys_calldatacopy + JUMPTABLE sys_codesize + JUMPTABLE sys_codecopy + JUMPTABLE sys_gasprice + JUMPTABLE sys_extcodesize + JUMPTABLE sys_extcodecopy + JUMPTABLE sys_returndatasize + JUMPTABLE sys_returndatacopy + JUMPTABLE sys_extcodehash + + // 0x40-0x4f + JUMPTABLE sys_blockhash + JUMPTABLE sys_coinbase + JUMPTABLE sys_timestamp + JUMPTABLE sys_number + JUMPTABLE sys_prevrandao + JUMPTABLE sys_gaslimit + JUMPTABLE sys_chainid + JUMPTABLE sys_selfbalance + JUMPTABLE sys_basefee + %rep 7 + JUMPTABLE panic // 0x49-0x4f are invalid opcodes + %endrep + + // 0x50-0x5f + JUMPTABLE panic // pop is implemented natively + JUMPTABLE sys_mload + JUMPTABLE sys_mstore + JUMPTABLE sys_mstore8 + JUMPTABLE sys_sload + JUMPTABLE sys_sstore + JUMPTABLE panic // jump is implemented natively + JUMPTABLE panic // jumpi is implemented natively + JUMPTABLE panic // pc is implemented natively + JUMPTABLE sys_msize + JUMPTABLE panic // gas is implemented natively + JUMPTABLE panic // jumpdest is implemented natively + JUMPTABLE panic // 0x5c is an invalid opcode + JUMPTABLE panic // 0x5d is an invalid opcode + JUMPTABLE panic // 0x5e is an invalid opcode + JUMPTABLE panic // 0x5f is an invalid opcode + + // 0x60-0x6f + %rep 16 + JUMPTABLE panic // push1-push16 are implemented natively + %endrep + + // 0x70-0x7f + %rep 16 + JUMPTABLE panic // push17-push32 are implemented natively + %endrep + + // 0x80-0x8f + %rep 16 + JUMPTABLE panic // dup1-dup16 are implemented natively + %endrep + + // 0x90-0x9f + %rep 16 + JUMPTABLE panic // swap1-swap16 are implemented natively + %endrep + + // 0xa0-0xaf + JUMPTABLE sys_log0 + JUMPTABLE sys_log1 + JUMPTABLE sys_log2 + JUMPTABLE sys_log3 + JUMPTABLE sys_log4 + %rep 11 + JUMPTABLE panic // 0xa5-0xaf are invalid opcodes + %endrep + + // 0xb0-0xbf + %rep 16 + JUMPTABLE panic // 0xb0-0xbf are invalid opcodes + %endrep + + // 0xc0-0xcf + %rep 16 + JUMPTABLE panic // 0xc0-0xcf are invalid opcodes + %endrep + + // 0xd0-0xdf + %rep 16 + JUMPTABLE panic // 0xd0-0xdf are invalid opcodes + %endrep + + // 0xe0-0xef + %rep 16 + JUMPTABLE panic // 0xe0-0xef are invalid opcodes + %endrep + + // 0xf0-0xff + JUMPTABLE sys_create + JUMPTABLE sys_call + JUMPTABLE sys_callcode + JUMPTABLE sys_return + JUMPTABLE sys_delegatecall + JUMPTABLE sys_create2 + JUMPTABLE panic // 0xf6 is an invalid opcode + JUMPTABLE panic // 0xf7 is an invalid opcode + JUMPTABLE panic // 0xf8 is an invalid opcode + JUMPTABLE panic // 0xf9 is an invalid opcode + JUMPTABLE sys_staticcall + JUMPTABLE panic // 0xfb is an invalid opcode + JUMPTABLE panic // 0xfc is an invalid opcode + JUMPTABLE sys_revert + JUMPTABLE panic // 0xfe is an invalid opcode + JUMPTABLE sys_selfdestruct diff --git a/evm/src/cpu/kernel/asm/core/syscall_stubs.asm b/evm/src/cpu/kernel/asm/core/syscall_stubs.asm new file mode 100644 index 00000000..d39d8145 --- /dev/null +++ b/evm/src/cpu/kernel/asm/core/syscall_stubs.asm @@ -0,0 +1,53 @@ +// Labels for unimplemented syscalls to make the kernel assemble. +// Each label should be removed from this file once it is implemented. + +global sys_sdiv: +global sys_smod: +global sys_signextend: +global sys_slt: +global sys_sgt: +global sys_sar: +global sys_keccak256: +global sys_address: +global sys_balance: +global sys_origin: +global sys_caller: +global sys_callvalue: +global sys_calldataload: +global sys_calldatasize: +global sys_calldatacopy: +global sys_codesize: +global sys_codecopy: +global sys_gasprice: +global sys_extcodesize: +global sys_extcodecopy: +global sys_returndatasize: +global sys_returndatacopy: +global sys_extcodehash: +global sys_blockhash: +global sys_coinbase: +global sys_timestamp: +global sys_number: +global sys_prevrandao: +global sys_gaslimit: +global sys_chainid: +global sys_selfbalance: +global sys_basefee: +global sys_mload: +global sys_mstore: +global sys_mstore8: +global sys_sload: +global sys_sstore: +global sys_msize: +global sys_log0: +global sys_log1: +global sys_log2: +global sys_log3: +global sys_log4: +global sys_create: +global sys_call: +global sys_callcode: +global sys_delegatecall: +global sys_create2: +global sys_staticcall: + PANIC diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index aad2dd53..5f9584ba 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -300,11 +300,29 @@ fn find_labels( } Item::StandardOp(_) => *offset += 1, Item::Bytes(bytes) => *offset += bytes.len(), + Item::Jumptable(labels) => *offset += labels.len() * (BYTES_PER_OFFSET as usize), } } local_labels } +fn look_up_label( + label: &String, + local_labels: &HashMap, + global_labels: &HashMap, +) -> Vec { + let offset = local_labels + .get(label) + .or_else(|| global_labels.get(label)) + .unwrap_or_else(|| panic!("No such label: {label}")); + // We want the BYTES_PER_OFFSET least significant bytes in BE order. + // It's easiest to rev the first BYTES_PER_OFFSET bytes of the LE encoding. + (0..BYTES_PER_OFFSET) + .rev() + .map(|i| offset.to_le_bytes()[i as usize]) + .collect() +} + fn assemble_file( body: Vec, code: &mut Vec, @@ -327,18 +345,7 @@ fn assemble_file( Item::Push(target) => { let target_bytes: Vec = match target { PushTarget::Literal(n) => u256_to_trimmed_be_bytes(&n), - PushTarget::Label(label) => { - let offset = local_labels - .get(&label) - .or_else(|| global_labels.get(&label)) - .unwrap_or_else(|| panic!("No such label: {label}")); - // We want the BYTES_PER_OFFSET least significant bytes in BE order. - // It's easiest to rev the first BYTES_PER_OFFSET bytes of the LE encoding. - (0..BYTES_PER_OFFSET) - .rev() - .map(|i| offset.to_le_bytes()[i as usize]) - .collect() - } + PushTarget::Label(label) => look_up_label(&label, &local_labels, global_labels), PushTarget::MacroLabel(v) => panic!("Macro label not in a macro: {v}"), PushTarget::MacroVar(v) => panic!("Variable not in a macro: {v}"), PushTarget::Constant(c) => panic!("Constant wasn't inlined: {c}"), @@ -353,6 +360,12 @@ fn assemble_file( code.push(get_opcode(&opcode)); } Item::Bytes(bytes) => code.extend(bytes), + Item::Jumptable(labels) => { + for label in labels { + let bytes = look_up_label(&label, &local_labels, global_labels); + code.extend(bytes); + } + } } } } diff --git a/evm/src/cpu/kernel/ast.rs b/evm/src/cpu/kernel/ast.rs index 6180b1c8..ed4f6dbb 100644 --- a/evm/src/cpu/kernel/ast.rs +++ b/evm/src/cpu/kernel/ast.rs @@ -34,6 +34,8 @@ pub(crate) enum Item { StandardOp(String), /// Literal hex data; should contain an even number of hex chars. Bytes(Vec), + /// Creates a table of addresses from a list of labels. + Jumptable(Vec), } /// The left hand side of a %stack stack-manipulation macro. diff --git a/evm/src/cpu/kernel/evm_asm.pest b/evm/src/cpu/kernel/evm_asm.pest index 9b8721f4..e7337430 100644 --- a/evm/src/cpu/kernel/evm_asm.pest +++ b/evm/src/cpu/kernel/evm_asm.pest @@ -15,7 +15,7 @@ literal = { literal_hex | literal_decimal } variable = ${ "$" ~ identifier } constant = ${ "@" ~ identifier } -item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | push_instruction | prover_input_instruction | nullary_instruction } +item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | jumptable_item | push_instruction | prover_input_instruction | nullary_instruction } macro_def = { ^"%macro" ~ identifier ~ paramlist? ~ item* ~ ^"%endmacro" } macro_call = ${ "%" ~ !((^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") ~ !identifier_char) ~ identifier ~ macro_arglist? } repeat = { ^"%rep" ~ literal ~ item* ~ ^"%endrep" } @@ -35,6 +35,7 @@ macro_label_decl = ${ "%%" ~ identifier ~ ":" } macro_label = ${ "%%" ~ identifier } bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* } +jumptable_item = { ^"JUMPTABLE " ~ identifier ~ ("," ~ identifier)* } push_instruction = { ^"PUSH " ~ push_target } push_target = { literal | identifier | macro_label | variable | constant } prover_input_instruction = { ^"PROVER_INPUT" ~ "(" ~ prover_input_fn ~ ")" } diff --git a/evm/src/cpu/kernel/parser.rs b/evm/src/cpu/kernel/parser.rs index b7a8124b..49181b71 100644 --- a/evm/src/cpu/kernel/parser.rs +++ b/evm/src/cpu/kernel/parser.rs @@ -39,6 +39,9 @@ fn parse_item(item: Pair) -> Item { Item::MacroLabelDeclaration(item.into_inner().next().unwrap().as_str().into()) } Rule::bytes_item => Item::Bytes(item.into_inner().map(parse_literal_u8).collect()), + Rule::jumptable_item => { + Item::Jumptable(item.into_inner().map(|i| i.as_str().into()).collect()) + } Rule::push_instruction => Item::Push(parse_push_target(item.into_inner().next().unwrap())), Rule::prover_input_instruction => Item::ProverInput( item.into_inner() diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index c72688ed..ea235578 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -40,73 +40,33 @@ const BASIC_TERNARY_OP: Option = Some(StackBehavior { // except the first `num_pops` and the last `pushes as usize` channels have their read flag and // address constrained automatically in this file. const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { - stop: None, // TODO add: BASIC_BINARY_OP, mul: BASIC_BINARY_OP, sub: BASIC_BINARY_OP, div: BASIC_BINARY_OP, - sdiv: BASIC_BINARY_OP, mod_: BASIC_BINARY_OP, - smod: BASIC_BINARY_OP, addmod: BASIC_TERNARY_OP, mulmod: BASIC_TERNARY_OP, - exp: None, // TODO - signextend: BASIC_BINARY_OP, addfp254: BASIC_BINARY_OP, mulfp254: BASIC_BINARY_OP, subfp254: BASIC_BINARY_OP, lt: BASIC_BINARY_OP, gt: BASIC_BINARY_OP, - slt: BASIC_BINARY_OP, - sgt: BASIC_BINARY_OP, eq: BASIC_BINARY_OP, iszero: BASIC_UNARY_OP, and: BASIC_BINARY_OP, or: BASIC_BINARY_OP, xor: BASIC_BINARY_OP, - not: BASIC_TERNARY_OP, + not: BASIC_UNARY_OP, byte: BASIC_BINARY_OP, shl: BASIC_BINARY_OP, shr: BASIC_BINARY_OP, - sar: BASIC_BINARY_OP, - keccak256: None, // TODO keccak_general: None, // TODO - address: None, // TODO - balance: None, // TODO - origin: None, // TODO - caller: None, // TODO - callvalue: None, // TODO - calldataload: None, // TODO - calldatasize: None, // TODO - calldatacopy: None, // TODO - codesize: None, // TODO - codecopy: None, // TODO - gasprice: None, // TODO - extcodesize: None, // TODO - extcodecopy: None, // TODO - returndatasize: None, // TODO - returndatacopy: None, // TODO - extcodehash: None, // TODO - blockhash: None, // TODO - coinbase: None, // TODO - timestamp: None, // TODO - number: None, // TODO - difficulty: None, // TODO - gaslimit: None, // TODO - chainid: None, // TODO - selfbalance: None, // TODO - basefee: None, // TODO prover_input: None, // TODO pop: None, // TODO - mload: None, // TODO - mstore: None, // TODO - mstore8: None, // TODO - sload: None, // TODO - sstore: None, // TODO jump: None, // TODO jumpi: None, // TODO pc: None, // TODO - msize: None, // TODO gas: None, // TODO jumpdest: None, // TODO get_state_root: None, // TODO @@ -116,27 +76,17 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { push: None, // TODO dup: None, swap: None, - log0: None, // TODO - log1: None, // TODO - log2: None, // TODO - log3: None, // TODO - log4: None, // TODO - create: None, // TODO - call: None, // TODO - callcode: None, // TODO - return_: None, // TODO - delegatecall: None, // TODO - create2: None, // TODO get_context: None, // TODO set_context: None, // TODO consume_gas: None, // TODO exit_kernel: None, // TODO - staticcall: None, // TODO mload_general: None, // TODO mstore_general: None, // TODO - revert: None, // TODO - selfdestruct: None, // TODO - invalid: None, // TODO + syscall: Some(StackBehavior { + num_pops: 0, + pushes: true, + disable_other_channels: false, + }), }; fn eval_packed_one( diff --git a/evm/src/cpu/syscalls.rs b/evm/src/cpu/syscalls.rs index 0ac31ef6..4033620e 100644 --- a/evm/src/cpu/syscalls.rs +++ b/evm/src/cpu/syscalls.rs @@ -2,62 +2,81 @@ //! //! These are usually the ones that are too complicated to implement in one CPU table row. -use once_cell::sync::Lazy; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; +use static_assertions::const_assert; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cpu::columns::{CpuColumnsView, COL_MAP}; +use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::memory::segments::Segment; -const NUM_SYSCALLS: usize = 3; - -fn make_syscall_list() -> [(usize, usize); NUM_SYSCALLS] { - let kernel = Lazy::force(&KERNEL); - [ - (COL_MAP.op.stop, "sys_stop"), - (COL_MAP.op.exp, "sys_exp"), - (COL_MAP.op.invalid, "handle_invalid"), - ] - .map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name])) -} - -static TRAP_LIST: Lazy<[(usize, usize); NUM_SYSCALLS]> = Lazy::new(make_syscall_list); +// Copy the constant but make it `usize`. +const BYTES_PER_OFFSET: usize = crate::cpu::kernel::assembler::BYTES_PER_OFFSET as usize; +const_assert!(BYTES_PER_OFFSET < NUM_GP_CHANNELS); // Reserve one channel for stack push pub fn eval_packed( lv: &CpuColumnsView

, nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - let syscall_list = Lazy::force(&TRAP_LIST); - // 1 if _any_ syscall, else 0. - let should_syscall: P = syscall_list - .iter() - .map(|&(col_index, _)| lv[col_index]) - .sum(); - let filter = lv.is_cpu_cycle * should_syscall; + let filter = lv.is_cpu_cycle * lv.op.syscall; - // If syscall: set program counter to the handler address - // Note that at most one of the `lv[col_index]`s will be 1 and all others will be 0. - let syscall_dst: P = syscall_list - .iter() - .map(|&(col_index, handler_addr)| { - lv[col_index] * P::Scalar::from_canonical_usize(handler_addr) - }) + // Look up the handler in memory + let code_segment = P::Scalar::from_canonical_usize(Segment::Code as usize); + let syscall_jumptable_start = + P::Scalar::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]); + let opcode: P = lv + .opcode_bits + .into_iter() + .enumerate() + .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) .sum(); - yield_constr.constraint_transition(filter * (nv.program_counter - syscall_dst)); - // If syscall: set kernel mode + let opcode_handler_addr_start = + syscall_jumptable_start + opcode * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET); + for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { + yield_constr.constraint(filter * (channel.used - P::ONES)); + yield_constr.constraint(filter * (channel.is_read - P::ONES)); + + // Set kernel context and code segment + yield_constr.constraint(filter * channel.addr_context); + yield_constr.constraint(filter * (channel.addr_segment - code_segment)); + + // Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs. + let limb_address = opcode_handler_addr_start + P::Scalar::from_canonical_usize(i); + yield_constr.constraint(filter * (channel.addr_virtual - limb_address)); + } + + // Disable unused channels (the last channel is used to push to the stack) + for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { + yield_constr.constraint(filter * channel.used); + } + + // Set program counter to the handler address + // The addresses are big-endian in memory + let target = lv.mem_channels[0..BYTES_PER_OFFSET] + .iter() + .map(|channel| channel.value[0]) + .fold(P::ZEROS, |cumul, limb| { + cumul * P::Scalar::from_canonical_u64(256) + limb + }); + yield_constr.constraint_transition(filter * (nv.program_counter - target)); + // Set kernel mode yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES)); + // Maintain current context + yield_constr.constraint_transition(filter * (nv.context - lv.context)); - let output = lv.mem_channels[0].value; - // If syscall: push current PC to stack + // This memory channel is constrained in `stack.rs`. + let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; + // Push current PC to stack yield_constr.constraint(filter * (output[0] - lv.program_counter)); - // If syscall: push current kernel flag to stack (share register with PC) + // Push current kernel flag to stack (share register with PC) yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode)); - // If syscall: zero the rest of that register + // Zero the rest of that register for &limb in &output[2..] { yield_constr.constraint(filter * limb); } @@ -69,46 +88,111 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let syscall_list = Lazy::force(&TRAP_LIST); - // 1 if _any_ syscall, else 0. - let should_syscall = - builder.add_many_extension(syscall_list.iter().map(|&(col_index, _)| lv[col_index])); - let filter = builder.mul_extension(lv.is_cpu_cycle, should_syscall); + let filter = builder.mul_extension(lv.is_cpu_cycle, lv.op.syscall); - // If syscall: set program counter to the handler address + // Look up the handler in memory + let code_segment = F::from_canonical_usize(Segment::Code as usize); + let syscall_jumptable_start = builder.constant_extension( + F::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]).into(), + ); + let opcode = lv + .opcode_bits + .into_iter() + .rev() + .fold(builder.zero_extension(), |cumul, bit| { + builder.mul_const_add_extension(F::TWO, cumul, bit) + }); + let opcode_handler_addr_start = builder.mul_const_add_extension( + F::from_canonical_usize(BYTES_PER_OFFSET), + opcode, + syscall_jumptable_start, + ); + for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() { + { + let constr = builder.mul_sub_extension(filter, channel.used, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.mul_sub_extension(filter, channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + + // Set kernel context and code segment + { + let constr = builder.mul_extension(filter, channel.addr_context); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.arithmetic_extension( + F::ONE, + -code_segment, + filter, + channel.addr_segment, + filter, + ); + yield_constr.constraint(builder, constr); + } + + // Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs. + { + let diff = builder.sub_extension(channel.addr_virtual, opcode_handler_addr_start); + let constr = builder.arithmetic_extension( + F::ONE, + -F::from_canonical_usize(i), + filter, + diff, + filter, + ); + yield_constr.constraint(builder, constr); + } + } + + // Disable unused channels (the last channel is used to push to the stack) + for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] { + let constr = builder.mul_extension(filter, channel.used); + yield_constr.constraint(builder, constr); + } + + // Set program counter to the handler address + // The addresses are big-endian in memory { - // Note that at most one of the `lv[col_index]`s will be 1 and all others will be 0. - let syscall_dst = syscall_list.iter().fold( - builder.zero_extension(), - |cumul, &(col_index, handler_addr)| { - let handler_addr = F::from_canonical_usize(handler_addr); - builder.mul_const_add_extension(handler_addr, lv[col_index], cumul) - }, - ); - let constr = builder.sub_extension(nv.program_counter, syscall_dst); - let constr = builder.mul_extension(filter, constr); + let target = lv.mem_channels[0..BYTES_PER_OFFSET] + .iter() + .map(|channel| channel.value[0]) + .fold(builder.zero_extension(), |cumul, limb| { + builder.mul_const_add_extension(F::from_canonical_u64(256), cumul, limb) + }); + let diff = builder.sub_extension(nv.program_counter, target); + let constr = builder.mul_extension(filter, diff); yield_constr.constraint_transition(builder, constr); } - // If syscall: set kernel mode + // Set kernel mode { let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter); yield_constr.constraint_transition(builder, constr); } + // Maintain current context + { + let diff = builder.sub_extension(nv.context, lv.context); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint_transition(builder, constr); + } - let output = lv.mem_channels[0].value; - // If syscall: push current PC to stack + // This memory channel is constrained in `stack.rs`. + let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value; + // Push current PC to stack { - let constr = builder.sub_extension(output[0], lv.program_counter); - let constr = builder.mul_extension(filter, constr); + let diff = builder.sub_extension(output[0], lv.program_counter); + let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - // If syscall: push current kernel flag to stack (share register with PC) + // Push current kernel flag to stack (share register with PC) { - let constr = builder.sub_extension(output[1], lv.is_kernel_mode); - let constr = builder.mul_extension(filter, constr); + let diff = builder.sub_extension(output[1], lv.is_kernel_mode); + let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - // If syscall: zero the rest of that register + // Zero the rest of that register for &limb in &output[2..] { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr);