Fork Update (#3)

* Use static `KERNEL` in tests

* Print opcode count

* Update criterion

* Combine all syscalls into one flag (#802)

* Combine all syscalls into one flag

* Minor: typo

* Daniel PR comments

* Check that `le_sum` won't overflow

* security notes

* Test reverse_index_bits

Thanks to Least Authority for this

* clippy

* EVM shift left/right operations (#801)

* First parts of shift implementation.

* Disable range check errors.

* Tidy up ASM.

* Update comments; fix some .sum() expressions.

* First full draft of shift left/right.

* Missed a +1.

* Clippy.

* Address Jacqui's comments.

* Add comment.

* Fix missing filter.

* Address second round of comments from Jacqui.

* Remove signed operation placeholders from arithmetic table. (#812)

Co-authored-by: wborgeaud <williamborgeaud@gmail.com>
Co-authored-by: Daniel Lubarov <daniel@lubarov.com>
Co-authored-by: Jacqueline Nabaglo <jakub@mirprotocol.org>
Co-authored-by: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com>
This commit is contained in:
Brandon H. Gomes 2022-11-15 01:51:29 -05:00 committed by GitHub
parent ea7fbed33a
commit 14c2a6dd1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 942 additions and 474 deletions

View File

@ -47,7 +47,11 @@ Jemalloc is known to cause crashes when a binary compiled for x86 is run on an A
As this is a monorepo, see the individual crates within for license information.
## Disclaimer
## Security
This code has not yet been audited, and should not be used in any production systems.
While Plonky2 is configurable, its defaults generally target 100 bits of security. The default FRI configuration targets 100 bits of *conjectured* security based on the conjecture in [ethSTARK](https://eprint.iacr.org/2021/582).
Plonky2's default hash function is Poseidon, configured with 8 full rounds, 22 partial rounds, a width of 12 field elements (each ~64 bits), and an S-box of `x^7`. [BBLP22](https://tosc.iacr.org/index.php/ToSC/article/view/9850) suggests that this configuration may have around 95 bits of security, falling a bit short of our 100-bit target.

View File

@ -28,10 +28,11 @@ rlp = "0.5.1"
rlp-derive = "0.1.0"
serde = { version = "1.0.144", features = ["derive"] }
sha2 = "0.10.2"
static_assertions = "1.1.0"
tiny-keccak = "2.0.2"
[dev-dependencies]
criterion = "0.3.5"
criterion = "0.4.0"
hex = "0.4.3"
[features]

View File

@ -22,25 +22,20 @@ pub const IS_ADD: usize = 0;
pub const IS_MUL: usize = IS_ADD + 1;
pub const IS_SUB: usize = IS_MUL + 1;
pub const IS_DIV: usize = IS_SUB + 1;
pub const IS_SDIV: usize = IS_DIV + 1;
pub const IS_MOD: usize = IS_SDIV + 1;
pub const IS_SMOD: usize = IS_MOD + 1;
pub const IS_ADDMOD: usize = IS_SMOD + 1;
pub const IS_MOD: usize = IS_DIV + 1;
pub const IS_ADDMOD: usize = IS_MOD + 1;
pub const IS_SUBMOD: usize = IS_ADDMOD + 1;
pub const IS_MULMOD: usize = IS_SUBMOD + 1;
pub const IS_LT: usize = IS_MULMOD + 1;
pub const IS_GT: usize = IS_LT + 1;
pub const IS_SLT: usize = IS_GT + 1;
pub const IS_SGT: usize = IS_SLT + 1;
pub const IS_SHL: usize = IS_SGT + 1;
pub const IS_SHL: usize = IS_GT + 1;
pub const IS_SHR: usize = IS_SHL + 1;
pub const IS_SAR: usize = IS_SHR + 1;
const START_SHARED_COLS: usize = IS_SAR + 1;
const START_SHARED_COLS: usize = IS_SHR + 1;
pub(crate) const ALL_OPERATIONS: [usize; 17] = [
IS_ADD, IS_MUL, IS_SUB, IS_DIV, IS_SDIV, IS_MOD, IS_SMOD, IS_ADDMOD, IS_SUBMOD, IS_MULMOD,
IS_LT, IS_GT, IS_SLT, IS_SGT, IS_SHL, IS_SHR, IS_SAR,
pub(crate) const ALL_OPERATIONS: [usize; 12] = [
IS_ADD, IS_MUL, IS_SUB, IS_DIV, IS_MOD, IS_ADDMOD, IS_SUBMOD, IS_MULMOD, IS_LT, IS_GT, IS_SHL,
IS_SHR,
];
/// Within the Arithmetic Unit, there are shared columns which can be

View File

@ -35,8 +35,6 @@ pub(crate) fn generate<F: RichField>(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize)
IS_LT => u256_sub_br(input0, input1),
// input1 - input0 == diff + br*2^256
IS_GT => u256_sub_br(input1, input0),
IS_SLT => todo!(),
IS_SGT => todo!(),
_ => panic!("op code not a comparison"),
};

View File

@ -1,6 +1,5 @@
use std::ops::{Add, AddAssign, Mul, Neg, Range, Shr, Sub, SubAssign};
use log::error;
use plonky2::field::extension::Extendable;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
@ -11,21 +10,24 @@ use crate::arithmetic::columns::{NUM_ARITH_COLUMNS, N_LIMBS};
/// Emit an error message regarding unchecked range assumptions.
/// Assumes the values in `cols` are `[cols[0], cols[0] + 1, ...,
/// cols[0] + cols.len() - 1]`.
///
/// TODO: Hamish to delete this when he has implemented and integrated
/// range checks.
pub(crate) fn _range_check_error<const RC_BITS: u32>(
file: &str,
line: u32,
cols: Range<usize>,
signedness: &str,
_file: &str,
_line: u32,
_cols: Range<usize>,
_signedness: &str,
) {
error!(
"{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented",
line,
file,
RC_BITS,
signedness,
cols.start,
cols.end - 1,
);
// error!(
// "{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented",
// line,
// file,
// RC_BITS,
// signedness,
// cols.start,
// cols.end - 1,
// );
}
#[macro_export]

View File

@ -9,6 +9,7 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
arithmetic: CpuArithmeticView<T>,
logic: CpuLogicView<T>,
jumps: CpuJumpsView<T>,
shift: CpuShiftView<T>,
}
impl<T: Copy> CpuGeneralColumnsView<T> {
@ -51,6 +52,16 @@ impl<T: Copy> CpuGeneralColumnsView<T> {
pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView<T> {
unsafe { &mut self.jumps }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn shift(&self) -> &CpuShiftView<T> {
unsafe { &self.shift }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView<T> {
unsafe { &mut self.shift }
}
}
impl<T: Copy + PartialEq> PartialEq<Self> for CpuGeneralColumnsView<T> {
@ -144,5 +155,12 @@ pub(crate) struct CpuJumpsView<T: Copy> {
pub(crate) should_trap: T,
}
#[derive(Copy, Clone)]
pub(crate) struct CpuShiftView<T: Copy> {
// For a shift amount of displacement: [T], this is the inverse of
// sum(displacement[1..]) or zero if the sum is zero.
pub(crate) high_limb_sum_inv: T,
}
// `u8` is guaranteed to have a `size_of` of 1.
pub const NUM_SHARED_COLUMNS: usize = size_of::<CpuGeneralColumnsView<u8>>();

View File

@ -7,106 +7,59 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks};
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
pub struct OpsColumnsView<T> {
pub stop: T,
// TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag
pub add: T,
pub mul: T,
pub sub: T,
pub div: T,
pub sdiv: T,
pub mod_: T,
pub smod: T,
// TODO: combine ADDMOD, MULMOD into one flag
pub addmod: T,
pub mulmod: T,
pub exp: T,
pub signextend: T,
pub addfp254: T,
pub mulfp254: T,
pub subfp254: T,
pub lt: T,
pub gt: T,
pub slt: T,
pub sgt: T,
pub eq: T, // Note: This column must be 0 when is_cpu_cycle = 0.
pub iszero: T, // Note: This column must be 0 when is_cpu_cycle = 0.
// TODO: combine AND, OR, and XOR into one flag
pub and: T,
pub or: T,
pub xor: T,
pub not: T,
pub byte: T,
// TODO: combine SHL and SHR into one flag
pub shl: T,
pub shr: T,
pub sar: T,
pub keccak256: T,
pub keccak_general: T,
pub address: T,
pub balance: T,
pub origin: T,
pub caller: T,
pub callvalue: T,
pub calldataload: T,
pub calldatasize: T,
pub calldatacopy: T,
pub codesize: T,
pub codecopy: T,
pub gasprice: T,
pub extcodesize: T,
pub extcodecopy: T,
pub returndatasize: T,
pub returndatacopy: T,
pub extcodehash: T,
pub blockhash: T,
pub coinbase: T,
pub timestamp: T,
pub number: T,
pub difficulty: T,
pub gaslimit: T,
pub chainid: T,
pub selfbalance: T,
pub basefee: T,
pub prover_input: T,
pub pop: T,
pub mload: T,
pub mstore: T,
pub mstore8: T,
pub sload: T,
pub sstore: T,
// TODO: combine JUMP and JUMPI into one flag
pub jump: T, // Note: This column must be 0 when is_cpu_cycle = 0.
pub jumpi: T, // Note: This column must be 0 when is_cpu_cycle = 0.
pub pc: T,
pub msize: T,
pub gas: T,
pub jumpdest: T,
// TODO: combine GET_STATE_ROOT and SET_STATE_ROOT into one flag
pub get_state_root: T,
pub set_state_root: T,
// TODO: combine GET_RECEIPT_ROOT and SET_RECEIPT_ROOT into one flag
pub get_receipt_root: T,
pub set_receipt_root: T,
pub push: T,
pub dup: T,
pub swap: T,
pub log0: T,
pub log1: T,
pub log2: T,
pub log3: T,
pub log4: T,
// PANIC does not get a flag; it fails at the decode stage.
pub create: T,
pub call: T,
pub callcode: T,
pub return_: T,
pub delegatecall: T,
pub create2: T,
// TODO: combine GET_CONTEXT and SET_CONTEXT into one flag
pub get_context: T,
pub set_context: T,
pub consume_gas: T,
pub exit_kernel: T,
pub staticcall: T,
// TODO: combine MLOAD_GENERAL and MSTORE_GENERAL into one flag
pub mload_general: T,
pub mstore_general: T,
pub revert: T,
pub selfdestruct: T,
// TODO: this doesn't actually need its own flag. We can just do `1 - sum(all other flags)`.
pub invalid: T,
pub syscall: T,
}
// `u8` is guaranteed to have a `size_of` of 1.

View File

@ -8,36 +8,49 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
use crate::cpu::kernel::aggregator::KERNEL;
// TODO: This list is incomplete.
const NATIVE_INSTRUCTIONS: [usize; 28] = [
const NATIVE_INSTRUCTIONS: [usize; 37] = [
COL_MAP.op.add,
COL_MAP.op.mul,
COL_MAP.op.sub,
COL_MAP.op.div,
COL_MAP.op.sdiv,
COL_MAP.op.mod_,
COL_MAP.op.smod,
COL_MAP.op.addmod,
COL_MAP.op.mulmod,
COL_MAP.op.signextend,
COL_MAP.op.addfp254,
COL_MAP.op.mulfp254,
COL_MAP.op.subfp254,
COL_MAP.op.lt,
COL_MAP.op.gt,
COL_MAP.op.slt,
COL_MAP.op.sgt,
COL_MAP.op.eq,
COL_MAP.op.iszero,
COL_MAP.op.and,
COL_MAP.op.or,
COL_MAP.op.xor,
COL_MAP.op.not,
COL_MAP.op.byte,
COL_MAP.op.shl,
COL_MAP.op.shr,
COL_MAP.op.sar,
COL_MAP.op.keccak_general,
COL_MAP.op.prover_input,
COL_MAP.op.pop,
// not JUMP (need to jump)
// not JUMPI (possible need to jump)
COL_MAP.op.pc,
COL_MAP.op.gas,
COL_MAP.op.jumpdest,
COL_MAP.op.get_state_root,
COL_MAP.op.set_state_root,
COL_MAP.op.get_receipt_root,
COL_MAP.op.set_receipt_root,
// not PUSH (need to increment by more than 1)
COL_MAP.op.dup,
COL_MAP.op.swap,
COL_MAP.op.get_context,
COL_MAP.op.set_context,
COL_MAP.op.consume_gas,
// not EXIT_KERNEL (performs a jump)
COL_MAP.op.mload_general,
COL_MAP.op.mstore_general,
// not SYSCALL (performs a jump)
];
fn get_halt_pcs<F: Field>() -> (F, F) {

View File

@ -11,8 +11,8 @@ use plonky2::hash::hash_types::RichField;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS};
use crate::cpu::{
bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, simple_logic, stack,
stack_bounds, syscalls,
bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, shift, simple_logic,
stack, stack_bounds, syscalls,
};
use crate::cross_table_lookup::Column;
use crate::memory::segments::Segment;
@ -151,6 +151,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
jumps::eval_packed(local_values, next_values, yield_constr);
membus::eval_packed(local_values, yield_constr);
modfp254::eval_packed(local_values, yield_constr);
shift::eval_packed(local_values, yield_constr);
simple_logic::eval_packed(local_values, yield_constr);
stack::eval_packed(local_values, yield_constr);
stack_bounds::eval_packed(local_values, yield_constr);
@ -172,6 +173,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
jumps::eval_ext_circuit(builder, local_values, next_values, yield_constr);
membus::eval_ext_circuit(builder, local_values, yield_constr);
modfp254::eval_ext_circuit(builder, local_values, yield_constr);
shift::eval_ext_circuit(builder, local_values, yield_constr);
simple_logic::eval_ext_circuit(builder, local_values, yield_constr);
stack::eval_ext_circuit(builder, local_values, yield_constr);
stack_bounds::eval_ext_circuit(builder, local_values, yield_constr);

View File

@ -22,27 +22,20 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP};
/// behavior.
/// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to
/// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid.
const OPCODES: [(u8, usize, bool, usize); 96] = [
const OPCODES: [(u8, usize, bool, usize); 42] = [
// (start index of block, number of top bits to check (log2), kernel-only, flag column)
(0x00, 0, false, COL_MAP.op.stop),
(0x01, 0, false, COL_MAP.op.add),
(0x02, 0, false, COL_MAP.op.mul),
(0x03, 0, false, COL_MAP.op.sub),
(0x04, 0, false, COL_MAP.op.div),
(0x05, 0, false, COL_MAP.op.sdiv),
(0x06, 0, false, COL_MAP.op.mod_),
(0x07, 0, false, COL_MAP.op.smod),
(0x08, 0, false, COL_MAP.op.addmod),
(0x09, 0, false, COL_MAP.op.mulmod),
(0x0a, 0, false, COL_MAP.op.exp),
(0x0b, 0, false, COL_MAP.op.signextend),
(0x0c, 0, true, COL_MAP.op.addfp254),
(0x0d, 0, true, COL_MAP.op.mulfp254),
(0x0e, 0, true, COL_MAP.op.subfp254),
(0x10, 0, false, COL_MAP.op.lt),
(0x11, 0, false, COL_MAP.op.gt),
(0x12, 0, false, COL_MAP.op.slt),
(0x13, 0, false, COL_MAP.op.sgt),
(0x14, 0, false, COL_MAP.op.eq),
(0x15, 0, false, COL_MAP.op.iszero),
(0x16, 0, false, COL_MAP.op.and),
@ -52,45 +45,12 @@ const OPCODES: [(u8, usize, bool, usize); 96] = [
(0x1a, 0, false, COL_MAP.op.byte),
(0x1b, 0, false, COL_MAP.op.shl),
(0x1c, 0, false, COL_MAP.op.shr),
(0x1d, 0, false, COL_MAP.op.sar),
(0x20, 0, false, COL_MAP.op.keccak256),
(0x21, 0, true, COL_MAP.op.keccak_general),
(0x30, 0, false, COL_MAP.op.address),
(0x31, 0, false, COL_MAP.op.balance),
(0x32, 0, false, COL_MAP.op.origin),
(0x33, 0, false, COL_MAP.op.caller),
(0x34, 0, false, COL_MAP.op.callvalue),
(0x35, 0, false, COL_MAP.op.calldataload),
(0x36, 0, false, COL_MAP.op.calldatasize),
(0x37, 0, false, COL_MAP.op.calldatacopy),
(0x38, 0, false, COL_MAP.op.codesize),
(0x39, 0, false, COL_MAP.op.codecopy),
(0x3a, 0, false, COL_MAP.op.gasprice),
(0x3b, 0, false, COL_MAP.op.extcodesize),
(0x3c, 0, false, COL_MAP.op.extcodecopy),
(0x3d, 0, false, COL_MAP.op.returndatasize),
(0x3e, 0, false, COL_MAP.op.returndatacopy),
(0x3f, 0, false, COL_MAP.op.extcodehash),
(0x40, 0, false, COL_MAP.op.blockhash),
(0x41, 0, false, COL_MAP.op.coinbase),
(0x42, 0, false, COL_MAP.op.timestamp),
(0x43, 0, false, COL_MAP.op.number),
(0x44, 0, false, COL_MAP.op.difficulty),
(0x45, 0, false, COL_MAP.op.gaslimit),
(0x46, 0, false, COL_MAP.op.chainid),
(0x47, 0, false, COL_MAP.op.selfbalance),
(0x48, 0, false, COL_MAP.op.basefee),
(0x49, 0, true, COL_MAP.op.prover_input),
(0x50, 0, false, COL_MAP.op.pop),
(0x51, 0, false, COL_MAP.op.mload),
(0x52, 0, false, COL_MAP.op.mstore),
(0x53, 0, false, COL_MAP.op.mstore8),
(0x54, 0, false, COL_MAP.op.sload),
(0x55, 0, false, COL_MAP.op.sstore),
(0x56, 0, false, COL_MAP.op.jump),
(0x57, 0, false, COL_MAP.op.jumpi),
(0x58, 0, false, COL_MAP.op.pc),
(0x59, 0, false, COL_MAP.op.msize),
(0x5a, 0, false, COL_MAP.op.gas),
(0x5b, 0, false, COL_MAP.op.jumpdest),
(0x5c, 0, true, COL_MAP.op.get_state_root),
@ -100,27 +60,12 @@ const OPCODES: [(u8, usize, bool, usize); 96] = [
(0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f
(0x80, 4, false, COL_MAP.op.dup), // 0x80-0x8f
(0x90, 4, false, COL_MAP.op.swap), // 0x90-0x9f
(0xa0, 0, false, COL_MAP.op.log0),
(0xa1, 0, false, COL_MAP.op.log1),
(0xa2, 0, false, COL_MAP.op.log2),
(0xa3, 0, false, COL_MAP.op.log3),
(0xa4, 0, false, COL_MAP.op.log4),
// Opcode 0xa5 is PANIC when Kernel. Make the proof unverifiable by giving it no flag to decode to.
(0xf0, 0, false, COL_MAP.op.create),
(0xf1, 0, false, COL_MAP.op.call),
(0xf2, 0, false, COL_MAP.op.callcode),
(0xf3, 0, false, COL_MAP.op.return_),
(0xf4, 0, false, COL_MAP.op.delegatecall),
(0xf5, 0, false, COL_MAP.op.create2),
(0xf6, 0, true, COL_MAP.op.get_context),
(0xf7, 0, true, COL_MAP.op.set_context),
(0xf8, 0, true, COL_MAP.op.consume_gas),
(0xf9, 0, true, COL_MAP.op.exit_kernel),
(0xfa, 0, false, COL_MAP.op.staticcall),
(0xfb, 0, true, COL_MAP.op.mload_general),
(0xfc, 0, true, COL_MAP.op.mstore_general),
(0xfd, 0, false, COL_MAP.op.revert),
(0xff, 0, false, COL_MAP.op.selfdestruct),
];
/// Bitfield of invalid opcodes, in little-endian order.
@ -188,19 +133,12 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
assert!(kernel <= 1);
let kernel = kernel != 0;
let mut any_flag_set = false;
for (oc, block_length, kernel_only, col) in OPCODES {
let available = !kernel_only || kernel;
let opcode_match = top_bits[8 - block_length] == oc;
let flag = available && opcode_match;
lv[col] = F::from_bool(flag);
if flag && any_flag_set {
panic!("opcode matched multiple flags");
}
any_flag_set = any_flag_set || flag;
}
// is_invalid is a catch-all for opcodes we can't decode.
lv.op.invalid = F::from_bool(!any_flag_set);
}
/// Break up an opcode (which is 8 bits long) into its eight bits.
@ -238,14 +176,12 @@ pub fn eval_packed_generic<P: PackedField>(
let flag = lv[flag_col];
yield_constr.constraint(cycle_filter * flag * (flag - P::ONES));
}
yield_constr.constraint(cycle_filter * lv.op.invalid * (lv.op.invalid - P::ONES));
// Now check that exactly one is 1.
// Now check that they sum to 0 or 1.
let flag_sum: P = OPCODES
.into_iter()
.map(|(_, _, _, flag_col)| lv[flag_col])
.sum::<P>()
+ lv.op.invalid;
yield_constr.constraint(cycle_filter * (P::ONES - flag_sum));
.sum::<P>();
yield_constr.constraint(cycle_filter * flag_sum * (flag_sum - P::ONES));
// Finally, classify all opcodes, together with the kernel flag, into blocks
for (oc, block_length, kernel_only, col) in OPCODES {
@ -308,20 +244,15 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_extension(cycle_filter, constr);
yield_constr.constraint(builder, constr);
}
// Now check that they sum to 0 or 1.
{
let constr = builder.mul_sub_extension(lv.op.invalid, lv.op.invalid, lv.op.invalid);
let constr = builder.mul_extension(cycle_filter, constr);
yield_constr.constraint(builder, constr);
}
// Now check that exactly one is 1.
{
let mut constr = builder.one_extension();
let mut flag_sum = builder.zero_extension();
for (_, _, _, flag_col) in OPCODES {
let flag = lv[flag_col];
constr = builder.sub_extension(constr, flag);
flag_sum = builder.add_extension(flag_sum, flag);
}
constr = builder.sub_extension(constr, lv.op.invalid);
constr = builder.mul_extension(cycle_filter, constr);
let constr = builder.mul_sub_extension(flag_sum, flag_sum, flag_sum);
let constr = builder.mul_extension(cycle_filter, constr);
yield_constr.constraint(builder, constr);
}

View File

@ -18,6 +18,8 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/core/invalid.asm"),
include_str!("asm/core/nonce.asm"),
include_str!("asm/core/process_txn.asm"),
include_str!("asm/core/syscall.asm"),
include_str!("asm/core/syscall_stubs.asm"),
include_str!("asm/core/terminate.asm"),
include_str!("asm/core/transfer.asm"),
include_str!("asm/core/util.asm"),
@ -75,6 +77,7 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/sha2/store_pad.asm"),
include_str!("asm/sha2/temp_words.asm"),
include_str!("asm/sha2/write_length.asm"),
include_str!("asm/shift.asm"),
include_str!("asm/transactions/router.asm"),
include_str!("asm/transactions/type_0.asm"),
include_str!("asm/transactions/type_1.asm"),

View File

@ -0,0 +1,160 @@
global syscall_jumptable:
// 0x00-0x0f
JUMPTABLE sys_stop
JUMPTABLE panic // add is implemented natively
JUMPTABLE panic // mul is implemented natively
JUMPTABLE panic // sub is implemented natively
JUMPTABLE panic // div is implemented natively
JUMPTABLE sys_sdiv
JUMPTABLE panic // mod is implemented natively
JUMPTABLE sys_smod
JUMPTABLE panic // addmod is implemented natively
JUMPTABLE panic // mulmod is implemented natively
JUMPTABLE sys_exp
JUMPTABLE sys_signextend
JUMPTABLE panic // 0x0c is an invalid opcode
JUMPTABLE panic // 0x0d is an invalid opcode
JUMPTABLE panic // 0x0e is an invalid opcode
JUMPTABLE panic // 0x0f is an invalid opcode
// 0x10-0x1f
JUMPTABLE panic // lt is implemented natively
JUMPTABLE panic // gt is implemented natively
JUMPTABLE sys_slt
JUMPTABLE sys_sgt
JUMPTABLE panic // eq is implemented natively
JUMPTABLE panic // iszero is implemented natively
JUMPTABLE panic // and is implemented natively
JUMPTABLE panic // or is implemented natively
JUMPTABLE panic // xor is implemented natively
JUMPTABLE panic // not is implemented natively
JUMPTABLE panic // byte is implemented natively
JUMPTABLE panic // shl is implemented natively
JUMPTABLE panic // shr is implemented natively
JUMPTABLE sys_sar
JUMPTABLE panic // 0x1e is an invalid opcode
JUMPTABLE panic // 0x1f is an invalid opcode
// 0x20-0x2f
JUMPTABLE sys_keccak256
%rep 15
JUMPTABLE panic // 0x21-0x2f are invalid opcodes
%endrep
// 0x30-0x3f
JUMPTABLE sys_address
JUMPTABLE sys_balance
JUMPTABLE sys_origin
JUMPTABLE sys_caller
JUMPTABLE sys_callvalue
JUMPTABLE sys_calldataload
JUMPTABLE sys_calldatasize
JUMPTABLE sys_calldatacopy
JUMPTABLE sys_codesize
JUMPTABLE sys_codecopy
JUMPTABLE sys_gasprice
JUMPTABLE sys_extcodesize
JUMPTABLE sys_extcodecopy
JUMPTABLE sys_returndatasize
JUMPTABLE sys_returndatacopy
JUMPTABLE sys_extcodehash
// 0x40-0x4f
JUMPTABLE sys_blockhash
JUMPTABLE sys_coinbase
JUMPTABLE sys_timestamp
JUMPTABLE sys_number
JUMPTABLE sys_prevrandao
JUMPTABLE sys_gaslimit
JUMPTABLE sys_chainid
JUMPTABLE sys_selfbalance
JUMPTABLE sys_basefee
%rep 7
JUMPTABLE panic // 0x49-0x4f are invalid opcodes
%endrep
// 0x50-0x5f
JUMPTABLE panic // pop is implemented natively
JUMPTABLE sys_mload
JUMPTABLE sys_mstore
JUMPTABLE sys_mstore8
JUMPTABLE sys_sload
JUMPTABLE sys_sstore
JUMPTABLE panic // jump is implemented natively
JUMPTABLE panic // jumpi is implemented natively
JUMPTABLE panic // pc is implemented natively
JUMPTABLE sys_msize
JUMPTABLE panic // gas is implemented natively
JUMPTABLE panic // jumpdest is implemented natively
JUMPTABLE panic // 0x5c is an invalid opcode
JUMPTABLE panic // 0x5d is an invalid opcode
JUMPTABLE panic // 0x5e is an invalid opcode
JUMPTABLE panic // 0x5f is an invalid opcode
// 0x60-0x6f
%rep 16
JUMPTABLE panic // push1-push16 are implemented natively
%endrep
// 0x70-0x7f
%rep 16
JUMPTABLE panic // push17-push32 are implemented natively
%endrep
// 0x80-0x8f
%rep 16
JUMPTABLE panic // dup1-dup16 are implemented natively
%endrep
// 0x90-0x9f
%rep 16
JUMPTABLE panic // swap1-swap16 are implemented natively
%endrep
// 0xa0-0xaf
JUMPTABLE sys_log0
JUMPTABLE sys_log1
JUMPTABLE sys_log2
JUMPTABLE sys_log3
JUMPTABLE sys_log4
%rep 11
JUMPTABLE panic // 0xa5-0xaf are invalid opcodes
%endrep
// 0xb0-0xbf
%rep 16
JUMPTABLE panic // 0xb0-0xbf are invalid opcodes
%endrep
// 0xc0-0xcf
%rep 16
JUMPTABLE panic // 0xc0-0xcf are invalid opcodes
%endrep
// 0xd0-0xdf
%rep 16
JUMPTABLE panic // 0xd0-0xdf are invalid opcodes
%endrep
// 0xe0-0xef
%rep 16
JUMPTABLE panic // 0xe0-0xef are invalid opcodes
%endrep
// 0xf0-0xff
JUMPTABLE sys_create
JUMPTABLE sys_call
JUMPTABLE sys_callcode
JUMPTABLE sys_return
JUMPTABLE sys_delegatecall
JUMPTABLE sys_create2
JUMPTABLE panic // 0xf6 is an invalid opcode
JUMPTABLE panic // 0xf7 is an invalid opcode
JUMPTABLE panic // 0xf8 is an invalid opcode
JUMPTABLE panic // 0xf9 is an invalid opcode
JUMPTABLE sys_staticcall
JUMPTABLE panic // 0xfb is an invalid opcode
JUMPTABLE panic // 0xfc is an invalid opcode
JUMPTABLE sys_revert
JUMPTABLE panic // 0xfe is an invalid opcode
JUMPTABLE sys_selfdestruct

View File

@ -0,0 +1,53 @@
// Labels for unimplemented syscalls to make the kernel assemble.
// Each label should be removed from this file once it is implemented.
global sys_sdiv:
global sys_smod:
global sys_signextend:
global sys_slt:
global sys_sgt:
global sys_sar:
global sys_keccak256:
global sys_address:
global sys_balance:
global sys_origin:
global sys_caller:
global sys_callvalue:
global sys_calldataload:
global sys_calldatasize:
global sys_calldatacopy:
global sys_codesize:
global sys_codecopy:
global sys_gasprice:
global sys_extcodesize:
global sys_extcodecopy:
global sys_returndatasize:
global sys_returndatacopy:
global sys_extcodehash:
global sys_blockhash:
global sys_coinbase:
global sys_timestamp:
global sys_number:
global sys_prevrandao:
global sys_gaslimit:
global sys_chainid:
global sys_selfbalance:
global sys_basefee:
global sys_mload:
global sys_mstore:
global sys_mstore8:
global sys_sload:
global sys_sstore:
global sys_msize:
global sys_log0:
global sys_log1:
global sys_log2:
global sys_log3:
global sys_log4:
global sys_create:
global sys_call:
global sys_callcode:
global sys_delegatecall:
global sys_create2:
global sys_staticcall:
PANIC

View File

@ -1,5 +1,7 @@
global main:
// First, load all MPT data from the prover.
// First, initialise the shift table
%shift_table_init
// Second, load all MPT data from the prover.
PUSH txn_loop
%jump(load_all_mpts)

View File

@ -0,0 +1,25 @@
/// Initialise the lookup table of binary powers for doing left/right shifts
///
/// Specifically, set SHIFT_TABLE_SEGMENT[i] = 2^i for i = 0, 1, ..., 255.
%macro shift_table_init
// Seed the stack with the first (context, segment, offset, value) group:
// value 2^0 at offset 0 of SHIFT_TABLE_SEGMENT, in the kernel context.
push 1 // 2^0
push 0 // initial offset is zero
push @SEGMENT_SHIFT_TABLE // segment
dup2 // kernel context is 0
// First pass: build all 256 (context, segment, offset, value) groups on the
// stack, doubling the value and incrementing the offset each iteration. The
// stores themselves are deferred to the second pass below.
%rep 255
// stack: context, segment, ost_i, 2^i
dup4
dup1
add
// stack: 2^(i+1), context, segment, ost_i, 2^i
dup4
%increment
// stack: ost_(i+1), 2^(i+1), context, segment, ost_i, 2^i
dup4
dup4
// stack: context, segment, ost_(i+1), 2^(i+1), context, segment, ost_i, 2^i
%endrep
// Second pass: each MSTORE_GENERAL consumes one (context, segment, offset,
// value) group from the top of the stack, writing all 256 table entries.
%rep 256
mstore_general
%endrep
%endmacro

View File

@ -300,11 +300,29 @@ fn find_labels(
}
Item::StandardOp(_) => *offset += 1,
Item::Bytes(bytes) => *offset += bytes.len(),
Item::Jumptable(labels) => *offset += labels.len() * (BYTES_PER_OFFSET as usize),
}
}
local_labels
}
fn look_up_label(
label: &String,
local_labels: &HashMap<String, usize>,
global_labels: &HashMap<String, usize>,
) -> Vec<u8> {
let offset = local_labels
.get(label)
.or_else(|| global_labels.get(label))
.unwrap_or_else(|| panic!("No such label: {label}"));
// We want the BYTES_PER_OFFSET least significant bytes in BE order.
// It's easiest to rev the first BYTES_PER_OFFSET bytes of the LE encoding.
(0..BYTES_PER_OFFSET)
.rev()
.map(|i| offset.to_le_bytes()[i as usize])
.collect()
}
fn assemble_file(
body: Vec<Item>,
code: &mut Vec<u8>,
@ -327,18 +345,7 @@ fn assemble_file(
Item::Push(target) => {
let target_bytes: Vec<u8> = match target {
PushTarget::Literal(n) => u256_to_trimmed_be_bytes(&n),
PushTarget::Label(label) => {
let offset = local_labels
.get(&label)
.or_else(|| global_labels.get(&label))
.unwrap_or_else(|| panic!("No such label: {label}"));
// We want the BYTES_PER_OFFSET least significant bytes in BE order.
// It's easiest to rev the first BYTES_PER_OFFSET bytes of the LE encoding.
(0..BYTES_PER_OFFSET)
.rev()
.map(|i| offset.to_le_bytes()[i as usize])
.collect()
}
PushTarget::Label(label) => look_up_label(&label, &local_labels, global_labels),
PushTarget::MacroLabel(v) => panic!("Macro label not in a macro: {v}"),
PushTarget::MacroVar(v) => panic!("Variable not in a macro: {v}"),
PushTarget::Constant(c) => panic!("Constant wasn't inlined: {c}"),
@ -353,6 +360,12 @@ fn assemble_file(
code.push(get_opcode(&opcode));
}
Item::Bytes(bytes) => code.extend(bytes),
Item::Jumptable(labels) => {
for label in labels {
let bytes = look_up_label(&label, &local_labels, global_labels);
code.extend(bytes);
}
}
}
}
}

View File

@ -34,6 +34,8 @@ pub(crate) enum Item {
StandardOp(String),
/// Literal hex data; should contain an even number of hex chars.
Bytes(Vec<u8>),
/// Creates a table of addresses from a list of labels.
Jumptable(Vec<String>),
}
/// The left hand side of a %stack stack-manipulation macro.

View File

@ -15,7 +15,7 @@ literal = { literal_hex | literal_decimal }
variable = ${ "$" ~ identifier }
constant = ${ "@" ~ identifier }
item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | push_instruction | prover_input_instruction | nullary_instruction }
item = { macro_def | macro_call | repeat | stack | global_label_decl | local_label_decl | macro_label_decl | bytes_item | jumptable_item | push_instruction | prover_input_instruction | nullary_instruction }
macro_def = { ^"%macro" ~ identifier ~ paramlist? ~ item* ~ ^"%endmacro" }
macro_call = ${ "%" ~ !((^"macro" | ^"endmacro" | ^"rep" | ^"endrep" | ^"stack") ~ !identifier_char) ~ identifier ~ macro_arglist? }
repeat = { ^"%rep" ~ literal ~ item* ~ ^"%endrep" }
@ -35,6 +35,7 @@ macro_label_decl = ${ "%%" ~ identifier ~ ":" }
macro_label = ${ "%%" ~ identifier }
bytes_item = { ^"BYTES " ~ literal ~ ("," ~ literal)* }
jumptable_item = { ^"JUMPTABLE " ~ identifier ~ ("," ~ identifier)* }
push_instruction = { ^"PUSH " ~ push_target }
push_target = { literal | identifier | macro_label | variable | constant }
prover_input_instruction = { ^"PROVER_INPUT" ~ "(" ~ prover_input_fn ~ ")" }

View File

@ -8,7 +8,6 @@ use keccak_hash::keccak;
use plonky2::field::goldilocks_field::GoldilocksField;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::assembler::Kernel;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField;
@ -78,19 +77,18 @@ pub struct Interpreter<'a> {
pub(crate) halt_offsets: Vec<usize>,
pub(crate) debug_offsets: Vec<usize>,
running: bool,
opcode_count: [usize; 0x100],
}
pub fn run_with_kernel(
// TODO: Remove param and just use KERNEL.
kernel: &Kernel,
pub fn run_interpreter(
initial_offset: usize,
initial_stack: Vec<U256>,
) -> anyhow::Result<Interpreter> {
) -> anyhow::Result<Interpreter<'static>> {
run(
&kernel.code,
&KERNEL.code,
initial_offset,
initial_stack,
&kernel.prover_inputs,
&KERNEL.prover_inputs,
)
}
@ -132,6 +130,7 @@ impl<'a> Interpreter<'a> {
halt_offsets: vec![DEFAULT_HALT_OFFSET],
debug_offsets: vec![],
running: false,
opcode_count: [0; 0x100],
}
}
@ -140,6 +139,12 @@ impl<'a> Interpreter<'a> {
while self.running {
self.run_opcode()?;
}
println!("Opcode count:");
for i in 0..0x100 {
if self.opcode_count[i] > 0 {
println!("{}: {}", get_mnemonic(i as u8), self.opcode_count[i])
}
}
Ok(())
}
@ -223,6 +228,7 @@ impl<'a> Interpreter<'a> {
fn run_opcode(&mut self) -> anyhow::Result<()> {
let opcode = self.code().get(self.offset).byte(0);
self.opcode_count[opcode as usize] += 1;
self.incr(1);
match opcode {
0x00 => self.run_stop(), // "STOP",
@ -690,6 +696,170 @@ fn find_jumpdests(code: &[u8]) -> Vec<usize> {
res
}
/// Maps a kernel-interpreter opcode byte to its human-readable mnemonic.
///
/// Used when printing the per-opcode execution counts accumulated in
/// `Interpreter::opcode_count` after a run. The table covers the standard
/// EVM opcodes plus kernel-only extensions visible in this interpreter
/// (e.g. `ADDFP254`/`MULFP254`/`SUBFP254`, `KECCAK_GENERAL`, `PROVER_INPUT`,
/// `GET_CONTEXT`/`SET_CONTEXT`, `MLOAD_GENERAL`/`MSTORE_GENERAL`, `PANIC`).
///
/// # Panics
/// Panics if `opcode` has no entry in the table, i.e. it is not an opcode
/// this interpreter recognizes (note e.g. that 0x0f, 0x1e-0x1f, 0x47 and
/// 0xa6-0xef have no mapping here).
fn get_mnemonic(opcode: u8) -> &'static str {
    match opcode {
        // 0x00-0x1d: halting, arithmetic, comparison and bitwise ops.
        0x00 => "STOP",
        0x01 => "ADD",
        0x02 => "MUL",
        0x03 => "SUB",
        0x04 => "DIV",
        0x05 => "SDIV",
        0x06 => "MOD",
        0x07 => "SMOD",
        0x08 => "ADDMOD",
        0x09 => "MULMOD",
        0x0a => "EXP",
        0x0b => "SIGNEXTEND",
        // Kernel-only BN254 base-field arithmetic.
        0x0c => "ADDFP254",
        0x0d => "MULFP254",
        0x0e => "SUBFP254",
        0x10 => "LT",
        0x11 => "GT",
        0x12 => "SLT",
        0x13 => "SGT",
        0x14 => "EQ",
        0x15 => "ISZERO",
        0x16 => "AND",
        0x17 => "OR",
        0x18 => "XOR",
        0x19 => "NOT",
        0x1a => "BYTE",
        0x1b => "SHL",
        0x1c => "SHR",
        0x1d => "SAR",
        // Hashing (KECCAK_GENERAL is a kernel-only variant).
        0x20 => "KECCAK256",
        0x21 => "KECCAK_GENERAL",
        // 0x30-0x49: environment / context information.
        0x30 => "ADDRESS",
        0x31 => "BALANCE",
        0x32 => "ORIGIN",
        0x33 => "CALLER",
        0x34 => "CALLVALUE",
        0x35 => "CALLDATALOAD",
        0x36 => "CALLDATASIZE",
        0x37 => "CALLDATACOPY",
        0x38 => "CODESIZE",
        0x39 => "CODECOPY",
        0x3a => "GASPRICE",
        0x3b => "EXTCODESIZE",
        0x3c => "EXTCODECOPY",
        0x3d => "RETURNDATASIZE",
        0x3e => "RETURNDATACOPY",
        0x3f => "EXTCODEHASH",
        0x40 => "BLOCKHASH",
        0x41 => "COINBASE",
        0x42 => "TIMESTAMP",
        0x43 => "NUMBER",
        0x44 => "DIFFICULTY",
        0x45 => "GASLIMIT",
        0x46 => "CHAINID",
        0x48 => "BASEFEE",
        // Kernel-only: pops a request and pushes a value supplied by the prover.
        0x49 => "PROVER_INPUT",
        // 0x50-0x5f: stack, memory, storage and flow operations
        // (0x5c-0x5f are kernel-only trie-root accessors).
        0x50 => "POP",
        0x51 => "MLOAD",
        0x52 => "MSTORE",
        0x53 => "MSTORE8",
        0x54 => "SLOAD",
        0x55 => "SSTORE",
        0x56 => "JUMP",
        0x57 => "JUMPI",
        0x58 => "GETPC",
        0x59 => "MSIZE",
        0x5a => "GAS",
        0x5b => "JUMPDEST",
        0x5c => "GET_STATE_ROOT",
        0x5d => "SET_STATE_ROOT",
        0x5e => "GET_RECEIPT_ROOT",
        0x5f => "SET_RECEIPT_ROOT",
        // 0x60-0x7f: PUSH1..PUSH32.
        0x60 => "PUSH1",
        0x61 => "PUSH2",
        0x62 => "PUSH3",
        0x63 => "PUSH4",
        0x64 => "PUSH5",
        0x65 => "PUSH6",
        0x66 => "PUSH7",
        0x67 => "PUSH8",
        0x68 => "PUSH9",
        0x69 => "PUSH10",
        0x6a => "PUSH11",
        0x6b => "PUSH12",
        0x6c => "PUSH13",
        0x6d => "PUSH14",
        0x6e => "PUSH15",
        0x6f => "PUSH16",
        0x70 => "PUSH17",
        0x71 => "PUSH18",
        0x72 => "PUSH19",
        0x73 => "PUSH20",
        0x74 => "PUSH21",
        0x75 => "PUSH22",
        0x76 => "PUSH23",
        0x77 => "PUSH24",
        0x78 => "PUSH25",
        0x79 => "PUSH26",
        0x7a => "PUSH27",
        0x7b => "PUSH28",
        0x7c => "PUSH29",
        0x7d => "PUSH30",
        0x7e => "PUSH31",
        0x7f => "PUSH32",
        // 0x80-0x8f: DUP1..DUP16.
        0x80 => "DUP1",
        0x81 => "DUP2",
        0x82 => "DUP3",
        0x83 => "DUP4",
        0x84 => "DUP5",
        0x85 => "DUP6",
        0x86 => "DUP7",
        0x87 => "DUP8",
        0x88 => "DUP9",
        0x89 => "DUP10",
        0x8a => "DUP11",
        0x8b => "DUP12",
        0x8c => "DUP13",
        0x8d => "DUP14",
        0x8e => "DUP15",
        0x8f => "DUP16",
        // 0x90-0x9f: SWAP1..SWAP16.
        0x90 => "SWAP1",
        0x91 => "SWAP2",
        0x92 => "SWAP3",
        0x93 => "SWAP4",
        0x94 => "SWAP5",
        0x95 => "SWAP6",
        0x96 => "SWAP7",
        0x97 => "SWAP8",
        0x98 => "SWAP9",
        0x99 => "SWAP10",
        0x9a => "SWAP11",
        0x9b => "SWAP12",
        0x9c => "SWAP13",
        0x9d => "SWAP14",
        0x9e => "SWAP15",
        0x9f => "SWAP16",
        // Logging, plus the kernel-only PANIC opcode.
        0xa0 => "LOG0",
        0xa1 => "LOG1",
        0xa2 => "LOG2",
        0xa3 => "LOG3",
        0xa4 => "LOG4",
        0xa5 => "PANIC",
        // 0xf0-0xff: calls/creates and kernel-only context, gas and
        // general-memory operations.
        0xf0 => "CREATE",
        0xf1 => "CALL",
        0xf2 => "CALLCODE",
        0xf3 => "RETURN",
        0xf4 => "DELEGATECALL",
        0xf5 => "CREATE2",
        0xf6 => "GET_CONTEXT",
        0xf7 => "SET_CONTEXT",
        0xf8 => "CONSUME_GAS",
        0xf9 => "EXIT_KERNEL",
        0xfa => "STATICCALL",
        0xfb => "MLOAD_GENERAL",
        0xfc => "MSTORE_GENERAL",
        0xfd => "REVERT",
        0xfe => "INVALID",
        0xff => "SELFDESTRUCT",
        _ => panic!("Unrecognized opcode {opcode}"),
    }
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;

View File

@ -39,6 +39,9 @@ fn parse_item(item: Pair<Rule>) -> Item {
Item::MacroLabelDeclaration(item.into_inner().next().unwrap().as_str().into())
}
Rule::bytes_item => Item::Bytes(item.into_inner().map(parse_literal_u8).collect()),
Rule::jumptable_item => {
Item::Jumptable(item.into_inner().map(|i| i.as_str().into()).collect())
}
Rule::push_instruction => Item::Push(parse_push_target(item.into_inner().next().unwrap())),
Rule::prover_input_instruction => Item::ProverInput(
item.into_inner()

View File

@ -3,17 +3,16 @@ mod bn {
use anyhow::Result;
use ethereum_types::U256;
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::interpreter::run_with_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::run_interpreter;
use crate::cpu::kernel::tests::u256ify;
#[test]
fn test_ec_ops() -> Result<()> {
// Make sure we can parse and assemble the entire kernel.
let kernel = combined_kernel();
let ec_add = kernel.global_labels["ec_add"];
let ec_double = kernel.global_labels["ec_double"];
let ec_mul = kernel.global_labels["ec_mul"];
let ec_add = KERNEL.global_labels["ec_add"];
let ec_double = KERNEL.global_labels["ec_double"];
let ec_mul = KERNEL.global_labels["ec_mul"];
let identity = ("0x0", "0x0");
let invalid = ("0x0", "0x3"); // Not on curve
let point0 = (
@ -43,110 +42,76 @@ mod bn {
// Standard addition #1
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0, point1.1, point1.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point2.1, point2.0])?);
// Standard addition #2
let initial_stack = u256ify(["0xdeadbeef", point1.1, point1.0, point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point2.1, point2.0])?);
// Standard doubling #1
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0, point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point3.1, point3.0])?);
// Standard doubling #2
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_double, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_double, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point3.1, point3.0])?);
// Standard doubling #3
let initial_stack = u256ify(["0xdeadbeef", "0x2", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point3.1, point3.0])?);
// Addition with identity #1
let initial_stack = u256ify(["0xdeadbeef", identity.1, identity.0, point1.1, point1.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point1.1, point1.0])?);
// Addition with identity #2
let initial_stack = u256ify(["0xdeadbeef", point1.1, point1.0, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point1.1, point1.0])?);
// Addition with identity #3
let initial_stack =
u256ify(["0xdeadbeef", identity.1, identity.0, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([identity.1, identity.0])?);
// Addition with invalid point(s) #1
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0, invalid.1, invalid.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, vec![U256::MAX, U256::MAX]);
// Addition with invalid point(s) #2
let initial_stack = u256ify(["0xdeadbeef", invalid.1, invalid.0, point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, vec![U256::MAX, U256::MAX]);
// Addition with invalid point(s) #3
let initial_stack = u256ify(["0xdeadbeef", invalid.1, invalid.0, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, vec![U256::MAX, U256::MAX]);
// Addition with invalid point(s) #4
let initial_stack = u256ify(["0xdeadbeef", invalid.1, invalid.0, invalid.1, invalid.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, vec![U256::MAX, U256::MAX]);
// Scalar multiplication #1
let initial_stack = u256ify(["0xdeadbeef", s, point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point4.1, point4.0])?);
// Scalar multiplication #2
let initial_stack = u256ify(["0xdeadbeef", "0x0", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([identity.1, identity.0])?);
// Scalar multiplication #3
let initial_stack = u256ify(["0xdeadbeef", "0x1", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point0.1, point0.0])?);
// Scalar multiplication #4
let initial_stack = u256ify(["0xdeadbeef", s, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([identity.1, identity.0])?);
// Scalar multiplication #5
let initial_stack = u256ify(["0xdeadbeef", s, invalid.1, invalid.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, vec![U256::MAX, U256::MAX]);
// Multiple calls
@ -160,9 +125,7 @@ mod bn {
point0.1,
point0.0,
])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point4.1, point4.0])?);
Ok(())
@ -174,7 +137,7 @@ mod secp {
use anyhow::Result;
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::interpreter::{run, run_with_kernel};
use crate::cpu::kernel::interpreter::{run, run_interpreter};
use crate::cpu::kernel::tests::u256ify;
#[test]
@ -212,9 +175,7 @@ mod secp {
// Standard addition #1
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0, point1.1, point1.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point2.1, point2.0])?);
// Standard addition #2
let initial_stack = u256ify(["0xdeadbeef", point1.1, point1.0, point0.1, point0.0])?;
@ -225,66 +186,46 @@ mod secp {
// Standard doubling #1
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0, point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point3.1, point3.0])?);
// Standard doubling #2
let initial_stack = u256ify(["0xdeadbeef", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_double, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_double, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point3.1, point3.0])?);
// Standard doubling #3
let initial_stack = u256ify(["0xdeadbeef", "0x2", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point3.1, point3.0])?);
// Addition with identity #1
let initial_stack = u256ify(["0xdeadbeef", identity.1, identity.0, point1.1, point1.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point1.1, point1.0])?);
// Addition with identity #2
let initial_stack = u256ify(["0xdeadbeef", point1.1, point1.0, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point1.1, point1.0])?);
// Addition with identity #3
let initial_stack =
u256ify(["0xdeadbeef", identity.1, identity.0, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([identity.1, identity.0])?);
// Scalar multiplication #1
let initial_stack = u256ify(["0xdeadbeef", s, point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point4.1, point4.0])?);
// Scalar multiplication #2
let initial_stack = u256ify(["0xdeadbeef", "0x0", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([identity.1, identity.0])?);
// Scalar multiplication #3
let initial_stack = u256ify(["0xdeadbeef", "0x1", point0.1, point0.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point0.1, point0.0])?);
// Scalar multiplication #4
let initial_stack = u256ify(["0xdeadbeef", s, identity.1, identity.0])?;
let stack = run_with_kernel(&kernel, ec_mul, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_mul, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([identity.1, identity.0])?);
// Multiple calls
@ -298,9 +239,7 @@ mod secp {
point0.1,
point0.0,
])?;
let stack = run_with_kernel(&kernel, ec_add, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ec_add, initial_stack)?.stack().to_vec();
assert_eq!(stack, u256ify([point4.1, point4.0])?);
Ok(())

View File

@ -3,35 +3,23 @@ use std::str::FromStr;
use anyhow::Result;
use ethereum_types::U256;
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::assembler::Kernel;
use crate::cpu::kernel::interpreter::run_with_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::run_interpreter;
use crate::cpu::kernel::tests::u256ify;
fn test_valid_ecrecover(
hash: &str,
v: &str,
r: &str,
s: &str,
expected: &str,
kernel: &Kernel,
) -> Result<()> {
let ecrecover = kernel.global_labels["ecrecover"];
fn test_valid_ecrecover(hash: &str, v: &str, r: &str, s: &str, expected: &str) -> Result<()> {
let ecrecover = KERNEL.global_labels["ecrecover"];
let initial_stack = u256ify(["0xdeadbeef", s, r, v, hash])?;
let stack = run_with_kernel(kernel, ecrecover, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ecrecover, initial_stack)?.stack().to_vec();
assert_eq!(stack[0], U256::from_str(expected).unwrap());
Ok(())
}
fn test_invalid_ecrecover(hash: &str, v: &str, r: &str, s: &str, kernel: &Kernel) -> Result<()> {
let ecrecover = kernel.global_labels["ecrecover"];
fn test_invalid_ecrecover(hash: &str, v: &str, r: &str, s: &str) -> Result<()> {
let ecrecover = KERNEL.global_labels["ecrecover"];
let initial_stack = u256ify(["0xdeadbeef", s, r, v, hash])?;
let stack = run_with_kernel(kernel, ecrecover, initial_stack)?
.stack()
.to_vec();
let stack = run_interpreter(ecrecover, initial_stack)?.stack().to_vec();
assert_eq!(stack, vec![U256::MAX]);
Ok(())
@ -39,15 +27,12 @@ fn test_invalid_ecrecover(hash: &str, v: &str, r: &str, s: &str, kernel: &Kernel
#[test]
fn test_ecrecover() -> Result<()> {
let kernel = combined_kernel();
test_valid_ecrecover(
"0x55f77e8909b1f1c9531c4a309bb2d40388e9ed4b87830c8f90363c6b36255fb9",
"0x1b",
"0xd667c5a20fa899b253924099e10ae92998626718585b8171eb98de468bbebc",
"0x58351f48ce34bf134ee611fb5bf255a5733f0029561d345a7d46bfa344b60ac0",
"0x67f3c0Da351384838d7F7641AB0fCAcF853E1844",
&kernel,
)?;
test_valid_ecrecover(
"0x55f77e8909b1f1c9531c4a309bb2d40388e9ed4b87830c8f90363c6b36255fb9",
@ -55,7 +40,6 @@ fn test_ecrecover() -> Result<()> {
"0xd667c5a20fa899b253924099e10ae92998626718585b8171eb98de468bbebc",
"0x58351f48ce34bf134ee611fb5bf255a5733f0029561d345a7d46bfa344b60ac0",
"0xaA58436DeABb64982a386B2De1A8015AA28fCCc0",
&kernel,
)?;
test_valid_ecrecover(
"0x0",
@ -63,7 +47,6 @@ fn test_ecrecover() -> Result<()> {
"0x1",
"0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364140",
"0x3344c6f6eeCA588be132142DB0a32C71ABFAAe7B",
&kernel,
)?;
test_invalid_ecrecover(
@ -71,28 +54,24 @@ fn test_ecrecover() -> Result<()> {
"0x42", // v not in {27,28}
"0x1",
"0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364140",
&kernel,
)?;
test_invalid_ecrecover(
"0x0",
"0x42",
"0xd667c5a20fa899b253924099e10ae92998626718585b8171eb98de468bbebc",
"0x0", // s=0
&kernel,
)?;
test_invalid_ecrecover(
"0x0",
"0x42",
"0x0", // r=0
"0xd667c5a20fa899b253924099e10ae92998626718585b8171eb98de468bbebc",
&kernel,
)?;
test_invalid_ecrecover(
"0x0",
"0x1c",
"0x3a18b21408d275dde53c0ea86f9c1982eca60193db0ce15008fa408d43024847", // r^3 + 7 isn't a square
"0x5db9745f44089305b2f2c980276e7025a594828d878e6e36dd2abd34ca6b9e3d",
&kernel,
)?;
Ok(())

View File

@ -2,50 +2,43 @@ use anyhow::Result;
use ethereum_types::U256;
use rand::{thread_rng, Rng};
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::interpreter::{run, run_with_kernel};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::{run, run_interpreter};
#[test]
fn test_exp() -> Result<()> {
// Make sure we can parse and assemble the entire kernel.
let kernel = combined_kernel();
let exp = kernel.global_labels["exp"];
let exp = KERNEL.global_labels["exp"];
let mut rng = thread_rng();
let a = U256([0; 4].map(|_| rng.gen()));
let b = U256([0; 4].map(|_| rng.gen()));
// Random input
let initial_stack = vec![0xDEADBEEFu32.into(), b, a];
let stack_with_kernel = run_with_kernel(&kernel, exp, initial_stack)?
.stack()
.to_vec();
let stack_with_kernel = run_interpreter(exp, initial_stack)?.stack().to_vec();
let initial_stack = vec![b, a];
let code = [0xa, 0x63, 0xde, 0xad, 0xbe, 0xef, 0x56]; // EXP, PUSH4 deadbeef, JUMP
let stack_with_opcode = run(&code, 0, initial_stack, &kernel.prover_inputs)?
let stack_with_opcode = run(&code, 0, initial_stack, &KERNEL.prover_inputs)?
.stack()
.to_vec();
assert_eq!(stack_with_kernel, stack_with_opcode);
// 0 base
let initial_stack = vec![0xDEADBEEFu32.into(), b, U256::zero()];
let stack_with_kernel = run_with_kernel(&kernel, exp, initial_stack)?
.stack()
.to_vec();
let stack_with_kernel = run_interpreter(exp, initial_stack)?.stack().to_vec();
let initial_stack = vec![b, U256::zero()];
let code = [0xa, 0x63, 0xde, 0xad, 0xbe, 0xef, 0x56]; // EXP, PUSH4 deadbeef, JUMP
let stack_with_opcode = run(&code, 0, initial_stack, &kernel.prover_inputs)?
let stack_with_opcode = run(&code, 0, initial_stack, &KERNEL.prover_inputs)?
.stack()
.to_vec();
assert_eq!(stack_with_kernel, stack_with_opcode);
// 0 exponent
let initial_stack = vec![0xDEADBEEFu32.into(), U256::zero(), a];
let stack_with_kernel = run_with_kernel(&kernel, exp, initial_stack)?
.stack()
.to_vec();
let stack_with_kernel = run_interpreter(exp, initial_stack)?.stack().to_vec();
let initial_stack = vec![U256::zero(), a];
let code = [0xa, 0x63, 0xde, 0xad, 0xbe, 0xef, 0x56]; // EXP, PUSH4 deadbeef, JUMP
let stack_with_opcode = run(&code, 0, initial_stack, &kernel.prover_inputs)?
let stack_with_opcode = run(&code, 0, initial_stack, &KERNEL.prover_inputs)?
.stack()
.to_vec();
assert_eq!(stack_with_kernel, stack_with_opcode);

View File

@ -2,8 +2,8 @@ use anyhow::Result;
use ethereum_types::U256;
use rand::{thread_rng, Rng};
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::interpreter::run_with_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::run_interpreter;
// TODO: 107 is hardcoded as a dummy prime for testing
// should be changed to the proper implementation prime
@ -137,10 +137,9 @@ fn test_fp6() -> Result<()> {
let mut input: Vec<u32> = [c, d].into_iter().flatten().flatten().collect();
input.push(0xdeadbeef);
let kernel = combined_kernel();
let initial_offset = kernel.global_labels["mul_fp6"];
let initial_offset = KERNEL.global_labels["mul_fp6"];
let initial_stack: Vec<U256> = as_stack(input);
let final_stack: Vec<U256> = run_with_kernel(&kernel, initial_offset, initial_stack)?
let final_stack: Vec<U256> = run_interpreter(initial_offset, initial_stack)?
.stack()
.to_vec();

View File

@ -6,8 +6,8 @@ use rand::{thread_rng, Rng};
use ripemd::{Digest, Ripemd160};
use sha2::Sha256;
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::interpreter::run_with_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::run_interpreter;
/// Standard Sha2 implementation.
fn sha2(input: Vec<u8>) -> U256 {
@ -62,12 +62,11 @@ fn test_hash(hash_fn_label: &str, standard_implementation: &dyn Fn(Vec<u8>) -> U
let initial_stack_custom = make_input_stack(message_custom);
// Make the kernel.
let kernel = combined_kernel();
let kernel_function = kernel.global_labels[hash_fn_label];
let kernel_function = KERNEL.global_labels[hash_fn_label];
// Run the kernel code.
let result_random = run_with_kernel(&kernel, kernel_function, initial_stack_random)?;
let result_custom = run_with_kernel(&kernel, kernel_function, initial_stack_custom)?;
let result_random = run_interpreter(kernel_function, initial_stack_random)?;
let result_custom = run_interpreter(kernel_function, initial_stack_custom)?;
// Extract the final output.
let actual_random = result_random.stack()[0];

View File

@ -2,8 +2,8 @@ use anyhow::Result;
use ethereum_types::U256;
use itertools::Itertools;
use crate::cpu::kernel::aggregator::combined_kernel;
use crate::cpu::kernel::interpreter::run_with_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::run_interpreter;
fn make_input(word: &str) -> Vec<u32> {
let mut input: Vec<u32> = vec![word.len().try_into().unwrap()];
@ -44,10 +44,9 @@ fn test_ripemd_reference() -> Result<()> {
let input: Vec<u32> = make_input(x);
let expected = U256::from(y);
let kernel = combined_kernel();
let initial_offset = kernel.global_labels["ripemd_stack"];
let initial_offset = KERNEL.global_labels["ripemd_stack"];
let initial_stack: Vec<U256> = input.iter().map(|&x| U256::from(x)).rev().collect();
let final_stack: Vec<U256> = run_with_kernel(&kernel, initial_offset, initial_stack)?
let final_stack: Vec<U256> = run_interpreter(initial_offset, initial_stack)?
.stack()
.to_vec();
let actual = final_stack[0];

View File

@ -8,6 +8,7 @@ mod jumps;
pub mod kernel;
pub(crate) mod membus;
mod modfp254;
mod shift;
mod simple_logic;
mod stack;
mod stack_bounds;

108
evm/src/cpu/shift.rs Normal file
View File

@ -0,0 +1,108 @@
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::memory::segments::Segment;
/// Packed-field constraints for the EVM SHL/SHR shift operations.
///
/// Per the trailing note below, the actual multiplication (SHL) or division
/// (SHR) is delegated to the arithmetic table via a cross-table lookup; this
/// function only constrains the memory lookup of the shifting factor `2^d`
/// from the kernel's shift table, where `d` is the shift displacement.
pub(crate) fn eval_packed<P: PackedField>(
    lv: &CpuColumnsView<P>,
    yield_constr: &mut ConstraintConsumer<P>,
) {
    // Filter: 1 on rows executing SHL or SHR, 0 otherwise. Every constraint
    // below is multiplied by this so it is vacuous on non-shift rows.
    let is_shift = lv.op.shl + lv.op.shr;
    let displacement = lv.mem_channels[1]; // holds the shift displacement d
    let two_exp = lv.mem_channels[2]; // holds 2^d

    // Not needed here; val is the input and we're verifying that output is
    // val * 2^d (mod 2^256)
    //let val = lv.mem_channels[0];
    //let output = lv.mem_channels[NUM_GP_CHANNELS - 1];

    let shift_table_segment = P::Scalar::from_canonical_u64(Segment::ShiftTable as u64);

    // Only lookup the shifting factor when displacement is < 2^32.
    // two_exp.used is true (1) if the high limbs of the displacement are
    // zero and false (0) otherwise.
    let high_limbs_are_zero = two_exp.used;
    // The 2^d memory channel, when active, must be a read (not a write).
    yield_constr.constraint(is_shift * (two_exp.is_read - P::ONES));

    let high_limbs_sum: P = displacement.value[1..].iter().copied().sum();
    let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv;
    // Verify that high_limbs_are_zero = 0 implies high_limbs_sum != 0 and
    // high_limbs_are_zero = 1 implies high_limbs_sum = 0.
    // (Standard inverse trick: if the sum is nonzero, the prover supplies its
    // inverse so that sum * inv = 1; the second constraint rules out the
    // flag being 1 while the sum is nonzero.)
    let t = high_limbs_sum * high_limbs_sum_inv - (P::ONES - high_limbs_are_zero);
    yield_constr.constraint(is_shift * t);
    yield_constr.constraint(is_shift * high_limbs_sum * high_limbs_are_zero);

    // When the shift displacement is < 2^32, constrain the two_exp
    // mem_channel to be the entry corresponding to `displacement` in
    // the shift table lookup (will be zero if displacement >= 256).
    yield_constr.constraint(is_shift * two_exp.addr_context); // read from kernel memory
    yield_constr.constraint(is_shift * (two_exp.addr_segment - shift_table_segment));
    yield_constr.constraint(is_shift * (two_exp.addr_virtual - displacement.value[0]));

    // Other channels must be unused (the last channel is reserved for the
    // output and is excluded from this range).
    for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] {
        yield_constr.constraint(is_shift * chan.used); // channel is not used
    }

    // Cross-table lookup must connect the memory channels here to MUL
    // (in the case of left shift) or DIV (in the case of right shift)
    // in the arithmetic table. Specifically, the mapping is
    //
    // 0 -> 0 (value to be shifted is the same)
    // 2 -> 1 (two_exp becomes the multiplicand (resp. divisor))
    // last -> last (output is the same)
}
/// Recursive-circuit counterpart of [`eval_packed`]: builds the same SHL/SHR
/// constraints as gates over extension targets using the `CircuitBuilder`.
/// Each `t` below is the circuit form of the corresponding packed constraint.
pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
    lv: &CpuColumnsView<ExtensionTarget<D>>,
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    // Filter: active on SHL or SHR rows.
    let is_shift = builder.add_extension(lv.op.shl, lv.op.shr);
    let displacement = lv.mem_channels[1]; // shift displacement d
    let two_exp = lv.mem_channels[2]; // 2^d, looked up from the shift table

    let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64);

    // As in eval_packed, two_exp.used doubles as the flag "all high limbs of
    // the displacement are zero".
    let high_limbs_are_zero = two_exp.used;

    // is_shift * (two_exp.is_read - 1) == 0: the 2^d channel must be a read.
    let one = builder.one_extension();
    let t = builder.sub_extension(two_exp.is_read, one);
    let t = builder.mul_extension(is_shift, t);
    yield_constr.constraint(builder, t);

    let high_limbs_sum = builder.add_many_extension(&displacement.value[1..]);
    let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv;
    // is_shift * (high_limbs_sum * high_limbs_sum_inv - (1 - high_limbs_are_zero)) == 0.
    let t = builder.one_extension();
    let t = builder.sub_extension(t, high_limbs_are_zero);
    let t = builder.mul_sub_extension(high_limbs_sum, high_limbs_sum_inv, t);
    let t = builder.mul_extension(is_shift, t);
    yield_constr.constraint(builder, t);
    // is_shift * high_limbs_sum * high_limbs_are_zero == 0.
    let t = builder.mul_many_extension([is_shift, high_limbs_sum, high_limbs_are_zero]);
    yield_constr.constraint(builder, t);

    // Address of the 2^d lookup: context must be 0 (kernel memory), ...
    let t = builder.mul_extension(is_shift, two_exp.addr_context);
    yield_constr.constraint(builder, t);
    // ... segment must be the shift table; arithmetic_extension computes
    // 1 * is_shift * two_exp.addr_segment + (-shift_table_segment) * is_shift,
    // i.e. is_shift * (addr_segment - shift_table_segment).
    let t = builder.arithmetic_extension(
        F::ONE,
        -shift_table_segment,
        is_shift,
        two_exp.addr_segment,
        is_shift,
    );
    yield_constr.constraint(builder, t);
    // ... and the virtual address must equal the displacement's low limb.
    let t = builder.sub_extension(two_exp.addr_virtual, displacement.value[0]);
    let t = builder.mul_extension(is_shift, t);
    yield_constr.constraint(builder, t);

    // Remaining general-purpose channels (except the final output channel)
    // must be unused on shift rows.
    for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] {
        let t = builder.mul_extension(is_shift, chan.used);
        yield_constr.constraint(builder, t);
    }
}

View File

@ -40,73 +40,33 @@ const BASIC_TERNARY_OP: Option<StackBehavior> = Some(StackBehavior {
// except the first `num_pops` and the last `pushes as usize` channels have their read flag and
// address constrained automatically in this file.
const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
stop: None, // TODO
add: BASIC_BINARY_OP,
mul: BASIC_BINARY_OP,
sub: BASIC_BINARY_OP,
div: BASIC_BINARY_OP,
sdiv: BASIC_BINARY_OP,
mod_: BASIC_BINARY_OP,
smod: BASIC_BINARY_OP,
addmod: BASIC_TERNARY_OP,
mulmod: BASIC_TERNARY_OP,
exp: None, // TODO
signextend: BASIC_BINARY_OP,
addfp254: BASIC_BINARY_OP,
mulfp254: BASIC_BINARY_OP,
subfp254: BASIC_BINARY_OP,
lt: BASIC_BINARY_OP,
gt: BASIC_BINARY_OP,
slt: BASIC_BINARY_OP,
sgt: BASIC_BINARY_OP,
eq: BASIC_BINARY_OP,
iszero: BASIC_UNARY_OP,
and: BASIC_BINARY_OP,
or: BASIC_BINARY_OP,
xor: BASIC_BINARY_OP,
not: BASIC_TERNARY_OP,
not: BASIC_UNARY_OP,
byte: BASIC_BINARY_OP,
shl: BASIC_BINARY_OP,
shr: BASIC_BINARY_OP,
sar: BASIC_BINARY_OP,
keccak256: None, // TODO
keccak_general: None, // TODO
address: None, // TODO
balance: None, // TODO
origin: None, // TODO
caller: None, // TODO
callvalue: None, // TODO
calldataload: None, // TODO
calldatasize: None, // TODO
calldatacopy: None, // TODO
codesize: None, // TODO
codecopy: None, // TODO
gasprice: None, // TODO
extcodesize: None, // TODO
extcodecopy: None, // TODO
returndatasize: None, // TODO
returndatacopy: None, // TODO
extcodehash: None, // TODO
blockhash: None, // TODO
coinbase: None, // TODO
timestamp: None, // TODO
number: None, // TODO
difficulty: None, // TODO
gaslimit: None, // TODO
chainid: None, // TODO
selfbalance: None, // TODO
basefee: None, // TODO
prover_input: None, // TODO
pop: None, // TODO
mload: None, // TODO
mstore: None, // TODO
mstore8: None, // TODO
sload: None, // TODO
sstore: None, // TODO
jump: None, // TODO
jumpi: None, // TODO
pc: None, // TODO
msize: None, // TODO
gas: None, // TODO
jumpdest: None, // TODO
get_state_root: None, // TODO
@ -116,27 +76,17 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
push: None, // TODO
dup: None,
swap: None,
log0: None, // TODO
log1: None, // TODO
log2: None, // TODO
log3: None, // TODO
log4: None, // TODO
create: None, // TODO
call: None, // TODO
callcode: None, // TODO
return_: None, // TODO
delegatecall: None, // TODO
create2: None, // TODO
get_context: None, // TODO
set_context: None, // TODO
consume_gas: None, // TODO
exit_kernel: None, // TODO
staticcall: None, // TODO
mload_general: None, // TODO
mstore_general: None, // TODO
revert: None, // TODO
selfdestruct: None, // TODO
invalid: None, // TODO
syscall: Some(StackBehavior {
num_pops: 0,
pushes: true,
disable_other_channels: false,
}),
};
fn eval_packed_one<P: PackedField>(

View File

@ -2,62 +2,81 @@
//!
//! These are usually the ones that are too complicated to implement in one CPU table row.
use once_cell::sync::Lazy;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use static_assertions::const_assert;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::memory::segments::Segment;
const NUM_SYSCALLS: usize = 3;
/// Builds the syscall dispatch list: pairs each trapped opcode's CPU-column
/// index (taken from `COL_MAP.op`) with the kernel code offset of its handler,
/// resolved through `KERNEL.global_labels`.
///
/// NOTE(review): indexing `global_labels` panics if a handler label is missing
/// from the assembled kernel — presumably intentional, since a missing handler
/// is a build-time bug; confirm against the kernel assembler.
fn make_syscall_list() -> [(usize, usize); NUM_SYSCALLS] {
    // Force the lazily-assembled kernel so its label table is available.
    let kernel = Lazy::force(&KERNEL);
    [
        (COL_MAP.op.stop, "sys_stop"),
        (COL_MAP.op.exp, "sys_exp"),
        (COL_MAP.op.invalid, "handle_invalid"),
    ]
    .map(|(col_index, handler_name)| (col_index, kernel.global_labels[handler_name]))
}
static TRAP_LIST: Lazy<[(usize, usize); NUM_SYSCALLS]> = Lazy::new(make_syscall_list);
// Copy the constant but make it `usize`.
const BYTES_PER_OFFSET: usize = crate::cpu::kernel::assembler::BYTES_PER_OFFSET as usize;
const_assert!(BYTES_PER_OFFSET < NUM_GP_CHANNELS); // Reserve one channel for stack push
pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let syscall_list = Lazy::force(&TRAP_LIST);
// 1 if _any_ syscall, else 0.
let should_syscall: P = syscall_list
.iter()
.map(|&(col_index, _)| lv[col_index])
.sum();
let filter = lv.is_cpu_cycle * should_syscall;
let filter = lv.is_cpu_cycle * lv.op.syscall;
// If syscall: set program counter to the handler address
// Note that at most one of the `lv[col_index]`s will be 1 and all others will be 0.
let syscall_dst: P = syscall_list
.iter()
.map(|&(col_index, handler_addr)| {
lv[col_index] * P::Scalar::from_canonical_usize(handler_addr)
})
// Look up the handler in memory
let code_segment = P::Scalar::from_canonical_usize(Segment::Code as usize);
let syscall_jumptable_start =
P::Scalar::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]);
let opcode: P = lv
.opcode_bits
.into_iter()
.enumerate()
.map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i))
.sum();
yield_constr.constraint_transition(filter * (nv.program_counter - syscall_dst));
// If syscall: set kernel mode
let opcode_handler_addr_start =
syscall_jumptable_start + opcode * P::Scalar::from_canonical_usize(BYTES_PER_OFFSET);
for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() {
yield_constr.constraint(filter * (channel.used - P::ONES));
yield_constr.constraint(filter * (channel.is_read - P::ONES));
// Set kernel context and code segment
yield_constr.constraint(filter * channel.addr_context);
yield_constr.constraint(filter * (channel.addr_segment - code_segment));
// Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs.
let limb_address = opcode_handler_addr_start + P::Scalar::from_canonical_usize(i);
yield_constr.constraint(filter * (channel.addr_virtual - limb_address));
}
// Disable unused channels (the last channel is used to push to the stack)
for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] {
yield_constr.constraint(filter * channel.used);
}
// Set program counter to the handler address
// The addresses are big-endian in memory
let target = lv.mem_channels[0..BYTES_PER_OFFSET]
.iter()
.map(|channel| channel.value[0])
.fold(P::ZEROS, |cumul, limb| {
cumul * P::Scalar::from_canonical_u64(256) + limb
});
yield_constr.constraint_transition(filter * (nv.program_counter - target));
// Set kernel mode
yield_constr.constraint_transition(filter * (nv.is_kernel_mode - P::ONES));
// Maintain current context
yield_constr.constraint_transition(filter * (nv.context - lv.context));
let output = lv.mem_channels[0].value;
// If syscall: push current PC to stack
// This memory channel is constrained in `stack.rs`.
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
// Push current PC to stack
yield_constr.constraint(filter * (output[0] - lv.program_counter));
// If syscall: push current kernel flag to stack (share register with PC)
// Push current kernel flag to stack (share register with PC)
yield_constr.constraint(filter * (output[1] - lv.is_kernel_mode));
// If syscall: zero the rest of that register
// Zero the rest of that register
for &limb in &output[2..] {
yield_constr.constraint(filter * limb);
}
@ -69,46 +88,111 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let syscall_list = Lazy::force(&TRAP_LIST);
// 1 if _any_ syscall, else 0.
let should_syscall =
builder.add_many_extension(syscall_list.iter().map(|&(col_index, _)| lv[col_index]));
let filter = builder.mul_extension(lv.is_cpu_cycle, should_syscall);
let filter = builder.mul_extension(lv.is_cpu_cycle, lv.op.syscall);
// If syscall: set program counter to the handler address
{
// Note that at most one of the `lv[col_index]`s will be 1 and all others will be 0.
let syscall_dst = syscall_list.iter().fold(
builder.zero_extension(),
|cumul, &(col_index, handler_addr)| {
let handler_addr = F::from_canonical_usize(handler_addr);
builder.mul_const_add_extension(handler_addr, lv[col_index], cumul)
},
// Look up the handler in memory
let code_segment = F::from_canonical_usize(Segment::Code as usize);
let syscall_jumptable_start = builder.constant_extension(
F::from_canonical_usize(KERNEL.global_labels["syscall_jumptable"]).into(),
);
let constr = builder.sub_extension(nv.program_counter, syscall_dst);
let constr = builder.mul_extension(filter, constr);
let opcode = lv
.opcode_bits
.into_iter()
.rev()
.fold(builder.zero_extension(), |cumul, bit| {
builder.mul_const_add_extension(F::TWO, cumul, bit)
});
let opcode_handler_addr_start = builder.mul_const_add_extension(
F::from_canonical_usize(BYTES_PER_OFFSET),
opcode,
syscall_jumptable_start,
);
for (i, channel) in lv.mem_channels[0..BYTES_PER_OFFSET].iter().enumerate() {
{
let constr = builder.mul_sub_extension(filter, channel.used, filter);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.mul_sub_extension(filter, channel.is_read, filter);
yield_constr.constraint(builder, constr);
}
// Set kernel context and code segment
{
let constr = builder.mul_extension(filter, channel.addr_context);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.arithmetic_extension(
F::ONE,
-code_segment,
filter,
channel.addr_segment,
filter,
);
yield_constr.constraint(builder, constr);
}
// Set address, using a separate channel for each of the `BYTES_PER_OFFSET` limbs.
{
let diff = builder.sub_extension(channel.addr_virtual, opcode_handler_addr_start);
let constr = builder.arithmetic_extension(
F::ONE,
-F::from_canonical_usize(i),
filter,
diff,
filter,
);
yield_constr.constraint(builder, constr);
}
}
// Disable unused channels (the last channel is used to push to the stack)
for channel in &lv.mem_channels[BYTES_PER_OFFSET..NUM_GP_CHANNELS - 1] {
let constr = builder.mul_extension(filter, channel.used);
yield_constr.constraint(builder, constr);
}
// Set program counter to the handler address
// The addresses are big-endian in memory
{
let target = lv.mem_channels[0..BYTES_PER_OFFSET]
.iter()
.map(|channel| channel.value[0])
.fold(builder.zero_extension(), |cumul, limb| {
builder.mul_const_add_extension(F::from_canonical_u64(256), cumul, limb)
});
let diff = builder.sub_extension(nv.program_counter, target);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint_transition(builder, constr);
}
// If syscall: set kernel mode
// Set kernel mode
{
let constr = builder.mul_sub_extension(filter, nv.is_kernel_mode, filter);
yield_constr.constraint_transition(builder, constr);
}
// Maintain current context
{
let diff = builder.sub_extension(nv.context, lv.context);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint_transition(builder, constr);
}
let output = lv.mem_channels[0].value;
// If syscall: push current PC to stack
// This memory channel is constrained in `stack.rs`.
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
// Push current PC to stack
{
let constr = builder.sub_extension(output[0], lv.program_counter);
let constr = builder.mul_extension(filter, constr);
let diff = builder.sub_extension(output[0], lv.program_counter);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);
}
// If syscall: push current kernel flag to stack (share register with PC)
// Push current kernel flag to stack (share register with PC)
{
let constr = builder.sub_extension(output[1], lv.is_kernel_mode);
let constr = builder.mul_extension(filter, constr);
let diff = builder.sub_extension(output[1], lv.is_kernel_mode);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);
}
// If syscall: zero the rest of that register
// Zero the rest of that register
for &limb in &output[2..] {
let constr = builder.mul_extension(filter, limb);
yield_constr.constraint(builder, constr);

View File

@ -35,10 +35,13 @@ pub(crate) enum Segment {
TrieEncodedChild = 14,
/// A buffer used to store the lengths of the encodings of a branch node's children.
TrieEncodedChildLen = 15,
/// A table of values 2^i for i=0..255 for use with shift
/// instructions; initialised by `kernel/asm/shift.asm::init_shift_table()`.
ShiftTable = 16,
}
impl Segment {
pub(crate) const COUNT: usize = 16;
pub(crate) const COUNT: usize = 17;
pub(crate) fn all() -> [Self; Self::COUNT] {
[
@ -58,6 +61,7 @@ impl Segment {
Self::TrieData,
Self::TrieEncodedChild,
Self::TrieEncodedChildLen,
Self::ShiftTable,
]
}
@ -80,6 +84,7 @@ impl Segment {
Segment::TrieData => "SEGMENT_TRIE_DATA",
Segment::TrieEncodedChild => "SEGMENT_TRIE_ENCODED_CHILD",
Segment::TrieEncodedChildLen => "SEGMENT_TRIE_ENCODED_CHILD_LEN",
Segment::ShiftTable => "SEGMENT_SHIFT_TABLE",
}
}
@ -102,6 +107,7 @@ impl Segment {
Segment::TrieData => 256,
Segment::TrieEncodedChild => 256,
Segment::TrieEncodedChildLen => 6,
Segment::ShiftTable => 256,
}
}
}

View File

@ -37,7 +37,7 @@ static_assertions = "1.1.0"
[dev-dependencies]
rand = "0.8.4"
rand_chacha = "0.3.1"
criterion = "0.3.5"
criterion = "0.4.0"
env_logger = "0.9.0"
tynm = "0.1.6"
structopt = "0.3.26"

View File

@ -3,6 +3,7 @@ use std::borrow::Borrow;
use itertools::Itertools;
use plonky2_field::extension::Extendable;
use plonky2_field::types::Field;
use plonky2_util::log_floor;
use crate::gates::base_sum::BaseSumGate;
use crate::hash::hash_types::RichField;
@ -33,6 +34,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub(crate) fn le_sum(&mut self, bits: impl Iterator<Item = impl Borrow<BoolTarget>>) -> Target {
let bits = bits.map(|b| *b.borrow()).collect_vec();
let num_bits = bits.len();
assert!(
num_bits <= log_floor(F::ORDER, 2),
"{} bits may overflow the field",
num_bits
);
if num_bits == 0 {
return self.zero();
}

View File

@ -16,7 +16,7 @@ rand = "0.8.4"
rand_chacha = "0.3.1"
[dev-dependencies]
criterion = "0.3.5"
criterion = "0.4.0"
[[bench]]
name = "lookup_permuted_cols"

View File

@ -5,3 +5,4 @@ version = "0.1.0"
edition = "2021"
[dependencies]
rand = { version = "0.8.5", default-features = false, features = ["getrandom"] }

View File

@ -274,8 +274,50 @@ pub fn branch_hint() {
#[cfg(test)]
mod tests {
use rand::rngs::OsRng;
use rand::Rng;
use crate::{log2_ceil, log2_strict};
#[test]
fn test_reverse_index_bits() {
let lengths = [32, 128, 1 << 16];
let mut rng = OsRng;
for _ in 0..32 {
for length in lengths {
let mut rand_list: Vec<u32> = Vec::with_capacity(length);
rand_list.resize_with(length, || rng.gen());
let out = super::reverse_index_bits(&rand_list);
let expect = reverse_index_bits_naive(&rand_list);
for (out, expect) in out.iter().zip(&expect) {
assert_eq!(out, expect);
}
}
}
}
#[test]
fn test_reverse_index_bits_in_place() {
let lengths = [32, 128, 1 << 16];
let mut rng = OsRng;
for _ in 0..32 {
for length in lengths {
let mut rand_list: Vec<u32> = Vec::with_capacity(length);
rand_list.resize_with(length, || rng.gen());
let expect = reverse_index_bits_naive(&rand_list);
super::reverse_index_bits_in_place(&mut rand_list);
for (got, expect) in rand_list.iter().zip(&expect) {
assert_eq!(got, expect);
}
}
}
}
#[test]
fn test_log2_strict() {
assert_eq!(log2_strict(1), 0);
@ -326,4 +368,17 @@ mod tests {
assert_eq!(log2_ceil(usize::MAX - 1), usize::BITS as usize);
assert_eq!(log2_ceil(usize::MAX), usize::BITS as usize);
}
fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
let n = arr.len();
let n_power = log2_strict(n);
let mut out = vec![None; n];
for (i, v) in arr.iter().enumerate() {
let dst = i.reverse_bits() >> (64 - n_power);
out[dst] = Some(*v);
}
out.into_iter().map(|x| x.unwrap()).collect()
}
}