From a44bf9ffd8cbd044fa4d3b58803c4964a9dd75b9 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 20 Aug 2021 09:50:07 +0200 Subject: [PATCH] Added witness trait --- src/gadgets/arithmetic_extension.rs | 2 +- src/gadgets/select.rs | 2 +- src/gates/gate_testing.rs | 2 +- src/hash/merkle_proofs.rs | 2 +- src/iop/generator.rs | 6 +- src/iop/witness.rs | 335 ++++++++++++---------------- src/plonk/prover.rs | 6 +- src/plonk/recursive_verifier.rs | 2 +- src/util/marking.rs | 2 +- 9 files changed, 151 insertions(+), 208 deletions(-) diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index f7b8eee5..8ed0e934 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -504,7 +504,7 @@ mod tests { use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::field_types::Field; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index 3be3455c..3f0e001d 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -43,7 +43,7 @@ mod tests { use crate::field::crandall_field::CrandallField; use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::field_types::Field; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; diff --git a/src/gates/gate_testing.rs b/src/gates/gate_testing.rs index 4d336454..0e773aa0 100644 --- a/src/gates/gate_testing.rs +++ b/src/gates/gate_testing.rs @@ -4,7 +4,7 @@ use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::Field; use crate::gates::gate::Gate; use crate::hash::hash_types::HashOut; -use crate::iop::witness::PartialWitness; +use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 7a7b358d..5f333f4e 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -221,7 +221,7 @@ mod tests { use super::*; use crate::field::crandall_field::CrandallField; use crate::hash::merkle_tree::MerkleTree; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::verifier::verify; diff --git a/src/iop/generator.rs b/src/iop/generator.rs index f86d9f30..62fe4d26 100644 --- a/src/iop/generator.rs +++ b/src/iop/generator.rs @@ -9,7 +9,7 @@ use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; -use crate::iop::witness::{PartialWitness, Witness}; +use crate::iop::witness::{MatrixWitness, PartialWitness, Witness}; use crate::plonk::permutation_argument::ForestNode; use crate::timed; use crate::util::timing::TimingTree; @@ -159,7 +159,7 @@ impl Yo { } } - pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness { + pub fn full_witness(self, degree: usize, num_wires: usize) -> MatrixWitness { let mut wire_values = vec![vec![F::ZERO; degree]; num_wires]; // assert!(self.wire_values.len() <= degree); for i in 0..degree { @@ -168,7 +168,7 @@ impl Yo { wire_values[j][i] = self.0[self.0[self.1(t)].parent].value.unwrap_or(F::ZERO); } } - Witness { wire_values } + MatrixWitness { wire_values } } } diff --git a/src/iop/witness.rs b/src/iop/witness.rs index cfcf5950..5d1441ff 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -13,12 +13,143 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::copy_constraint::CopyConstraint; +pub trait Witness { + fn try_get_target(&self, target: Target) -> Option; + + fn set_target(&mut self, target: Target, value: F); + + fn get_target(&self, target: Target) -> F { + self.try_get_target(target).unwrap() + } + + fn get_targets(&self, targets: &[Target]) -> Vec { + targets.iter().map(|&t| self.get_target(t)).collect() + } + + fn get_extension_target(&self, et: ExtensionTarget) -> F::Extension + where + F: Extendable, + { + F::Extension::from_basefield_array( + self.get_targets(&et.to_target_array()).try_into().unwrap(), + ) + } + + fn get_extension_targets(&self, ets: &[ExtensionTarget]) -> Vec + where + F: Extendable, + { + ets.iter() + .map(|&et| self.get_extension_target(et)) + .collect() + } + + fn get_bool_target(&self, target: BoolTarget) -> bool { + let value = self.get_target(target.target).to_canonical_u64(); + match value { + 0 => false, + 1 => true, + _ => panic!("not a bool"), + } + } + + fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { + HashOut { + elements: self.get_targets(&ht.elements).try_into().unwrap(), + } + } + + fn get_wire(&self, wire: Wire) -> F { + self.get_target(Target::Wire(wire)) + } + + fn try_get_wire(&self, wire: Wire) -> Option { + self.try_get_target(Target::Wire(wire)) + } + + fn contains(&self, target: Target) -> bool { + self.try_get_target(target).is_some() + } + + fn contains_all(&self, targets: &[Target]) -> bool { + targets.iter().all(|&t| self.contains(t)) + } + + fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { + ht.elements + .iter() + .zip(value.elements) + .for_each(|(&t, x)| self.set_target(t, x)); + } + + fn set_cap_target(&mut self, ct: &MerkleCapTarget, value: &MerkleCap) { + for (ht, h) in ct.0.iter().zip(&value.0) { + self.set_hash_target(*ht, *h); + } + } + + fn set_extension_target(&mut self, et: ExtensionTarget, value: F::Extension) + where + F: Extendable, + { + let limbs = value.to_basefield_array(); + (0..D).for_each(|i| { + self.set_target(et.0[i], limbs[i]); + }); + } + + fn set_extension_targets( + &mut self, + ets: &[ExtensionTarget], + values: &[F::Extension], + ) where + F: Extendable, + { + debug_assert_eq!(ets.len(), values.len()); + ets.iter() + .zip(values) + .for_each(|(&et, &v)| self.set_extension_target(et, v)); + } + + fn set_bool_target(&mut self, target: BoolTarget, value: bool) { + self.set_target(target.target, F::from_bool(value)) + } + + fn set_wire(&mut self, wire: Wire, value: F) { + self.set_target(Target::Wire(wire), value) + } + + fn set_wires(&mut self, wires: W, values: &[F]) + where + W: IntoIterator, + { + // If we used itertools, we could use zip_eq for extra safety. + for (wire, &value) in wires.into_iter().zip(values) { + self.set_wire(wire, value); + } + } + + fn set_ext_wires(&mut self, wires: W, value: F::Extension) + where + F: Extendable, + W: IntoIterator, + { + self.set_wires(wires, &value.to_basefield_array()); + } + + fn extend>(&mut self, pairs: I) { + for (t, v) in pairs { + self.set_target(t, v); + } + } +} + #[derive(Clone, Debug)] -pub struct Witness { +pub struct MatrixWitness { pub(crate) wire_values: Vec>, } -impl Witness { +impl MatrixWitness { pub fn get_wire(&self, gate: usize, input: usize) -> F { self.wire_values[input][gate] } @@ -39,86 +170,17 @@ impl PartialWitness { set_targets: vec![], } } +} - pub fn get_target(&self, target: Target) -> F { +impl Witness for PartialWitness { + fn try_get_target(&self, target: Target) -> Option { match target { - Target::Wire(Wire { gate, input }) => self.wire_values[gate][input].unwrap(), - Target::VirtualTarget { index } => self.virtual_target_values[index].unwrap(), + Target::Wire(Wire { gate, input }) => *self.wire_values.get(gate)?.get(input)?, + Target::VirtualTarget { index } => *self.virtual_target_values.get(index)?, } } - pub fn get_targets(&self, targets: &[Target]) -> Vec { - targets.iter().map(|&t| self.get_target(t)).collect() - } - - pub fn get_extension_target(&self, et: ExtensionTarget) -> F::Extension - where - F: Extendable, - { - F::Extension::from_basefield_array( - self.get_targets(&et.to_target_array()).try_into().unwrap(), - ) - } - - pub fn get_extension_targets( - &self, - ets: &[ExtensionTarget], - ) -> Vec - where - F: Extendable, - { - ets.iter() - .map(|&et| self.get_extension_target(et)) - .collect() - } - - pub fn get_bool_target(&self, target: BoolTarget) -> bool { - let value = self.get_target(target.target).to_canonical_u64(); - match value { - 0 => false, - 1 => true, - _ => panic!("not a bool"), - } - } - - pub fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { - HashOut { - elements: self.get_targets(&ht.elements).try_into().unwrap(), - } - } - - pub fn try_get_target(&self, target: Target) -> Option { - match target { - Target::Wire(Wire { gate, input }) => self.wire_values[gate][input], - Target::VirtualTarget { index } => self.virtual_target_values[index], - } - } - - pub fn get_wire(&self, wire: Wire) -> F { - self.get_target(Target::Wire(wire)) - } - - pub fn try_get_wire(&self, wire: Wire) -> Option { - self.try_get_target(Target::Wire(wire)) - } - - pub fn contains(&self, target: Target) -> bool { - match target { - Target::Wire(Wire { gate, input }) => { - self.wire_values.len() > gate && self.wire_values[gate][input].is_some() - } - Target::VirtualTarget { index } => { - self.virtual_target_values.len() > index - && self.virtual_target_values[index].is_some() - } - } - } - - pub fn contains_all(&self, targets: &[Target]) -> bool { - targets.iter().all(|&t| self.contains(t)) - } - - pub fn set_target(&mut self, target: Target, value: F) { + fn set_target(&mut self, target: Target, value: F) { match target { Target::Wire(Wire { gate, input }) => { if gate >= self.wire_values.len() { @@ -152,123 +214,4 @@ impl PartialWitness { } self.set_targets.push((target, value)); } - - pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { - ht.elements - .iter() - .zip(value.elements) - .for_each(|(&t, x)| self.set_target(t, x)); - } - - pub fn set_cap_target(&mut self, ct: &MerkleCapTarget, value: &MerkleCap) { - for (ht, h) in ct.0.iter().zip(&value.0) { - self.set_hash_target(*ht, *h); - } - } - - pub fn set_extension_target( - &mut self, - et: ExtensionTarget, - value: F::Extension, - ) where - F: Extendable, - { - let limbs = value.to_basefield_array(); - (0..D).for_each(|i| { - self.set_target(et.0[i], limbs[i]); - }); - } - - pub fn set_extension_targets( - &mut self, - ets: &[ExtensionTarget], - values: &[F::Extension], - ) where - F: Extendable, - { - debug_assert_eq!(ets.len(), values.len()); - ets.iter() - .zip(values) - .for_each(|(&et, &v)| self.set_extension_target(et, v)); - } - - pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) { - self.set_target(target.target, F::from_bool(value)) - } - - pub fn set_wire(&mut self, wire: Wire, value: F) { - self.set_target(Target::Wire(wire), value) - } - - pub fn set_wires(&mut self, wires: W, values: &[F]) - where - W: IntoIterator, - { - // If we used itertools, we could use zip_eq for extra safety. - for (wire, &value) in wires.into_iter().zip(values) { - self.set_wire(wire, value); - } - } - - pub fn set_ext_wires(&mut self, wires: W, value: F::Extension) - where - F: Extendable, - W: IntoIterator, - { - self.set_wires(wires, &value.to_basefield_array()); - } - - pub fn extend>(&mut self, pairs: I) { - for (t, v) in pairs { - self.set_target(t, v); - } - } - - pub fn full_witness(self, degree: usize, num_wires: usize) -> Witness { - let mut wire_values = vec![vec![F::ZERO; degree]; num_wires]; - assert!(self.wire_values.len() <= degree); - for i in 0..self.wire_values.len() { - for j in 0..num_wires { - wire_values[j][i] = self.wire_values[i][j].unwrap_or(F::ZERO); - } - } - Witness { wire_values } - } - - /// Checks that the copy constraints are satisfied in the witness. - pub fn check_copy_constraints( - &self, - copy_constraints: &[CopyConstraint], - gate_instances: &[GateInstance], - ) -> Result<()> - where - F: Extendable, - { - for CopyConstraint { pair: (a, b), name } in copy_constraints { - let va = self.try_get_target(*a).unwrap_or(F::ZERO); - let vb = self.try_get_target(*b).unwrap_or(F::ZERO); - let desc = |t: &Target| -> String { - match t { - Target::Wire(Wire { gate, input }) => format!( - "wire {} of gate #{} (`{}`)", - input, - gate, - gate_instances[*gate].gate_ref.0.id() - ), - Target::VirtualTarget { index } => format!("{}-th virtual target", index), - } - }; - ensure!( - va == vb, - "Copy constraint '{}' between {} and {} is not satisfied. \ - Got values of {} and {} respectively.", - name, - desc(a), - desc(b), - va, - vb - ); - } - Ok(()) - } } diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 558ca290..49b93190 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -10,7 +10,7 @@ use crate::iop::challenger::Challenger; use crate::iop::generator::{generate_partial_witness, Yo}; use crate::iop::target::Target; use crate::iop::wire::Wire; -use crate::iop::witness::{PartialWitness, Witness}; +use crate::iop::witness::{MatrixWitness, PartialWitness, Witness}; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::plonk::permutation_argument::ForestNode; use crate::plonk::plonk_common::PlonkPolynomials; @@ -268,7 +268,7 @@ pub(crate) fn prove, const D: usize>( /// Compute the partial products used in the `Z` polynomials. fn all_wires_permutation_partial_products, const D: usize>( - witness: &Witness, + witness: &MatrixWitness, betas: &[F], gammas: &[F], prover_data: &ProverOnlyCircuitData, @@ -291,7 +291,7 @@ fn all_wires_permutation_partial_products, const D: usize>( /// Returns the polynomials interpolating `partial_products(f / g)` /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. fn wires_permutation_partial_products, const D: usize>( - witness: &Witness, + witness: &MatrixWitness, beta: F, gamma: F, prover_data: &ProverOnlyCircuitData, diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 5f45af7e..97ebd84c 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -137,7 +137,7 @@ mod tests { use crate::fri::FriConfig; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::hash::merkle_proofs::MerkleProofTarget; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::proof::{OpeningSetTarget, Proof, ProofTarget, ProofWithPublicInputs}; use crate::plonk::verifier::verify; use crate::util::log2_strict; diff --git a/src/util/marking.rs b/src/util/marking.rs index e98f7cbc..6e0ad993 100644 --- a/src/util/marking.rs +++ b/src/util/marking.rs @@ -2,7 +2,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::hash::hash_types::HashOutTarget; use crate::iop::target::Target; -use crate::iop::witness::PartialWitness; +use crate::iop::witness::{PartialWitness, Witness}; /// Enum representing all types of targets, so that they can be marked. #[derive(Clone)]