diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index 74999ab4..ec989c94 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -1,6 +1,5 @@ use std::ops::{Add, AddAssign, Mul, Neg, Range, Shr, Sub, SubAssign}; -use log::error; use plonky2::field::extension::Extendable; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -11,21 +10,24 @@ use crate::arithmetic::columns::{NUM_ARITH_COLUMNS, N_LIMBS}; /// Emit an error message regarding unchecked range assumptions. /// Assumes the values in `cols` are `[cols[0], cols[0] + 1, ..., /// cols[0] + cols.len() - 1]`. +/// +/// TODO: Hamish to delete this when he has implemented and integrated +/// range checks. pub(crate) fn _range_check_error( - file: &str, - line: u32, - cols: Range, - signedness: &str, + _file: &str, + _line: u32, + _cols: Range, + _signedness: &str, ) { - error!( - "{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented", - line, - file, - RC_BITS, - signedness, - cols.start, - cols.end - 1, - ); + // error!( + // "{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented", + // line, + // file, + // RC_BITS, + // signedness, + // cols.start, + // cols.end - 1, + // ); } #[macro_export] diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 134788dc..5a2c9426 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -9,6 +9,7 @@ pub(crate) union CpuGeneralColumnsView { arithmetic: CpuArithmeticView, logic: CpuLogicView, jumps: CpuJumpsView, + shift: CpuShiftView, } impl CpuGeneralColumnsView { @@ -51,6 +52,16 @@ impl CpuGeneralColumnsView { pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView { unsafe { &mut self.jumps } } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn shift(&self) -> &CpuShiftView { + unsafe { &self.shift } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView { + unsafe { &mut self.shift } + } } impl PartialEq for CpuGeneralColumnsView { @@ -144,5 +155,12 @@ pub(crate) struct CpuJumpsView { pub(crate) should_trap: T, } +#[derive(Copy, Clone)] +pub(crate) struct CpuShiftView { + // For a shift amount of displacement: [T], this is the inverse of + // sum(displacement[1..]) or zero if the sum is zero. + pub(crate) high_limb_sum_inv: T, +} + // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 7b34cc4f..4cc38823 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -11,8 +11,8 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::{ - bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, simple_logic, stack, - stack_bounds, syscalls, + bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, shift, simple_logic, + stack, stack_bounds, syscalls, }; use crate::cross_table_lookup::Column; use crate::memory::segments::Segment; @@ -151,6 +151,7 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark Kernel { include_str!("asm/sha2/store_pad.asm"), include_str!("asm/sha2/temp_words.asm"), include_str!("asm/sha2/write_length.asm"), + include_str!("asm/shift.asm"), include_str!("asm/transactions/router.asm"), include_str!("asm/transactions/type_0.asm"), include_str!("asm/transactions/type_1.asm"), diff --git a/evm/src/cpu/kernel/asm/main.asm b/evm/src/cpu/kernel/asm/main.asm index e8c8e3e4..41cb8079 100644 --- a/evm/src/cpu/kernel/asm/main.asm +++ b/evm/src/cpu/kernel/asm/main.asm @@ -1,5 +1,7 @@ global main: - // First, load all MPT data from the prover. + // First, initialise the shift table + %shift_table_init + // Second, load all MPT data from the prover. PUSH txn_loop %jump(load_all_mpts) diff --git a/evm/src/cpu/kernel/asm/shift.asm b/evm/src/cpu/kernel/asm/shift.asm new file mode 100644 index 00000000..ce481ea2 --- /dev/null +++ b/evm/src/cpu/kernel/asm/shift.asm @@ -0,0 +1,25 @@ +/// Initialise the lookup table of binary powers for doing left/right shifts +/// +/// Specifically, set SHIFT_TABLE_SEGMENT[i] = 2^i for i = 0..255. +%macro shift_table_init + push 1 // 2^0 + push 0 // initial offset is zero + push @SEGMENT_SHIFT_TABLE // segment + dup2 // kernel context is 0 + %rep 255 + // stack: context, segment, ost_i, 2^i + dup4 + dup1 + add + // stack: 2^(i+1), context, segment, ost_i, 2^i + dup4 + %increment + // stack: ost_(i+1), 2^(i+1), context, segment, ost_i, 2^i + dup4 + dup4 + // stack: context, segment, ost_(i+1), 2^(i+1), context, segment, ost_i, 2^i + %endrep + %rep 256 + mstore_general + %endrep +%endmacro diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index fda5db80..ece07c1c 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -8,6 +8,7 @@ mod jumps; pub mod kernel; pub(crate) mod membus; mod modfp254; +mod shift; mod simple_logic; mod stack; mod stack_bounds; diff --git a/evm/src/cpu/shift.rs b/evm/src/cpu/shift.rs new file mode 100644 index 00000000..d383b6b2 --- /dev/null +++ b/evm/src/cpu/shift.rs @@ -0,0 +1,108 @@ +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::CpuColumnsView; +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::memory::segments::Segment; + +pub(crate) fn eval_packed( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + let is_shift = lv.op.shl + lv.op.shr; + let displacement = lv.mem_channels[1]; // holds the shift displacement d + let two_exp = lv.mem_channels[2]; // holds 2^d + + // Not needed here; val is the input and we're verifying that output is + // val * 2^d (mod 2^256) + //let val = lv.mem_channels[0]; + //let output = lv.mem_channels[NUM_GP_CHANNELS - 1]; + + let shift_table_segment = P::Scalar::from_canonical_u64(Segment::ShiftTable as u64); + + // Only lookup the shifting factor when displacement is < 2^32. + // two_exp.used is true (1) if the high limbs of the displacement are + // zero and false (0) otherwise. + let high_limbs_are_zero = two_exp.used; + yield_constr.constraint(is_shift * (two_exp.is_read - P::ONES)); + + let high_limbs_sum: P = displacement.value[1..].iter().copied().sum(); + let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv; + // Verify that high_limbs_are_zero = 0 implies high_limbs_sum != 0 and + // high_limbs_are_zero = 1 implies high_limbs_sum = 0. + let t = high_limbs_sum * high_limbs_sum_inv - (P::ONES - high_limbs_are_zero); + yield_constr.constraint(is_shift * t); + yield_constr.constraint(is_shift * high_limbs_sum * high_limbs_are_zero); + + // When the shift displacement is < 2^32, constrain the two_exp + // mem_channel to be the entry corresponding to `displacement` in + // the shift table lookup (will be zero if displacement >= 256). + yield_constr.constraint(is_shift * two_exp.addr_context); // read from kernel memory + yield_constr.constraint(is_shift * (two_exp.addr_segment - shift_table_segment)); + yield_constr.constraint(is_shift * (two_exp.addr_virtual - displacement.value[0])); + + // Other channels must be unused + for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] { + yield_constr.constraint(is_shift * chan.used); // channel is not used + } + + // Cross-table lookup must connect the memory channels here to MUL + // (in the case of left shift) or DIV (in the case of right shift) + // in the arithmetic table. Specifically, the mapping is + // + // 0 -> 0 (value to be shifted is the same) + // 2 -> 1 (two_exp becomes the multiplicand (resp. divisor)) + // last -> last (output is the same) +} + +pub(crate) fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_shift = builder.add_extension(lv.op.shl, lv.op.shr); + let displacement = lv.mem_channels[1]; + let two_exp = lv.mem_channels[2]; + + let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64); + + let high_limbs_are_zero = two_exp.used; + let one = builder.one_extension(); + let t = builder.sub_extension(two_exp.is_read, one); + let t = builder.mul_extension(is_shift, t); + yield_constr.constraint(builder, t); + + let high_limbs_sum = builder.add_many_extension(&displacement.value[1..]); + let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv; + let t = builder.one_extension(); + let t = builder.sub_extension(t, high_limbs_are_zero); + let t = builder.mul_sub_extension(high_limbs_sum, high_limbs_sum_inv, t); + let t = builder.mul_extension(is_shift, t); + yield_constr.constraint(builder, t); + + let t = builder.mul_many_extension([is_shift, high_limbs_sum, high_limbs_are_zero]); + yield_constr.constraint(builder, t); + + let t = builder.mul_extension(is_shift, two_exp.addr_context); + yield_constr.constraint(builder, t); + let t = builder.arithmetic_extension( + F::ONE, + -shift_table_segment, + is_shift, + two_exp.addr_segment, + is_shift, + ); + yield_constr.constraint(builder, t); + let t = builder.sub_extension(two_exp.addr_virtual, displacement.value[0]); + let t = builder.mul_extension(is_shift, t); + yield_constr.constraint(builder, t); + + for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] { + let t = builder.mul_extension(is_shift, chan.used); + yield_constr.constraint(builder, t); + } +} diff --git a/evm/src/memory/segments.rs b/evm/src/memory/segments.rs index f8d536e9..7a28cb96 100644 --- a/evm/src/memory/segments.rs +++ b/evm/src/memory/segments.rs @@ -35,10 +35,13 @@ pub(crate) enum Segment { TrieEncodedChild = 14, /// A buffer used to store the lengths of the encodings of a branch node's children. TrieEncodedChildLen = 15, + /// A table of values 2^i for i=0..255 for use with shift + /// instructions; initialised by `kernel/asm/shift.asm::init_shift_table()`. + ShiftTable = 16, } impl Segment { - pub(crate) const COUNT: usize = 16; + pub(crate) const COUNT: usize = 17; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -58,6 +61,7 @@ impl Segment { Self::TrieData, Self::TrieEncodedChild, Self::TrieEncodedChildLen, + Self::ShiftTable, ] } @@ -80,6 +84,7 @@ impl Segment { Segment::TrieData => "SEGMENT_TRIE_DATA", Segment::TrieEncodedChild => "SEGMENT_TRIE_ENCODED_CHILD", Segment::TrieEncodedChildLen => "SEGMENT_TRIE_ENCODED_CHILD_LEN", + Segment::ShiftTable => "SEGMENT_SHIFT_TABLE", } } @@ -102,6 +107,7 @@ impl Segment { Segment::TrieData => 256, Segment::TrieEncodedChild => 256, Segment::TrieEncodedChildLen => 6, + Segment::ShiftTable => 256, } } }