diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index ac2e7aba..9cb07c31 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -36,7 +36,7 @@ impl, C: GenericConfig, const D: usize> A AllProofChallenges { stark_challenges: std::array::from_fn(|i| { - challenger.duplexing(); + challenger.compact(); self.stark_proofs[i].get_challenges( &mut challenger, num_permutation_zs[i] > 0, @@ -66,8 +66,7 @@ impl, C: GenericConfig, const D: usize> A let num_permutation_zs = all_stark.nums_permutation_zs(config); let num_permutation_batch_sizes = all_stark.permutation_batch_sizes(); - challenger.duplexing(); - let mut challenger_states = vec![challenger.state()]; + let mut challenger_states = vec![challenger.compact()]; for i in 0..NUM_TABLES { self.stark_proofs[i].get_challenges( &mut challenger, @@ -75,8 +74,7 @@ impl, C: GenericConfig, const D: usize> A num_permutation_batch_sizes[i], config, ); - challenger.duplexing(); - challenger_states.push(challenger.state()); + challenger_states.push(challenger.compact()); } AllChallengerState { diff --git a/evm/src/prover.rs b/evm/src/prover.rs index eba0759f..3b702c56 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -201,7 +201,7 @@ where "FRI total reduction arity is too large.", ); - challenger.duplexing(); + challenger.compact(); // Permutation arguments. let permutation_challenges = stark.uses_permutation_args().then(|| { diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 3f079ea4..6a94feb9 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -127,8 +127,7 @@ impl, C: GenericConfig, const D: usize> ensure!(ctl_challenges == pi.ctl_challenges); } - challenger.duplexing(); - let state = challenger.state(); + let state = challenger.compact(); ensure!(state == pis[0].challenger_state_before); // Check that the challenger state is consistent between proofs. for i in 1..NUM_TABLES { @@ -209,8 +208,7 @@ impl, C: GenericConfig, const D: usize> } } - challenger.duplexing(builder); - let state = challenger.state(); + let state = challenger.compact(builder); for k in 0..SPONGE_WIDTH { builder.connect(state[k], pis[0].challenger_state_before[k]); } @@ -321,8 +319,7 @@ where num_permutation_batch_size, inner_config, ); - challenger.duplexing(&mut builder); - let challenger_state = challenger.state(); + let challenger_state = challenger.compact(&mut builder); builder.register_public_inputs(&challenger_state); builder.register_public_inputs(&proof_target.openings.ctl_zs_last); @@ -402,8 +399,7 @@ where num_permutation_batch_size, inner_config, ); - challenger.duplexing(&mut builder); - let challenger_state = challenger.state(); + let challenger_state = challenger.compact(&mut builder); builder.register_public_inputs(&challenger_state); builder.register_public_inputs(&proof_target.openings.ctl_zs_last); diff --git a/plonky2/src/iop/challenger.rs b/plonky2/src/iop/challenger.rs index eeb18038..84bbad1c 100644 --- a/plonky2/src/iop/challenger.rs +++ b/plonky2/src/iop/challenger.rs @@ -147,7 +147,11 @@ impl> Challenger { .extend_from_slice(&self.sponge_state[0..SPONGE_RATE]); } - pub fn state(&self) -> [F; SPONGE_WIDTH] { + pub fn compact(&mut self) -> [F; SPONGE_WIDTH] { + if !self.input_buffer.is_empty() { + self.duplexing(); + } + self.output_buffer.clear(); self.sponge_state } } @@ -181,11 +185,10 @@ impl, H: AlgebraicHasher, const D: usize> } pub fn from_state(sponge_state: [Target; SPONGE_WIDTH]) -> Self { - let output_buffer = sponge_state[0..SPONGE_RATE].to_vec(); RecursiveChallenger { sponge_state, input_buffer: vec![], - output_buffer, + output_buffer: vec![], } } @@ -286,29 +289,9 @@ impl, H: AlgebraicHasher, const D: usize> self.input_buffer.clear(); } - pub fn duplexing(&mut self, builder: &mut CircuitBuilder) { - if self.input_buffer.is_empty() { - self.sponge_state = builder.permute::(self.sponge_state); - } else { - for input_chunk in self.input_buffer.chunks(SPONGE_RATE) { - // Overwrite the first r elements with the inputs. This differs from a standard sponge, - // where we would xor or add in the inputs. This is a well-known variant, though, - // sometimes called "overwrite mode". - for (i, &input) in input_chunk.iter().enumerate() { - self.sponge_state[i] = input; - } - - // Apply the permutation. - self.sponge_state = builder.permute::(self.sponge_state); - } - } - - self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); - - self.input_buffer.clear(); - } - - pub fn state(&self) -> [Target; SPONGE_WIDTH] { + pub fn compact(&mut self, builder: &mut CircuitBuilder) -> [Target; SPONGE_WIDTH] { + self.absorb_buffered_inputs(builder); + self.output_buffer.clear(); self.sponge_state } }