diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 0539a551..a7e87339 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -1,19 +1,16 @@ -use std::collections::HashSet; +use std::collections::{HashSet, HashMap}; use std::time::Instant; use log::info; -use crate::circuit_data::{ - CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, - VerifierCircuitData, VerifierOnlyCircuitData, -}; -use crate::field::cosets::get_unique_coset_shifts; +use crate::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierOnlyCircuitData}; use crate::field::field::Field; use crate::gates::constant::ConstantGate; use crate::gates::gate::{GateInstance, GateRef}; use crate::gates::noop::NoopGate; use crate::generator::{CopyGenerator, WitnessGenerator}; use crate::hash::merkle_root_bit_rev_order; +use crate::field::cosets::get_unique_coset_shifts; use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; use crate::util::{log2_strict, transpose, transpose_poly_values}; @@ -33,6 +30,9 @@ pub struct CircuitBuilder { /// Generators used to generate the witness. generators: Vec>>, + + constants_to_targets: HashMap, + targets_to_constants: HashMap, } impl CircuitBuilder { @@ -43,6 +43,8 @@ impl CircuitBuilder { gate_instances: Vec::new(), virtual_target_index: 0, generators: Vec::new(), + constants_to_targets: HashMap::new(), + targets_to_constants: HashMap::new(), } } @@ -82,21 +84,14 @@ impl CircuitBuilder { // TODO: Not passing next constants for now. Not sure if it's really useful... self.add_generators(gate_type.0.generators(index, &constants, &[])); - self.gate_instances.push(GateInstance { - gate_type, - constants, - }); + self.gate_instances.push(GateInstance { gate_type, constants }); index } fn check_gate_compatibility(&self, gate: &GateRef) { - assert!( - gate.0.num_wires() <= self.config.num_wires, - "{:?} requires {} wires, but our GateConfig has only {}", - gate.0.id(), - gate.0.num_wires(), - self.config.num_wires - ); + assert!(gate.0.num_wires() <= self.config.num_wires, + "{:?} requires {} wires, but our GateConfig has only {}", + gate.0.id(), gate.0.num_wires(), self.config.num_wires); } /// Shorthand for `generate_copy` and `assert_equal`. @@ -114,14 +109,8 @@ impl CircuitBuilder { /// Uses Plonk's permutation argument to require that two elements be equal. /// Both elements must be routable, otherwise this method will panic. pub fn assert_equal(&mut self, x: Target, y: Target) { - assert!( - x.is_routable(self.config), - "Tried to route a wire that isn't routable" - ); - assert!( - y.is_routable(self.config), - "Tried to route a wire that isn't routable" - ); + assert!(x.is_routable(self.config), "Tried to route a wire that isn't routable"); + assert!(y.is_routable(self.config), "Tried to route a wire that isn't routable"); // TODO: Add to copy_constraints. } @@ -155,17 +144,28 @@ impl CircuitBuilder { /// Returns a routable target with the given constant value. pub fn constant(&mut self, c: F) -> Target { + if let Some(&target) = self.constants_to_targets.get(&c) { + // We already have a wire for this constant. + return target; + } + let gate = self.add_gate(ConstantGate::get(), vec![c]); - Target::Wire(Wire { - gate, - input: ConstantGate::WIRE_OUTPUT, - }) + let target = Target::Wire(Wire { gate, input: ConstantGate::WIRE_OUTPUT }); + self.constants_to_targets.insert(c, target); + self.targets_to_constants.insert(target, c); + target } pub fn constants(&mut self, constants: &[F]) -> Vec { constants.iter().map(|&c| self.constant(c)).collect() } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns + /// its constant value. Otherwise, returns `None`. + pub fn target_as_constant(&self, target: Target) -> Option { + self.targets_to_constants.get(&target).cloned() + } + fn blind_and_pad(&mut self) { // TODO: Blind. @@ -175,15 +175,11 @@ impl CircuitBuilder { } fn constant_polys(&self) -> Vec> { - let num_constants = self - .gate_instances - .iter() + let num_constants = self.gate_instances.iter() .map(|gate_inst| gate_inst.constants.len()) .max() .unwrap(); - let constants_per_gate = self - .gate_instances - .iter() + let constants_per_gate = self.gate_instances.iter() .map(|gate_inst| { let mut padded_constants = gate_inst.constants.clone(); for _ in padded_constants.len()..num_constants { @@ -200,17 +196,13 @@ impl CircuitBuilder { } fn sigma_vecs(&self) -> Vec> { - vec![PolynomialValues::zero(self.gate_instances.len()); self.config.num_routed_wires] - // TODO + vec![PolynomialValues::zero(self.gate_instances.len()); self.config.num_routed_wires] // TODO } /// Builds a "full circuit", with both prover and verifier data. pub fn build(mut self) -> CircuitData { let start = Instant::now(); - info!( - "degree before blinding & padding: {}", - self.gate_instances.len() - ); + info!("degree before blinding & padding: {}", self.gate_instances.len()); self.blind_and_pad(); let degree = self.gate_instances.len(); info!("degree after blinding & padding: {}", degree); @@ -226,11 +218,7 @@ impl CircuitBuilder { let sigmas_root = merkle_root_bit_rev_order(sigma_ldes_t.clone()); let generators = self.generators; - let prover_only = ProverOnlyCircuitData { - generators, - constant_ldes_t, - sigma_ldes_t, - }; + let prover_only = ProverOnlyCircuitData { generators, constant_ldes_t, sigma_ldes_t }; let verifier_only = VerifierOnlyCircuitData {}; // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we @@ -238,8 +226,7 @@ impl CircuitBuilder { let mut gates = self.gates.iter().cloned().collect::>(); gates.sort_unstable_by_key(|gate| gate.0.id()); - let num_gate_constraints = gates - .iter() + let num_gate_constraints = gates.iter() .map(|gate| gate.0.num_constraints()) .max() .expect("No gates?"); @@ -268,28 +255,14 @@ impl CircuitBuilder { /// Builds a "prover circuit", with data needed to generate proofs but not verify them. pub fn build_prover(self) -> ProverCircuitData { // TODO: Can skip parts of this. - let CircuitData { - prover_only, - common, - .. - } = self.build(); - ProverCircuitData { - prover_only, - common, - } + let CircuitData { prover_only, common, .. } = self.build(); + ProverCircuitData { prover_only, common } } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. pub fn build_verifier(self) -> VerifierCircuitData { // TODO: Can skip parts of this. - let CircuitData { - verifier_only, - common, - .. - } = self.build(); - VerifierCircuitData { - verifier_only, - common, - } + let CircuitData { verifier_only, common, .. } = self.build(); + VerifierCircuitData { verifier_only, common } } } diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 2a8700e7..8fa727bb 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -3,6 +3,8 @@ use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; use crate::target::Target; use crate::wire::Wire; +use crate::generator::SimpleGenerator; +use crate::witness::PartialWitness; impl CircuitBuilder { pub fn neg(&mut self, x: Target) -> Target { @@ -10,17 +12,22 @@ impl CircuitBuilder { self.mul(x, neg_one) } - pub fn add(&mut self, x: Target, y: Target) -> Target { - let zero = self.zero(); - let one = self.one(); - if x == zero { - return y; - } - if y == zero { - return x; + /// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`. + pub fn arithmetic( + &mut self, + const_0: F, + multiplicand_0: Target, + multiplicand_1: Target, + const_1: F, + addend: Target, + ) -> Target { + // See if we can determine the result without adding an `ArithmeticGate`. + if let Some(result) = self.arithmetic_special_cases( + const_0, multiplicand_0, multiplicand_1, const_1, addend) { + return result; } - let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ONE]); + let gate = self.add_gate(ArithmeticGate::new(), vec![const_0, const_1]); let wire_multiplicand_0 = Wire { gate, @@ -39,12 +46,77 @@ impl CircuitBuilder { input: ArithmeticGate::WIRE_OUTPUT, }; - self.route(x, Target::Wire(wire_multiplicand_0)); - self.route(one, Target::Wire(wire_multiplicand_1)); - self.route(y, Target::Wire(wire_addend)); + self.route(multiplicand_0, Target::Wire(wire_multiplicand_0)); + self.route(multiplicand_1, Target::Wire(wire_multiplicand_1)); + self.route(addend, Target::Wire(wire_addend)); Target::Wire(wire_output) } + /// Checks for special cases where the value of + /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` + /// can be determined without adding an `ArithmeticGate`. + fn arithmetic_special_cases( + &mut self, + const_0: F, + multiplicand_0: Target, + multiplicand_1: Target, + const_1: F, + addend: Target, + ) -> Option { + let zero = self.zero(); + + let mul_0_const = self.target_as_constant(multiplicand_0); + let mul_1_const = self.target_as_constant(multiplicand_1); + let addend_const = self.target_as_constant(addend); + + let first_term_zero = const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let second_term_zero = const_1 == F::ZERO || addend == zero; + + // If both terms are constant, return their (constant) sum. + let first_term_const = if first_term_zero { + Some(F::ZERO) + } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { + Some(const_0 * x * y) + } else { + None + }; + let second_term_const = if second_term_zero { + Some(F::ZERO) + } else { + addend_const.map(|x| const_1 * x) + }; + if let (Some(x), Some(y)) = (first_term_const, second_term_const) { + return Some(self.constant(x + y)); + } + + if first_term_zero { + if const_1.is_one() { + return Some(addend); + } + } + + if second_term_zero { + if let Some(x) = mul_0_const { + if (const_0 * x).is_one() { + return Some(multiplicand_1); + } + } + if let Some(x) = mul_1_const { + if (const_1 * x).is_one() { + return Some(multiplicand_0); + } + } + } + + None + } + + pub fn add(&mut self, x: Target, y: Target) -> Target { + let one = self.one(); + // x + y = 1 * x * 1 + 1 * y + self.arithmetic(F::ONE, x, one, F::ONE, y) + } + pub fn add_many(&mut self, terms: &[Target]) -> Target { let mut sum = self.zero(); for term in terms { @@ -54,22 +126,14 @@ impl CircuitBuilder { } pub fn sub(&mut self, x: Target, y: Target) -> Target { - let zero = self.zero(); - if x == zero { - return y; - } - if y == zero { - return x; - } - - // TODO: Inefficient impl for now. - let neg_y = self.neg(y); - self.add(x, neg_y) + let one = self.one(); + // x - y = 1 * x * 1 + (-1) * y + self.arithmetic(F::ONE, x, one, F::NEG_ONE, y) } pub fn mul(&mut self, x: Target, y: Target) -> Target { - // TODO: Check if one operand is 0 or 1. - todo!() + // x * y = 1 * x * y + 0 * x + self.arithmetic(F::ONE, x, y, F::ZERO, x) } pub fn mul_many(&mut self, terms: &[Target]) -> Target { @@ -80,8 +144,63 @@ impl CircuitBuilder { product } - pub fn div(&mut self, x: Target, y: Target) -> Target { - // TODO: Check if one operand is 0 or 1. - todo!() + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in + /// some cases, as it allows `0 / 0 = `. + pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { + // Check for special cases where we can determine the result without an `ArithmeticGate`. + let zero = self.zero(); + let one = self.one(); + if x == zero { + return zero; + } + if y == one { + return x; + } + if let (Some(x_const), Some(y_const)) = (self.target_as_constant(x), self.target_as_constant(y)) { + return self.constant(x_const / y_const); + } + + // Add an `ArithmeticGate` to compute `q * y`. + let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); + + let wire_multiplicand_0 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_0 }; + let wire_multiplicand_1 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_1 }; + let wire_addend = Wire { gate, input: ArithmeticGate::WIRE_ADDEND }; + let wire_output = Wire { gate, input: ArithmeticGate::WIRE_OUTPUT }; + + let q = Target::Wire(wire_multiplicand_0); + self.add_generator(QuotientGenerator { + numerator: x, + denominator: y, + quotient: q, + }); + + self.route(y, Target::Wire(wire_multiplicand_1)); + + // This can be anything, since the whole second term has a weight of zero. + self.route(zero, Target::Wire(wire_addend)); + + let q_y = Target::Wire(wire_output); + self.assert_equal(q_y, x); + + q + } +} + +struct QuotientGenerator { + numerator: Target, + denominator: Target, + quotient: Target, +} + +impl SimpleGenerator for QuotientGenerator { + fn dependencies(&self) -> Vec { + vec![self.numerator, self.denominator] + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let num = witness.get_target(self.numerator); + let den = witness.get_target(self.denominator); + PartialWitness::singleton_target(self.quotient, num / den) } }