diff --git a/evm/src/cpu/kernel/asm/curve/bls381/util.asm b/evm/src/cpu/kernel/asm/curve/bls381/util.asm index a4e846cb..13943be7 100644 --- a/evm/src/cpu/kernel/asm/curve/bls381/util.asm +++ b/evm/src/cpu/kernel/asm/curve/bls381/util.asm @@ -12,20 +12,6 @@ // stack: z0, z1 %endmacro -%macro mul_fp381 - // stack: x0, x1, y0, y1 - PROVER_INPUT(sf::bls381_base::mul_hi) - // stack: z1, x0, x1, y0, y1 - SWAP4 - // stack: y1, x0, x1, y0, z1 - PROVER_INPUT(sf::bls381_base::mul_lo) - // stack: z0, y1, x0, x1, y0, z1 - SWAP4 - // stack: y0, y1, x0, x1, z0, z1 - %pop4 - // stack: z0, z1 -%endmacro - %macro sub_fp381 // stack: x0, x1, y0, y1 PROVER_INPUT(sf::bls381_base::sub_hi) @@ -40,14 +26,76 @@ // stack: z0, z1 %endmacro -global test_add_fp381: +%macro mul_fp381 + // stack: x0, x1, y0, y1 + PROVER_INPUT(sf::bls381_base::mul_hi) + // stack: z1, x0, x1, y0, y1 + SWAP4 + // stack: y1, x0, x1, y0, z1 + PROVER_INPUT(sf::bls381_base::mul_lo) + // stack: z0, y1, x0, x1, y0, z1 + SWAP4 + // stack: y0, y1, x0, x1, z0, z1 + %pop4 + // stack: z0, z1 +%endmacro + +%macro add_fp381_2 + // stack: x_re, x_im, y_re, y_im + %stack (x_re: 2, x_im: 2, y_re: 2, y_im: 2) -> (y_im, x_im, y_re, x_re) + // stack: y_im, x_im, y_re, x_re %add_fp381 - %jump(0xdeadbeef) + // stack: z_im, y_re, x_re + %stack (z_im: 2, y_re: 2, x_re: 2) -> (x_re, y_re, z_im) + // stack: x_re, y_re, z_im + %add_fp381 + // stack: z_re, z_im +%endmacro -global test_mul_fp381: - %mul_fp381 - %jump(0xdeadbeef) - -global test_sub_fp381: +%macro sub_fp381_2 + // stack: x_re, x_im, y_re, y_im + %stack (x_re: 2, x_im: 2, y_re: 2, y_im: 2) -> (x_im, y_im, y_re, x_re) + // stack: x_im, y_im, y_re, x_re %sub_fp381 - %jump(0xdeadbeef) + // stack: z_im, y_re, x_re + %stack (z_im: 2, y_re: 2, x_re: 2) -> (x_re, y_re, z_im) + // stack: x_re, y_re, z_im + %sub_fp381 + // stack: z_re, z_im +%endmacro + +// note that {x,y}_{re,im} all take up two stack terms +global mul_fp381_2: + // stack: x_re, x_im, y_re, y_im, jumpdest + DUP4 + DUP4 + // stack: x_im, x_re, x_im, y_re, y_im, jumpdest + DUP8 + DUP8 + // stack: y_re, x_im, x_re, x_im, y_re, y_im, jumpdest + DUP12 + DUP12 + // stack: y_im, y_re, x_im, x_re, x_im, y_re, y_im, jumpdest + DUP8 + DUP8 + // stack: x_re , y_im, y_re, x_im, x_re, x_im, y_re, y_im, jumpdest + %mul_fp381 + // stack: x_re * y_im, y_re, x_im, x_re, x_im, y_re, y_im, jumpdest + %stack (v: 2, y_re: 2, x_im: 2) -> (x_im, y_re, v) + // stack: x_im , y_re, x_re*y_im, x_re, x_im, y_re, y_im, jumpdest + %mul_fp381 + // stack: x_im * y_re, x_re*y_im, x_re, x_im, y_re, y_im, jumpdest + %add_fp381 + // stack: z_im, x_re, x_im, y_re, y_im, jumpdest + %stack (z_im: 2, x_re: 2, x_im: 2, y_re: 2, y_im: 2) -> (x_im, y_im, y_re, x_re, z_im) + // stack: x_im , y_im, y_re, x_re, z_im, jumpdest + %mul_fp381 + // stack: x_im * y_im, y_re, x_re, z_im, jumpdest + %stack (v: 2, y_re: 2, x_re: 2) -> (x_re, y_re, v) + // stack: x_re , y_re, x_im*y_im, z_im, jumpdest + %mul_fp381 + // stack: x_re * y_re, x_im*y_im, z_im, jumpdest + %sub_fp381 + // stack: z_re, z_im, jumpdest + %stack (z_re: 2, z_im: 2, jumpdest) -> (jumpdest, z_re, z_im) + JUMP diff --git a/evm/src/cpu/kernel/tests/bls381.rs b/evm/src/cpu/kernel/tests/bls381.rs index afd22a14..aeba6fbd 100644 --- a/evm/src/cpu/kernel/tests/bls381.rs +++ b/evm/src/cpu/kernel/tests/bls381.rs @@ -1,42 +1,32 @@ use anyhow::Result; -use ethereum_types::U512; +use ethereum_types::U256; use rand::Rng; use crate::cpu::kernel::interpreter::{ run_interpreter_with_memory, InterpreterMemoryInitialization, }; -use crate::extension_tower::{Stack, BLS381}; +use crate::extension_tower::{Fp2, Stack, BLS381}; use crate::memory::segments::Segment::KernelGeneral; -fn run_and_return_bls(label: String, x: BLS381, y: BLS381) -> BLS381 { - let mut stack = x.on_stack(); - stack.extend(y.on_stack()); +#[test] +fn test_bls_fp2_mul() -> Result<()> { + let mut rng = rand::thread_rng(); + let x: Fp2 = rng.gen::>(); + let y: Fp2 = rng.gen::>(); + + let mut stack = x.to_stack().to_vec(); + stack.extend(y.to_stack().to_vec()); + stack.push(U256::from(0xdeadbeefu32)); let setup = InterpreterMemoryInitialization { - label, + label: "mul_fp381_2".to_string(), stack, segment: KernelGeneral, memory: vec![], }; let interpreter = run_interpreter_with_memory(setup).unwrap(); - let output = interpreter.stack(); - BLS381 { - val: U512::from(output[1]) + (U512::from(output[0]) << 256), - } -} - -#[test] -fn test_bls_ops() -> Result<()> { - let mut rng = rand::thread_rng(); - let x: BLS381 = rng.gen::(); - let y: BLS381 = rng.gen::(); - - let output_add = run_and_return_bls("test_add_fp381".to_string(), x, y); - let output_mul = run_and_return_bls("test_mul_fp381".to_string(), x, y); - let output_sub = run_and_return_bls("test_sub_fp381".to_string(), x, y); - - assert_eq!(output_add, x + y); - assert_eq!(output_mul, x * y); - assert_eq!(output_sub, x - y); + let stack: Vec = interpreter.stack().iter().rev().cloned().collect(); + let output = Fp2::::from_stack(&stack); + assert_eq!(output, x * y); Ok(()) } diff --git a/evm/src/cpu/kernel/tests/bn254.rs b/evm/src/cpu/kernel/tests/bn254.rs index 548f9789..5ed60e7a 100644 --- a/evm/src/cpu/kernel/tests/bn254.rs +++ b/evm/src/cpu/kernel/tests/bn254.rs @@ -11,22 +11,12 @@ use crate::curve_pairings::{ use crate::extension_tower::{FieldExt, Fp12, Fp2, Fp6, Stack, BN254}; use crate::memory::segments::Segment::BnPairing; -fn extract_stack(interpreter: Interpreter<'static>) -> Vec { - interpreter - .stack() - .iter() - .rev() - .cloned() - .collect::>() -} - -fn run_bn_mul_fp6(f: Fp6, g: Fp6, label: &str) -> Vec { - let mut stack = f.on_stack(); +fn run_bn_mul_fp6(f: Fp6, g: Fp6, label: &str) -> Fp6 { + let mut stack = f.to_stack(); if label == "mul_fp254_6" { - stack.extend(g.on_stack()); + stack.extend(g.to_stack().to_vec()); } stack.push(U256::from(0xdeadbeefu32)); - let setup = InterpreterMemoryInitialization { label: label.to_string(), stack, @@ -34,7 +24,8 @@ fn run_bn_mul_fp6(f: Fp6, g: Fp6, label: &str) -> Vec { memory: vec![], }; let interpreter = run_interpreter_with_memory(setup).unwrap(); - extract_stack(interpreter) + let output: Vec = interpreter.stack().iter().rev().cloned().collect(); + Fp6::::from_stack(&output) } #[test] @@ -43,19 +34,16 @@ fn test_bn_mul_fp6() -> Result<()> { let f: Fp6 = rng.gen::>(); let g: Fp6 = rng.gen::>(); - let out_normal: Vec = run_bn_mul_fp6(f, g, "mul_fp254_6"); - let out_square: Vec = run_bn_mul_fp6(f, f, "square_fp254_6"); + let output_normal: Fp6 = run_bn_mul_fp6(f, g, "mul_fp254_6"); + let output_square: Fp6 = run_bn_mul_fp6(f, f, "square_fp254_6"); - let exp_normal: Vec = (f * g).on_stack(); - let exp_square: Vec = (f * f).on_stack(); - - assert_eq!(out_normal, exp_normal); - assert_eq!(out_square, exp_square); + assert_eq!(output_normal, f * g); + assert_eq!(output_square, f * f); Ok(()) } -fn run_bn_mul_fp12(f: Fp12, g: Fp12, label: &str) -> Vec { +fn run_bn_mul_fp12(f: Fp12, g: Fp12, label: &str) -> Fp12 { let in0: usize = 100; let in1: usize = 112; let out: usize = 124; @@ -69,15 +57,15 @@ fn run_bn_mul_fp12(f: Fp12, g: Fp12, label: &str) -> Vec { if label == "square_fp254_12" { stack.remove(0); } - let setup = InterpreterMemoryInitialization { label: label.to_string(), stack, segment: BnPairing, - memory: vec![(in0, f.on_stack()), (in1, g.on_stack())], + memory: vec![(in0, f.to_stack().to_vec()), (in1, g.to_stack().to_vec())], }; let interpreter = run_interpreter_with_memory(setup).unwrap(); - interpreter.extract_kernel_memory(BnPairing, out..out + 12) + let output = interpreter.extract_kernel_memory(BnPairing, out..out + 12); + Fp12::::from_stack(&output) } #[test] @@ -87,30 +75,27 @@ fn test_bn_mul_fp12() -> Result<()> { let g: Fp12 = rng.gen::>(); let h: Fp12 = gen_bn_fp12_sparse(&mut rng); - let out_normal: Vec = run_bn_mul_fp12(f, g, "mul_fp254_12"); - let out_sparse: Vec = run_bn_mul_fp12(f, h, "mul_fp254_12_sparse"); - let out_square: Vec = run_bn_mul_fp12(f, f, "square_fp254_12"); + let output_normal = run_bn_mul_fp12(f, g, "mul_fp254_12"); + let output_sparse = run_bn_mul_fp12(f, h, "mul_fp254_12_sparse"); + let output_square = run_bn_mul_fp12(f, f, "square_fp254_12"); - let exp_normal: Vec = (f * g).on_stack(); - let exp_sparse: Vec = (f * h).on_stack(); - let exp_square: Vec = (f * f).on_stack(); - - assert_eq!(out_normal, exp_normal); - assert_eq!(out_sparse, exp_sparse); - assert_eq!(out_square, exp_square); + assert_eq!(output_normal, f * g); + assert_eq!(output_sparse, f * h); + assert_eq!(output_square, f * f); Ok(()) } -fn run_bn_frob_fp6(f: Fp6, n: usize) -> Vec { +fn run_bn_frob_fp6(n: usize, f: Fp6) -> Fp6 { let setup = InterpreterMemoryInitialization { label: format!("test_frob_fp254_6_{}", n), - stack: f.on_stack(), + stack: f.to_stack().to_vec(), segment: BnPairing, memory: vec![], }; - let interpreter = run_interpreter_with_memory(setup).unwrap(); - extract_stack(interpreter) + let interpreter: Interpreter = run_interpreter_with_memory(setup).unwrap(); + let output: Vec = interpreter.stack().iter().rev().cloned().collect(); + Fp6::::from_stack(&output) } #[test] @@ -118,34 +103,33 @@ fn test_bn_frob_fp6() -> Result<()> { let mut rng = rand::thread_rng(); let f: Fp6 = rng.gen::>(); for n in 1..4 { - let output: Vec = run_bn_frob_fp6(f, n); - let expected: Vec = f.frob(n).on_stack(); - assert_eq!(output, expected); + let output = run_bn_frob_fp6(n, f); + assert_eq!(output, f.frob(n)); } Ok(()) } -fn run_bn_frob_fp12(f: Fp12, n: usize) -> Vec { +fn run_bn_frob_fp12(f: Fp12, n: usize) -> Fp12 { let ptr: usize = 100; let setup = InterpreterMemoryInitialization { label: format!("test_frob_fp254_12_{}", n), stack: vec![U256::from(ptr)], segment: BnPairing, - memory: vec![(ptr, f.on_stack())], + memory: vec![(ptr, f.to_stack().to_vec())], }; - let interpreter = run_interpreter_with_memory(setup).unwrap(); - interpreter.extract_kernel_memory(BnPairing, ptr..ptr + 12) + let interpeter: Interpreter = run_interpreter_with_memory(setup).unwrap(); + let output: Vec = interpeter.extract_kernel_memory(BnPairing, ptr..ptr + 12); + Fp12::::from_stack(&output) } #[test] -fn test_bn_frob_fp12() -> Result<()> { +fn test_frob_fp12() -> Result<()> { let mut rng = rand::thread_rng(); let f: Fp12 = rng.gen::>(); for n in [1, 2, 3, 6] { let output = run_bn_frob_fp12(f, n); - let expected: Vec = f.frob(n).on_stack(); - assert_eq!(output, expected); + assert_eq!(output, f.frob(n)); } Ok(()) } @@ -161,13 +145,13 @@ fn test_bn_inv_fp12() -> Result<()> { label: "inv_fp254_12".to_string(), stack: vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)], segment: BnPairing, - memory: vec![(ptr, f.on_stack())], + memory: vec![(ptr, f.to_stack().to_vec())], }; let interpreter: Interpreter = run_interpreter_with_memory(setup).unwrap(); let output: Vec = interpreter.extract_kernel_memory(BnPairing, inv..inv + 12); - let expected: Vec = f.inv().on_stack(); + let output = Fp12::::from_stack(&output); - assert_eq!(output, expected); + assert_eq!(output, f.inv()); Ok(()) } @@ -188,12 +172,12 @@ fn test_bn_final_exponent() -> Result<()> { U256::from(0xdeadbeefu32), ], segment: BnPairing, - memory: vec![(ptr, f.on_stack())], + memory: vec![(ptr, f.to_stack().to_vec())], }; let interpreter: Interpreter = run_interpreter_with_memory(setup).unwrap(); let output: Vec = interpreter.extract_kernel_memory(BnPairing, ptr..ptr + 12); - let expected: Vec = bn_final_exponent(f).on_stack(); + let expected: Vec = bn_final_exponent(f).to_stack(); assert_eq!(output, expected); @@ -209,8 +193,8 @@ fn test_bn_miller() -> Result<()> { let p: Curve = rng.gen::>(); let q: Curve> = rng.gen::>>(); - let mut input = p.on_stack(); - input.extend(q.on_stack()); + let mut input = p.to_stack(); + input.extend(q.to_stack()); let setup = InterpreterMemoryInitialization { label: "bn254_miller".to_string(), @@ -220,7 +204,7 @@ fn test_bn_miller() -> Result<()> { }; let interpreter = run_interpreter_with_memory(setup).unwrap(); let output: Vec = interpreter.extract_kernel_memory(BnPairing, out..out + 12); - let expected = bn_miller_loop(p, q).on_stack(); + let expected = bn_miller_loop(p, q).to_stack(); assert_eq!(output, expected); @@ -243,13 +227,13 @@ fn test_bn_pairing() -> Result<()> { let p: Curve = Curve::::int(m); let q: Curve> = Curve::>::int(n); - input.extend(p.on_stack()); - input.extend(q.on_stack()); + input.extend(p.to_stack()); + input.extend(q.to_stack()); } let p: Curve = Curve::::int(acc); let q: Curve> = Curve::>::GENERATOR; - input.extend(p.on_stack()); - input.extend(q.on_stack()); + input.extend(p.to_stack()); + input.extend(q.to_stack()); let setup = InterpreterMemoryInitialization { label: "bn254_pairing".to_string(), diff --git a/evm/src/curve_pairings.rs b/evm/src/curve_pairings.rs index 708e7fb2..d789051a 100644 --- a/evm/src/curve_pairings.rs +++ b/evm/src/curve_pairings.rs @@ -25,12 +25,21 @@ impl Curve { } } -impl Curve { - pub fn on_stack(self) -> Vec { - let mut stack = self.x.on_stack(); - stack.extend(self.y.on_stack()); +impl Stack for Curve { + const SIZE: usize = 2 * T::SIZE; + + fn to_stack(&self) -> Vec { + let mut stack = self.x.to_stack(); + stack.extend(self.y.to_stack()); stack } + + fn from_stack(stack: &[U256]) -> Self { + Curve { + x: T::from_stack(&stack[0..T::SIZE]), + y: T::from_stack(&stack[T::SIZE..2 * T::SIZE]), + } + } } impl Curve diff --git a/evm/src/extension_tower.rs b/evm/src/extension_tower.rs index 0e654c88..845d99aa 100644 --- a/evm/src/extension_tower.rs +++ b/evm/src/extension_tower.rs @@ -1223,43 +1223,79 @@ where } pub trait Stack { - fn on_stack(self) -> Vec; + const SIZE: usize; + + fn to_stack(&self) -> Vec; + + fn from_stack(stack: &[U256]) -> Self; } impl Stack for BN254 { - fn on_stack(self) -> Vec { + const SIZE: usize = 1; + + fn to_stack(&self) -> Vec { vec![self.val] } + + fn from_stack(stack: &[U256]) -> BN254 { + BN254 { val: stack[0] } + } } impl Stack for BLS381 { - fn on_stack(self) -> Vec { + const SIZE: usize = 2; + + fn to_stack(&self) -> Vec { vec![self.lo(), self.hi()] } + + fn from_stack(stack: &[U256]) -> BLS381 { + let mut val = [0u64; 8]; + val[..4].copy_from_slice(&stack[0].0); + val[4..].copy_from_slice(&stack[1].0); + BLS381 { val: U512(val) } + } } -impl Stack for Fp2 -where - T: FieldExt + Stack, -{ - fn on_stack(self) -> Vec { - let mut stack = self.re.on_stack(); - stack.extend(self.im.on_stack()); +impl Stack for Fp2 { + const SIZE: usize = 2 * T::SIZE; + + fn to_stack(&self) -> Vec { + let mut stack = self.re.to_stack(); + stack.extend(self.im.to_stack()); stack } + + fn from_stack(stack: &[U256]) -> Fp2 { + let field_size = T::SIZE; + let re = T::from_stack(&stack[0..field_size]); + let im = T::from_stack(&stack[field_size..2 * field_size]); + Fp2 { re, im } + } } impl Stack for Fp6 where T: FieldExt, - Fp2: Adj + Stack, + Fp2: Adj, + Fp2: Stack, { - fn on_stack(self) -> Vec { - let mut stack = self.t0.on_stack(); - stack.extend(self.t1.on_stack()); - stack.extend(self.t2.on_stack()); + const SIZE: usize = 3 * Fp2::::SIZE; + + fn to_stack(&self) -> Vec { + let mut stack = self.t0.to_stack(); + stack.extend(self.t1.to_stack()); + stack.extend(self.t2.to_stack()); stack } + + fn from_stack(stack: &[U256]) -> Self { + let field_size = Fp2::::SIZE; + let t0 = Fp2::::from_stack(&stack[0..field_size]); + let t1 = Fp2::::from_stack(&stack[field_size..2 * field_size]); + let t2 = Fp2::::from_stack(&stack[2 * field_size..3 * field_size]); + Fp6 { t0, t1, t2 } + } } impl Stack for Fp12 @@ -1268,9 +1304,18 @@ where Fp2: Adj, Fp6: Stack, { - fn on_stack(self) -> Vec { - let mut stack = self.z0.on_stack(); - stack.extend(self.z1.on_stack()); + const SIZE: usize = 2 * Fp6::::SIZE; + + fn to_stack(&self) -> Vec { + let mut stack = self.z0.to_stack(); + stack.extend(self.z1.to_stack()); stack } + + fn from_stack(stack: &[U256]) -> Self { + let field_size = Fp6::::SIZE; + let z0 = Fp6::::from_stack(&stack[0..field_size]); + let z1 = Fp6::::from_stack(&stack[field_size..2 * field_size]); + Fp12 { z0, z1 } + } }