diff --git a/nomos-da/kzgrs/src/rs.rs b/nomos-da/kzgrs/src/rs.rs
index 507ff727..902f04dd 100644
--- a/nomos-da/kzgrs/src/rs.rs
+++ b/nomos-da/kzgrs/src/rs.rs
@@ -1,8 +1,15 @@
 use ark_bls12_381::Fr;
-use ark_ff::{BigInteger, PrimeField};
+use ark_ff::{BigInteger, Field, PrimeField};
 use ark_poly::univariate::DensePolynomial;
-use ark_poly::{EvaluationDomain, Evaluations, GeneralEvaluationDomain, Polynomial};
+use ark_poly::{
+    DenseUVPolynomial, EvaluationDomain, Evaluations, GeneralEvaluationDomain, Polynomial,
+};
+use num_traits::Zero;
+use std::ops::{Mul, Neg};
 
+/// Extend a polynomial by some `factor`, returning `polynomial.len() * factor` evaluations:
+/// the original points plus the extra ones.
+/// `factor` must be `> 1`.
 pub fn encode(
     polynomial: &DensePolynomial<Fr>,
     evaluations: &Evaluations<Fr>,
@@ -18,14 +25,26 @@ pub fn encode(
     )
 }
 
+/// Interpolate points into a polynomial, then evaluate the polynomial at the original evaluation
+/// points to recover the original data.
+/// `domain` must be the same domain as the original `evaluations` and `polynomial` used for encoding.
 pub fn decode(
     original_chunks_len: usize,
-    points: &[Fr],
+    points: &[Option<Fr>],
     domain: &GeneralEvaluationDomain<Fr>,
 ) -> Evaluations<Fr> {
-    let evals = Evaluations::<Fr>::from_vec_and_domain(points.to_vec(), *domain);
-    let coeffs = evals.interpolate();
-
+    let (points, roots_of_unity): (Vec<Fr>, Vec<Fr>) = points
+        .iter()
+        .enumerate()
+        .flat_map(|(i, e)| {
+            if let Some(e) = e {
+                Some((*e, domain.element(i)))
+            } else {
+                None
+            }
+        })
+        .unzip();
+    let coeffs = lagrange_interpolate(&points, &roots_of_unity);
     Evaluations::from_vec_and_domain(
         (0..original_chunks_len)
             .map(|i| coeffs.evaluate(&domain.element(i)))
@@ -34,6 +53,35 @@ pub fn decode(
     )
 }
 
+/// Interpolate a set of points using Lagrange interpolation and roots of unity.
+/// Warning! Be aware that the mapping between points and roots of unity must be preserved:
+/// the polynomial `f(x)` is derived so that each root `w_i` maps to its point `p_i`, `[(w_1, p_1)..(w_n, p_n)]`.
+/// Even if points are missing, it is important to keep that mapping intact.
+pub fn lagrange_interpolate(points: &[Fr], roots_of_unity: &[Fr]) -> DensePolynomial<Fr> {
+    assert_eq!(points.len(), roots_of_unity.len());
+    let mut result = DensePolynomial::from_coefficients_vec(vec![Fr::zero()]);
+    for i in 0..roots_of_unity.len() {
+        let mut summand = DensePolynomial::from_coefficients_vec(vec![points[i]]);
+        for j in 0..points.len() {
+            if i != j {
+                let weight_adjustment =
+                    (roots_of_unity[i] - roots_of_unity[j])
+                        .inverse()
+                        .expect(
+                            "Roots of unity should not be repeated. If this panics, the evaluation domain does not have enough elements",
+                        );
+                summand = summand.naive_mul(&DensePolynomial::from_coefficients_vec(vec![
+                    weight_adjustment.mul(roots_of_unity[j]).neg(),
+                    weight_adjustment,
+                ]))
+            }
+        }
+        result = result + summand;
+    }
+    result
+}
+
+/// Reconstruct the original bytes from a set of evaluation points, truncating each point to the original chunk size.
 pub fn points_to_bytes<const CHUNK_SIZE: usize>(points: &[Fr]) -> Vec<u8> {
     fn point_to_buff<const CHUNK_SIZE: usize>(p: &Fr) -> impl Iterator<Item = u8> {
         p.into_bigint().to_bytes_le().into_iter().take(CHUNK_SIZE)
@@ -49,14 +97,12 @@ pub fn points_to_bytes<const CHUNK_SIZE: usize>(points: &[Fr]) -> Vec<u8> {
 mod test {
     use crate::common::bytes_to_polynomial;
     use crate::rs::{decode, encode, points_to_bytes};
-    use ark_bls12_381::{Bls12_381, Fr};
-    use ark_poly::univariate::DensePolynomial;
+    use ark_bls12_381::Fr;
     use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
-    use ark_poly_commit::kzg10::{UniversalParams, KZG10};
     use once_cell::sync::Lazy;
     use rand::{thread_rng, Fill};
 
-    const COEFFICIENTS_SIZE: usize = 16;
+    const COEFFICIENTS_SIZE: usize = 32;
     static DOMAIN: Lazy<GeneralEvaluationDomain<Fr>> =
         Lazy::new(|| GeneralEvaluationDomain::new(COEFFICIENTS_SIZE).unwrap());
 
@@ -68,9 +114,20 @@ mod test {
 
         let (evals, poly) = bytes_to_polynomial::<31>(&bytes, *DOMAIN).unwrap();
 
-        let mut encoded = encode(&poly, &evals, 2, &DOMAIN);
+        let encoded = encode(&poly, &evals, 2, &DOMAIN);
+        let mut encoded: Vec<Option<Fr>> = encoded.evals.into_iter().map(Some).collect();
+
+        let decoded = decode(10, &encoded, &DOMAIN);
+        let decoded_bytes = points_to_bytes::<31>(&decoded.evals);
+        assert_eq!(decoded_bytes, bytes);
+
+        // check with missing pieces
+
+        for i in (1..encoded.len()).step_by(2) {
+            encoded[i] = None;
+        }
 
-        let decoded = decode(10, &encoded.evals, &DOMAIN);
+        let decoded = decode(10, &encoded, &DOMAIN);
         let decoded_bytes = points_to_bytes::<31>(&decoded.evals);
         assert_eq!(decoded_bytes, bytes);
     }
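
A minimal usage sketch of the new `lagrange_interpolate` (illustrative only, not part of the patch): it recovers a dropped evaluation as long as each kept point stays paired with its own root of unity, which is the "mapping integrity" the doc comment warns about. The degree-1 polynomial and the `kept` indices are made up for the example; the calls mirror the ones used in the patch.

    use ark_bls12_381::Fr;
    use ark_poly::univariate::DensePolynomial;
    use ark_poly::{DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain, Polynomial};

    // f(x) = 1 + 2x, evaluated over a domain of 4 roots of unity.
    let domain = GeneralEvaluationDomain::<Fr>::new(4).unwrap();
    let f = DensePolynomial::from_coefficients_vec(vec![Fr::from(1u64), Fr::from(2u64)]);

    // Pretend evaluations 1 and 3 were lost; keep the remaining points paired with their roots.
    let kept = [0usize, 2];
    let points: Vec<Fr> = kept.iter().map(|&i| f.evaluate(&domain.element(i))).collect();
    let roots: Vec<Fr> = kept.iter().map(|&i| domain.element(i)).collect();

    // Two points are enough to pin down a degree-1 polynomial, so the lost evaluation comes back.
    let recovered = lagrange_interpolate(&points, &roots);
    assert_eq!(recovered.evaluate(&domain.element(1)), f.evaluate(&domain.element(1)));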