diff --git a/src/bin/bench_field_mul_interleaved.rs b/src/bin/bench_field_mul_interleaved.rs index d04415f1..737105ad 100644 --- a/src/bin/bench_field_mul_interleaved.rs +++ b/src/bin/bench_field_mul_interleaved.rs @@ -14,8 +14,8 @@ const EXPONENT: usize = 1000000000; fn main() { let mut bases = [F::ZERO; WIDTH]; - for i in 0..WIDTH { - bases[i] = F::rand(); + for base_i in bases.iter_mut() { + *base_i = F::rand(); } let mut state = [F::ONE; WIDTH]; diff --git a/src/bin/bench_gmimc.rs b/src/bin/bench_gmimc.rs index f234285d..2f81aac0 100644 --- a/src/bin/bench_gmimc.rs +++ b/src/bin/bench_gmimc.rs @@ -14,8 +14,8 @@ const PROVER_POLYS: usize = 113 + 3 + 4; fn main() { const THREADS: usize = 12; const LDE_BITS: i32 = 3; - const W: usize = 13; - const HASHES_PER_POLY: usize = 1 << (13 + LDE_BITS); + const W: usize = 12; + const HASHES_PER_POLY: usize = 1 << (13 + LDE_BITS) / 6; let threads = (0..THREADS) .map(|_i| { diff --git a/src/bin/bench_ldes.rs b/src/bin/bench_ldes.rs index ecdcd4fb..5620bd80 100644 --- a/src/bin/bench_ldes.rs +++ b/src/bin/bench_ldes.rs @@ -3,9 +3,8 @@ use std::time::Instant; use rayon::prelude::*; use plonky2::field::crandall_field::CrandallField; -use plonky2::field::fft; use plonky2::field::field::Field; -use plonky2::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use plonky2::polynomial::polynomial::PolynomialValues; type F = CrandallField; @@ -18,9 +17,9 @@ fn main() { let start = Instant::now(); (0usize..PROVER_POLYS).into_par_iter().for_each(|i| { - let mut values = vec![CrandallField::ZERO; DEGREE]; + let mut values = vec![F::ZERO; DEGREE]; for j in 0usize..DEGREE { - values[j] = CrandallField((i * j) as u64); + values[j] = F::from_canonical_u64((i * j) as u64); } let poly_values = PolynomialValues::new(values); let start = Instant::now(); diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 81409642..566fb056 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -1,24 +1,14 @@ -use std::thread; -use std::time::Instant; - use env_logger::Env; -use rayon::prelude::*; use plonky2::circuit_builder::CircuitBuilder; use plonky2::circuit_data::CircuitConfig; use plonky2::field::crandall_field::CrandallField; -use plonky2::field::fft; use plonky2::field::field::Field; use plonky2::gates::constant::ConstantGate; use plonky2::gates::gmimc::GMiMCGate; -use plonky2::gmimc::gmimc_permute_array; -use plonky2::hash::{GMIMC_CONSTANTS, GMIMC_ROUNDS}; -use plonky2::polynomial::polynomial::PolynomialCoeffs; +use plonky2::hash::GMIMC_ROUNDS; use plonky2::witness::PartialWitness; -// 113 wire polys, 3 Z polys, 4 parts of quotient poly. -const PROVER_POLYS: usize = 113 + 3 + 4; - fn main() { // Set the default log filter. This can be overridden using the `RUST_LOG` environment variable, // e.g. `RUST_LOG=debug`. diff --git a/src/bin/bench_rescue.rs b/src/bin/bench_rescue.rs new file mode 100644 index 00000000..96334689 --- /dev/null +++ b/src/bin/bench_rescue.rs @@ -0,0 +1,46 @@ +use std::thread; +use std::time::Instant; + +use plonky2::field::crandall_field::CrandallField; +use plonky2::field::field::Field; +use plonky2::rescue::rescue; + +type F = CrandallField; + +// 113 wire polys, 3 Z polys, 4 parts of quotient poly. +const PROVER_POLYS: usize = 113 + 3 + 4; + +fn main() { + const THREADS: usize = 12; + const LDE_BITS: i32 = 3; + const W: usize = 12; + const HASHES_PER_POLY: usize = (1 << (13 + LDE_BITS)) / 6; + + let threads = (0..THREADS) + .map(|_i| { + thread::spawn(move || { + let mut x = [F::ZERO; W]; + for i in 0..W { + x[i] = F::from_canonical_u64((i as u64) * 123456 + 789); + } + + let hashes_per_thread = HASHES_PER_POLY * PROVER_POLYS / THREADS; + let start = Instant::now(); + for _ in 0..hashes_per_thread { + x = rescue(x); + } + let duration = start.elapsed(); + println!("took {:?}", duration); + println!( + "avg {:?}us", + duration.as_secs_f64() * 1e6 / (hashes_per_thread as f64) + ); + println!("result {:?}", x); + }) + }) + .collect::>(); + + for t in threads { + t.join().expect("oops"); + } +} diff --git a/src/bin/field_search.rs b/src/bin/field_search.rs index e0fdb7e6..6c433339 100644 --- a/src/bin/field_search.rs +++ b/src/bin/field_search.rs @@ -53,5 +53,5 @@ fn is_prime(n: u64) -> bool { d += 2; } - return true; + true } diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index a7e87339..d21279d2 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -1,16 +1,20 @@ -use std::collections::{HashSet, HashMap}; +use std::collections::{HashMap, HashSet}; use std::time::Instant; use log::info; -use crate::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierOnlyCircuitData}; +use crate::circuit_data::{ + CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, + VerifierCircuitData, VerifierOnlyCircuitData, +}; +use crate::field::cosets::get_unique_coset_shifts; use crate::field::field::Field; use crate::gates::constant::ConstantGate; use crate::gates::gate::{GateInstance, GateRef}; use crate::gates::noop::NoopGate; use crate::generator::{CopyGenerator, WitnessGenerator}; -use crate::hash::merkle_root_bit_rev_order; -use crate::field::cosets::get_unique_coset_shifts; +use crate::hash::hash_n_to_hash; +use crate::merkle_tree::MerkleTree; use crate::polynomial::polynomial::PolynomialValues; use crate::target::Target; use crate::util::{log2_strict, transpose, transpose_poly_values}; @@ -81,17 +85,23 @@ impl CircuitBuilder { let index = self.gate_instances.len(); - // TODO: Not passing next constants for now. Not sure if it's really useful... - self.add_generators(gate_type.0.generators(index, &constants, &[])); + self.add_generators(gate_type.0.generators(index, &constants)); - self.gate_instances.push(GateInstance { gate_type, constants }); + self.gate_instances.push(GateInstance { + gate_type, + constants, + }); index } fn check_gate_compatibility(&self, gate: &GateRef) { - assert!(gate.0.num_wires() <= self.config.num_wires, - "{:?} requires {} wires, but our GateConfig has only {}", - gate.0.id(), gate.0.num_wires(), self.config.num_wires); + assert!( + gate.0.num_wires() <= self.config.num_wires, + "{:?} requires {} wires, but our GateConfig has only {}", + gate.0.id(), + gate.0.num_wires(), + self.config.num_wires + ); } /// Shorthand for `generate_copy` and `assert_equal`. @@ -109,8 +119,14 @@ impl CircuitBuilder { /// Uses Plonk's permutation argument to require that two elements be equal. /// Both elements must be routable, otherwise this method will panic. pub fn assert_equal(&mut self, x: Target, y: Target) { - assert!(x.is_routable(self.config), "Tried to route a wire that isn't routable"); - assert!(y.is_routable(self.config), "Tried to route a wire that isn't routable"); + assert!( + x.is_routable(self.config), + "Tried to route a wire that isn't routable" + ); + assert!( + y.is_routable(self.config), + "Tried to route a wire that isn't routable" + ); // TODO: Add to copy_constraints. } @@ -150,7 +166,10 @@ impl CircuitBuilder { } let gate = self.add_gate(ConstantGate::get(), vec![c]); - let target = Target::Wire(Wire { gate, input: ConstantGate::WIRE_OUTPUT }); + let target = Target::Wire(Wire { + gate, + input: ConstantGate::WIRE_OUTPUT, + }); self.constants_to_targets.insert(c, target); self.targets_to_constants.insert(target, c); target @@ -175,11 +194,15 @@ impl CircuitBuilder { } fn constant_polys(&self) -> Vec> { - let num_constants = self.gate_instances.iter() + let num_constants = self + .gate_instances + .iter() .map(|gate_inst| gate_inst.constants.len()) .max() .unwrap(); - let constants_per_gate = self.gate_instances.iter() + let constants_per_gate = self + .gate_instances + .iter() .map(|gate_inst| { let mut padded_constants = gate_inst.constants.clone(); for _ in padded_constants.len()..num_constants { @@ -196,13 +219,17 @@ impl CircuitBuilder { } fn sigma_vecs(&self) -> Vec> { - vec![PolynomialValues::zero(self.gate_instances.len()); self.config.num_routed_wires] // TODO + vec![PolynomialValues::zero(self.gate_instances.len()); self.config.num_routed_wires] + // TODO } /// Builds a "full circuit", with both prover and verifier data. pub fn build(mut self) -> CircuitData { let start = Instant::now(); - info!("degree before blinding & padding: {}", self.gate_instances.len()); + info!( + "degree before blinding & padding: {}", + self.gate_instances.len() + ); self.blind_and_pad(); let degree = self.gate_instances.len(); info!("degree after blinding & padding: {}", degree); @@ -210,23 +237,34 @@ impl CircuitBuilder { let constant_vecs = self.constant_polys(); let constant_ldes = PolynomialValues::lde_multiple(constant_vecs, self.config.rate_bits); let constant_ldes_t = transpose_poly_values(constant_ldes); - let constants_root = merkle_root_bit_rev_order(constant_ldes_t.clone()); + let constants_tree = MerkleTree::new(constant_ldes_t, true); let sigma_vecs = self.sigma_vecs(); let sigma_ldes = PolynomialValues::lde_multiple(sigma_vecs, self.config.rate_bits); let sigma_ldes_t = transpose_poly_values(sigma_ldes); - let sigmas_root = merkle_root_bit_rev_order(sigma_ldes_t.clone()); + let sigmas_tree = MerkleTree::new(sigma_ldes_t, true); + + let constants_root = constants_tree.root; + let sigmas_root = sigmas_tree.root; + let verifier_only = VerifierOnlyCircuitData { + constants_root, + sigmas_root, + }; let generators = self.generators; - let prover_only = ProverOnlyCircuitData { generators, constant_ldes_t, sigma_ldes_t }; - let verifier_only = VerifierOnlyCircuitData {}; + let prover_only = ProverOnlyCircuitData { + generators, + constants_tree, + sigmas_tree, + }; // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we // sort by ID to make the ordering deterministic. let mut gates = self.gates.iter().cloned().collect::>(); gates.sort_unstable_by_key(|gate| gate.0.id()); - let num_gate_constraints = gates.iter() + let num_gate_constraints = gates + .iter() .map(|gate| gate.0.num_constraints()) .max() .expect("No gates?"); @@ -234,14 +272,17 @@ impl CircuitBuilder { let degree_bits = log2_strict(degree); let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); + // TODO: This should also include an encoding of gate constraints. + let circuit_digest_parts = [constants_root.elements, sigmas_root.elements]; + let circuit_digest = hash_n_to_hash(circuit_digest_parts.concat(), false); + let common = CommonCircuitData { config: self.config, degree_bits, gates, num_gate_constraints, - constants_root, - sigmas_root, k_is, + circuit_digest, }; info!("Building circuit took {}s", start.elapsed().as_secs_f32()); @@ -255,14 +296,28 @@ impl CircuitBuilder { /// Builds a "prover circuit", with data needed to generate proofs but not verify them. pub fn build_prover(self) -> ProverCircuitData { // TODO: Can skip parts of this. - let CircuitData { prover_only, common, .. } = self.build(); - ProverCircuitData { prover_only, common } + let CircuitData { + prover_only, + common, + .. + } = self.build(); + ProverCircuitData { + prover_only, + common, + } } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. pub fn build_verifier(self) -> VerifierCircuitData { // TODO: Can skip parts of this. - let CircuitData { verifier_only, common, .. } = self.build(); - VerifierCircuitData { verifier_only, common } + let CircuitData { + verifier_only, + common, + .. + } = self.build(); + VerifierCircuitData { + verifier_only, + common, + } } } diff --git a/src/circuit_data.rs b/src/circuit_data.rs index ef0a74c6..f75e7275 100644 --- a/src/circuit_data.rs +++ b/src/circuit_data.rs @@ -1,6 +1,7 @@ use crate::field::field::Field; use crate::gates::gate::GateRef; use crate::generator::WitnessGenerator; +use crate::merkle_tree::MerkleTree; use crate::proof::{Hash, HashTarget, Proof}; use crate::prover::prove; use crate::verifier::verify; @@ -37,7 +38,7 @@ impl CircuitConfig { /// Circuit data required by the prover or the verifier. pub struct CircuitData { pub(crate) prover_only: ProverOnlyCircuitData, - pub(crate) verifier_only: VerifierOnlyCircuitData, + pub(crate) verifier_only: VerifierOnlyCircuitData, pub(crate) common: CommonCircuitData, } @@ -71,7 +72,7 @@ impl ProverCircuitData { /// Circuit data required by the prover. pub struct VerifierCircuitData { - pub(crate) verifier_only: VerifierOnlyCircuitData, + pub(crate) verifier_only: VerifierOnlyCircuitData, pub(crate) common: CommonCircuitData, } @@ -84,13 +85,20 @@ impl VerifierCircuitData { /// Circuit data required by the prover, but not the verifier. pub(crate) struct ProverOnlyCircuitData { pub generators: Vec>>, - pub constant_ldes_t: Vec>, - /// Transpose of LDEs of sigma polynomials (in the context of Plonk's permutation argument). - pub sigma_ldes_t: Vec>, + /// Merkle tree containing LDEs of each constant polynomial. + pub constants_tree: MerkleTree, + /// Merkle tree containing LDEs of each sigma polynomial. + pub sigmas_tree: MerkleTree, } /// Circuit data required by the verifier, but not the prover. -pub(crate) struct VerifierOnlyCircuitData {} +pub(crate) struct VerifierOnlyCircuitData { + /// A commitment to each constant polynomial. + pub(crate) constants_root: Hash, + + /// A commitment to each permutation polynomial. + pub(crate) sigmas_root: Hash, +} /// Circuit data required by both the prover and the verifier. pub(crate) struct CommonCircuitData { @@ -104,14 +112,12 @@ pub(crate) struct CommonCircuitData { /// The largest number of constraints imposed by any gate. pub(crate) num_gate_constraints: usize, - /// A commitment to each constant polynomial. - pub(crate) constants_root: Hash, - - /// A commitment to each permutation polynomial. - pub(crate) sigmas_root: Hash, - /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, + + /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to + /// seed Fiat-Shamir. + pub(crate) circuit_digest: Hash, } impl CommonCircuitData { diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index ae5dd89e..657193d9 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -6,6 +6,7 @@ use num::Integer; use crate::field::field::Field; use std::hash::{Hash, Hasher}; +use std::iter::{Product, Sum}; /// EPSILON = 9 * 2**28 - 1 const EPSILON: u64 = 2415919103; @@ -69,6 +70,7 @@ impl Field for CrandallField { *self * *self * *self } + #[allow(clippy::many_single_char_names)] // The names are from the paper. fn try_inverse(&self) -> Option { if self.is_zero() { return None; @@ -144,6 +146,85 @@ impl Field for CrandallField { fn from_canonical_u64(n: u64) -> Self { Self(n) } + + fn cube_root(&self) -> Self { + let x0 = *self; + let x1 = x0.square(); + let x2 = x1.square(); + let x3 = x2 * x0; + let x4 = x3.square(); + let x5 = x4.square(); + // let x6 = x4.square(); + let x7 = x5.square(); + let x8 = x7.square(); + let x9 = x8.square(); + let x10 = x9.square(); + let x11 = x10 * x5; + let x12 = x11.square(); + let x13 = x12.square(); + let x14 = x13.square(); + // let x15 = x13.square(); + let x16 = x14.square(); + let x17 = x16.square(); + let x18 = x17.square(); + let x19 = x18.square(); + let x20 = x19.square(); + let x21 = x20 * x11; + let x22 = x21.square(); + let x23 = x22.square(); + let x24 = x23.square(); + let x25 = x24.square(); + let x26 = x25.square(); + let x27 = x26.square(); + let x28 = x27.square(); + let x29 = x28.square(); + let x30 = x29.square(); + let x31 = x30.square(); + let x32 = x31.square(); + let x33 = x32 * x14; + let x34 = x33 * x3; + let x35 = x34.square(); + let x36 = x35 * x34; + let x37 = x36 * x5; + let x38 = x37 * x34; + let x39 = x38 * x37; + let x40 = x39.square(); + let x41 = x40.square(); + let x42 = x41 * x38; + let x43 = x42.square(); + let x44 = x43.square(); + let x45 = x44.square(); + let x46 = x45.square(); + let x47 = x46.square(); + let x48 = x47.square(); + let x49 = x48.square(); + let x50 = x49.square(); + let x51 = x50.square(); + let x52 = x51.square(); + let x53 = x52.square(); + let x54 = x53.square(); + let x55 = x54.square(); + let x56 = x55.square(); + let x57 = x56.square(); + let x58 = x57.square(); + let x59 = x58.square(); + let x60 = x59.square(); + let x61 = x60.square(); + let x62 = x61.square(); + let x63 = x62.square(); + let x64 = x63.square(); + let x65 = x64.square(); + let x66 = x65.square(); + let x67 = x66.square(); + let x68 = x67.square(); + let x69 = x68.square(); + let x70 = x69.square(); + let x71 = x70.square(); + let x72 = x71.square(); + let x73 = x72.square(); + let x74 = x73 * x39; + x74 + } } impl Neg for CrandallField { @@ -164,6 +245,7 @@ impl Add for CrandallField { type Output = Self; #[inline] + #[allow(clippy::suspicious_arithmetic_impl)] fn add(self, rhs: Self) -> Self { let (sum, over) = self.0.overflowing_add(rhs.0); Self(sum.overflowing_sub((over as u64) * Self::ORDER).0) @@ -176,10 +258,17 @@ impl AddAssign for CrandallField { } } +impl Sum for CrandallField { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |acc, x| acc + x) + } +} + impl Sub for CrandallField { type Output = Self; #[inline] + #[allow(clippy::suspicious_arithmetic_impl)] fn sub(self, rhs: Self) -> Self { let (diff, under) = self.0.overflowing_sub(rhs.0); Self(diff.overflowing_add((under as u64) * Self::ORDER).0) @@ -209,9 +298,16 @@ impl MulAssign for CrandallField { } } +impl Product for CrandallField { + fn product>(iter: I) -> Self { + iter.fold(Self::ONE, |acc, x| acc * x) + } +} + impl Div for CrandallField { type Output = Self; + #[allow(clippy::suspicious_arithmetic_impl)] fn div(self, rhs: Self) -> Self::Output { self * rhs.inverse() } diff --git a/src/field/fft.rs b/src/field/fft.rs index fe3d2117..8bcde967 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -41,10 +41,10 @@ pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { } pub(crate) fn fft_precompute(degree: usize) -> FftPrecomputation { - let degree_pow = log2_ceil(degree); + let degree_log = log2_ceil(degree); let mut subgroups_rev = Vec::new(); - for i in 0..=degree_pow { + for i in 0..=degree_log { let g_i = F::primitive_root_of_unity(i); let subgroup = F::cyclic_subgroup_known_order(g_i, 1 << i); let subgroup_rev = reverse_index_bits(subgroup); @@ -66,8 +66,8 @@ pub(crate) fn ifft_with_precomputation_power_of_2( fft_with_precomputation_power_of_2(PolynomialCoeffs { coeffs: values }, precomputation); // We reverse all values except the first, and divide each by n. - result[0] = result[0] * n_inv; - result[n / 2] = result[n / 2] * n_inv; + result[0] *= n_inv; + result[n / 2] *= n_inv; for i in 1..(n / 2) { let j = n - i; let result_i = result[j] * n_inv; @@ -89,14 +89,14 @@ pub(crate) fn fft_with_precomputation_power_of_2( ); let half_degree = poly.len() >> 1; - let degree_pow = poly.log_len(); + let degree_log = poly.log_len(); // In the base layer, we're just evaluating "degree 0 polynomials", i.e. the coefficients // themselves. let PolynomialCoeffs { coeffs } = poly; let mut evaluations = reverse_index_bits(coeffs); - for i in 1..=degree_pow { + for i in 1..=degree_log { // In layer i, we're evaluating a series of polynomials, each at 2^i points. In practice // we evaluate a pair of points together, so we have 2^(i - 1) pairs. let points_per_poly = 1 << i; @@ -169,7 +169,7 @@ mod tests { for i in 0..degree { coefficients.push(F::from_canonical_usize(i * 1337 % 100)); } - let coefficients = PolynomialCoeffs::pad(coefficients); + let coefficients = PolynomialCoeffs::new_padded(coefficients); let points = fft(coefficients.clone()); assert_eq!(points, evaluate_naive(&coefficients)); @@ -198,9 +198,9 @@ mod tests { coefficients: &PolynomialCoeffs, ) -> PolynomialValues { let degree = coefficients.len(); - let degree_pow = log2_strict(degree); + let degree_log = log2_strict(degree); - let g = F::primitive_root_of_unity(degree_pow); + let g = F::primitive_root_of_unity(degree_log); let powers_of_g = F::cyclic_subgroup_known_order(g, degree); let values = powers_of_g diff --git a/src/field/field.rs b/src/field/field.rs index cfadfbe6..cafc536a 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -1,10 +1,14 @@ -use crate::util::bits_u64; -use rand::rngs::OsRng; -use rand::Rng; use std::fmt::{Debug, Display}; use std::hash::Hash; +use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use num::Integer; +use rand::Rng; +use rand::rngs::OsRng; + +use crate::util::bits_u64; + /// A finite field with prime order less than 2^64. pub trait Field: 'static @@ -14,10 +18,12 @@ pub trait Field: + Neg + Add + AddAssign + + Sum + Sub + SubAssign + Mul + MulAssign + + Product + Div + DivAssign + Debug @@ -42,6 +48,10 @@ pub trait Field: *self == Self::ZERO } + fn is_nonzero(&self) -> bool { + *self != Self::ZERO + } + fn is_one(&self) -> bool { *self == Self::ONE } @@ -90,13 +100,13 @@ pub trait Field: x_inv } - fn primitive_root_of_unity(n_power: usize) -> Self { - assert!(n_power <= Self::TWO_ADICITY); - let base = Self::POWER_OF_TWO_GENERATOR; - // TODO: Just repeated squaring should be a bit faster, to avoid conditionals. - base.exp(Self::from_canonical_u64( - 1u64 << (Self::TWO_ADICITY - n_power), - )) + fn primitive_root_of_unity(n_log: usize) -> Self { + assert!(n_log <= Self::TWO_ADICITY); + let mut base = Self::POWER_OF_TWO_GENERATOR; + for _ in n_log..Self::TWO_ADICITY { + base = base.square(); + } + base } /// Computes a multiplicative subgroup whose order is known in advance. @@ -105,11 +115,26 @@ pub trait Field: let mut current = Self::ONE; for _i in 0..order { subgroup.push(current); - current = current * generator; + current *= generator; } subgroup } + fn cyclic_subgroup_unknown_order(generator: Self) -> Vec { + let mut subgroup = Vec::new(); + for power in generator.powers() { + if power.is_one() && !subgroup.is_empty() { + break; + } + subgroup.push(power); + } + subgroup + } + + fn generator_order(generator: Self) -> usize { + generator.powers().skip(1).position(|y| y.is_one()).unwrap() + 1 + } + /// Computes a coset of a multiplicative subgroup whose order is known in advance. fn cyclic_subgroup_coset_known_order(generator: Self, shift: Self, order: usize) -> Vec { let subgroup = Self::cyclic_subgroup_known_order(generator, order); @@ -120,6 +145,10 @@ pub trait Field: fn from_canonical_u64(n: u64) -> Self; + fn from_canonical_u32(n: u32) -> Self { + Self::from_canonical_u64(n as u64) + } + fn from_canonical_usize(n: usize) -> Self { Self::from_canonical_u64(n as u64) } @@ -134,17 +163,68 @@ pub trait Field: for j in 0..power.bits() { if (power.to_canonical_u64() >> j & 1) != 0 { - product = product * current; + product *= current; } current = current.square(); } product } + fn exp_u32(&self, power: u32) -> Self { + self.exp(Self::from_canonical_u32(power)) + } + fn exp_usize(&self, power: usize) -> Self { self.exp(Self::from_canonical_usize(power)) } + /// Returns whether `x^power` is a permutation of this field. + fn is_monomial_permutation(power: Self) -> bool { + if power.is_zero() { + return false; + } + if power.is_one() { + return true; + } + (Self::ORDER - 1).gcd(&power.to_canonical_u64()) == 1 + } + + fn kth_root(&self, k: Self) -> Self { + let p = Self::ORDER; + let p_minus_1 = p - 1; + debug_assert!( + Self::is_monomial_permutation(k), + "Not a permutation of this field" + ); + let k = k.to_canonical_u64(); + + // By Fermat's little theorem, x^p = x and x^(p - 1) = 1, so x^(p + n(p - 1)) = x for any n. + // Our assumption that the k'th root operation is a permutation implies gcd(p - 1, k) = 1, + // so there exists some n such that p + n(p - 1) is a multiple of k. Once we find such an n, + // we can rewrite the above as + // x^((p + n(p - 1))/k)^k = x, + // implying that x^((p + n(p - 1))/k) is a k'th root of x. + for n in 0..k { + let numerator = p as u128 + n as u128 * p_minus_1 as u128; + if numerator % k as u128 == 0 { + let power = (numerator / k as u128) as u64 % p_minus_1; + return self.exp(Self::from_canonical_u64(power)); + } + } + panic!( + "x^{} and x^(1/{}) are not permutations of this field, or we have a bug!", + k, k + ); + } + + fn kth_root_u32(&self, k: u32) -> Self { + self.kth_root(Self::from_canonical_u32(k)) + } + + fn cube_root(&self) -> Self { + self.kth_root_u32(3) + } + fn powers(&self) -> Powers { Powers { base: *self, diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index d070cb8d..bcb190e8 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -225,35 +225,86 @@ macro_rules! test_arithmetic { ) } - // #[test] - // #[ignore] - // fn arithmetic_division() { - // // This test takes ages to finish so is #[ignore]d by default. - // // TODO: Re-enable and reimplement when - // // https://github.com/rust-num/num-bigint/issues/60 is finally resolved. - // let modulus = <$field>::ORDER; - // crate::field::field_testing::run_binaryop_test_cases( - // modulus, - // WORD_BITS, - // // Need to help the compiler infer the type of y here - // |x: $field, y: $field| { - // // TODO: Work out how to check that div() panics - // // appropriately when given a zero divisor. - // if !y.is_zero() { - // <$field>::div(x, y) - // } else { - // <$field>::ZERO - // } - // }, - // |x, y| { - // // yinv = y^-1 (mod modulus) - // let exp = modulus - 2u64; - // let yinv = y.modpow(exp, modulus); - // // returns 0 if y was 0 - // x * yinv % modulus - // }, - // ) - // } + #[test] + fn inversion() { + let zero = <$field>::ZERO; + let one = <$field>::ONE; + let order = <$field>::ORDER; + + assert_eq!(zero.try_inverse(), None); + + for &x in &[1, 2, 3, order - 3, order - 2, order - 1] { + let x = <$field>::from_canonical_u64(x); + let inv = x.inverse(); + assert_eq!(x * inv, one); + } + } + + #[test] + fn batch_inversion() { + let xs = (1..=3) + .map(|i| <$field>::from_canonical_u64(i)) + .collect::>(); + let invs = <$field>::batch_multiplicative_inverse(&xs); + for (x, inv) in xs.into_iter().zip(invs) { + assert_eq!(x * inv, <$field>::ONE); + } + } + + #[test] + fn primitive_root_order() { + for n_power in 0..8 { + let root = <$field>::primitive_root_of_unity(n_power); + let order = <$field>::generator_order(root); + assert_eq!(order, 1 << n_power, "2^{}'th primitive root", n_power); + } + } + + #[test] + fn negation() { + let zero = <$field>::ZERO; + let order = <$field>::ORDER; + + for &i in &[0, 1, 2, order - 2, order - 1] { + let i_f = <$field>::from_canonical_u64(i); + assert_eq!(i_f + -i_f, zero); + } + } + + #[test] + fn bits() { + assert_eq!(<$field>::ZERO.bits(), 0); + assert_eq!(<$field>::ONE.bits(), 1); + assert_eq!(<$field>::TWO.bits(), 2); + assert_eq!(<$field>::from_canonical_u64(3).bits(), 2); + assert_eq!(<$field>::from_canonical_u64(4).bits(), 3); + assert_eq!(<$field>::from_canonical_u64(5).bits(), 3); + } + + #[test] + fn exponentiation() { + type F = $field; + + assert_eq!(F::ZERO.exp_u32(0), ::ONE); + assert_eq!(F::ONE.exp_u32(0), ::ONE); + assert_eq!(F::TWO.exp_u32(0), ::ONE); + + assert_eq!(F::ZERO.exp_u32(1), ::ZERO); + assert_eq!(F::ONE.exp_u32(1), ::ONE); + assert_eq!(F::TWO.exp_u32(1), ::TWO); + + assert_eq!(F::ZERO.kth_root_u32(1), ::ZERO); + assert_eq!(F::ONE.kth_root_u32(1), ::ONE); + assert_eq!(F::TWO.kth_root_u32(1), ::TWO); + + for power in 1..10 { + let power = F::from_canonical_u32(power); + if F::is_monomial_permutation(power) { + let x = F::rand(); + assert_eq!(x.exp(power).kth_root(power), x); + } + } + } } }; } diff --git a/src/field/lagrange.rs b/src/field/lagrange.rs new file mode 100644 index 00000000..06204b60 --- /dev/null +++ b/src/field/lagrange.rs @@ -0,0 +1,132 @@ +use crate::field::fft::ifft; +use crate::field::field::Field; +use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; +use crate::util::log2_ceil; + +/// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs. +/// +/// Note that the implementation assumes that `F` is two-adic, in particular that +/// `2^{F::TWO_ADICITY} >= points.len()`. This leads to a simple FFT-based implementation. +pub(crate) fn interpolant(points: &[(F, F)]) -> PolynomialCoeffs { + let n = points.len(); + let n_log = log2_ceil(n); + let n_padded = 1 << n_log; + + let g = F::primitive_root_of_unity(n_log); + let subgroup = F::cyclic_subgroup_known_order(g, n_padded); + let barycentric_weights = barycentric_weights(points); + let subgroup_evals = subgroup + .into_iter() + .map(|x| interpolate(points, x, &barycentric_weights)) + .collect(); + + let mut coeffs = ifft(PolynomialValues { + values: subgroup_evals, + }); + coeffs.trim(); + coeffs +} + +/// Interpolate the polynomial defined by an arbitrary set of (point, value) pairs at the given +/// point `x`. +fn interpolate(points: &[(F, F)], x: F, barycentric_weights: &[F]) -> F { + // If x is in the list of points, the Lagrange formula would divide by zero. + for &(x_i, y_i) in points { + if x_i == x { + return y_i; + } + } + + let l_x: F = points.iter().map(|&(x_i, y_i)| x - x_i).product(); + + let sum = (0..points.len()) + .map(|i| { + let x_i = points[i].0; + let y_i = points[i].1; + let w_i = barycentric_weights[i]; + w_i / (x - x_i) * y_i + }) + .sum(); + + l_x * sum +} + +fn barycentric_weights(points: &[(F, F)]) -> Vec { + let n = points.len(); + (0..n) + .map(|i| { + (0..n) + .filter(|&j| j != i) + .map(|j| points[i].0 - points[j].0) + .product::() + .inverse() + }) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::field::crandall_field::CrandallField; + use crate::field::field::Field; + use crate::field::lagrange::interpolant; + use crate::polynomial::polynomial::PolynomialCoeffs; + + #[test] + fn interpolant_random() { + type F = CrandallField; + + for deg in 0..10 { + let domain = (0..deg).map(|_| F::rand()).collect::>(); + let coeffs = (0..deg).map(|_| F::rand()).collect(); + let coeffs = PolynomialCoeffs { coeffs }; + + let points = eval_naive(&coeffs, &domain); + assert_eq!(interpolant(&points), coeffs); + } + } + + #[test] + fn interpolant_random_roots_of_unity() { + type F = CrandallField; + + for deg_log in 0..4 { + let deg = 1 << deg_log; + let g = F::primitive_root_of_unity(deg_log); + let domain = F::cyclic_subgroup_known_order(g, deg); + let coeffs = (0..deg).map(|_| F::rand()).collect(); + let coeffs = PolynomialCoeffs { coeffs }; + + let points = eval_naive(&coeffs, &domain); + assert_eq!(interpolant(&points), coeffs); + } + } + + #[test] + fn interpolant_random_overspecified() { + type F = CrandallField; + + for deg in 0..10 { + let points = deg + 5; + let domain = (0..points).map(|_| F::rand()).collect::>(); + let coeffs = (0..deg).map(|_| F::rand()).collect(); + let coeffs = PolynomialCoeffs { coeffs }; + + let points = eval_naive(&coeffs, &domain); + assert_eq!(interpolant(&points), coeffs); + } + } + + fn eval_naive(coeffs: &PolynomialCoeffs, domain: &[F]) -> Vec<(F, F)> { + domain + .iter() + .map(|&x| { + let eval = x + .powers() + .zip(&coeffs.coeffs) + .map(|(x_power, &coeff)| coeff * x_power) + .sum(); + (x, eval) + }) + .collect() + } +} diff --git a/src/field/mod.rs b/src/field/mod.rs index 9f58ef08..6c2828a6 100644 --- a/src/field/mod.rs +++ b/src/field/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod cosets; pub mod crandall_field; pub mod fft; pub mod field; +pub(crate) mod lagrange; #[cfg(test)] mod field_testing; diff --git a/src/fri.rs b/src/fri.rs index 9813d032..7047fb06 100644 --- a/src/fri.rs +++ b/src/fri.rs @@ -8,8 +8,6 @@ use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::proof::{FriEvaluations, FriMerkleProofs, FriProof, FriQueryRound, Hash}; use crate::util::log2_strict; use anyhow::{ensure, Result}; -use std::intrinsics::rotate_left; -use std::iter::FromIterator; /// Somewhat arbitrary. Smaller values will increase delta, but with diminishing returns, /// while increasing L, potentially requiring more challenge points. @@ -127,13 +125,12 @@ fn fri_proof_of_work(current_hash: Hash, config: &FriConfig) -> F { (0u64..) .find(|&i| { hash_n_to_1( - Vec::from_iter( - current_hash - .elements - .iter() - .copied() - .chain(Some(F::from_canonical_u64(i))), - ), + current_hash + .elements + .iter() + .copied() + .chain(Some(F::from_canonical_u64(i))) + .collect(), false, ) .to_canonical_u64() @@ -150,14 +147,13 @@ fn fri_verify_proof_of_work( config: &FriConfig, ) -> Result<()> { let hash = hash_n_to_1( - Vec::from_iter( - challenger - .get_hash() - .elements - .iter() - .copied() - .chain(Some(proof.pow_witness)), - ), + challenger + .get_hash() + .elements + .iter() + .copied() + .chain(Some(proof.pow_witness)) + .collect(), false, ); ensure!( diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 8fa727bb..8fa5a226 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -1,9 +1,9 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::field::Field; use crate::gates::arithmetic::ArithmeticGate; +use crate::generator::SimpleGenerator; use crate::target::Target; use crate::wire::Wire; -use crate::generator::SimpleGenerator; use crate::witness::PartialWitness; impl CircuitBuilder { @@ -22,8 +22,9 @@ impl CircuitBuilder { addend: Target, ) -> Target { // See if we can determine the result without adding an `ArithmeticGate`. - if let Some(result) = self.arithmetic_special_cases( - const_0, multiplicand_0, multiplicand_1, const_1, addend) { + if let Some(result) = + self.arithmetic_special_cases(const_0, multiplicand_0, multiplicand_1, const_1, addend) + { return result; } @@ -69,7 +70,8 @@ impl CircuitBuilder { let mul_1_const = self.target_as_constant(multiplicand_1); let addend_const = self.target_as_constant(addend); - let first_term_zero = const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; + let first_term_zero = + const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero; let second_term_zero = const_1 == F::ZERO || addend == zero; // If both terms are constant, return their (constant) sum. @@ -89,10 +91,8 @@ impl CircuitBuilder { return Some(self.constant(x + y)); } - if first_term_zero { - if const_1.is_one() { - return Some(addend); - } + if first_term_zero && const_1.is_one() { + return Some(addend); } if second_term_zero { @@ -156,17 +156,31 @@ impl CircuitBuilder { if y == one { return x; } - if let (Some(x_const), Some(y_const)) = (self.target_as_constant(x), self.target_as_constant(y)) { + if let (Some(x_const), Some(y_const)) = + (self.target_as_constant(x), self.target_as_constant(y)) + { return self.constant(x_const / y_const); } // Add an `ArithmeticGate` to compute `q * y`. let gate = self.add_gate(ArithmeticGate::new(), vec![F::ONE, F::ZERO]); - let wire_multiplicand_0 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_0 }; - let wire_multiplicand_1 = Wire { gate, input: ArithmeticGate::WIRE_MULTIPLICAND_1 }; - let wire_addend = Wire { gate, input: ArithmeticGate::WIRE_ADDEND }; - let wire_output = Wire { gate, input: ArithmeticGate::WIRE_OUTPUT }; + let wire_multiplicand_0 = Wire { + gate, + input: ArithmeticGate::WIRE_MULTIPLICAND_0, + }; + let wire_multiplicand_1 = Wire { + gate, + input: ArithmeticGate::WIRE_MULTIPLICAND_1, + }; + let wire_addend = Wire { + gate, + input: ArithmeticGate::WIRE_ADDEND, + }; + let wire_output = Wire { + gate, + input: ArithmeticGate::WIRE_OUTPUT, + }; let q = Target::Wire(wire_multiplicand_0); self.add_generator(QuotientGenerator { diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 85cc8e08..4a2f63b7 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -64,7 +64,6 @@ impl Gate for ArithmeticGate { &self, gate_index: usize, local_constants: &[F], - _next_constants: &[F], ) -> Vec>> { let gen = ArithmeticGenerator { gate_index, diff --git a/src/gates/constant.rs b/src/gates/constant.rs index cf2b9e04..8482a6de 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -45,7 +45,6 @@ impl Gate for ConstantGate { &self, gate_index: usize, local_constants: &[F], - _next_constants: &[F], ) -> Vec>> { let gen = ConstantGenerator { gate_index, diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 58bd9dd9..a340a1bd 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -37,7 +37,6 @@ pub trait Gate: 'static + Send + Sync { &self, gate_index: usize, local_constants: &[F], - next_constants: &[F], ) -> Vec>>; /// The number of wires used by this gate. diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 706b4bd1..82539388 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -126,7 +126,6 @@ impl Gate for GMiMCGate { &self, gate_index: usize, _local_constants: &[F], - _next_constants: &[F], ) -> Vec>> { let gen = GMiMCGenerator { gate_index, @@ -304,7 +303,7 @@ mod tests { ); } - let generators = gate.0.generators(0, &[], &[]); + let generators = gate.0.generators(0, &[]); generate_partial_witness(&mut witness, &generators); let expected_outputs: [F; W] = diff --git a/src/gates/gmimc_eval.rs b/src/gates/gmimc_eval.rs index 427559ab..e2eb4cb6 100644 --- a/src/gates/gmimc_eval.rs +++ b/src/gates/gmimc_eval.rs @@ -38,7 +38,6 @@ impl Gate for GMiMCEvalGate { &self, gate_index: usize, local_constants: &[F], - _next_constants: &[F], ) -> Vec>> { let gen = GMiMCEvalGenerator:: { gate_index, diff --git a/src/gates/noop.rs b/src/gates/noop.rs index fa261873..edd4e5dd 100644 --- a/src/gates/noop.rs +++ b/src/gates/noop.rs @@ -35,7 +35,6 @@ impl Gate for NoopGate { &self, _gate_index: usize, _local_constants: &[F], - _next_constants: &[F], ) -> Vec>> { Vec::new() } diff --git a/src/gmimc.rs b/src/gmimc.rs index 3b7d1e2c..9a65d49d 100644 --- a/src/gmimc.rs +++ b/src/gmimc.rs @@ -35,7 +35,7 @@ pub fn gmimc_compress( F::ZERO, F::ZERO, ]; - let state_1 = gmimc_permute::(state_0, constants.clone()); + let state_1 = gmimc_permute::(state_0, constants); [state_1[0], state_1[1], state_1[2], state_1[3]] } @@ -96,7 +96,7 @@ pub fn gmimc_permute_naive( let f = (xs[active] + constants[r]).cube(); for i in 0..W { if i != active { - xs[i] = xs[i] + f; + xs[i] += f; } } } diff --git a/src/hash.rs b/src/hash.rs index d87d3e28..a51d1f08 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -1,13 +1,10 @@ //! Concrete instantiation of a hash function. -use rayon::prelude::*; - use crate::circuit_builder::CircuitBuilder; use crate::field::field::Field; use crate::gmimc::gmimc_permute_array; use crate::proof::{Hash, HashTarget}; use crate::target::Target; -use crate::util::reverse_index_bits_in_place; pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; @@ -173,9 +170,7 @@ impl CircuitBuilder { // Overwrite the first r elements with the inputs. This differs from a standard sponge, // where we would xor or add in the inputs. This is a well-known variant, though, // sometimes called "overwrite mode". - for i in 0..input_chunk.len() { - state[i] = input_chunk[i]; - } + state[..input_chunk.len()].copy_from_slice(input_chunk); state = self.permute(state); } @@ -247,36 +242,3 @@ pub fn hash_n_to_hash(inputs: Vec, pad: bool) -> Hash { pub fn hash_n_to_1(inputs: Vec, pad: bool) -> F { hash_n_to_m(inputs, 1, pad)[0] } - -/// Like `merkle_root`, but first reorders each vector so that `new[i] = old[i.reverse_bits()]`. -pub(crate) fn merkle_root_bit_rev_order(mut vecs: Vec>) -> Hash { - reverse_index_bits_in_place(&mut vecs); - merkle_root(vecs) -} - -/// Given `n` vectors, each of length `l`, constructs a Merkle tree with `l` leaves, where each leaf -/// is a hash obtained by hashing a "leaf set" consisting of `n` elements. If `n <= 4`, this hashing -/// is skipped, as there is no need to compress leaf data. -pub(crate) fn merkle_root(vecs: Vec>) -> Hash { - let elems_per_leaf = vecs[0].len(); - let leaves_per_chunk = (ELEMS_PER_CHUNK / elems_per_leaf).next_power_of_two(); - let subtree_roots: Vec> = vecs - .par_chunks(leaves_per_chunk) - .map(|chunk| merkle_root_inner(chunk.to_vec()).elements.to_vec()) - .collect(); - merkle_root_inner(subtree_roots) -} - -pub(crate) fn merkle_root_inner(vecs: Vec>) -> Hash { - let mut hashes = vecs - .into_iter() - .map(|leaf_set| hash_or_noop(leaf_set)) - .collect::>(); - while hashes.len() > 1 { - hashes = hashes - .chunks(2) - .map(|pair| compress(pair[0], pair[1])) - .collect(); - } - hashes[0] -} diff --git a/src/merkle_tree.rs b/src/merkle_tree.rs index 5b695e79..d0e65058 100644 --- a/src/merkle_tree.rs +++ b/src/merkle_tree.rs @@ -1,3 +1,5 @@ +use rayon::prelude::*; + use crate::field::field::Field; use crate::hash::{compress, hash_or_noop}; use crate::merkle_proofs::MerkleProof; @@ -26,7 +28,7 @@ impl MerkleTree { reverse_index_bits_in_place(&mut leaves); } let mut layers = vec![leaves - .iter() + .par_iter() .map(|l| hash_or_noop(l.clone())) .collect::>()]; while let Some(l) = layers.last() { @@ -34,7 +36,7 @@ impl MerkleTree { break; } let next_layer = l - .chunks(2) + .par_chunks(2) .map(|chunk| compress(chunk[0], chunk[1])) .collect::>(); layers.push(next_layer); @@ -116,11 +118,13 @@ impl MerkleTree { #[cfg(test)] mod tests { - use super::*; + use anyhow::Result; + use crate::field::crandall_field::CrandallField; use crate::merkle_proofs::{verify_merkle_proof, verify_merkle_proof_subtree}; use crate::polynomial::division::divide_by_z_h; - use anyhow::Result; + + use super::*; fn random_data(n: usize, k: usize) -> Vec> { (0..n) diff --git a/src/plonk_challenger.rs b/src/plonk_challenger.rs index a98a6fc3..a8f5e605 100644 --- a/src/plonk_challenger.rs +++ b/src/plonk_challenger.rs @@ -107,6 +107,12 @@ impl Challenger { } } +impl Default for Challenger { + fn default() -> Self { + Self::new() + } +} + /// A recursive version of `Challenger`. pub(crate) struct RecursiveChallenger { sponge_state: [Target; SPONGE_WIDTH], diff --git a/src/polynomial/division.rs b/src/polynomial/division.rs index 41d872d0..0b5055ef 100644 --- a/src/polynomial/division.rs +++ b/src/polynomial/division.rs @@ -9,15 +9,15 @@ use crate::util::log2_strict; pub(crate) fn divide_by_z_h(mut a: PolynomialCoeffs, n: usize) -> PolynomialCoeffs { // TODO: Is this special case needed? if a.coeffs.iter().all(|p| *p == F::ZERO) { - return a.clone(); + return a; } let g = F::MULTIPLICATIVE_GROUP_GENERATOR; let mut g_pow = F::ONE; // Multiply the i-th coefficient of `a` by `g^i`. Then `new_a(w^j) = old_a(g.w^j)`. a.coeffs.iter_mut().for_each(|x| { - *x = (*x) * g_pow; - g_pow = g * g_pow; + *x *= g_pow; + g_pow *= g; }); let root = F::primitive_root_of_unity(log2_strict(a.len())); @@ -43,7 +43,7 @@ pub(crate) fn divide_by_z_h(mut a: PolynomialCoeffs, n: usize) -> P .iter_mut() .zip(denominators_inv.iter()) .for_each(|(x, &d)| { - *x = (*x) * d; + *x *= d; }); // `p` is the interpolating polynomial of `a_eval` on `{w^i}`. let mut p = ifft(a_eval); @@ -52,16 +52,46 @@ pub(crate) fn divide_by_z_h(mut a: PolynomialCoeffs, n: usize) -> P let g_inv = g.inverse(); let mut g_inv_pow = F::ONE; p.coeffs.iter_mut().for_each(|x| { - *x = (*x) * g_inv_pow; - g_inv_pow = g_inv_pow * g_inv; + *x *= g_inv_pow; + g_inv_pow *= g_inv; }); p } #[cfg(test)] mod tests { + use crate::field::crandall_field::CrandallField; + use crate::field::field::Field; + use crate::polynomial::division::divide_by_z_h; + use crate::polynomial::polynomial::PolynomialCoeffs; + + #[test] + fn zero_div_z_h() { + type F = CrandallField; + let zero = PolynomialCoeffs::::zero(16); + let quotient = divide_by_z_h(zero.clone(), 4); + assert_eq!(quotient, zero); + } + #[test] fn division_by_z_h() { - // TODO + type F = CrandallField; + let zero = F::ZERO; + let one = F::ONE; + let two = F::TWO; + let three = F::from_canonical_u64(3); + let four = F::from_canonical_u64(4); + let five = F::from_canonical_u64(5); + let six = F::from_canonical_u64(6); + + // a(x) = Z_4(x) q(x), where + // a(x) = 3 x^7 + 4 x^6 + 5 x^5 + 6 x^4 - 3 x^3 - 4 x^2 - 5 x - 6 + // Z_4(x) = x^4 - 1 + // q(x) = 3 x^3 + 4 x^2 + 5 x + 6 + let a = PolynomialCoeffs::new(vec![-six, -five, -four, -three, six, five, four, three]); + let q = PolynomialCoeffs::new(vec![six, five, four, three, zero, zero, zero, zero]); + + let computed_q = divide_by_z_h(a, 4); + assert_eq!(computed_q, q); } } diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 11949494..7034a8b9 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -2,7 +2,7 @@ use crate::field::fft::{fft, ifft}; use crate::field::field::Field; use crate::util::log2_strict; -/// A polynomial in point-value form. The number of values must be a power of two. +/// A polynomial in point-value form. /// /// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number /// of points. @@ -13,7 +13,6 @@ pub struct PolynomialValues { impl PolynomialValues { pub fn new(values: Vec) -> Self { - assert!(values.len().is_power_of_two()); PolynomialValues { values } } @@ -31,12 +30,12 @@ impl PolynomialValues { } pub fn lde(self, rate_bits: usize) -> Self { - let mut coeffs = ifft(self).lde(rate_bits); + let coeffs = ifft(self).lde(rate_bits); fft(coeffs) } } -/// A polynomial in coefficient form. The number of coefficients must be a power of two. +/// A polynomial in coefficient form. #[derive(Clone, Debug, Eq, PartialEq)] pub struct PolynomialCoeffs { pub(crate) coeffs: Vec, @@ -44,11 +43,11 @@ pub struct PolynomialCoeffs { impl PolynomialCoeffs { pub fn new(coeffs: Vec) -> Self { - assert!(coeffs.len().is_power_of_two()); PolynomialCoeffs { coeffs } } - pub(crate) fn pad(mut coeffs: Vec) -> Self { + /// Create a new polynomial with its coefficient list padded to the next power of two. + pub(crate) fn new_padded(mut coeffs: Vec) -> Self { while !coeffs.len().is_power_of_two() { coeffs.push(F::ZERO); } @@ -70,7 +69,6 @@ impl PolynomialCoeffs { } pub(crate) fn chunks(&self, chunk_size: usize) -> Vec { - assert!(chunk_size.is_power_of_two()); self.coeffs .chunks(chunk_size) .map(|chunk| PolynomialCoeffs::new(chunk.to_vec())) @@ -88,7 +86,7 @@ impl PolynomialCoeffs { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } - pub(crate) fn lde(mut self, rate_bits: usize) -> Self { + pub(crate) fn lde(self, rate_bits: usize) -> Self { let original_size = self.len(); let lde_size = original_size << rate_bits; let Self { mut coeffs } = self; @@ -97,4 +95,17 @@ impl PolynomialCoeffs { } Self { coeffs } } + + /// Removes leading zero coefficients. + pub fn trim(&mut self) { + self.coeffs.drain(self.degree_plus_one()..); + } + + /// Degree of the polynomial + 1. + fn degree_plus_one(&self) -> usize { + (0usize..self.len()) + .rev() + .find(|&i| self.coeffs[i].is_nonzero()) + .map_or(0, |i| i + 1) + } } diff --git a/src/prover.rs b/src/prover.rs index 32735e17..4801b79d 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -7,7 +7,7 @@ use crate::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::field::fft::{fft, ifft}; use crate::field::field::Field; use crate::generator::generate_partial_witness; -use crate::hash::merkle_root_bit_rev_order; +use crate::merkle_tree::MerkleTree; use crate::plonk_challenger::Challenger; use crate::plonk_common::{eval_l_1, evaluate_gate_constraints, reduce_with_powers_multi}; use crate::polynomial::division::divide_by_z_h; @@ -61,14 +61,18 @@ pub(crate) fn prove( // TODO: Could avoid cloning if it's significant? let start_wires_root = Instant::now(); - let wires_root = merkle_root_bit_rev_order(wire_ldes_t.clone()); + let wires_tree = MerkleTree::new(wire_ldes_t, true); info!( "{:.3}s to Merklize wire LDEs", start_wires_root.elapsed().as_secs_f32() ); let mut challenger = Challenger::new(); - challenger.observe_hash(&wires_root); + // Observe the instance. + // TODO: Need to include public inputs as well. + challenger.observe_hash(&common_data.circuit_digest); + + challenger.observe_hash(&wires_tree.root); let betas = challenger.get_n_challenges(num_checks); let gammas = challenger.get_n_challenges(num_checks); @@ -82,28 +86,24 @@ pub(crate) fn prove( ); let start_plonk_z_root = Instant::now(); - let plonk_zs_root = merkle_root_bit_rev_order(plonk_z_ldes_t.clone()); + let plonk_zs_tree = MerkleTree::new(plonk_z_ldes_t, true); info!( "{:.3}s to Merklize Z's", start_plonk_z_root.elapsed().as_secs_f32() ); - challenger.observe_hash(&plonk_zs_root); + challenger.observe_hash(&plonk_zs_tree.root); let alphas = challenger.get_n_challenges(num_checks); - // TODO - let beta = betas[0]; - let gamma = gammas[0]; - let start_vanishing_polys = Instant::now(); let vanishing_polys = compute_vanishing_polys( common_data, prover_data, - wire_ldes_t, - plonk_z_ldes_t, - beta, - gamma, + &wires_tree, + &plonk_zs_tree, + &betas, + &gammas, &alphas, ); info!( @@ -125,8 +125,8 @@ pub(crate) fn prove( quotient_poly_coeff_ldes.into_par_iter().map(fft).collect(); all_quotient_poly_chunk_ldes.extend(quotient_poly_chunk_ldes); } - let quotient_polys_root = - merkle_root_bit_rev_order(transpose_poly_values(all_quotient_poly_chunk_ldes)); + let quotient_polys_tree = + MerkleTree::new(transpose_poly_values(all_quotient_poly_chunk_ldes), true); info!( "{:.3}s to compute quotient polys and their LDEs", quotient_polys_start.elapsed().as_secs_f32() @@ -142,9 +142,9 @@ pub(crate) fn prove( ); Proof { - wires_root, - plonk_zs_root, - quotient_polys_root, + wires_root: wires_tree.root, + plonk_zs_root: plonk_zs_tree.root, + quotient_polys_root: quotient_polys_tree.root, openings, fri_proofs, } @@ -164,10 +164,10 @@ fn compute_z(common_data: &CommonCircuitData, i: usize) -> Polynomi fn compute_vanishing_polys( common_data: &CommonCircuitData, prover_data: &ProverOnlyCircuitData, - wire_ldes_t: Vec>, - plonk_z_lde_t: Vec>, - beta: F, - gamma: F, + wires_tree: &MerkleTree, + plonk_zs_tree: &MerkleTree, + betas: &[F], + gammas: &[F], alphas: &[F], ) -> Vec> { let lde_size = common_data.lde_size(); @@ -180,22 +180,18 @@ fn compute_vanishing_polys( .enumerate() .map(|(i, x)| { let i_next = (i + 1) % lde_size; - let local_wires = &wire_ldes_t[i]; - let next_wires = &wire_ldes_t[i_next]; - let local_constants = &prover_data.constant_ldes_t[i]; - let next_constants = &prover_data.constant_ldes_t[i_next]; - let local_plonk_zs = &plonk_z_lde_t[i]; - let next_plonk_zs = &plonk_z_lde_t[i_next]; - let s_sigmas = &prover_data.sigma_ldes_t[i]; + let local_wires = &wires_tree.leaves[i]; + let local_constants = &prover_data.constants_tree.leaves[i]; + let local_plonk_zs = &plonk_zs_tree.leaves[i]; + let next_plonk_zs = &plonk_zs_tree.leaves[i_next]; + let s_sigmas = &prover_data.sigmas_tree.leaves[i]; debug_assert_eq!(local_wires.len(), common_data.config.num_wires); debug_assert_eq!(local_plonk_zs.len(), num_checks); let vars = EvaluationVars { local_constants, - next_constants, local_wires, - next_wires, }; compute_vanishing_poly_entry( common_data, @@ -204,8 +200,8 @@ fn compute_vanishing_polys( local_plonk_zs, next_plonk_zs, s_sigmas, - beta, - gamma, + betas, + gammas, alphas, ) }) @@ -227,8 +223,8 @@ fn compute_vanishing_poly_entry( local_plonk_zs: &[F], next_plonk_zs: &[F], s_sigmas: &[F], - beta: F, - gamma: F, + betas: &[F], + gammas: &[F], alphas: &[F], ) -> Vec { let constraint_terms = @@ -251,8 +247,8 @@ fn compute_vanishing_poly_entry( let k_i = common_data.k_is[j]; let s_id = k_i * x; let s_sigma = s_sigmas[j]; - f_prime *= wire_value + beta * s_id + gamma; - g_prime *= wire_value + beta * s_sigma + gamma; + f_prime *= wire_value + betas[i] * s_id + gammas[i]; + g_prime *= wire_value + betas[i] * s_sigma + gammas[i]; } vanishing_v_shift_terms.push(f_prime * z_x - g_prime * z_gz); } diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index 54d35bf0..b850e587 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -17,4 +17,6 @@ pub fn add_recursive_verifier( ) { assert!(builder.config.num_wires >= MIN_WIRES); assert!(builder.config.num_wires >= MIN_ROUTED_WIRES); + + todo!() } diff --git a/src/rescue.rs b/src/rescue.rs index 88785a1e..f088fdb7 100644 --- a/src/rescue.rs +++ b/src/rescue.rs @@ -1,8 +1,10 @@ +//! Implements Rescue Prime. + use unroll::unroll_for_loops; use crate::field::field::Field; -const ROUNDS: usize = 10; +const ROUNDS: usize = 8; const W: usize = 12; @@ -177,7 +179,7 @@ const MDS: [[u64; W]; W] = [ ], ]; -const RESCUE_CONSTANTS: [[u64; W]; 20] = [ +const RESCUE_CONSTANTS: [[u64; W]; 16] = [ [ 12050887499329086906, 1748247961703512657, @@ -402,66 +404,10 @@ const RESCUE_CONSTANTS: [[u64; W]; 20] = [ 16465224002344550280, 10282380383506806095, ], - [ - 12608209810104211593, - 11808578423511814760, - 16177950852717156460, - 9394439296563712221, - 12586575762376685187, - 17703393198607870393, - 9811861465513647715, - 14126450959506560131, - 12713673607080398908, - 18301828072718562389, - 11180556590297273821, - 4451415492203885059, - ], - [ - 10465807219916311101, - 1213997644391575261, - 17672155373280862521, - 1491206970207330736, - 10977478805896263804, - 13260961975618373124, - 16060889403827043708, - 3223573072465920682, - 17624203443801796697, - 10247205738678800822, - 11100653267668698651, - 14328592975764892571, - ], - [ - 6984072551318461094, - 3416562710010527326, - 12847783919251969270, - 12223185134739244472, - 12073170519625198198, - 6221124633828606855, - 17596623990006806590, - 1153871693574764968, - 2548851681903410721, - 9823373270182377847, - 16708030507924899244, - 9619306826188519218, - ], - [ - 5842685042453818473, - 12400879353954910914, - 647112787845575111, - 4893664959929687347, - 3759391664155971284, - 15871181179823725763, - 3629377713951158273, - 3439101502554162312, - 8325686353010019444, - 10630488935940555500, - 3478529754946055748, - 12681233130980545828, - ], ]; -fn rescue(mut xs: [F; W]) -> [F; W] { - for r in 0..10 { +pub fn rescue(mut xs: [F; W]) -> [F; W] { + for r in 0..8 { xs = sbox_layer_a(xs); xs = mds_layer(xs); xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2]); @@ -470,61 +416,27 @@ fn rescue(mut xs: [F; W]) -> [F; W] { xs = mds_layer(xs); xs = constant_layer(xs, &RESCUE_CONSTANTS[r * 2 + 1]); } - - // for i in 0..W { - // xs[i] = xs[i].to_canonical(); - // } - xs } -// #[inline(always)] #[unroll_for_loops] fn sbox_layer_a(x: [F; W]) -> [F; W] { let mut result = [F::ZERO; W]; for i in 0..W { - result[i] = sbox_a(x[i]); + result[i] = x[i].cube(); } result } -// #[inline(always)] #[unroll_for_loops] fn sbox_layer_b(x: [F; W]) -> [F; W] { let mut result = [F::ZERO; W]; for i in 0..W { - result[i] = sbox_b(x[i]); + result[i] = x[i].cube_root(); } result } -// #[inline(always)] -#[unroll_for_loops] -fn sbox_a(x: F) -> F { - // x^{-5}, via Fermat's little theorem - const EXP: u64 = 7378697628517453005; - - let mut product = F::ONE; - let mut current = x; - - for i in 0..64 { - if ((EXP >> i) & 1) != 0 { - product = product * current; - } - current = current.square(); - } - product -} - -#[inline(always)] -fn sbox_b(x: F) -> F { - // x^5 - let x2 = x.square(); - let x3 = x2 * x; - x2 * x3 -} - -// #[inline(always)] #[unroll_for_loops] fn mds_layer(x: [F; W]) -> [F; W] { let mut result = [F::ZERO; W]; @@ -536,7 +448,6 @@ fn mds_layer(x: [F; W]) -> [F; W] { result } -#[inline(always)] #[unroll_for_loops] fn constant_layer(xs: [F; W], con: &[u64; W]) -> [F; W] { let mut result = [F::ZERO; W]; diff --git a/src/vars.rs b/src/vars.rs index 532ddbfe..f2744e8f 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -4,15 +4,11 @@ use crate::target::Target; #[derive(Copy, Clone)] pub struct EvaluationVars<'a, F: Field> { pub(crate) local_constants: &'a [F], - pub(crate) next_constants: &'a [F], pub(crate) local_wires: &'a [F], - pub(crate) next_wires: &'a [F], } #[derive(Copy, Clone)] pub struct EvaluationTargets<'a> { pub(crate) local_constants: &'a [Target], - pub(crate) next_constants: &'a [Target], pub(crate) local_wires: &'a [Target], - pub(crate) next_wires: &'a [Target], } diff --git a/src/verifier.rs b/src/verifier.rs index 64ae2aed..c0afc07f 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -2,7 +2,7 @@ use crate::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; use crate::field::field::Field; pub(crate) fn verify( - verifier_data: &VerifierOnlyCircuitData, + verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, ) { todo!() diff --git a/src/witness.rs b/src/witness.rs index 4c5e89a8..42b7150c 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -79,3 +79,9 @@ impl PartialWitness { } } } + +impl Default for PartialWitness { + fn default() -> Self { + Self::new() + } +}