EVM shift left/right operations (#801)

* First parts of shift implementation.

* Disable range check errors.

* Tidy up ASM.

* Update comments; fix some .sum() expressions.

* First full draft of shift left/right.

* Missed a +1.

* Clippy.

* Address Jacqui's comments.

* Add comment.

* Fix missing filter.

* Address second round of comments from Jacqui.
This commit is contained in:
Hamish Ivey-Law 2022-11-09 10:47:15 +11:00 committed by GitHub
parent 7126231b52
commit 1c87fbb712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 183 additions and 18 deletions

View File

@ -1,6 +1,5 @@
use std::ops::{Add, AddAssign, Mul, Neg, Range, Shr, Sub, SubAssign};
use log::error;
use plonky2::field::extension::Extendable;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
@ -11,21 +10,24 @@ use crate::arithmetic::columns::{NUM_ARITH_COLUMNS, N_LIMBS};
/// Emit an error message regarding unchecked range assumptions.
/// Assumes the values in `cols` are `[cols[0], cols[0] + 1, ...,
/// cols[0] + cols.len() - 1]`.
///
/// TODO: Hamish to delete this when he has implemented and integrated
/// range checks.
pub(crate) fn _range_check_error<const RC_BITS: u32>(
file: &str,
line: u32,
cols: Range<usize>,
signedness: &str,
_file: &str,
_line: u32,
_cols: Range<usize>,
_signedness: &str,
) {
error!(
"{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented",
line,
file,
RC_BITS,
signedness,
cols.start,
cols.end - 1,
);
// error!(
// "{}:{}: arithmetic unit skipped {}-bit {} range-checks on columns {}--{}: not yet implemented",
// line,
// file,
// RC_BITS,
// signedness,
// cols.start,
// cols.end - 1,
// );
}
#[macro_export]

View File

@ -9,6 +9,7 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
arithmetic: CpuArithmeticView<T>,
logic: CpuLogicView<T>,
jumps: CpuJumpsView<T>,
shift: CpuShiftView<T>,
}
impl<T: Copy> CpuGeneralColumnsView<T> {
@ -51,6 +52,16 @@ impl<T: Copy> CpuGeneralColumnsView<T> {
pub(crate) fn jumps_mut(&mut self) -> &mut CpuJumpsView<T> {
unsafe { &mut self.jumps }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn shift(&self) -> &CpuShiftView<T> {
unsafe { &self.shift }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView<T> {
unsafe { &mut self.shift }
}
}
impl<T: Copy + PartialEq> PartialEq<Self> for CpuGeneralColumnsView<T> {
@ -144,5 +155,12 @@ pub(crate) struct CpuJumpsView<T: Copy> {
pub(crate) should_trap: T,
}
#[derive(Copy, Clone)]
pub(crate) struct CpuShiftView<T: Copy> {
// For a shift amount of displacement: [T], this is the inverse of
// sum(displacement[1..]) or zero if the sum is zero.
pub(crate) high_limb_sum_inv: T,
}
// `u8` is guaranteed to have a `size_of` of 1.
pub const NUM_SHARED_COLUMNS: usize = size_of::<CpuGeneralColumnsView<u8>>();

View File

@ -11,8 +11,8 @@ use plonky2::hash::hash_types::RichField;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS};
use crate::cpu::{
bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, simple_logic, stack,
stack_bounds, syscalls,
bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, shift, simple_logic,
stack, stack_bounds, syscalls,
};
use crate::cross_table_lookup::Column;
use crate::memory::segments::Segment;
@ -151,6 +151,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
jumps::eval_packed(local_values, next_values, yield_constr);
membus::eval_packed(local_values, yield_constr);
modfp254::eval_packed(local_values, yield_constr);
shift::eval_packed(local_values, yield_constr);
simple_logic::eval_packed(local_values, yield_constr);
stack::eval_packed(local_values, yield_constr);
stack_bounds::eval_packed(local_values, yield_constr);
@ -172,6 +173,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
jumps::eval_ext_circuit(builder, local_values, next_values, yield_constr);
membus::eval_ext_circuit(builder, local_values, yield_constr);
modfp254::eval_ext_circuit(builder, local_values, yield_constr);
shift::eval_ext_circuit(builder, local_values, yield_constr);
simple_logic::eval_ext_circuit(builder, local_values, yield_constr);
stack::eval_ext_circuit(builder, local_values, yield_constr);
stack_bounds::eval_ext_circuit(builder, local_values, yield_constr);

View File

@ -77,6 +77,7 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/sha2/store_pad.asm"),
include_str!("asm/sha2/temp_words.asm"),
include_str!("asm/sha2/write_length.asm"),
include_str!("asm/shift.asm"),
include_str!("asm/transactions/router.asm"),
include_str!("asm/transactions/type_0.asm"),
include_str!("asm/transactions/type_1.asm"),

View File

@ -1,5 +1,7 @@
global main:
// First, load all MPT data from the prover.
// First, initialise the shift table
%shift_table_init
// Second, load all MPT data from the prover.
PUSH txn_loop
%jump(load_all_mpts)

View File

@ -0,0 +1,25 @@
/// Initialise the lookup table of binary powers for doing left/right shifts
///
/// Specifically, set SHIFT_TABLE_SEGMENT[i] = 2^i for i = 0..255.
%macro shift_table_init
push 1 // 2^0
push 0 // initial offset is zero
push @SEGMENT_SHIFT_TABLE // segment
dup2 // kernel context is 0
%rep 255
// stack: context, segment, ost_i, 2^i
dup4
dup1
add
// stack: 2^(i+1), context, segment, ost_i, 2^i
dup4
%increment
// stack: ost_(i+1), 2^(i+1), context, segment, ost_i, 2^i
dup4
dup4
// stack: context, segment, ost_(i+1), 2^(i+1), context, segment, ost_i, 2^i
%endrep
%rep 256
mstore_general
%endrep
%endmacro

View File

@ -8,6 +8,7 @@ mod jumps;
pub mod kernel;
pub(crate) mod membus;
mod modfp254;
mod shift;
mod simple_logic;
mod stack;
mod stack_bounds;

108
evm/src/cpu/shift.rs Normal file
View File

@ -0,0 +1,108 @@
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::memory::segments::Segment;
pub(crate) fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let is_shift = lv.op.shl + lv.op.shr;
let displacement = lv.mem_channels[1]; // holds the shift displacement d
let two_exp = lv.mem_channels[2]; // holds 2^d
// Not needed here; val is the input and we're verifying that output is
// val * 2^d (mod 2^256)
//let val = lv.mem_channels[0];
//let output = lv.mem_channels[NUM_GP_CHANNELS - 1];
let shift_table_segment = P::Scalar::from_canonical_u64(Segment::ShiftTable as u64);
// Only lookup the shifting factor when displacement is < 2^32.
// two_exp.used is true (1) if the high limbs of the displacement are
// zero and false (0) otherwise.
let high_limbs_are_zero = two_exp.used;
yield_constr.constraint(is_shift * (two_exp.is_read - P::ONES));
let high_limbs_sum: P = displacement.value[1..].iter().copied().sum();
let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv;
// Verify that high_limbs_are_zero = 0 implies high_limbs_sum != 0 and
// high_limbs_are_zero = 1 implies high_limbs_sum = 0.
let t = high_limbs_sum * high_limbs_sum_inv - (P::ONES - high_limbs_are_zero);
yield_constr.constraint(is_shift * t);
yield_constr.constraint(is_shift * high_limbs_sum * high_limbs_are_zero);
// When the shift displacement is < 2^32, constrain the two_exp
// mem_channel to be the entry corresponding to `displacement` in
// the shift table lookup (will be zero if displacement >= 256).
yield_constr.constraint(is_shift * two_exp.addr_context); // read from kernel memory
yield_constr.constraint(is_shift * (two_exp.addr_segment - shift_table_segment));
yield_constr.constraint(is_shift * (two_exp.addr_virtual - displacement.value[0]));
// Other channels must be unused
for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] {
yield_constr.constraint(is_shift * chan.used); // channel is not used
}
// Cross-table lookup must connect the memory channels here to MUL
// (in the case of left shift) or DIV (in the case of right shift)
// in the arithmetic table. Specifically, the mapping is
//
// 0 -> 0 (value to be shifted is the same)
// 2 -> 1 (two_exp becomes the multiplicand (resp. divisor))
// last -> last (output is the same)
}
pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let is_shift = builder.add_extension(lv.op.shl, lv.op.shr);
let displacement = lv.mem_channels[1];
let two_exp = lv.mem_channels[2];
let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64);
let high_limbs_are_zero = two_exp.used;
let one = builder.one_extension();
let t = builder.sub_extension(two_exp.is_read, one);
let t = builder.mul_extension(is_shift, t);
yield_constr.constraint(builder, t);
let high_limbs_sum = builder.add_many_extension(&displacement.value[1..]);
let high_limbs_sum_inv = lv.general.shift().high_limb_sum_inv;
let t = builder.one_extension();
let t = builder.sub_extension(t, high_limbs_are_zero);
let t = builder.mul_sub_extension(high_limbs_sum, high_limbs_sum_inv, t);
let t = builder.mul_extension(is_shift, t);
yield_constr.constraint(builder, t);
let t = builder.mul_many_extension([is_shift, high_limbs_sum, high_limbs_are_zero]);
yield_constr.constraint(builder, t);
let t = builder.mul_extension(is_shift, two_exp.addr_context);
yield_constr.constraint(builder, t);
let t = builder.arithmetic_extension(
F::ONE,
-shift_table_segment,
is_shift,
two_exp.addr_segment,
is_shift,
);
yield_constr.constraint(builder, t);
let t = builder.sub_extension(two_exp.addr_virtual, displacement.value[0]);
let t = builder.mul_extension(is_shift, t);
yield_constr.constraint(builder, t);
for chan in &lv.mem_channels[3..NUM_GP_CHANNELS - 1] {
let t = builder.mul_extension(is_shift, chan.used);
yield_constr.constraint(builder, t);
}
}

View File

@ -35,10 +35,13 @@ pub(crate) enum Segment {
TrieEncodedChild = 14,
/// A buffer used to store the lengths of the encodings of a branch node's children.
TrieEncodedChildLen = 15,
/// A table of values 2^i for i=0..255 for use with shift
/// instructions; initialised by `kernel/asm/shift.asm::init_shift_table()`.
ShiftTable = 16,
}
impl Segment {
pub(crate) const COUNT: usize = 16;
pub(crate) const COUNT: usize = 17;
pub(crate) fn all() -> [Self; Self::COUNT] {
[
@ -58,6 +61,7 @@ impl Segment {
Self::TrieData,
Self::TrieEncodedChild,
Self::TrieEncodedChildLen,
Self::ShiftTable,
]
}
@ -80,6 +84,7 @@ impl Segment {
Segment::TrieData => "SEGMENT_TRIE_DATA",
Segment::TrieEncodedChild => "SEGMENT_TRIE_ENCODED_CHILD",
Segment::TrieEncodedChildLen => "SEGMENT_TRIE_ENCODED_CHILD_LEN",
Segment::ShiftTable => "SEGMENT_SHIFT_TABLE",
}
}
@ -102,6 +107,7 @@ impl Segment {
Segment::TrieData => 256,
Segment::TrieEncodedChild => 256,
Segment::TrieEncodedChildLen => 6,
Segment::ShiftTable => 256,
}
}
}