From 7d11d0f8a1f83b5e885caee2bd53fef25064bdad Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 6 Aug 2021 14:58:39 +0200 Subject: [PATCH] Change PartialWitness to use `Vec`s --- src/bin/bench_recursion.rs | 2 +- src/gadgets/arithmetic_extension.rs | 4 +- src/gadgets/insert.rs | 2 +- src/gadgets/interpolation.rs | 4 +- src/gadgets/random_access.rs | 2 +- src/gadgets/select.rs | 2 +- src/gadgets/split_base.rs | 4 +- src/gates/gate_testing.rs | 2 +- src/gates/gmimc.rs | 2 +- src/hash/merkle_proofs.rs | 2 +- src/iop/challenger.rs | 2 +- src/iop/witness.rs | 71 +++++++++++++++++++---------- src/plonk/recursive_verifier.rs | 10 ++-- src/util/reducing.rs | 4 +- 14 files changed, 69 insertions(+), 44 deletions(-) diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 59cdf4a6..59ee2646 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -49,7 +49,7 @@ fn bench_prove, const D: usize>() -> Result<()> { builder.add_extension(zero_ext, zero_ext); let circuit = builder.build(); - let inputs = PartialWitness::new(); + let inputs = PartialWitness::new(0, 0, 0); let proof_with_pis = circuit.prove(inputs)?; let proof_bytes = serde_cbor::to_vec(&proof_with_pis).unwrap(); info!("Proof length: {} bytes", proof_bytes.len()); diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index d76dab9d..ad7519ea 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -604,7 +604,7 @@ mod tests { let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(0, 0, 0); let vs = FF::rand_vec(3); let ts = builder.add_virtual_extension_targets(3); @@ -654,7 +654,7 @@ mod tests { builder.assert_equal_extension(zt, comp_zt_unsafe); let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } diff --git a/src/gadgets/insert.rs b/src/gadgets/insert.rs index 72fe032b..37a1d302 100644 --- a/src/gadgets/insert.rs +++ b/src/gadgets/insert.rs @@ -83,7 +83,7 @@ mod tests { } let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index cc547d88..12b694d9 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -99,7 +99,7 @@ mod tests { builder.assert_equal_extension(eval, true_eval_target); let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } @@ -133,7 +133,7 @@ mod tests { builder.assert_equal_extension(eval, true_eval_target); let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } diff --git a/src/gadgets/random_access.rs b/src/gadgets/random_access.rs index a435b99f..493a1ecc 100644 --- a/src/gadgets/random_access.rs +++ b/src/gadgets/random_access.rs @@ -61,7 +61,7 @@ mod tests { } let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index f1f651dc..ed78eee5 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -49,7 +49,7 @@ mod tests { type FF = QuarticCrandallField; let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(0, 0, 0); let (x, y) = (FF::rand(), FF::rand()); let xt = builder.add_virtual_extension_target(); diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 0e135c05..c579213d 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -115,7 +115,7 @@ mod tests { builder.assert_leading_zeros(xt, 64 - 9); let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } @@ -147,7 +147,7 @@ mod tests { let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index f7a17d81..6c27e76d 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -125,7 +125,7 @@ pub(crate) fn test_eval_fns, G: Gate, const D: usize>( let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(0, 0, 0); let wires_t = builder.add_virtual_extension_targets(wires.len()); let constants_t = builder.add_virtual_extension_targets(constants.len()); diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 73fcc190..a6c74e91 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -345,7 +345,7 @@ mod tests { let permutation_inputs = (0..W).map(F::from_canonical_usize).collect::>(); - let mut witness = PartialWitness::new(); + let mut witness = PartialWitness::new(0, 0, 0); witness.set_wire( Wire { gate: 0, diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 9af1cb58..2b5fae13 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -155,7 +155,7 @@ mod tests { type F = CrandallField; let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(0, 0, 0); let log_n = 8; let n = 1 << log_n; diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 886fb09e..cfffc79f 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -409,7 +409,7 @@ mod tests { ); } let circuit = builder.build(); - let mut witness = PartialWitness::new(); + let mut witness = PartialWitness::new(0, 0, 0); generate_partial_witness( &mut witness, &circuit.prover_only.generators, diff --git a/src/iop/witness.rs b/src/iop/witness.rs index 5b97f546..1eabad4b 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -26,22 +26,23 @@ impl Witness { #[derive(Clone, Debug)] pub struct PartialWitness { - pub(crate) target_values: HashMap, + pub(crate) wire_values: Vec>>, + pub(crate) virtual_target_values: Vec>, } impl PartialWitness { - pub fn new() -> Self { + pub fn new(degree: usize, num_wires: usize, max_virtual_target: usize) -> Self { PartialWitness { - target_values: HashMap::new(), + wire_values: vec![vec![None; num_wires]; degree], + virtual_target_values: vec![None; max_virtual_target], } } - pub fn is_empty(&self) -> bool { - self.target_values.is_empty() - } - pub fn get_target(&self, target: Target) -> F { - self.target_values[&target] + match target { + Target::Wire(Wire { gate, input }) => self.wire_values[gate][input].unwrap(), + Target::VirtualTarget { index } => self.virtual_target_values[index].unwrap(), + } } pub fn get_targets(&self, targets: &[Target]) -> Vec { @@ -76,7 +77,10 @@ impl PartialWitness { } pub fn try_get_target(&self, target: Target) -> Option { - self.target_values.get(&target).cloned() + match target { + Target::Wire(Wire { gate, input }) => self.wire_values[gate][input], + Target::VirtualTarget { index } => self.virtual_target_values[index], + } } pub fn get_wire(&self, wire: Wire) -> F { @@ -88,7 +92,10 @@ impl PartialWitness { } pub fn contains(&self, target: Target) -> bool { - self.target_values.contains_key(&target) + match target { + Target::Wire(Wire { gate, input }) => self.wire_values[gate][input].is_some(), + Target::VirtualTarget { index } => self.virtual_target_values[index].is_some(), + } } pub fn contains_all(&self, targets: &[Target]) -> bool { @@ -96,13 +103,29 @@ impl PartialWitness { } pub fn set_target(&mut self, target: Target, value: F) { - let opt_old_value = self.target_values.insert(target, value); - if let Some(old_value) = opt_old_value { - assert_eq!( - old_value, value, - "Target was set twice with different values: {:?}", - target - ); + match target { + Target::Wire(Wire { gate, input }) => { + if let Some(old_value) = self.wire_values[gate][input] { + assert_eq!( + old_value, value, + "Target was set twice with different values: {:?}", + target + ); + } else { + self.wire_values[gate][input] = Some(value); + } + } + Target::VirtualTarget { index } => { + if let Some(old_value) = self.virtual_target_values[index] { + assert_eq!( + old_value, value, + "Target was set twice with different values: {:?}", + target + ); + } else { + self.virtual_target_values[index] = Some(value); + } + } } } @@ -162,16 +185,18 @@ impl PartialWitness { } pub fn extend>(&mut self, pairs: I) { - self.target_values.extend(pairs); + for (t, v) in pairs { + self.set_target(t, v); + } } pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness { let mut wire_values = vec![vec![F::ZERO; degree]; num_wires]; - self.target_values.into_iter().for_each(|(t, v)| { - if let Target::Wire(Wire { gate, input }) = t { - wire_values[input][gate] = v; + for i in 0..degree { + for j in 0..num_wires { + wire_values[j][i] = self.wire_values[i][j].unwrap_or(F::ZERO); } - }); + } Witness { wire_values } } @@ -215,6 +240,6 @@ impl PartialWitness { impl Default for PartialWitness { fn default() -> Self { - Self::new() + Self::new(0, 0, 0) } } diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 773a695e..e21eec98 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -387,7 +387,7 @@ mod tests { } let data = builder.build(); ( - data.prove(PartialWitness::new())?, + data.prove(PartialWitness::new(0, 0, 0))?, data.verifier_only, data.common, ) @@ -395,7 +395,7 @@ mod tests { verify(proof_with_pis.clone(), &vd, &cd)?; let mut builder = CircuitBuilder::::new(config.clone()); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(0, 0, 0); let pt = proof_to_proof_target(&proof_with_pis, &mut builder); set_proof_target(&proof_with_pis, &pt, &mut pw); @@ -442,7 +442,7 @@ mod tests { } let data = builder.build(); ( - data.prove(PartialWitness::new())?, + data.prove(PartialWitness::new(1 << 14, config.num_wires, 1000))?, data.verifier_only, data.common, ) @@ -450,7 +450,7 @@ mod tests { verify(proof_with_pis.clone(), &vd, &cd)?; let mut builder = CircuitBuilder::::new(config.clone()); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(1 << 14, config.num_wires, 100000); let pt = proof_to_proof_target(&proof_with_pis, &mut builder); set_proof_target(&proof_with_pis, &pt, &mut pw); @@ -468,7 +468,7 @@ mod tests { verify(proof_with_pis.clone(), &vd, &cd)?; let mut builder = CircuitBuilder::::new(config.clone()); - let mut pw = PartialWitness::new(); + let mut pw = PartialWitness::new(1 << 14, config.num_wires, 100000); let pt = proof_to_proof_target(&proof_with_pis, &mut builder); set_proof_target(&proof_with_pis, &pt, &mut pw); diff --git a/src/util/reducing.rs b/src/util/reducing.rs index be95f67a..001783fd 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -264,7 +264,7 @@ mod tests { builder.assert_equal_extension(manual_reduce, circuit_reduce); let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) } @@ -294,7 +294,7 @@ mod tests { builder.assert_equal_extension(manual_reduce, circuit_reduce); let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; + let proof = data.prove(PartialWitness::new(0, 0, 0))?; verify(proof, &data.verifier_only, &data.common) }