Refactor to support PROVER_INPUT

This commit is contained in:
Daniel Lubarov 2022-12-01 11:15:51 -08:00
parent b6326c56b2
commit 027dfc14b6
8 changed files with 334 additions and 373 deletions

View File

@ -14,11 +14,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::generation::state::GenerationState;
use crate::keccak_sponge::columns::KECCAK_RATE_U32S;
use crate::memory::segments::Segment;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
use crate::witness::memory::MemoryAddress;
use crate::witness::traces::Traces;
use crate::witness::util::mem_write_gp_log_and_fill;
/// We can't process more than `NUM_CHANNELS` bytes per row, since that's all the memory bandwidth
@ -26,7 +26,7 @@ use crate::witness::util::mem_write_gp_log_and_fill;
/// want them to fit in a single limb of Keccak input.
const BYTES_PER_ROW: usize = 4;
pub(crate) fn generate_bootstrap_kernel<F: Field>(traces: &mut Traces<F>) {
pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>) {
let mut sponge_state = [0u32; 50];
let mut sponge_input_pos: usize = 0;
@ -47,11 +47,11 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(traces: &mut Traces<F>) {
let write = mem_write_gp_log_and_fill(
channel,
address,
traces,
state,
&mut current_cpu_row,
byte.into(),
);
traces.push_memory(write);
state.traces.push_memory(write);
packed_bytes = (packed_bytes << 8) | byte as u32;
}
@ -70,7 +70,7 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(traces: &mut Traces<F>) {
keccak.output_limbs = sponge_state.map(F::from_canonical_u32);
}
traces.push_cpu(current_cpu_row);
state.traces.push_cpu(current_cpu_row);
}
}

View File

@ -13,11 +13,10 @@ use crate::config::StarkConfig;
use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::proof::{BlockMetadata, PublicValues, TrieRoots};
use crate::witness::memory::{MemoryAddress, MemoryState};
use crate::witness::state::RegistersState;
use crate::witness::traces::Traces;
use crate::witness::memory::MemoryAddress;
use crate::witness::transition::transition;
pub(crate) mod memory;
@ -65,29 +64,26 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
config: &StarkConfig,
timing: &mut TimingTree,
) -> ([Vec<PolynomialValues<F>>; NUM_TABLES], PublicValues) {
// let mut state = GenerationState::<F>::new(inputs.clone());
let mut state = GenerationState::<F>::new(inputs.clone(), &KERNEL.code);
let mut memory_state = MemoryState::new(&KERNEL.code);
let mut traces = Traces::<F>::default();
generate_bootstrap_kernel::<F>(&mut traces);
generate_bootstrap_kernel::<F>(&mut state);
let mut registers_state = RegistersState::default();
let halt_pc0 = KERNEL.global_labels["halt_pc0"];
let halt_pc1 = KERNEL.global_labels["halt_pc1"];
loop {
// If we've reached the kernel's halt routine, and our trace length is a power of 2, stop.
let pc = registers_state.program_counter as usize;
let pc = state.registers.program_counter;
let in_halt_loop = pc == halt_pc0 || pc == halt_pc1;
if in_halt_loop && traces.cpu.len().is_power_of_two() {
if in_halt_loop && state.traces.clock().is_power_of_two() {
break;
}
registers_state = transition(registers_state, &mut memory_state, &mut traces);
transition(&mut state);
}
let read_metadata = |field| {
memory_state.get(MemoryAddress::new(
state.memory.get(MemoryAddress::new(
0,
Segment::GlobalMetadata,
field as usize,
@ -115,5 +111,8 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
block_metadata: inputs.block_metadata,
};
(traces.to_tables(all_stark, config, timing), public_values)
(
state.traces.to_tables(all_stark, config, timing),
public_values,
)
}

View File

@ -8,6 +8,7 @@ use crate::generation::prover_input::EvmField::{
};
use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
use crate::generation::state::GenerationState;
use crate::witness::util::stack_peek;
/// Prover input function represented as a scoped function name.
/// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`.
@ -22,13 +23,13 @@ impl From<Vec<String>> for ProverInputFn {
impl<F: Field> GenerationState<F> {
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn prover_input(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 {
pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> U256 {
match input_fn.0[0].as_str() {
"end_of_txns" => self.run_end_of_txns(),
"ff" => self.run_ff(stack, input_fn),
"ff" => self.run_ff(input_fn),
"mpt" => self.run_mpt(),
"rlp" => self.run_rlp(),
"account_code" => self.run_account_code(stack, input_fn),
"account_code" => self.run_account_code(input_fn),
_ => panic!("Unrecognized prover input function."),
}
}
@ -44,10 +45,10 @@ impl<F: Field> GenerationState<F> {
}
/// Finite field operations.
fn run_ff(&self, stack: &[U256], input_fn: &ProverInputFn) -> U256 {
fn run_ff(&self, input_fn: &ProverInputFn) -> U256 {
let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap();
let op = FieldOp::from_str(input_fn.0[2].as_str()).unwrap();
let x = *stack.last().expect("Empty stack");
let x = stack_peek(self, 0).expect("Empty stack");
field.op(op, x)
}
@ -66,22 +67,21 @@ impl<F: Field> GenerationState<F> {
}
/// Account code.
fn run_account_code(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 {
fn run_account_code(&mut self, input_fn: &ProverInputFn) -> U256 {
match input_fn.0[1].as_str() {
"length" => {
// Return length of code.
// stack: codehash, ...
let codehash = stack.last().expect("Empty stack");
self.inputs.contract_code[&H256::from_uint(codehash)]
let codehash = stack_peek(self, 0).expect("Empty stack");
self.inputs.contract_code[&H256::from_uint(&codehash)]
.len()
.into()
}
"get" => {
// Return `code[i]`.
// stack: i, code_length, codehash, ...
let stacklen = stack.len();
let i = stack[stacklen - 1].as_usize();
let codehash = stack[stacklen - 3];
let i = stack_peek(self, 0).expect("Unexpected stack").as_usize();
let codehash = stack_peek(self, 2).expect("Unexpected stack");
self.inputs.contract_code[&H256::from_uint(&codehash)][i].into()
}
_ => panic!("Invalid prover input function."),

View File

@ -6,7 +6,12 @@ use crate::generation::rlp::all_rlp_prover_inputs_reversed;
use crate::generation::GenerationInputs;
use crate::witness::memory::MemoryState;
use crate::witness::state::RegistersState;
use crate::witness::traces::Traces;
use crate::witness::traces::{TraceCheckpoint, Traces};
pub(crate) struct GenerationStateCheckpoint {
pub(crate) registers: RegistersState,
pub(crate) traces: TraceCheckpoint,
}
#[derive(Debug)]
pub(crate) struct GenerationState<F: Field> {
@ -27,14 +32,14 @@ pub(crate) struct GenerationState<F: Field> {
}
impl<F: Field> GenerationState<F> {
pub(crate) fn new(inputs: GenerationInputs) -> Self {
pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Self {
let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries);
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
Self {
inputs,
registers: Default::default(),
memory: MemoryState::default(),
memory: MemoryState::new(kernel_code),
traces: Traces::default(),
next_txn_index: 0,
mpt_prover_inputs,
@ -42,6 +47,18 @@ impl<F: Field> GenerationState<F> {
}
}
pub fn checkpoint(&self) -> GenerationStateCheckpoint {
GenerationStateCheckpoint {
registers: self.registers,
traces: self.traces.checkpoint(),
}
}
pub fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) {
self.registers = checkpoint.registers;
self.traces.rollback(checkpoint.traces);
}
// /// Evaluate the Keccak-f permutation in-place on some data in memory, and record the operations
// /// for the purpose of witness generation.
// #[allow(unused)] // TODO: Should be used soon.

View File

@ -6,12 +6,11 @@ use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::util::u256_saturating_cast_usize;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryState};
use crate::witness::state::RegistersState;
use crate::witness::traces::Traces;
use crate::witness::memory::MemoryAddress;
use crate::witness::util::{
mem_read_code_with_log_and_fill, mem_read_gp_with_log_and_fill, mem_write_gp_log_and_fill,
stack_pop_with_log_and_fill, stack_push_log_and_fill,
@ -49,133 +48,118 @@ pub(crate) enum Operation {
pub(crate) fn generate_binary_logic_op<F: Field>(
op: logic::Op,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(in0, log_in0), (in1, log_in1)] =
stack_pop_with_log_and_fill::<2, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let operation = logic::Operation::new(op, in0, in1);
let log_out =
stack_push_log_and_fill(&mut registers_state, traces, &mut row, operation.result)?;
let log_out = stack_push_log_and_fill(state, &mut row, operation.result)?;
traces.push_logic(operation);
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_logic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_binary_arithmetic_op<F: Field>(
operator: arithmetic::BinaryOperator,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
) -> Result<(), ProgramError> {
let [(input0, log_in0), (input1, log_in1)] =
stack_pop_with_log_and_fill::<2, _>(&mut registers_state, memory_state, traces, &mut row)?;
stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let operation = arithmetic::Operation::binary(operator, input0, input1);
let log_out =
stack_push_log_and_fill(&mut registers_state, traces, &mut row, operation.result())?;
let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?;
traces.push_arithmetic(operation);
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_ternary_arithmetic_op<F: Field>(
operator: arithmetic::TernaryOperator,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
) -> Result<(), ProgramError> {
let [(input0, log_in0), (input1, log_in1), (input2, log_in2)] =
stack_pop_with_log_and_fill::<3, _>(&mut registers_state, memory_state, traces, &mut row)?;
stack_pop_with_log_and_fill::<3, _>(state, &mut row)?;
let operation = arithmetic::Operation::ternary(operator, input0, input1, input2);
let log_out =
stack_push_log_and_fill(&mut registers_state, traces, &mut row, operation.result())?;
let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?;
traces.push_arithmetic(operation);
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_in2);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_prover_input<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
todo!()
) -> Result<(), ProgramError> {
let pc = state.registers.program_counter;
let input_fn = &KERNEL.prover_inputs[&pc];
let input = state.prover_input(input_fn);
let write = stack_push_log_and_fill(state, &mut row, input)?;
state.traces.push_memory(write);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_pop<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
) -> Result<(), ProgramError> {
todo!()
}
pub(crate) fn generate_jump<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(dst, log_in0)] =
stack_pop_with_log_and_fill::<1, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
traces.push_memory(log_in0);
traces.push_cpu(row);
registers_state.program_counter = u256_saturating_cast_usize(dst);
state.traces.push_memory(log_in0);
state.traces.push_cpu(row);
state.registers.program_counter = u256_saturating_cast_usize(dst);
// TODO: Set other cols like input0_upper_sum_inv.
Ok(registers_state)
Ok(())
}
pub(crate) fn generate_jumpi<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(dst, log_in0), (cond, log_in1)] =
stack_pop_with_log_and_fill::<2, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_cpu(row);
registers_state.program_counter = if cond.is_zero() {
registers_state.program_counter + 1
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_cpu(row);
state.registers.program_counter = if cond.is_zero() {
state.registers.program_counter + 1
} else {
u256_saturating_cast_usize(dst)
};
// TODO: Set other cols like input0_upper_sum_inv.
Ok(registers_state)
Ok(())
}
pub(crate) fn generate_push<F: Field>(
n: u8,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let context = registers_state.effective_context();
) -> Result<(), ProgramError> {
let context = state.registers.effective_context();
let num_bytes = n as usize + 1;
let initial_offset = registers_state.program_counter + 1;
let initial_offset = state.registers.program_counter + 1;
let offsets = initial_offset..initial_offset + num_bytes;
let mut addrs = offsets.map(|offset| MemoryAddress::new(context, Segment::Code, offset));
@ -183,7 +167,8 @@ pub(crate) fn generate_push<F: Field>(
// to stack_push_log_and_fill.
let bytes = (0..num_bytes)
.map(|i| {
memory_state
state
.memory
.get(MemoryAddress::new(
context,
Segment::Code,
@ -194,16 +179,16 @@ pub(crate) fn generate_push<F: Field>(
.collect_vec();
let val = U256::from_big_endian(&bytes);
let write = stack_push_log_and_fill(&mut registers_state, traces, &mut row, val)?;
let write = stack_push_log_and_fill(state, &mut row, val)?;
// In the first cycle, we read up to NUM_GP_CHANNELS - 1 bytes, leaving the last GP channel
// to push the result.
for (i, addr) in (&mut addrs).take(NUM_GP_CHANNELS - 1).enumerate() {
let (_, read) = mem_read_gp_with_log_and_fill(i, addr, memory_state, traces, &mut row);
traces.push_memory(read);
let (_, read) = mem_read_gp_with_log_and_fill(i, addr, state, &mut row);
state.traces.push_memory(read);
}
traces.push_memory(write);
traces.push_cpu(row);
state.traces.push_memory(write);
state.traces.push_cpu(row);
// In any subsequent cycles, we read up to 1 + NUM_GP_CHANNELS bytes.
for mut addrs_chunk in &addrs.chunks(1 + NUM_GP_CHANNELS) {
@ -211,258 +196,238 @@ pub(crate) fn generate_push<F: Field>(
// TODO: Set other row fields, like push=1?
let first_addr = addrs_chunk.next().unwrap();
let (_, first_read) =
mem_read_code_with_log_and_fill(first_addr, memory_state, traces, &mut row);
traces.push_memory(first_read);
let (_, first_read) = mem_read_code_with_log_and_fill(first_addr, state, &mut row);
state.traces.push_memory(first_read);
for (i, addr) in addrs_chunk.enumerate() {
let (_, read) = mem_read_gp_with_log_and_fill(i, addr, memory_state, traces, &mut row);
traces.push_memory(read);
let (_, read) = mem_read_gp_with_log_and_fill(i, addr, state, &mut row);
state.traces.push_memory(read);
}
traces.push_cpu(row);
state.traces.push_cpu(row);
}
Ok(registers_state)
Ok(())
}
pub(crate) fn generate_dup<F: Field>(
n: u8,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let other_addr_lo = registers_state
) -> Result<(), ProgramError> {
let other_addr_lo = state
.registers
.stack_len
.checked_sub(1 + (n as usize))
.ok_or(ProgramError::StackUnderflow)?;
let other_addr = MemoryAddress::new(registers_state.context, Segment::Stack, other_addr_lo);
let other_addr = MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
other_addr_lo,
);
let (val, log_in) =
mem_read_gp_with_log_and_fill(0, other_addr, memory_state, traces, &mut row);
let log_out = stack_push_log_and_fill(&mut registers_state, traces, &mut row, val)?;
let (val, log_in) = mem_read_gp_with_log_and_fill(0, other_addr, state, &mut row);
let log_out = stack_push_log_and_fill(state, &mut row, val)?;
traces.push_memory(log_in);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_swap<F: Field>(
n: u8,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let other_addr_lo = registers_state
) -> Result<(), ProgramError> {
let other_addr_lo = state
.registers
.stack_len
.checked_sub(2 + (n as usize))
.ok_or(ProgramError::StackUnderflow)?;
let other_addr = MemoryAddress::new(registers_state.context, Segment::Stack, other_addr_lo);
let other_addr = MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
other_addr_lo,
);
let [(in0, log_in0)] =
stack_pop_with_log_and_fill::<1, _>(&mut registers_state, memory_state, traces, &mut row)?;
let (in1, log_in1) =
mem_read_gp_with_log_and_fill(1, other_addr, memory_state, traces, &mut row);
let log_out0 =
mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 2, other_addr, traces, &mut row, in0);
let log_out1 = stack_push_log_and_fill(&mut registers_state, traces, &mut row, in1)?;
let [(in0, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row);
let log_out0 = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 2, other_addr, state, &mut row, in0);
let log_out1 = stack_push_log_and_fill(state, &mut row, in1)?;
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_out0);
traces.push_memory(log_out1);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out0);
state.traces.push_memory(log_out1);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_not<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(x, log_in)] =
stack_pop_with_log_and_fill::<1, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let result = !x;
let log_out = stack_push_log_and_fill(&mut registers_state, traces, &mut row, result)?;
let log_out = stack_push_log_and_fill(state, &mut row, result)?;
traces.push_memory(log_in);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_iszero<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(x, log_in)] =
stack_pop_with_log_and_fill::<1, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let is_zero = x.is_zero();
let result = {
let t: u64 = is_zero.into();
t.into()
};
let log_out = stack_push_log_and_fill(&mut registers_state, traces, &mut row, result)?;
let log_out = stack_push_log_and_fill(state, &mut row, result)?;
generate_pinv_diff(x, U256::zero(), &mut row);
traces.push_memory(log_in);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_syscall<F: Field>(
opcode: u8,
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
) -> Result<(), ProgramError> {
let handler_jumptable_addr = KERNEL.global_labels["syscall_jumptable"] as usize;
let handler_addr_addr = handler_jumptable_addr + (opcode as usize);
let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill(
0,
MemoryAddress::new(0, Segment::Code, handler_addr_addr),
memory_state,
traces,
state,
&mut row,
);
let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill(
1,
MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1),
memory_state,
traces,
state,
&mut row,
);
let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill(
2,
MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2),
memory_state,
traces,
state,
&mut row,
);
let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2;
let new_program_counter = handler_addr.as_usize();
let syscall_info = U256::from(registers_state.program_counter)
+ (U256::from(u64::from(registers_state.is_kernel)) << 32);
let log_out = stack_push_log_and_fill(&mut registers_state, traces, &mut row, syscall_info)?;
let syscall_info = U256::from(state.registers.program_counter)
+ (U256::from(u64::from(state.registers.is_kernel)) << 32);
let log_out = stack_push_log_and_fill(state, &mut row, syscall_info)?;
registers_state.program_counter = new_program_counter;
registers_state.is_kernel = true;
state.registers.program_counter = new_program_counter;
state.registers.is_kernel = true;
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_in2);
traces.push_memory(log_out);
traces.push_cpu(row);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(registers_state)
Ok(())
}
pub(crate) fn generate_eq<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(in0, log_in0), (in1, log_in1)] =
stack_pop_with_log_and_fill::<2, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let eq = in0 == in1;
let result = U256::from(u64::from(eq));
let log_out = stack_push_log_and_fill(&mut registers_state, traces, &mut row, result)?;
let log_out = stack_push_log_and_fill(state, &mut row, result)?;
generate_pinv_diff(in0, in1, &mut row);
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_exit_kernel<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let [(kexit_info, log_in)] =
stack_pop_with_log_and_fill::<1, _>(&mut registers_state, memory_state, traces, &mut row)?;
) -> Result<(), ProgramError> {
let [(kexit_info, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let kexit_info_u64: [u64; 4] = kexit_info.0;
let program_counter = kexit_info_u64[0] as usize;
let is_kernel_mode_val = (kexit_info_u64[1] >> 32) as u32;
assert!(is_kernel_mode_val == 0 || is_kernel_mode_val == 1);
let is_kernel_mode = is_kernel_mode_val != 0;
registers_state.program_counter = program_counter;
registers_state.is_kernel = is_kernel_mode;
state.registers.program_counter = program_counter;
state.registers.is_kernel = is_kernel_mode;
traces.push_memory(log_in);
traces.push_cpu(row);
state.traces.push_memory(log_in);
state.traces.push_cpu(row);
Ok(registers_state)
Ok(())
}
pub(crate) fn generate_mload_general<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
) -> Result<(), ProgramError> {
let [(context, log_in0), (segment, log_in1), (virt, log_in2)] =
stack_pop_with_log_and_fill::<3, _>(&mut registers_state, memory_state, traces, &mut row)?;
stack_pop_with_log_and_fill::<3, _>(state, &mut row)?;
// If virt won't fit in a usize, don't try to convert it, just return 0.
let val = if virt > usize::MAX.into() {
U256::zero()
} else {
memory_state.get(MemoryAddress {
state.memory.get(MemoryAddress {
context: context.as_usize(),
segment: segment.as_usize(),
virt: virt.as_usize(),
})
};
let log_out = stack_push_log_and_fill(&mut registers_state, traces, &mut row, val)?;
let log_out = stack_push_log_and_fill(state, &mut row, val)?;
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_in2);
traces.push_memory(log_out);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_mstore_general<F: Field>(
mut registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
) -> Result<(), ProgramError> {
let [(context, log_in0), (segment, log_in1), (virt, log_in2), (val, log_in3)] =
stack_pop_with_log_and_fill::<4, _>(&mut registers_state, memory_state, traces, &mut row)?;
stack_pop_with_log_and_fill::<4, _>(state, &mut row)?;
let address = MemoryAddress {
context: context.as_usize(),
segment: segment.as_usize(),
virt: virt.as_usize(),
};
let log_write = mem_write_gp_log_and_fill(4, address, traces, &mut row, val);
let log_write = mem_write_gp_log_and_fill(4, address, state, &mut row, val);
traces.push_memory(log_in0);
traces.push_memory(log_in1);
traces.push_memory(log_in2);
traces.push_memory(log_in3);
traces.push_memory(log_write);
traces.push_cpu(row);
Ok(registers_state)
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_in3);
state.traces.push_memory(log_write);
state.traces.push_cpu(row);
Ok(())
}

View File

@ -1,5 +1,7 @@
use crate::cpu::kernel::aggregator::KERNEL;
const KERNEL_CONTEXT: usize = 0;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RegistersState {
pub program_counter: usize,
@ -11,7 +13,7 @@ pub struct RegistersState {
impl RegistersState {
pub(crate) fn effective_context(&self) -> usize {
if self.is_kernel {
0
KERNEL_CONTEXT
} else {
self.context
}

View File

@ -1,40 +1,29 @@
use plonky2::field::types::Field;
use crate::cpu::columns::CpuColumnsView;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryState};
use crate::witness::memory::MemoryAddress;
use crate::witness::operation::*;
use crate::witness::state::RegistersState;
use crate::witness::traces::Traces;
use crate::witness::util::mem_read_code_with_log_and_fill;
use crate::{arithmetic, logic};
const KERNEL_CONTEXT: usize = 0;
fn read_code_memory<F: Field>(
registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
row: &mut CpuColumnsView<F>,
) -> u8 {
let code_context = if registers_state.is_kernel {
KERNEL_CONTEXT
} else {
registers_state.context
};
fn read_code_memory<F: Field>(state: &mut GenerationState<F>, row: &mut CpuColumnsView<F>) -> u8 {
let code_context = state.registers.effective_context();
row.code_context = F::from_canonical_usize(code_context);
let address = MemoryAddress::new(code_context, Segment::Code, registers_state.program_counter);
let (opcode, mem_log) = mem_read_code_with_log_and_fill(address, memory_state, traces, row);
let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter);
let (opcode, mem_log) = mem_read_code_with_log_and_fill(address, state, row);
traces.push_memory(mem_log);
state.traces.push_memory(mem_log);
opcode
}
fn decode(registers_state: RegistersState, opcode: u8) -> Result<Operation, ProgramError> {
match (opcode, registers_state.is_kernel) {
fn decode(registers: RegistersState, opcode: u8) -> Result<Operation, ProgramError> {
match (opcode, registers.is_kernel) {
(0x00, _) => Ok(Operation::Syscall(opcode)),
(0x01, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add)),
(0x02, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul)),
@ -187,105 +176,82 @@ fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) {
}
fn perform_op<F: Field>(
state: &mut GenerationState<F>,
op: Operation,
registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
row: CpuColumnsView<F>,
) -> Result<RegistersState, ProgramError> {
let mut new_registers_state = match op {
Operation::Push(n) => generate_push(n, registers_state, memory_state, traces, row)?,
Operation::Dup(n) => generate_dup(n, registers_state, memory_state, traces, row)?,
Operation::Swap(n) => generate_swap(n, registers_state, memory_state, traces, row)?,
Operation::Iszero => generate_iszero(registers_state, memory_state, traces, row)?,
Operation::Not => generate_not(registers_state, memory_state, traces, row)?,
) -> Result<(), ProgramError> {
match op {
Operation::Push(n) => generate_push(n, state, row)?,
Operation::Dup(n) => generate_dup(n, state, row)?,
Operation::Swap(n) => generate_swap(n, state, row)?,
Operation::Iszero => generate_iszero(state, row)?,
Operation::Not => generate_not(state, row)?,
Operation::Byte => todo!(),
Operation::Syscall(opcode) => {
generate_syscall(opcode, registers_state, memory_state, traces, row)?
}
Operation::Eq => generate_eq(registers_state, memory_state, traces, row)?,
Operation::Syscall(opcode) => generate_syscall(opcode, state, row)?,
Operation::Eq => generate_eq(state, row)?,
Operation::BinaryLogic(binary_logic_op) => {
generate_binary_logic_op(binary_logic_op, registers_state, memory_state, traces, row)?
}
Operation::BinaryArithmetic(op) => {
generate_binary_arithmetic_op(op, registers_state, memory_state, traces, row)?
}
Operation::TernaryArithmetic(op) => {
generate_ternary_arithmetic_op(op, registers_state, memory_state, traces, row)?
generate_binary_logic_op(binary_logic_op, state, row)?
}
Operation::BinaryArithmetic(op) => generate_binary_arithmetic_op(op, state, row)?,
Operation::TernaryArithmetic(op) => generate_ternary_arithmetic_op(op, state, row)?,
Operation::KeccakGeneral => todo!(),
Operation::ProverInput => {
generate_prover_input(registers_state, memory_state, traces, row)?
}
Operation::Pop => generate_pop(registers_state, memory_state, traces, row)?,
Operation::Jump => generate_jump(registers_state, memory_state, traces, row)?,
Operation::Jumpi => generate_jumpi(registers_state, memory_state, traces, row)?,
Operation::ProverInput => generate_prover_input(state, row)?,
Operation::Pop => generate_pop(state, row)?,
Operation::Jump => generate_jump(state, row)?,
Operation::Jumpi => generate_jumpi(state, row)?,
Operation::Pc => todo!(),
Operation::Gas => todo!(),
Operation::Jumpdest => todo!(),
Operation::GetContext => todo!(),
Operation::SetContext => todo!(),
Operation::ConsumeGas => todo!(),
Operation::ExitKernel => generate_exit_kernel(registers_state, memory_state, traces, row)?,
Operation::MloadGeneral => {
generate_mload_general(registers_state, memory_state, traces, row)?
}
Operation::MstoreGeneral => {
generate_mstore_general(registers_state, memory_state, traces, row)?
}
Operation::ExitKernel => generate_exit_kernel(state, row)?,
Operation::MloadGeneral => generate_mload_general(state, row)?,
Operation::MstoreGeneral => generate_mstore_general(state, row)?,
};
new_registers_state.program_counter += match op {
state.registers.program_counter += match op {
Operation::Syscall(_) | Operation::ExitKernel => 0,
Operation::Push(n) => n as usize + 2,
Operation::Jump | Operation::Jumpi => 0,
_ => 1,
};
Ok(new_registers_state)
Ok(())
}
fn try_perform_instruction<F: Field>(
registers_state: RegistersState,
memory_state: &MemoryState,
traces: &mut Traces<F>,
) -> Result<RegistersState, ProgramError> {
fn try_perform_instruction<F: Field>(state: &mut GenerationState<F>) -> Result<(), ProgramError> {
let mut row: CpuColumnsView<F> = CpuColumnsView::default();
row.is_cpu_cycle = F::ONE;
let opcode = read_code_memory(registers_state, memory_state, traces, &mut row);
let op = decode(registers_state, opcode)?;
log::trace!("Executing {:?} at {}", op, registers_state.program_counter);
let opcode = read_code_memory(state, &mut row);
let op = decode(state.registers, opcode)?;
log::trace!("Executing {:?} at {}", op, state.registers.program_counter);
fill_op_flag(op, &mut row);
perform_op(op, registers_state, memory_state, traces, row)
perform_op(state, op, row)
}
fn handle_error<F: Field>(
_registers_state: RegistersState,
_memory_state: &MemoryState,
_traces: &mut Traces<F>,
) -> RegistersState {
todo!("constraints for exception handling are not implemented");
fn handle_error<F: Field>(_state: &mut GenerationState<F>) {
todo!("generation for exception handling is not implemented");
}
pub(crate) fn transition<F: Field>(
registers_state: RegistersState,
memory_state: &mut MemoryState,
traces: &mut Traces<F>,
) -> RegistersState {
let checkpoint = traces.checkpoint();
let result = try_perform_instruction(registers_state, memory_state, traces);
memory_state.apply_ops(traces.mem_ops_since(checkpoint));
pub(crate) fn transition<F: Field>(state: &mut GenerationState<F>) {
let checkpoint = state.checkpoint();
let result = try_perform_instruction(state);
match result {
Ok(new_registers_state) => new_registers_state,
Ok(()) => {
state
.memory
.apply_ops(state.traces.mem_ops_since(checkpoint.traces));
}
Err(_) => {
traces.rollback(checkpoint);
if registers_state.is_kernel {
state.rollback(checkpoint);
if state.registers.is_kernel {
panic!("exception in kernel mode");
}
handle_error(registers_state, memory_state, traces)
handle_error(state)
}
}
}

View File

@ -4,11 +4,10 @@ use plonky2::field::types::Field;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind, MemoryState};
use crate::witness::state::RegistersState;
use crate::witness::traces::Traces;
use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind};
fn to_byte_checked(n: U256) -> u8 {
let res = n.byte(0);
@ -24,33 +23,55 @@ fn to_bits_le<F: Field>(n: u8) -> [F; 8] {
res
}
pub(crate) fn mem_read_with_log<T: Copy>(
/// Peak at the stack item `i`th from the top. If `i=0` this gives the tip.
pub(crate) fn stack_peek<F: Field>(state: &GenerationState<F>, i: usize) -> Option<U256> {
if i >= state.registers.stack_len {
return None;
}
Some(state.memory.get(MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
state.registers.stack_len - 1 - i,
)))
}
pub(crate) fn mem_read_with_log<F: Field>(
channel: MemoryChannel,
address: MemoryAddress,
memory_state: &MemoryState,
traces: &Traces<T>,
state: &GenerationState<F>,
) -> (U256, MemoryOp) {
let val = memory_state.get(address);
let op = MemoryOp::new(channel, traces.clock(), address, MemoryOpKind::Read, val);
let val = state.memory.get(address);
let op = MemoryOp::new(
channel,
state.traces.clock(),
address,
MemoryOpKind::Read,
val,
);
(val, op)
}
pub(crate) fn mem_write_log<T: Copy>(
pub(crate) fn mem_write_log<F: Field>(
channel: MemoryChannel,
address: MemoryAddress,
traces: &Traces<T>,
state: &mut GenerationState<F>,
val: U256,
) -> MemoryOp {
MemoryOp::new(channel, traces.clock(), address, MemoryOpKind::Write, val)
MemoryOp::new(
channel,
state.traces.clock(),
address,
MemoryOpKind::Write,
val,
)
}
pub(crate) fn mem_read_code_with_log_and_fill<F: Field>(
address: MemoryAddress,
memory_state: &MemoryState,
traces: &Traces<F>,
state: &GenerationState<F>,
row: &mut CpuColumnsView<F>,
) -> (u8, MemoryOp) {
let (val, op) = mem_read_with_log(MemoryChannel::Code, address, memory_state, traces);
let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state);
let val_u8 = to_byte_checked(val);
row.opcode_bits = to_bits_le(val_u8);
@ -61,16 +82,10 @@ pub(crate) fn mem_read_code_with_log_and_fill<F: Field>(
pub(crate) fn mem_read_gp_with_log_and_fill<F: Field>(
n: usize,
address: MemoryAddress,
memory_state: &MemoryState,
traces: &Traces<F>,
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
) -> (U256, MemoryOp) {
let (val, op) = mem_read_with_log(
MemoryChannel::GeneralPurpose(n),
address,
memory_state,
traces,
);
let (val, op) = mem_read_with_log(MemoryChannel::GeneralPurpose(n), address, state);
let val_limbs: [u64; 4] = val.0;
let channel = &mut row.mem_channels[n];
@ -91,11 +106,11 @@ pub(crate) fn mem_read_gp_with_log_and_fill<F: Field>(
pub(crate) fn mem_write_gp_log_and_fill<F: Field>(
n: usize,
address: MemoryAddress,
traces: &Traces<F>,
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
val: U256,
) -> MemoryOp {
let op = mem_write_log(MemoryChannel::GeneralPurpose(n), address, traces, val);
let op = mem_write_log(MemoryChannel::GeneralPurpose(n), address, state, val);
let val_limbs: [u64; 4] = val.0;
let channel = &mut row.mem_channels[n];
@ -114,12 +129,10 @@ pub(crate) fn mem_write_gp_log_and_fill<F: Field>(
}
pub(crate) fn stack_pop_with_log_and_fill<const N: usize, F: Field>(
registers_state: &mut RegistersState,
memory_state: &MemoryState,
traces: &Traces<F>,
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
) -> Result<[(U256, MemoryOp); N], ProgramError> {
if (registers_state.stack_len as usize) < N {
if state.registers.stack_len < N {
return Err(ProgramError::StackUnderflow);
}
@ -127,39 +140,38 @@ pub(crate) fn stack_pop_with_log_and_fill<const N: usize, F: Field>(
let mut i = 0usize;
[(); N].map(|_| {
let address = MemoryAddress::new(
registers_state.context,
state.registers.effective_context(),
Segment::Stack,
registers_state.stack_len - 1 - i,
state.registers.stack_len - 1 - i,
);
let res = mem_read_gp_with_log_and_fill(i, address, memory_state, traces, row);
let res = mem_read_gp_with_log_and_fill(i, address, state, row);
i += 1;
res
})
};
registers_state.stack_len -= N;
state.registers.stack_len -= N;
Ok(result)
}
pub(crate) fn stack_push_log_and_fill<F: Field>(
registers_state: &mut RegistersState,
traces: &Traces<F>,
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
val: U256,
) -> Result<MemoryOp, ProgramError> {
if !registers_state.is_kernel && registers_state.stack_len >= MAX_USER_STACK_SIZE {
if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE {
return Err(ProgramError::StackOverflow);
}
let address = MemoryAddress::new(
registers_state.context,
state.registers.effective_context(),
Segment::Stack,
registers_state.stack_len,
state.registers.stack_len,
);
let res = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 1, address, traces, row, val);
let res = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 1, address, state, row, val);
registers_state.stack_len += 1;
state.registers.stack_len += 1;
Ok(res)
}