From 69afed9297231d0ff7f0cdb74c2d4fae957b4ee9 Mon Sep 17 00:00:00 2001 From: Dmitry Vagner Date: Tue, 7 Feb 2023 14:54:07 -0800 Subject: [PATCH] refactor --- evm/src/bn254_pairing.rs | 4 +- evm/src/cpu/kernel/aggregator.rs | 3 +- ...final_power.asm => invariant_exponent.asm} | 35 +- .../bn254/curve_arithmetic/miller_loop.asm | 283 ---------------- .../bn254/curve_arithmetic/tate_pairing.asm | 315 ++++++++++++++++-- evm/src/cpu/kernel/tests/bn254.rs | 115 +++---- 6 files changed, 378 insertions(+), 377 deletions(-) rename evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/{final_power.asm => invariant_exponent.asm} (90%) delete mode 100644 evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm diff --git a/evm/src/bn254_pairing.rs b/evm/src/bn254_pairing.rs index 71f9575f..bf5db74a 100644 --- a/evm/src/bn254_pairing.rs +++ b/evm/src/bn254_pairing.rs @@ -41,7 +41,7 @@ pub struct TwistedCurve { // The tate pairing takes a point each from the curve and its twist and outputs an Fp12 element pub fn tate(p: Curve, q: TwistedCurve) -> Fp12 { let miller_output = miller_loop(p, q); - invariance_inducing_power(miller_output) + invariant_exponent(miller_output) } /// Standard code for miller loop, can be found on page 99 at this url: @@ -116,7 +116,7 @@ pub fn gen_fp12_sparse(rng: &mut R) -> Fp12 { /// (p^4 - p^2 + 1)/N = p^3 + (a2)p^2 - (a1)p - a0 /// where 0 < a0, a1, a2 < p. 
Then the final power is given by /// y = y_3 * (y^a2)_2 * (y^-a1)_1 * (y^-a0) -pub fn invariance_inducing_power(f: Fp12) -> Fp12 { +pub fn invariant_exponent(f: Fp12) -> Fp12 { let mut y = f.frob(6) / f; y = y.frob(2) * y; let (y_a2, y_a1, y_a0) = get_custom_powers(y); diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 3c998449..7fbb9f08 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -27,8 +27,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/curve/bn254/curve_arithmetic/constants.asm"), include_str!("asm/curve/bn254/curve_arithmetic/curve_add.asm"), include_str!("asm/curve/bn254/curve_arithmetic/curve_mul.asm"), - include_str!("asm/curve/bn254/curve_arithmetic/final_power.asm"), - include_str!("asm/curve/bn254/curve_arithmetic/miller_loop.asm"), + include_str!("asm/curve/bn254/curve_arithmetic/invariant_exponent.asm"), include_str!("asm/curve/bn254/curve_arithmetic/tate_pairing.asm"), include_str!("asm/curve/bn254/field_arithmetic/inverse.asm"), include_str!("asm/curve/bn254/field_arithmetic/degree_6_mul.asm"), diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_power.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/invariant_exponent.asm similarity index 90% rename from evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_power.asm rename to evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/invariant_exponent.asm index 7f22587a..3176dbf5 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/final_power.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/invariant_exponent.asm @@ -22,7 +22,40 @@ /// y1 = y1.frob(1) /// return y * y2 * y1 * y0 -global bn254_final_exp: +/// def bn254_invariant_exponent(y: Fp12): +/// y = first_exp(y) +/// y = second_exp(y) +/// return final_exponentiation(y) + +global bn254_invariant_exponent: + +/// map t to t^(p^6 - 1) via +/// def first_exp(t): +/// return t.frob(6) / t + // stack: 
out, retdest {out: y} + %stack (out) -> (out, 100, first_exp, out) + // stack: out, 100, first_exp, out, retdest {out: y} + %jump(inv_fp254_12) +first_exp: + // stack: out, retdest {out: y , 100: y^-1} + %frob_fp254_12_6 + // stack: out, retdest {out: y_6, 100: y^-1} + %stack (out) -> (out, 100, out, second_exp, out) + // stack: out, 100, out, second_exp, out, retdest {out: y_6, 100: y^-1} + %jump(mul_fp254_12) + +/// map t to t^(p^2 + 1) via +/// def second_exp(t): +/// return t.frob(2) * t +second_exp: + // stack: out, retdest {out: y} + %stack (out) -> (out, 100, out, out, final_exp, out) + // stack: out, 100, out, out, final_exp, out, retdest {out: y} + %frob_fp254_12_2_ + // stack: 100, out, out, final_exp, out, retdest {out: y, 100: y_2} + %jump(mul_fp254_12) + +final_exp: // stack: val, retdest %stack (val) -> (val, 300, val) // stack: val, 300, val, retdest diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm deleted file mode 100644 index 63387cb4..00000000 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/miller_loop.asm +++ /dev/null @@ -1,283 +0,0 @@ -/// def miller(P, Q): -/// miller_init() -/// miller_loop() -/// -/// def miller_init(): -/// out = 1 -/// O = P -/// times = 61 -/// -/// def miller_loop(): -/// while times: -/// 0xnm = load(miller_data) -/// while 0xnm > 0x20: -/// miller_one() -/// while 0xnm: -/// miller_zero() -/// times -= 1 -/// -/// def miller_one(): -/// 0xnm -= 0x20 -/// mul_tangent() -/// mul_cord() -/// -/// def miller_zero(): -/// 0xnm -= 1 -/// mul_tangent() - -global bn254_miller: - // stack: ptr, out, retdest - %stack (ptr, out) -> (out, 1, ptr, out) - // stack: out, 1, ptr, out, retdest - %mstore_kernel_general - // stack: ptr, out, retdest - %load_fp254_6 - // stack: P, Q, out, retdest - %stack (P: 2) -> (0, 53, P, P) - // stack: 0, 53, O, P, Q, out, retdest - // the head 0 lets miller_loop start with POP 
-miller_loop: - POP - // stack: times , O, P, Q, out, retdest - DUP1 - ISZERO - // stack: break?, times , O, P, Q, out, retdest - %jumpi(miller_return) - // stack: times , O, P, Q, out, retdest - %sub_const(1) - // stack: times-1, O, P, Q, out, retdest - DUP1 - // stack: times-1, times-1, O, P, Q, out, retdest - %mload_kernel_code(miller_data) - // stack: 0xnm, times-1, O, P, Q, out, retdest - %jump(miller_one) -miller_return: - // stack: times, O, P, Q, out, retdest - %stack (times, O: 2, P: 2, Q: 4, out, retdest) -> (retdest) - // stack: retdest - JUMP - -miller_one: - // stack: 0xnm, times, O, P, Q, out, retdest - DUP1 - %lt_const(0x20) - // stack: skip?, 0xnm, times, O, P, Q, out, retdest - %jumpi(miller_zero) - // stack: 0xnm, times, O, P, Q, out, retdest - %sub_const(0x20) - // stack: 0x{n-1}m, times, O, P, Q, out, retdest - PUSH mul_cord - // stack: mul_cord, 0x{n-1}m, times, O, P, Q, out, retdest - %jump(mul_tangent) - -miller_zero: - // stack: m , times, O, P, Q, out, retdest - DUP1 - ISZERO - // stack: skip?, m , times, O, P, Q, out, retdest - %jumpi(miller_loop) - // stack: m , times, O, P, Q, out, retdest - %sub_const(1) - // stack: m-1, times, O, P, Q, out, retdest - PUSH miller_zero - // stack: miller_zero, m-1, times, O, P, Q, out, retdest - %jump(mul_tangent) - - -/// def mul_tangent() -/// out = square_fp254_12(out) -/// line = tangent(O, Q) -/// out = mul_fp254_12_sparse(out, line) -/// O += O - -mul_tangent: - // stack: retdest, 0xnm, times, O, P, Q, out - PUSH mul_tangent_2 - DUP13 - PUSH mul_tangent_1 - // stack: mul_tangent_1, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out - %stack (mul_tangent_1, out) -> (out, out, mul_tangent_1, out) - // stack: out, out, mul_tangent_1, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out - %jump(square_fp254_12) -mul_tangent_1: - // stack: out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out - DUP13 - DUP13 - DUP13 - DUP13 - // stack: Q, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out - 
DUP11 - DUP11 - // stack: O, Q, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out - %tangent - // stack: out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out {100: line} - %stack (out) -> (out, 100, out) - // stack: out, 100, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out {100: line} - %jump(mul_fp254_12_sparse) -mul_tangent_2: - // stack: retdest, 0xnm, times, O, P, Q, out {100: line} - PUSH after_double - // stack: after_double, retdest, 0xnm, times, O, P, Q, out {100: line} - DUP6 - DUP6 - // stack: O, after_double, retdest, 0xnm, times, O, P, Q, out {100: line} - %jump(ec_double) -after_double: - // stack: 2*O, retdest, 0xnm, times, O, P, Q, out {100: line} - SWAP5 - POP - SWAP5 - POP - // stack: retdest, 0xnm, times, 2*O, P, Q, out {100: line} - JUMP - -/// def mul_cord() -/// line = cord(P, O, Q) -/// out = mul_fp254_12_sparse(out, line) -/// O += P - -mul_cord: - // stack: 0xnm, times, O, P, Q, out - PUSH mul_cord_1 - // stack: mul_cord_1, 0xnm, times, O, P, Q, out - DUP11 - DUP11 - DUP11 - DUP11 - // stack: Q, mul_cord_1, 0xnm, times, O, P, Q, out - DUP9 - DUP9 - // stack: O, Q, mul_cord_1, 0xnm, times, O, P, Q, out - DUP13 - DUP13 - // stack: P, O, Q, mul_cord_1, 0xnm, times, O, P, Q, out - %cord - // stack: mul_cord_1, 0xnm, times, O, P, Q, out {100: line} - DUP12 - // stack: out, mul_cord_1, 0xnm, times, O, P, Q, out {100: line} - %stack (out) -> (out, 100, out) - // stack: out, 100, out, mul_cord_1, 0xnm, times, O, P, Q, out {100: line} - %jump(mul_fp254_12_sparse) -mul_cord_1: - // stack: 0xnm, times, O , P, Q, out - PUSH after_add - // stack: after_add, 0xnm, times, O , P, Q, out - DUP7 - DUP7 - DUP7 - DUP7 - // stack: O , P, after_add, 0xnm, times, O , P, Q, out - %jump(ec_add_valid_points) -after_add: - // stack: O + P, 0xnm, times, O , P, Q, out - SWAP4 - POP - SWAP4 - POP - // stack: 0xnm, times, O+P, P, Q, out - %jump(miller_one) - - -/// def tangent(px, py, qx, qy): -/// return sparse_store( -/// py**2 - 9, -/// (-3px**2) * qx, 
-/// (2py) * qy, -/// ) - -%macro tangent - // stack: px, py, qx, qx_, qy, qy_ - %stack (px, py) -> (py, py , 9, px, py) - // stack: py, py , 9, px, py, qx, qx_, qy, qy_ - MULFP254 - // stack: py^2 , 9, px, py, qx, qx_, qy, qy_ - SUBFP254 - // stack: py^2 - 9, px, py, qx, qx_, qy, qy_ - %mstore_kernel_general(100) - // stack: px, py, qx, qx_, qy, qy_ - DUP1 - MULFP254 - // stack: px^2, py, qx, qx_, qy, qy_ - PUSH 3 - MULFP254 - // stack: 3*px^2, py, qx, qx_, qy, qy_ - PUSH 0 - SUBFP254 - // stack: -3*px^2, py, qx, qx_, qy, qy_ - SWAP2 - // stack: qx, py, -3px^2, qx_, qy, qy_ - DUP3 - MULFP254 - // stack: (-3*px^2)qx, py, -3px^2, qx_, qy, qy_ - %mstore_kernel_general(102) - // stack: py, -3px^2, qx_, qy, qy_ - PUSH 2 - MULFP254 - // stack: 2py, -3px^2, qx_, qy, qy_ - SWAP3 - // stack: qy, -3px^2, qx_, 2py, qy_ - DUP4 - MULFP254 - // stack: (2py)qy, -3px^2, qx_, 2py, qy_ - %mstore_kernel_general(108) - // stack: -3px^2, qx_, 2py, qy_ - MULFP254 - // stack: (-3px^2)*qx_, 2py, qy_ - %mstore_kernel_general(103) - // stack: 2py, qy_ - MULFP254 - // stack: (2py)*qy_ - %mstore_kernel_general(109) -%endmacro - -/// def cord(p1x, p1y, p2x, p2y, qx, qy): -/// return sparse_store( -/// p1y*p2x - p2y*p1x, -/// (p2y - p1y) * qx, -/// (p1x - p2x) * qy, -/// ) - -%macro cord - // stack: p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ - DUP1 - DUP5 - MULFP254 - // stack: p2y*p1x, p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ - DUP3 - DUP5 - MULFP254 - // stack: p1y*p2x , p2y*p1x, p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ - SUBFP254 - // stack: p1y*p2x - p2y*p1x, p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ - %mstore_kernel_general(100) - // stack: p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ - SWAP3 - // stack: p2y , p1y, p2x , p1x, qx, qx_, qy, qy_ - SUBFP254 - // stack: p2y - p1y, p2x , p1x, qx, qx_, qy, qy_ - SWAP2 - // stack: p1x , p2x, p2y - p1y, qx, qx_, qy, qy_ - SUBFP254 - // stack: p1x - p2x, p2y - p1y, qx, qx_, qy, qy_ - SWAP4 - // stack: qy, p2y - p1y, qx, qx_, p1x - p2x, qy_ - DUP5 - MULFP254 - // 
stack: (p1x - p2x)qy, p2y - p1y, qx, qx_, p1x - p2x, qy_ - %mstore_kernel_general(108) - // stack: p2y - p1y, qx, qx_, p1x - p2x, qy_ - SWAP1 - // stack: qx, p2y - p1y, qx_, p1x - p2x, qy_ - DUP2 - MULFP254 - // stack: (p2y - p1y)qx, p2y - p1y, qx_, p1x - p2x, qy_ - %mstore_kernel_general(102) - // stack: p2y - p1y, qx_, p1x - p2x, qy_ - MULFP254 - // stack: (p2y - p1y)qx_, p1x - p2x, qy_ - %mstore_kernel_general(103) - // stack: p1x - p2x, qy_ - MULFP254 - // stack: (p1x - p2x)*qy_ - %mstore_kernel_general(109) -%endmacro diff --git a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/tate_pairing.asm b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/tate_pairing.asm index cb3fe066..356f002a 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/tate_pairing.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/curve_arithmetic/tate_pairing.asm @@ -1,41 +1,292 @@ +/// def miller(P, Q): +/// miller_init() +/// miller_loop() +/// +/// def miller_init(): +/// out = 1 +/// O = P +/// times = 61 +/// +/// def miller_loop(): +/// while times: +/// 0xnm = load(miller_data) +/// while 0xnm > 0x20: +/// miller_one() +/// while 0xnm: +/// miller_zero() +/// times -= 1 +/// +/// def miller_one(): +/// 0xnm -= 0x20 +/// mul_tangent() +/// mul_cord() +/// +/// def miller_zero(): +/// 0xnm -= 1 +/// mul_tangent() + /// def tate(P: Curve, Q: TwistedCurve) -> Fp12: /// out = miller_loop(P, Q) -/// return make_invariant(P, Q) +/// return bn254_invariant_exponent(out) global bn254_tate: // stack: inp, out, retdest - %stack (inp, out) -> (inp, out, make_invariant, out) - // stack: inp, out, make_invariant, out, retdest + %stack (inp, out) -> (inp, out, bn254_invariant_exponent, out) + // stack: inp, out, bn254_invariant_exponent, out, retdest %jump(bn254_miller) +global bn254_miller: + // stack: ptr, out, retdest + %stack (ptr, out) -> (out, 1, ptr, out) + // stack: out, 1, ptr, out, retdest + %mstore_kernel_general + // stack: ptr, out, retdest + %load_fp254_6 + // stack: P, 
Q, out, retdest + %stack (P: 2) -> (0, 53, P, P) + // stack: 0, 53, O, P, Q, out, retdest + // the head 0 lets miller_loop start with POP +miller_loop: + POP + // stack: times , O, P, Q, out, retdest + DUP1 + ISZERO + // stack: break?, times , O, P, Q, out, retdest + %jumpi(miller_return) + // stack: times , O, P, Q, out, retdest + %sub_const(1) + // stack: times-1, O, P, Q, out, retdest + DUP1 + // stack: times-1, times-1, O, P, Q, out, retdest + %mload_kernel_code(miller_data) + // stack: 0xnm, times-1, O, P, Q, out, retdest + %jump(miller_one) +miller_return: + // stack: times, O, P, Q, out, retdest + %stack (times, O: 2, P: 2, Q: 4, out, retdest) -> (retdest) + // stack: retdest + JUMP -/// def make_invariant(y: Fp12): -/// y = first_exp(y) -/// y = second_exp(y) -/// return final_exponentiation(y) -make_invariant: +miller_one: + // stack: 0xnm, times, O, P, Q, out, retdest + DUP1 + %lt_const(0x20) + // stack: skip?, 0xnm, times, O, P, Q, out, retdest + %jumpi(miller_zero) + // stack: 0xnm, times, O, P, Q, out, retdest + %sub_const(0x20) + // stack: 0x{n-1}m, times, O, P, Q, out, retdest + PUSH mul_cord + // stack: mul_cord, 0x{n-1}m, times, O, P, Q, out, retdest + %jump(mul_tangent) -/// map t to t^(p^6 - 1) via -/// def first_exp(t): -/// return t.frob(6) / t - // stack: out, retdest {out: y} - %stack (out) -> (out, 100, first_exp, out) - // stack: out, 100, first_exp, out, retdest {out: y} - %jump(inv_fp254_12) -first_exp: - // stack: out, retdest {out: y , 100: y^-1} - %frob_fp254_12_6 - // stack: out, retdest {out: y_6, 100: y^-1} - %stack (out) -> (out, 100, out, second_exp, out) - // stack: out, 100, out, second_exp, out, retdest {out: y_6, 100: y^-1} - %jump(mul_fp254_12) +miller_zero: + // stack: m , times, O, P, Q, out, retdest + DUP1 + ISZERO + // stack: skip?, m , times, O, P, Q, out, retdest + %jumpi(miller_loop) + // stack: m , times, O, P, Q, out, retdest + %sub_const(1) + // stack: m-1, times, O, P, Q, out, retdest + PUSH miller_zero + // stack: 
miller_zero, m-1, times, O, P, Q, out, retdest + %jump(mul_tangent) -/// map t to t^(p^2 + 1) via -/// def second_exp(t): -/// return t.frob(2) * t -second_exp: - // stack: out, retdest {out: y} - %stack (out) -> (out, 100, out, out, bn254_final_exp, out) - // stack: out, 100, out, out, bn254_final_exp, out, retdest {out: y} - %frob_fp254_12_2_ - // stack: 100, out, out, bn254_final_exp, out, retdest {out: y, 100: y_2} - %jump(mul_fp254_12) + +/// def mul_tangent() +/// out = square_fp254_12(out) +/// line = tangent(O, Q) +/// out = mul_fp254_12_sparse(out, line) +/// O += O + +mul_tangent: + // stack: retdest, 0xnm, times, O, P, Q, out + PUSH mul_tangent_2 + DUP13 + PUSH mul_tangent_1 + // stack: mul_tangent_1, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out + %stack (mul_tangent_1, out) -> (out, out, mul_tangent_1, out) + // stack: out, out, mul_tangent_1, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out + %jump(square_fp254_12) +mul_tangent_1: + // stack: out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out + DUP13 + DUP13 + DUP13 + DUP13 + // stack: Q, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out + DUP11 + DUP11 + // stack: O, Q, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out + %tangent + // stack: out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out {100: line} + %stack (out) -> (out, 100, out) + // stack: out, 100, out, mul_tangent_2, retdest, 0xnm, times, O, P, Q, out {100: line} + %jump(mul_fp254_12_sparse) +mul_tangent_2: + // stack: retdest, 0xnm, times, O, P, Q, out {100: line} + PUSH after_double + // stack: after_double, retdest, 0xnm, times, O, P, Q, out {100: line} + DUP6 + DUP6 + // stack: O, after_double, retdest, 0xnm, times, O, P, Q, out {100: line} + %jump(ec_double) +after_double: + // stack: 2*O, retdest, 0xnm, times, O, P, Q, out {100: line} + SWAP5 + POP + SWAP5 + POP + // stack: retdest, 0xnm, times, 2*O, P, Q, out {100: line} + JUMP + +/// def mul_cord() +/// line = cord(P, O, Q) +/// out = 
mul_fp254_12_sparse(out, line) +/// O += P + +mul_cord: + // stack: 0xnm, times, O, P, Q, out + PUSH mul_cord_1 + // stack: mul_cord_1, 0xnm, times, O, P, Q, out + DUP11 + DUP11 + DUP11 + DUP11 + // stack: Q, mul_cord_1, 0xnm, times, O, P, Q, out + DUP9 + DUP9 + // stack: O, Q, mul_cord_1, 0xnm, times, O, P, Q, out + DUP13 + DUP13 + // stack: P, O, Q, mul_cord_1, 0xnm, times, O, P, Q, out + %cord + // stack: mul_cord_1, 0xnm, times, O, P, Q, out {100: line} + DUP12 + // stack: out, mul_cord_1, 0xnm, times, O, P, Q, out {100: line} + %stack (out) -> (out, 100, out) + // stack: out, 100, out, mul_cord_1, 0xnm, times, O, P, Q, out {100: line} + %jump(mul_fp254_12_sparse) +mul_cord_1: + // stack: 0xnm, times, O , P, Q, out + PUSH after_add + // stack: after_add, 0xnm, times, O , P, Q, out + DUP7 + DUP7 + DUP7 + DUP7 + // stack: O , P, after_add, 0xnm, times, O , P, Q, out + %jump(ec_add_valid_points) +after_add: + // stack: O + P, 0xnm, times, O , P, Q, out + SWAP4 + POP + SWAP4 + POP + // stack: 0xnm, times, O+P, P, Q, out + %jump(miller_one) + + +/// def tangent(px, py, qx, qy): +/// return sparse_store( +/// py**2 - 9, +/// (-3px**2) * qx, +/// (2py) * qy, +/// ) + +%macro tangent + // stack: px, py, qx, qx_, qy, qy_ + %stack (px, py) -> (py, py , 9, px, py) + // stack: py, py , 9, px, py, qx, qx_, qy, qy_ + MULFP254 + // stack: py^2 , 9, px, py, qx, qx_, qy, qy_ + SUBFP254 + // stack: py^2 - 9, px, py, qx, qx_, qy, qy_ + %mstore_kernel_general(100) + // stack: px, py, qx, qx_, qy, qy_ + DUP1 + MULFP254 + // stack: px^2, py, qx, qx_, qy, qy_ + PUSH 3 + MULFP254 + // stack: 3*px^2, py, qx, qx_, qy, qy_ + PUSH 0 + SUBFP254 + // stack: -3*px^2, py, qx, qx_, qy, qy_ + SWAP2 + // stack: qx, py, -3px^2, qx_, qy, qy_ + DUP3 + MULFP254 + // stack: (-3*px^2)qx, py, -3px^2, qx_, qy, qy_ + %mstore_kernel_general(102) + // stack: py, -3px^2, qx_, qy, qy_ + PUSH 2 + MULFP254 + // stack: 2py, -3px^2, qx_, qy, qy_ + SWAP3 + // stack: qy, -3px^2, qx_, 2py, qy_ + DUP4 + MULFP254 + 
// stack: (2py)qy, -3px^2, qx_, 2py, qy_ + %mstore_kernel_general(108) + // stack: -3px^2, qx_, 2py, qy_ + MULFP254 + // stack: (-3px^2)*qx_, 2py, qy_ + %mstore_kernel_general(103) + // stack: 2py, qy_ + MULFP254 + // stack: (2py)*qy_ + %mstore_kernel_general(109) +%endmacro + +/// def cord(p1x, p1y, p2x, p2y, qx, qy): +/// return sparse_store( +/// p1y*p2x - p2y*p1x, +/// (p2y - p1y) * qx, +/// (p1x - p2x) * qy, +/// ) + +%macro cord + // stack: p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ + DUP1 + DUP5 + MULFP254 + // stack: p2y*p1x, p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ + DUP3 + DUP5 + MULFP254 + // stack: p1y*p2x , p2y*p1x, p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ + SUBFP254 + // stack: p1y*p2x - p2y*p1x, p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ + %mstore_kernel_general(100) + // stack: p1x , p1y, p2x , p2y, qx, qx_, qy, qy_ + SWAP3 + // stack: p2y , p1y, p2x , p1x, qx, qx_, qy, qy_ + SUBFP254 + // stack: p2y - p1y, p2x , p1x, qx, qx_, qy, qy_ + SWAP2 + // stack: p1x , p2x, p2y - p1y, qx, qx_, qy, qy_ + SUBFP254 + // stack: p1x - p2x, p2y - p1y, qx, qx_, qy, qy_ + SWAP4 + // stack: qy, p2y - p1y, qx, qx_, p1x - p2x, qy_ + DUP5 + MULFP254 + // stack: (p1x - p2x)qy, p2y - p1y, qx, qx_, p1x - p2x, qy_ + %mstore_kernel_general(108) + // stack: p2y - p1y, qx, qx_, p1x - p2x, qy_ + SWAP1 + // stack: qx, p2y - p1y, qx_, p1x - p2x, qy_ + DUP2 + MULFP254 + // stack: (p2y - p1y)qx, p2y - p1y, qx_, p1x - p2x, qy_ + %mstore_kernel_general(102) + // stack: p2y - p1y, qx_, p1x - p2x, qy_ + MULFP254 + // stack: (p2y - p1y)qx_, p1x - p2x, qy_ + %mstore_kernel_general(103) + // stack: p1x - p2x, qy_ + MULFP254 + // stack: (p1x - p2x)*qy_ + %mstore_kernel_general(109) +%endmacro diff --git a/evm/src/cpu/kernel/tests/bn254.rs b/evm/src/cpu/kernel/tests/bn254.rs index daed596d..24a84906 100644 --- a/evm/src/cpu/kernel/tests/bn254.rs +++ b/evm/src/cpu/kernel/tests/bn254.rs @@ -5,7 +5,9 @@ use ethereum_types::U256; use rand::Rng; use crate::bn254_arithmetic::{Fp, Fp12, Fp2}; -use 
crate::bn254_pairing::{gen_fp12_sparse, miller_loop, tate, Curve, TwistedCurve}; +use crate::bn254_pairing::{ + gen_fp12_sparse, invariant_exponent, miller_loop, tate, Curve, TwistedCurve, +}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::memory::segments::Segment; @@ -17,24 +19,26 @@ struct InterpreterSetup { memory: Vec<(usize, Vec)>, } -fn run_setup_interpreter(setup: InterpreterSetup) -> Result> { - let label = KERNEL.global_labels[&setup.label]; - let mut stack = setup.stack; - stack.reverse(); - let mut interpreter = Interpreter::new_with_kernel(label, stack); - for (pointer, data) in setup.memory { - for (i, term) in data.iter().enumerate() { - interpreter.generation_state.memory.set( - MemoryAddress::new(0, Segment::KernelGeneral, pointer + i), - *term, - ) +impl InterpreterSetup { + fn run(self) -> Result> { + let label = KERNEL.global_labels[&self.label]; + let mut stack = self.stack; + stack.reverse(); + let mut interpreter = Interpreter::new_with_kernel(label, stack); + for (pointer, data) in self.memory { + for (i, term) in data.iter().enumerate() { + interpreter.generation_state.memory.set( + MemoryAddress::new(0, Segment::KernelGeneral, pointer + i), + *term, + ) + } } + interpreter.run()?; + Ok(interpreter) } - interpreter.run()?; - Ok(interpreter) } -fn extract_kernel_output(range: Range, interpreter: Interpreter<'static>) -> Vec { +fn extract_kernel_memory(range: Range, interpreter: Interpreter<'static>) -> Vec { let mut output: Vec = vec![]; for i in range { let term = interpreter.generation_state.memory.get(MemoryAddress::new( @@ -63,7 +67,7 @@ fn setup_mul_test(out: usize, f: Fp12, g: Fp12, label: &str) -> InterpreterSetup } #[test] -fn test_mul_fp254_12() -> Result<()> { +fn test_mul_fp12() -> Result<()> { let out: usize = 88; let mut rng = rand::thread_rng(); @@ -75,13 +79,13 @@ fn test_mul_fp254_12() -> Result<()> { let setup_sparse: InterpreterSetup = setup_mul_test(out, f, h, 
"mul_fp254_12_sparse"); let setup_square: InterpreterSetup = setup_mul_test(out, f, f, "square_fp254_12_test"); - let intrptr_normal: Interpreter = run_setup_interpreter(setup_normal).unwrap(); - let intrptr_sparse: Interpreter = run_setup_interpreter(setup_sparse).unwrap(); - let intrptr_square: Interpreter = run_setup_interpreter(setup_square).unwrap(); + let intrptr_normal: Interpreter = setup_normal.run().unwrap(); + let intrptr_sparse: Interpreter = setup_sparse.run().unwrap(); + let intrptr_square: Interpreter = setup_square.run().unwrap(); - let out_normal: Vec = extract_kernel_output(out..out + 12, intrptr_normal); - let out_sparse: Vec = extract_kernel_output(out..out + 12, intrptr_sparse); - let out_square: Vec = extract_kernel_output(out..out + 12, intrptr_square); + let out_normal: Vec = extract_kernel_memory(out..out + 12, intrptr_normal); + let out_sparse: Vec = extract_kernel_memory(out..out + 12, intrptr_sparse); + let out_square: Vec = extract_kernel_memory(out..out + 12, intrptr_square); let exp_normal: Vec = (f * g).on_stack(); let exp_sparse: Vec = (f * h).on_stack(); @@ -103,7 +107,7 @@ fn setup_frob_test(ptr: usize, f: Fp12, label: &str) -> InterpreterSetup { } #[test] -fn test_frob_fp254_12() -> Result<()> { +fn test_frob_fp12() -> Result<()> { let ptr: usize = 100; let mut rng = rand::thread_rng(); @@ -114,15 +118,15 @@ fn test_frob_fp254_12() -> Result<()> { let setup_frob_3 = setup_frob_test(ptr, f, "test_frob_fp254_12_3"); let setup_frob_6 = setup_frob_test(ptr, f, "test_frob_fp254_12_6"); - let intrptr_frob_1: Interpreter = run_setup_interpreter(setup_frob_1).unwrap(); - let intrptr_frob_2: Interpreter = run_setup_interpreter(setup_frob_2).unwrap(); - let intrptr_frob_3: Interpreter = run_setup_interpreter(setup_frob_3).unwrap(); - let intrptr_frob_6: Interpreter = run_setup_interpreter(setup_frob_6).unwrap(); + let intrptr_frob_1: Interpreter = setup_frob_1.run().unwrap(); + let intrptr_frob_2: Interpreter = setup_frob_2.run().unwrap(); 
+ let intrptr_frob_3: Interpreter = setup_frob_3.run().unwrap(); + let intrptr_frob_6: Interpreter = setup_frob_6.run().unwrap(); - let out_frob_1: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_1); - let out_frob_2: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_2); - let out_frob_3: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_3); - let out_frob_6: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_6); + let out_frob_1: Vec = extract_kernel_memory(ptr..ptr + 12, intrptr_frob_1); + let out_frob_2: Vec = extract_kernel_memory(ptr..ptr + 12, intrptr_frob_2); + let out_frob_3: Vec = extract_kernel_memory(ptr..ptr + 12, intrptr_frob_3); + let out_frob_6: Vec = extract_kernel_memory(ptr..ptr + 12, intrptr_frob_6); let exp_frob_1: Vec = f.frob(1).on_stack(); let exp_frob_2: Vec = f.frob(2).on_stack(); @@ -138,10 +142,9 @@ fn test_frob_fp254_12() -> Result<()> { } #[test] -fn test_inv_fp254_12() -> Result<()> { +fn test_inv_fp12() -> Result<()> { let ptr: usize = 100; let inv: usize = 112; - let mut rng = rand::thread_rng(); let f: Fp12 = rng.gen::(); @@ -150,8 +153,8 @@ fn test_inv_fp254_12() -> Result<()> { stack: vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)], memory: vec![(ptr, f.on_stack())], }; - let interpreter: Interpreter = run_setup_interpreter(setup).unwrap(); - let output: Vec = extract_kernel_output(inv..inv + 12, interpreter); + let interpreter: Interpreter = setup.run().unwrap(); + let output: Vec = extract_kernel_memory(inv..inv + 12, interpreter); let expected: Vec = f.inv().on_stack(); assert_eq!(output, expected); @@ -159,29 +162,27 @@ fn test_inv_fp254_12() -> Result<()> { Ok(()) } -// #[test] -// fn test_invariance_inducing_power() -> Result<()> { -// let ptr = U256::from(300); -// let out = U256::from(400); +#[test] +fn test_invariant_exponent() -> Result<()> { + let ptr: usize = 400; -// let f: Fp12 = gen_fp12(); + let mut rng = rand::thread_rng(); + let f: Fp12 = rng.gen::(); -// let mut stack = 
vec![ptr]; -// stack.extend(fp12_on_stack(f)); -// stack.extend(vec![ -// ptr, -// out, -// get_address_from_label("return_fp12_on_stack"), -// out, -// ]); + let setup = InterpreterSetup { + label: "bn254_invariant_exponent".to_string(), + stack: vec![U256::from(ptr), U256::from(0xdeadbeefu32)], + memory: vec![(ptr, f.on_stack())], + }; -// let output: Vec = run_setup_interpreter("test_pow", stack); -// let expected: Vec = fp12_on_stack(invariance_inducing_power(f)); + let interpreter: Interpreter = setup.run().unwrap(); + let output: Vec = extract_kernel_memory(ptr..ptr + 12, interpreter); + let expected: Vec = invariant_exponent(f).on_stack(); -// assert_eq!(output, expected); + assert_eq!(output, expected); -// Ok(()) -// } + Ok(()) +} // The curve is cyclic with generator (1, 2) pub const CURVE_GENERATOR: Curve = { @@ -253,8 +254,8 @@ fn test_miller() -> Result<()> { stack: vec![U256::from(ptr), U256::from(out), U256::from(0xdeadbeefu32)], memory: vec![(ptr, inputs)], }; - let interpreter = run_setup_interpreter(setup).unwrap(); - let output: Vec = extract_kernel_output(out..out + 12, interpreter); + let interpreter = setup.run().unwrap(); + let output: Vec = extract_kernel_memory(out..out + 12, interpreter); let expected = miller_loop(CURVE_GENERATOR, TWISTED_GENERATOR).on_stack(); assert_eq!(output, expected); @@ -280,8 +281,8 @@ fn test_tate() -> Result<()> { stack: vec![U256::from(ptr), U256::from(out), U256::from(0xdeadbeefu32)], memory: vec![(ptr, inputs)], }; - let interpreter = run_setup_interpreter(setup).unwrap(); - let output: Vec = extract_kernel_output(out..out + 12, interpreter); + let interpreter = setup.run().unwrap(); + let output: Vec = extract_kernel_memory(out..out + 12, interpreter); let expected = tate(CURVE_GENERATOR, TWISTED_GENERATOR).on_stack(); assert_eq!(output, expected);