plonky2/evm/src/cpu/jumps.rs
Robin Salen 24aa9668f2
Revert "Make gas fit in 2 limbs (#1261)" (#1361)
* Revert "Make gas fit in 2 limbs (#1261)"

This reverts commit 0f19cd0dbc25f9f1aa8fc325ae4dd1b95ca933b3.

* Comment
2023-11-17 10:01:26 -05:00

384 lines
16 KiB
Rust

use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::memory::segments::Segment;
/// Evaluates constraints for EXIT_KERNEL.
pub(crate) fn eval_packed_exit_kernel<P: PackedField>(
    lv: &CpuColumnsView<P>,
    nv: &CpuColumnsView<P>,
    yield_constr: &mut ConstraintConsumer<P>,
) {
    // The popped word encodes the kernel state to restore:
    // limb 0 -> program counter, limb 1 -> kernel-mode flag, limb 6 -> gas counter.
    let kexit_info = lv.mem_channels[0].value;
    let is_exit_kernel = lv.op.exit_kernel;
    // On `EXIT_KERNEL`, restore the program counter, kernel-mode flag and gas
    // counter from the popped word. The middle 4 (32-bit) limbs are ignored
    // (this is not part of the spec, but we trust the kernel to set them to zero).
    yield_constr.constraint_transition(is_exit_kernel * (kexit_info[0] - nv.program_counter));
    yield_constr.constraint_transition(is_exit_kernel * (kexit_info[1] - nv.is_kernel_mode));
    yield_constr.constraint_transition(is_exit_kernel * (kexit_info[6] - nv.gas));
    // High limb of gas must be 0 for convenient detection of overflow.
    yield_constr.constraint(is_exit_kernel * kexit_info[7]);
}
/// Circuit version of `eval_packed_exit_kernel`.
/// Evaluates constraints for EXIT_KERNEL.
pub(crate) fn eval_ext_circuit_exit_kernel<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
    lv: &CpuColumnsView<ExtensionTarget<D>>,
    nv: &CpuColumnsView<ExtensionTarget<D>>,
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    let input = lv.mem_channels[0].value;
    let filter = lv.op.exit_kernel;
    // If we are executing `EXIT_KERNEL` then we simply restore the program counter, kernel mode
    // flag, and gas counter. The middle 4 (32-bit) limbs are ignored (this is not part of the
    // spec, but we trust the kernel to set them to zero).
    // (Mirrors `eval_packed_exit_kernel`; limbs 0, 1, 6 and 7 are all constrained below.)
    let pc_constr = builder.sub_extension(input[0], nv.program_counter);
    let pc_constr = builder.mul_extension(filter, pc_constr);
    yield_constr.constraint_transition(builder, pc_constr);
    let kernel_constr = builder.sub_extension(input[1], nv.is_kernel_mode);
    let kernel_constr = builder.mul_extension(filter, kernel_constr);
    yield_constr.constraint_transition(builder, kernel_constr);
    {
        // Restore the gas counter from limb 6: `filter * (input[6] - nv.gas)`.
        let diff = builder.sub_extension(input[6], nv.gas);
        let constr = builder.mul_extension(filter, diff);
        yield_constr.constraint_transition(builder, constr);
    }
    {
        // High limb of gas must be 0 for convenient detection of overflow.
        let constr = builder.mul_extension(filter, input[7]);
        yield_constr.constraint(builder, constr);
    }
}
/// Evaluates constraints jump operations: JUMP and JUMPI.
pub(crate) fn eval_packed_jump_jumpi<P: PackedField>(
    lv: &CpuColumnsView<P>,
    nv: &CpuColumnsView<P>,
    yield_constr: &mut ConstraintConsumer<P>,
) {
    let jumps_lv = lv.general.jumps();
    // Channel 0 holds the jump destination; channel 1 holds the JUMPI predicate.
    let dst = lv.mem_channels[0].value;
    let cond = lv.mem_channels[1].value;
    let filter = lv.op.jumps; // `JUMP` or `JUMPI`
    // The last GP channel is used below to read the JUMPDEST bit at `dst`.
    let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];

    // `opcode_bits[0]` distinguishes the two opcodes: 0 -> JUMP, 1 -> JUMPI.
    let is_jump = filter * (P::ONES - lv.opcode_bits[0]);
    let is_jumpi = filter * lv.opcode_bits[0];

    // Stack constraints.
    // If (JUMP and stack_len != 1) or (JUMPI and stack_len != 2)...
    let len_diff = lv.stack_len - P::ONES - lv.opcode_bits[0];
    let new_filter = len_diff * filter;
    // ...then the next row must read an extra element (the new stack top) on
    // memory channel 0.
    // Read an extra element.
    let channel = nv.mem_channels[0];
    yield_constr.constraint_transition(new_filter * (channel.used - P::ONES));
    yield_constr.constraint_transition(new_filter * (channel.is_read - P::ONES));
    yield_constr.constraint_transition(new_filter * (channel.addr_context - nv.context));
    yield_constr.constraint_transition(
        new_filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)),
    );
    // The read targets the top of the next row's stack.
    let addr_virtual = nv.stack_len - P::ONES;
    yield_constr.constraint_transition(new_filter * (channel.addr_virtual - addr_virtual));
    // Constrain `stack_inv_aux`: it must equal `len_diff * stack_inv`, where
    // `stack_inv` is the purported inverse of `len_diff`; so `stack_inv_aux` acts
    // as a boolean flag for `len_diff != 0`.
    yield_constr.constraint(
        filter * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux),
    );
    // Disable channel if stack_len == N.
    // (When the stack already has the expected length, i.e. `len_diff == 0`, the
    // extra-read channel above must be unused.)
    let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES);
    yield_constr.constraint_transition(empty_stack_filter * channel.used);

    // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
    // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
    yield_constr.constraint(is_jump * (cond[0] - P::ONES));
    for &limb in &cond[1..] {
        // Set all limbs (other than the least-significant limb) to 0.
        // NB: Technically, they don't have to be 0, as long as the sum
        // `cond[0] + ... + cond[7]` cannot overflow.
        yield_constr.constraint(is_jump * limb);
    }

    // Check `should_jump`: it must be boolean, it must be 0 when the condition
    // sums to 0, and a nonzero condition (witnessed via the purported inverse
    // `cond_sum_pinv`) forces it to 1.
    yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.should_jump - P::ONES));
    let cond_sum: P = cond.into_iter().sum();
    yield_constr.constraint(filter * (jumps_lv.should_jump - P::ONES) * cond_sum);
    yield_constr.constraint(filter * (jumps_lv.cond_sum_pinv * cond_sum - jumps_lv.should_jump));

    // If we're jumping, then the high 7 limbs of the destination must be 0.
    let dst_hi_sum: P = dst[1..].iter().copied().sum();
    yield_constr.constraint(filter * jumps_lv.should_jump * dst_hi_sum);

    // Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint
    // does not need to be conditioned on `should_jump` because no read takes place if we're not
    // jumping, so we're free to set the channel to 1.
    yield_constr.constraint(filter * (jumpdest_flag_channel.value[0] - P::ONES));

    // Make sure that the JUMPDEST flag channel is constrained.
    // Only need to read if we're about to jump and we're not in kernel mode.
    yield_constr.constraint(
        filter
            * (jumpdest_flag_channel.used - jumps_lv.should_jump * (P::ONES - lv.is_kernel_mode)),
    );
    yield_constr.constraint(filter * (jumpdest_flag_channel.is_read - P::ONES));
    yield_constr.constraint(filter * (jumpdest_flag_channel.addr_context - lv.context));
    yield_constr.constraint(
        filter
            * (jumpdest_flag_channel.addr_segment
                - P::Scalar::from_canonical_u64(Segment::JumpdestBits as u64)),
    );
    yield_constr.constraint(filter * (jumpdest_flag_channel.addr_virtual - dst[0]));

    // Disable unused memory channels
    for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] {
        yield_constr.constraint(filter * channel.used);
    }
    // Channel 1 is unused by the `JUMP` instruction.
    yield_constr.constraint(is_jump * lv.mem_channels[1].used);

    // Update stack length: JUMP pops 1 element, JUMPI pops 2.
    yield_constr.constraint_transition(is_jump * (nv.stack_len - lv.stack_len + P::ONES));
    yield_constr.constraint_transition(
        is_jumpi * (nv.stack_len - lv.stack_len + P::Scalar::from_canonical_u64(2)),
    );

    // Finally, set the next program counter: fall through to `pc + 1` when not
    // jumping, otherwise go to `dst[0]`.
    let fallthrough_dst = lv.program_counter + P::ONES;
    let jump_dest = dst[0];
    yield_constr.constraint_transition(
        filter * (jumps_lv.should_jump - P::ONES) * (nv.program_counter - fallthrough_dst),
    );
    yield_constr
        .constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - jump_dest));
}
/// Circuit version of `eval_packed_jump_jumpi`.
/// Evaluates constraints jump operations: JUMP and JUMPI.
pub(crate) fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
    lv: &CpuColumnsView<ExtensionTarget<D>>,
    nv: &CpuColumnsView<ExtensionTarget<D>>,
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    let jumps_lv = lv.general.jumps();
    // Channel 0 holds the jump destination; channel 1 holds the JUMPI predicate.
    let dst = lv.mem_channels[0].value;
    let cond = lv.mem_channels[1].value;
    let filter = lv.op.jumps; // `JUMP` or `JUMPI`
    // The last GP channel is used below to read the JUMPDEST bit at `dst`.
    let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
    let one_extension = builder.one_extension();
    // `opcode_bits[0]` distinguishes the two opcodes: 0 -> JUMP, 1 -> JUMPI.
    let is_jump = builder.sub_extension(one_extension, lv.opcode_bits[0]);
    let is_jump = builder.mul_extension(filter, is_jump);
    let is_jumpi = builder.mul_extension(filter, lv.opcode_bits[0]);

    // Stack constraints.
    // If (JUMP and stack_len != 1) or (JUMPI and stack_len != 2)...
    let len_diff = builder.sub_extension(lv.stack_len, one_extension);
    let len_diff = builder.sub_extension(len_diff, lv.opcode_bits[0]);
    let new_filter = builder.mul_extension(len_diff, filter);
    // ...then the next row must read an extra element (the new stack top) on
    // memory channel 0.
    // Read an extra element.
    let channel = nv.mem_channels[0];
    {
        // `new_filter * (channel.used - 1)`, as in the packed version.
        let constr = builder.mul_sub_extension(new_filter, channel.used, new_filter);
        yield_constr.constraint_transition(builder, constr);
    }
    {
        // `new_filter * (channel.is_read - 1)`.
        let constr = builder.mul_sub_extension(new_filter, channel.is_read, new_filter);
        yield_constr.constraint_transition(builder, constr);
    }
    {
        let diff = builder.sub_extension(channel.addr_context, nv.context);
        let constr = builder.mul_extension(new_filter, diff);
        yield_constr.constraint_transition(builder, constr);
    }
    {
        // `new_filter * (channel.addr_segment - Segment::Stack)`.
        let constr = builder.arithmetic_extension(
            F::ONE,
            -F::from_canonical_u64(Segment::Stack as u64),
            new_filter,
            channel.addr_segment,
            new_filter,
        );
        yield_constr.constraint_transition(builder, constr);
    }
    {
        // The read targets the top of the next row's stack:
        // `new_filter * (channel.addr_virtual - (nv.stack_len - 1))`.
        let diff = builder.sub_extension(channel.addr_virtual, nv.stack_len);
        let constr = builder.arithmetic_extension(F::ONE, F::ONE, new_filter, diff, new_filter);
        yield_constr.constraint_transition(builder, constr);
    }
    // Constrain `stack_inv_aux`: it must equal `len_diff * stack_inv`, where
    // `stack_inv` is the purported inverse of `len_diff`; so `stack_inv_aux`
    // acts as a boolean flag for `len_diff != 0`.
    {
        let prod = builder.mul_extension(len_diff, lv.general.stack().stack_inv);
        let diff = builder.sub_extension(prod, lv.general.stack().stack_inv_aux);
        let constr = builder.mul_extension(filter, diff);
        yield_constr.constraint(builder, constr);
    }
    // Disable channel if stack_len == N.
    // (When the stack already has the expected length, the extra-read channel
    // above must be unused.)
    {
        let empty_stack_filter =
            builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter);
        let constr = builder.mul_extension(empty_stack_filter, channel.used);
        yield_constr.constraint_transition(builder, constr);
    }
    // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
    // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
    {
        // `is_jump * (cond[0] - 1)`.
        let constr = builder.mul_sub_extension(is_jump, cond[0], is_jump);
        yield_constr.constraint(builder, constr);
    }
    for &limb in &cond[1..] {
        // Set all limbs (other than the least-significant limb) to 0.
        // NB: Technically, they don't have to be 0, as long as the sum
        // `cond[0] + ... + cond[7]` cannot overflow.
        let constr = builder.mul_extension(is_jump, limb);
        yield_constr.constraint(builder, constr);
    }
    // Check `should_jump`: boolean; 0 when the condition sums to 0; forced to 1
    // by a nonzero condition via the purported inverse `cond_sum_pinv`.
    {
        // `filter * (should_jump^2 - should_jump)`.
        let constr = builder.mul_sub_extension(
            jumps_lv.should_jump,
            jumps_lv.should_jump,
            jumps_lv.should_jump,
        );
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    let cond_sum = builder.add_many_extension(cond);
    {
        // `filter * (should_jump - 1) * cond_sum`.
        let constr = builder.mul_sub_extension(cond_sum, jumps_lv.should_jump, cond_sum);
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    {
        // `filter * (cond_sum_pinv * cond_sum - should_jump)`.
        let constr =
            builder.mul_sub_extension(jumps_lv.cond_sum_pinv, cond_sum, jumps_lv.should_jump);
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    // If we're jumping, then the high 7 limbs of the destination must be 0.
    let dst_hi_sum = builder.add_many_extension(&dst[1..]);
    {
        let constr = builder.mul_extension(jumps_lv.should_jump, dst_hi_sum);
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    // Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint
    // does not need to be conditioned on `should_jump` because no read takes place if we're not
    // jumping, so we're free to set the channel to 1.
    {
        // `filter * (jumpdest_flag_channel.value[0] - 1)`.
        let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.value[0], filter);
        yield_constr.constraint(builder, constr);
    }
    // Make sure that the JUMPDEST flag channel is constrained.
    // Only need to read if we're about to jump and we're not in kernel mode:
    // `filter * (used - should_jump * (1 - is_kernel_mode))`.
    {
        let constr = builder.mul_sub_extension(
            jumps_lv.should_jump,
            lv.is_kernel_mode,
            jumps_lv.should_jump,
        );
        let constr = builder.add_extension(jumpdest_flag_channel.used, constr);
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    {
        let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.is_read, filter);
        yield_constr.constraint(builder, constr);
    }
    {
        let constr = builder.sub_extension(jumpdest_flag_channel.addr_context, lv.context);
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    {
        // `filter * (addr_segment - Segment::JumpdestBits)`.
        let constr = builder.arithmetic_extension(
            F::ONE,
            -F::from_canonical_u64(Segment::JumpdestBits as u64),
            filter,
            jumpdest_flag_channel.addr_segment,
            filter,
        );
        yield_constr.constraint(builder, constr);
    }
    {
        let constr = builder.sub_extension(jumpdest_flag_channel.addr_virtual, dst[0]);
        let constr = builder.mul_extension(filter, constr);
        yield_constr.constraint(builder, constr);
    }
    // Disable unused memory channels
    for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] {
        let constr = builder.mul_extension(filter, channel.used);
        yield_constr.constraint(builder, constr);
    }
    // Channel 1 is unused by the `JUMP` instruction.
    {
        let constr = builder.mul_extension(is_jump, lv.mem_channels[1].used);
        yield_constr.constraint(builder, constr);
    }
    // Update stack length: JUMP pops 1 element, JUMPI pops 2.
    {
        let diff = builder.sub_extension(nv.stack_len, lv.stack_len);
        let constr = builder.mul_add_extension(is_jump, diff, is_jump);
        yield_constr.constraint_transition(builder, constr);
    }
    {
        let diff = builder.sub_extension(nv.stack_len, lv.stack_len);
        let diff = builder.add_const_extension(diff, F::TWO);
        let constr = builder.mul_extension(is_jumpi, diff);
        yield_constr.constraint_transition(builder, constr);
    }
    // Finally, set the next program counter.
    let fallthrough_dst = builder.add_const_extension(lv.program_counter, F::ONE);
    let jump_dest = dst[0];
    {
        // Not jumping: `filter * (should_jump - 1) * (nv.program_counter - (pc + 1))`.
        let constr_a = builder.mul_sub_extension(filter, jumps_lv.should_jump, filter);
        let constr_b = builder.sub_extension(nv.program_counter, fallthrough_dst);
        let constr = builder.mul_extension(constr_a, constr_b);
        yield_constr.constraint_transition(builder, constr);
    }
    {
        // Jumping: `filter * should_jump * (nv.program_counter - dst[0])`.
        let constr_a = builder.mul_extension(filter, jumps_lv.should_jump);
        let constr_b = builder.sub_extension(nv.program_counter, jump_dest);
        let constr = builder.mul_extension(constr_a, constr_b);
        yield_constr.constraint_transition(builder, constr);
    }
}
/// Evaluates constraints for EXIT_KERNEL, JUMP and JUMPI.
///
/// Packed entry point for this module; calls the sub-evaluators in the same
/// order as `eval_ext_circuit`.
pub(crate) fn eval_packed<P: PackedField>(
    lv: &CpuColumnsView<P>,
    nv: &CpuColumnsView<P>,
    yield_constr: &mut ConstraintConsumer<P>,
) {
    eval_packed_exit_kernel(lv, nv, yield_constr);
    eval_packed_jump_jumpi(lv, nv, yield_constr);
}
/// Circuit version of `eval_packed`.
/// Evaluates constraints for EXIT_KERNEL, JUMP and JUMPI.
///
/// Recursive entry point for this module; calls the sub-evaluators in the same
/// order as `eval_packed`.
pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
    lv: &CpuColumnsView<ExtensionTarget<D>>,
    nv: &CpuColumnsView<ExtensionTarget<D>>,
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    eval_ext_circuit_exit_kernel(builder, lv, nv, yield_constr);
    eval_ext_circuit_jump_jumpi(builder, lv, nv, yield_constr);
}