Merge branch 'main' of github.com:mir-protocol/plonky2 into non-inv

This commit is contained in:
Dmitry Vagner 2022-12-05 12:16:58 -08:00
commit 8f15402041
80 changed files with 2779 additions and 2139 deletions

View File

@ -24,7 +24,7 @@ jobs:
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
toolchain: nightly-2022-11-23
override: true
- name: rust-cache
@ -60,7 +60,7 @@ jobs:
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
toolchain: nightly-2022-11-23
override: true
components: rustfmt, clippy

View File

@ -29,7 +29,7 @@ RUSTFLAGS=-Ctarget-cpu=native cargo run --release --example bench_recursion -- -
## Jemalloc
Plonky2 prefers the [Jemalloc](http://jemalloc.net) memory allocator due to its superior performance. To use it, include `jemallocator = "0.3.2"` in `Cargo.toml` and add the following lines
Plonky2 prefers the [Jemalloc](http://jemalloc.net) memory allocator due to its superior performance. To use it, include `jemallocator = "0.5.0"` in `Cargo.toml` and add the following lines
to your `main.rs`:
```rust
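// The README's snippet is cut off by this diff hunk; a conventional
// jemallocator setup (an assumption, the actual lines are not shown
// here) registers it as the global allocator:
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;

#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
```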

View File

@ -31,6 +31,9 @@ sha2 = "0.10.2"
static_assertions = "1.1.0"
tiny-keccak = "2.0.2"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"
[dev-dependencies]
criterion = "0.4.0"
hex = "0.4.3"

View File

@ -8,12 +8,12 @@ use crate::config::StarkConfig;
use crate::cpu::cpu_stark;
use crate::cpu::cpu_stark::CpuStark;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns};
use crate::cross_table_lookup::{Column, CrossTableLookup, TableWithColumns};
use crate::keccak::keccak_stark;
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_memory::columns::KECCAK_WIDTH_BYTES;
use crate::keccak_memory::keccak_memory_stark;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::keccak_sponge::columns::KECCAK_RATE_BYTES;
use crate::keccak_sponge::keccak_sponge_stark;
use crate::keccak_sponge::keccak_sponge_stark::{num_logic_ctls, KeccakSpongeStark};
use crate::logic;
use crate::logic::LogicStark;
use crate::memory::memory_stark;
@ -24,7 +24,7 @@ use crate::stark::Stark;
pub struct AllStark<F: RichField + Extendable<D>, const D: usize> {
pub cpu_stark: CpuStark<F, D>,
pub keccak_stark: KeccakStark<F, D>,
pub keccak_memory_stark: KeccakMemoryStark<F, D>,
pub keccak_sponge_stark: KeccakSpongeStark<F, D>,
pub logic_stark: LogicStark<F, D>,
pub memory_stark: MemoryStark<F, D>,
pub cross_table_lookups: Vec<CrossTableLookup<F>>,
@ -35,7 +35,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Default for AllStark<F, D> {
Self {
cpu_stark: CpuStark::default(),
keccak_stark: KeccakStark::default(),
keccak_memory_stark: KeccakMemoryStark::default(),
keccak_sponge_stark: KeccakSpongeStark::default(),
logic_stark: LogicStark::default(),
memory_stark: MemoryStark::default(),
cross_table_lookups: all_cross_table_lookups(),
@ -48,7 +48,7 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
[
self.cpu_stark.num_permutation_batches(config),
self.keccak_stark.num_permutation_batches(config),
self.keccak_memory_stark.num_permutation_batches(config),
self.keccak_sponge_stark.num_permutation_batches(config),
self.logic_stark.num_permutation_batches(config),
self.memory_stark.num_permutation_batches(config),
]
@ -58,7 +58,7 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
[
self.cpu_stark.permutation_batch_size(),
self.keccak_stark.permutation_batch_size(),
self.keccak_memory_stark.permutation_batch_size(),
self.keccak_sponge_stark.permutation_batch_size(),
self.logic_stark.permutation_batch_size(),
self.memory_stark.permutation_batch_size(),
]
@ -69,66 +69,77 @@ impl<F: RichField + Extendable<D>, const D: usize> AllStark<F, D> {
pub enum Table {
Cpu = 0,
Keccak = 1,
KeccakMemory = 2,
KeccakSponge = 2,
Logic = 3,
Memory = 4,
}
pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1;
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_memory()]
let mut ctls = vec![ctl_keccak(), ctl_logic(), ctl_memory(), ctl_keccak_sponge()];
// TODO: Some CTLs temporarily disabled while we get them working.
disable_ctl(&mut ctls[0]);
disable_ctl(&mut ctls[1]);
disable_ctl(&mut ctls[2]);
disable_ctl(&mut ctls[3]);
ctls
}
fn disable_ctl<F: Field>(ctl: &mut CrossTableLookup<F>) {
for table in &mut ctl.looking_tables {
table.filter_column = Some(Column::zero());
}
ctl.looked_table.filter_column = Some(Column::zero());
}
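A gloss on why zeroing the filters disables a lookup (my reading; I'm taking `Column::zero()` to be a constant-zero linear combination): a row of a looking or looked table participates in the CTL only where its filter column evaluates to one, and the lookup asserts that the participating multisets match. A minimal sketch of that semantics, with hypothetical names:
```rust
// Toy model of CTL filtering (hypothetical; not plonky2's actual API).
// A constant-zero filter selects no rows, so both sides of the lookup
// contribute empty multisets and the argument is vacuously satisfied.
fn filtered_rows(rows: &[u64], filter: &[u64]) -> Vec<u64> {
    rows.iter()
        .zip(filter)
        .filter(|&(_, &f)| f == 1)
        .map(|(&r, _)| r)
        .collect()
}

fn main() {
    let looking = filtered_rows(&[7, 8, 9], &[0, 0, 0]); // disabled: zero filter
    let looked = filtered_rows(&[7, 8, 9], &[0, 0, 0]);
    assert_eq!(looking, looked); // both empty, so the lookup holds trivially
}
```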
fn ctl_keccak<F: Field>() -> CrossTableLookup<F> {
let cpu_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_keccak(),
Some(cpu_stark::ctl_filter_keccak()),
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_keccak(),
Some(keccak_sponge_stark::ctl_looking_keccak_filter()),
);
let keccak_memory_looking = TableWithColumns::new(
Table::KeccakMemory,
keccak_memory_stark::ctl_looking_keccak(),
Some(keccak_memory_stark::ctl_filter()),
let keccak_looked = TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data(),
Some(keccak_stark::ctl_filter()),
);
CrossTableLookup::new(
vec![cpu_looking, keccak_memory_looking],
TableWithColumns::new(
Table::Keccak,
keccak_stark::ctl_data(),
Some(keccak_stark::ctl_filter()),
),
None,
)
CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked, None)
}
fn ctl_keccak_memory<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_keccak_memory(),
Some(cpu_stark::ctl_filter_keccak_memory()),
)],
TableWithColumns::new(
Table::KeccakMemory,
keccak_memory_stark::ctl_looked_data(),
Some(keccak_memory_stark::ctl_filter()),
),
None,
)
fn ctl_keccak_sponge<F: Field>() -> CrossTableLookup<F> {
let cpu_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_keccak_sponge(),
Some(cpu_stark::ctl_filter_keccak_sponge()),
);
let keccak_sponge_looked = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looked_data(),
Some(keccak_sponge_stark::ctl_looked_filter()),
);
CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked, None)
}
fn ctl_logic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_logic(),
Some(cpu_stark::ctl_filter_logic()),
)],
TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter())),
None,
)
let cpu_looking = TableWithColumns::new(
Table::Cpu,
cpu_stark::ctl_data_logic(),
Some(cpu_stark::ctl_filter_logic()),
);
let mut all_lookers = vec![cpu_looking];
for i in 0..num_logic_ctls() {
let keccak_sponge_looking = TableWithColumns::new(
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_logic(i),
// TODO: Double check, but I think it's the same filter for memory and logic?
Some(keccak_sponge_stark::ctl_looking_memory_filter(i)),
);
all_lookers.push(keccak_sponge_looking);
}
let logic_looked =
TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter()));
CrossTableLookup::new(all_lookers, logic_looked, None)
}
fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
@ -144,662 +155,21 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
Some(cpu_stark::ctl_filter_gp_memory(channel)),
)
});
let keccak_memory_reads = (0..KECCAK_WIDTH_BYTES).map(|i| {
let keccak_sponge_reads = (0..KECCAK_RATE_BYTES).map(|i| {
TableWithColumns::new(
Table::KeccakMemory,
keccak_memory_stark::ctl_looking_memory(i, true),
Some(keccak_memory_stark::ctl_filter()),
)
});
let keccak_memory_writes = (0..KECCAK_WIDTH_BYTES).map(|i| {
TableWithColumns::new(
Table::KeccakMemory,
keccak_memory_stark::ctl_looking_memory(i, false),
Some(keccak_memory_stark::ctl_filter()),
Table::KeccakSponge,
keccak_sponge_stark::ctl_looking_memory(i),
Some(keccak_sponge_stark::ctl_looking_memory_filter(i)),
)
});
let all_lookers = iter::once(cpu_memory_code_read)
.chain(cpu_memory_gp_ops)
.chain(keccak_memory_reads)
.chain(keccak_memory_writes)
.chain(keccak_sponge_reads)
.collect();
CrossTableLookup::new(
all_lookers,
TableWithColumns::new(
Table::Memory,
memory_stark::ctl_data(),
Some(memory_stark::ctl_filter()),
),
None,
)
}
#[cfg(test)]
mod tests {
use std::borrow::BorrowMut;
use anyhow::Result;
use ethereum_types::U256;
use itertools::Itertools;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData};
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::util::timing::TimingTree;
use rand::{thread_rng, Rng};
use crate::all_stark::{AllStark, NUM_TABLES};
use crate::config::StarkConfig;
use crate::cpu::cpu_stark::CpuStark;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cross_table_lookup::testutils::check_ctls;
use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::logic::{self, LogicStark, Operation};
use crate::memory::memory_stark::tests::generate_random_memory_ops;
use crate::memory::memory_stark::MemoryStark;
use crate::memory::NUM_CHANNELS;
use crate::proof::{AllProof, PublicValues};
use crate::prover::prove_with_traces;
use crate::recursive_verifier::tests::recursively_verify_all_proof;
use crate::recursive_verifier::{
add_virtual_recursive_all_proof, all_verifier_data_recursive_stark_proof,
set_recursive_all_proof_target, RecursiveAllProof,
};
use crate::stark::Stark;
use crate::util::{limb_from_bits_le, trace_rows_to_poly_values};
use crate::verifier::verify_proof;
use crate::{cpu, keccak, memory};
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
fn make_keccak_trace<R: Rng>(
num_keccak_perms: usize,
keccak_stark: &KeccakStark<F, D>,
rng: &mut R,
) -> Vec<PolynomialValues<F>> {
let keccak_inputs = (0..num_keccak_perms)
.map(|_| [0u64; NUM_INPUTS].map(|_| rng.gen()))
.collect_vec();
keccak_stark.generate_trace(keccak_inputs, &mut TimingTree::default())
}
fn make_keccak_memory_trace(
keccak_memory_stark: &KeccakMemoryStark<F, D>,
config: &StarkConfig,
) -> Vec<PolynomialValues<F>> {
keccak_memory_stark.generate_trace(
vec![],
config.fri_config.num_cap_elements(),
&mut TimingTree::default(),
)
}
fn make_logic_trace<R: Rng>(
num_rows: usize,
logic_stark: &LogicStark<F, D>,
rng: &mut R,
) -> Vec<PolynomialValues<F>> {
let all_ops = [logic::Op::And, logic::Op::Or, logic::Op::Xor];
let ops = (0..num_rows)
.map(|_| {
let op = all_ops[rng.gen_range(0..all_ops.len())];
let input0 = U256(rng.gen());
let input1 = U256(rng.gen());
Operation::new(op, input0, input1)
})
.collect();
logic_stark.generate_trace(ops, &mut TimingTree::default())
}
fn make_memory_trace<R: Rng>(
num_memory_ops: usize,
memory_stark: &MemoryStark<F, D>,
rng: &mut R,
) -> (Vec<PolynomialValues<F>>, usize) {
let memory_ops = generate_random_memory_ops(num_memory_ops, rng);
let trace = memory_stark.generate_trace(memory_ops, &mut TimingTree::default());
let num_ops = trace[0].values.len();
(trace, num_ops)
}
fn bits_from_opcode(opcode: u8) -> [F; 8] {
[
F::from_bool(opcode & (1 << 0) != 0),
F::from_bool(opcode & (1 << 1) != 0),
F::from_bool(opcode & (1 << 2) != 0),
F::from_bool(opcode & (1 << 3) != 0),
F::from_bool(opcode & (1 << 4) != 0),
F::from_bool(opcode & (1 << 5) != 0),
F::from_bool(opcode & (1 << 6) != 0),
F::from_bool(opcode & (1 << 7) != 0),
]
}
fn make_cpu_trace(
num_keccak_perms: usize,
num_logic_rows: usize,
num_memory_ops: usize,
cpu_stark: &CpuStark<F, D>,
keccak_trace: &[PolynomialValues<F>],
logic_trace: &[PolynomialValues<F>],
memory_trace: &mut [PolynomialValues<F>],
) -> Vec<PolynomialValues<F>> {
let keccak_input_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms)
.map(|i| {
(0..2 * NUM_INPUTS)
.map(|j| {
keccak::columns::reg_input_limb(j)
.eval_table(keccak_trace, (i + 1) * NUM_ROUNDS - 1)
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect();
let keccak_output_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms)
.map(|i| {
(0..2 * NUM_INPUTS)
.map(|j| {
keccak_trace[keccak::columns::reg_output_limb(j)].values
[(i + 1) * NUM_ROUNDS - 1]
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect();
let mut cpu_trace_rows: Vec<[F; CpuStark::<F, D>::COLUMNS]> = vec![];
let mut bootstrap_row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
bootstrap_row.is_bootstrap_kernel = F::ONE;
cpu_trace_rows.push(bootstrap_row.into());
for i in 0..num_keccak_perms {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_keccak = F::ONE;
let keccak = row.general.keccak_mut();
for j in 0..2 * NUM_INPUTS {
keccak.input_limbs[j] = keccak_input_limbs[i][j];
keccak.output_limbs[j] = keccak_output_limbs[i][j];
}
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// Pad to `num_memory_ops` for memory testing.
for _ in cpu_trace_rows.len()..num_memory_ops {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.opcode_bits = bits_from_opcode(0x5b);
row.is_cpu_cycle = F::ONE;
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["main"]);
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
for i in 0..num_memory_ops {
let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i]
.to_canonical_u64()
.try_into()
.unwrap();
let clock = mem_timestamp / NUM_CHANNELS;
let channel = mem_timestamp % NUM_CHANNELS;
let filter = memory_trace[memory::columns::FILTER].values[i];
assert!(filter.is_one() || filter.is_zero());
let is_actual_op = filter.is_one();
if is_actual_op {
let row: &mut cpu::columns::CpuColumnsView<F> = cpu_trace_rows[clock].borrow_mut();
row.clock = F::from_canonical_usize(clock);
dbg!(channel, row.mem_channels.len());
let channel = &mut row.mem_channels[channel];
channel.used = F::ONE;
channel.is_read = memory_trace[memory::columns::IS_READ].values[i];
channel.addr_context = memory_trace[memory::columns::ADDR_CONTEXT].values[i];
channel.addr_segment = memory_trace[memory::columns::ADDR_SEGMENT].values[i];
channel.addr_virtual = memory_trace[memory::columns::ADDR_VIRTUAL].values[i];
for j in 0..8 {
channel.value[j] = memory_trace[memory::columns::value_limb(j)].values[i];
}
}
}
for i in 0..num_logic_rows {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.is_kernel_mode = F::ONE;
// Since these are the first cycle rows, we must start with PC=main then increment.
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["main"] + i);
row.opcode_bits = bits_from_opcode(
if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO {
0x16
} else if logic_trace[logic::columns::IS_OR].values[i] != F::ZERO {
0x17
} else if logic_trace[logic::columns::IS_XOR].values[i] != F::ZERO {
0x18
} else {
panic!()
},
);
let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0);
for (col_cpu, limb_cols_logic) in
row.mem_channels[0].value.iter_mut().zip(input0_bit_cols)
{
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
let input1_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT1);
for (col_cpu, limb_cols_logic) in
row.mem_channels[1].value.iter_mut().zip(input1_bit_cols)
{
*col_cpu = limb_from_bits_le(limb_cols_logic.map(|col| logic_trace[col].values[i]));
}
for (col_cpu, col_logic) in row.mem_channels[2]
.value
.iter_mut()
.zip(logic::columns::RESULT)
{
*col_cpu = logic_trace[col_logic].values[i];
}
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// Trap to kernel
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
let last_row: cpu::columns::CpuColumnsView<F> =
cpu_trace_rows[cpu_trace_rows.len() - 1].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software
row.is_kernel_mode = F::ONE;
row.program_counter = last_row.program_counter + F::ONE;
row.mem_channels[0].value = [
row.program_counter,
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `EXIT_KERNEL` (to kernel)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0xf9);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]);
row.mem_channels[0].value = [
F::from_canonical_u16(15682),
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `JUMP` (in kernel mode)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_u16(15682);
row.mem_channels[0].value = [
F::from_canonical_u16(15106),
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.mem_channels[1].value = [
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input0_upper_zero = F::ONE;
row.general.jumps_mut().dst_valid_or_kernel = F::ONE;
row.general.jumps_mut().input0_jumpable = F::ONE;
row.general.jumps_mut().input1_sum_inv = F::ONE;
row.general.jumps_mut().should_jump = F::ONE;
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `EXIT_KERNEL` (to userspace)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0xf9);
row.is_kernel_mode = F::ONE;
row.program_counter = F::from_canonical_u16(15106);
row.mem_channels[0].value = [
F::from_canonical_u16(63064),
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `JUMP` (taken)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(63064);
row.mem_channels[0].value = [
F::from_canonical_u16(3754),
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.mem_channels[1].value = [
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input0_upper_zero = F::ONE;
row.general.jumps_mut().dst_valid = F::ONE;
row.general.jumps_mut().dst_valid_or_kernel = F::ONE;
row.general.jumps_mut().input0_jumpable = F::ONE;
row.general.jumps_mut().input1_sum_inv = F::ONE;
row.general.jumps_mut().should_jump = F::ONE;
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `JUMPI` (taken)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0x57);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(3754);
row.mem_channels[0].value = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.mem_channels[1].value = [
F::ZERO,
F::ZERO,
F::ZERO,
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input0_upper_zero = F::ONE;
row.general.jumps_mut().dst_valid = F::ONE;
row.general.jumps_mut().dst_valid_or_kernel = F::ONE;
row.general.jumps_mut().input0_jumpable = F::ONE;
row.general.jumps_mut().input1_sum_inv = F::ONE;
row.general.jumps_mut().should_jump = F::ONE;
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `JUMPI` (not taken)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0x57);
row.is_kernel_mode = F::ZERO;
row.program_counter = F::from_canonical_u16(37543);
row.mem_channels[0].value = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
F::ZERO,
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input0_upper_sum_inv = F::ONE;
row.general.jumps_mut().dst_valid = F::ONE;
row.general.jumps_mut().dst_valid_or_kernel = F::ONE;
row.general.jumps_mut().input0_jumpable = F::ZERO;
row.general.jumps_mut().should_continue = F::ONE;
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// `JUMP` (trapping)
{
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
let last_row: cpu::columns::CpuColumnsView<F> =
cpu_trace_rows[cpu_trace_rows.len() - 1].into();
row.is_cpu_cycle = F::ONE;
row.opcode_bits = bits_from_opcode(0x56);
row.is_kernel_mode = F::ZERO;
row.program_counter = last_row.program_counter + F::ONE;
row.mem_channels[0].value = [
F::from_canonical_u16(37543),
F::ZERO,
F::ZERO,
F::ZERO,
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.mem_channels[1].value = [
F::ONE,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
F::ZERO,
];
row.general.jumps_mut().input0_upper_sum_inv = F::ONE;
row.general.jumps_mut().dst_valid = F::ONE;
row.general.jumps_mut().dst_valid_or_kernel = F::ONE;
row.general.jumps_mut().input0_jumpable = F::ZERO;
row.general.jumps_mut().input1_sum_inv = F::ONE;
row.general.jumps_mut().should_trap = F::ONE;
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// Pad to a power of two.
for i in 0..cpu_trace_rows.len().next_power_of_two() - cpu_trace_rows.len() {
let mut row: cpu::columns::CpuColumnsView<F> =
[F::ZERO; CpuStark::<F, D>::COLUMNS].into();
row.opcode_bits = bits_from_opcode(0xff);
row.is_cpu_cycle = F::ONE;
row.is_kernel_mode = F::ONE;
row.program_counter =
F::from_canonical_usize(KERNEL.global_labels["fault_exception"] + i);
cpu_stark.generate(row.borrow_mut());
cpu_trace_rows.push(row.into());
}
// Ensure we finish in a halted state.
{
let num_rows = cpu_trace_rows.len();
let halt_label = F::from_canonical_usize(KERNEL.global_labels["halt_pc0"]);
let last_row: &mut cpu::columns::CpuColumnsView<F> =
cpu_trace_rows[num_rows - 1].borrow_mut();
last_row.program_counter = halt_label;
}
trace_rows_to_poly_values(cpu_trace_rows)
}
fn get_proof(config: &StarkConfig) -> Result<(AllStark<F, D>, AllProof<F, C, D>)> {
let all_stark = AllStark::default();
let num_logic_rows = 62;
let num_memory_ops = 1 << 5;
let mut rng = thread_rng();
let num_keccak_perms = 2;
let keccak_trace = make_keccak_trace(num_keccak_perms, &all_stark.keccak_stark, &mut rng);
let keccak_memory_trace = make_keccak_memory_trace(&all_stark.keccak_memory_stark, config);
let logic_trace = make_logic_trace(num_logic_rows, &all_stark.logic_stark, &mut rng);
let mem_trace = make_memory_trace(num_memory_ops, &all_stark.memory_stark, &mut rng);
let mut memory_trace = mem_trace.0;
let num_memory_ops = mem_trace.1;
let cpu_trace = make_cpu_trace(
num_keccak_perms,
num_logic_rows,
num_memory_ops,
&all_stark.cpu_stark,
&keccak_trace,
&logic_trace,
&mut memory_trace,
);
let traces = [
cpu_trace,
keccak_trace,
keccak_memory_trace,
logic_trace,
memory_trace,
];
check_ctls(&traces, &all_stark.cross_table_lookups);
let public_values = PublicValues::default();
let proof = prove_with_traces::<F, C, D>(
&all_stark,
config,
traces,
public_values,
&mut TimingTree::default(),
)?;
Ok((all_stark, proof))
}
#[test]
#[ignore] // Ignoring but not deleting so the test can serve as an API usage example
fn test_all_stark() -> Result<()> {
let config = StarkConfig::standard_fast_config();
let (all_stark, proof) = get_proof(&config)?;
verify_proof(all_stark, proof, &config)
}
#[test]
#[ignore] // Ignoring but not deleting so the test can serve as an API usage example
fn test_all_stark_recursive_verifier() -> Result<()> {
init_logger();
let config = StarkConfig::standard_fast_config();
let (all_stark, proof) = get_proof(&config)?;
verify_proof(all_stark.clone(), proof.clone(), &config)?;
recursive_proof(all_stark, proof, &config)
}
fn recursive_proof(
inner_all_stark: AllStark<F, D>,
inner_proof: AllProof<F, C, D>,
inner_config: &StarkConfig,
) -> Result<()> {
let circuit_config = CircuitConfig::standard_recursion_config();
let recursive_all_proof = recursively_verify_all_proof(
&inner_all_stark,
&inner_proof,
inner_config,
&circuit_config,
)?;
let verifier_data: [VerifierCircuitData<F, C, D>; NUM_TABLES] =
all_verifier_data_recursive_stark_proof(
&inner_all_stark,
inner_proof.degree_bits(inner_config),
inner_config,
&circuit_config,
);
let circuit_config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(circuit_config);
let mut pw = PartialWitness::new();
let recursive_all_proof_target =
add_virtual_recursive_all_proof(&mut builder, &verifier_data);
set_recursive_all_proof_target(&mut pw, &recursive_all_proof_target, &recursive_all_proof);
RecursiveAllProof::verify_circuit(
&mut builder,
recursive_all_proof_target,
&verifier_data,
inner_all_stark.cross_table_lookups,
inner_config,
);
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)
}
fn init_logger() {
let _ = env_logger::builder().format_timestamp(None).try_init();
}
let memory_looked = TableWithColumns::new(
Table::Memory,
memory_stark::ctl_data(),
Some(memory_stark::ctl_filter()),
);
CrossTableLookup::new(all_lookers, memory_looked, None)
}

View File

@ -35,6 +35,7 @@ pub(crate) fn eval_packed_generic_are_equal<P, I, J>(
is_op: P,
larger: I,
smaller: J,
is_two_row_op: bool,
) -> P
where
P: PackedField,
@ -47,7 +48,11 @@ where
for (a, b) in larger.zip(smaller) {
// t should be either 0 or 2^LIMB_BITS
let t = cy + a - b;
yield_constr.constraint(is_op * t * (overflow - t));
if is_two_row_op {
yield_constr.constraint_transition(is_op * t * (overflow - t));
} else {
yield_constr.constraint(is_op * t * (overflow - t));
}
// cy <-- 0 or 1
// NB: this is multiplication by a constant, so doesn't
// increase the degree of the constraint.
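A short derivation of what this enforces (my gloss): on rows where is_op = 1, the constraint t * (2^LIMB_BITS - t) = 0, with t = cy + a - b, forces t to be either 0 or exactly 2^LIMB_BITS, i.e. the limbs a and b agree up to a single carry-out; the next carry cy = t / 2^LIMB_BITS is then 0 or 1, and because 1/2^LIMB_BITS is a constant, this division does not increase the constraint degree.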
@ -62,6 +67,7 @@ pub(crate) fn eval_ext_circuit_are_equal<F, const D: usize, I, J>(
is_op: ExtensionTarget<D>,
larger: I,
smaller: J,
is_two_row_op: bool,
) -> ExtensionTarget<D>
where
F: RichField + Extendable<D>,
@ -87,7 +93,11 @@ where
let t2 = builder.mul_extension(t, t1);
let filtered_limb_constraint = builder.mul_extension(is_op, t2);
yield_constr.constraint(builder, filtered_limb_constraint);
if is_two_row_op {
yield_constr.constraint_transition(builder, filtered_limb_constraint);
} else {
yield_constr.constraint(builder, filtered_limb_constraint);
}
cy = builder.mul_const_extension(overflow_inv, t);
}
@ -125,6 +135,7 @@ pub fn eval_packed_generic<P: PackedField>(
is_add,
output_computed,
output_limbs.iter().copied(),
false,
);
}
@ -155,6 +166,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
is_add,
output_computed.into_iter(),
output_limbs.iter().copied(),
false,
);
}

View File

@ -17,7 +17,11 @@ pub struct ArithmeticStark<F, const D: usize> {
}
impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
pub fn generate(&self, local_values: &mut [F; columns::NUM_ARITH_COLUMNS]) {
pub fn generate(
&self,
local_values: &mut [F; columns::NUM_ARITH_COLUMNS],
next_values: &mut [F; columns::NUM_ARITH_COLUMNS],
) {
// Check that at most one operation column is "one" and that the
// rest are "zero".
assert_eq!(
@ -47,17 +51,17 @@ impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
} else if local_values[columns::IS_GT].is_one() {
compare::generate(local_values, columns::IS_GT);
} else if local_values[columns::IS_ADDMOD].is_one() {
modular::generate(local_values, columns::IS_ADDMOD);
modular::generate(local_values, next_values, columns::IS_ADDMOD);
} else if local_values[columns::IS_SUBMOD].is_one() {
modular::generate(local_values, columns::IS_SUBMOD);
modular::generate(local_values, next_values, columns::IS_SUBMOD);
} else if local_values[columns::IS_MULMOD].is_one() {
modular::generate(local_values, columns::IS_MULMOD);
modular::generate(local_values, next_values, columns::IS_MULMOD);
} else if local_values[columns::IS_MOD].is_one() {
modular::generate(local_values, columns::IS_MOD);
modular::generate(local_values, next_values, columns::IS_MOD);
} else if local_values[columns::IS_DIV].is_one() {
modular::generate(local_values, columns::IS_DIV);
modular::generate(local_values, next_values, columns::IS_DIV);
} else {
todo!("the requested operation has not yet been implemented");
panic!("the requested operation should not be handled by the arithmetic table");
}
}
}
@ -74,11 +78,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
P: PackedField<Scalar = FE>,
{
let lv = vars.local_values;
let nv = vars.next_values;
add::eval_packed_generic(lv, yield_constr);
sub::eval_packed_generic(lv, yield_constr);
mul::eval_packed_generic(lv, yield_constr);
compare::eval_packed_generic(lv, yield_constr);
modular::eval_packed_generic(lv, yield_constr);
modular::eval_packed_generic(lv, nv, yield_constr);
}
fn eval_ext_circuit(
@ -88,11 +93,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let lv = vars.local_values;
let nv = vars.next_values;
add::eval_ext_circuit(builder, lv, yield_constr);
sub::eval_ext_circuit(builder, lv, yield_constr);
mul::eval_ext_circuit(builder, lv, yield_constr);
compare::eval_ext_circuit(builder, lv, yield_constr);
modular::eval_ext_circuit(builder, lv, yield_constr);
modular::eval_ext_circuit(builder, lv, nv, yield_constr);
}
fn constraint_degree(&self) -> usize {

View File

@ -12,7 +12,11 @@ const fn n_limbs() -> usize {
if EVM_REGISTER_BITS % LIMB_BITS != 0 {
panic!("limb size must divide EVM register size");
}
EVM_REGISTER_BITS / LIMB_BITS
let n = EVM_REGISTER_BITS / LIMB_BITS;
if n % 2 == 1 {
panic!("number of limbs must be even");
}
n
}
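For concreteness (parameters inferred from the "3 * N_LIMBS = 48" and "9 * N_LIMBS = 144" comments below, i.e. EVM_REGISTER_BITS = 256 and LIMB_BITS = 16): n_limbs() = 256 / 16 = 16, which both divides exactly and is even, so neither panic fires.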
/// Number of LIMB_BITS limbs that are in an EVM register-sized number.
@ -40,43 +44,66 @@ pub(crate) const ALL_OPERATIONS: [usize; 12] = [
/// Within the Arithmetic Unit, there are shared columns which can be
/// used by any arithmetic circuit, depending on which one is active
/// this cycle. Can be increased as needed as other operations are
/// implemented.
const NUM_SHARED_COLS: usize = 9 * N_LIMBS; // only need 64 for add, sub, and mul
/// this cycle.
///
/// Modular arithmetic takes 9 * N_LIMBS columns which is split across
/// two rows, the first with 5 * N_LIMBS columns and the second with
/// 4 * N_LIMBS columns. (There are hence N_LIMBS "wasted columns" in
/// the second row.)
const NUM_SHARED_COLS: usize = 5 * N_LIMBS;
const GENERAL_INPUT_0: Range<usize> = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS;
const GENERAL_INPUT_1: Range<usize> = GENERAL_INPUT_0.end..GENERAL_INPUT_0.end + N_LIMBS;
const GENERAL_INPUT_2: Range<usize> = GENERAL_INPUT_1.end..GENERAL_INPUT_1.end + N_LIMBS;
const GENERAL_INPUT_3: Range<usize> = GENERAL_INPUT_2.end..GENERAL_INPUT_2.end + N_LIMBS;
const AUX_INPUT_0: Range<usize> = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + 2 * N_LIMBS;
const AUX_INPUT_1: Range<usize> = AUX_INPUT_0.end..AUX_INPUT_0.end + 2 * N_LIMBS;
const AUX_INPUT_0_LO: Range<usize> = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + N_LIMBS;
// The auxiliary input columns overlap the general input columns
// because they correspond to the values in the second row for modular
// operations.
const AUX_INPUT_0_HI: Range<usize> = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS;
const AUX_INPUT_1: Range<usize> = AUX_INPUT_0_HI.end..AUX_INPUT_0_HI.end + 2 * N_LIMBS;
// These auxiliary input columns are awkwardly split across two rows,
// with the first half after the general input columns and the second
// half after the auxiliary input columns.
const AUX_INPUT_2: Range<usize> = AUX_INPUT_1.end..AUX_INPUT_1.end + N_LIMBS;
// ADD takes 3 * N_LIMBS = 48 columns
pub(crate) const ADD_INPUT_0: Range<usize> = GENERAL_INPUT_0;
pub(crate) const ADD_INPUT_1: Range<usize> = GENERAL_INPUT_1;
pub(crate) const ADD_OUTPUT: Range<usize> = GENERAL_INPUT_2;
// SUB takes 3 * N_LIMBS = 48 columns
pub(crate) const SUB_INPUT_0: Range<usize> = GENERAL_INPUT_0;
pub(crate) const SUB_INPUT_1: Range<usize> = GENERAL_INPUT_1;
pub(crate) const SUB_OUTPUT: Range<usize> = GENERAL_INPUT_2;
// MUL takes 4 * N_LIMBS = 64 columns
pub(crate) const MUL_INPUT_0: Range<usize> = GENERAL_INPUT_0;
pub(crate) const MUL_INPUT_1: Range<usize> = GENERAL_INPUT_1;
pub(crate) const MUL_OUTPUT: Range<usize> = GENERAL_INPUT_2;
pub(crate) const MUL_AUX_INPUT: Range<usize> = GENERAL_INPUT_3;
// LT and GT take 4 * N_LIMBS = 64 columns
pub(crate) const CMP_INPUT_0: Range<usize> = GENERAL_INPUT_0;
pub(crate) const CMP_INPUT_1: Range<usize> = GENERAL_INPUT_1;
pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2.start;
pub(crate) const CMP_AUX_INPUT: Range<usize> = GENERAL_INPUT_3;
// MULMOD takes 4 * N_LIMBS + 2 * 2*N_LIMBS + N_LIMBS = 144 columns
// but split over two rows of 80 columns and 64 columns.
//
// ADDMOD, SUBMOD, MOD and DIV are currently implemented in terms of
// the general modular code, so they also take 144 columns (also split
// over two rows).
pub(crate) const MODULAR_INPUT_0: Range<usize> = GENERAL_INPUT_0;
pub(crate) const MODULAR_INPUT_1: Range<usize> = GENERAL_INPUT_1;
pub(crate) const MODULAR_MODULUS: Range<usize> = GENERAL_INPUT_2;
pub(crate) const MODULAR_OUTPUT: Range<usize> = GENERAL_INPUT_3;
pub(crate) const MODULAR_QUO_INPUT: Range<usize> = AUX_INPUT_0;
pub(crate) const MODULAR_QUO_INPUT_LO: Range<usize> = AUX_INPUT_0_LO;
// NB: The last value is not used in AUX; it is used in MOD_IS_ZERO
pub(crate) const MODULAR_AUX_INPUT: Range<usize> = AUX_INPUT_1;
pub(crate) const MODULAR_QUO_INPUT_HI: Range<usize> = AUX_INPUT_0_HI;
pub(crate) const MODULAR_AUX_INPUT: Range<usize> = AUX_INPUT_1.start..AUX_INPUT_1.end - 1;
pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1;
pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_INPUT_2;
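Spelling out the MULMOD accounting above with N_LIMBS = 16: the two inputs, the modulus and the output take 4 * 16 = 64 columns; the double-width quotient and the double-width aux polynomial take 32 columns each (the last aux slot doubling as MOD_IS_ZERO); and OUT_AUX_RED takes 16, for 64 + 32 + 32 + 16 = 144 = 9 * 16. The first row holds GENERAL_INPUT_0..3 plus AUX_INPUT_0_LO, i.e. 5 * 16 = 80 columns, and the second row holds AUX_INPUT_0_HI, AUX_INPUT_1 and AUX_INPUT_2, i.e. 4 * 16 = 64.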
@ -85,6 +112,6 @@ pub(crate) const DIV_NUMERATOR: Range<usize> = MODULAR_INPUT_0;
#[allow(unused)] // TODO: Will be used when hooking into the CPU
pub(crate) const DIV_DENOMINATOR: Range<usize> = MODULAR_MODULUS;
#[allow(unused)] // TODO: Will be used when hooking into the CPU
pub(crate) const DIV_OUTPUT: Range<usize> = MODULAR_QUO_INPUT.start..MODULAR_QUO_INPUT.start + 16;
pub(crate) const DIV_OUTPUT: Range<usize> = MODULAR_QUO_INPUT_LO;
pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS;

View File

@ -57,16 +57,27 @@ pub(crate) fn eval_packed_generic_lt<P: PackedField>(
input1: &[P],
aux: &[P],
output: P,
is_two_row_op: bool,
) {
debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS);
// Verify (input0 < input1) == output by providing aux such that
// input0 - input1 == aux + output*2^256.
let lhs_limbs = input0.iter().zip(input1).map(|(&a, &b)| a - b);
let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.iter().copied(), lhs_limbs);
let cy = eval_packed_generic_are_equal(
yield_constr,
is_op,
aux.iter().copied(),
lhs_limbs,
is_two_row_op,
);
// We don't need to check that cy is 0 or 1, since output has
// already been checked to be 0 or 1.
yield_constr.constraint(is_op * (cy - output));
if is_two_row_op {
yield_constr.constraint_transition(is_op * (cy - output));
} else {
yield_constr.constraint(is_op * (cy - output));
}
}
pub fn eval_packed_generic<P: PackedField>(
@ -88,8 +99,8 @@ pub fn eval_packed_generic<P: PackedField>(
let is_cmp = is_lt + is_gt;
eval_packed_generic_check_is_one_bit(yield_constr, is_cmp, output);
eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output);
eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output);
eval_packed_generic_lt(yield_constr, is_lt, input0, input1, aux, output, false);
eval_packed_generic_lt(yield_constr, is_gt, input1, input0, aux, output, false);
}
fn eval_ext_circuit_check_is_one_bit<F: RichField + Extendable<D>, const D: usize>(
@ -112,6 +123,7 @@ pub(crate) fn eval_ext_circuit_lt<F: RichField + Extendable<D>, const D: usize>(
input1: &[ExtensionTarget<D>],
aux: &[ExtensionTarget<D>],
output: ExtensionTarget<D>,
is_two_row_op: bool,
) {
debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS);
@ -131,10 +143,11 @@ pub(crate) fn eval_ext_circuit_lt<F: RichField + Extendable<D>, const D: usize>(
is_op,
aux.iter().copied(),
lhs_limbs.into_iter(),
is_two_row_op,
);
let good_output = builder.sub_extension(cy, output);
let filter = builder.mul_extension(is_op, good_output);
yield_constr.constraint(builder, filter);
yield_constr.constraint_transition(builder, filter);
}
pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
@ -153,8 +166,26 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let is_cmp = builder.add_extension(is_lt, is_gt);
eval_ext_circuit_check_is_one_bit(builder, yield_constr, is_cmp, output);
eval_ext_circuit_lt(builder, yield_constr, is_lt, input0, input1, aux, output);
eval_ext_circuit_lt(builder, yield_constr, is_gt, input1, input0, aux, output);
eval_ext_circuit_lt(
builder,
yield_constr,
is_lt,
input0,
input1,
aux,
output,
false,
);
eval_ext_circuit_lt(
builder,
yield_constr,
is_gt,
input1,
input0,
aux,
output,
false,
);
}
#[cfg(test)]

View File

@ -1,3 +1,9 @@
use std::str::FromStr;
use ethereum_types::U256;
use crate::util::{addmod, mulmod, submod};
mod add;
mod compare;
mod modular;
@ -7,3 +13,146 @@ mod utils;
pub mod arithmetic_stark;
pub(crate) mod columns;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum BinaryOperator {
Add,
Mul,
Sub,
Div,
Mod,
Lt,
Gt,
Shl,
Shr,
AddFp254,
MulFp254,
SubFp254,
}
impl BinaryOperator {
pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 {
match self {
BinaryOperator::Add => input0.overflowing_add(input1).0,
BinaryOperator::Mul => input0.overflowing_mul(input1).0,
BinaryOperator::Sub => input0.overflowing_sub(input1).0,
BinaryOperator::Div => {
if input1.is_zero() {
U256::zero()
} else {
input0 / input1
}
}
BinaryOperator::Mod => {
if input1.is_zero() {
U256::zero()
} else {
input0 % input1
}
}
BinaryOperator::Lt => {
if input0 < input1 {
U256::one()
} else {
U256::zero()
}
}
BinaryOperator::Gt => {
if input0 > input1 {
U256::one()
} else {
U256::zero()
}
}
BinaryOperator::Shl => {
if input0 > 255.into() {
U256::zero()
} else {
input1 << input0
}
}
BinaryOperator::Shr => {
if input0 > 255.into() {
U256::zero()
} else {
input1 >> input0
}
}
BinaryOperator::AddFp254 => addmod(input0, input1, bn_base_order()),
BinaryOperator::MulFp254 => mulmod(input0, input1, bn_base_order()),
BinaryOperator::SubFp254 => submod(input0, input1, bn_base_order()),
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum TernaryOperator {
AddMod,
MulMod,
}
impl TernaryOperator {
pub(crate) fn result(&self, input0: U256, input1: U256, input2: U256) -> U256 {
match self {
TernaryOperator::AddMod => addmod(input0, input1, input2),
TernaryOperator::MulMod => mulmod(input0, input1, input2),
}
}
}
#[derive(Debug)]
#[allow(unused)] // TODO: Should be used soon.
pub(crate) enum Operation {
BinaryOperation {
operator: BinaryOperator,
input0: U256,
input1: U256,
result: U256,
},
TernaryOperation {
operator: TernaryOperator,
input0: U256,
input1: U256,
input2: U256,
result: U256,
},
}
impl Operation {
pub(crate) fn binary(operator: BinaryOperator, input0: U256, input1: U256) -> Self {
let result = operator.result(input0, input1);
Self::BinaryOperation {
operator,
input0,
input1,
result,
}
}
pub(crate) fn ternary(
operator: TernaryOperator,
input0: U256,
input1: U256,
input2: U256,
) -> Self {
let result = operator.result(input0, input1, input2);
Self::TernaryOperation {
operator,
input0,
input1,
input2,
result,
}
}
pub(crate) fn result(&self) -> U256 {
match self {
Operation::BinaryOperation { result, .. } => *result,
Operation::TernaryOperation { result, .. } => *result,
}
}
}
fn bn_base_order() -> U256 {
U256::from_str("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47").unwrap()
}
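A usage sketch for the API above (hypothetical call site; since these types are pub(crate), it only compiles inside the crate):
```rust
use ethereum_types::U256;

fn demo() {
    // Addition wraps modulo 2^256: 2 + (2^256 - 1) = 1.
    let add = Operation::binary(BinaryOperator::Add, U256::from(2), U256::max_value());
    assert_eq!(add.result(), U256::one());

    // Division by zero is defined to return zero, matching EVM semantics.
    let div = Operation::binary(BinaryOperator::Div, U256::from(7), U256::zero());
    assert_eq!(div.result(), U256::zero());
}
```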

View File

@ -86,6 +86,27 @@
//!
//! In the case of DIV, we do something similar, except that we "replace"
//! the modulus with "2^256" to force the quotient to be zero.
//!
//! -*-
//!
//! NB: The implementation uses 9 * N_LIMBS = 144 columns because of
//! the requirements of the general purpose MULMOD; since ADDMOD,
//! SUBMOD, MOD and DIV are currently implemented in terms of the
//! general modular code, they also take 144 columns. Possible
//! improvements:
//!
//! - We could reduce the number of columns to 112 for ADDMOD, SUBMOD,
//! etc. if they were implemented separately, so they don't pay the
//! full cost of the general MULMOD.
//!
//! - All these operations could have alternative forms where the
//! output was not guaranteed to be reduced, which is often sufficient
//! in practice, and which would save a further 16 columns.
//!
//! - If the modulus is known in advance (such as for elliptic curve
//! arithmetic), specialised handling of MULMOD in that case would
//! only require 96 columns, or 80 if the output doesn't need to be
//! reduced.
use num::bigint::Sign;
use num::{BigInt, One, Zero};
@ -171,11 +192,13 @@ fn bigint_to_columns<const N: usize>(num: &BigInt) -> [i64; N] {
/// zero if they are not used.
fn generate_modular_op<F: RichField>(
lv: &mut [F; NUM_ARITH_COLUMNS],
nv: &mut [F; NUM_ARITH_COLUMNS],
filter: usize,
operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1],
) {
// Inputs are all range-checked in [0, 2^16), so the "as i64"
// conversion is safe.
let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0);
let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1);
let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS);
@ -246,21 +269,38 @@ fn generate_modular_op<F: RichField>(
let aux_limbs = pol_remove_root_2exp::<LIMB_BITS, _, { 2 * N_LIMBS }>(constr_poly);
lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c)));
lv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c)));
lv[MODULAR_QUO_INPUT].copy_from_slice(&quot_limbs.map(|c| F::from_noncanonical_i64(c)));
lv[MODULAR_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c)));
lv[MODULAR_MOD_IS_ZERO] = mod_is_zero;
// Copy lo and hi halves of quot_limbs into their respective registers
for (i, &lo) in MODULAR_QUO_INPUT_LO.zip(&quot_limbs[..N_LIMBS]) {
lv[i] = F::from_noncanonical_i64(lo);
}
for (i, &hi) in MODULAR_QUO_INPUT_HI.zip(&quot_limbs[N_LIMBS..]) {
nv[i] = F::from_noncanonical_i64(hi);
}
for (i, &c) in MODULAR_AUX_INPUT.zip(&aux_limbs[..2 * N_LIMBS - 1]) {
nv[i] = F::from_noncanonical_i64(c);
}
nv[MODULAR_MOD_IS_ZERO] = mod_is_zero;
nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c)));
}
/// Generate the output and auxiliary values for modular operations.
///
/// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`.
pub(crate) fn generate<F: RichField>(lv: &mut [F; NUM_ARITH_COLUMNS], filter: usize) {
pub(crate) fn generate<F: RichField>(
lv: &mut [F; NUM_ARITH_COLUMNS],
nv: &mut [F; NUM_ARITH_COLUMNS],
filter: usize,
) {
match filter {
columns::IS_ADDMOD => generate_modular_op(lv, filter, pol_add),
columns::IS_SUBMOD => generate_modular_op(lv, filter, pol_sub),
columns::IS_MULMOD => generate_modular_op(lv, filter, pol_mul_wide),
columns::IS_MOD | columns::IS_DIV => generate_modular_op(lv, filter, |a, _| pol_extend(a)),
columns::IS_ADDMOD => generate_modular_op(lv, nv, filter, pol_add),
columns::IS_SUBMOD => generate_modular_op(lv, nv, filter, pol_sub),
columns::IS_MULMOD => generate_modular_op(lv, nv, filter, pol_mul_wide),
columns::IS_MOD | columns::IS_DIV => {
generate_modular_op(lv, nv, filter, |a, _| pol_extend(a))
}
_ => panic!("generate modular operation called with unknown opcode"),
}
}
@ -275,26 +315,28 @@ pub(crate) fn generate<F: RichField>(lv: &mut [F; NUM_ARITH_COLUMNS], filter: us
/// and check consistency when m = 0, and that c is reduced.
fn modular_constr_poly<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
filter: P,
) -> [P; 2 * N_LIMBS] {
range_check_error!(MODULAR_INPUT_0, 16);
range_check_error!(MODULAR_INPUT_1, 16);
range_check_error!(MODULAR_MODULUS, 16);
range_check_error!(MODULAR_QUO_INPUT, 16);
range_check_error!(MODULAR_QUO_INPUT_LO, 16);
range_check_error!(MODULAR_QUO_INPUT_HI, 16);
range_check_error!(MODULAR_AUX_INPUT, 20, signed);
range_check_error!(MODULAR_OUTPUT, 16);
let mut modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS);
let mod_is_zero = lv[MODULAR_MOD_IS_ZERO];
let mod_is_zero = nv[MODULAR_MOD_IS_ZERO];
// Check that mod_is_zero is zero or one
yield_constr.constraint(filter * (mod_is_zero * mod_is_zero - mod_is_zero));
yield_constr.constraint_transition(filter * (mod_is_zero * mod_is_zero - mod_is_zero));
// Check that mod_is_zero is zero if modulus is not zero (they
// could both be zero)
let limb_sum = modulus.into_iter().sum::<P>();
yield_constr.constraint(filter * limb_sum * mod_is_zero);
yield_constr.constraint_transition(filter * limb_sum * mod_is_zero);
// See the file documentation for why this suffices to handle
// modulus = 0.
@ -308,8 +350,8 @@ fn modular_constr_poly<P: PackedField>(
output[0] += mod_is_zero * lv[IS_DIV];
// Verify that the output is reduced, i.e. output < modulus.
let out_aux_red = &lv[MODULAR_OUT_AUX_RED];
// this sets is_less_than to 1 unless we get mod_is_zero when
let out_aux_red = &nv[MODULAR_OUT_AUX_RED];
// This sets is_less_than to 1 unless we get mod_is_zero when
// doing a DIV; in that case, we need is_less_than=0, since the
// function checks
//
@ -317,6 +359,8 @@ fn modular_constr_poly<P: PackedField>(
//
// and we were given output = out_aux_red
let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV];
// NB: output and modulus in lv while out_aux_red and is_less_than
// (via mod_is_zero) depend on nv.
eval_packed_generic_lt(
yield_constr,
filter,
@ -324,16 +368,23 @@ fn modular_constr_poly<P: PackedField>(
&modulus,
out_aux_red,
is_less_than,
true,
);
// restore output[0]
output[0] -= mod_is_zero * lv[IS_DIV];
// prod = q(x) * m(x)
let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT);
let quot = {
let mut quot = [P::default(); 2 * N_LIMBS];
quot[..N_LIMBS].copy_from_slice(&lv[MODULAR_QUO_INPUT_LO]);
quot[N_LIMBS..].copy_from_slice(&nv[MODULAR_QUO_INPUT_HI]);
quot
};
let prod = pol_mul_wide2(quot, modulus);
// higher order terms must be zero
for &x in prod[2 * N_LIMBS..].iter() {
yield_constr.constraint(filter * x);
yield_constr.constraint_transition(filter * x);
}
// constr_poly = c(x) + q(x) * m(x)
@ -341,8 +392,11 @@ fn modular_constr_poly<P: PackedField>(
pol_add_assign(&mut constr_poly, &output);
// constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x)
let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT);
aux[2 * N_LIMBS - 1] = P::ZEROS; // zero out the MOD_IS_ZERO flag
let mut aux = [P::ZEROS; 2 * N_LIMBS];
for (i, j) in MODULAR_AUX_INPUT.enumerate() {
aux[i] = nv[j];
}
let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS);
pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base));
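A gloss on the identity being assembled here (my reading of the comments): the prover supplies q(x) and s(x) such that, coefficient by coefficient, op(a, b)(x) = c(x) + q(x) * m(x) + (x - β) * s(x), where β = 2^LIMB_BITS. Evaluating at x = β kills the (x - β) * s(x) term and turns each limb polynomial into the 256-bit integer it encodes, yielding op(a, b) = c + q * m over the integers; combined with the earlier check that the output is reduced (c < m), this forces c = op(a, b) mod m.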
@ -352,6 +406,7 @@ fn modular_constr_poly<P: PackedField>(
/// Add constraints for modular operations.
pub(crate) fn eval_packed_generic<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
// NB: The CTL code guarantees that filter is 0 or 1, i.e. that
@ -362,8 +417,12 @@ pub(crate) fn eval_packed_generic<P: PackedField>(
+ lv[columns::IS_SUBMOD]
+ lv[columns::IS_DIV];
// Ensure that this operation is not the last row of the table;
// needed because we access the next row of the table in nv.
yield_constr.constraint_last_row(filter);
// constr_poly has 2*N_LIMBS limbs
let constr_poly = modular_constr_poly(lv, yield_constr, filter);
let constr_poly = modular_constr_poly(lv, nv, yield_constr, filter);
let input0 = read_value(lv, MODULAR_INPUT_0);
let input1 = read_value(lv, MODULAR_INPUT_1);
@ -394,35 +453,36 @@ pub(crate) fn eval_packed_generic<P: PackedField>(
// operation is valid if and only if all of those coefficients
// are zero.
for &c in constr_poly_copy.iter() {
yield_constr.constraint(filter * c);
yield_constr.constraint_transition(filter * c);
}
}
}
fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
builder: &mut CircuitBuilder<F, D>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
filter: ExtensionTarget<D>,
) -> [ExtensionTarget<D>; 2 * N_LIMBS] {
let mut modulus = read_value::<N_LIMBS, _>(lv, MODULAR_MODULUS);
let mod_is_zero = lv[MODULAR_MOD_IS_ZERO];
let mod_is_zero = nv[MODULAR_MOD_IS_ZERO];
let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero);
let t = builder.mul_extension(filter, t);
yield_constr.constraint(builder, t);
yield_constr.constraint_transition(builder, t);
let limb_sum = builder.add_many_extension(modulus);
let t = builder.mul_extension(limb_sum, mod_is_zero);
let t = builder.mul_extension(filter, t);
yield_constr.constraint(builder, t);
yield_constr.constraint_transition(builder, t);
modulus[0] = builder.add_extension(modulus[0], mod_is_zero);
let mut output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
output[0] = builder.mul_add_extension(mod_is_zero, lv[IS_DIV], output[0]);
let out_aux_red = &lv[MODULAR_OUT_AUX_RED];
let out_aux_red = &nv[MODULAR_OUT_AUX_RED];
let one = builder.one_extension();
let is_less_than =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
@ -435,22 +495,33 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
&modulus,
out_aux_red,
is_less_than,
true,
);
output[0] =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], output[0]);
let quot = {
let zero = builder.zero_extension();
let mut quot = [zero; 2 * N_LIMBS];
quot[..N_LIMBS].copy_from_slice(&lv[MODULAR_QUO_INPUT_LO]);
quot[N_LIMBS..].copy_from_slice(&nv[MODULAR_QUO_INPUT_HI]);
quot
};
let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT);
let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus);
for &x in prod[2 * N_LIMBS..].iter() {
let t = builder.mul_extension(filter, x);
yield_constr.constraint(builder, t);
yield_constr.constraint_transition(builder, t);
}
let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap();
pol_add_assign_ext_circuit(builder, &mut constr_poly, &output);
let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT);
aux[2 * N_LIMBS - 1] = builder.zero_extension();
let zero = builder.zero_extension();
let mut aux = [zero; 2 * N_LIMBS];
for (i, j) in MODULAR_AUX_INPUT.enumerate() {
aux[i] = nv[j];
}
let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS));
let t = pol_adjoin_root_ext_circuit(builder, aux, base);
pol_add_assign_ext_circuit(builder, &mut constr_poly, &t);
@ -461,6 +532,7 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let filter = builder.add_many_extension([
@ -471,8 +543,9 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv[columns::IS_DIV],
]);
let constr_poly = modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter);
yield_constr.constraint_last_row(builder, filter);
let constr_poly = modular_constr_poly_ext_circuit(lv, nv, builder, yield_constr, filter);
let input0 = read_value(lv, MODULAR_INPUT_0);
let input1 = read_value(lv, MODULAR_INPUT_1);
@ -492,7 +565,7 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input);
for &c in constr_poly_copy.iter() {
let t = builder.mul_extension(filter, c);
yield_constr.constraint(builder, t);
yield_constr.constraint_transition(builder, t);
}
}
}
@ -518,6 +591,7 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
// if `IS_ADDMOD == 0`, then the constraints should be met even
// if all values are garbage.
@ -533,7 +607,7 @@ mod tests {
GoldilocksField::ONE,
GoldilocksField::ONE,
);
eval_packed_generic(&lv, &mut constraint_consumer);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
@ -545,6 +619,7 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] {
// Reset operation columns, then select one
@ -563,9 +638,9 @@ mod tests {
lv[mi] = F::from_canonical_u16(rng.gen());
}
// For the second half of the tests, set the top 16 -
// start digits of the modulus to zero so it is much
// smaller than the inputs.
// For the second half of the tests, set the top
// 16-start digits of the modulus to zero so it is
// much smaller than the inputs.
if i > N_RND_TESTS / 2 {
// 1 <= start < N_LIMBS
let start = (rng.gen::<usize>() % (N_LIMBS - 1)) + 1;
@ -574,15 +649,15 @@ mod tests {
}
}
generate(&mut lv, op_filter);
generate(&mut lv, &mut nv, op_filter);
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ZERO,
GoldilocksField::ZERO,
);
eval_packed_generic(&lv, &mut constraint_consumer);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
@ -596,6 +671,7 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] {
// Reset operation columns, then select one
@ -609,13 +685,14 @@ mod tests {
for _i in 0..N_RND_TESTS {
// set inputs to random values and the modulus to zero;
// the output is defined to be zero when modulus is zero.
for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) {
lv[ai] = F::from_canonical_u16(rng.gen());
lv[bi] = F::from_canonical_u16(rng.gen());
lv[mi] = F::ZERO;
}
generate(&mut lv, op_filter);
generate(&mut lv, &mut nv, op_filter);
// check that the correct output was generated
if op_filter == IS_DIV {
@ -627,24 +704,25 @@ mod tests {
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ZERO,
GoldilocksField::ZERO,
);
eval_packed_generic(&lv, &mut constraint_consumer);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
assert!(constraint_consumer
.constraint_accs
.iter()
.all(|&acc| acc == F::ZERO));
// Corrupt one output limb by setting it to a non-zero value
let random_oi = if op_filter == IS_DIV {
DIV_OUTPUT.start + rng.gen::<usize>() % N_LIMBS
if op_filter == IS_DIV {
let random_oi = DIV_OUTPUT.start + rng.gen::<usize>() % N_LIMBS;
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
} else {
MODULAR_OUTPUT.start + rng.gen::<usize>() % N_LIMBS
let random_oi = MODULAR_OUTPUT.start + rng.gen::<usize>() % N_LIMBS;
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
};
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
eval_packed_generic(&lv, &mut constraint_consumer);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
// Check that at least one of the constraints was non-zero
assert!(constraint_consumer

View File

@ -57,6 +57,7 @@ pub fn eval_packed_generic<P: PackedField>(
is_sub,
output_limbs.iter().copied(),
output_computed,
false,
);
}
@ -87,6 +88,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
is_sub,
output_limbs.iter().copied(),
output_computed.into_iter(),
false,
);
}

View File

@ -13,53 +13,47 @@ use plonky2::plonk::circuit_builder::CircuitBuilder;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::generation::state::GenerationState;
use crate::keccak_sponge::columns::KECCAK_RATE_U32S;
use crate::memory::segments::Segment;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
/// We can't process more than `NUM_CHANNELS` bytes per row, since that's all the memory bandwidth
/// we have. We also can't process more than 4 bytes (or the number of bytes in a `u32`), since we
/// want them to fit in a single limb of Keccak input.
const BYTES_PER_ROW: usize = 4;
use crate::witness::memory::MemoryAddress;
use crate::witness::util::{keccak_sponge_log, mem_write_gp_log_and_fill};
pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>) {
let mut sponge_state = [0u32; 50];
let mut sponge_input_pos: usize = 0;
// Iterate through chunks of the code, such that we can write one chunk to memory per row.
for chunk in &KERNEL
.padded_code()
.iter()
.enumerate()
.chunks(BYTES_PER_ROW)
{
state.current_cpu_row.is_bootstrap_kernel = F::ONE;
for chunk in &KERNEL.code.iter().enumerate().chunks(NUM_GP_CHANNELS) {
let mut cpu_row = CpuColumnsView::default();
cpu_row.clock = F::from_canonical_usize(state.traces.clock());
cpu_row.is_bootstrap_kernel = F::ONE;
// Write this chunk to memory, while simultaneously packing its bytes into a u32 word.
let mut packed_bytes: u32 = 0;
for (channel, (addr, &byte)) in chunk.enumerate() {
state.set_mem_cpu_current(channel, Segment::Code, addr, byte.into());
packed_bytes = (packed_bytes << 8) | byte as u32;
let address = MemoryAddress::new(0, Segment::Code, addr);
let write =
mem_write_gp_log_and_fill(channel, address, state, &mut cpu_row, byte.into());
state.traces.push_memory(write);
}
sponge_state[sponge_input_pos] = packed_bytes;
let keccak = state.current_cpu_row.general.keccak_mut();
keccak.input_limbs = sponge_state.map(F::from_canonical_u32);
state.commit_cpu_row();
sponge_input_pos = (sponge_input_pos + 1) % KECCAK_RATE_U32S;
// If we just crossed a multiple of KECCAK_RATE_LIMBS, then we've filled the Keccak input
// buffer, so it's time to absorb.
if sponge_input_pos == 0 {
state.current_cpu_row.is_keccak = F::ONE;
keccakf_u32s(&mut sponge_state);
let keccak = state.current_cpu_row.general.keccak_mut();
keccak.output_limbs = sponge_state.map(F::from_canonical_u32);
}
state.traces.push_cpu(cpu_row);
}
let mut final_cpu_row = CpuColumnsView::default();
final_cpu_row.clock = F::from_canonical_usize(state.traces.clock());
final_cpu_row.is_bootstrap_kernel = F::ONE;
final_cpu_row.is_keccak_sponge = F::ONE;
// The Keccak sponge CTL uses memory value columns for its inputs and outputs.
final_cpu_row.mem_channels[0].value[0] = F::ZERO;
final_cpu_row.mem_channels[1].value[0] = F::from_canonical_usize(Segment::Code as usize);
final_cpu_row.mem_channels[2].value[0] = F::ZERO;
final_cpu_row.mem_channels[3].value[0] = F::from_canonical_usize(state.traces.clock());
final_cpu_row.mem_channels[4].value = KERNEL.code_hash.map(F::from_canonical_u32);
state.traces.push_cpu(final_cpu_row);
keccak_sponge_log(
state,
MemoryAddress::new(0, Segment::Code, 0),
KERNEL.code.clone(),
);
}
pub(crate) fn eval_bootstrap_kernel<F: Field, P: PackedField<Scalar = F>>(
@ -77,19 +71,25 @@ pub(crate) fn eval_bootstrap_kernel<F: Field, P: PackedField<Scalar = F>>(
let delta_is_bootstrap = next_is_bootstrap - local_is_bootstrap;
yield_constr.constraint_transition(delta_is_bootstrap * (delta_is_bootstrap + P::ONES));
// TODO: Constraints to enforce that, if IS_BOOTSTRAP_KERNEL,
// - If CLOCK is a multiple of KECCAK_RATE_LIMBS, activate the Keccak CTL, and ensure the output
// is copied to the next row (besides the first limb which will immediately be overwritten).
// - Otherwise, ensure that the Keccak input is copied to the next row (besides the next limb).
// - The next limb we add to the buffer is also written to memory.
// If this is a bootloading row and the i'th memory channel is used, it must have the right
// address, namely context = 0, segment = Code, virt = clock * NUM_GP_CHANNELS + i.
let code_segment = F::from_canonical_usize(Segment::Code as usize);
for (i, channel) in local_values.mem_channels.iter().enumerate() {
let filter = local_is_bootstrap * channel.used;
yield_constr.constraint(filter * channel.addr_context);
yield_constr.constraint(filter * (channel.addr_segment - code_segment));
let expected_virt = local_values.clock * F::from_canonical_usize(NUM_GP_CHANNELS)
+ F::from_canonical_usize(i);
yield_constr.constraint(filter * (channel.addr_virtual - expected_virt));
}
// If IS_BOOTSTRAP_KERNEL changed (from 1 to 0), check that
// - the clock is a multiple of KECCAK_RATE_LIMBS (TODO)
// If this is the final bootstrap row (i.e. delta_is_bootstrap = 1), check that
// - all memory channels are disabled (TODO)
// - the current kernel hash matches a precomputed one
for (&expected, actual) in KERNEL
.code_hash
.iter()
.zip(local_values.general.keccak().output_limbs)
.zip(local_values.mem_channels.last().unwrap().value)
{
let expected = P::from(F::from_canonical_u32(expected));
let diff = expected - actual;
@ -117,19 +117,35 @@ pub(crate) fn eval_bootstrap_kernel_circuit<F: RichField + Extendable<D>, const
builder.mul_add_extension(delta_is_bootstrap, delta_is_bootstrap, delta_is_bootstrap);
yield_constr.constraint_transition(builder, constraint);
// TODO: Constraints to enforce that, if IS_BOOTSTRAP_KERNEL,
// - If CLOCK is a multiple of KECCAK_RATE_LIMBS, activate the Keccak CTL, and ensure the output
// is copied to the next row (besides the first limb which will immediately be overwritten).
// - Otherwise, ensure that the Keccak input is copied to the next row (besides the next limb).
// - The next limb we add to the buffer is also written to memory.
// If this is a bootloading row and the i'th memory channel is used, it must have the right
// address, namely context = 0, segment = Code, virt = clock * NUM_GP_CHANNELS + i.
let code_segment =
builder.constant_extension(F::Extension::from_canonical_usize(Segment::Code as usize));
for (i, channel) in local_values.mem_channels.iter().enumerate() {
let filter = builder.mul_extension(local_is_bootstrap, channel.used);
let constraint = builder.mul_extension(filter, channel.addr_context);
yield_constr.constraint(builder, constraint);
// If IS_BOOTSTRAP_KERNEL changed (from 1 to 0), check that
// - the clock is a multiple of KECCAK_RATE_LIMBS (TODO)
let segment_diff = builder.sub_extension(channel.addr_segment, code_segment);
let constraint = builder.mul_extension(filter, segment_diff);
yield_constr.constraint(builder, constraint);
let i_ext = builder.constant_extension(F::Extension::from_canonical_usize(i));
let num_gp_channels_f = F::from_canonical_usize(NUM_GP_CHANNELS);
let expected_virt =
builder.mul_const_add_extension(num_gp_channels_f, local_values.clock, i_ext);
let virt_diff = builder.sub_extension(channel.addr_virtual, expected_virt);
let constraint = builder.mul_extension(filter, virt_diff);
yield_constr.constraint(builder, constraint);
}
// If this is the final bootstrap row (i.e. delta_is_bootstrap = 1), check that
// - all memory channels are disabled (TODO)
// - the current kernel hash matches a precomputed one
for (&expected, actual) in KERNEL
.code_hash
.iter()
.zip(local_values.general.keccak().output_limbs)
.zip(local_values.mem_channels.last().unwrap().value)
{
let expected = builder.constant_extension(F::Extension::from_canonical_u32(expected));
let diff = builder.sub_extension(expected, actual);

View File

@ -4,8 +4,8 @@ use std::mem::{size_of, transmute};
/// General purpose columns, which can have different meanings depending on what CTL or other
/// operation is occurring at this row.
#[derive(Clone, Copy)]
pub(crate) union CpuGeneralColumnsView<T: Copy> {
keccak: CpuKeccakView<T>,
arithmetic: CpuArithmeticView<T>,
logic: CpuLogicView<T>,
jumps: CpuJumpsView<T>,
@ -13,16 +13,6 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
}
impl<T: Copy> CpuGeneralColumnsView<T> {
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn keccak(&self) -> &CpuKeccakView<T> {
unsafe { &self.keccak }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn keccak_mut(&mut self) -> &mut CpuKeccakView<T> {
unsafe { &mut self.keccak }
}
// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn arithmetic(&self) -> &CpuArithmeticView<T> {
unsafe { &self.arithmetic }
@ -93,12 +83,6 @@ impl<T: Copy> BorrowMut<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView<T> {
}
}
#[derive(Copy, Clone)]
pub(crate) struct CpuKeccakView<T: Copy> {
pub(crate) input_limbs: [T; 50],
pub(crate) output_limbs: [T; 50],
}
#[derive(Copy, Clone)]
pub(crate) struct CpuArithmeticView<T: Copy> {
// TODO: Add "looking" columns for the arithmetic CTL.

View File

@ -6,6 +6,8 @@ use std::fmt::Debug;
use std::mem::{size_of, transmute};
use std::ops::{Index, IndexMut};
use plonky2::field::types::Field;
use crate::cpu::columns::general::CpuGeneralColumnsView;
use crate::cpu::columns::ops::OpsColumnsView;
use crate::cpu::membus::NUM_GP_CHANNELS;
@ -31,7 +33,7 @@ pub struct MemoryChannelView<T: Copy> {
}
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub struct CpuColumnsView<T: Copy> {
/// Filter. 1 if the row is part of bootstrapping the kernel code, 0 otherwise.
pub is_bootstrap_kernel: T,
@ -67,11 +69,8 @@ pub struct CpuColumnsView<T: Copy> {
/// If CPU cycle: the opcode, broken up into bits in little-endian order.
pub opcode_bits: [T; 8],
/// Filter. 1 iff a Keccak lookup is performed on this row.
pub is_keccak: T,
/// Filter. 1 iff a Keccak memory lookup is performed on this row.
pub is_keccak_memory: T,
/// Filter. 1 iff a Keccak sponge lookup is performed on this row.
pub is_keccak_sponge: T,
pub(crate) general: CpuGeneralColumnsView<T>,
@ -82,6 +81,12 @@ pub struct CpuColumnsView<T: Copy> {
// `u8` is guaranteed to have a `size_of` of 1.
pub const NUM_CPU_COLUMNS: usize = size_of::<CpuColumnsView<u8>>();
impl<F: Field> Default for CpuColumnsView<F> {
fn default() -> Self {
Self::from([F::ZERO; NUM_CPU_COLUMNS])
}
}
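The `size_of::<CpuColumnsView<u8>>()` pattern above is worth a gloss: a `#[repr(C)]` struct instantiated with `u8` fields (size 1, align 1) has no padding, so its total size equals its number of scalar fields. A minimal standalone sketch of the same trick, with a hypothetical miniature struct:

```rust
use std::mem::size_of;

// Hypothetical miniature of the column-counting trick: with `u8` fields
// (size 1, align 1) there is no padding, so `size_of` counts scalar fields.
#[repr(C)]
#[allow(dead_code)]
struct MiniColumnsView<T: Copy> {
    flag: T,
    limbs: [T; 3],
}

const NUM_MINI_COLUMNS: usize = size_of::<MiniColumnsView<u8>>();

fn main() {
    assert_eq!(NUM_MINI_COLUMNS, 4); // 1 flag + 3 limbs
}
```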
impl<T: Copy> From<[T; NUM_CPU_COLUMNS]> for CpuColumnsView<T> {
fn from(value: [T; NUM_CPU_COLUMNS]) -> Self {
unsafe { transmute_no_compile_time_size_checks(value) }

View File

@ -5,8 +5,8 @@ use std::ops::{Deref, DerefMut};
use crate::util::{indices_arr, transmute_no_compile_time_size_checks};
#[repr(C)]
#[derive(Eq, PartialEq, Debug)]
pub struct OpsColumnsView<T> {
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub struct OpsColumnsView<T: Copy> {
// TODO: combine ADD, MUL, SUB, DIV, MOD, ADDFP254, MULFP254, SUBFP254, LT, and GT into one flag
pub add: T,
pub mul: T,
@ -41,12 +41,6 @@ pub struct OpsColumnsView<T> {
pub pc: T,
pub gas: T,
pub jumpdest: T,
// TODO: combine GET_STATE_ROOT and SET_STATE_ROOT into one flag
pub get_state_root: T,
pub set_state_root: T,
// TODO: combine GET_RECEIPT_ROOT and SET_RECEIPT_ROOT into one flag
pub get_receipt_root: T,
pub set_receipt_root: T,
pub push: T,
pub dup: T,
pub swap: T,
@ -65,38 +59,38 @@ pub struct OpsColumnsView<T> {
// `u8` is guaranteed to have a `size_of` of 1.
pub const NUM_OPS_COLUMNS: usize = size_of::<OpsColumnsView<u8>>();
impl<T> From<[T; NUM_OPS_COLUMNS]> for OpsColumnsView<T> {
impl<T: Copy> From<[T; NUM_OPS_COLUMNS]> for OpsColumnsView<T> {
fn from(value: [T; NUM_OPS_COLUMNS]) -> Self {
unsafe { transmute_no_compile_time_size_checks(value) }
}
}
impl<T> From<OpsColumnsView<T>> for [T; NUM_OPS_COLUMNS] {
impl<T: Copy> From<OpsColumnsView<T>> for [T; NUM_OPS_COLUMNS] {
fn from(value: OpsColumnsView<T>) -> Self {
unsafe { transmute_no_compile_time_size_checks(value) }
}
}
impl<T> Borrow<OpsColumnsView<T>> for [T; NUM_OPS_COLUMNS] {
impl<T: Copy> Borrow<OpsColumnsView<T>> for [T; NUM_OPS_COLUMNS] {
fn borrow(&self) -> &OpsColumnsView<T> {
unsafe { transmute(self) }
}
}
impl<T> BorrowMut<OpsColumnsView<T>> for [T; NUM_OPS_COLUMNS] {
impl<T: Copy> BorrowMut<OpsColumnsView<T>> for [T; NUM_OPS_COLUMNS] {
fn borrow_mut(&mut self) -> &mut OpsColumnsView<T> {
unsafe { transmute(self) }
}
}
impl<T> Deref for OpsColumnsView<T> {
impl<T: Copy> Deref for OpsColumnsView<T> {
type Target = [T; NUM_OPS_COLUMNS];
fn deref(&self) -> &Self::Target {
unsafe { transmute(self) }
}
}
impl<T> DerefMut for OpsColumnsView<T> {
impl<T: Copy> DerefMut for OpsColumnsView<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { transmute(self) }
}

View File

@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
use crate::cpu::kernel::aggregator::KERNEL;
const NATIVE_INSTRUCTIONS: [usize; 37] = [
const NATIVE_INSTRUCTIONS: [usize; 33] = [
COL_MAP.op.add,
COL_MAP.op.mul,
COL_MAP.op.sub,
@ -37,10 +37,6 @@ const NATIVE_INSTRUCTIONS: [usize; 37] = [
COL_MAP.op.pc,
COL_MAP.op.gas,
COL_MAP.op.jumpdest,
COL_MAP.op.get_state_root,
COL_MAP.op.set_state_root,
COL_MAP.op.get_receipt_root,
COL_MAP.op.set_receipt_root,
// not PUSH (need to increment by more than 1)
COL_MAP.op.dup,
COL_MAP.op.swap,
@ -53,7 +49,7 @@ const NATIVE_INSTRUCTIONS: [usize; 37] = [
// not SYSCALL (performs a jump)
];
fn get_halt_pcs<F: Field>() -> (F, F) {
pub(crate) fn get_halt_pcs<F: Field>() -> (F, F) {
let halt_pc0 = KERNEL.global_labels["halt_pc0"];
let halt_pc1 = KERNEL.global_labels["halt_pc1"];
@ -63,6 +59,12 @@ fn get_halt_pcs<F: Field>() -> (F, F) {
)
}
pub(crate) fn get_start_pc<F: Field>() -> F {
let start_pc = KERNEL.global_labels["main"];
F::from_canonical_usize(start_pc)
}
pub fn eval_packed_generic<P: PackedField>(
lv: &CpuColumnsView<P>,
nv: &CpuColumnsView<P>,
@ -89,8 +91,7 @@ pub fn eval_packed_generic<P: PackedField>(
// - execution is in kernel mode, and
// - the stack is empty.
let is_last_noncpu_cycle = (lv.is_cpu_cycle - P::ONES) * nv.is_cpu_cycle;
let pc_diff =
nv.program_counter - P::Scalar::from_canonical_usize(KERNEL.global_labels["main"]);
let pc_diff = nv.program_counter - get_start_pc::<P::Scalar>();
yield_constr.constraint_transition(is_last_noncpu_cycle * pc_diff);
yield_constr.constraint_transition(is_last_noncpu_cycle * (nv.is_kernel_mode - P::ONES));
yield_constr.constraint_transition(is_last_noncpu_cycle * nv.stack_len);
@ -142,9 +143,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder.mul_sub_extension(lv.is_cpu_cycle, nv.is_cpu_cycle, nv.is_cpu_cycle);
// Start at `main`.
let main = builder.constant_extension(F::Extension::from_canonical_usize(
KERNEL.global_labels["main"],
));
let main = builder.constant_extension(get_start_pc::<F>().into());
let pc_diff = builder.sub_extension(nv.program_counter, main);
let pc_constr = builder.mul_extension(is_last_noncpu_cycle, pc_diff);
yield_constr.constraint_transition(builder, pc_constr);

View File

@ -20,34 +20,28 @@ use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
pub fn ctl_data_keccak<F: Field>() -> Vec<Column<F>> {
let keccak = COL_MAP.general.keccak();
let mut res: Vec<_> = Column::singles(keccak.input_limbs).collect();
res.extend(Column::singles(keccak.output_limbs));
res
}
pub fn ctl_data_keccak_memory<F: Field>() -> Vec<Column<F>> {
pub fn ctl_data_keccak_sponge<F: Field>() -> Vec<Column<F>> {
// When executing KECCAK_GENERAL, the GP memory channels are used as follows:
// GP channel 0: stack[-1] = context
// GP channel 1: stack[-2] = segment
// GP channel 2: stack[-3] = virtual
// GP channel 2: stack[-3] = virt
// GP channel 3: stack[-4] = len
// GP channel 4: pushed = outputs
let context = Column::single(COL_MAP.mem_channels[0].value[0]);
let segment = Column::single(COL_MAP.mem_channels[1].value[0]);
let virt = Column::single(COL_MAP.mem_channels[2].value[0]);
let len = Column::single(COL_MAP.mem_channels[3].value[0]);
let num_channels = F::from_canonical_usize(NUM_CHANNELS);
let clock = Column::linear_combination([(COL_MAP.clock, num_channels)]);
let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]);
vec![context, segment, virt, clock]
let mut cols = vec![context, segment, virt, len, timestamp];
cols.extend(COL_MAP.mem_channels[3].value.map(Column::single));
cols
}
pub fn ctl_filter_keccak<F: Field>() -> Column<F> {
Column::single(COL_MAP.is_keccak)
}
pub fn ctl_filter_keccak_memory<F: Field>() -> Column<F> {
Column::single(COL_MAP.is_keccak_memory)
pub fn ctl_filter_keccak_sponge<F: Field>() -> Column<F> {
Column::single(COL_MAP.is_keccak_sponge)
}
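The `clock * NUM_CHANNELS` scaling in the `timestamp` column gives every cycle a block of `NUM_CHANNELS` timestamp slots, so each memory operation within a cycle can sit at a distinct, strictly increasing timestamp. A standalone sketch of the numbering, assuming operations are stamped `clock * NUM_CHANNELS + channel` (the `NUM_CHANNELS = 8` here is illustrative only):

```rust
// Sketch of the timestamp numbering, assuming memory operations are stamped
// clock * NUM_CHANNELS + channel (NUM_CHANNELS = 8 is illustrative only).
fn main() {
    const NUM_CHANNELS: usize = 8;
    let mut prev = None;
    for clock in 0..4 {
        for channel in 0..NUM_CHANNELS {
            let timestamp = clock * NUM_CHANNELS + channel;
            // Every operation gets a distinct, strictly increasing slot.
            assert!(prev.map_or(true, |p| timestamp > p));
            prev = Some(timestamp);
        }
    }
}
```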
pub fn ctl_data_logic<F: Field>() -> Vec<Column<F>> {
@ -122,11 +116,11 @@ pub struct CpuStark<F, const D: usize> {
}
impl<F: RichField, const D: usize> CpuStark<F, D> {
// TODO: Remove?
pub fn generate(&self, local_values: &mut [F; NUM_CPU_COLUMNS]) {
let local_values: &mut CpuColumnsView<_> = local_values.borrow_mut();
decode::generate(local_values);
membus::generate(local_values);
simple_logic::generate(local_values);
stack_bounds::generate(local_values); // Must come after `decode`.
}
}
@ -144,17 +138,19 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
{
let local_values = vars.local_values.borrow();
let next_values = vars.next_values.borrow();
// TODO: Some failing constraints temporarily disabled by using this dummy consumer.
let mut dummy_yield_constr = ConstraintConsumer::new(vec![], P::ZEROS, P::ZEROS, P::ZEROS);
bootstrap_kernel::eval_bootstrap_kernel(vars, yield_constr);
control_flow::eval_packed_generic(local_values, next_values, yield_constr);
decode::eval_packed_generic(local_values, yield_constr);
dup_swap::eval_packed(local_values, yield_constr);
jumps::eval_packed(local_values, next_values, yield_constr);
jumps::eval_packed(local_values, next_values, &mut dummy_yield_constr);
membus::eval_packed(local_values, yield_constr);
modfp254::eval_packed(local_values, yield_constr);
shift::eval_packed(local_values, yield_constr);
simple_logic::eval_packed(local_values, yield_constr);
stack::eval_packed(local_values, yield_constr);
stack_bounds::eval_packed(local_values, yield_constr);
stack_bounds::eval_packed(local_values, &mut dummy_yield_constr);
syscalls::eval_packed(local_values, next_values, yield_constr);
}
@ -166,17 +162,21 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
) {
let local_values = vars.local_values.borrow();
let next_values = vars.next_values.borrow();
// TODO: Some failing constraints temporarily disabled by using this dummy consumer.
let zero = builder.zero_extension();
let mut dummy_yield_constr =
RecursiveConstraintConsumer::new(zero, vec![], zero, zero, zero);
bootstrap_kernel::eval_bootstrap_kernel_circuit(builder, vars, yield_constr);
control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr);
decode::eval_ext_circuit(builder, local_values, yield_constr);
dup_swap::eval_ext_circuit(builder, local_values, yield_constr);
jumps::eval_ext_circuit(builder, local_values, next_values, yield_constr);
jumps::eval_ext_circuit(builder, local_values, next_values, &mut dummy_yield_constr);
membus::eval_ext_circuit(builder, local_values, yield_constr);
modfp254::eval_ext_circuit(builder, local_values, yield_constr);
shift::eval_ext_circuit(builder, local_values, yield_constr);
simple_logic::eval_ext_circuit(builder, local_values, yield_constr);
stack::eval_ext_circuit(builder, local_values, yield_constr);
stack_bounds::eval_ext_circuit(builder, local_values, yield_constr);
stack_bounds::eval_ext_circuit(builder, local_values, &mut dummy_yield_constr);
syscalls::eval_ext_circuit(builder, local_values, next_values, yield_constr);
}

View File

@ -22,7 +22,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP};
/// behavior.
/// Note: invalid opcodes are not represented here. _Any_ opcode is permitted to decode to
/// `is_invalid`. The kernel then verifies that the opcode was _actually_ invalid.
const OPCODES: [(u8, usize, bool, usize); 42] = [
const OPCODES: [(u8, usize, bool, usize); 38] = [
// (start index of block, number of top bits to check (log2), kernel-only, flag column)
(0x01, 0, false, COL_MAP.op.add),
(0x02, 0, false, COL_MAP.op.mul),
@ -53,10 +53,6 @@ const OPCODES: [(u8, usize, bool, usize); 42] = [
(0x58, 0, false, COL_MAP.op.pc),
(0x5a, 0, false, COL_MAP.op.gas),
(0x5b, 0, false, COL_MAP.op.jumpdest),
(0x5c, 0, true, COL_MAP.op.get_state_root),
(0x5d, 0, true, COL_MAP.op.set_state_root),
(0x5e, 0, true, COL_MAP.op.get_receipt_root),
(0x5f, 0, true, COL_MAP.op.set_receipt_root),
(0x60, 5, false, COL_MAP.op.push), // 0x60-0x7f
(0x80, 4, false, COL_MAP.op.dup), // 0x80-0x8f
(0x90, 4, false, COL_MAP.op.swap), // 0x90-0x9f
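Reading the second field of each entry as the log2 of the block size, a match only has to compare the top `8 - log2_block_size` bits of the opcode. A standalone sketch of that matching rule (`block_matches` is a hypothetical helper, not the actual constraint code):

```rust
// Hypothetical helper illustrating the block-matching rule the table encodes:
// `(0x60, 5, ..)` covers 0x60..=0x7f, i.e. every opcode whose top 3 bits
// match those of 0x60.
fn block_matches(opcode: u8, start: u8, log2_block_size: u32) -> bool {
    (opcode >> log2_block_size) == (start >> log2_block_size)
}

fn main() {
    assert!(block_matches(0x7f, 0x60, 5)); // PUSH32 is in the PUSH block
    assert!(!block_matches(0x80, 0x60, 5)); // DUP1 is not
    assert!(block_matches(0x8f, 0x80, 4)); // DUP16 is in the DUP block
}
```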

View File

@ -73,4 +73,4 @@ recursion_return:
jump
global sys_exp:
PANIC
PANIC // TODO: Implement.

View File

@ -1,8 +1,9 @@
global main:
// First, initialise the shift table
%shift_table_init
// Second, load all MPT data from the prover.
PUSH txn_loop
PUSH hash_initial_tries
%jump(load_all_mpts)
hash_initial_tries:

View File

@ -2,19 +2,17 @@ use std::collections::HashMap;
use ethereum_types::U256;
use itertools::izip;
use keccak_hash::keccak;
use log::debug;
use plonky2_util::ceil_div_usize;
use super::ast::PushTarget;
use crate::cpu::kernel::ast::Item::LocalLabelDeclaration;
use crate::cpu::kernel::ast::{File, Item, StackReplacement};
use crate::cpu::kernel::keccak_util::hash_kernel;
use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode};
use crate::cpu::kernel::optimizer::optimize_asm;
use crate::cpu::kernel::stack::stack_manipulation::expand_stack_manipulation;
use crate::cpu::kernel::utils::u256_to_trimmed_be_bytes;
use crate::generation::prover_input::ProverInputFn;
use crate::keccak_sponge::columns::KECCAK_RATE_BYTES;
/// The number of bytes to push when pushing an offset within the code (i.e. when assembling jumps).
/// Ideally we would automatically use the minimal number of bytes required, but that would be
@ -41,8 +39,10 @@ impl Kernel {
global_labels: HashMap<String, usize>,
prover_inputs: HashMap<usize, ProverInputFn>,
) -> Self {
let code_hash = hash_kernel(&Self::padded_code_helper(&code));
let code_hash_bytes = keccak(&code).0;
let code_hash = std::array::from_fn(|i| {
u32::from_le_bytes(std::array::from_fn(|j| code_hash_bytes[i * 4 + j]))
});
Self {
code,
code_hash,
@ -51,16 +51,16 @@ impl Kernel {
}
}
/// Zero-pads the code such that its length is a multiple of the Keccak rate.
pub(crate) fn padded_code(&self) -> Vec<u8> {
Self::padded_code_helper(&self.code)
/// Get a string representation of the current offset for debugging purposes.
pub(crate) fn offset_name(&self, offset: usize) -> String {
self.offset_label(offset)
.unwrap_or_else(|| offset.to_string())
}
fn padded_code_helper(code: &[u8]) -> Vec<u8> {
let padded_len = ceil_div_usize(code.len(), KECCAK_RATE_BYTES) * KECCAK_RATE_BYTES;
let mut padded_code = code.to_vec();
padded_code.resize(padded_len, 0);
padded_code
pub(crate) fn offset_label(&self, offset: usize) -> Option<String> {
self.global_labels
.iter()
.find_map(|(k, v)| (*v == offset).then(|| k.clone()))
}
}
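The new `code_hash` above is plain Keccak-256 of the raw, unpadded code, repacked into eight little-endian `u32` limbs. A standalone sketch of the repacking, using the `keccak_hash` crate already imported above and round-tripping the digest to show the limb layout:

```rust
use keccak_hash::keccak;

fn main() {
    let code = b"example kernel code";
    let digest = keccak(code).0; // [u8; 32]

    // Pack the digest into eight little-endian u32 limbs, as in `Kernel::new`.
    let code_hash: [u32; 8] =
        std::array::from_fn(|i| u32::from_le_bytes(std::array::from_fn(|j| digest[i * 4 + j])));

    // Round-trip back to bytes to confirm the limb layout.
    let mut rebuilt = [0u8; 32];
    for (i, limb) in code_hash.iter().enumerate() {
        rebuilt[i * 4..][..4].copy_from_slice(&limb.to_le_bytes());
    }
    assert_eq!(rebuilt, digest);
}
```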

View File

@ -11,43 +11,21 @@ use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField;
use crate::generation::memory::{MemoryContextState, MemorySegmentState};
use crate::generation::prover_input::ProverInputFn;
use crate::generation::state::GenerationState;
use crate::generation::GenerationInputs;
use crate::memory::segments::Segment;
use crate::witness::memory::{MemoryContextState, MemorySegmentState, MemoryState};
use crate::witness::util::stack_peek;
type F = GoldilocksField;
/// Halt interpreter execution whenever a jump to this offset is done.
const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef;
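The sentinel is used by pushing `0xdeadbeef` as a fake return address before running a kernel routine, so the routine's final jump back to it halts the interpreter rather than running off into unrelated code. A standalone toy of the mechanism, mirroring the `halt_offsets` check in `run_jump`:

```rust
// Standalone toy of the sentinel: a jump target listed in `halt_offsets`
// flips `running` off, exactly like `run_jump` does for 0xdeadbeef.
fn main() {
    let halt_offsets = [0xdeadbeefusize];
    let mut running = true;
    let mut program_counter = 0usize;

    // A routine "returns" by jumping to the fake return address we pushed.
    for &jump_target in &[8usize, 0xdeadbeef] {
        program_counter = jump_target;
        if halt_offsets.contains(&program_counter) {
            running = false;
            break;
        }
    }
    assert!(!running);
}
```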
#[derive(Clone, Debug)]
pub(crate) struct InterpreterMemory {
pub(crate) context_memory: Vec<MemoryContextState>,
}
impl Default for InterpreterMemory {
fn default() -> Self {
Self {
context_memory: vec![MemoryContextState::default()],
}
}
}
impl InterpreterMemory {
fn with_code_and_stack(code: &[u8], stack: Vec<U256>) -> Self {
let mut mem = Self::default();
for (i, b) in code.iter().copied().enumerate() {
mem.context_memory[0].segments[Segment::Code as usize].set(i, b.into());
}
mem.context_memory[0].segments[Segment::Stack as usize].content = stack;
mem
}
impl MemoryState {
fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 {
let value = self.context_memory[context].segments[segment as usize].get(offset);
let value = self.contexts[context].segments[segment as usize].get(offset);
assert!(
value.bits() <= segment.bit_range(),
"Value read from memory exceeds expected range of {:?} segment",
@ -62,16 +40,14 @@ impl InterpreterMemory {
"Value written to memory exceeds expected range of {:?} segment",
segment
);
self.context_memory[context].segments[segment as usize].set(offset, value)
self.contexts[context].segments[segment as usize].set(offset, value)
}
}
pub struct Interpreter<'a> {
kernel_mode: bool,
jumpdests: Vec<usize>,
pub(crate) offset: usize,
pub(crate) context: usize,
pub(crate) memory: InterpreterMemory,
pub(crate) generation_state: GenerationState<F>,
prover_inputs_map: &'a HashMap<usize, ProverInputFn>,
pub(crate) halt_offsets: Vec<usize>,
@ -119,19 +95,21 @@ impl<'a> Interpreter<'a> {
initial_stack: Vec<U256>,
prover_inputs: &'a HashMap<usize, ProverInputFn>,
) -> Self {
Self {
let mut result = Self {
kernel_mode: true,
jumpdests: find_jumpdests(code),
offset: initial_offset,
memory: InterpreterMemory::with_code_and_stack(code, initial_stack),
generation_state: GenerationState::new(GenerationInputs::default()),
generation_state: GenerationState::new(GenerationInputs::default(), code),
prover_inputs_map: prover_inputs,
context: 0,
halt_offsets: vec![DEFAULT_HALT_OFFSET],
debug_offsets: vec![],
running: false,
opcode_count: [0; 0x100],
}
};
result.generation_state.registers.program_counter = initial_offset;
result.generation_state.registers.stack_len = initial_stack.len();
*result.stack_mut() = initial_stack;
result
}
pub(crate) fn run(&mut self) -> anyhow::Result<()> {
@ -152,48 +130,51 @@ impl<'a> Interpreter<'a> {
}
fn code(&self) -> &MemorySegmentState {
&self.memory.context_memory[self.context].segments[Segment::Code as usize]
&self.generation_state.memory.contexts[self.context].segments[Segment::Code as usize]
}
fn code_slice(&self, n: usize) -> Vec<u8> {
self.code().content[self.offset..self.offset + n]
let pc = self.generation_state.registers.program_counter;
self.code().content[pc..pc + n]
.iter()
.map(|u256| u256.byte(0))
.collect::<Vec<_>>()
}
pub(crate) fn get_txn_field(&self, field: NormalizedTxnField) -> U256 {
self.memory.context_memory[0].segments[Segment::TxnFields as usize].get(field as usize)
self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize]
.get(field as usize)
}
pub(crate) fn set_txn_field(&mut self, field: NormalizedTxnField, value: U256) {
self.memory.context_memory[0].segments[Segment::TxnFields as usize]
self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize]
.set(field as usize, value);
}
pub(crate) fn get_txn_data(&self) -> &[U256] {
&self.memory.context_memory[0].segments[Segment::TxnData as usize].content
&self.generation_state.memory.contexts[0].segments[Segment::TxnData as usize].content
}
pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 {
self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize].get(field as usize)
self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize]
.get(field as usize)
}
pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) {
self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize]
self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize]
.set(field as usize, value)
}
pub(crate) fn get_trie_data(&self) -> &[U256] {
&self.memory.context_memory[0].segments[Segment::TrieData as usize].content
&self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content
}
pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec<U256> {
&mut self.memory.context_memory[0].segments[Segment::TrieData as usize].content
&mut self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content
}
pub(crate) fn get_rlp_memory(&self) -> Vec<u8> {
self.memory.context_memory[0].segments[Segment::RlpRaw as usize]
self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize]
.content
.iter()
.map(|x| x.as_u32() as u8)
@ -201,23 +182,24 @@ impl<'a> Interpreter<'a> {
}
pub(crate) fn set_rlp_memory(&mut self, rlp: Vec<u8>) {
self.memory.context_memory[0].segments[Segment::RlpRaw as usize].content =
self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize].content =
rlp.into_iter().map(U256::from).collect();
}
pub(crate) fn set_code(&mut self, context: usize, code: Vec<u8>) {
assert_ne!(context, 0, "Can't modify kernel code.");
while self.memory.context_memory.len() <= context {
self.memory
.context_memory
while self.generation_state.memory.contexts.len() <= context {
self.generation_state
.memory
.contexts
.push(MemoryContextState::default());
}
self.memory.context_memory[context].segments[Segment::Code as usize].content =
self.generation_state.memory.contexts[context].segments[Segment::Code as usize].content =
code.into_iter().map(U256::from).collect();
}
pub(crate) fn get_jumpdest_bits(&self, context: usize) -> Vec<bool> {
self.memory.context_memory[context].segments[Segment::JumpdestBits as usize]
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content
.iter()
.map(|x| x.bit(0))
@ -225,19 +207,22 @@ impl<'a> Interpreter<'a> {
}
fn incr(&mut self, n: usize) {
self.offset += n;
self.generation_state.registers.program_counter += n;
}
pub(crate) fn stack(&self) -> &[U256] {
&self.memory.context_memory[self.context].segments[Segment::Stack as usize].content
&self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize]
.content
}
fn stack_mut(&mut self) -> &mut Vec<U256> {
&mut self.memory.context_memory[self.context].segments[Segment::Stack as usize].content
&mut self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize]
.content
}
pub(crate) fn push(&mut self, x: U256) {
self.stack_mut().push(x);
self.generation_state.registers.stack_len += 1;
}
fn push_bool(&mut self, x: bool) {
@ -245,11 +230,18 @@ impl<'a> Interpreter<'a> {
}
pub(crate) fn pop(&mut self) -> U256 {
self.stack_mut().pop().expect("Pop on empty stack.")
let result = stack_peek(&self.generation_state, 0);
self.generation_state.registers.stack_len -= 1;
let new_len = self.stack_len();
self.stack_mut().truncate(new_len);
result.expect("Empty stack")
}
fn run_opcode(&mut self) -> anyhow::Result<()> {
let opcode = self.code().get(self.offset).byte(0);
let opcode = self
.code()
.get(self.generation_state.registers.program_counter)
.byte(0);
self.opcode_count[opcode as usize] += 1;
self.incr(1);
match opcode {
@ -321,10 +313,6 @@ impl<'a> Interpreter<'a> {
0x59 => self.run_msize(), // "MSIZE",
0x5a => todo!(), // "GAS",
0x5b => self.run_jumpdest(), // "JUMPDEST",
0x5c => todo!(), // "GET_STATE_ROOT",
0x5d => todo!(), // "SET_STATE_ROOT",
0x5e => todo!(), // "GET_RECEIPT_ROOT",
0x5f => todo!(), // "SET_RECEIPT_ROOT",
x if (0x60..0x80).contains(&x) => self.run_push(x - 0x5f), // "PUSH"
x if (0x80..0x90).contains(&x) => self.run_dup(x - 0x7f), // "DUP"
x if (0x90..0xa0).contains(&x) => self.run_swap(x - 0x8f)?, // "SWAP"
@ -353,7 +341,10 @@ impl<'a> Interpreter<'a> {
_ => bail!("Unrecognized opcode {}.", opcode),
};
if self.debug_offsets.contains(&self.offset) {
if self
.debug_offsets
.contains(&self.generation_state.registers.program_counter)
{
println!("At {}, stack={:?}", self.offset_name(), self.stack());
} else if let Some(label) = self.offset_label() {
println!("At {label}");
@ -362,18 +353,12 @@ impl<'a> Interpreter<'a> {
Ok(())
}
/// Get a string representation of the current offset for debugging purposes.
fn offset_name(&self) -> String {
self.offset_label()
.unwrap_or_else(|| self.offset.to_string())
KERNEL.offset_name(self.generation_state.registers.program_counter)
}
fn offset_label(&self) -> Option<String> {
// TODO: Not sure we should use KERNEL? Interpreter is more general in other places.
KERNEL
.global_labels
.iter()
.find_map(|(k, v)| (*v == self.offset).then(|| k.clone()))
KERNEL.offset_label(self.generation_state.registers.program_counter)
}
fn run_stop(&mut self) {
@ -508,12 +493,10 @@ impl<'a> Interpreter<'a> {
fn run_byte(&mut self) {
let i = self.pop();
let x = self.pop();
let result = if i > 32.into() {
0
let result = if i < 32.into() {
x.byte(31 - i.as_usize())
} else {
let mut bytes = [0; 32];
x.to_big_endian(&mut bytes);
bytes[i.as_usize()]
0
};
self.push(result.into());
}
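Besides restructuring, this also fixes the boundary: the old branch accepted `i == 32` and would then index `bytes[32]` out of range. EVM's `BYTE` counts from the most-significant byte while `U256::byte` counts from the least-significant one, hence the `31 - i` flip. A standalone check of the semantics:

```rust
use ethereum_types::U256;

fn main() {
    // EVM BYTE indexes from the most-significant end; `U256::byte` indexes
    // from the least-significant end, hence `31 - i` in `run_byte`.
    let x = U256::from_big_endian(&{
        let mut bytes = [0u8; 32];
        bytes[1] = 0xab; // byte index 1 from the big end
        bytes
    });
    assert_eq!(x.byte(31 - 1), 0xab);
    // Indices i >= 32 are defined to give 0 in the new branch.
}
```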
@ -535,7 +518,8 @@ impl<'a> Interpreter<'a> {
let size = self.pop().as_usize();
let bytes = (offset..offset + size)
.map(|i| {
self.memory
self.generation_state
.memory
.mload_general(self.context, Segment::MainMemory, i)
.byte(0)
})
@ -552,7 +536,12 @@ impl<'a> Interpreter<'a> {
let offset = self.pop().as_usize();
let size = self.pop().as_usize();
let bytes = (offset..offset + size)
.map(|i| self.memory.mload_general(context, segment, i).byte(0))
.map(|i| {
self.generation_state
.memory
.mload_general(context, segment, i)
.byte(0)
})
.collect::<Vec<_>>();
println!("Hashing {:?}", &bytes);
let hash = keccak(bytes);
@ -561,7 +550,8 @@ impl<'a> Interpreter<'a> {
fn run_callvalue(&mut self) {
self.push(
self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize]
self.generation_state.memory.contexts[self.context].segments
[Segment::ContextMetadata as usize]
.get(ContextMetadata::CallValue as usize),
)
}
@ -571,7 +561,8 @@ impl<'a> Interpreter<'a> {
let value = U256::from_big_endian(
&(0..32)
.map(|i| {
self.memory
self.generation_state
.memory
.mload_general(self.context, Segment::Calldata, offset + i)
.byte(0)
})
@ -582,7 +573,8 @@ impl<'a> Interpreter<'a> {
fn run_calldatasize(&mut self) {
self.push(
self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize]
self.generation_state.memory.contexts[self.context].segments
[Segment::ContextMetadata as usize]
.get(ContextMetadata::CalldataSize as usize),
)
}
@ -592,10 +584,12 @@ impl<'a> Interpreter<'a> {
let offset = self.pop().as_usize();
let size = self.pop().as_usize();
for i in 0..size {
let calldata_byte =
self.memory
.mload_general(self.context, Segment::Calldata, offset + i);
self.memory.mstore_general(
let calldata_byte = self.generation_state.memory.mload_general(
self.context,
Segment::Calldata,
offset + i,
);
self.generation_state.memory.mstore_general(
self.context,
Segment::MainMemory,
dest_offset + i,
@ -607,10 +601,9 @@ impl<'a> Interpreter<'a> {
fn run_prover_input(&mut self) -> anyhow::Result<()> {
let prover_input_fn = self
.prover_inputs_map
.get(&(self.offset - 1))
.get(&(self.generation_state.registers.program_counter - 1))
.ok_or_else(|| anyhow!("Offset not in prover inputs."))?;
let stack = self.stack().to_vec();
let output = self.generation_state.prover_input(&stack, prover_input_fn);
let output = self.generation_state.prover_input(prover_input_fn);
self.push(output);
Ok(())
}
@ -624,7 +617,8 @@ impl<'a> Interpreter<'a> {
let value = U256::from_big_endian(
&(0..32)
.map(|i| {
self.memory
self.generation_state
.memory
.mload_general(self.context, Segment::MainMemory, offset + i)
.byte(0)
})
@ -639,15 +633,19 @@ impl<'a> Interpreter<'a> {
let mut bytes = [0; 32];
value.to_big_endian(&mut bytes);
for (i, byte) in (0..32).zip(bytes) {
self.memory
.mstore_general(self.context, Segment::MainMemory, offset + i, byte.into());
self.generation_state.memory.mstore_general(
self.context,
Segment::MainMemory,
offset + i,
byte.into(),
);
}
}
fn run_mstore8(&mut self) {
let offset = self.pop().as_usize();
let value = self.pop();
self.memory.mstore_general(
self.generation_state.memory.mstore_general(
self.context,
Segment::MainMemory,
offset,
@ -669,12 +667,13 @@ impl<'a> Interpreter<'a> {
}
fn run_pc(&mut self) {
self.push((self.offset - 1).into());
self.push((self.generation_state.registers.program_counter - 1).into());
}
fn run_msize(&mut self) {
self.push(
self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize]
self.generation_state.memory.contexts[self.context].segments
[Segment::ContextMetadata as usize]
.get(ContextMetadata::MSize as usize),
)
}
@ -689,7 +688,7 @@ impl<'a> Interpreter<'a> {
panic!("Destination is not a JUMPDEST.");
}
self.offset = offset;
self.generation_state.registers.program_counter = offset;
if self.halt_offsets.contains(&offset) {
self.running = false;
@ -703,11 +702,11 @@ impl<'a> Interpreter<'a> {
}
fn run_dup(&mut self, n: u8) {
self.push(self.stack()[self.stack().len() - n as usize]);
self.push(self.stack()[self.stack_len() - n as usize]);
}
fn run_swap(&mut self, n: u8) -> anyhow::Result<()> {
let len = self.stack().len();
let len = self.stack_len();
ensure!(len > n as usize);
self.stack_mut().swap(len - 1, len - n as usize - 1);
Ok(())
@ -726,7 +725,10 @@ impl<'a> Interpreter<'a> {
let context = self.pop().as_usize();
let segment = Segment::all()[self.pop().as_usize()];
let offset = self.pop().as_usize();
let value = self.memory.mload_general(context, segment, offset);
let value = self
.generation_state
.memory
.mload_general(context, segment, offset);
assert!(value.bits() <= segment.bit_range());
self.push(value);
}
@ -743,7 +745,13 @@ impl<'a> Interpreter<'a> {
segment,
segment.bit_range()
);
self.memory.mstore_general(context, segment, offset, value);
self.generation_state
.memory
.mstore_general(context, segment, offset, value);
}
fn stack_len(&self) -> usize {
self.generation_state.registers.stack_len
}
}
@ -833,10 +841,6 @@ fn get_mnemonic(opcode: u8) -> &'static str {
0x59 => "MSIZE",
0x5a => "GAS",
0x5b => "JUMPDEST",
0x5c => "GET_STATE_ROOT",
0x5d => "SET_STATE_ROOT",
0x5e => "GET_RECEIPT_ROOT",
0x5f => "SET_RECEIPT_ROOT",
0x60 => "PUSH1",
0x61 => "PUSH2",
0x62 => "PUSH3",
@ -969,11 +973,13 @@ mod tests {
let run = run(&code, 0, vec![], &pis)?;
assert_eq!(run.stack(), &[0xff.into(), 0xff00.into()]);
assert_eq!(
run.memory.context_memory[0].segments[Segment::MainMemory as usize].get(0x27),
run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize]
.get(0x27),
0x42.into()
);
assert_eq!(
run.memory.context_memory[0].segments[Segment::MainMemory as usize].get(0x1f),
run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize]
.get(0x1f),
0xff.into()
);
Ok(())

View File

@ -1,31 +1,9 @@
use tiny_keccak::keccakf;
use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_RATE_U32S};
/// A Keccak-f based hash.
///
/// This hash does not use standard Keccak padding, since we don't care about extra zeros at the
/// end of the code. It also uses an overwrite-mode sponge, rather than a standard sponge where
/// inputs are xor'ed in.
pub(crate) fn hash_kernel(code: &[u8]) -> [u32; 8] {
debug_assert_eq!(
code.len() % KECCAK_RATE_BYTES,
0,
"Code should have been padded to a multiple of the Keccak rate."
);
let mut state = [0u32; 50];
for chunk in code.chunks(KECCAK_RATE_BYTES) {
for i in 0..KECCAK_RATE_U32S {
state[i] = u32::from_le_bytes(std::array::from_fn(|j| chunk[i * 4 + j]));
}
keccakf_u32s(&mut state);
}
state[..8].try_into().unwrap()
}
use crate::keccak_sponge::columns::{KECCAK_WIDTH_BYTES, KECCAK_WIDTH_U32S};
/// Like tiny-keccak's `keccakf`, but deals with `u32` limbs instead of `u64` limbs.
pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; 50]) {
pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; KECCAK_WIDTH_U32S]) {
let mut state_u64s: [u64; 25] = std::array::from_fn(|i| {
let lo = state_u32s[i * 2] as u64;
let hi = state_u32s[i * 2 + 1] as u64;
@ -39,6 +17,17 @@ pub(crate) fn keccakf_u32s(state_u32s: &mut [u32; 50]) {
});
}
/// Like tiny-keccak's `keccakf`, but deals with bytes instead of `u64` limbs.
pub(crate) fn keccakf_u8s(state_u8s: &mut [u8; KECCAK_WIDTH_BYTES]) {
let mut state_u64s: [u64; 25] =
std::array::from_fn(|i| u64::from_le_bytes(state_u8s[i * 8..][..8].try_into().unwrap()));
keccakf(&mut state_u64s);
*state_u8s = std::array::from_fn(|i| {
let u64_limb = state_u64s[i / 8];
u64_limb.to_le_bytes()[i % 8]
});
}
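Both wrappers are little-endian re-slicings of the same 25 `u64` lanes, so they must agree on a shared state. A hypothetical consistency check in the spirit of the `tests` module below (recall `KECCAK_WIDTH_BYTES = 200` and `KECCAK_WIDTH_U32S = 50`):

```rust
// Hypothetical consistency check: the u8 and u32 views of the Keccak state
// must agree after permutation, since both are little-endian slicings of
// the same 25 u64 lanes.
fn main() {
    let mut state_u8s = [0u8; 200];
    state_u8s[0] = 1;
    let mut state_u32s = [0u32; 50];
    state_u32s[0] = 1; // same starting state, viewed as u32 limbs

    keccakf_u8s(&mut state_u8s);
    keccakf_u32s(&mut state_u32s);

    for i in 0..50 {
        let from_bytes = u32::from_le_bytes(state_u8s[i * 4..][..4].try_into().unwrap());
        assert_eq!(from_bytes, state_u32s[i]);
    }
}
```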
#[cfg(test)]
mod tests {
use tiny_keccak::keccakf;

View File

@ -76,10 +76,6 @@ pub(crate) fn get_opcode(mnemonic: &str) -> u8 {
"MSIZE" => 0x59,
"GAS" => 0x5a,
"JUMPDEST" => 0x5b,
"GET_STATE_ROOT" => 0x5c,
"SET_STATE_ROOT" => 0x5d,
"GET_RECEIPT_ROOT" => 0x5e,
"SET_RECEIPT_ROOT" => 0x5f,
"DUP1" => 0x80,
"DUP2" => 0x81,
"DUP3" => 0x82,

View File

@ -42,7 +42,7 @@ fn prepare_interpreter(
let mut state_trie: PartialTrie = Default::default();
let trie_inputs = Default::default();
interpreter.offset = load_all_mpts;
interpreter.generation_state.registers.program_counter = load_all_mpts;
interpreter.push(0xDEADBEEFu32.into());
interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs);
@ -53,7 +53,7 @@ fn prepare_interpreter(
keccak(address.to_fixed_bytes()).as_bytes(),
));
// Next, execute mpt_insert_state_trie.
interpreter.offset = mpt_insert_state_trie;
interpreter.generation_state.registers.program_counter = mpt_insert_state_trie;
let trie_data = interpreter.get_trie_data_mut();
if trie_data.is_empty() {
// In the assembly we skip over 0, knowing trie_data[0] = 0 by default.
@ -83,7 +83,7 @@ fn prepare_interpreter(
);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;
@ -115,7 +115,7 @@ fn test_extcodesize() -> Result<()> {
let extcodesize = KERNEL.global_labels["extcodesize"];
// Test `extcodesize`
interpreter.offset = extcodesize;
interpreter.generation_state.registers.program_counter = extcodesize;
interpreter.pop();
assert!(interpreter.stack().is_empty());
interpreter.push(0xDEADBEEFu32.into());
@ -144,10 +144,10 @@ fn test_extcodecopy() -> Result<()> {
// Put random data in main memory and the `KernelAccountCode` segment for realism.
let mut rng = thread_rng();
for i in 0..2000 {
interpreter.memory.context_memory[interpreter.context].segments
interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::MainMemory as usize]
.set(i, U256::from(rng.gen::<u8>()));
interpreter.memory.context_memory[interpreter.context].segments
interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::KernelAccountCode as usize]
.set(i, U256::from(rng.gen::<u8>()));
}
@ -158,7 +158,7 @@ fn test_extcodecopy() -> Result<()> {
let size = rng.gen_range(0..1500);
// Test `extcodecopy`
interpreter.offset = extcodecopy;
interpreter.generation_state.registers.program_counter = extcodecopy;
interpreter.pop();
assert!(interpreter.stack().is_empty());
interpreter.push(0xDEADBEEFu32.into());
@ -173,7 +173,7 @@ fn test_extcodecopy() -> Result<()> {
assert!(interpreter.stack().is_empty());
// Check that the code was correctly copied to memory.
for i in 0..size {
let memory = interpreter.memory.context_memory[interpreter.context].segments
let memory = interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::MainMemory as usize]
.get(dest_offset + i);
assert_eq!(

View File

@ -33,7 +33,7 @@ fn prepare_interpreter(
let mut state_trie: PartialTrie = Default::default();
let trie_inputs = Default::default();
interpreter.offset = load_all_mpts;
interpreter.generation_state.registers.program_counter = load_all_mpts;
interpreter.push(0xDEADBEEFu32.into());
interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs);
@ -44,7 +44,7 @@ fn prepare_interpreter(
keccak(address.to_fixed_bytes()).as_bytes(),
));
// Next, execute mpt_insert_state_trie.
interpreter.offset = mpt_insert_state_trie;
interpreter.generation_state.registers.program_counter = mpt_insert_state_trie;
let trie_data = interpreter.get_trie_data_mut();
if trie_data.is_empty() {
// In the assembly we skip over 0, knowing trie_data[0] = 0 by default.
@ -74,7 +74,7 @@ fn prepare_interpreter(
);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;
@ -105,7 +105,7 @@ fn test_balance() -> Result<()> {
prepare_interpreter(&mut interpreter, address, &account)?;
// Test `balance`
interpreter.offset = KERNEL.global_labels["balance"];
interpreter.generation_state.registers.program_counter = KERNEL.global_labels["balance"];
interpreter.pop();
assert!(interpreter.stack().is_empty());
interpreter.push(0xDEADBEEFu32.into());

View File

@ -113,7 +113,7 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> {
assert_eq!(interpreter.stack(), vec![]);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;

View File

@ -164,7 +164,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account
assert_eq!(interpreter.stack(), vec![]);
// Next, execute mpt_insert_state_trie.
interpreter.offset = mpt_insert_state_trie;
interpreter.generation_state.registers.program_counter = mpt_insert_state_trie;
let trie_data = interpreter.get_trie_data_mut();
if trie_data.is_empty() {
// In the assembly we skip over 0, knowing trie_data[0] = 0 by default.
@ -194,7 +194,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account
);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;

View File

@ -27,7 +27,7 @@ fn mpt_read() -> Result<()> {
assert_eq!(interpreter.stack(), vec![]);
// Now, execute mpt_read on the state trie.
interpreter.offset = mpt_read;
interpreter.generation_state.registers.program_counter = mpt_read;
interpreter.push(0xdeadbeefu32.into());
interpreter.push(0xABCDEFu64.into());
interpreter.push(6.into());

View File

@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::CpuColumnsView;
/// General-purpose memory channels; they can read and write to all contexts/segments/addresses.
pub const NUM_GP_CHANNELS: usize = 4;
pub const NUM_GP_CHANNELS: usize = 5;
pub mod channel_indices {
use std::ops::Range;

View File

@ -1,6 +1,6 @@
pub(crate) mod bootstrap_kernel;
pub(crate) mod columns;
mod control_flow;
pub(crate) mod control_flow;
pub mod cpu_stark;
pub(crate) mod decode;
mod dup_swap;
@ -9,7 +9,7 @@ pub mod kernel;
pub(crate) mod membus;
mod modfp254;
mod shift;
mod simple_logic;
pub(crate) mod simple_logic;
mod stack;
mod stack_bounds;
pub(crate) mod stack_bounds;
mod syscalls;

View File

@ -14,7 +14,7 @@ pub(crate) fn eval_packed<P: PackedField>(
yield_constr: &mut ConstraintConsumer<P>,
) {
let is_shift = lv.op.shl + lv.op.shr;
let displacement = lv.mem_channels[1]; // holds the shift displacement d
let displacement = lv.mem_channels[0]; // holds the shift displacement d
let two_exp = lv.mem_channels[2]; // holds 2^d
// Not needed here; val is the input and we're verifying that output is
@ -65,7 +65,7 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let is_shift = builder.add_extension(lv.op.shl, lv.op.shr);
let displacement = lv.mem_channels[1];
let displacement = lv.mem_channels[0];
let two_exp = lv.mem_channels[2];
let shift_table_segment = F::from_canonical_u64(Segment::ShiftTable as u64);
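Only the source channel of the displacement changes here; the scheme itself is unchanged: a shift by `d` is checked as a multiplication (`SHL`) or division (`SHR`) by `2^d`, with `2^d` supplied by the precomputed shift table. A standalone sketch of the identity being relied on:

```rust
// Standalone sketch: shifting by d is the same as multiplying/dividing by
// 2^d, which lets the CPU reuse its MUL/DIV logic with a table lookup for
// 2^d (Segment::ShiftTable in the real code).
fn main() {
    let x: u64 = 0xdead_beef;
    for d in 0..32u32 {
        let two_exp = 1u64 << d; // fetched from the shift table in the STARK
        assert_eq!(x << d, x * two_exp);
        assert_eq!(x >> d, x / two_exp);
    }
}
```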

View File

@ -1,34 +1,29 @@
use ethereum_types::U256;
use itertools::izip;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let input0 = lv.mem_channels[0].value;
let eq_filter = lv.op.eq.to_canonical_u64();
let iszero_filter = lv.op.iszero.to_canonical_u64();
assert!(eq_filter <= 1);
assert!(iszero_filter <= 1);
assert!(eq_filter + iszero_filter <= 1);
if eq_filter + iszero_filter == 0 {
return;
fn limbs(x: U256) -> [u32; 8] {
let mut res = [0; 8];
let x_u64: [u64; 4] = x.0;
for i in 0..4 {
res[2 * i] = x_u64[i] as u32;
res[2 * i + 1] = (x_u64[i] >> 32) as u32;
}
res
}
let input1 = &mut lv.mem_channels[1].value;
if iszero_filter != 0 {
for limb in input1.iter_mut() {
*limb = F::ZERO;
}
}
pub fn generate_pinv_diff<F: Field>(val0: U256, val1: U256, lv: &mut CpuColumnsView<F>) {
let val0_limbs = limbs(val0).map(F::from_canonical_u32);
let val1_limbs = limbs(val1).map(F::from_canonical_u32);
let input1 = lv.mem_channels[1].value;
let num_unequal_limbs = izip!(input0, input1)
let num_unequal_limbs = izip!(val0_limbs, val1_limbs)
.map(|(limb0, limb1)| (limb0 != limb1) as usize)
.sum();
let equal = num_unequal_limbs == 0;
@ -40,7 +35,7 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
}
// Form `diff_pinv`.
// Let `diff = input0 - input1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise.
// Let `diff = val0 - val1`. Consider `x[i] = diff[i]^-1` if `diff[i] != 0` and 0 otherwise.
// Then `diff @ x = num_unequal_limbs`, where `@` denotes the dot product. We set
// `diff_pinv = num_unequal_limbs^-1 * x` if `num_unequal_limbs != 0` and 0 otherwise. We have
// `diff @ diff_pinv = 1 - equal` as desired.
@ -48,7 +43,7 @@ pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let num_unequal_limbs_inv = F::from_canonical_usize(num_unequal_limbs)
.try_inverse()
.unwrap_or(F::ZERO);
for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), input0, input1) {
for (limb_pinv, limb0, limb1) in izip!(logic.diff_pinv.iter_mut(), val0_limbs, val1_limbs) {
*limb_pinv = (limb0 - limb1).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv;
}
}
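The identity in the comment above can be checked concretely: each term `diff[i] * diff_pinv[i]` equals `num_unequal_limbs^{-1}` where the limbs differ and `0` elsewhere, so the dot product is `1` exactly when the values differ. A standalone check over Goldilocks with toy limb values (not trace data):

```rust
use plonky2::field::goldilocks_field::GoldilocksField as F;
use plonky2::field::types::Field;

fn main() {
    // Toy limb values differing in exactly two positions.
    let val0 = [1u32, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u32);
    let val1 = [1u32, 2, 9, 4, 5, 6, 0, 8].map(F::from_canonical_u32);

    let num_unequal = val0.iter().zip(&val1).filter(|(a, b)| a != b).count();
    let num_unequal_inv = F::from_canonical_usize(num_unequal)
        .try_inverse()
        .unwrap_or(F::ZERO);

    // diff @ diff_pinv should equal 1 - equal; here 1, since the values differ.
    let dot = val0.iter().zip(&val1).fold(F::ZERO, |acc, (&a, &b)| {
        let diff = a - b;
        let pinv = diff.try_inverse().unwrap_or(F::ZERO) * num_unequal_inv;
        acc + diff * pinv
    });
    assert_eq!(dot, F::ONE);
}
```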

View File

@ -1,4 +1,4 @@
mod eq_iszero;
pub(crate) mod eq_iszero;
mod not;
use plonky2::field::extension::Extendable;
@ -9,17 +9,6 @@ use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let cycle_filter = lv.is_cpu_cycle.to_canonical_u64();
if cycle_filter == 0 {
return;
}
assert_eq!(cycle_filter, 1);
not::generate(lv);
eq_iszero::generate(lv);
}
pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,

View File

@ -6,34 +6,18 @@ use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::membus::NUM_GP_CHANNELS;
const LIMB_SIZE: usize = 32;
const ALL_1_LIMB: u64 = (1 << LIMB_SIZE) - 1;
pub fn generate<F: RichField>(lv: &mut CpuColumnsView<F>) {
let is_not_filter = lv.op.not.to_canonical_u64();
if is_not_filter == 0 {
return;
}
assert_eq!(is_not_filter, 1);
let input = lv.mem_channels[0].value;
let output = &mut lv.mem_channels[1].value;
for (input, output_ref) in input.into_iter().zip(output.iter_mut()) {
let input = input.to_canonical_u64();
assert_eq!(input >> LIMB_SIZE, 0);
let output = input ^ ALL_1_LIMB;
*output_ref = F::from_canonical_u64(output);
}
}
pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
// This is simple: just do output = 0xffffffff - input.
let input = lv.mem_channels[0].value;
let output = lv.mem_channels[1].value;
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.op.not;
let filter = cycle_filter * is_not_filter;
@ -50,7 +34,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let input = lv.mem_channels[0].value;
let output = lv.mem_channels[1].value;
let output = lv.mem_channels[NUM_GP_CHANNELS - 1].value;
let cycle_filter = lv.is_cpu_cycle;
let is_not_filter = lv.op.not;
let filter = builder.mul_extension(cycle_filter, is_not_filter);

View File

@ -61,19 +61,15 @@ const STACK_BEHAVIORS: OpsColumnsView<Option<StackBehavior>> = OpsColumnsView {
byte: BASIC_BINARY_OP,
shl: BASIC_BINARY_OP,
shr: BASIC_BINARY_OP,
keccak_general: None, // TODO
prover_input: None, // TODO
pop: None, // TODO
jump: None, // TODO
jumpi: None, // TODO
pc: None, // TODO
gas: None, // TODO
jumpdest: None, // TODO
get_state_root: None, // TODO
set_state_root: None, // TODO
get_receipt_root: None, // TODO
set_receipt_root: None, // TODO
push: None, // TODO
keccak_general: None, // TODO
prover_input: None, // TODO
pop: None, // TODO
jump: None, // TODO
jumpi: None, // TODO
pc: None, // TODO
gas: None, // TODO
jumpdest: None, // TODO
push: None, // TODO
dup: None,
swap: None,
get_context: None, // TODO

View File

@ -19,7 +19,7 @@ use plonky2::iop::ext_target::ExtensionTarget;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::{CpuColumnsView, COL_MAP};
const MAX_USER_STACK_SIZE: u64 = 1024;
pub const MAX_USER_STACK_SIZE: usize = 1024;
// Below only includes the operations that pop the top of the stack **without reading the value from
// memory**, i.e. `POP`.
@ -45,7 +45,7 @@ pub fn generate<F: Field>(lv: &mut CpuColumnsView<F>) {
let check_overflow: F = INCREMENTING_FLAGS.map(|i| lv[i]).into_iter().sum();
let no_check = F::ONE - (check_underflow + check_overflow);
let disallowed_len = check_overflow * F::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check;
let disallowed_len = check_overflow * F::from_canonical_usize(MAX_USER_STACK_SIZE) - no_check;
let diff = lv.stack_len - disallowed_len;
let user_mode = F::ONE - lv.is_kernel_mode;
@ -84,7 +84,7 @@ pub fn eval_packed<P: PackedField>(
// 0 if `check_underflow`, `MAX_USER_STACK_SIZE` if `check_overflow`, and -1 if `no_check`.
let disallowed_len =
check_overflow * P::Scalar::from_canonical_u64(MAX_USER_STACK_SIZE) - no_check;
check_overflow * P::Scalar::from_canonical_usize(MAX_USER_STACK_SIZE) - no_check;
// This `lhs` must equal some `rhs`. If `rhs` is nonzero, then this shows that `lv.stack_len` is
// not `disallowed_len`.
let lhs = (lv.stack_len - disallowed_len) * lv.stack_len_bounds_aux;
@ -108,7 +108,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
) {
let one = builder.one_extension();
let max_stack_size =
builder.constant_extension(F::from_canonical_u64(MAX_USER_STACK_SIZE).into());
builder.constant_extension(F::from_canonical_usize(MAX_USER_STACK_SIZE).into());
// `check_underflow`, `check_overflow`, and `no_check` are mutually exclusive.
let check_underflow = builder.add_many_extension(DECREMENTING_FLAGS.map(|i| lv[i]));
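The `stack_len_bounds_aux` column is an inverse witness: to show `stack_len != disallowed_len`, the prover supplies the inverse of the difference and the verifier checks the product. A standalone sketch of the trick over Goldilocks, assuming a constraint of the shape `(stack_len - disallowed_len) * aux = 1` whenever the check is active:

```rust
use plonky2::field::goldilocks_field::GoldilocksField as F;
use plonky2::field::types::Field;

// Sketch of the inverse-witness nonequality proof: only a prover whose
// stack_len differs from disallowed_len can produce a satisfying aux value.
fn main() {
    let max_user_stack_size = F::from_canonical_usize(1024);
    let stack_len = F::from_canonical_usize(17); // a legal stack length
    let disallowed_len = max_user_stack_size; // the overflow case

    // Prover side: witness the inverse of the difference.
    let aux = (stack_len - disallowed_len)
        .try_inverse()
        .expect("lengths differ");

    // Verifier side: the product being 1 proves the difference is nonzero.
    assert_eq!((stack_len - disallowed_len) * aux, F::ONE);
}
```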

View File

@ -145,7 +145,7 @@ impl<F: Field> Column<F> {
pub struct TableWithColumns<F: Field> {
table: Table,
columns: Vec<Column<F>>,
filter_column: Option<Column<F>>,
pub(crate) filter_column: Option<Column<F>>,
}
impl<F: Field> TableWithColumns<F> {
@ -160,8 +160,8 @@ impl<F: Field> TableWithColumns<F> {
#[derive(Clone)]
pub struct CrossTableLookup<F: Field> {
looking_tables: Vec<TableWithColumns<F>>,
looked_table: TableWithColumns<F>,
pub(crate) looking_tables: Vec<TableWithColumns<F>>,
pub(crate) looked_table: TableWithColumns<F>,
/// Default value if filters are not used.
default: Option<Vec<F>>,
}
@ -248,6 +248,7 @@ pub fn cross_table_lookup_data<F: RichField, C: GenericConfig<D, F = F>, const D
default,
} in cross_table_lookups
{
log::debug!("Processing CTL for {:?}", looked_table.table);
for &challenge in &challenges.challenges {
let zs_looking = looking_tables.iter().map(|table| {
partial_products(
@ -610,16 +611,15 @@ pub(crate) fn verify_cross_table_lookups<
.product::<F>();
let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap();
let challenge = challenges.challenges[i % config.num_challenges];
let combined_default = default
.as_ref()
.map(|default| challenge.combine(default.iter()))
.unwrap_or(F::ONE);
ensure!(
looking_zs_prod
== looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree),
"Cross-table lookup verification failed."
);
if let Some(default) = default.as_ref() {
let combined_default = challenge.combine(default.iter());
ensure!(
looking_zs_prod
== looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree),
"Cross-table lookup verification failed."
);
}
}
}
debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none()));
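The rewrite above makes the default correction conditional: `combined_default` is only multiplied in when a default row is configured. In toy form, the identity being checked is the following (Goldilocks stands in for the generic field; the names are illustrative):

```rust
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::Field;

// Filtered-out rows of each looking table contribute `combined_default` to that
// table's running product, so the looking side carries a surplus factor of
// combined_default^(looking_degrees_sum - looked_degree).
fn ctl_products_match(
    looking_zs: &[GoldilocksField],
    looked_z: GoldilocksField,
    combined_default: GoldilocksField,
    degree_excess: u64,
) -> bool {
    let looking_prod: GoldilocksField = looking_zs.iter().copied().product();
    looking_prod == looked_z * combined_default.exp_u64(degree_excess)
}
```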
@@ -694,6 +694,7 @@ pub(crate) mod testutils {
type MultiSet<F> = HashMap<Vec<F>, Vec<(Table, usize)>>;
/// Check that the provided traces and cross-table lookups are consistent.
#[allow(unused)] // TODO: used later?
pub(crate) fn check_ctls<F: Field>(
trace_poly_values: &[Vec<PolynomialValues<F>>],
cross_table_lookups: &[CrossTableLookup<F>],

View File

@@ -1,50 +0,0 @@
use ethereum_types::U256;
use crate::memory::memory_stark::MemoryOp;
use crate::memory::segments::Segment;
#[allow(unused)] // TODO: Should be used soon.
#[derive(Debug)]
pub(crate) struct MemoryState {
/// A log of each memory operation, in the order that it occurred.
pub log: Vec<MemoryOp>,
pub contexts: Vec<MemoryContextState>,
}
impl Default for MemoryState {
fn default() -> Self {
Self {
log: vec![],
// We start with an initial context for the kernel.
contexts: vec![MemoryContextState::default()],
}
}
}
#[derive(Clone, Default, Debug)]
pub(crate) struct MemoryContextState {
/// The content of each memory segment.
pub segments: [MemorySegmentState; Segment::COUNT],
}
#[derive(Clone, Default, Debug)]
pub(crate) struct MemorySegmentState {
pub content: Vec<U256>,
}
impl MemorySegmentState {
pub(crate) fn get(&self, virtual_addr: usize) -> U256 {
self.content
.get(virtual_addr)
.copied()
.unwrap_or(U256::zero())
}
pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) {
if virtual_addr >= self.content.len() {
self.content.resize(virtual_addr + 1, U256::zero());
}
self.content[virtual_addr] = value;
}
}

View File

@@ -4,24 +4,27 @@ use eth_trie_utils::partial_trie::PartialTrie;
use ethereum_types::{Address, BigEndianHash, H256};
use plonky2::field::extension::Extendable;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use serde::{Deserialize, Serialize};
use GlobalMetadata::{
ReceiptTrieRootDigestAfter, ReceiptTrieRootDigestBefore, StateTrieRootDigestAfter,
StateTrieRootDigestBefore, TransactionTrieRootDigestAfter, TransactionTrieRootDigestBefore,
};
use crate::all_stark::{AllStark, NUM_TABLES};
use crate::config::StarkConfig;
use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel;
use crate::cpu::columns::NUM_CPU_COLUMNS;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS;
use crate::proof::{BlockMetadata, PublicValues, TrieRoots};
use crate::util::trace_rows_to_poly_values;
use crate::witness::memory::MemoryAddress;
use crate::witness::transition::transition;
pub(crate) mod memory;
pub(crate) mod mpt;
pub mod mpt;
pub(crate) mod prover_input;
pub(crate) mod rlp;
pub(crate) mod state;
@@ -65,79 +68,68 @@ pub(crate) fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
config: &StarkConfig,
timing: &mut TimingTree,
) -> ([Vec<PolynomialValues<F>>; NUM_TABLES], PublicValues) {
let mut state = GenerationState::<F>::new(inputs.clone());
let mut state = GenerationState::<F>::new(inputs.clone(), &KERNEL.code);
generate_bootstrap_kernel::<F>(&mut state);
for txn in &inputs.signed_txns {
generate_txn(&mut state, txn);
}
timed!(timing, "simulate CPU", simulate_cpu(&mut state));
// TODO: Pad to a power of two, ending in the `halt` kernel function.
log::info!(
"Trace lengths (before padding): {:?}",
state.traces.checkpoint()
);
let cpu_rows = state.cpu_rows.len();
let mem_end_timestamp = cpu_rows * NUM_CHANNELS;
let mut read_metadata = |field| {
state.get_mem(
let read_metadata = |field| {
state.memory.get(MemoryAddress::new(
0,
Segment::GlobalMetadata,
field as usize,
mem_end_timestamp,
)
))
};
let trie_roots_before = TrieRoots {
state_root: H256::from_uint(&read_metadata(GlobalMetadata::StateTrieRootDigestBefore)),
transactions_root: H256::from_uint(&read_metadata(
GlobalMetadata::TransactionTrieRootDigestBefore,
)),
receipts_root: H256::from_uint(&read_metadata(GlobalMetadata::ReceiptTrieRootDigestBefore)),
state_root: H256::from_uint(&read_metadata(StateTrieRootDigestBefore)),
transactions_root: H256::from_uint(&read_metadata(TransactionTrieRootDigestBefore)),
receipts_root: H256::from_uint(&read_metadata(ReceiptTrieRootDigestBefore)),
};
let trie_roots_after = TrieRoots {
state_root: H256::from_uint(&read_metadata(GlobalMetadata::StateTrieRootDigestAfter)),
transactions_root: H256::from_uint(&read_metadata(
GlobalMetadata::TransactionTrieRootDigestAfter,
)),
receipts_root: H256::from_uint(&read_metadata(GlobalMetadata::ReceiptTrieRootDigestAfter)),
state_root: H256::from_uint(&read_metadata(StateTrieRootDigestAfter)),
transactions_root: H256::from_uint(&read_metadata(TransactionTrieRootDigestAfter)),
receipts_root: H256::from_uint(&read_metadata(ReceiptTrieRootDigestAfter)),
};
let GenerationState {
cpu_rows,
current_cpu_row,
memory,
keccak_inputs,
keccak_memory_inputs,
logic_ops,
..
} = state;
assert_eq!(current_cpu_row, [F::ZERO; NUM_CPU_COLUMNS].into());
let cpu_trace = trace_rows_to_poly_values(cpu_rows);
let keccak_trace = all_stark.keccak_stark.generate_trace(keccak_inputs, timing);
let keccak_memory_trace = all_stark.keccak_memory_stark.generate_trace(
keccak_memory_inputs,
config.fri_config.num_cap_elements(),
timing,
);
let logic_trace = all_stark.logic_stark.generate_trace(logic_ops, timing);
let memory_trace = all_stark.memory_stark.generate_trace(memory.log, timing);
let traces = [
cpu_trace,
keccak_trace,
keccak_memory_trace,
logic_trace,
memory_trace,
];
let public_values = PublicValues {
trie_roots_before,
trie_roots_after,
block_metadata: inputs.block_metadata,
};
(traces, public_values)
let tables = timed!(
timing,
"convert trace data to tables",
state.traces.into_tables(all_stark, config, timing)
);
(tables, public_values)
}
fn generate_txn<F: Field>(_state: &mut GenerationState<F>, _signed_txn: &[u8]) {
// TODO
fn simulate_cpu<F: RichField + Extendable<D>, const D: usize>(state: &mut GenerationState<F>) {
let halt_pc0 = KERNEL.global_labels["halt_pc0"];
let halt_pc1 = KERNEL.global_labels["halt_pc1"];
let mut already_in_halt_loop = false;
loop {
// If we've reached the kernel's halt routine, and our trace length is a power of 2, stop.
let pc = state.registers.program_counter;
let in_halt_loop = pc == halt_pc0 || pc == halt_pc1;
if in_halt_loop && !already_in_halt_loop {
log::info!("CPU halted after {} cycles", state.traces.clock());
}
already_in_halt_loop |= in_halt_loop;
if already_in_halt_loop && state.traces.clock().is_power_of_two() {
log::info!("CPU trace padded to {} cycles", state.traces.clock());
break;
}
transition(state);
}
}
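The halt loop above doubles as trace padding: once the PC reaches `halt_pc0`/`halt_pc1`, the CPU keeps cycling in place until the clock is a power of two, so no separate padding pass is needed. In effect:

```rust
// Toy recomputation of the final trace length: the CPU spins in the halt loop,
// so the padded length is the next power of two at or above the halt cycle.
fn padded_trace_length(halt_cycle: usize) -> usize {
    halt_cycle.next_power_of_two() // e.g. halting at cycle 100 pads to 128
}
```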

View File

@@ -9,11 +9,11 @@ use crate::cpu::kernel::constants::trie_type::PartialTrieType;
use crate::generation::TrieInputs;
#[derive(RlpEncodable, RlpDecodable, Debug)]
pub(crate) struct AccountRlp {
pub(crate) nonce: U256,
pub(crate) balance: U256,
pub(crate) storage_root: H256,
pub(crate) code_hash: H256,
pub struct AccountRlp {
pub nonce: U256,
pub balance: U256,
pub storage_root: H256,
pub code_hash: H256,
}
pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec<U256> {

View File

@@ -8,6 +8,7 @@ use crate::generation::prover_input::EvmField::{
};
use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
use crate::generation::state::GenerationState;
use crate::witness::util::stack_peek;
/// Prover input function represented as a scoped function name.
/// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`.
@@ -21,14 +22,13 @@ impl From<Vec<String>> for ProverInputFn {
}
impl<F: Field> GenerationState<F> {
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn prover_input(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 {
pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> U256 {
match input_fn.0[0].as_str() {
"end_of_txns" => self.run_end_of_txns(),
"ff" => self.run_ff(stack, input_fn),
"ff" => self.run_ff(input_fn),
"mpt" => self.run_mpt(),
"rlp" => self.run_rlp(),
"account_code" => self.run_account_code(stack, input_fn),
"account_code" => self.run_account_code(input_fn),
_ => panic!("Unrecognized prover input function."),
}
}
@@ -44,10 +44,10 @@ impl<F: Field> GenerationState<F> {
}
/// Finite field operations.
fn run_ff(&self, stack: &[U256], input_fn: &ProverInputFn) -> U256 {
fn run_ff(&self, input_fn: &ProverInputFn) -> U256 {
let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap();
let op = FieldOp::from_str(input_fn.0[2].as_str()).unwrap();
let x = *stack.last().expect("Empty stack");
let x = stack_peek(self, 0).expect("Empty stack");
field.op(op, x)
}
@@ -66,22 +66,21 @@ impl<F: Field> GenerationState<F> {
}
/// Account code.
fn run_account_code(&mut self, stack: &[U256], input_fn: &ProverInputFn) -> U256 {
fn run_account_code(&mut self, input_fn: &ProverInputFn) -> U256 {
match input_fn.0[1].as_str() {
"length" => {
// Return length of code.
// stack: codehash, ...
let codehash = stack.last().expect("Empty stack");
self.inputs.contract_code[&H256::from_uint(codehash)]
let codehash = stack_peek(self, 0).expect("Empty stack");
self.inputs.contract_code[&H256::from_uint(&codehash)]
.len()
.into()
}
"get" => {
// Return `code[i]`.
// stack: i, code_length, codehash, ...
let stacklen = stack.len();
let i = stack[stacklen - 1].as_usize();
let codehash = stack[stacklen - 3];
let i = stack_peek(self, 0).expect("Unexpected stack").as_usize();
let codehash = stack_peek(self, 2).expect("Unexpected stack");
self.inputs.contract_code[&H256::from_uint(&codehash)][i].into()
}
_ => panic!("Invalid prover input function."),

View File

@@ -1,35 +1,26 @@
use std::mem;
use ethereum_types::U256;
use plonky2::field::types::Field;
use tiny_keccak::keccakf;
use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS};
use crate::generation::memory::MemoryState;
use crate::generation::mpt::all_mpt_prover_inputs_reversed;
use crate::generation::rlp::all_rlp_prover_inputs_reversed;
use crate::generation::GenerationInputs;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryOp;
use crate::memory::memory_stark::MemoryOp;
use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS;
use crate::util::u256_limbs;
use crate::{keccak, logic};
use crate::witness::memory::MemoryState;
use crate::witness::state::RegistersState;
use crate::witness::traces::{TraceCheckpoint, Traces};
pub(crate) struct GenerationStateCheckpoint {
pub(crate) registers: RegistersState,
pub(crate) traces: TraceCheckpoint,
}
#[derive(Debug)]
pub(crate) struct GenerationState<F: Field> {
#[allow(unused)] // TODO: Should be used soon.
pub(crate) inputs: GenerationInputs,
pub(crate) next_txn_index: usize,
pub(crate) cpu_rows: Vec<[F; NUM_CPU_COLUMNS]>,
pub(crate) current_cpu_row: CpuColumnsView<F>,
pub(crate) current_context: usize,
pub(crate) registers: RegistersState,
pub(crate) memory: MemoryState,
pub(crate) traces: Traces<F>,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>,
pub(crate) keccak_memory_inputs: Vec<KeccakMemoryOp>,
pub(crate) logic_ops: Vec<logic::Operation>,
pub(crate) next_txn_index: usize,
/// Prover inputs containing MPT data, in reverse order so that the next input can be obtained
/// via `pop()`.
@@ -41,212 +32,30 @@ pub(crate) struct GenerationState<F: Field> {
}
impl<F: Field> GenerationState<F> {
pub(crate) fn new(inputs: GenerationInputs) -> Self {
pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Self {
let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries);
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
Self {
inputs,
registers: Default::default(),
memory: MemoryState::new(kernel_code),
traces: Traces::default(),
next_txn_index: 0,
cpu_rows: vec![],
current_cpu_row: [F::ZERO; NUM_CPU_COLUMNS].into(),
current_context: 0,
memory: MemoryState::default(),
keccak_inputs: vec![],
keccak_memory_inputs: vec![],
logic_ops: vec![],
mpt_prover_inputs,
rlp_prover_inputs,
}
}
/// Compute logical AND, and record the operation to be added in the logic table later.
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn and(&mut self, input0: U256, input1: U256) -> U256 {
self.logic_op(logic::Op::And, input0, input1)
pub fn checkpoint(&self) -> GenerationStateCheckpoint {
GenerationStateCheckpoint {
registers: self.registers,
traces: self.traces.checkpoint(),
}
}
/// Compute logical OR, and record the operation to be added in the logic table later.
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn or(&mut self, input0: U256, input1: U256) -> U256 {
self.logic_op(logic::Op::Or, input0, input1)
}
/// Compute logical XOR, and record the operation to be added in the logic table later.
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn xor(&mut self, input0: U256, input1: U256) -> U256 {
self.logic_op(logic::Op::Xor, input0, input1)
}
/// Compute a logical operation, and record it to be added in the logic table later.
pub(crate) fn logic_op(&mut self, op: logic::Op, input0: U256, input1: U256) -> U256 {
let operation = logic::Operation::new(op, input0, input1);
let result = operation.result;
self.logic_ops.push(operation);
result
}
/// Like `get_mem_cpu`, but reads from the current context specifically.
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn get_mem_cpu_current(
&mut self,
channel_index: usize,
segment: Segment,
virt: usize,
) -> U256 {
let context = self.current_context;
self.get_mem_cpu(channel_index, context, segment, virt)
}
/// Simulates the CPU reading some memory through the given channel. Besides logging the memory
/// operation, this also generates the associated registers in the current CPU row.
pub(crate) fn get_mem_cpu(
&mut self,
channel_index: usize,
context: usize,
segment: Segment,
virt: usize,
) -> U256 {
let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index;
let value = self.get_mem(context, segment, virt, timestamp);
let channel = &mut self.current_cpu_row.mem_channels[channel_index];
channel.used = F::ONE;
channel.is_read = F::ONE;
channel.addr_context = F::from_canonical_usize(context);
channel.addr_segment = F::from_canonical_usize(segment as usize);
channel.addr_virtual = F::from_canonical_usize(virt);
channel.value = u256_limbs(value);
value
}
/// Read some memory, and log the operation.
pub(crate) fn get_mem(
&mut self,
context: usize,
segment: Segment,
virt: usize,
timestamp: usize,
) -> U256 {
let value = self.memory.contexts[context].segments[segment as usize].get(virt);
self.memory.log.push(MemoryOp {
filter: true,
timestamp,
is_read: true,
context,
segment,
virt,
value,
});
value
}
/// Write some memory within the current execution context, and log the operation.
pub(crate) fn set_mem_cpu_current(
&mut self,
channel_index: usize,
segment: Segment,
virt: usize,
value: U256,
) {
let context = self.current_context;
self.set_mem_cpu(channel_index, context, segment, virt, value);
}
/// Write some memory, and log the operation.
pub(crate) fn set_mem_cpu(
&mut self,
channel_index: usize,
context: usize,
segment: Segment,
virt: usize,
value: U256,
) {
let timestamp = self.cpu_rows.len() * NUM_CHANNELS + channel_index;
self.set_mem(context, segment, virt, value, timestamp);
let channel = &mut self.current_cpu_row.mem_channels[channel_index];
channel.used = F::ONE;
channel.is_read = F::ZERO; // For clarity; should already be 0.
channel.addr_context = F::from_canonical_usize(context);
channel.addr_segment = F::from_canonical_usize(segment as usize);
channel.addr_virtual = F::from_canonical_usize(virt);
channel.value = u256_limbs(value);
}
/// Write some memory, and log the operation.
pub(crate) fn set_mem(
&mut self,
context: usize,
segment: Segment,
virt: usize,
value: U256,
timestamp: usize,
) {
self.memory.log.push(MemoryOp {
filter: true,
timestamp,
is_read: false,
context,
segment,
virt,
value,
});
self.memory.contexts[context].segments[segment as usize].set(virt, value)
}
/// Evaluate the Keccak-f permutation in-place on some data in memory, and record the operations
/// for the purpose of witness generation.
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn keccak_memory(
&mut self,
context: usize,
segment: Segment,
virt: usize,
) -> [u64; keccak::keccak_stark::NUM_INPUTS] {
let read_timestamp = self.cpu_rows.len() * NUM_CHANNELS;
let _write_timestamp = read_timestamp + 1;
let input = (0..25)
.map(|i| {
let bytes = [0, 1, 2, 3, 4, 5, 6, 7].map(|j| {
let virt = virt + i * 8 + j;
let byte = self.get_mem(context, segment, virt, read_timestamp);
debug_assert!(byte.bits() <= 8);
byte.as_u32() as u8
});
u64::from_le_bytes(bytes)
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let output = self.keccak(input);
self.keccak_memory_inputs.push(KeccakMemoryOp {
context,
segment,
virt,
read_timestamp,
input,
output,
});
// TODO: Write output to memory.
output
}
/// Evaluate the Keccak-f permutation, and record the operation for the purpose of witness
/// generation.
pub(crate) fn keccak(
&mut self,
mut input: [u64; keccak::keccak_stark::NUM_INPUTS],
) -> [u64; keccak::keccak_stark::NUM_INPUTS] {
self.keccak_inputs.push(input);
keccakf(&mut input);
input
}
pub(crate) fn commit_cpu_row(&mut self) {
let mut swapped_row = [F::ZERO; NUM_CPU_COLUMNS].into();
mem::swap(&mut self.current_cpu_row, &mut swapped_row);
self.cpu_rows.push(swapped_row.into());
pub fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) {
self.registers = checkpoint.registers;
self.traces.rollback(checkpoint.traces);
}
}

View File

@@ -1,7 +1,6 @@
use std::marker::PhantomData;
use itertools::Itertools;
use log::info;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
@@ -39,6 +38,7 @@ pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
}
pub fn ctl_filter<F: Field>() -> Column<F> {
// TODO: Also need to filter out padding rows somehow.
Column::single(reg_step(NUM_ROUNDS - 1))
}
@@ -50,12 +50,14 @@ pub struct KeccakStark<F, const D: usize> {
impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
/// Generate the rows of the trace. Note that this does not generate the permuted columns used
/// in our lookup arguments, as those are computed after transposing to column-wise form.
pub(crate) fn generate_trace_rows(
fn generate_trace_rows(
&self,
inputs: Vec<[u64; NUM_INPUTS]>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
info!("{} rows", num_rows);
let num_rows = (inputs.len() * NUM_ROUNDS)
.max(min_rows)
.next_power_of_two();
let mut rows = Vec::with_capacity(num_rows);
for input in inputs.iter() {
rows.extend(self.generate_trace_rows_for_perm(*input));
@@ -204,13 +206,14 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
pub fn generate_trace(
&self,
inputs: Vec<[u64; NUM_INPUTS]>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
// Generate the witness, except for permuted columns in the lookup argument.
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_rows(inputs)
self.generate_trace_rows(inputs, min_rows)
);
let trace_polys = timed!(
timing,
@@ -598,7 +601,7 @@ mod tests {
f: Default::default(),
};
let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()]);
let rows = stark.generate_trace_rows(vec![input.try_into().unwrap()], 8);
let last_row = rows[NUM_ROUNDS - 1];
let output = (0..NUM_INPUTS)
.map(|i| {
@@ -637,7 +640,7 @@ mod tests {
let trace_poly_values = timed!(
timing,
"generate trace",
stark.generate_trace(input.try_into().unwrap(), &mut timing)
stark.generate_trace(input.try_into().unwrap(), 8, &mut timing)
);
// TODO: Cloning this isn't great; consider having `from_values` accept a reference,
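The new `min_rows` parameter threads a minimum table height through trace generation (elsewhere in this diff it is derived from the FRI configuration), on top of the usual power-of-two padding. Worked out as a toy:

```rust
// Toy version of the row-count rule above: each permutation takes NUM_ROUNDS = 24
// rows, then the table grows to at least `min_rows`, then to a power of two.
fn keccak_num_rows(num_perms: usize, min_rows: usize) -> usize {
    const NUM_ROUNDS: usize = 24;
    (num_perms * NUM_ROUNDS).max(min_rows).next_power_of_two()
}
// keccak_num_rows(1, 8) == 32, matching the single-permutation tests above.
```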

View File

@@ -1,29 +0,0 @@
pub(crate) const KECCAK_WIDTH_BYTES: usize = 200;
/// 1 if this row represents a real operation; 0 if it's a padding row.
pub(crate) const COL_IS_REAL: usize = 0;
// The address at which we will read inputs and write outputs.
pub(crate) const COL_CONTEXT: usize = 1;
pub(crate) const COL_SEGMENT: usize = 2;
pub(crate) const COL_VIRTUAL: usize = 3;
/// The timestamp at which inputs should be read from memory.
/// Outputs will be written at the following timestamp.
pub(crate) const COL_READ_TIMESTAMP: usize = 4;
const START_INPUT_LIMBS: usize = 5;
/// A byte of the input.
pub(crate) fn col_input_byte(i: usize) -> usize {
debug_assert!(i < KECCAK_WIDTH_BYTES);
START_INPUT_LIMBS + i
}
const START_OUTPUT_LIMBS: usize = START_INPUT_LIMBS + KECCAK_WIDTH_BYTES;
/// A byte of the output.
pub(crate) fn col_output_byte(i: usize) -> usize {
debug_assert!(i < KECCAK_WIDTH_BYTES);
START_OUTPUT_LIMBS + i
}
pub const NUM_COLUMNS: usize = START_OUTPUT_LIMBS + KECCAK_WIDTH_BYTES;

View File

@@ -1,224 +0,0 @@
use std::marker::PhantomData;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;
use crate::keccak::keccak_stark::NUM_INPUTS;
use crate::keccak_memory::columns::*;
use crate::memory::segments::Segment;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
Column::singles([COL_CONTEXT, COL_SEGMENT, COL_VIRTUAL, COL_READ_TIMESTAMP]).collect()
}
pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
let input_cols = (0..50).map(|i| {
Column::le_bytes((0..4).map(|j| {
let byte_index = i * 4 + j;
col_input_byte(byte_index)
}))
});
let output_cols = (0..50).map(|i| {
Column::le_bytes((0..4).map(|j| {
let byte_index = i * 4 + j;
col_output_byte(byte_index)
}))
});
input_cols.chain(output_cols).collect()
}
pub(crate) fn ctl_looking_memory<F: Field>(i: usize, is_read: bool) -> Vec<Column<F>> {
let mut res = vec![Column::constant(F::from_bool(is_read))];
res.extend(Column::singles([COL_CONTEXT, COL_SEGMENT, COL_VIRTUAL]));
res.push(Column::single(col_input_byte(i)));
// Since we're reading or writing a single byte, the higher limbs must be zero.
res.extend((1..8).map(|_| Column::zero()));
// Since COL_READ_TIMESTAMP is the read time, we add 1 if this is a write.
let is_write_f = F::from_bool(!is_read);
res.push(Column::linear_combination_with_constant(
[(COL_READ_TIMESTAMP, F::ONE)],
is_write_f,
));
assert_eq!(
res.len(),
crate::memory::memory_stark::ctl_data::<F>().len()
);
res
}
/// CTL filter used for both directions (looked and looking).
pub(crate) fn ctl_filter<F: Field>() -> Column<F> {
Column::single(COL_IS_REAL)
}
/// Information about a Keccak memory operation needed for witness generation.
#[derive(Debug)]
pub(crate) struct KeccakMemoryOp {
// The address at which we will read inputs and write outputs.
pub(crate) context: usize,
pub(crate) segment: Segment,
pub(crate) virt: usize,
/// The timestamp at which inputs should be read from memory.
/// Outputs will be written at the following timestamp.
pub(crate) read_timestamp: usize,
/// The input that was read at that address.
pub(crate) input: [u64; NUM_INPUTS],
pub(crate) output: [u64; NUM_INPUTS],
}
#[derive(Copy, Clone, Default)]
pub struct KeccakMemoryStark<F, const D: usize> {
pub(crate) f: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> KeccakMemoryStark<F, D> {
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn generate_trace(
&self,
operations: Vec<KeccakMemoryOp>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
// Generate the witness row-wise.
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_rows(operations, min_rows)
);
let trace_polys = timed!(
timing,
"convert to PolynomialValues",
trace_rows_to_poly_values(trace_rows)
);
trace_polys
}
fn generate_trace_rows(
&self,
operations: Vec<KeccakMemoryOp>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = operations.len().max(min_rows).next_power_of_two();
let mut rows = Vec::with_capacity(num_rows);
for op in operations {
rows.push(self.generate_row_for_op(op));
}
let padding_row = self.generate_padding_row();
for _ in rows.len()..num_rows {
rows.push(padding_row);
}
rows
}
fn generate_row_for_op(&self, op: KeccakMemoryOp) -> [F; NUM_COLUMNS] {
let mut row = [F::ZERO; NUM_COLUMNS];
row[COL_IS_REAL] = F::ONE;
row[COL_CONTEXT] = F::from_canonical_usize(op.context);
row[COL_SEGMENT] = F::from_canonical_usize(op.segment as usize);
row[COL_VIRTUAL] = F::from_canonical_usize(op.virt);
row[COL_READ_TIMESTAMP] = F::from_canonical_usize(op.read_timestamp);
for i in 0..25 {
let input_u64 = op.input[i];
let output_u64 = op.output[i];
for j in 0..8 {
let byte_index = i * 8 + j;
row[col_input_byte(byte_index)] = F::from_canonical_u8(input_u64.to_le_bytes()[j]);
row[col_output_byte(byte_index)] =
F::from_canonical_u8(output_u64.to_le_bytes()[j]);
}
}
row
}
fn generate_padding_row(&self) -> [F; NUM_COLUMNS] {
// We just need COL_IS_REAL to be zero, which it is by default.
// The other fields will have no effect.
[F::ZERO; NUM_COLUMNS]
}
}
impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakMemoryStark<F, D> {
const COLUMNS: usize = NUM_COLUMNS;
fn eval_packed_generic<FE, P, const D2: usize>(
&self,
vars: StarkEvaluationVars<FE, P, { Self::COLUMNS }>,
yield_constr: &mut ConstraintConsumer<P>,
) where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
// is_real must be 0 or 1.
let is_real = vars.local_values[COL_IS_REAL];
yield_constr.constraint(is_real * (is_real - P::ONES));
}
fn eval_ext_circuit(
&self,
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, { Self::COLUMNS }>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
// is_real must be 0 or 1.
let is_real = vars.local_values[COL_IS_REAL];
let constraint = builder.mul_sub_extension(is_real, is_real, is_real);
yield_constr.constraint(builder, constraint);
}
fn constraint_degree(&self) -> usize {
2
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
#[test]
fn test_stark_degree() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type S = KeccakMemoryStark<F, D>;
let stark = S {
f: Default::default(),
};
test_stark_low_degree(stark)
}
#[test]
fn test_stark_circuit() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type S = KeccakMemoryStark<F, D>;
let stark = S {
f: Default::default(),
};
test_stark_circuit_constraints::<F, C, S, D>(stark)
}
}

View File

@@ -1,2 +0,0 @@
pub mod columns;
pub mod keccak_memory_stark;

View File

@@ -21,7 +21,7 @@ pub(crate) struct KeccakSpongeColumnsView<T: Copy> {
/// in the block will be padding bytes; 0 otherwise.
pub is_final_block: T,
// The address at which we will read the input block.
// The base address at which we will read the input block.
pub context: T,
pub segment: T,
pub virt: T,

View File

@@ -18,12 +18,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::cross_table_lookup::Column;
use crate::keccak_sponge::columns::*;
use crate::memory::segments::Segment;
use crate::stark::Stark;
use crate::util::trace_rows_to_poly_values;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
use crate::witness::memory::MemoryAddress;
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
let outputs = Column::singles(&cols.updated_state_u32s[..8]);
@@ -31,14 +30,13 @@ pub(crate) fn ctl_looked_data<F: Field>() -> Vec<Column<F>> {
cols.context,
cols.segment,
cols.virt,
cols.timestamp,
cols.len,
cols.timestamp,
])
.chain(outputs)
.collect()
}
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
Column::singles(
@@ -52,7 +50,6 @@ pub(crate) fn ctl_looking_keccak<F: Field>() -> Vec<Column<F>> {
.collect()
}
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
let cols = KECCAK_SPONGE_COL_MAP;
@@ -81,14 +78,18 @@ pub(crate) fn ctl_looking_memory<F: Field>(i: usize) -> Vec<Column<F>> {
res
}
pub(crate) fn num_logic_ctls() -> usize {
const U8S_PER_CTL: usize = 32;
ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL)
}
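`ceil_div_usize` is plain ceiling division, so with a 136-byte Keccak rate and 32 bytes per logic CTL it evaluates to the 5 CTLs mentioned in the comment below:

```rust
// Plain ceiling division, as in plonky2's util: ceil(136 / 32) = 5, so five
// logic CTLs cover one 136-byte rate block.
fn ceil_div_usize(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}
// ceil_div_usize(136, 32) == 5
```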
/// Columns for the `i`th looking logic CTL. Since we need to XOR 136 bytes per block, and each
/// logic CTL can cover 32 bytes, there are 5 such CTLs.
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn ctl_looking_logic<F: Field>(i: usize) -> Vec<Column<F>> {
const U32S_PER_CTL: usize = 8;
const U8S_PER_CTL: usize = 32;
debug_assert!(i < ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL));
debug_assert!(i < num_logic_ctls());
let cols = KECCAK_SPONGE_COL_MAP;
let mut res = vec![
@@ -111,7 +112,7 @@ pub(crate) fn ctl_looking_logic<F: Field>(i: usize) -> Vec<Column<F>> {
.chunks(size_of::<u32>())
.map(|chunk| Column::le_bytes(chunk))
.chain(repeat(Column::zero()))
.take(U8S_PER_CTL),
.take(U32S_PER_CTL),
);
// The output contains the XOR'd rate part.
@@ -124,14 +125,12 @@ pub(crate) fn ctl_looking_logic<F: Field>(i: usize) -> Vec<Column<F>> {
res
}
#[allow(unused)] // TODO: Should be used soon.
pub(crate) fn ctl_looked_filter<F: Field>() -> Column<F> {
// The CPU table is only interested in our final-block rows, since those contain the final
// sponge output.
Column::single(KECCAK_SPONGE_COL_MAP.is_final_block)
}
#[allow(unused)] // TODO: Should be used soon.
/// CTL filter for reading the `i`th byte of input from memory.
pub(crate) fn ctl_looking_memory_filter<F: Field>(i: usize) -> Column<F> {
// We perform the `i`th read if either
@@ -141,26 +140,26 @@ pub(crate) fn ctl_looking_memory_filter<F: Field>(i: usize) -> Column<F> {
Column::sum(once(&cols.is_full_input_block).chain(&cols.is_final_input_len[i..]))
}
pub(crate) fn ctl_looking_keccak_filter<F: Field>() -> Column<F> {
let cols = KECCAK_SPONGE_COL_MAP;
Column::sum([cols.is_full_input_block, cols.is_final_block])
}
/// Information about a Keccak sponge operation needed for witness generation.
#[derive(Debug)]
pub(crate) struct KeccakSpongeOp {
// The address at which inputs are read.
pub(crate) context: usize,
pub(crate) segment: Segment,
pub(crate) virt: usize,
/// The base address at which inputs are read.
pub(crate) base_address: MemoryAddress,
/// The timestamp at which inputs are read.
pub(crate) timestamp: usize,
/// The length of the input, in bytes.
pub(crate) len: usize,
/// The input that was read.
pub(crate) input: Vec<u8>,
}
#[derive(Copy, Clone, Default)]
pub(crate) struct KeccakSpongeStark<F, const D: usize> {
pub struct KeccakSpongeStark<F, const D: usize> {
f: PhantomData<F>,
}
@@ -261,7 +260,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
sponge_state: [u32; KECCAK_WIDTH_U32S],
final_inputs: &[u8],
) -> KeccakSpongeColumnsView<F> {
assert_eq!(already_absorbed_bytes + final_inputs.len(), op.len);
assert_eq!(already_absorbed_bytes + final_inputs.len(), op.input.len());
let mut row = KeccakSpongeColumnsView {
is_final_block: F::ONE,
@@ -295,11 +294,11 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
already_absorbed_bytes: usize,
mut sponge_state: [u32; KECCAK_WIDTH_U32S],
) {
row.context = F::from_canonical_usize(op.context);
row.segment = F::from_canonical_usize(op.segment as usize);
row.virt = F::from_canonical_usize(op.virt);
row.context = F::from_canonical_usize(op.base_address.context);
row.segment = F::from_canonical_usize(op.base_address.segment);
row.virt = F::from_canonical_usize(op.base_address.virt);
row.timestamp = F::from_canonical_usize(op.timestamp);
row.len = F::from_canonical_usize(op.len);
row.len = F::from_canonical_usize(op.input.len());
row.already_absorbed_bytes = F::from_canonical_usize(already_absorbed_bytes);
row.original_rate_u32s = sponge_state[..KECCAK_RATE_U32S]
@@ -410,6 +409,7 @@ mod tests {
use crate::keccak_sponge::keccak_sponge_stark::{KeccakSpongeOp, KeccakSpongeStark};
use crate::memory::segments::Segment;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
use crate::witness::memory::MemoryAddress;
#[test]
fn test_stark_degree() -> Result<()> {
@@ -443,11 +443,12 @@ mod tests {
let expected_output = keccak(&input);
let op = KeccakSpongeOp {
context: 0,
segment: Segment::Code,
virt: 0,
base_address: MemoryAddress {
context: 0,
segment: Segment::Code as usize,
virt: 0,
},
timestamp: 0,
len: input.len(),
input,
};
let stark = S::default();

View File

@@ -2,6 +2,7 @@
#![allow(clippy::needless_range_loop)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]
#![allow(clippy::field_reassign_with_default)]
#![feature(let_chains)]
#![feature(generic_const_exprs)]
@@ -14,7 +15,6 @@ pub mod cross_table_lookup;
pub mod generation;
mod get_challenges;
pub mod keccak;
pub mod keccak_memory;
pub mod keccak_sponge;
pub mod logic;
pub mod lookup;
@@ -29,3 +29,12 @@ pub mod util;
pub mod vanishing_poly;
pub mod vars;
pub mod verifier;
pub mod witness;
// Set up Jemalloc
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;

View File

@@ -72,13 +72,23 @@ pub struct LogicStark<F, const D: usize> {
pub f: PhantomData<F>,
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum Op {
And,
Or,
Xor,
}
impl Op {
pub(crate) fn result(&self, a: U256, b: U256) -> U256 {
match self {
Op::And => a & b,
Op::Or => a | b,
Op::Xor => a ^ b,
}
}
}
#[derive(Debug)]
pub(crate) struct Operation {
operator: Op,
@@ -89,11 +99,7 @@ pub(crate) struct Operation {
impl Operation {
pub(crate) fn new(operator: Op, input0: U256, input1: U256) -> Self {
let result = match operator {
Op::And => input0 & input1,
Op::Or => input0 | input1,
Op::Xor => input0 ^ input1,
};
let result = operator.result(input0, input1);
Operation {
operator,
input0,
@@ -101,18 +107,44 @@ impl Operation {
result,
}
}
fn into_row<F: Field>(self) -> [F; NUM_COLUMNS] {
let Operation {
operator,
input0,
input1,
result,
} = self;
let mut row = [F::ZERO; NUM_COLUMNS];
row[match operator {
Op::And => columns::IS_AND,
Op::Or => columns::IS_OR,
Op::Xor => columns::IS_XOR,
}] = F::ONE;
for i in 0..256 {
row[columns::INPUT0.start + i] = F::from_bool(input0.bit(i));
row[columns::INPUT1.start + i] = F::from_bool(input1.bit(i));
}
let result_limbs: &[u64] = result.as_ref();
for (i, &limb) in result_limbs.iter().enumerate() {
row[columns::RESULT.start + 2 * i] = F::from_canonical_u32(limb as u32);
row[columns::RESULT.start + 2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32);
}
row
}
}
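The new `into_row` writes each of the 256 input bits to its own column but packs the result as eight little-endian `u32` limbs. A standalone check of that limb layout (using the same `as_ref` view of `U256` as the code above):

```rust
use ethereum_types::U256;

// The limb layout used by `into_row` above: a U256 is four little-endian u64s,
// and each u64 splits into (low u32, high u32).
fn result_u32_limbs(x: U256) -> [u32; 8] {
    let mut out = [0u32; 8];
    let limbs: &[u64] = x.as_ref();
    for (i, &limb) in limbs.iter().enumerate() {
        out[2 * i] = limb as u32;
        out[2 * i + 1] = (limb >> 32) as u32;
    }
    out
}
```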
impl<F: RichField, const D: usize> LogicStark<F, D> {
pub(crate) fn generate_trace(
&self,
operations: Vec<Operation>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_rows(operations)
self.generate_trace_rows(operations, min_rows)
);
let trace_polys = timed!(
timing,
@@ -122,46 +154,30 @@ impl<F: RichField, const D: usize> LogicStark<F, D> {
trace_polys
}
fn generate_trace_rows(&self, operations: Vec<Operation>) -> Vec<[F; NUM_COLUMNS]> {
fn generate_trace_rows(
&self,
operations: Vec<Operation>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let len = operations.len();
let padded_len = len.next_power_of_two();
let padded_len = len.max(min_rows).next_power_of_two();
let mut rows = Vec::with_capacity(padded_len);
for op in operations {
rows.push(Self::generate_row(op));
rows.push(op.into_row());
}
// Pad to a power of two.
for _ in len..padded_len {
rows.push([F::ZERO; columns::NUM_COLUMNS]);
rows.push([F::ZERO; NUM_COLUMNS]);
}
rows
}
fn generate_row(operation: Operation) -> [F; columns::NUM_COLUMNS] {
let mut row = [F::ZERO; columns::NUM_COLUMNS];
match operation.operator {
Op::And => row[columns::IS_AND] = F::ONE,
Op::Or => row[columns::IS_OR] = F::ONE,
Op::Xor => row[columns::IS_XOR] = F::ONE,
}
for (i, col) in columns::INPUT0.enumerate() {
row[col] = F::from_bool(operation.input0.bit(i));
}
for (i, col) in columns::INPUT1.enumerate() {
row[col] = F::from_bool(operation.input1.bit(i));
}
for (i, col) in columns::RESULT.enumerate() {
let bit_range = i * PACKED_LIMB_BITS..(i + 1) * PACKED_LIMB_BITS;
row[col] = limb_from_bits_le(bit_range.map(|j| F::from_bool(operation.result.bit(j))));
}
row
}
}
impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F, D> {
const COLUMNS: usize = columns::NUM_COLUMNS;
const COLUMNS: usize = NUM_COLUMNS;
fn eval_packed_generic<FE, P, const D2: usize>(
&self,

View File

@@ -1,6 +1,5 @@
use std::marker::PhantomData;
use ethereum_types::U256;
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::extension::{Extendable, FieldExtension};
@@ -20,11 +19,12 @@ use crate::memory::columns::{
COUNTER_PERMUTED, FILTER, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
};
use crate::memory::segments::Segment;
use crate::memory::VALUE_LIMBS;
use crate::permutation::PermutationPair;
use crate::stark::Stark;
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
use crate::witness::memory::MemoryOpKind::Read;
use crate::witness::memory::{MemoryAddress, MemoryOp};
pub fn ctl_data<F: Field>() -> Vec<Column<F>> {
let mut res =
@@ -43,31 +43,24 @@ pub struct MemoryStark<F, const D: usize> {
pub(crate) f: PhantomData<F>,
}
#[derive(Clone, Debug)]
pub(crate) struct MemoryOp {
/// true if this is an actual memory operation, or false if it's a padding row.
pub filter: bool,
pub timestamp: usize,
pub is_read: bool,
pub context: usize,
pub segment: Segment,
pub virt: usize,
pub value: U256,
}
impl MemoryOp {
/// Generate a row for a given memory operation. Note that this does not generate columns which
/// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later.
/// It also does not generate columns such as `COUNTER`, which are generated later, after the
/// trace has been transposed into column-major form.
fn to_row<F: Field>(&self) -> [F; NUM_COLUMNS] {
fn into_row<F: Field>(self) -> [F; NUM_COLUMNS] {
let mut row = [F::ZERO; NUM_COLUMNS];
row[FILTER] = F::from_bool(self.filter);
row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
row[IS_READ] = F::from_bool(self.is_read);
row[ADDR_CONTEXT] = F::from_canonical_usize(self.context);
row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment as usize);
row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt);
row[IS_READ] = F::from_bool(self.kind == Read);
let MemoryAddress {
context,
segment,
virt,
} = self.address;
row[ADDR_CONTEXT] = F::from_canonical_usize(context);
row[ADDR_SEGMENT] = F::from_canonical_usize(segment);
row[ADDR_VIRTUAL] = F::from_canonical_usize(virt);
for j in 0..VALUE_LIMBS {
row[value_limb(j)] = F::from_canonical_u32((self.value >> (j * 32)).low_u32());
}
@@ -80,14 +73,14 @@ fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize {
.iter()
.tuple_windows()
.map(|(curr, next)| {
if curr.context != next.context {
next.context - curr.context - 1
} else if curr.segment != next.segment {
next.segment as usize - curr.segment as usize - 1
} else if curr.virt != next.virt {
next.virt - curr.virt - 1
if curr.address.context != next.address.context {
next.address.context - curr.address.context - 1
} else if curr.address.segment != next.address.segment {
next.address.segment - curr.address.segment - 1
} else if curr.address.virt != next.address.virt {
next.address.virt - curr.address.virt - 1
} else {
next.timestamp - curr.timestamp - 1
next.timestamp - curr.timestamp
}
})
.max()
@@ -131,7 +124,7 @@ pub fn generate_first_change_flags_and_rc<F: RichField>(trace_rows: &mut [[F; NU
} else if virtual_first_change {
next_virt - virt - F::ONE
} else {
next_timestamp - timestamp - F::ONE
next_timestamp - timestamp
};
}
}
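Both the witness side and the constraints drop the `- 1` in the timestamp branch, so two operations on the same address may now share a timestamp; only the address components must strictly increase. The gap rule in plain integers, as a toy over the sort key `(context, segment, virt, timestamp)`:

```rust
// Plain-integer version of the range-checked gap, for rows sorted by
// (context, segment, virt, timestamp): address components must strictly
// increase (hence the -1), while equal timestamps are now allowed.
fn range_check_gap(
    curr: (usize, usize, usize, usize),
    next: (usize, usize, usize, usize),
) -> usize {
    if next.0 != curr.0 {
        next.0 - curr.0 - 1
    } else if next.1 != curr.1 {
        next.1 - curr.1 - 1
    } else if next.2 != curr.2 {
        next.2 - curr.2 - 1
    } else {
        next.3 - curr.3
    }
}
```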
@@ -140,13 +133,20 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
/// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated
/// later, after transposing to column-major form.
fn generate_trace_row_major(&self, mut memory_ops: Vec<MemoryOp>) -> Vec<[F; NUM_COLUMNS]> {
memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp));
memory_ops.sort_by_key(|op| {
(
op.address.context,
op.address.segment,
op.address.virt,
op.timestamp,
)
});
Self::pad_memory_ops(&mut memory_ops);
let mut trace_rows = memory_ops
.into_par_iter()
.map(|op| op.to_row())
.map(|op| op.into_row())
.collect::<Vec<_>>();
generate_first_change_flags_and_rc(trace_rows.as_mut_slice());
trace_rows
@@ -170,7 +170,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two();
let to_pad = num_ops_padded - num_ops;
let last_op = memory_ops.last().expect("No memory ops?").clone();
let last_op = *memory_ops.last().expect("No memory ops?");
// We essentially repeat the last operation until our operation list has the desired size,
// with a few changes:
@@ -181,7 +181,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
memory_ops.push(MemoryOp {
filter: false,
timestamp: last_op.timestamp + i + 1,
is_read: true,
kind: Read,
..last_op
});
}
@@ -283,7 +283,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let computed_range_check = context_first_change * (next_addr_context - addr_context - one)
+ segment_first_change * (next_addr_segment - addr_segment - one)
+ virtual_first_change * (next_addr_virtual - addr_virtual - one)
+ address_unchanged * (next_timestamp - timestamp - one);
+ address_unchanged * (next_timestamp - timestamp);
yield_constr.constraint_transition(range_check - computed_range_check);
// Enumerate purportedly-ordered log.
@@ -394,10 +394,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(diff, one)
};
let virtual_range_check = builder.mul_extension(virtual_first_change, virtual_diff);
let timestamp_diff = {
let diff = builder.sub_extension(next_timestamp, timestamp);
builder.sub_extension(diff, one)
};
let timestamp_diff = builder.sub_extension(next_timestamp, timestamp);
let timestamp_range_check = builder.mul_extension(address_unchanged, timestamp_diff);
let computed_range_check = {
@@ -439,94 +436,12 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
#[cfg(test)]
pub(crate) mod tests {
use std::collections::{HashMap, HashSet};
use anyhow::Result;
use ethereum_types::U256;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use rand::prelude::SliceRandom;
use rand::Rng;
use crate::memory::memory_stark::{MemoryOp, MemoryStark};
use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS;
use crate::memory::memory_stark::MemoryStark;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
pub(crate) fn generate_random_memory_ops<R: Rng>(num_ops: usize, rng: &mut R) -> Vec<MemoryOp> {
let mut memory_ops = Vec::new();
let mut current_memory_values: HashMap<(usize, Segment, usize), U256> = HashMap::new();
let num_cycles = num_ops / 2;
for clock in 0..num_cycles {
let mut used_indices = HashSet::new();
let mut new_writes_this_cycle = HashMap::new();
let mut has_read = false;
for _ in 0..2 {
let mut channel_index = rng.gen_range(0..NUM_CHANNELS);
while used_indices.contains(&channel_index) {
channel_index = rng.gen_range(0..NUM_CHANNELS);
}
used_indices.insert(channel_index);
let is_read = if clock == 0 {
false
} else {
!has_read && rng.gen()
};
has_read = has_read || is_read;
let (context, segment, virt, vals) = if is_read {
let written: Vec<_> = current_memory_values.keys().collect();
let &(mut context, mut segment, mut virt) =
written[rng.gen_range(0..written.len())];
while new_writes_this_cycle.contains_key(&(context, segment, virt)) {
(context, segment, virt) = *written[rng.gen_range(0..written.len())];
}
let &vals = current_memory_values
.get(&(context, segment, virt))
.unwrap();
(context, segment, virt, vals)
} else {
// TODO: with taller memory table or more padding (to enable range-checking bigger diffs),
// test larger address values.
let mut context = rng.gen_range(0..40);
let segments = [Segment::Code, Segment::Stack, Segment::MainMemory];
let mut segment = *segments.choose(rng).unwrap();
let mut virt = rng.gen_range(0..20);
while new_writes_this_cycle.contains_key(&(context, segment, virt)) {
context = rng.gen_range(0..40);
segment = *segments.choose(rng).unwrap();
virt = rng.gen_range(0..20);
}
let val = U256(rng.gen());
new_writes_this_cycle.insert((context, segment, virt), val);
(context, segment, virt, val)
};
let timestamp = clock * NUM_CHANNELS + channel_index;
memory_ops.push(MemoryOp {
filter: true,
timestamp,
is_read,
context,
segment,
virt,
value: vals,
});
}
for (k, v) in new_writes_this_cycle {
current_memory_values.insert(k, v);
}
}
memory_ops
}
#[test]
fn test_stark_degree() -> Result<()> {
const D: usize = 2;

View File

@@ -24,7 +24,7 @@ use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData};
use crate::generation::{generate_traces, GenerationInputs};
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::{
@@ -49,7 +49,7 @@ where
[(); C::Hasher::HASH_SIZE]:,
[(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakMemoryStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
{
@@ -71,7 +71,7 @@ where
[(); C::Hasher::HASH_SIZE]:,
[(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakMemoryStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
{
@@ -132,12 +132,12 @@ where
&mut challenger,
timing,
)?;
let keccak_memory_proof = prove_single_table(
&all_stark.keccak_memory_stark,
let keccak_sponge_proof = prove_single_table(
&all_stark.keccak_sponge_stark,
config,
&trace_poly_values[Table::KeccakMemory as usize],
&trace_commitments[Table::KeccakMemory as usize],
&ctl_data_per_table[Table::KeccakMemory as usize],
&trace_poly_values[Table::KeccakSponge as usize],
&trace_commitments[Table::KeccakSponge as usize],
&ctl_data_per_table[Table::KeccakSponge as usize],
&mut challenger,
timing,
)?;
@@ -163,7 +163,7 @@ where
let stark_proofs = [
cpu_proof,
keccak_proof,
keccak_memory_proof,
keccak_sponge_proof,
logic_proof,
memory_proof,
];

View File

@@ -27,7 +27,7 @@ use crate::cross_table_lookup::{
CtlCheckVarsTarget,
};
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::{
@@ -231,7 +231,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
.enumerate()
{
builder.verify_proof::<C>(
recursive_proof,
&recursive_proof,
&verifier_data_target,
&verifier_data[i].common,
);
@@ -332,7 +332,7 @@ pub fn all_verifier_data_recursive_stark_proof<
where
[(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakMemoryStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
[(); C::Hasher::HASH_SIZE]:,
@@ -356,9 +356,9 @@ where
circuit_config,
),
verifier_data_recursive_stark_proof(
Table::KeccakMemory,
all_stark.keccak_memory_stark,
degree_bits[Table::KeccakMemory as usize],
Table::KeccakSponge,
all_stark.keccak_sponge_stark,
degree_bits[Table::KeccakSponge as usize],
&all_stark.cross_table_lookups,
inner_config,
circuit_config,
@@ -534,10 +534,10 @@ pub fn add_virtual_all_proof<F: RichField + Extendable<D>, const D: usize>(
),
add_virtual_stark_proof(
builder,
&all_stark.keccak_memory_stark,
&all_stark.keccak_sponge_stark,
config,
degree_bits[Table::KeccakMemory as usize],
nums_ctl_zs[Table::KeccakMemory as usize],
degree_bits[Table::KeccakSponge as usize],
nums_ctl_zs[Table::KeccakSponge as usize],
),
add_virtual_stark_proof(
builder,
@@ -853,7 +853,7 @@ pub(crate) mod tests {
use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{CrossTableLookup, CtlCheckVarsTarget};
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet};
@@ -866,6 +866,7 @@ pub(crate) mod tests {
/// Recursively verify a Stark proof.
/// Outputs the recursive proof and the associated verifier data.
#[allow(unused)] // TODO: used later?
fn recursively_verify_stark_proof<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
@@ -965,6 +966,7 @@ pub(crate) mod tests {
}
/// Recursively verify every Stark proof in an `AllProof`.
#[allow(unused)] // TODO: used later?
pub fn recursively_verify_all_proof<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
@@ -978,7 +980,7 @@ pub(crate) mod tests {
where
[(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakMemoryStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
[(); C::Hasher::HASH_SIZE]:,
@@ -1013,9 +1015,9 @@ pub(crate) mod tests {
)?
.0,
recursively_verify_stark_proof(
Table::KeccakMemory,
all_stark.keccak_memory_stark,
&all_proof.stark_proofs[Table::KeccakMemory as usize],
Table::KeccakSponge,
all_stark.keccak_sponge_stark,
&all_proof.stark_proofs[Table::KeccakSponge as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
states[2],

View File

@@ -2,6 +2,7 @@ use std::mem::{size_of, transmute_copy, ManuallyDrop};
use ethereum_types::{H160, H256, U256};
use itertools::Itertools;
use num::BigUint;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
@@ -44,6 +45,7 @@ pub fn trace_rows_to_poly_values<F: Field, const COLUMNS: usize>(
.collect()
}
#[allow(unused)] // TODO: Remove?
/// Returns the 32-bit little-endian limbs of a `U256`.
pub(crate) fn u256_limbs<F: Field>(u256: U256) -> [F; 8] {
u256.0
@@ -98,3 +100,55 @@ pub(crate) unsafe fn transmute_no_compile_time_size_checks<T, U>(value: T) -> U
// Copy the bit pattern. The original value is no longer safe to use.
transmute_copy(&value)
}
pub(crate) fn addmod(x: U256, y: U256, m: U256) -> U256 {
if m.is_zero() {
return m;
}
let x = u256_to_biguint(x);
let y = u256_to_biguint(y);
let m = u256_to_biguint(m);
biguint_to_u256((x + y) % m)
}
pub(crate) fn mulmod(x: U256, y: U256, m: U256) -> U256 {
if m.is_zero() {
return m;
}
let x = u256_to_biguint(x);
let y = u256_to_biguint(y);
let m = u256_to_biguint(m);
biguint_to_u256(x * y % m)
}
pub(crate) fn submod(x: U256, y: U256, m: U256) -> U256 {
if m.is_zero() {
return m;
}
let mut x = u256_to_biguint(x);
let y = u256_to_biguint(y);
let m = u256_to_biguint(m);
while x < y {
x += &m;
}
biguint_to_u256((x - y) % m)
}
pub(crate) fn u256_to_biguint(x: U256) -> BigUint {
let mut bytes = [0u8; 32];
x.to_little_endian(&mut bytes);
BigUint::from_bytes_le(&bytes)
}
pub(crate) fn biguint_to_u256(x: BigUint) -> U256 {
let bytes = x.to_bytes_le();
U256::from_little_endian(&bytes)
}
pub(crate) fn u256_saturating_cast_usize(x: U256) -> usize {
if x > usize::MAX.into() {
usize::MAX
} else {
x.as_usize()
}
}
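These helpers route through `BigUint` because the intermediate `x + y` or `x * y` can exceed 256 bits; by convention a zero modulus returns zero, matching the EVM's `ADDMOD`/`MULMOD`. A usage sketch, assuming the functions above are in scope:

```rust
use ethereum_types::U256;

// addmod computes (x + y) % m exactly even though x + y overflows 256 bits;
// a zero modulus yields zero, as in the EVM.
fn addmod_demo() {
    let m = U256::from(7u64);
    let x = U256::max_value(); // 2^256 - 1: x + y would wrap in plain U256 math
    let y = U256::from(5u64);
    assert!(addmod(x, y, m) < m);
    assert_eq!(addmod(x, y, U256::zero()), U256::zero());
}
```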

View File

@@ -1,3 +1,5 @@
use std::any::type_name;
use anyhow::{ensure, Result};
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::types::Field;
@@ -12,7 +14,7 @@ use crate::constraint_consumer::ConstraintConsumer;
use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars};
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::PermutationCheckVars;
@@ -31,7 +33,7 @@ pub fn verify_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, co
where
[(); CpuStark::<F, D>::COLUMNS]:,
[(); KeccakStark::<F, D>::COLUMNS]:,
[(); KeccakMemoryStark::<F, D>::COLUMNS]:,
[(); KeccakSpongeStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
[(); C::Hasher::HASH_SIZE]:,
@@ -46,7 +48,7 @@ where
let AllStark {
cpu_stark,
keccak_stark,
keccak_memory_stark,
keccak_sponge_stark,
logic_stark,
memory_stark,
cross_table_lookups,
@@ -74,10 +76,10 @@ where
config,
)?;
verify_stark_proof_with_challenges(
keccak_memory_stark,
&all_proof.stark_proofs[Table::KeccakMemory as usize],
&stark_challenges[Table::KeccakMemory as usize],
&ctl_vars_per_table[Table::KeccakMemory as usize],
keccak_sponge_stark,
&all_proof.stark_proofs[Table::KeccakSponge as usize],
&stark_challenges[Table::KeccakSponge as usize],
&ctl_vars_per_table[Table::KeccakSponge as usize],
config,
)?;
verify_stark_proof_with_challenges(
@@ -122,6 +124,7 @@ where
[(); S::COLUMNS]:,
[(); C::Hasher::HASH_SIZE]:,
{
log::debug!("Checking proof: {}", type_name::<S>());
validate_proof_shape(&stark, proof, config, ctl_vars.len())?;
let StarkOpeningSet {
local_values,

evm/src/witness/errors.rs Normal file
View File

@@ -0,0 +1,10 @@
#[allow(dead_code)]
#[derive(Debug)]
pub enum ProgramError {
OutOfGas,
InvalidOpcode,
StackUnderflow,
InvalidJumpDestination,
InvalidJumpiDestination,
StackOverflow,
}

evm/src/witness/mem_tx.rs Normal file
View File

@@ -0,0 +1,12 @@
use crate::witness::memory::{MemoryOp, MemoryOpKind, MemoryState};
pub fn apply_mem_ops(state: &mut MemoryState, mut ops: Vec<MemoryOp>) {
ops.sort_unstable_by_key(|mem_op| mem_op.timestamp);
for op in ops {
let MemoryOp { address, kind, value, .. } = op;
if kind == MemoryOpKind::Write {
state.set(address, value);
}
}
}

evm/src/witness/memory.rs Normal file
View File

@@ -0,0 +1,162 @@
use ethereum_types::U256;
use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS};
#[derive(Clone, Copy, Debug)]
pub enum MemoryChannel {
Code,
GeneralPurpose(usize),
}
use MemoryChannel::{Code, GeneralPurpose};
use crate::memory::segments::Segment;
use crate::util::u256_saturating_cast_usize;
impl MemoryChannel {
pub fn index(&self) -> usize {
match *self {
Code => 0,
GeneralPurpose(n) => {
assert!(n < NUM_GP_CHANNELS);
n + 1
}
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct MemoryAddress {
pub(crate) context: usize,
pub(crate) segment: usize,
pub(crate) virt: usize,
}
impl MemoryAddress {
pub(crate) fn new(context: usize, segment: Segment, virt: usize) -> Self {
Self {
context,
segment: segment as usize,
virt,
}
}
pub(crate) fn new_u256s(context: U256, segment: U256, virt: U256) -> Self {
Self {
context: u256_saturating_cast_usize(context),
segment: u256_saturating_cast_usize(segment),
virt: u256_saturating_cast_usize(virt),
}
}
pub(crate) fn increment(&mut self) {
self.virt = self.virt.saturating_add(1);
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryOpKind {
Read,
Write,
}
#[derive(Clone, Copy, Debug)]
pub struct MemoryOp {
/// true if this is an actual memory operation, or false if it's a padding row.
pub filter: bool,
pub timestamp: usize,
pub address: MemoryAddress,
pub kind: MemoryOpKind,
pub value: U256,
}
impl MemoryOp {
pub fn new(
channel: MemoryChannel,
clock: usize,
address: MemoryAddress,
kind: MemoryOpKind,
value: U256,
) -> Self {
let timestamp = clock * NUM_CHANNELS + channel.index();
MemoryOp {
filter: true,
timestamp,
address,
kind,
value,
}
}
}
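// Quick sketch of the timestamp packing (relies only on `MemoryChannel::index`
// above): ops issued in the same clock cycle get distinct, channel-ordered
// timestamps.
//
//     let addr = MemoryAddress::new(0, Segment::Stack, 0);
//     let t_code = MemoryOp::new(Code, 7, addr, MemoryOpKind::Read, U256::zero()).timestamp;
//     let t_gp0 = MemoryOp::new(GeneralPurpose(0), 7, addr, MemoryOpKind::Read, U256::zero()).timestamp;
//     assert_eq!(t_code, 7 * NUM_CHANNELS);
//     assert_eq!(t_gp0, 7 * NUM_CHANNELS + 1);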
#[derive(Clone, Debug)]
pub struct MemoryState {
pub(crate) contexts: Vec<MemoryContextState>,
}
impl MemoryState {
pub fn new(kernel_code: &[u8]) -> Self {
let code_u256s = kernel_code.iter().map(|&x| x.into()).collect();
let mut result = Self::default();
result.contexts[0].segments[Segment::Code as usize].content = code_u256s;
result
}
pub fn apply_ops(&mut self, ops: &[MemoryOp]) {
for &op in ops {
let MemoryOp {
address,
kind,
value,
..
} = op;
if kind == MemoryOpKind::Write {
self.set(address, value);
}
}
}
pub fn get(&self, address: MemoryAddress) -> U256 {
self.contexts[address.context].segments[address.segment].get(address.virt)
}
pub fn set(&mut self, address: MemoryAddress, val: U256) {
self.contexts[address.context].segments[address.segment].set(address.virt, val);
}
}
impl Default for MemoryState {
fn default() -> Self {
Self {
// We start with an initial context for the kernel.
contexts: vec![MemoryContextState::default()],
}
}
}
#[derive(Clone, Default, Debug)]
pub(crate) struct MemoryContextState {
/// The content of each memory segment.
pub(crate) segments: [MemorySegmentState; Segment::COUNT],
}
#[derive(Clone, Default, Debug)]
pub(crate) struct MemorySegmentState {
pub(crate) content: Vec<U256>,
}
impl MemorySegmentState {
pub(crate) fn get(&self, virtual_addr: usize) -> U256 {
self.content
.get(virtual_addr)
.copied()
.unwrap_or(U256::zero())
}
pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) {
if virtual_addr >= self.content.len() {
self.content.resize(virtual_addr + 1, U256::zero());
}
self.content[virtual_addr] = value;
}
}
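// Round-trip sketch of the sparse segment semantics (illustration only):
// unwritten addresses read as zero, and `set` grows the backing `Vec` on demand.
//
//     let mut mem = MemoryState::default();
//     let addr = MemoryAddress::new(0, Segment::Stack, 100);
//     assert_eq!(mem.get(addr), U256::zero());
//     mem.set(addr, U256::from(42));
//     assert_eq!(mem.get(addr), U256::from(42));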

7
evm/src/witness/mod.rs Normal file
View File

@ -0,0 +1,7 @@
mod errors;
pub(crate) mod memory;
mod operation;
pub(crate) mod state;
pub(crate) mod traces;
pub mod transition;
pub(crate) mod util;

507
evm/src/witness/operation.rs Normal file
View File

@ -0,0 +1,507 @@
use ethereum_types::{BigEndianHash, U256};
use itertools::Itertools;
use keccak_hash::keccak;
use plonky2::field::types::Field;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::simple_logic::eq_iszero::generate_pinv_diff;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::util::u256_saturating_cast_usize;
use crate::witness::errors::ProgramError;
use crate::witness::memory::MemoryAddress;
use crate::witness::util::{
keccak_sponge_log, mem_read_code_with_log_and_fill, mem_read_gp_with_log_and_fill,
mem_write_gp_log_and_fill, stack_pop_with_log_and_fill, stack_push_log_and_fill,
};
use crate::{arithmetic, logic};
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Operation {
Push(u8),
Dup(u8),
Swap(u8),
Iszero,
Not,
Byte,
Syscall(u8),
Eq,
BinaryLogic(logic::Op),
BinaryArithmetic(arithmetic::BinaryOperator),
TernaryArithmetic(arithmetic::TernaryOperator),
KeccakGeneral,
ProverInput,
Pop,
Jump,
Jumpi,
Pc,
Gas,
Jumpdest,
GetContext,
SetContext,
ConsumeGas,
ExitKernel,
MloadGeneral,
MstoreGeneral,
}
pub(crate) fn generate_binary_logic_op<F: Field>(
op: logic::Op,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let operation = logic::Operation::new(op, in0, in1);
let log_out = stack_push_log_and_fill(state, &mut row, operation.result)?;
state.traces.push_logic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_binary_arithmetic_op<F: Field>(
operator: arithmetic::BinaryOperator,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(input0, log_in0), (input1, log_in1)] =
stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let operation = arithmetic::Operation::binary(operator, input0, input1);
let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?;
if operator == arithmetic::BinaryOperator::Shl || operator == arithmetic::BinaryOperator::Shr {
const LOOKUP_CHANNEL: usize = 2;
let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize);
if input0.bits() <= 32 {
let (_, read) =
mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, state, &mut row);
state.traces.push_memory(read);
} else {
// The shift constraints still expect the address to be set, even though no read will occur.
let channel = &mut row.mem_channels[LOOKUP_CHANNEL];
channel.addr_context = F::from_canonical_usize(lookup_addr.context);
channel.addr_segment = F::from_canonical_usize(lookup_addr.segment);
channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt);
}
}
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_ternary_arithmetic_op<F: Field>(
operator: arithmetic::TernaryOperator,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(input0, log_in0), (input1, log_in1), (input2, log_in2)] =
stack_pop_with_log_and_fill::<3, _>(state, &mut row)?;
let operation = arithmetic::Operation::ternary(operator, input0, input1, input2);
let log_out = stack_push_log_and_fill(state, &mut row, operation.result())?;
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_keccak_general<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
row.is_keccak_sponge = F::ONE;
let [(context, log_in0), (segment, log_in1), (base_virt, log_in2), (len, log_in3)] =
stack_pop_with_log_and_fill::<4, _>(state, &mut row)?;
let len = len.as_usize();
let base_address = MemoryAddress::new_u256s(context, segment, base_virt);
let input = (0..len)
.map(|i| {
let address = MemoryAddress {
virt: base_address.virt.saturating_add(i),
..base_address
};
let val = state.memory.get(address);
val.as_u32() as u8
})
.collect_vec();
log::debug!("Hashing {:?}", input);
let hash = keccak(&input);
let log_push = stack_push_log_and_fill(state, &mut row, hash.into_uint())?;
keccak_sponge_log(state, base_address, input);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_in3);
state.traces.push_memory(log_push);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_prover_input<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let pc = state.registers.program_counter;
let input_fn = &KERNEL.prover_inputs[&pc];
let input = state.prover_input(input_fn);
let write = stack_push_log_and_fill(state, &mut row, input)?;
state.traces.push_memory(write);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_pop<F: Field>(
state: &mut GenerationState<F>,
row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
if state.registers.stack_len == 0 {
return Err(ProgramError::StackUnderflow);
}
state.registers.stack_len -= 1;
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_jump<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(dst, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
state.traces.push_memory(log_in0);
state.traces.push_cpu(row);
state.registers.program_counter = u256_saturating_cast_usize(dst);
// TODO: Set other cols like input0_upper_sum_inv.
Ok(())
}
pub(crate) fn generate_jumpi<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(dst, log_in0), (cond, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_cpu(row);
state.registers.program_counter = if cond.is_zero() {
state.registers.program_counter + 1
} else {
u256_saturating_cast_usize(dst)
};
// TODO: Set other cols like input0_upper_sum_inv.
Ok(())
}
pub(crate) fn generate_push<F: Field>(
n: u8,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let context = state.registers.effective_context();
let num_bytes = n as usize + 1;
let initial_offset = state.registers.program_counter + 1;
let offsets = initial_offset..initial_offset + num_bytes;
let mut addrs = offsets.map(|offset| MemoryAddress::new(context, Segment::Code, offset));
// First read the value without going through the `mem_read_with_log`-style helpers, so we can
// pass it to `stack_push_log_and_fill`.
let bytes = (0..num_bytes)
.map(|i| {
state
.memory
.get(MemoryAddress::new(
context,
Segment::Code,
initial_offset + i,
))
.as_u32() as u8
})
.collect_vec();
let val = U256::from_big_endian(&bytes);
let write = stack_push_log_and_fill(state, &mut row, val)?;
// In the first cycle, we read up to NUM_GP_CHANNELS - 1 bytes, leaving the last GP channel
// to push the result.
for (i, addr) in (&mut addrs).take(NUM_GP_CHANNELS - 1).enumerate() {
let (_, read) = mem_read_gp_with_log_and_fill(i, addr, state, &mut row);
state.traces.push_memory(read);
}
state.traces.push_memory(write);
state.traces.push_cpu(row);
// In any subsequent cycles, we read up to 1 + NUM_GP_CHANNELS bytes.
for mut addrs_chunk in &addrs.chunks(1 + NUM_GP_CHANNELS) {
let mut row = CpuColumnsView::default();
row.is_cpu_cycle = F::ONE;
row.op.push = F::ONE;
let first_addr = addrs_chunk.next().unwrap();
let (_, first_read) = mem_read_code_with_log_and_fill(first_addr, state, &mut row);
state.traces.push_memory(first_read);
for (i, addr) in addrs_chunk.enumerate() {
let (_, read) = mem_read_gp_with_log_and_fill(i, addr, state, &mut row);
state.traces.push_memory(read);
}
state.traces.push_cpu(row);
}
Ok(())
}
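// For intuition, a hypothetical helper (illustration only, not part of this
// change) counting the cycles a PUSH of `num_bytes` immediate bytes takes
// under the budget described above: the first cycle reads up to
// NUM_GP_CHANNELS - 1 bytes, and each later cycle reads 1 code-channel byte
// plus NUM_GP_CHANNELS GP-channel bytes.
//
//     fn push_cycles(num_bytes: usize, num_gp: usize) -> usize {
//         let rest = num_bytes.saturating_sub(num_gp - 1);
//         1 + (rest + num_gp) / (num_gp + 1) // 1 + ceil(rest / (num_gp + 1))
//     }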
pub(crate) fn generate_dup<F: Field>(
n: u8,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let other_addr_lo = state
.registers
.stack_len
.checked_sub(1 + (n as usize))
.ok_or(ProgramError::StackUnderflow)?;
let other_addr = MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
other_addr_lo,
);
let (val, log_in) = mem_read_gp_with_log_and_fill(0, other_addr, state, &mut row);
let log_out = stack_push_log_and_fill(state, &mut row, val)?;
state.traces.push_memory(log_in);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_swap<F: Field>(
n: u8,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let other_addr_lo = state
.registers
.stack_len
.checked_sub(2 + (n as usize))
.ok_or(ProgramError::StackUnderflow)?;
let other_addr = MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
other_addr_lo,
);
let [(in0, log_in0)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, state, &mut row);
let log_out0 = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 2, other_addr, state, &mut row, in0);
let log_out1 = stack_push_log_and_fill(state, &mut row, in1)?;
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out0);
state.traces.push_memory(log_out1);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_not<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let result = !x;
let log_out = stack_push_log_and_fill(state, &mut row, result)?;
state.traces.push_memory(log_in);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_byte<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(i, log_in0), (x, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let byte = if i < 32.into() {
// byte(i) is the i'th little-endian byte; we want the i'th big-endian byte.
x.byte(31 - i.as_usize())
} else {
0
};
let log_out = stack_push_log_and_fill(state, &mut row, byte.into())?;
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
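// Worked example of the index flip (`U256::byte` indexes from the little end):
//
//     let x = U256::from(0xff);
//     assert_eq!(x.byte(31 - 31), 0xff); // BYTE(31) is the least-significant byte
//     assert_eq!(x.byte(31 - 30), 0x00); // BYTE(30)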
pub(crate) fn generate_iszero<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(x, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let is_zero = x.is_zero();
let result = {
let t: u64 = is_zero.into();
t.into()
};
let log_out = stack_push_log_and_fill(state, &mut row, result)?;
generate_pinv_diff(x, U256::zero(), &mut row);
state.traces.push_memory(log_in);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_syscall<F: Field>(
opcode: u8,
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let handler_jumptable_addr = KERNEL.global_labels["syscall_jumptable"];
let handler_addr_addr = handler_jumptable_addr + (opcode as usize);
let (handler_addr0, log_in0) = mem_read_gp_with_log_and_fill(
0,
MemoryAddress::new(0, Segment::Code, handler_addr_addr),
state,
&mut row,
);
let (handler_addr1, log_in1) = mem_read_gp_with_log_and_fill(
1,
MemoryAddress::new(0, Segment::Code, handler_addr_addr + 1),
state,
&mut row,
);
let (handler_addr2, log_in2) = mem_read_gp_with_log_and_fill(
2,
MemoryAddress::new(0, Segment::Code, handler_addr_addr + 2),
state,
&mut row,
);
let handler_addr = (handler_addr0 << 16) + (handler_addr1 << 8) + handler_addr2;
let new_program_counter = handler_addr.as_usize();
let syscall_info = U256::from(state.registers.program_counter)
+ (U256::from(u64::from(state.registers.is_kernel)) << 32);
let log_out = stack_push_log_and_fill(state, &mut row, syscall_info)?;
state.registers.program_counter = new_program_counter;
state.registers.is_kernel = true;
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_eq<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(in0, log_in0), (in1, log_in1)] = stack_pop_with_log_and_fill::<2, _>(state, &mut row)?;
let eq = in0 == in1;
let result = U256::from(u64::from(eq));
let log_out = stack_push_log_and_fill(state, &mut row, result)?;
generate_pinv_diff(in0, in1, &mut row);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_exit_kernel<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(kexit_info, log_in)] = stack_pop_with_log_and_fill::<1, _>(state, &mut row)?;
let kexit_info_u64 = kexit_info.0[0];
let program_counter = kexit_info_u64 as u32 as usize;
let is_kernel_mode_val = (kexit_info_u64 >> 32) as u32;
assert!(is_kernel_mode_val == 0 || is_kernel_mode_val == 1);
let is_kernel_mode = is_kernel_mode_val != 0;
state.registers.program_counter = program_counter;
state.registers.is_kernel = is_kernel_mode;
state.traces.push_memory(log_in);
state.traces.push_cpu(row);
Ok(())
}
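// Round-trip sketch matching the packing in `generate_syscall` above
// (program counter in bits 0..32, kernel flag at bit 32):
//
//     let pc = 0x1234usize;
//     let info = U256::from(pc) + (U256::from(1u64) << 32); // kernel mode
//     assert_eq!(info.0[0] as u32 as usize, pc);
//     assert_eq!((info.0[0] >> 32) as u32, 1);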
pub(crate) fn generate_mload_general<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(context, log_in0), (segment, log_in1), (virt, log_in2)] =
stack_pop_with_log_and_fill::<3, _>(state, &mut row)?;
let val = state
.memory
.get(MemoryAddress::new_u256s(context, segment, virt));
let log_out = stack_push_log_and_fill(state, &mut row, val)?;
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_out);
state.traces.push_cpu(row);
Ok(())
}
pub(crate) fn generate_mstore_general<F: Field>(
state: &mut GenerationState<F>,
mut row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
let [(context, log_in0), (segment, log_in1), (virt, log_in2), (val, log_in3)] =
stack_pop_with_log_and_fill::<4, _>(state, &mut row)?;
let address = MemoryAddress {
context: context.as_usize(),
segment: segment.as_usize(),
virt: virt.as_usize(),
};
let log_write = mem_write_gp_log_and_fill(4, address, state, &mut row, val);
state.traces.push_memory(log_in0);
state.traces.push_memory(log_in1);
state.traces.push_memory(log_in2);
state.traces.push_memory(log_in3);
state.traces.push_memory(log_write);
state.traces.push_cpu(row);
Ok(())
}

32
evm/src/witness/state.rs Normal file
View File

@ -0,0 +1,32 @@
use crate::cpu::kernel::aggregator::KERNEL;
const KERNEL_CONTEXT: usize = 0;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RegistersState {
pub program_counter: usize,
pub is_kernel: bool,
pub stack_len: usize,
pub context: usize,
}
impl RegistersState {
pub(crate) fn effective_context(&self) -> usize {
if self.is_kernel {
KERNEL_CONTEXT
} else {
self.context
}
}
}
impl Default for RegistersState {
fn default() -> Self {
Self {
program_counter: KERNEL.global_labels["main"],
is_kernel: true,
stack_len: 0,
context: 0,
}
}
}

161
evm/src/witness/traces.rs Normal file
View File

@ -0,0 +1,161 @@
use std::mem::size_of;
use itertools::Itertools;
use plonky2::field::extension::Extendable;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::util::timing::TimingTree;
use crate::all_stark::{AllStark, NUM_TABLES};
use crate::config::StarkConfig;
use crate::cpu::columns::CpuColumnsView;
use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
use crate::util::trace_rows_to_poly_values;
use crate::witness::memory::MemoryOp;
use crate::{arithmetic, keccak, logic};
#[derive(Clone, Copy, Debug)]
pub struct TraceCheckpoint {
pub(self) cpu_len: usize,
pub(self) keccak_len: usize,
pub(self) keccak_sponge_len: usize,
pub(self) logic_len: usize,
pub(self) arithmetic_len: usize,
pub(self) memory_len: usize,
}
#[derive(Debug)]
pub(crate) struct Traces<T: Copy> {
pub(crate) cpu: Vec<CpuColumnsView<T>>,
pub(crate) logic_ops: Vec<logic::Operation>,
pub(crate) arithmetic: Vec<arithmetic::Operation>,
pub(crate) memory_ops: Vec<MemoryOp>,
pub(crate) keccak_inputs: Vec<[u64; keccak::keccak_stark::NUM_INPUTS]>,
pub(crate) keccak_sponge_ops: Vec<KeccakSpongeOp>,
}
impl<T: Copy> Traces<T> {
pub fn new() -> Self {
Traces {
cpu: vec![],
logic_ops: vec![],
arithmetic: vec![],
memory_ops: vec![],
keccak_inputs: vec![],
keccak_sponge_ops: vec![],
}
}
pub fn checkpoint(&self) -> TraceCheckpoint {
TraceCheckpoint {
cpu_len: self.cpu.len(),
keccak_len: self.keccak_inputs.len(),
keccak_sponge_len: self.keccak_sponge_ops.len(),
logic_len: self.logic_ops.len(),
arithmetic_len: self.arithmetic.len(),
memory_len: self.memory_ops.len(),
}
}
pub fn rollback(&mut self, checkpoint: TraceCheckpoint) {
self.cpu.truncate(checkpoint.cpu_len);
self.keccak_inputs.truncate(checkpoint.keccak_len);
self.keccak_sponge_ops
.truncate(checkpoint.keccak_sponge_len);
self.logic_ops.truncate(checkpoint.logic_len);
self.arithmetic.truncate(checkpoint.arithmetic_len);
self.memory_ops.truncate(checkpoint.memory_len);
}
pub fn mem_ops_since(&self, checkpoint: TraceCheckpoint) -> &[MemoryOp] {
&self.memory_ops[checkpoint.memory_len..]
}
pub fn push_cpu(&mut self, val: CpuColumnsView<T>) {
self.cpu.push(val);
}
pub fn push_logic(&mut self, op: logic::Operation) {
self.logic_ops.push(op);
}
pub fn push_arithmetic(&mut self, op: arithmetic::Operation) {
self.arithmetic.push(op);
}
pub fn push_memory(&mut self, op: MemoryOp) {
self.memory_ops.push(op);
}
pub fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS]) {
self.keccak_inputs.push(input);
}
pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES]) {
let chunks = input
.chunks(size_of::<u64>())
.map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
.collect_vec()
.try_into()
.unwrap();
self.push_keccak(chunks);
}
pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) {
self.keccak_sponge_ops.push(op);
}
pub fn clock(&self) -> usize {
self.cpu.len()
}
pub fn into_tables<const D: usize>(
self,
all_stark: &AllStark<T, D>,
config: &StarkConfig,
timing: &mut TimingTree,
) -> [Vec<PolynomialValues<T>>; NUM_TABLES]
where
T: RichField + Extendable<D>,
{
let cap_elements = config.fri_config.num_cap_elements();
let Traces {
cpu,
logic_ops,
arithmetic: _, // TODO
memory_ops,
keccak_inputs,
keccak_sponge_ops,
} = self;
let cpu_rows = cpu.into_iter().map(|x| x.into()).collect();
let cpu_trace = trace_rows_to_poly_values(cpu_rows);
let keccak_trace =
all_stark
.keccak_stark
.generate_trace(keccak_inputs, cap_elements, timing);
let keccak_sponge_trace =
all_stark
.keccak_sponge_stark
.generate_trace(keccak_sponge_ops, cap_elements, timing);
let logic_trace = all_stark
.logic_stark
.generate_trace(logic_ops, cap_elements, timing);
let memory_trace = all_stark.memory_stark.generate_trace(memory_ops, timing);
[
cpu_trace,
keccak_trace,
keccak_sponge_trace,
logic_trace,
memory_trace,
]
}
}
impl<T: Copy> Default for Traces<T> {
fn default() -> Self {
Self::new()
}
}

279
evm/src/witness/transition.rs Normal file
View File

@ -0,0 +1,279 @@
use itertools::Itertools;
use plonky2::field::types::Field;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::witness::errors::ProgramError;
use crate::witness::memory::MemoryAddress;
use crate::witness::operation::*;
use crate::witness::state::RegistersState;
use crate::witness::util::{mem_read_code_with_log_and_fill, stack_peek};
use crate::{arithmetic, logic};
fn read_code_memory<F: Field>(state: &mut GenerationState<F>, row: &mut CpuColumnsView<F>) -> u8 {
let code_context = state.registers.effective_context();
row.code_context = F::from_canonical_usize(code_context);
let address = MemoryAddress::new(code_context, Segment::Code, state.registers.program_counter);
let (opcode, mem_log) = mem_read_code_with_log_and_fill(address, state, row);
state.traces.push_memory(mem_log);
opcode
}
fn decode(registers: RegistersState, opcode: u8) -> Result<Operation, ProgramError> {
match (opcode, registers.is_kernel) {
(0x00, _) => Ok(Operation::Syscall(opcode)),
(0x01, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add)),
(0x02, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul)),
(0x03, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Sub)),
(0x04, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Div)),
(0x05, _) => Ok(Operation::Syscall(opcode)),
(0x06, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod)),
(0x07, _) => Ok(Operation::Syscall(opcode)),
(0x08, _) => Ok(Operation::TernaryArithmetic(
arithmetic::TernaryOperator::AddMod,
)),
(0x09, _) => Ok(Operation::TernaryArithmetic(
arithmetic::TernaryOperator::MulMod,
)),
(0x0a, _) => Ok(Operation::Syscall(opcode)),
(0x0b, _) => Ok(Operation::Syscall(opcode)),
(0x0c, true) => Ok(Operation::BinaryArithmetic(
arithmetic::BinaryOperator::AddFp254,
)),
(0x0d, true) => Ok(Operation::BinaryArithmetic(
arithmetic::BinaryOperator::MulFp254,
)),
(0x0e, true) => Ok(Operation::BinaryArithmetic(
arithmetic::BinaryOperator::SubFp254,
)),
(0x10, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt)),
(0x11, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt)),
(0x12, _) => Ok(Operation::Syscall(opcode)),
(0x13, _) => Ok(Operation::Syscall(opcode)),
(0x14, _) => Ok(Operation::Eq),
(0x15, _) => Ok(Operation::Iszero),
(0x16, _) => Ok(Operation::BinaryLogic(logic::Op::And)),
(0x17, _) => Ok(Operation::BinaryLogic(logic::Op::Or)),
(0x18, _) => Ok(Operation::BinaryLogic(logic::Op::Xor)),
(0x19, _) => Ok(Operation::Not),
(0x1a, _) => Ok(Operation::Byte),
(0x1b, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl)),
(0x1c, _) => Ok(Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr)),
(0x1d, _) => Ok(Operation::Syscall(opcode)),
(0x20, _) => Ok(Operation::Syscall(opcode)),
(0x21, true) => Ok(Operation::KeccakGeneral),
(0x30, _) => Ok(Operation::Syscall(opcode)),
(0x31, _) => Ok(Operation::Syscall(opcode)),
(0x32, _) => Ok(Operation::Syscall(opcode)),
(0x33, _) => Ok(Operation::Syscall(opcode)),
(0x34, _) => Ok(Operation::Syscall(opcode)),
(0x35, _) => Ok(Operation::Syscall(opcode)),
(0x36, _) => Ok(Operation::Syscall(opcode)),
(0x37, _) => Ok(Operation::Syscall(opcode)),
(0x38, _) => Ok(Operation::Syscall(opcode)),
(0x39, _) => Ok(Operation::Syscall(opcode)),
(0x3a, _) => Ok(Operation::Syscall(opcode)),
(0x3b, _) => Ok(Operation::Syscall(opcode)),
(0x3c, _) => Ok(Operation::Syscall(opcode)),
(0x3d, _) => Ok(Operation::Syscall(opcode)),
(0x3e, _) => Ok(Operation::Syscall(opcode)),
(0x3f, _) => Ok(Operation::Syscall(opcode)),
(0x40, _) => Ok(Operation::Syscall(opcode)),
(0x41, _) => Ok(Operation::Syscall(opcode)),
(0x42, _) => Ok(Operation::Syscall(opcode)),
(0x43, _) => Ok(Operation::Syscall(opcode)),
(0x44, _) => Ok(Operation::Syscall(opcode)),
(0x45, _) => Ok(Operation::Syscall(opcode)),
(0x46, _) => Ok(Operation::Syscall(opcode)),
(0x47, _) => Ok(Operation::Syscall(opcode)),
(0x48, _) => Ok(Operation::Syscall(opcode)),
(0x49, _) => Ok(Operation::ProverInput),
(0x50, _) => Ok(Operation::Pop),
(0x51, _) => Ok(Operation::Syscall(opcode)),
(0x52, _) => Ok(Operation::Syscall(opcode)),
(0x53, _) => Ok(Operation::Syscall(opcode)),
(0x54, _) => Ok(Operation::Syscall(opcode)),
(0x55, _) => Ok(Operation::Syscall(opcode)),
(0x56, _) => Ok(Operation::Jump),
(0x57, _) => Ok(Operation::Jumpi),
(0x58, _) => Ok(Operation::Pc),
(0x59, _) => Ok(Operation::Syscall(opcode)),
(0x5a, _) => Ok(Operation::Gas),
(0x5b, _) => Ok(Operation::Jumpdest),
(0x60..=0x7f, _) => Ok(Operation::Push(opcode & 0x1f)),
(0x80..=0x8f, _) => Ok(Operation::Dup(opcode & 0xf)),
(0x90..=0x9f, _) => Ok(Operation::Swap(opcode & 0xf)),
(0xa0, _) => Ok(Operation::Syscall(opcode)),
(0xa1, _) => Ok(Operation::Syscall(opcode)),
(0xa2, _) => Ok(Operation::Syscall(opcode)),
(0xa3, _) => Ok(Operation::Syscall(opcode)),
(0xa4, _) => Ok(Operation::Syscall(opcode)),
(0xf0, _) => Ok(Operation::Syscall(opcode)),
(0xf1, _) => Ok(Operation::Syscall(opcode)),
(0xf2, _) => Ok(Operation::Syscall(opcode)),
(0xf3, _) => Ok(Operation::Syscall(opcode)),
(0xf4, _) => Ok(Operation::Syscall(opcode)),
(0xf5, _) => Ok(Operation::Syscall(opcode)),
(0xf6, true) => Ok(Operation::GetContext),
(0xf7, true) => Ok(Operation::SetContext),
(0xf8, true) => Ok(Operation::ConsumeGas),
(0xf9, true) => Ok(Operation::ExitKernel),
(0xfa, _) => Ok(Operation::Syscall(opcode)),
(0xfb, true) => Ok(Operation::MloadGeneral),
(0xfc, true) => Ok(Operation::MstoreGeneral),
(0xfd, _) => Ok(Operation::Syscall(opcode)),
(0xff, _) => Ok(Operation::Syscall(opcode)),
_ => Err(ProgramError::InvalidOpcode),
}
}
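// Usage sketch (hypothetical register values): native ops decode directly,
// unimplemented EVM opcodes are routed to kernel syscall handlers, and
// kernel-only opcodes are invalid in user mode.
//
//     let user = RegistersState { is_kernel: false, ..RegistersState::default() };
//     assert!(matches!(decode(user, 0x16), Ok(Operation::BinaryLogic(logic::Op::And))));
//     assert!(matches!(decode(user, 0x20), Ok(Operation::Syscall(0x20)))); // SHA3
//     assert!(matches!(decode(user, 0x21), Err(ProgramError::InvalidOpcode)));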
fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) {
let flags = &mut row.op;
*match op {
Operation::Push(_) => &mut flags.push,
Operation::Dup(_) => &mut flags.dup,
Operation::Swap(_) => &mut flags.swap,
Operation::Iszero => &mut flags.iszero,
Operation::Not => &mut flags.not,
Operation::Byte => &mut flags.byte,
Operation::Syscall(_) => &mut flags.syscall,
Operation::Eq => &mut flags.eq,
Operation::BinaryLogic(logic::Op::And) => &mut flags.and,
Operation::BinaryLogic(logic::Op::Or) => &mut flags.or,
Operation::BinaryLogic(logic::Op::Xor) => &mut flags.xor,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Add) => &mut flags.add,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mul) => &mut flags.mul,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Sub) => &mut flags.sub,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Div) => &mut flags.div,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Mod) => &mut flags.mod_,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Lt) => &mut flags.lt,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Gt) => &mut flags.gt,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shl) => &mut flags.shl,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::Shr) => &mut flags.shr,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::AddFp254) => &mut flags.addfp254,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::MulFp254) => &mut flags.mulfp254,
Operation::BinaryArithmetic(arithmetic::BinaryOperator::SubFp254) => &mut flags.subfp254,
Operation::TernaryArithmetic(arithmetic::TernaryOperator::AddMod) => &mut flags.addmod,
Operation::TernaryArithmetic(arithmetic::TernaryOperator::MulMod) => &mut flags.mulmod,
Operation::KeccakGeneral => &mut flags.keccak_general,
Operation::ProverInput => &mut flags.prover_input,
Operation::Pop => &mut flags.pop,
Operation::Jump => &mut flags.jump,
Operation::Jumpi => &mut flags.jumpi,
Operation::Pc => &mut flags.pc,
Operation::Gas => &mut flags.gas,
Operation::Jumpdest => &mut flags.jumpdest,
Operation::GetContext => &mut flags.get_context,
Operation::SetContext => &mut flags.set_context,
Operation::ConsumeGas => &mut flags.consume_gas,
Operation::ExitKernel => &mut flags.exit_kernel,
Operation::MloadGeneral => &mut flags.mload_general,
Operation::MstoreGeneral => &mut flags.mstore_general,
} = F::ONE;
}
fn perform_op<F: Field>(
state: &mut GenerationState<F>,
op: Operation,
row: CpuColumnsView<F>,
) -> Result<(), ProgramError> {
match op {
Operation::Push(n) => generate_push(n, state, row)?,
Operation::Dup(n) => generate_dup(n, state, row)?,
Operation::Swap(n) => generate_swap(n, state, row)?,
Operation::Iszero => generate_iszero(state, row)?,
Operation::Not => generate_not(state, row)?,
Operation::Byte => generate_byte(state, row)?,
Operation::Syscall(opcode) => generate_syscall(opcode, state, row)?,
Operation::Eq => generate_eq(state, row)?,
Operation::BinaryLogic(binary_logic_op) => {
generate_binary_logic_op(binary_logic_op, state, row)?
}
Operation::BinaryArithmetic(op) => generate_binary_arithmetic_op(op, state, row)?,
Operation::TernaryArithmetic(op) => generate_ternary_arithmetic_op(op, state, row)?,
Operation::KeccakGeneral => generate_keccak_general(state, row)?,
Operation::ProverInput => generate_prover_input(state, row)?,
Operation::Pop => generate_pop(state, row)?,
Operation::Jump => generate_jump(state, row)?,
Operation::Jumpi => generate_jumpi(state, row)?,
Operation::Pc => todo!(),
Operation::Gas => todo!(),
Operation::Jumpdest => todo!(),
Operation::GetContext => todo!(),
Operation::SetContext => todo!(),
Operation::ConsumeGas => todo!(),
Operation::ExitKernel => generate_exit_kernel(state, row)?,
Operation::MloadGeneral => generate_mload_general(state, row)?,
Operation::MstoreGeneral => generate_mstore_general(state, row)?,
};
state.registers.program_counter += match op {
Operation::Syscall(_) | Operation::ExitKernel => 0,
Operation::Push(n) => n as usize + 2,
Operation::Jump | Operation::Jumpi => 0,
_ => 1,
};
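// E.g. PUSH1 decodes as Operation::Push(0), so the counter advances by 2
// (the opcode byte plus one immediate byte); jumps, syscalls and exit-kernel
// set the counter themselves, hence the 0 above.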
if let Some(label) = KERNEL.offset_label(state.registers.program_counter) {
if !label.starts_with("halt_pc") {
log::debug!("At {label}");
}
}
Ok(())
}
fn try_perform_instruction<F: Field>(state: &mut GenerationState<F>) -> Result<(), ProgramError> {
let mut row: CpuColumnsView<F> = CpuColumnsView::default();
row.is_cpu_cycle = F::ONE;
row.clock = F::from_canonical_usize(state.traces.clock());
row.context = F::from_canonical_usize(state.registers.context);
row.program_counter = F::from_canonical_usize(state.registers.program_counter);
row.is_kernel_mode = F::from_bool(state.registers.is_kernel);
row.stack_len = F::from_canonical_usize(state.registers.stack_len);
let opcode = read_code_memory(state, &mut row);
let op = decode(state.registers, opcode)?;
let pc = state.registers.program_counter;
log::trace!("\nCycle {}", state.traces.clock());
log::trace!(
"Stack: {:?}",
(0..state.registers.stack_len)
.map(|i| stack_peek(state, i).unwrap())
.collect_vec()
);
log::trace!("Executing {:?} at {}", op, KERNEL.offset_name(pc));
fill_op_flag(op, &mut row);
perform_op(state, op, row)
}
fn handle_error<F: Field>(_state: &mut GenerationState<F>) {
todo!("generation for exception handling is not implemented");
}
pub(crate) fn transition<F: Field>(state: &mut GenerationState<F>) {
let checkpoint = state.checkpoint();
let result = try_perform_instruction(state);
match result {
Ok(()) => {
state
.memory
.apply_ops(state.traces.mem_ops_since(checkpoint.traces));
}
Err(e) => {
if state.registers.is_kernel {
panic!("exception in kernel mode: {:?}", e);
}
state.rollback(checkpoint);
handle_error(state)
}
}
}

250
evm/src/witness/util.rs Normal file
View File

@ -0,0 +1,250 @@
use ethereum_types::U256;
use plonky2::field::types::Field;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::keccak_util::keccakf_u8s;
use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS};
use crate::cpu::stack_bounds::MAX_USER_STACK_SIZE;
use crate::generation::state::GenerationState;
use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_WIDTH_BYTES};
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
use crate::logic;
use crate::memory::segments::Segment;
use crate::witness::errors::ProgramError;
use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind};
fn to_byte_checked(n: U256) -> u8 {
let res = n.byte(0);
assert_eq!(n, res.into());
res
}
fn to_bits_le<F: Field>(n: u8) -> [F; 8] {
let mut res = [F::ZERO; 8];
for (i, bit) in res.iter_mut().enumerate() {
*bit = F::from_bool(n & (1 << i) != 0);
}
res
}
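// Sketch (assuming e.g. GoldilocksField for F): 0b0000_0101 maps to
// little-endian bit order.
//
//     let bits = to_bits_le::<GoldilocksField>(5);
//     assert_eq!(bits[0], GoldilocksField::ONE);
//     assert_eq!(bits[1], GoldilocksField::ZERO);
//     assert_eq!(bits[2], GoldilocksField::ONE);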
/// Peek at the `i`th stack item from the top. If `i == 0`, this returns the top of the stack.
pub(crate) fn stack_peek<F: Field>(state: &GenerationState<F>, i: usize) -> Option<U256> {
if i >= state.registers.stack_len {
return None;
}
Some(state.memory.get(MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
state.registers.stack_len - 1 - i,
)))
}
pub(crate) fn mem_read_with_log<F: Field>(
channel: MemoryChannel,
address: MemoryAddress,
state: &GenerationState<F>,
) -> (U256, MemoryOp) {
let val = state.memory.get(address);
let op = MemoryOp::new(
channel,
state.traces.clock(),
address,
MemoryOpKind::Read,
val,
);
(val, op)
}
pub(crate) fn mem_write_log<F: Field>(
channel: MemoryChannel,
address: MemoryAddress,
state: &mut GenerationState<F>,
val: U256,
) -> MemoryOp {
MemoryOp::new(
channel,
state.traces.clock(),
address,
MemoryOpKind::Write,
val,
)
}
pub(crate) fn mem_read_code_with_log_and_fill<F: Field>(
address: MemoryAddress,
state: &GenerationState<F>,
row: &mut CpuColumnsView<F>,
) -> (u8, MemoryOp) {
let (val, op) = mem_read_with_log(MemoryChannel::Code, address, state);
let val_u8 = to_byte_checked(val);
row.opcode_bits = to_bits_le(val_u8);
(val_u8, op)
}
pub(crate) fn mem_read_gp_with_log_and_fill<F: Field>(
n: usize,
address: MemoryAddress,
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
) -> (U256, MemoryOp) {
let (val, op) = mem_read_with_log(MemoryChannel::GeneralPurpose(n), address, state);
let val_limbs: [u64; 4] = val.0;
let channel = &mut row.mem_channels[n];
assert_eq!(channel.used, F::ZERO);
channel.used = F::ONE;
channel.is_read = F::ONE;
channel.addr_context = F::from_canonical_usize(address.context);
channel.addr_segment = F::from_canonical_usize(address.segment);
channel.addr_virtual = F::from_canonical_usize(address.virt);
for (i, limb) in val_limbs.into_iter().enumerate() {
channel.value[2 * i] = F::from_canonical_u32(limb as u32);
channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32);
}
(val, op)
}
pub(crate) fn mem_write_gp_log_and_fill<F: Field>(
n: usize,
address: MemoryAddress,
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
val: U256,
) -> MemoryOp {
let op = mem_write_log(MemoryChannel::GeneralPurpose(n), address, state, val);
let val_limbs: [u64; 4] = val.0;
let channel = &mut row.mem_channels[n];
assert_eq!(channel.used, F::ZERO);
channel.used = F::ONE;
channel.is_read = F::ZERO;
channel.addr_context = F::from_canonical_usize(address.context);
channel.addr_segment = F::from_canonical_usize(address.segment);
channel.addr_virtual = F::from_canonical_usize(address.virt);
for (i, limb) in val_limbs.into_iter().enumerate() {
channel.value[2 * i] = F::from_canonical_u32(limb as u32);
channel.value[2 * i + 1] = F::from_canonical_u32((limb >> 32) as u32);
}
op
}
pub(crate) fn stack_pop_with_log_and_fill<const N: usize, F: Field>(
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
) -> Result<[(U256, MemoryOp); N], ProgramError> {
if state.registers.stack_len < N {
return Err(ProgramError::StackUnderflow);
}
let result = std::array::from_fn(|i| {
let address = MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
state.registers.stack_len - 1 - i,
);
mem_read_gp_with_log_and_fill(i, address, state, row)
});
state.registers.stack_len -= N;
Ok(result)
}
pub(crate) fn stack_push_log_and_fill<F: Field>(
state: &mut GenerationState<F>,
row: &mut CpuColumnsView<F>,
val: U256,
) -> Result<MemoryOp, ProgramError> {
if !state.registers.is_kernel && state.registers.stack_len >= MAX_USER_STACK_SIZE {
return Err(ProgramError::StackOverflow);
}
let address = MemoryAddress::new(
state.registers.effective_context(),
Segment::Stack,
state.registers.stack_len,
);
let res = mem_write_gp_log_and_fill(NUM_GP_CHANNELS - 1, address, state, row, val);
state.registers.stack_len += 1;
Ok(res)
}
fn xor_into_sponge<F: Field>(
state: &mut GenerationState<F>,
sponge_state: &mut [u8; KECCAK_WIDTH_BYTES],
block: &[u8; KECCAK_RATE_BYTES],
) {
for i in (0..KECCAK_RATE_BYTES).step_by(32) {
let range = i..KECCAK_RATE_BYTES.min(i + 32);
let lhs = U256::from_little_endian(&sponge_state[range.clone()]);
let rhs = U256::from_little_endian(&block[range]);
state
.traces
.push_logic(logic::Operation::new(logic::Op::Xor, lhs, rhs));
}
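// Keccak-256 has a 1088-bit rate, so KECCAK_RATE_BYTES == 136 and the loop
// above logs five XOR rows per block: four 32-byte chunks plus one final
// 8-byte chunk (136 = 4 * 32 + 8).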
for i in 0..KECCAK_RATE_BYTES {
sponge_state[i] ^= block[i];
}
}
pub(crate) fn keccak_sponge_log<F: Field>(
state: &mut GenerationState<F>,
base_address: MemoryAddress,
input: Vec<u8>,
) {
let clock = state.traces.clock();
let mut address = base_address;
let mut input_blocks = input.chunks_exact(KECCAK_RATE_BYTES);
let mut sponge_state = [0u8; KECCAK_WIDTH_BYTES];
for block in input_blocks.by_ref() {
for &byte in block {
state.traces.push_memory(MemoryOp::new(
MemoryChannel::Code,
clock,
address,
MemoryOpKind::Read,
byte.into(),
));
address.increment();
}
xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap());
state.traces.push_keccak_bytes(sponge_state);
keccakf_u8s(&mut sponge_state);
}
for &byte in input_blocks.remainder() {
state.traces.push_memory(MemoryOp::new(
MemoryChannel::Code,
clock,
address,
MemoryOpKind::Read,
byte.into(),
));
address.increment();
}
let mut final_block = [0u8; KECCAK_RATE_BYTES];
final_block[..input_blocks.remainder().len()].copy_from_slice(input_blocks.remainder());
// pad10*1 rule
if input_blocks.remainder().len() == KECCAK_RATE_BYTES - 1 {
// Both 1s are placed in the same byte.
final_block[input_blocks.remainder().len()] = 0b10000001;
} else {
final_block[input_blocks.remainder().len()] = 1;
final_block[KECCAK_RATE_BYTES - 1] = 0b10000000;
}
xor_into_sponge(state, &mut sponge_state, &final_block);
state.traces.push_keccak_bytes(sponge_state);
state.traces.push_keccak_sponge(KeccakSpongeOp {
base_address,
timestamp: clock * NUM_CHANNELS,
input,
});
}
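// pad10*1 worked examples (illustration only): an empty remainder pads to
// 0x01, 0x00, ..., 0x00, 0x80, while a remainder of KECCAK_RATE_BYTES - 1
// bytes packs both padding bits into one final byte, 0b10000001.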

View File

@ -1,6 +1,7 @@
use std::collections::HashMap;
use eth_trie_utils::partial_trie::{Nibbles, PartialTrie};
use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV};
use eth_trie_utils::partial_trie::PartialTrie;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::util::timing::TimingTree;
@ -17,20 +18,15 @@ type C = PoseidonGoldilocksConfig;
/// Execute the empty list of transactions, i.e. a no-op.
#[test]
#[ignore] // TODO: Won't work until witness generation logic is finished.
fn test_empty_txn_list() -> anyhow::Result<()> {
init_logger();
let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
let block_metadata = BlockMetadata::default();
let state_trie = PartialTrie::Leaf {
nibbles: Nibbles {
count: 5,
packed: 0xABCDE.into(),
},
value: vec![1, 2, 3],
};
let state_trie = PartialTrie::Empty;
let transactions_trie = PartialTrie::Empty;
let receipts_trie = PartialTrie::Empty;
let storage_tries = vec![];
@ -51,7 +47,10 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
block_metadata,
};
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut TimingTree::default())?;
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
timing.print();
assert_eq!(
proof.public_values.trie_roots_before.state_root,
state_trie_root
@ -79,3 +78,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
verify_proof(all_stark, proof, &config)
}
fn init_logger() {
let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug"));
}

View File

@ -46,7 +46,7 @@ structopt = { version = "0.3.26", default-features = false }
tynm = { version = "0.1.6", default-features = false }
[target.'cfg(not(target_env = "msvc"))'.dev-dependencies]
jemallocator = "0.3.2"
jemallocator = "0.5.0"
[[bin]]
name = "generate_constants"

View File

@ -105,17 +105,14 @@ where
{
let (inner_proof, inner_vd, inner_cd) = inner;
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let mut pw = PartialWitness::new();
let pt = builder.add_virtual_proof_with_pis::<InnerC>(inner_cd);
pw.set_proof_with_pis_target(&pt, inner_proof);
let inner_data = VerifierCircuitTarget {
constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height),
circuit_digest: builder.add_virtual_hash(),
};
pw.set_verifier_data_target(&inner_data, inner_vd);
builder.verify_proof::<InnerC>(pt, &inner_data, inner_cd);
builder.verify_proof::<InnerC>(&pt, &inner_data, inner_cd);
builder.print_gate_counts(0);
if let Some(min_degree_bits) = min_degree_bits {
@ -131,6 +128,10 @@ where
let data = builder.build::<C>();
let mut pw = PartialWitness::new();
pw.set_proof_with_pis_target(&pt, inner_proof);
pw.set_verifier_data_target(&inner_data, inner_vd);
let mut timing = TimingTree::new("prove", Level::Debug);
let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?;
timing.print();

View File

@ -112,7 +112,7 @@ pub struct FriProof<F: RichField + Extendable<D>, H: Hasher<F>, const D: usize>
pub pow_witness: F,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct FriProofTarget<const D: usize> {
pub commit_phase_merkle_caps: Vec<MerkleCapTarget>,
pub query_round_proofs: Vec<FriQueryRoundTarget<D>>,

View File

@ -7,7 +7,7 @@ use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::reducing::ReducingFactorTarget;
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct PolynomialCoeffsExtTarget<const D: usize>(pub Vec<ExtensionTarget<D>>);
impl<const D: usize> PolynomialCoeffsExtTarget<D> {

View File

@ -1,5 +1,6 @@
use alloc::vec::Vec;
use anyhow::ensure;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::field::goldilocks_field::GoldilocksField;
@ -25,6 +26,7 @@ impl<F: Field> HashOut<F> {
elements: [F::ZERO; 4],
};
// TODO: Switch to a TryFrom impl.
pub fn from_vec(elements: Vec<F>) -> Self {
debug_assert!(elements.len() == 4);
Self {
@ -39,6 +41,23 @@ impl<F: Field> HashOut<F> {
}
}
impl<F: Field> From<[F; 4]> for HashOut<F> {
fn from(elements: [F; 4]) -> Self {
Self { elements }
}
}
impl<F: Field> TryFrom<&[F]> for HashOut<F> {
type Error = anyhow::Error;
fn try_from(elements: &[F]) -> Result<Self, Self::Error> {
ensure!(elements.len() == 4);
Ok(Self {
elements: elements.try_into().unwrap(),
})
}
}
impl<F> Sample for HashOut<F>
where
F: Field,
@ -97,6 +116,7 @@ pub struct HashOutTarget {
}
impl HashOutTarget {
// TODO: Switch to a TryFrom impl.
pub fn from_vec(elements: Vec<Target>) -> Self {
debug_assert!(elements.len() == 4);
Self {
@ -111,6 +131,23 @@ impl HashOutTarget {
}
}
impl From<[Target; 4]> for HashOutTarget {
fn from(elements: [Target; 4]) -> Self {
Self { elements }
}
}
impl TryFrom<&[Target]> for HashOutTarget {
type Error = anyhow::Error;
fn try_from(elements: &[Target]) -> Result<Self, Self::Error> {
ensure!(elements.len() == 4);
Ok(Self {
elements: elements.try_into().unwrap(),
})
}
}
#[derive(Clone, Debug)]
pub struct MerkleCapTarget(pub Vec<HashOutTarget>);

View File

@ -135,7 +135,9 @@ impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
let log2_leaves_len = log2_strict(leaves.len());
assert!(
cap_height <= log2_leaves_len,
"cap height should be at most log2(leaves.len())"
"cap_height={} should be at most log2(leaves.len())={}",
cap_height,
log2_leaves_len
);
let num_digests = 2 * (leaves.len() - (1 << cap_height));

View File

@ -244,9 +244,15 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.register_public_input(t);
t
}
/// Adds virtual verifier data, registers it as public inputs, and stores it in `self.verifier_data_public_input`.
/// WARNING: Do not register any public input after calling this! TODO: relax this
pub(crate) fn add_verifier_data_public_input(&mut self) {
pub fn add_verifier_data_public_inputs(&mut self) {
assert!(
self.verifier_data_public_input.is_none(),
"add_verifier_data_public_inputs only needs to be called once"
);
let verifier_data = VerifierCircuitTarget {
constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height),
circuit_digest: self.add_virtual_hash(),
@ -886,7 +892,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
num_partial_products,
};
if let Some(goal_data) = self.goal_common_data {
assert_eq!(goal_data, common);
assert_eq!(goal_data, common, "The expected circuit data passed to cyclic recursion method did not match the actual circuit");
}
let prover_only = ProverOnlyCircuitData {

View File

@ -40,7 +40,7 @@ pub struct Proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const
pub opening_proof: FriProof<F, C::Hasher, D>,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct ProofTarget<const D: usize> {
pub wires_cap: MerkleCapTarget,
pub plonk_zs_partial_products_cap: MerkleCapTarget,
@ -283,7 +283,7 @@ pub(crate) struct FriInferredElements<F: RichField + Extendable<D>, const D: usi
pub Vec<F::Extension>,
);
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct ProofWithPublicInputsTarget<const D: usize> {
pub proof: ProofTarget<D>,
pub public_inputs: Vec<Target>,

View File

@ -1,7 +1,5 @@
use alloc::vec;
use alloc::vec::Vec;
use anyhow::{ensure, Result};
use itertools::Itertools;
use crate::field::extension::Extendable;
@ -9,66 +7,16 @@ use crate::fri::proof::{
FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget, FriQueryStepTarget,
};
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::gates::noop::NoopGate;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField};
use crate::hash::merkle_proofs::MerkleProofTarget;
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{
CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData,
};
use crate::plonk::circuit_data::{CommonCircuitData, VerifierCircuitTarget};
use crate::plonk::config::{AlgebraicHasher, GenericConfig};
use crate::plonk::proof::{
OpeningSetTarget, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget,
};
use crate::util::ceil_div_usize;
use crate::plonk::proof::{OpeningSetTarget, ProofTarget, ProofWithPublicInputsTarget};
use crate::with_context;
/// Generate a proof having a given `CommonCircuitData`.
pub(crate) fn dummy_proof<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
common_data: &CommonCircuitData<F, D>,
) -> Result<(
ProofWithPublicInputs<F, C, D>,
VerifierOnlyCircuitData<C, D>,
)> {
let config = common_data.config.clone();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
ensure!(
!common_data.config.zero_knowledge,
"Degree calculation can be off if zero-knowledge is on."
);
let degree = common_data.degree();
// Number of `NoopGate`s to add to get a circuit of size `degree` in the end.
// Need to account for public input hashing, a `PublicInputGate` and a `ConstantGate`.
let num_noop_gate = degree - ceil_div_usize(common_data.num_public_inputs, 8) - 2;
for _ in 0..num_noop_gate {
builder.add_gate(NoopGate, vec![]);
}
for gate in &common_data.gates {
builder.add_gate_to_gate_set(gate.clone());
}
for _ in 0..common_data.num_public_inputs {
let t = builder.add_virtual_public_input();
pw.set_target(t, F::ZERO);
}
let data = builder.build::<C>();
assert_eq!(&data.common, common_data);
let proof = data.prove(pw)?;
Ok((proof, data.verifier_only))
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Verify `proof0` if `condition` else verify `proof1`.
/// `proof0` and `proof1` are assumed to use the same `CommonCircuitData`.
@ -143,7 +91,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
),
};
self.verify_proof::<C>(selected_proof, &selected_verifier_data, inner_common_data);
self.verify_proof::<C>(&selected_proof, &selected_verifier_data, inner_common_data);
}
/// Conditionally verify a proof with a new generated dummy proof.
@ -369,6 +317,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[cfg(test)]
mod tests {
use anyhow::Result;
use hashbrown::HashMap;
use super::*;
use crate::field::types::Sample;
@ -376,6 +325,7 @@ mod tests {
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::recursion::dummy_circuit::{dummy_circuit, dummy_proof};
#[test]
fn test_conditional_recursive_verifier() -> Result<()> {
@ -400,7 +350,8 @@ mod tests {
data.verify(proof.clone())?;
// Generate dummy proof with the same `CommonCircuitData`.
let (dummy_proof, dummy_data) = dummy_proof(&data.common)?;
let dummy_data = dummy_circuit(&data.common);
let dummy_proof = dummy_proof(&dummy_data, HashMap::new())?;
// Conditionally verify the two proofs.
let mut builder = CircuitBuilder::<F, D>::new(config);
@ -418,7 +369,7 @@ mod tests {
constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height),
circuit_digest: builder.add_virtual_hash(),
};
pw.set_verifier_data_target(&dummy_inner_data, &dummy_data);
pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only);
let b = builder.constant_bool(F::rand().0 % 2 == 0);
builder.conditionally_verify_proof::<C>(
b,

View File

@ -3,6 +3,7 @@
use alloc::vec;
use anyhow::{ensure, Result};
use hashbrown::HashMap;
use itertools::Itertools;
use crate::field::extension::Extendable;
@ -13,11 +14,11 @@ use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{
CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData,
CircuitData, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData,
};
use crate::plonk::config::{AlgebraicHasher, GenericConfig};
use crate::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget};
use crate::recursion::conditional_recursive_verifier::dummy_proof;
use crate::recursion::dummy_circuit::{dummy_circuit, dummy_proof};
pub struct CyclicRecursionData<
'a,
@ -30,12 +31,17 @@ pub struct CyclicRecursionData<
common_data: &'a CommonCircuitData<F, D>,
}
pub struct CyclicRecursionTarget<const D: usize> {
pub proof: ProofWithPublicInputsTarget<D>,
pub verifier_data: VerifierCircuitTarget,
pub dummy_proof: ProofWithPublicInputsTarget<D>,
pub dummy_verifier_data: VerifierCircuitTarget,
pub condition: BoolTarget,
pub struct CyclicRecursionTarget<F, C, const D: usize>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
{
pub(crate) proof: ProofWithPublicInputsTarget<D>,
pub(crate) verifier_data: VerifierCircuitTarget,
pub(crate) dummy_proof: ProofWithPublicInputsTarget<D>,
pub(crate) dummy_verifier_data: VerifierCircuitTarget,
pub(crate) condition: BoolTarget,
pub(crate) dummy_circuit: CircuitData<F, C, D>,
}
impl<C: GenericConfig<D>, const D: usize> VerifierOnlyCircuitData<C, D> {
@ -107,17 +113,16 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn cyclic_recursion<C: GenericConfig<D, F = F>>(
&mut self,
condition: BoolTarget,
previous_virtual_public_inputs: &[Target],
common_data: &mut CommonCircuitData<F, D>,
) -> Result<CyclicRecursionTarget<D>>
proof_with_pis: &ProofWithPublicInputsTarget<D>,
common_data: &CommonCircuitData<F, D>,
) -> Result<CyclicRecursionTarget<F, C, D>>
where
C::Hasher: AlgebraicHasher<F>,
{
if self.verifier_data_public_input.is_none() {
self.add_verifier_data_public_input();
}
let verifier_data = self.verifier_data_public_input.clone().unwrap();
common_data.num_public_inputs = self.num_public_inputs();
let verifier_data = self
.verifier_data_public_input
.clone()
.expect("Must call add_verifier_data_public_inputs before cyclic recursion");
self.goal_common_data = Some(common_data.clone());
let dummy_verifier_data = VerifierCircuitTarget {
@ -125,10 +130,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
circuit_digest: self.add_virtual_hash(),
};
let proof = self.add_virtual_proof_with_pis::<C>(common_data);
let dummy_proof = self.add_virtual_proof_with_pis::<C>(common_data);
let pis = VerifierCircuitTarget::from_slice::<F, C, D>(&proof.public_inputs, common_data)?;
let pis = VerifierCircuitTarget::from_slice::<F, C, D>(
&proof_with_pis.public_inputs,
common_data,
)?;
// Connect previous verifier data to current one. This guarantees that every proof in the cycle uses the same verifier data.
self.connect_hashes(pis.circuit_digest, verifier_data.circuit_digest);
for (h0, h1) in pis
@ -140,17 +147,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.connect_hashes(*h0, *h1);
}
for (x, y) in previous_virtual_public_inputs
.iter()
.zip(&proof.public_inputs)
{
self.connect(*x, *y);
}
// Verify the real proof if `condition` is set to true, otherwise verify the dummy proof.
self.conditionally_verify_proof::<C>(
condition,
&proof,
proof_with_pis,
&verifier_data,
&dummy_proof,
&dummy_verifier_data,
@ -167,26 +167,29 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
Ok(CyclicRecursionTarget {
proof,
verifier_data: verifier_data.clone(),
proof: proof_with_pis.clone(),
verifier_data,
dummy_proof,
dummy_verifier_data,
condition,
dummy_circuit: dummy_circuit(common_data),
})
}
}
/// Set the targets in a `CyclicRecursionTarget` to their corresponding values in a `CyclicRecursionData`.
/// The `public_inputs` parameter lets the caller specify certain public inputs (identified by their
/// indices) which should be given specific values. The rest will default to zero.
pub fn set_cyclic_recursion_data_target<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
pw: &mut PartialWitness<F>,
cyclic_recursion_data_target: &CyclicRecursionTarget<D>,
cyclic_recursion_data_target: &CyclicRecursionTarget<F, C, D>,
cyclic_recursion_data: &CyclicRecursionData<F, C, D>,
// Public inputs to set in the base case to seed some initial data.
public_inputs: &[F],
mut public_inputs: HashMap<usize, F>,
) -> Result<()>
where
C::Hasher: AlgebraicHasher<F>,
@ -204,36 +207,41 @@ where
cyclic_recursion_data.verifier_data,
);
} else {
let (dummy_proof, dummy_data) = dummy_proof::<F, C, D>(cyclic_recursion_data.common_data)?;
pw.set_bool_target(cyclic_recursion_data_target.condition, false);
let mut proof = dummy_proof.clone();
proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs);
let pis_len = proof.public_inputs.len();
// The circuit checks that the verifier data is the same throughout the cycle, so
// we set the verifier data to the "real" verifier data even though it's unused in the base case.
let num_cap = cyclic_recursion_data
let pis_len = cyclic_recursion_data_target
.dummy_circuit
.common
.num_public_inputs;
let cap_elements = cyclic_recursion_data
.common_data
.config
.fri_config
.num_cap_elements();
let s = pis_len - 4 - 4 * num_cap;
proof.public_inputs[s..s + 4]
.copy_from_slice(&cyclic_recursion_data.verifier_data.circuit_digest.elements);
for i in 0..num_cap {
proof.public_inputs[s + 4 * (1 + i)..s + 4 * (2 + i)].copy_from_slice(
&cyclic_recursion_data.verifier_data.constants_sigmas_cap.0[i].elements,
);
let start_vk_pis = pis_len - 4 - 4 * cap_elements;
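// The verifier key occupies the final `4 + 4 * cap_elements` public inputs: the circuit
// digest (4 field elements) followed by the constants/sigmas cap (4 elements per cap entry).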
// The circuit checks that the verifier data is the same throughout the cycle, so
// we set the verifier data to the "real" verifier data even though it's unused in the base case.
let verifier_data = &cyclic_recursion_data.verifier_data;
public_inputs.extend((start_vk_pis..).zip(verifier_data.circuit_digest.elements));
for i in 0..cap_elements {
let start = start_vk_pis + 4 + 4 * i;
public_inputs.extend((start..).zip(verifier_data.constants_sigmas_cap.0[i].elements));
}
let proof = dummy_proof(&cyclic_recursion_data_target.dummy_circuit, public_inputs)?;
pw.set_proof_with_pis_target(&cyclic_recursion_data_target.proof, &proof);
pw.set_verifier_data_target(
&cyclic_recursion_data_target.verifier_data,
cyclic_recursion_data.verifier_data,
);
pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, &dummy_proof);
let dummy_p = dummy_proof(&cyclic_recursion_data_target.dummy_circuit, HashMap::new())?;
pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, &dummy_p);
pw.set_verifier_data_target(
&cyclic_recursion_data_target.dummy_verifier_data,
&dummy_data,
&cyclic_recursion_data_target.dummy_circuit.verifier_only,
);
}
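A brief usage sketch (an illustration only; the names mirror the test further below): in the base case the caller seeds selected public inputs through the map, while in the recursive case an empty map suffices.

```rust
// Base case: seed public inputs 0..4 with an initial hash; all other indices default to zero.
let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)];
let seed: HashMap<usize, F> = initial_hash.into_iter().enumerate().collect();
set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data, seed)?;

// Recursive case: nothing to seed, so pass an empty map.
set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data, HashMap::new())?;
```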
@ -264,11 +272,12 @@ where
#[cfg(test)]
mod tests {
use anyhow::Result;
use hashbrown::HashMap;
use crate::field::extension::Extendable;
use crate::field::types::{Field, PrimeField64};
use crate::gates::noop::NoopGate;
use crate::hash::hash_types::RichField;
use crate::hash::hash_types::{HashOutTarget, RichField};
use crate::hash::hashing::hash_n_to_hash_no_pad;
use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation};
use crate::iop::witness::PartialWitness;
@ -298,7 +307,7 @@ mod tests {
constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height),
circuit_digest: builder.add_virtual_hash(),
};
builder.verify_proof::<C>(proof, &verifier_data, &data.common);
builder.verify_proof::<C>(&proof, &verifier_data, &data.common);
let data = builder.build::<C>();
let config = CircuitConfig::standard_recursion_config();
@ -308,13 +317,19 @@ mod tests {
constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height),
circuit_digest: builder.add_virtual_hash(),
};
builder.verify_proof::<C>(proof, &verifier_data, &data.common);
builder.verify_proof::<C>(&proof, &verifier_data, &data.common);
while builder.num_gates() < 1 << 12 {
builder.add_gate(NoopGate, vec![]);
}
builder.build::<C>().common
}
/// Uses cyclic recursion to build a hash chain.
/// The circuit has the following public input structure:
/// - Initial hash (4)
/// - Output for the tip of the hash chain (4)
/// - Chain length, i.e. the number of times the hash has been applied (1)
/// - VK for cyclic recursion (4 digest elements + 4 per cap element, so size depends on the config)
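///
/// For example, the end of this test recovers the layout with
/// `initial_hash = &proof.public_inputs[..4]`, `hash = &proof.public_inputs[4..8]`, and
/// `counter = proof.public_inputs[8]`.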
#[test]
fn test_cyclic_recursion() -> Result<()> {
const D: usize = 2;
@ -322,55 +337,62 @@ mod tests {
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let one = builder.one();
// Circuit that computes a repeated hash.
let initial_hash = builder.add_virtual_hash();
builder.register_public_inputs(&initial_hash.elements);
// Hash from the previous proof.
let old_hash = builder.add_virtual_hash();
// The input hash is either the previous hash or the initial hash depending on whether
// the last proof was a base case.
let input_hash = builder.add_virtual_hash();
let h = builder.hash_n_to_hash_no_pad::<PoseidonHash>(input_hash.elements.to_vec());
builder.register_public_inputs(&h.elements);
// Previous counter.
let old_counter = builder.add_virtual_target();
let new_counter = builder.add_virtual_public_input();
let old_pis = [
initial_hash.elements.as_slice(),
old_hash.elements.as_slice(),
[old_counter].as_slice(),
]
.concat();
let current_hash_in = builder.add_virtual_hash();
let current_hash_out =
builder.hash_n_to_hash_no_pad::<PoseidonHash>(current_hash_in.elements.to_vec());
builder.register_public_inputs(&current_hash_out.elements);
let counter = builder.add_virtual_public_input();
let mut common_data = common_data_for_recursion::<F, C, D>();
builder.add_verifier_data_public_inputs();
common_data.num_public_inputs = builder.num_public_inputs();
let condition = builder.add_virtual_bool_target_safe();
// Add cyclic recursion gadget.
// Unpack inner proof's public inputs.
let inner_proof_with_pis = builder.add_virtual_proof_with_pis::<C>(&common_data);
let inner_pis = &inner_proof_with_pis.public_inputs;
let inner_initial_hash = HashOutTarget::try_from(&inner_pis[0..4]).unwrap();
let inner_latest_hash = HashOutTarget::try_from(&inner_pis[4..8]).unwrap();
let inner_counter = inner_pis[8];
// Connect our initial hash to that of our inner proof. (If there is no inner proof, the
// initial hash will be unconstrained, which is intentional.)
builder.connect_hashes(initial_hash, inner_initial_hash);
// The input hash is the previous hash output if we have an inner proof, or the initial hash
// if this is the base case.
let actual_hash_in = builder.select_hash(condition, inner_latest_hash, initial_hash);
builder.connect_hashes(current_hash_in, actual_hash_in);
// Our chain length will be inner_counter + 1 if we have an inner proof, or 1 if not.
let new_counter = builder.mul_add(condition.target, inner_counter, one);
builder.connect(counter, new_counter);
let cyclic_data_target =
builder.cyclic_recursion::<C>(condition, &old_pis, &mut common_data)?;
let input_hash_bis =
builder.select_hash(cyclic_data_target.condition, old_hash, initial_hash);
builder.connect_hashes(input_hash, input_hash_bis);
// New counter is the previous counter +1 if the previous proof wasn't a base case.
let new_counter_bis = builder.add(old_counter, condition.target);
builder.connect(new_counter, new_counter_bis);
builder.cyclic_recursion::<C>(condition, &inner_proof_with_pis, &common_data)?;
let cyclic_circuit_data = builder.build::<C>();
let mut pw = PartialWitness::new();
let cyclic_recursion_data = CyclicRecursionData {
proof: &None, // Base case: We don't have a proof to put here yet.
verifier_data: &cyclic_circuit_data.verifier_only,
common_data: &cyclic_circuit_data.common,
};
let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)];
let initial_hash_pis = initial_hash.into_iter().enumerate().collect();
set_cyclic_recursion_data_target(
&mut pw,
&cyclic_data_target,
&cyclic_recursion_data,
&initial_hash,
initial_hash_pis,
)?;
let proof = cyclic_circuit_data.prove(pw)?;
check_cyclic_proof_verifier_data(
@ -391,7 +413,7 @@ mod tests {
&mut pw,
&cyclic_data_target,
&cyclic_recursion_data,
&[],
HashMap::new(),
)?;
let proof = cyclic_circuit_data.prove(pw)?;
check_cyclic_proof_verifier_data(
@ -412,7 +434,7 @@ mod tests {
&mut pw,
&cyclic_data_target,
&cyclic_recursion_data,
&[],
HashMap::new(),
)?;
let proof = cyclic_circuit_data.prove(pw)?;
check_cyclic_proof_verifier_data(
@ -425,17 +447,20 @@ mod tests {
let initial_hash = &proof.public_inputs[..4];
let hash = &proof.public_inputs[4..8];
let counter = proof.public_inputs[8];
let mut h: [F; 4] = initial_hash.try_into().unwrap();
assert_eq!(
hash,
core::iter::repeat_with(|| {
h = hash_n_to_hash_no_pad::<F, PoseidonPermutation>(&h).elements;
h
})
.nth(counter.to_canonical_u64() as usize)
.unwrap()
let expected_hash: [F; 4] = iterate_poseidon(
initial_hash.try_into().unwrap(),
counter.to_canonical_u64() as usize,
);
assert_eq!(hash, expected_hash);
cyclic_circuit_data.verify(proof)
}
fn iterate_poseidon<F: RichField>(initial_state: [F; 4], n: usize) -> [F; 4] {
let mut current = initial_state;
for _ in 0..n {
current = hash_n_to_hash_no_pad::<F, PoseidonPermutation>(&current).elements;
}
current
}
}

View File

@ -0,0 +1,67 @@
use alloc::vec;
use hashbrown::HashMap;
use plonky2_field::extension::Extendable;
use plonky2_util::ceil_div_usize;
use crate::gates::noop::NoopGate;
use crate::hash::hash_types::RichField;
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{CircuitData, CommonCircuitData};
use crate::plonk::config::GenericConfig;
use crate::plonk::proof::ProofWithPublicInputs;
/// Generate a proof for a dummy circuit. The `public_inputs` parameter lets the caller specify
/// certain public inputs (identified by their indices) which should be given specific values.
/// The rest will default to zero.
pub(crate) fn dummy_proof<F, C, const D: usize>(
circuit: &CircuitData<F, C, D>,
nonzero_public_inputs: HashMap<usize, F>,
) -> anyhow::Result<ProofWithPublicInputs<F, C, D>>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
{
let mut pw = PartialWitness::new();
for i in 0..circuit.common.num_public_inputs {
let pi = nonzero_public_inputs.get(&i).copied().unwrap_or_default();
pw.set_target(circuit.prover_only.public_inputs[i], pi);
}
circuit.prove(pw)
}
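For example, the cyclic-recursion witness setter above requests the base-case dummy proof with an empty map, so every public input defaults to zero:

```rust
let dummy_p = dummy_proof(&cyclic_recursion_data_target.dummy_circuit, HashMap::new())?;
```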
/// Generate a circuit matching a given `CommonCircuitData`.
pub(crate) fn dummy_circuit<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
common_data: &CommonCircuitData<F, D>,
) -> CircuitData<F, C, D> {
let config = common_data.config.clone();
assert!(
!common_data.config.zero_knowledge,
"Degree calculation can be off if zero-knowledge is on."
);
// Number of `NoopGate`s to add to get a circuit of size `degree` in the end.
// Need to account for public input hashing, a `PublicInputGate` and a `ConstantGate`.
let degree = common_data.degree();
let num_noop_gate = degree - ceil_div_usize(common_data.num_public_inputs, 8) - 2;
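// For instance, with `degree = 4096` and 9 public inputs: 4096 - ceil(9 / 8) - 2 = 4092 noop gates.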
let mut builder = CircuitBuilder::<F, D>::new(config);
for _ in 0..num_noop_gate {
builder.add_gate(NoopGate, vec![]);
}
for gate in &common_data.gates {
builder.add_gate_to_gate_set(gate.clone());
}
for _ in 0..common_data.num_public_inputs {
builder.add_virtual_public_input();
}
let circuit = builder.build::<C>();
assert_eq!(&circuit.common, common_data);
circuit
}

View File

@ -1,3 +1,4 @@
pub mod conditional_recursive_verifier;
pub mod cyclic_recursion;
pub(crate) mod dummy_circuit;
pub mod recursive_verifier;

View File

@ -16,7 +16,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Recursively verifies an inner proof.
pub fn verify_proof<C: GenericConfig<D, F = F>>(
&mut self,
proof_with_pis: ProofWithPublicInputsTarget<D>,
proof_with_pis: &ProofWithPublicInputsTarget<D>,
inner_verifier_data: &VerifierCircuitTarget,
inner_common_data: &CommonCircuitData<F, D>,
) where
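A call-site sketch of the new borrowing signature, taken verbatim from the updated test below; since the proof target is only borrowed, it remains usable after verification:

```rust
builder.verify_proof::<InnerC>(&pt, &inner_data, &inner_cd);
```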
@ -36,7 +36,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
self.verify_proof_with_challenges::<C>(
proof_with_pis.proof,
&proof_with_pis.proof,
public_inputs_hash,
challenges,
inner_verifier_data,
@ -47,7 +47,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Recursively verifies an inner proof.
fn verify_proof_with_challenges<C: GenericConfig<D, F = F>>(
&mut self,
proof: ProofTarget<D>,
proof: &ProofTarget<D>,
public_inputs_hash: HashOutTarget,
challenges: ProofChallengesTarget<D>,
inner_verifier_data: &VerifierCircuitTarget,
@ -106,9 +106,9 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let merkle_caps = &[
inner_verifier_data.constants_sigmas_cap.clone(),
proof.wires_cap,
proof.plonk_zs_partial_products_cap,
proof.quotient_polys_cap,
proof.wires_cap.clone(),
proof.plonk_zs_partial_products_cap.clone(),
proof.quotient_polys_cap.clone(),
];
let fri_instance = inner_common_data.get_fri_instance_target(self, challenges.plonk_zeta);
@ -376,7 +376,7 @@ mod tests {
);
pw.set_hash_target(inner_data.circuit_digest, inner_vd.circuit_digest);
builder.verify_proof::<InnerC>(pt, &inner_data, &inner_cd);
builder.verify_proof::<InnerC>(&pt, &inner_data, &inner_cd);
if print_gate_counts {
builder.print_gate_counts(0);

View File

@ -23,4 +23,4 @@ name = "lookup_permuted_cols"
harness = false
[target.'cfg(not(target_env = "msvc"))'.dev-dependencies]
jemallocator = "0.3.2"
jemallocator = "0.5.0"