Add challenger state

This commit is contained in:
wborgeaud 2022-09-22 11:01:27 +02:00
parent 00c439513a
commit 6e6c2daf29
7 changed files with 186 additions and 50 deletions

View File

@ -5,7 +5,7 @@ use plonky2::iop::challenger::{Challenger, RecursiveChallenger};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig};
use crate::all_stark::AllStark;
use crate::all_stark::{AllStark, NUM_TABLES};
use crate::config::StarkConfig;
use crate::permutation::{
get_grand_product_challenge_set, get_grand_product_challenge_set_target,
@ -46,6 +46,43 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> A
ctl_challenges,
}
}
pub(crate) fn get_challenger_states(
&self,
all_stark: &AllStark<F, D>,
config: &StarkConfig,
) -> AllChallengerState<F, D> {
let mut challenger = Challenger::<F, C::Hasher>::new();
for proof in &self.stark_proofs {
challenger.observe_cap(&proof.trace_cap);
}
// TODO: Observe public values.
let ctl_challenges =
get_grand_product_challenge_set(&mut challenger, config.num_challenges);
let num_permutation_zs = all_stark.nums_permutation_zs(config);
let num_permutation_batch_sizes = all_stark.permutation_batch_sizes();
challenger.duplexing();
let mut challenger_states = vec![challenger.state()];
for i in 0..NUM_TABLES {
self.stark_proofs[i].get_challenges(
&mut challenger,
num_permutation_zs[i] > 0,
num_permutation_batch_sizes[i],
config,
);
challenger.duplexing();
challenger_states.push(challenger.state());
}
AllChallengerState {
states: challenger_states.try_into().unwrap(),
ctl_challenges,
}
}
}
impl<const D: usize> AllProofTarget<D> {

View File

@ -8,6 +8,7 @@ use plonky2::fri::structure::{
FriOpeningBatch, FriOpeningBatchTarget, FriOpenings, FriOpeningsTarget,
};
use plonky2::hash::hash_types::{MerkleCapTarget, RichField};
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::merkle_tree::MerkleCap;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::iop::target::Target;
@ -38,6 +39,11 @@ pub(crate) struct AllProofChallenges<F: RichField + Extendable<D>, const D: usiz
pub ctl_challenges: GrandProductChallengeSet<F>,
}
pub(crate) struct AllChallengerState<F: RichField + Extendable<D>, const D: usize> {
pub states: [[F; SPONGE_WIDTH]; NUM_TABLES + 1],
pub ctl_challenges: GrandProductChallengeSet<F>,
}
pub struct AllProofTarget<const D: usize> {
pub stark_proofs: [StarkProofTarget<D>; NUM_TABLES],
pub public_values: PublicValuesTarget,

View File

@ -201,6 +201,8 @@ where
"FRI total reduction arity is too large.",
);
challenger.duplexing();
// Permutation arguments.
let permutation_challenges = stark.uses_permutation_args().then(|| {
get_n_grand_product_challenge_sets(

View File

@ -4,6 +4,8 @@ use plonky2::field::extension::Extendable;
use plonky2::field::types::Field;
use plonky2::fri::witness_util::set_fri_proof_target;
use plonky2::hash::hash_types::RichField;
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::iop::challenger::RecursiveChallenger;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::iop::target::Target;
use plonky2::iop::witness::{PartialWitness, Witness};
@ -19,15 +21,19 @@ use crate::all_stark::{AllStark, Table, NUM_TABLES};
use crate::config::StarkConfig;
use crate::constraint_consumer::RecursiveConstraintConsumer;
use crate::cpu::cpu_stark::CpuStark;
use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CtlCheckVarsTarget};
use crate::cross_table_lookup::{
verify_cross_table_lookups_circuit, CrossTableLookup, CtlCheckVarsTarget,
};
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_memory::keccak_memory_stark::KeccakMemoryStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark::MemoryStark;
use crate::permutation::PermutationCheckDataTarget;
use crate::permutation::{
GrandProductChallenge, GrandProductChallengeSet, PermutationCheckDataTarget,
};
use crate::proof::{
AllProof, AllProofChallengesTarget, AllProofTarget, BlockMetadata, BlockMetadataTarget,
PublicValues, PublicValuesTarget, StarkOpeningSetTarget, StarkProof,
AllChallengerState, AllProof, AllProofChallengesTarget, AllProofTarget, BlockMetadata,
BlockMetadataTarget, PublicValues, PublicValuesTarget, StarkOpeningSetTarget, StarkProof,
StarkProofChallengesTarget, StarkProofTarget, TrieRoots, TrieRootsTarget,
};
use crate::stark::Stark;
@ -93,8 +99,10 @@ fn recursively_verify_stark_proof<
>(
table: Table,
stark: S,
all_stark: &AllStark<F, D>,
all_proof: &AllProof<F, C, D>,
proof: &StarkProof<F, C, D>,
cross_table_lookups: &[CrossTableLookup<F>],
ctl_challenges: &GrandProductChallengeSet<F>,
challenger_state_before: [F; SPONGE_WIDTH],
inner_config: &StarkConfig,
circuit_config: &CircuitConfig,
) -> Result<(ProofWithPublicInputs<F, C, D>, VerifierCircuitData<F, C, D>)>
@ -106,36 +114,63 @@ where
let mut builder = CircuitBuilder::<F, D>::new(circuit_config.clone());
let mut pw = PartialWitness::new();
let nums_ctl_zs = all_proof.nums_ctl_zs();
let degree_bits = all_proof.degree_bits(inner_config);
let num_permutation_zs = stark.num_permutation_batches(inner_config);
let all_proof_target = add_virtual_all_proof(
let num_permutation_batch_size = stark.permutation_batch_size();
let proof_target = add_virtual_stark_proof(
&mut builder,
all_stark,
&stark,
inner_config,
&degree_bits,
&nums_ctl_zs,
proof.recover_degree_bits(inner_config),
proof.num_ctl_zs(),
);
set_all_proof_target(&mut pw, &all_proof_target, all_proof, builder.zero());
set_stark_proof_target(&mut pw, &proof_target, proof, builder.zero());
let AllProofChallengesTarget {
stark_challenges,
ctl_challenges,
} = all_proof_target.get_challenges::<F, C>(&mut builder, all_stark, inner_config);
let ctl_challenges_target = GrandProductChallengeSet {
challenges: (0..inner_config.num_challenges)
.map(|_| GrandProductChallenge {
beta: builder.add_virtual_public_input(),
gamma: builder.add_virtual_public_input(),
})
.collect(),
};
for i in 0..inner_config.num_challenges {
pw.set_target(
ctl_challenges_target.challenges[i].beta,
ctl_challenges.challenges[i].beta,
);
pw.set_target(
ctl_challenges_target.challenges[i].gamma,
ctl_challenges.challenges[i].gamma,
);
}
let ctl_vars = CtlCheckVarsTarget::from_proof(
table,
&all_proof_target.stark_proofs[table as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
&proof_target,
cross_table_lookups,
&ctl_challenges_target,
num_permutation_zs,
);
let challenger_state = std::array::from_fn(|_| builder.add_virtual_public_input());
pw.set_target_arr(challenger_state, challenger_state_before);
let mut challenger = RecursiveChallenger::<F, C::Hasher, D>::from_state(challenger_state);
let challenges = proof_target.get_challenges::<F, C>(
&mut builder,
&mut challenger,
num_permutation_zs > 0,
num_permutation_batch_size,
inner_config,
);
challenger.duplexing(&mut builder);
let challenger_state = challenger.state();
builder.register_public_inputs(&challenger_state);
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
&mut builder,
stark,
&all_proof_target.stark_proofs[table as usize],
&stark_challenges[table as usize],
&stark,
&proof_target,
&challenges,
&ctl_vars,
inner_config,
);
@ -164,45 +199,59 @@ where
[(); C::Hasher::HASH_SIZE]:,
C::Hasher: AlgebraicHasher<F>,
{
let AllChallengerState {
states,
ctl_challenges,
} = all_proof.get_challenger_states(all_stark, inner_config);
Ok(RecursiveAllProof {
recursive_proofs: [
recursively_verify_stark_proof(
Table::Cpu,
all_stark.cpu_stark,
all_stark,
all_proof,
&all_proof.stark_proofs[Table::Cpu as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
states[0],
inner_config,
&circuit_config,
)?,
recursively_verify_stark_proof(
Table::Keccak,
all_stark.keccak_stark,
all_stark,
all_proof,
&all_proof.stark_proofs[Table::Keccak as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
states[1],
inner_config,
&circuit_config,
)?,
recursively_verify_stark_proof(
Table::KeccakMemory,
all_stark.keccak_memory_stark,
all_stark,
all_proof,
&all_proof.stark_proofs[Table::KeccakMemory as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
states[2],
inner_config,
&circuit_config,
)?,
recursively_verify_stark_proof(
Table::Logic,
all_stark.logic_stark,
all_stark,
all_proof,
&all_proof.stark_proofs[Table::Logic as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
states[3],
inner_config,
&circuit_config,
)?,
recursively_verify_stark_proof(
Table::Memory,
all_stark.memory_stark,
all_stark,
all_proof,
&all_proof.stark_proofs[Table::Memory as usize],
&all_stark.cross_table_lookups,
&ctl_challenges,
states[4],
inner_config,
&circuit_config,
)?,
@ -255,7 +304,7 @@ pub fn verify_proof_circuit<
"verify CPU proof",
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
builder,
cpu_stark,
&cpu_stark,
&all_proof.stark_proofs[Table::Cpu as usize],
&stark_challenges[Table::Cpu as usize],
&ctl_vars_per_table[Table::Cpu as usize],
@ -267,7 +316,7 @@ pub fn verify_proof_circuit<
"verify Keccak proof",
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
builder,
keccak_stark,
&keccak_stark,
&all_proof.stark_proofs[Table::Keccak as usize],
&stark_challenges[Table::Keccak as usize],
&ctl_vars_per_table[Table::Keccak as usize],
@ -279,7 +328,7 @@ pub fn verify_proof_circuit<
"verify Keccak memory proof",
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
builder,
keccak_memory_stark,
&keccak_memory_stark,
&all_proof.stark_proofs[Table::KeccakMemory as usize],
&stark_challenges[Table::KeccakMemory as usize],
&ctl_vars_per_table[Table::KeccakMemory as usize],
@ -291,7 +340,7 @@ pub fn verify_proof_circuit<
"verify logic proof",
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
builder,
logic_stark,
&logic_stark,
&all_proof.stark_proofs[Table::Logic as usize],
&stark_challenges[Table::Logic as usize],
&ctl_vars_per_table[Table::Logic as usize],
@ -303,7 +352,7 @@ pub fn verify_proof_circuit<
"verify memory proof",
verify_stark_proof_with_challenges_circuit::<F, C, _, D>(
builder,
memory_stark,
&memory_stark,
&all_proof.stark_proofs[Table::Memory as usize],
&stark_challenges[Table::Memory as usize],
&ctl_vars_per_table[Table::Memory as usize],
@ -332,7 +381,7 @@ fn verify_stark_proof_with_challenges_circuit<
const D: usize,
>(
builder: &mut CircuitBuilder<F, D>,
stark: S,
stark: &S,
proof: &StarkProofTarget<D>,
challenges: &StarkProofChallengesTarget<D>,
ctl_vars: &[CtlCheckVarsTarget<F, D>],
@ -388,7 +437,7 @@ fn verify_stark_proof_with_challenges_circuit<
"evaluate vanishing polynomial",
eval_vanishing_poly_circuit::<F, C, S, D>(
builder,
&stark,
stark,
inner_config,
vars,
permutation_data,
@ -462,35 +511,35 @@ pub fn add_virtual_all_proof<F: RichField + Extendable<D>, const D: usize>(
let stark_proofs = [
add_virtual_stark_proof(
builder,
all_stark.cpu_stark,
&all_stark.cpu_stark,
config,
degree_bits[Table::Cpu as usize],
nums_ctl_zs[Table::Cpu as usize],
),
add_virtual_stark_proof(
builder,
all_stark.keccak_stark,
&all_stark.keccak_stark,
config,
degree_bits[Table::Keccak as usize],
nums_ctl_zs[Table::Keccak as usize],
),
add_virtual_stark_proof(
builder,
all_stark.keccak_memory_stark,
&all_stark.keccak_memory_stark,
config,
degree_bits[Table::KeccakMemory as usize],
nums_ctl_zs[Table::KeccakMemory as usize],
),
add_virtual_stark_proof(
builder,
all_stark.logic_stark,
&all_stark.logic_stark,
config,
degree_bits[Table::Logic as usize],
nums_ctl_zs[Table::Logic as usize],
),
add_virtual_stark_proof(
builder,
all_stark.memory_stark,
&all_stark.memory_stark,
config,
degree_bits[Table::Memory as usize],
nums_ctl_zs[Table::Memory as usize],
@ -553,7 +602,7 @@ pub fn add_virtual_block_metadata<F: RichField + Extendable<D>, const D: usize>(
pub fn add_virtual_stark_proof<F: RichField + Extendable<D>, S: Stark<F, D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
stark: S,
stark: &S,
config: &StarkConfig,
degree_bits: usize,
num_ctl_zs: usize,
@ -580,7 +629,7 @@ pub fn add_virtual_stark_proof<F: RichField + Extendable<D>, S: Stark<F, D>, con
fn add_stark_opening_set<F: RichField + Extendable<D>, S: Stark<F, D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
stark: S,
stark: &S,
num_ctl_zs: usize,
config: &StarkConfig,
) -> StarkOpeningSetTarget<D> {

View File

@ -129,7 +129,7 @@ impl<F: RichField, H: Hasher<F>> Challenger<F, H> {
/// Absorb any buffered inputs. After calling this, the input buffer will be empty, and the
/// output buffer will be full.
fn duplexing(&mut self) {
pub fn duplexing(&mut self) {
assert!(self.input_buffer.len() <= SPONGE_RATE);
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
@ -146,6 +146,10 @@ impl<F: RichField, H: Hasher<F>> Challenger<F, H> {
self.output_buffer
.extend_from_slice(&self.sponge_state[0..SPONGE_RATE]);
}
pub fn state(&self) -> [F; SPONGE_WIDTH] {
self.sponge_state
}
}
impl<F: RichField, H: AlgebraicHasher<F>> Default for Challenger<F, H> {
@ -176,6 +180,15 @@ impl<F: RichField + Extendable<D>, H: AlgebraicHasher<F>, const D: usize>
}
}
pub fn from_state(sponge_state: [Target; SPONGE_WIDTH]) -> Self {
let output_buffer = sponge_state[0..SPONGE_RATE].to_vec();
RecursiveChallenger {
sponge_state,
input_buffer: vec![],
output_buffer,
}
}
pub(crate) fn observe_element(&mut self, target: Target) {
// Any buffered outputs are now invalid, since they wouldn't reflect this input.
self.output_buffer.clear();
@ -272,6 +285,28 @@ impl<F: RichField + Extendable<D>, H: AlgebraicHasher<F>, const D: usize>
self.input_buffer.clear();
}
pub fn duplexing(&mut self, builder: &mut CircuitBuilder<F, D>) {
for input_chunk in self.input_buffer.chunks(SPONGE_RATE) {
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
// where we would xor or add in the inputs. This is a well-known variant, though,
// sometimes called "overwrite mode".
for (i, &input) in input_chunk.iter().enumerate() {
self.sponge_state[i] = input;
}
// Apply the permutation.
self.sponge_state = builder.permute::<H>(self.sponge_state);
}
self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec();
self.input_buffer.clear();
}
pub fn state(&self) -> [Target; SPONGE_WIDTH] {
self.sponge_state
}
}
#[cfg(test)]

View File

@ -208,6 +208,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b
}
/// Add a virtual target and register it as a public input.
pub fn add_virtual_public_input(&mut self) -> Target {
let t = self.add_virtual_target();
self.register_public_input(t);
t
}
/// Adds a gate to the circuit, and returns its index.
pub fn add_gate<G: Gate<F, D>>(&mut self, gate_type: G, mut constants: Vec<F>) -> usize {
self.check_gate_compatibility(&gate_type);

View File

@ -282,7 +282,7 @@ impl Buffer {
arity: usize,
compressed: bool,
) -> Result<FriQueryStep<F, C::Hasher, D>> {
let evals = self.read_field_ext_vec::<F, D>(arity - if compressed { 1 } else { 0 })?;
let evals = self.read_field_ext_vec::<F, D>(arity - usize::from(compressed))?;
let merkle_proof = self.read_merkle_proof()?;
Ok(FriQueryStep {
evals,