From f3bfd666578a1de71c4bd74aa7ddf4dd6aa0ffad Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sat, 14 Aug 2021 08:53:39 -0700 Subject: [PATCH] Add a BoolTarget (#179) It's just a wrapper around `Target`, which signifies that the wrapped `Target` has already been range checked. Should make it easier to audit code that expects bools. --- src/field/field_types.rs | 4 ++++ src/fri/recursive_verifier.rs | 6 +++--- src/gadgets/arithmetic.rs | 16 ++++++++-------- src/gadgets/range_check.rs | 4 ++-- src/gadgets/select.rs | 30 +++++++++++++++++++----------- src/gadgets/split_base.rs | 20 +++++++++++--------- src/gadgets/split_join.rs | 15 ++++++++------- src/gates/exponentiation.rs | 2 +- src/hash/merkle_proofs.rs | 10 +++++----- src/iop/target.rs | 17 +++++++++++++++++ src/iop/witness.rs | 15 ++++++++++++++- src/plonk/circuit_builder.rs | 23 ++++++++++++++++++++++- 12 files changed, 114 insertions(+), 48 deletions(-) diff --git a/src/field/field_types.rs b/src/field/field_types.rs index fd5f8ac1..8d2f8872 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -185,6 +185,10 @@ pub trait Field: Self::from_canonical_u64(n as u64) } + fn from_bool(b: bool) -> Self { + Self::from_canonical_u64(b as u64) + } + fn to_canonical_biguint(&self) -> BigUint; fn from_canonical_biguint(n: BigUint) -> Self; diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index a8bd5d94..bf07639f 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -5,7 +5,7 @@ use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRound use crate::fri::FriConfig; use crate::hash::hash_types::MerkleCapTarget; use crate::iop::challenger::RecursiveChallenger; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::plonk_common::PlonkPolynomials; @@ -20,7 +20,7 @@ impl, const D: usize> CircuitBuilder { fn compute_evaluation( &mut self, x: Target, - x_index_within_coset_bits: &[Target], + x_index_within_coset_bits: &[BoolTarget], arity_bits: usize, evals: &[ExtensionTarget], beta: ExtensionTarget, @@ -181,7 +181,7 @@ impl, const D: usize> CircuitBuilder { fn fri_verify_initial_proof( &mut self, - x_index_bits: &[Target], + x_index_bits: &[BoolTarget], proof: &FriInitialTreeProofTarget, initial_merkle_caps: &[MerkleCapTarget], cap_index: Target, diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 27a204bc..2c41482e 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -2,7 +2,7 @@ use std::borrow::Borrow; use crate::field::extension_field::Extendable; use crate::gates::exponentiation::ExponentiationGate; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { @@ -115,21 +115,21 @@ impl, const D: usize> CircuitBuilder { pub fn exp_from_bits( &mut self, base: Target, - exponent_bits: impl IntoIterator>, + exponent_bits: impl IntoIterator>, ) -> Target { - let zero = self.zero(); + let _false = self._false(); let gate = ExponentiationGate::new(self.config.clone()); let num_power_bits = gate.num_power_bits; - let mut exp_bits_vec: Vec = + let mut exp_bits_vec: Vec = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); while exp_bits_vec.len() < num_power_bits { - exp_bits_vec.push(zero); + exp_bits_vec.push(_false); } let gate_index = self.add_gate(gate.clone(), vec![]); self.route(base, Target::wire(gate_index, gate.wire_base())); exp_bits_vec.iter().enumerate().for_each(|(i, bit)| { - self.route(*bit, Target::wire(gate_index, gate.wire_power_bit(i))); + self.route(bit.target, Target::wire(gate_index, gate.wire_power_bit(i))); }); Target::wire(gate_index, gate.wire_output()) @@ -148,8 +148,8 @@ impl, const D: usize> CircuitBuilder { pub fn exp_u64(&mut self, base: Target, mut exponent: u64) -> Target { let mut exp_bits = Vec::new(); while exponent != 0 { - let bit = exponent & 1; - let bit_target = self.constant(F::from_canonical_u64(bit)); + let bit = (exponent & 1) == 1; + let bit_target = self.constant_bool(bit); exp_bits.push(bit_target); exponent >>= 1; } diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs index 53bbf55c..be7eaf2e 100644 --- a/src/gadgets/range_check.rs +++ b/src/gadgets/range_check.rs @@ -2,7 +2,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::base_sum::BaseSumGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -15,7 +15,7 @@ impl, const D: usize> CircuitBuilder { } /// Returns the first `num_low_bits` little-endian bits of `x`. - pub fn low_bits(&mut self, x: Target, num_low_bits: usize, num_bits: usize) -> Vec { + pub fn low_bits(&mut self, x: Target, num_low_bits: usize, num_bits: usize) -> Vec { let mut res = self.split_le(x, num_bits); res.truncate(num_low_bits); res diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index 58de09c6..c81dd37b 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -1,14 +1,25 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::arithmetic::ArithmeticExtensionGate; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { - /// Selects `x` or `y` based on `b`, which is assumed to be binary, i.e., this returns `if b { x } else { y }`. - /// This expression is gotten as `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`. - /// Note: This does not range-check `b`. + /// Selects `x` or `y` based on `b`, i.e., this returns `if b { x } else { y }`. pub fn select_ext( + &mut self, + b: BoolTarget, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + let b_ext = self.convert_to_ext(b.target); + self.select_ext_generalized(b_ext, x, y) + } + + /// Like `select_ext`, but accepts a condition input which does not necessarily have to be + /// binary. In this case, it computes the arithmetic generalization of `if b { x } else { y }`, + /// i.e. `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`. + pub fn select_ext_generalized( &mut self, b: ExtensionTarget, x: ExtensionTarget, @@ -23,11 +34,10 @@ impl, const D: usize> CircuitBuilder { } /// See `select_ext`. - pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target { - let b_ext = self.convert_to_ext(b); + pub fn select(&mut self, b: BoolTarget, x: Target, y: Target) -> Target { let x_ext = self.convert_to_ext(x); let y_ext = self.convert_to_ext(y); - self.select_ext(b_ext, x_ext, y_ext).to_target_array()[0] + self.select_ext(b, x_ext, y_ext).to_target_array()[0] } } @@ -54,13 +64,11 @@ mod tests { let (x, y) = (FF::rand(), FF::rand()); let xt = builder.add_virtual_extension_target(); let yt = builder.add_virtual_extension_target(); - let truet = builder.add_virtual_extension_target(); - let falset = builder.add_virtual_extension_target(); + let truet = builder._true(); + let falset = builder._false(); pw.set_extension_target(xt, x); pw.set_extension_target(yt, y); - pw.set_extension_target(truet, FF::ONE); - pw.set_extension_target(falset, FF::ZERO); let should_be_x = builder.select_ext(truet, xt, yt); let should_be_y = builder.select_ext(falset, xt, yt); diff --git a/src/gadgets/split_base.rs b/src/gadgets/split_base.rs index 9b481e78..69cfa377 100644 --- a/src/gadgets/split_base.rs +++ b/src/gadgets/split_base.rs @@ -4,7 +4,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::base_sum::BaseSumGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -33,7 +33,7 @@ impl, const D: usize> CircuitBuilder { /// the number with little-endian bit representation given by `bits`. pub(crate) fn le_sum( &mut self, - bits: impl ExactSizeIterator> + Clone, + bits: impl ExactSizeIterator> + Clone, ) -> Target { let num_bits = bits.len(); debug_assert!( @@ -45,7 +45,7 @@ impl, const D: usize> CircuitBuilder { .clone() .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits) { - self.route(*limb.borrow(), Target::wire(gate_index, wire)); + self.route(limb.borrow().target, Target::wire(gate_index, wire)); } self.add_generator(BaseSumGenerator::<2> { @@ -60,21 +60,23 @@ impl, const D: usize> CircuitBuilder { #[derive(Debug)] struct BaseSumGenerator { gate_index: usize, - limbs: Vec, + limbs: Vec, } impl SimpleGenerator for BaseSumGenerator { fn dependencies(&self) -> Vec { - self.limbs.clone() + self.limbs.iter().map(|b| b.target).collect() } fn run_once(&self, witness: &PartialWitness, out_buffer: &mut GeneratedValues) { let sum = self .limbs .iter() - .map(|&t| witness.get_target(t)) + .map(|&t| witness.get_bool_target(t)) .rev() - .fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb); + .fold(F::ZERO, |acc, limb| { + acc * F::from_canonical_usize(B) + F::from_bool(limb) + }); out_buffer.set_target( Target::wire(self.gate_index, BaseSumGate::::WIRE_SUM), @@ -131,8 +133,8 @@ mod tests { let n = thread_rng().gen_range(0..(1 << 10)); let x = builder.constant(F::from_canonical_usize(n)); - let zero = builder.zero(); - let one = builder.one(); + let zero = builder._false(); + let one = builder._true(); let y = builder.le_sum( (0..10) diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index 1fe25dbe..71c171f0 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -2,7 +2,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::Field; use crate::gates::base_sum::BaseSumGate; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::ceil_div_usize; @@ -11,8 +11,8 @@ impl, const D: usize> CircuitBuilder { /// Split the given integer into a list of wires, where each one represents a /// bit of the integer, with little-endian ordering. /// Verifies that the decomposition is correct by using `k` `BaseSum<2>` gates - /// with `k` such that `k*num_routed_wires>=num_bits`. - pub(crate) fn split_le(&mut self, integer: Target, num_bits: usize) -> Vec { + /// with `k` such that `k * num_routed_wires >= num_bits`. + pub(crate) fn split_le(&mut self, integer: Target, num_bits: usize) -> Vec { if num_bits == 0 { return Vec::new(); } @@ -24,10 +24,11 @@ impl, const D: usize> CircuitBuilder { let mut bits = Vec::with_capacity(num_bits); for &gate in &gates { - bits.extend(Target::wires_from_range( - gate, - BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + bits_per_gate, - )); + let start_limbs = BaseSumGate::<2>::START_LIMBS; + for limb_input in start_limbs..start_limbs + bits_per_gate { + // `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`. + bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input))); + } } bits.drain(num_bits..); diff --git a/src/gates/exponentiation.rs b/src/gates/exponentiation.rs index ac3a467a..468f58e2 100644 --- a/src/gates/exponentiation.rs +++ b/src/gates/exponentiation.rs @@ -158,7 +158,7 @@ impl, const D: usize> Gate for ExponentiationGate { // power_bits is in LE order, but we accumulate in BE order. let cur_bit = power_bits[self.num_power_bits - i - 1]; - let mul_by = builder.select_ext(cur_bit, base, one); + let mul_by = builder.select_ext_generalized(cur_bit, base, one); let intermediate_value_diff = builder.mul_sub_extension(prev_intermediate_value, mul_by, intermediate_values[i]); constraints.push(intermediate_value_diff); diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 45009c7d..1b42e393 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -10,7 +10,7 @@ use crate::gates::gmimc::GMiMCGate; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::{compress, hash_or_noop, GMIMC_ROUNDS}; use crate::hash::merkle_tree::MerkleCap; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_builder::CircuitBuilder; @@ -66,7 +66,7 @@ impl, const D: usize> CircuitBuilder { pub(crate) fn verify_merkle_proof( &mut self, leaf_data: Vec, - leaf_index_bits: &[Target], + leaf_index_bits: &[BoolTarget], merkle_cap: &MerkleCapTarget, proof: &MerkleProofTarget, ) { @@ -83,7 +83,7 @@ impl, const D: usize> CircuitBuilder { gate, input: swap_wire, }); - self.generate_copy(bit, swap_wire); + self.generate_copy(bit.target, swap_wire); let input_wires = (0..12) .map(|i| { @@ -131,7 +131,7 @@ impl, const D: usize> CircuitBuilder { pub(crate) fn verify_merkle_proof_with_cap_index( &mut self, leaf_data: Vec, - leaf_index_bits: &[Target], + leaf_index_bits: &[BoolTarget], cap_index: Target, merkle_cap: &MerkleCapTarget, proof: &MerkleProofTarget, @@ -149,7 +149,7 @@ impl, const D: usize> CircuitBuilder { gate, input: swap_wire, }); - self.generate_copy(bit, swap_wire); + self.generate_copy(bit.target, swap_wire); let input_wires = (0..12) .map(|i| { diff --git a/src/iop/target.rs b/src/iop/target.rs index 50bd6bb6..877da5b8 100644 --- a/src/iop/target.rs +++ b/src/iop/target.rs @@ -31,3 +31,20 @@ impl Target { range.map(|i| Self::wire(gate, i)).collect() } } + +/// A `Target` which has already been constrained such that it can only be 0 or 1. +#[derive(Copy, Clone, Debug)] +pub struct BoolTarget { + pub target: Target, + /// This private field is here to force all instantiations to go through `new_unsafe`. + _private: (), +} + +impl BoolTarget { + pub fn new_unsafe(target: Target) -> BoolTarget { + BoolTarget { + target, + _private: (), + } + } +} diff --git a/src/iop/witness.rs b/src/iop/witness.rs index f2c0d453..51b5a182 100644 --- a/src/iop/witness.rs +++ b/src/iop/witness.rs @@ -9,7 +9,7 @@ use crate::gates::gate::GateInstance; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::merkle_tree::MerkleCap; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::copy_constraint::CopyConstraint; @@ -70,6 +70,15 @@ impl PartialWitness { .collect() } + pub fn get_bool_target(&self, target: BoolTarget) -> bool { + let value = self.get_target(target.target).to_canonical_u64(); + match value { + 0 => false, + 1 => true, + _ => panic!("not a bool"), + } + } + pub fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), @@ -180,6 +189,10 @@ impl PartialWitness { .for_each(|(&et, &v)| self.set_extension_target(et, v)); } + pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) { + self.set_target(target.target, F::from_bool(value)) + } + pub fn set_wire(&mut self, wire: Wire, value: F) { self.set_target(Target::Wire(wire), value) } diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 1db21fa5..aa8998b6 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -16,7 +16,7 @@ use crate::gates::public_input::PublicInputGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, @@ -129,6 +129,11 @@ impl, const D: usize> CircuitBuilder { .collect() } + // TODO: Unsafe + pub fn add_virtual_bool_target(&mut self) -> BoolTarget { + BoolTarget::new_unsafe(self.add_virtual_target()) + } + /// Adds a gate to the circuit, and returns its index. pub fn add_gate>(&mut self, gate_type: G, constants: Vec) -> usize { self.check_gate_compatibility(&gate_type); @@ -279,6 +284,14 @@ impl, const D: usize> CircuitBuilder { self.constant(F::NEG_ONE) } + pub fn _false(&mut self) -> BoolTarget { + BoolTarget::new_unsafe(self.zero()) + } + + pub fn _true(&mut self) -> BoolTarget { + BoolTarget::new_unsafe(self.one()) + } + /// Returns a routable target with the given constant value. pub fn constant(&mut self, c: F) -> Target { if let Some(&target) = self.constants_to_targets.get(&c) { @@ -300,6 +313,14 @@ impl, const D: usize> CircuitBuilder { constants.iter().map(|&c| self.constant(c)).collect() } + pub fn constant_bool(&mut self, b: bool) -> BoolTarget { + if b { + self._true() + } else { + self._false() + } + } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// its constant value. Otherwise, returns `None`. pub fn target_as_constant(&self, target: Target) -> Option {