mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 23:03:08 +00:00
improved prover input and test api
This commit is contained in:
parent
e06a2f2d46
commit
d2aa937a2f
@ -2,7 +2,6 @@ use std::mem::transmute;
|
||||
use std::ops::{Add, Div, Mul, Neg, Sub};
|
||||
|
||||
use ethereum_types::U256;
|
||||
use itertools::Itertools;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
pub const BN_BASE: U256 = U256([
|
||||
@ -139,17 +138,13 @@ impl Mul for Fp2 {
|
||||
}
|
||||
}
|
||||
|
||||
/// The inverse of a + bi is given by (a - bi)/(a^2 + b^2) since
|
||||
/// (a + bi)(a - bi)/(a^2 + b^2) = (a^2 + b^2)/(a^2 + b^2) = 1
|
||||
/// The inverse of z is given by z'/||z|| since ||z|| = zz'
|
||||
impl Div for Fp2 {
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
let norm = rhs.re * rhs.re + rhs.im * rhs.im;
|
||||
let inv = Fp2 {
|
||||
re: rhs.re / norm,
|
||||
im: -rhs.im / norm,
|
||||
};
|
||||
let inv = mul_fp_fp2(norm, conj_fp2(rhs));
|
||||
self * inv
|
||||
}
|
||||
}
|
||||
@ -833,36 +828,9 @@ const FROB_Z: [Fp2; 12] = [
|
||||
},
|
||||
];
|
||||
|
||||
pub fn fp12_to_array(f: Fp12) -> [U256; 12] {
|
||||
unsafe { transmute(f) }
|
||||
}
|
||||
|
||||
pub fn fp12_to_vec(f: Fp12) -> Vec<U256> {
|
||||
fp12_to_array(f).into_iter().collect()
|
||||
}
|
||||
|
||||
pub fn vec_to_fp12(xs: Vec<U256>) -> Fp12 {
|
||||
xs.into_iter()
|
||||
.tuples::<(U256, U256)>()
|
||||
.map(|(v1, v2)| Fp2 {
|
||||
re: Fp { val: v1 },
|
||||
im: Fp { val: v2 },
|
||||
})
|
||||
.tuples()
|
||||
.map(|(a1, a2, a3, a4, a5, a6)| Fp12 {
|
||||
z0: Fp6 {
|
||||
t0: a1,
|
||||
t1: a2,
|
||||
t2: a3,
|
||||
},
|
||||
z1: Fp6 {
|
||||
t0: a4,
|
||||
t1: a5,
|
||||
t2: a6,
|
||||
},
|
||||
})
|
||||
.next()
|
||||
.unwrap()
|
||||
let f: [U256; 12] = unsafe { transmute(f) };
|
||||
f.into_iter().collect()
|
||||
}
|
||||
|
||||
fn gen_fp() -> Fp {
|
||||
|
||||
@ -34,7 +34,7 @@ pub(crate) fn combined_kernel() -> Kernel {
|
||||
include_str!("asm/curve/bn254/field_arithmetic/fp12_mul.asm"),
|
||||
include_str!("asm/curve/bn254/field_arithmetic/frobenius.asm"),
|
||||
include_str!("asm/curve/bn254/field_arithmetic/power.asm"),
|
||||
include_str!("asm/curve/bn254/field_arithmetic/utils.asm"),
|
||||
include_str!("asm/curve/bn254/field_arithmetic/util.asm"),
|
||||
include_str!("asm/curve/common.asm"),
|
||||
include_str!("asm/curve/secp256k1/curve_mul.asm"),
|
||||
include_str!("asm/curve/secp256k1/curve_add.asm"),
|
||||
|
||||
@ -23,6 +23,18 @@
|
||||
|
||||
|
||||
global inv_fp12:
|
||||
// stack: ptr, inv, retdest
|
||||
%prover_inv_fp12
|
||||
// stack: f^-1, ptr, inv, retdest
|
||||
DUP14
|
||||
// stack: inv, f^-1, ptr, inv, retdest
|
||||
%store_fp12
|
||||
// stack: ptr, inv, retdest
|
||||
%stack (ptr, inv) -> (ptr, inv, 50, check_inv)
|
||||
// stack: ptr, inv, 50, check_inv, retdest
|
||||
%jump(mul_fp12)
|
||||
|
||||
global inv_fp12_old:
|
||||
// stack: ptr, inv, retdest
|
||||
DUP1 %load_fp12
|
||||
// stack: f, ptr, inv, retdest
|
||||
@ -39,9 +51,12 @@ global inv_fp12:
|
||||
%stack (check_inv, mem, ptr, inv) -> (ptr, inv, mem, check_inv)
|
||||
// stack: ptr, inv, 50, check_inv, retdest
|
||||
%jump(mul_fp12)
|
||||
|
||||
|
||||
global check_inv:
|
||||
// stack: retdest
|
||||
PUSH 50 %load_fp12
|
||||
PUSH 50
|
||||
%load_fp12
|
||||
// stack: unit?, retdest
|
||||
%assert_eq_unit_fp12
|
||||
// stack: retdest
|
||||
|
||||
@ -13,15 +13,13 @@ struct InterpreterSetup {
|
||||
offset: String,
|
||||
stack: Vec<U256>,
|
||||
memory: Vec<(usize, Vec<U256>)>,
|
||||
output: Range<usize>,
|
||||
}
|
||||
|
||||
fn get_interpreter_output(setup: InterpreterSetup) -> Result<Vec<U256>> {
|
||||
fn run_setup_interpreter(setup: InterpreterSetup) -> Result<Interpreter<'static>> {
|
||||
let label = KERNEL.global_labels[&setup.offset];
|
||||
let mut stack = setup.stack;
|
||||
stack.reverse();
|
||||
let mut interpreter = Interpreter::new_with_kernel(label, stack);
|
||||
|
||||
for (pointer, data) in setup.memory {
|
||||
for (i, term) in data.iter().enumerate() {
|
||||
interpreter.generation_state.memory.set(
|
||||
@ -30,54 +28,64 @@ fn get_interpreter_output(setup: InterpreterSetup) -> Result<Vec<U256>> {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
interpreter.run()?;
|
||||
|
||||
let kernel = &interpreter.generation_state.memory.contexts[interpreter.context].segments
|
||||
[Segment::KernelGeneral as usize]
|
||||
.content;
|
||||
|
||||
let mut output: Vec<U256> = vec![];
|
||||
for i in setup.output {
|
||||
output.push(kernel[i]);
|
||||
}
|
||||
Ok(output)
|
||||
Ok(interpreter)
|
||||
}
|
||||
|
||||
fn setup_mul_test(f: Fp12, g: Fp12, label: &str) -> InterpreterSetup {
|
||||
let in0: usize = 64;
|
||||
let in1: usize = 76;
|
||||
let out: usize = 88;
|
||||
|
||||
let stack = vec![
|
||||
U256::from(in0),
|
||||
U256::from(in1),
|
||||
U256::from(out),
|
||||
U256::from(0xdeadbeefu32),
|
||||
];
|
||||
let memory = vec![(in0, fp12_to_vec(f)), (in1, fp12_to_vec(g))];
|
||||
fn extract_kernel_output(range: Range<usize>, interpreter: Interpreter<'static>) -> Vec<U256> {
|
||||
let mut output: Vec<U256> = vec![];
|
||||
for i in range {
|
||||
let term = interpreter.generation_state.memory.get(MemoryAddress::new(
|
||||
0,
|
||||
Segment::KernelGeneral,
|
||||
i,
|
||||
));
|
||||
output.push(term);
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
fn setup_mul_test(
|
||||
in0: usize,
|
||||
in1: usize,
|
||||
out: usize,
|
||||
f: Fp12,
|
||||
g: Fp12,
|
||||
label: &str,
|
||||
) -> InterpreterSetup {
|
||||
InterpreterSetup {
|
||||
offset: label.to_string(),
|
||||
stack,
|
||||
memory,
|
||||
output: out..out + 12,
|
||||
stack: vec![
|
||||
U256::from(in0),
|
||||
U256::from(in1),
|
||||
U256::from(out),
|
||||
U256::from(0xdeadbeefu32),
|
||||
],
|
||||
memory: vec![(in0, fp12_to_vec(f)), (in1, fp12_to_vec(g))],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul_fp12() -> Result<()> {
|
||||
let in0: usize = 64;
|
||||
let in1: usize = 76;
|
||||
let out: usize = 88;
|
||||
|
||||
let f: Fp12 = gen_fp12();
|
||||
let g: Fp12 = gen_fp12();
|
||||
let h: Fp12 = gen_fp12_sparse();
|
||||
|
||||
let setup_normal: InterpreterSetup = setup_mul_test(f, g, "mul_fp12");
|
||||
let setup_sparse: InterpreterSetup = setup_mul_test(f, h, "mul_fp12_sparse");
|
||||
let setup_square: InterpreterSetup = setup_mul_test(f, f, "square_fp12_test");
|
||||
let setup_normal: InterpreterSetup = setup_mul_test(in0, in1, out, f, g, "mul_fp12");
|
||||
let setup_sparse: InterpreterSetup = setup_mul_test(in0, in1, out, f, h, "mul_fp12_sparse");
|
||||
let setup_square: InterpreterSetup = setup_mul_test(in0, in1, out, f, f, "square_fp12_test");
|
||||
|
||||
let out_normal: Vec<U256> = get_interpreter_output(setup_normal).unwrap();
|
||||
let out_sparse: Vec<U256> = get_interpreter_output(setup_sparse).unwrap();
|
||||
let out_square: Vec<U256> = get_interpreter_output(setup_square).unwrap();
|
||||
let intrptr_normal: Interpreter = run_setup_interpreter(setup_normal).unwrap();
|
||||
let intrptr_sparse: Interpreter = run_setup_interpreter(setup_sparse).unwrap();
|
||||
let intrptr_square: Interpreter = run_setup_interpreter(setup_square).unwrap();
|
||||
|
||||
let out_normal: Vec<U256> = extract_kernel_output(out..out + 12, intrptr_normal);
|
||||
let out_sparse: Vec<U256> = extract_kernel_output(out..out + 12, intrptr_sparse);
|
||||
let out_square: Vec<U256> = extract_kernel_output(out..out + 12, intrptr_square);
|
||||
|
||||
let exp_normal: Vec<U256> = fp12_to_vec(f * g);
|
||||
let exp_sparse: Vec<U256> = fp12_to_vec(f * h);
|
||||
@ -90,32 +98,33 @@ fn test_mul_fp12() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn setup_frob_test(f: Fp12, label: &str) -> InterpreterSetup {
|
||||
let ptr: usize = 100;
|
||||
let stack = vec![U256::from(ptr)];
|
||||
let memory = vec![(ptr, fp12_to_vec(f))];
|
||||
|
||||
fn setup_frob_test(ptr: usize, f: Fp12, label: &str) -> InterpreterSetup {
|
||||
InterpreterSetup {
|
||||
offset: label.to_string(),
|
||||
stack,
|
||||
memory,
|
||||
output: ptr..ptr + 12,
|
||||
stack: vec![U256::from(ptr)],
|
||||
memory: vec![(ptr, fp12_to_vec(f))],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frob_fp12() -> Result<()> {
|
||||
let ptr: usize = 100;
|
||||
let f: Fp12 = gen_fp12();
|
||||
|
||||
let setup_frob_1 = setup_frob_test(f, "test_frob_fp12_1");
|
||||
let setup_frob_2 = setup_frob_test(f, "test_frob_fp12_2");
|
||||
let setup_frob_3 = setup_frob_test(f, "test_frob_fp12_3");
|
||||
let setup_frob_6 = setup_frob_test(f, "test_frob_fp12_6");
|
||||
let setup_frob_1 = setup_frob_test(ptr, f, "test_frob_fp12_1");
|
||||
let setup_frob_2 = setup_frob_test(ptr, f, "test_frob_fp12_2");
|
||||
let setup_frob_3 = setup_frob_test(ptr, f, "test_frob_fp12_3");
|
||||
let setup_frob_6 = setup_frob_test(ptr, f, "test_frob_fp12_6");
|
||||
|
||||
let out_frob_1: Vec<U256> = get_interpreter_output(setup_frob_1).unwrap();
|
||||
let out_frob_2: Vec<U256> = get_interpreter_output(setup_frob_2).unwrap();
|
||||
let out_frob_3: Vec<U256> = get_interpreter_output(setup_frob_3).unwrap();
|
||||
let out_frob_6: Vec<U256> = get_interpreter_output(setup_frob_6).unwrap();
|
||||
let intrptr_frob_1: Interpreter = run_setup_interpreter(setup_frob_1).unwrap();
|
||||
let intrptr_frob_2: Interpreter = run_setup_interpreter(setup_frob_2).unwrap();
|
||||
let intrptr_frob_3: Interpreter = run_setup_interpreter(setup_frob_3).unwrap();
|
||||
let intrptr_frob_6: Interpreter = run_setup_interpreter(setup_frob_6).unwrap();
|
||||
|
||||
let out_frob_1: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_1);
|
||||
let out_frob_2: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_2);
|
||||
let out_frob_3: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_3);
|
||||
let out_frob_6: Vec<U256> = extract_kernel_output(ptr..ptr + 12, intrptr_frob_6);
|
||||
|
||||
let exp_frob_1: Vec<U256> = fp12_to_vec(frob_fp12(1, f));
|
||||
let exp_frob_2: Vec<U256> = fp12_to_vec(frob_fp12(2, f));
|
||||
@ -130,26 +139,19 @@ fn test_frob_fp12() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn setup_inv_test(f: Fp12) -> InterpreterSetup {
|
||||
let ptr: usize = 100;
|
||||
let inv: usize = 112;
|
||||
let stack = vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)];
|
||||
let memory = vec![(ptr, fp12_to_vec(f))];
|
||||
|
||||
InterpreterSetup {
|
||||
offset: "inv_fp12".to_string(),
|
||||
stack,
|
||||
memory,
|
||||
output: inv..inv + 12,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inv_fp12() -> Result<()> {
|
||||
let ptr: usize = 100;
|
||||
let inv: usize = 112;
|
||||
let f: Fp12 = gen_fp12();
|
||||
|
||||
let setup = setup_inv_test(f);
|
||||
let output: Vec<U256> = get_interpreter_output(setup).unwrap();
|
||||
let setup = InterpreterSetup {
|
||||
offset: "inv_fp12".to_string(),
|
||||
stack: vec![U256::from(ptr), U256::from(inv), U256::from(0xdeadbeefu32)],
|
||||
memory: vec![(ptr, fp12_to_vec(f))],
|
||||
};
|
||||
let interpreter: Interpreter = run_setup_interpreter(setup).unwrap();
|
||||
let output: Vec<U256> = extract_kernel_output(inv..inv + 12, interpreter);
|
||||
let expected: Vec<U256> = fp12_to_vec(inv_fp12(f));
|
||||
|
||||
assert_eq!(output, expected);
|
||||
@ -173,7 +175,7 @@ fn test_inv_fp12() -> Result<()> {
|
||||
// out,
|
||||
// ]);
|
||||
|
||||
// let output: Vec<U256> = get_interpreter_output("test_pow", stack);
|
||||
// let output: Vec<U256> = run_setup_interpreter("test_pow", stack);
|
||||
// let expected: Vec<U256> = fp12_to_vec(power(f));
|
||||
|
||||
// assert_eq!(output, expected);
|
||||
@ -206,7 +208,7 @@ fn test_inv_fp12() -> Result<()> {
|
||||
// let q: TwistedCurve = twisted_curve_generator();
|
||||
|
||||
// let stack = make_tate_stack(p, q);
|
||||
// let output = get_interpreter_output("test_miller", stack);
|
||||
// let output = run_setup_interpreter("test_miller", stack);
|
||||
// let expected = fp12_to_vec(miller_loop(p, q));
|
||||
|
||||
// assert_eq!(output, expected);
|
||||
@ -220,7 +222,7 @@ fn test_inv_fp12() -> Result<()> {
|
||||
// let q: TwistedCurve = twisted_curve_generator();
|
||||
|
||||
// let stack = make_tate_stack(p, q);
|
||||
// let output = get_interpreter_output("test_tate", stack);
|
||||
// let output = run_setup_interpreter("test_tate", stack);
|
||||
// let expected = fp12_to_vec(tate(p, q));
|
||||
|
||||
// assert_eq!(output, expected);
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
use std::mem::transmute;
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{bail, Error};
|
||||
use ethereum_types::{BigEndianHash, H256, U256};
|
||||
use plonky2::field::types::Field;
|
||||
|
||||
use crate::bn254_arithmetic::{fp12_to_array, inv_fp12, vec_to_fp12};
|
||||
use crate::bn254_arithmetic::{inv_fp12, Fp12};
|
||||
use crate::generation::prover_input::EvmField::{
|
||||
Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar,
|
||||
};
|
||||
use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
|
||||
use crate::generation::state::GenerationState;
|
||||
use crate::witness::util::{stack_peek, stack_peeks};
|
||||
use crate::witness::util::{kernel_general_peek, stack_peek};
|
||||
|
||||
/// Prover input function represented as a scoped function name.
|
||||
/// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as `ProverInputFn([ff, bn254_base, inverse])`.
|
||||
@ -57,10 +58,7 @@ impl<F: Field> GenerationState<F> {
|
||||
/// Finite field extension operations.
|
||||
fn run_ffe(&self, input_fn: &ProverInputFn) -> U256 {
|
||||
let field = EvmField::from_str(input_fn.0[1].as_str()).unwrap();
|
||||
let component = input_fn.0[2].as_str();
|
||||
let xs = stack_peeks(self).expect("Empty stack");
|
||||
// TODO: This sucks... come back later
|
||||
let n = match component {
|
||||
let n = match input_fn.0[2].as_str() {
|
||||
"component_0" => 0,
|
||||
"component_1" => 1,
|
||||
"component_2" => 2,
|
||||
@ -75,7 +73,12 @@ impl<F: Field> GenerationState<F> {
|
||||
"component_11" => 11,
|
||||
_ => panic!("out of bounds"),
|
||||
};
|
||||
field.inverse_fp12(n, xs)
|
||||
let ptr = stack_peek(self, 11 - n).expect("Empty stack").as_usize();
|
||||
let mut f: [U256; 12] = [U256::zero(); 12];
|
||||
for i in 0..12 {
|
||||
f[i] = kernel_general_peek(self, ptr + i);
|
||||
}
|
||||
field.inverse_fp12(n, f)
|
||||
}
|
||||
|
||||
/// MPT data.
|
||||
@ -196,11 +199,10 @@ impl EvmField {
|
||||
modexp(x, q, n)
|
||||
}
|
||||
|
||||
fn inverse_fp12(&self, n: usize, xs: Vec<U256>) -> U256 {
|
||||
let offset = 12 - n;
|
||||
let vec: Vec<U256> = xs[offset..].to_vec();
|
||||
let f = fp12_to_array(inv_fp12(vec_to_fp12(vec)));
|
||||
f[n]
|
||||
fn inverse_fp12(&self, n: usize, f: [U256; 12]) -> U256 {
|
||||
let f: Fp12 = unsafe { transmute(f) };
|
||||
let f_inv: [U256; 12] = unsafe { transmute(inv_fp12(f)) };
|
||||
f_inv[n]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ fn to_bits_le<F: Field>(n: u8) -> [F; 8] {
|
||||
res
|
||||
}
|
||||
|
||||
/// Peak at the stack item `i`th from the top. If `i=0` this gives the tip.
|
||||
/// Peek at the stack item `i`th from the top. If `i=0` this gives the tip.
|
||||
pub(crate) fn stack_peek<F: Field>(state: &GenerationState<F>, i: usize) -> Option<U256> {
|
||||
if i >= state.registers.stack_len {
|
||||
return None;
|
||||
@ -39,18 +39,13 @@ pub(crate) fn stack_peek<F: Field>(state: &GenerationState<F>, i: usize) -> Opti
|
||||
)))
|
||||
}
|
||||
|
||||
/// Peek at the entire stack.
|
||||
pub(crate) fn stack_peeks<F: Field>(state: &GenerationState<F>) -> Option<Vec<U256>> {
|
||||
let n = state.registers.stack_len;
|
||||
let mut stack: Vec<U256> = vec![];
|
||||
for i in 0..n {
|
||||
stack.extend(vec![state.memory.get(MemoryAddress::new(
|
||||
state.registers.code_context(),
|
||||
Segment::Stack,
|
||||
n - 1 - i,
|
||||
))])
|
||||
}
|
||||
Some(stack)
|
||||
/// Peek at the kernel general item at address `i`
|
||||
pub(crate) fn kernel_general_peek<F: Field>(state: &GenerationState<F>, i: usize) -> U256 {
|
||||
state.memory.get(MemoryAddress::new(
|
||||
state.registers.context,
|
||||
Segment::KernelGeneral,
|
||||
i,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn mem_read_with_log<F: Field>(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user