Merge pull request #835 from mir-protocol/gen_fixes

Misc witness generation fixes
This commit is contained in:
Daniel Lubarov 2022-12-03 22:58:59 -08:00 committed by GitHub
commit 7d0ba54e40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 142 additions and 55 deletions

View File

@ -80,7 +80,7 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
let mut ctls = vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_sponge()];
// TODO: Some CTLs temporarily disabled while we get them working.
disable_ctl(&mut ctls[0]);
disable_ctl(&mut ctls[1]); // Enable once we populate logic log in keccak_sponge_log.
disable_ctl(&mut ctls[1]);
disable_ctl(&mut ctls[2]);
disable_ctl(&mut ctls[3]);
ctls

View File

@ -33,11 +33,23 @@ pub(crate) enum BinaryOperator {
impl BinaryOperator {
pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 {
match self {
BinaryOperator::Add => input0 + input1,
BinaryOperator::Mul => input0 * input1,
BinaryOperator::Sub => input0 - input1,
BinaryOperator::Div => input0 / input1,
BinaryOperator::Mod => input0 % input1,
BinaryOperator::Add => input0.overflowing_add(input1).0,
BinaryOperator::Mul => input0.overflowing_mul(input1).0,
BinaryOperator::Sub => input0.overflowing_sub(input1).0,
BinaryOperator::Div => {
if input1.is_zero() {
U256::zero()
} else {
input0 / input1
}
}
BinaryOperator::Mod => {
if input1.is_zero() {
U256::zero()
} else {
input0 % input1
}
}
BinaryOperator::Lt => {
if input0 < input1 {
U256::one()
@ -52,8 +64,20 @@ impl BinaryOperator {
U256::zero()
}
}
BinaryOperator::Shl => input0 << input1,
BinaryOperator::Shr => input0 >> input1,
BinaryOperator::Shl => {
if input0 > 255.into() {
U256::zero()
} else {
input1 << input0
}
}
BinaryOperator::Shr => {
if input0 > 255.into() {
U256::zero()
} else {
input1 >> input0
}
}
BinaryOperator::AddFp254 => addmod(input0, input1, bn_base_order()),
BinaryOperator::MulFp254 => mulmod(input0, input1, bn_base_order()),
BinaryOperator::SubFp254 => submod(input0, input1, bn_base_order()),

View File

@ -490,12 +490,10 @@ impl<'a> Interpreter<'a> {
fn run_byte(&mut self) {
let i = self.pop();
let x = self.pop();
let result = if i > 32.into() {
0
let result = if i < 32.into() {
x.byte(31 - i.as_usize())
} else {
let mut bytes = [0; 32];
x.to_big_endian(&mut bytes);
bytes[i.as_usize()]
0
};
self.push(result.into());
}

View File

@ -14,7 +14,7 @@ pub(crate) fn eval_packed<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let is_shift = lv.op.shl + lv.op.shr;
let displacement = lv.mem_channels[1]; // holds the shift displacement d
let displacement = lv.mem_channels[0]; // holds the shift displacement d
let two_exp = lv.mem_channels[2]; // holds 2^d
// Not needed here; val is the input and we're verifying that output is
@ -65,7 +65,7 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let is_shift = builder.add_extension(lv.op.shl, lv.op.shr);
let displacement = lv.mem_channels[1];
let displacement = lv.mem_channels[0];
let two_exp = lv.mem_channels[2];
let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64);

View File

@ -24,7 +24,7 @@ use crate::proof::{BlockMetadata, PublicValues, TrieRoots};
use crate::witness::memory::MemoryAddress;
use crate::witness::transition::transition;
pub(crate) mod mpt;
pub mod mpt;
pub(crate) mod prover_input;
pub(crate) mod rlp;
pub(crate) mod state;
@ -74,6 +74,11 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
timed!(timing, "simulate CPU", simulate_cpu(&mut state));
log::info!(
"Trace lengths (before padding): {:?}",
state.traces.checkpoint()
);
let read_metadata = |field| {
state.memory.get(MemoryAddress::new(
0,

View File

@ -9,11 +9,11 @@ use crate::cpu::kernel::constants::trie_type::PartialTrieType;
use crate::generation::TrieInputs;
#[derive(RlpEncodable, RlpDecodable, Debug)]
pub(crate) struct AccountRlp {
pub(crate) nonce: U256,
pub(crate) balance: U256,
pub(crate) storage_root: H256,
pub(crate) code_hash: H256,
pub struct AccountRlp {
pub nonce: U256,
pub balance: U256,
pub storage_root: H256,
pub code_hash: H256,
}
pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec<U256> {

View File

@ -80,7 +80,7 @@ fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize {
} else if curr.address.virt != next.address.virt {
next.address.virt - curr.address.virt - 1
} else {
next.timestamp - curr.timestamp - 1
next.timestamp - curr.timestamp
}
})
.max()
@ -124,7 +124,7 @@ pub fn generate_first_change_flags_and_rc<F: RichField>(trace_rows: &mut [[F; NU
} else if virtual_first_change {
next_virt - virt - F::ONE
} else {
next_timestamp - timestamp - F::ONE
next_timestamp - timestamp
};
}
}
@ -283,7 +283,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let computed_range_check = context_first_change * (next_addr_context - addr_context - one)
+ segment_first_change * (next_addr_segment - addr_segment - one)
+ virtual_first_change * (next_addr_virtual - addr_virtual - one)
+ address_unchanged * (next_timestamp - timestamp - one);
+ address_unchanged * (next_timestamp - timestamp);
yield_constr.constraint_transition(range_check - computed_range_check);
// Enumerate purportedly-ordered log.
@ -394,10 +394,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(diff, one)
};
let virtual_range_check = builder.mul_extension(virtual_first_change, virtual_diff);
let timestamp_diff = {
let diff = builder.sub_extension(next_timestamp, timestamp);
builder.sub_extension(diff, one)
};
let timestamp_diff = builder.sub_extension(next_timestamp, timestamp);
let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff);
let computed_range_check = {

View File

@ -48,6 +48,10 @@ impl MemoryAddress {
virt: u256_saturating_cast_usize(virt),
}
}
pub(crate) fn increment(&mut self) {
self.virt = self.virt.saturating_add(1);
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]

View File

@ -72,8 +72,25 @@ pub(crate) fn generate_binary_arithmetic_op<F: Field>(
let [(input0, log_in0), (input1, log_in1)] =
stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let operation = arithmetic::Operation::binary(operator, input0, input1);
let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?;
if operator == arithmetic::BinaryOperator::Shl || operator == arithmetic::BinaryOperator::Shr {
const LOOKUP_CHANNEL: usize = 2;
let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize);
if input0.bits() <= 32 {
let (_, read) =
mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row);
state.traces.push_memory(read);
} else {
// The shift constraints still expect the address to be set, even though no read will occur.
let mut channel = &mut row.mem_channels[LOOKUP_CHANNEL];
channel.addr_context = F::from_canonical_usize(lookup_addr.context);
channel.addr_segment = F::from_canonical_usize(lookup_addr.segment);
channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt);
}
}
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
@ -121,6 +138,7 @@ pub(crate) fn generate_keccak_general<F: Field>(
val.as_u32() as u8
})
.collect_vec();
log::debug!("Hashing {:?}", input);
let hash = keccak(&input);
let log_push = stack_push_log_and_fill(state, &mut row, hash.into_uint())?;

View File

@ -18,6 +18,8 @@ use crate::{arithmetic, keccak, logic};
#[derive(Clone, Copy, Debug)]
pub struct TraceCheckpoint {
pub(self) cpu_len: usize,
pub(self) keccak_len: usize,
pub(self) keccak_sponge_len: usize,
pub(self) logic_len: usize,
pub(self) arithmetic_len: usize,
pub(self) memory_len: usize,
@ -48,19 +50,22 @@ impl<T: Copy> Traces<T> {
pub fn checkpoint(&self) -> TraceCheckpoint {
TraceCheckpoint {
cpu_len: self.cpu.len(),
keccak_len: self.keccak_inputs.len(),
keccak_sponge_len: self.keccak_sponge_ops.len(),
logic_len: self.logic_ops.len(),
arithmetic_len: self.arithmetic.len(),
memory_len: self.memory_ops.len(),
// TODO others
}
}
pub fn rollback(&mut self, checkpoint: TraceCheckpoint) {
self.cpu.truncate(checkpoint.cpu_len);
self.keccak_inputs.truncate(checkpoint.keccak_len);
self.keccak_sponge_ops
.truncate(checkpoint.keccak_sponge_len);
self.logic_ops.truncate(checkpoint.logic_len);
self.arithmetic.truncate(checkpoint.arithmetic_len);
self.memory_ops.truncate(checkpoint.memory_len);
// TODO others
}
pub fn mem_ops_since(&self, checkpoint: TraceCheckpoint) -> &[MemoryOp] {

View File

@ -219,6 +219,12 @@ fn perform_op<F: Field>(
_ => 1,
};
if let Some(label) = KERNEL.offset_label(state.registers.program_counter) {
if !label.starts_with("halt_pc") {
log::debug!("At {label}");
}
}
Ok(())
}

View File

@ -3,11 +3,12 @@ use plonky2::field::types::Field;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::keccak_util::keccakf_u8s;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS};
use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE;
use crate::generation::state::GenerationState;
use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_WIDTH_BYTES};
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
use crate::logic;
use crate::memory::segments::Segment;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind};
@ -174,39 +175,76 @@ pub(crate) fn stack_push_log_and_fill<F: Field>(
Ok(res)
}
fn xor_into_sponge<F: Field>(
state: &mut GenerationState<F>,
sponge_state: &mut [u8; KECCAK_WIDTH_BYTES],
block: &[u8; KECCAK_RATE_BYTES],
) {
for i in (0..KECCAK_RATE_BYTES).step_by(32) {
let range = i..KECCAK_RATE_BYTES.min(i + 32);
let lhs = U256::from_little_endian(&sponge_state[range.clone()]);
let rhs = U256::from_little_endian(&block[range]);
state
.traces
.push_logic(logic::Operation::new(logic::Op::Xor, lhs, rhs));
}
for i in 0..KECCAK_RATE_BYTES {
sponge_state[i] ^= block[i];
}
}
pub(crate) fn keccak_sponge_log<F: Field>(
state: &mut GenerationState<F>,
base_address: MemoryAddress,
input: Vec<u8>,
) {
let clock = state.traces.clock();
let mut address = base_address;
let mut input_blocks = input.chunks_exact(KECCAK_RATE_BYTES);
let mut sponge_state = [0u8; KECCAK_WIDTH_BYTES];
for block in input_blocks.by_ref() {
sponge_state[..KECCAK_RATE_BYTES].copy_from_slice(block);
for &byte in block {
state.traces.push_memory(MemoryOp::new(
MemoryChannel::Code,
clock,
address,
MemoryOpKind::Read,
byte.into(),
));
address.increment();
}
xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap());
state.traces.push_keccak_bytes(sponge_state);
// TODO: Also push logic rows for XORs.
// TODO: Also push memory read rows.
keccakf_u8s(&mut sponge_state);
}
let final_inputs = input_blocks.remainder();
sponge_state[..final_inputs.len()].copy_from_slice(final_inputs);
// pad10*1 rule
sponge_state[final_inputs.len()..KECCAK_RATE_BYTES].fill(0);
if final_inputs.len() == KECCAK_RATE_BYTES - 1 {
// Both 1s are placed in the same byte.
sponge_state[final_inputs.len()] = 0b10000001;
} else {
sponge_state[final_inputs.len()] = 1;
sponge_state[KECCAK_RATE_BYTES - 1] = 0b10000000;
for &byte in input_blocks.remainder() {
state.traces.push_memory(MemoryOp::new(
MemoryChannel::Code,
clock,
address,
MemoryOpKind::Read,
byte.into(),
));
address.increment();
}
let mut final_block = [0u8; KECCAK_RATE_BYTES];
final_block[..input_blocks.remainder().len()].copy_from_slice(input_blocks.remainder());
// pad10*1 rule
if input_blocks.remainder().len() == KECCAK_RATE_BYTES - 1 {
// Both 1s are placed in the same byte.
final_block[input_blocks.remainder().len()] = 0b10000001;
} else {
final_block[input_blocks.remainder().len()] = 1;
final_block[KECCAK_RATE_BYTES - 1] = 0b10000000;
}
xor_into_sponge(state, &mut sponge_state, &final_block);
state.traces.push_keccak_bytes(sponge_state);
// TODO: Also push logic rows for XORs.
// TODO: Also push memory read rows.
state.traces.push_keccak_sponge(KeccakSpongeOp {
base_address,
timestamp: state.traces.clock(),
timestamp: clock * NUM_CHANNELS,
input,
});
}

View File

@ -26,14 +26,6 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
let block_metadata = BlockMetadata::default();
// TODO: This trie isn't working yet.
// let state_trie = PartialTrie::Leaf {
// nibbles: Nibbles {
// count: 5,
// packed: 0xABCDE.into(),
// },
// value: vec![1, 2, 3],
// };
let state_trie = PartialTrie::Empty;
let transactions_trie = PartialTrie::Empty;
let receipts_trie = PartialTrie::Empty;