From 93d695d33e5953738a04a94d00db00924f527ec0 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 2 Dec 2021 15:14:25 +0100 Subject: [PATCH 1/2] Variable number of U32 ops --- src/gadgets/arithmetic_u32.rs | 20 ++---- src/gates/arithmetic_u32.rs | 122 ++++++++++++++++++---------------- src/plonk/circuit_builder.rs | 31 +++------ 3 files changed, 80 insertions(+), 93 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index ce7aa121..d15df304 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -76,34 +76,26 @@ impl, const D: usize> CircuitBuilder { return result; } + let gate = U32ArithmeticGate::::new_from_config(&self.config); let (gate_index, copy) = self.find_u32_arithmetic_gate(); self.connect( - Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_0(copy), - ), + Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)), x.0, ); self.connect( - Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_1(copy), - ), + Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)), y.0, ); - self.connect( - Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(copy)), - z.0, - ); + self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z.0); let output_low = U32Target(Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_output_low_half(copy), + gate.wire_ith_output_low_half(copy), )); let output_high = U32Target(Target::wire( gate_index, - U32ArithmeticGate::::wire_ith_output_high_half(copy), + gate.wire_ith_output_high_half(copy), )); (output_low, output_high) diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs index a5a63047..e88654df 100644 --- a/src/gates/arithmetic_u32.rs +++ b/src/gates/arithmetic_u32.rs @@ -11,43 +11,49 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Number of arithmetic operations performed by an arithmetic gate. -pub const NUM_U32_ARITHMETIC_OPS: usize = 3; - /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub struct U32ArithmeticGate, const D: usize> { + pub num_ops: usize, _phantom: PhantomData, } impl, const D: usize> U32ArithmeticGate { - pub fn new() -> Self { + pub fn new_from_config(config: &CircuitConfig) -> Self { Self { + num_ops: Self::num_ops(config), _phantom: PhantomData, } } - pub fn wire_ith_multiplicand_0(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i } - pub fn wire_ith_multiplicand_1(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 1 } - pub fn wire_ith_addend(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_addend(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 2 } - pub fn wire_ith_output_low_half(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_low_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 3 } - pub fn wire_ith_output_high_half(i: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_high_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 4 } @@ -58,10 +64,10 @@ impl, const D: usize> U32ArithmeticGate { 64 / Self::limb_bits() } - pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { - debug_assert!(i < NUM_U32_ARITHMETIC_OPS); + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); debug_assert!(j < Self::num_limbs()); - 5 * NUM_U32_ARITHMETIC_OPS + Self::num_limbs() * i + j + 5 * self.num_ops + Self::num_limbs() * i + j } } @@ -72,15 +78,15 @@ impl, const D: usize> Gate for U32ArithmeticG fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = multiplicand_0 * multiplicand_1 + addend; - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base = F::Extension::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; @@ -92,7 +98,7 @@ impl, const D: usize> Gate for U32ArithmeticG let midpoint = Self::num_limbs() / 2; let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::Extension::from_canonical_usize(x)) @@ -114,15 +120,15 @@ impl, const D: usize> Gate for U32ArithmeticG fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = multiplicand_0 * multiplicand_1 + addend; - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base = F::from_canonical_u64(1 << 32u64); let combined_output = output_high * base + output_low; @@ -134,7 +140,7 @@ impl, const D: usize> Gate for U32ArithmeticG let midpoint = Self::num_limbs() / 2; let base = F::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) @@ -161,15 +167,15 @@ impl, const D: usize> Gate for U32ArithmeticG ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_ARITHMETIC_OPS { - let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[Self::wire_ith_addend(i)]; + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend); - let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)]; + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); let base_target = builder.constant_extension(base); @@ -183,7 +189,7 @@ impl, const D: usize> Gate for U32ArithmeticG let base = builder .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let mut product = builder.one_extension(); @@ -216,10 +222,11 @@ impl, const D: usize> Gate for U32ArithmeticG gate_index: usize, _local_constants: &[F], ) -> Vec>> { - (0..NUM_U32_ARITHMETIC_OPS) + (0..self.num_ops) .map(|i| { let g: Box> = Box::new( U32ArithmeticGenerator { + gate: *self, gate_index, i, _phantom: PhantomData, @@ -232,7 +239,7 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_wires(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (5 + Self::num_limbs()) + self.num_ops * (5 + Self::num_limbs()) } fn num_constants(&self) -> usize { @@ -244,12 +251,13 @@ impl, const D: usize> Gate for U32ArithmeticG } fn num_constraints(&self) -> usize { - NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs()) + self.num_ops * (3 + Self::num_limbs()) } } #[derive(Clone, Debug)] struct U32ArithmeticGenerator, const D: usize> { + gate: U32ArithmeticGate, gate_index: usize, i: usize, _phantom: PhantomData, @@ -262,9 +270,9 @@ impl, const D: usize> SimpleGenerator let local_target = |input| Target::wire(self.gate_index, input); vec![ - local_target(U32ArithmeticGate::::wire_ith_multiplicand_0(self.i)), - local_target(U32ArithmeticGate::::wire_ith_multiplicand_1(self.i)), - local_target(U32ArithmeticGate::::wire_ith_addend(self.i)), + local_target(self.gate.wire_ith_multiplicand_0(self.i)), + local_target(self.gate.wire_ith_multiplicand_1(self.i)), + local_target(self.gate.wire_ith_addend(self.i)), ] } @@ -276,11 +284,9 @@ impl, const D: usize> SimpleGenerator let get_local_wire = |input| witness.get_wire(local_wire(input)); - let multiplicand_0 = - get_local_wire(U32ArithmeticGate::::wire_ith_multiplicand_0(self.i)); - let multiplicand_1 = - get_local_wire(U32ArithmeticGate::::wire_ith_multiplicand_1(self.i)); - let addend = get_local_wire(U32ArithmeticGate::::wire_ith_addend(self.i)); + let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i)); + let addend = get_local_wire(self.gate.wire_ith_addend(self.i)); let output = multiplicand_0 * multiplicand_1 + addend; let mut output_u64 = output.to_canonical_u64(); @@ -291,10 +297,8 @@ impl, const D: usize> SimpleGenerator let output_high = F::from_canonical_u64(output_high_u64); let output_low = F::from_canonical_u64(output_low_u64); - let output_high_wire = - local_wire(U32ArithmeticGate::::wire_ith_output_high_half(self.i)); - let output_low_wire = - local_wire(U32ArithmeticGate::::wire_ith_output_low_half(self.i)); + let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i)); + let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i)); out_buffer.set_wire(output_high_wire, output_high); out_buffer.set_wire(output_low_wire, output_low); @@ -310,9 +314,7 @@ impl, const D: usize> SimpleGenerator let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64); for (j, output_limb) in output_limbs_f.enumerate() { - let wire = local_wire(U32ArithmeticGate::::wire_ith_output_jth_limb( - self.i, j, - )); + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); out_buffer.set_wire(wire, output_limb); } } @@ -328,7 +330,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticExtension; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; - use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; + use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::hash::hash_types::HashOut; @@ -337,6 +339,7 @@ mod tests { #[test] fn low_degree() { test_low_degree::(U32ArithmeticGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -344,6 +347,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { test_eval_fns::(U32ArithmeticGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -353,6 +357,7 @@ mod tests { type F = GoldilocksField; type FF = QuarticExtension; const D: usize = 4; + const NUM_U32_ARITHMETIC_OPS: usize = 3; fn get_wires( multiplicands_0: Vec, @@ -410,6 +415,7 @@ mod tests { .collect(); let gate = U32ArithmeticGate:: { + num_ops: NUM_U32_ARITHMETIC_OPS, _phantom: PhantomData, }; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 3cb62cf3..7799bc38 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -16,7 +16,7 @@ use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; -use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; +use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; @@ -965,14 +965,14 @@ impl, const D: usize> CircuitBuilder { pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { None => { - let gate = U32ArithmeticGate::new(); + let gate = U32ArithmeticGate::new_from_config(&self.config); let gate_index = self.add_gate(gate, vec![]); (gate_index, 0) } Some((gate_index, copy)) => (gate_index, copy), }; - if copy == NUM_U32_ARITHMETIC_OPS - 1 { + if copy == U32ArithmeticGate::::num_ops(&self.config) - 1 { self.batched_gates.current_u32_arithmetic_gate = None; } else { self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); @@ -1111,23 +1111,12 @@ impl, const D: usize> CircuitBuilder { /// Fill the remaining unused U32 arithmetic operations with zeros, so that all /// `U32ArithmeticGenerator`s are run. fn fill_u32_arithmetic_gates(&mut self) { - let zero = self.zero(); - if let Some((gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { - for i in copy..NUM_U32_ARITHMETIC_OPS { - let wire_multiplicand_0 = Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_0(i), - ); - let wire_multiplicand_1 = Target::wire( - gate_index, - U32ArithmeticGate::::wire_ith_multiplicand_1(i), - ); - let wire_addend = - Target::wire(gate_index, U32ArithmeticGate::::wire_ith_addend(i)); - - self.connect(zero, wire_multiplicand_0); - self.connect(zero, wire_multiplicand_1); - self.connect(zero, wire_addend); + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { + for _ in copy..U32ArithmeticGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.mul_add_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); } } } @@ -1137,7 +1126,7 @@ impl, const D: usize> CircuitBuilder { fn fill_u32_subtraction_gates(&mut self) { let zero = self.zero(); if let Some((gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { - for i in copy..NUM_U32_ARITHMETIC_OPS { + for i in copy..NUM_U32_SUBTRACTION_OPS { let wire_input_x = Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_x(i)); let wire_input_y = From 29ed0673f2a76b912af970e80761626283851921 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 2 Dec 2021 15:35:59 +0100 Subject: [PATCH 2/2] Variable number of U32 sub ops --- src/gadgets/arithmetic_u32.rs | 32 ++------- src/gates/subtraction_u32.rs | 121 ++++++++++++++++++---------------- src/plonk/circuit_builder.rs | 27 +++----- 3 files changed, 79 insertions(+), 101 deletions(-) diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs index d15df304..3bf6ce58 100644 --- a/src/gadgets/arithmetic_u32.rs +++ b/src/gadgets/arithmetic_u32.rs @@ -136,38 +136,18 @@ impl, const D: usize> CircuitBuilder { y: U32Target, borrow: U32Target, ) -> (U32Target, U32Target) { + let gate = U32SubtractionGate::::new_from_config(&self.config); let (gate_index, copy) = self.find_u32_subtraction_gate(); + self.connect(Target::wire(gate_index, gate.wire_ith_input_x(copy)), x.0); + self.connect(Target::wire(gate_index, gate.wire_ith_input_y(copy)), y.0); self.connect( - Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_x(copy), - ), - x.0, - ); - self.connect( - Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_y(copy), - ), - y.0, - ); - self.connect( - Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_borrow(copy), - ), + Target::wire(gate_index, gate.wire_ith_input_borrow(copy)), borrow.0, ); - let output_result = U32Target(Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_output_result(copy), - )); - let output_borrow = U32Target(Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_output_borrow(copy), - )); + let output_result = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_borrow = U32Target(Target::wire(gate_index, gate.wire_ith_output_borrow(copy))); (output_result, output_borrow) } diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs index 26f6302e..fc4cd646 100644 --- a/src/gates/subtraction_u32.rs +++ b/src/gates/subtraction_u32.rs @@ -9,44 +9,50 @@ use crate::iop::target::Target; use crate::iop::wire::Wire; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Maximum number of subtractions operations performed by a single gate. -pub const NUM_U32_SUBTRACTION_OPS: usize = 3; - /// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns /// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub struct U32SubtractionGate, const D: usize> { + pub num_ops: usize, _phantom: PhantomData, } impl, const D: usize> U32SubtractionGate { - pub fn new() -> Self { + pub fn new_from_config(config: &CircuitConfig) -> Self { Self { + num_ops: Self::num_ops(config), _phantom: PhantomData, } } - pub fn wire_ith_input_x(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_input_x(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i } - pub fn wire_ith_input_y(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_input_y(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 1 } - pub fn wire_ith_input_borrow(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_input_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 2 } - pub fn wire_ith_output_result(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 3 } - pub fn wire_ith_output_borrow(i: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_output_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); 5 * i + 4 } @@ -58,10 +64,10 @@ impl, const D: usize> U32SubtractionGate { 32 / Self::limb_bits() } - pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize { - debug_assert!(i < NUM_U32_SUBTRACTION_OPS); + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); debug_assert!(j < Self::num_limbs()); - 5 * NUM_U32_SUBTRACTION_OPS + Self::num_limbs() * i + j + 5 * self.num_ops + Self::num_limbs() * i + j } } @@ -72,16 +78,16 @@ impl, const D: usize> Gate for U32Subtraction fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; - let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; let result_initial = input_x - input_y - input_borrow; let base = F::Extension::from_canonical_u64(1 << 32u64); - let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; constraints.push(output_result - (result_initial + base * output_borrow)); @@ -89,7 +95,7 @@ impl, const D: usize> Gate for U32Subtraction let mut combined_limbs = F::Extension::ZERO; let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::Extension::from_canonical_usize(x)) @@ -109,16 +115,16 @@ impl, const D: usize> Gate for U32Subtraction fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; - let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; let result_initial = input_x - input_y - input_borrow; let base = F::from_canonical_u64(1 << 32u64); - let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; constraints.push(output_result - (result_initial + base * output_borrow)); @@ -126,7 +132,7 @@ impl, const D: usize> Gate for U32Subtraction let mut combined_limbs = F::ZERO; let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let product = (0..max_limb) .map(|x| this_limb - F::from_canonical_usize(x)) @@ -150,17 +156,17 @@ impl, const D: usize> Gate for U32Subtraction vars: EvaluationTargets, ) -> Vec> { let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = vars.local_wires[Self::wire_ith_input_x(i)]; - let input_y = vars.local_wires[Self::wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[Self::wire_ith_input_borrow(i)]; + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; let diff = builder.sub_extension(input_x, input_y); let result_initial = builder.sub_extension(diff, input_borrow); let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); - let output_result = vars.local_wires[Self::wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[Self::wire_ith_output_borrow(i)]; + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); constraints.push(builder.sub_extension(output_result, computed_output)); @@ -170,7 +176,7 @@ impl, const D: usize> Gate for U32Subtraction let limb_base = builder .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)]; + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; let max_limb = 1 << Self::limb_bits(); let mut product = builder.one_extension(); for x in 0..max_limb { @@ -199,10 +205,11 @@ impl, const D: usize> Gate for U32Subtraction gate_index: usize, _local_constants: &[F], ) -> Vec>> { - (0..NUM_U32_SUBTRACTION_OPS) + (0..self.num_ops) .map(|i| { let g: Box> = Box::new( U32SubtractionGenerator { + gate: *self, gate_index, i, _phantom: PhantomData, @@ -215,7 +222,7 @@ impl, const D: usize> Gate for U32Subtraction } fn num_wires(&self) -> usize { - NUM_U32_SUBTRACTION_OPS * (5 + Self::num_limbs()) + self.num_ops * (5 + Self::num_limbs()) } fn num_constants(&self) -> usize { @@ -227,12 +234,13 @@ impl, const D: usize> Gate for U32Subtraction } fn num_constraints(&self) -> usize { - NUM_U32_SUBTRACTION_OPS * (3 + Self::num_limbs()) + self.num_ops * (3 + Self::num_limbs()) } } #[derive(Clone, Debug)] struct U32SubtractionGenerator, const D: usize> { + gate: U32SubtractionGate, gate_index: usize, i: usize, _phantom: PhantomData, @@ -245,9 +253,9 @@ impl, const D: usize> SimpleGenerator let local_target = |input| Target::wire(self.gate_index, input); vec![ - local_target(U32SubtractionGate::::wire_ith_input_x(self.i)), - local_target(U32SubtractionGate::::wire_ith_input_y(self.i)), - local_target(U32SubtractionGate::::wire_ith_input_borrow(self.i)), + local_target(self.gate.wire_ith_input_x(self.i)), + local_target(self.gate.wire_ith_input_y(self.i)), + local_target(self.gate.wire_ith_input_borrow(self.i)), ] } @@ -259,10 +267,9 @@ impl, const D: usize> SimpleGenerator let get_local_wire = |input| witness.get_wire(local_wire(input)); - let input_x = get_local_wire(U32SubtractionGate::::wire_ith_input_x(self.i)); - let input_y = get_local_wire(U32SubtractionGate::::wire_ith_input_y(self.i)); - let input_borrow = - get_local_wire(U32SubtractionGate::::wire_ith_input_borrow(self.i)); + let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i)); + let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i)); + let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i)); let result_initial = input_x - input_y - input_borrow; let result_initial_u64 = result_initial.to_canonical_u64(); @@ -275,10 +282,8 @@ impl, const D: usize> SimpleGenerator let base = F::from_canonical_u64(1 << 32u64); let output_result = result_initial + base * output_borrow; - let output_result_wire = - local_wire(U32SubtractionGate::::wire_ith_output_result(self.i)); - let output_borrow_wire = - local_wire(U32SubtractionGate::::wire_ith_output_borrow(self.i)); + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i)); out_buffer.set_wire(output_result_wire, output_result); out_buffer.set_wire(output_borrow_wire, output_borrow); @@ -296,9 +301,7 @@ impl, const D: usize> SimpleGenerator .collect(); for j in 0..num_limbs { - let wire = local_wire(U32SubtractionGate::::wire_ith_output_jth_limb( - self.i, j, - )); + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); out_buffer.set_wire(wire, output_limbs[j]); } } @@ -316,13 +319,14 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField; use crate::gates::gate::Gate; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; + use crate::gates::subtraction_u32::U32SubtractionGate; use crate::hash::hash_types::HashOut; use crate::plonk::vars::EvaluationVars; #[test] fn low_degree() { test_low_degree::(U32SubtractionGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -330,6 +334,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { test_eval_fns::(U32SubtractionGate:: { + num_ops: 3, _phantom: PhantomData, }) } @@ -339,6 +344,7 @@ mod tests { type F = GoldilocksField; type FF = QuarticExtension; const D: usize = 4; + const NUM_U32_SUBTRACTION_OPS: usize = 3; fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { let mut v0 = Vec::new(); @@ -399,6 +405,7 @@ mod tests { .collect(); let gate = U32SubtractionGate:: { + num_ops: NUM_U32_SUBTRACTION_OPS, _phantom: PhantomData, }; diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 7799bc38..8b8ce1ff 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -24,7 +24,7 @@ use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; use crate::gates::random_access::RandomAccessGate; -use crate::gates::subtraction_u32::{U32SubtractionGate, NUM_U32_SUBTRACTION_OPS}; +use crate::gates::subtraction_u32::U32SubtractionGate; use crate::gates::switch::SwitchGate; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hashing::hash_n_to_hash; @@ -984,14 +984,14 @@ impl, const D: usize> CircuitBuilder { pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { None => { - let gate = U32SubtractionGate::new(); + let gate = U32SubtractionGate::new_from_config(&self.config); let gate_index = self.add_gate(gate, vec![]); (gate_index, 0) } Some((gate_index, copy)) => (gate_index, copy), }; - if copy == NUM_U32_SUBTRACTION_OPS - 1 { + if copy == U32SubtractionGate::::num_ops(&self.config) - 1 { self.batched_gates.current_u32_subtraction_gate = None; } else { self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); @@ -1124,21 +1124,12 @@ impl, const D: usize> CircuitBuilder { /// Fill the remaining unused U32 subtraction operations with zeros, so that all /// `U32SubtractionGenerator`s are run. fn fill_u32_subtraction_gates(&mut self) { - let zero = self.zero(); - if let Some((gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { - for i in copy..NUM_U32_SUBTRACTION_OPS { - let wire_input_x = - Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_x(i)); - let wire_input_y = - Target::wire(gate_index, U32SubtractionGate::::wire_ith_input_y(i)); - let wire_input_borrow = Target::wire( - gate_index, - U32SubtractionGate::::wire_ith_input_borrow(i), - ); - - self.connect(zero, wire_input_x); - self.connect(zero, wire_input_y); - self.connect(zero, wire_input_borrow); + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { + for _i in copy..U32SubtractionGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.sub_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); } } }