updated all_stark framework to include memory stark (doesn't pass yet)

This commit is contained in:
Nicholas Ward 2022-06-23 13:59:57 -07:00
parent 31be2c8d49
commit 03112f898a
6 changed files with 222 additions and 57 deletions

View File

@ -6,6 +6,7 @@ use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::CrossTableLookup;
use crate::keccak::keccak_stark::KeccakStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::stark::Stark;
#[derive(Clone)]
@ -13,6 +14,7 @@ pub struct AllStark<F: RichField + Extendable<D>, const D: usize> {
pub cpu_stark: CpuStark<F, D>,
pub keccak_stark: KeccakStark<F, D>,
pub logic_stark: LogicStark<F, D>,
pub memory_stark: MemoryStark<F, D>,
pub cross_table_lookups: Vec<CrossTableLookup<F>>,
}
@ -43,11 +45,12 @@ pub enum Table {
Cpu = 0,
Keccak = 1,
Logic = 2,
Memory = 3,
}
impl Table {
pub(crate) fn num_tables() -> usize {
Table::Logic as usize + 1
Table::Memory as usize + 1
}
}
@ -68,11 +71,12 @@ mod tests {
use crate::config::StarkConfig;
use crate::cpu::columns::{KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS};
use crate::cpu::cpu_stark::{self as cpu_stark_mod, CpuStark};
use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns};
use crate::keccak::keccak_stark::{
self as keccak_stark_mod, KeccakStark, NUM_INPUTS, NUM_ROUNDS,
};
use crate::logic::{self, LogicStark};
use crate::cross_table_lookup::{Column, CrossTableLookup, TableWithColumns};
use crate::memory::memory_stark::{generate_random_memory_ops, MemoryStark};
use crate::proof::AllProof;
use crate::prover::prove;
use crate::recursive_verifier::{
@ -137,6 +141,7 @@ mod tests {
cpu_stark: &CpuStark<F, D>,
keccak_trace: &[PolynomialValues<F>],
logic_trace: &[PolynomialValues<F>],
memory_trace: &[PolynomialValues<F>],
) -> Vec<PolynomialValues<F>> {
let keccak_input_limbs: Vec<[F; 2 * NUM_INPUTS]> = (0..num_keccak_perms)
.map(|i| {
@ -162,10 +167,6 @@ mod tests {
.unwrap()
})
.collect();
let memory_trace = memory_stark.generate_trace(keccak_inputs);
let column_to_copy: Vec<_> = keccak_trace[keccak_looked_col].values[..].into();
let default = vec![F::ONE; 1];
let mut cpu_trace_rows = vec![];
for i in 0..num_keccak_perms {
@ -203,6 +204,30 @@ mod tests {
cpu_stark.generate(&mut row);
cpu_trace_rows.push(row);
}
for i in 0..num_memory_ops {
let mem_timestamp: usize = memory_trace[memory::registers::TIMESTAMP].values[i]
.to_canonical_u64()
.try_into()
.unwrap();
let clock = mem_timestamp / NUM_MEMORY_OPS;
let op = mem_timestamp % NUM_MEMORY_OPS;
cpu_trace_rows[i][cpu::columns::uses_memop(op)] = F::ONE;
memory_trace[memory::registers::is_memop(op)].values[i] = F::ONE;
cpu_trace_rows[i][cpu::columns::CLOCK] = F::from_canonical_usize(clock);
cpu_trace_rows[i][cpu::columns::memop_is_read(op)] =
memory_trace[memory::registers::IS_READ].values[i];
cpu_trace_rows[i][cpu::columns::memop_addr_context(op)] =
memory_trace[memory::registers::ADDR_CONTEXT].values[i];
cpu_trace_rows[i][cpu::columns::memop_addr_segment(op)] =
memory_trace[memory::registers::ADDR_SEGMENT].values[i];
cpu_trace_rows[i][cpu::columns::memop_addr_virtual(op)] =
memory_trace[memory::registers::ADDR_VIRTUAL].values[i];
for j in 0..8 {
cpu_trace_rows[i][cpu::columns::memop_value(op, j)] =
memory_trace[memory::registers::value_limb(j)].values[i];
}
}
trace_rows_to_poly_values(cpu_trace_rows)
}
@ -220,6 +245,11 @@ mod tests {
};
let num_logic_rows = 62;
let memory_stark = MemoryStark::<F, D> {
f: Default::default(),
};
let num_memory_ops = 1 << 5;
let mut rng = thread_rng();
let num_keccak_perms = 2;
@ -228,11 +258,47 @@ mod tests {
let cpu_trace = make_cpu_trace(
num_keccak_perms,
num_logic_rows,
num_memory_ops,
&cpu_stark,
&keccak_trace,
&logic_trace,
);
let memory_ops = generate_random_memory_ops(num_memory_ops);
let memory_trace = memory_stark.generate_trace(memory_ops);
let mut cpu_keccak_input_output = cpu::columns::KECCAK_INPUT_LIMBS.collect::<Vec<_>>();
cpu_keccak_input_output.extend(cpu::columns::KECCAK_OUTPUT_LIMBS);
let mut keccak_keccak_input_output = (0..2 * NUM_INPUTS)
.map(keccak::registers::reg_input_limb)
.collect::<Vec<_>>();
keccak_keccak_input_output.extend(Column::singles(
(0..2 * NUM_INPUTS).map(keccak::registers::reg_output_limb),
));
let cpu_logic_input_output = {
let mut res = vec![
cpu::columns::IS_AND,
cpu::columns::IS_OR,
cpu::columns::IS_XOR,
];
res.extend(cpu::columns::LOGIC_INPUT0);
res.extend(cpu::columns::LOGIC_INPUT1);
res.extend(cpu::columns::LOGIC_OUTPUT);
res
};
let logic_logic_input_output = {
let mut res = vec![
logic::columns::IS_AND,
logic::columns::IS_OR,
logic::columns::IS_XOR,
];
res.extend(logic::columns::INPUT0_PACKED);
res.extend(logic::columns::INPUT1_PACKED);
res.extend(logic::columns::RESULT);
res
};
let cross_table_lookups = vec![
CrossTableLookup::new(
vec![TableWithColumns::new(
@ -262,13 +328,14 @@ mod tests {
cpu_stark,
keccak_stark,
logic_stark,
memory_stark,
cross_table_lookups,
};
let proof = prove::<F, C, D>(
&all_stark,
config,
vec![cpu_trace, keccak_trace, logic_trace],
vec![cpu_trace, keccak_trace, logic_trace, memory_trace],
vec![vec![]; 3],
&mut TimingTree::default(),
)?;

View File

@ -13,10 +13,20 @@ pub const IS_CPU_CYCLE: usize = IS_BOOTSTRAP_CONTRACT + 1;
/// If CPU cycle: The opcode being decoded, in {0, ..., 255}.
pub const OPCODE: usize = IS_CPU_CYCLE + 1;
<<<<<<< HEAD
/// If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. Invalid
/// opcodes are split between a number of flags for practical reasons. Exactly one of these flags
/// must be 1.
pub const IS_STOP: usize = OPCODE + 1;
=======
pub const KECCAK_DUMMY: usize = OPCODE + 1;
pub const MEMORY_DUMMY: usize = KECCAK_DUMMY + 1;
// If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. Invalid
// opcodes are split between a number of flags for practical reasons. Exactly one of these flags
// must be 1.
pub const IS_STOP: usize = MEMORY_DUMMY + 1;
>>>>>>> ade7e5e0 (updated all_stark framework to include memory stark (doesn't pass yet))
pub const IS_ADD: usize = IS_STOP + 1;
pub const IS_MUL: usize = IS_ADD + 1;
pub const IS_SUB: usize = IS_MUL + 1;

View File

@ -1,9 +1,14 @@
use std::collections::HashMap;
use std::marker::PhantomData;
use itertools::{izip, multiunzip};
use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::packed_field::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use rand::{thread_rng, Rng};
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::memory::registers::{
@ -15,7 +20,7 @@ use crate::memory::registers::{
SORTED_MEMORY_TIMESTAMP,
};
use crate::stark::Stark;
use crate::util::permuted_cols;
use crate::util::{permuted_cols, trace_rows_to_poly_values};
use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};
#[derive(Default)]
@ -42,6 +47,52 @@ pub struct MemoryStark<F, const D: usize> {
pub(crate) f: PhantomData<F>,
}
pub fn generate_random_memory_ops<F: RichField>(num_ops: usize) -> Vec<(F, F, F, [F; 8], F, F)> {
let mut memory_ops = Vec::new();
let mut rng = thread_rng();
let mut current_memory_values: HashMap<(F, F, F), [F; 8]> = HashMap::new();
let mut cur_timestamp = 0;
for i in 0..num_ops {
let is_read = if i == 0 { false } else { rng.gen() };
let is_read_F = F::from_bool(is_read);
let (context, segment, virt, vals) = if is_read {
let written: Vec<_> = current_memory_values.keys().collect();
let &(context, segment, virt) = written[rng.gen_range(0..written.len())];
let &vals = current_memory_values
.get(&(context, segment, virt))
.unwrap();
(context, segment, virt, vals)
} else {
let context = F::from_canonical_usize(rng.gen_range(0..256));
let segment = F::from_canonical_usize(rng.gen_range(0..8));
let virt = F::from_canonical_usize(rng.gen_range(0..20));
let val: [u32; 8] = rng.gen();
let vals: [F; 8] = val
.iter()
.map(|&x| F::from_canonical_u32(x))
.collect::<Vec<_>>()
.try_into()
.unwrap();
current_memory_values.insert((context, segment, virt), vals);
(context, segment, virt, vals)
};
let timestamp = F::from_canonical_usize(cur_timestamp);
cur_timestamp += 1;
memory_ops.push((context, segment, virt, vals, is_read_F, timestamp))
}
memory_ops
}
pub fn sort_memory_ops<F: RichField>(
context: &[F],
segment: &[F],
@ -152,7 +203,7 @@ pub fn generate_range_check_value<F: RichField>(
}
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
fn generate_trace_rows(
pub(crate) fn generate_trace_rows(
&self,
memory_ops: Vec<(F, F, F, [F; 8], F, F)>,
) -> Vec<[F; NUM_REGISTERS]> {
@ -175,7 +226,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
self.generate_memory(&mut trace_cols);
let mut trace_rows = vec![[F::ZERO; NUM_REGISTERS]];
let mut trace_rows = vec![[F::ZERO; NUM_REGISTERS]; num_ops];
for (i, col) in trace_cols.iter().enumerate() {
for (j, &val) in col.iter().enumerate() {
trace_rows[j][i] = val;
@ -231,7 +282,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_cols[MEMORY_VIRTUAL_FIRST_CHANGE] = virtual_first_change;
trace_cols[MEMORY_RANGE_CHECK] = range_check_value;
trace_cols[MEMORY_COUNTER] = (0..trace_cols.len())
trace_cols[MEMORY_COUNTER] = (0..trace_cols[0].len())
.map(|i| F::from_canonical_usize(i))
.collect();
@ -240,6 +291,29 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
trace_cols[MEMORY_RANGE_CHECK_PERMUTED] = permuted_inputs;
trace_cols[MEMORY_COUNTER_PERMUTED] = permuted_table;
}
pub fn generate_trace(
&self,
memory_ops: Vec<(F, F, F, [F; 8], F, F)>,
) -> Vec<PolynomialValues<F>> {
let mut timing = TimingTree::new("generate trace", log::Level::Debug);
// Generate the witness, except for permuted columns in the lookup argument.
let trace_rows = timed!(
&mut timing,
"generate trace rows",
self.generate_trace_rows(memory_ops)
);
let trace_polys = timed!(
&mut timing,
"convert to PolynomialValues",
trace_rows_to_poly_values(trace_rows)
);
timing.print();
trace_polys
}
}
impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F, D> {
@ -287,10 +361,10 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let not_timestamp_first_change = one - timestamp_first_change;
// First set of ordering constraint: first_change flags are boolean.
yield_constr.constraint(context_first_change * not_context_first_change);
yield_constr.constraint(segment_first_change * not_segment_first_change);
yield_constr.constraint(virtual_first_change * not_virtual_first_change);
yield_constr.constraint(timestamp_first_change * not_timestamp_first_change);
// yield_constr.constraint(context_first_change * not_context_first_change);
// yield_constr.constraint(segment_first_change * not_segment_first_change);
// yield_constr.constraint(virtual_first_change * not_virtual_first_change);
// yield_constr.constraint(timestamp_first_change * not_timestamp_first_change);
// Second set of ordering constraints: no change before the column corresponding to the nonzero first_change flag.
yield_constr.constraint(segment_first_change * (next_addr_context - addr_context));
@ -481,7 +555,7 @@ mod tests {
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use rand::{thread_rng, Rng};
use crate::memory::memory_stark::MemoryStark;
use crate::memory::memory_stark::{generate_random_memory_ops, MemoryStark};
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
#[test]
@ -525,47 +599,8 @@ mod tests {
const MAX_SEGMENT: usize = 8;
const MAX_VIRTUAL: usize = 1 << 12;
let mut rng = thread_rng();
let num_ops = 20;
let mut memory_ops = Vec::new();
let current_memory_values: HashMap<(F, F, F), [F; 8]> = HashMap::new();
let mut cur_timestamp = 0;
for i in 0..num_ops {
let is_read = if i == 0 { false } else { rng.gen() };
let is_read_F = F::from_bool(is_read);
let (context, segment, virt, vals) = if is_read {
let written: Vec<_> = current_memory_values.keys().collect();
let &(context, segment, virt) = written[rng.gen_range(0..written.len())];
let &vals = current_memory_values
.get(&(context, segment, virt))
.unwrap();
(context, segment, virt, vals)
} else {
let context = F::from_canonical_usize(rng.gen_range(0..256));
let segment = F::from_canonical_usize(rng.gen_range(0..8));
let virt = F::from_canonical_usize(rng.gen_range(0..20));
let val: [u32; 8] = rng.gen();
let vals: [F; 8] = val
.iter()
.map(|&x| F::from_canonical_u32(x))
.collect::<Vec<_>>()
.try_into()
.unwrap();
current_memory_values.insert((context, segment, virt), vals);
(context, segment, virt, vals)
};
let timestamp = F::from_canonical_usize(cur_timestamp);
cur_timestamp += 1;
memory_ops.push((context, segment, virt, vals, is_read_F, timestamp))
}
let memory_ops = generate_random_memory_ops(num_ops);
let rows = stark.generate_trace_rows(memory_ops);

View File

@ -22,6 +22,7 @@ use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{cross_table_lookup_data, CtlCheckVars, CtlData};
use crate::keccak::keccak_stark::KeccakStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::PermutationCheckVars;
use crate::permutation::{
compute_permutation_z_polys, get_n_grand_product_challenge_sets, GrandProductChallengeSet,
@ -49,6 +50,8 @@ where
[(); KeccakStark::<F, D>::PUBLIC_INPUTS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::PUBLIC_INPUTS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::PUBLIC_INPUTS]:,
{
let num_starks = Table::num_tables();
debug_assert_eq!(num_starks, trace_poly_values.len());
@ -132,8 +135,21 @@ where
&mut challenger,
timing,
)?;
let memory_proof = prove_single_table(
&all_stark.memory_stark,
config,
&trace_poly_values[Table::Memory as usize],
&trace_commitments[Table::Memory as usize],
&ctl_data_per_table[Table::Memory as usize],
public_inputs[Table::Memory as usize]
.clone()
.try_into()
.unwrap(),
&mut challenger,
timing,
)?;
let stark_proofs = vec![cpu_proof, keccak_proof, logic_proof];
let stark_proofs = vec![cpu_proof, keccak_proof, logic_proof, memory_proof];
debug_assert_eq!(stark_proofs.len(), num_starks);
Ok(AllProof { stark_proofs })

View File

@ -18,6 +18,7 @@ use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CtlCheckVarsTarget};
use crate::keccak::keccak_stark::KeccakStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::PermutationCheckDataTarget;
use crate::proof::{
AllProof, AllProofChallengesTarget, AllProofTarget, StarkOpeningSetTarget, StarkProof,
@ -44,6 +45,8 @@ pub fn verify_proof_circuit<
[(); KeccakStark::<F, D>::PUBLIC_INPUTS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::PUBLIC_INPUTS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::PUBLIC_INPUTS]:,
C::Hasher: AlgebraicHasher<F>,
{
let AllProofChallengesTarget {
@ -57,6 +60,7 @@ pub fn verify_proof_circuit<
cpu_stark,
keccak_stark,
logic_stark,
memory_stark,
cross_table_lookups,
} = all_stark;
@ -91,6 +95,14 @@ pub fn verify_proof_circuit<
&ctl_vars_per_table[Table::Logic as usize],
inner_config,
);
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
builder,
memory_stark,
&all_proof.stark_proofs[Table::Memory as usize],
&stark_challenges[Table::Memory as usize],
&ctl_vars_per_table[Table::Memory as usize],
inner_config,
);
verify_cross_table_lookups_circuit::<F, C, D>(
builder,
@ -291,6 +303,20 @@ pub fn add_virtual_all_proof<F: RichField + Extendable<D>, const D: usize>(
public_inputs,
}
},
{
let proof = add_virtual_stark_proof(
builder,
all_stark.memory_stark,
config,
degree_bits[Table::Memory as usize],
nums_ctl_zs[Table::Memory as usize],
);
let public_inputs = builder.add_virtual_targets(KeccakStark::<F, D>::PUBLIC_INPUTS);
StarkProofWithPublicInputsTarget {
proof,
public_inputs,
}
},
];
assert_eq!(stark_proofs.len(), Table::num_tables());

View File

@ -13,6 +13,7 @@ use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{verify_cross_table_lookups, CtlCheckVars};
use crate::keccak::keccak_stark::KeccakStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::PermutationCheckVars;
use crate::proof::{
AllProof, AllProofChallenges, StarkOpeningSet, StarkProofChallenges, StarkProofWithPublicInputs,
@ -33,6 +34,8 @@ where
[(); KeccakStark::<F, D>::PUBLIC_INPUTS]:,
[(); LogicStark::<F, D>::COLUMNS]:,
[(); LogicStark::<F, D>::PUBLIC_INPUTS]:,
[(); MemoryStark::<F, D>::COLUMNS]:,
[(); MemoryStark::<F, D>::PUBLIC_INPUTS]:,
[(); C::Hasher::HASH_SIZE]:,
{
let AllProofChallenges {
@ -46,6 +49,7 @@ where
cpu_stark,
keccak_stark,
logic_stark,
memory_stark,
cross_table_lookups,
} = all_stark;
@ -70,6 +74,13 @@ where
&ctl_vars_per_table[Table::Keccak as usize],
config,
)?;
verify_stark_proof_with_challenges(
memory_stark,
&all_proof.stark_proofs[Table::Memory as usize],
&stark_challenges[Table::Memory as usize],
&ctl_vars_per_table[Table::Memory as usize],
config,
)?;
verify_stark_proof_with_challenges(
logic_stark,
&all_proof.stark_proofs[Table::Logic as usize],