Add a prove_with_outputs method

It returns information about the post-state after execution, which is useful for debugging purposes.
This commit is contained in:
Daniel Lubarov 2023-03-16 13:32:34 -07:00
parent eebdd02972
commit c8d591f6da
8 changed files with 223 additions and 53 deletions

View File

@ -8,6 +8,8 @@ global get_create_address:
// TODO: Replace with actual implementation.
%pop2
PUSH 123
// stack: address, retdest
%observe_new_address
SWAP1
JUMP
@ -21,5 +23,22 @@ global get_create2_address:
// TODO: Replace with actual implementation.
%pop3
PUSH 123
// stack: address, retdest
%observe_new_address
SWAP1
JUMP
// This should be called whenever a new address is created. It exists only for debugging: the
// routine itself does nothing, but it provides a single hook where code can react to newly
// created addresses.
global observe_new_address:
// stack: address, retdest
SWAP1
// stack: retdest, address
// Return to retdest, leaving address on the stack.
JUMP
// Convenience macro to call observe_new_address and return where we left off.
// Pre stack: address
// Post stack: address
%macro observe_new_address
// Place the local %%after label beneath address, to serve as the return destination.
%stack (address) -> (address, %%after)
%jump(observe_new_address)
%%after:
%endmacro

View File

@ -18,25 +18,24 @@ use crate::config::StarkConfig;
use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata::StateTrieRoot;
use crate::generation::mpt::AccountRlp;
use crate::generation::outputs::{get_outputs, GenerationOutputs};
use crate::generation::state::GenerationState;
use crate::generation::trie_extractor::read_state_trie_value;
use crate::memory::segments::Segment;
use crate::proof::{BlockMetadata, PublicValues, TrieRoots};
use crate::witness::memory::{MemoryAddress, MemoryChannel};
use crate::witness::transition::transition;
pub mod mpt;
pub mod outputs;
pub(crate) mod prover_input;
pub(crate) mod rlp;
pub(crate) mod state;
mod trie_extractor;
use crate::generation::trie_extractor::read_trie;
use crate::witness::util::mem_write_log;
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
/// Inputs needed for trace generation.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct GenerationInputs {
pub signed_txns: Vec<Vec<u8>>,
@ -104,7 +103,11 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
inputs: GenerationInputs,
config: &StarkConfig,
timing: &mut TimingTree,
) -> anyhow::Result<([Vec<PolynomialValues<F>>; NUM_TABLES], PublicValues)> {
) -> anyhow::Result<(
[Vec<PolynomialValues<F>>; NUM_TABLES],
PublicValues,
GenerationOutputs,
)> {
let mut state = GenerationState::<F>::new(inputs.clone(), &KERNEL.code);
apply_metadata_memops(&mut state, &inputs.block_metadata);
@ -118,23 +121,9 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
state.traces.checkpoint()
);
let read_metadata = |field| {
state.memory.get(MemoryAddress::new(
0,
Segment::GlobalMetadata,
field as usize,
))
};
log::debug!(
"Updated state trie:\n{:#?}",
read_trie::<F, AccountRlp, D>(
&state.memory,
read_metadata(StateTrieRoot).as_usize(),
read_state_trie_value
)
);
let outputs = get_outputs(&mut state);
let read_metadata = |field| state.memory.read_global_metadata(field);
let trie_roots_before = TrieRoots {
state_root: H256::from_uint(&read_metadata(StateTrieRootDigestBefore)),
transactions_root: H256::from_uint(&read_metadata(TransactionTrieRootDigestBefore)),
@ -157,7 +146,7 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
"convert trace data to tables",
state.traces.into_tables(all_stark, config, timing)
);
Ok((tables, public_values))
Ok((tables, public_values, outputs))
}
fn simulate_cpu<F: RichField + Extendable<D>, const D: usize>(

View File

@ -0,0 +1,99 @@
use std::collections::HashMap;
use ethereum_types::{Address, BigEndianHash, H256, U256};
use plonky2::field::types::Field;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata::StateTrieRoot;
use crate::generation::state::GenerationState;
use crate::generation::trie_extractor::{
read_state_trie_value, read_storage_trie_value, read_trie, AccountTrieRecord,
};
/// The post-state after trace generation; intended for debugging.
#[derive(Clone, Debug)]
pub struct GenerationOutputs {
/// Every account read from the final state trie, keyed by its address when the address is
/// known and by its state key otherwise.
pub accounts: HashMap<AddressOrStateKey, AccountOutput>,
}
/// Identifies an account in `GenerationOutputs`: by address when the state key could be mapped
/// back to an observed address, and by the state key itself otherwise.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub enum AddressOrStateKey {
/// The account's address, recovered from the known state-key-to-address mapping.
Address(Address),
/// The raw state key, used when no address is known for it.
StateKey(H256),
}
/// The post-state of a single account; intended for debugging.
#[derive(Clone, Debug)]
pub struct AccountOutput {
/// Balance as read from the account's state trie record.
pub balance: U256,
/// Nonce as read from the account's state trie record.
pub nonce: u64,
/// Contract code, looked up in the generation inputs by the account's code hash.
pub code: Vec<u8>,
/// The account's storage trie, flattened into a map from storage key to value.
pub storage: HashMap<U256, U256>,
}
pub(crate) fn get_outputs<F: Field>(state: &mut GenerationState<F>) -> GenerationOutputs {
let account_map = read_trie::<AccountTrieRecord>(
&state.memory,
state.memory.read_global_metadata(StateTrieRoot).as_usize(),
read_state_trie_value,
);
let accounts = account_map
.into_iter()
.map(|(state_key_nibbles, account)| {
assert_eq!(
state_key_nibbles.count, 64,
"Each state key should have 64 nibbles = 256 bits"
);
let state_key_h256 = H256::from_uint(&state_key_nibbles.packed);
let addr_or_state_key =
if let Some(address) = state.state_key_to_address.get(&state_key_h256) {
AddressOrStateKey::Address(*address)
} else {
AddressOrStateKey::StateKey(state_key_h256)
};
let account_output = account_trie_record_to_output(state, account);
(addr_or_state_key, account_output)
})
.collect();
GenerationOutputs { accounts }
}
fn account_trie_record_to_output<F: Field>(
state: &GenerationState<F>,
account: AccountTrieRecord,
) -> AccountOutput {
let storage = get_storage(state, account.storage_ptr);
// TODO: This won't work if the account was created during the txn.
// Need to track changes to code, similar to how we track addresses
// with observe_new_address.
let code = state
.inputs
.contract_code
.get(&account.code_hash)
.expect("Code not found")
.clone();
AccountOutput {
balance: account.balance,
nonce: account.nonce,
storage,
code,
}
}
/// Reads the storage trie rooted at `storage_ptr` into a map from storage key to value.
///
/// # Panics
///
/// Panics if any storage key does not have exactly 64 nibbles.
fn get_storage<F: Field>(state: &GenerationState<F>, storage_ptr: usize) -> HashMap<U256, U256> {
    let entries = read_trie::<U256>(&state.memory, storage_ptr, read_storage_trie_value);

    let mut storage = HashMap::with_capacity(entries.len());
    for (key_nibbles, value) in entries {
        assert_eq!(
            key_nibbles.count, 64,
            "Each storage key should have 64 nibbles = 256 bits"
        );
        storage.insert(key_nibbles.packed, value);
    }
    storage
}

View File

@ -1,12 +1,17 @@
use ethereum_types::U256;
use std::collections::HashMap;
use ethereum_types::{Address, H160, H256, U256};
use keccak_hash::keccak;
use plonky2::field::types::Field;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::generation::mpt::all_mpt_prover_inputs_reversed;
use crate::generation::rlp::all_rlp_prover_inputs_reversed;
use crate::generation::GenerationInputs;
use crate::witness::memory::MemoryState;
use crate::witness::state::RegistersState;
use crate::witness::traces::{TraceCheckpoint, Traces};
use crate::witness::util::stack_peek;
pub(crate) struct GenerationStateCheckpoint {
pub(crate) registers: RegistersState,
@ -29,6 +34,11 @@ pub(crate) struct GenerationState<F: Field> {
/// Prover inputs containing RLP data, in reverse order so that the next input can be obtained
/// via `pop()`.
pub(crate) rlp_prover_inputs: Vec<U256>,
/// The state trie only stores state keys, which are hashes of addresses, but sometimes it is
/// useful to see the actual addresses for debugging. Here we store the mapping for all known
/// addresses.
pub(crate) state_key_to_address: HashMap<H256, Address>,
}
impl<F: Field> GenerationState<F> {
@ -53,9 +63,29 @@ impl<F: Field> GenerationState<F> {
next_txn_index: 0,
mpt_prover_inputs,
rlp_prover_inputs,
state_key_to_address: HashMap::new(),
}
}
/// Updates `program_counter`, and potentially adds some extra handling if we're jumping to a
/// special location.
pub fn jump_to(&mut self, dst: usize) {
self.registers.program_counter = dst;
if dst == KERNEL.global_labels["observe_new_address"] {
let address = stack_peek(self, 0).expect("Empty stack");
let mut address_bytes = [0; 20];
address.to_big_endian(&mut address_bytes);
self.observe_address(H160(address_bytes));
}
}
/// Records `address` under its state key (the Keccak hash of the address bytes), so that the
/// state key can later be mapped back to the address. This is just for debugging purposes.
pub fn observe_address(&mut self, address: Address) {
    self.state_key_to_address
        .insert(keccak(address.0), address);
}
pub fn checkpoint(&self) -> GenerationStateCheckpoint {
GenerationStateCheckpoint {
registers: self.registers,

View File

@ -1,50 +1,56 @@
//! Code for extracting trie data after witness generation. This is intended only for debugging.
use std::collections::HashMap;
use eth_trie_utils::partial_trie::Nibbles;
use ethereum_types::{BigEndianHash, H256, U256};
use plonky2::field::extension::Extendable;
use plonky2::hash::hash_types::RichField;
use crate::cpu::kernel::constants::trie_type::PartialTrieType;
use crate::generation::mpt::AccountRlp;
use crate::memory::segments::Segment;
use crate::witness::memory::{MemoryAddress, MemoryState};
pub(crate) fn read_state_trie_value(slice: &[U256]) -> AccountRlp {
AccountRlp {
nonce: slice[0],
/// Account data as it's stored in the state trie, with a pointer to the storage trie.
pub(crate) struct AccountTrieRecord {
/// The account's nonce.
pub(crate) nonce: u64,
/// The account's balance.
pub(crate) balance: U256,
/// Pointer to the root of the account's storage trie in trie data memory.
pub(crate) storage_ptr: usize,
/// Hash identifying the account's contract code.
pub(crate) code_hash: H256,
}
pub(crate) fn read_state_trie_value(slice: &[U256]) -> AccountTrieRecord {
AccountTrieRecord {
nonce: slice[0].as_u64(),
balance: slice[1],
storage_root: H256::from_uint(&slice[2]),
storage_ptr: slice[2].as_usize(),
code_hash: H256::from_uint(&slice[3]),
}
}
pub(crate) fn read_trie<F, V, const D: usize>(
/// Reads a storage trie value, which is the first word of `slice`.
///
/// # Panics
///
/// Panics if `slice` is empty.
pub(crate) fn read_storage_trie_value(slice: &[U256]) -> U256 {
slice[0]
}
pub(crate) fn read_trie<V>(
memory: &MemoryState,
ptr: usize,
read_value: fn(&[U256]) -> V,
) -> HashMap<Nibbles, V>
where
F: RichField + Extendable<D>,
{
) -> HashMap<Nibbles, V> {
let mut res = HashMap::new();
let empty_nibbles = Nibbles {
count: 0,
packed: U256::zero(),
};
read_trie_helper::<F, V, D>(memory, ptr, read_value, empty_nibbles, &mut res);
read_trie_helper::<V>(memory, ptr, read_value, empty_nibbles, &mut res);
res
}
pub(crate) fn read_trie_helper<F, V, const D: usize>(
pub(crate) fn read_trie_helper<V>(
memory: &MemoryState,
ptr: usize,
read_value: fn(&[U256]) -> V,
prefix: Nibbles,
res: &mut HashMap<Nibbles, V>,
) where
F: RichField + Extendable<D>,
{
) {
let load = |offset| memory.get(MemoryAddress::new(0, Segment::TrieData, offset));
let load_slice_from = |init_offset| {
&memory.contexts[0].segments[Segment::TrieData as usize].content[init_offset..]
@ -58,13 +64,7 @@ pub(crate) fn read_trie_helper<F, V, const D: usize>(
let ptr_payload = ptr + 1;
for i in 0u8..16 {
let child_ptr = load(ptr_payload + i as usize).as_usize();
read_trie_helper::<F, V, D>(
memory,
child_ptr,
read_value,
prefix.merge_nibble(i),
res,
);
read_trie_helper::<V>(memory, child_ptr, read_value, prefix.merge_nibble(i), res);
}
let value_ptr = load(ptr_payload + 16).as_usize();
if value_ptr != 0 {
@ -76,7 +76,7 @@ pub(crate) fn read_trie_helper<F, V, const D: usize>(
let packed = load(ptr + 2);
let nibbles = Nibbles { count, packed };
let child_ptr = load(ptr + 3).as_usize();
read_trie_helper::<F, V, D>(
read_trie_helper::<V>(
memory,
child_ptr,
read_value,

View File

@ -25,6 +25,7 @@ use crate::constraint_consumer::ConstraintConsumer;
use crate::cpu::cpu_stark::CpuStark;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData};
use crate::generation::outputs::GenerationOutputs;
use crate::generation::{generate_traces, GenerationInputs};
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
@ -46,6 +47,28 @@ pub fn prove<F, C, const D: usize>(
inputs: GenerationInputs,
timing: &mut TimingTree,
) -> Result<AllProof<F, C, D>>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
[(); C::Hasher::HASH_SIZE]:,
[(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
{
let (proof, _outputs) = prove_with_outputs(all_stark, config, inputs, timing)?;
Ok(proof)
}
/// Generate traces, then create all STARK proofs. Returns information about the post-state,
/// intended for debugging, in addition to the proof.
pub fn prove_with_outputs<F, C, const D: usize>(
all_stark: &AllStark<F, D>,
config: &StarkConfig,
inputs: GenerationInputs,
timing: &mut TimingTree,
) -> Result<(AllProof<F, C, D>, GenerationOutputs)>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
@ -57,12 +80,13 @@ where
[(); MemoryStark::<F, D>::COLUMNS]:,
{
timed!(timing, "build kernel", Lazy::force(&KERNEL));
let (traces, public_values) = timed!(
let (traces, public_values, outputs) = timed!(
timing,
"generate all traces",
generate_traces(all_stark, inputs, config, timing)?
);
prove_with_traces(all_stark, config, traces, public_values, timing)
let proof = prove_with_traces(all_stark, config, traces, public_values, timing)?;
Ok((proof, outputs))
}
/// Compute all STARK proofs.

View File

@ -10,6 +10,7 @@ pub enum MemoryChannel {
use MemoryChannel::{Code, GeneralPurpose};
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::memory::segments::Segment;
impl MemoryChannel {
@ -173,6 +174,14 @@ impl MemoryState {
);
self.contexts[address.context].segments[address.segment].set(address.virt, val);
}
/// Reads the value of the given global metadata field.
///
/// The field is looked up in context 0 of the `GlobalMetadata` segment, at the offset given
/// by the field's discriminant.
pub(crate) fn read_global_metadata(&self, field: GlobalMetadata) -> U256 {
    let address = MemoryAddress::new(0, Segment::GlobalMetadata, field as usize);
    self.get(address)
}
}
impl Default for MemoryState {

View File

@ -200,7 +200,7 @@ pub(crate) fn generate_jump<F: Field>(
state.traces.push_memory(log_in0);
state.traces.push_cpu(row);
state.registers.program_counter = dst as usize;
state.jump_to(dst as usize);
Ok(())
}
@ -224,7 +224,7 @@ pub(crate) fn generate_jumpi<F: Field>(
let dst: u32 = dst
.try_into()
.map_err(|_| ProgramError::InvalidJumpiDestination)?;
state.registers.program_counter = dst as usize;
state.jump_to(dst as usize);
} else {
row.general.jumps_mut().should_jump = F::ZERO;
row.general.jumps_mut().cond_sum_pinv = F::ZERO;