Optimize lookup builder (#1258)

* Add tests with big LUTs

* Optimize lookup builder

* Fix comment describing optimization

* Cargo fmt

* Clone LookupTableGate instead of instantiating

* Remove needless enumerate + improving comments
This commit is contained in:
nicholas-mainardi 2023-09-29 15:57:56 +02:00 committed by GitHub
parent 1ff6d4a283
commit 300059572b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 151 additions and 10 deletions

View File

@ -89,25 +89,52 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let lookups = self.get_lut_lookups(lut_index).to_owned();
for (looking_in, looking_out) in lookups {
let gate = LookupGate::new_from_table(&self.config, lut.clone());
let gate = LookupGate::new_from_table(&self.config, lut.clone());
let num_slots = LookupGate::num_slots(&self.config);
// Given the number of lookups and the number of slots for each gate, it is possible
// to compute the number of gates that will employ all their slots; such gates can
// can be instantiated with `add_gate` rather than being instantiated slot by slot
// lookup_iter will iterate over the lookups that can be placed in fully utilized
// gates, splitting them in chunks that can be placed in the same `LookupGate`
let lookup_iter = lookups.chunks_exact(num_slots);
// `last_chunk` will contain the remainder of lookups, which cannot fill all the
// slots of a `LookupGate`; this last chunk will be processed by incrementally
// filling slots, to avoid that the `LookupGenerator` is run on unused slots
let last_chunk = lookup_iter.remainder();
// handle chunks that can fill all the slots of a `LookupGate`
lookup_iter.for_each(|chunk| {
let row = self.add_gate(gate.clone(), vec![]);
for (i, (looking_in, looking_out)) in chunk.iter().enumerate() {
let gate_in = Target::wire(row, LookupGate::wire_ith_looking_inp(i));
let gate_out = Target::wire(row, LookupGate::wire_ith_looking_out(i));
self.connect(gate_in, *looking_in);
self.connect(gate_out, *looking_out);
}
});
// deal with the last chunk
for (looking_in, looking_out) in last_chunk.iter() {
let (gate, i) =
self.find_slot(gate, &[F::from_canonical_usize(lut_index)], &[]);
self.find_slot(gate.clone(), &[F::from_canonical_usize(lut_index)], &[]);
let gate_in = Target::wire(gate, LookupGate::wire_ith_looking_inp(i));
let gate_out = Target::wire(gate, LookupGate::wire_ith_looking_out(i));
self.connect(gate_in, looking_in);
self.connect(gate_out, looking_out);
self.connect(gate_in, *looking_in);
self.connect(gate_out, *looking_out);
}
// Create LUT gates. Nothing is connected to them.
let last_lut_gate = self.num_gates();
let num_lut_entries = LookupTableGate::num_slots(&self.config);
let num_lut_rows = (self.get_luts_idx_length(lut_index) - 1) / num_lut_entries + 1;
let num_lut_cells = num_lut_entries * num_lut_rows;
for _ in 0..num_lut_cells {
let gate =
LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate);
self.find_slot(gate, &[], &[]);
let gate =
LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate);
// Also instances of `LookupTableGate` can be placed with the `add_gate` function
// rather than being instantiated slot by slot; note that in this case there is no
// need to separately handle the last chunk of LUT entries that cannot fill all the
// slots of a `LookupTableGate`, as the generator already handles empty slots
for _ in 0..num_lut_rows {
self.add_gate(gate.clone(), vec![]);
}
let first_lut_gate = self.num_gates() - 1;

View File

@ -467,6 +467,120 @@ pub fn test_same_luts() -> anyhow::Result<()> {
Ok(())
}
#[test]
fn test_big_lut() -> anyhow::Result<()> {
use crate::field::types::Field;
use crate::iop::witness::{PartialWitness, WitnessWrite};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
const LUT_SIZE: usize = u16::MAX as usize + 1;
let inputs: [u16; LUT_SIZE] = core::array::from_fn(|i| i as u16);
let lut_fn = |inp: u16| inp / 10;
let lut_index = builder.add_lookup_table_from_fn(lut_fn, &inputs);
let initial_a = builder.add_virtual_target();
let initial_b = builder.add_virtual_target();
let look_val_a = 51;
let look_val_b = 2;
let output_a = builder.add_lookup_from_index(initial_a, lut_index);
let output_b = builder.add_lookup_from_index(initial_b, lut_index);
builder.register_public_input(output_a);
builder.register_public_input(output_b);
let data = builder.build::<C>();
let mut pw = PartialWitness::new();
pw.set_target(initial_a, F::from_canonical_u16(look_val_a));
pw.set_target(initial_b, F::from_canonical_u16(look_val_b));
let proof = data.prove(pw)?;
assert_eq!(
proof.public_inputs[0],
F::from_canonical_u16(lut_fn(look_val_a))
);
assert_eq!(
proof.public_inputs[1],
F::from_canonical_u16(lut_fn(look_val_b))
);
data.verify(proof)
}
#[test]
fn test_many_lookups_on_big_lut() -> anyhow::Result<()> {
use crate::field::types::Field;
use crate::iop::witness::{PartialWitness, WitnessWrite};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
LOGGER_INITIALIZED.call_once(|| init_logger().unwrap());
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
const LUT_SIZE: usize = u16::MAX as usize + 1;
let inputs: [u16; LUT_SIZE] = core::array::from_fn(|i| i as u16);
let lut_fn = |inp: u16| inp / 10;
let lut_index = builder.add_lookup_table_from_fn(lut_fn, &inputs);
let inputs = (0..LUT_SIZE)
.map(|_| {
let input_target = builder.add_virtual_target();
_ = builder.add_lookup_from_index(input_target, lut_index);
input_target
})
.collect::<Vec<_>>();
let initial_a = builder.add_virtual_target();
let initial_b = builder.add_virtual_target();
let look_val_a = 51;
let look_val_b = 2;
let output_a = builder.add_lookup_from_index(initial_a, lut_index);
let output_b = builder.add_lookup_from_index(initial_b, lut_index);
let sum = builder.add(output_a, output_b);
builder.register_public_input(sum);
let data = builder.build::<C>();
let mut pw = PartialWitness::new();
inputs
.into_iter()
.enumerate()
.for_each(|(i, t)| pw.set_target(t, F::from_canonical_usize(i)));
pw.set_target(initial_a, F::from_canonical_u16(look_val_a));
pw.set_target(initial_b, F::from_canonical_u16(look_val_b));
let proof = data.prove(pw)?;
assert_eq!(
proof.public_inputs[0],
F::from_canonical_u16(lut_fn(look_val_a) + lut_fn(look_val_b))
);
data.verify(proof)
}
fn init_logger() -> anyhow::Result<()> {
let mut builder = env_logger::Builder::from_default_env();
builder.format_timestamp(None);