improved prover input and test api

This commit is contained in:
Dmitry Vagner 2023-01-21 13:19:07 +07:00
parent e06a2f2d46
commit d2aa937a2f
7 changed files with 115 additions and 133 deletions

View File

@ -2,7 +2,6 @@ use std::mem::transmute;
use std::ops::{Add, Div, Mul, Neg, Sub};
use ethereum_types::U256;
use itertools::Itertools;
use rand::{thread_rng, Rng};
pub const BN_BASE: U256 = U256([
@ -139,17 +138,13 @@ impl Mul for Fp2 {
}
}
/// The inverse of a + bi is given by (a - bi)/(a^2 + b^2) since
/// (a + bi)(a - bi)/(a^2 + b^2) = (a^2 + b^2)/(a^2 + b^2) = 1
/// The inverse of z is given by z'/||z|| since ||z|| = zz'
impl Div for Fp2 {
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
let norm = rhs.re * rhs.re + rhs.im * rhs.im;
let inv = Fp2 {
re: rhs.re / norm,
im: -rhs.im / norm,
};
let inv = mul_fp_fp2(norm, conj_fp2(rhs));
self * inv
}
}
@ -833,36 +828,9 @@ const FROB_Z: [Fp2; 12] = [
},
];
pub fn fp12_to_array(f: Fp12) -> [U256; 12] {
unsafe { transmute(f) }
}
pub fn fp12_to_vec(f: Fp12) -> Vec<U256> {
fp12_to_array(f).into_iter().collect()
}
pub fn vec_to_fp12(xs: Vec<U256>) -> Fp12 {
xs.into_iter()
.tuples::<(U256, U256)>()
.map(|(v1, v2)| Fp2 {
re: Fp { val: v1 },
im: Fp { val: v2 },
})
.tuples()
.map(|(a1, a2, a3, a4, a5, a6)| Fp12 {
z0: Fp6 {
t0: a1,
t1: a2,
t2: a3,
},
z1: Fp6 {
t0: a4,
t1: a5,
t2: a6,
},
})
.next()
.unwrap()
let f: [U256; 12] = unsafe { transmute(f) };
f.into_iter().collect()
}
fn gen_fp() -> Fp {

View File

@ -34,7 +34,7 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/curve/bn254/field_arithmetic/fp12_mul.asm"),
include_str!("asm/curve/bn254/field_arithmetic/frobenius.asm"),
include_str!("asm/curve/bn254/field_arithmetic/power.asm"),
include_str!("asm/curve/bn254/field_arithmetic/utils.asm"),
include_str!("asm/curve/bn254/field_arithmetic/util.asm"),
include_str!("asm/curve/common.asm"),
include_str!("asm/curve/secp256k1/curve_mul.asm"),
include_str!("asm/curve/secp256k1/curve_add.asm"),

View File

@ -23,6 +23,18 @@
global inv_fp12:
// stack: ptr, inv, retdest
%prover_inv_fp12
// stack: f^-1, ptr, inv, retdest
DUP14
// stack: inv, f^-1, ptr, inv, retdest
%store_fp12
// stack: ptr, inv, retdest
%stack (ptr, inv) -> (ptr, inv, 50, check_inv)
// stack: ptr, inv, 50, check_inv, retdest
%jump(mul_fp12)
global inv_fp12_old:
// stack: ptr, inv, retdest
DUP1 %load_fp12
// stack: f, ptr, inv, retdest
@ -39,9 +51,12 @@ global inv_fp12:
%stack (check_inv, mem, ptr, inv) -> (ptr, inv, mem, check_inv)
// stack: ptr, inv, 50, check_inv, retdest
%jump(mul_fp12)
global check_inv:
// stack: retdest
PUSH 50 %load_fp12
PUSH 50
%load_fp12
// stack: unit?, retdest
%assert_eq_unit_fp12
// stack: retdest

View File

@ -13,15 +13,13 @@ struct InterpreterSetup {
offset: String,
stack: Vec<U256>,
memory: Vec<(usize, Vec<U256>)>,
output: Range<usize>,
}
fn get_interpreter_output(setup: InterpreterSetup) -> Result<Vec<U256>> {
fn run_setup_interpreter(setup: InterpreterSetup) -> Result<Interpreter<'static>> {
let label = KERNEL.global_labels[&setup.offset];
let mut stack = setup.stack;
stack.reverse();
let mut interpreter = Interpreter::new_with_kernel(label, stack);
for (pointer, data) in setup.memory {
for (i, term) in data.iter().enumerate() {
interpreter.generation_state.memory.set(
@ -30,54 +28,64 @@ fn get_interpreter_output(setup: InterpreterSetup) -> Result<Vec<U256>> {
)
}
}
interpreter.run()?;
let kernel = &interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::KernelGeneral as usize]
.content;
let mut output: Vec<U256> = vec![];
for i in setup.output {
output.push(kernel[i]);
}
Ok(output)
Ok(interpreter)
}
fn setup_mul_test(f: Fp12, g: Fp12, label: &str) -> InterpreterSetup {
let in0: usize = 64;
let in1: usize = 76;
let out: usize = 88;
let stack = vec![
U256::from(in0),
U256::from(in1),
U256::from(out),
U256::from(0xdeadbeefu32),
];
let memory = vec![(in0, fp12_to_vec(f)), (in1, fp12_to_vec(g))];
fn extract_kernel_output(range: Range<usize>, interpreter: Interpreter<'static>) -> Vec<U256> {
let mut output: Vec<U256> = vec![];
for i in range {
let term = interpreter.generation_state.memory.get(MemoryAddress::new(
0,
Segment::KernelGeneral,
i,
));
output.push(term);
}
output
}
fn setup_mul_test(
in0: usize,
in1: usize,
out: usize,
f: Fp12,
g: Fp12,
label: &str,
) -> InterpreterSetup {
InterpreterSetup {
offset: label.to_string(),
stack,
memory,
output: out..out + 12,
stack: vec![
U256::from(in0),
U256::from(in1),
U256::from(out),
U256::from(0xdeadbeefu32),
],
memory: vec![(in0, fp12_to_vec(f)), (in1, fp12_to_vec(g))],
}
}
#[test]
fn test_mul_fp12() -> Result<()> {
let in0: usize = 64;
let in1: usize = 76;
let out: usize = 88;
let f: Fp12 = gen_fp12();
let g: Fp12 = gen_fp12();
let h: Fp12 = gen_fp12_sparse();
let setup_normal: InterpreterSetup = setup_mul_test(f, g, "mul_fp12");
let setup_sparse: InterpreterSetup = setup_mul_test(f, h, "mul_fp12_sparse");
let setup_square: InterpreterSetup = setup_mul_test(f, f, "square_fp12_test");
let setup_normal: InterpreterSetup = setup_mul_test(in0, in1, out, f, g, "mul_fp12");
let setup_sparse: InterpreterSetup = setup_mul_test(in0, in1, out, f, h, "mul_fp12_sparse");
let setup_square: InterpreterSetup = setup_mul_test(in0, in1, out, f, f, "square_fp12_test");
let out_normal: Vec<U256> = get_interpreter_output(setup_normal).unwrap();
let out_sparse: Vec<U256> = get_interpreter_output(setup_sparse).unwrap();
let out_square: Vec<U256> = get_interpreter_output(setup_square).unwrap();
let intrptr_normal: Interpreter = run_setup_interpreter(setup_normal).unwrap();
let intrptr_sparse: Interpreter = run_setup_interpreter(setup_sparse).unwrap();
let intrptr_square: Interpreter = run_setup_interpreter(setup_square).unwrap();
let out_normal: Vec<U256> = extract_kernel_output(out..out + 12, intrptr_normal);
let out_sparse: Vec<U256> = extract_kernel_output(out..out + 12, intrptr_sparse);
let out_square: Vec<U256> = extract_kernel_output(out..out + 12, intrptr_square);
let exp_normal: Vec<U256> = fp12_to_vec(f * g);
let exp_sparse: Vec<U256> = fp12_to_vec(f * h);
@ -90,32 +98,33 @@ fn test_mul_fp12() -> Result<()> {
Ok(())
}
fn setup_frob_test(f: Fp12, label: &str) -> InterpreterSetup {
let ptr: usize = 100;
let stack = vec![U256::from(ptr)];
let memory = vec![(ptr, fp12_to_vec(f))];
fn setup_frob_test(ptr: usize, f: Fp12, label: &str) -> InterpreterSetup {
InterpreterSetup {
offset: label.to_string(),
stack,
memory,
output: ptr..ptr + 12,
stack: vec![U256::from(ptr)],
memory: vec![(ptr, fp12_to_vec(f))],
}
}
#[test]
fn test_frob_fp12() -> Result<()> {
let ptr: usize = 100;
let f: Fp12 = gen_fp12();
let setup_frob_1 = setup_frob_test(f, "test_frob_fp12_1");
let setup_frob_2 = setup_frob_test(f, "test_frob_fp12_2");
let setup_frob_3 = setup_frob_test(f, "test_frob_fp12_3");
let setup_frob_6 = setup_frob_test(f, "test_frob_fp12_6");
let setup_frob_1 = setup_frob_test(ptr, f, "test_frob_fp12_1");
let setup_frob_2 = setup_frob_test(ptr, f, "test_frob_fp12_2");
let setup_frob_3 = setup_frob_test(ptr, f, "test_frob_fp12_3");
let setup_frob_6 = setup_frob_test(ptr, f, "test_frob_fp12_6");
let out_frob_1: Vec<U256> = get_interpreter_output(setup_frob_1).unwrap();
let out_frob_2: Vec<U256> = get_interpreter_output(setup_frob_2).unwrap();
let out_frob_3: Vec<U256> = get_interpreter_output(setup_frob_3).unwrap();
let out_frob_6: Vec<U256> = get_interpreter_output(setup_frob_6).unwrap();
let intrptr_frob_1: Interpreter = run_setup_interpreter(setup_frob_1).unwrap();
let intrptr_frob_2: Interpreter = run_setup_interpreter(setup_frob_2).unwrap();
let intrptr_frob_3: Interpreter = run_setup_interpreter(setup_frob_3).unwrap();
let intrptr_frob_6: Interpreter = run_setup_interpreter(setup_frob_6).unwrap();
let out_frob_1: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_1);
let out_frob_2: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_2);
let out_frob_3: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_3);
let out_frob_6: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_6);
let exp_frob_1: Vec<U256> = fp12_to_vec(frob_fp12(1, f));
let exp_frob_2: Vec<U256> = fp12_to_vec(frob_fp12(2, f));
@ -130,26 +139,19 @@ fn test_frob_fp12() -> Result<()> {
Ok(())
}
fn setup_inv_test(f: Fp12) -> InterpreterSetup {
let ptr: usize = 100;
let inv: usize = 112;
let stack = vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)];
let memory = vec![(ptr, fp12_to_vec(f))];
InterpreterSetup {
offset: "inv_fp12".to_string(),
stack,
memory,
output: inv..inv + 12,
}
}
#[test]
fn test_inv_fp12() -> Result<()> {
let ptr: usize = 100;
let inv: usize = 112;
let f: Fp12 = gen_fp12();
let setup = setup_inv_test(f);
let output: Vec<U256> = get_interpreter_output(setup).unwrap();
let setup = InterpreterSetup {
offset: "inv_fp12".to_string(),
stack: vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)],
memory: vec![(ptr, fp12_to_vec(f))],
};
let interpreter: Interpreter = run_setup_interpreter(setup).unwrap();
let output: Vec<U256> = extract_kernel_output(inv..inv + 12, interpreter);
let expected: Vec<U256> = fp12_to_vec(inv_fp12(f));
assert_eq!(output, expected);
@ -173,7 +175,7 @@ fn test_inv_fp12() -> Result<()> {
// out,
// ]);
// let output: Vec<U256> = get_interpreter_output("test_pow", stack);
// let output: Vec<U256> = run_setup_interpreter("test_pow", stack);
// let expected: Vec<U256> = fp12_to_vec(power(f));
// assert_eq!(output, expected);
@ -206,7 +208,7 @@ fn test_inv_fp12() -> Result<()> {
// let q: TwistedCurve = twisted_curve_generator();
// let stack = make_tate_stack(p, q);
// let output = get_interpreter_output("test_miller", stack);
// let output = run_setup_interpreter("test_miller", stack);
// let expected = fp12_to_vec(miller_loop(p, q));
// assert_eq!(output, expected);
@ -220,7 +222,7 @@ fn test_inv_fp12() -> Result<()> {
// let q: TwistedCurve = twisted_curve_generator();
// let stack = make_tate_stack(p, q);
// let output = get_interpreter_output("test_tate", stack);
// let output = run_setup_interpreter("test_tate", stack);
// let expected = fp12_to_vec(tate(p, q));
// assert_eq!(output, expected);

View File

@ -1,16 +1,17 @@
use std::mem::transmute;
use std::str::FromStr;
use anyhow::{bail, Error};
use ethereum_types::{BigEndianHash, H256, U256};
use plonky2::field::types::Field;
use crate::bn254_arithmetic::{fp12_to_array, inv_fp12, vec_to_fp12};
use crate::bn254_arithmetic::{inv_fp12, Fp12};
use crate::generation::prover_input::EvmField::{
Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar,
};
use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
use crate::generation::state::GenerationState;
use crate::witness::util::{stack_peek, stack_peeks};
use crate::witness::util::{kernel_general_peek, stack_peek};
/// Prover input function represented as a scoped function name.
/// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`.
@ -57,10 +58,7 @@ impl<F: Field> GenerationState<F> {
/// Finite field extension operations.
fn run_ffe(&self, input_fn: &ProverInputFn) -> U256 {
let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap();
let component = input_fn.0[2].as_str();
let xs = stack_peeks(self).expect("Empty stack");
// TODO: This sucks... come back later
let n = match component {
let n = match input_fn.0[2].as_str() {
"component_0" => 0,
"component_1" => 1,
"component_2" => 2,
@ -75,7 +73,12 @@ impl<F: Field> GenerationState<F> {
"component_11" => 11,
_ => panic!("out of bounds"),
};
field.inverse_fp12(n, xs)
let ptr = stack_peek(self, 11 - n).expect("Empty stack").as_usize();
let mut f: [U256; 12] = [U256::zero(); 12];
for i in 0..12 {
f[i] = kernel_general_peek(self, ptr + i);
}
field.inverse_fp12(n, f)
}
/// MPT data.
@ -196,11 +199,10 @@ impl EvmField {
modexp(x, q, n)
}
fn inverse_fp12(&self, n: usize, xs: Vec<U256>) -> U256 {
let offset = 12 - n;
let vec: Vec<U256> = xs[offset..].to_vec();
let f = fp12_to_array(inv_fp12(vec_to_fp12(vec)));
f[n]
fn inverse_fp12(&self, n: usize, f: [U256; 12]) -> U256 {
let f: Fp12 = unsafe { transmute(f) };
let f_inv: [U256; 12] = unsafe { transmute(inv_fp12(f)) };
f_inv[n]
}
}

View File

@ -27,7 +27,7 @@ fn to_bits_le<F: Field>(n: u8) -> [F; 8] {
res
}
/// Peak at the stack item `i`th from the top. If `i=0` this gives the tip.
/// Peek at the stack item `i`th from the top. If `i=0` this gives the tip.
pub(crate) fn stack_peek<F: Field>(state: &GenerationState<F>, i: usize) -> Option<U256> {
if i >= state.registers.stack_len {
return None;
@ -39,18 +39,13 @@ pub(crate) fn stack_peek<F: Field>(state: &GenerationState<F>, i: usize) -> Opti
)))
}
/// Peek at the entire stack.
pub(crate) fn stack_peeks<F: Field>(state: &GenerationState<F>) -> Option<Vec<U256>> {
let n = state.registers.stack_len;
let mut stack: Vec<U256> = vec![];
for i in 0..n {
stack.extend(vec![state.memory.get(MemoryAddress::new(
state.registers.code_context(),
Segment::Stack,
n - 1 - i,
))])
}
Some(stack)
/// Peek at the kernel general item at address `i`
pub(crate) fn kernel_general_peek<F: Field>(state: &GenerationState<F>, i: usize) -> U256 {
state.memory.get(MemoryAddress::new(
state.registers.context,
Segment::KernelGeneral,
i,
))
}
pub(crate) fn mem_read_with_log<F: Field>(