mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-02-24 15:53:09 +00:00
fixes, testing, and in-progress debugging
This commit is contained in:
parent
24705e1e39
commit
d59501e6a7
@ -6,7 +6,6 @@
|
||||
global add_bignum:
|
||||
// stack: len, a_start_loc, b_start_loc, retdest
|
||||
DUP1
|
||||
// stack: len, len, a_start_loc, b_start_loc, retdest
|
||||
ISZERO
|
||||
%jumpi(len_zero)
|
||||
// stack: len, a_start_loc, b_start_loc, retdest
|
||||
|
||||
@ -8,6 +8,9 @@
|
||||
// All of scratch_2..scratch_5 must have size 2 * length and be initialized with zeroes.
|
||||
global modexp_bignum:
|
||||
// stack: len, b_loc, e_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, s4, s5, retdest
|
||||
DUP1
|
||||
ISZERO
|
||||
%jumpi(len_zero)
|
||||
|
||||
// We store the repeated-squares accumulator x_i in scratch_1, starting with x_0 := b.
|
||||
DUP1
|
||||
@ -125,10 +128,11 @@ modexp_iszero_return:
|
||||
%jumpi(modexp_loop)
|
||||
// end of modexp_loop
|
||||
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
|
||||
%rep 10
|
||||
POP
|
||||
%endrep
|
||||
%pop10
|
||||
// stack: retdest
|
||||
JUMP
|
||||
len_zero:
|
||||
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
|
||||
%pop10
|
||||
// stack: retdest
|
||||
JUMP
|
||||
|
||||
|
||||
|
||||
@ -6,158 +6,141 @@
|
||||
// Both output_loc and scratch_1 must have size length.
|
||||
// Both scratch_2 and scratch_3 have size 2 * length and be initialized with zeroes.
|
||||
global modmul_bignum:
|
||||
// stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, retdest
|
||||
DUP1
|
||||
ISZERO
|
||||
%jumpi(len_zero)
|
||||
|
||||
// The prover provides x := (a * b) % m, which we store in output_loc.
|
||||
|
||||
PUSH 0
|
||||
// stack: i=0, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i=0, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
modmul_remainder_loop:
|
||||
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
PROVER_INPUT(bignum_modmul::remainder)
|
||||
// stack: PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
PROVER_INPUT(bignum_modmul)
|
||||
// stack: PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
DUP7
|
||||
// stack: output_loc, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
DUP3
|
||||
// stack: i, output_loc, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
ADD
|
||||
// stack: output_loc[i], PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: out_loc[i], PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%mstore_kernel_general
|
||||
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%increment
|
||||
// stack: i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
DUP2
|
||||
DUP2
|
||||
// stack: i+1, length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
EQ
|
||||
// stack: i+1==length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i+1, len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
SUB // functions as NEQ
|
||||
// stack: i+1!=length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i+1!=len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%jumpi(modmul_remainder_loop)
|
||||
// end of modmul_remainder_loop
|
||||
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
POP
|
||||
|
||||
// The prover provides k := (a * b) / m, which we store in scratch_1.
|
||||
|
||||
// stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
PUSH 0
|
||||
// stack: i=0, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i=0, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
modmul_quotient_loop:
|
||||
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
PROVER_INPUT(bignum_modmul::quotient)
|
||||
// stack: PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
PROVER_INPUT(bignum_modmul)
|
||||
// stack: PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
DUP8
|
||||
// stack: scratch_1, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
DUP3
|
||||
// stack: i, scratch_1, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
ADD
|
||||
// stack: scratch_1[i], PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: s1[i], PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%mstore_kernel_general
|
||||
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%increment
|
||||
// stack: i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
DUP2
|
||||
DUP2
|
||||
// stack: i+1, length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i+1, len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
SUB // functions as NEQ
|
||||
// stack: i+1!=length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i+1!=len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%jumpi(modmul_quotient_loop)
|
||||
// end of modmul_quotient_loop
|
||||
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
POP
|
||||
// stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
|
||||
// Verification step 1: calculate x + k * m.
|
||||
|
||||
// Store k * m in scratch_2.
|
||||
PUSH modmul_return_1
|
||||
// stack: modmul_return_1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
|
||||
%stack (return, len, a, b, m, out, s1, s2) -> (len, s1, m, s2, return, len, a, b, out, s2)
|
||||
// stack: length, scratch_1, m_start_loc, scratch_2, modmul_return_1, length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, retdest
|
||||
// stack: len, s1, m_loc, s2, modmul_return_1, len, a_loc, b_loc, out_loc, s2, s3, retdest
|
||||
%jump(mul_bignum)
|
||||
modmul_return_1:
|
||||
// stack: length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, out_loc, s2, s3, retdest
|
||||
|
||||
// Add x into k * m (in scratch_2).
|
||||
PUSH modmul_return_2
|
||||
// stack: modmul_return_2, length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, retdest
|
||||
%stack (return, len, a, b, out, s2) -> (len, s2, out, return, len, a, b, s2)
|
||||
// stack: length, scratch_2, output_loc, modmul_return_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: len, s2, out_loc, modmul_return_2, len, a_loc, b_loc, s2, s3, retdest
|
||||
%jump(add_bignum)
|
||||
modmul_return_2:
|
||||
// stack: carry, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: carry, len, a_loc, b_loc, s2, s3, retdest
|
||||
ISZERO
|
||||
%jumpi(no_carry)
|
||||
|
||||
// stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, s2, s3, retdest
|
||||
DUP4
|
||||
// stack: scratch_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
DUP2
|
||||
// stack: length, scratch_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
ADD
|
||||
// stack: cur_loc=scratch_2 + length, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: cur_loc=s2 + len, len, a_loc, b_loc, s2, s3, retdest
|
||||
increment_loop:
|
||||
// stack: cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
DUP1
|
||||
%mload_kernel_general
|
||||
// stack: val, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: val, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
%increment
|
||||
// stack: val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
DUP1
|
||||
// stack: val+1, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: val+1, val+1, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
%eq_const(@BIGNUM_LIMB_BASE)
|
||||
// stack: val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
DUP1
|
||||
// stack: val+1==limb_base, val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
ISZERO
|
||||
// stack: val+1!=limb_base, val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: val+1!=limb_base, val+1==limb_base, val+1, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
SWAP1
|
||||
SWAP2
|
||||
// stack: val+1, val+1!=limb_base, val+1==limb_base, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: val+1, val+1!=limb_base, val+1==limb_base, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
MUL
|
||||
// stack: to_write=(val+1)*(val+1!=limb_base), continue=val+1==limb_base, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: to_write=(val+1)*(val+1!=limb_base), continue=val+1==limb_base, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
DUP3
|
||||
// stack: cur_loc, to_write, continue, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: cur_loc, to_write, continue, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
%mstore_kernel_general
|
||||
// stack: continue, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: continue, cur_loc, len, a_loc, b_loc, s2, s3, retdest
|
||||
SWAP1
|
||||
// stack: cur_loc, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
%increment
|
||||
// stack: cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
DUP1
|
||||
// stack: cur_loc + 1, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
DUP8
|
||||
// stack: scratch_3, cur_loc + 1, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: s3, cur_loc + 1, cur_loc + 1, continue, len, a_loc, b_loc, s2, s3, retdest
|
||||
EQ
|
||||
// stack: cur_loc + 1 == scratch_3, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
ISZERO
|
||||
// stack: cur_loc + 1 != scratch_3, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: cur_loc + 1 != s3, cur_loc + 1, continue, len, a_loc, b_loc, s2, s3, retdest
|
||||
SWAP1
|
||||
SWAP2
|
||||
// stack: continue, cur_loc + 1 != scratch_3, cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: continue, cur_loc + 1 != s3, cur_loc + 1, len, a_loc, b_loc, s2, s3, retdest
|
||||
MUL
|
||||
// stack: new_continue=continue*(cur_loc + 1 != scratch_3), cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: new_continue=continue*(cur_loc + 1 != s3), cur_loc + 1, len, a_loc, b_loc, s2, s3, retdest
|
||||
%jumpi(increment_loop)
|
||||
// stack: cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: cur_loc + 1, len, a_loc, b_loc, s2, s3, retdest
|
||||
POP
|
||||
no_carry:
|
||||
// stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, s2, s3, retdest
|
||||
|
||||
// Calculate a * b.
|
||||
|
||||
// Store a * b in scratch_3.
|
||||
PUSH modmul_return_3
|
||||
// stack: modmul_return_3, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
|
||||
%stack (return, len, a, b, s2, s3) -> (len, a, b, s3, return, len, s2, s3)
|
||||
// stack: length, a_start_loc, b_start_loc, scratch_3, modmul_return_3, length, scratch_2, scratch_3, retdest
|
||||
// stack: len, a_loc, b_loc, s3, modmul_return_3, len, s2, s3, retdest
|
||||
%jump(mul_bignum)
|
||||
modmul_return_3:
|
||||
// stack: length, scratch_2, scratch_3, retdest
|
||||
// stack: len, s2, s3, retdest
|
||||
|
||||
// Check that x + k * m = a * b.
|
||||
// Walk through scratch_2 and scratch_3, checking that they are equal.
|
||||
// stack: n=length, i=scratch_2, j=scratch_3, retdest
|
||||
// stack: n=len, i=s2, j=s3, retdest
|
||||
modmul_check_loop:
|
||||
// stack: n, i, j, retdest
|
||||
%stack (l, idx: 2) -> (idx, l, idx)
|
||||
@ -170,23 +153,23 @@ modmul_check_loop:
|
||||
%assert_eq
|
||||
// stack: n, i, j, retdest
|
||||
%decrement
|
||||
// stack: n-1, i, j, retdest
|
||||
SWAP1
|
||||
// stack: i, n-1, j, retdest
|
||||
%increment
|
||||
// stack: i+1, n-1, j, retdest
|
||||
SWAP2
|
||||
// stack: j, n-1, i+1, retdest
|
||||
%increment
|
||||
// stack: j+1, n-1, i+1, retdest
|
||||
SWAP2
|
||||
SWAP1
|
||||
// stack: n-1, i+1, j+1, retdest
|
||||
DUP1
|
||||
// stack: n-1, n-1, i+1, j+1, retdest
|
||||
%jumpi(modmul_check_loop)
|
||||
modmul_check_end:
|
||||
// end of modmul_check_loop
|
||||
// stack: n-1, i+1, j+1, retdest
|
||||
%pop3
|
||||
// stack: retdest
|
||||
JUMP
|
||||
len_zero:
|
||||
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
|
||||
%pop8
|
||||
// stack: retdest
|
||||
JUMP
|
||||
@ -50,6 +50,18 @@
|
||||
%endrep
|
||||
%endmacro
|
||||
|
||||
%macro pop9
|
||||
%rep 9
|
||||
POP
|
||||
%endrep
|
||||
%endmacro
|
||||
|
||||
%macro pop10
|
||||
%rep 10
|
||||
POP
|
||||
%endrep
|
||||
%endmacro
|
||||
|
||||
%macro and_const(c)
|
||||
// stack: input, ...
|
||||
PUSH $c
|
||||
|
||||
@ -399,9 +399,9 @@ impl<'a> Interpreter<'a> {
|
||||
.debug_offsets
|
||||
.contains(&self.generation_state.registers.program_counter)
|
||||
{
|
||||
println!("At {}, stack={:?}", self.offset_name(), self.stack());
|
||||
// println!("At {}, stack={:?}", self.offset_name(), self.stack());
|
||||
} else if let Some(label) = self.offset_label() {
|
||||
println!("At {label}");
|
||||
// println!("At {label}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@ -20,12 +20,15 @@ const MINUS_ONE: U256 = U256::MAX;
|
||||
|
||||
const TEST_DATA_BIGNUM_INPUTS: &str = "bignum_inputs";
|
||||
const TEST_DATA_U128_INPUTS: &str = "u128_inputs";
|
||||
|
||||
const TEST_DATA_SHR_OUTPUTS: &str = "shr_outputs";
|
||||
const TEST_DATA_ISZERO_OUTPUTS: &str = "iszero_outputs";
|
||||
const TEST_DATA_CMP_OUTPUTS: &str = "cmp_outputs";
|
||||
const TEST_DATA_ADD_OUTPUTS: &str = "add_outputs";
|
||||
const TEST_DATA_ADDMUL_OUTPUTS: &str = "addmul_outputs";
|
||||
const TEST_DATA_MUL_OUTPUTS: &str = "mul_outputs";
|
||||
const TEST_DATA_MODMUL_OUTPUTS: &str = "modmul_outputs";
|
||||
const TEST_DATA_MODEXP_OUTPUTS: &str = "modexp_outputs";
|
||||
|
||||
const BIT_SIZES_TO_TEST: [usize; 15] = [
|
||||
0, 1, 2, 127, 128, 129, 255, 256, 257, 512, 1000, 1023, 1024, 1025, 31415,
|
||||
@ -232,6 +235,76 @@ fn test_mul_bignum(a: BigUint, b: BigUint, expected_output: BigUint) -> Result<(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_modmul_bignum(a: BigUint, b: BigUint, m: BigUint, expected_output: BigUint) -> Result<()> {
|
||||
let len = bignum_len(&a).max(bignum_len(&b)).max(bignum_len(&m));
|
||||
let output_len = len;
|
||||
let memory = pad_bignums(&[a, b, m], len);
|
||||
|
||||
let a_start_loc = 0;
|
||||
let b_start_loc = len;
|
||||
let m_start_loc = 2 * len;
|
||||
let output_start_loc = 3 * len;
|
||||
let scratch_1 = 4 * len;
|
||||
let scratch_2 = 5 * len; // size 2*len
|
||||
let scratch_3 = 7 * len; // size 2*len
|
||||
let (new_memory, _new_stack) = run_test(
|
||||
"modmul_bignum",
|
||||
memory,
|
||||
vec![
|
||||
len.into(),
|
||||
a_start_loc.into(),
|
||||
b_start_loc.into(),
|
||||
m_start_loc.into(),
|
||||
output_start_loc.into(),
|
||||
scratch_1.into(),
|
||||
scratch_2.into(),
|
||||
scratch_3.into(),
|
||||
],
|
||||
)?;
|
||||
|
||||
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
|
||||
assert_eq!(output, expected_output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_modexp_bignum(b: BigUint, e: BigUint, m: BigUint, expected_output: BigUint) -> Result<()> {
|
||||
let len = bignum_len(&b).max(bignum_len(&e)).max(bignum_len(&m));
|
||||
let output_len = len;
|
||||
let memory = pad_bignums(&[b, e, m], len);
|
||||
|
||||
let b_start_loc = 0;
|
||||
let e_start_loc = len;
|
||||
let m_start_loc = 2 * len;
|
||||
let output_start_loc = 3 * len;
|
||||
let scratch_1 = 4 * len;
|
||||
let scratch_2 = 5 * len; // size 2*len
|
||||
let scratch_3 = 7 * len; // size 2*len
|
||||
let scratch_4 = 9 * len; // size 2*len
|
||||
let scratch_5 = 11 * len; // size 2*len
|
||||
let (new_memory, new_stack) = run_test(
|
||||
"modexp_bignum",
|
||||
memory,
|
||||
vec![
|
||||
len.into(),
|
||||
b_start_loc.into(),
|
||||
e_start_loc.into(),
|
||||
m_start_loc.into(),
|
||||
output_start_loc.into(),
|
||||
scratch_1.into(),
|
||||
scratch_2.into(),
|
||||
scratch_3.into(),
|
||||
scratch_4.into(),
|
||||
scratch_5.into(),
|
||||
],
|
||||
)?;
|
||||
|
||||
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
|
||||
assert_eq!(output, expected_output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shr_bignum_all() -> Result<()> {
|
||||
for bit_size in BIT_SIZES_TO_TEST {
|
||||
@ -394,3 +467,83 @@ fn test_mul_bignum_all() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modmul_bignum_all() -> Result<()> {
|
||||
for bit_size in BIT_SIZES_TO_TEST {
|
||||
let a = gen_bignum(bit_size);
|
||||
let b = gen_bignum(bit_size);
|
||||
let m = gen_bignum(bit_size);
|
||||
if !m.is_zero() {
|
||||
let output = a.clone() * b.clone() % m.clone();
|
||||
test_modmul_bignum(a, b, m, output)?;
|
||||
}
|
||||
|
||||
let a = max_bignum(bit_size);
|
||||
let b = max_bignum(bit_size);
|
||||
let m = max_bignum(bit_size);
|
||||
if !m.is_zero() {
|
||||
let output = a.clone() * b.clone() % m.clone();
|
||||
test_modmul_bignum(a, b, m, output)?;
|
||||
}
|
||||
}
|
||||
|
||||
let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
|
||||
let modmul_outputs = test_data_biguint(TEST_DATA_MODMUL_OUTPUTS);
|
||||
let mut modmul_outputs_iter = modmul_outputs.iter();
|
||||
for a in &inputs {
|
||||
for b in &inputs {
|
||||
// For m, skip the first input, which is zero.
|
||||
for m in &inputs[1..] {
|
||||
let output = modmul_outputs_iter.next().unwrap();
|
||||
test_modmul_bignum(a.clone(), b.clone(), m.clone(), output.clone())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modexp_bignum_all() -> Result<()> {
|
||||
// Only test smaller values for exponent.
|
||||
let exp_bit_sizes = vec![2, 100, 127, 128, 129];
|
||||
|
||||
for bit_size in &BIT_SIZES_TO_TEST[3..] {
|
||||
for exp_bit_size in &exp_bit_sizes {
|
||||
let b = gen_bignum(*bit_size);
|
||||
let e = gen_bignum(*exp_bit_size);
|
||||
let m = gen_bignum(*bit_size);
|
||||
if !m.is_zero() {
|
||||
let output = b.clone().modpow(&e, &m);
|
||||
dbg!(b.clone(), e.clone(), m.clone(), output.clone());
|
||||
test_modexp_bignum(b, e, m, output)?;
|
||||
}
|
||||
|
||||
let b = max_bignum(*bit_size);
|
||||
let e = max_bignum(*exp_bit_size);
|
||||
let m = max_bignum(*bit_size);
|
||||
if !m.is_zero() {
|
||||
let output = b.clone().modpow(&e, &m);
|
||||
dbg!(b.clone(), e.clone(), m.clone(), output.clone());
|
||||
test_modexp_bignum(b, e, m, output)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
|
||||
// let modexp_outputs = test_data_biguint(TEST_DATA_MODEXP_OUTPUTS);
|
||||
// let mut modexp_outputs_iter = modexp_outputs.iter();
|
||||
// for b in &inputs {
|
||||
// // Include only smaller exponents, to keep tests from becoming too slow.
|
||||
// for e in &inputs[..7] {
|
||||
// // For m, skip the first input, which is zero.
|
||||
// for m in &inputs[1..] {
|
||||
// let output = modexp_outputs_iter.next().unwrap();
|
||||
// test_modexp_bignum(b.clone(), e.clone(), m.clone(), output.clone())?;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
Ok(())
|
||||
}
|
||||
1470
evm/src/cpu/kernel/tests/bignum/test_data/modexp_outputs
Normal file
1470
evm/src/cpu/kernel/tests/bignum/test_data/modexp_outputs
Normal file
File diff suppressed because it is too large
Load Diff
3150
evm/src/cpu/kernel/tests/bignum/test_data/modmul_outputs
Normal file
3150
evm/src/cpu/kernel/tests/bignum/test_data/modmul_outputs
Normal file
File diff suppressed because it is too large
Load Diff
@ -37,7 +37,7 @@ impl<F: Field> GenerationState<F> {
|
||||
"mpt" => self.run_mpt(),
|
||||
"rlp" => self.run_rlp(),
|
||||
"account_code" => self.run_account_code(input_fn),
|
||||
"bignum_modmul" => self.run_bignum_modmul(input_fn),
|
||||
"bignum_modmul" => self.run_bignum_modmul(),
|
||||
_ => panic!("Unrecognized prover input function."),
|
||||
}
|
||||
}
|
||||
@ -128,11 +128,12 @@ impl<F: Field> GenerationState<F> {
|
||||
}
|
||||
}
|
||||
|
||||
// Bignum modular multiplication related code.
|
||||
fn run_bignum_modmul(&mut self, input_fn: &ProverInputFn) -> U256 {
|
||||
if self.bignum_modmul_prover_inputs.is_empty() {
|
||||
let function = input_fn.0[1].as_str();
|
||||
|
||||
// Bignum modular multiplication.
|
||||
// On the first call, calculates the remainder and quotient of the given inputs.
|
||||
// These are stored, as limbs, in self.bignum_modmul_result_limbs.
|
||||
// Subsequent calls return one limb at a time, in order (first remainder and then quotient).
|
||||
fn run_bignum_modmul(&mut self) -> U256 {
|
||||
if self.bignum_modmul_result_limbs.is_empty() {
|
||||
let len = stack_peek(self, 1)
|
||||
.expect("Stack does not have enough items")
|
||||
.try_into()
|
||||
@ -150,34 +151,31 @@ impl<F: Field> GenerationState<F> {
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let result = match function {
|
||||
"remainder" => {
|
||||
self.bignum_modmul_remainder(len, a_start_loc, b_start_loc, m_start_loc)
|
||||
}
|
||||
"quotient" => {
|
||||
self.bignum_modmul_quotient(len, a_start_loc, b_start_loc, m_start_loc)
|
||||
}
|
||||
_ => panic!("Invalid prover input function."),
|
||||
};
|
||||
let (remainder, quotient) =
|
||||
self.bignum_modmul(len, a_start_loc, b_start_loc, m_start_loc);
|
||||
|
||||
dbg!(remainder.clone(), quotient.clone());
|
||||
|
||||
self.bignum_modmul_prover_inputs = result
|
||||
self.bignum_modmul_result_limbs = remainder
|
||||
.iter()
|
||||
.cloned()
|
||||
.pad_using(len, |_| 0.into())
|
||||
.chain(quotient.iter().cloned().pad_using(len, |_| 0.into()))
|
||||
.collect();
|
||||
self.bignum_modmul_prover_inputs.reverse();
|
||||
dbg!(self.bignum_modmul_result_limbs.clone());
|
||||
self.bignum_modmul_result_limbs.reverse();
|
||||
}
|
||||
|
||||
self.bignum_modmul_prover_inputs.pop().unwrap()
|
||||
self.bignum_modmul_result_limbs.pop().unwrap()
|
||||
}
|
||||
|
||||
fn bignum_modmul_remainder(
|
||||
fn bignum_modmul(
|
||||
&mut self,
|
||||
len: usize,
|
||||
a_start_loc: usize,
|
||||
b_start_loc: usize,
|
||||
m_start_loc: usize,
|
||||
) -> Vec<U256> {
|
||||
) -> (Vec<U256>, Vec<U256>) {
|
||||
let a = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
|
||||
[a_start_loc..a_start_loc + len];
|
||||
let b = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
|
||||
@ -189,30 +187,12 @@ impl<F: Field> GenerationState<F> {
|
||||
let b_biguint = mem_vec_to_biguint(b);
|
||||
let m_biguint = mem_vec_to_biguint(m);
|
||||
|
||||
let result_biguint = (a_biguint * b_biguint) % m_biguint;
|
||||
biguint_to_mem_vec(result_biguint)
|
||||
}
|
||||
dbg!(a_biguint.clone(), b_biguint.clone(), m_biguint.clone());
|
||||
|
||||
fn bignum_modmul_quotient(
|
||||
&mut self,
|
||||
len: usize,
|
||||
a_start_loc: usize,
|
||||
b_start_loc: usize,
|
||||
m_start_loc: usize,
|
||||
) -> Vec<U256> {
|
||||
let a = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
|
||||
[a_start_loc..a_start_loc + len];
|
||||
let b = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
|
||||
[b_start_loc..b_start_loc + len];
|
||||
let m = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
|
||||
[m_start_loc..m_start_loc + len];
|
||||
|
||||
let a_biguint = mem_vec_to_biguint(a);
|
||||
let b_biguint = mem_vec_to_biguint(b);
|
||||
let m_biguint = mem_vec_to_biguint(m);
|
||||
|
||||
let result_biguint = (a_biguint * b_biguint) / m_biguint;
|
||||
biguint_to_mem_vec(result_biguint)
|
||||
let prod = a_biguint * b_biguint;
|
||||
let quo = prod.clone() / m_biguint.clone();
|
||||
let rem = prod - quo.clone() * m_biguint;
|
||||
(biguint_to_mem_vec(rem), biguint_to_mem_vec(quo))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -40,9 +40,10 @@ pub(crate) struct GenerationState<F: Field> {
|
||||
/// addresses.
|
||||
pub(crate) state_key_to_address: HashMap<H256, Address>,
|
||||
|
||||
/// Prover inputs containing the result of a MODMUL-related operation, in little-endian order (so that
|
||||
/// inputs are obtained in big-endian order via `pop()`).
|
||||
pub(crate) bignum_modmul_prover_inputs: Vec<U256>,
|
||||
/// Prover inputs containing the result of a MODMUL operation, in little-endian order (so that
|
||||
/// inputs are obtained in big-endian order via `pop()`). Contains both the remainder and the
|
||||
/// quotient, in that order.
|
||||
pub(crate) bignum_modmul_result_limbs: Vec<U256>,
|
||||
}
|
||||
|
||||
impl<F: Field> GenerationState<F> {
|
||||
@ -58,7 +59,7 @@ impl<F: Field> GenerationState<F> {
|
||||
log::debug!("Input contract_code: {:?}", &inputs.contract_code);
|
||||
let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries);
|
||||
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
|
||||
let bignum_modmul_prover_inputs = Vec::new();
|
||||
let bignum_modmul_result_limbs = Vec::new();
|
||||
|
||||
Self {
|
||||
inputs,
|
||||
@ -69,7 +70,7 @@ impl<F: Field> GenerationState<F> {
|
||||
mpt_prover_inputs,
|
||||
rlp_prover_inputs,
|
||||
state_key_to_address: HashMap::new(),
|
||||
bignum_modmul_prover_inputs,
|
||||
bignum_modmul_result_limbs,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user