diff --git a/src/circuit/poseidon.rs b/src/circuit/poseidon.rs index 94f6411..76de7d1 100644 --- a/src/circuit/poseidon.rs +++ b/src/circuit/poseidon.rs @@ -3,7 +3,6 @@ use sapling_crypto::bellman::pairing::ff::{Field, PrimeField, PrimeFieldRepr}; use sapling_crypto::bellman::pairing::Engine; use sapling_crypto::bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError}; use sapling_crypto::circuit::{boolean, ecc, num, Assignment}; -// use sapling_crypto::jubjub::JubjubEngine; #[derive(Clone)] struct Element @@ -35,6 +34,7 @@ where elements: Vec>, } +#[derive(Clone)] pub struct PoseidonCircuit where E: Engine, @@ -136,6 +136,10 @@ where } } + pub fn is_last_round(&self) -> bool { + self.number == self.params.total_rounds() - 1 + } + pub fn in_transition(&self) -> bool { let a1 = self.params.full_round_half_len(); let a2 = a1 + self.params.partial_round_len(); @@ -180,8 +184,7 @@ impl State where E: Engine, { - pub fn new(input: Vec>) -> Self { - let elements = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect(); + pub fn new(elements: Vec>) -> Self { Self { elements } } @@ -198,6 +201,7 @@ where fn sbox>(&mut self, mut cs: CS, ctx: &mut RoundCtx) -> Result<(), SynthesisError> { assert_eq!(ctx.width(), self.elements.len()); + for i in 0..if ctx.is_full_round() { ctx.width() } else { 1 } { let round_constant = ctx.round_constant(i); let si = { @@ -244,41 +248,49 @@ where ctx: &mut RoundCtx, ) -> Result<(), SynthesisError> { assert_eq!(ctx.width(), self.elements.len()); - let mut new_state: Vec> = Vec::new(); - let w = ctx.width(); - for i in 0..w { - let row = ctx.mds_matrix_row(i); - let mut acc = num::Num::::zero(); - for j in 0..w { - let mut r = self.elements[j].num(); - r.scale(row[j]); - acc.add_assign(&r); + if !ctx.is_last_round() { + // skip mds multiplication in last round + + let mut new_state: Vec> = Vec::new(); + let w = ctx.width(); + + for i in 0..w { + let row = ctx.mds_matrix_row(i); + let mut acc = num::Num::::zero(); + for j in 0..w { + let mut r = self.elements[j].num(); + r.scale(row[j]); + acc.add_assign(&r); + } + new_state.push(acc); } - new_state.push(acc); - } - // round ends here - let is_full_round = ctx.is_full_round(); - let in_transition = ctx.in_transition(); - ctx.round_end(); + // round ends here + let is_full_round = ctx.is_full_round(); + let in_transition = ctx.in_transition(); + ctx.round_end(); - // add round constants just after mds if - // first full round has just ended - // or in partial rounds expect the last one. - if in_transition == is_full_round { - // add round constants for elements in {1, t} - let round_constants = ctx.round_constants(); - for i in 1..w { - let mut constant_as_num = num::Num::::zero(); - constant_as_num = - constant_as_num.add_bool_with_coeff(CS::one(), &boolean::Boolean::Constant(true), round_constants[i]); - new_state[i].add_assign(&constant_as_num); + // add round constants just after mds if + // first full round has just ended + // or in partial rounds expect the last one. + if in_transition == is_full_round { + // add round constants for elements in {1, t} + let round_constants = ctx.round_constants(); + for i in 1..w { + let mut constant_as_num = num::Num::::zero(); + constant_as_num = + constant_as_num.add_bool_with_coeff(CS::one(), &boolean::Boolean::Constant(true), round_constants[i]); + new_state[i].add_assign(&constant_as_num); + } } - } - for (s0, s1) in self.elements.iter_mut().zip(new_state) { - s0.update_with_num(s1); + for (s0, s1) in self.elements.iter_mut().zip(new_state) { + s0.update_with_num(s1); + } + } else { + // terminates hades + ctx.round_end(); } Ok(()) } @@ -288,16 +300,25 @@ impl PoseidonCircuit where E: Engine, { - fn new(params: PoseidonParams) -> Self { + pub fn new(params: PoseidonParams) -> Self { Self { params: params } } + pub fn width(&self) -> usize { + self.params.width() + } + pub fn alloc>( - &mut self, + &self, mut cs: CS, input: Vec>, ) -> Result, SynthesisError> { - let mut state = State::new(input); + assert!(input.len() < self.params.width()); + + let mut elements: Vec> = input.iter().map(|el| Element::new_from_alloc(el.clone())).collect(); + elements.resize(self.width(), Element::new_from_num(num::Num::zero())); + + let mut state = State::new(elements); let mut ctx = RoundCtx::new(&self.params); loop { match ctx.round_type() { @@ -324,7 +345,7 @@ fn test_poseidon_circuit() { let mut cs = TestConstraintSystem::::new(); let params = PoseidonParams::default(); - let inputs: Vec = ["1", "2", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); + let inputs: Vec = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); let allocated_inputs = inputs .clone() .into_iter() @@ -336,16 +357,15 @@ fn test_poseidon_circuit() { .collect(); let mut circuit = PoseidonCircuit::::new(params.clone()); - let res = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap(); - let val = res.get_value().unwrap(); - + let res_allocated = circuit.alloc(cs.namespace(|| "hash alloc"), allocated_inputs).unwrap(); + let result = res_allocated.get_value().unwrap(); let mut poseidon = PoseidonHasher::new(params.clone()); let expected = poseidon.hash(inputs); - assert_eq!(val, expected); + assert_eq!(result, expected); assert!(cs.is_satisfied()); println!( - "number of constraints for\nt {}, rf {}, rp {}\n{}", + "number of constraints for (t: {}, rf: {}, rp: {}), {}", params.width(), params.full_round_half_len() * 2, params.partial_round_len(), diff --git a/src/poseidon.rs b/src/poseidon.rs index 41a3b82..71cd035 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -12,6 +12,13 @@ pub struct PoseidonParams { mds_matrix: Vec, } +#[derive(Clone)] +pub struct Poseidon { + state: Vec, + round: usize, + params: PoseidonParams, +} + impl PoseidonParams { pub fn new(rf: usize, rp: usize, t: usize, round_constants: Vec, mds_matrix: Vec) -> PoseidonParams { assert_eq!((rf + rp) * t, round_constants.len()); @@ -27,9 +34,8 @@ impl PoseidonParams { pub fn default() -> PoseidonParams { let (t, rf, rp) = (3usize, 8usize, 55usize); let seed = b"".to_vec(); - let person_mds_matrix = b"drlnhdsm"; let round_constants = PoseidonParams::::generate_constants(b"drlnhdsc", seed.clone(), (rf + rp) * t); - let mds_matrix = PoseidonParams::::generate_mds_matrix(person_mds_matrix, seed.clone(), t); + let mds_matrix = PoseidonParams::::generate_mds_matrix(b"drlnhdsm", seed.clone(), t); PoseidonParams::new(rf, rp, t, round_constants, mds_matrix) } @@ -103,12 +109,6 @@ impl PoseidonParams { } } -pub struct Poseidon { - state: Vec, - round: usize, - params: PoseidonParams, -} - impl Poseidon { pub fn new_with_params( rf: usize, @@ -169,7 +169,11 @@ impl Poseidon { } else if round >= a1 && round < a2 { self.partial_round(round); } else if round >= a2 && round < a3 { - self.full_round(round); + if round == a3 - 1 { + self.full_round_last(); + } else { + self.full_round(round); + } } else { panic!("should not be here") } @@ -181,6 +185,12 @@ impl Poseidon { self.mul_mds_matrix(); } + fn full_round_last(&mut self) { + let last_round = self.params.total_rounds() - 1; + self.add_round_constants(last_round); + self.apply_quintic_sbox(true); + } + fn partial_round(&mut self, round: usize) { self.add_round_constants(round); self.apply_quintic_sbox(false); @@ -230,5 +240,6 @@ fn test_poseidon_hash() { let r1: Fr = hasher.hash(input1.to_vec()); let input2: Vec = ["0", "0"].iter().map(|e| Fr::from_str(e).unwrap()).collect(); let r2: Fr = hasher.hash(input2.to_vec()); + println!("{}", r2); assert_eq!(r1, r2, "just to see if internal state resets"); }