From 300059572b0dd5bb53c0060bf4dc32f8a5a9bf90 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 29 Sep 2023 15:57:56 +0200 Subject: [PATCH] Optimize lookup builder (#1258) * Add tests with big LUTs * Optimize lookup builder * Fix comment describing optimization * Cargo fmt * Clone LookupTableGate instead of instantiating * Remove needless enumerate + improving comments --- plonky2/src/gadgets/lookup.rs | 47 +++++++++++--- plonky2/src/lookup_test.rs | 114 ++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 10 deletions(-) diff --git a/plonky2/src/gadgets/lookup.rs b/plonky2/src/gadgets/lookup.rs index b71b7808..826f3e29 100644 --- a/plonky2/src/gadgets/lookup.rs +++ b/plonky2/src/gadgets/lookup.rs @@ -89,25 +89,52 @@ impl, const D: usize> CircuitBuilder { let lookups = self.get_lut_lookups(lut_index).to_owned(); - for (looking_in, looking_out) in lookups { - let gate = LookupGate::new_from_table(&self.config, lut.clone()); + let gate = LookupGate::new_from_table(&self.config, lut.clone()); + let num_slots = LookupGate::num_slots(&self.config); + + // Given the number of lookups and the number of slots for each gate, it is possible + // to compute the number of gates that will employ all their slots; such gates can + // can be instantiated with `add_gate` rather than being instantiated slot by slot + + // lookup_iter will iterate over the lookups that can be placed in fully utilized + // gates, splitting them in chunks that can be placed in the same `LookupGate` + let lookup_iter = lookups.chunks_exact(num_slots); + // `last_chunk` will contain the remainder of lookups, which cannot fill all the + // slots of a `LookupGate`; this last chunk will be processed by incrementally + // filling slots, to avoid that the `LookupGenerator` is run on unused slots + let last_chunk = lookup_iter.remainder(); + // handle chunks that can fill all the slots of a `LookupGate` + lookup_iter.for_each(|chunk| { + let row = self.add_gate(gate.clone(), vec![]); + for (i, (looking_in, looking_out)) in chunk.iter().enumerate() { + let gate_in = Target::wire(row, LookupGate::wire_ith_looking_inp(i)); + let gate_out = Target::wire(row, LookupGate::wire_ith_looking_out(i)); + self.connect(gate_in, *looking_in); + self.connect(gate_out, *looking_out); + } + }); + // deal with the last chunk + for (looking_in, looking_out) in last_chunk.iter() { let (gate, i) = - self.find_slot(gate, &[F::from_canonical_usize(lut_index)], &[]); + self.find_slot(gate.clone(), &[F::from_canonical_usize(lut_index)], &[]); let gate_in = Target::wire(gate, LookupGate::wire_ith_looking_inp(i)); let gate_out = Target::wire(gate, LookupGate::wire_ith_looking_out(i)); - self.connect(gate_in, looking_in); - self.connect(gate_out, looking_out); + self.connect(gate_in, *looking_in); + self.connect(gate_out, *looking_out); } // Create LUT gates. Nothing is connected to them. let last_lut_gate = self.num_gates(); let num_lut_entries = LookupTableGate::num_slots(&self.config); let num_lut_rows = (self.get_luts_idx_length(lut_index) - 1) / num_lut_entries + 1; - let num_lut_cells = num_lut_entries * num_lut_rows; - for _ in 0..num_lut_cells { - let gate = - LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate); - self.find_slot(gate, &[], &[]); + let gate = + LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate); + // Also instances of `LookupTableGate` can be placed with the `add_gate` function + // rather than being instantiated slot by slot; note that in this case there is no + // need to separately handle the last chunk of LUT entries that cannot fill all the + // slots of a `LookupTableGate`, as the generator already handles empty slots + for _ in 0..num_lut_rows { + self.add_gate(gate.clone(), vec![]); } let first_lut_gate = self.num_gates() - 1; diff --git a/plonky2/src/lookup_test.rs b/plonky2/src/lookup_test.rs index bca90d59..af85deca 100644 --- a/plonky2/src/lookup_test.rs +++ b/plonky2/src/lookup_test.rs @@ -467,6 +467,120 @@ pub fn test_same_luts() -> anyhow::Result<()> { Ok(()) } +#[test] +fn test_big_lut() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + const LUT_SIZE: usize = u16::MAX as usize + 1; + let inputs: [u16; LUT_SIZE] = core::array::from_fn(|i| i as u16); + let lut_fn = |inp: u16| inp / 10; + let lut_index = builder.add_lookup_table_from_fn(lut_fn, &inputs); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 51; + let look_val_b = 2; + + let output_a = builder.add_lookup_from_index(initial_a, lut_index); + let output_b = builder.add_lookup_from_index(initial_b, lut_index); + + builder.register_public_input(output_a); + builder.register_public_input(output_b); + + let data = builder.build::(); + + let mut pw = PartialWitness::new(); + + pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); + pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + + let proof = data.prove(pw)?; + assert_eq!( + proof.public_inputs[0], + F::from_canonical_u16(lut_fn(look_val_a)) + ); + assert_eq!( + proof.public_inputs[1], + F::from_canonical_u16(lut_fn(look_val_b)) + ); + + data.verify(proof) +} + +#[test] +fn test_many_lookups_on_big_lut() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + const LUT_SIZE: usize = u16::MAX as usize + 1; + let inputs: [u16; LUT_SIZE] = core::array::from_fn(|i| i as u16); + let lut_fn = |inp: u16| inp / 10; + let lut_index = builder.add_lookup_table_from_fn(lut_fn, &inputs); + + let inputs = (0..LUT_SIZE) + .map(|_| { + let input_target = builder.add_virtual_target(); + _ = builder.add_lookup_from_index(input_target, lut_index); + input_target + }) + .collect::>(); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 51; + let look_val_b = 2; + + let output_a = builder.add_lookup_from_index(initial_a, lut_index); + let output_b = builder.add_lookup_from_index(initial_b, lut_index); + let sum = builder.add(output_a, output_b); + + builder.register_public_input(sum); + + let data = builder.build::(); + + let mut pw = PartialWitness::new(); + + inputs + .into_iter() + .enumerate() + .for_each(|(i, t)| pw.set_target(t, F::from_canonical_usize(i))); + pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); + pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + + let proof = data.prove(pw)?; + assert_eq!( + proof.public_inputs[0], + F::from_canonical_u16(lut_fn(look_val_a) + lut_fn(look_val_b)) + ); + + data.verify(proof) +} + fn init_logger() -> anyhow::Result<()> { let mut builder = env_logger::Builder::from_default_env(); builder.format_timestamp(None);