Merge pull request #1184 from topos-protocol/combine_jump_flags

Combine jump flags
This commit is contained in:
Hamy Ratoanina 2023-08-15 01:35:56 +02:00 committed by GitHub
commit 830fdf5374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 27 deletions

View File

@@ -33,8 +33,7 @@ pub struct OpsColumnsView<T: Copy> {
     pub prover_input: T,
     pub pop: T,
     // TODO: combine JUMP and JUMPI into one flag
-    pub jump: T,  // Note: This column must be 0 when is_cpu_cycle = 0.
-    pub jumpi: T, // Note: This column must be 0 when is_cpu_cycle = 0.
+    pub jumps: T, // Note: This column must be 0 when is_cpu_cycle = 0.
     pub pc: T,
     pub jumpdest: T,
     pub push0: T,

View File

@@ -22,7 +22,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP};
 /// behavior.
 /// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to
 /// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid.
-const OPCODES: [(u8, usize, bool, usize); 33] = [
+const OPCODES: [(u8, usize, bool, usize); 32] = [
     // (start index of block, number of top bits to check (log2), kernel-only, flag column)
     (0x01, 0, false, COL_MAP.op.add),
     (0x02, 0, false, COL_MAP.op.mul),
@@ -45,8 +45,7 @@ const OPCODES: [(u8, usize, bool, usize); 33] = [
     (0x21, 0, true, COL_MAP.op.keccak_general),
     (0x49, 0, true, COL_MAP.op.prover_input),
     (0x50, 0, false, COL_MAP.op.pop),
-    (0x56, 0, false, COL_MAP.op.jump),
-    (0x57, 0, false, COL_MAP.op.jumpi),
+    (0x56, 1, false, COL_MAP.op.jumps), // 0x56-0x57
     (0x58, 0, false, COL_MAP.op.pc),
     (0x5b, 0, false, COL_MAP.op.jumpdest),
     (0x5f, 0, false, COL_MAP.op.push0),

View File

@@ -40,8 +40,7 @@ const SIMPLE_OPCODES: OpsColumnsView<Option<u32>> = OpsColumnsView {
     keccak_general: KERNEL_ONLY_INSTR,
     prover_input: KERNEL_ONLY_INSTR,
     pop: G_BASE,
-    jump: G_MID,
-    jumpi: G_HIGH,
+    jumps: None, // Combined flag handled separately.
     pc: G_BASE,
     jumpdest: G_JUMPDEST,
     push0: G_BASE,
@@ -93,6 +92,12 @@ fn eval_packed_accumulate<P: PackedField>(
                 .constraint_transition(lv.is_cpu_cycle * op_flag * (nv.gas - lv.gas - cost));
         }
     }
+
+    // For jumps.
+    let jump_gas_cost = P::Scalar::from_canonical_u32(G_MID.unwrap())
+        + lv.opcode_bits[0] * P::Scalar::from_canonical_u32(G_HIGH.unwrap() - G_MID.unwrap());
+    yield_constr
+        .constraint_transition(lv.is_cpu_cycle * lv.op.jumps * (nv.gas - lv.gas - jump_gas_cost));
 }
 
 fn eval_packed_init<P: PackedField>(
@@ -168,6 +173,20 @@ fn eval_ext_circuit_accumulate<F: RichField + Extendable<D>, const D: usize>(
             yield_constr.constraint_transition(builder, constr);
         }
     }
+
+    // For jumps.
+    let filter = builder.mul_extension(lv.is_cpu_cycle, lv.op.jumps);
+    let jump_gas_cost = builder.mul_const_extension(
+        F::from_canonical_u32(G_HIGH.unwrap() - G_MID.unwrap()),
+        lv.opcode_bits[0],
+    );
+    let jump_gas_cost =
+        builder.add_const_extension(jump_gas_cost, F::from_canonical_u32(G_MID.unwrap()));
+    let nv_lv_diff = builder.sub_extension(nv.gas, lv.gas);
+    let gas_diff = builder.sub_extension(nv_lv_diff, jump_gas_cost);
+    let constr = builder.mul_extension(filter, gas_diff);
+    yield_constr.constraint_transition(builder, constr);
 }
 
 fn eval_ext_circuit_init<F: RichField + Extendable<D>, const D: usize>(

View File

@@ -7,6 +7,7 @@ use plonky2::iop::ext_target::ExtensionTarget;
 use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
 use crate::cpu::columns::CpuColumnsView;
 use crate::cpu::membus::NUM_GP_CHANNELS;
+use crate::cpu::stack;
 use crate::memory::segments::Segment;
 
 pub fn eval_packed_exit_kernel<P: PackedField>(
@@ -68,17 +69,23 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
     let jumps_lv = lv.general.jumps();
     let dst = lv.mem_channels[0].value;
     let cond = lv.mem_channels[1].value;
-    let filter = lv.op.jump + lv.op.jumpi; // `JUMP` or `JUMPI`
+    let filter = lv.op.jumps; // `JUMP` or `JUMPI`
     let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
+    let is_jump = filter * (P::ONES - lv.opcode_bits[0]);
+    let is_jumpi = filter * lv.opcode_bits[0];
+
+    // Stack constraints.
+    stack::eval_packed_one(lv, is_jump, stack::JUMP_OP.unwrap(), yield_constr);
+    stack::eval_packed_one(lv, is_jumpi, stack::JUMPI_OP.unwrap(), yield_constr);
 
     // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
     // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
-    yield_constr.constraint(lv.op.jump * (cond[0] - P::ONES));
+    yield_constr.constraint(is_jump * (cond[0] - P::ONES));
     for &limb in &cond[1..] {
         // Set all limbs (other than the least-significant limb) to 0.
         // NB: Technically, they don't have to be 0, as long as the sum
         // `cond[0] + ... + cond[7]` cannot overflow.
-        yield_constr.constraint(lv.op.jump * limb);
+        yield_constr.constraint(is_jump * limb);
     }
 
     // Check `should_jump`:
@@ -115,7 +122,7 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
         yield_constr.constraint(filter * channel.used);
     }
     // Channel 1 is unused by the `JUMP` instruction.
-    yield_constr.constraint(lv.op.jump * lv.mem_channels[1].used);
+    yield_constr.constraint(is_jump * lv.mem_channels[1].used);
 
     // Finally, set the next program counter.
     let fallthrough_dst = lv.program_counter + P::ONES;
@@ -136,20 +143,34 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
     let jumps_lv = lv.general.jumps();
     let dst = lv.mem_channels[0].value;
     let cond = lv.mem_channels[1].value;
-    let filter = builder.add_extension(lv.op.jump, lv.op.jumpi); // `JUMP` or `JUMPI`
+    let filter = lv.op.jumps; // `JUMP` or `JUMPI`
     let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
+    let one_extension = builder.one_extension();
+    let is_jump = builder.sub_extension(one_extension, lv.opcode_bits[0]);
+    let is_jump = builder.mul_extension(filter, is_jump);
+    let is_jumpi = builder.mul_extension(filter, lv.opcode_bits[0]);
+
+    // Stack constraints.
+    stack::eval_ext_circuit_one(builder, lv, is_jump, stack::JUMP_OP.unwrap(), yield_constr);
+    stack::eval_ext_circuit_one(
+        builder,
+        lv,
+        is_jumpi,
+        stack::JUMPI_OP.unwrap(),
+        yield_constr,
+    );
 
     // If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
     // In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
     {
-        let constr = builder.mul_sub_extension(lv.op.jump, cond[0], lv.op.jump);
+        let constr = builder.mul_sub_extension(is_jump, cond[0], is_jump);
         yield_constr.constraint(builder, constr);
     }
     for &limb in &cond[1..] {
         // Set all limbs (other than the least-significant limb) to 0.
         // NB: Technically, they don't have to be 0, as long as the sum
         // `cond[0] + ... + cond[7]` cannot overflow.
-        let constr = builder.mul_extension(lv.op.jump, limb);
+        let constr = builder.mul_extension(is_jump, limb);
         yield_constr.constraint(builder, constr);
     }
@@ -235,7 +256,7 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
     }
     // Channel 1 is unused by the `JUMP` instruction.
     {
-        let constr = builder.mul_extension(lv.op.jump, lv.mem_channels[1].used);
+        let constr = builder.mul_extension(is_jump, lv.mem_channels[1].used);
         yield_constr.constraint(builder, constr);
     }

View File

@@ -33,6 +33,16 @@ const BASIC_TERNARY_OP: Option<StackBehavior> = Some(StackBehavior {
     pushes: true,
     disable_other_channels: true,
 });
+pub(crate) const JUMP_OP: Option<StackBehavior> = Some(StackBehavior {
+    num_pops: 1,
+    pushes: false,
+    disable_other_channels: false,
+});
+pub(crate) const JUMPI_OP: Option<StackBehavior> = Some(StackBehavior {
+    num_pops: 2,
+    pushes: false,
+    disable_other_channels: false,
+});
 
 // AUDITORS: If the value below is `None`, then the operation must be manually checked to ensure
 // that every general-purpose memory channel is either disabled or has its read flag and address
@@ -78,16 +88,7 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
         pushes: false,
         disable_other_channels: true,
     }),
-    jump: Some(StackBehavior {
-        num_pops: 1,
-        pushes: false,
-        disable_other_channels: false,
-    }),
-    jumpi: Some(StackBehavior {
-        num_pops: 2,
-        pushes: false,
-        disable_other_channels: false,
-    }),
+    jumps: None, // Depends on whether it's a JUMP or a JUMPI.
     pc: Some(StackBehavior {
         num_pops: 0,
         pushes: true,

View File

@@ -179,8 +179,7 @@ fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) {
         Operation::KeccakGeneral => &mut flags.keccak_general,
         Operation::ProverInput => &mut flags.prover_input,
         Operation::Pop => &mut flags.pop,
-        Operation::Jump => &mut flags.jump,
-        Operation::Jumpi => &mut flags.jumpi,
+        Operation::Jump | Operation::Jumpi => &mut flags.jumps,
         Operation::Pc => &mut flags.pc,
         Operation::Jumpdest => &mut flags.jumpdest,
         Operation::GetContext => &mut flags.get_context,