round functions for keccak

This commit is contained in:
Manish Kumar 2024-05-31 14:28:20 +05:30
parent 1db0d91ed1
commit 85505b3104
4 changed files with 158 additions and 98 deletions

View File

@ -3,6 +3,8 @@ use plonky2::field::extension::Extendable;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::iop::target::BoolTarget;
//TODO: remove the dead codes later
#[allow(dead_code)]
pub trait CircuitBuilderBoolTarget<F: RichField + Extendable<D>, const D: usize> {
fn and(&mut self, a: BoolTarget, b: BoolTarget) -> BoolTarget;
fn or(&mut self, a: BoolTarget, b: BoolTarget) -> BoolTarget;

View File

@ -6,6 +6,8 @@ use plonky2_u32::gadgets::arithmetic_u32::U32Target;
use super::binary_arithmetic::CircuitBuilderBoolTarget;
use plonky2_u32::gadgets::arithmetic_u32::CircuitBuilderU32;
//TODO: remove the dead codes later
#[allow(dead_code)]
pub trait CircuitBuilderU32M<F: RichField + Extendable<D>, const D: usize> {
fn or_u32(&mut self, a: U32Target, b: U32Target) -> U32Target;
fn and_u32(&mut self, a: U32Target, b: U32Target) -> U32Target;
@ -15,16 +17,14 @@ pub trait CircuitBuilderU32M<F: RichField + Extendable<D>, const D: usize> {
fn from_u32(&mut self, a: U32Target) -> Vec<BoolTarget>;
fn to_u32(&mut self, a: Vec<BoolTarget>) -> U32Target;
// fn constant_u32(&mut self, c: u32) -> U32Target;
// not := 0xFFFFFFFF - x
fn not_u32(&mut self, a: U32Target) -> U32Target;
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderU32M<F, D>
for CircuitBuilder<F, D>{
// fn constant_u32(&mut self, c: u32) -> U32Target {
// U32Target(self.constant(F::from_canonical_u32(c)))
// }
fn from_u32(&mut self, a: U32Target) -> Vec<BoolTarget> {
let mut res = Vec::new();
@ -88,4 +88,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderU32M<F, D>
let (lo, hi) = self.mul_u32(a, two_power_n);
self.add_u32(lo, hi).0
}
// not := 0xFFFFFFFF - x
fn not_u32(&mut self, a: U32Target) -> U32Target {
let zero = self.zero_u32();
let ff = self.constant_u32(0xFFFFFFFF);
self.sub_u32(ff, a, zero).0
}
}

View File

@ -8,12 +8,16 @@ use plonky2_u32::gadgets::arithmetic_u32::CircuitBuilderU32;
#[derive(Clone, Copy, Debug)]
pub struct U64Target(pub [U32Target;2]);
//TODO: remove the dead codes later
#[allow(dead_code)]
pub trait CircuitBuilderU64<F: RichField + Extendable<D>, const D: usize> {
fn and_u64(&mut self, a: U64Target, b: U64Target) -> U64Target;
fn xor_u64(&mut self, a: U64Target, b: U64Target) -> U64Target;
fn rotate_left_u64(&mut self, a: U64Target, n: u8) -> U64Target;
fn zero_u64(&mut self) -> U64Target;
fn not_u64(&mut self, a: U64Target) -> U64Target;
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderU64<F, D>
@ -49,4 +53,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderU64<F, D>
let zero_u32 = self.zero_u32();
U64Target([zero_u32,zero_u32])
}
fn not_u64(&mut self, a: U64Target) -> U64Target {
U64Target([self.not_u32(a.0[0]), self.not_u32(a.0[1])])
}
}

View File

@ -8,7 +8,7 @@ use rand::Rng;
use plonky2::field::extension::Extendable;
use plonky2::plonk::config::Hasher;
use plonky2::hash::hash_types::RichField;
use plonky2_u32::gadgets::arithmetic_u32::CircuitBuilderU32;
use crate::arithmetic::u64_arithmetic::U64Target;
use crate::arithmetic::u64_arithmetic::CircuitBuilderU64;
@ -36,7 +36,7 @@ pub fn keccak_bench(_size: usize) {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let initial = builder.add_virtual_targets(data.len());
let _initial = builder.add_virtual_targets(data.len());
let hash = KeccakHash::<32>::hash_or_noop(&data);
eprintln!("{:?}", hash);
@ -46,103 +46,145 @@ pub fn keccak_bench(_size: usize) {
//----------------------------------------------------------
// const KECCAK_WIDTH: usize = 1600;
// const KECCAK_RATE: usize = 1088;
// const KECCAK_CAPACITY: usize = KECCAK_WIDTH - KECCAK_RATE;
// const KECCAK_LANES: usize = KECCAK_WIDTH / 64;
// const KECCAK_ROUNDS: usize = 24;
const KECCAK_ROUNDS: usize = 24;
// const ROUND_CONSTANTS: [u64; KECCAK_ROUNDS] = [
// 0x0000000000000001, 0x0000000000008082, 0x800000000000808A, 0x8000000080008000,
// 0x000000000000808B, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009,
// 0x000000000000008A, 0x0000000000000088, 0x0000000080008009, 0x000000008000000A,
// 0x000000008000808B, 0x800000000000008B, 0x8000000000008089, 0x8000000000008003,
// 0x8000000000008002, 0x8000000000000080, 0x000000000000800A, 0x800000008000000A,
// 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008,
// ];
// fn initialize_state() -> [u64; KECCAK_LANES] {
// [0; KECCAK_LANES]
// }
// pub struct U64Target([U32Target;2]);
// // copied from sha256 circuit
// // TODO: move to some common place
// pub fn u32_to_bits_target<F: RichField + Extendable<D>, const D: usize, const B: usize>(
// builder: &mut CircuitBuilder<F, D>,
// a: &U32Target,
// ) -> Vec<BoolTarget> {
// let mut res = Vec::new();
// let bit_targets = builder.split_le_base::<B>(a.0, 32);
// for i in (0..32).rev() {
// res.push(BoolTarget::new_unsafe(bit_targets[i]));
// }
// res
// }
// // copied from sha256 circuit
// // TODO: move to some common place
// pub fn bits_to_u32_target<F: RichField + Extendable<D>, const D: usize>(
// builder: &mut CircuitBuilder<F, D>,
// bits_target: Vec<BoolTarget>,
// ) -> U32Target {
// let bit_len = bits_target.len();
// assert_eq!(bit_len, 32);
// U32Target(builder.le_sum(bits_target[0..32].iter().rev()))
// }
// //TODO: not tested
// pub fn xor_u64<F: RichField + Extendable<D>, const D: usize>(
// builder: &mut CircuitBuilder<F, D>,
// x: U64Target,
// y: U64Target,
// ) -> U64Target {
// let xor_x0_y0 = xor_u32(builder, x.0[0], y.0[0]);
// let xor_x1_y1 = xor_u32(builder, x.0[1], y.0[1]);
// U64Target([xor_x0_y0,xor_x1_y1])
// }
// pub fn xor_u32<F: RichField + Extendable<D>, const D: usize>(
// builder: &mut CircuitBuilder<F, D>,
// x: U32Target,
// y: U32Target,
// ) -> U32Target {
// let bits_target_x = u32_to_bits_target::<F, D, 2>(builder, &x);
// let bits_target_y = u32_to_bits_target::<F, D, 2>(builder, &y);
// assert_eq!(bits_target_x.len(), bits_target_y.len());
// let mut xor_result_final = Vec::<BoolTarget>::new();
// for i in 0..bits_target_x.len() {
// let a_plus_b = builder.add(bits_target_x.get(i).unwrap().target, bits_target_y.get(i).unwrap().target);
// let ab = builder.mul(bits_target_x.get(i).unwrap().target, bits_target_y.get(i).unwrap().target);
// let two_ab = builder.mul_const(F::from_canonical_u64(2), ab);
// let xor_result = builder.sub(a_plus_b, two_ab);
// xor_result_final.push(BoolTarget::new_unsafe(xor_result));
// }
// let result = bits_to_u32_target(builder, xor_result_final);
// result
// }
//TODO: remove the dead codes later
#[allow(dead_code)]
const ROUND_CONSTANTS: [u64; KECCAK_ROUNDS] = [
0x0000000000000001, 0x0000000000008082, 0x800000000000808A, 0x8000000080008000,
0x000000000000808B, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009,
0x000000000000008A, 0x0000000000000088, 0x0000000080008009, 0x000000008000000A,
0x000000008000808B, 0x800000000000008B, 0x8000000000008089, 0x8000000000008003,
0x8000000000008002, 0x8000000000000080, 0x000000000000800A, 0x800000008000000A,
0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008,
];
//TODO: remove the dead codes later
#[allow(dead_code)]
// Theta
// pub fn theta<F: RichField + Extendable<D>, const D: usize>(
// builder: &mut CircuitBuilder<F, D>,
// state: &mut [U64Target; KECCAK_LANES]
// ) {
pub fn theta<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [[U64Target; 5]; 5]
){
let mut c = [builder.zero_u64(); 5];
let mut d = [builder.zero_u64(); 5];
// let mut c = [0u64; 5];
// for x in 0..5 {
// c[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20];
// }
// for x in 0..5 {
// let d = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1);
// for y in 0..5 {
// state[x + 5 * y] ^= d;
// }
// }
// Compute column parities
for x in 0..5 {
// }
let xor_x0_x1 = builder.xor_u64(state[x][0], state[x][1]);
let xor_x0_x1_x2 = builder.xor_u64(xor_x0_x1, state[x][2]);
let xor_x0_x1_x2_x3 = builder.xor_u64(xor_x0_x1_x2, state[x][3]);
c[x] = builder.xor_u64(xor_x0_x1_x2_x3, state[x][4]);
}
// Compute rotated parities
for x in 0..5 {
let c_left = c[(x + 4) % 5];
let c_right_rot = builder.rotate_left_u64(c[(x + 1) % 5], 1);
d[x] = builder.xor_u64(c_left, c_right_rot);
}
// Modify the state
for x in 0..5 {
for y in 0..5 {
state[x][y] = builder.xor_u64(state[x][y], d[x]);
}
}
}
//TODO: remove the dead codes later
#[allow(dead_code)]
//rho
fn rho<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [[U64Target; 5]; 5]
){
const RHO_OFFSETS: [[usize; 5]; 5] = [
[0, 1, 62, 28, 27],
[36, 44, 6, 55, 20],
[3, 10, 43, 25, 39],
[41, 45, 15, 21, 8],
[18, 2, 61, 56, 14],
];
for x in 0..5 {
for y in 0..5 {
let rotation = RHO_OFFSETS[x][y];
state[x][y] = builder.rotate_left_u64(state[x][y], rotation as u8);
}
}
}
//TODO: remove the dead codes later
#[allow(dead_code)]
//pi
fn pi<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [[U64Target; 5]; 5]
){
let mut new_state = [[builder.zero_u64(); 5]; 5];
for x in 0..5 {
for y in 0..5 {
new_state[(2 * x + 3 * y) % 5][y] = state[x][y];
}
}
*state = new_state;
}
//TODO: remove the dead codes later
#[allow(dead_code)]
//iota
fn iota<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [[U64Target; 5]; 5],
round: usize
){
let lo = builder.constant_u32((ROUND_CONSTANTS[round] & 0xFFFFFFFF) as u32);
let hi = builder.constant_u32(((ROUND_CONSTANTS[round] >> 32)& 0xFFFFFFFF) as u32);
state[0][0] = builder.xor_u64(state[0][0], U64Target([lo,hi])) ;
}
//TODO: remove the dead codes later
#[allow(dead_code)]
fn chi<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [[U64Target; 5]; 5]
){
for x in 0..5 {
let mut temp = [builder.zero_u64(); 5];
for y in 0..5 {
temp[y] = state[x][y];
}
for y in 0..5 {
let t1 = builder.not_u64(temp[(y + 1) % 5]);
let t2 = builder.and_u64(t1, temp[(y + 2) % 5]);
state[x][y] = builder.xor_u64(state[x][y], t2);
}
}
}
//TODO: remove the dead codes later
#[allow(dead_code)]
// permutation
fn keccak_permutation<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [[U64Target; 5]; 5]
) {
for i in 0..24 {
theta(builder, state);
rho(builder, state);
pi(builder, state);
chi(builder, state);
iota(builder, state, ROUND_CONSTANTS[i] as usize)
}
}