fixes, testing, and in-progress debugging

This commit is contained in:
Nicholas Ward 2023-03-21 16:03:54 -07:00
parent 24705e1e39
commit d59501e6a7
10 changed files with 4878 additions and 126 deletions

View File

@ -6,7 +6,6 @@
global add_bignum:
// stack: len, a_start_loc, b_start_loc, retdest
DUP1
// stack: len, len, a_start_loc, b_start_loc, retdest
ISZERO
%jumpi(len_zero)
// stack: len, a_start_loc, b_start_loc, retdest

View File

@ -8,6 +8,9 @@
// All of scratch_2..scratch_5 must have size 2 * length and be initialized with zeroes.
global modexp_bignum:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, s4, s5, retdest
DUP1
ISZERO
%jumpi(len_zero)
// We store the repeated-squares accumulator x_i in scratch_1, starting with x_0 := b.
DUP1
@ -125,10 +128,11 @@ modexp_iszero_return:
%jumpi(modexp_loop)
// end of modexp_loop
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%rep 10
POP
%endrep
%pop10
// stack: retdest
JUMP
len_zero:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%pop10
// stack: retdest
JUMP

View File

@ -6,158 +6,141 @@
// Both output_loc and scratch_1 must have size length.
// Both scratch_2 and scratch_3 have size 2 * length and be initialized with zeroes.
global modmul_bignum:
// stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, retdest
DUP1
ISZERO
%jumpi(len_zero)
// The prover provides x := (a * b) % m, which we store in output_loc.
PUSH 0
// stack: i=0, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i=0, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
modmul_remainder_loop:
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
PROVER_INPUT(bignum_modmul::remainder)
// stack: PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PROVER_INPUT(bignum_modmul)
// stack: PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
DUP7
// stack: output_loc, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
DUP3
// stack: i, output_loc, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
ADD
// stack: output_loc[i], PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: out_loc[i], PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%mstore_kernel_general
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%increment
// stack: i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
DUP2
DUP2
// stack: i+1, length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
EQ
// stack: i+1==length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i+1, len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
SUB // functions as NEQ
// stack: i+1!=length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i+1!=len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%jumpi(modmul_remainder_loop)
// end of modmul_remainder_loop
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
POP
// The prover provides k := (a * b) / m, which we store in scratch_1.
// stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PUSH 0
// stack: i=0, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i=0, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
modmul_quotient_loop:
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
PROVER_INPUT(bignum_modmul::quotient)
// stack: PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PROVER_INPUT(bignum_modmul)
// stack: PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
DUP8
// stack: scratch_1, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
DUP3
// stack: i, scratch_1, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
ADD
// stack: scratch_1[i], PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: s1[i], PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%mstore_kernel_general
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%increment
// stack: i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
DUP2
DUP2
// stack: i+1, length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i+1, len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
SUB // functions as NEQ
// stack: i+1!=length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i+1!=len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%jumpi(modmul_quotient_loop)
// end of modmul_quotient_loop
// stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
POP
// stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
// Verification step 1: calculate x + k * m.
// Store k * m in scratch_2.
PUSH modmul_return_1
// stack: modmul_return_1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, retdest
%stack (return, len, a, b, m, out, s1, s2) -> (len, s1, m, s2, return, len, a, b, out, s2)
// stack: length, scratch_1, m_start_loc, scratch_2, modmul_return_1, length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, retdest
// stack: len, s1, m_loc, s2, modmul_return_1, len, a_loc, b_loc, out_loc, s2, s3, retdest
%jump(mul_bignum)
modmul_return_1:
// stack: length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, out_loc, s2, s3, retdest
// Add x into k * m (in scratch_2).
PUSH modmul_return_2
// stack: modmul_return_2, length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, retdest
%stack (return, len, a, b, out, s2) -> (len, s2, out, return, len, a, b, s2)
// stack: length, scratch_2, output_loc, modmul_return_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: len, s2, out_loc, modmul_return_2, len, a_loc, b_loc, s2, s3, retdest
%jump(add_bignum)
modmul_return_2:
// stack: carry, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: carry, len, a_loc, b_loc, s2, s3, retdest
ISZERO
%jumpi(no_carry)
// stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, s2, s3, retdest
DUP4
// stack: scratch_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
DUP2
// stack: length, scratch_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
ADD
// stack: cur_loc=scratch_2 + length, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: cur_loc=s2 + len, len, a_loc, b_loc, s2, s3, retdest
increment_loop:
// stack: cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: cur_loc, len, a_loc, b_loc, s2, s3, retdest
DUP1
%mload_kernel_general
// stack: val, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: val, cur_loc, len, a_loc, b_loc, s2, s3, retdest
%increment
// stack: val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
DUP1
// stack: val+1, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: val+1, val+1, cur_loc, len, a_loc, b_loc, s2, s3, retdest
%eq_const(@BIGNUM_LIMB_BASE)
// stack: val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
DUP1
// stack: val+1==limb_base, val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
ISZERO
// stack: val+1!=limb_base, val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: val+1!=limb_base, val+1==limb_base, val+1, cur_loc, len, a_loc, b_loc, s2, s3, retdest
SWAP1
SWAP2
// stack: val+1, val+1!=limb_base, val+1==limb_base, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: val+1, val+1!=limb_base, val+1==limb_base, cur_loc, len, a_loc, b_loc, s2, s3, retdest
MUL
// stack: to_write=(val+1)*(val+1!=limb_base), continue=val+1==limb_base, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: to_write=(val+1)*(val+1!=limb_base), continue=val+1==limb_base, cur_loc, len, a_loc, b_loc, s2, s3, retdest
DUP3
// stack: cur_loc, to_write, continue, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: cur_loc, to_write, continue, cur_loc, len, a_loc, b_loc, s2, s3, retdest
%mstore_kernel_general
// stack: continue, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: continue, cur_loc, len, a_loc, b_loc, s2, s3, retdest
SWAP1
// stack: cur_loc, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
%increment
// stack: cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
DUP1
// stack: cur_loc + 1, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
DUP8
// stack: scratch_3, cur_loc + 1, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: s3, cur_loc + 1, cur_loc + 1, continue, len, a_loc, b_loc, s2, s3, retdest
EQ
// stack: cur_loc + 1 == scratch_3, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
ISZERO
// stack: cur_loc + 1 != scratch_3, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: cur_loc + 1 != s3, cur_loc + 1, continue, len, a_loc, b_loc, s2, s3, retdest
SWAP1
SWAP2
// stack: continue, cur_loc + 1 != scratch_3, cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: continue, cur_loc + 1 != s3, cur_loc + 1, len, a_loc, b_loc, s2, s3, retdest
MUL
// stack: new_continue=continue*(cur_loc + 1 != scratch_3), cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: new_continue=continue*(cur_loc + 1 != s3), cur_loc + 1, len, a_loc, b_loc, s2, s3, retdest
%jumpi(increment_loop)
// stack: cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: cur_loc + 1, len, a_loc, b_loc, s2, s3, retdest
POP
no_carry:
// stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, s2, s3, retdest
// Calculate a * b.
// Store a * b in scratch_3.
PUSH modmul_return_3
// stack: modmul_return_3, length, a_start_loc, b_start_loc, scratch_2, scratch_3, retdest
%stack (return, len, a, b, s2, s3) -> (len, a, b, s3, return, len, s2, s3)
// stack: length, a_start_loc, b_start_loc, scratch_3, modmul_return_3, length, scratch_2, scratch_3, retdest
// stack: len, a_loc, b_loc, s3, modmul_return_3, len, s2, s3, retdest
%jump(mul_bignum)
modmul_return_3:
// stack: length, scratch_2, scratch_3, retdest
// stack: len, s2, s3, retdest
// Check that x + k * m = a * b.
// Walk through scratch_2 and scratch_3, checking that they are equal.
// stack: n=length, i=scratch_2, j=scratch_3, retdest
// stack: n=len, i=s2, j=s3, retdest
modmul_check_loop:
// stack: n, i, j, retdest
%stack (l, idx: 2) -> (idx, l, idx)
@ -170,23 +153,23 @@ modmul_check_loop:
%assert_eq
// stack: n, i, j, retdest
%decrement
// stack: n-1, i, j, retdest
SWAP1
// stack: i, n-1, j, retdest
%increment
// stack: i+1, n-1, j, retdest
SWAP2
// stack: j, n-1, i+1, retdest
%increment
// stack: j+1, n-1, i+1, retdest
SWAP2
SWAP1
// stack: n-1, i+1, j+1, retdest
DUP1
// stack: n-1, n-1, i+1, j+1, retdest
%jumpi(modmul_check_loop)
modmul_check_end:
// end of modmul_check_loop
// stack: n-1, i+1, j+1, retdest
%pop3
// stack: retdest
JUMP
len_zero:
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%pop8
// stack: retdest
JUMP

View File

@ -50,6 +50,18 @@
%endrep
%endmacro
%macro pop9
%rep 9
POP
%endrep
%endmacro
%macro pop10
%rep 10
POP
%endrep
%endmacro
%macro and_const(c)
// stack: input, ...
PUSH $c

View File

@ -399,9 +399,9 @@ impl<'a> Interpreter<'a> {
.debug_offsets
.contains(&self.generation_state.registers.program_counter)
{
println!("At {}, stack={:?}", self.offset_name(), self.stack());
// println!("At {}, stack={:?}", self.offset_name(), self.stack());
} else if let Some(label) = self.offset_label() {
println!("At {label}");
// println!("At {label}");
}
Ok(())

View File

@ -20,12 +20,15 @@ const MINUS_ONE: U256 = U256::MAX;
const TEST_DATA_BIGNUM_INPUTS: &str = "bignum_inputs";
const TEST_DATA_U128_INPUTS: &str = "u128_inputs";
const TEST_DATA_SHR_OUTPUTS: &str = "shr_outputs";
const TEST_DATA_ISZERO_OUTPUTS: &str = "iszero_outputs";
const TEST_DATA_CMP_OUTPUTS: &str = "cmp_outputs";
const TEST_DATA_ADD_OUTPUTS: &str = "add_outputs";
const TEST_DATA_ADDMUL_OUTPUTS: &str = "addmul_outputs";
const TEST_DATA_MUL_OUTPUTS: &str = "mul_outputs";
const TEST_DATA_MODMUL_OUTPUTS: &str = "modmul_outputs";
const TEST_DATA_MODEXP_OUTPUTS: &str = "modexp_outputs";
const BIT_SIZES_TO_TEST: [usize; 15] = [
0, 1, 2, 127, 128, 129, 255, 256, 257, 512, 1000, 1023, 1024, 1025, 31415,
@ -232,6 +235,76 @@ fn test_mul_bignum(a: BigUint, b: BigUint, expected_output: BigUint) -> Result<(
Ok(())
}
fn test_modmul_bignum(a: BigUint, b: BigUint, m: BigUint, expected_output: BigUint) -> Result<()> {
let len = bignum_len(&a).max(bignum_len(&b)).max(bignum_len(&m));
let output_len = len;
let memory = pad_bignums(&[a, b, m], len);
let a_start_loc = 0;
let b_start_loc = len;
let m_start_loc = 2 * len;
let output_start_loc = 3 * len;
let scratch_1 = 4 * len;
let scratch_2 = 5 * len; // size 2*len
let scratch_3 = 7 * len; // size 2*len
let (new_memory, _new_stack) = run_test(
"modmul_bignum",
memory,
vec![
len.into(),
a_start_loc.into(),
b_start_loc.into(),
m_start_loc.into(),
output_start_loc.into(),
scratch_1.into(),
scratch_2.into(),
scratch_3.into(),
],
)?;
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
assert_eq!(output, expected_output);
Ok(())
}
fn test_modexp_bignum(b: BigUint, e: BigUint, m: BigUint, expected_output: BigUint) -> Result<()> {
let len = bignum_len(&b).max(bignum_len(&e)).max(bignum_len(&m));
let output_len = len;
let memory = pad_bignums(&[b, e, m], len);
let b_start_loc = 0;
let e_start_loc = len;
let m_start_loc = 2 * len;
let output_start_loc = 3 * len;
let scratch_1 = 4 * len;
let scratch_2 = 5 * len; // size 2*len
let scratch_3 = 7 * len; // size 2*len
let scratch_4 = 9 * len; // size 2*len
let scratch_5 = 11 * len; // size 2*len
let (new_memory, new_stack) = run_test(
"modexp_bignum",
memory,
vec![
len.into(),
b_start_loc.into(),
e_start_loc.into(),
m_start_loc.into(),
output_start_loc.into(),
scratch_1.into(),
scratch_2.into(),
scratch_3.into(),
scratch_4.into(),
scratch_5.into(),
],
)?;
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
assert_eq!(output, expected_output);
Ok(())
}
#[test]
fn test_shr_bignum_all() -> Result<()> {
for bit_size in BIT_SIZES_TO_TEST {
@ -394,3 +467,83 @@ fn test_mul_bignum_all() -> Result<()> {
Ok(())
}
#[test]
fn test_modmul_bignum_all() -> Result<()> {
for bit_size in BIT_SIZES_TO_TEST {
let a = gen_bignum(bit_size);
let b = gen_bignum(bit_size);
let m = gen_bignum(bit_size);
if !m.is_zero() {
let output = a.clone() * b.clone() % m.clone();
test_modmul_bignum(a, b, m, output)?;
}
let a = max_bignum(bit_size);
let b = max_bignum(bit_size);
let m = max_bignum(bit_size);
if !m.is_zero() {
let output = a.clone() * b.clone() % m.clone();
test_modmul_bignum(a, b, m, output)?;
}
}
let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
let modmul_outputs = test_data_biguint(TEST_DATA_MODMUL_OUTPUTS);
let mut modmul_outputs_iter = modmul_outputs.iter();
for a in &inputs {
for b in &inputs {
// For m, skip the first input, which is zero.
for m in &inputs[1..] {
let output = modmul_outputs_iter.next().unwrap();
test_modmul_bignum(a.clone(), b.clone(), m.clone(), output.clone())?;
}
}
}
Ok(())
}
#[test]
fn test_modexp_bignum_all() -> Result<()> {
// Only test smaller values for exponent.
let exp_bit_sizes = vec![2, 100, 127, 128, 129];
for bit_size in &BIT_SIZES_TO_TEST[3..] {
for exp_bit_size in &exp_bit_sizes {
let b = gen_bignum(*bit_size);
let e = gen_bignum(*exp_bit_size);
let m = gen_bignum(*bit_size);
if !m.is_zero() {
let output = b.clone().modpow(&e, &m);
dbg!(b.clone(), e.clone(), m.clone(), output.clone());
test_modexp_bignum(b, e, m, output)?;
}
let b = max_bignum(*bit_size);
let e = max_bignum(*exp_bit_size);
let m = max_bignum(*bit_size);
if !m.is_zero() {
let output = b.clone().modpow(&e, &m);
dbg!(b.clone(), e.clone(), m.clone(), output.clone());
test_modexp_bignum(b, e, m, output)?;
}
}
}
// let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
// let modexp_outputs = test_data_biguint(TEST_DATA_MODEXP_OUTPUTS);
// let mut modexp_outputs_iter = modexp_outputs.iter();
// for b in &inputs {
// // Include only smaller exponents, to keep tests from becoming too slow.
// for e in &inputs[..7] {
// // For m, skip the first input, which is zero.
// for m in &inputs[1..] {
// let output = modexp_outputs_iter.next().unwrap();
// test_modexp_bignum(b.clone(), e.clone(), m.clone(), output.clone())?;
// }
// }
// }
Ok(())
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -37,7 +37,7 @@ impl<F: Field> GenerationState<F> {
"mpt" => self.run_mpt(),
"rlp" => self.run_rlp(),
"account_code" => self.run_account_code(input_fn),
"bignum_modmul" => self.run_bignum_modmul(input_fn),
"bignum_modmul" => self.run_bignum_modmul(),
_ => panic!("Unrecognized prover input function."),
}
}
@ -128,11 +128,12 @@ impl<F: Field> GenerationState<F> {
}
}
// Bignum modular multiplication related code.
fn run_bignum_modmul(&mut self, input_fn: &ProverInputFn) -> U256 {
if self.bignum_modmul_prover_inputs.is_empty() {
let function = input_fn.0[1].as_str();
// Bignum modular multiplication.
// On the first call, calculates the remainder and quotient of the given inputs.
// These are stored, as limbs, in self.bignum_modmul_result_limbs.
// Subsequent calls return one limb at a time, in order (first remainder and then quotient).
fn run_bignum_modmul(&mut self) -> U256 {
if self.bignum_modmul_result_limbs.is_empty() {
let len = stack_peek(self, 1)
.expect("Stack does not have enough items")
.try_into()
@ -150,34 +151,31 @@ impl<F: Field> GenerationState<F> {
.try_into()
.unwrap();
let result = match function {
"remainder" => {
self.bignum_modmul_remainder(len, a_start_loc, b_start_loc, m_start_loc)
}
"quotient" => {
self.bignum_modmul_quotient(len, a_start_loc, b_start_loc, m_start_loc)
}
_ => panic!("Invalid prover input function."),
};
let (remainder, quotient) =
self.bignum_modmul(len, a_start_loc, b_start_loc, m_start_loc);
dbg!(remainder.clone(), quotient.clone());
self.bignum_modmul_prover_inputs = result
self.bignum_modmul_result_limbs = remainder
.iter()
.cloned()
.pad_using(len, |_| 0.into())
.chain(quotient.iter().cloned().pad_using(len, |_| 0.into()))
.collect();
self.bignum_modmul_prover_inputs.reverse();
dbg!(self.bignum_modmul_result_limbs.clone());
self.bignum_modmul_result_limbs.reverse();
}
self.bignum_modmul_prover_inputs.pop().unwrap()
self.bignum_modmul_result_limbs.pop().unwrap()
}
fn bignum_modmul_remainder(
fn bignum_modmul(
&mut self,
len: usize,
a_start_loc: usize,
b_start_loc: usize,
m_start_loc: usize,
) -> Vec<U256> {
) -> (Vec<U256>, Vec<U256>) {
let a = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[a_start_loc..a_start_loc + len];
let b = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
@ -189,30 +187,12 @@ impl<F: Field> GenerationState<F> {
let b_biguint = mem_vec_to_biguint(b);
let m_biguint = mem_vec_to_biguint(m);
let result_biguint = (a_biguint * b_biguint) % m_biguint;
biguint_to_mem_vec(result_biguint)
}
dbg!(a_biguint.clone(), b_biguint.clone(), m_biguint.clone());
fn bignum_modmul_quotient(
&mut self,
len: usize,
a_start_loc: usize,
b_start_loc: usize,
m_start_loc: usize,
) -> Vec<U256> {
let a = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[a_start_loc..a_start_loc + len];
let b = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[b_start_loc..b_start_loc + len];
let m = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[m_start_loc..m_start_loc + len];
let a_biguint = mem_vec_to_biguint(a);
let b_biguint = mem_vec_to_biguint(b);
let m_biguint = mem_vec_to_biguint(m);
let result_biguint = (a_biguint * b_biguint) / m_biguint;
biguint_to_mem_vec(result_biguint)
let prod = a_biguint * b_biguint;
let quo = prod.clone() / m_biguint.clone();
let rem = prod - quo.clone() * m_biguint;
(biguint_to_mem_vec(rem), biguint_to_mem_vec(quo))
}
}

View File

@ -40,9 +40,10 @@ pub(crate) struct GenerationState<F: Field> {
/// addresses.
pub(crate) state_key_to_address: HashMap<H256, Address>,
/// Prover inputs containing the result of a MODMUL-related operation, in little-endian order (so that
/// inputs are obtained in big-endian order via `pop()`).
pub(crate) bignum_modmul_prover_inputs: Vec<U256>,
/// Prover inputs containing the result of a MODMUL operation, in little-endian order (so that
/// inputs are obtained in big-endian order via `pop()`). Contains both the remainder and the
/// quotient, in that order.
pub(crate) bignum_modmul_result_limbs: Vec<U256>,
}
impl<F: Field> GenerationState<F> {
@ -58,7 +59,7 @@ impl<F: Field> GenerationState<F> {
log::debug!("Input contract_code: {:?}", &inputs.contract_code);
let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries);
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
let bignum_modmul_prover_inputs = Vec::new();
let bignum_modmul_result_limbs = Vec::new();
Self {
inputs,
@ -69,7 +70,7 @@ impl<F: Field> GenerationState<F> {
mpt_prover_inputs,
rlp_prover_inputs,
state_key_to_address: HashMap::new(),
bignum_modmul_prover_inputs,
bignum_modmul_result_limbs,
}
}