diff --git a/evm/src/cpu/columns/general.rs b/evm/src/cpu/columns/general.rs index 67fe4256..6727eb07 100644 --- a/evm/src/cpu/columns/general.rs +++ b/evm/src/cpu/columns/general.rs @@ -97,46 +97,10 @@ pub(crate) struct CpuLogicView { #[derive(Copy, Clone)] pub(crate) struct CpuJumpsView { - /// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the - /// address; for `EXIT_KERNEL`, the address and new privilege level). - /// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the - /// predicate). For `JUMP`, 1. - - /// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value. - /// Needed to prove that `input0` is nonzero. - pub(crate) input0_upper_sum_inv: T, - /// 1 if `input0[1..7]` is zero; else 0. - pub(crate) input0_upper_zero: T, - - /// 1 if `input0[0]` is the address of a valid jump destination (i.e. `JUMPDEST` that is not - /// part of a `PUSH` immediate); else 0. Note that the kernel is allowed to jump anywhere it - /// wants, so this flag is computed but ignored in kernel mode. - /// NOTE: this flag only considers `input0[0]`, the low 32 bits of the 256-bit register. Even if - /// this flag is 1, `input0` will still be an invalid address if the high 224 bits are not 0. - pub(crate) dst_valid: T, // TODO: populate this (check for JUMPDEST) - /// 1 if either `dst_valid` is 1 or we are in kernel mode; else 0. (Just a logical OR.) - pub(crate) dst_valid_or_kernel: T, - /// 1 if `dst_valid_or_kernel` and `input0_upper_zero` are both 1; else 0. In other words, we - /// are allowed to jump to `input0[0]` because either it's a valid address or we're in kernel - /// mode (`dst_valid_or_kernel`), and also `input0[1..7]` are all 0 so `input0[0]` is in fact - /// the whole address (we're not being asked to jump to an address that would overflow). - pub(crate) input0_jumpable: T, - - /// Inverse of `input1[0] + ... + input1[7]`, if one exists; otherwise, an arbitrary value. - /// Needed to prove that `input1` is nonzero. - pub(crate) input1_sum_inv: T, - - /// Note that the below flags are mutually exclusive. - /// 1 if the JUMPI falls though (because input1 is 0); else 0. - pub(crate) should_continue: T, - /// 1 if the JUMP/JUMPI does in fact jump to `input0`; else 0. This requires `input0` to be a - /// valid destination (`input0[0]` is a `JUMPDEST` not in an immediate or we are in kernel mode - /// and also `input0[1..7]` is 0) and `input1` to be nonzero. + // A flag. pub(crate) should_jump: T, - /// 1 if the JUMP/JUMPI faults; else 0. This happens when `input0` is not a valid destination - /// (`input0[0]` is not `JUMPDEST` that is not in an immediate while we are in user mode, or - /// `input0[1..7]` is nonzero) and `input1` is nonzero. - pub(crate) should_trap: T, + // Pseudoinverse of `cond.iter().sum()`. Used to check `should_jump`. + pub(crate) cond_sum_pinv: T, } #[derive(Copy, Clone)] diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index cd154e54..1af8428d 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -145,7 +145,7 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark = - Lazy::new(|| KERNEL.global_labels["fault_exception"]); +use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::memory::segments::Segment; pub fn eval_packed_exit_kernel( lv: &CpuColumnsView

, @@ -58,99 +55,65 @@ pub fn eval_packed_jump_jumpi( yield_constr: &mut ConstraintConsumer

, ) { let jumps_lv = lv.general.jumps(); - let input0 = lv.mem_channels[0].value; - let input1 = lv.mem_channels[1].value; + let dst = lv.mem_channels[0].value; + let cond = lv.mem_channels[1].value; let filter = lv.op.jump + lv.op.jumpi; // `JUMP` or `JUMPI` + let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. - // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. - yield_constr.constraint(lv.op.jump * (input1[0] - P::ONES)); - for &limb in &input1[1..] { + // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. + yield_constr.constraint(lv.op.jump * (cond[0] - P::ONES)); + for &limb in &cond[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum - // `input1[0] + ... + input1[7]` cannot overflow. + // `cond[0] + ... + cond[7]` cannot overflow. yield_constr.constraint(lv.op.jump * limb); } - // Check `input0_upper_zero` - // `input0_upper_zero` is either 0 or 1. - yield_constr - .constraint(filter * jumps_lv.input0_upper_zero * (jumps_lv.input0_upper_zero - P::ONES)); - // The below sum cannot overflow due to the limb size. - let input0_upper_sum: P = input0[1..].iter().copied().sum(); - // `input0_upper_zero` = 1 implies `input0_upper_sum` = 0. - yield_constr.constraint(filter * jumps_lv.input0_upper_zero * input0_upper_sum); - // `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can only - // happen when `input0_upper_sum` is nonzero. - yield_constr.constraint( - filter - * (jumps_lv.input0_upper_sum_inv * input0_upper_sum + jumps_lv.input0_upper_zero - - P::ONES), - ); - - // Check `dst_valid_or_kernel` (this is just a logical OR) - yield_constr.constraint( - filter - * (jumps_lv.dst_valid + lv.is_kernel_mode - - jumps_lv.dst_valid * lv.is_kernel_mode - - jumps_lv.dst_valid_or_kernel), - ); - - // Check `input0_jumpable` (this is just `dst_valid_or_kernel` AND `input0_upper_zero`) - yield_constr.constraint( - filter - * (jumps_lv.dst_valid_or_kernel * jumps_lv.input0_upper_zero - - jumps_lv.input0_jumpable), - ); - - // Make sure that `should_continue`, `should_jump`, `should_trap` are all binary and exactly one - // is set. - yield_constr - .constraint(filter * jumps_lv.should_continue * (jumps_lv.should_continue - P::ONES)); + // Check `should_jump`: yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.should_jump - P::ONES)); - yield_constr.constraint(filter * jumps_lv.should_trap * (jumps_lv.should_trap - P::ONES)); + let cond_sum: P = cond.into_iter().sum(); + yield_constr.constraint(filter * (jumps_lv.should_jump - P::ONES) * cond_sum); + yield_constr.constraint(filter * (jumps_lv.cond_sum_pinv * cond_sum - jumps_lv.should_jump)); + + // If we're jumping, then the high 7 limbs of the destination must be 0. + let dst_hi_sum: P = dst[1..].iter().copied().sum(); + yield_constr.constraint(filter * jumps_lv.should_jump * dst_hi_sum); + // Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint + // does not need to be conditioned on `should_jump` because no read takes place if we're not + // jumping, so we're free to set the channel to 1. + yield_constr.constraint(filter * (jumpdest_flag_channel.value[0] - P::ONES)); + + // Make sure that the JUMPDEST flag channel is constrained. + // Only need to read if we're about to jump and we're not in kernel mode. yield_constr.constraint( - filter * (jumps_lv.should_continue + jumps_lv.should_jump + jumps_lv.should_trap - P::ONES), - ); - - // Validate `should_continue` - // This sum cannot overflow (due to limb size). - let input1_sum: P = input1.into_iter().sum(); - // `should_continue` = 1 implies `input1_sum` = 0. - yield_constr.constraint(filter * jumps_lv.should_continue * input1_sum); - // `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if - // input1_sum is nonzero. - yield_constr.constraint( - filter * (input1_sum * jumps_lv.input1_sum_inv + jumps_lv.should_continue - P::ONES), - ); - - // Validate `should_jump` and `should_trap` by splitting on `input0_jumpable`. - // Note that `should_jump` = 1 and `should_trap` = 1 both imply that `should_continue` = 0, so - // `input1` is nonzero. - yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.input0_jumpable - P::ONES)); - yield_constr.constraint(filter * jumps_lv.should_trap * jumps_lv.input0_jumpable); - - // Handle trap - // Set program counter and kernel flag - yield_constr - .constraint_transition(filter * jumps_lv.should_trap * (nv.is_kernel_mode - P::ONES)); - yield_constr.constraint_transition( filter - * jumps_lv.should_trap - * (nv.program_counter - P::Scalar::from_canonical_usize(*INVALID_DST_HANDLER_ADDR)), + * (jumpdest_flag_channel.used - jumps_lv.should_jump * (P::ONES - lv.is_kernel_mode)), ); + yield_constr.constraint(filter * (jumpdest_flag_channel.is_read - P::ONES)); + yield_constr.constraint(filter * (jumpdest_flag_channel.addr_context - lv.context)); + yield_constr.constraint( + filter + * (jumpdest_flag_channel.addr_segment + - P::Scalar::from_canonical_u64(Segment::JumpdestBits as u64)), + ); + yield_constr.constraint(filter * (jumpdest_flag_channel.addr_virtual - dst[0])); - // Handle continue and jump - let continue_or_jump = jumps_lv.should_continue + jumps_lv.should_jump; - // Keep kernel mode. - yield_constr - .constraint_transition(filter * continue_or_jump * (nv.is_kernel_mode - lv.is_kernel_mode)); - // Set program counter depending on whether we're continuing or jumping. + // Disable unused memory channels + for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] { + yield_constr.constraint(filter * channel.used); + } + // Channel 1 is unused by the `JUMP` instruction. + yield_constr.constraint(lv.op.jump * lv.mem_channels[1].used); + + // Finally, set the next program counter. + let fallthrough_dst = lv.program_counter + P::ONES; + let jump_dest = dst[0]; yield_constr.constraint_transition( - filter * jumps_lv.should_continue * (nv.program_counter - lv.program_counter - P::ONES), + filter * (jumps_lv.should_jump - P::ONES) * (nv.program_counter - fallthrough_dst), ); yield_constr - .constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - input0[0])); + .constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - jump_dest)); } pub fn eval_ext_circuit_jump_jumpi, const D: usize>( @@ -160,178 +123,124 @@ pub fn eval_ext_circuit_jump_jumpi, const D: usize> yield_constr: &mut RecursiveConstraintConsumer, ) { let jumps_lv = lv.general.jumps(); - let input0 = lv.mem_channels[0].value; - let input1 = lv.mem_channels[1].value; + let dst = lv.mem_channels[0].value; + let cond = lv.mem_channels[1].value; let filter = builder.add_extension(lv.op.jump, lv.op.jumpi); // `JUMP` or `JUMPI` + let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1]; // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1. - // In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`. + // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`. { - let constr = builder.mul_sub_extension(lv.op.jump, input1[0], lv.op.jump); + let constr = builder.mul_sub_extension(lv.op.jump, cond[0], lv.op.jump); yield_constr.constraint(builder, constr); } - for &limb in &input1[1..] { + for &limb in &cond[1..] { // Set all limbs (other than the least-significant limb) to 0. // NB: Technically, they don't have to be 0, as long as the sum - // `input1[0] + ... + input1[7]` cannot overflow. + // `cond[0] + ... + cond[7]` cannot overflow. let constr = builder.mul_extension(lv.op.jump, limb); yield_constr.constraint(builder, constr); } - // Check `input0_upper_zero` - // `input0_upper_zero` is either 0 or 1. + // Check `should_jump`: { let constr = builder.mul_sub_extension( - jumps_lv.input0_upper_zero, - jumps_lv.input0_upper_zero, - jumps_lv.input0_upper_zero, + jumps_lv.should_jump, + jumps_lv.should_jump, + jumps_lv.should_jump, ); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } + let cond_sum = builder.add_many_extension(cond); { - // The below sum cannot overflow due to the limb size. - let input0_upper_sum = builder.add_many_extension(input0[1..].iter()); - - // `input0_upper_zero` = 1 implies `input0_upper_sum` = 0. - let constr = builder.mul_extension(jumps_lv.input0_upper_zero, input0_upper_sum); + let constr = builder.mul_sub_extension(cond_sum, jumps_lv.should_jump, cond_sum); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); - - // `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can - // only happen when `input0_upper_sum` is nonzero. - let constr = builder.mul_add_extension( - jumps_lv.input0_upper_sum_inv, - input0_upper_sum, - jumps_lv.input0_upper_zero, - ); - let constr = builder.mul_sub_extension(filter, constr, filter); - yield_constr.constraint(builder, constr); - }; - - // Check `dst_valid_or_kernel` (this is just a logical OR) + } { - let constr = builder.mul_add_extension( - jumps_lv.dst_valid, + let constr = + builder.mul_sub_extension(jumps_lv.cond_sum_pinv, cond_sum, jumps_lv.should_jump); + let constr = builder.mul_extension(filter, constr); + yield_constr.constraint(builder, constr); + } + + // If we're jumping, then the high 7 limbs of the destination must be 0. + let dst_hi_sum = builder.add_many_extension(&dst[1..]); + { + let constr = builder.mul_extension(jumps_lv.should_jump, dst_hi_sum); + let constr = builder.mul_extension(filter, constr); + yield_constr.constraint(builder, constr); + } + // Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint + // does not need to be conditioned on `should_jump` because no read takes place if we're not + // jumping, so we're free to set the channel to 1. + { + let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.value[0], filter); + yield_constr.constraint(builder, constr); + } + + // Make sure that the JUMPDEST flag channel is constrained. + // Only need to read if we're about to jump and we're not in kernel mode. + { + let constr = builder.mul_sub_extension( + jumps_lv.should_jump, lv.is_kernel_mode, - jumps_lv.dst_valid_or_kernel, - ); - let constr = builder.sub_extension(jumps_lv.dst_valid, constr); - let constr = builder.add_extension(lv.is_kernel_mode, constr); - let constr = builder.mul_extension(filter, constr); - yield_constr.constraint(builder, constr); - } - - // Check `input0_jumpable` (this is just `dst_valid_or_kernel` AND `input0_upper_zero`) - { - let constr = builder.mul_sub_extension( - jumps_lv.dst_valid_or_kernel, - jumps_lv.input0_upper_zero, - jumps_lv.input0_jumpable, - ); - let constr = builder.mul_extension(filter, constr); - yield_constr.constraint(builder, constr); - } - - // Make sure that `should_continue`, `should_jump`, `should_trap` are all binary and exactly one - // is set. - for flag in [ - jumps_lv.should_continue, - jumps_lv.should_jump, - jumps_lv.should_trap, - ] { - let constr = builder.mul_sub_extension(flag, flag, flag); - let constr = builder.mul_extension(filter, constr); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.add_extension(jumps_lv.should_continue, jumps_lv.should_jump); - let constr = builder.add_extension(constr, jumps_lv.should_trap); - let constr = builder.mul_sub_extension(filter, constr, filter); - yield_constr.constraint(builder, constr); - } - - // Validate `should_continue` - { - // This sum cannot overflow (due to limb size). - let input1_sum = builder.add_many_extension(input1.into_iter()); - - // `should_continue` = 1 implies `input1_sum` = 0. - let constr = builder.mul_extension(jumps_lv.should_continue, input1_sum); - let constr = builder.mul_extension(filter, constr); - yield_constr.constraint(builder, constr); - - // `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if - // input1_sum is nonzero. - let constr = builder.mul_add_extension( - input1_sum, - jumps_lv.input1_sum_inv, - jumps_lv.should_continue, - ); - let constr = builder.mul_sub_extension(filter, constr, filter); - yield_constr.constraint(builder, constr); - } - - // Validate `should_jump` and `should_trap` by splitting on `input0_jumpable`. - // Note that `should_jump` = 1 and `should_trap` = 1 both imply that `should_continue` = 0, so - // `input1` is nonzero. - { - let constr = builder.mul_sub_extension( - jumps_lv.should_jump, - jumps_lv.input0_jumpable, jumps_lv.should_jump, ); + let constr = builder.add_extension(jumpdest_flag_channel.used, constr); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } { - let constr = builder.mul_extension(jumps_lv.should_trap, jumps_lv.input0_jumpable); + let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.is_read, filter); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.sub_extension(jumpdest_flag_channel.addr_context, lv.context); let constr = builder.mul_extension(filter, constr); yield_constr.constraint(builder, constr); } - - // Handle trap { - let trap_filter = builder.mul_extension(filter, jumps_lv.should_trap); - - // Set kernel flag - let constr = builder.mul_sub_extension(trap_filter, nv.is_kernel_mode, trap_filter); - yield_constr.constraint_transition(builder, constr); - - // Set program counter let constr = builder.arithmetic_extension( F::ONE, - -F::from_canonical_usize(*INVALID_DST_HANDLER_ADDR), - trap_filter, - nv.program_counter, - trap_filter, + -F::from_canonical_u64(Segment::JumpdestBits as u64), + filter, + jumpdest_flag_channel.addr_segment, + filter, ); - yield_constr.constraint_transition(builder, constr); + yield_constr.constraint(builder, constr); + } + { + let constr = builder.sub_extension(jumpdest_flag_channel.addr_virtual, dst[0]); + let constr = builder.mul_extension(filter, constr); + yield_constr.constraint(builder, constr); } - // Handle continue and jump + // Disable unused memory channels + for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] { + let constr = builder.mul_extension(filter, channel.used); + yield_constr.constraint(builder, constr); + } + // Channel 1 is unused by the `JUMP` instruction. { - // Keep kernel mode. - let continue_or_jump = - builder.add_extension(jumps_lv.should_continue, jumps_lv.should_jump); - let constr = builder.sub_extension(nv.is_kernel_mode, lv.is_kernel_mode); - let constr = builder.mul_extension(continue_or_jump, constr); - let constr = builder.mul_extension(filter, constr); + let constr = builder.mul_extension(lv.op.jump, lv.mem_channels[1].used); + yield_constr.constraint(builder, constr); + } + + // Finally, set the next program counter. + let fallthrough_dst = builder.add_const_extension(lv.program_counter, F::ONE); + let jump_dest = dst[0]; + { + let constr_a = builder.mul_sub_extension(filter, jumps_lv.should_jump, filter); + let constr_b = builder.sub_extension(nv.program_counter, fallthrough_dst); + let constr = builder.mul_extension(constr_a, constr_b); yield_constr.constraint_transition(builder, constr); } - // Set program counter depending on whether we're continuing... { - let constr = builder.sub_extension(nv.program_counter, lv.program_counter); - let constr = - builder.mul_sub_extension(jumps_lv.should_continue, constr, jumps_lv.should_continue); - let constr = builder.mul_extension(filter, constr); - yield_constr.constraint_transition(builder, constr); - } - // ...or jumping. - { - let constr = builder.sub_extension(nv.program_counter, input0[0]); - let constr = builder.mul_extension(jumps_lv.should_jump, constr); - let constr = builder.mul_extension(filter, constr); + let constr_a = builder.mul_extension(filter, jumps_lv.should_jump); + let constr_b = builder.sub_extension(nv.program_counter, jump_dest); + let constr = builder.mul_extension(constr_a, constr_b); yield_constr.constraint_transition(builder, constr); } } diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 42181e41..b666d46c 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -64,8 +64,16 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { keccak_general: None, // TODO prover_input: None, // TODO pop: None, // TODO - jump: None, // TODO - jumpi: None, // TODO + jump: Some(StackBehavior { + num_pops: 1, + pushes: false, + disable_other_channels: false, + }), + jumpi: Some(StackBehavior { + num_pops: 2, + pushes: false, + disable_other_channels: false, + }), pc: Some(StackBehavior { num_pops: 0, pushes: true, @@ -91,7 +99,11 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { disable_other_channels: true, }), consume_gas: None, // TODO - exit_kernel: None, // TODO + exit_kernel: Some(StackBehavior { + num_pops: 1, + pushes: false, + disable_other_channels: true, + }), mload_general: Some(StackBehavior { num_pops: 3, pushes: true, diff --git a/evm/src/util.rs b/evm/src/util.rs index fb3f1f13..dcb0a8ef 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -144,11 +144,3 @@ pub(crate) fn biguint_to_u256(x: BigUint) -> U256 { let bytes = x.to_bytes_le(); U256::from_little_endian(&bytes) } - -pub(crate) fn u256_saturating_cast_usize(x: U256) -> usize { - if x > usize::MAX.into() { - usize::MAX - } else { - x.as_usize() - } -} diff --git a/evm/src/witness/operation.rs b/evm/src/witness/operation.rs index fa9d0fac..b026de01 100644 --- a/evm/src/witness/operation.rs +++ b/evm/src/witness/operation.rs @@ -10,7 +10,6 @@ use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; -use crate::util::u256_saturating_cast_usize; use crate::witness::errors::ProgramError; use crate::witness::memory::MemoryAddress; use crate::witness::util::{ @@ -187,12 +186,37 @@ pub(crate) fn generate_jump( mut row: CpuColumnsView, ) -> Result<(), ProgramError> { let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?; + let dst: u32 = dst + .try_into() + .map_err(|_| ProgramError::InvalidJumpDestination)?; + + let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( + NUM_GP_CHANNELS - 1, + MemoryAddress::new(state.registers.context, Segment::JumpdestBits, dst as usize), + state, + &mut row, + ); + if state.registers.is_kernel { + // Don't actually do the read, just set the address, etc. + let mut channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; + channel.used = F::ZERO; + channel.value[0] = F::ONE; + + row.mem_channels[1].value[0] = F::ONE; + } else { + if jumpdest_bit != U256::one() { + return Err(ProgramError::InvalidJumpDestination); + } + state.traces.push_memory(jumpdest_bit_log); + } + + // Extra fields required by the constraints. + row.general.jumps_mut().should_jump = F::ONE; + row.general.jumps_mut().cond_sum_pinv = F::ONE; state.traces.push_memory(log_in0); state.traces.push_cpu(row); - // TODO: First check if it's a valid JUMPDEST - state.registers.program_counter = u256_saturating_cast_usize(dst); - // TODO: Set other cols like input0_upper_sum_inv. + state.registers.program_counter = dst as usize; Ok(()) } @@ -202,16 +226,52 @@ pub(crate) fn generate_jumpi( ) -> Result<(), ProgramError> { let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?; + let should_jump = !cond.is_zero(); + if should_jump { + row.general.jumps_mut().should_jump = F::ONE; + let cond_sum_u64 = cond + .0 + .into_iter() + .map(|limb| ((limb as u32) as u64) + (limb >> 32)) + .sum(); + let cond_sum = F::from_canonical_u64(cond_sum_u64); + row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse(); + + let dst: u32 = dst + .try_into() + .map_err(|_| ProgramError::InvalidJumpiDestination)?; + state.registers.program_counter = dst as usize; + } else { + row.general.jumps_mut().should_jump = F::ZERO; + row.general.jumps_mut().cond_sum_pinv = F::ZERO; + state.registers.program_counter += 1; + } + + let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill( + NUM_GP_CHANNELS - 1, + MemoryAddress::new( + state.registers.context, + Segment::JumpdestBits, + dst.low_u32() as usize, + ), + state, + &mut row, + ); + if !should_jump || state.registers.is_kernel { + // Don't actually do the read, just set the address, etc. + let mut channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1]; + channel.used = F::ZERO; + channel.value[0] = F::ONE; + } else { + if jumpdest_bit != U256::one() { + return Err(ProgramError::InvalidJumpiDestination); + } + state.traces.push_memory(jumpdest_bit_log); + } + state.traces.push_memory(log_in0); state.traces.push_memory(log_in1); state.traces.push_cpu(row); - state.registers.program_counter = if cond.is_zero() { - state.registers.program_counter + 1 - } else { - // TODO: First check if it's a valid JUMPDEST - u256_saturating_cast_usize(dst) - }; - // TODO: Set other cols like input0_upper_sum_inv. Ok(()) }