basic bignum

This commit is contained in:
Nicholas Ward 2023-02-10 16:31:35 -08:00
parent 0f55956ade
commit fa605d7b22
23 changed files with 799 additions and 23 deletions

View File

@ -22,6 +22,7 @@ keccak-hash = "0.10.0"
log = "0.4.14"
plonky2_maybe_rayon = "0.1.0"
num = "0.4.0"
num-bigint = "0.4.3"
once_cell = "1.13.0"
pest = "2.1.3"
pest_derive = "2.1.0"

View File

@ -11,6 +11,12 @@ pub static KERNEL: Lazy<Kernel> = Lazy::new(combined_kernel);
pub(crate) fn combined_kernel() -> Kernel {
let files = vec![
include_str!("asm/bignum/add.asm"),
include_str!("asm/bignum/ge.asm"),
include_str!("asm/bignum/iszero.asm"),
include_str!("asm/bignum/mul.asm"),
include_str!("asm/bignum/shr.asm"),
include_str!("asm/bignum/util.asm"),
include_str!("asm/core/bootloader.asm"),
include_str!("asm/core/call.asm"),
include_str!("asm/core/create.asm"),
@ -69,6 +75,7 @@ pub(crate) fn combined_kernel() -> Kernel {
include_str!("asm/main.asm"),
include_str!("asm/memory/core.asm"),
include_str!("asm/memory/memcpy.asm"),
include_str!("asm/memory/memset.asm"),
include_str!("asm/memory/metadata.asm"),
include_str!("asm/memory/packing.asm"),
include_str!("asm/memory/syscalls.asm"),

View File

@ -0,0 +1,56 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
// Replaces a with a + b, leaving b unchanged.
global add_bignum:
// stack: len, a_start_loc, b_start_loc, retdest
PUSH 0
// stack: carry=0, i=len, a_start_loc, b_start_loc, retdest
add_loop:
// stack: carry, i, a_cur_loc, b_cur_loc, retdest
DUP4
%mload_kernel_general
// stack: b[cur], carry, i, a_cur_loc, b_cur_loc, retdest
DUP4
%mload_kernel_general
// stack: a[cur], b[cur], carry, i, a_cur_loc, b_cur_loc, retdest
ADD
ADD
// stack: a[cur] + b[cur] + carry, i, a_cur_loc, b_cur_loc, retdest
DUP1
// stack: a[cur] + b[cur] + carry, a[cur] + b[cur] + carry, i, a_cur_loc, b_cur_loc, retdest
%shr_const(128)
// stack: (a[cur] + b[cur] + carry) // 2^128, a[cur] + b[cur] + carry, i, a_cur_loc, b_cur_loc, retdest
SWAP1
// stack: a[cur] + b[cur] + carry, (a[cur] + b[cur] + carry) // 2^128, i, a_cur_loc, b_cur_loc, retdest
%shl_const(128)
%shr_const(128)
// stack: c[cur] = (a[cur] + b[cur] + carry) % 2^128, carry_new = (a[cur] + b[cur] + carry) // 2^128, i, a_cur_loc, b_cur_loc, retdest
DUP4
// stack: a_cur_loc, c[cur], carry_new, i, a_cur_loc, b_cur_loc, retdest
%mstore_kernel_general
// stack: carry_new, i, a_cur_loc, b_cur_loc, retdest
%stack (c, i, a, b) -> (a, b, c, i)
// stack: a_cur_loc, b_cur_loc, carry_new, i, retdest
%increment
// stack: a_cur_loc + 1, b_cur_loc, carry_new, i, retdest
SWAP1
// stack: b_cur_loc, a_cur_loc + 1, carry_new, i, retdest
%increment
// stack: b_cur_loc + 1, a_cur_loc + 1, carry_new, i, retdest
%stack (b, a, c, i) -> (i, c, a, b)
// stack: i, carry_new, a_cur_loc + 1, b_cur_loc + 1, retdest
%decrement
// stack: i - 1, carry_new, a_cur_loc + 1, b_cur_loc + 1, retdest
SWAP1
// stack: carry_new, i - 1, a_cur_loc + 1, b_cur_loc + 1, retdest
DUP2
// stack: i - 1, carry_new, i - 1, a_cur_loc + 1, b_cur_loc + 1, retdest
%jumpi(add_loop)
add_end:
// stack: carry_new, i - 1, a_cur_loc + 1, b_cur_loc + 1, retdest
%stack (c, i, a, b) -> (c)
// stack: carry_new, retdest
SWAP1
// stack: retdest, carry_new
JUMP

View File

@ -0,0 +1,80 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
// Returns a >= b.
global ge_bignum:
// stack: len, a_start_loc, b_start_loc, retdest
SWAP1
// stack: a_start_loc, len, b_start_loc, retdest
DUP2
// stack: len, a_start_loc, len, b_start_loc, retdest
ADD
%decrement
// stack: a_end_loc, len, b_start_loc, retdest
SWAP2
// stack: b_start_loc, len, a_end_loc, retdest
DUP2
// stack: len, b_start_loc, len, a_end_loc, retdest
ADD
%decrement
// stack: b_end_loc, len, a_end_loc, retdest
%stack (b, l, a) -> (l, a, b)
// stack: len, a_end_loc, b_end_loc, retdest
%decrement
ge_loop:
// stack: i, a_i_loc, b_i_loc, retdest
DUP3
DUP3
// stack: a_i_loc, b_i_loc, i, a_i_loc, b_i_loc, retdest
%mload_kernel_general
SWAP1
%mload_kernel_general
SWAP1
// stack: a[i], b[i], i, a_i_loc, b_i_loc, retdest
%stack (vals: 2) -> (vals, vals)
GT
%jumpi(greater)
// stack: a[i], b[i], i, a_i_loc, b_i_loc, retdest
LT
%jumpi(less)
// stack: i, a_i_loc, b_i_loc, retdest
DUP1
ISZERO
%jumpi(equal)
%decrement
// stack: i-1, a_i_loc, b_i_loc, retdest
SWAP1
// stack: a_i_loc, i-1, b_i_loc, retdest
%decrement
// stack: a_i_loc_new, i-1, b_i_loc, retdest
SWAP2
// stack: b_i_loc, i-1, a_i_loc_new, retdest
%decrement
// stack: b_i_loc_new, i-1, a_i_loc_new, retdest
%stack (b, i, a) -> (i, a, b)
// stack: i-1, a_i_loc_new, b_i_loc_new, retdest
%jump(ge_loop)
equal:
// stack: i, a_i_loc, b_i_loc, retdest
%pop3
// stack: retdest
PUSH 3
// stack: 3, retdest
SWAP1
JUMP
greater:
// stack: a[i], b[i], i, a_i_loc, b_i_loc, retdest
%pop5
// stack: retdest
PUSH 1
// stack: 1, retdest
SWAP1
JUMP
less:
// stack: i, a_i_loc, b_i_loc, retdest
%pop3
// stack: retdest
PUSH 0
// stack: 0, retdest
SWAP1
JUMP

View File

@ -0,0 +1,36 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
global iszero_bignum:
// stack: len, start_loc, retdest
DUP2
// stack: start_loc, len, start_loc, retdest
ADD
// stack: end_loc, start_loc, retdest
SWAP1
// stack: cur_loc=start_loc, end_loc, retdest
iszero_loop:
// stack: cur_loc, end_loc, retdest
DUP1
// stack: cur_loc, cur_loc, end_loc, retdest
%mload_kernel_general
// stack: cur_val, cur_loc, end_loc, retdest
%jumpi(neqzero)
// stack: cur_loc, end_loc, retdest
%increment
// stack: cur_loc + 1, end_loc, retdest
%stack (vals: 2) -> (vals, vals)
// stack: cur_loc + 1, end_loc, cur_loc + 1, end_loc, retdest
EQ
%jumpi(eqzero)
%jump(iszero_loop)
neqzero:
// stack: cur_loc, end_loc, retdest
%stack (vals: 2, retdest) -> (retdest, 0)
// stack: retdest, 0
JUMP
eqzero:
// stack: cur_loc, end_loc, retdest
%stack (vals: 2, retdest) -> (retdest, 1)
// stack: retdest, 1
JUMP

View File

@ -0,0 +1,202 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
// Multiplies a bignum by a single-limb value. Resulting limbs may be larger than 128 bits.
// This is a naive multiplication algorithm (BasecaseMultiply from Modern Computer Arithmetic).
mul_bignum_helper:
// stack: len, start_loc, val, retdest
DUP2
// stack: start_loc, len, start_loc, val, retdest
ADD
// stack: end_loc, start_loc, val, retdest
SWAP2
SWAP1
// stack: i=start_loc, val, end_loc, retdest
mul_helper_loop:
// stack: i, val, end_loc, retdest
DUP1
// stack: i, i, val, end_loc, retdest
%mload_kernel_general
// stack: bignum[i], i, val, end_loc, retdest
DUP3
// stack: val, bignum[i], i, val, end_loc, retdest
MUL
// stack: val * bignum[i], i, val, end_loc, retdest
DUP2
// stack: i, val * bignum[i], i, val, end_loc, retdest
%mstore_kernel_general
// stack: i, val, end_loc, retdest
%increment
// stack: i + 1, val, end_loc, retdest
DUP1
// stack: i + 1, i + 1, val, end_loc, retdest
DUP4
// stack: end_loc, i + 1, i + 1, val, end_loc, retdest
GT
%jumpi(mul_helper_loop)
// stack: n = 0, i, val, retdest
%pop3
// stack: retdest
JUMP
// Reduces a bignum with limbs possibly greater than 128 bits to a normalized bignum with length len + 1.
// Used after `mul_bignum_helper` to complete the process of multiplying a bignum by a constant value.
mul_bignum_reduce_helper:
// stack: len, start_loc, retdest
DUP2
// stack: start_loc, len, start_loc, retdest
ADD
// stack: end_loc, start_loc, retdest
SWAP1
// stack: i=start_loc, end_loc, retdest
reduce_loop:
// stack: i, end_loc, retdest
DUP1
// stack: i, i, end_loc, retdest
%mload_kernel_general
// stack: bignum[i], i, end_loc, retdest
DUP1
// stack: bignum[i], bignum[i], i, end_loc, retdest
%shl_const(128)
%shr_const(128)
// stack: bignum[i] % 2^128, bignum[i], i, end_loc, retdest
SWAP1
// stack: bignum[i], bignum[i] % 2^128, i, end_loc, retdest
%shr_const(128)
// stack: bignum[i] // 2^128, bignum[i] % 2^128, i, end_loc, retdest
DUP3
// stack: i, bignum[i] // 2^128, bignum[i] % 2^128, i, end_loc, retdest
%increment
// stack: i+1, bignum[i] // 2^128, bignum[i] % 2^128, i, end_loc, retdest
SWAP1
// stack: bignum[i] // 2^128, i+1, bignum[i] % 2^128, i, end_loc, retdest
DUP2
// stack: i+1, bignum[i] // 2^128, i+1, bignum[i] % 2^128, i, end_loc, retdest
%mload_kernel_general
// stack: bignum[i+1], bignum[i] // 2^128, i+1, bignum[i] % 2^128, i, end_loc, retdest
ADD
// stack: bignum[i+1] + bignum[i] // 2^128, i+1, bignum[i] % 2^128, i, end_loc, retdest
SWAP1
// stack: i+1, bignum[i+1] + bignum[i] // 2^128, bignum[i] % 2^128, i, end_loc, retdest
%mstore_kernel_general
// stack: bignum[i] % 2^128, i, end_loc, retdest
DUP2
// stack: i, bignum[i] % 2^128, i, end_loc, retdest
%mstore_kernel_general
// stack: i, end_loc, retdest
%increment
// stack: i + 1, end_loc, retdest
%stack (vals: 2) -> (vals, vals)
// stack: i + 1, end_loc, i + 1, end_loc, retdest
EQ
%jumpi(reduce_loop)
reduce_end:
// stack: n = 0, i, retdest
%pop2
// stack: retdest
JUMP
// Stores a * b in output_loc, leaving a and b unchanged.
// Both a and b have length len; a * b will have length 2 * len.
// Both output_loc and scratch_space must be initialized as zeroes (2 * len of them in the case
// of output_loc, and len + 1 of them in the case of scratch_space).
global mul_bignum:
// stack: len, a_start_loc, b_start_loc, output_loc, scratch_space, retdest
DUP1
// stack: len, n=len, a_start_loc, bi=b_start_loc, output_cur=output_loc, scratch_space, retdest
mul_loop:
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
// Copy a from a_start_loc into scratch_space.
DUP1
// stack: len, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP4
// stack: a_start_loc, len, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP8
// stack: scratch_space, a_start_loc, len, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%memcpy_kernel_general
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
// Insert a zero into scratch_space[len].
DUP6
// stack: scratch_space, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP2
// stack: len, scratch_space, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
ADD
// stack: scratch_space + len, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
PUSH 0
SWAP1
// stack: scratch_space + len, 0, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%mstore_kernel_general
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
// Use scratch_space to multiply a by b[i].
PUSH mul_return_1
// stack: mul_return_1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP5
// stack: bi, mul_return_1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%mload_kernel_general
// stack: b[i], mul_return_1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP8
// stack: scratch_space, b[i], mul_return_1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP4
// stack: len, scratch_space, b[i], mul_return_1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%jump(mul_bignum_helper)
mul_return_1:
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
PUSH mul_return_2
// stack: mul_return_2, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP7
// stack: scratch_space, mul_return_2, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP3
// stack: len, scratch_space, mul_return_2, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%jump(mul_bignum_reduce_helper)
mul_return_2:
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
// Add the multiplication result into output_cur = output_len[i].
PUSH mul_return_3
// stack: mul_return_3, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP7
// stack: scratch_space, mul_return_3, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP7
// stack: output_cur, scratch_space, mul_return_3, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP4
// stack: len, output_cur, scratch_space, mul_return_3, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%increment
// stack: len + 1, output_cur, scratch_space, mul_return_3, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%jump(add_bignum)
mul_return_3:
// stack: carry, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP6
// stack: output_cur, carry, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP3
// stack: len, output_cur, carry, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
ADD
// stack: output_cur + len, carry, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%increment
// stack: output_cur + len + 1, carry, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%mstore_kernel_general
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
// Increment output_cur and b[i], decrement n, and check if we're done.
DUP5
// stack: output_cur, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%increment
// stack: output_cur+1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP5
%increment
// stack: bi+1, output_cur+1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP5
// stack: a_start_loc, bi+1, output_cur+1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
DUP5
%decrement
// stack: n-1, a_start_loc, bi+1, output_cur+1, len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%stack (new: 4, len, old: 4) -> (len, new)
// stack: len, n-1, a_start_loc, bi+1, output_cur+1, scratch_space, retdest
DUP2
// stack: n-1, len, n-1, a_start_loc, bi+1, output_cur+1, scratch_space, retdest
%jumpi(mul_loop)
// stack: len, n, a_start_loc, bi, output_cur, scratch_space, retdest
%pop6
JUMP

View File

@ -0,0 +1,56 @@
// Arithmetic on little-endian integers represented with 128-bit limbs.
// All integers must be under a given length bound, and are padded with leading zeroes.
// Shifts a given bignum right by one bit (in place).
global shr_bignum:
// stack: len, start_loc, retdest
DUP2
// stack: start_loc, len, start_loc, retdest
ADD
// stack: start_loc + len, start_loc, retdest
%decrement
// stack: end_loc, start_loc, retdest
%stack (e) -> (e, 0)
// stack: i=end_loc, carry=0, start_loc, retdest
shr_loop:
// stack: i, carry, start_loc, retdest
DUP1
// stack: i, i, carry, start_loc, retdest
%mload_kernel_general
// stack: a[i], i, carry, start_loc, retdest
DUP1
// stack: a[i], a[i], i, carry, start_loc, retdest
%shr_const(1)
// stack: a[i] >> 1, a[i], i, carry, start_loc, retdest
SWAP1
// stack: a[i], a[i] >> 1, i, carry, start_loc, retdest
%mod_const(2)
// stack: new_carry = a[i] % 2, a[i] >> 1, i, carry, start_loc, retdest
SWAP3
// stack: carry, a[i] >> 1, i, new_carry, start_loc, retdest
%shl_const(127)
// stack: carry << 127, a[i] >> 1, i, new_carry, start_loc, retdest
OR
// stack: carry << 127 | a[i] >> 1, i, new_carry, start_loc, retdest
DUP2
// stack: i, carry << 127 | a[i] >> 1, i, new_carry, start_loc, retdest
%mstore_kernel_general
// stack: i, new_carry, start_loc, retdest
DUP1
// stack: i, i, new_carry, start_loc, retdest
%decrement
// stack: i-1, i, new_carry, start_loc, retdest
SWAP1
// stack: i, i-1, new_carry, start_loc, retdest
DUP4
// stack: start_loc, i, i-1, new_carry, start_loc, retdest
EQ
// stack: i == start_loc, i-1, new_carry, start_loc, retdest
ISZERO
// stack: i != start_loc, i-1, new_carry, start_loc, retdest
%jumpi(shr_loop)
shr_end:
// stack: i, new_carry, start_loc, retdest
%pop3
// stack: retdest
JUMP

View File

@ -0,0 +1,13 @@
%macro memcpy_kernel_general
// stack: dst, src, len
%stack (dst, src, len) -> (0, @SEGMENT_KERNEL_GENERAL, dst, 0, @SEGMENT_KERNEL_GENERAL, src, len, %%after)
%jump(memcpy)
%%after:
%endmacro
%macro clear_kernel_general
// stack: dst, len
%stack (dst, len) -> (0, @SEGMENT_KERNEL_GENERAL, dst, 0, len, %%after)
%jump(memset)
%%after:
%endmacro

View File

@ -0,0 +1,38 @@
// Sets `count` values to `value` at
// DST = (dst_ctx, dst_segment, dst_addr).
// This tuple definition is used for brevity in the stack comments below.
global memset:
// stack: DST, value, count, retdest
DUP5
// stack: count, DST, value, count, retdest
ISZERO
// stack: count == 0, DST, value, count, retdest
%jumpi(memset_finish)
// stack: DST, value, count, retdest
DUP4
// stack: value, DST, value, count, retdest
DUP4
DUP4
DUP4
// stack: DST, value, DST, value, count, retdest
MSTORE_GENERAL
// stack: DST, value, count, retdest
// Increment dst_addr.
SWAP2
%increment
SWAP2
// Decrement count.
SWAP4
%decrement
SWAP4
// Continue the loop.
%jump(memset)
memset_finish:
// stack: DST, value, count, retdest
%pop5
// stack: retdest
JUMP

View File

@ -50,6 +50,14 @@
%endrep
%endmacro
%macro neq
// stack: x, y
EQ
// stack: x == y
ISZERO
// stack: x != y
%endmacro
%macro and_const(c)
// stack: input, ...
PUSH $c

View File

@ -19,7 +19,11 @@ pub(crate) mod txn_fields;
pub fn evm_constants() -> HashMap<String, U256> {
let mut c = HashMap::new();
let hex_constants = EC_CONSTANTS.iter().chain(HASH_CONSTANTS.iter()).cloned();
let hex_constants = MISC_CONSTANTS
.iter()
.chain(EC_CONSTANTS.iter())
.chain(HASH_CONSTANTS.iter())
.cloned();
for (name, value) in hex_constants {
c.insert(name.into(), U256::from_big_endian(&value));
}
@ -50,6 +54,14 @@ pub fn evm_constants() -> HashMap<String, U256> {
c
}
const MISC_CONSTANTS: [(&str, [u8; 32]); 1] = [
// Base for limbs used in bignum arithmetic.
(
"BIGNUM_LIMB_BASE",
hex!("0000000000000000000000000000000100000000000000000000000000000000"),
),
];
const HASH_CONSTANTS: [(&str, [u8; 32]); 2] = [
// Hash of an empty string: keccak(b'').hex()
(

View File

@ -26,7 +26,7 @@ stack = { ^"%stack" ~ stack_placeholders ~ "->" ~ stack_replacements }
stack_placeholders = { "(" ~ (stack_placeholder ~ ("," ~ stack_placeholder)*)? ~ ")" }
stack_placeholder = { stack_block | identifier }
stack_block = { identifier ~ ":" ~ literal_decimal }
stack_replacements = { "(" ~ stack_replacement ~ ("," ~ stack_replacement)* ~ ")" }
stack_replacements = { "(" ~ (stack_replacement ~ ("," ~ stack_replacement)*)? ~ ")" }
stack_replacement = { literal | identifier | constant | macro_label | variable }
global_label_decl = ${ ^"GLOBAL " ~ identifier ~ ":" }

View File

@ -194,6 +194,12 @@ impl<'a> Interpreter<'a> {
&mut self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content
}
pub(crate) fn get_memory_segment(&self, segment: Segment) -> Vec<U256> {
self.generation_state.memory.contexts[0].segments[segment as usize]
.content
.clone()
}
pub(crate) fn get_memory_segment_bytes(&self, segment: Segment) -> Vec<u8> {
self.generation_state.memory.contexts[0].segments[segment as usize]
.content
@ -202,10 +208,22 @@ impl<'a> Interpreter<'a> {
.collect()
}
pub(crate) fn get_kernel_general_memory(&self) -> Vec<U256> {
self.get_memory_segment(Segment::KernelGeneral)
}
pub(crate) fn get_rlp_memory(&self) -> Vec<u8> {
self.get_memory_segment_bytes(Segment::RlpRaw)
}
pub(crate) fn set_memory_segment(&mut self, segment: Segment, memory: Vec<U256>) {
self.generation_state.memory.contexts[0].segments[segment as usize].content = memory;
}
pub(crate) fn set_kernel_general_memory(&mut self, memory: Vec<U256>) {
self.set_memory_segment(Segment::KernelGeneral, memory)
}
pub(crate) fn set_memory_segment_bytes(&mut self, segment: Segment, memory: Vec<u8>) {
self.generation_state.memory.contexts[0].segments[segment as usize].content =
memory.into_iter().map(U256::from).collect();

View File

@ -0,0 +1,222 @@
use anyhow::Result;
use ethereum_types::U256;
use itertools::Itertools;
use num::{BigUint, Signed};
use num_bigint::RandBigInt;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::Interpreter;
use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint};
fn pack_bignums(biguints: &[BigUint], length: usize) -> Vec<U256> {
biguints
.iter()
.flat_map(|biguint| {
biguint_to_mem_vec(biguint.clone())
.into_iter()
.pad_using(length, |_| U256::zero())
})
.collect()
}
fn gen_bignum(bit_size: usize) -> BigUint {
let mut rng = rand::thread_rng();
rng.gen_bigint(bit_size as u64).abs().to_biguint().unwrap()
}
fn bignum_len(a: &BigUint) -> usize {
(a.bits() as usize) / 128 + 1
}
fn gen_two_bignums_ordered(bit_size: usize) -> (BigUint, BigUint) {
let mut rng = rand::thread_rng();
let (a, b) = {
let a = rng.gen_bigint(bit_size as u64).abs().to_biguint().unwrap();
let b = rng.gen_bigint(bit_size as u64).abs().to_biguint().unwrap();
(a.clone().max(b.clone()), a.min(b))
};
(a, b)
}
fn prepare_bignum(bit_size: usize) -> (BigUint, U256, Vec<U256>) {
let a = gen_bignum(bit_size);
let length: U256 = bignum_len(&a).into();
let a_limbs = biguint_to_mem_vec(a.clone());
(a, length, a_limbs)
}
fn prepare_two_bignums(bit_size: usize) -> (BigUint, BigUint, U256, Vec<U256>) {
let (a, b) = gen_two_bignums_ordered(bit_size);
let length: U256 = bignum_len(&a).into();
let memory = pack_bignums(&[a.clone(), b.clone()], length.try_into().unwrap());
(a, b, length, memory)
}
#[test]
fn test_shr_bignum() -> Result<()> {
let (a, length, memory) = prepare_bignum(1000);
let halved = a >> 1;
let retdest = 0xDEADBEEFu32.into();
let shr_bignum = KERNEL.global_labels["shr_bignum"];
let a_start_loc = 0.into();
let mut initial_stack: Vec<U256> = vec![length, a_start_loc, retdest];
initial_stack.reverse();
let mut interpreter = Interpreter::new_with_kernel(shr_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory);
interpreter.run()?;
let new_memory = interpreter.get_kernel_general_memory();
let new_a = mem_vec_to_biguint(&new_memory[0..length.as_usize()]);
assert_eq!(new_a, halved);
Ok(())
}
#[test]
fn test_iszero_bignum() -> Result<()> {
let (_a, length, memory) = prepare_bignum(1000);
let retdest = 0xDEADBEEFu32.into();
let iszero_bignum = KERNEL.global_labels["iszero_bignum"];
let a_start_loc = 0.into();
// Test with a > 0.
let mut initial_stack: Vec<U256> = vec![length, a_start_loc, retdest];
initial_stack.reverse();
let mut interpreter = Interpreter::new_with_kernel(iszero_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory.clone());
interpreter.run()?;
let result = interpreter.stack()[0];
assert_eq!(result, 0.into());
let memory = vec![0.into(); memory.len()];
// Test with a == 0.
let mut initial_stack: Vec<U256> = vec![length, a_start_loc, retdest];
initial_stack.reverse();
let mut interpreter = Interpreter::new_with_kernel(iszero_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory);
interpreter.run()?;
let result = interpreter.stack()[0];
assert_eq!(result, U256::one());
Ok(())
}
#[test]
fn test_ge_bignum() -> Result<()> {
let (_a, _b, length, memory) = prepare_two_bignums(1000);
let retdest = 0xDEADBEEFu32.into();
let ge_bignum = KERNEL.global_labels["ge_bignum"];
let a_start_loc = 0.into();
let b_start_loc = length;
// Test with a > b.
let mut initial_stack: Vec<U256> = vec![length, a_start_loc, b_start_loc, retdest];
initial_stack.reverse();
let mut interpreter = Interpreter::new_with_kernel(ge_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory.clone());
interpreter.run()?;
let result = interpreter.stack()[0];
assert_eq!(result, U256::one());
// Swap a and b, to test the less-than case.
let mut initial_stack: Vec<U256> = vec![length, b_start_loc, a_start_loc, retdest];
initial_stack.reverse();
let mut interpreter = Interpreter::new_with_kernel(ge_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory);
interpreter.run()?;
let result = interpreter.stack()[0];
assert_eq!(result, 0.into());
Ok(())
}
#[test]
fn test_add_bignum() -> Result<()> {
let (a, b, length, memory) = prepare_two_bignums(1000);
// Determine expected sum.
let sum = a + b;
let expected_sum: Vec<U256> = biguint_to_mem_vec(sum);
let a_start_loc = 0.into();
let b_start_loc = length;
// Prepare stack.
let retdest = 0xDEADBEEFu32.into();
let mut initial_stack: Vec<U256> = vec![length, a_start_loc, b_start_loc, retdest];
initial_stack.reverse();
// Prepare interpreter.
let add_bignum = KERNEL.global_labels["add_bignum"];
let mut interpreter = Interpreter::new_with_kernel(add_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory);
// Run add function.
interpreter.run()?;
// Determine actual sum.
let new_memory = interpreter.get_kernel_general_memory();
let actual_sum: Vec<_> = new_memory[..expected_sum.len()].into();
// Compare.
assert_eq!(actual_sum, expected_sum);
Ok(())
}
#[test]
fn test_mul_bignum() -> Result<()> {
let (a, b, length, memory) = prepare_two_bignums(1000);
// Determine expected product.
let product = a * b;
let expected_product: Vec<U256> = biguint_to_mem_vec(product);
// Output and scratch space locations (initialized as zeroes) follow a and b in memory.
let a_start_loc = 0.into();
let b_start_loc = length;
let output_loc = length * 2;
let scratch_space = length * 4;
// Prepare stack.
let retdest = 0xDEADBEEFu32.into();
let mut initial_stack: Vec<U256> = vec![
length,
a_start_loc,
b_start_loc,
output_loc,
scratch_space,
retdest,
];
initial_stack.reverse();
// Prepare interpreter.
let mul_bignum = KERNEL.global_labels["mul_bignum"];
let mut interpreter = Interpreter::new_with_kernel(mul_bignum, initial_stack);
interpreter.set_kernel_general_memory(memory);
// Run mul function.
interpreter.run()?;
// Determine actual product.
let new_memory = interpreter.get_kernel_general_memory();
let output_location: usize = output_loc.try_into().unwrap();
let actual_product: Vec<_> =
new_memory[output_location..output_location + expected_product.len()].into();
assert_eq!(actual_product, expected_product);
Ok(())
}

View File

@ -1,5 +1,6 @@
mod account_code;
mod balance;
mod bignum;
mod core;
mod ecc;
mod exp;

View File

@ -68,7 +68,7 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
let pad_rows = self.generate_trace_rows_for_perm([0; NUM_INPUTS]);
while rows.len() < num_rows {
rows.extend(&pad_rows);
rows.extend(pad_rows.clone());
}
rows.drain(num_rows..);
rows

View File

@ -144,3 +144,41 @@ pub(crate) fn biguint_to_u256(x: BigUint) -> U256 {
let bytes = x.to_bytes_le();
U256::from_little_endian(&bytes)
}
pub(crate) fn le_limbs_to_biguint(x: &[u128]) -> BigUint {
BigUint::from_slice(
&x.iter()
.flat_map(|&a| {
[
(a % (1 << 32)) as u32,
((a >> 32) % (1 << 32)) as u32,
((a >> 64) % (1 << 32)) as u32,
((a >> 96) % (1 << 32)) as u32,
]
})
.collect::<Vec<u32>>(),
)
}
pub(crate) fn mem_vec_to_biguint(x: &[U256]) -> BigUint {
le_limbs_to_biguint(&x.iter().map(|&n| n.try_into().unwrap()).collect_vec())
}
pub(crate) fn biguint_to_le_limbs(x: BigUint) -> Vec<u128> {
let mut digits = x.to_u32_digits();
// Pad to a multiple of 8.
digits.resize((digits.len() + 7) / 8 * 8, 0);
digits
.chunks(4)
.map(|c| (c[3] as u128) << 96 | (c[2] as u128) << 64 | (c[1] as u128) << 32 | c[0] as u128)
.collect()
}
pub(crate) fn biguint_to_mem_vec(x: BigUint) -> Vec<U256> {
biguint_to_le_limbs(x)
.into_iter()
.map(|n| n.into())
.collect()
}

View File

@ -78,16 +78,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
assert!(
self.config.num_wires >= min_wires,
"To efficiently perform FRI checks with an arity of 2^{}, at least {} wires are needed. Consider reducing arity.",
max_fri_arity_bits,
min_wires
"To efficiently perform FRI checks with an arity of 2^{max_fri_arity_bits}, at least {min_wires} wires are needed. Consider reducing arity."
);
assert!(
self.config.num_routed_wires >= min_routed_wires,
"To efficiently perform FRI checks with an arity of 2^{}, at least {} routed wires are needed. Consider reducing arity.",
max_fri_arity_bits,
min_routed_wires
"To efficiently perform FRI checks with an arity of 2^{max_fri_arity_bits}, at least {min_routed_wires} routed wires are needed. Consider reducing arity."
);
}

View File

@ -38,8 +38,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let num_bits = bits.len();
assert!(
num_bits <= log_floor(F::ORDER, 2),
"{} bits may overflow the field",
num_bits
"{num_bits} bits may overflow the field"
);
if num_bits == 0 {
return self.zero();

View File

@ -361,9 +361,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let access_index = access_index_f.to_canonical_u64() as usize;
debug_assert!(
access_index < vec_size,
"Access index {} is larger than the vector size {}",
access_index,
vec_size
"Access index {access_index} is larger than the vector size {vec_size}"
);
set_local_wire(

View File

@ -136,9 +136,7 @@ impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
let log2_leaves_len = log2_strict(leaves.len());
assert!(
cap_height <= log2_leaves_len,
"cap_height={} should be at most log2(leaves.len())={}",
cap_height,
log2_leaves_len
"cap_height={cap_height} should be at most log2(leaves.len())={log2_leaves_len}"
);
let num_digests = 2 * (leaves.len() - (1 << cap_height));

View File

@ -89,8 +89,7 @@ pub(crate) fn generate_partial_witness<
assert_eq!(
remaining_generators, 0,
"{} generators weren't run",
remaining_generators,
"{remaining_generators} generators weren't run",
);
witness

View File

@ -284,8 +284,7 @@ impl<F: Field> WitnessWrite<F> for PartialWitness<F> {
if let Some(old_value) = opt_old_value {
assert_eq!(
value, old_value,
"Target {:?} was set twice with different values: {} != {}",
target, old_value, value
"Target {target:?} was set twice with different values: {old_value} != {value}"
);
}
}
@ -325,8 +324,7 @@ impl<'a, F: Field> PartitionWitness<'a, F> {
if let Some(old_value) = *rep_value {
assert_eq!(
value, old_value,
"Partition containing {:?} was set twice with different values: {} != {}",
target, old_value, value
"Partition containing {target:?} was set twice with different values: {old_value} != {value}"
);
None
} else {