From d2aa937a2ff42e1e437ad44207d7745d14448c1c Mon Sep 17 00:00:00 2001 From: Dmitry Vagner Date: Sat, 21 Jan 2023 13:19:07 +0700 Subject: [PATCH] improved prover input and test api --- evm/src/bn254_arithmetic.rs | 40 +---- evm/src/cpu/kernel/aggregator.rs | 2 +- .../curve/bn254/field_arithmetic/inverse.asm | 17 ++- .../field_arithmetic/{utils.asm => util.asm} | 0 evm/src/cpu/kernel/tests/bn254.rs | 142 +++++++++--------- evm/src/generation/prover_input.rs | 26 ++-- evm/src/witness/util.rs | 21 +-- 7 files changed, 115 insertions(+), 133 deletions(-) rename evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/{utils.asm => util.asm} (100%) diff --git a/evm/src/bn254_arithmetic.rs b/evm/src/bn254_arithmetic.rs index 22bf92c4..c5c7961d 100644 --- a/evm/src/bn254_arithmetic.rs +++ b/evm/src/bn254_arithmetic.rs @@ -2,7 +2,6 @@ use std::mem::transmute; use std::ops::{Add, Div, Mul, Neg, Sub}; use ethereum_types::U256; -use itertools::Itertools; use rand::{thread_rng, Rng}; pub const BN_BASE: U256 = U256([ @@ -139,17 +138,13 @@ impl Mul for Fp2 { } } -/// The inverse of a + bi is given by (a - bi)/(a^2 + b^2) since -/// (a + bi)(a - bi)/(a^2 + b^2) = (a^2 + b^2)/(a^2 + b^2) = 1 +/// The inverse of z is given by z'/||z|| since ||z|| = zz' impl Div for Fp2 { type Output = Self; fn div(self, rhs: Self) -> Self::Output { let norm = rhs.re * rhs.re + rhs.im * rhs.im; - let inv = Fp2 { - re: rhs.re / norm, - im: -rhs.im / norm, - }; + let inv = mul_fp_fp2(norm, conj_fp2(rhs)); self * inv } } @@ -833,36 +828,9 @@ const FROB_Z: [Fp2; 12] = [ }, ]; -pub fn fp12_to_array(f: Fp12) -> [U256; 12] { - unsafe { transmute(f) } -} - pub fn fp12_to_vec(f: Fp12) -> Vec { - fp12_to_array(f).into_iter().collect() -} - -pub fn vec_to_fp12(xs: Vec) -> Fp12 { - xs.into_iter() - .tuples::<(U256, U256)>() - .map(|(v1, v2)| Fp2 { - re: Fp { val: v1 }, - im: Fp { val: v2 }, - }) - .tuples() - .map(|(a1, a2, a3, a4, a5, a6)| Fp12 { - z0: Fp6 { - t0: a1, - t1: a2, - t2: a3, - }, - z1: Fp6 { - t0: a4, - t1: a5, - t2: a6, - }, - }) - .next() - .unwrap() + let f: [U256; 12] = unsafe { transmute(f) }; + f.into_iter().collect() } fn gen_fp() -> Fp { diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index d924eeb4..c74baa65 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -34,7 +34,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/curve/bn254/field_arithmetic/fp12_mul.asm"), include_str!("asm/curve/bn254/field_arithmetic/frobenius.asm"), include_str!("asm/curve/bn254/field_arithmetic/power.asm"), - include_str!("asm/curve/bn254/field_arithmetic/utils.asm"), + include_str!("asm/curve/bn254/field_arithmetic/util.asm"), include_str!("asm/curve/common.asm"), include_str!("asm/curve/secp256k1/curve_mul.asm"), include_str!("asm/curve/secp256k1/curve_add.asm"), diff --git a/evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/inverse.asm b/evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/inverse.asm index 408d3cc9..8f42e047 100644 --- a/evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/inverse.asm +++ b/evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/inverse.asm @@ -23,6 +23,18 @@ global inv_fp12: + // stack: ptr, inv, retdest + %prover_inv_fp12 + // stack: f^-1, ptr, inv, retdest + DUP14 + // stack: inv, f^-1, ptr, inv, retdest + %store_fp12 + // stack: ptr, inv, retdest + %stack (ptr, inv) -> (ptr, inv, 50, check_inv) + // stack: ptr, inv, 50, check_inv, retdest + %jump(mul_fp12) + +global inv_fp12_old: // stack: ptr, inv, retdest DUP1 %load_fp12 // stack: f, ptr, inv, retdest @@ -39,9 +51,12 @@ global inv_fp12: %stack (check_inv, mem, ptr, inv) -> (ptr, inv, mem, check_inv) // stack: ptr, inv, 50, check_inv, retdest %jump(mul_fp12) + + global check_inv: // stack: retdest - PUSH 50 %load_fp12 + PUSH 50 + %load_fp12 // stack: unit?, retdest %assert_eq_unit_fp12 // stack: retdest diff --git a/evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/utils.asm b/evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/util.asm similarity index 100% rename from evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/utils.asm rename to evm/src/cpu/kernel/asm/curve/bn254/field_arithmetic/util.asm diff --git a/evm/src/cpu/kernel/tests/bn254.rs b/evm/src/cpu/kernel/tests/bn254.rs index 58f26bcc..23f9d531 100644 --- a/evm/src/cpu/kernel/tests/bn254.rs +++ b/evm/src/cpu/kernel/tests/bn254.rs @@ -13,15 +13,13 @@ struct InterpreterSetup { offset: String, stack: Vec, memory: Vec<(usize, Vec)>, - output: Range, } -fn get_interpreter_output(setup: InterpreterSetup) -> Result> { +fn run_setup_interpreter(setup: InterpreterSetup) -> Result> { let label = KERNEL.global_labels[&setup.offset]; let mut stack = setup.stack; stack.reverse(); let mut interpreter = Interpreter::new_with_kernel(label, stack); - for (pointer, data) in setup.memory { for (i, term) in data.iter().enumerate() { interpreter.generation_state.memory.set( @@ -30,54 +28,64 @@ fn get_interpreter_output(setup: InterpreterSetup) -> Result> { ) } } - interpreter.run()?; - - let kernel = &interpreter.generation_state.memory.contexts[interpreter.context].segments - [Segment::KernelGeneral as usize] - .content; - - let mut output: Vec = vec![]; - for i in setup.output { - output.push(kernel[i]); - } - Ok(output) + Ok(interpreter) } -fn setup_mul_test(f: Fp12, g: Fp12, label: &str) -> InterpreterSetup { - let in0: usize = 64; - let in1: usize = 76; - let out: usize = 88; - - let stack = vec![ - U256::from(in0), - U256::from(in1), - U256::from(out), - U256::from(0xdeadbeefu32), - ]; - let memory = vec![(in0, fp12_to_vec(f)), (in1, fp12_to_vec(g))]; +fn extract_kernel_output(range: Range, interpreter: Interpreter<'static>) -> Vec { + let mut output: Vec = vec![]; + for i in range { + let term = interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + Segment::KernelGeneral, + i, + )); + output.push(term); + } + output +} +fn setup_mul_test( + in0: usize, + in1: usize, + out: usize, + f: Fp12, + g: Fp12, + label: &str, +) -> InterpreterSetup { InterpreterSetup { offset: label.to_string(), - stack, - memory, - output: out..out + 12, + stack: vec![ + U256::from(in0), + U256::from(in1), + U256::from(out), + U256::from(0xdeadbeefu32), + ], + memory: vec![(in0, fp12_to_vec(f)), (in1, fp12_to_vec(g))], } } #[test] fn test_mul_fp12() -> Result<()> { + let in0: usize = 64; + let in1: usize = 76; + let out: usize = 88; + let f: Fp12 = gen_fp12(); let g: Fp12 = gen_fp12(); let h: Fp12 = gen_fp12_sparse(); - let setup_normal: InterpreterSetup = setup_mul_test(f, g, "mul_fp12"); - let setup_sparse: InterpreterSetup = setup_mul_test(f, h, "mul_fp12_sparse"); - let setup_square: InterpreterSetup = setup_mul_test(f, f, "square_fp12_test"); + let setup_normal: InterpreterSetup = setup_mul_test(in0, in1, out, f, g, "mul_fp12"); + let setup_sparse: InterpreterSetup = setup_mul_test(in0, in1, out, f, h, "mul_fp12_sparse"); + let setup_square: InterpreterSetup = setup_mul_test(in0, in1, out, f, f, "square_fp12_test"); - let out_normal: Vec = get_interpreter_output(setup_normal).unwrap(); - let out_sparse: Vec = get_interpreter_output(setup_sparse).unwrap(); - let out_square: Vec = get_interpreter_output(setup_square).unwrap(); + let intrptr_normal: Interpreter = run_setup_interpreter(setup_normal).unwrap(); + let intrptr_sparse: Interpreter = run_setup_interpreter(setup_sparse).unwrap(); + let intrptr_square: Interpreter = run_setup_interpreter(setup_square).unwrap(); + + let out_normal: Vec = extract_kernel_output(out..out + 12, intrptr_normal); + let out_sparse: Vec = extract_kernel_output(out..out + 12, intrptr_sparse); + let out_square: Vec = extract_kernel_output(out..out + 12, intrptr_square); let exp_normal: Vec = fp12_to_vec(f * g); let exp_sparse: Vec = fp12_to_vec(f * h); @@ -90,32 +98,33 @@ fn test_mul_fp12() -> Result<()> { Ok(()) } -fn setup_frob_test(f: Fp12, label: &str) -> InterpreterSetup { - let ptr: usize = 100; - let stack = vec![U256::from(ptr)]; - let memory = vec![(ptr, fp12_to_vec(f))]; - +fn setup_frob_test(ptr: usize, f: Fp12, label: &str) -> InterpreterSetup { InterpreterSetup { offset: label.to_string(), - stack, - memory, - output: ptr..ptr + 12, + stack: vec![U256::from(ptr)], + memory: vec![(ptr, fp12_to_vec(f))], } } #[test] fn test_frob_fp12() -> Result<()> { + let ptr: usize = 100; let f: Fp12 = gen_fp12(); - let setup_frob_1 = setup_frob_test(f, "test_frob_fp12_1"); - let setup_frob_2 = setup_frob_test(f, "test_frob_fp12_2"); - let setup_frob_3 = setup_frob_test(f, "test_frob_fp12_3"); - let setup_frob_6 = setup_frob_test(f, "test_frob_fp12_6"); + let setup_frob_1 = setup_frob_test(ptr, f, "test_frob_fp12_1"); + let setup_frob_2 = setup_frob_test(ptr, f, "test_frob_fp12_2"); + let setup_frob_3 = setup_frob_test(ptr, f, "test_frob_fp12_3"); + let setup_frob_6 = setup_frob_test(ptr, f, "test_frob_fp12_6"); - let out_frob_1: Vec = get_interpreter_output(setup_frob_1).unwrap(); - let out_frob_2: Vec = get_interpreter_output(setup_frob_2).unwrap(); - let out_frob_3: Vec = get_interpreter_output(setup_frob_3).unwrap(); - let out_frob_6: Vec = get_interpreter_output(setup_frob_6).unwrap(); + let intrptr_frob_1: Interpreter = run_setup_interpreter(setup_frob_1).unwrap(); + let intrptr_frob_2: Interpreter = run_setup_interpreter(setup_frob_2).unwrap(); + let intrptr_frob_3: Interpreter = run_setup_interpreter(setup_frob_3).unwrap(); + let intrptr_frob_6: Interpreter = run_setup_interpreter(setup_frob_6).unwrap(); + + let out_frob_1: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_1); + let out_frob_2: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_2); + let out_frob_3: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_3); + let out_frob_6: Vec = extract_kernel_output(ptr..ptr + 12, intrptr_frob_6); let exp_frob_1: Vec = fp12_to_vec(frob_fp12(1, f)); let exp_frob_2: Vec = fp12_to_vec(frob_fp12(2, f)); @@ -130,26 +139,19 @@ fn test_frob_fp12() -> Result<()> { Ok(()) } -fn setup_inv_test(f: Fp12) -> InterpreterSetup { - let ptr: usize = 100; - let inv: usize = 112; - let stack = vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)]; - let memory = vec![(ptr, fp12_to_vec(f))]; - - InterpreterSetup { - offset: "inv_fp12".to_string(), - stack, - memory, - output: inv..inv + 12, - } -} - #[test] fn test_inv_fp12() -> Result<()> { + let ptr: usize = 100; + let inv: usize = 112; let f: Fp12 = gen_fp12(); - let setup = setup_inv_test(f); - let output: Vec = get_interpreter_output(setup).unwrap(); + let setup = InterpreterSetup { + offset: "inv_fp12".to_string(), + stack: vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)], + memory: vec![(ptr, fp12_to_vec(f))], + }; + let interpreter: Interpreter = run_setup_interpreter(setup).unwrap(); + let output: Vec = extract_kernel_output(inv..inv + 12, interpreter); let expected: Vec = fp12_to_vec(inv_fp12(f)); assert_eq!(output, expected); @@ -173,7 +175,7 @@ fn test_inv_fp12() -> Result<()> { // out, // ]); -// let output: Vec = get_interpreter_output("test_pow", stack); +// let output: Vec = run_setup_interpreter("test_pow", stack); // let expected: Vec = fp12_to_vec(power(f)); // assert_eq!(output, expected); @@ -206,7 +208,7 @@ fn test_inv_fp12() -> Result<()> { // let q: TwistedCurve = twisted_curve_generator(); // let stack = make_tate_stack(p, q); -// let output = get_interpreter_output("test_miller", stack); +// let output = run_setup_interpreter("test_miller", stack); // let expected = fp12_to_vec(miller_loop(p, q)); // assert_eq!(output, expected); @@ -220,7 +222,7 @@ fn test_inv_fp12() -> Result<()> { // let q: TwistedCurve = twisted_curve_generator(); // let stack = make_tate_stack(p, q); -// let output = get_interpreter_output("test_tate", stack); +// let output = run_setup_interpreter("test_tate", stack); // let expected = fp12_to_vec(tate(p, q)); // assert_eq!(output, expected); diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 9f305e41..4dff42c7 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,16 +1,17 @@ +use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256}; use plonky2::field::types::Field; -use crate::bn254_arithmetic::{fp12_to_array, inv_fp12, vec_to_fp12}; +use crate::bn254_arithmetic::{inv_fp12, Fp12}; use crate::generation::prover_input::EvmField::{ Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; use crate::generation::state::GenerationState; -use crate::witness::util::{stack_peek, stack_peeks}; +use crate::witness::util::{kernel_general_peek, stack_peek}; /// Prover input function represented as a scoped function name. /// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`. @@ -57,10 +58,7 @@ impl GenerationState { /// Finite field extension operations. fn run_ffe(&self, input_fn: &ProverInputFn) -> U256 { let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap(); - let component = input_fn.0[2].as_str(); - let xs = stack_peeks(self).expect("Empty stack"); - // TODO: This sucks... come back later - let n = match component { + let n = match input_fn.0[2].as_str() { "component_0" => 0, "component_1" => 1, "component_2" => 2, @@ -75,7 +73,12 @@ impl GenerationState { "component_11" => 11, _ => panic!("out of bounds"), }; - field.inverse_fp12(n, xs) + let ptr = stack_peek(self, 11 - n).expect("Empty stack").as_usize(); + let mut f: [U256; 12] = [U256::zero(); 12]; + for i in 0..12 { + f[i] = kernel_general_peek(self, ptr + i); + } + field.inverse_fp12(n, f) } /// MPT data. @@ -196,11 +199,10 @@ impl EvmField { modexp(x, q, n) } - fn inverse_fp12(&self, n: usize, xs: Vec) -> U256 { - let offset = 12 - n; - let vec: Vec = xs[offset..].to_vec(); - let f = fp12_to_array(inv_fp12(vec_to_fp12(vec))); - f[n] + fn inverse_fp12(&self, n: usize, f: [U256; 12]) -> U256 { + let f: Fp12 = unsafe { transmute(f) }; + let f_inv: [U256; 12] = unsafe { transmute(inv_fp12(f)) }; + f_inv[n] } } diff --git a/evm/src/witness/util.rs b/evm/src/witness/util.rs index 9aa0cb03..d47365f0 100644 --- a/evm/src/witness/util.rs +++ b/evm/src/witness/util.rs @@ -27,7 +27,7 @@ fn to_bits_le(n: u8) -> [F; 8] { res } -/// Peak at the stack item `i`th from the top. If `i=0` this gives the tip. +/// Peek at the stack item `i`th from the top. If `i=0` this gives the tip. pub(crate) fn stack_peek(state: &GenerationState, i: usize) -> Option { if i >= state.registers.stack_len { return None; @@ -39,18 +39,13 @@ pub(crate) fn stack_peek(state: &GenerationState, i: usize) -> Opti ))) } -/// Peek at the entire stack. -pub(crate) fn stack_peeks(state: &GenerationState) -> Option> { - let n = state.registers.stack_len; - let mut stack: Vec = vec![]; - for i in 0..n { - stack.extend(vec![state.memory.get(MemoryAddress::new( - state.registers.code_context(), - Segment::Stack, - n - 1 - i, - ))]) - } - Some(stack) +/// Peek at the kernel general item at address `i` +pub(crate) fn kernel_general_peek(state: &GenerationState, i: usize) -> U256 { + state.memory.get(MemoryAddress::new( + state.registers.context, + Segment::KernelGeneral, + i, + )) } pub(crate) fn mem_read_with_log(