mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 23:03:08 +00:00
Simplify JUMP/JUMPI constraints and finish witness generation (#846)
* Simplify `JUMP`/`JUMPI` constraints and finish witness generation * Constrain stack
This commit is contained in:
parent
1732239a00
commit
b6bc018cba
@ -97,46 +97,10 @@ pub(crate) struct CpuLogicView<T: Copy> {
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct CpuJumpsView<T: Copy> {
|
||||
/// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the
|
||||
/// address; for `EXIT_KERNEL`, the address and new privilege level).
|
||||
/// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the
|
||||
/// predicate). For `JUMP`, 1.
|
||||
|
||||
/// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value.
|
||||
/// Needed to prove that `input0` is nonzero.
|
||||
pub(crate) input0_upper_sum_inv: T,
|
||||
/// 1 if `input0[1..7]` is zero; else 0.
|
||||
pub(crate) input0_upper_zero: T,
|
||||
|
||||
/// 1 if `input0[0]` is the address of a valid jump destination (i.e. `JUMPDEST` that is not
|
||||
/// part of a `PUSH` immediate); else 0. Note that the kernel is allowed to jump anywhere it
|
||||
/// wants, so this flag is computed but ignored in kernel mode.
|
||||
/// NOTE: this flag only considers `input0[0]`, the low 32 bits of the 256-bit register. Even if
|
||||
/// this flag is 1, `input0` will still be an invalid address if the high 224 bits are not 0.
|
||||
pub(crate) dst_valid: T, // TODO: populate this (check for JUMPDEST)
|
||||
/// 1 if either `dst_valid` is 1 or we are in kernel mode; else 0. (Just a logical OR.)
|
||||
pub(crate) dst_valid_or_kernel: T,
|
||||
/// 1 if `dst_valid_or_kernel` and `input0_upper_zero` are both 1; else 0. In other words, we
|
||||
/// are allowed to jump to `input0[0]` because either it's a valid address or we're in kernel
|
||||
/// mode (`dst_valid_or_kernel`), and also `input0[1..7]` are all 0 so `input0[0]` is in fact
|
||||
/// the whole address (we're not being asked to jump to an address that would overflow).
|
||||
pub(crate) input0_jumpable: T,
|
||||
|
||||
/// Inverse of `input1[0] + ... + input1[7]`, if one exists; otherwise, an arbitrary value.
|
||||
/// Needed to prove that `input1` is nonzero.
|
||||
pub(crate) input1_sum_inv: T,
|
||||
|
||||
/// Note that the below flags are mutually exclusive.
|
||||
/// 1 if the JUMPI falls though (because input1 is 0); else 0.
|
||||
pub(crate) should_continue: T,
|
||||
/// 1 if the JUMP/JUMPI does in fact jump to `input0`; else 0. This requires `input0` to be a
|
||||
/// valid destination (`input0[0]` is a `JUMPDEST` not in an immediate or we are in kernel mode
|
||||
/// and also `input0[1..7]` is 0) and `input1` to be nonzero.
|
||||
// A flag.
|
||||
pub(crate) should_jump: T,
|
||||
/// 1 if the JUMP/JUMPI faults; else 0. This happens when `input0` is not a valid destination
|
||||
/// (`input0[0]` is not `JUMPDEST` that is not in an immediate while we are in user mode, or
|
||||
/// `input0[1..7]` is nonzero) and `input1` is nonzero.
|
||||
pub(crate) should_trap: T,
|
||||
// Pseudoinverse of `cond.iter().sum()`. Used to check `should_jump`.
|
||||
pub(crate) cond_sum_pinv: T,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
||||
@ -145,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
|
||||
control_flow::eval_packed_generic(local_values, next_values, yield_constr);
|
||||
decode::eval_packed_generic(local_values, &mut dummy_yield_constr);
|
||||
dup_swap::eval_packed(local_values, yield_constr);
|
||||
jumps::eval_packed(local_values, next_values, &mut dummy_yield_constr);
|
||||
jumps::eval_packed(local_values, next_values, yield_constr);
|
||||
membus::eval_packed(local_values, yield_constr);
|
||||
memio::eval_packed(local_values, yield_constr);
|
||||
modfp254::eval_packed(local_values, yield_constr);
|
||||
@ -174,7 +174,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
|
||||
control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr);
|
||||
decode::eval_ext_circuit(builder, local_values, &mut dummy_yield_constr);
|
||||
dup_swap::eval_ext_circuit(builder, local_values, yield_constr);
|
||||
jumps::eval_ext_circuit(builder, local_values, next_values, &mut dummy_yield_constr);
|
||||
jumps::eval_ext_circuit(builder, local_values, next_values, yield_constr);
|
||||
membus::eval_ext_circuit(builder, local_values, yield_constr);
|
||||
memio::eval_ext_circuit(builder, local_values, yield_constr);
|
||||
modfp254::eval_ext_circuit(builder, local_values, yield_constr);
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::field::packed::PackedField;
|
||||
use plonky2::field::types::Field;
|
||||
@ -7,10 +6,8 @@ use plonky2::iop::ext_target::ExtensionTarget;
|
||||
|
||||
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
|
||||
use crate::cpu::columns::CpuColumnsView;
|
||||
use crate::cpu::kernel::aggregator::KERNEL;
|
||||
|
||||
static INVALID_DST_HANDLER_ADDR: Lazy<usize> =
|
||||
Lazy::new(|| KERNEL.global_labels["fault_exception"]);
|
||||
use crate::cpu::membus::NUM_GP_CHANNELS;
|
||||
use crate::memory::segments::Segment;
|
||||
|
||||
pub fn eval_packed_exit_kernel<P: PackedField>(
|
||||
lv: &CpuColumnsView<P>,
|
||||
@ -58,99 +55,65 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
|
||||
yield_constr: &mut ConstraintConsumer<P>,
|
||||
) {
|
||||
let jumps_lv = lv.general.jumps();
|
||||
let input0 = lv.mem_channels[0].value;
|
||||
let input1 = lv.mem_channels[1].value;
|
||||
let dst = lv.mem_channels[0].value;
|
||||
let cond = lv.mem_channels[1].value;
|
||||
let filter = lv.op.jump + lv.op.jumpi; // `JUMP` or `JUMPI`
|
||||
let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
|
||||
|
||||
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
|
||||
// In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`.
|
||||
yield_constr.constraint(lv.op.jump * (input1[0] - P::ONES));
|
||||
for &limb in &input1[1..] {
|
||||
// In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
|
||||
yield_constr.constraint(lv.op.jump * (cond[0] - P::ONES));
|
||||
for &limb in &cond[1..] {
|
||||
// Set all limbs (other than the least-significant limb) to 0.
|
||||
// NB: Technically, they don't have to be 0, as long as the sum
|
||||
// `input1[0] + ... + input1[7]` cannot overflow.
|
||||
// `cond[0] + ... + cond[7]` cannot overflow.
|
||||
yield_constr.constraint(lv.op.jump * limb);
|
||||
}
|
||||
|
||||
// Check `input0_upper_zero`
|
||||
// `input0_upper_zero` is either 0 or 1.
|
||||
yield_constr
|
||||
.constraint(filter * jumps_lv.input0_upper_zero * (jumps_lv.input0_upper_zero - P::ONES));
|
||||
// The below sum cannot overflow due to the limb size.
|
||||
let input0_upper_sum: P = input0[1..].iter().copied().sum();
|
||||
// `input0_upper_zero` = 1 implies `input0_upper_sum` = 0.
|
||||
yield_constr.constraint(filter * jumps_lv.input0_upper_zero * input0_upper_sum);
|
||||
// `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can only
|
||||
// happen when `input0_upper_sum` is nonzero.
|
||||
yield_constr.constraint(
|
||||
filter
|
||||
* (jumps_lv.input0_upper_sum_inv * input0_upper_sum + jumps_lv.input0_upper_zero
|
||||
- P::ONES),
|
||||
);
|
||||
|
||||
// Check `dst_valid_or_kernel` (this is just a logical OR)
|
||||
yield_constr.constraint(
|
||||
filter
|
||||
* (jumps_lv.dst_valid + lv.is_kernel_mode
|
||||
- jumps_lv.dst_valid * lv.is_kernel_mode
|
||||
- jumps_lv.dst_valid_or_kernel),
|
||||
);
|
||||
|
||||
// Check `input0_jumpable` (this is just `dst_valid_or_kernel` AND `input0_upper_zero`)
|
||||
yield_constr.constraint(
|
||||
filter
|
||||
* (jumps_lv.dst_valid_or_kernel * jumps_lv.input0_upper_zero
|
||||
- jumps_lv.input0_jumpable),
|
||||
);
|
||||
|
||||
// Make sure that `should_continue`, `should_jump`, `should_trap` are all binary and exactly one
|
||||
// is set.
|
||||
yield_constr
|
||||
.constraint(filter * jumps_lv.should_continue * (jumps_lv.should_continue - P::ONES));
|
||||
// Check `should_jump`:
|
||||
yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.should_jump - P::ONES));
|
||||
yield_constr.constraint(filter * jumps_lv.should_trap * (jumps_lv.should_trap - P::ONES));
|
||||
let cond_sum: P = cond.into_iter().sum();
|
||||
yield_constr.constraint(filter * (jumps_lv.should_jump - P::ONES) * cond_sum);
|
||||
yield_constr.constraint(filter * (jumps_lv.cond_sum_pinv * cond_sum - jumps_lv.should_jump));
|
||||
|
||||
// If we're jumping, then the high 7 limbs of the destination must be 0.
|
||||
let dst_hi_sum: P = dst[1..].iter().copied().sum();
|
||||
yield_constr.constraint(filter * jumps_lv.should_jump * dst_hi_sum);
|
||||
// Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint
|
||||
// does not need to be conditioned on `should_jump` because no read takes place if we're not
|
||||
// jumping, so we're free to set the channel to 1.
|
||||
yield_constr.constraint(filter * (jumpdest_flag_channel.value[0] - P::ONES));
|
||||
|
||||
// Make sure that the JUMPDEST flag channel is constrained.
|
||||
// Only need to read if we're about to jump and we're not in kernel mode.
|
||||
yield_constr.constraint(
|
||||
filter * (jumps_lv.should_continue + jumps_lv.should_jump + jumps_lv.should_trap - P::ONES),
|
||||
);
|
||||
|
||||
// Validate `should_continue`
|
||||
// This sum cannot overflow (due to limb size).
|
||||
let input1_sum: P = input1.into_iter().sum();
|
||||
// `should_continue` = 1 implies `input1_sum` = 0.
|
||||
yield_constr.constraint(filter * jumps_lv.should_continue * input1_sum);
|
||||
// `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if
|
||||
// input1_sum is nonzero.
|
||||
yield_constr.constraint(
|
||||
filter * (input1_sum * jumps_lv.input1_sum_inv + jumps_lv.should_continue - P::ONES),
|
||||
);
|
||||
|
||||
// Validate `should_jump` and `should_trap` by splitting on `input0_jumpable`.
|
||||
// Note that `should_jump` = 1 and `should_trap` = 1 both imply that `should_continue` = 0, so
|
||||
// `input1` is nonzero.
|
||||
yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.input0_jumpable - P::ONES));
|
||||
yield_constr.constraint(filter * jumps_lv.should_trap * jumps_lv.input0_jumpable);
|
||||
|
||||
// Handle trap
|
||||
// Set program counter and kernel flag
|
||||
yield_constr
|
||||
.constraint_transition(filter * jumps_lv.should_trap * (nv.is_kernel_mode - P::ONES));
|
||||
yield_constr.constraint_transition(
|
||||
filter
|
||||
* jumps_lv.should_trap
|
||||
* (nv.program_counter - P::Scalar::from_canonical_usize(*INVALID_DST_HANDLER_ADDR)),
|
||||
* (jumpdest_flag_channel.used - jumps_lv.should_jump * (P::ONES - lv.is_kernel_mode)),
|
||||
);
|
||||
yield_constr.constraint(filter * (jumpdest_flag_channel.is_read - P::ONES));
|
||||
yield_constr.constraint(filter * (jumpdest_flag_channel.addr_context - lv.context));
|
||||
yield_constr.constraint(
|
||||
filter
|
||||
* (jumpdest_flag_channel.addr_segment
|
||||
- P::Scalar::from_canonical_u64(Segment::JumpdestBits as u64)),
|
||||
);
|
||||
yield_constr.constraint(filter * (jumpdest_flag_channel.addr_virtual - dst[0]));
|
||||
|
||||
// Handle continue and jump
|
||||
let continue_or_jump = jumps_lv.should_continue + jumps_lv.should_jump;
|
||||
// Keep kernel mode.
|
||||
yield_constr
|
||||
.constraint_transition(filter * continue_or_jump * (nv.is_kernel_mode - lv.is_kernel_mode));
|
||||
// Set program counter depending on whether we're continuing or jumping.
|
||||
// Disable unused memory channels
|
||||
for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] {
|
||||
yield_constr.constraint(filter * channel.used);
|
||||
}
|
||||
// Channel 1 is unused by the `JUMP` instruction.
|
||||
yield_constr.constraint(lv.op.jump * lv.mem_channels[1].used);
|
||||
|
||||
// Finally, set the next program counter.
|
||||
let fallthrough_dst = lv.program_counter + P::ONES;
|
||||
let jump_dest = dst[0];
|
||||
yield_constr.constraint_transition(
|
||||
filter * jumps_lv.should_continue * (nv.program_counter - lv.program_counter - P::ONES),
|
||||
filter * (jumps_lv.should_jump - P::ONES) * (nv.program_counter - fallthrough_dst),
|
||||
);
|
||||
yield_constr
|
||||
.constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - input0[0]));
|
||||
.constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - jump_dest));
|
||||
}
|
||||
|
||||
pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>(
|
||||
@ -160,178 +123,124 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
|
||||
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
|
||||
) {
|
||||
let jumps_lv = lv.general.jumps();
|
||||
let input0 = lv.mem_channels[0].value;
|
||||
let input1 = lv.mem_channels[1].value;
|
||||
let dst = lv.mem_channels[0].value;
|
||||
let cond = lv.mem_channels[1].value;
|
||||
let filter = builder.add_extension(lv.op.jump, lv.op.jumpi); // `JUMP` or `JUMPI`
|
||||
let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
|
||||
|
||||
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
|
||||
// In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`.
|
||||
// In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
|
||||
{
|
||||
let constr = builder.mul_sub_extension(lv.op.jump, input1[0], lv.op.jump);
|
||||
let constr = builder.mul_sub_extension(lv.op.jump, cond[0], lv.op.jump);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
for &limb in &input1[1..] {
|
||||
for &limb in &cond[1..] {
|
||||
// Set all limbs (other than the least-significant limb) to 0.
|
||||
// NB: Technically, they don't have to be 0, as long as the sum
|
||||
// `input1[0] + ... + input1[7]` cannot overflow.
|
||||
// `cond[0] + ... + cond[7]` cannot overflow.
|
||||
let constr = builder.mul_extension(lv.op.jump, limb);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Check `input0_upper_zero`
|
||||
// `input0_upper_zero` is either 0 or 1.
|
||||
// Check `should_jump`:
|
||||
{
|
||||
let constr = builder.mul_sub_extension(
|
||||
jumps_lv.input0_upper_zero,
|
||||
jumps_lv.input0_upper_zero,
|
||||
jumps_lv.input0_upper_zero,
|
||||
jumps_lv.should_jump,
|
||||
jumps_lv.should_jump,
|
||||
jumps_lv.should_jump,
|
||||
);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
let cond_sum = builder.add_many_extension(cond);
|
||||
{
|
||||
// The below sum cannot overflow due to the limb size.
|
||||
let input0_upper_sum = builder.add_many_extension(input0[1..].iter());
|
||||
|
||||
// `input0_upper_zero` = 1 implies `input0_upper_sum` = 0.
|
||||
let constr = builder.mul_extension(jumps_lv.input0_upper_zero, input0_upper_sum);
|
||||
let constr = builder.mul_sub_extension(cond_sum, jumps_lv.should_jump, cond_sum);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
|
||||
// `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can
|
||||
// only happen when `input0_upper_sum` is nonzero.
|
||||
let constr = builder.mul_add_extension(
|
||||
jumps_lv.input0_upper_sum_inv,
|
||||
input0_upper_sum,
|
||||
jumps_lv.input0_upper_zero,
|
||||
);
|
||||
let constr = builder.mul_sub_extension(filter, constr, filter);
|
||||
yield_constr.constraint(builder, constr);
|
||||
};
|
||||
|
||||
// Check `dst_valid_or_kernel` (this is just a logical OR)
|
||||
}
|
||||
{
|
||||
let constr = builder.mul_add_extension(
|
||||
jumps_lv.dst_valid,
|
||||
let constr =
|
||||
builder.mul_sub_extension(jumps_lv.cond_sum_pinv, cond_sum, jumps_lv.should_jump);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// If we're jumping, then the high 7 limbs of the destination must be 0.
|
||||
let dst_hi_sum = builder.add_many_extension(&dst[1..]);
|
||||
{
|
||||
let constr = builder.mul_extension(jumps_lv.should_jump, dst_hi_sum);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
// Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint
|
||||
// does not need to be conditioned on `should_jump` because no read takes place if we're not
|
||||
// jumping, so we're free to set the channel to 1.
|
||||
{
|
||||
let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.value[0], filter);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Make sure that the JUMPDEST flag channel is constrained.
|
||||
// Only need to read if we're about to jump and we're not in kernel mode.
|
||||
{
|
||||
let constr = builder.mul_sub_extension(
|
||||
jumps_lv.should_jump,
|
||||
lv.is_kernel_mode,
|
||||
jumps_lv.dst_valid_or_kernel,
|
||||
);
|
||||
let constr = builder.sub_extension(jumps_lv.dst_valid, constr);
|
||||
let constr = builder.add_extension(lv.is_kernel_mode, constr);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Check `input0_jumpable` (this is just `dst_valid_or_kernel` AND `input0_upper_zero`)
|
||||
{
|
||||
let constr = builder.mul_sub_extension(
|
||||
jumps_lv.dst_valid_or_kernel,
|
||||
jumps_lv.input0_upper_zero,
|
||||
jumps_lv.input0_jumpable,
|
||||
);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Make sure that `should_continue`, `should_jump`, `should_trap` are all binary and exactly one
|
||||
// is set.
|
||||
for flag in [
|
||||
jumps_lv.should_continue,
|
||||
jumps_lv.should_jump,
|
||||
jumps_lv.should_trap,
|
||||
] {
|
||||
let constr = builder.mul_sub_extension(flag, flag, flag);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
{
|
||||
let constr = builder.add_extension(jumps_lv.should_continue, jumps_lv.should_jump);
|
||||
let constr = builder.add_extension(constr, jumps_lv.should_trap);
|
||||
let constr = builder.mul_sub_extension(filter, constr, filter);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Validate `should_continue`
|
||||
{
|
||||
// This sum cannot overflow (due to limb size).
|
||||
let input1_sum = builder.add_many_extension(input1.into_iter());
|
||||
|
||||
// `should_continue` = 1 implies `input1_sum` = 0.
|
||||
let constr = builder.mul_extension(jumps_lv.should_continue, input1_sum);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
|
||||
// `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if
|
||||
// input1_sum is nonzero.
|
||||
let constr = builder.mul_add_extension(
|
||||
input1_sum,
|
||||
jumps_lv.input1_sum_inv,
|
||||
jumps_lv.should_continue,
|
||||
);
|
||||
let constr = builder.mul_sub_extension(filter, constr, filter);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Validate `should_jump` and `should_trap` by splitting on `input0_jumpable`.
|
||||
// Note that `should_jump` = 1 and `should_trap` = 1 both imply that `should_continue` = 0, so
|
||||
// `input1` is nonzero.
|
||||
{
|
||||
let constr = builder.mul_sub_extension(
|
||||
jumps_lv.should_jump,
|
||||
jumps_lv.input0_jumpable,
|
||||
jumps_lv.should_jump,
|
||||
);
|
||||
let constr = builder.add_extension(jumpdest_flag_channel.used, constr);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
{
|
||||
let constr = builder.mul_extension(jumps_lv.should_trap, jumps_lv.input0_jumpable);
|
||||
let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.is_read, filter);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
{
|
||||
let constr = builder.sub_extension(jumpdest_flag_channel.addr_context, lv.context);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Handle trap
|
||||
{
|
||||
let trap_filter = builder.mul_extension(filter, jumps_lv.should_trap);
|
||||
|
||||
// Set kernel flag
|
||||
let constr = builder.mul_sub_extension(trap_filter, nv.is_kernel_mode, trap_filter);
|
||||
yield_constr.constraint_transition(builder, constr);
|
||||
|
||||
// Set program counter
|
||||
let constr = builder.arithmetic_extension(
|
||||
F::ONE,
|
||||
-F::from_canonical_usize(*INVALID_DST_HANDLER_ADDR),
|
||||
trap_filter,
|
||||
nv.program_counter,
|
||||
trap_filter,
|
||||
-F::from_canonical_u64(Segment::JumpdestBits as u64),
|
||||
filter,
|
||||
jumpdest_flag_channel.addr_segment,
|
||||
filter,
|
||||
);
|
||||
yield_constr.constraint_transition(builder, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
{
|
||||
let constr = builder.sub_extension(jumpdest_flag_channel.addr_virtual, dst[0]);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Handle continue and jump
|
||||
// Disable unused memory channels
|
||||
for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] {
|
||||
let constr = builder.mul_extension(filter, channel.used);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
// Channel 1 is unused by the `JUMP` instruction.
|
||||
{
|
||||
// Keep kernel mode.
|
||||
let continue_or_jump =
|
||||
builder.add_extension(jumps_lv.should_continue, jumps_lv.should_jump);
|
||||
let constr = builder.sub_extension(nv.is_kernel_mode, lv.is_kernel_mode);
|
||||
let constr = builder.mul_extension(continue_or_jump, constr);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
let constr = builder.mul_extension(lv.op.jump, lv.mem_channels[1].used);
|
||||
yield_constr.constraint(builder, constr);
|
||||
}
|
||||
|
||||
// Finally, set the next program counter.
|
||||
let fallthrough_dst = builder.add_const_extension(lv.program_counter, F::ONE);
|
||||
let jump_dest = dst[0];
|
||||
{
|
||||
let constr_a = builder.mul_sub_extension(filter, jumps_lv.should_jump, filter);
|
||||
let constr_b = builder.sub_extension(nv.program_counter, fallthrough_dst);
|
||||
let constr = builder.mul_extension(constr_a, constr_b);
|
||||
yield_constr.constraint_transition(builder, constr);
|
||||
}
|
||||
// Set program counter depending on whether we're continuing...
|
||||
{
|
||||
let constr = builder.sub_extension(nv.program_counter, lv.program_counter);
|
||||
let constr =
|
||||
builder.mul_sub_extension(jumps_lv.should_continue, constr, jumps_lv.should_continue);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
yield_constr.constraint_transition(builder, constr);
|
||||
}
|
||||
// ...or jumping.
|
||||
{
|
||||
let constr = builder.sub_extension(nv.program_counter, input0[0]);
|
||||
let constr = builder.mul_extension(jumps_lv.should_jump, constr);
|
||||
let constr = builder.mul_extension(filter, constr);
|
||||
let constr_a = builder.mul_extension(filter, jumps_lv.should_jump);
|
||||
let constr_b = builder.sub_extension(nv.program_counter, jump_dest);
|
||||
let constr = builder.mul_extension(constr_a, constr_b);
|
||||
yield_constr.constraint_transition(builder, constr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,8 +64,16 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
|
||||
keccak_general: None, // TODO
|
||||
prover_input: None, // TODO
|
||||
pop: None, // TODO
|
||||
jump: None, // TODO
|
||||
jumpi: None, // TODO
|
||||
jump: Some(StackBehavior {
|
||||
num_pops: 1,
|
||||
pushes: false,
|
||||
disable_other_channels: false,
|
||||
}),
|
||||
jumpi: Some(StackBehavior {
|
||||
num_pops: 2,
|
||||
pushes: false,
|
||||
disable_other_channels: false,
|
||||
}),
|
||||
pc: Some(StackBehavior {
|
||||
num_pops: 0,
|
||||
pushes: true,
|
||||
@ -91,7 +99,11 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
|
||||
disable_other_channels: true,
|
||||
}),
|
||||
consume_gas: None, // TODO
|
||||
exit_kernel: None, // TODO
|
||||
exit_kernel: Some(StackBehavior {
|
||||
num_pops: 1,
|
||||
pushes: false,
|
||||
disable_other_channels: true,
|
||||
}),
|
||||
mload_general: Some(StackBehavior {
|
||||
num_pops: 3,
|
||||
pushes: true,
|
||||
|
||||
@ -144,11 +144,3 @@ pub(crate) fn biguint_to_u256(x: BigUint) -> U256 {
|
||||
let bytes = x.to_bytes_le();
|
||||
U256::from_little_endian(&bytes)
|
||||
}
|
||||
|
||||
pub(crate) fn u256_saturating_cast_usize(x: U256) -> usize {
|
||||
if x > usize::MAX.into() {
|
||||
usize::MAX
|
||||
} else {
|
||||
x.as_usize()
|
||||
}
|
||||
}
|
||||
|
||||
@ -10,7 +10,6 @@ use crate::cpu::membus::NUM_GP_CHANNELS;
|
||||
use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff;
|
||||
use crate::generation::state::GenerationState;
|
||||
use crate::memory::segments::Segment;
|
||||
use crate::util::u256_saturating_cast_usize;
|
||||
use crate::witness::errors::ProgramError;
|
||||
use crate::witness::memory::MemoryAddress;
|
||||
use crate::witness::util::{
|
||||
@ -187,12 +186,37 @@ pub(crate) fn generate_jump<F: Field>(
|
||||
mut row: CpuColumnsView<F>,
|
||||
) -> Result<(), ProgramError> {
|
||||
let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
|
||||
let dst: u32 = dst
|
||||
.try_into()
|
||||
.map_err(|_| ProgramError::InvalidJumpDestination)?;
|
||||
|
||||
let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill(
|
||||
NUM_GP_CHANNELS - 1,
|
||||
MemoryAddress::new(state.registers.context, Segment::JumpdestBits, dst as usize),
|
||||
state,
|
||||
&mut row,
|
||||
);
|
||||
if state.registers.is_kernel {
|
||||
// Don't actually do the read, just set the address, etc.
|
||||
let mut channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1];
|
||||
channel.used = F::ZERO;
|
||||
channel.value[0] = F::ONE;
|
||||
|
||||
row.mem_channels[1].value[0] = F::ONE;
|
||||
} else {
|
||||
if jumpdest_bit != U256::one() {
|
||||
return Err(ProgramError::InvalidJumpDestination);
|
||||
}
|
||||
state.traces.push_memory(jumpdest_bit_log);
|
||||
}
|
||||
|
||||
// Extra fields required by the constraints.
|
||||
row.general.jumps_mut().should_jump = F::ONE;
|
||||
row.general.jumps_mut().cond_sum_pinv = F::ONE;
|
||||
|
||||
state.traces.push_memory(log_in0);
|
||||
state.traces.push_cpu(row);
|
||||
// TODO: First check if it's a valid JUMPDEST
|
||||
state.registers.program_counter = u256_saturating_cast_usize(dst);
|
||||
// TODO: Set other cols like input0_upper_sum_inv.
|
||||
state.registers.program_counter = dst as usize;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -202,16 +226,52 @@ pub(crate) fn generate_jumpi<F: Field>(
|
||||
) -> Result<(), ProgramError> {
|
||||
let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
|
||||
|
||||
let should_jump = !cond.is_zero();
|
||||
if should_jump {
|
||||
row.general.jumps_mut().should_jump = F::ONE;
|
||||
let cond_sum_u64 = cond
|
||||
.0
|
||||
.into_iter()
|
||||
.map(|limb| ((limb as u32) as u64) + (limb >> 32))
|
||||
.sum();
|
||||
let cond_sum = F::from_canonical_u64(cond_sum_u64);
|
||||
row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse();
|
||||
|
||||
let dst: u32 = dst
|
||||
.try_into()
|
||||
.map_err(|_| ProgramError::InvalidJumpiDestination)?;
|
||||
state.registers.program_counter = dst as usize;
|
||||
} else {
|
||||
row.general.jumps_mut().should_jump = F::ZERO;
|
||||
row.general.jumps_mut().cond_sum_pinv = F::ZERO;
|
||||
state.registers.program_counter += 1;
|
||||
}
|
||||
|
||||
let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill(
|
||||
NUM_GP_CHANNELS - 1,
|
||||
MemoryAddress::new(
|
||||
state.registers.context,
|
||||
Segment::JumpdestBits,
|
||||
dst.low_u32() as usize,
|
||||
),
|
||||
state,
|
||||
&mut row,
|
||||
);
|
||||
if !should_jump || state.registers.is_kernel {
|
||||
// Don't actually do the read, just set the address, etc.
|
||||
let mut channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1];
|
||||
channel.used = F::ZERO;
|
||||
channel.value[0] = F::ONE;
|
||||
} else {
|
||||
if jumpdest_bit != U256::one() {
|
||||
return Err(ProgramError::InvalidJumpiDestination);
|
||||
}
|
||||
state.traces.push_memory(jumpdest_bit_log);
|
||||
}
|
||||
|
||||
state.traces.push_memory(log_in0);
|
||||
state.traces.push_memory(log_in1);
|
||||
state.traces.push_cpu(row);
|
||||
state.registers.program_counter = if cond.is_zero() {
|
||||
state.registers.program_counter + 1
|
||||
} else {
|
||||
// TODO: First check if it's a valid JUMPDEST
|
||||
u256_saturating_cast_usize(dst)
|
||||
};
|
||||
// TODO: Set other cols like input0_upper_sum_inv.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user