Combine all syscalls into one flag (#802)

* Combine all syscalls into one flag

* Minor: typo

* Daniel PR comments
This commit is contained in:
Jacqueline Nabaglo 2022-11-07 12:29:28 -08:00 committed by GitHub
parent 98b9f3a462
commit 626c2583de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 439 additions and 273 deletions

View File

@ -28,6 +28,7 @@ rlp = "0.5.1"
rlp-derive = "0.1.0"
serde = { version = "1.0.144", features = ["derive"] }
sha2 = "0.10.2"
static_assertions = "1.1.0"
tiny-keccak = "2.0.2"
[dev-dependencies]

View File

@ -7,106 +7,59 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks};
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
pub struct OpsColumnsView<T> {
pub stop: T,
// TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag
pub add: T,
pub mul: T,
pub sub: T,
pub div: T,
pub sdiv: T,
pub mod_: T,
pub smod: T,
// TODO: combine ADDMOD, MULMOD into one flag
pub addmod: T,
pub mulmod: T,
pub exp: T,
pub signextend: T,
pub addfp254: T,
pub mulfp254: T,
pub subfp254: T,
pub lt: T,
pub gt: T,
pub slt: T,
pub sgt: T,
pub eq: T, // Note: This column must be 0 when is_cpu_cycle = 0.
pub iszero: T, // Note: This column must be 0 when is_cpu_cycle = 0.
// TODO: combine AND, OR, and XOR into one flag
pub and: T,
pub or: T,
pub xor: T,
pub not: T,
pub byte: T,
// TODO: combine SHL and SHR into one flag
pub shl: T,
pub shr: T,
pub sar: T,
pub keccak256: T,
pub keccak_general: T,
pub address: T,
pub balance: T,
pub origin: T,
pub caller: T,
pub callvalue: T,
pub calldataload: T,
pub calldatasize: T,
pub calldatacopy: T,
pub codesize: T,
pub codecopy: T,
pub gasprice: T,
pub extcodesize: T,
pub extcodecopy: T,
pub returndatasize: T,
pub returndatacopy: T,
pub extcodehash: T,
pub blockhash: T,
pub coinbase: T,
pub timestamp: T,
pub number: T,
pub difficulty: T,
pub gaslimit: T,
pub chainid: T,
pub selfbalance: T,
pub basefee: T,
pub prover_input: T,
pub pop: T,
pub mload: T,
pub mstore: T,
pub mstore8: T,
pub sload: T,
pub sstore: T,
// TODO: combine JUMP and JUMPI into one flag
pub jump: T, // Note: This column must be 0 when is_cpu_cycle = 0.
pub jumpi: T, // Note: This column must be 0 when is_cpu_cycle = 0.
pub pc: T,
pub msize: T,
pub gas: T,
pub jumpdest: T,
// TODO: combine GET_STATE_ROOT and SET_STATE_ROOT into one flag
pub get_state_root: T,
pub set_state_root: T,
// TODO: combine GET_RECEIPT_ROOT and SET_RECEIPT_ROOT into one flag
pub get_receipt_root: T,
pub set_receipt_root: T,
pub push: T,
pub dup: T,
pub swap: T,
pub log0: T,
pub log1: T,
pub log2: T,
pub log3: T,
pub log4: T,
// PANIC does not get a flag; it fails at the decode stage.
pub create: T,
pub call: T,
pub callcode: T,
pub return_: T,
pub delegatecall: T,
pub create2: T,
// TODO: combine GET_CONTEXT and SET_CONTEXT into one flag
pub get_context: T,
pub set_context: T,
pub consume_gas: T,
pub exit_kernel: T,
pub staticcall: T,
// TODO: combine MLOAD_GENERAL and MSTORE_GENERAL into one flag
pub mload_general: T,
pub mstore_general: T,
pub revert: T,
pub selfdestruct: T,
// TODO: this doesn't actually need its own flag. We can just do `1 - sum(all other flags)`.
pub invalid: T,
pub syscall: T,
}
// `u8` is guaranteed to have a `size_of` of 1.

View File

@ -8,36 +8,49 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
use crate::cpu::kernel::aggregator::KERNEL;
// TODO: This list is incomplete.
const NATIVE_INSTRUCTIONS: [usize; 28] = [
const NATIVE_INSTRUCTIONS: [usize; 37] = [
COL_MAP.op.add,
COL_MAP.op.mul,
COL_MAP.op.sub,
COL_MAP.op.div,
COL_MAP.op.sdiv,
COL_MAP.op.mod_,
COL_MAP.op.smod,
COL_MAP.op.addmod,
COL_MAP.op.mulmod,
COL_MAP.op.signextend,
COL_MAP.op.addfp254,
COL_MAP.op.mulfp254,
COL_MAP.op.subfp254,
COL_MAP.op.lt,
COL_MAP.op.gt,
COL_MAP.op.slt,
COL_MAP.op.sgt,
COL_MAP.op.eq,
COL_MAP.op.iszero,
COL_MAP.op.and,
COL_MAP.op.or,
COL_MAP.op.xor,
COL_MAP.op.not,
COL_MAP.op.byte,
COL_MAP.op.shl,
COL_MAP.op.shr,
COL_MAP.op.sar,
COL_MAP.op.keccak_general,
COL_MAP.op.prover_input,
COL_MAP.op.pop,
// not JUMP (need to jump)
// not JUMPI (possible need to jump)
COL_MAP.op.pc,
COL_MAP.op.gas,
COL_MAP.op.jumpdest,
COL_MAP.op.get_state_root,
COL_MAP.op.set_state_root,
COL_MAP.op.get_receipt_root,
COL_MAP.op.set_receipt_root,
// not PUSH (need to increment by more than 1)
COL_MAP.op.dup,
COL_MAP.op.swap,
COL_MAP.op.get_context,
COL_MAP.op.set_context,
COL_MAP.op.consume_gas,
// not EXIT_KERNEL (performs a jump)
COL_MAP.op.mload_general,
COL_MAP.op.mstore_general,
// not SYSCALL (performs a jump)
];
fn get_halt_pcs<F: Field>() -> (F, F) {

View File

@ -22,27 +22,20 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP};
/// behavior.
/// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to
/// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid.
const OPCODES: [(u8, usize, bool, usize); 96] = [
const OPCODES: [(u8, usize, bool, usize); 42] = [
// (start index of block, number of top bits to check (log2), kernel-only, flag column)
(0x00, 0, false, COL_MAP.op.stop),
(0x01, 0, false, COL_MAP.op.add),
(0x02, 0, false, COL_MAP.op.mul),
(0x03, 0, false, COL_MAP.op.sub),
(0x04, 0, false, COL_MAP.op.div),
(0x05, 0, false, COL_MAP.op.sdiv),
(0x06, 0, false, COL_MAP.op.mod_),
(0x07, 0, false, COL_MAP.op.smod),
(0x08, 0, false, COL_MAP.op.addmod),
(0x09, 0, false, COL_MAP.op.mulmod),
(0x0a, 0, false, COL_MAP.op.exp),
(0x0b, 0, false, COL_MAP.op.signextend),
(0x0c, 0, true, COL_MAP.op.addfp254),
(0x0d, 0, true, COL_MAP.op.mulfp254),
(0x0e, 0, true, COL_MAP.op.subfp254),
(0x10, 0, false, COL_MAP.op.lt),
(0x11, 0, false, COL_MAP.op.gt),
(0x12, 0, false, COL_MAP.op.slt),
(0x13, 0, false, COL_MAP.op.sgt),
(0x14, 0, false, COL_MAP.op.eq),
(0x15, 0, false, COL_MAP.op.iszero),
(0x16, 0, false, COL_MAP.op.and),
@ -52,45 +45,12 @@ const OPCODES: [(u8, usize, bool, usize); 96] = [
(0x1a, 0, false, COL_MAP.op.byte),
(0x1b, 0, false, COL_MAP.op.shl),
(0x1c, 0, false, COL_MAP.op.shr),
(0x1d, 0, false, COL_MAP.op.sar),
(0x20, 0, false, COL_MAP.op.keccak256),
(0x21, 0, true, COL_MAP.op.keccak_general),
(0x30, 0, false, COL_MAP.op.address),
(0x31, 0, false, COL_MAP.op.balance),
(0x32, 0, false, COL_MAP.op.origin),
(0x33, 0, false, COL_MAP.op.caller),
(0x34, 0, false, COL_MAP.op.callvalue),
(0x35, 0, false, COL_MAP.op.calldataload),
(0x36, 0, false, COL_MAP.op.calldatasize),
(0x37, 0, false, COL_MAP.op.calldatacopy),
(0x38, 0, false, COL_MAP.op.codesize),
(0x39, 0, false, COL_MAP.op.codecopy),
(0x3a, 0, false, COL_MAP.op.gasprice),
(0x3b, 0, false, COL_MAP.op.extcodesize),
(0x3c, 0, false, COL_MAP.op.extcodecopy),
(0x3d, 0, false, COL_MAP.op.returndatasize),
(0x3e, 0, false, COL_MAP.op.returndatacopy),
(0x3f, 0, false, COL_MAP.op.extcodehash),
(0x40, 0, false, COL_MAP.op.blockhash),
(0x41, 0, false, COL_MAP.op.coinbase),
(0x42, 0, false, COL_MAP.op.timestamp),
(0x43, 0, false, COL_MAP.op.number),
(0x44, 0, false, COL_MAP.op.difficulty),
(0x45, 0, false, COL_MAP.op.gaslimit),
(0x46, 0, false, COL_MAP.op.chainid),
(0x47, 0, false, COL_MAP.op.selfbalance),
(0x48, 0, false, COL_MAP.op.basefee),
(0x49, 0, true, COL_MAP.op.prover_input),
(0x50, 0, false, COL_MAP.op.pop),
(0x51, 0, false, COL_MAP.op.mload),
(0x52, 0, false, COL_MAP.op.mstore),
(0x53, 0, false, COL_MAP.op.mstore8),
(0x54, 0, false, COL_MAP.op.sload),
(0x55, 0, false, COL_MAP.op.sstore),
(0x56, 0, false, COL_MAP.op.jump),
(0x57, 0, false, COL_MAP.op.jumpi),
(0x58, 0, false, COL_MAP.op.pc),
(0x59, 0, false, COL_MAP.op.msize),
(0x5a, 0, false, COL_MAP.op.gas),
(0x5b, 0, false, COL_MAP.op.jumpdest),
(0x5c, 0, true, COL_MAP.op.get_state_root),
@ -100,27 +60,12 @@ const OPCODES: [(u8, usize, bool, usize); 96] = [
(0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f
(0x80, 4, false, COL_MAP.op.dup), // 0x80-0x8f
(0x90, 4, false, COL_MAP.op.swap), // 0x90-0x9f
(0xa0, 0, false, COL_MAP.op.log0),
(0xa1, 0, false, COL_MAP.op.log1),
(0xa2, 0, false, COL_MAP.op.log2),
(0xa3, 0, false, COL_MAP.op.log3),
(0xa4, 0, false, COL_MAP.op.log4),
// Opcode 0xa5 is PANIC when Kernel. Make the proof unverifiable by giving it no flag to decode to.
(0xf0, 0, false, COL_MAP.op.create),
(0xf1, 0, false, COL_MAP.op.call),
(0xf2, 0, false, COL_MAP.op.callcode),
(0xf3, 0, false, COL_MAP.op.return_),
(0xf4, 0, false, COL_MAP.op.delegatecall),
(0xf5, 0, false, COL_MAP.op.create2),
(0xf6, 0, true, COL_MAP.op.get_context),
(0xf7, 0, true, COL_MAP.op.set_context),
(0xf8, 0, true, COL_MAP.op.consume_gas),
(0xf9, 0, true, COL_MAP.op.exit_kernel),
(0xfa, 0, false, COL_MAP.op.staticcall),
(0xfb, 0, true, COL_MAP.op.mload_general),
(0xfc, 0, true, COL_MAP.op.mstore_general),
(0xfd, 0, false, COL_MAP.op.revert),
(0xff, 0, false, COL_MAP.op.selfdestruct),
];
/// Bitfield of invalid opcodes, in little-endian order.
@ -188,19 +133,12 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
assert!(kernel <= 1);
let kernel = kernel != 0;
let mut any_flag_set = false;
for (oc, block_length, kernel_only, col) in OPCODES {
let available = !kernel_only || kernel;
let opcode_match = top_bits[8 - block_length] == oc;
let flag = available && opcode_match;
lv[col] = F::from_bool(flag);
if flag && any_flag_set {
panic!("opcode matched multiple flags");
}
any_flag_set = any_flag_set || flag;
}
// is_invalid is a catch-all for opcodes we can't decode.
lv.op.invalid = F::from_bool(!any_flag_set);
}
/// Break up an opcode (which is 8 bits long) into its eight bits.
@ -238,14 +176,12 @@ pub fn eval_packed_generic<P: PackedField>(
let flag = lv[flag_col];
yield_constr.constraint(cycle_filter * flag * (flag - P::ONES));
}
yield_constr.constraint(cycle_filter * lv.op.invalid * (lv.op.invalid - P::ONES));
// Now check that exactly one is 1.
// Now check that they sum to 0 or 1.
let flag_sum: P = OPCODES
.into_iter()
.map(|(_, _, _, flag_col)| lv[flag_col])
.sum::<P>()
+ lv.op.invalid;
yield_constr.constraint(cycle_filter * (P::ONES - flag_sum));
.sum::<P>();
yield_constr.constraint(cycle_filter * flag_sum * (flag_sum - P::ONES));
// Finally, classify all opcodes, together with the kernel flag, into blocks
for (oc, block_length, kernel_only, col) in OPCODES {
@ -308,20 +244,15 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_extension(cycle_filter, constr);
yield_constr.constraint(builder, constr);
}
// Now check that they sum to 0 or 1.
{
let constr = builder.mul_sub_extension(lv.op.invalid, lv.op.invalid, lv.op.invalid);
let constr = builder.mul_extension(cycle_filter, constr);
yield_constr.constraint(builder, constr);
}
// Now check that exactly one is 1.
{
let mut constr = builder.one_extension();
let mut flag_sum = builder.zero_extension();
for (_, _, _, flag_col) in OPCODES {
let flag = lv[flag_col];
constr = builder.sub_extension(constr, flag);
flag_sum = builder.add_extension(flag_sum, flag);
}
constr = builder.sub_extension(constr, lv.op.invalid);
constr = builder.mul_extension(cycle_filter, constr);
let constr = builder.mul_sub_extension(flag_sum, flag_sum, flag_sum);
let constr = builder.mul_extension(cycle_filter, constr);
yield_constr.constraint(builder, constr);
}

View File

@ -18,6 +18,8 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/core/invalid.asm"),
include_str!("asm/core/nonce.asm"),
include_str!("asm/core/process_txn.asm"),
include_str!("asm/core/syscall.asm"),
include_str!("asm/core/syscall_stubs.asm"),
include_str!("asm/core/terminate.asm"),
include_str!("asm/core/transfer.asm"),
include_str!("asm/core/util.asm"),

View File

@ -0,0 +1,160 @@
// Table of syscall handler addresses, one entry per EVM opcode (0x00-0xff).
// Each JUMPTABLE entry assembles to BYTES_PER_OFFSET big-endian bytes holding
// the address of the named label; the CPU's syscall logic reads the entry at
// syscall_jumptable + opcode * BYTES_PER_OFFSET to find the handler to jump to.
// Opcodes implemented natively by the CPU, and invalid opcodes, map to `panic`,
// since reaching them via the syscall path indicates a decoding bug.
global syscall_jumptable:
// 0x00-0x0f
JUMPTABLE sys_stop
JUMPTABLE panic // add is implemented natively
JUMPTABLE panic // mul is implemented natively
JUMPTABLE panic // sub is implemented natively
JUMPTABLE panic // div is implemented natively
JUMPTABLE sys_sdiv
JUMPTABLE panic // mod is implemented natively
JUMPTABLE sys_smod
JUMPTABLE panic // addmod is implemented natively
JUMPTABLE panic // mulmod is implemented natively
JUMPTABLE sys_exp
JUMPTABLE sys_signextend
JUMPTABLE panic // 0x0c is an invalid opcode
JUMPTABLE panic // 0x0d is an invalid opcode
JUMPTABLE panic // 0x0e is an invalid opcode
JUMPTABLE panic // 0x0f is an invalid opcode
// 0x10-0x1f
JUMPTABLE panic // lt is implemented natively
JUMPTABLE panic // gt is implemented natively
JUMPTABLE sys_slt
JUMPTABLE sys_sgt
JUMPTABLE panic // eq is implemented natively
JUMPTABLE panic // iszero is implemented natively
JUMPTABLE panic // and is implemented natively
JUMPTABLE panic // or is implemented natively
JUMPTABLE panic // xor is implemented natively
JUMPTABLE panic // not is implemented natively
JUMPTABLE panic // byte is implemented natively
JUMPTABLE panic // shl is implemented natively
JUMPTABLE panic // shr is implemented natively
JUMPTABLE sys_sar
JUMPTABLE panic // 0x1e is an invalid opcode
JUMPTABLE panic // 0x1f is an invalid opcode
// 0x20-0x2f
JUMPTABLE sys_keccak256
%rep 15
JUMPTABLE panic // 0x21-0x2f are invalid opcodes
%endrep
// 0x30-0x3f
JUMPTABLE sys_address
JUMPTABLE sys_balance
JUMPTABLE sys_origin
JUMPTABLE sys_caller
JUMPTABLE sys_callvalue
JUMPTABLE sys_calldataload
JUMPTABLE sys_calldatasize
JUMPTABLE sys_calldatacopy
JUMPTABLE sys_codesize
JUMPTABLE sys_codecopy
JUMPTABLE sys_gasprice
JUMPTABLE sys_extcodesize
JUMPTABLE sys_extcodecopy
JUMPTABLE sys_returndatasize
JUMPTABLE sys_returndatacopy
JUMPTABLE sys_extcodehash
// 0x40-0x4f
JUMPTABLE sys_blockhash
JUMPTABLE sys_coinbase
JUMPTABLE sys_timestamp
JUMPTABLE sys_number
JUMPTABLE sys_prevrandao
JUMPTABLE sys_gaslimit
JUMPTABLE sys_chainid
JUMPTABLE sys_selfbalance
JUMPTABLE sys_basefee
%rep 7
JUMPTABLE panic // 0x49-0x4f are invalid opcodes
%endrep
// 0x50-0x5f
JUMPTABLE panic // pop is implemented natively
JUMPTABLE sys_mload
JUMPTABLE sys_mstore
JUMPTABLE sys_mstore8
JUMPTABLE sys_sload
JUMPTABLE sys_sstore
JUMPTABLE panic // jump is implemented natively
JUMPTABLE panic // jumpi is implemented natively
JUMPTABLE panic // pc is implemented natively
JUMPTABLE sys_msize
JUMPTABLE panic // gas is implemented natively
JUMPTABLE panic // jumpdest is implemented natively
JUMPTABLE panic // 0x5c is an invalid opcode
JUMPTABLE panic // 0x5d is an invalid opcode
JUMPTABLE panic // 0x5e is an invalid opcode
JUMPTABLE panic // 0x5f is an invalid opcode
// 0x60-0x6f
%rep 16
JUMPTABLE panic // push1-push16 are implemented natively
%endrep
// 0x70-0x7f
%rep 16
JUMPTABLE panic // push17-push32 are implemented natively
%endrep
// 0x80-0x8f
%rep 16
JUMPTABLE panic // dup1-dup16 are implemented natively
%endrep
// 0x90-0x9f
%rep 16
JUMPTABLE panic // swap1-swap16 are implemented natively
%endrep
// 0xa0-0xaf
JUMPTABLE sys_log0
JUMPTABLE sys_log1
JUMPTABLE sys_log2
JUMPTABLE sys_log3
JUMPTABLE sys_log4
%rep 11
JUMPTABLE panic // 0xa5-0xaf are invalid opcodes
%endrep
// 0xb0-0xbf
%rep 16
JUMPTABLE panic // 0xb0-0xbf are invalid opcodes
%endrep
// 0xc0-0xcf
%rep 16
JUMPTABLE panic // 0xc0-0xcf are invalid opcodes
%endrep
// 0xd0-0xdf
%rep 16
JUMPTABLE panic // 0xd0-0xdf are invalid opcodes
%endrep
// 0xe0-0xef
%rep 16
JUMPTABLE panic // 0xe0-0xef are invalid opcodes
%endrep
// 0xf0-0xff
JUMPTABLE sys_create
JUMPTABLE sys_call
JUMPTABLE sys_callcode
JUMPTABLE sys_return
JUMPTABLE sys_delegatecall
JUMPTABLE sys_create2
JUMPTABLE panic // 0xf6 is an invalid opcode
JUMPTABLE panic // 0xf7 is an invalid opcode
JUMPTABLE panic // 0xf8 is an invalid opcode
JUMPTABLE panic // 0xf9 is an invalid opcode
JUMPTABLE sys_staticcall
JUMPTABLE panic // 0xfb is an invalid opcode
JUMPTABLE panic // 0xfc is an invalid opcode
JUMPTABLE sys_revert
JUMPTABLE panic // 0xfe is an invalid opcode
JUMPTABLE sys_selfdestruct

View File

@ -0,0 +1,53 @@
// Labels for unimplemented syscalls to make the kernel assemble.
// Each label should be removed from this file once it is implemented.
// All of these labels are declared consecutively with no code between them,
// so every one of them resolves to the address of the single PANIC below:
// invoking any still-unimplemented syscall aborts execution.
// NOTE(review): handlers not listed here (e.g. sys_stop, sys_exp) are
// presumably defined in other asm files — confirm against the aggregator list.
global sys_sdiv:
global sys_smod:
global sys_signextend:
global sys_slt:
global sys_sgt:
global sys_sar:
global sys_keccak256:
global sys_address:
global sys_balance:
global sys_origin:
global sys_caller:
global sys_callvalue:
global sys_calldataload:
global sys_calldatasize:
global sys_calldatacopy:
global sys_codesize:
global sys_codecopy:
global sys_gasprice:
global sys_extcodesize:
global sys_extcodecopy:
global sys_returndatasize:
global sys_returndatacopy:
global sys_extcodehash:
global sys_blockhash:
global sys_coinbase:
global sys_timestamp:
global sys_number:
global sys_prevrandao:
global sys_gaslimit:
global sys_chainid:
global sys_selfbalance:
global sys_basefee:
global sys_mload:
global sys_mstore:
global sys_mstore8:
global sys_sload:
global sys_sstore:
global sys_msize:
global sys_log0:
global sys_log1:
global sys_log2:
global sys_log3:
global sys_log4:
global sys_create:
global sys_call:
global sys_callcode:
global sys_delegatecall:
global sys_create2:
global sys_staticcall:
PANIC

View File

@ -300,11 +300,29 @@ fn find_labels(
}
Item::StandardOp(_) => *offset += 1,
Item::Bytes(bytes) => *offset += bytes.len(),
Item::Jumptable(labels) => *offset += labels.len() * (BYTES_PER_OFFSET as usize),
}
}
local_labels
}
/// Resolves `label` to its code offset, checking local labels first and
/// falling back to the global label table; panics if the label is undefined.
///
/// Returns the `BYTES_PER_OFFSET` least-significant bytes of the offset in
/// big-endian order — the byte layout used for label push targets and
/// `JUMPTABLE` entries.
fn look_up_label(
label: &String,
local_labels: &HashMap<String, usize>,
global_labels: &HashMap<String, usize>,
) -> Vec<u8> {
let offset = local_labels
.get(label)
.or_else(|| global_labels.get(label))
.unwrap_or_else(|| panic!("No such label: {label}"));
// We want the BYTES_PER_OFFSET least significant bytes in BE order.
// It's easiest to rev the first BYTES_PER_OFFSET bytes of the LE encoding.
(0..BYTES_PER_OFFSET)
.rev()
.map(|i| offset.to_le_bytes()[i as usize])
.collect()
}
fn assemble_file(
body: Vec<Item>,
code: &mut Vec<u8>,
@ -327,18 +345,7 @@ fn assemble_file(
Item::Push(target) => {
let target_bytes: Vec<u8> = match target {
PushTarget::Literal(n) => u256_to_trimmed_be_bytes(&n),
PushTarget::Label(label) => {
let offset = local_labels
.get(&label)
.or_else(|| global_labels.get(&label))
.unwrap_or_else(|| panic!("No such label: {label}"));
// We want the BYTES_PER_OFFSET least significant bytes in BE order.
// It's easiest to rev the first BYTES_PER_OFFSET bytes of the LE encoding.
(0..BYTES_PER_OFFSET)
.rev()
.map(|i| offset.to_le_bytes()[i as usize])
.collect()
}
PushTarget::Label(label) => look_up_label(&label, &local_labels, global_labels),
PushTarget::MacroLabel(v) => panic!("Macro label not in a macro: {v}"),
PushTarget::MacroVar(v) => panic!("Variable not in a macro: {v}"),
PushTarget::Constant(c) => panic!("Constant wasn't inlined: {c}"),
@ -353,6 +360,12 @@ fn assemble_file(
code.push(get_opcode(&opcode));
}
Item::Bytes(bytes) => code.extend(bytes),
Item::Jumptable(labels) => {
for label in labels {
let bytes = look_up_label(&label, &local_labels, global_labels);
code.extend(bytes);
}
}
}
}
}

View File

@ -34,6 +34,8 @@ pub(crate) enum Item {
StandardOp(String),
/// Literal hex data; should contain an even number of hex chars.
Bytes(Vec<u8>),
/// Creates a table of addresses from a list of labels.
Jumptable(Vec<String>),
}
/// The left hand side of a %stack stack-manipulation macro.

View File

@ -15,7 +15,7 @@ literal = { literal_hex | literal_decimal }
variable = ${ "$" ~ identifier }
constant = ${ "@" ~ identifier }
item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | push_instruction | prover_input_instruction | nullary_instruction }
item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | jumptable_item | push_instruction | prover_input_instruction | nullary_instruction }
macro_def = { ^"%macro" ~ identifier ~ paramlist? ~ item* ~ ^"%endmacro" }
macro_call = ${ "%" ~ !((^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") ~ !identifier_char) ~ identifier ~ macro_arglist? }
repeat = { ^"%rep" ~ literal ~ item* ~ ^"%endrep" }
@ -35,6 +35,7 @@ macro_label_decl = ${ "%%" ~ identifier ~ ":" }
macro_label = ${ "%%" ~ identifier }
bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* }
jumptable_item = { ^"JUMPTABLE " ~ identifier ~ ("," ~ identifier)* }
push_instruction = { ^"PUSH " ~ push_target }
push_target = { literal | identifier | macro_label | variable | constant }
prover_input_instruction = { ^"PROVER_INPUT" ~ "(" ~ prover_input_fn ~ ")" }

View File

@ -39,6 +39,9 @@ fn parse_item(item: Pair<Rule>) -> Item {
Item::MacroLabelDeclaration(item.into_inner().next().unwrap().as_str().into())
}
Rule::bytes_item => Item::Bytes(item.into_inner().map(parse_literal_u8).collect()),
Rule::jumptable_item => {
Item::Jumptable(item.into_inner().map(|i| i.as_str().into()).collect())
}
Rule::push_instruction => Item::Push(parse_push_target(item.into_inner().next().unwrap())),
Rule::prover_input_instruction => Item::ProverInput(
item.into_inner()

View File

@ -40,73 +40,33 @@ const BASIC_TERNARY_OP: Option<StackBehavior> = Some(StackBehavior {
// except the first `num_pops` and the last `pushes as usize` channels have their read flag and
// address constrained automatically in this file.
const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
stop: None, // TODO
add: BASIC_BINARY_OP,
mul: BASIC_BINARY_OP,
sub: BASIC_BINARY_OP,
div: BASIC_BINARY_OP,
sdiv: BASIC_BINARY_OP,
mod_: BASIC_BINARY_OP,
smod: BASIC_BINARY_OP,
addmod: BASIC_TERNARY_OP,
mulmod: BASIC_TERNARY_OP,
exp: None, // TODO
signextend: BASIC_BINARY_OP,
addfp254: BASIC_BINARY_OP,
mulfp254: BASIC_BINARY_OP,
subfp254: BASIC_BINARY_OP,
lt: BASIC_BINARY_OP,
gt: BASIC_BINARY_OP,
slt: BASIC_BINARY_OP,
sgt: BASIC_BINARY_OP,
eq: BASIC_BINARY_OP,
iszero: BASIC_UNARY_OP,
and: BASIC_BINARY_OP,
or: BASIC_BINARY_OP,
xor: BASIC_BINARY_OP,
not: BASIC_TERNARY_OP,
not: BASIC_UNARY_OP,
byte: BASIC_BINARY_OP,
shl: BASIC_BINARY_OP,
shr: BASIC_BINARY_OP,
sar: BASIC_BINARY_OP,
keccak256: None, // TODO
keccak_general: None, // TODO
address: None, // TODO
balance: None, // TODO
origin: None, // TODO
caller: None, // TODO
callvalue: None, // TODO
calldataload: None, // TODO
calldatasize: None, // TODO
calldatacopy: None, // TODO
codesize: None, // TODO
codecopy: None, // TODO
gasprice: None, // TODO
extcodesize: None, // TODO
extcodecopy: None, // TODO
returndatasize: None, // TODO
returndatacopy: None, // TODO
extcodehash: None, // TODO
blockhash: None, // TODO
coinbase: None, // TODO
timestamp: None, // TODO
number: None, // TODO
difficulty: None, // TODO
gaslimit: None, // TODO
chainid: None, // TODO
selfbalance: None, // TODO
basefee: None, // TODO
prover_input: None, // TODO
pop: None, // TODO
mload: None, // TODO
mstore: None, // TODO
mstore8: None, // TODO
sload: None, // TODO
sstore: None, // TODO
jump: None, // TODO
jumpi: None, // TODO
pc: None, // TODO
msize: None, // TODO
gas: None, // TODO
jumpdest: None, // TODO
get_state_root: None, // TODO
@ -116,27 +76,17 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
push: None, // TODO
dup: None,
swap: None,
log0: None, // TODO
log1: None, // TODO
log2: None, // TODO
log3: None, // TODO
log4: None, // TODO
create: None, // TODO
call: None, // TODO
callcode: None, // TODO
return_: None, // TODO
delegatecall: None, // TODO
create2: None, // TODO
get_context: None, // TODO
set_context: None, // TODO
consume_gas: None, // TODO
exit_kernel: None, // TODO
staticcall: None, // TODO
mload_general: None, // TODO
mstore_general: None, // TODO
revert: None, // TODO
selfdestruct: None, // TODO
invalid: None, // TODO
syscall: Some(StackBehavior {
num_pops: 0,
pushes: true,
disable_other_channels: false,
}),
};
fn eval_packed_one<P: PackedField>(

View File

@ -2,62 +2,81 @@
//!
//! These are usually the ones that are too complicated to implement in one CPU table row.
use once_cell::sync::Lazy;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use static_assertions::const_assert;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::memory::segments::Segment;
const NUM_SYSCALLS: usize = 3;
fn make_syscall_list() -> [(usize, usize); NUM_SYSCALLS] {
let kernel = Lazy::force(&KERNEL);
[
(COL_MAP.op.stop, "sys_stop"),
(COL_MAP.op.exp, "sys_exp"),
(COL_MAP.op.invalid, "handle_invalid"),
]
.map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name]))
}
static TRAP_LIST: Lazy<[(usize, usize); NUM_SYSCALLS]> = Lazy::new(make_syscall_list);
// Copy the constant but make it `usize`.
const BYTES_PER_OFFSET: usize = crate::cpu::kernel::assembler::BYTES_PER_OFFSET as usize;
const_assert!(BYTES_PER_OFFSET < NUM_GP_CHANNELS); // Reserve one channel for stack push
pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let syscall_list = Lazy::force(&TRAP_LIST);
// 1 if _any_ syscall, else 0.
let should_syscall: P = syscall_list
.iter()
.map(|&(col_index, _)| lv[col_index])
.sum();
let filter = lv.is_cpu_cycle * should_syscall;
let filter = lv.is_cpu_cycle * lv.op.syscall;
// If syscall: set program counter to the handler address
// Note that at most one of the `lv[col_index]`s will be 1 and all others will be 0.
let syscall_dst: P = syscall_list
.iter()
.map(|&(col_index, handler_addr)| {
lv[col_index] * P::Scalar::from_canonical_usize(handler_addr)
})
// Look up the handler in memory
let code_segment = P::Scalar::from_canonical_usize(Segment::Code as usize);
let syscall_jumptable_start =
P::Scalar::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]);
let opcode: P = lv
.opcode_bits
.into_iter()
.enumerate()
.map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i))
.sum();
yield_constr.constraint_transition(filter * (nv.program_counter - syscall_dst));
// If syscall: set kernel mode
let opcode_handler_addr_start =
syscall_jumptable_start + opcode * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET);
for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() {
yield_constr.constraint(filter * (channel.used - P::ONES));
yield_constr.constraint(filter * (channel.is_read - P::ONES));
// Set kernel context and code segment
yield_constr.constraint(filter * channel.addr_context);
yield_constr.constraint(filter * (channel.addr_segment - code_segment));
// Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs.
let limb_address = opcode_handler_addr_start + P::Scalar::from_canonical_usize(i);
yield_constr.constraint(filter * (channel.addr_virtual - limb_address));
}
// Disable unused channels (the last channel is used to push to the stack)
for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] {
yield_constr.constraint(filter * channel.used);
}
// Set program counter to the handler address
// The addresses are big-endian in memory
let target = lv.mem_channels[0..BYTES_PER_OFFSET]
.iter()
.map(|channel| channel.value[0])
.fold(P::ZEROS, |cumul, limb| {
cumul * P::Scalar::from_canonical_u64(256) + limb
});
yield_constr.constraint_transition(filter * (nv.program_counter - target));
// Set kernel mode
yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES));
// Maintain current context
yield_constr.constraint_transition(filter * (nv.context - lv.context));
let output = lv.mem_channels[0].value;
// If syscall: push current PC to stack
// This memory channel is constrained in `stack.rs`.
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
// Push current PC to stack
yield_constr.constraint(filter * (output[0] - lv.program_counter));
// If syscall: push current kernel flag to stack (share register with PC)
// Push current kernel flag to stack (share register with PC)
yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode));
// If syscall: zero the rest of that register
// Zero the rest of that register
for &limb in &output[2..] {
yield_constr.constraint(filter * limb);
}
@ -69,46 +88,111 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let syscall_list = Lazy::force(&TRAP_LIST);
// 1 if _any_ syscall, else 0.
let should_syscall =
builder.add_many_extension(syscall_list.iter().map(|&(col_index, _)| lv[col_index]));
let filter = builder.mul_extension(lv.is_cpu_cycle, should_syscall);
let filter = builder.mul_extension(lv.is_cpu_cycle, lv.op.syscall);
// If syscall: set program counter to the handler address
// Look up the handler in memory
let code_segment = F::from_canonical_usize(Segment::Code as usize);
let syscall_jumptable_start = builder.constant_extension(
F::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]).into(),
);
let opcode = lv
.opcode_bits
.into_iter()
.rev()
.fold(builder.zero_extension(), |cumul, bit| {
builder.mul_const_add_extension(F::TWO, cumul, bit)
});
let opcode_handler_addr_start = builder.mul_const_add_extension(
F::from_canonical_usize(BYTES_PER_OFFSET),
opcode,
syscall_jumptable_start,
);
for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() {
{
let constr = builder.mul_sub_extension(filter, channel.used, filter);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.mul_sub_extension(filter, channel.is_read, filter);
yield_constr.constraint(builder, constr);
}
// Set kernel context and code segment
{
let constr = builder.mul_extension(filter, channel.addr_context);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.arithmetic_extension(
F::ONE,
-code_segment,
filter,
channel.addr_segment,
filter,
);
yield_constr.constraint(builder, constr);
}
// Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs.
{
let diff = builder.sub_extension(channel.addr_virtual, opcode_handler_addr_start);
let constr = builder.arithmetic_extension(
F::ONE,
-F::from_canonical_usize(i),
filter,
diff,
filter,
);
yield_constr.constraint(builder, constr);
}
}
// Disable unused channels (the last channel is used to push to the stack)
for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] {
let constr = builder.mul_extension(filter, channel.used);
yield_constr.constraint(builder, constr);
}
// Set program counter to the handler address
// The addresses are big-endian in memory
{
// Note that at most one of the `lv[col_index]`s will be 1 and all others will be 0.
let syscall_dst = syscall_list.iter().fold(
builder.zero_extension(),
|cumul, &(col_index, handler_addr)| {
let handler_addr = F::from_canonical_usize(handler_addr);
builder.mul_const_add_extension(handler_addr, lv[col_index], cumul)
},
);
let constr = builder.sub_extension(nv.program_counter, syscall_dst);
let constr = builder.mul_extension(filter, constr);
let target = lv.mem_channels[0..BYTES_PER_OFFSET]
.iter()
.map(|channel| channel.value[0])
.fold(builder.zero_extension(), |cumul, limb| {
builder.mul_const_add_extension(F::from_canonical_u64(256), cumul, limb)
});
let diff = builder.sub_extension(nv.program_counter, target);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint_transition(builder, constr);
}
// If syscall: set kernel mode
// Set kernel mode
{
let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter);
yield_constr.constraint_transition(builder, constr);
}
// Maintain current context
{
let diff = builder.sub_extension(nv.context, lv.context);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint_transition(builder, constr);
}
let output = lv.mem_channels[0].value;
// If syscall: push current PC to stack
// This memory channel is constrained in `stack.rs`.
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
// Push current PC to stack
{
let constr = builder.sub_extension(output[0], lv.program_counter);
let constr = builder.mul_extension(filter, constr);
let diff = builder.sub_extension(output[0], lv.program_counter);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);
}
// If syscall: push current kernel flag to stack (share register with PC)
// Push current kernel flag to stack (share register with PC)
{
let constr = builder.sub_extension(output[1], lv.is_kernel_mode);
let constr = builder.mul_extension(filter, constr);
let diff = builder.sub_extension(output[1], lv.is_kernel_mode);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);
}
// If syscall: zero the rest of that register
// Zero the rest of that register
for &limb in &output[2..] {
let constr = builder.mul_extension(filter, limb);
yield_constr.constraint(builder, constr);