make poseidon hasher working

This commit is contained in:
kilic 2020-04-25 21:53:24 +03:00
parent 9352948baf
commit 9f3762c16a
1 changed files with 117 additions and 100 deletions

View File

@ -3,34 +3,7 @@ use blake2::{Blake2s, Digest};
use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr};
use sapling_crypto::bellman::pairing::Engine;
#[test]
fn test_poseidon() {
use sapling_crypto::bellman::pairing::bn256;
use sapling_crypto::bellman::pairing::bn256::{Bn256, Fr};
let (t, rf, rp) = (3usize, 8usize, 57usize);
let mut hasher = Blake2s::new();
hasher.input(b"rln poseidion t3rf4rp57");
let seed = hasher.result().to_vec();
let person_full_round_constant = b"rlnhds01";
let full_round_constants = PoseidonParams::<Bn256>::generate_constants(person_full_round_constant, seed.clone(), rf);
let person_partial_round_constant = b"rlnhds02";
let partial_round_constants =
PoseidonParams::<Bn256>::generate_constants(person_partial_round_constant, seed.clone(), rp);
let person_mds_matrix = b"rlnhds03";
let mds_matrix = PoseidonParams::<Bn256>::generate_mds_matrix(person_mds_matrix, seed.clone(), t);
let mut constants: Vec<Fr> = Vec::new();
constants.extend_from_slice(&full_round_constants[0..rf / 2]);
constants.extend_from_slice(&partial_round_constants);
constants.extend_from_slice(&full_round_constants[(rf / 2)..rf]);
let mut hasher = Poseidon::<Bn256>::new_with_params(rf, rp, t, constants, mds_matrix);
let input = [Fr::zero()];
let r1: Fr = hasher.hash(&input);
let r2: Fr = hasher.hash(&input);
println!("{}", r1);
assert_eq!(r1, r2, "just to see if internal state resets");
}
struct PoseidonParams<E: Engine> {
pub struct PoseidonParams<E: Engine> {
rf: usize,
rp: usize,
t: usize,
@ -40,7 +13,7 @@ struct PoseidonParams<E: Engine> {
impl<E: Engine> PoseidonParams<E> {
pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec<E::Fr>, mds_matrix: Vec<E::Fr>) -> PoseidonParams<E> {
assert_eq!(rf + rp, round_constants.len());
assert_eq!((rf + rp) * t, round_constants.len());
PoseidonParams {
rf,
rp,
@ -49,43 +22,51 @@ impl<E: Engine> PoseidonParams<E> {
mds_matrix,
}
}
pub fn t(&self) -> usize {
return self.t;
}
pub fn width(&self) -> usize {
return self.t;
}
pub fn partial_round_len(&self) -> usize {
return self.rp;
}
pub fn full_round_len(&self) -> usize {
return self.rf + self.rp;
}
pub fn full_round_half_len(&self) -> usize {
return self.rf / 2;
}
pub fn total_rounds(&self) -> usize {
return self.rf + self.rp;
}
pub fn round_constant(&self, round: usize, block: usize) -> E::Fr {
let w = self.t();
return self.round_constants[round * w + block];
}
pub fn mds_matrix_row(&self, i: usize) -> Vec<E::Fr> {
let t = self.t();
let mut row: Vec<E::Fr> = Vec::with_capacity(t);
for j in i * t..(i + 1) * t {
row.push(self.mds_matrix[j]);
}
row
}
pub fn generate_mds_matrix(persona: &[u8; 8], seed: Vec<u8>, t: usize) -> Vec<E::Fr> {
let v: Vec<E::Fr> = PoseidonParams::<E>::generate_constants(persona, seed, t * 2);
let mut matrix: Vec<E::Fr> = Vec::with_capacity(t * t);
let mut xs: Vec<E::Fr> = Vec::with_capacity(t);
let mut ys: Vec<E::Fr> = Vec::with_capacity(t);
let mut source = seed.clone();
loop {
let mut hasher = Blake2s::new();
hasher.input(persona);
hasher.input(source);
source = hasher.result().to_vec();
let mut candidate_repr = <E::Fr as PrimeField>::Repr::default();
candidate_repr.read_le(&source[..]).unwrap();
if let Ok(candidate) = E::Fr::from_repr(candidate_repr) {
xs.push(candidate);
if xs.len() == t {
break;
}
}
}
loop {
let mut hasher = Blake2s::new();
hasher.input(persona);
hasher.input(source);
source = hasher.result().to_vec();
let mut candidate_repr = <E::Fr as PrimeField>::Repr::default();
candidate_repr.read_le(&source[..]).unwrap();
if let Ok(candidate) = E::Fr::from_repr(candidate_repr) {
ys.push(candidate);
if ys.len() == t {
break;
}
}
}
for i in 0..t {
for j in 0..t {
let mut tmp = xs[i];
tmp.add_assign(&ys[j]);
let mut tmp = v[i];
tmp.add_assign(&v[t + j]);
let entry = tmp.inverse().unwrap();
matrix.insert((i * t) + j, entry);
}
@ -93,7 +74,8 @@ impl<E: Engine> PoseidonParams<E> {
matrix
}
fn generate_constants(persona: &[u8; 8], seed: Vec<u8>, len: usize) -> Vec<E::Fr> {
pub fn generate_constants(persona: &[u8; 8], seed: Vec<u8>, len: usize) -> Vec<E::Fr> {
use hex;
let mut constants: Vec<E::Fr> = Vec::new();
let mut source = seed.clone();
loop {
@ -114,7 +96,7 @@ impl<E: Engine> PoseidonParams<E> {
}
}
struct Poseidon<E: Engine> {
pub struct Poseidon<E: Engine> {
state: Vec<E::Fr>,
round: usize,
params: PoseidonParams<E>,
@ -139,13 +121,6 @@ impl<E: Engine> Poseidon<E> {
}
}
pub fn hash(&mut self, inputs: &[E::Fr]) -> E::Fr {
self.new_state(inputs);
while self.round() {}
self.round = 0;
self.result()
}
fn new_state(&mut self, inputs: &[E::Fr]) {
let t = self.t();
assert!(inputs.len() < t);
@ -153,6 +128,10 @@ impl<E: Engine> Poseidon<E> {
self.state.resize(t, E::Fr::zero());
}
fn clear(&mut self) {
self.round = 0;
}
fn t(&self) -> usize {
self.params.t
}
@ -161,65 +140,75 @@ impl<E: Engine> Poseidon<E> {
self.state[0]
}
fn round(&mut self) -> bool {
let a1 = self.params.rf / 2;
let a2 = self.params.rf / 2 + self.params.rp;
let a3 = self.params.rf + self.params.rp;
pub fn hash(&mut self, inputs: &[E::Fr]) -> E::Fr {
self.new_state(inputs);
loop {
self.round(self.round);
self.round += 1;
if self.round == self.params.full_round_len() {
break;
}
}
let r = self.result();
self.clear();
r
}
if self.round < a1 {
self.full_round();
false
} else if self.round >= a1 && self.round < a2 {
self.partial_round();
false
} else if self.round >= a2 && self.round < a3 {
self.full_round();
false
fn round(&mut self, round: usize) {
let a1 = self.params.full_round_half_len();
let a2 = a1 + self.params.partial_round_len();
let a3 = self.params.total_rounds();
if round < a1 {
self.full_round(round);
} else if round >= a1 && round < a2 {
self.partial_round(round);
} else if round >= a2 && round < a3 {
self.full_round(round);
} else {
true
panic!("should not be here")
}
}
fn full_round(&mut self) {
self.add_round_constants();
self.apply_quintic_sbox();
fn full_round(&mut self, round: usize) {
self.add_round_constants(round);
self.apply_quintic_sbox(true);
self.mul_mds_matrix();
}
fn full_round_no_mds(&mut self) {
self.add_round_constants();
self.apply_quintic_sbox();
fn full_round_no_mds(&mut self, round: usize) {
self.add_round_constants(round);
self.apply_quintic_sbox(true);
}
fn partial_round(&mut self) {
self.add_round_constants();
self.apply_quintic_sbox();
fn partial_round(&mut self, round: usize) {
self.add_round_constants(round);
self.apply_quintic_sbox(false);
self.mul_mds_matrix();
}
fn add_round_constants(&mut self) {
fn add_round_constants(&mut self, round: usize) {
let w = self.params.t;
// use zip
for (j, b) in self.state.iter_mut().enumerate() {
let c = self.params.round_constants[self.round * w + j];
let c = self.params.round_constants[round * w + j];
b.add_assign(&c);
}
}
fn apply_quintic_sbox(&mut self) {
fn apply_quintic_sbox(&mut self, full: bool) {
for s in self.state.iter_mut() {
let mut b = s.clone();
b.square();
b.square();
s.mul_assign(&b);
if !full {
break;
}
}
}
fn mul_mds_matrix(&mut self) {
let w = self.params.t;
let mut new_state = vec![E::Fr::zero(); w];
for (i, ns) in new_state.iter_mut().enumerate() {
// slice and zip
for (j, s) in self.state.iter().enumerate() {
let mut tmp = s.clone();
tmp.mul_assign(&self.params.mds_matrix[i * w + j]);
@ -229,3 +218,31 @@ impl<E: Engine> Poseidon<E> {
self.state = new_state;
}
}
#[test]
fn test_poseidon() {
use sapling_crypto::bellman::pairing::bn256;
use sapling_crypto::bellman::pairing::bn256::{Bn256, Fr};
let (t, rf, rp) = (3usize, 8usize, 57usize);
let mut hasher = Blake2s::new();
hasher.input(b"rln poseidion t3rf4rp57");
let seed = hasher.result().to_vec();
let person_full_round_constant = b"rlnhds_c";
let person_mds_matrix = b"rlnhds_m";
let round_constants =
PoseidonParams::<Bn256>::generate_constants(person_full_round_constant, seed.clone(), (rf + rp) * t);
let mds_matrix = PoseidonParams::<Bn256>::generate_mds_matrix(person_mds_matrix, seed.clone(), t);
let mut hasher = Poseidon::<Bn256>::new_with_params(rf, rp, t, round_constants, mds_matrix);
// let input1 = [Fr::from_str("1").unwrap(), Fr::from_str("2").unwrap()];
let input1 = [Fr::zero()];
let r1: Fr = hasher.hash(&input1);
let input2 = [Fr::zero(), Fr::zero()];
let r2: Fr = hasher.hash(&input2);
assert_eq!(r1, r2, "just to see if internal state resets");
}