skip last mds mul

This commit is contained in:
kilic 2020-05-01 15:34:36 +03:00
parent d639d64aaf
commit c1130234ef
2 changed files with 81 additions and 50 deletions

View File

@ -3,7 +3,6 @@ use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr};
use sapling_crypto::bellman::pairing::Engine;
use sapling_crypto::bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError};
use sapling_crypto::circuit::{boolean, ecc, num, Assignment};
// use sapling_crypto::jubjub::JubjubEngine;
#[derive(Clone)]
struct Element<E>
@ -35,6 +34,7 @@ where
elements: Vec<Element<E>>,
}
#[derive(Clone)]
pub struct PoseidonCircuit<E>
where
E: Engine,
@ -136,6 +136,10 @@ where
}
}
pub fn is_last_round(&self) -> bool {
self.number == self.params.total_rounds() - 1
}
pub fn in_transition(&self) -> bool {
let a1 = self.params.full_round_half_len();
let a2 = a1 + self.params.partial_round_len();
@ -180,8 +184,7 @@ impl<E> State<E>
where
E: Engine,
{
pub fn new(input: Vec<num::AllocatedNum<E>>) -> Self {
let elements = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect();
pub fn new(elements: Vec<Element<E>>) -> Self {
Self { elements }
}
@ -198,6 +201,7 @@ where
fn sbox<CS: ConstraintSystem<E>>(&mut self, mut cs: CS, ctx: &mut RoundCtx<E>) -> Result<(), SynthesisError> {
assert_eq!(ctx.width(), self.elements.len());
for i in 0..if ctx.is_full_round() { ctx.width() } else { 1 } {
let round_constant = ctx.round_constant(i);
let si = {
@ -244,41 +248,49 @@ where
ctx: &mut RoundCtx<E>,
) -> Result<(), SynthesisError> {
assert_eq!(ctx.width(), self.elements.len());
let mut new_state: Vec<num::Num<E>> = Vec::new();
let w = ctx.width();
for i in 0..w {
let row = ctx.mds_matrix_row(i);
let mut acc = num::Num::<E>::zero();
for j in 0..w {
let mut r = self.elements[j].num();
r.scale(row[j]);
acc.add_assign(&r);
if !ctx.is_last_round() {
// skip mds multiplication in last round
let mut new_state: Vec<num::Num<E>> = Vec::new();
let w = ctx.width();
for i in 0..w {
let row = ctx.mds_matrix_row(i);
let mut acc = num::Num::<E>::zero();
for j in 0..w {
let mut r = self.elements[j].num();
r.scale(row[j]);
acc.add_assign(&r);
}
new_state.push(acc);
}
new_state.push(acc);
}
// round ends here
let is_full_round = ctx.is_full_round();
let in_transition = ctx.in_transition();
ctx.round_end();
// round ends here
let is_full_round = ctx.is_full_round();
let in_transition = ctx.in_transition();
ctx.round_end();
// add round constants just after mds if
// first full round has just ended
// or in partial rounds expect the last one.
if in_transition == is_full_round {
// add round constants for elements in {1, t}
let round_constants = ctx.round_constants();
for i in 1..w {
let mut constant_as_num = num::Num::<E>::zero();
constant_as_num =
constant_as_num.add_bool_with_coeff(CS::one(), &boolean::Boolean::Constant(true), round_constants[i]);
new_state[i].add_assign(&constant_as_num);
// add round constants just after mds if
// first full round has just ended
// or in partial rounds expect the last one.
if in_transition == is_full_round {
// add round constants for elements in {1, t}
let round_constants = ctx.round_constants();
for i in 1..w {
let mut constant_as_num = num::Num::<E>::zero();
constant_as_num =
constant_as_num.add_bool_with_coeff(CS::one(), &boolean::Boolean::Constant(true), round_constants[i]);
new_state[i].add_assign(&constant_as_num);
}
}
}
for (s0, s1) in self.elements.iter_mut().zip(new_state) {
s0.update_with_num(s1);
for (s0, s1) in self.elements.iter_mut().zip(new_state) {
s0.update_with_num(s1);
}
} else {
// terminates hades
ctx.round_end();
}
Ok(())
}
@ -288,16 +300,25 @@ impl<E> PoseidonCircuit<E>
where
E: Engine,
{
fn new(params: PoseidonParams<E>) -> Self {
pub fn new(params: PoseidonParams<E>) -> Self {
Self { params: params }
}
pub fn width(&self) -> usize {
self.params.width()
}
pub fn alloc<CS: ConstraintSystem<E>>(
&mut self,
&self,
mut cs: CS,
input: Vec<num::AllocatedNum<E>>,
) -> Result<num::AllocatedNum<E>, SynthesisError> {
let mut state = State::new(input);
assert!(input.len() < self.params.width());
let mut elements: Vec<Element<E>> = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect();
elements.resize(self.width(), Element::new_from_num(num::Num::zero()));
let mut state = State::new(elements);
let mut ctx = RoundCtx::new(&self.params);
loop {
match ctx.round_type() {
@ -324,7 +345,7 @@ fn test_poseidon_circuit() {
let mut cs = TestConstraintSystem::<Bn256>::new();
let params = PoseidonParams::default();
let inputs: Vec<Fr> = ["1", "2", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
let inputs: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
let allocated_inputs = inputs
.clone()
.into_iter()
@ -336,16 +357,15 @@ fn test_poseidon_circuit() {
.collect();
let mut circuit = PoseidonCircuit::<Bn256>::new(params.clone());
let res = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap();
let val = res.get_value().unwrap();
let res_allocated = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap();
let result = res_allocated.get_value().unwrap();
let mut poseidon = PoseidonHasher::new(params.clone());
let expected = poseidon.hash(inputs);
assert_eq!(val, expected);
assert_eq!(result, expected);
assert!(cs.is_satisfied());
println!(
"number of constraints for\nt {}, rf {}, rp {}\n{}",
"number of constraints for (t: {}, rf: {}, rp: {}), {}",
params.width(),
params.full_round_half_len() * 2,
params.partial_round_len(),

View File

@ -12,6 +12,13 @@ pub struct PoseidonParams<E: Engine> {
mds_matrix: Vec<E::Fr>,
}
#[derive(Clone)]
pub struct Poseidon<E: Engine> {
state: Vec<E::Fr>,
round: usize,
params: PoseidonParams<E>,
}
impl<E: Engine> PoseidonParams<E> {
pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec<E::Fr>, mds_matrix: Vec<E::Fr>) -> PoseidonParams<E> {
assert_eq!((rf + rp) * t, round_constants.len());
@ -27,9 +34,8 @@ impl<E: Engine> PoseidonParams<E> {
pub fn default() -> PoseidonParams<E> {
let (t, rf, rp) = (3usize, 8usize, 55usize);
let seed = b"".to_vec();
let person_mds_matrix = b"drlnhdsm";
let round_constants = PoseidonParams::<E>::generate_constants(b"drlnhdsc", seed.clone(), (rf + rp) * t);
let mds_matrix = PoseidonParams::<E>::generate_mds_matrix(person_mds_matrix, seed.clone(), t);
let mds_matrix = PoseidonParams::<E>::generate_mds_matrix(b"drlnhdsm", seed.clone(), t);
PoseidonParams::new(rf, rp, t, round_constants, mds_matrix)
}
@ -103,12 +109,6 @@ impl<E: Engine> PoseidonParams<E> {
}
}
pub struct Poseidon<E: Engine> {
state: Vec<E::Fr>,
round: usize,
params: PoseidonParams<E>,
}
impl<E: Engine> Poseidon<E> {
pub fn new_with_params(
rf: usize,
@ -169,7 +169,11 @@ impl<E: Engine> Poseidon<E> {
} else if round >= a1 && round < a2 {
self.partial_round(round);
} else if round >= a2 && round < a3 {
self.full_round(round);
if round == a3 - 1 {
self.full_round_last();
} else {
self.full_round(round);
}
} else {
panic!("should not be here")
}
@ -181,6 +185,12 @@ impl<E: Engine> Poseidon<E> {
self.mul_mds_matrix();
}
fn full_round_last(&mut self) {
let last_round = self.params.total_rounds() - 1;
self.add_round_constants(last_round);
self.apply_quintic_sbox(true);
}
fn partial_round(&mut self, round: usize) {
self.add_round_constants(round);
self.apply_quintic_sbox(false);
@ -230,5 +240,6 @@ fn test_poseidon_hash() {
let r1: Fr = hasher.hash(input1.to_vec());
let input2: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
let r2: Fr = hasher.hash(input2.to_vec());
println!("{}", r2);
assert_eq!(r1, r2, "just to see if internal state resets");
}