Add a BoolTarget (#179)

It's just a wrapper around `Target`, which signifies that the wrapped `Target` has already been range checked. Should make it easier to audit code that expects bools.
This commit is contained in:
Daniel Lubarov 2021-08-14 08:53:39 -07:00 committed by GitHub
parent 8effaf76e9
commit f3bfd66657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 114 additions and 48 deletions

View File

@ -185,6 +185,10 @@ pub trait Field:
Self::from_canonical_u64(n as u64)
}
fn from_bool(b: bool) -> Self {
Self::from_canonical_u64(b as u64)
}
fn to_canonical_biguint(&self) -> BigUint;
fn from_canonical_biguint(n: BigUint) -> Self;

View File

@ -5,7 +5,7 @@ use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRound
use crate::fri::FriConfig;
use crate::hash::hash_types::MerkleCapTarget;
use crate::iop::challenger::RecursiveChallenger;
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::plonk_common::PlonkPolynomials;
@ -20,7 +20,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn compute_evaluation(
&mut self,
x: Target,
x_index_within_coset_bits: &[Target],
x_index_within_coset_bits: &[BoolTarget],
arity_bits: usize,
evals: &[ExtensionTarget<D>],
beta: ExtensionTarget<D>,
@ -181,7 +181,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn fri_verify_initial_proof(
&mut self,
x_index_bits: &[Target],
x_index_bits: &[BoolTarget],
proof: &FriInitialTreeProofTarget,
initial_merkle_caps: &[MerkleCapTarget],
cap_index: Target,

View File

@ -2,7 +2,7 @@ use std::borrow::Borrow;
use crate::field::extension_field::Extendable;
use crate::gates::exponentiation::ExponentiationGate;
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
@ -115,21 +115,21 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn exp_from_bits(
&mut self,
base: Target,
exponent_bits: impl IntoIterator<Item = impl Borrow<Target>>,
exponent_bits: impl IntoIterator<Item = impl Borrow<BoolTarget>>,
) -> Target {
let zero = self.zero();
let _false = self._false();
let gate = ExponentiationGate::new(self.config.clone());
let num_power_bits = gate.num_power_bits;
let mut exp_bits_vec: Vec<Target> =
let mut exp_bits_vec: Vec<BoolTarget> =
exponent_bits.into_iter().map(|b| *b.borrow()).collect();
while exp_bits_vec.len() < num_power_bits {
exp_bits_vec.push(zero);
exp_bits_vec.push(_false);
}
let gate_index = self.add_gate(gate.clone(), vec![]);
self.route(base, Target::wire(gate_index, gate.wire_base()));
exp_bits_vec.iter().enumerate().for_each(|(i, bit)| {
self.route(*bit, Target::wire(gate_index, gate.wire_power_bit(i)));
self.route(bit.target, Target::wire(gate_index, gate.wire_power_bit(i)));
});
Target::wire(gate_index, gate.wire_output())
@ -148,8 +148,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn exp_u64(&mut self, base: Target, mut exponent: u64) -> Target {
let mut exp_bits = Vec::new();
while exponent != 0 {
let bit = exponent & 1;
let bit_target = self.constant(F::from_canonical_u64(bit));
let bit = (exponent & 1) == 1;
let bit_target = self.constant_bool(bit);
exp_bits.push(bit_target);
exponent >>= 1;
}

View File

@ -2,7 +2,7 @@ use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::base_sum::BaseSumGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -15,7 +15,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// Returns the first `num_low_bits` little-endian bits of `x`.
pub fn low_bits(&mut self, x: Target, num_low_bits: usize, num_bits: usize) -> Vec<Target> {
pub fn low_bits(&mut self, x: Target, num_low_bits: usize, num_bits: usize) -> Vec<BoolTarget> {
let mut res = self.split_le(x, num_bits);
res.truncate(num_low_bits);
res

View File

@ -1,14 +1,25 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Selects `x` or `y` based on `b`, which is assumed to be binary, i.e., this returns `if b { x } else { y }`.
/// This expression is gotten as `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`.
/// Note: This does not range-check `b`.
/// Selects `x` or `y` based on `b`, i.e., this returns `if b { x } else { y }`.
pub fn select_ext(
&mut self,
b: BoolTarget,
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let b_ext = self.convert_to_ext(b.target);
self.select_ext_generalized(b_ext, x, y)
}
/// Like `select_ext`, but accepts a condition input which does not necessarily have to be
/// binary. In this case, it computes the arithmetic generalization of `if b { x } else { y }`,
/// i.e. `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`.
pub fn select_ext_generalized(
&mut self,
b: ExtensionTarget<D>,
x: ExtensionTarget<D>,
@ -23,11 +34,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// See `select_ext`.
pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target {
let b_ext = self.convert_to_ext(b);
pub fn select(&mut self, b: BoolTarget, x: Target, y: Target) -> Target {
let x_ext = self.convert_to_ext(x);
let y_ext = self.convert_to_ext(y);
self.select_ext(b_ext, x_ext, y_ext).to_target_array()[0]
self.select_ext(b, x_ext, y_ext).to_target_array()[0]
}
}
@ -54,13 +64,11 @@ mod tests {
let (x, y) = (FF::rand(), FF::rand());
let xt = builder.add_virtual_extension_target();
let yt = builder.add_virtual_extension_target();
let truet = builder.add_virtual_extension_target();
let falset = builder.add_virtual_extension_target();
let truet = builder._true();
let falset = builder._false();
pw.set_extension_target(xt, x);
pw.set_extension_target(yt, y);
pw.set_extension_target(truet, FF::ONE);
pw.set_extension_target(falset, FF::ZERO);
let should_be_x = builder.select_ext(truet, xt, yt);
let should_be_y = builder.select_ext(falset, xt, yt);

View File

@ -4,7 +4,7 @@ use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::base_sum::BaseSumGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -33,7 +33,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// the number with little-endian bit representation given by `bits`.
pub(crate) fn le_sum(
&mut self,
bits: impl ExactSizeIterator<Item = impl Borrow<Target>> + Clone,
bits: impl ExactSizeIterator<Item = impl Borrow<BoolTarget>> + Clone,
) -> Target {
let num_bits = bits.len();
debug_assert!(
@ -45,7 +45,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.clone()
.zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits)
{
self.route(*limb.borrow(), Target::wire(gate_index, wire));
self.route(limb.borrow().target, Target::wire(gate_index, wire));
}
self.add_generator(BaseSumGenerator::<2> {
@ -60,21 +60,23 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
#[derive(Debug)]
struct BaseSumGenerator<const B: usize> {
gate_index: usize,
limbs: Vec<Target>,
limbs: Vec<BoolTarget>,
}
impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSumGenerator<B> {
fn dependencies(&self) -> Vec<Target> {
self.limbs.clone()
self.limbs.iter().map(|b| b.target).collect()
}
fn run_once(&self, witness: &PartialWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let sum = self
.limbs
.iter()
.map(|&t| witness.get_target(t))
.map(|&t| witness.get_bool_target(t))
.rev()
.fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb);
.fold(F::ZERO, |acc, limb| {
acc * F::from_canonical_usize(B) + F::from_bool(limb)
});
out_buffer.set_target(
Target::wire(self.gate_index, BaseSumGate::<B>::WIRE_SUM),
@ -131,8 +133,8 @@ mod tests {
let n = thread_rng().gen_range(0..(1 << 10));
let x = builder.constant(F::from_canonical_usize(n));
let zero = builder.zero();
let one = builder.one();
let zero = builder._false();
let one = builder._true();
let y = builder.le_sum(
(0..10)

View File

@ -2,7 +2,7 @@ use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::base_sum::BaseSumGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
@ -11,8 +11,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Split the given integer into a list of wires, where each one represents a
/// bit of the integer, with little-endian ordering.
/// Verifies that the decomposition is correct by using `k` `BaseSum<2>` gates
/// with `k` such that `k*num_routed_wires>=num_bits`.
pub(crate) fn split_le(&mut self, integer: Target, num_bits: usize) -> Vec<Target> {
/// with `k` such that `k * num_routed_wires >= num_bits`.
pub(crate) fn split_le(&mut self, integer: Target, num_bits: usize) -> Vec<BoolTarget> {
if num_bits == 0 {
return Vec::new();
}
@ -24,10 +24,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut bits = Vec::with_capacity(num_bits);
for &gate in &gates {
bits.extend(Target::wires_from_range(
gate,
BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + bits_per_gate,
));
let start_limbs = BaseSumGate::<2>::START_LIMBS;
for limb_input in start_limbs..start_limbs + bits_per_gate {
// `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`.
bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input)));
}
}
bits.drain(num_bits..);

View File

@ -158,7 +158,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ExponentiationGate<F, D> {
// power_bits is in LE order, but we accumulate in BE order.
let cur_bit = power_bits[self.num_power_bits - i - 1];
let mul_by = builder.select_ext(cur_bit, base, one);
let mul_by = builder.select_ext_generalized(cur_bit, base, one);
let intermediate_value_diff =
builder.mul_sub_extension(prev_intermediate_value, mul_by, intermediate_values[i]);
constraints.push(intermediate_value_diff);

View File

@ -10,7 +10,7 @@ use crate::gates::gmimc::GMiMCGate;
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};
use crate::hash::hashing::{compress, hash_or_noop, GMIMC_ROUNDS};
use crate::hash::merkle_tree::MerkleCap;
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -66,7 +66,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub(crate) fn verify_merkle_proof(
&mut self,
leaf_data: Vec<Target>,
leaf_index_bits: &[Target],
leaf_index_bits: &[BoolTarget],
merkle_cap: &MerkleCapTarget,
proof: &MerkleProofTarget,
) {
@ -83,7 +83,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
gate,
input: swap_wire,
});
self.generate_copy(bit, swap_wire);
self.generate_copy(bit.target, swap_wire);
let input_wires = (0..12)
.map(|i| {
@ -131,7 +131,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub(crate) fn verify_merkle_proof_with_cap_index(
&mut self,
leaf_data: Vec<Target>,
leaf_index_bits: &[Target],
leaf_index_bits: &[BoolTarget],
cap_index: Target,
merkle_cap: &MerkleCapTarget,
proof: &MerkleProofTarget,
@ -149,7 +149,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
gate,
input: swap_wire,
});
self.generate_copy(bit, swap_wire);
self.generate_copy(bit.target, swap_wire);
let input_wires = (0..12)
.map(|i| {

View File

@ -31,3 +31,20 @@ impl Target {
range.map(|i| Self::wire(gate, i)).collect()
}
}
/// A `Target` which has already been constrained such that it can only be 0 or 1.
#[derive(Copy, Clone, Debug)]
pub struct BoolTarget {
pub target: Target,
/// This private field is here to force all instantiations to go through `new_unsafe`.
_private: (),
}
impl BoolTarget {
pub fn new_unsafe(target: Target) -> BoolTarget {
BoolTarget {
target,
_private: (),
}
}
}

View File

@ -9,7 +9,7 @@ use crate::gates::gate::GateInstance;
use crate::hash::hash_types::HashOutTarget;
use crate::hash::hash_types::{HashOut, MerkleCapTarget};
use crate::hash::merkle_tree::MerkleCap;
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::plonk::copy_constraint::CopyConstraint;
@ -70,6 +70,15 @@ impl<F: Field> PartialWitness<F> {
.collect()
}
pub fn get_bool_target(&self, target: BoolTarget) -> bool {
let value = self.get_target(target.target).to_canonical_u64();
match value {
0 => false,
1 => true,
_ => panic!("not a bool"),
}
}
pub fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> {
HashOut {
elements: self.get_targets(&ht.elements).try_into().unwrap(),
@ -180,6 +189,10 @@ impl<F: Field> PartialWitness<F> {
.for_each(|(&et, &v)| self.set_extension_target(et, v));
}
pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) {
self.set_target(target.target, F::from_bool(value))
}
pub fn set_wire(&mut self, wire: Wire, value: F) {
self.set_target(Target::Wire(wire), value)
}

View File

@ -16,7 +16,7 @@ use crate::gates::public_input::PublicInputGate;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget};
use crate::hash::hashing::hash_n_to_hash;
use crate::iop::generator::{CopyGenerator, RandomValueGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::plonk::circuit_data::{
CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData,
@ -129,6 +129,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.collect()
}
// TODO: Unsafe
pub fn add_virtual_bool_target(&mut self) -> BoolTarget {
BoolTarget::new_unsafe(self.add_virtual_target())
}
/// Adds a gate to the circuit, and returns its index.
pub fn add_gate<G: Gate<F, D>>(&mut self, gate_type: G, constants: Vec<F>) -> usize {
self.check_gate_compatibility(&gate_type);
@ -279,6 +284,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.constant(F::NEG_ONE)
}
pub fn _false(&mut self) -> BoolTarget {
BoolTarget::new_unsafe(self.zero())
}
pub fn _true(&mut self) -> BoolTarget {
BoolTarget::new_unsafe(self.one())
}
/// Returns a routable target with the given constant value.
pub fn constant(&mut self, c: F) -> Target {
if let Some(&target) = self.constants_to_targets.get(&c) {
@ -300,6 +313,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
constants.iter().map(|&c| self.constant(c)).collect()
}
pub fn constant_bool(&mut self, b: bool) -> BoolTarget {
if b {
self._true()
} else {
self._false()
}
}
/// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns
/// its constant value. Otherwise, returns `None`.
pub fn target_as_constant(&self, target: Target) -> Option<F> {