mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-07 00:03:10 +00:00
Working ReducingFactorTarget
This commit is contained in:
parent
beadce72fc
commit
8a119f035d
@ -7,6 +7,7 @@ use crate::field::field::Field;
|
||||
use crate::gates::mul_extension::ArithmeticExtensionGate;
|
||||
use crate::generator::SimpleGenerator;
|
||||
use crate::target::Target;
|
||||
use crate::util::bits_u64;
|
||||
use crate::wire::Wire;
|
||||
use crate::witness::PartialWitness;
|
||||
|
||||
@ -22,6 +23,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
self.mul(x, x)
|
||||
}
|
||||
|
||||
/// Computes `x^2`.
|
||||
pub fn square_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
|
||||
self.mul_extension(x, x)
|
||||
}
|
||||
|
||||
/// Computes `x^3`.
|
||||
pub fn cube(&mut self, x: Target) -> Target {
|
||||
self.mul_many(&[x, x, x])
|
||||
@ -161,21 +167,58 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
|
||||
// TODO: Optimize this, maybe with a new gate.
|
||||
// TODO: Test
|
||||
/// Exponentiate `base` to the power of `exponent`, where `exponent < 2^num_bits`.
|
||||
pub fn exp(&mut self, base: Target, exponent: Target, num_bits: usize) -> Target {
|
||||
let mut current = base;
|
||||
let one = self.one();
|
||||
let mut product = one;
|
||||
let one_ext = self.one_extension();
|
||||
let mut product = self.one();
|
||||
let exponent_bits = self.split_le(exponent, num_bits);
|
||||
|
||||
for bit in exponent_bits.into_iter() {
|
||||
product = self.mul_many(&[bit, current, product]);
|
||||
let current_ext = self.convert_to_ext(current);
|
||||
let multiplicand = self.select(bit, current_ext, one_ext);
|
||||
product = self.mul(product, multiplicand.0[0]);
|
||||
current = self.mul(current, current);
|
||||
}
|
||||
|
||||
product
|
||||
}
|
||||
|
||||
/// Exponentiate `base` to the power of a known `exponent`.
|
||||
// TODO: Test
|
||||
pub fn exp_u64(&mut self, base: Target, exponent: u64) -> Target {
|
||||
let mut current = base;
|
||||
let mut product = self.one();
|
||||
|
||||
for j in 0..bits_u64(exponent as u64) {
|
||||
if (exponent >> j & 1) != 0 {
|
||||
product = self.mul(product, current);
|
||||
}
|
||||
current = self.square(current);
|
||||
}
|
||||
product
|
||||
}
|
||||
|
||||
/// Exponentiate `base` to the power of a known `exponent`.
|
||||
// TODO: Test
|
||||
pub fn exp_u64_extension(
|
||||
&mut self,
|
||||
base: ExtensionTarget<D>,
|
||||
exponent: u64,
|
||||
) -> ExtensionTarget<D> {
|
||||
let mut current = base;
|
||||
let mut product = self.one_extension();
|
||||
|
||||
for j in 0..bits_u64(exponent as u64) {
|
||||
if (exponent >> j & 1) != 0 {
|
||||
product = self.mul_extension(product, current);
|
||||
}
|
||||
current = self.square_extension(current);
|
||||
}
|
||||
product
|
||||
}
|
||||
|
||||
/// Computes `q = x / y` by witnessing `q` and requiring that `q * y = x`. This can be unsafe in
|
||||
/// some cases, as it allows `0 / 0 = <anything>`.
|
||||
pub fn div_unsafe(&mut self, x: Target, y: Target) -> Target {
|
||||
|
||||
@ -80,112 +80,252 @@ impl<F: Field> ReducingFactor<F> {
|
||||
}
|
||||
}
|
||||
|
||||
// #[derive(Debug, Copy, Clone)]
|
||||
// pub struct ReducingFactorTarget<const D: usize> {
|
||||
// base: ExtensionTarget<D>,
|
||||
// count: u64,
|
||||
// }
|
||||
//
|
||||
// impl<F: Extendable<D>, const D: usize> ReducingFactorTarget<D> {
|
||||
// pub fn new(base: ExtensionTarget<D>) -> Self {
|
||||
// Self { base, count: 0 }
|
||||
// }
|
||||
//
|
||||
// fn mul(
|
||||
// &mut self,
|
||||
// x: ExtensionTarget<D>,
|
||||
// builder: &mut CircuitBuilder<F, D>,
|
||||
// ) -> ExtensionTarget<D> {
|
||||
// self.count += 1;
|
||||
// builder.mul_extension(self.base, x)
|
||||
// }
|
||||
//
|
||||
// pub fn reduce(
|
||||
// &mut self,
|
||||
// iter: &[ExtensionTarget<D>], // Could probably work with a `DoubleEndedIterator` too.
|
||||
// builder: &mut CircuitBuilder<F, D>,
|
||||
// ) -> ExtensionTarget<D> {
|
||||
// let l = iter.len();
|
||||
// let padded_iter = if l % 2 == 0 {
|
||||
// iter.to_vec()
|
||||
// } else {
|
||||
// [iter, &[builder.zero_extension()]].concat()
|
||||
// };
|
||||
// let half_length = padded_iter.len() / 2;
|
||||
// let gates = (0..half_length)
|
||||
// .map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE]))
|
||||
// .collect::<Vec<_>>();
|
||||
//
|
||||
// struct ParallelReductionGenerator<'a, const D: usize> {
|
||||
// base: ExtensionTarget<D>,
|
||||
// padded_iter: &'a [ExtensionTarget<D>],
|
||||
// gates: &'a [usize],
|
||||
// half_length: usize,
|
||||
// }
|
||||
//
|
||||
// impl<'a, F: Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
// for ParallelReductionGenerator<'a, D>
|
||||
// {
|
||||
// fn dependencies(&self) -> Vec<Target> {
|
||||
// self.padded_iter
|
||||
// .iter()
|
||||
// .flat_map(|ext| ext.to_target_array())
|
||||
// .chain(self.base.to_target_array())
|
||||
// .collect()
|
||||
// }
|
||||
//
|
||||
// fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
|
||||
// let mut pw = PartialWitness::new();
|
||||
// let base = witness.get_extension_target(self.base);
|
||||
// let vs = self
|
||||
// .padded_iter
|
||||
// .iter()
|
||||
// .map(|&ext| witness.get_extension_target(ext))
|
||||
// .collect::<Vec<_>>();
|
||||
// let first_half = &vs[..self.half_length];
|
||||
// let intermediate_acc = base.reduce(first_half);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// pub fn reduce_parallel(
|
||||
// &mut self,
|
||||
// iter0: impl DoubleEndedIterator<Item = impl Borrow<ExtensionTarget<D>>>,
|
||||
// iter1: impl DoubleEndedIterator<Item = impl Borrow<ExtensionTarget<D>>>,
|
||||
// builder: &mut CircuitBuilder<F, D>,
|
||||
// ) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
||||
// iter.rev().fold(builder.zero_extension(), |acc, x| {
|
||||
// builder.arithmetic_extension(F::ONE, F::ONE, self.base, acc, x)
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// pub fn shift(
|
||||
// &mut self,
|
||||
// x: ExtensionTarget<D>,
|
||||
// builder: &mut CircuitBuilder<F, D>,
|
||||
// ) -> ExtensionTarget<D> {
|
||||
// let tmp = self.base.exp(self.count) * x;
|
||||
// self.count = 0;
|
||||
// tmp
|
||||
// }
|
||||
//
|
||||
// pub fn shift_poly(
|
||||
// &mut self,
|
||||
// p: &mut PolynomialCoeffs<ExtensionTarget<D>>,
|
||||
// builder: &mut CircuitBuilder<F, D>,
|
||||
// ) {
|
||||
// *p *= self.base.exp(self.count);
|
||||
// self.count = 0;
|
||||
// }
|
||||
//
|
||||
// pub fn reset(&mut self) {
|
||||
// self.count = 0;
|
||||
// }
|
||||
//
|
||||
// pub fn repeated_frobenius(&self, count: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
|
||||
// Self {
|
||||
// base: self.base.repeated_frobenius(count),
|
||||
// count: self.count,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ReducingFactorTarget<const D: usize> {
|
||||
base: ExtensionTarget<D>,
|
||||
count: u64,
|
||||
}
|
||||
|
||||
impl<const D: usize> ReducingFactorTarget<D> {
|
||||
pub fn new(base: ExtensionTarget<D>) -> Self {
|
||||
Self { base, count: 0 }
|
||||
}
|
||||
|
||||
pub fn reduce<F>(
|
||||
&mut self,
|
||||
iter: &[ExtensionTarget<D>], // Could probably work with a `DoubleEndedIterator` too.
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
) -> ExtensionTarget<D>
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
let l = iter.len();
|
||||
self.count += l as u64;
|
||||
let padded_iter = if l % 2 == 0 {
|
||||
iter.to_vec()
|
||||
} else {
|
||||
[iter, &[builder.zero_extension()]].concat()
|
||||
};
|
||||
let half_length = padded_iter.len() / 2;
|
||||
let gates = (0..half_length)
|
||||
.map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE]))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
builder.add_generator(ParallelReductionGenerator {
|
||||
base: self.base,
|
||||
padded_iter: padded_iter.clone(),
|
||||
gates: gates.clone(),
|
||||
half_length,
|
||||
});
|
||||
|
||||
for i in 0..half_length {
|
||||
builder.route_extension(
|
||||
ExtensionTarget::from_range(
|
||||
gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_addend_0(),
|
||||
),
|
||||
padded_iter[2 * half_length - i - 1],
|
||||
);
|
||||
}
|
||||
for i in 0..half_length {
|
||||
builder.route_extension(
|
||||
ExtensionTarget::from_range(
|
||||
gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_addend_1(),
|
||||
),
|
||||
padded_iter[half_length - i - 1],
|
||||
);
|
||||
}
|
||||
for gate_pair in gates[..half_length].windows(2) {
|
||||
builder.assert_equal_extension(
|
||||
ExtensionTarget::from_range(
|
||||
gate_pair[0],
|
||||
ArithmeticExtensionGate::<D>::wires_output_0(),
|
||||
),
|
||||
ExtensionTarget::from_range(
|
||||
gate_pair[1],
|
||||
ArithmeticExtensionGate::<D>::wires_multiplicand_0(),
|
||||
),
|
||||
);
|
||||
}
|
||||
for gate_pair in gates[half_length..].windows(2) {
|
||||
builder.assert_equal_extension(
|
||||
ExtensionTarget::from_range(
|
||||
gate_pair[0],
|
||||
ArithmeticExtensionGate::<D>::wires_output_1(),
|
||||
),
|
||||
ExtensionTarget::from_range(
|
||||
gate_pair[1],
|
||||
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
|
||||
),
|
||||
);
|
||||
}
|
||||
builder.assert_equal_extension(
|
||||
ExtensionTarget::from_range(
|
||||
gates[half_length - 1],
|
||||
ArithmeticExtensionGate::<D>::wires_output_0(),
|
||||
),
|
||||
ExtensionTarget::from_range(
|
||||
gates[0],
|
||||
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
|
||||
),
|
||||
);
|
||||
|
||||
ExtensionTarget::from_range(
|
||||
gates[half_length - 1],
|
||||
ArithmeticExtensionGate::<D>::wires_output_1(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn shift<F>(
|
||||
&mut self,
|
||||
x: ExtensionTarget<D>,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
) -> ExtensionTarget<D>
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
let exp = builder.exp_u64_extension(self.base, self.count);
|
||||
let tmp = builder.mul_extension(exp, x);
|
||||
self.count = 0;
|
||||
tmp
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.count = 0;
|
||||
}
|
||||
|
||||
pub fn repeated_frobenius<F>(&self, count: usize, builder: &mut CircuitBuilder<F, D>) -> Self
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
Self {
|
||||
base: self.base.repeated_frobenius(count, builder),
|
||||
count: self.count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ParallelReductionGenerator<const D: usize> {
|
||||
base: ExtensionTarget<D>,
|
||||
padded_iter: Vec<ExtensionTarget<D>>,
|
||||
gates: Vec<usize>,
|
||||
half_length: usize,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionGenerator<D> {
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.padded_iter
|
||||
.iter()
|
||||
.flat_map(|ext| ext.to_target_array())
|
||||
.chain(self.base.to_target_array())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
|
||||
let mut pw = PartialWitness::new();
|
||||
let base = witness.get_extension_target(self.base);
|
||||
let vs = self
|
||||
.padded_iter
|
||||
.iter()
|
||||
.map(|&ext| witness.get_extension_target(ext))
|
||||
.collect::<Vec<_>>();
|
||||
let intermediate_accs = vs
|
||||
.iter()
|
||||
.rev()
|
||||
.scan(F::Extension::ZERO, |acc, &x| {
|
||||
let tmp = *acc;
|
||||
*acc = *acc * base + x;
|
||||
Some(tmp)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
for i in 0..self.half_length {
|
||||
pw.set_extension_target(
|
||||
ExtensionTarget::from_range(
|
||||
self.gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
|
||||
),
|
||||
base,
|
||||
);
|
||||
pw.set_extension_target(
|
||||
ExtensionTarget::from_range(
|
||||
self.gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_multiplicand_0(),
|
||||
),
|
||||
intermediate_accs[i],
|
||||
);
|
||||
pw.set_extension_target(
|
||||
ExtensionTarget::from_range(
|
||||
self.gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_addend_0(),
|
||||
),
|
||||
vs[2 * self.half_length - i - 1],
|
||||
);
|
||||
pw.set_extension_target(
|
||||
ExtensionTarget::from_range(
|
||||
self.gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
|
||||
),
|
||||
intermediate_accs[self.half_length + i],
|
||||
);
|
||||
pw.set_extension_target(
|
||||
ExtensionTarget::from_range(
|
||||
self.gates[i],
|
||||
ArithmeticExtensionGate::<D>::wires_addend_1(),
|
||||
),
|
||||
vs[self.half_length - i - 1],
|
||||
);
|
||||
}
|
||||
|
||||
pw
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::circuit_data::CircuitConfig;
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::field::extension_field::quartic::QuarticCrandallField;
|
||||
|
||||
fn test_reduce_gadget(n: usize) {
|
||||
type F = CrandallField;
|
||||
type FF = QuarticCrandallField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::large_config();
|
||||
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let alpha = FF::rand();
|
||||
let alpha = FF::ONE;
|
||||
let vs = (0..n)
|
||||
.map(|i| FF::from_canonical_usize(i))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let manual_reduce = ReducingFactor::new(alpha).reduce(vs.iter());
|
||||
let manual_reduce = builder.constant_extension(manual_reduce);
|
||||
|
||||
let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha));
|
||||
let vs_t = vs
|
||||
.iter()
|
||||
.map(|&v| builder.constant_extension(v))
|
||||
.collect::<Vec<_>>();
|
||||
let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder);
|
||||
|
||||
builder.assert_equal_extension(manual_reduce, circuit_reduce);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(PartialWitness::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_gadget_even() {
|
||||
test_reduce_gadget(10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_gadget_odd() {
|
||||
test_reduce_gadget(11);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user