diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 71a3d245..a7e87339 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::collections::{HashSet, HashMap}; use std::time::Instant; use log::info; @@ -30,6 +30,9 @@ pub struct CircuitBuilder { /// Generators used to generate the witness. generators: Vec>>, + + constants_to_targets: HashMap, + targets_to_constants: HashMap, } impl CircuitBuilder { @@ -40,6 +43,8 @@ impl CircuitBuilder { gate_instances: Vec::new(), virtual_target_index: 0, generators: Vec::new(), + constants_to_targets: HashMap::new(), + targets_to_constants: HashMap::new(), } } @@ -139,14 +144,28 @@ impl CircuitBuilder { /// Returns a routable target with the given constant value. pub fn constant(&mut self, c: F) -> Target { + if let Some(&target) = self.constants_to_targets.get(&c) { + // We already have a wire for this constant. + return target; + } + let gate = self.add_gate(ConstantGate::get(), vec![c]); - Target::Wire(Wire { gate, input: ConstantGate::WIRE_OUTPUT }) + let target = Target::Wire(Wire { gate, input: ConstantGate::WIRE_OUTPUT }); + self.constants_to_targets.insert(c, target); + self.targets_to_constants.insert(target, c); + target } pub fn constants(&mut self, constants: &[F]) -> Vec { constants.iter().map(|&c| self.constant(c)).collect() } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns + /// its constant value. Otherwise, returns `None`. + pub fn target_as_constant(&self, target: Target) -> Option { + self.targets_to_constants.get(&target).cloned() + } + fn blind_and_pad(&mut self) { // TODO: Blind. diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 7c9076da..ef289c58 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -3,6 +3,8 @@ use crate::field::field::Field; use crate::target::Target; use crate::gates::arithmetic::ArithmeticGate; use crate::wire::Wire; +use crate::generator::SimpleGenerator; +use crate::witness::PartialWitness; impl CircuitBuilder { pub fn neg(&mut self, x: Target) -> Target { @@ -10,29 +12,99 @@ impl CircuitBuilder { self.mul(x, neg_one) } - pub fn add(&mut self, x: Target, y: Target) -> Target { - let zero = self.zero(); - let one = self.one(); - if x == zero { - return y; - } - if y == zero { - return x; + /// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`. + pub fn arithmetic( + &mut self, + const_0: F, + multiplicand_0: Target, + multiplicand_1: Target, + const_1: F, + addend: Target, + ) -> Target { + // See if we can determine the result without adding an `ArithmeticGate`. + if let Some(result) = self.arithmetic_special_cases( + const_0, multiplicand_0, multiplicand_1, const_1, addend) { + return result; } - let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ONE]); + let gate = self.add_gate(ArithmeticGate::new(), vec![const_0, const_1]); let wire_multiplicand_0 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_0 }; let wire_multiplicand_1 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_1 }; let wire_addend = Wire { gate, input: ArithmeticGate::WIRE_ADDEND }; let wire_output = Wire { gate, input: ArithmeticGate::WIRE_OUTPUT }; - self.route(x, Target::Wire(wire_multiplicand_0)); - self.route(one, Target::Wire(wire_multiplicand_1)); - self.route(y, Target::Wire(wire_addend)); + self.route(multiplicand_0, Target::Wire(wire_multiplicand_0)); + self.route(multiplicand_1, Target::Wire(wire_multiplicand_1)); + self.route(addend, Target::Wire(wire_addend)); Target::Wire(wire_output) } + /// Checks for special cases where the value of + /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` + /// can be determined without adding an `ArithmeticGate`. + fn arithmetic_special_cases( + &mut self, + const_0: F, + multiplicand_0: Target, + multiplicand_1: Target, + const_1: F, + addend: Target, + ) -> Option { + let zero = self.zero(); + + let mul_0_const = self.target_as_constant(multiplicand_0); + let mul_1_const = self.target_as_constant(multiplicand_1); + let addend_const = self.target_as_constant(addend); + + let first_term_zero = const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let second_term_zero = const_1 == F::ZERO || addend == zero; + + // If both terms are constant, return their (constant) sum. + let first_term_const = if first_term_zero { + Some(F::ZERO) + } else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) { + Some(const_0 * x * y) + } else { + None + }; + let second_term_const = if second_term_zero { + Some(F::ZERO) + } else { + addend_const.map(|x| const_1 * x) + }; + if let (Some(x), Some(y)) = (first_term_const, second_term_const) { + return Some(self.constant(x + y)); + } + + if first_term_zero { + if const_1.is_one() { + return Some(addend); + } + } + + if second_term_zero { + if let Some(x) = mul_0_const { + if (const_0 * x).is_one() { + return Some(multiplicand_1); + } + } + if let Some(x) = mul_1_const { + if (const_1 * x).is_one() { + return Some(multiplicand_0); + } + } + } + + None + } + + pub fn add(&mut self, x: Target, y: Target) -> Target { + let one = self.one(); + // x + y = 1 * x * 1 + 1 * y + self.arithmetic(F::ONE, x, one, F::ONE, y) + } + pub fn add_many(&mut self, terms: &[Target]) -> Target { let mut sum = self.zero(); for term in terms { @@ -42,22 +114,14 @@ impl CircuitBuilder { } pub fn sub(&mut self, x: Target, y: Target) -> Target { - let zero = self.zero(); - if x == zero { - return y; - } - if y == zero { - return x; - } - - // TODO: Inefficient impl for now. - let neg_y = self.neg(y); - self.add(x, neg_y) + let one = self.one(); + // x - y = 1 * x * 1 + (-1) * y + self.arithmetic(F::ONE, x, one, F::NEG_ONE, y) } pub fn mul(&mut self, x: Target, y: Target) -> Target { - // TODO: Check if one operand is 0 or 1. - todo!() + // x * y = 1 * x * y + 0 * x + self.arithmetic(F::ONE, x, y, F::ZERO, x) } pub fn mul_many(&mut self, terms: &[Target]) -> Target { @@ -68,8 +132,63 @@ impl CircuitBuilder { product } - pub fn div(&mut self, x: Target, y: Target) -> Target { - // TODO: Check if one operand is 0 or 1. - todo!() + /// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in + /// some cases, as it allows `0 / 0 = `. + pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target { + // Check for special cases where we can determine the result without an `ArithmeticGate`. + let zero = self.zero(); + let one = self.one(); + if x == zero { + return zero; + } + if y == one { + return x; + } + if let (Some(x_const), Some(y_const)) = (self.target_as_constant(x), self.target_as_constant(y)) { + return self.constant(x_const / y_const); + } + + // Add an `ArithmeticGate` to compute `q * y`. + let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); + + let wire_multiplicand_0 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_0 }; + let wire_multiplicand_1 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_1 }; + let wire_addend = Wire { gate, input: ArithmeticGate::WIRE_ADDEND }; + let wire_output = Wire { gate, input: ArithmeticGate::WIRE_OUTPUT }; + + let q = Target::Wire(wire_multiplicand_0); + self.add_generator(QuotientGenerator { + numerator: x, + denominator: y, + quotient: q, + }); + + self.route(y, Target::Wire(wire_multiplicand_1)); + + // This can be anything, since the whole second term has a weight of zero. + self.route(zero, Target::Wire(wire_addend)); + + let q_y = Target::Wire(wire_output); + self.assert_equal(q_y, x); + + q + } +} + +struct QuotientGenerator { + numerator: Target, + denominator: Target, + quotient: Target, +} + +impl SimpleGenerator for QuotientGenerator { + fn dependencies(&self) -> Vec { + vec![self.numerator, self.denominator] + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let num = witness.get_target(self.numerator); + let den = witness.get_target(self.denominator); + PartialWitness::singleton_target(self.quotient, num / den) } }