Merge pull request #925 from mir-protocol/bignum-modexp

Bignum modexp
This commit is contained in:
Nicholas Ward 2023-04-04 13:37:48 -07:00 committed by GitHub
commit d59fa59af8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 5186 additions and 30 deletions

View File

@ -17,6 +17,8 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/bignum/addmul.asm"),
include_str!("asm/bignum/cmp.asm"),
include_str!("asm/bignum/iszero.asm"),
include_str!("asm/bignum/modexp.asm"),
include_str!("asm/bignum/modmul.asm"),
include_str!("asm/bignum/mul.asm"),
include_str!("asm/bignum/shr.asm"),
include_str!("asm/bignum/util.asm"),

View File

@ -6,7 +6,6 @@
global add_bignum:
// stack: len, a_start_loc, b_start_loc, retdest
DUP1
// stack: len, len, a_start_loc, b_start_loc, retdest
ISZERO
%jumpi(len_zero)
// stack: len, a_start_loc, b_start_loc, retdest
@ -57,6 +56,7 @@ add_end:
SWAP1
// stack: retdest, carry_new
JUMP
len_zero:
// stack: len, a_start_loc, b_start_loc, retdest
%pop3

View File

@ -101,6 +101,7 @@ addmul_end:
SWAP1
// stack: retdest, carry_limb_new
JUMP
len_zero:
// stack: len, a_start_loc, b_start_loc, val, retdest
%pop4

View File

@ -0,0 +1,135 @@
// Arithmetic on integers represented with 128-bit limbs.
// These integers are represented in LITTLE-ENDIAN form.
// All integers must be under a given length bound, and are padded with leading zeroes.
// Stores b ^ e % m in output_loc, leaving b, e, and m unchanged.
// b, e, and m must have the same length.
// output_loc must have size length and be initialized with zeroes; scratch_1 must have size length.
// All of scratch_2..scratch_5 must have size 2 * length and be initialized with zeroes.
// Also, scratch_2..scratch_5 must be CONSECUTIVE in memory.
global modexp_bignum:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, s4, s5, retdest
DUP1
ISZERO
%jumpi(len_zero)
// We store the repeated-squares accumulator x_i in scratch_1, starting with x_0 := b.
DUP1
DUP3
DUP8
// stack: s1, b_loc, len, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%memcpy_kernel_general
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// We store the accumulated output value x_i in output_loc, starting with x_0=1.
PUSH 1
DUP6
// stack: out_loc, 1, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%mstore_kernel_general
modexp_loop:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// y := e % 2
DUP3
// stack: e_loc, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%mload_kernel_general
// stack: e_first, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%mod_const(2)
// stack: y = e_first % 2 = e % 2, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
ISZERO
// stack: y == 0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jumpi(modexp_y_0)
// if y == 1, modular-multiply output_loc by scratch_1, using scratch_2..scratch_4 as scratch space, and store in scratch_5.
PUSH modexp_mul_return
DUP10
DUP10
DUP10
DUP14
DUP9
DUP12
DUP12
DUP9
// stack: len, out_loc, s1, m_loc, s5, s2, s3, s4, modexp_mul_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(modmul_bignum)
modexp_mul_return:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// Copy scratch_5 to output_loc.
DUP1
DUP11
DUP7
// stack: out_loc, s5, len, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%memcpy_kernel_general
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// Zero out scratch_2..scratch_5.
DUP1
%mul_const(8)
DUP8
// stack: s2, 8 * len, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%clear_kernel_general
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
modexp_y_0:
// if y == 0, do nothing
// Modular-square repeated-squares accumulator x_i (in scratch_1), using scratch_2..scratch_4 as scratch space, and store in scratch_5.
PUSH modexp_square_return
DUP10
DUP10
DUP10
DUP14
DUP9
DUP12
DUP1
DUP9
// stack: len, s1, s1, m_loc, s5, s2, s3, s4, modexp_square_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(modmul_bignum)
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
modexp_square_return:
// Copy scratch_5 to scratch_1.
DUP1
DUP11
DUP8
// stack: s1, s5, len, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%memcpy_kernel_general
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// Zero out scratch_2..scratch_5.
DUP1
%mul_const(8)
DUP8
// stack: s2, 8 * len, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%clear_kernel_general
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// e //= 2 (with shr_bignum)
PUSH modexp_shr_return
DUP4
DUP3
// stack: len, e_loc, modexp_shr_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(shr_bignum)
modexp_shr_return:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
// check if e == 0 (with iszero_bignum)
PUSH modexp_iszero_return
DUP4
DUP3
// stack: len, e_loc, modexp_iszero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(iszero_bignum)
modexp_iszero_return:
// stack: e == 0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
ISZERO
// stack: e != 0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jumpi(modexp_loop)
// end of modexp_loop
len_zero:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%pop10
// stack: retdest
JUMP

View File

@ -0,0 +1,167 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
// Stores a * b % m in output_loc, leaving a, b, and m unchanged.
// a, b, and m must have the same length.
// output_loc must have size length; scratch_2 must have size 2*length.
// Both scratch_2 and scratch_3 have size 2*length and be initialized with zeroes.
// The prover provides x := (a * b) % m, which is the output of this function.
// We first check that x < m.
// The prover also provides k := (a * b) / m, stored in scratch space.
// We then check that x + k * m = a * b, by computing both of those using
// bignum arithmetic, storing the results in scratch space.
// We assert equality between those two, limb by limb.
global modmul_bignum:
// stack: len, a_loc, b_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, retdest
DUP1
ISZERO
%jumpi(len_zero)
// STEP 1:
// The prover provides x := (a * b) % m, which we store in output_loc.
PUSH 0
// stack: i=0, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
modmul_remainder_loop:
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PROVER_INPUT(bignum_modmul)
// stack: PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
DUP7
DUP3
ADD
// stack: out_loc[i], PI, i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%mstore_kernel_general
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%increment
DUP2
DUP2
// stack: i+1, len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
SUB // functions as NEQ
// stack: i+1!=len, i+1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%jumpi(modmul_remainder_loop)
// end of modmul_remainder_loop
// stack: i, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
POP
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
// STEP 2:
// We check that x < m.
PUSH modmul_return_1
DUP6
DUP6
DUP4
// stack: len, m_loc, out_loc, modmul_return_1, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
// Should return 1 iff the value at m_loc > the value at out_loc; in other words, if x < m.
%jump(cmp_bignum)
modmul_return_1:
// stack: cmp_result, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PUSH 1
%assert_eq
// STEP 3:
// The prover provides k := (a * b) / m, which we store in scratch_1.
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
DUP1
// stack: len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%mul_const(2)
// stack: 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PUSH 0
// stack: i=0, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
modmul_quotient_loop:
// stack: i, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
PROVER_INPUT(bignum_modmul)
// stack: PI, i, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
DUP9
DUP3
ADD
// stack: s1[i], PI, i, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%mstore_kernel_general
// stack: i, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%increment
DUP2
DUP2
// stack: i+1, 2*len, i+1, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
SUB // functions as NEQ
// stack: i+1!=2*len, i+1, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%jumpi(modmul_quotient_loop)
// end of modmul_quotient_loop
// stack: i, 2*len, len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%pop2
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
// STEP 4:
// We calculate x + k * m.
// STEP 4.1:
// Multiply k with m and store k * m in scratch_2.
PUSH modmul_return_2
%stack (return, len, a, b, m, out, s1, s2) -> (len, s1, m, s2, return, len, a, b, out, s2)
// stack: len, s1, m_loc, s2, modmul_return_2, len, a_loc, b_loc, out_loc, s2, s3, retdest
%jump(mul_bignum)
modmul_return_2:
// stack: len, a_loc, b_loc, out_loc, s2, s3, retdest
// STEP 4.2:
// Add x into k * m (in scratch_2).
PUSH modmul_return_3
%stack (return, len, a, b, out, s2) -> (len, s2, out, return, len, a, b, s2)
// stack: len, s2, out_loc, modmul_return_3, len, a_loc, b_loc, s2, s3, retdest
%jump(add_bignum)
modmul_return_3:
// stack: carry, len, a_loc, b_loc, s2, s3, retdest
POP
// stack: len, a_loc, b_loc, s2, s3, retdest
// STEP 5:
// We calculate a * b.
// Multiply a with b and store a * b in scratch_3.
PUSH modmul_return_4
%stack (return, len, a, b, s2, s3) -> (len, a, b, s3, return, len, s2, s3)
// stack: len, a_loc, b_loc, s3, modmul_return_4, len, s2, s3, retdest
%jump(mul_bignum)
modmul_return_4:
// stack: len, s2, s3, retdest
// STEP 6:
// Check that x + k * m = a * b.
// Walk through scratch_2 and scratch_3, checking that they are equal.
// stack: n=len, i=s2, j=s3, retdest
modmul_check_loop:
// stack: n, i, j, retdest
%stack (l, idx: 2) -> (idx, l, idx)
// stack: i, j, n, i, j, retdest
%mload_kernel_general
SWAP1
%mload_kernel_general
SWAP1
// stack: mem[i], mem[j], n, i, j, retdest
%assert_eq
// stack: n, i, j, retdest
%decrement
SWAP1
%increment
SWAP2
%increment
SWAP2
SWAP1
// stack: n-1, i+1, j+1, retdest
DUP1
// stack: n-1, n-1, i+1, j+1, retdest
%jumpi(modmul_check_loop)
// end of modmul_check_loop
// stack: n-1, i+1, j+1, retdest
%pop3
// stack: retdest
JUMP
len_zero:
// stack: len, a_loc, b_loc, m_loc, out_loc, s1, s2, s3, retdest
%pop8
// stack: retdest
JUMP

View File

@ -57,6 +57,7 @@ mul_end:
%pop5
// stack: retdest
JUMP
len_zero:
// stack: len, a_start_loc, b_start_loc, output_loc, retdest
%pop4

View File

@ -60,6 +60,7 @@ shr_end:
%pop3
// stack: retdest
JUMP
len_zero:
// stack: len, start_loc, retdest
%pop2

View File

@ -50,6 +50,18 @@
%endrep
%endmacro
%macro pop9
%rep 9
POP
%endrep
%endmacro
%macro pop10
%rep 10
POP
%endrep
%endmacro
%macro and_const(c)
// stack: input, ...
PUSH $c

View File

@ -376,7 +376,11 @@ impl<'a> Interpreter<'a> {
0xa2 => todo!(), // "LOG2",
0xa3 => todo!(), // "LOG3",
0xa4 => todo!(), // "LOG4",
0xa5 => bail!("Executed PANIC, stack={:?}", self.stack()), // "PANIC",
0xa5 => bail!(
"Executed PANIC, stack={:?}, memory={:?}",
self.stack(),
self.get_kernel_general_memory()
), // "PANIC",
0xf0 => todo!(), // "CREATE",
0xf1 => todo!(), // "CALL",
0xf2 => todo!(), // "CALLCODE",

View File

@ -20,12 +20,15 @@ const MINUS_ONE: U256 = U256::MAX;
const TEST_DATA_BIGNUM_INPUTS: &str = "bignum_inputs";
const TEST_DATA_U128_INPUTS: &str = "u128_inputs";
const TEST_DATA_SHR_OUTPUTS: &str = "shr_outputs";
const TEST_DATA_ISZERO_OUTPUTS: &str = "iszero_outputs";
const TEST_DATA_CMP_OUTPUTS: &str = "cmp_outputs";
const TEST_DATA_ADD_OUTPUTS: &str = "add_outputs";
const TEST_DATA_ADDMUL_OUTPUTS: &str = "addmul_outputs";
const TEST_DATA_MUL_OUTPUTS: &str = "mul_outputs";
const TEST_DATA_MODMUL_OUTPUTS: &str = "modmul_outputs";
const TEST_DATA_MODEXP_OUTPUTS: &str = "modexp_outputs";
const BIT_SIZES_TO_TEST: [usize; 15] = [
0, 1, 2, 127, 128, 129, 255, 256, 257, 512, 1000, 1023, 1024, 1025, 31415,
@ -232,6 +235,76 @@ fn test_mul_bignum(a: BigUint, b: BigUint, expected_output: BigUint) -> Result<(
Ok(())
}
fn test_modmul_bignum(a: BigUint, b: BigUint, m: BigUint, expected_output: BigUint) -> Result<()> {
let len = bignum_len(&a).max(bignum_len(&b)).max(bignum_len(&m));
let output_len = len;
let memory = pad_bignums(&[a, b, m], len);
let a_start_loc = 0;
let b_start_loc = len;
let m_start_loc = 2 * len;
let output_start_loc = 3 * len;
let scratch_1 = 4 * len; // size 2*len
let scratch_2 = 6 * len; // size 2*len
let scratch_3 = 8 * len; // size 2*len
let (new_memory, _new_stack) = run_test(
"modmul_bignum",
memory,
vec![
len.into(),
a_start_loc.into(),
b_start_loc.into(),
m_start_loc.into(),
output_start_loc.into(),
scratch_1.into(),
scratch_2.into(),
scratch_3.into(),
],
)?;
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
assert_eq!(output, expected_output);
Ok(())
}
fn test_modexp_bignum(b: BigUint, e: BigUint, m: BigUint, expected_output: BigUint) -> Result<()> {
let len = bignum_len(&b).max(bignum_len(&e)).max(bignum_len(&m));
let output_len = len;
let memory = pad_bignums(&[b, e, m], len);
let b_start_loc = 0;
let e_start_loc = len;
let m_start_loc = 2 * len;
let output_start_loc = 3 * len;
let scratch_1 = 4 * len;
let scratch_2 = 5 * len; // size 2*len
let scratch_3 = 7 * len; // size 2*len
let scratch_4 = 9 * len; // size 2*len
let scratch_5 = 11 * len; // size 2*len
let (new_memory, _new_stack) = run_test(
"modexp_bignum",
memory,
vec![
len.into(),
b_start_loc.into(),
e_start_loc.into(),
m_start_loc.into(),
output_start_loc.into(),
scratch_1.into(),
scratch_2.into(),
scratch_3.into(),
scratch_4.into(),
scratch_5.into(),
],
)?;
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
assert_eq!(output, expected_output);
Ok(())
}
#[test]
fn test_shr_bignum_all() -> Result<()> {
for bit_size in BIT_SIZES_TO_TEST {
@ -394,3 +467,81 @@ fn test_mul_bignum_all() -> Result<()> {
Ok(())
}
#[test]
fn test_modmul_bignum_all() -> Result<()> {
for bit_size in BIT_SIZES_TO_TEST {
let a = gen_bignum(bit_size);
let b = gen_bignum(bit_size);
let m = gen_bignum(bit_size);
if !m.is_zero() {
let output = &a * &b % &m;
test_modmul_bignum(a, b, m, output)?;
}
let a = max_bignum(bit_size);
let b = max_bignum(bit_size);
let m = max_bignum(bit_size);
if !m.is_zero() {
let output = &a * &b % &m;
test_modmul_bignum(a, b, m, output)?;
}
}
let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
let modmul_outputs = test_data_biguint(TEST_DATA_MODMUL_OUTPUTS);
let mut modmul_outputs_iter = modmul_outputs.into_iter();
for a in &inputs {
for b in &inputs {
// For m, skip the first input, which is zero.
for m in &inputs[1..] {
let output = modmul_outputs_iter.next().unwrap();
test_modmul_bignum(a.clone(), b.clone(), m.clone(), output)?;
}
}
}
Ok(())
}
#[test]
fn test_modexp_bignum_all() -> Result<()> {
// Only test smaller values for exponent.
let exp_bit_sizes = vec![2, 100, 127, 128, 129];
for bit_size in &BIT_SIZES_TO_TEST[3..14] {
for exp_bit_size in &exp_bit_sizes {
let b = gen_bignum(*bit_size);
let e = gen_bignum(*exp_bit_size);
let m = gen_bignum(*bit_size);
if !m.is_zero() {
let output = b.clone().modpow(&e, &m);
test_modexp_bignum(b, e, m, output)?;
}
let b = max_bignum(*bit_size);
let e = max_bignum(*exp_bit_size);
let m = max_bignum(*bit_size);
if !m.is_zero() {
let output = b.modpow(&e, &m);
test_modexp_bignum(b, e, m, output)?;
}
}
}
let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
let modexp_outputs = test_data_biguint(TEST_DATA_MODEXP_OUTPUTS);
let mut modexp_outputs_iter = modexp_outputs.into_iter();
for b in &inputs {
// Include only smaller exponents, to keep tests from becoming too slow.
for e in &inputs[..7] {
// For m, skip the first input, which is zero.
for m in &inputs[1..] {
let output = modexp_outputs_iter.next().unwrap();
test_modexp_bignum(b.clone(), e.clone(), m.clone(), output)?;
}
}
}
Ok(())
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@ use std::str::FromStr;
use anyhow::{bail, Error};
use ethereum_types::{BigEndianHash, H256, U256, U512};
use itertools::Itertools;
use plonky2::field::types::Field;
use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254};
@ -11,7 +12,9 @@ use crate::generation::prover_input::EvmField::{
};
use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::memory::segments::Segment::BnPairing;
use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint};
use crate::witness::util::{kernel_peek, stack_peek};
/// Prover input function represented as a scoped function name.
@ -35,6 +38,7 @@ impl<F: Field> GenerationState<F> {
"mpt" => self.run_mpt(),
"rlp" => self.run_rlp(),
"account_code" => self.run_account_code(input_fn),
"bignum_modmul" => self.run_bignum_modmul(),
_ => panic!("Unrecognized prover input function."),
}
}
@ -140,6 +144,69 @@ impl<F: Field> GenerationState<F> {
_ => panic!("Invalid prover input function."),
}
}
// Bignum modular multiplication.
// On the first call, calculates the remainder and quotient of the given inputs.
// These are stored, as limbs, in self.bignum_modmul_result_limbs.
// Subsequent calls return one limb at a time, in order (first remainder and then quotient).
fn run_bignum_modmul(&mut self) -> U256 {
if self.bignum_modmul_result_limbs.is_empty() {
let len = stack_peek(self, 1)
.expect("Stack does not have enough items")
.try_into()
.unwrap();
let a_start_loc = stack_peek(self, 2)
.expect("Stack does not have enough items")
.try_into()
.unwrap();
let b_start_loc = stack_peek(self, 3)
.expect("Stack does not have enough items")
.try_into()
.unwrap();
let m_start_loc = stack_peek(self, 4)
.expect("Stack does not have enough items")
.try_into()
.unwrap();
let (remainder, quotient) =
self.bignum_modmul(len, a_start_loc, b_start_loc, m_start_loc);
self.bignum_modmul_result_limbs = remainder
.iter()
.cloned()
.pad_using(len, |_| 0.into())
.chain(quotient.iter().cloned().pad_using(2 * len, |_| 0.into()))
.collect();
self.bignum_modmul_result_limbs.reverse();
}
self.bignum_modmul_result_limbs.pop().unwrap()
}
fn bignum_modmul(
&mut self,
len: usize,
a_start_loc: usize,
b_start_loc: usize,
m_start_loc: usize,
) -> (Vec<U256>, Vec<U256>) {
let a = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[a_start_loc..a_start_loc + len];
let b = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[b_start_loc..b_start_loc + len];
let m = &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content
[m_start_loc..m_start_loc + len];
let a_biguint = mem_vec_to_biguint(a);
let b_biguint = mem_vec_to_biguint(b);
let m_biguint = mem_vec_to_biguint(m);
let prod = a_biguint * b_biguint;
let quo = &prod / &m_biguint;
let rem = prod - m_biguint * &quo;
(biguint_to_mem_vec(rem), biguint_to_mem_vec(quo))
}
}
enum EvmField {

View File

@ -39,6 +39,11 @@ pub(crate) struct GenerationState<F: Field> {
/// useful to see the actual addresses for debugging. Here we store the mapping for all known
/// addresses.
pub(crate) state_key_to_address: HashMap<H256, Address>,
/// Prover inputs containing the result of a MODMUL operation, in little-endian order (so that
/// inputs are obtained in big-endian order via `pop()`). Contains both the remainder and the
/// quotient, in that order.
pub(crate) bignum_modmul_result_limbs: Vec<U256>,
}
impl<F: Field> GenerationState<F> {
@ -54,6 +59,7 @@ impl<F: Field> GenerationState<F> {
log::debug!("Input contract_code: {:?}", &inputs.contract_code);
let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries);
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
let bignum_modmul_result_limbs = Vec::new();
Self {
inputs,
@ -64,6 +70,7 @@ impl<F: Field> GenerationState<F> {
mpt_prover_inputs,
rlp_prover_inputs,
state_key_to_address: HashMap::new(),
bignum_modmul_result_limbs,
}
}

View File

@ -145,11 +145,11 @@ pub(crate) fn biguint_to_u256(x: BigUint) -> U256 {
U256::from_little_endian(&bytes)
}
#[cfg(test)]
pub(crate) fn le_limbs_to_biguint(x: &[u128]) -> BigUint {
pub(crate) fn mem_vec_to_biguint(x: &[U256]) -> BigUint {
BigUint::from_slice(
&x.iter()
.flat_map(|&a| {
.map(|&n| n.try_into().unwrap())
.flat_map(|a: u128| {
[
(a % (1 << 32)) as u32,
((a >> 32) % (1 << 32)) as u32,
@ -161,28 +161,15 @@ pub(crate) fn le_limbs_to_biguint(x: &[u128]) -> BigUint {
)
}
#[cfg(test)]
pub(crate) fn mem_vec_to_biguint(x: &[U256]) -> BigUint {
le_limbs_to_biguint(&x.iter().map(|&n| n.try_into().unwrap()).collect_vec())
}
#[cfg(test)]
pub(crate) fn biguint_to_le_limbs(x: BigUint) -> Vec<u128> {
let mut digits = x.to_u32_digits();
// Pad to a multiple of 4.
digits.resize((digits.len() + 3) / 4 * 4, 0);
digits
.chunks(4)
.map(|c| (c[3] as u128) << 96 | (c[2] as u128) << 64 | (c[1] as u128) << 32 | c[0] as u128)
.collect()
}
#[cfg(test)]
pub(crate) fn biguint_to_mem_vec(x: BigUint) -> Vec<U256> {
biguint_to_le_limbs(x)
.into_iter()
.map(|n| n.into())
.collect()
let num_limbs = ((x.bits() + 127) / 128) as usize;
let mut digits = x.iter_u64_digits();
let mut mem_vec = Vec::with_capacity(num_limbs);
while let Some(lo) = digits.next() {
let hi = digits.next().unwrap_or(0);
mem_vec.push(U256::from(lo as u128 | (hi as u128) << 64));
}
mem_vec
}

View File

@ -309,10 +309,11 @@ pub(crate) fn transition<F: Field>(state: &mut GenerationState<F>) -> anyhow::Re
if state.registers.is_kernel {
let offset_name = KERNEL.offset_name(state.registers.program_counter);
bail!(
"{:?} in kernel at pc={}, stack={:?}",
"{:?} in kernel at pc={}, stack={:?}, memory={:?}",
e,
offset_name,
state.stack()
state.stack(),
state.memory.contexts[0].segments[Segment::KernelGeneral as usize].content,
);
}
state.rollback(checkpoint);