Simplify JUMP/JUMPI constraints and finish witness generation (#846)

* Simplify `JUMP`/`JUMPI` constraints and finish witness generation

* Constrain stack
This commit is contained in:
Jacqueline Nabaglo 2022-12-11 11:08:33 -08:00 committed by GitHub
parent 1732239a00
commit b6bc018cba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 212 additions and 275 deletions

View File

@ -97,46 +97,10 @@ pub(crate) struct CpuLogicView<T: Copy> {
#[derive(Copy, Clone)]
pub(crate) struct CpuJumpsView<T: Copy> {
/// `input0` is `mem_channel[0].value`. It's the top stack value at entry (for jumps, the
/// address; for `EXIT_KERNEL`, the address and new privilege level).
/// `input1` is `mem_channel[1].value`. For `JUMPI`, it's the second stack value (the
/// predicate). For `JUMP`, 1.
/// Inverse of `input0[1] + ... + input0[7]`, if one exists; otherwise, an arbitrary value.
/// Needed to prove that `input0` is nonzero.
pub(crate) input0_upper_sum_inv: T,
/// 1 if `input0[1..7]` is zero; else 0.
pub(crate) input0_upper_zero: T,
/// 1 if `input0[0]` is the address of a valid jump destination (i.e. `JUMPDEST` that is not
/// part of a `PUSH` immediate); else 0. Note that the kernel is allowed to jump anywhere it
/// wants, so this flag is computed but ignored in kernel mode.
/// NOTE: this flag only considers `input0[0]`, the low 32 bits of the 256-bit register. Even if
/// this flag is 1, `input0` will still be an invalid address if the high 224 bits are not 0.
pub(crate) dst_valid: T, // TODO: populate this (check for JUMPDEST)
/// 1 if either `dst_valid` is 1 or we are in kernel mode; else 0. (Just a logical OR.)
pub(crate) dst_valid_or_kernel: T,
/// 1 if `dst_valid_or_kernel` and `input0_upper_zero` are both 1; else 0. In other words, we
/// are allowed to jump to `input0[0]` because either it's a valid address or we're in kernel
/// mode (`dst_valid_or_kernel`), and also `input0[1..7]` are all 0 so `input0[0]` is in fact
/// the whole address (we're not being asked to jump to an address that would overflow).
pub(crate) input0_jumpable: T,
/// Inverse of `input1[0] + ... + input1[7]`, if one exists; otherwise, an arbitrary value.
/// Needed to prove that `input1` is nonzero.
pub(crate) input1_sum_inv: T,
/// Note that the below flags are mutually exclusive.
/// 1 if the JUMPI falls through (because input1 is 0); else 0.
pub(crate) should_continue: T,
/// 1 if the JUMP/JUMPI does in fact jump to `input0`; else 0. This requires `input0` to be a
/// valid destination (`input0[0]` is a `JUMPDEST` not in an immediate or we are in kernel mode
/// and also `input0[1..7]` is 0) and `input1` to be nonzero.
// A flag.
pub(crate) should_jump: T,
/// 1 if the JUMP/JUMPI faults; else 0. This happens when `input0` is not a valid destination
/// (`input0[0]` is not `JUMPDEST` that is not in an immediate while we are in user mode, or
/// `input0[1..7]` is nonzero) and `input1` is nonzero.
pub(crate) should_trap: T,
// Pseudoinverse of `cond.iter().sum()`. Used to check `should_jump`.
pub(crate) cond_sum_pinv: T,
}
#[derive(Copy, Clone)]

View File

@ -145,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
control_flow::eval_packed_generic(local_values, next_values, yield_constr);
decode::eval_packed_generic(local_values, &mut dummy_yield_constr);
dup_swap::eval_packed(local_values, yield_constr);
jumps::eval_packed(local_values, next_values, &mut dummy_yield_constr);
jumps::eval_packed(local_values, next_values, yield_constr);
membus::eval_packed(local_values, yield_constr);
memio::eval_packed(local_values, yield_constr);
modfp254::eval_packed(local_values, yield_constr);
@ -174,7 +174,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr);
decode::eval_ext_circuit(builder, local_values, &mut dummy_yield_constr);
dup_swap::eval_ext_circuit(builder, local_values, yield_constr);
jumps::eval_ext_circuit(builder, local_values, next_values, &mut dummy_yield_constr);
jumps::eval_ext_circuit(builder, local_values, next_values, yield_constr);
membus::eval_ext_circuit(builder, local_values, yield_constr);
memio::eval_ext_circuit(builder, local_values, yield_constr);
modfp254::eval_ext_circuit(builder, local_values, yield_constr);

View File

@ -1,4 +1,3 @@
use once_cell::sync::Lazy;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
@ -7,10 +6,8 @@ use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
static INVALID_DST_HANDLER_ADDR: Lazy<usize> =
Lazy::new(|| KERNEL.global_labels["fault_exception"]);
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::memory::segments::Segment;
pub fn eval_packed_exit_kernel<P: PackedField>(
lv: &CpuColumnsView<P>,
@ -58,99 +55,65 @@ pub fn eval_packed_jump_jumpi<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let jumps_lv = lv.general.jumps();
let input0 = lv.mem_channels[0].value;
let input1 = lv.mem_channels[1].value;
let dst = lv.mem_channels[0].value;
let cond = lv.mem_channels[1].value;
let filter = lv.op.jump + lv.op.jumpi; // `JUMP` or `JUMPI`
let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
// In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`.
yield_constr.constraint(lv.op.jump * (input1[0] - P::ONES));
for &limb in &input1[1..] {
// In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
yield_constr.constraint(lv.op.jump * (cond[0] - P::ONES));
for &limb in &cond[1..] {
// Set all limbs (other than the least-significant limb) to 0.
// NB: Technically, they don't have to be 0, as long as the sum
// `input1[0] + ... + input1[7]` cannot overflow.
// `cond[0] + ... + cond[7]` cannot overflow.
yield_constr.constraint(lv.op.jump * limb);
}
// Check `input0_upper_zero`
// `input0_upper_zero` is either 0 or 1.
yield_constr
.constraint(filter * jumps_lv.input0_upper_zero * (jumps_lv.input0_upper_zero - P::ONES));
// The below sum cannot overflow due to the limb size.
let input0_upper_sum: P = input0[1..].iter().copied().sum();
// `input0_upper_zero` = 1 implies `input0_upper_sum` = 0.
yield_constr.constraint(filter * jumps_lv.input0_upper_zero * input0_upper_sum);
// `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can only
// happen when `input0_upper_sum` is nonzero.
yield_constr.constraint(
filter
* (jumps_lv.input0_upper_sum_inv * input0_upper_sum + jumps_lv.input0_upper_zero
- P::ONES),
);
// Check `dst_valid_or_kernel` (this is just a logical OR)
yield_constr.constraint(
filter
* (jumps_lv.dst_valid + lv.is_kernel_mode
- jumps_lv.dst_valid * lv.is_kernel_mode
- jumps_lv.dst_valid_or_kernel),
);
// Check `input0_jumpable` (this is just `dst_valid_or_kernel` AND `input0_upper_zero`)
yield_constr.constraint(
filter
* (jumps_lv.dst_valid_or_kernel * jumps_lv.input0_upper_zero
- jumps_lv.input0_jumpable),
);
// Make sure that `should_continue`, `should_jump`, `should_trap` are all binary and exactly one
// is set.
yield_constr
.constraint(filter * jumps_lv.should_continue * (jumps_lv.should_continue - P::ONES));
// Check `should_jump`:
yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.should_jump - P::ONES));
yield_constr.constraint(filter * jumps_lv.should_trap * (jumps_lv.should_trap - P::ONES));
let cond_sum: P = cond.into_iter().sum();
yield_constr.constraint(filter * (jumps_lv.should_jump - P::ONES) * cond_sum);
yield_constr.constraint(filter * (jumps_lv.cond_sum_pinv * cond_sum - jumps_lv.should_jump));
// If we're jumping, then the high 7 limbs of the destination must be 0.
let dst_hi_sum: P = dst[1..].iter().copied().sum();
yield_constr.constraint(filter * jumps_lv.should_jump * dst_hi_sum);
// Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint
// does not need to be conditioned on `should_jump` because no read takes place if we're not
// jumping, so we're free to set the channel to 1.
yield_constr.constraint(filter * (jumpdest_flag_channel.value[0] - P::ONES));
// Make sure that the JUMPDEST flag channel is constrained.
// Only need to read if we're about to jump and we're not in kernel mode.
yield_constr.constraint(
filter * (jumps_lv.should_continue + jumps_lv.should_jump + jumps_lv.should_trap - P::ONES),
);
// Validate `should_continue`
// This sum cannot overflow (due to limb size).
let input1_sum: P = input1.into_iter().sum();
// `should_continue` = 1 implies `input1_sum` = 0.
yield_constr.constraint(filter * jumps_lv.should_continue * input1_sum);
// `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if
// input1_sum is nonzero.
yield_constr.constraint(
filter * (input1_sum * jumps_lv.input1_sum_inv + jumps_lv.should_continue - P::ONES),
);
// Validate `should_jump` and `should_trap` by splitting on `input0_jumpable`.
// Note that `should_jump` = 1 and `should_trap` = 1 both imply that `should_continue` = 0, so
// `input1` is nonzero.
yield_constr.constraint(filter * jumps_lv.should_jump * (jumps_lv.input0_jumpable - P::ONES));
yield_constr.constraint(filter * jumps_lv.should_trap * jumps_lv.input0_jumpable);
// Handle trap
// Set program counter and kernel flag
yield_constr
.constraint_transition(filter * jumps_lv.should_trap * (nv.is_kernel_mode - P::ONES));
yield_constr.constraint_transition(
filter
* jumps_lv.should_trap
* (nv.program_counter - P::Scalar::from_canonical_usize(*INVALID_DST_HANDLER_ADDR)),
* (jumpdest_flag_channel.used - jumps_lv.should_jump * (P::ONES - lv.is_kernel_mode)),
);
yield_constr.constraint(filter * (jumpdest_flag_channel.is_read - P::ONES));
yield_constr.constraint(filter * (jumpdest_flag_channel.addr_context - lv.context));
yield_constr.constraint(
filter
* (jumpdest_flag_channel.addr_segment
- P::Scalar::from_canonical_u64(Segment::JumpdestBits as u64)),
);
yield_constr.constraint(filter * (jumpdest_flag_channel.addr_virtual - dst[0]));
// Handle continue and jump
let continue_or_jump = jumps_lv.should_continue + jumps_lv.should_jump;
// Keep kernel mode.
yield_constr
.constraint_transition(filter * continue_or_jump * (nv.is_kernel_mode - lv.is_kernel_mode));
// Set program counter depending on whether we're continuing or jumping.
// Disable unused memory channels
for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] {
yield_constr.constraint(filter * channel.used);
}
// Channel 1 is unused by the `JUMP` instruction.
yield_constr.constraint(lv.op.jump * lv.mem_channels[1].used);
// Finally, set the next program counter.
let fallthrough_dst = lv.program_counter + P::ONES;
let jump_dest = dst[0];
yield_constr.constraint_transition(
filter * jumps_lv.should_continue * (nv.program_counter - lv.program_counter - P::ONES),
filter * (jumps_lv.should_jump - P::ONES) * (nv.program_counter - fallthrough_dst),
);
yield_constr
.constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - input0[0]));
.constraint_transition(filter * jumps_lv.should_jump * (nv.program_counter - jump_dest));
}
pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>(
@ -160,178 +123,124 @@ pub fn eval_ext_circuit_jump_jumpi<F: RichField + Extendable<D>, const D: usize>
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let jumps_lv = lv.general.jumps();
let input0 = lv.mem_channels[0].value;
let input1 = lv.mem_channels[1].value;
let dst = lv.mem_channels[0].value;
let cond = lv.mem_channels[1].value;
let filter = builder.add_extension(lv.op.jump, lv.op.jumpi); // `JUMP` or `JUMPI`
let jumpdest_flag_channel = lv.mem_channels[NUM_GP_CHANNELS - 1];
// If `JUMP`, re-use the `JUMPI` logic, but setting the second input (the predicate) to be 1.
// In other words, we implement `JUMP(addr)` as `JUMPI(addr, cond=1)`.
// In other words, we implement `JUMP(dst)` as `JUMPI(dst, cond=1)`.
{
let constr = builder.mul_sub_extension(lv.op.jump, input1[0], lv.op.jump);
let constr = builder.mul_sub_extension(lv.op.jump, cond[0], lv.op.jump);
yield_constr.constraint(builder, constr);
}
for &limb in &input1[1..] {
for &limb in &cond[1..] {
// Set all limbs (other than the least-significant limb) to 0.
// NB: Technically, they don't have to be 0, as long as the sum
// `input1[0] + ... + input1[7]` cannot overflow.
// `cond[0] + ... + cond[7]` cannot overflow.
let constr = builder.mul_extension(lv.op.jump, limb);
yield_constr.constraint(builder, constr);
}
// Check `input0_upper_zero`
// `input0_upper_zero` is either 0 or 1.
// Check `should_jump`:
{
let constr = builder.mul_sub_extension(
jumps_lv.input0_upper_zero,
jumps_lv.input0_upper_zero,
jumps_lv.input0_upper_zero,
jumps_lv.should_jump,
jumps_lv.should_jump,
jumps_lv.should_jump,
);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
let cond_sum = builder.add_many_extension(cond);
{
// The below sum cannot overflow due to the limb size.
let input0_upper_sum = builder.add_many_extension(input0[1..].iter());
// `input0_upper_zero` = 1 implies `input0_upper_sum` = 0.
let constr = builder.mul_extension(jumps_lv.input0_upper_zero, input0_upper_sum);
let constr = builder.mul_sub_extension(cond_sum, jumps_lv.should_jump, cond_sum);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
// `input0_upper_zero` = 0 implies `input0_upper_sum_inv * input0_upper_sum` = 1, which can
// only happen when `input0_upper_sum` is nonzero.
let constr = builder.mul_add_extension(
jumps_lv.input0_upper_sum_inv,
input0_upper_sum,
jumps_lv.input0_upper_zero,
);
let constr = builder.mul_sub_extension(filter, constr, filter);
yield_constr.constraint(builder, constr);
};
// Check `dst_valid_or_kernel` (this is just a logical OR)
}
{
let constr = builder.mul_add_extension(
jumps_lv.dst_valid,
let constr =
builder.mul_sub_extension(jumps_lv.cond_sum_pinv, cond_sum, jumps_lv.should_jump);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// If we're jumping, then the high 7 limbs of the destination must be 0.
let dst_hi_sum = builder.add_many_extension(&dst[1..]);
{
let constr = builder.mul_extension(jumps_lv.should_jump, dst_hi_sum);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// Check that the destination address holds a `JUMPDEST` instruction. Note that this constraint
// does not need to be conditioned on `should_jump` because no read takes place if we're not
// jumping, so we're free to set the channel to 1.
{
let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.value[0], filter);
yield_constr.constraint(builder, constr);
}
// Make sure that the JUMPDEST flag channel is constrained.
// Only need to read if we're about to jump and we're not in kernel mode.
{
let constr = builder.mul_sub_extension(
jumps_lv.should_jump,
lv.is_kernel_mode,
jumps_lv.dst_valid_or_kernel,
);
let constr = builder.sub_extension(jumps_lv.dst_valid, constr);
let constr = builder.add_extension(lv.is_kernel_mode, constr);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// Check `input0_jumpable` (this is just `dst_valid_or_kernel` AND `input0_upper_zero`)
{
let constr = builder.mul_sub_extension(
jumps_lv.dst_valid_or_kernel,
jumps_lv.input0_upper_zero,
jumps_lv.input0_jumpable,
);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// Make sure that `should_continue`, `should_jump`, `should_trap` are all binary and exactly one
// is set.
for flag in [
jumps_lv.should_continue,
jumps_lv.should_jump,
jumps_lv.should_trap,
] {
let constr = builder.mul_sub_extension(flag, flag, flag);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.add_extension(jumps_lv.should_continue, jumps_lv.should_jump);
let constr = builder.add_extension(constr, jumps_lv.should_trap);
let constr = builder.mul_sub_extension(filter, constr, filter);
yield_constr.constraint(builder, constr);
}
// Validate `should_continue`
{
// This sum cannot overflow (due to limb size).
let input1_sum = builder.add_many_extension(input1.into_iter());
// `should_continue` = 1 implies `input1_sum` = 0.
let constr = builder.mul_extension(jumps_lv.should_continue, input1_sum);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
// `should_continue` = 0 implies `input1_sum * input1_sum_inv` = 1, which can only happen if
// input1_sum is nonzero.
let constr = builder.mul_add_extension(
input1_sum,
jumps_lv.input1_sum_inv,
jumps_lv.should_continue,
);
let constr = builder.mul_sub_extension(filter, constr, filter);
yield_constr.constraint(builder, constr);
}
// Validate `should_jump` and `should_trap` by splitting on `input0_jumpable`.
// Note that `should_jump` = 1 and `should_trap` = 1 both imply that `should_continue` = 0, so
// `input1` is nonzero.
{
let constr = builder.mul_sub_extension(
jumps_lv.should_jump,
jumps_lv.input0_jumpable,
jumps_lv.should_jump,
);
let constr = builder.add_extension(jumpdest_flag_channel.used, constr);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.mul_extension(jumps_lv.should_trap, jumps_lv.input0_jumpable);
let constr = builder.mul_sub_extension(filter, jumpdest_flag_channel.is_read, filter);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.sub_extension(jumpdest_flag_channel.addr_context, lv.context);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// Handle trap
{
let trap_filter = builder.mul_extension(filter, jumps_lv.should_trap);
// Set kernel flag
let constr = builder.mul_sub_extension(trap_filter, nv.is_kernel_mode, trap_filter);
yield_constr.constraint_transition(builder, constr);
// Set program counter
let constr = builder.arithmetic_extension(
F::ONE,
-F::from_canonical_usize(*INVALID_DST_HANDLER_ADDR),
trap_filter,
nv.program_counter,
trap_filter,
-F::from_canonical_u64(Segment::JumpdestBits as u64),
filter,
jumpdest_flag_channel.addr_segment,
filter,
);
yield_constr.constraint_transition(builder, constr);
yield_constr.constraint(builder, constr);
}
{
let constr = builder.sub_extension(jumpdest_flag_channel.addr_virtual, dst[0]);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint(builder, constr);
}
// Handle continue and jump
// Disable unused memory channels
for &channel in &lv.mem_channels[2..NUM_GP_CHANNELS - 1] {
let constr = builder.mul_extension(filter, channel.used);
yield_constr.constraint(builder, constr);
}
// Channel 1 is unused by the `JUMP` instruction.
{
// Keep kernel mode.
let continue_or_jump =
builder.add_extension(jumps_lv.should_continue, jumps_lv.should_jump);
let constr = builder.sub_extension(nv.is_kernel_mode, lv.is_kernel_mode);
let constr = builder.mul_extension(continue_or_jump, constr);
let constr = builder.mul_extension(filter, constr);
let constr = builder.mul_extension(lv.op.jump, lv.mem_channels[1].used);
yield_constr.constraint(builder, constr);
}
// Finally, set the next program counter.
let fallthrough_dst = builder.add_const_extension(lv.program_counter, F::ONE);
let jump_dest = dst[0];
{
let constr_a = builder.mul_sub_extension(filter, jumps_lv.should_jump, filter);
let constr_b = builder.sub_extension(nv.program_counter, fallthrough_dst);
let constr = builder.mul_extension(constr_a, constr_b);
yield_constr.constraint_transition(builder, constr);
}
// Set program counter depending on whether we're continuing...
{
let constr = builder.sub_extension(nv.program_counter, lv.program_counter);
let constr =
builder.mul_sub_extension(jumps_lv.should_continue, constr, jumps_lv.should_continue);
let constr = builder.mul_extension(filter, constr);
yield_constr.constraint_transition(builder, constr);
}
// ...or jumping.
{
let constr = builder.sub_extension(nv.program_counter, input0[0]);
let constr = builder.mul_extension(jumps_lv.should_jump, constr);
let constr = builder.mul_extension(filter, constr);
let constr_a = builder.mul_extension(filter, jumps_lv.should_jump);
let constr_b = builder.sub_extension(nv.program_counter, jump_dest);
let constr = builder.mul_extension(constr_a, constr_b);
yield_constr.constraint_transition(builder, constr);
}
}

View File

@ -64,8 +64,16 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
keccak_general: None, // TODO
prover_input: None, // TODO
pop: None, // TODO
jump: None, // TODO
jumpi: None, // TODO
jump: Some(StackBehavior {
num_pops: 1,
pushes: false,
disable_other_channels: false,
}),
jumpi: Some(StackBehavior {
num_pops: 2,
pushes: false,
disable_other_channels: false,
}),
pc: Some(StackBehavior {
num_pops: 0,
pushes: true,
@ -91,7 +99,11 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
disable_other_channels: true,
}),
consume_gas: None, // TODO
exit_kernel: None, // TODO
exit_kernel: Some(StackBehavior {
num_pops: 1,
pushes: false,
disable_other_channels: true,
}),
mload_general: Some(StackBehavior {
num_pops: 3,
pushes: true,

View File

@ -144,11 +144,3 @@ pub(crate) fn biguint_to_u256(x: BigUint) -> U256 {
let bytes = x.to_bytes_le();
U256::from_little_endian(&bytes)
}
pub(crate) fn u256_saturating_cast_usize(x: U256) -> usize {
if x > usize::MAX.into() {
usize::MAX
} else {
x.as_usize()
}
}

View File

@ -10,7 +10,6 @@ use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::util::u256_saturating_cast_usize;
use crate::witness::errors::ProgramError;
use crate::witness::memory::MemoryAddress;
use crate::witness::util::{
@ -187,12 +186,37 @@ pub(crate) fn generate_jump<F: Field>(
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let dst: u32 = dst
.try_into()
.map_err(|_| ProgramError::InvalidJumpDestination)?;
let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill(
NUM_GP_CHANNELS - 1,
MemoryAddress::new(state.registers.context, Segment::JumpdestBits, dst as usize),
state,
&mut row,
);
if state.registers.is_kernel {
// Don't actually do the read, just set the address, etc.
let mut channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1];
channel.used = F::ZERO;
channel.value[0] = F::ONE;
row.mem_channels[1].value[0] = F::ONE;
} else {
if jumpdest_bit != U256::one() {
return Err(ProgramError::InvalidJumpDestination);
}
state.traces.push_memory(jumpdest_bit_log);
}
// Extra fields required by the constraints.
row.general.jumps_mut().should_jump = F::ONE;
row.general.jumps_mut().cond_sum_pinv = F::ONE;
state.traces.push_memory(log_in0);
state.traces.push_cpu(row);
// TODO: First check if it's a valid JUMPDEST
state.registers.program_counter = u256_saturating_cast_usize(dst);
// TODO: Set other cols like input0_upper_sum_inv.
state.registers.program_counter = dst as usize;
Ok(())
}
@ -202,16 +226,52 @@ pub(crate) fn generate_jumpi<F: Field>(
) -> Result<(), ProgramError> {
let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let should_jump = !cond.is_zero();
if should_jump {
row.general.jumps_mut().should_jump = F::ONE;
let cond_sum_u64 = cond
.0
.into_iter()
.map(|limb| ((limb as u32) as u64) + (limb >> 32))
.sum();
let cond_sum = F::from_canonical_u64(cond_sum_u64);
row.general.jumps_mut().cond_sum_pinv = cond_sum.inverse();
let dst: u32 = dst
.try_into()
.map_err(|_| ProgramError::InvalidJumpiDestination)?;
state.registers.program_counter = dst as usize;
} else {
row.general.jumps_mut().should_jump = F::ZERO;
row.general.jumps_mut().cond_sum_pinv = F::ZERO;
state.registers.program_counter += 1;
}
let (jumpdest_bit, jumpdest_bit_log) = mem_read_gp_with_log_and_fill(
NUM_GP_CHANNELS - 1,
MemoryAddress::new(
state.registers.context,
Segment::JumpdestBits,
dst.low_u32() as usize,
),
state,
&mut row,
);
if !should_jump || state.registers.is_kernel {
// Don't actually do the read, just set the address, etc.
let mut channel = &mut row.mem_channels[NUM_GP_CHANNELS - 1];
channel.used = F::ZERO;
channel.value[0] = F::ONE;
} else {
if jumpdest_bit != U256::one() {
return Err(ProgramError::InvalidJumpiDestination);
}
state.traces.push_memory(jumpdest_bit_log);
}
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_cpu(row);
state.registers.program_counter = if cond.is_zero() {
state.registers.program_counter + 1
} else {
// TODO: First check if it's a valid JUMPDEST
u256_saturating_cast_usize(dst)
};
// TODO: Set other cols like input0_upper_sum_inv.
Ok(())
}