From 4cef5aaa84934ad9ef8da2af7159ff73333c3d05 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 7 Mar 2023 15:15:20 -0800 Subject: [PATCH] modmul and modexp --- evm/src/cpu/kernel/aggregator.rs | 2 + evm/src/cpu/kernel/asm/bignum/modexp.asm | 169 +++++++++++++++++++ evm/src/cpu/kernel/asm/bignum/modmul.asm | 202 +++++++++++++++++++++++ evm/src/generation/prover_input.rs | 100 ++++++++++- evm/src/generation/state.rs | 6 + 5 files changed, 478 insertions(+), 1 deletion(-) create mode 100644 evm/src/cpu/kernel/asm/bignum/modexp.asm create mode 100644 evm/src/cpu/kernel/asm/bignum/modmul.asm diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 89f11467..4a0855b4 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -17,6 +17,8 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/bignum/addmul.asm"), include_str!("asm/bignum/cmp.asm"), include_str!("asm/bignum/iszero.asm"), + include_str!("asm/bignum/modexp.asm"), + include_str!("asm/bignum/modmul.asm"), include_str!("asm/bignum/mul.asm"), include_str!("asm/bignum/shr.asm"), include_str!("asm/bignum/util.asm"), diff --git a/evm/src/cpu/kernel/asm/bignum/modexp.asm b/evm/src/cpu/kernel/asm/bignum/modexp.asm new file mode 100644 index 00000000..9707c67e --- /dev/null +++ b/evm/src/cpu/kernel/asm/bignum/modexp.asm @@ -0,0 +1,169 @@ +// Arithmetic on little-endian integers represented with 128-bit limbs. +// All integers must be under a given length bound, and are padded with leading zeroes. + +// Stores b ^ e % m in output_loc, leaving b, e, and m unchanged. +// b, e, and m must have the same length. +// output_loc must have size length and be initialized with zeroes; scratch_1 must have size length. +// All of scratch_2..scratch_6 must have size 2 * length and be initialized with zeroes. 
+global modexp_bignum:
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // We store the repeated-squares accumulator x_i in scratch_1, starting with x_0 := b (copied so that b itself stays unchanged).
+    DUP1
+    // stack: length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP3
+    // stack: b_start_loc, length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP8
+    // stack: scratch_1, b_start_loc, length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %memcpy_kernel_general
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // We store the accumulated output value in output_loc, starting with 1 (only the lowest limb is written; the other limbs rely on output_loc being zero-initialized).
+    PUSH 1
+    // stack: 1, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP6
+    // stack: output_loc, 1, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %mstore_kernel_general
+
+modexp_loop:
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // y := e % 2 (read from the lowest limb of e, which is stored at e_start_loc)
+    DUP3
+    // stack: e_start_loc, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %mload_kernel_general
+    // stack: e_first, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %mod_const(2)
+    // stack: y = e_first % 2 = e % 2, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    ISZERO
+    // stack: y == 0, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %jumpi(modexp_y_0)
+
+    // if y == 1, modular-multiply output_loc by scratch_1, using scratch_2..scratch_5 as scratch space, and store in scratch_6.
+    PUSH modexp_mul_return
+    // stack: modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_2, scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP16
+    // stack: scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP10
+    // stack: m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP13
+    // stack: scratch_1, m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP13
+    // stack: output_loc, scratch_1, m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP10
+    // stack: length, output_loc, scratch_1, m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_mul_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %jump(modmul_bignum)
+modexp_mul_return:
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // Copy scratch_6 to output_loc.
+    DUP1
+    // stack: length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP12
+    // stack: scratch_6, length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP7
+    // stack: output_loc, scratch_6, length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %memcpy_kernel_general
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // Zero out scratch_2..scratch_6: clears 10 * length cells starting at scratch_2. NOTE(review): this assumes the five buffers are laid out contiguously -- confirm against callers.
+    DUP1
+    // stack: length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %mul_const(10)
+    // stack: 10 * length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP8
+    // stack: scratch_2, 10 * length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %clear_kernel_general
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+modexp_y_0:
+    // if y == 0, the multiplication above was skipped; the y == 1 path also falls through to here.
+
+    // Modular-square repeated-squares accumulator x_i (in scratch_1), using scratch_2..scratch_5 as scratch space, and store in scratch_6.
+    PUSH modexp_square_return
+    // stack: modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP11
+    // stack: scratch_2, scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP16
+    // stack: scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP10
+    // stack: m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP13
+    // stack: scratch_1, m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP1
+    // stack: scratch_1, scratch_1, m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP10
+    // stack: length, scratch_1, scratch_1, m_start_loc, scratch_6, scratch_2, scratch_3, scratch_4, scratch_5, modexp_square_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %jump(modmul_bignum)
+    // stack (once modmul_bignum has returned): length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+modexp_square_return:
+    // Copy scratch_6 to scratch_1.
+    DUP1
+    // stack: length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP12
+    // stack: scratch_6, length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP8
+    // stack: scratch_1, scratch_6, length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %memcpy_kernel_general
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // Zero out scratch_2..scratch_6 (again 10 * length cells starting at scratch_2).
+    DUP1
+    // stack: length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %mul_const(10)
+    // stack: 10 * length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP8
+    // stack: scratch_2, 10 * length, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %clear_kernel_general
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // e //= 2 (with shr_bignum)
+
+    PUSH modexp_shr_return
+    // stack: modexp_shr_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP4
+    // stack: e_start_loc, modexp_shr_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP3
+    // stack: length, e_start_loc, modexp_shr_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %jump(shr_bignum)
+modexp_shr_return:
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+
+    // check if e == 0 (with iszero_bignum)
+
+    PUSH modexp_iszero_return
+    // stack: modexp_iszero_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP4
+    // stack: e_start_loc, modexp_iszero_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    DUP3
+    // stack: length, e_start_loc, modexp_iszero_return, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %jump(iszero_bignum)
+modexp_iszero_return:
+    // stack: e == 0, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    ISZERO
+    // stack: e != 0, length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %jumpi(modexp_loop)
+modexp_end:
+    // stack: length, b_start_loc, e_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, scratch_5, scratch_6, retdest
+    %rep 11
+        POP
+    %endrep
+    // stack: retdest
+    JUMP
+
+
diff --git a/evm/src/cpu/kernel/asm/bignum/modmul.asm b/evm/src/cpu/kernel/asm/bignum/modmul.asm
new file mode 100644
index 00000000..92ee63f8
--- /dev/null
+++ b/evm/src/cpu/kernel/asm/bignum/modmul.asm
@@ -0,0 +1,202 @@
+// Arithmetic on little-endian integers represented with 128-bit limbs.
+// All integers must be under a given length bound, and are padded with leading zeroes.
+
+// Stores a * b % m in output_loc, leaving a, b, and m unchanged.
+// a, b, and m must have the same length.
+// Both output_loc and scratch_1 must have size length.
+// All of scratch_2, scratch_3, and scratch_4 must have size 2 * length and be initialized with zeroes.
+global modmul_bignum:
+    // stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+
+    // The prover provides x := (a * b) % m, which we store in output_loc. NOTE(review): x < m is not checked anywhere in this routine.
+
+    PUSH 0
+    // stack: i=0, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+modmul_remainder_loop:
+    // stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    PROVER_INPUT(bignum_modmul::remainder)
+    // stack: PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    DUP7
+    // stack: output_loc, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    DUP3
+    // stack: i, output_loc, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    ADD
+    // stack: output_loc[i], PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %mstore_kernel_general
+    // stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %increment
+    // stack: i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    DUP2
+    DUP2
+    // stack: i+1, length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    EQ
+    // stack: i+1==length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    ISZERO
+    // stack: i+1!=length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %jumpi(modmul_remainder_loop)
+modmul_remainder_end:
+    // stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    POP
+
+    // The prover provides k := (a * b) / m, which we store in scratch_1.
+
+    // stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    PUSH 0
+    // stack: i=0, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+modmul_quotient_loop:
+    // stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    PROVER_INPUT(bignum_modmul::quotient)
+    // stack: PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    DUP8
+    // stack: scratch_1, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    DUP3
+    // stack: i, scratch_1, PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    ADD
+    // stack: scratch_1[i], PI, i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %mstore_kernel_general
+    // stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %increment
+    // stack: i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    DUP2
+    DUP2
+    // stack: i+1, length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %neq
+    // stack: i+1!=length, i+1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %jumpi(modmul_quotient_loop)
+modmul_quotient_end:
+    // stack: i, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    POP
+    // stack: length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+
+    // Verification step 1: calculate x + k * m.
+
+    // Store k * m in scratch_2, using scratch_3 as scratch space.
+    PUSH modmul_return_1
+    // stack: modmul_return_1, length, a_start_loc, b_start_loc, m_start_loc, output_loc, scratch_1, scratch_2, scratch_3, scratch_4, retdest
+    %stack (return, len, a, b, m, out, s1, s2, s3) -> (len, s1, m, s2, s3, return, len, a, b, out, s2, s3)
+    // stack: length, scratch_1, m_start_loc, scratch_2, scratch_3, modmul_return_1, length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, scratch_4, retdest
+    %jump(mul_bignum)
+modmul_return_1:
+    // stack: length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, scratch_4, retdest
+
+    // Add x into k * m (in scratch_2).
+    PUSH modmul_return_2
+    // stack: modmul_return_2, length, a_start_loc, b_start_loc, output_loc, scratch_2, scratch_3, scratch_4, retdest
+    %stack (return, len, a, b, out, s2) -> (len, s2, out, return, len, a, b, s2)
+    // stack: length, scratch_2, output_loc, modmul_return_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %jump(add_bignum)
+modmul_return_2:
+    // stack: carry, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    ISZERO
+    %jumpi(no_carry)
+
+    // Carry: propagate it through the limbs of scratch_2 from offset length on (if the prover's input is correct, x + k * m will equal a * b, which has length at most 2 * length).
+
+    // stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP4
+    // stack: scratch_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP2
+    // stack: length, scratch_2, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    ADD
+    // stack: cur_loc=scratch_2 + length, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+increment_loop:
+    // stack: cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP1
+    %mload_kernel_general
+    // stack: val, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %increment
+    // stack: val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP1
+    // stack: val+1, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %eq_const(@BIGNUM_LIMB_BASE)
+    // stack: val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP1
+    // stack: val+1==limb_base, val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    ISZERO
+    // stack: val+1!=limb_base, val+1==limb_base, val+1, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    SWAP1
+    SWAP2
+    // stack: val+1, val+1!=limb_base, val+1==limb_base, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    MUL
+    // stack: to_write=(val+1)*(val+1!=limb_base), continue=val+1==limb_base, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP3
+    // stack: cur_loc, to_write, continue, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %mstore_kernel_general
+    // stack: continue, cur_loc, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    SWAP1
+    // stack: cur_loc, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %increment
+    // stack: cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP1
+    // stack: cur_loc + 1, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP8
+    // stack: scratch_3, cur_loc + 1, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    EQ
+    // stack: cur_loc + 1 == scratch_3, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    ISZERO
+    // stack: cur_loc + 1 != scratch_3, cur_loc + 1, continue, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    SWAP1
+    SWAP2
+    // stack: continue, cur_loc + 1 != scratch_3, cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    MUL
+    // stack: new_continue=continue*(cur_loc + 1 != scratch_3), cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %jumpi(increment_loop)
+    // stack: cur_loc + 1, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    POP
+no_carry:
+    // stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+
+    // Calculate a * b. First store zeroes in all 2 * length limbs of scratch_3: it was mul_bignum's scratch space above and now receives a 2 * length-limb product.
+    DUP1
+    // stack: length, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %mul_const(2)
+    // stack: 2 * length, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    DUP6
+    // stack: scratch_3, 2 * length, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %clear_kernel_general
+    // stack: length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+
+    // Store a * b in scratch_3, using scratch_4 as scratch space.
+    PUSH modmul_return_3
+    // stack: modmul_return_3, length, a_start_loc, b_start_loc, scratch_2, scratch_3, scratch_4, retdest
+    %stack (return, len, a, b, s2, s3, s4) -> (len, a, b, s3, s4, return, len, s2, s3)
+    // stack: length, a_start_loc, b_start_loc, scratch_3, scratch_4, modmul_return_3, length, scratch_2, scratch_3, retdest
+    %jump(mul_bignum)
+modmul_return_3:
+    // stack: length, scratch_2, scratch_3, retdest
+    %mul_const(2)
+    // Check that x + k * m = a * b. Both sides occupy 2 * length limbs, so every one of them is compared (not just the low half).
+    // Walk through scratch_2 and scratch_3, checking that they are equal.
+    // stack: n=2*length, i=scratch_2, j=scratch_3, retdest
+modmul_check_loop:
+    // stack: n, i, j, retdest
+    %stack (l, idx: 2) -> (idx, l, idx)
+    // stack: i, j, n, i, j, retdest
+    %mload_kernel_general
+    SWAP1
+    %mload_kernel_general
+    SWAP1
+    // stack: mem[i], mem[j], n, i, j, retdest
+    %assert_eq
+    // stack: n, i, j, retdest
+    %decrement
+    // stack: n-1, i, j, retdest
+    SWAP1
+    // stack: i, n-1, j, retdest
+    %increment
+    // stack: i+1, n-1, j, retdest
+    SWAP2
+    // stack: j, n-1, i+1, retdest
+    %increment
+    // stack: j+1, n-1, i+1, retdest
+    SWAP2
+    SWAP1
+    // stack: n-1, i+1, j+1, retdest
+    DUP1
+    // stack: n-1, n-1, i+1, j+1, retdest
+    %jumpi(modmul_check_loop)
+modmul_check_end:
+    // stack: n-1, i+1, j+1, retdest
+    %pop3
+    // stack: retdest
+    JUMP
diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs
index 1313de83..d5506ad4 100644
--- a/evm/src/generation/prover_input.rs
+++ b/evm/src/generation/prover_input.rs
@@ -3,6 +3,7 @@ use std::str::FromStr;
 
 use anyhow::{bail, Error};
 use ethereum_types::{BigEndianHash, H256, U256};
+use itertools::Itertools;
 use plonky2::field::types::Field;
 
 use crate::bn254_arithmetic::Fp12;
@@ -11,7 +12,8 @@ use crate::generation::prover_input::EvmField::{
 };
 use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
 use crate::generation::state::GenerationState;
-use crate::memory::segments::Segment::BnPairing;
+use crate::memory::segments::{Segment, Segment::BnPairing};
+use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint};
 use crate::witness::util::{kernel_peek, stack_peek};
 
 /// Prover input function represented as a scoped function name.
@@ -34,6 +36,7 @@ impl GenerationState {
             "mpt" => self.run_mpt(),
             "rlp" => self.run_rlp(),
             "account_code" => self.run_account_code(input_fn),
+            "bignum_modmul" => self.run_bignum_modmul(input_fn),
             _ => panic!("Unrecognized prover input function."),
         }
     }
@@ -123,6 +126,103 @@ impl GenerationState {
             _ => panic!("Invalid prover input function."),
         }
     }
+
+    // Bignum modular multiplication related code.
+
+    /// Returns the next limb of the bignum modmul result requested by `input_fn`.
+    ///
+    /// When no limbs are queued, `len`, `a_start_loc`, `b_start_loc` and `m_start_loc` are read
+    /// from stack positions 1 through 4 and the complete `remainder` or `quotient` is computed,
+    /// zero-padded to at least `len` limbs and queued in reverse; every call pops one limb.
+    ///
+    /// NOTE(review): a result longer than `len` limbs is queued whole, so limbs the kernel never
+    /// requests would be served to the next caller -- confirm this cannot happen.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the stack is too shallow, an operand does not fit in `usize`, `input_fn` names
+    /// an unknown function, or no limb is available after the queue has been filled.
+    fn run_bignum_modmul(&mut self, input_fn: &ProverInputFn) -> U256 {
+        if self.bignum_modmul_prover_inputs.is_empty() {
+            let function = input_fn.0[1].as_str();
+
+            let len = stack_peek(self, 1)
+                .expect("Stack does not have enough items")
+                .try_into()
+                .unwrap();
+            let a_start_loc = stack_peek(self, 2)
+                .expect("Stack does not have enough items")
+                .try_into()
+                .unwrap();
+            let b_start_loc = stack_peek(self, 3)
+                .expect("Stack does not have enough items")
+                .try_into()
+                .unwrap();
+            let m_start_loc = stack_peek(self, 4)
+                .expect("Stack does not have enough items")
+                .try_into()
+                .unwrap();
+
+            let result = match function {
+                "remainder" => {
+                    self.bignum_modmul_remainder(len, a_start_loc, b_start_loc, m_start_loc)
+                }
+                "quotient" => {
+                    self.bignum_modmul_quotient(len, a_start_loc, b_start_loc, m_start_loc)
+                }
+                _ => panic!("Invalid prover input function."),
+            };
+
+            // Queue the limbs reversed so that `pop()` hands them out in their original order.
+            self.bignum_modmul_prover_inputs = result
+                .into_iter()
+                .pad_using(len, |_| 0.into())
+                .collect();
+            self.bignum_modmul_prover_inputs.reverse();
+        }
+
+        self.bignum_modmul_prover_inputs.pop().unwrap()
+    }
+
+    /// Computes `(a * b) % m` for three `len`-limb operands read from context 0's
+    /// `KernelGeneral` segment and returns the result converted back to memory limbs.
+    ///
+    /// NOTE(review): a zero modulus presumably panics in the big-integer division -- confirm.
+    fn bignum_modmul_remainder(
+        &mut self,
+        len: usize,
+        a_start_loc: usize,
+        b_start_loc: usize,
+        m_start_loc: usize,
+    ) -> Vec<U256> {
+        let kernel_general =
+            &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content;
+        let a = mem_vec_to_biguint(&kernel_general[a_start_loc..a_start_loc + len]);
+        let b = mem_vec_to_biguint(&kernel_general[b_start_loc..b_start_loc + len]);
+        let m = mem_vec_to_biguint(&kernel_general[m_start_loc..m_start_loc + len]);
+
+        biguint_to_mem_vec((a * b) % m)
+    }
+
+    /// Computes `(a * b) / m` for three `len`-limb operands read from context 0's
+    /// `KernelGeneral` segment and returns the result converted back to memory limbs.
+    ///
+    /// NOTE(review): a zero modulus presumably panics in the big-integer division -- confirm.
+    fn bignum_modmul_quotient(
+        &mut self,
+        len: usize,
+        a_start_loc: usize,
+        b_start_loc: usize,
+        m_start_loc: usize,
+    ) -> Vec<U256> {
+        let kernel_general =
+            &self.memory.contexts[0].segments[Segment::KernelGeneral as usize].content;
+        let a = mem_vec_to_biguint(&kernel_general[a_start_loc..a_start_loc + len]);
+        let b = mem_vec_to_biguint(&kernel_general[b_start_loc..b_start_loc + len]);
+        let m = mem_vec_to_biguint(&kernel_general[m_start_loc..m_start_loc + len]);
+
+        biguint_to_mem_vec((a * b) / m)
+    }
 }
 
 enum EvmField {
diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs
index 9399e4b6..f2d2a287 100644
--- a/evm/src/generation/state.rs
+++ b/evm/src/generation/state.rs
@@ -39,6 +39,10 @@ pub(crate) struct GenerationState {
     /// useful to see the actual addresses for debugging. Here we store the mapping for all known
     /// addresses.
     pub(crate) state_key_to_address: HashMap<H256, Address>,
+
+    /// Prover inputs containing the result of a MODMUL-related operation, in reverse order so that the next
+    /// input can be obtained via `pop()`.
+    pub(crate) bignum_modmul_prover_inputs: Vec<U256>,
 }
 
 impl GenerationState {
@@ -54,6 +58,7 @@ impl GenerationState {
         log::debug!("Input contract_code: {:?}", &inputs.contract_code);
         let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries);
         let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
+        let bignum_modmul_prover_inputs = Vec::new();
 
         Self {
             inputs,
@@ -64,6 +69,7 @@ impl GenerationState {
             mpt_prover_inputs,
             rlp_prover_inputs,
             state_key_to_address: HashMap::new(),
+            bignum_modmul_prover_inputs,
         }
     }
 