diff --git a/evm/src/arithmetic/mod.rs b/evm/src/arithmetic/mod.rs index d214fa3b..60c8a2f8 100644 --- a/evm/src/arithmetic/mod.rs +++ b/evm/src/arithmetic/mod.rs @@ -23,8 +23,6 @@ pub(crate) enum BinaryOperator { Mod, Lt, Gt, - Shl, - Shr, AddFp254, MulFp254, SubFp254, @@ -64,20 +62,6 @@ impl BinaryOperator { U256::zero() } } - BinaryOperator::Shl => { - if input0 > 255.into() { - U256::zero() - } else { - input1 << input0 - } - } - BinaryOperator::Shr => { - if input0 > 255.into() { - U256::zero() - } else { - input1 >> input0 - } - } BinaryOperator::AddFp254 => addmod(input0, input1, bn_base_order()), BinaryOperator::MulFp254 => mulmod(input0, input1, bn_base_order()), BinaryOperator::SubFp254 => submod(input0, input1, bn_base_order()), diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index b8cb2473..54d8ce3f 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -12,7 +12,7 @@ use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::witness::errors::ProgramError; -use crate::witness::memory::MemoryAddress; +use crate::witness::memory::{MemoryAddress, MemoryOp}; use crate::witness::util::{ keccak_sponge_log, mem_read_code_with_log_and_fill, mem_read_gp_with_log_and_fill, mem_write_gp_log_and_fill, stack_pop_with_log_and_fill, stack_push_log_and_fill, @@ -24,6 +24,8 @@ pub(crate) enum Operation { Iszero, Not, Byte, + Shl, + Shr, Syscall(u8), Eq, BinaryLogic(logic::Op), @@ -73,25 +75,8 @@ pub(crate) fn generate_binary_arithmetic_op( let [(input0, log_in0), (input1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; let operation = arithmetic::Operation::binary(operator, input0, input1); - let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?; - if operator == arithmetic::BinaryOperator::Shl || operator == arithmetic::BinaryOperator::Shr { - const LOOKUP_CHANNEL: usize = 2; - let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); - if input0.bits() <= 32 { - let (_, read) = - mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row); - state.traces.push_memory(read); - } else { - // The shift constraints still expect the address to be set, even though no read will occur. - let mut channel = &mut row.mem_channels[LOOKUP_CHANNEL]; - channel.addr_context = F::from_canonical_usize(lookup_addr.context); - channel.addr_segment = F::from_canonical_usize(lookup_addr.segment); - channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); - } - } - state.traces.push_arithmetic(operation); state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); @@ -490,6 +475,56 @@ pub(crate) fn generate_iszero( Ok(()) } +fn append_shift( + state: &mut GenerationState, + mut row: CpuColumnsView, + input0: U256, + log_in0: MemoryOp, + log_in1: MemoryOp, + result: U256, +) -> Result<(), ProgramError> { + let log_out = stack_push_log_and_fill(state, &mut row, result)?; + + const LOOKUP_CHANNEL: usize = 2; + let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); + if input0.bits() <= 32 { + let (_, read) = mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row); + state.traces.push_memory(read); + } else { + // The shift constraints still expect the address to be set, even though no read will occur. + let mut channel = &mut row.mem_channels[LOOKUP_CHANNEL]; + channel.addr_context = F::from_canonical_usize(lookup_addr.context); + channel.addr_segment = F::from_canonical_usize(lookup_addr.segment); + channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt); + } + + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_out); + state.traces.push_cpu(row); + Ok(()) +} + +pub(crate) fn generate_shl( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(input0, log_in0), (input1, log_in1)] = + stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let result = input1 << input0; + append_shift(state, row, input0, log_in0, log_in1, result) +} + +pub(crate) fn generate_shr( + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let [(input0, log_in0), (input1, log_in1)] = + stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let result = input1 >> input0; + append_shift(state, row, input0, log_in0, log_in1, result) +} + pub(crate) fn generate_syscall( opcode: u8, state: &mut GenerationState, diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 6f007b22..b8e46d78 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -62,8 +62,8 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::BinaryLogic(logic::Op::Xor)), (0x19, _) => Ok(Operation::Not), (0x1a, _) => Ok(Operation::Byte), - (0x1b, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl)), - (0x1c, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr)), + (0x1b, _) => Ok(Operation::Shl), + (0x1c, _) => Ok(Operation::Shr), (0x1d, _) => Ok(Operation::Syscall(opcode)), (0x20, _) => Ok(Operation::Syscall(opcode)), (0x21, true) => Ok(Operation::KeccakGeneral), @@ -160,8 +160,8 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod) => &mut flags.mod_, Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt) => &mut flags.lt, Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt) => &mut flags.gt, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => &mut flags.shl, - Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => &mut flags.shr, + Operation::Shl => &mut flags.shl, + Operation::Shr => &mut flags.shr, Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) => &mut flags.addfp254, Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) => &mut flags.mulfp254, Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.subfp254, @@ -196,6 +196,8 @@ fn perform_op( Operation::Iszero => generate_iszero(state, row)?, Operation::Not => generate_not(state, row)?, Operation::Byte => generate_byte(state, row)?, + Operation::Shl => generate_shl(state, row)?, + Operation::Shr => generate_shr(state, row)?, Operation::Syscall(opcode) => generate_syscall(opcode, state, row)?, Operation::Eq => generate_eq(state, row)?, Operation::BinaryLogic(binary_logic_op) => {