Have add_gate take a generic type instead of GateRef (#125)

* Have add_gate take a generic type instead of GateRef

There are a couple advantages
- Users writing their own gates won't need to know about the `GateRef` wrapper; it's more of an internal thing now.
- Easier access to gate methods requiring `self` -- for example, `split_le_base` can just call `gate_type.limbs()` now.

* Update comment

* Always insert
This commit is contained in:
Daniel Lubarov 2021-07-22 23:48:03 -07:00 committed by GitHub
parent d435720d04
commit bcf524bed0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 77 additions and 115 deletions

View File

@ -14,7 +14,7 @@ use crate::field::cosets::get_unique_coset_shifts;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::gates::constant::ConstantGate;
use crate::gates::gate::{GateInstance, GateRef, PrefixedGate};
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
use crate::gates::gate_tree::Tree;
use crate::gates::noop::NoopGate;
use crate::gates::public_input::PublicInputGate;
@ -124,41 +124,35 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.collect()
}
pub fn add_gate_no_constants(&mut self, gate_type: GateRef<F, D>) -> usize {
self.add_gate(gate_type, Vec::new())
}
/// Adds a gate to the circuit, and returns its index.
pub fn add_gate(&mut self, gate_type: GateRef<F, D>, constants: Vec<F>) -> usize {
pub fn add_gate<G: Gate<F, D>>(&mut self, gate_type: G, constants: Vec<F>) -> usize {
self.check_gate_compatibility(&gate_type);
assert_eq!(
gate_type.0.num_constants(),
gate_type.num_constants(),
constants.len(),
"Number of constants doesn't match."
);
// If we haven't seen a gate of this type before, check that it's compatible with our
// circuit configuration, then register it.
if !self.gates.contains(&gate_type) {
self.check_gate_compatibility(&gate_type);
self.gates.insert(gate_type.clone());
}
let index = self.gate_instances.len();
self.add_generators(gate_type.generators(index, &constants));
self.add_generators(gate_type.0.generators(index, &constants));
// Register this gate type if we haven't seen it before.
let gate_ref = GateRef::new(gate_type);
self.gates.insert(gate_ref.clone());
self.gate_instances.push(GateInstance {
gate_type,
gate_ref,
constants,
});
index
}
fn check_gate_compatibility(&self, gate: &GateRef<F, D>) {
fn check_gate_compatibility<G: Gate<F, D>>(&self, gate: &G) {
assert!(
gate.0.num_wires() <= self.config.num_wires,
gate.num_wires() <= self.config.num_wires,
"{:?} requires {} wires, but our GateConfig has only {}",
gate.0.id(),
gate.0.num_wires(),
gate.id(),
gate.num_wires(),
self.config.num_wires
);
}
@ -287,7 +281,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return target;
}
let gate = self.add_gate(ConstantGate::get(), vec![c]);
let gate = self.add_gate(ConstantGate, vec![c]);
let target = Target::Wire(Wire {
gate,
input: ConstantGate::WIRE_OUTPUT,
@ -377,7 +371,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
while !self.gate_instances.len().is_power_of_two() {
self.add_gate_no_constants(NoopGate::get());
self.add_gate(NoopGate, vec![]);
}
}
@ -394,7 +388,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// For each "regular" blinding factor, we simply add a no-op gate, and insert a random value
// for each wire.
for _ in 0..regular_poly_openings {
let gate = self.add_gate_no_constants(NoopGate::get());
let gate = self.add_gate(NoopGate, vec![]);
for w in 0..num_wires {
self.add_generator(RandomValueGenerator {
target: Target::Wire(Wire { gate, input: w }),
@ -406,8 +400,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// enforce a copy constraint between them.
// See https://mirprotocol.org/blog/Adding-zero-knowledge-to-Plonk-Halo
for _ in 0..z_openings {
let gate_1 = self.add_gate_no_constants(NoopGate::get());
let gate_2 = self.add_gate_no_constants(NoopGate::get());
let gate_1 = self.add_gate(NoopGate, vec![]);
let gate_2 = self.add_gate(NoopGate, vec![]);
for w in 0..num_routed_wires {
self.add_generator(RandomValueGenerator {
@ -441,7 +435,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.map(|gate| {
let prefix = &gates
.iter()
.find(|g| g.gate.0.id() == gate.gate_type.0.id())
.find(|g| g.gate.0.id() == gate.gate_ref.0.id())
.unwrap()
.prefix;
let mut prefixed_constants = Vec::with_capacity(num_constants);
@ -498,7 +492,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// Hash the public inputs, and route them to a `PublicInputGate` which will enforce that
// those hash wires match the claimed public inputs.
let public_inputs_hash = self.hash_n_to_hash(self.public_inputs.clone(), true);
let pi_gate = self.add_gate_no_constants(PublicInputGate::get());
let pi_gate = self.add_gate(PublicInputGate, vec![]);
for (&hash_part, wire) in public_inputs_hash
.elements
.iter()

View File

@ -24,7 +24,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
second_multiplicand_1: ExtensionTarget<D>,
second_addend: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]);
let gate = self.add_gate(ArithmeticExtensionGate, vec![const_0, const_1]);
let wire_first_multiplicand_0 = ExtensionTarget::from_range(
gate,

View File

@ -11,8 +11,8 @@ use crate::wire::Wire;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn permute(&mut self, inputs: [Target; 12]) -> [Target; 12] {
let zero = self.zero();
let gate =
self.add_gate_no_constants(GMiMCGate::<F, D, GMIMC_ROUNDS>::with_automatic_constants());
let gate_type = GMiMCGate::<F, D, GMIMC_ROUNDS>::new_automatic_constants();
let gate = self.add_gate(gate_type, vec![]);
// We don't want to swap any inputs, so set that wire to 0.
let swap_wire = GMiMCGate::<F, D, GMIMC_ROUNDS>::WIRE_SWAP;

View File

@ -15,11 +15,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
element: ExtensionTarget<D>,
v: Vec<ExtensionTarget<D>>,
) -> Vec<ExtensionTarget<D>> {
let gate = InsertionGate::<F, D> {
vec_size: v.len(),
_phantom: PhantomData,
};
let gate_index = self.add_gate_no_constants(InsertionGate::new(v.len()));
let gate = InsertionGate::new(v.len());
let gate_index = self.add_gate(gate.clone(), vec![]);
v.iter().enumerate().for_each(|(i, &val)| {
self.route_extension(

View File

@ -1,5 +1,3 @@
use std::marker::PhantomData;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
@ -33,12 +31,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
interpolation_points: &[(Target, ExtensionTarget<D>)],
evaluation_point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let gate = InterpolationGate::<F, D> {
num_points: interpolation_points.len(),
_phantom: PhantomData,
};
let gate_index =
self.add_gate_no_constants(InterpolationGate::new(interpolation_points.len()));
let gate = InterpolationGate::new(interpolation_points.len());
let gate_index = self.add_gate(gate.clone(), vec![]);
for (i, &(p, v)) in interpolation_points.iter().enumerate() {
self.route(p, Target::wire(gate_index, gate.wire_point(i)));
self.route_extension(

View File

@ -16,14 +16,12 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
x: Target,
num_limbs: usize,
) -> Vec<Target> {
let gate = self.add_gate(BaseSumGate::<B>::new(num_limbs), vec![]);
let gate_type = BaseSumGate::<B>::new(num_limbs);
let gate = self.add_gate(gate_type.clone(), vec![]);
let sum = Target::wire(gate, BaseSumGate::<B>::WIRE_SUM);
self.route(x, sum);
Target::wires_from_range(
gate,
BaseSumGate::<B>::START_LIMBS..BaseSumGate::<B>::START_LIMBS + num_limbs,
)
Target::wires_from_range(gate, gate_type.limbs())
}
/// Asserts that `x`'s big-endian bit representation has at least `leading_zeros` leading zeros.

View File

@ -35,7 +35,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let bits_per_gate = self.config.num_routed_wires - BaseSumGate::<2>::START_LIMBS;
let k = ceil_div_usize(num_bits, bits_per_gate);
let gates = (0..k)
.map(|_| self.add_gate_no_constants(BaseSumGate::<2>::new(bits_per_gate)))
.map(|_| self.add_gate(BaseSumGate::<2>::new(bits_per_gate), vec![]))
.collect::<Vec<_>>();
let mut bits = Vec::with_capacity(num_bits);

View File

@ -2,8 +2,9 @@ use std::ops::Range;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::gates::gate::{Gate, GateRef};
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::gates::gate::Gate;
use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::target::Target;
use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
@ -14,10 +15,6 @@ use crate::witness::PartialWitness;
pub struct ArithmeticExtensionGate<const D: usize>;
impl<const D: usize> ArithmeticExtensionGate<D> {
pub fn new<F: Extendable<D>>() -> GateRef<F, D> {
GateRef::new(ArithmeticExtensionGate)
}
pub fn wires_first_multiplicand_0() -> Range<usize> {
0..D
}
@ -259,6 +256,6 @@ mod tests {
#[test]
fn low_degree() {
test_low_degree(ArithmeticExtensionGate::<4>::new::<CrandallField>())
test_low_degree::<CrandallField, _, 4>(ArithmeticExtensionGate)
}
}

View File

@ -13,14 +13,14 @@ use crate::witness::PartialWitness;
/// A gate which can decompose a number into base B little-endian limbs,
/// and compute the limb-reversed (i.e. big-endian) sum.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct BaseSumGate<const B: usize> {
num_limbs: usize,
}
impl<const B: usize> BaseSumGate<B> {
pub fn new<F: Extendable<D>, const D: usize>(num_limbs: usize) -> GateRef<F, D> {
GateRef::new(BaseSumGate::<B> { num_limbs })
pub fn new(num_limbs: usize) -> Self {
Self { num_limbs }
}
pub const WIRE_SUM: usize = 0;
@ -186,10 +186,11 @@ impl<F: Field, const B: usize> SimpleGenerator<F> for BaseSplitGenerator<B> {
mod tests {
use crate::field::crandall_field::CrandallField;
use crate::gates::base_sum::BaseSumGate;
use crate::gates::gate::GateRef;
use crate::gates::gate_testing::test_low_degree;
#[test]
fn low_degree() {
test_low_degree(BaseSumGate::<6>::new::<CrandallField, 4>(11))
test_low_degree::<CrandallField, _, 4>(BaseSumGate::<6>::new(11))
}
}

View File

@ -13,10 +13,6 @@ use crate::witness::PartialWitness;
pub struct ConstantGate;
impl ConstantGate {
pub fn get<F: Extendable<D>, const D: usize>() -> GateRef<F, D> {
GateRef::new(ConstantGate)
}
pub const CONST_INPUT: usize = 0;
pub const WIRE_OUTPUT: usize = 0;
@ -106,6 +102,6 @@ mod tests {
#[test]
fn low_degree() {
test_low_degree(ConstantGate::get::<CrandallField, 4>())
test_low_degree::<CrandallField, _, 4>(ConstantGate)
}
}

View File

@ -139,7 +139,7 @@ impl<F: Extendable<D>, const D: usize> Debug for GateRef<F, D> {
/// A gate along with any constants used to configure it.
pub struct GateInstance<F: Extendable<D>, const D: usize> {
pub gate_type: GateRef<F, D>,
pub gate_ref: GateRef<F, D>,
pub constants: Vec<F>,
}

View File

@ -1,6 +1,6 @@
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::gates::gate::GateRef;
use crate::gates::gate::{Gate, GateRef};
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::proof::Hash;
use crate::util::{log2_ceil, transpose};
@ -11,8 +11,7 @@ const WITNESS_DEGREE: usize = WITNESS_SIZE - 1;
/// Tests that the constraints imposed by the given gate are low-degree by applying them to random
/// low-degree witness polynomials.
pub(crate) fn test_low_degree<F: Extendable<D>, const D: usize>(gate: GateRef<F, D>) {
let gate = gate.0;
pub(crate) fn test_low_degree<F: Extendable<D>, G: Gate<F, D>, const D: usize>(gate: G) {
let rate_bits = log2_ceil(gate.degree() + 1);
let wire_ldes = random_low_degree_matrix::<F::Extension>(gate.num_wires(), rate_bits);

View File

@ -237,12 +237,12 @@ mod tests {
const D: usize = 4;
let gates = vec![
NoopGate::get::<F, D>(),
ConstantGate::get(),
ArithmeticExtensionGate::new(),
BaseSumGate::<4>::new(4),
GMiMCGate::<F, D, GMIMC_ROUNDS>::with_automatic_constants(),
InterpolationGate::new(4),
GateRef::new(NoopGate),
GateRef::new(ConstantGate),
GateRef::new(ArithmeticExtensionGate),
GateRef::new(BaseSumGate::<4>::new(4)),
GateRef::new(GMiMCGate::<F, D, GMIMC_ROUNDS>::new_automatic_constants()),
GateRef::new(InterpolationGate::new(4)),
];
let len = gates.len();

View File

@ -4,7 +4,7 @@ use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::gates::gate::{Gate, GateRef};
use crate::gates::gate::Gate;
use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::gmimc::gmimc_automatic_constants;
use crate::target::Target;
@ -28,14 +28,13 @@ pub struct GMiMCGate<F: Extendable<D>, const D: usize, const R: usize> {
}
impl<F: Extendable<D>, const D: usize, const R: usize> GMiMCGate<F, D, R> {
pub fn with_constants(constants: Arc<[F; R]>) -> GateRef<F, D> {
let gate = GMiMCGate::<F, D, R> { constants };
GateRef::new(gate)
pub fn new(constants: Arc<[F; R]>) -> Self {
Self { constants }
}
pub fn with_automatic_constants() -> GateRef<F, D> {
pub fn new_automatic_constants() -> Self {
let constants = Arc::new(gmimc_automatic_constants::<F, R>());
Self::with_constants(constants)
Self::new(constants)
}
/// The wire index for the `i`th input to the permutation.
@ -335,6 +334,7 @@ mod tests {
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::quartic::QuarticCrandallField;
use crate::field::field::Field;
use crate::gates::gate::{Gate, GateRef};
use crate::gates::gate_testing::test_low_degree;
use crate::gates::gmimc::{GMiMCGate, W};
use crate::generator::generate_partial_witness;
@ -351,7 +351,7 @@ mod tests {
const R: usize = 101;
let constants = Arc::new([F::TWO; R]);
type Gate = GMiMCGate<F, 4, R>;
let gate = Gate::with_constants(constants.clone());
let gate = Gate::new(constants.clone());
let permutation_inputs = (0..W).map(F::from_canonical_usize).collect::<Vec<_>>();
@ -373,7 +373,7 @@ mod tests {
);
}
let generators = gate.0.generators(0, &[]);
let generators = gate.generators(0, &[]);
generate_partial_witness(&mut witness, &generators);
let expected_outputs: [F; W] =
@ -393,8 +393,7 @@ mod tests {
type F = CrandallField;
const R: usize = 101;
let constants = Arc::new([F::TWO; R]);
type Gate = GMiMCGate<F, 4, R>;
let gate = Gate::with_constants(constants);
let gate = GMiMCGate::<F, 4, R>::new(constants);
test_low_degree(gate)
}
@ -408,7 +407,7 @@ mod tests {
let mut pw = PartialWitness::<F>::new();
let constants = Arc::new([F::TWO; R]);
type Gate = GMiMCGate<F, 4, R>;
let gate = Gate::with_constants(constants);
let gate = Gate::new(constants);
let wires = FF::rand_vec(Gate::end());
let public_inputs_hash = &Hash::rand();
@ -418,7 +417,7 @@ mod tests {
public_inputs_hash,
};
let ev = gate.0.eval_unfiltered(vars);
let ev = gate.eval_unfiltered(vars);
let wires_t = builder.add_virtual_extension_targets(Gate::end());
for i in 0..Gate::end() {
@ -434,7 +433,7 @@ mod tests {
public_inputs_hash: &public_inputs_hash_t,
};
let ev_t = gate.0.eval_unfiltered_recursively(&mut builder, vars_t);
let ev_t = gate.eval_unfiltered_recursively(&mut builder, vars_t);
assert_eq!(ev.len(), ev_t.len());
for (e, e_t) in ev.into_iter().zip(ev_t) {

View File

@ -17,16 +17,15 @@ use crate::witness::PartialWitness;
#[derive(Clone, Debug)]
pub(crate) struct InsertionGate<F: Extendable<D>, const D: usize> {
pub vec_size: usize,
pub _phantom: PhantomData<F>,
_phantom: PhantomData<F>,
}
impl<F: Extendable<D>, const D: usize> InsertionGate<F, D> {
pub fn new(vec_size: usize) -> GateRef<F, D> {
let gate = Self {
pub fn new(vec_size: usize) -> Self {
Self {
vec_size,
_phantom: PhantomData,
};
GateRef::new(gate)
}
}
pub fn wires_insertion_index(&self) -> usize {
@ -350,8 +349,7 @@ mod tests {
#[test]
fn low_degree() {
type F = CrandallField;
test_low_degree(InsertionGate::<F, 4>::new(4));
test_low_degree::<CrandallField, _, 4>(InsertionGate::new(4));
}
#[test]

View File

@ -24,16 +24,15 @@ use crate::witness::PartialWitness;
#[derive(Clone, Debug)]
pub(crate) struct InterpolationGate<F: Extendable<D>, const D: usize> {
pub num_points: usize,
pub _phantom: PhantomData<F>,
_phantom: PhantomData<F>,
}
impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
pub fn new(num_points: usize) -> GateRef<F, D> {
let gate = Self {
pub fn new(num_points: usize) -> Self {
Self {
num_points,
_phantom: PhantomData,
};
GateRef::new(gate)
}
}
fn start_points(&self) -> usize {
@ -321,7 +320,7 @@ mod tests {
#[test]
fn low_degree() {
type F = CrandallField;
test_low_degree(InterpolationGate::<F, 4>::new(4));
test_low_degree::<CrandallField, _, 4>(InterpolationGate::new(4));
}
#[test]

View File

@ -8,12 +8,6 @@ use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which does nothing.
pub struct NoopGate;
impl NoopGate {
pub fn get<F: Extendable<D>, const D: usize>() -> GateRef<F, D> {
GateRef::new(NoopGate)
}
}
impl<F: Extendable<D>, const D: usize> Gate<F, D> for NoopGate {
fn id(&self) -> String {
"NoopGate".into()
@ -68,6 +62,6 @@ mod tests {
#[test]
fn low_degree() {
test_low_degree(NoopGate::get::<CrandallField, 4>())
test_low_degree::<CrandallField, _, 4>(NoopGate)
}
}

View File

@ -11,10 +11,6 @@ use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
pub struct PublicInputGate;
impl PublicInputGate {
pub fn get<F: Extendable<D>, const D: usize>() -> GateRef<F, D> {
GateRef::new(PublicInputGate)
}
pub fn wires_public_inputs_hash() -> Range<usize> {
0..4
}
@ -86,6 +82,6 @@ mod tests {
#[test]
fn low_degree() {
test_low_degree(PublicInputGate::get::<CrandallField, 4>())
test_low_degree::<CrandallField, _, 4>(PublicInputGate)
}
}

View File

@ -72,8 +72,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut state: HashTarget = self.hash_or_noop(leaf_data);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let gate = self
.add_gate_no_constants(GMiMCGate::<F, D, GMIMC_ROUNDS>::with_automatic_constants());
let gate_type = GMiMCGate::<F, D, GMIMC_ROUNDS>::new_automatic_constants();
let gate = self.add_gate(gate_type, vec![]);
let swap_wire = GMiMCGate::<F, D, GMIMC_ROUNDS>::WIRE_SWAP;
let swap_wire = Target::Wire(Wire {

View File

@ -170,7 +170,7 @@ impl<F: Field> PartialWitness<F> {
"wire {} of gate #{} (`{}`)",
input,
gate,
gate_instances[*gate].gate_type.0.id()
gate_instances[*gate].gate_ref.0.id()
),
Target::VirtualTarget { index } => format!("{}-th virtual target", index),
}