Remerge context flags (#1292)

* Remerge context flags

* Apply comments and revert some unwanted changes
Hamy Ratoanina 2023-10-30 12:56:11 -04:00 committed by GitHub
parent 0258ad4a3d
commit 4b40bc0313
9 changed files with 199 additions and 106 deletions
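
Background for the diff below: GET_CONTEXT (0xf6) and SET_CONTEXT (0xf7) differ only in the low bit of their opcode, so the PR replaces the two dedicated flag columns with a single context_op column and derives per-operation filters from opcode_bits[0]. A minimal sketch of the filter split, with field elements modeled as u64 and both inputs assumed boolean (hypothetical helper, not from this commit):

    fn context_filters(context_op: u64, opcode_bit0: u64) -> (u64, u64) {
        let is_get_context = context_op * (1 - opcode_bit0); // GET_CONTEXT = 0xf6, low bit 0
        let is_set_context = context_op * opcode_bit0;       // SET_CONTEXT = 0xf7, low bit 1
        (is_get_context, is_set_context)
    }

In the diff that follows, removed lines are prefixed with `-` and added lines with `+`; unprefixed lines are unchanged context.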

View File

@@ -22,9 +22,8 @@ pub struct OpsColumnsView<T: Copy> {
pub jumpdest: T,
pub push0: T,
pub push: T,
- pub dup_swap: T,
- pub get_context: T,
- pub set_context: T,
+ pub dup_swap: T, // Combines DUP and SWAP flags.
+ pub context_op: T, // Combines GET_CONTEXT and SET_CONTEXT flags.
pub mstore_32bytes: T,
pub mload_32bytes: T,
pub exit_kernel: T,

View File

@@ -5,6 +5,7 @@ use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
+ use super::membus::NUM_GP_CHANNELS;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
@@ -15,12 +16,25 @@ fn eval_packed_get<P: PackedField>(
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
- let filter = lv.op.get_context;
+ // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0.
+ let filter = lv.op.context_op * (P::ONES - lv.opcode_bits[0]);
let new_stack_top = nv.mem_channels[0].value;
yield_constr.constraint(filter * (new_stack_top[0] - lv.context));
for &limb in &new_stack_top[1..] {
yield_constr.constraint(filter * limb);
}
+ // Constrain new stack length.
+ yield_constr.constraint(filter * (nv.stack_len - (lv.stack_len + P::ONES)));
+ // Unused channels.
+ for i in 1..NUM_GP_CHANNELS {
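+ // Channel 3 is shared with SET_CONTEXT and constrained jointly in eval_packed below.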
+ if i != 3 {
+ let channel = lv.mem_channels[i];
+ yield_constr.constraint(filter * channel.used);
+ }
+ }
+ yield_constr.constraint(filter * nv.mem_channels[0].used);
}
fn eval_ext_circuit_get<F: RichField + Extendable<D>, const D: usize>(
@@ -29,7 +43,9 @@ fn eval_ext_circuit_get<F: RichField + Extendable<D>, const D: usize>(
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
- let filter = lv.op.get_context;
+ // If the opcode is GET_CONTEXT, then lv.opcode_bits[0] = 0.
+ let prod = builder.mul_extension(lv.op.context_op, lv.opcode_bits[0]);
+ let filter = builder.sub_extension(lv.op.context_op, prod);
let new_stack_top = nv.mem_channels[0].value;
{
let diff = builder.sub_extension(new_stack_top[0], lv.context);
@@ -40,6 +56,27 @@ fn eval_ext_circuit_get<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_extension(filter, limb);
yield_constr.constraint(builder, constr);
}
+ // Constrain new stack length.
+ {
+ let new_len = builder.add_const_extension(lv.stack_len, F::ONE);
+ let diff = builder.sub_extension(nv.stack_len, new_len);
+ let constr = builder.mul_extension(filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
+ // Unused channels.
+ for i in 1..NUM_GP_CHANNELS {
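+ // Channel 3 is constrained jointly for both operations in eval_ext_circuit below.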
+ if i != 3 {
+ let channel = lv.mem_channels[i];
+ let constr = builder.mul_extension(filter, channel.used);
+ yield_constr.constraint(builder, constr);
+ }
+ }
+ {
+ let constr = builder.mul_extension(filter, nv.mem_channels[0].used);
+ yield_constr.constraint(builder, constr);
+ }
}
fn eval_packed_set<P: PackedField>(
@@ -47,7 +84,7 @@ fn eval_packed_set<P: PackedField>(
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
- let filter = lv.op.set_context;
+ let filter = lv.op.context_op * lv.opcode_bits[0];
let stack_top = lv.mem_channels[0].value;
let write_old_sp_channel = lv.mem_channels[1];
let read_new_sp_channel = lv.mem_channels[2];
@@ -77,34 +114,29 @@
yield_constr.constraint(filter * (read_new_sp_channel.addr_segment - ctx_metadata_segment));
yield_constr.constraint(filter * (read_new_sp_channel.addr_virtual - stack_size_field));
- // The next row's stack top is loaded from memory (if the stack isn't empty).
- yield_constr.constraint(filter * nv.mem_channels[0].used);
- let read_new_stack_top_channel = lv.mem_channels[3];
- let stack_segment = P::Scalar::from_canonical_u64(Segment::Stack as u64);
- let new_filter = filter * nv.stack_len;
- for (limb_channel, limb_top) in read_new_stack_top_channel
- .value
- .iter()
- .zip(nv.mem_channels[0].value)
- {
- yield_constr.constraint(new_filter * (*limb_channel - limb_top));
+ // Constrain stack_inv_aux_2.
+ let new_top_channel = nv.mem_channels[0];
+ yield_constr.constraint(
+ lv.op.context_op
+ * (lv.general.stack().stack_inv_aux * lv.opcode_bits[0]
+ - lv.general.stack().stack_inv_aux_2),
+ );
+ // The new top is loaded in memory channel 3, if the stack isn't empty (see eval_packed).
+ yield_constr.constraint(
+ lv.op.context_op
+ * lv.general.stack().stack_inv_aux_2
+ * (lv.mem_channels[3].value[0] - new_top_channel.value[0]),
+ );
+ for &limb in &new_top_channel.value[1..] {
+ yield_constr.constraint(lv.op.context_op * lv.general.stack().stack_inv_aux_2 * limb);
}
- yield_constr.constraint(new_filter * (read_new_stack_top_channel.used - P::ONES));
- yield_constr.constraint(new_filter * (read_new_stack_top_channel.is_read - P::ONES));
- yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_context - nv.context));
- yield_constr.constraint(new_filter * (read_new_stack_top_channel.addr_segment - stack_segment));
- yield_constr.constraint(
- new_filter * (read_new_stack_top_channel.addr_virtual - (nv.stack_len - P::ONES)),
- );
- // If the new stack is empty, disable the channel read.
- yield_constr.constraint(
- filter * (nv.stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux),
- );
- let empty_stack_filter = filter * (lv.general.stack().stack_inv_aux - P::ONES);
- yield_constr.constraint(empty_stack_filter * read_new_stack_top_channel.used);
+ // Unused channels.
+ for i in 4..NUM_GP_CHANNELS {
+ let channel = lv.mem_channels[i];
+ yield_constr.constraint(filter * channel.used);
+ }
+ yield_constr.constraint(filter * new_top_channel.used);
}
fn eval_ext_circuit_set<F: RichField + Extendable<D>, const D: usize>(
@@ -113,7 +145,7 @@ fn eval_ext_circuit_set<F: RichField + Extendable<D>, const D: usize>(
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
- let filter = lv.op.set_context;
+ let filter = builder.mul_extension(lv.op.context_op, lv.opcode_bits[0]);
let stack_top = lv.mem_channels[0].value;
let write_old_sp_channel = lv.mem_channels[1];
let read_new_sp_channel = lv.mem_channels[2];
@@ -197,66 +229,38 @@ fn eval_ext_circuit_set<F: RichField + Extendable<D>, const D: usize>(
yield_constr.constraint(builder, constr);
}
- // The next row's stack top is loaded from memory (if the stack isn't empty).
+ // Constrain stack_inv_aux_2.
+ let new_top_channel = nv.mem_channels[0];
{
- let constr = builder.mul_extension(filter, nv.mem_channels[0].used);
+ let diff = builder.mul_sub_extension(
+ lv.general.stack().stack_inv_aux,
+ lv.opcode_bits[0],
+ lv.general.stack().stack_inv_aux_2,
+ );
+ let constr = builder.mul_extension(lv.op.context_op, diff);
yield_constr.constraint(builder, constr);
}
+ // The new top is loaded in memory channel 3, if the stack isn't empty (see eval_packed).
+ {
+ let diff = builder.sub_extension(lv.mem_channels[3].value[0], new_top_channel.value[0]);
+ let prod = builder.mul_extension(lv.general.stack().stack_inv_aux_2, diff);
+ let constr = builder.mul_extension(lv.op.context_op, prod);
+ yield_constr.constraint(builder, constr);
+ }
+ for &limb in &new_top_channel.value[1..] {
+ let prod = builder.mul_extension(lv.general.stack().stack_inv_aux_2, limb);
+ let constr = builder.mul_extension(lv.op.context_op, prod);
+ yield_constr.constraint(builder, constr);
+ }
- let read_new_stack_top_channel = lv.mem_channels[3];
- let stack_segment =
- builder.constant_extension(F::Extension::from_canonical_u32(Segment::Stack as u32));
- let new_filter = builder.mul_extension(filter, nv.stack_len);
- for (limb_channel, limb_top) in read_new_stack_top_channel
- .value
- .iter()
- .zip(nv.mem_channels[0].value)
- {
- let diff = builder.sub_extension(*limb_channel, limb_top);
- let constr = builder.mul_extension(new_filter, diff);
+ // Unused channels.
+ for i in 4..NUM_GP_CHANNELS {
+ let channel = lv.mem_channels[i];
+ let constr = builder.mul_extension(filter, channel.used);
yield_constr.constraint(builder, constr);
}
- {
- let constr =
- builder.mul_sub_extension(new_filter, read_new_stack_top_channel.used, new_filter);
- yield_constr.constraint(builder, constr);
- }
- {
- let constr =
- builder.mul_sub_extension(new_filter, read_new_stack_top_channel.is_read, new_filter);
- yield_constr.constraint(builder, constr);
- }
- {
- let diff = builder.sub_extension(read_new_stack_top_channel.addr_context, nv.context);
- let constr = builder.mul_extension(new_filter, diff);
- yield_constr.constraint(builder, constr);
- }
- {
- let diff = builder.sub_extension(read_new_stack_top_channel.addr_segment, stack_segment);
- let constr = builder.mul_extension(new_filter, diff);
- yield_constr.constraint(builder, constr);
- }
- {
- let diff = builder.sub_extension(nv.stack_len, one);
- let diff = builder.sub_extension(read_new_stack_top_channel.addr_virtual, diff);
- let constr = builder.mul_extension(new_filter, diff);
- yield_constr.constraint(builder, constr);
- }
- // If the new stack is empty, disable the channel read.
- {
- let diff = builder.mul_extension(nv.stack_len, lv.general.stack().stack_inv);
- let diff = builder.sub_extension(diff, lv.general.stack().stack_inv_aux);
- let constr = builder.mul_extension(filter, diff);
- yield_constr.constraint(builder, constr);
- }
{
- let empty_stack_filter =
- builder.mul_sub_extension(filter, lv.general.stack().stack_inv_aux, filter);
- let constr = builder.mul_extension(empty_stack_filter, read_new_stack_top_channel.used);
+ let constr = builder.mul_extension(filter, new_top_channel.used);
yield_constr.constraint(builder, constr);
}
}
@@ -268,6 +272,33 @@ pub fn eval_packed<P: PackedField>(
) {
eval_packed_get(lv, nv, yield_constr);
eval_packed_set(lv, nv, yield_constr);
+ // Stack constraints.
+ // Both operations use memory channel 3. The operations are similar enough that
+ // we can constrain both at the same time.
+ let filter = lv.op.context_op;
+ let channel = lv.mem_channels[3];
+ // For get_context, we check if lv.stack_len is 0. For set_context, we check if nv.stack_len is 0.
+ // However, for get_context, we can deduce lv.stack_len from nv.stack_len since the operation only pushes.
+ let stack_len = nv.stack_len - (P::ONES - lv.opcode_bits[0]);
+ // Constrain stack_inv_aux. It's 0 if the relevant stack is empty, 1 otherwise.
+ yield_constr.constraint(
+ filter * (stack_len * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux),
+ );
+ // Enable or disable the channel.
+ yield_constr.constraint(filter * (lv.general.stack().stack_inv_aux - channel.used));
+ let new_filter = filter * lv.general.stack().stack_inv_aux;
+ // It's a write for get_context, a read for set_context.
+ yield_constr.constraint(new_filter * (channel.is_read - lv.opcode_bits[0]));
+ // In both cases, the address uses the next row's context.
+ yield_constr.constraint(new_filter * (channel.addr_context - nv.context));
+ // Same segment for both.
+ yield_constr.constraint(
+ new_filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)),
+ );
+ // The address is one less than stack_len.
+ let addr_virtual = stack_len - P::ONES;
+ yield_constr.constraint(new_filter * (channel.addr_virtual - addr_virtual));
}
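
For reference, a minimal sketch of the witness convention the stack_inv / stack_inv_aux pair above encodes (a standard zero-test), mirroring generate_set_context later in this commit; the only API assumed is plonky2's Field::try_inverse:

    use plonky2::field::types::Field;

    // The constraint `stack_len * stack_inv == stack_inv_aux` holds with
    // stack_inv_aux = 1 iff stack_len != 0, and stack_inv_aux = 0 otherwise.
    fn fill_stack_inv<F: Field>(stack_len: F) -> (F, F) {
        match stack_len.try_inverse() {
            Some(inv) => (inv, F::ONE),  // non-empty stack: channel enabled
            None => (F::ZERO, F::ZERO),  // empty stack: channel disabled
        }
    }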
pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
@@ -278,4 +309,59 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
) {
eval_ext_circuit_get(builder, lv, nv, yield_constr);
eval_ext_circuit_set(builder, lv, nv, yield_constr);
+ // Stack constraints.
+ // Both operations use memory channel 3. The operations are similar enough that
+ // we can constrain both at the same time.
+ let filter = lv.op.context_op;
+ let channel = lv.mem_channels[3];
+ // For get_context, we check if lv.stack_len is 0. For set_context, we check if nv.stack_len is 0.
+ // However, for get_context, we can deduce lv.stack_len from nv.stack_len since the operation only pushes.
+ let diff = builder.add_const_extension(lv.opcode_bits[0], -F::ONE);
+ let stack_len = builder.add_extension(nv.stack_len, diff);
+ // Constrain stack_inv_aux. It's 0 if the relevant stack is empty, 1 otherwise.
+ {
+ let diff = builder.mul_sub_extension(
+ stack_len,
+ lv.general.stack().stack_inv,
+ lv.general.stack().stack_inv_aux,
+ );
+ let constr = builder.mul_extension(filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
+ // Enable or disable the channel.
+ {
+ let diff = builder.sub_extension(lv.general.stack().stack_inv_aux, channel.used);
+ let constr = builder.mul_extension(filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
+ let new_filter = builder.mul_extension(filter, lv.general.stack().stack_inv_aux);
+ // It's a write for get_context, a read for set_context.
+ {
+ let diff = builder.sub_extension(channel.is_read, lv.opcode_bits[0]);
+ let constr = builder.mul_extension(new_filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
+ // In both cases, the address uses the next row's context.
+ {
+ let diff = builder.sub_extension(channel.addr_context, nv.context);
+ let constr = builder.mul_extension(new_filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
+ // Same segment for both.
+ {
+ let diff = builder.add_const_extension(
+ channel.addr_segment,
+ -F::from_canonical_u64(Segment::Stack as u64),
+ );
+ let constr = builder.mul_extension(new_filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
+ // The address is one less than stack_len.
+ {
+ let addr_virtual = builder.add_const_extension(stack_len, -F::ONE);
+ let diff = builder.sub_extension(channel.addr_virtual, addr_virtual);
+ let constr = builder.mul_extension(new_filter, diff);
+ yield_constr.constraint(builder, constr);
+ }
}

View File

@@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
use crate::cpu::kernel::aggregator::KERNEL;
- const NATIVE_INSTRUCTIONS: [usize; 17] = [
+ const NATIVE_INSTRUCTIONS: [usize; 16] = [
COL_MAP.op.binary_op,
COL_MAP.op.ternary_op,
COL_MAP.op.fp254_op,
@@ -25,8 +25,7 @@ const NATIVE_INSTRUCTIONS: [usize; 17] = [
COL_MAP.op.push0,
// not PUSH (need to increment by more than 1)
COL_MAP.op.dup_swap,
- COL_MAP.op.get_context,
- COL_MAP.op.set_context,
+ COL_MAP.op.context_op,
// not EXIT_KERNEL (performs a jump)
COL_MAP.op.m_op_general,
// not SYSCALL (performs a jump)

View File

@@ -23,7 +23,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP};
/// behavior.
/// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to
/// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid.
- const OPCODES: [(u8, usize, bool, usize); 16] = [
+ const OPCODES: [(u8, usize, bool, usize); 15] = [
// (start index of block, number of top bits to check (log2), kernel-only, flag column)
// ADD, MUL, SUB, DIV, MOD, LT, GT and BYTE flags are handled partly manually here, and partly through the Arithmetic table CTL.
// ADDMOD, MULMOD and SUBMOD flags are handled partly manually here, and partly through the Arithmetic table CTL.
@@ -42,8 +42,7 @@ const OPCODES: [(u8, usize, bool, usize); 16] = [
(0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f
(0x80, 5, false, COL_MAP.op.dup_swap), // 0x80-0x9f
(0xee, 0, true, COL_MAP.op.mstore_32bytes),
- (0xf6, 0, true, COL_MAP.op.get_context),
- (0xf7, 0, true, COL_MAP.op.set_context),
+ (0xf6, 1, true, COL_MAP.op.context_op), // 0xf6-0xf7
(0xf8, 0, true, COL_MAP.op.mload_32bytes),
(0xf9, 0, true, COL_MAP.op.exit_kernel),
// MLOAD_GENERAL and MSTORE_GENERAL flags are handled manually here.
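
A small illustration of how a (block start, log2 of block length) entry in the OPCODES table covers a range of opcodes, which is why (0xf6, 1, ...) now matches both context opcodes. This is a hypothetical check, not the decoder's actual code:

    fn block_matches(opcode: u8, block_start: u8, log_len: u8) -> bool {
        // An entry covers the 2^log_len opcodes sharing block_start's top bits.
        (opcode >> log_len) == (block_start >> log_len)
    }
    // block_matches(0xf6, 0xf6, 1) and block_matches(0xf7, 0xf6, 1) both hold.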

View File

@@ -35,8 +35,7 @@ const SIMPLE_OPCODES: OpsColumnsView<Option<u32>> = OpsColumnsView {
push0: G_BASE,
push: G_VERYLOW,
dup_swap: G_VERYLOW,
- get_context: KERNEL_ONLY_INSTR,
- set_context: KERNEL_ONLY_INSTR,
+ context_op: KERNEL_ONLY_INSTR,
mstore_32bytes: KERNEL_ONLY_INSTR,
mload_32bytes: KERNEL_ONLY_INSTR,
exit_kernel: None,

View File

@@ -264,7 +264,7 @@ fn eval_ext_circuit_store<F: RichField + Extendable<D>, const D: usize>(
let top_read_channel = nv.mem_channels[0];
let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]);
let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read);
- // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`.
+ // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * (1 - opcode_bits[0])`.
{
let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read);
let constr = builder.mul_extension(lv.op.m_op_general, diff);

View File

@@ -97,12 +97,7 @@ pub(crate) const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsCol
}),
push: None, // TODO
dup_swap: None,
- get_context: Some(StackBehavior {
- num_pops: 0,
- pushes: true,
- disable_other_channels: true,
- }),
- set_context: None, // SET_CONTEXT is special since it involves the old and the new stack.
+ context_op: None,
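+ // (Both operations now constrain their stack channels manually in cpu/context.rs.)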
mload_32bytes: Some(StackBehavior {
num_pops: 4,
pushes: true,

View File

@@ -308,7 +308,22 @@ pub(crate) fn generate_get_context<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
- push_with_write(state, &mut row, state.registers.context.into())?;
+ // Same logic as push_with_write, but we have to use channel 3 for stack constraint reasons.
+ let write = if state.registers.stack_len == 0 {
+ None
+ } else {
+ let address = MemoryAddress::new(
+ state.registers.context,
+ Segment::Stack,
+ state.registers.stack_len - 1,
+ );
+ let res = mem_write_gp_log_and_fill(3, address, state, &mut row, state.registers.stack_top);
+ Some(res)
+ };
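+ // The context value becomes the new cached top of the stack; no memory write is needed for it.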
+ push_no_write(state, state.registers.context.into());
+ if let Some(log) = write {
+ state.traces.push_memory(log);
+ }
state.traces.push_cpu(row);
Ok(())
}
@@ -364,9 +379,11 @@ pub(crate) fn generate_set_context<F: Field>(
if let Some(inv) = new_sp_field.try_inverse() {
row.general.stack_mut().stack_inv = inv;
row.general.stack_mut().stack_inv_aux = F::ONE;
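+ // For SET_CONTEXT, opcode_bits[0] == 1, so stack_inv_aux_2 == stack_inv_aux.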
+ row.general.stack_mut().stack_inv_aux_2 = F::ONE;
} else {
row.general.stack_mut().stack_inv = F::ZERO;
row.general.stack_mut().stack_inv_aux = F::ZERO;
+ row.general.stack_mut().stack_inv_aux_2 = F::ZERO;
}
let new_top_addr = MemoryAddress::new(new_ctx, Segment::Stack, new_sp - 1);
@@ -833,6 +850,7 @@ pub(crate) fn generate_mstore_general<F: Field>(
state.traces.push_memory(log_in2);
state.traces.push_memory(log_in3);
state.traces.push_memory(log_write);
state.traces.push_cpu(row);
Ok(())

View File

@@ -180,8 +180,7 @@ fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) {
Operation::Jump | Operation::Jumpi => &mut flags.jumps,
Operation::Pc => &mut flags.pc,
Operation::Jumpdest => &mut flags.jumpdest,
- Operation::GetContext => &mut flags.get_context,
- Operation::SetContext => &mut flags.set_context,
+ Operation::GetContext | Operation::SetContext => &mut flags.context_op,
Operation::Mload32Bytes => &mut flags.mload_32bytes,
Operation::Mstore32Bytes => &mut flags.mstore_32bytes,
Operation::ExitKernel => &mut flags.exit_kernel,
@@ -216,8 +215,7 @@ fn get_op_special_length(op: Operation) -> Option<usize> {
Operation::Jumpi => JUMPI_OP,
Operation::Pc => STACK_BEHAVIORS.pc,
Operation::Jumpdest => STACK_BEHAVIORS.jumpdest,
- Operation::GetContext => STACK_BEHAVIORS.get_context,
- Operation::SetContext => None,
+ Operation::GetContext | Operation::SetContext => None,
Operation::Mload32Bytes => STACK_BEHAVIORS.mload_32bytes,
Operation::Mstore32Bytes => STACK_BEHAVIORS.mstore_32bytes,
Operation::ExitKernel => STACK_BEHAVIORS.exit_kernel,