skip last mds mul

This commit is contained in:
kilic 2020-05-01 15:34:36 +03:00
parent d639d64aaf
commit c1130234ef
2 changed files with 81 additions and 50 deletions

View File

@ -3,7 +3,6 @@ use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr};
use sapling_crypto::bellman::pairing::Engine; use sapling_crypto::bellman::pairing::Engine;
use sapling_crypto::bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError}; use sapling_crypto::bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError};
use sapling_crypto::circuit::{boolean, ecc, num, Assignment}; use sapling_crypto::circuit::{boolean, ecc, num, Assignment};
// use sapling_crypto::jubjub::JubjubEngine;
#[derive(Clone)] #[derive(Clone)]
struct Element<E> struct Element<E>
@ -35,6 +34,7 @@ where
elements: Vec<Element<E>>, elements: Vec<Element<E>>,
} }
#[derive(Clone)]
pub struct PoseidonCircuit<E> pub struct PoseidonCircuit<E>
where where
E: Engine, E: Engine,
@ -136,6 +136,10 @@ where
} }
} }
pub fn is_last_round(&self) -> bool {
self.number == self.params.total_rounds() - 1
}
pub fn in_transition(&self) -> bool { pub fn in_transition(&self) -> bool {
let a1 = self.params.full_round_half_len(); let a1 = self.params.full_round_half_len();
let a2 = a1 + self.params.partial_round_len(); let a2 = a1 + self.params.partial_round_len();
@ -180,8 +184,7 @@ impl<E> State<E>
where where
E: Engine, E: Engine,
{ {
pub fn new(input: Vec<num::AllocatedNum<E>>) -> Self { pub fn new(elements: Vec<Element<E>>) -> Self {
let elements = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect();
Self { elements } Self { elements }
} }
@ -198,6 +201,7 @@ where
fn sbox<CS: ConstraintSystem<E>>(&mut self, mut cs: CS, ctx: &mut RoundCtx<E>) -> Result<(), SynthesisError> { fn sbox<CS: ConstraintSystem<E>>(&mut self, mut cs: CS, ctx: &mut RoundCtx<E>) -> Result<(), SynthesisError> {
assert_eq!(ctx.width(), self.elements.len()); assert_eq!(ctx.width(), self.elements.len());
for i in 0..if ctx.is_full_round() { ctx.width() } else { 1 } { for i in 0..if ctx.is_full_round() { ctx.width() } else { 1 } {
let round_constant = ctx.round_constant(i); let round_constant = ctx.round_constant(i);
let si = { let si = {
@ -244,6 +248,10 @@ where
ctx: &mut RoundCtx<E>, ctx: &mut RoundCtx<E>,
) -> Result<(), SynthesisError> { ) -> Result<(), SynthesisError> {
assert_eq!(ctx.width(), self.elements.len()); assert_eq!(ctx.width(), self.elements.len());
if !ctx.is_last_round() {
// skip mds multiplication in last round
let mut new_state: Vec<num::Num<E>> = Vec::new(); let mut new_state: Vec<num::Num<E>> = Vec::new();
let w = ctx.width(); let w = ctx.width();
@ -280,6 +288,10 @@ where
for (s0, s1) in self.elements.iter_mut().zip(new_state) { for (s0, s1) in self.elements.iter_mut().zip(new_state) {
s0.update_with_num(s1); s0.update_with_num(s1);
} }
} else {
// terminates hades
ctx.round_end();
}
Ok(()) Ok(())
} }
} }
@ -288,16 +300,25 @@ impl<E> PoseidonCircuit<E>
where where
E: Engine, E: Engine,
{ {
fn new(params: PoseidonParams<E>) -> Self { pub fn new(params: PoseidonParams<E>) -> Self {
Self { params: params } Self { params: params }
} }
pub fn width(&self) -> usize {
self.params.width()
}
pub fn alloc<CS: ConstraintSystem<E>>( pub fn alloc<CS: ConstraintSystem<E>>(
&mut self, &self,
mut cs: CS, mut cs: CS,
input: Vec<num::AllocatedNum<E>>, input: Vec<num::AllocatedNum<E>>,
) -> Result<num::AllocatedNum<E>, SynthesisError> { ) -> Result<num::AllocatedNum<E>, SynthesisError> {
let mut state = State::new(input); assert!(input.len() < self.params.width());
let mut elements: Vec<Element<E>> = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect();
elements.resize(self.width(), Element::new_from_num(num::Num::zero()));
let mut state = State::new(elements);
let mut ctx = RoundCtx::new(&self.params); let mut ctx = RoundCtx::new(&self.params);
loop { loop {
match ctx.round_type() { match ctx.round_type() {
@ -324,7 +345,7 @@ fn test_poseidon_circuit() {
let mut cs = TestConstraintSystem::<Bn256>::new(); let mut cs = TestConstraintSystem::<Bn256>::new();
let params = PoseidonParams::default(); let params = PoseidonParams::default();
let inputs: Vec<Fr> = ["1", "2", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); let inputs: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
let allocated_inputs = inputs let allocated_inputs = inputs
.clone() .clone()
.into_iter() .into_iter()
@ -336,16 +357,15 @@ fn test_poseidon_circuit() {
.collect(); .collect();
let mut circuit = PoseidonCircuit::<Bn256>::new(params.clone()); let mut circuit = PoseidonCircuit::<Bn256>::new(params.clone());
let res = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap(); let res_allocated = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap();
let val = res.get_value().unwrap(); let result = res_allocated.get_value().unwrap();
let mut poseidon = PoseidonHasher::new(params.clone()); let mut poseidon = PoseidonHasher::new(params.clone());
let expected = poseidon.hash(inputs); let expected = poseidon.hash(inputs);
assert_eq!(val, expected); assert_eq!(result, expected);
assert!(cs.is_satisfied()); assert!(cs.is_satisfied());
println!( println!(
"number of constraints for\nt {}, rf {}, rp {}\n{}", "number of constraints for (t: {}, rf: {}, rp: {}), {}",
params.width(), params.width(),
params.full_round_half_len() * 2, params.full_round_half_len() * 2,
params.partial_round_len(), params.partial_round_len(),

View File

@ -12,6 +12,13 @@ pub struct PoseidonParams<E: Engine> {
mds_matrix: Vec<E::Fr>, mds_matrix: Vec<E::Fr>,
} }
#[derive(Clone)]
pub struct Poseidon<E: Engine> {
state: Vec<E::Fr>,
round: usize,
params: PoseidonParams<E>,
}
impl<E: Engine> PoseidonParams<E> { impl<E: Engine> PoseidonParams<E> {
pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec<E::Fr>, mds_matrix: Vec<E::Fr>) -> PoseidonParams<E> { pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec<E::Fr>, mds_matrix: Vec<E::Fr>) -> PoseidonParams<E> {
assert_eq!((rf + rp) * t, round_constants.len()); assert_eq!((rf + rp) * t, round_constants.len());
@ -27,9 +34,8 @@ impl<E: Engine> PoseidonParams<E> {
pub fn default() -> PoseidonParams<E> { pub fn default() -> PoseidonParams<E> {
let (t, rf, rp) = (3usize, 8usize, 55usize); let (t, rf, rp) = (3usize, 8usize, 55usize);
let seed = b"".to_vec(); let seed = b"".to_vec();
let person_mds_matrix = b"drlnhdsm";
let round_constants = PoseidonParams::<E>::generate_constants(b"drlnhdsc", seed.clone(), (rf + rp) * t); let round_constants = PoseidonParams::<E>::generate_constants(b"drlnhdsc", seed.clone(), (rf + rp) * t);
let mds_matrix = PoseidonParams::<E>::generate_mds_matrix(person_mds_matrix, seed.clone(), t); let mds_matrix = PoseidonParams::<E>::generate_mds_matrix(b"drlnhdsm", seed.clone(), t);
PoseidonParams::new(rf, rp, t, round_constants, mds_matrix) PoseidonParams::new(rf, rp, t, round_constants, mds_matrix)
} }
@ -103,12 +109,6 @@ impl<E: Engine> PoseidonParams<E> {
} }
} }
pub struct Poseidon<E: Engine> {
state: Vec<E::Fr>,
round: usize,
params: PoseidonParams<E>,
}
impl<E: Engine> Poseidon<E> { impl<E: Engine> Poseidon<E> {
pub fn new_with_params( pub fn new_with_params(
rf: usize, rf: usize,
@ -169,7 +169,11 @@ impl<E: Engine> Poseidon<E> {
} else if round >= a1 && round < a2 { } else if round >= a1 && round < a2 {
self.partial_round(round); self.partial_round(round);
} else if round >= a2 && round < a3 { } else if round >= a2 && round < a3 {
if round == a3 - 1 {
self.full_round_last();
} else {
self.full_round(round); self.full_round(round);
}
} else { } else {
panic!("should not be here") panic!("should not be here")
} }
@ -181,6 +185,12 @@ impl<E: Engine> Poseidon<E> {
self.mul_mds_matrix(); self.mul_mds_matrix();
} }
fn full_round_last(&mut self) {
let last_round = self.params.total_rounds() - 1;
self.add_round_constants(last_round);
self.apply_quintic_sbox(true);
}
fn partial_round(&mut self, round: usize) { fn partial_round(&mut self, round: usize) {
self.add_round_constants(round); self.add_round_constants(round);
self.apply_quintic_sbox(false); self.apply_quintic_sbox(false);
@ -230,5 +240,6 @@ fn test_poseidon_hash() {
let r1: Fr = hasher.hash(input1.to_vec()); let r1: Fr = hasher.hash(input1.to_vec());
let input2: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); let input2: Vec<Fr> = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect();
let r2: Fr = hasher.hash(input2.to_vec()); let r2: Fr = hasher.hash(input2.to_vec());
println!("{}", r2);
assert_eq!(r1, r2, "just to see if internal state resets"); assert_eq!(r1, r2, "just to see if internal state resets");
} }