Witness generation work

This commit is contained in:
Jacqueline Nabaglo 2022-11-15 09:26:54 -08:00
parent 626c2583de
commit 205bd58f98
13 changed files with 514 additions and 28 deletions

View File

@ -53,7 +53,7 @@ const NATIVE_INSTRUCTIONS: [usize; 37] = [
// not SYSCALL (performs a jump)
];
fn get_halt_pcs<F: Field>() -> (F, F) {
pub(crate) fn get_halt_pcs<F: Field>() -> (F, F) {
let halt_pc0 = KERNEL.global_labels["halt_pc0"];
let halt_pc1 = KERNEL.global_labels["halt_pc1"];
@ -63,6 +63,12 @@ fn get_halt_pcs<F: Field>() -> (F, F) {
)
}
pub(crate) fn get_start_pc<F: Field>() -> F {
let start_pc = KERNEL.global_labels["main"];
F::from_canonical_usize(start_pc)
}
pub fn eval_packed_generic<P: PackedField>(
lv: &CpuColumnsView<P>,
nv: &CpuColumnsView<P>,
@ -89,8 +95,7 @@ pub fn eval_packed_generic<P: PackedField>(
// - execution is in kernel mode, and
// - the stack is empty.
let is_last_noncpu_cycle = (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle;
let pc_diff =
nv.program_counter - P::Scalar::from_canonical_usize(KERNEL.global_labels["main"]);
let pc_diff = nv.program_counter - get_start_pc::<P::Scalar>();
yield_constr.constraint_transition(is_last_noncpu_cycle * pc_diff);
yield_constr.constraint_transition(is_last_noncpu_cycle * (nv.is_kernel_mode - P::ONES));
yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len);
@ -142,9 +147,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle);
// Start at `main`.
let main = builder.constant_extension(F::Extension::from_canonical_usize(
KERNEL.global_labels["main"],
));
let main = builder.constant_extension(get_start_pc::<F>().into());
let pc_diff = builder.sub_extension(nv.program_counter, main);
let pc_constr = builder.mul_extension(is_last_noncpu_cycle, pc_diff);
yield_constr.constraint_transition(builder, pc_constr);

View File

@ -1,6 +1,6 @@
pub(crate) mod bootstrap_kernel;
pub(crate) mod columns;
mod control_flow;
pub(crate) mod control_flow;
pub mod cpu_stark;
pub(crate) mod decode;
mod dup_swap;

View File

@ -7,7 +7,17 @@ use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
fn limbs(x: U256) -> [u32; 8] {
let mut res = [0; 8];
let x_u64: &[u64; 4] = x.as_ref();
for i in 0..4 {
res[2 * i] = x_u64[i] as u32;
res[2 * i + 1] = (x_u64[i] >> 32) as u32;
}
res
}
pub fn generate_pinv_diff<F: RichField>(val0: U256, val1: U256, lv: &mut CpuColumnsView<F>) {
let input0 = lv.mem_channels[0].value;
let eq_filter = lv.op.eq.to_canonical_u64();

View File

@ -10,30 +10,13 @@ use crate::cpu::columns::CpuColumnsView;
const LIMB_SIZE: usize = 32;
const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let is_not_filter = lv.op.not.to_canonical_u64();
if is_not_filter == 0 {
return;
}
assert_eq!(is_not_filter, 1);
let input = lv.mem_channels[0].value;
let output = &mut lv.mem_channels[1].value;
for (input, output_ref) in input.into_iter().zip(output.iter_mut()) {
let input = input.to_canonical_u64();
assert_eq!(input >> LIMB_SIZE, 0);
let output = input ^ ALL_1_LIMB;
*output_ref = F::from_canonical_u64(output);
}
}
pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
// This is simple: just do output = 0xffffffff - input.
let input = lv.mem_channels[0].value;
let output = lv.mem_channels[1].value;
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.op.not;
let filter = cycle_filter * is_not_filter;
@ -50,7 +33,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let input = lv.mem_channels[0].value;
let output = lv.mem_channels[1].value;
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.op.not;
let filter = builder.mul_extension(cycle_filter, is_not_filter);

View File

@ -19,7 +19,7 @@ use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
const MAX_USER_STACK_SIZE: u64 = 1024;
pub const MAX_USER_STACK_SIZE: u64 = 1024;
// Below only includes the operations that pop the top of the stack **without reading the value from
// memory**, i.e. `POP`.

View File

@ -0,0 +1,8 @@
enum ProgramError {
OutOfGas,
InvalidOpcode,
StackUnderflow,
InvalidJumpDestination,
InvalidJumpiDestination,
StackOverflow,
}

12
evm/src/witness/mem_tx.rs Normal file
View File

@ -0,0 +1,12 @@
use crate::witness::memory::{MemoryOp, MemoryOpKind, MemoryState};
pub fn apply_mem_ops(state: &mut MemoryState, mut ops: Vec<MemoryOp>) {
ops.sort_unstable_by_key(|mem_op| mem_op.timestamp);
for op in ops {
let MemoryOp { address, op, .. } = op;
if let MemoryOpKind::Write(val) = op {
state.set(address, val);
}
}
}

83
evm/src/witness/memory.rs Normal file
View File

@ -0,0 +1,83 @@
use std::collections::HashMap;
use ethereum_types::U256;
use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS};
pub enum MemoryChannel {
Code,
GeneralPurpose(usize),
}
use MemoryChannel::{Code, GeneralPurpose};
impl MemoryChannel {
pub fn index(&self) -> usize {
match *self {
Code => 0,
GeneralPurpose(n) => {
assert!(n < NUM_GP_CHANNELS);
n + 1
}
}
}
}
pub type MemoryAddress = (u32, u32, u32);
pub enum MemoryOpKind {
Read,
Write(U256),
}
pub struct MemoryOp {
pub timestamp: u64,
pub address: MemoryAddress,
pub op: MemoryOpKind,
}
impl MemoryOp {
pub fn new(
channel: MemoryChannel,
clock: usize,
address: MemoryAddress,
op: MemoryOpKind,
) -> Self {
let timestamp = (clock * NUM_CHANNELS + channel.index()) as u64;
MemoryOp { timestamp, address, op }
}
}
#[derive(Clone)]
pub struct MemoryState {
contents: HashMap<MemoryAddress, U256>,
}
impl MemoryState {
pub fn new(kernel_code: &[u8]) -> Self {
let mut contents = HashMap::new();
for (i, &byte) in kernel_code.iter().enumerate() {
if byte != 0 {
let address = (0, 0, i as u32);
let val = byte.into();
contents.insert(address, val);
}
}
Self { contents }
}
pub fn get(&self, address: MemoryAddress) -> U256 {
self.contents.get(&address).copied().unwrap_or_else(U256::zero)
}
pub fn set(&mut self, address: MemoryAddress, val: U256) {
if val.is_zero() {
self.contents.remove(&address);
} else {
self.contents.insert(address, val);
}
}
}

4
evm/src/witness/mod.rs Normal file
View File

@ -0,0 +1,4 @@
mod mem_tx;
mod memory;
mod state;
mod traces;

View File

@ -0,0 +1,156 @@
use crate::cpu::kernel::aggregator::KERNEL;
enum Operation {
Dup(u8),
Swap(u8),
Iszero,
Not,
Jump(JumpOp),
Syscall(u8),
Eq,
ExitKernel,
BinaryLogic(BinaryLogicOp),
NotImplemented,
}
enum JumpOp {
Jump,
Jumpi,
}
enum BinaryLogicOp {
And,
Or,
Xor,
}
impl BinaryLogicOp {
fn result(&self, a: U256, b: U256) -> U256 {
match self {
BinaryLogicOp::And => a & b,
BinaryLogicOp::Or => a | b,
BinaryLogicOp::Xor => a ^ b,
}
}
}
fn make_logic_row<F>(op: BinaryLogicOp, in0: U256, in1: U256, result: U256) -> [F; logic::columns::NUM_COLUMNS] {
let mut row = [F::ZERO; logic::columns::NUM_COLUMNS];
row[match op {
BinaryLogicOp::And => logic::columns::IS_AND,
BinaryLogicOp::Or => logic::columns::IS_OR,
BinaryLogicOp::Xor => logic::columns::IS_XOR,
}] = F::ONE;
for i in 0..256 {
row[logic::columns::INPUT0[i]] = F::from_bool(in0.bit(i));
row[logic::columns::INPUT1[i]] = F::from_bool(in1.bit(i));
}
let result_limbs: &[u64] = result.as_ref();
for (i, &limb) in result_limbs.iter().enumerate() {
row[logic::columns::RESULT[2 * i]] = F::from_canonical_u32(limb as u32);
row[logic::columns::RESULT[2 * i + 1]] = F::from_canonical_u32((limb >> 32) as u32);
}
row
}
fn generate_binary_logic_op<F>(op: BinaryLogicOp, state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let ([in0, in1], logs_in) = state.pop_stack_with_log::<2>()?;
let result = op.result(in0, in1);
let log_out = state.push_stack_with_log(result)?;
traces.logic.append(make_logic_row(op, in0, in1, result));
traces.memory.extend(logs_in);
traces.memory.append(log_out);
}
fn generate_dup<F>(n: u8, state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let other_addr_lo = state.stack_len.sub_checked(1 + (n as usize)).ok_or(ProgramError::StackUnderflow)?;
let other_addr = (state.context, Segment::Stack as u32, other_addr_lo);
let (val, log_in) = state.mem_read_with_log(MemoryChannel::GeneralPurpose(0), other_addr);
let log_out = state.push_stack_with_log(val)?;
traces.memory.extend([log_in, log_out]);
}
fn generate_swap<F>(n: u8, state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let other_addr_lo = state.stack_len.sub_checked(2 + (n as usize)).ok_or(ProgramError::StackUnderflow)?;
let other_addr = (state.context, Segment::Stack as u32, other_addr_lo);
let ([in0], [log_in0]) = state.pop_stack_with_log::<1>()?;
let (in1, log_in1) = state.mem_read_with_log(MemoryChannel::GeneralPurpose(1), other_addr);
let log_out0 = state.mem_write_with_log(MemoryChannel::GeneralPurpose(NUM_GP_CHANNELS - 2), other_addr, in0);
let log_out1 = state.push_stack_with_log(in1)?;
traces.memory.extend([log_in0, log_in1, log_out0, log_out1]);
}
fn generate_not<F>(state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let ([x], [log_in]) = state.pop_stack_with_log::<1>()?;
let result = !x;
let log_out = state.push_stack_with_log(result)?;
traces.memory.append(log_in);
traces.memory.append(log_out);
}
fn generate_iszero<F>(state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let ([x], [log_in]) = state.pop_stack_with_log::<1>()?;
let is_zero = state.is_zero();
let result = is_zero.into::<u64>().into::<U256>();
let log_out = state.push_stack_with_log(result)?;
generate_pinv_diff(x, U256::zero(), row);
traces.memory.append(log_in);
traces.memory.append(log_out);
}
fn generate_jump<F>(op: JumpOp, state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
todo!();
}
fn generate_syscall<F>(opcode: u8, state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let handler_jumptable_addr = KERNEL.global_labels["syscall_jumptable"] as u32;
let handler_addr_addr = handler_jumptable_addr + (opcode as u32);
let (handler_addr0, in_log0) = state.mem_read_with_log(MemoryChannel::GeneralPurpose(0), (0, Segment::Code as u32, handler_addr_addr));
let (handler_addr1, in_log1) = state.mem_read_with_log(MemoryChannel::GeneralPurpose(1), (0, Segment::Code as u32, handler_addr_addr + 1));
let (handler_addr2, in_log2) = state.mem_read_with_log(MemoryChannel::GeneralPurpose(2), (0, Segment::Code as u32, handler_addr_addr + 2));
let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2;
let new_program_counter = handler_addr.as_u32();
let syscall_info = state.program_counter.into::<U256>() + (state.is_kernel_mode.into::<u64>.into::<U256> << 32);
let log_out = state.push_stack_with_log(syscall_info)?;
state.program_counter = new_program_counter;
state.is_kernel = true;
}
fn generate_eq<F>(state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let ([x0, x1], logs_in) = state.pop_stack_with_log::<1>()?;
let equal = x0 == x1;
let result = equal.into::<u64>().into::<U256>();
let log_out = state.push_stack_with_log(result)?;
generate_pinv_diff(x0, x1, row);
traces.memory.extend(logs_in);
traces.memory.append(log_out);
}
fn generate_exit_kernel<F>(state: &mut State, row: &mut CpuRow, traces: &mut Traces<T>) -> Result<(), ProgramError> {
let ([kexit_info], [log_in]) = state.pop_stack_with_log::<1>()?;
let kexit_info_u64: &[u64; 4] = kexit_info.as_ref();
let program_counter = kexit_info_u64[0] as u32;
let is_kernel_mode_val = (kexit_info_u64[1] >> 32) as u32
assert!(is_kernel_mode_val == 0 || is_kernel_mode_val == 1);
let is_kernel_mode = is_kernel_mode_val != 0;
state.program_counter = program_counter;
state.is_kernel = is_kernel_mode;
}

108
evm/src/witness/state.rs Normal file
View File

@ -0,0 +1,108 @@
use ethereum_types::U256;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind, MemoryState};
pub const KERNEL_CONTEXT: u32 = 0;
pub const MAX_USER_STACK_SIZE: u32 = crate::cpu::stack_bounds::MAX_USER_STACK_SIZE as u32;
#[derive(Clone)]
pub struct State {
pub clock: usize,
pub program_counter: u32,
pub is_kernel: bool,
pub stack_len: u32,
pub context: u32,
pub memory: MemoryState,
}
impl State {
pub fn initial(clock: usize, memory: MemoryState) -> Self {
Self {
clock,
program_counter: KERNEL.global_labels["main"] as u32,
is_kernel: true,
stack_len: 0,
context: KERNEL_CONTEXT,
memory: memory,
}
}
pub fn is_terminal(&self) -> bool {
self.is_kernel && [
KERNEL.global_labels["halt_pc0"] as u32,
KERNEL.global_labels["halt_pc1"] as u32,
].contains(&self.program_counter)
}
pub fn mem_write_log(
&self,
channel: MemoryChannel,
address: MemoryAddress,
val: U256,
) -> MemoryOp {
MemoryOp::new(channel, self.clock, address, MemoryOpKind::Write(val))
}
pub fn mem_write_with_log(
&mut self,
channel: MemoryChannel,
address: MemoryAddress,
val: U256,
) -> MemoryOp {
(self.memory.set(address, val), self.mem_write_log(channel, address, val))
}
pub fn mem_read_log(
&self,
channel: MemoryChannel,
address: MemoryAddress,
) -> MemoryOp {
MemoryOp::new(channel, self.clock, address, MemoryOpKind::Read)
}
pub fn mem_read_with_log(
&self,
channel: MemoryChannel,
address: MemoryAddress,
) -> (U256, MemoryOp) {
(self.memory.get(address), self.mem_read_log(channel, address))
}
pub fn pop_stack_with_log<const N: usize>(&mut self) -> Result<[(U256, MemoryOp); N], ProgramError> {
if stack_len < N {
return Err(ProgramError::StackUnderflow);
}
let mut result = [U256::default(); N];
for i in 0..N {
let channel = GeneralPurpose(i);
let address = (self.context, Segment::Stack as u32, self.stack_len - 1 - i);
result[i] = self.mem_read_with_log(channel, address);
}
self.stack_len -= N;
Ok(result)
}
pub fn push_stack_with_log(&mut self, val: U256) -> Result<MemoryOp, ProgramError> {
if !self.is_kernel_mode {
assert!(self.stack_len <= MAX_USER_STACK_SIZE);
if self.stack_len == MAX_USER_STACK_SIZE {
return Err(ProgramError::StackOverflow);
}
}
let channel = GeneralPurpose(NUM_GP_CHANNELS - 1);
let address = (self.context, Segment::Stack as u32, self.stack_len);
let result = self.mem_write_with_log(channel, address, val);
self.stack_len += 1;
Ok(result)
}
}

32
evm/src/witness/traces.rs Normal file
View File

@ -0,0 +1,32 @@
use crate::arithmetic::columns::NUM_ARITH_COLUMNS;
use crate::cpu::columns::CpuColumnsView;
use crate::logic;
use crate::witness::memory::MemoryOp;
type LogicRow<T> = [T; logic::columns::NUM_COLUMNS];
type ArithmeticRow<T> = [T; NUM_ARITH_COLUMNS];
struct Traces<T: Copy> {
pub cpu: Vec<CpuColumnsView<T>>,
pub logic: Vec<LogicRow<T>>,
pub arithmetic: Vec<ArithmeticRow<T>>,
pub memory: Vec<MemoryOp>,
}
impl<T: Copy> Traces<T> {
pub fn new() -> Self {
Traces {
cpu: vec![],
logic: vec![],
arithmetic: vec![],
memory: vec![],
}
}
pub fn append(&mut self, other: &mut Self) {
self.cpu.append(&mut other.cpu);
self.logic.append(&mut other.logic);
self.arithmetic.append(&mut other.arithmetic);
self.memory.append(&mut other.memory);
}
}

View File

@ -0,0 +1,87 @@
fn to_byte_checked(n: U256) -> u8 {
let res = n.byte(0);
assert_eq!(n, res.into());
res
}
fn to_bits<F: Field>(n: u8) -> [F; 8] {
let mut res = [F::ZERO; 8];
for (i, bit) in res.iter_mut().enumerate() {
*bit = F::from_bool(n & (1 << i) != 0);
}
res
}
fn decode(state: &State, row: &mut CpuRow) -> (Operation, MemoryLog) {
let code_context = if state.is_kernel {
KERNEL_CONTEXT
} else {
state.context
};
row.code_context = F::from_canonical_u32(code_context);
let address = (context, Segment::Code as u32, state.program_counter);
let mem_contents, mem_log = state.mem_read_with_log(address);
let opcode = to_byte_checked(mem_contents);
row.opcode_bits = to_bits(address);
let operation = match opcode {
0x01 => Operation::NotImplemented,
0x02 => Operation::NotImplemented,
0x03 => Operation::NotImplemented,
0x04 => Operation::NotImplemented,
0x06 => Operation::NotImplemented,
0x08 => Operation::NotImplemented,
0x09 => Operation::NotImplemented,
0x0c => Operation::NotImplemented,
0x0d => Operation::NotImplemented,
0x0e => Operation::NotImplemented,
0x10 => Operation::NotImplemented,
0x11 => Operation::NotImplemented,
0x14 => Operation::Eq,
0x15 => Operation::Iszero,
0x16 => Operation::BinaryLogic(BinaryLogicOp::And),
0x17 => Operation::BinaryLogic(BinaryLogicOp::Or),
0x18 => Operation::BinaryLogic(BinaryLogicOp::Xor),
0x19 => Operation::Not,
0x1a => Operation::Byte,
0x1b => Operation::NotImplemented,
0x1c => Operation::NotImplemented,
0x21 => Operation::NotImplemented,
0x49 => Operation::NotImplemented,
0x50 => Operation::NotImplemented,
0x56 => Operation::Jump,
0x57 => Operation::Jumpi,
0x58 => Operation::NotImplemented,
0x5a => Operation::NotImplemented,
0x5b => Operation::NotImplemented,
0x5c => Operation::NotImplemented,
0x5d => Operation::NotImplemented,
0x5e => Operation::NotImplemented,
0x5f => Operation::NotImplemented,
0x60..0x7f => Operation::NotImplemented,
0x80..0x8f => Operation::Dup(opcode & 0xf),
0x90..0x9f => Operation::Swap(opcode & 0xf),
0xf6 => Operation::NotImplemented,
0xf7 => Operation::NotImplemented,
0xf8 => Operation::NotImplemented,
0xf9 => Operation::ExitKernel,
0xfb => Operation::NotImplemented,
0xfc => Operation::NotImplemented,
_ => Operation::Syscall,
}
}
fn op_result() {
match op {
}
}
pub fn transition<F: Field>(state: State) -> (State, Traces<T>) {
let mut current_row: CpuColumnsView<F> = [F::ZERO; NUM_CPU_COLUMNS].into();
current_row.is_cpu_cycle = F::ONE;
let (op, code_mem_log) = decode(&state, &mut current_row);
}