From 9b54ee43db862e373cd71d5b3a9232945e65e09c Mon Sep 17 00:00:00 2001 From: Dmitry Vagner Date: Wed, 19 Apr 2023 13:12:47 -0700 Subject: [PATCH] refactor --- evm/src/bn254_pairing.rs | 4 +- .../bn254/curve_arithmetic/final_exponent.asm | 8 +-- .../bn254/curve_arithmetic/miller_loop.asm | 17 ++--- .../curve/bn254/curve_arithmetic/pairing.asm | 20 ++++++ evm/src/cpu/kernel/tests/bn254.rs | 65 +++++++------------ 5 files changed, 52 insertions(+), 62 deletions(-) diff --git a/evm/src/bn254_pairing.rs b/evm/src/bn254_pairing.rs index 7277c2a8..6d2347a9 100644 --- a/evm/src/bn254_pairing.rs +++ b/evm/src/bn254_pairing.rs @@ -41,7 +41,7 @@ pub struct TwistedCurve { // The tate pairing takes a point each from the curve and its twist and outputs an Fp12 element pub fn tate(p: Curve, q: TwistedCurve) -> Fp12 { let miller_output = miller_loop(p, q); - invariant_exponent(miller_output) + final_exponent(miller_output) } /// Standard code for miller loop, can be found on page 99 at this url: @@ -120,7 +120,7 @@ pub fn gen_fp12_sparse(rng: &mut R) -> Fp12 { /// (p^4 - p^2 + 1)/N = p^3 + (a2)p^2 - (a1)p - a0 /// where 0 < a0, a1, a2 < p. Then the final power is given by /// y = y_3 * (y^a2)_2 * (y^-a1)_1 * (y^-a0) -pub fn invariant_exponent(f: Fp12) -> Fp12 { +pub fn final_exponent(f: Fp12) -> Fp12 { let mut y = f.frob(6) / f; y = y.frob(2) * y; let (y_a2, y_a1, y_a0) = get_custom_powers(y); diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_exponent.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_exponent.asm index 2fcd5d2b..85b1d639 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_exponent.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_exponent.asm @@ -2,18 +2,18 @@ /// (p^12 - 1)/N = (p^6 - 1) * (p^2 + 1) * (p^4 - p^2 + 1)/N /// and thus we can exponentiate by each factor sequentially. /// -/// def bn254_invariant_exponent(y: Fp12): +/// def bn254_final_exponent(y: Fp12): /// y = first_exp(y) /// y = second_exp(y) /// return final_exp(y) -global bn254_invariant_exponent: +global bn254_final_exponent: /// first, exponentiate by (p^6 - 1) via /// def first_exp(y): /// return y.frob(6) / y - // stack: out, retdest {out: y} - %stack (out) -> (out, 0, first_exp, out) + // stack: k, inp, out, retdest {out: y} + %stack (k, inp, out) -> (out, 0, first_exp, out) // stack: out, 0, first_exp, out, retdest {out: y} %jump(inv_fp254_12) first_exp: diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm index f09684bd..0067e0ec 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm @@ -1,12 +1,3 @@ -/// def tate(P: Curve, Q: TwistedCurve) -> Fp12: -/// out = miller_loop(P, Q) -/// return bn254_invariant_exponent(P, Q) -global bn254_tate: - // stack: inp, out, retdest - %stack (inp, out) -> (inp, out, bn254_invariant_exponent, out) - // stack: inp, out, bn254_invariant_exponent, out, retdest - %jump(bn254_miller) - /// def miller(P, Q): /// miller_init() /// miller_loop() @@ -35,13 +26,13 @@ global bn254_tate: /// mul_tangent() global bn254_miller: - // stack: ptr, out, retdest + // stack: ptr, out, retdest %stack (ptr, out) -> (out, 1, ptr, out) - // stack: out, 1, ptr, out, retdest + // stack: out, 1, ptr, out, retdest %mstore_kernel_bn254_pairing - // stack: ptr, out, retdest + // stack: ptr, out, retdest %load_fp254_6 - // stack: P, Q, out, retdest + // stack: P, Q, out, retdest %stack (P: 2) -> (0, 53, P, P) // stack: 0, 53, O, P, Q, out, retdest // the head 0 lets miller_loop start with POP diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/pairing.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/pairing.asm index e69de29b..57008494 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/pairing.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/pairing.asm @@ -0,0 +1,20 @@ +/// def tate(pairs: List((Curve, TwistedCurve))) -> Fp12: +/// out = 1 +/// for P, Q in pairs: +/// out *= miller_loop(P, Q) +/// return bn254_final_exponent(out) + +global bn254_tate: + // stack: k, inp, out, retdest + DUP1 + ISZERO + // stack: end?, k, inp, out, retdest + %jumpi(bn254_final_exponent) + // stack: k, inp, out, retdest + + + + + %stack (inp, out) -> (inp, out, bn254_final_exponent, out) + // stack: inp, out, bn254_final_exponent, out, retdest + %jump(bn254_miller) diff --git a/evm/src/cpu/kernel/tests/bn254.rs b/evm/src/cpu/kernel/tests/bn254.rs index 8e71ffd6..48455b15 100644 --- a/evm/src/cpu/kernel/tests/bn254.rs +++ b/evm/src/cpu/kernel/tests/bn254.rs @@ -3,7 +3,7 @@ use ethereum_types::U256; use rand::Rng; use crate::bn254_pairing::{ - gen_fp12_sparse, invariant_exponent, miller_loop, tate, Curve, TwistedCurve, + final_exponent, gen_fp12_sparse, miller_loop, tate, Curve, TwistedCurve, }; use crate::cpu::kernel::interpreter::{ run_interpreter_with_memory, Interpreter, InterpreterMemoryInitialization, @@ -20,22 +20,21 @@ fn extract_stack(interpreter: Interpreter<'static>) -> Vec { .collect::>() } -fn setup_mul_fp6_test( - f: Fp6, - g: Fp6, - label: &str, -) -> InterpreterMemoryInitialization { +fn run_mul_fp6_test(f: Fp6, g: Fp6, label: &str) -> Vec { let mut stack = f.on_stack(); if label == "mul_fp254_6" { stack.extend(g.on_stack()); } stack.push(U256::from(0xdeadbeefu32)); - InterpreterMemoryInitialization { + + let setup = InterpreterMemoryInitialization { label: label.to_string(), stack, segment: BnPairing, memory: vec![], - } + }; + let interpreter = run_interpreter_with_memory(setup).unwrap(); + extract_stack(interpreter) } #[test] @@ -44,14 +43,8 @@ fn test_mul_fp6() -> Result<()> { let f: Fp6 = rng.gen::>(); let g: Fp6 = rng.gen::>(); - let setup_normal: InterpreterMemoryInitialization = setup_mul_fp6_test(f, g, "mul_fp254_6"); - let setup_square: InterpreterMemoryInitialization = setup_mul_fp6_test(f, f, "square_fp254_6"); - - let intrptr_normal: Interpreter = run_interpreter_with_memory(setup_normal).unwrap(); - let intrptr_square: Interpreter = run_interpreter_with_memory(setup_square).unwrap(); - - let out_normal: Vec = extract_stack(intrptr_normal); - let out_square: Vec = extract_stack(intrptr_square); + let out_normal: Vec = run_mul_fp6_test(f, g, "mul_fp254_6"); + let out_square: Vec = run_mul_fp6_test(f, f, "square_fp254_6"); let exp_normal: Vec = (f * g).on_stack(); let exp_square: Vec = (f * f).on_stack(); @@ -62,14 +55,10 @@ fn test_mul_fp6() -> Result<()> { Ok(()) } -fn setup_mul_fp12_test( - out: usize, - f: Fp12, - g: Fp12, - label: &str, -) -> InterpreterMemoryInitialization { +fn run_mul_fp12_test(f: Fp12, g: Fp12, label: &str) -> Vec { let in0: usize = 200; let in1: usize = 212; + let out: usize = 224; let mut stack = vec![ U256::from(in0), @@ -80,37 +69,27 @@ fn setup_mul_fp12_test( if label == "square_fp254_12" { stack.remove(0); } - InterpreterMemoryInitialization { + + let setup = InterpreterMemoryInitialization { label: label.to_string(), stack, segment: BnPairing, memory: vec![(in0, f.on_stack()), (in1, g.on_stack())], - } + }; + let interpreter = run_interpreter_with_memory(setup).unwrap(); + interpreter.extract_kernel_memory(BnPairing, out..out + 12) } #[test] fn test_mul_fp12() -> Result<()> { - let out: usize = 224; - let mut rng = rand::thread_rng(); let f: Fp12 = rng.gen::>(); let g: Fp12 = rng.gen::>(); let h: Fp12 = gen_fp12_sparse(&mut rng); - let setup_normal: InterpreterMemoryInitialization = - setup_mul_fp12_test(out, f, g, "mul_fp254_12"); - let setup_sparse: InterpreterMemoryInitialization = - setup_mul_fp12_test(out, f, h, "mul_fp254_12_sparse"); - let setup_square: InterpreterMemoryInitialization = - setup_mul_fp12_test(out, f, f, "square_fp254_12"); - - let intrptr_normal: Interpreter = run_interpreter_with_memory(setup_normal).unwrap(); - let intrptr_sparse: Interpreter = run_interpreter_with_memory(setup_sparse).unwrap(); - let intrptr_square: Interpreter = run_interpreter_with_memory(setup_square).unwrap(); - - let out_normal: Vec = intrptr_normal.extract_kernel_memory(BnPairing, out..out + 12); - let out_sparse: Vec = intrptr_sparse.extract_kernel_memory(BnPairing, out..out + 12); - let out_square: Vec = intrptr_square.extract_kernel_memory(BnPairing, out..out + 12); + let out_normal: Vec = run_mul_fp12_test(f, g, "mul_fp254_12"); + let out_sparse: Vec = run_mul_fp12_test(f, h, "mul_fp254_12_sparse"); + let out_square: Vec = run_mul_fp12_test(f, f, "square_fp254_12"); let exp_normal: Vec = (f * g).on_stack(); let exp_sparse: Vec = (f * h).on_stack(); @@ -193,13 +172,13 @@ fn test_inv_fp12() -> Result<()> { } #[test] -fn test_invariant_exponent() -> Result<()> { +fn test_final_exponent() -> Result<()> { let ptr: usize = 200; let mut rng = rand::thread_rng(); let f: Fp12 = rng.gen::>(); let setup = InterpreterMemoryInitialization { - label: "bn254_invariant_exponent".to_string(), + label: "bn254_final_exponent".to_string(), stack: vec![U256::from(ptr), U256::from(0xdeadbeefu32)], segment: BnPairing, memory: vec![(ptr, f.on_stack())], @@ -207,7 +186,7 @@ fn test_invariant_exponent() -> Result<()> { let interpreter: Interpreter = run_interpreter_with_memory(setup).unwrap(); let output: Vec = interpreter.extract_kernel_memory(BnPairing, ptr..ptr + 12); - let expected: Vec = invariant_exponent(f).on_stack(); + let expected: Vec = final_exponent(f).on_stack(); assert_eq!(output, expected);