Merge pull request #998 from mir-protocol/even-smaller-bignum-modexp-test

even smaller bignum modexp test, and fixes
This commit is contained in:
Nicholas Ward 2023-04-24 15:48:30 -07:00 committed by GitHub
commit 137a9966e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1742 additions and 1057 deletions

View File

@ -16,6 +16,7 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/bignum/add.asm"),
include_str!("asm/bignum/addmul.asm"),
include_str!("asm/bignum/cmp.asm"),
include_str!("asm/bignum/isone.asm"),
include_str!("asm/bignum/iszero.asm"),
include_str!("asm/bignum/modexp.asm"),
include_str!("asm/bignum/modmul.asm"),

View File

@ -0,0 +1,35 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
global isone_bignum:
// stack: len, start_loc, retdest
DUP1
// stack: len, len, start_loc, retdest
ISZERO
%jumpi(eqzero)
// stack: len, start_loc, retdest
DUP2
// stack: start_loc, len, start_loc, retdest
%mload_kernel_general
// stack: start_val, len, start_loc, retdest
%eq_const(1)
%jumpi(starts_with_one)
// Does not start with one, so not equal to one.
// stack: len, start_loc, retdest
%stack (vals: 2, retdest) -> (retdest, 0)
JUMP
eqzero:
// Is zero, so not equal to one.
// stack: cur_loc, end_loc, retdest
%stack (vals: 2, retdest) -> (retdest, 0)
// stack: retdest, 0
JUMP
starts_with_one:
// Starts with one, so check that the remaining limbs are zero.
// stack: len, start_loc, retdest
%decrement
SWAP1
%increment
SWAP1
// stack: len-1, start_loc+1, retdest
%jump(iszero_bignum)

View File

@ -8,10 +8,55 @@
// All of scratch_2..scratch_5 must have size 2 * length and be initialized with zeroes.
// Also, scratch_2..scratch_5 must be CONSECUTIVE in memory.
global modexp_bignum:
// stack: len, b_loc, e_loc, m_loc, out_loc, s1 (=scratch_1), s2, s3, s4, s5, retdest
DUP1
ISZERO
%jumpi(len_zero)
// Special input cases:
// (1) Modulus is zero (also covers len=0 case).
PUSH modulus_zero_return
// stack: modulus_zero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP5
// stack: m_loc, modulus_zero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP3
// stack: len, m_loc, modulus_zero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(iszero_bignum)
modulus_zero_return:
// stack: m==0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jumpi(modulus_zero_or_one)
// (2) Modulus is one.
PUSH modulus_one_return
// stack: modulus_one_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP5
// stack: m_loc, modulus_one_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP3
// stack: len, m_loc, modulus_one_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(isone_bignum)
modulus_one_return:
// stack: m==1, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jumpi(modulus_zero_or_one)
// (3) Both b and e are zero.
PUSH b_zero_return
// stack: b_zero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP3
// stack: b_loc, b_zero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP3
// stack: len, b_loc, b_zero_return, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(iszero_bignum)
b_zero_return:
// stack: b==0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
PUSH e_zero_return
// stack: e_zero_return, b==0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP5
// stack: e_loc, e_zero_return, b==0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
DUP4
// stack: len, e_loc, e_zero_return, b==0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jump(iszero_bignum)
e_zero_return:
// stack: e==0, b==0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
MUL // logical and
%jumpi(b_and_e_zero)
// End of special cases.
// We store the repeated-squares accumulator x_i in scratch_1, starting with x_0 := b.
DUP1
@ -128,8 +173,18 @@ modexp_iszero_return:
// stack: e != 0, len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%jumpi(modexp_loop)
// end of modexp_loop
len_zero:
modulus_zero_or_one:
// If modulus is zero or one, return 0.
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
%pop10
// stack: retdest
JUMP
b_and_e_zero:
// If base and exponent are zero (and modulus > 1), return 1.
// stack: len, b_loc, e_loc, m_loc, out_loc, s1, s2, s3, s4, s5, retdest
PUSH 1
DUP6
%mstore_kernel_general
%pop10
// stack: retdest
JUMP

View File

@ -29,6 +29,7 @@ const TEST_DATA_ADDMUL_OUTPUTS: &str = "addmul_outputs";
const TEST_DATA_MUL_OUTPUTS: &str = "mul_outputs";
const TEST_DATA_MODMUL_OUTPUTS: &str = "modmul_outputs";
const TEST_DATA_MODEXP_OUTPUTS: &str = "modexp_outputs";
const TEST_DATA_MODEXP_OUTPUTS_FULL: &str = "modexp_outputs_full";
const BIT_SIZES_TO_TEST: [usize; 15] = [
0, 1, 2, 127, 128, 129, 255, 256, 257, 512, 1000, 1023, 1024, 1025, 31415,
@ -282,7 +283,7 @@ fn test_modexp_bignum(b: BigUint, e: BigUint, m: BigUint, expected_output: BigUi
let scratch_3 = 7 * len; // size 2*len
let scratch_4 = 9 * len; // size 2*len
let scratch_5 = 11 * len; // size 2*len
let (new_memory, _new_stack) = run_test(
let (mut new_memory, _new_stack) = run_test(
"modexp_bignum",
memory,
vec![
@ -298,6 +299,10 @@ fn test_modexp_bignum(b: BigUint, e: BigUint, m: BigUint, expected_output: BigUi
scratch_5.into(),
],
)?;
new_memory.resize(
new_memory.len().max(output_start_loc + output_len),
0.into(),
);
let output = mem_vec_to_biguint(&new_memory[output_start_loc..output_start_loc + output_len]);
assert_eq!(output, expected_output);
@ -533,9 +538,8 @@ fn test_modexp_bignum_all() -> Result<()> {
let mut modexp_outputs_iter = modexp_outputs.into_iter();
for b in &inputs[..9] {
// Include only smaller exponents, to keep tests from becoming too slow.
for e in &inputs[..7] {
// For m, skip the first input, which is zero.
for m in &inputs[1..] {
for e in &inputs[..6] {
for m in &inputs[..9] {
let output = modexp_outputs_iter.next().unwrap();
test_modexp_bignum(b.clone(), e.clone(), m.clone(), output)?;
}
@ -572,13 +576,12 @@ fn test_modexp_bignum_all_full() -> Result<()> {
}
let inputs = test_data_biguint(TEST_DATA_BIGNUM_INPUTS);
let modexp_outputs = test_data_biguint(TEST_DATA_MODEXP_OUTPUTS);
let modexp_outputs = test_data_biguint(TEST_DATA_MODEXP_OUTPUTS_FULL);
let mut modexp_outputs_iter = modexp_outputs.into_iter();
for b in &inputs {
// Include only smaller exponents, to keep tests from becoming too slow.
for e in &inputs[..7] {
// For m, skip the first input, which is zero.
for m in &inputs[1..] {
for m in &inputs {
let output = modexp_outputs_iter.next().unwrap();
test_modexp_bignum(b.clone(), e.clone(), m.clone(), output)?;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff