mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-08 16:53:07 +00:00
msm (outside circuit)
This commit is contained in:
parent
051b79db2c
commit
d6630869e1
263
src/curve/curve_msm.rs
Normal file
263
src/curve/curve_msm.rs
Normal file
@ -0,0 +1,263 @@
|
||||
use itertools::Itertools;
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::curve::curve_summation::affine_multisummation_best;
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would
|
||||
/// be easiest to assign individual summations to threads, but this would be sub-optimal because
|
||||
/// multi-summations can be more efficient than repeating individual summations (see
|
||||
/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of
|
||||
/// digits to threads. Note that there is a delicate balance here, as large chunks can result in
|
||||
/// uneven distributions of work among threads.
|
||||
const DIGITS_PER_CHUNK: usize = 80;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MsmPrecomputation<C: Curve> {
|
||||
/// For each generator (in the order they were passed to `msm_precompute`), contains a vector
|
||||
/// of powers, i.e. [(2^w)^i] for i < DIGITS.
|
||||
// TODO: Use compressed coordinates here.
|
||||
powers_per_generator: Vec<Vec<AffinePoint<C>>>,
|
||||
|
||||
/// The window size.
|
||||
w: usize,
|
||||
}
|
||||
|
||||
pub fn msm_precompute<C: Curve>(
|
||||
generators: &[ProjectivePoint<C>],
|
||||
w: usize,
|
||||
) -> MsmPrecomputation<C> {
|
||||
MsmPrecomputation {
|
||||
powers_per_generator: generators
|
||||
.into_par_iter()
|
||||
.map(|&g| precompute_single_generator(g, w))
|
||||
.collect(),
|
||||
w,
|
||||
}
|
||||
}
|
||||
|
||||
fn precompute_single_generator<C: Curve>(g: ProjectivePoint<C>, w: usize) -> Vec<AffinePoint<C>> {
|
||||
let digits = (C::ScalarField::BITS + w - 1) / w;
|
||||
let mut powers: Vec<ProjectivePoint<C>> = Vec::with_capacity(digits);
|
||||
powers.push(g);
|
||||
for i in 1..digits {
|
||||
let mut power_i_proj = powers[i - 1];
|
||||
for _j in 0..w {
|
||||
power_i_proj = power_i_proj.double();
|
||||
}
|
||||
powers.push(power_i_proj);
|
||||
}
|
||||
ProjectivePoint::batch_to_affine(&powers)
|
||||
}
|
||||
|
||||
pub fn msm_parallel<C: Curve>(
|
||||
scalars: &[C::ScalarField],
|
||||
generators: &[ProjectivePoint<C>],
|
||||
w: usize,
|
||||
) -> ProjectivePoint<C> {
|
||||
let precomputation = msm_precompute(generators, w);
|
||||
msm_execute_parallel(&precomputation, scalars)
|
||||
}
|
||||
|
||||
pub fn msm_execute<C: Curve>(
|
||||
precomputation: &MsmPrecomputation<C>,
|
||||
scalars: &[C::ScalarField],
|
||||
) -> ProjectivePoint<C> {
|
||||
assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
|
||||
let w = precomputation.w;
|
||||
let digits = (C::ScalarField::BITS + w - 1) / w;
|
||||
let base = 1 << w;
|
||||
|
||||
// This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
|
||||
// extremely large windows, the repeated scans in Yao's method could be more expensive than the
|
||||
// actual group operations. To avoid this, we store a multimap from each possible digit to the
|
||||
// positions in which that digit occurs in the scalars. These positions have the form (i, j),
|
||||
// where i is the index of the generator and j is an index into the digits of the scalar
|
||||
// associated with that generator.
|
||||
let mut digit_occurrences: Vec<Vec<(usize, usize)>> = Vec::with_capacity(digits);
|
||||
for _i in 0..base {
|
||||
digit_occurrences.push(Vec::new());
|
||||
}
|
||||
for (i, scalar) in scalars.iter().enumerate() {
|
||||
let digits = to_digits::<C>(scalar, w);
|
||||
for (j, &digit) in digits.iter().enumerate() {
|
||||
digit_occurrences[digit].push((i, j));
|
||||
}
|
||||
}
|
||||
|
||||
let mut y = ProjectivePoint::ZERO;
|
||||
let mut u = ProjectivePoint::ZERO;
|
||||
|
||||
for digit in (1..base).rev() {
|
||||
for &(i, j) in &digit_occurrences[digit] {
|
||||
u = u + precomputation.powers_per_generator[i][j];
|
||||
}
|
||||
y = y + u;
|
||||
}
|
||||
|
||||
y
|
||||
}
|
||||
|
||||
pub fn msm_execute_parallel<C: Curve>(
|
||||
precomputation: &MsmPrecomputation<C>,
|
||||
scalars: &[C::ScalarField],
|
||||
) -> ProjectivePoint<C> {
|
||||
assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
|
||||
let w = precomputation.w;
|
||||
let digits = (C::ScalarField::BITS + w - 1) / w;
|
||||
let base = 1 << w;
|
||||
|
||||
// This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
|
||||
// extremely large windows, the repeated scans in Yao's method could be more expensive than the
|
||||
// actual group operations. To avoid this, we store a multimap from each possible digit to the
|
||||
// positions in which that digit occurs in the scalars. These positions have the form (i, j),
|
||||
// where i is the index of the generator and j is an index into the digits of the scalar
|
||||
// associated with that generator.
|
||||
let mut digit_occurrences: Vec<Vec<(usize, usize)>> = Vec::with_capacity(digits);
|
||||
for _i in 0..base {
|
||||
digit_occurrences.push(Vec::new());
|
||||
}
|
||||
for (i, scalar) in scalars.iter().enumerate() {
|
||||
let digits = to_digits::<C>(scalar, w);
|
||||
for (j, &digit) in digits.iter().enumerate() {
|
||||
digit_occurrences[digit].push((i, j));
|
||||
}
|
||||
}
|
||||
|
||||
// For each digit, we add up the powers associated with all occurrences that digit.
|
||||
let digits: Vec<usize> = (0..base).collect();
|
||||
let digit_acc: Vec<ProjectivePoint<C>> = digits
|
||||
.par_chunks(DIGITS_PER_CHUNK)
|
||||
.flat_map(|chunk| {
|
||||
let summations: Vec<Vec<AffinePoint<C>>> = chunk
|
||||
.iter()
|
||||
.map(|&digit| {
|
||||
digit_occurrences[digit]
|
||||
.iter()
|
||||
.map(|&(i, j)| precomputation.powers_per_generator[i][j])
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
affine_multisummation_best(summations)
|
||||
})
|
||||
.collect();
|
||||
// println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64());
|
||||
|
||||
let mut y = ProjectivePoint::ZERO;
|
||||
let mut u = ProjectivePoint::ZERO;
|
||||
for digit in (1..base).rev() {
|
||||
u = u + digit_acc[digit];
|
||||
y = y + u;
|
||||
}
|
||||
// println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64());
|
||||
y
|
||||
}
|
||||
|
||||
pub(crate) fn to_digits<C: Curve>(x: &C::ScalarField, w: usize) -> Vec<usize> {
|
||||
let scalar_bits = C::ScalarField::BITS;
|
||||
let num_digits = (scalar_bits + w - 1) / w;
|
||||
|
||||
// Convert x to a bool array.
|
||||
let x_canonical: Vec<_> = x
|
||||
.to_biguint()
|
||||
.to_u64_digits()
|
||||
.iter()
|
||||
.cloned()
|
||||
.pad_using(scalar_bits / 64, |_| 0)
|
||||
.collect();
|
||||
let mut x_bits = Vec::with_capacity(scalar_bits);
|
||||
for i in 0..scalar_bits {
|
||||
x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0);
|
||||
}
|
||||
|
||||
let mut digits = Vec::with_capacity(num_digits);
|
||||
for i in 0..num_digits {
|
||||
let mut digit = 0;
|
||||
for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() {
|
||||
digit <<= 1;
|
||||
digit |= x_bits[j] as usize;
|
||||
}
|
||||
digits.push(digit);
|
||||
}
|
||||
digits
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num::BigUint;
|
||||
|
||||
use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits};
|
||||
use crate::curve::curve_types::Curve;
|
||||
use crate::curve::secp256k1::Secp256K1;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::secp256k1_scalar::Secp256K1Scalar;
|
||||
|
||||
#[test]
|
||||
fn test_to_digits() {
|
||||
let x_canonical = [
|
||||
0b10101010101010101010101010101010,
|
||||
0b10101010101010101010101010101010,
|
||||
0b11001100110011001100110011001100,
|
||||
0b11001100110011001100110011001100,
|
||||
0b11110000111100001111000011110000,
|
||||
0b11110000111100001111000011110000,
|
||||
0b00001111111111111111111111111111,
|
||||
0b11111111111111111111111111111111,
|
||||
];
|
||||
let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical));
|
||||
assert_eq!(x.to_biguint().to_u32_digits(), x_canonical);
|
||||
assert_eq!(
|
||||
to_digits::<Secp256K1>(&x, 17),
|
||||
vec![
|
||||
0b01010101010101010,
|
||||
0b10101010101010101,
|
||||
0b01010101010101010,
|
||||
0b11001010101010101,
|
||||
0b01100110011001100,
|
||||
0b00110011001100110,
|
||||
0b10011001100110011,
|
||||
0b11110000110011001,
|
||||
0b01111000011110000,
|
||||
0b00111100001111000,
|
||||
0b00011110000111100,
|
||||
0b11111111111111110,
|
||||
0b01111111111111111,
|
||||
0b11111111111111000,
|
||||
0b11111111111111111,
|
||||
0b1,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_msm() {
|
||||
let w = 5;
|
||||
|
||||
let generator_1 = Secp256K1::GENERATOR_PROJECTIVE;
|
||||
let generator_2 = generator_1 + generator_1;
|
||||
let generator_3 = generator_1 + generator_2;
|
||||
|
||||
let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
11111111, 22222222, 33333333, 44444444,
|
||||
]));
|
||||
let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
22222222, 22222222, 33333333, 44444444,
|
||||
]));
|
||||
let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
33333333, 22222222, 33333333, 44444444,
|
||||
]));
|
||||
|
||||
let generators = vec![generator_1, generator_2, generator_3];
|
||||
let scalars = vec![scalar_1, scalar_2, scalar_3];
|
||||
|
||||
let precomputation = msm_precompute(&generators, w);
|
||||
let result_msm = msm_execute(&precomputation, &scalars);
|
||||
|
||||
let result_naive = Secp256K1::convert(scalar_1) * generator_1
|
||||
+ Secp256K1::convert(scalar_2) * generator_2
|
||||
+ Secp256K1::convert(scalar_3) * generator_3;
|
||||
|
||||
assert_eq!(result_msm, result_naive);
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
pub mod curve_adds;
|
||||
pub mod curve_msm;
|
||||
pub mod curve_multiplication;
|
||||
pub mod curve_summation;
|
||||
pub mod curve_types;
|
||||
|
||||
@ -88,7 +88,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
AffinePointTarget { x: x3, y: y3 }
|
||||
}
|
||||
|
||||
pub fn curve_add_two_affine<C: Curve>(
|
||||
pub fn curve_add<C: Curve>(
|
||||
&mut self,
|
||||
p1: &AffinePointTarget<C>,
|
||||
p2: &AffinePointTarget<C>,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user