diff --git a/Cargo.toml b/Cargo.toml
index c1d9faf1..71970b41 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,6 +25,7 @@ unroll = "0.1.5"
 anyhow = "1.0.40"
 serde = { version = "1.0", features = ["derive"] }
 serde_cbor = "0.11.1"
+static_assertions = "1.1.0"
 
 [dev-dependencies]
 criterion = "0.3.5"
diff --git a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
index 1f0978f0..8503d5b2 100644
--- a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
+++ b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
@@ -2,13 +2,19 @@ use core::arch::x86_64::*;
 use std::convert::TryInto;
 use std::mem::size_of;
 
+use static_assertions::const_assert;
+
 use crate::field::field_types::Field;
 use crate::field::goldilocks_field::GoldilocksField;
-use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS};
+use crate::hash::poseidon::{
+    Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS,
+};
 
 // WARNING: This code contains tricks that work for the current MDS matrix and round constants, but
 // are not guaranteed to work if those are changed.
 
+// * Constant definitions *
+
 const WIDTH: usize = 12;
 
 // These tranformed round constants are used where the constant layer is fused with the preceeding
@@ -31,6 +37,68 @@ const FUSED_ROUND_CONSTANTS: [u64; WIDTH * N_ROUNDS] = make_fused_round_constant
 // indices: [0, 11, ..., 1].
 static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
 
+// * Compile-time checks *
+
+/// The MDS matrix multiplication ASM is specific to the MDS matrix below. We want this file to
+/// fail to compile if it has been changed.
+#[allow(dead_code)]
+const fn check_mds_matrix() -> bool {
+    // Can't == two arrays in a const_assert! (:
+    let mut i = 0;
+    let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10];
+    while i < WIDTH {
+        if <GoldilocksField as Poseidon<12>>::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] {
+            return false;
+        }
+        i += 1;
+    }
+    true
+}
+const_assert!(check_mds_matrix());
+
+/// The maximum amount by which the MDS matrix will multiply the input.
+/// i.e. max(MDS(state)) <= mds_matrix_inf_norm() * max(state).
+const fn mds_matrix_inf_norm() -> u64 {
+    let mut cumul = 0;
+    let mut i = 0;
+    while i < WIDTH {
+        cumul += 1 << <GoldilocksField as Poseidon<12>>::MDS_MATRIX_EXPS[i];
+        i += 1;
+    }
+    cumul
+}
+
+/// Ensure that adding round constants to the low result of the MDS multiplication can never
+/// overflow.
+#[allow(dead_code)]
+const fn check_round_const_bounds_mds() -> bool {
+    let max_mds_res = mds_matrix_inf_norm() * (u32::MAX as u64);
+    let mut i = WIDTH; // First const layer is handled specially.
+    while i < WIDTH * N_ROUNDS {
+        if ALL_ROUND_CONSTANTS[i].overflowing_add(max_mds_res).1 {
+            return false;
+        }
+        i += 1;
+    }
+    true
+}
+const_assert!(check_round_const_bounds_mds());
+
+/// Ensure that the first WIDTH round constants are in canonical form for the vpcmpgtd trick.
+#[allow(dead_code)]
+const fn check_round_const_bounds_init() -> bool {
+    let max_permitted_round_const = 0xffffffff00000000;
+    let mut i = 0; // First const layer is handled specially.
+    while i < WIDTH {
+        if ALL_ROUND_CONSTANTS[i] > max_permitted_round_const {
+            return false;
+        }
+        i += 1;
+    }
+    true
+}
+const_assert!(check_round_const_bounds_init());
+
 // Preliminary notes:
 // 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
 //    emulated. The method recognizes that for a + b overflowed iff (a + b) < a: