From 85505b3104642b9a3d1a114d6c45673b1bdfc359 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Fri, 31 May 2024 14:28:20 +0530 Subject: [PATCH] round functions for keccak --- .../src/arithmetic/binary_arithmetic.rs | 2 + hash/plonky2/src/arithmetic/u32_arithmetic.rs | 18 +- hash/plonky2/src/arithmetic/u64_arithmetic.rs | 8 + hash/plonky2/src/bench/keccak.rs | 228 +++++++++++------- 4 files changed, 158 insertions(+), 98 deletions(-) diff --git a/hash/plonky2/src/arithmetic/binary_arithmetic.rs b/hash/plonky2/src/arithmetic/binary_arithmetic.rs index ffdde19..89c95b0 100644 --- a/hash/plonky2/src/arithmetic/binary_arithmetic.rs +++ b/hash/plonky2/src/arithmetic/binary_arithmetic.rs @@ -3,6 +3,8 @@ use plonky2::field::extension::Extendable; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::iop::target::BoolTarget; +//TODO: remove the dead codes later +#[allow(dead_code)] pub trait CircuitBuilderBoolTarget, const D: usize> { fn and(&mut self, a: BoolTarget, b: BoolTarget) -> BoolTarget; fn or(&mut self, a: BoolTarget, b: BoolTarget) -> BoolTarget; diff --git a/hash/plonky2/src/arithmetic/u32_arithmetic.rs b/hash/plonky2/src/arithmetic/u32_arithmetic.rs index f6c4676..cacbdea 100644 --- a/hash/plonky2/src/arithmetic/u32_arithmetic.rs +++ b/hash/plonky2/src/arithmetic/u32_arithmetic.rs @@ -6,6 +6,8 @@ use plonky2_u32::gadgets::arithmetic_u32::U32Target; use super::binary_arithmetic::CircuitBuilderBoolTarget; use plonky2_u32::gadgets::arithmetic_u32::CircuitBuilderU32; +//TODO: remove the dead codes later +#[allow(dead_code)] pub trait CircuitBuilderU32M, const D: usize> { fn or_u32(&mut self, a: U32Target, b: U32Target) -> U32Target; fn and_u32(&mut self, a: U32Target, b: U32Target) -> U32Target; @@ -15,16 +17,14 @@ pub trait CircuitBuilderU32M, const D: usize> { fn from_u32(&mut self, a: U32Target) -> Vec; fn to_u32(&mut self, a: Vec) -> U32Target; - // fn constant_u32(&mut self, c: u32) -> U32Target; + // not := 0xFFFFFFFF - x + fn not_u32(&mut self, a: U32Target) -> U32Target; + } impl, const D: usize> CircuitBuilderU32M for CircuitBuilder{ - // fn constant_u32(&mut self, c: u32) -> U32Target { - // U32Target(self.constant(F::from_canonical_u32(c))) - // } - fn from_u32(&mut self, a: U32Target) -> Vec { let mut res = Vec::new(); @@ -88,4 +88,12 @@ impl, const D: usize> CircuitBuilderU32M let (lo, hi) = self.mul_u32(a, two_power_n); self.add_u32(lo, hi).0 } + + // not := 0xFFFFFFFF - x + fn not_u32(&mut self, a: U32Target) -> U32Target { + let zero = self.zero_u32(); + let ff = self.constant_u32(0xFFFFFFFF); + self.sub_u32(ff, a, zero).0 + } + } \ No newline at end of file diff --git a/hash/plonky2/src/arithmetic/u64_arithmetic.rs b/hash/plonky2/src/arithmetic/u64_arithmetic.rs index 9d3f0b7..c33d2be 100644 --- a/hash/plonky2/src/arithmetic/u64_arithmetic.rs +++ b/hash/plonky2/src/arithmetic/u64_arithmetic.rs @@ -8,12 +8,16 @@ use plonky2_u32::gadgets::arithmetic_u32::CircuitBuilderU32; #[derive(Clone, Copy, Debug)] pub struct U64Target(pub [U32Target;2]); +//TODO: remove the dead codes later +#[allow(dead_code)] pub trait CircuitBuilderU64, const D: usize> { fn and_u64(&mut self, a: U64Target, b: U64Target) -> U64Target; fn xor_u64(&mut self, a: U64Target, b: U64Target) -> U64Target; fn rotate_left_u64(&mut self, a: U64Target, n: u8) -> U64Target; fn zero_u64(&mut self) -> U64Target; + + fn not_u64(&mut self, a: U64Target) -> U64Target; } impl, const D: usize> CircuitBuilderU64 @@ -49,4 +53,8 @@ impl, const D: usize> CircuitBuilderU64 let zero_u32 = self.zero_u32(); U64Target([zero_u32,zero_u32]) } + + fn not_u64(&mut self, a: U64Target) -> U64Target { + U64Target([self.not_u32(a.0[0]), self.not_u32(a.0[1])]) + } } \ No newline at end of file diff --git a/hash/plonky2/src/bench/keccak.rs b/hash/plonky2/src/bench/keccak.rs index 5822815..4898ba3 100644 --- a/hash/plonky2/src/bench/keccak.rs +++ b/hash/plonky2/src/bench/keccak.rs @@ -8,7 +8,7 @@ use rand::Rng; use plonky2::field::extension::Extendable; use plonky2::plonk::config::Hasher; use plonky2::hash::hash_types::RichField; - +use plonky2_u32::gadgets::arithmetic_u32::CircuitBuilderU32; use crate::arithmetic::u64_arithmetic::U64Target; use crate::arithmetic::u64_arithmetic::CircuitBuilderU64; @@ -36,7 +36,7 @@ pub fn keccak_bench(_size: usize) { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let initial = builder.add_virtual_targets(data.len()); + let _initial = builder.add_virtual_targets(data.len()); let hash = KeccakHash::<32>::hash_or_noop(&data); eprintln!("{:?}", hash); @@ -46,103 +46,145 @@ pub fn keccak_bench(_size: usize) { //---------------------------------------------------------- + // const KECCAK_WIDTH: usize = 1600; // const KECCAK_RATE: usize = 1088; // const KECCAK_CAPACITY: usize = KECCAK_WIDTH - KECCAK_RATE; // const KECCAK_LANES: usize = KECCAK_WIDTH / 64; -// const KECCAK_ROUNDS: usize = 24; +const KECCAK_ROUNDS: usize = 24; -// const ROUND_CONSTANTS: [u64; KECCAK_ROUNDS] = [ -// 0x0000000000000001, 0x0000000000008082, 0x800000000000808A, 0x8000000080008000, -// 0x000000000000808B, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, -// 0x000000000000008A, 0x0000000000000088, 0x0000000080008009, 0x000000008000000A, -// 0x000000008000808B, 0x800000000000008B, 0x8000000000008089, 0x8000000000008003, -// 0x8000000000008002, 0x8000000000000080, 0x000000000000800A, 0x800000008000000A, -// 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008, -// ]; - -// fn initialize_state() -> [u64; KECCAK_LANES] { -// [0; KECCAK_LANES] -// } -// pub struct U64Target([U32Target;2]); - -// // copied from sha256 circuit -// // TODO: move to some common place -// pub fn u32_to_bits_target, const D: usize, const B: usize>( -// builder: &mut CircuitBuilder, -// a: &U32Target, -// ) -> Vec { -// let mut res = Vec::new(); -// let bit_targets = builder.split_le_base::(a.0, 32); -// for i in (0..32).rev() { -// res.push(BoolTarget::new_unsafe(bit_targets[i])); -// } -// res -// } - -// // copied from sha256 circuit -// // TODO: move to some common place -// pub fn bits_to_u32_target, const D: usize>( -// builder: &mut CircuitBuilder, -// bits_target: Vec, -// ) -> U32Target { -// let bit_len = bits_target.len(); -// assert_eq!(bit_len, 32); -// U32Target(builder.le_sum(bits_target[0..32].iter().rev())) -// } - -// //TODO: not tested -// pub fn xor_u64, const D: usize>( -// builder: &mut CircuitBuilder, -// x: U64Target, -// y: U64Target, -// ) -> U64Target { -// let xor_x0_y0 = xor_u32(builder, x.0[0], y.0[0]); -// let xor_x1_y1 = xor_u32(builder, x.0[1], y.0[1]); - -// U64Target([xor_x0_y0,xor_x1_y1]) - -// } - -// pub fn xor_u32, const D: usize>( -// builder: &mut CircuitBuilder, -// x: U32Target, -// y: U32Target, -// ) -> U32Target { - -// let bits_target_x = u32_to_bits_target::(builder, &x); -// let bits_target_y = u32_to_bits_target::(builder, &y); - -// assert_eq!(bits_target_x.len(), bits_target_y.len()); - -// let mut xor_result_final = Vec::::new(); -// for i in 0..bits_target_x.len() { -// let a_plus_b = builder.add(bits_target_x.get(i).unwrap().target, bits_target_y.get(i).unwrap().target); -// let ab = builder.mul(bits_target_x.get(i).unwrap().target, bits_target_y.get(i).unwrap().target); -// let two_ab = builder.mul_const(F::from_canonical_u64(2), ab); -// let xor_result = builder.sub(a_plus_b, two_ab); -// xor_result_final.push(BoolTarget::new_unsafe(xor_result)); -// } -// let result = bits_to_u32_target(builder, xor_result_final); -// result - -// } +//TODO: remove the dead codes later +#[allow(dead_code)] +const ROUND_CONSTANTS: [u64; KECCAK_ROUNDS] = [ + 0x0000000000000001, 0x0000000000008082, 0x800000000000808A, 0x8000000080008000, + 0x000000000000808B, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, + 0x000000000000008A, 0x0000000000000088, 0x0000000080008009, 0x000000008000000A, + 0x000000008000808B, 0x800000000000008B, 0x8000000000008089, 0x8000000000008003, + 0x8000000000008002, 0x8000000000000080, 0x000000000000800A, 0x800000008000000A, + 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008, +]; +//TODO: remove the dead codes later +#[allow(dead_code)] // Theta -// pub fn theta, const D: usize>( -// builder: &mut CircuitBuilder, -// state: &mut [U64Target; KECCAK_LANES] -// ) { +pub fn theta, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [[U64Target; 5]; 5] +){ + let mut c = [builder.zero_u64(); 5]; + let mut d = [builder.zero_u64(); 5]; -// let mut c = [0u64; 5]; -// for x in 0..5 { -// c[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20]; -// } -// for x in 0..5 { -// let d = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1); -// for y in 0..5 { -// state[x + 5 * y] ^= d; -// } -// } + // Compute column parities + for x in 0..5 { -// } \ No newline at end of file + let xor_x0_x1 = builder.xor_u64(state[x][0], state[x][1]); + let xor_x0_x1_x2 = builder.xor_u64(xor_x0_x1, state[x][2]); + let xor_x0_x1_x2_x3 = builder.xor_u64(xor_x0_x1_x2, state[x][3]); + c[x] = builder.xor_u64(xor_x0_x1_x2_x3, state[x][4]); + + } + + // Compute rotated parities + for x in 0..5 { + let c_left = c[(x + 4) % 5]; + let c_right_rot = builder.rotate_left_u64(c[(x + 1) % 5], 1); + d[x] = builder.xor_u64(c_left, c_right_rot); + } + + // Modify the state + for x in 0..5 { + for y in 0..5 { + state[x][y] = builder.xor_u64(state[x][y], d[x]); + } + } +} + +//TODO: remove the dead codes later +#[allow(dead_code)] +//rho +fn rho, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [[U64Target; 5]; 5] +){ + const RHO_OFFSETS: [[usize; 5]; 5] = [ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], + ]; + + for x in 0..5 { + for y in 0..5 { + let rotation = RHO_OFFSETS[x][y]; + state[x][y] = builder.rotate_left_u64(state[x][y], rotation as u8); + } + } +} + +//TODO: remove the dead codes later +#[allow(dead_code)] +//pi +fn pi, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [[U64Target; 5]; 5] +){ + let mut new_state = [[builder.zero_u64(); 5]; 5]; + for x in 0..5 { + for y in 0..5 { + new_state[(2 * x + 3 * y) % 5][y] = state[x][y]; + } + } + *state = new_state; +} + +//TODO: remove the dead codes later +#[allow(dead_code)] +//iota +fn iota, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [[U64Target; 5]; 5], + round: usize +){ + + let lo = builder.constant_u32((ROUND_CONSTANTS[round] & 0xFFFFFFFF) as u32); + let hi = builder.constant_u32(((ROUND_CONSTANTS[round] >> 32)& 0xFFFFFFFF) as u32); + state[0][0] = builder.xor_u64(state[0][0], U64Target([lo,hi])) ; +} + +//TODO: remove the dead codes later +#[allow(dead_code)] +fn chi, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [[U64Target; 5]; 5] +){ + for x in 0..5 { + let mut temp = [builder.zero_u64(); 5]; + for y in 0..5 { + temp[y] = state[x][y]; + } + + + for y in 0..5 { + let t1 = builder.not_u64(temp[(y + 1) % 5]); + let t2 = builder.and_u64(t1, temp[(y + 2) % 5]); + state[x][y] = builder.xor_u64(state[x][y], t2); + } + } +} + +//TODO: remove the dead codes later +#[allow(dead_code)] +// permutation +fn keccak_permutation, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [[U64Target; 5]; 5] +) { + for i in 0..24 { + theta(builder, state); + rho(builder, state); + pi(builder, state); + chi(builder, state); + iota(builder, state, ROUND_CONSTANTS[i] as usize) + } +} \ No newline at end of file