mirror of https://github.com/status-im/rln.git
skip last mds mul
This commit is contained in:
parent
d639d64aaf
commit
c1130234ef
|
@ -3,7 +3,6 @@ use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr};
|
||||||
use sapling_crypto::bellman::pairing::Engine;
|
use sapling_crypto::bellman::pairing::Engine;
|
||||||
use sapling_crypto::bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError};
|
use sapling_crypto::bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError};
|
||||||
use sapling_crypto::circuit::{boolean, ecc, num, Assignment};
|
use sapling_crypto::circuit::{boolean, ecc, num, Assignment};
|
||||||
// use sapling_crypto::jubjub::JubjubEngine;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Element<E>
|
struct Element<E>
|
||||||
|
@ -35,6 +34,7 @@ where
|
||||||
elements: Vec<Element<E>>,
|
elements: Vec<Element<E>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct PoseidonCircuit<E>
|
pub struct PoseidonCircuit<E>
|
||||||
where
|
where
|
||||||
E: Engine,
|
E: Engine,
|
||||||
|
@ -136,6 +136,10 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_last_round(&self) -> bool {
|
||||||
|
self.number == self.params.total_rounds() - 1
|
||||||
|
}
|
||||||
|
|
||||||
pub fn in_transition(&self) -> bool {
|
pub fn in_transition(&self) -> bool {
|
||||||
let a1 = self.params.full_round_half_len();
|
let a1 = self.params.full_round_half_len();
|
||||||
let a2 = a1 + self.params.partial_round_len();
|
let a2 = a1 + self.params.partial_round_len();
|
||||||
|
@ -180,8 +184,7 @@ impl<E> State<E>
|
||||||
where
|
where
|
||||||
E: Engine,
|
E: Engine,
|
||||||
{
|
{
|
||||||
pub fn new(input: Vec<num::AllocatedNum<E>>) -> Self {
|
pub fn new(elements: Vec<Element<E>>) -> Self {
|
||||||
let elements = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect();
|
|
||||||
Self { elements }
|
Self { elements }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,6 +201,7 @@ where
|
||||||
|
|
||||||
fn sbox<CS: ConstraintSystem<E>>(&mut self, mut cs: CS, ctx: &mut RoundCtx<E>) -> Result<(), SynthesisError> {
|
fn sbox<CS: ConstraintSystem<E>>(&mut self, mut cs: CS, ctx: &mut RoundCtx<E>) -> Result<(), SynthesisError> {
|
||||||
assert_eq!(ctx.width(), self.elements.len());
|
assert_eq!(ctx.width(), self.elements.len());
|
||||||
|
|
||||||
for i in 0..if ctx.is_full_round() { ctx.width() } else { 1 } {
|
for i in 0..if ctx.is_full_round() { ctx.width() } else { 1 } {
|
||||||
let round_constant = ctx.round_constant(i);
|
let round_constant = ctx.round_constant(i);
|
||||||
let si = {
|
let si = {
|
||||||
|
@ -244,41 +248,49 @@ where
|
||||||
ctx: &mut RoundCtx<E>,
|
ctx: &mut RoundCtx<E>,
|
||||||
) -> Result<(), SynthesisError> {
|
) -> Result<(), SynthesisError> {
|
||||||
assert_eq!(ctx.width(), self.elements.len());
|
assert_eq!(ctx.width(), self.elements.len());
|
||||||
let mut new_state: Vec<num::Num<E>> = Vec::new();
|
|
||||||
let w = ctx.width();
|
|
||||||
|
|
||||||
for i in 0..w {
|
if !ctx.is_last_round() {
|
||||||
let row = ctx.mds_matrix_row(i);
|
// skip mds multiplication in last round
|
||||||
let mut acc = num::Num::<E>::zero();
|
|
||||||
for j in 0..w {
|
let mut new_state: Vec<num::Num<E>> = Vec::new();
|
||||||
let mut r = self.elements[j].num();
|
let w = ctx.width();
|
||||||
r.scale(row[j]);
|
|
||||||
acc.add_assign(&r);
|
for i in 0..w {
|
||||||
|
let row = ctx.mds_matrix_row(i);
|
||||||
|
let mut acc = num::Num::<E>::zero();
|
||||||
|
for j in 0..w {
|
||||||
|
let mut r = self.elements[j].num();
|
||||||
|
r.scale(row[j]);
|
||||||
|
acc.add_assign(&r);
|
||||||
|
}
|
||||||
|
new_state.push(acc);
|
||||||
}
|
}
|
||||||
new_state.push(acc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// round ends here
|
// round ends here
|
||||||
let is_full_round = ctx.is_full_round();
|
let is_full_round = ctx.is_full_round();
|
||||||
let in_transition = ctx.in_transition();
|
let in_transition = ctx.in_transition();
|
||||||
ctx.round_end();
|
ctx.round_end();
|
||||||
|
|
||||||
// add round constants just after mds if
|
// add round constants just after mds if
|
||||||
// first full round has just ended
|
// first full round has just ended
|
||||||
// or in partial rounds expect the last one.
|
// or in partial rounds expect the last one.
|
||||||
if in_transition == is_full_round {
|
if in_transition == is_full_round {
|
||||||
// add round constants for elements in {1, t}
|
// add round constants for elements in {1, t}
|
||||||
let round_constants = ctx.round_constants();
|
let round_constants = ctx.round_constants();
|
||||||
for i in 1..w {
|
for i in 1..w {
|
||||||
let mut constant_as_num = num::Num::<E>::zero();
|
let mut constant_as_num = num::Num::<E>::zero();
|
||||||
constant_as_num =
|
constant_as_num =
|
||||||
constant_as_num.add_bool_with_coeff(CS::one(), &boolean::Boolean::Constant(true), round_constants[i]);
|
constant_as_num.add_bool_with_coeff(CS::one(), &boolean::Boolean::Constant(true), round_constants[i]);
|
||||||
new_state[i].add_assign(&constant_as_num);
|
new_state[i].add_assign(&constant_as_num);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for (s0, s1) in self.elements.iter_mut().zip(new_state) {
|
for (s0, s1) in self.elements.iter_mut().zip(new_state) {
|
||||||
s0.update_with_num(s1);
|
s0.update_with_num(s1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// terminates hades
|
||||||
|
ctx.round_end();
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -288,16 +300,25 @@ impl<E> PoseidonCircuit<E>
|
||||||
where
|
where
|
||||||
E: Engine,
|
E: Engine,
|
||||||
{
|
{
|
||||||
fn new(params: PoseidonParams<E>) -> Self {
|
pub fn new(params: PoseidonParams<E>) -> Self {
|
||||||
Self { params: params }
|
Self { params: params }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn width(&self) -> usize {
|
||||||
|
self.params.width()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn alloc<CS: ConstraintSystem<E>>(
|
pub fn alloc<CS: ConstraintSystem<E>>(
|
||||||
&mut self,
|
&self,
|
||||||
mut cs: CS,
|
mut cs: CS,
|
||||||
input: Vec<num::AllocatedNum<E>>,
|
input: Vec<num::AllocatedNum<E>>,
|
||||||
) -> Result<num::AllocatedNum<E>, SynthesisError> {
|
) -> Result<num::AllocatedNum<E>, SynthesisError> {
|
||||||
let mut state = State::new(input);
|
assert!(input.len() < self.params.width());
|
||||||
|
|
||||||
|
let mut elements: Vec<Element<E>> = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect();
|
||||||
|
elements.resize(self.width(), Element::new_from_num(num::Num::zero()));
|
||||||
|
|
||||||
|
let mut state = State::new(elements);
|
||||||
let mut ctx = RoundCtx::new(&self.params);
|
let mut ctx = RoundCtx::new(&self.params);
|
||||||
loop {
|
loop {
|
||||||
match ctx.round_type() {
|
match ctx.round_type() {
|
||||||
|
@ -324,7 +345,7 @@ fn test_poseidon_circuit() {
|
||||||
let mut cs = TestConstraintSystem::<Bn256>::new();
|
let mut cs = TestConstraintSystem::<Bn256>::new();
|
||||||
let params = PoseidonParams::default();
|
let params = PoseidonParams::default();
|
||||||
|
|
||||||
let inputs: Vec<Fr> = ["1", "2", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
|
let inputs: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
|
||||||
let allocated_inputs = inputs
|
let allocated_inputs = inputs
|
||||||
.clone()
|
.clone()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -336,16 +357,15 @@ fn test_poseidon_circuit() {
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut circuit = PoseidonCircuit::<Bn256>::new(params.clone());
|
let mut circuit = PoseidonCircuit::<Bn256>::new(params.clone());
|
||||||
let res = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap();
|
let res_allocated = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap();
|
||||||
let val = res.get_value().unwrap();
|
let result = res_allocated.get_value().unwrap();
|
||||||
|
|
||||||
let mut poseidon = PoseidonHasher::new(params.clone());
|
let mut poseidon = PoseidonHasher::new(params.clone());
|
||||||
let expected = poseidon.hash(inputs);
|
let expected = poseidon.hash(inputs);
|
||||||
|
|
||||||
assert_eq!(val, expected);
|
assert_eq!(result, expected);
|
||||||
assert!(cs.is_satisfied());
|
assert!(cs.is_satisfied());
|
||||||
println!(
|
println!(
|
||||||
"number of constraints for\nt {}, rf {}, rp {}\n{}",
|
"number of constraints for (t: {}, rf: {}, rp: {}), {}",
|
||||||
params.width(),
|
params.width(),
|
||||||
params.full_round_half_len() * 2,
|
params.full_round_half_len() * 2,
|
||||||
params.partial_round_len(),
|
params.partial_round_len(),
|
||||||
|
|
|
@ -12,6 +12,13 @@ pub struct PoseidonParams<E: Engine> {
|
||||||
mds_matrix: Vec<E::Fr>,
|
mds_matrix: Vec<E::Fr>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Poseidon<E: Engine> {
|
||||||
|
state: Vec<E::Fr>,
|
||||||
|
round: usize,
|
||||||
|
params: PoseidonParams<E>,
|
||||||
|
}
|
||||||
|
|
||||||
impl<E: Engine> PoseidonParams<E> {
|
impl<E: Engine> PoseidonParams<E> {
|
||||||
pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec<E::Fr>, mds_matrix: Vec<E::Fr>) -> PoseidonParams<E> {
|
pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec<E::Fr>, mds_matrix: Vec<E::Fr>) -> PoseidonParams<E> {
|
||||||
assert_eq!((rf + rp) * t, round_constants.len());
|
assert_eq!((rf + rp) * t, round_constants.len());
|
||||||
|
@ -27,9 +34,8 @@ impl<E: Engine> PoseidonParams<E> {
|
||||||
pub fn default() -> PoseidonParams<E> {
|
pub fn default() -> PoseidonParams<E> {
|
||||||
let (t, rf, rp) = (3usize, 8usize, 55usize);
|
let (t, rf, rp) = (3usize, 8usize, 55usize);
|
||||||
let seed = b"".to_vec();
|
let seed = b"".to_vec();
|
||||||
let person_mds_matrix = b"drlnhdsm";
|
|
||||||
let round_constants = PoseidonParams::<E>::generate_constants(b"drlnhdsc", seed.clone(), (rf + rp) * t);
|
let round_constants = PoseidonParams::<E>::generate_constants(b"drlnhdsc", seed.clone(), (rf + rp) * t);
|
||||||
let mds_matrix = PoseidonParams::<E>::generate_mds_matrix(person_mds_matrix, seed.clone(), t);
|
let mds_matrix = PoseidonParams::<E>::generate_mds_matrix(b"drlnhdsm", seed.clone(), t);
|
||||||
PoseidonParams::new(rf, rp, t, round_constants, mds_matrix)
|
PoseidonParams::new(rf, rp, t, round_constants, mds_matrix)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,12 +109,6 @@ impl<E: Engine> PoseidonParams<E> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Poseidon<E: Engine> {
|
|
||||||
state: Vec<E::Fr>,
|
|
||||||
round: usize,
|
|
||||||
params: PoseidonParams<E>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E: Engine> Poseidon<E> {
|
impl<E: Engine> Poseidon<E> {
|
||||||
pub fn new_with_params(
|
pub fn new_with_params(
|
||||||
rf: usize,
|
rf: usize,
|
||||||
|
@ -169,7 +169,11 @@ impl<E: Engine> Poseidon<E> {
|
||||||
} else if round >= a1 && round < a2 {
|
} else if round >= a1 && round < a2 {
|
||||||
self.partial_round(round);
|
self.partial_round(round);
|
||||||
} else if round >= a2 && round < a3 {
|
} else if round >= a2 && round < a3 {
|
||||||
self.full_round(round);
|
if round == a3 - 1 {
|
||||||
|
self.full_round_last();
|
||||||
|
} else {
|
||||||
|
self.full_round(round);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
panic!("should not be here")
|
panic!("should not be here")
|
||||||
}
|
}
|
||||||
|
@ -181,6 +185,12 @@ impl<E: Engine> Poseidon<E> {
|
||||||
self.mul_mds_matrix();
|
self.mul_mds_matrix();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn full_round_last(&mut self) {
|
||||||
|
let last_round = self.params.total_rounds() - 1;
|
||||||
|
self.add_round_constants(last_round);
|
||||||
|
self.apply_quintic_sbox(true);
|
||||||
|
}
|
||||||
|
|
||||||
fn partial_round(&mut self, round: usize) {
|
fn partial_round(&mut self, round: usize) {
|
||||||
self.add_round_constants(round);
|
self.add_round_constants(round);
|
||||||
self.apply_quintic_sbox(false);
|
self.apply_quintic_sbox(false);
|
||||||
|
@ -230,5 +240,6 @@ fn test_poseidon_hash() {
|
||||||
let r1: Fr = hasher.hash(input1.to_vec());
|
let r1: Fr = hasher.hash(input1.to_vec());
|
||||||
let input2: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
|
let input2: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
|
||||||
let r2: Fr = hasher.hash(input2.to_vec());
|
let r2: Fr = hasher.hash(input2.to_vec());
|
||||||
|
println!("{}", r2);
|
||||||
assert_eq!(r1, r2, "just to see if internal state resets");
|
assert_eq!(r1, r2, "just to see if internal state resets");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue