From e6b5e3656f6f851587e0906521996f94616d94d4 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Fri, 7 Oct 2022 12:03:37 -0700 Subject: [PATCH 01/17] Some more uses of %increment, %decrement --- evm/src/cpu/kernel/asm/core/intrinsic_gas.asm | 2 +- evm/src/cpu/kernel/asm/core/util.asm | 2 +- evm/src/cpu/kernel/asm/memory/core.asm | 6 +++--- evm/src/cpu/kernel/asm/memory/memcpy.asm | 6 +++--- evm/src/cpu/kernel/asm/memory/packing.asm | 4 ++-- evm/src/cpu/kernel/asm/mpt/hash.asm | 6 +++--- evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm | 4 ++-- evm/src/cpu/kernel/asm/mpt/hex_prefix.asm | 8 ++++---- evm/src/cpu/kernel/asm/mpt/load.asm | 4 ++-- evm/src/cpu/kernel/asm/mpt/read.asm | 6 +++--- evm/src/cpu/kernel/asm/mpt/util.asm | 4 ++-- evm/src/cpu/kernel/asm/ripemd/memory.asm | 14 +++++++------- evm/src/cpu/kernel/asm/rlp/decode.asm | 8 ++++---- evm/src/cpu/kernel/asm/rlp/encode.asm | 12 ++++++------ evm/src/cpu/kernel/asm/rlp/read_to_memory.asm | 2 +- 15 files changed, 44 insertions(+), 44 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm b/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm index 931a6a7b..5891807c 100644 --- a/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm +++ b/evm/src/cpu/kernel/asm/core/intrinsic_gas.asm @@ -22,7 +22,7 @@ count_zeros_loop: // stack: zeros', i, retdest SWAP1 // stack: i, zeros', retdest - %add_const(1) + %increment // stack: i', zeros', retdest %jump(count_zeros_loop) diff --git a/evm/src/cpu/kernel/asm/core/util.asm b/evm/src/cpu/kernel/asm/core/util.asm index 4ceaec3b..dfacf1a2 100644 --- a/evm/src/cpu/kernel/asm/core/util.asm +++ b/evm/src/cpu/kernel/asm/core/util.asm @@ -14,7 +14,7 @@ %macro next_context_id // stack: (empty) %mload_global_metadata(@GLOBAL_METADATA_LARGEST_CONTEXT) - %add_const(1) + %increment // stack: new_ctx DUP1 %mstore_global_metadata(@GLOBAL_METADATA_LARGEST_CONTEXT) diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index c2c19811..f4bcf1f1 
100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -64,7 +64,7 @@ %shl_const(8) // stack: c_3 << 8, offset DUP2 - %add_const(1) + %increment %mload_kernel($segment) OR // stack: (c_3 << 8) | c_2, offset @@ -91,7 +91,7 @@ %mload_kernel($segment) // stack: c0 , offset DUP2 - %add_const(1) + %increment %mload_kernel($segment) %shl_const(8) OR @@ -208,7 +208,7 @@ // stack: c_2, c_1, c_0, offset DUP4 // stack: offset, c_2, c_1, c_0, offset - %add_const(1) + %increment %mstore_kernel($segment) // stack: c_1, c_0, offset DUP3 diff --git a/evm/src/cpu/kernel/asm/memory/memcpy.asm b/evm/src/cpu/kernel/asm/memory/memcpy.asm index 3feca35d..dd0569e7 100644 --- a/evm/src/cpu/kernel/asm/memory/memcpy.asm +++ b/evm/src/cpu/kernel/asm/memory/memcpy.asm @@ -28,15 +28,15 @@ global memcpy: // Increment dst_addr. SWAP2 - %add_const(1) + %increment SWAP2 // Increment src_addr. SWAP5 - %add_const(1) + %increment SWAP5 // Decrement count. SWAP6 - %sub_const(1) + %decrement SWAP6 // Continue the loop. diff --git a/evm/src/cpu/kernel/asm/memory/packing.asm b/evm/src/cpu/kernel/asm/memory/packing.asm index c8b4c468..f12c7b17 100644 --- a/evm/src/cpu/kernel/asm/memory/packing.asm +++ b/evm/src/cpu/kernel/asm/memory/packing.asm @@ -71,9 +71,9 @@ mstore_unpacking_loop: // stack: i, context, segment, offset, value, len, retdest // Increment offset. - SWAP3 %add_const(1) SWAP3 + SWAP3 %increment SWAP3 // Increment i. - %add_const(1) + %increment %jump(mstore_unpacking_loop) diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash.asm index abd436fe..ef0158e0 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash.asm @@ -82,7 +82,7 @@ global encode_node: DUP1 %mload_trie_data // stack: node_type, node_ptr, encode_value, retdest // Increment node_ptr, so it points to the node payload instead of its type. 
- SWAP1 %add_const(1) SWAP1 + SWAP1 %increment SWAP1 // stack: node_type, node_payload_ptr, encode_value, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(encode_node_empty) @@ -214,7 +214,7 @@ encode_node_extension_after_encode_child: PUSH encode_node_extension_after_hex_prefix // retdest PUSH 0 // terminated // stack: terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest - DUP5 %add_const(1) %mload_trie_data // Load the packed_nibbles field, which is at index 1. + DUP5 %increment %mload_trie_data // Load the packed_nibbles field, which is at index 1. // stack: packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest DUP6 %mload_trie_data // Load the num_nibbles field, which is at index 0. // stack: num_nibbles, packed_nibbles, terminated, encode_node_extension_after_hex_prefix, result, result_len, node_payload_ptr, retdest @@ -247,7 +247,7 @@ encode_node_leaf: PUSH encode_node_leaf_after_hex_prefix // retdest PUSH 1 // terminated // stack: terminated, encode_node_leaf_after_hex_prefix, node_payload_ptr, encode_value, retdest - DUP3 %add_const(1) %mload_trie_data // Load the packed_nibbles field, which is at index 1. + DUP3 %increment %mload_trie_data // Load the packed_nibbles field, which is at index 1. // stack: packed_nibbles, terminated, encode_node_leaf_after_hex_prefix, node_payload_ptr, encode_value, retdest DUP4 %mload_trie_data // Load the num_nibbles field, which is at index 0. 
// stack: num_nibbles, packed_nibbles, terminated, encode_node_leaf_after_hex_prefix, node_payload_ptr, encode_value, retdest diff --git a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm index 80763deb..221c0f20 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm @@ -48,7 +48,7 @@ encode_account: DUP2 %mload_trie_data // nonce = value[0] %rlp_scalar_len // stack: nonce_rlp_len, rlp_pos, value_ptr, retdest - DUP3 %add_const(1) %mload_trie_data // balance = value[1] + DUP3 %increment %mload_trie_data // balance = value[1] %rlp_scalar_len // stack: balance_rlp_len, nonce_rlp_len, rlp_pos, value_ptr, retdest PUSH 66 // storage_root and code_hash fields each take 1 + 32 bytes @@ -68,7 +68,7 @@ encode_account: // stack: nonce, rlp_pos_3, value_ptr, retdest SWAP1 %encode_rlp_scalar // stack: rlp_pos_4, value_ptr, retdest - DUP2 %add_const(1) %mload_trie_data // balance = value[1] + DUP2 %increment %mload_trie_data // balance = value[1] // stack: balance, rlp_pos_4, value_ptr, retdest SWAP1 %encode_rlp_scalar // stack: rlp_pos_5, value_ptr, retdest diff --git a/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm b/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm index 72ac18cc..b7a3073b 100644 --- a/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm +++ b/evm/src/cpu/kernel/asm/mpt/hex_prefix.asm @@ -15,7 +15,7 @@ global hex_prefix_rlp: // Compute the length of the hex-prefix string, in bytes: // hp_len = num_nibbles / 2 + 1 = i + 1 - DUP1 %add_const(1) + DUP1 %increment // stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest // Write the RLP header. 
@@ -35,7 +35,7 @@ rlp_header_medium: %mstore_rlp // rlp_pos += 1 - SWAP2 %add_const(1) SWAP2 + SWAP2 %increment SWAP2 %jump(start_loop) @@ -49,7 +49,7 @@ rlp_header_large: %mstore_rlp DUP1 // value = hp_len - DUP4 %add_const(1) // offset = rlp_pos + 1 + DUP4 %increment // offset = rlp_pos + 1 %mstore_rlp // rlp_pos += 2 @@ -74,7 +74,7 @@ loop: %mstore_rlp // stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest - %sub_const(1) + %decrement SWAP4 %shr_const(8) SWAP4 // packed_nibbles >>= 8 %jump(loop) diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm index f072f202..f37e94ba 100644 --- a/evm/src/cpu/kernel/asm/mpt/load.asm +++ b/evm/src/cpu/kernel/asm/mpt/load.asm @@ -111,7 +111,7 @@ load_mpt_extension: // one element, appending our child pointer. Thus our child node will start // at i + 1. So we will set our child pointer to i + 1. %get_trie_data_size - %add_const(1) + %increment %append_to_trie_data // stack: retdest @@ -172,7 +172,7 @@ load_mpt_digest: // stack: leaf_part, leaf_len %append_to_trie_data // stack: leaf_len - %sub_const(1) + %decrement // stack: leaf_len' %jump(%%loop) %%finish: diff --git a/evm/src/cpu/kernel/asm/mpt/read.asm b/evm/src/cpu/kernel/asm/mpt/read.asm index f952f49a..32ab8f7f 100644 --- a/evm/src/cpu/kernel/asm/mpt/read.asm +++ b/evm/src/cpu/kernel/asm/mpt/read.asm @@ -31,7 +31,7 @@ global mpt_read: %mload_trie_data // stack: node_type, node_ptr, num_nibbles, key, retdest // Increment node_ptr, so it points to the node payload instead of its type. 
- SWAP1 %add_const(1) SWAP1 + SWAP1 %increment SWAP1 // stack: node_type, node_payload_ptr, num_nibbles, key, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_read_empty) @@ -103,7 +103,7 @@ mpt_read_extension: %mul_const(4) SHR // key_part = key >> (future_nibbles * 4) DUP1 // stack: key_part, key_part, future_nibbles, key, node_payload_ptr, retdest - DUP5 %add_const(1) %mload_trie_data + DUP5 %increment %mload_trie_data // stack: node_key, key_part, key_part, future_nibbles, key, node_payload_ptr, retdest EQ // does the first part of our key match the node's key? %jumpi(mpt_read_extension_found) @@ -131,7 +131,7 @@ mpt_read_leaf: // stack: node_payload_ptr, num_nibbles, key, retdest DUP1 %mload_trie_data // stack: node_num_nibbles, node_payload_ptr, num_nibbles, key, retdest - DUP2 %add_const(1) %mload_trie_data + DUP2 %increment %mload_trie_data // stack: node_key, node_num_nibbles, node_payload_ptr, num_nibbles, key, retdest SWAP3 // stack: num_nibbles, node_num_nibbles, node_payload_ptr, node_key, key, retdest diff --git a/evm/src/cpu/kernel/asm/mpt/util.asm b/evm/src/cpu/kernel/asm/mpt/util.asm index 0e0006d3..19cad943 100644 --- a/evm/src/cpu/kernel/asm/mpt/util.asm +++ b/evm/src/cpu/kernel/asm/mpt/util.asm @@ -28,7 +28,7 @@ %get_trie_data_size // stack: trie_data_size, value DUP1 - %add_const(1) + %increment // stack: trie_data_size', trie_data_size, value %set_trie_data_size // stack: trie_data_size, value @@ -45,7 +45,7 @@ // return (first_nibble, num_nibbles, key) %macro split_first_nibble // stack: num_nibbles, key - %sub_const(1) // num_nibbles -= 1 + %decrement // num_nibbles -= 1 // stack: num_nibbles, key DUP2 // stack: key, num_nibbles, key diff --git a/evm/src/cpu/kernel/asm/ripemd/memory.asm b/evm/src/cpu/kernel/asm/ripemd/memory.asm index 5d0266bd..e3b7cbe6 100644 --- a/evm/src/cpu/kernel/asm/ripemd/memory.asm +++ b/evm/src/cpu/kernel/asm/ripemd/memory.asm @@ -44,7 +44,7 @@ store_input_stack: // stack: offset, byte, rem, length, REM_INP 
%mstore_kernel_general // stack: rem, length, REM_INP - %sub_const(1) + %decrement DUP1 // stack: rem - 1, rem - 1, length, REM_INP %jumpi(store_input_stack) @@ -66,10 +66,10 @@ store_input: // stack: offset, byte, rem , ADDR , length %mstore_kernel_general // stack: rem , ADDR , length - %sub_const(1) + %decrement // stack: rem-1, ADDR , length SWAP3 - %add_const(1) + %increment SWAP3 // stack: rem-1, ADDR+1, length DUP1 @@ -90,12 +90,12 @@ global buffer_update: // stack: get, set, get , set , times , retdest %mupdate_kernel_general // stack: get , set , times , retdest - %add_const(1) + %increment SWAP1 - %add_const(1) + %increment SWAP1 SWAP2 - %sub_const(1) + %decrement SWAP2 // stack: get+1, set+1, times-1, retdest DUP3 @@ -112,7 +112,7 @@ global buffer_update: // stack: offset = N-i, 0, i %mstore_kernel_general // stack: i - %sub_const(1) + %decrement DUP1 // stack: i-1, i-1 %jumpi($label) diff --git a/evm/src/cpu/kernel/asm/rlp/decode.asm b/evm/src/cpu/kernel/asm/rlp/decode.asm index 5749aee7..182354c4 100644 --- a/evm/src/cpu/kernel/asm/rlp/decode.asm +++ b/evm/src/cpu/kernel/asm/rlp/decode.asm @@ -36,7 +36,7 @@ decode_rlp_string_len_medium: %sub_const(0x80) // stack: len, pos, retdest SWAP1 - %add_const(1) + %increment // stack: pos', len, retdest %stack (pos, len, retdest) -> (retdest, pos, len) JUMP @@ -47,7 +47,7 @@ decode_rlp_string_len_large: %sub_const(0xb7) // stack: len_of_len, pos, retdest SWAP1 - %add_const(1) + %increment // stack: pos', len_of_len, retdest %jump(decode_int_given_len) @@ -92,7 +92,7 @@ global decode_rlp_list_len: %mload_current(@SEGMENT_RLP_RAW) // stack: first_byte, pos, retdest SWAP1 - %add_const(1) // increment pos + %increment // increment pos SWAP1 // stack: first_byte, pos', retdest // If first_byte is >= 0xf8, it's a > 55 byte list, and @@ -157,7 +157,7 @@ decode_int_given_len_loop: // stack: acc', pos, end_pos, retdest // Increment pos. 
SWAP1 - %add_const(1) + %increment SWAP1 // stack: acc', pos', end_pos, retdest %jump(decode_int_given_len_loop) diff --git a/evm/src/cpu/kernel/asm/rlp/encode.asm b/evm/src/cpu/kernel/asm/rlp/encode.asm index 851ad3cf..dada98b0 100644 --- a/evm/src/cpu/kernel/asm/rlp/encode.asm +++ b/evm/src/cpu/kernel/asm/rlp/encode.asm @@ -14,7 +14,7 @@ global encode_rlp_scalar: // stack: pos, scalar, pos, retdest %mstore_rlp // stack: pos, retdest - %add_const(1) + %increment // stack: pos', retdest SWAP1 JUMP @@ -76,7 +76,7 @@ encode_rlp_fixed: %mstore_rlp // stack: len, pos, string, retdest SWAP1 - %add_const(1) // increment pos + %increment // increment pos // stack: pos, len, string, retdest %stack (pos, len, string) -> (pos, string, len, encode_rlp_fixed_finish) // stack: context, segment, pos, string, len, encode_rlp_fixed_finish, retdest @@ -159,7 +159,7 @@ global encode_rlp_list_prefix: // stack: pos, prefix, pos, retdest %mstore_rlp // stack: pos, retdest - %add_const(1) + %increment SWAP1 JUMP encode_rlp_list_prefix_large: @@ -172,7 +172,7 @@ encode_rlp_list_prefix_large: DUP3 // pos %mstore_rlp // stack: len_of_len, pos, payload_len, retdest - SWAP1 %add_const(1) + SWAP1 %increment // stack: pos', len_of_len, payload_len, retdest %stack (pos, len_of_len, payload_len) -> (pos, payload_len, len_of_len, @@ -231,7 +231,7 @@ prepend_rlp_list_prefix_big: SUB // stack: start_pos, len_of_len, payload_len, end_pos, retdest DUP2 %add_const(0xf7) DUP2 %mstore_rlp // rlp[start_pos] = 0xf7 + len_of_len - DUP1 %add_const(1) // start_len_pos = start_pos + 1 + DUP1 %increment // start_len_pos = start_pos + 1 %stack (start_len_pos, start_pos, len_of_len, payload_len, end_pos, retdest) -> (start_len_pos, payload_len, len_of_len, prepend_rlp_list_prefix_big_done_writing_len, @@ -269,7 +269,7 @@ prepend_rlp_list_prefix_big_done_writing_len: // stack: scalar %num_bytes // stack: scalar_bytes - %add_const(1) // Account for the length prefix. + %increment // Account for the length prefix. 
// stack: rlp_len %%finish: %endmacro diff --git a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm index 189edd1d..5d8cbd17 100644 --- a/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm +++ b/evm/src/cpu/kernel/asm/rlp/read_to_memory.asm @@ -25,7 +25,7 @@ read_rlp_to_memory_loop: // stack: pos, byte, pos, len, retdest %mstore_current(@SEGMENT_RLP_RAW) // stack: pos, len, retdest - %add_const(1) + %increment // stack: pos', len, retdest %jump(read_rlp_to_memory_loop) From 817156cd479612d7c16369e462550c8ecd6e1894 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sat, 8 Oct 2022 13:23:00 -0700 Subject: [PATCH 02/17] Begin MPT insert --- evm/src/cpu/kernel/aggregator.rs | 1 + evm/src/cpu/kernel/asm/mpt/delete.asm | 6 ++++ evm/src/cpu/kernel/asm/mpt/read.asm | 2 +- evm/src/cpu/kernel/asm/mpt/write.asm | 50 +++++++++++++++++++++++++-- 4 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 evm/src/cpu/kernel/asm/mpt/delete.asm diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 0d94c86f..2dcbd41c 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -39,6 +39,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/memory/metadata.asm"), include_str!("asm/memory/packing.asm"), include_str!("asm/memory/txn_fields.asm"), + include_str!("asm/mpt/delete.asm"), include_str!("asm/mpt/hash.asm"), include_str!("asm/mpt/hash_trie_specific.asm"), include_str!("asm/mpt/hex_prefix.asm"), diff --git a/evm/src/cpu/kernel/asm/mpt/delete.asm b/evm/src/cpu/kernel/asm/mpt/delete.asm new file mode 100644 index 00000000..3e0b8afe --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/delete.asm @@ -0,0 +1,6 @@ +// Return a copy of the given node with the given key deleted. 
+// +// Pre stack: node_ptr, num_nibbles, key, retdest +// Post stack: updated_node_ptr +global mpt_delete: + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/read.asm b/evm/src/cpu/kernel/asm/mpt/read.asm index 32ab8f7f..c6f36204 100644 --- a/evm/src/cpu/kernel/asm/mpt/read.asm +++ b/evm/src/cpu/kernel/asm/mpt/read.asm @@ -39,7 +39,7 @@ global mpt_read: DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_read_extension) DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_read_leaf) - // There's still the MPT_NODE_HASH case, but if we hit a digest node, + // There's still the MPT_NODE_HASH case, but if we hit a hash node, // it means the prover failed to provide necessary Merkle data, so panic. PANIC diff --git a/evm/src/cpu/kernel/asm/mpt/write.asm b/evm/src/cpu/kernel/asm/mpt/write.asm index 5b59d016..eab51d3e 100644 --- a/evm/src/cpu/kernel/asm/mpt/write.asm +++ b/evm/src/cpu/kernel/asm/mpt/write.asm @@ -1,3 +1,47 @@ -global mpt_write: - // stack: node_ptr, num_nibbles, key, retdest - // TODO +// TODO: Need a special case for deleting, if value = ''. +// Or canonicalize once, before final hashing, to remove empty leaves etc. + +// Return a copy of the given node, with the given key set to the given value. +// +// Pre stack: node_ptr, num_nibbles, key, value_ptr, retdest +// Post stack: updated_node_ptr +global mpt_insert: + // stack: node_ptr, num_nibbles, key, value_ptr, retdest + DUP1 %mload_trie_data + // stack: node_type, node_ptr, num_nibbles, key, value_ptr, retdest + // Increment node_ptr, so it points to the node payload instead of its type. 
+ SWAP1 %increment SWAP1 + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_insert_empty) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(mpt_insert_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_insert_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_insert_leaf) + + // There's still the MPT_NODE_HASH case, but if we hit a hash node, + // it means the prover failed to provide necessary Merkle data, so panic. + PANIC + +mpt_insert_empty: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + PANIC // TODO + +mpt_insert_branch: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + PANIC // TODO + +mpt_insert_extension: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + PANIC // TODO + +mpt_insert_leaf: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + PANIC // TODO From 8ee7265863009917c2e509abd76194eb78a707ca Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sat, 8 Oct 2022 13:51:52 -0700 Subject: [PATCH 03/17] Tweak MPT value storage --- evm/src/cpu/kernel/asm/mpt/hash.asm | 7 ++++++- evm/src/cpu/kernel/asm/mpt/load.asm | 12 +++++++++++- evm/src/cpu/kernel/asm/mpt/read.asm | 14 +++++++++----- evm/src/cpu/kernel/tests/mpt/load.rs | 1 + evm/src/cpu/kernel/tests/mpt/read.rs | 11 ++++++----- 5 files changed, 33 insertions(+), 12 deletions(-) diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash.asm index ef0158e0..511e9d18 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash.asm @@ -137,6 +137,8 @@ encode_node_branch: // stack: 
rlp_pos', node_payload_ptr, encode_value, retdest SWAP1 %add_const(16) + // stack: value_ptr_ptr, rlp_pos', encode_value, retdest + %mload_trie_data // stack: value_len_ptr, rlp_pos', encode_value, retdest DUP1 %mload_trie_data // stack: value_len, value_len_ptr, rlp_pos', encode_value, retdest @@ -257,7 +259,10 @@ encode_node_leaf: encode_node_leaf_after_hex_prefix: // stack: rlp_pos, node_payload_ptr, encode_value, retdest SWAP1 - %add_const(3) // The value starts at index 3, after num_nibbles, packed_nibbles, and value_len. + %add_const(2) // The value pointer is at index 2, after num_nibbles and packed_nibbles. + // stack: value_ptr_ptr, rlp_pos, encode_value, retdest + %mload_trie_data + %increment // skip over length prefix // stack: value_ptr, rlp_pos, encode_value, retdest %stack (value_ptr, rlp_pos, encode_value, retdest) -> (encode_value, rlp_pos, value_ptr, encode_node_leaf_after_encode_value, retdest) diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm index f37e94ba..62909f2d 100644 --- a/evm/src/cpu/kernel/asm/mpt/load.asm +++ b/evm/src/cpu/kernel/asm/mpt/load.asm @@ -76,9 +76,14 @@ load_mpt_branch: %get_trie_data_size // stack: ptr_children, retdest DUP1 %add_const(16) - // stack: ptr_leaf, ptr_children, retdest + // stack: ptr_value, ptr_children, retdest %set_trie_data_size // stack: ptr_children, retdest + // We need to append a pointer to where the value will live. + // %load_leaf_value will append the value just after this pointer; + // we add 1 to account for the pointer itself. + %get_trie_data_size %increment %append_to_trie_data + // stack: ptr_children, retdest %load_leaf_value // Load the 16 children. @@ -128,6 +133,11 @@ load_mpt_leaf: PROVER_INPUT(mpt) // read packed_nibbles %append_to_trie_data // stack: retdest + // We need to append a pointer to where the value will live. + // %load_leaf_value will append the value just after this pointer; + // we add 1 to account for the pointer itself.
+ %get_trie_data_size %increment %append_to_trie_data + // stack: retdest %load_leaf_value // stack: retdest JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/read.asm b/evm/src/cpu/kernel/asm/mpt/read.asm index c6f36204..dae97336 100644 --- a/evm/src/cpu/kernel/asm/mpt/read.asm +++ b/evm/src/cpu/kernel/asm/mpt/read.asm @@ -1,6 +1,6 @@ -// Given an address, return a pointer to the associated account data, which -// consists of four words (nonce, balance, storage_root, code_hash), in the -// state trie. Returns 0 if the address is not found. +// Given an address, return a pointer to the associated (length-prefixed) +// account data, which consists of four words (nonce, balance, storage_root, +// code_hash), in the state trie. Returns 0 if the address is not found. global mpt_read_state_trie: // stack: addr, retdest // The key is the hash of the address. Since KECCAK_GENERAL takes input from @@ -24,7 +24,7 @@ mpt_read_state_trie_after_mstore: // - the key, as a U256 // - the number of nibbles in the key (should start at 64) // -// This function returns a pointer to the leaf, or 0 if the key is not found. +// This function returns a pointer to the length-prefixed leaf, or 0 if the key is not found. global mpt_read: // stack: node_ptr, num_nibbles, key, retdest DUP1 @@ -75,6 +75,8 @@ mpt_read_branch_end_of_key: %stack (node_payload_ptr, num_nibbles, key, retdest) -> (node_payload_ptr, retdest) // stack: node_payload_ptr, retdest %add_const(16) // skip over the 16 child nodes + // stack: value_ptr_ptr, retdest + %mload_trie_data // stack: value_len_ptr, retdest DUP1 %mload_trie_data // stack: value_len, value_len_ptr, retdest @@ -147,7 +149,9 @@ mpt_read_leaf: JUMP mpt_read_leaf_found: // stack: node_payload_ptr, retdest - %add_const(3) // The value is located after num_nibbles, the key, and the value length. + %add_const(2) // The value pointer is located after num_nibbles and the key. 
+ // stack: value_ptr_ptr, retdest + %mload_trie_data // stack: value_ptr, retdest SWAP1 JUMP diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index 3af39e30..ca4f7071 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -48,6 +48,7 @@ fn load_all_mpts() -> Result<()> { type_leaf, 3.into(), // 3 nibbles 0xDEF.into(), // key part + 9.into(), // value pointer 4.into(), // value length account.nonce, account.balance, diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index c45a6b60..06d89ff6 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -44,11 +44,12 @@ fn mpt_read() -> Result<()> { assert_eq!(interpreter.stack().len(), 1); let result_ptr = interpreter.stack()[0].as_usize(); - let result = &interpreter.get_trie_data()[result_ptr..][..4]; - assert_eq!(result[0], account.nonce); - assert_eq!(result[1], account.balance); - assert_eq!(result[2], account.storage_root.into_uint()); - assert_eq!(result[3], account.code_hash.into_uint()); + let result = &interpreter.get_trie_data()[result_ptr..][..5]; + assert_eq!(result[0], 4.into()); + assert_eq!(result[1], account.nonce); + assert_eq!(result[2], account.balance); + assert_eq!(result[3], account.storage_root.into_uint()); + assert_eq!(result[4], account.code_hash.into_uint()); Ok(()) } From 443a07000390bc0957321b1aaaa8eb48a7443b00 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sat, 8 Oct 2022 13:59:02 -0700 Subject: [PATCH 04/17] Clippy fix --- evm/src/cpu/kernel/opcodes.rs | 2 +- evm/src/cpu/kernel/tests/ripemd.rs | 2 +- field/src/types.rs | 2 +- plonky2/src/gadgets/arithmetic_extension.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index 2325c53a..c5133050 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ b/evm/src/cpu/kernel/opcodes.rs @@ -2,7 +2,7 @@ pub(crate) fn 
get_push_opcode(n: u8) -> u8 { assert!(n > 0); assert!(n <= 32); - 0x60 + (n as u8 - 1) + 0x60 + n - 1 } /// The opcode of a standard instruction (not a `PUSH`). diff --git a/evm/src/cpu/kernel/tests/ripemd.rs b/evm/src/cpu/kernel/tests/ripemd.rs index 6123c336..305548ec 100644 --- a/evm/src/cpu/kernel/tests/ripemd.rs +++ b/evm/src/cpu/kernel/tests/ripemd.rs @@ -46,7 +46,7 @@ fn test_ripemd_reference() -> Result<()> { let kernel = combined_kernel(); let initial_offset = kernel.global_labels["ripemd_stack"]; - let initial_stack: Vec = input.iter().map(|&x| U256::from(x as u32)).rev().collect(); + let initial_stack: Vec = input.iter().map(|&x| U256::from(x)).rev().collect(); let final_stack: Vec = run_with_kernel(&kernel, initial_offset, initial_stack)? .stack() .to_vec(); diff --git a/field/src/types.rs b/field/src/types.rs index b112fde2..545f90c5 100644 --- a/field/src/types.rs +++ b/field/src/types.rs @@ -455,7 +455,7 @@ pub trait PrimeField: Field { let mut x = w * *self; let mut b = x * w; - let mut v = Self::TWO_ADICITY as usize; + let mut v = Self::TWO_ADICITY; while !b.is_one() { let mut k = 0usize; diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index 23caeac1..23c401b8 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -443,7 +443,7 @@ impl, const D: usize> CircuitBuilder { let mut current = base; let mut product = self.one_extension(); - for j in 0..bits_u64(exponent as u64) { + for j in 0..bits_u64(exponent) { if j != 0 { current = self.square_extension(current); } From 6bb1ad94e8acc9fb144cf8562f4c9b4d0bb74b48 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sat, 8 Oct 2022 15:09:07 -0700 Subject: [PATCH 05/17] MPT insert logic, part 1 --- evm/src/cpu/kernel/aggregator.rs | 3 +- .../cpu/kernel/asm/mpt/hash_trie_specific.asm | 2 +- evm/src/cpu/kernel/asm/mpt/insert.asm | 118 +++++++++++ .../kernel/asm/mpt/insert_trie_specific.asm | 14 ++ 
evm/src/cpu/kernel/asm/mpt/load.asm | 159 +++++++------- evm/src/cpu/kernel/asm/mpt/write.asm | 47 ----- evm/src/cpu/kernel/interpreter.rs | 11 +- evm/src/cpu/kernel/tests/mpt/insert.rs | 173 ++++++++++++++++ evm/src/cpu/kernel/tests/mpt/load.rs | 194 ++++++++++++++++-- evm/src/cpu/kernel/tests/mpt/mod.rs | 30 +++ evm/src/generation/mpt.rs | 18 +- 11 files changed, 620 insertions(+), 149 deletions(-) create mode 100644 evm/src/cpu/kernel/asm/mpt/insert.asm create mode 100644 evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm delete mode 100644 evm/src/cpu/kernel/asm/mpt/write.asm create mode 100644 evm/src/cpu/kernel/tests/mpt/insert.rs diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 2dcbd41c..6fb2231e 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -43,12 +43,13 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/mpt/hash.asm"), include_str!("asm/mpt/hash_trie_specific.asm"), include_str!("asm/mpt/hex_prefix.asm"), + include_str!("asm/mpt/insert.asm"), + include_str!("asm/mpt/insert_trie_specific.asm"), include_str!("asm/mpt/load.asm"), include_str!("asm/mpt/read.asm"), include_str!("asm/mpt/storage_read.asm"), include_str!("asm/mpt/storage_write.asm"), include_str!("asm/mpt/util.asm"), - include_str!("asm/mpt/write.asm"), include_str!("asm/ripemd/box.asm"), include_str!("asm/ripemd/compression.asm"), include_str!("asm/ripemd/constants.asm"), diff --git a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm index 221c0f20..bf2c46f0 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash_trie_specific.asm @@ -39,7 +39,7 @@ global mpt_hash_receipt_trie: %%after: %endmacro -encode_account: +global encode_account: // stack: rlp_pos, value_ptr, retdest // First, we compute the length of the RLP data we're about to write. 
// The nonce and balance fields are variable-length, so we need to load them diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert.asm new file mode 100644 index 00000000..9a5ed7f2 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert.asm @@ -0,0 +1,118 @@ +// Return a copy of the given node, with the given key set to the given value. +// +// Pre stack: node_ptr, num_nibbles, key, value_ptr, retdest +// Post stack: updated_node_ptr +global mpt_insert: + // stack: node_ptr, num_nibbles, key, value_ptr, retdest + DUP1 %mload_trie_data + // stack: node_type, node_ptr, num_nibbles, key, value_ptr, retdest + // Increment node_ptr, so it points to the node payload instead of its type. + SWAP1 %increment SWAP1 + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_insert_empty) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(mpt_insert_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_insert_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_insert_leaf) + + // There's still the MPT_NODE_HASH case, but if we hit a hash node, + // it means the prover failed to provide necessary Merkle data, so panic. + PANIC + +mpt_insert_empty: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + %pop2 + // stack: num_nibbles, key, value_ptr, retdest + // We will append a new leaf node to our MPT tape and return a pointer to it. 
+ %get_trie_data_size + // stack: leaf_ptr, num_nibbles, key, value_ptr, retdest + PUSH @MPT_NODE_LEAF %append_to_trie_data + // stack: leaf_ptr, num_nibbles, key, value_ptr, retdest + SWAP1 %append_to_trie_data + // stack: leaf_ptr, key, value_ptr, retdest + SWAP1 %append_to_trie_data + // stack: leaf_ptr, value_ptr, retdest + SWAP1 %append_to_trie_data + // stack: leaf_ptr, retdest + SWAP1 + JUMP + +mpt_insert_branch: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + %get_trie_data_size + // stack: updated_branch_ptr, node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + SWAP1 + %append_to_trie_data + // stack: updated_branch_ptr, node_payload_ptr, num_nibbles, key, value_ptr, retdest + SWAP1 + // stack: node_payload_ptr, updated_branch_ptr, num_nibbles, key, value_ptr, retdest + + // Copy the original node's data to our updated node. + DUP1 %mload_trie_data %append_to_trie_data // Copy child[0] + DUP1 %add_const(1) %mload_trie_data %append_to_trie_data // ... 
+ DUP1 %add_const(2) %mload_trie_data %append_to_trie_data + DUP1 %add_const(3) %mload_trie_data %append_to_trie_data + DUP1 %add_const(4) %mload_trie_data %append_to_trie_data + DUP1 %add_const(5) %mload_trie_data %append_to_trie_data + DUP1 %add_const(6) %mload_trie_data %append_to_trie_data + DUP1 %add_const(7) %mload_trie_data %append_to_trie_data + DUP1 %add_const(8) %mload_trie_data %append_to_trie_data + DUP1 %add_const(9) %mload_trie_data %append_to_trie_data + DUP1 %add_const(10) %mload_trie_data %append_to_trie_data + DUP1 %add_const(11) %mload_trie_data %append_to_trie_data + DUP1 %add_const(12) %mload_trie_data %append_to_trie_data + DUP1 %add_const(13) %mload_trie_data %append_to_trie_data + DUP1 %add_const(14) %mload_trie_data %append_to_trie_data + DUP1 %add_const(15) %mload_trie_data %append_to_trie_data // Copy child[15] + %add_const(16) %mload_trie_data %append_to_trie_data // Copy value_ptr + + // At this point, we branch based on whether the key terminates with this branch node. + // stack: updated_branch_ptr, num_nibbles, key, value_ptr, retdest + DUP2 %jumpi(mpt_insert_branch_nonterminal) + + // The key terminates here, so the value will be placed right in our (updated) branch node. + // stack: updated_branch_ptr, num_nibbles, key, value_ptr, retdest + SWAP3 + // stack: value_ptr, num_nibbles, key, updated_branch_ptr, retdest + DUP4 %add_const(17) + // stack: updated_branch_value_ptr_ptr, value_ptr, num_nibbles, key, updated_branch_ptr, retdest + %mstore_trie_data + // stack: num_nibbles, key, updated_branch_ptr, retdest + %pop2 + // stack: updated_branch_ptr, retdest + SWAP1 + JUMP + +mpt_insert_branch_nonterminal: + // The key continues, so we split off the first (most significant) nibble, + // and recursively insert into the child associated with that nibble. 
+ // stack: updated_branch_ptr, num_nibbles, key, value_ptr, retdest + %stack (updated_branch_ptr, num_nibbles, key) -> (num_nibbles, key, updated_branch_ptr) + %split_first_nibble + // stack: first_nibble, num_nibbles, key, updated_branch_ptr, value_ptr, retdest + DUP4 %increment ADD + // stack: child_ptr_ptr, num_nibbles, key, updated_branch_ptr, value_ptr, retdest + %stack (child_ptr_ptr, num_nibbles, key, updated_branch_ptr, value_ptr) + -> (child_ptr_ptr, num_nibbles, key, value_ptr, + mpt_insert_branch_nonterminal_after_recursion, + child_ptr_ptr, updated_branch_ptr) + %mload_trie_data // Deref child_ptr_ptr, giving child_ptr + %jump(mpt_insert) +mpt_insert_branch_nonterminal_after_recursion: + // stack: updated_child_ptr, child_ptr_ptr, updated_branch_ptr, retdest + SWAP1 %mstore_trie_data // Store the pointer to the updated child. + // stack: updated_branch_ptr, retdest + SWAP1 + JUMP + +mpt_insert_extension: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + PANIC // TODO + +mpt_insert_leaf: + // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + POP + // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm b/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm new file mode 100644 index 00000000..4c03d96c --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert_trie_specific.asm @@ -0,0 +1,14 @@ +// Insertion logic specific to a particular trie. + +// Mutate the state trie, inserting the given key-value pair. 
+global mpt_insert_state_trie: + // stack: num_nibbles, key, value_ptr, retdest + %stack (num_nibbles, key, value_ptr) + -> (num_nibbles, key, value_ptr, mpt_insert_state_trie_save) + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + // stack: state_root_ptr, num_nibbles, key, value_ptr, mpt_insert_state_trie_save, retdest + %jump(mpt_insert) +mpt_insert_state_trie_save: + // stack: updated_node_ptr, retdest + %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/load.asm b/evm/src/cpu/kernel/asm/mpt/load.asm index 62909f2d..73f58b95 100644 --- a/evm/src/cpu/kernel/asm/mpt/load.asm +++ b/evm/src/cpu/kernel/asm/mpt/load.asm @@ -9,9 +9,9 @@ global load_all_mpts: PUSH 1 %set_trie_data_size - %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) - %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) - %load_mpt_and_return_root_ptr %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) + %load_mpt %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %load_mpt %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) + %load_mpt %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_ROOT) PROVER_INPUT(mpt) // stack: num_storage_tries, retdest @@ -30,7 +30,7 @@ storage_trie_loop: // stack: i, storage_trie_addr, i, num_storage_tries, retdest %mstore_kernel(@SEGMENT_STORAGE_TRIE_ADDRS) // stack: i, num_storage_tries, retdest - %load_mpt_and_return_root_ptr + %load_mpt // stack: root_ptr, i, num_storage_tries, retdest DUP2 // stack: i, root_ptr, i, num_storage_tries, retdest @@ -45,13 +45,11 @@ storage_trie_loop_end: // Load an MPT from prover inputs. 
// Pre stack: retdest -// Post stack: (empty) +// Post stack: node_ptr load_mpt: // stack: retdest PROVER_INPUT(mpt) // stack: node_type, retdest - DUP1 %append_to_trie_data - // stack: node_type, retdest DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(load_mpt_empty) DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(load_mpt_branch) @@ -61,94 +59,108 @@ load_mpt: PANIC // Invalid node type load_mpt_empty: - // stack: node_type, retdest - POP - // stack: retdest + // TRIE_DATA[0] = 0, and an empty node has type 0, so we can simply return the null pointer. + %stack (node_type, retdest) -> (retdest, 0) JUMP load_mpt_branch: // stack: node_type, retdest - POP - // stack: retdest + %get_trie_data_size + // stack: node_ptr, node_type, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, retdest // Save the offset of our 16 child pointers so we can write them later. // Then advance out current trie pointer beyond them, so we can load the // value and have it placed after our child pointers. %get_trie_data_size - // stack: ptr_children, retdest - DUP1 %add_const(16) - // stack: ptr_value, ptr_children, retdest + // stack: children_ptr, node_ptr, retdest + DUP1 %add_const(17) // Skip over 16 children plus the value pointer + // stack: value_ptr, children_ptr, node_ptr, retdest %set_trie_data_size - // stack: ptr_children, retdest - // We need to append a pointer to where the value will live. - // %load_leaf_value will append the value just after this pointer; - // we add 1 to account for the pointer itself. - %get_trie_data_size %increment %append_to_trie_data - // stack: ptr_children, retdest - %load_leaf_value + // stack: children_ptr, node_ptr, retdest + %load_value + // stack: children_ptr, value_ptr, node_ptr, retdest + SWAP1 // Load the 16 children. 
%rep 16 - %load_mpt_and_return_root_ptr - // stack: child_ptr, ptr_next_child, retdest + %load_mpt + // stack: child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, retdest DUP2 - // stack: ptr_next_child, child_ptr, ptr_next_child, retdest + // stack: next_child_ptr_ptr, child_ptr, next_child_ptr_ptr, value_ptr, node_ptr, retdest %mstore_trie_data - // stack: ptr_next_child, retdest + // stack: next_child_ptr_ptr, value_ptr, node_ptr, retdest %increment - // stack: ptr_next_child, retdest + // stack: next_child_ptr_ptr, value_ptr, node_ptr, retdest %endrep - // stack: ptr_next_child, retdest - POP + // stack: value_ptr_ptr, value_ptr, node_ptr, retdest + %mstore_trie_data + // stack: node_ptr, retdest + SWAP1 JUMP load_mpt_extension: // stack: node_type, retdest - POP - // stack: retdest + %get_trie_data_size + // stack: node_ptr, node_type, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, retdest PROVER_INPUT(mpt) // read num_nibbles %append_to_trie_data PROVER_INPUT(mpt) // read packed_nibbles %append_to_trie_data - // stack: retdest + // stack: node_ptr, retdest - // Let i be the current trie data size. We still need to expand this node by - // one element, appending our child pointer. Thus our child node will start - // at i + 1. So we will set our child pointer to i + 1. %get_trie_data_size - %increment - %append_to_trie_data - // stack: retdest + // stack: child_ptr_ptr, node_ptr, retdest + // Increment trie_data_size, to leave room for child_ptr_ptr, before we load our child. 
+ DUP1 %increment %set_trie_data_size + // stack: child_ptr_ptr, node_ptr, retdest %load_mpt - // stack: retdest + // stack: child_ptr, child_ptr_ptr, node_ptr, retdest + SWAP1 + %mstore_trie_data + // stack: node_ptr, retdest + SWAP1 JUMP load_mpt_leaf: // stack: node_type, retdest - POP - // stack: retdest + %get_trie_data_size + // stack: node_ptr, node_type, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, retdest PROVER_INPUT(mpt) // read num_nibbles %append_to_trie_data PROVER_INPUT(mpt) // read packed_nibbles %append_to_trie_data - // stack: retdest - // We need to append a pointer to where the value will live. - // %load_leaf_value will append the value just after this pointer; - // we add 1 to account for the pointer itself. - %get_trie_data_size %increment %append_to_trie_data - // stack: retdest - %load_leaf_value - // stack: retdest + // stack: node_ptr, retdest + // We save value_ptr_ptr = get_trie_data_size, then increment trie_data_size + // to skip over the slot for value_ptr. We will write value_ptr after the + // load_value call. + %get_trie_data_size + // stack: value_ptr_ptr, node_ptr, retdest + DUP1 %increment %set_trie_data_size + // stack: value_ptr_ptr, node_ptr, retdest + %load_value + // stack: value_ptr, value_ptr_ptr, node_ptr, retdest + SWAP1 %mstore_trie_data + // stack: node_ptr, retdest + SWAP1 JUMP load_mpt_digest: // stack: node_type, retdest - POP - // stack: retdest + %get_trie_data_size + // stack: node_ptr, node_type, retdest + SWAP1 %append_to_trie_data + // stack: node_ptr, retdest PROVER_INPUT(mpt) // read digest %append_to_trie_data - // stack: retdest + // stack: node_ptr, retdest + SWAP1 JUMP // Convenience macro to call load_mpt and return where we left off. @@ -158,34 +170,37 @@ load_mpt_digest: %%after: %endmacro -%macro load_mpt_and_return_root_ptr - // stack: (empty) - %get_trie_data_size - // stack: ptr - %load_mpt - // stack: ptr -%endmacro - -// Load a leaf from prover input, and append it to trie data. 
-%macro load_leaf_value +// Load a leaf from prover input, append it to trie data, and return a pointer to it. +%macro load_value // stack: (empty) PROVER_INPUT(mpt) - // stack: leaf_len + // stack: value_len + DUP1 ISZERO + %jumpi(%%return_null) + // stack: value_len + %get_trie_data_size + SWAP1 + // stack: value_len, value_ptr DUP1 %append_to_trie_data - // stack: leaf_len + // stack: value_len, value_ptr %%loop: DUP1 ISZERO - // stack: leaf_len == 0, leaf_len - %jumpi(%%finish) - // stack: leaf_len + // stack: value_len == 0, value_len, value_ptr + %jumpi(%%finish_loop) + // stack: value_len, value_ptr PROVER_INPUT(mpt) - // stack: leaf_part, leaf_len + // stack: leaf_part, value_len, value_ptr %append_to_trie_data - // stack: leaf_len + // stack: value_len, value_ptr %decrement - // stack: leaf_len' + // stack: value_len', value_ptr %jump(%%loop) -%%finish: +%%finish_loop: + // stack: value_len, value_ptr POP - // stack: (empty) + // stack: value_ptr + %jump(%%end) +%%return_null: + %stack (value_len) -> (0) +%%end: %endmacro diff --git a/evm/src/cpu/kernel/asm/mpt/write.asm b/evm/src/cpu/kernel/asm/mpt/write.asm deleted file mode 100644 index eab51d3e..00000000 --- a/evm/src/cpu/kernel/asm/mpt/write.asm +++ /dev/null @@ -1,47 +0,0 @@ -// TODO: Need a special case for deleting, if value = ''. -// Or canonicalize once, before final hashing, to remove empty leaves etc. - -// Return a copy of the given node, with the given key set to the given value. -// -// Pre stack: node_ptr, num_nibbles, key, value_ptr, retdest -// Post stack: updated_node_ptr -global mpt_insert: - // stack: node_ptr, num_nibbles, key, value_ptr, retdest - DUP1 %mload_trie_data - // stack: node_type, node_ptr, num_nibbles, key, value_ptr, retdest - // Increment node_ptr, so it points to the node payload instead of its type. 
- SWAP1 %increment SWAP1 - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest - - DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(mpt_insert_empty) - DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(mpt_insert_branch) - DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(mpt_insert_extension) - DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(mpt_insert_leaf) - - // There's still the MPT_NODE_HASH case, but if we hit a hash node, - // it means the prover failed to provide necessary Merkle data, so panic. - PANIC - -mpt_insert_empty: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest - POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest - PANIC // TODO - -mpt_insert_branch: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest - POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest - PANIC // TODO - -mpt_insert_extension: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest - POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest - PANIC // TODO - -mpt_insert_leaf: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest - POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest - PANIC // TODO diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 589ba6b3..2eb9dcb9 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -168,10 +168,19 @@ impl<'a> Interpreter<'a> { self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize].get(field as usize) } + pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) { + self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize] + .set(field as usize, value) + } + pub(crate) fn get_trie_data(&self) -> &[U256] { &self.memory.context_memory[0].segments[Segment::TrieData as usize].content } + pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec { + &mut 
self.memory.context_memory[0].segments[Segment::TrieData as usize].content + } + pub(crate) fn get_rlp_memory(&self) -> Vec { self.memory.context_memory[0].segments[Segment::RlpRaw as usize] .content @@ -205,7 +214,7 @@ impl<'a> Interpreter<'a> { self.push(if x { U256::one() } else { U256::zero() }); } - fn pop(&mut self) -> U256 { + pub(crate) fn pop(&mut self) -> U256 { self.stack_mut().pop().expect("Pop on empty stack.") } diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs new file mode 100644 index 00000000..7aeb4a1a --- /dev/null +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -0,0 +1,173 @@ +use anyhow::Result; +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use eth_trie_utils::trie_builder::InsertEntry; +use ethereum_types::{BigEndianHash, H256}; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1_rlp, test_account_2_rlp}; +use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::generation::TrieInputs; + +#[test] +fn mpt_insert_empty() -> Result<()> { + let insert = InsertEntry { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + v: test_account_2_rlp(), + }; + test_state_trie(Default::default(), insert) +} + +#[test] +#[ignore] // TODO: Enable when mpt_insert_leaf is done. 
+fn mpt_insert_leaf_same_key() -> Result<()> { + let key = Nibbles { + count: 3, + packed: 0xABC.into(), + }; + let state_trie = PartialTrie::Leaf { + nibbles: key, + value: test_account_1_rlp(), + }; + let insert = InsertEntry { + nibbles: key, + v: test_account_2_rlp(), + }; + + test_state_trie(state_trie, insert) +} + +#[test] +fn mpt_insert_branch_replacing_empty_child() -> Result<()> { + let children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + + let insert = InsertEntry { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + v: test_account_2_rlp(), + }; + + test_state_trie(state_trie, insert) +} + +#[test] +#[ignore] // TODO: Enable when mpt_insert_extension is done. +fn mpt_insert_extension_to_leaf_same_key() -> Result<()> { + let state_trie = extension_to_leaf(test_account_1_rlp()); + + let insert = InsertEntry { + nibbles: Nibbles { + count: 3, + packed: 0xABCDEF.into(), + }, + v: test_account_2_rlp(), + }; + + test_state_trie(state_trie, insert) +} + +#[test] +#[ignore] // TODO: Enable when mpt_insert_leaf is done. 
+fn mpt_insert_branch_to_leaf_same_key() -> Result<()> { + let leaf = PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xBCD.into(), + }, + value: test_account_1_rlp(), + }; + let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + children[0xA] = Box::new(leaf); + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + + let insert = InsertEntry { + nibbles: Nibbles { + count: 4, + packed: 0xABCD.into(), + }, + v: test_account_2_rlp(), + }; + + test_state_trie(state_trie, insert) +} + +fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { + let trie_inputs = TrieInputs { + state_trie: state_trie.clone(), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; + let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + // Next, execute mpt_insert_state_trie. + interpreter.offset = mpt_insert_state_trie; + let trie_data = interpreter.get_trie_data_mut(); + if trie_data.is_empty() { + // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. + // Since we don't explicitly set it to 0, we need to do so here. 
+ trie_data.push(0.into()); + } + let value_ptr = trie_data.len(); + let account: AccountRlp = rlp::decode(&insert.v).expect("Decoding failed"); + let account_data = account.to_vec(); + trie_data.push(account_data.len().into()); + trie_data.extend(account_data); + let trie_data_len = trie_data.len().into(); + interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); + interpreter.push(0xDEADBEEFu32.into()); + interpreter.push(value_ptr.into()); // value_ptr + interpreter.push(insert.nibbles.packed); // key + interpreter.push(insert.nibbles.count.into()); // num_nibbles + + interpreter.run()?; + assert_eq!(interpreter.stack().len(), 0); + + // Now, execute mpt_hash_state_trie. + interpreter.offset = mpt_hash_state_trie; + interpreter.push(0xDEADBEEFu32.into()); + interpreter.run()?; + + assert_eq!( + interpreter.stack().len(), + 1, + "Expected 1 item on stack, found {:?}", + interpreter.stack() + ); + let hash = H256::from_uint(&interpreter.stack()[0]); + + let expected_state_trie_hash = apply_insert(state_trie, insert).calc_hash(); + assert_eq!(hash, expected_state_trie_hash); + + Ok(()) +} + +fn apply_insert(trie: PartialTrie, insert: InsertEntry) -> PartialTrie { + let mut trie = Box::new(trie); + if let Some(updated_trie) = PartialTrie::insert_into_trie(&mut trie, insert) { + *updated_trie + } else { + *trie + } +} diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index ca4f7071..ccf8353e 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,26 +1,19 @@ use anyhow::Result; -use ethereum_types::{BigEndianHash, H256, U256}; +use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use ethereum_types::{BigEndianHash, U256}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::cpu::kernel::interpreter::Interpreter; -use 
crate::cpu::kernel::tests::mpt::extension_to_leaf; -use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; +use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; +use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::TrieInputs; #[test] -fn load_all_mpts() -> Result<()> { - let account = AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: H256::from_uint(&U256::from(3333)), - code_hash: H256::from_uint(&U256::from(4444)), - }; - let account_rlp = rlp::encode(&account); - +fn load_all_mpts_empty() -> Result<()> { let trie_inputs = TrieInputs { - state_trie: extension_to_leaf(account_rlp.to_vec()), + state_trie: Default::default(), transactions_trie: Default::default(), receipts_trie: Default::default(), storage_tries: vec![], @@ -28,13 +21,174 @@ fn load_all_mpts() -> Result<()> { let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; - let initial_stack = vec![0xdeadbeefu32.into()]; + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + assert_eq!(interpreter.get_trie_data(), vec![]); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), + trie_inputs.storage_tries.len().into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_leaf() -> Result<()> { + let trie_inputs = TrieInputs { + state_trie: PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 
0xABC.into(), + }, + value: test_account_1_rlp(), + }, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let type_leaf = U256::from(PartialTrieType::Leaf as u32); + assert_eq!( + interpreter.get_trie_data(), + vec![ + 0.into(), + type_leaf, + 3.into(), + 0xABC.into(), + 5.into(), // value ptr + 4.into(), // value length + test_account_1().nonce, + test_account_1().balance, + test_account_1().storage_root.into_uint(), + test_account_1().code_hash.into_uint(), + ] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), + trie_inputs.storage_tries.len().into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_empty_branch() -> Result<()> { + let children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + let state_trie = PartialTrie::Branch { + children, + value: vec![], + }; + let trie_inputs = TrieInputs { + state_trie, + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; + let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); + interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); + + let 
type_branch = U256::from(PartialTrieType::Branch as u32); + assert_eq!( + interpreter.get_trie_data(), + vec![ + 0.into(), // First address is unused, so that 0 can be treated as a null pointer. + type_branch, + 0.into(), // child 0 + 0.into(), // ... + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), + 0.into(), // child 16 + 0.into(), // value_ptr + ] + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::TransactionTrieRoot), + 0.into() + ); + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::ReceiptTrieRoot), + 0.into() + ); + + assert_eq!( + interpreter.get_global_metadata_field(GlobalMetadata::NumStorageTries), + trie_inputs.storage_tries.len().into() + ); + + Ok(()) +} + +#[test] +fn load_all_mpts_ext_to_leaf() -> Result<()> { + let trie_inputs = TrieInputs { + state_trie: extension_to_leaf(test_account_1_rlp()), + transactions_trie: Default::default(), + receipts_trie: Default::default(), + storage_tries: vec![], + }; + + let load_all_mpts = KERNEL.global_labels["load_all_mpts"]; + + let initial_stack = vec![0xDEADBEEFu32.into()]; let mut interpreter = Interpreter::new_with_kernel(load_all_mpts, initial_stack); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); - let type_empty = U256::from(PartialTrieType::Empty as u32); let type_extension = U256::from(PartialTrieType::Extension as u32); let type_leaf = U256::from(PartialTrieType::Leaf as u32); assert_eq!( @@ -50,12 +204,10 @@ fn load_all_mpts() -> Result<()> { 0xDEF.into(), // key part 9.into(), // value pointer 4.into(), // value length - account.nonce, - account.balance, - account.storage_root.into_uint(), - account.code_hash.into_uint(), - type_empty, // txn trie - type_empty, // receipt trie + test_account_1().nonce, + test_account_1().balance, + 
test_account_1().storage_root.into_uint(), + test_account_1().code_hash.into_uint(), ] ); diff --git a/evm/src/cpu/kernel/tests/mpt/mod.rs b/evm/src/cpu/kernel/tests/mpt/mod.rs index 55a56653..e3414b38 100644 --- a/evm/src/cpu/kernel/tests/mpt/mod.rs +++ b/evm/src/cpu/kernel/tests/mpt/mod.rs @@ -1,10 +1,40 @@ use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use ethereum_types::{BigEndianHash, H256, U256}; + +use crate::generation::mpt::AccountRlp; mod hash; mod hex_prefix; +mod insert; mod load; mod read; +pub(crate) fn test_account_1() -> AccountRlp { + AccountRlp { + nonce: U256::from(1111), + balance: U256::from(2222), + storage_root: H256::from_uint(&U256::from(3333)), + code_hash: H256::from_uint(&U256::from(4444)), + } +} + +pub(crate) fn test_account_1_rlp() -> Vec { + rlp::encode(&test_account_1()).to_vec() +} + +pub(crate) fn test_account_2() -> AccountRlp { + AccountRlp { + nonce: U256::from(5555), + balance: U256::from(6666), + storage_root: H256::from_uint(&U256::from(7777)), + code_hash: H256::from_uint(&U256::from(8888)), + } +} + +pub(crate) fn test_account_2_rlp() -> Vec { + rlp::encode(&test_account_2()).to_vec() +} + /// A `PartialTrie` where an extension node leads to a leaf node containing an account. 
pub(crate) fn extension_to_leaf(value: Vec) -> PartialTrie { PartialTrie::Extension { diff --git a/evm/src/generation/mpt.rs b/evm/src/generation/mpt.rs index f6bc630d..e35364c6 100644 --- a/evm/src/generation/mpt.rs +++ b/evm/src/generation/mpt.rs @@ -13,6 +13,17 @@ pub(crate) struct AccountRlp { pub(crate) code_hash: H256, } +impl AccountRlp { + pub(crate) fn to_vec(&self) -> Vec { + vec![ + self.nonce, + self.balance, + self.storage_root.into_uint(), + self.code_hash.into_uint(), + ] + } +} + pub(crate) fn all_mpt_prover_inputs_reversed(trie_inputs: &TrieInputs) -> Vec { let mut inputs = all_mpt_prover_inputs(trie_inputs); inputs.reverse(); @@ -25,12 +36,7 @@ pub(crate) fn all_mpt_prover_inputs(trie_inputs: &TrieInputs) -> Vec { mpt_prover_inputs(&trie_inputs.state_trie, &mut prover_inputs, &|rlp| { let account: AccountRlp = rlp::decode(rlp).expect("Decoding failed"); - vec![ - account.nonce, - account.balance, - account.storage_root.into_uint(), - account.code_hash.into_uint(), - ] + account.to_vec() }); mpt_prover_inputs(&trie_inputs.transactions_trie, &mut prover_inputs, &|rlp| { From 4a055b3a76c650c529bcd6b74161a7bf17318980 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 9 Oct 2022 11:32:01 -0700 Subject: [PATCH 06/17] MPT insert logic, part 2 --- evm/src/cpu/kernel/asm/mpt/insert.asm | 35 +++++++++-- evm/src/cpu/kernel/asm/mpt/util.asm | 85 ++++++++++++++++++++++++++ evm/src/cpu/kernel/tests/mpt/insert.rs | 10 ++- 3 files changed, 123 insertions(+), 7 deletions(-) diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert.asm index 9a5ed7f2..cc748e79 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert.asm @@ -106,13 +106,40 @@ mpt_insert_branch_nonterminal_after_recursion: JUMP mpt_insert_extension: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest POP - // stack: 
node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest PANIC // TODO mpt_insert_leaf: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest + %stack (node_payload_ptr, insert_len, insert_key) -> (insert_len, insert_key, node_payload_ptr) + // stack: insert_len, insert_key, node_payload_ptr, value_ptr, retdest + DUP3 %increment %mload_trie_data + // stack: node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + DUP4 %mload_trie_data + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + // TODO: Maybe skip %split_common_prefix if lengths & keys exactly match. + %split_common_prefix + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + DUP3 DUP6 ADD %jumpi(mpt_insert_leaf_not_exact_match) + // If we got here, the node key exactly matches the insert key, so we will + // keep the same leaf node structure and just replace its value. 
+ %stack (common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr) + -> (common_len, common_key, value_ptr) + // stack: common_len, common_key, value_ptr, retdest + %get_trie_data_size + // stack: updated_leaf_ptr, common_len, common_key, value_ptr, retdest + PUSH @MPT_NODE_LEAF %append_to_trie_data + SWAP1 %append_to_trie_data // append common_len + SWAP1 %append_to_trie_data // append common_key + SWAP1 %append_to_trie_data // append value_ptr + // stack: updated_leaf_ptr, retdest + SWAP1 + JUMP +mpt_insert_leaf_not_exact_match: + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + // %get_trie_data_size PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/util.asm b/evm/src/cpu/kernel/asm/mpt/util.asm index 19cad943..a3ff7f38 100644 --- a/evm/src/cpu/kernel/asm/mpt/util.asm +++ b/evm/src/cpu/kernel/asm/mpt/util.asm @@ -72,3 +72,88 @@ POP // stack: first_nibble, num_nibbles, key %endmacro + +// Split off the common prefix among two key parts. 
Roughly equivalent to +// def split_common_prefix(len_1, key_1, len_2, key_2): +// bits_1 = len_1 * 4 +// bits_2 = len_2 * 4 +// len_common = 0 +// key_common = 0 +// while True: +// if bits_1 * bits_2 == 0: +// break +// first_nib_1 = (key_1 >> (bits_1 - 4)) & 0xF +// first_nib_2 = (key_2 >> (bits_2 - 4)) & 0xF +// if first_nib_1 != first_nib_2: +// break +// len_common += 1 +// key_common = key_common * 16 + first_nib_1 +// bits_1 -= 4 +// bits_2 -= 4 +// key_1 -= (first_nib_1 << bits_1) +// key_2 -= (first_nib_2 << bits_2) +// len_1 = bits_1 // 4 +// len_2 = bits_2 // 4 +// return (len_common, key_common, len_1, key_1, len_2, key_2) +%macro split_common_prefix + // stack: len_1, key_1, len_2, key_2 + %mul_const(4) + SWAP2 %mul_const(4) SWAP2 + // stack: bits_1, key_1, bits_2, key_2 + PUSH 0 + PUSH 0 + +%%loop: + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + + // if bits_1 * bits_2 == 0: break + DUP3 DUP6 MUL ISZERO %jumpi(%%return) + + // first_nib_2 = (key_2 >> (bits_2 - 4)) & 0xF + DUP6 DUP6 %sub_const(4) SHR %and_const(0xF) + // first_nib_1 = (key_1 >> (bits_1 - 4)) & 0xF + DUP5 DUP5 %sub_const(4) SHR %and_const(0xF) + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // if first_nib_1 != first_nib_2: break + DUP2 DUP2 SUB %jumpi(%%return_with_first_nibs) + + // len_common += 1 + SWAP2 %increment SWAP2 + + // key_common = key_common * 16 + first_nib_1 + SWAP3 + %mul_const(16) + DUP4 ADD + SWAP3 + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // bits_1 -= 4 + SWAP4 %sub_const(4) SWAP4 + // bits_2 -= 4 + SWAP6 %sub_const(4) SWAP6 + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // key_1 -= (first_nib_1 << bits_1) + DUP5 SHL + // stack: first_nib_1 << bits_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + DUP6 SUB + // stack: key_1, first_nib_2, len_common, key_common, bits_1, key_1_old, 
bits_2, key_2 + SWAP5 POP + // stack: first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // key_2 -= (first_nib_2 << bits_2) + DUP6 SHL + // stack: first_nib_2 << bits_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + DUP7 SUB + // stack: key_2, len_common, key_common, bits_1, key_1, bits_2, key_2_old + SWAP6 POP + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + + %jump(%%loop) +%%return_with_first_nibs: + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + %pop2 +%%return: + // stack: len_common, key_common, len_1, key_1, len_2, key_2 +%endmacro diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 7aeb4a1a..11927d52 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -23,7 +23,6 @@ fn mpt_insert_empty() -> Result<()> { } #[test] -#[ignore] // TODO: Enable when mpt_insert_leaf is done. fn mpt_insert_leaf_same_key() -> Result<()> { let key = Nibbles { count: 3, @@ -142,7 +141,12 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { interpreter.push(insert.nibbles.count.into()); // num_nibbles interpreter.run()?; - assert_eq!(interpreter.stack().len(), 0); + assert_eq!( + interpreter.stack().len(), + 0, + "Expected empty stack after insert, found {:?}", + interpreter.stack() + ); // Now, execute mpt_hash_state_trie. 
interpreter.offset = mpt_hash_state_trie; @@ -152,7 +156,7 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { assert_eq!( interpreter.stack().len(), 1, - "Expected 1 item on stack, found {:?}", + "Expected 1 item on stack after hashing, found {:?}", interpreter.stack() ); let hash = H256::from_uint(&interpreter.stack()[0]); From 33dba3a23dfcf2fbfc07bd08c6a7db3a97630c69 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 9 Oct 2022 20:18:16 -0700 Subject: [PATCH 07/17] Insertion optimization for leaf case --- evm/src/cpu/kernel/asm/mpt/insert.asm | 31 ++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert.asm index cc748e79..65c18428 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert.asm @@ -121,14 +121,25 @@ mpt_insert_leaf: // stack: node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest DUP4 %mload_trie_data // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - // TODO: Maybe skip %split_common_prefix if lengths & keys exactly match. + + // If the keys match, i.e. node_len == insert_len && node_key == insert_key, + // then we're simply replacing the leaf node's value. Since this is a common + // case, it's best to detect it early. Calling %split_common_prefix could be + // expensive as leaf keys tend to be long. 
+ DUP1 DUP4 EQ // node_len == insert_len + DUP3 DUP6 EQ // node_key == insert_key + MUL // Cheaper than AND + // stack: key_match, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + %jumpi(mpt_insert_leaf_keys_match) + %split_common_prefix - // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - DUP3 DUP6 ADD %jumpi(mpt_insert_leaf_not_exact_match) - // If we got here, the node key exactly matches the insert key, so we will - // keep the same leaf node structure and just replace its value. - %stack (common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr) - -> (common_len, common_key, value_ptr) + PANIC // TODO + +mpt_insert_leaf_keys_match: + // The keys match exactly, so we simply create a new leaf node with the new value.xs + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + %stack (node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr) + -> (node_len, node_key, value_ptr) // stack: common_len, common_key, value_ptr, retdest %get_trie_data_size // stack: updated_leaf_ptr, common_len, common_key, value_ptr, retdest @@ -136,10 +147,6 @@ mpt_insert_leaf: SWAP1 %append_to_trie_data // append common_len SWAP1 %append_to_trie_data // append common_key SWAP1 %append_to_trie_data // append value_ptr - // stack: updated_leaf_ptr, retdest + // stack: updated_leaf_ptr, retdestx SWAP1 JUMP -mpt_insert_leaf_not_exact_match: - // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - // %get_trie_data_size - PANIC // TODO From cad0473e1d76dad91aa8c1ae215d928c481b4930 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 9 Oct 2022 21:37:46 -0700 Subject: [PATCH 08/17] More MPT insert logic --- evm/src/cpu/kernel/aggregator.rs | 2 + evm/src/cpu/kernel/asm/mpt/insert.asm | 47 +---- .../cpu/kernel/asm/mpt/insert_extension.asm 
| 5 + evm/src/cpu/kernel/asm/mpt/insert_leaf.asm | 167 ++++++++++++++++++ evm/src/cpu/kernel/asm/mpt/util.asm | 10 +- evm/src/cpu/kernel/asm/util/basic_macros.asm | 6 + evm/src/cpu/kernel/tests/mpt/insert.rs | 24 ++- 7 files changed, 212 insertions(+), 49 deletions(-) create mode 100644 evm/src/cpu/kernel/asm/mpt/insert_extension.asm create mode 100644 evm/src/cpu/kernel/asm/mpt/insert_leaf.asm diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 6fb2231e..032338b7 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -44,6 +44,8 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/mpt/hash_trie_specific.asm"), include_str!("asm/mpt/hex_prefix.asm"), include_str!("asm/mpt/insert.asm"), + include_str!("asm/mpt/insert_extension.asm"), + include_str!("asm/mpt/insert_leaf.asm"), include_str!("asm/mpt/insert_trie_specific.asm"), include_str!("asm/mpt/load.asm"), include_str!("asm/mpt/read.asm"), diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert.asm index 65c18428..2830d376 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert.asm @@ -98,55 +98,10 @@ mpt_insert_branch_nonterminal: child_ptr_ptr, updated_branch_ptr) %mload_trie_data // Deref child_ptr_ptr, giving child_ptr %jump(mpt_insert) + mpt_insert_branch_nonterminal_after_recursion: // stack: updated_child_ptr, child_ptr_ptr, updated_branch_ptr, retdest SWAP1 %mstore_trie_data // Store the pointer to the updated child. 
// stack: updated_branch_ptr, retdest SWAP1 JUMP - -mpt_insert_extension: - // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest - POP - // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest - PANIC // TODO - -mpt_insert_leaf: - // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest - POP - // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest - %stack (node_payload_ptr, insert_len, insert_key) -> (insert_len, insert_key, node_payload_ptr) - // stack: insert_len, insert_key, node_payload_ptr, value_ptr, retdest - DUP3 %increment %mload_trie_data - // stack: node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - DUP4 %mload_trie_data - // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - - // If the keys match, i.e. node_len == insert_len && node_key == insert_key, - // then we're simply replacing the leaf node's value. Since this is a common - // case, it's best to detect it early. Calling %split_common_prefix could be - // expensive as leaf keys tend to be long. 
- DUP1 DUP4 EQ // node_len == insert_len - DUP3 DUP6 EQ // node_key == insert_key - MUL // Cheaper than AND - // stack: key_match, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - %jumpi(mpt_insert_leaf_keys_match) - - %split_common_prefix - PANIC // TODO - -mpt_insert_leaf_keys_match: - // The keys match exactly, so we simply create a new leaf node with the new value.xs - // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest - %stack (node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr) - -> (node_len, node_key, value_ptr) - // stack: common_len, common_key, value_ptr, retdest - %get_trie_data_size - // stack: updated_leaf_ptr, common_len, common_key, value_ptr, retdest - PUSH @MPT_NODE_LEAF %append_to_trie_data - SWAP1 %append_to_trie_data // append common_len - SWAP1 %append_to_trie_data // append common_key - SWAP1 %append_to_trie_data // append value_ptr - // stack: updated_leaf_ptr, retdestx - SWAP1 - JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/insert_extension.asm b/evm/src/cpu/kernel/asm/mpt/insert_extension.asm new file mode 100644 index 00000000..36458165 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert_extension.asm @@ -0,0 +1,5 @@ +global mpt_insert_extension: + // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest + POP + // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest + PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm new file mode 100644 index 00000000..b82653f2 --- /dev/null +++ b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm @@ -0,0 +1,167 @@ +// The high-level logic can be expressed with the following pseudocode: +// +// if node_len == insert_len && node_key == insert_key: +// return Leaf[node_key, insert_value] +// +// common_len, common_key, node_len, node_key, insert_len, insert_key = +// consume_common_prefix(node_len, node_key, 
insert_len, insert_key) +// +// branch = [MPT_TYPE_BRANCH] + [0] * 17 +// +// if node_len > 0: +// node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) +// branch[node_key_first + 1] = Leaf[node_len, node_key, node_value] +// else: +// branch[17] = node_value +// +// if insert_len > 0: +// insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) +// branch[insert_key_first + 1] = Leaf[insert_len, insert_key, insert_value] +// else: +// branch[17] = insert_value +// +// if common_len > 0: +// return Extension[common_key, branch] +// else: +// return branch + +global mpt_insert_leaf: + // stack: node_type, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + POP + // stack: node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + %stack (node_payload_ptr, insert_len, insert_key) -> (insert_len, insert_key, node_payload_ptr) + // stack: insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + DUP3 %increment %mload_trie_data + // stack: node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + DUP4 %mload_trie_data + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + + // If the keys match, i.e. node_len == insert_len && node_key == insert_key, + // then we're simply replacing the leaf node's value. Since this is a common + // case, it's best to detect it early. Calling %split_common_prefix could be + // expensive as leaf keys tend to be long. + DUP1 DUP4 EQ // node_len == insert_len + DUP3 DUP6 EQ // node_key == insert_key + MUL // Cheaper than AND + // stack: keys_match, node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + %jumpi(keys_match) + + // Replace node_payload_ptr with node_value, which is node_payload[2]. 
+ // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + SWAP4 + %add_const(2) + %mload_trie_data + SWAP4 + // stack: node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + + // Split off any common prefix between the node key and the inserted key. + %split_common_prefix + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + + // For the remaining cases, we will need a new branch node since the two keys diverge. + // We may also need an extension node above it (if common_len > 0); we will handle that later. + // For now, we allocate the branch node, initially with no children or value. + %get_trie_data_size + PUSH @MPT_NODE_BRANCH %append_to_trie_data + %rep 17 + PUSH 0 %append_to_trie_data + %endrep + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + + // Here, we branch based on whether each key continues beyond the common + // prefix, starting with the node key. + DUP4 // node_len + %jumpi(node_key_continues) + + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[17] = node_value_ptr + DUP8 // node_value_ptr + DUP2 // branch_ptr + %add_const(17) + %mstore_trie_data + +finished_processing_node_value: + DUP6 // insert_len + %jumpi(insert_key_continues) + + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[17] = insert_value_ptr + DUP9 // insert_value_ptr + DUP2 // branch_ptr + %add_const(17) + %mstore_trie_data + +finished_processing_insert_value: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // If common_len > 0, we need to add an extension node. 
+ DUP2 %jumpi(extension_for_common_key) + // Otherwise, we simply return our branch node. + SWAP8 + %pop8 + // stack: branch_ptr, retdest + SWAP1 + JUMP + +extension_for_common_key: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + PANIC // TODO + +node_key_continues: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[node_key_first + 1] = Leaf[node_len, node_key, node_value] + DUP5 DUP5 + // stack: node_len, node_key, branch_ptr, ... + %split_first_nibble + // stack: node_key_first, node_len, node_key, branch_ptr, ... + %get_trie_data_size + // stack: leaf_ptr, node_key_first, node_len, node_key, branch_ptr, ... + SWAP1 + DUP5 // branch_ptr + %increment // Skip over node type field + ADD // Add node_key_first + %mstore_trie_data + // stack: node_len, node_key, branch_ptr, ... + PUSH @MPT_NODE_LEAF %append_to_trie_data + %append_to_trie_data // Append node_len to our leaf node + %append_to_trie_data // Append node_key to our leaf node + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + DUP8 %append_to_trie_data // Append node_value_ptr to our leaf node + %jump(finished_processing_node_value) + +insert_key_continues: + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + // branch[insert_key_first + 1] = Leaf[insert_len, insert_key, insert_value] + DUP7 DUP7 + // stack: insert_len, insert_key, branch_ptr, ... + %split_first_nibble + // stack: insert_key_first, insert_len, insert_key, branch_ptr, ... + %get_trie_data_size + // stack: leaf_ptr, insert_key_first, insert_len, insert_key, branch_ptr, ... 
+ SWAP1 + DUP5 // branch_ptr + %increment // Skip over node type field + ADD // Add insert_key_first + %mstore_trie_data + // stack: insert_len, insert_key, branch_ptr, ... + PUSH @MPT_NODE_LEAF %append_to_trie_data + %append_to_trie_data // Append insert_len to our leaf node + %append_to_trie_data // Append insert_key to our leaf node + // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + DUP9 %append_to_trie_data // Append insert_value_ptr to our leaf node + %jump(finished_processing_insert_value) + +keys_match: + // The keys match exactly, so we simply create a new leaf node with the new value.xs + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr, retdest + %stack (node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr) + -> (node_len, node_key, insert_value_ptr) + // stack: common_len, common_key, insert_value_ptr, retdest + %get_trie_data_size + // stack: updated_leaf_ptr, common_len, common_key, insert_value_ptr, retdest + PUSH @MPT_NODE_LEAF %append_to_trie_data + SWAP1 %append_to_trie_data // Append common_len to our leaf node + SWAP1 %append_to_trie_data // Append common_key to our leaf node + SWAP1 %append_to_trie_data // Append insert_value_ptr to our leaf node + // stack: updated_leaf_ptr, retdestx + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/util.asm b/evm/src/cpu/kernel/asm/mpt/util.asm index a3ff7f38..0faa72f4 100644 --- a/evm/src/cpu/kernel/asm/mpt/util.asm +++ b/evm/src/cpu/kernel/asm/mpt/util.asm @@ -73,7 +73,12 @@ // stack: first_nibble, num_nibbles, key %endmacro -// Split off the common prefix among two key parts. Roughly equivalent to +// Split off the common prefix among two key parts. 
+// +// Pre stack: len_1, key_1, len_2, key_2 +// Post stack: len_common, key_common, len_1, key_1, len_2, key_2 +// +// Roughly equivalent to // def split_common_prefix(len_1, key_1, len_2, key_2): // bits_1 = len_1 * 4 // bits_2 = len_2 * 4 @@ -155,5 +160,8 @@ // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 %pop2 %%return: + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + SWAP2 %div_const(4) SWAP2 // bits_1 -> len_1 (in nibbles) + SWAP4 %div_const(4) SWAP4 // bits_2 -> len_2 (in nibbles) // stack: len_common, key_common, len_1, key_1, len_2, key_2 %endmacro diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index bba2a2c1..02a2c807 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -44,6 +44,12 @@ %endrep %endmacro +%macro pop8 + %rep 8 + POP + %endrep +%endmacro + %macro and_const(c) // stack: input, ... PUSH $c diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 11927d52..1ac7974f 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -40,6 +40,26 @@ fn mpt_insert_leaf_same_key() -> Result<()> { test_state_trie(state_trie, insert) } +#[test] +fn mpt_insert_leaf_nonoverlapping_key() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + value: test_account_1_rlp(), + }; + let insert = InsertEntry { + nibbles: Nibbles { + count: 3, + packed: 0x123.into(), + }, + v: test_account_2_rlp(), + }; + + test_state_trie(state_trie, insert) +} + #[test] fn mpt_insert_branch_replacing_empty_child() -> Result<()> { let children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); @@ -76,7 +96,6 @@ fn mpt_insert_extension_to_leaf_same_key() -> Result<()> { } #[test] -#[ignore] // TODO: Enable when mpt_insert_leaf is done. 
fn mpt_insert_branch_to_leaf_same_key() -> Result<()> { let leaf = PartialTrie::Leaf { nibbles: Nibbles { @@ -161,7 +180,8 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { ); let hash = H256::from_uint(&interpreter.stack()[0]); - let expected_state_trie_hash = apply_insert(state_trie, insert).calc_hash(); + let updated_trie = apply_insert(state_trie, insert); + let expected_state_trie_hash = updated_trie.calc_hash(); assert_eq!(hash, expected_state_trie_hash); Ok(()) From 50002df8e49cd1675ae13d20b126b507d81196d6 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 10 Oct 2022 10:42:02 -0700 Subject: [PATCH 09/17] MPT insert into leaf, overlapping keys case --- evm/src/cpu/kernel/asm/mpt/insert_leaf.asm | 16 +++++++++++++-- evm/src/cpu/kernel/tests/mpt/insert.rs | 24 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm index b82653f2..eeb7612a 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm @@ -21,7 +21,7 @@ // branch[17] = insert_value // // if common_len > 0: -// return Extension[common_key, branch] +// return Extension[common_len, common_key, branch] // else: // return branch @@ -104,7 +104,19 @@ finished_processing_insert_value: extension_for_common_key: // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest - PANIC // TODO + // return Extension[common_len, common_key, branch] + %get_trie_data_size + // stack: extension_ptr, branch_ptr, common_len, common_key, ... 
+ PUSH @MPT_NODE_EXTENSION %append_to_trie_data + SWAP2 %append_to_trie_data // Append common_len to our node + SWAP2 %append_to_trie_data // Append common_key to our node + SWAP1 %append_to_trie_data // Append branch_ptr to our node + // stack: extension_ptr, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest + SWAP6 + %pop6 + // stack: extension_ptr, retdest + SWAP1 + JUMP node_key_continues: // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 1ac7974f..872ef4af 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -23,7 +23,7 @@ fn mpt_insert_empty() -> Result<()> { } #[test] -fn mpt_insert_leaf_same_key() -> Result<()> { +fn mpt_insert_leaf_identical_keys() -> Result<()> { let key = Nibbles { count: 3, packed: 0xABC.into(), @@ -41,7 +41,7 @@ fn mpt_insert_leaf_same_key() -> Result<()> { } #[test] -fn mpt_insert_leaf_nonoverlapping_key() -> Result<()> { +fn mpt_insert_leaf_nonoverlapping_keys() -> Result<()> { let state_trie = PartialTrie::Leaf { nibbles: Nibbles { count: 3, @@ -60,6 +60,26 @@ fn mpt_insert_leaf_nonoverlapping_key() -> Result<()> { test_state_trie(state_trie, insert) } +#[test] +fn mpt_insert_leaf_overlapping_keys() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + value: test_account_1_rlp(), + }; + let insert = InsertEntry { + nibbles: Nibbles { + count: 3, + packed: 0xADE.into(), + }, + v: test_account_2_rlp(), + }; + + test_state_trie(state_trie, insert) +} + #[test] fn mpt_insert_branch_replacing_empty_child() -> Result<()> { let children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); From caf928b11efacf62989e20a9776e913697022991 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 10 Oct 2022 11:14:16 
-0700 Subject: [PATCH 10/17] MPT logic for inserts into extension nodes --- .../cpu/kernel/asm/mpt/insert_extension.asm | 202 +++++++++++++++++- evm/src/cpu/kernel/asm/mpt/insert_leaf.asm | 94 ++++---- evm/src/cpu/kernel/tests/mpt/insert.rs | 65 +++++- 3 files changed, 310 insertions(+), 51 deletions(-) diff --git a/evm/src/cpu/kernel/asm/mpt/insert_extension.asm b/evm/src/cpu/kernel/asm/mpt/insert_extension.asm index 36458165..3ead805b 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert_extension.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert_extension.asm @@ -1,5 +1,201 @@ +/* +Insert into an extension node. +The high-level logic can be expressed with the following pseudocode: + +common_len, common_key, node_len, node_key, insert_len, insert_key = + split_common_prefix(node_len, node_key, insert_len, insert_key) + +if node_len == 0: + new_node = insert(node_child, insert_len, insert_key, insert_value) +else: + new_node = [MPT_TYPE_BRANCH] + [0] * 17 + + // Process the node's child. + if node_len > 1: + // The node key continues with multiple nibbles left, so we can't place + // node_child directly in the branch, but need an extension for it. + node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) + new_node[node_key_first + 1] = [MPT_TYPE_EXTENSION, node_len, node_key, node_child] + else: + // The remaining node_key is a single nibble, so we can place node_child directly in the branch. + new_node[node_key + 1] = node_child + + // Process the inserted entry. + if insert_len > 0: + // The insert key continues. Add a leaf node for it. 
+ insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) + new_node[insert_key_first + 1] = [MPT_TYPE_LEAF, insert_len, insert_key, insert_value] + else: + new_node[17] = insert_value + +if common_len > 0: + return [MPT_TYPE_EXTENSION, common_len, common_key, new_node] +else: + return new_node +*/ + global mpt_insert_extension: - // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest + // stack: node_type, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest POP - // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest - PANIC // TODO + // stack: node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + + // We start by loading the extension node's three fields: node_len, node_key, node_child_ptr + DUP1 %add_const(2) %mload_trie_data + // stack: node_child_ptr, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest + %stack (node_child_ptr, node_payload_ptr, insert_len, insert_key) + -> (node_payload_ptr, insert_len, insert_key, node_child_ptr) + // stack: node_payload_ptr, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP1 %increment %mload_trie_data + // stack: node_key, node_payload_ptr, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + SWAP1 %mload_trie_data + // stack: node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + + // Next, we split off any key prefix which is common to the node's key and the inserted key. + %split_common_prefix + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + + // Now we branch based on whether the node key continues beyond the common prefix. + DUP3 %jumpi(node_key_continues) + + // The node key does not continue. In this case we recurse. 
Pseudocode: + // new_node = insert(node_child, insert_len, insert_key, insert_value) + // and then proceed to maybe_add_extension_for_common_key. + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + PUSH maybe_add_extension_for_common_key + DUP9 // insert_value_ptr + DUP8 // insert_key + DUP8 // insert_len + DUP11 // node_child_ptr + %jump(mpt_insert) + +node_key_continues: + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Allocate new_node, a branch node which is initially empty + // Pseudocode: new_node = [MPT_TYPE_BRANCH] + [0] * 17 + %get_trie_data_size // pointer to the branch node we're about to create + PUSH @MPT_NODE_BRANCH %append_to_trie_data + %rep 17 + PUSH 0 %append_to_trie_data + %endrep + +process_node_child: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // We want to check if node_len > 1. We already know node_len > 0 since we're in node_key_continues, + // so it suffices to check 1 - node_len != 0 + DUP4 // node_len + PUSH 1 SUB + %jumpi(node_key_continues_multiple_nibbles) + + // If we got here, node_len = 1. 
+ // Pseudocode: new_node[node_key + 1] = node_child + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP8 // node_child_ptr + DUP2 // new_node_ptr + %increment + DUP7 // node_key + ADD + %mstore_trie_data + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + %jump(process_inserted_entry) + +node_key_continues_multiple_nibbles: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Pseudocode: node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) + // To minimize stack manipulation, we won't actually mutate the node_len, node_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. + DUP5 DUP5 + // stack: node_len, node_key, new_node_ptr, ... + %split_first_nibble + // stack: node_key_first, node_len, node_key, new_node_ptr, ... + + // Pseudocode: new_node[node_key_first + 1] = [MPT_TYPE_EXTENSION, node_len, node_key, node_child] + %get_trie_data_size // pointer to the extension node we're about to create + // stack: ext_node_ptr, node_key_first, node_len, node_key, new_node_ptr, ... + PUSH @MPT_NODE_EXTENSION %append_to_trie_data + // stack: ext_node_ptr, node_key_first, node_len, node_key, new_node_ptr, ... + SWAP2 %append_to_trie_data // Append node_len + // stack: node_key_first, ext_node_ptr, node_key, new_node_ptr, ... + SWAP2 %append_to_trie_data // Append node_key + // stack: ext_node_ptr, node_key_first, new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP10 %append_to_trie_data // Append node_child_ptr + + SWAP1 + // stack: node_key_first, ext_node_ptr, new_node_ptr, ... 
+ DUP3 // new_node_ptr + ADD + %increment + // stack: new_node_ptr + node_key_first + 1, ext_node_ptr, new_node_ptr, ... + %mstore_trie_data + %jump(process_inserted_entry) + +process_inserted_entry: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP6 // insert_len + %jumpi(insert_key_continues) + + // If we got here, insert_len = 0, so we store the inserted value directly in our new branch node. + // Pseudocode: new_node[17] = insert_value + DUP9 // insert_value_ptr + DUP2 // new_node_ptr + %add_const(17) + %mstore_trie_data + %jump(maybe_add_extension_for_common_key) + +insert_key_continues: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Pseudocode: insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) + // To minimize stack manipulation, we won't actually mutate the insert_len, insert_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. + DUP7 DUP7 + // stack: insert_len, insert_key, new_node_ptr, ... + %split_first_nibble + // stack: insert_key_first, insert_len, insert_key, new_node_ptr, ... + + // Pseudocode: new_node[insert_key_first + 1] = [MPT_TYPE_LEAF, insert_len, insert_key, insert_value] + %get_trie_data_size // pointer to the leaf node we're about to create + // stack: leaf_node_ptr, insert_key_first, insert_len, insert_key, new_node_ptr, ... + PUSH @MPT_NODE_LEAF %append_to_trie_data + // stack: leaf_node_ptr, insert_key_first, insert_len, insert_key, new_node_ptr, ... + SWAP2 %append_to_trie_data // Append insert_len + // stack: insert_key_first, leaf_node_ptr, insert_key, new_node_ptr, ...
+ SWAP2 %append_to_trie_data // Append insert_key + // stack: leaf_node_ptr, insert_key_first, new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + DUP11 %append_to_trie_data // Append insert_value_ptr + + SWAP1 + // stack: insert_key_first, leaf_node_ptr, new_node_ptr, ... + DUP3 // new_node_ptr + ADD + %increment + // stack: new_node_ptr + insert_key_first + 1, leaf_node_ptr, new_node_ptr, ... + %mstore_trie_data + %jump(maybe_add_extension_for_common_key) + +maybe_add_extension_for_common_key: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // If common_len > 0, we need to add an extension node. + DUP2 %jumpi(add_extension_for_common_key) + // Otherwise, we simply return new_node_ptr. + SWAP8 + %pop8 + // stack: new_node_ptr, retdest + SWAP1 + JUMP + +add_extension_for_common_key: + // stack: new_node_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + // Pseudocode: return [MPT_TYPE_EXTENSION, common_len, common_key, new_node] + %get_trie_data_size // pointer to the extension node we're about to create + // stack: extension_ptr, new_node_ptr, common_len, common_key, ... + PUSH @MPT_NODE_EXTENSION %append_to_trie_data + SWAP2 %append_to_trie_data // Append common_len to our node + // stack: new_node_ptr, extension_ptr, common_key, ... + SWAP2 %append_to_trie_data // Append common_key to our node + // stack: extension_ptr, new_node_ptr, ... 
+ SWAP1 %append_to_trie_data // Append new_node_ptr to our node + // stack: extension_ptr, node_len, node_key, insert_len, insert_key, node_child_ptr, insert_value_ptr, retdest + SWAP6 + %pop6 + // stack: extension_ptr, retdest + SWAP1 + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm index eeb7612a..6afe2f14 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert_leaf.asm @@ -1,29 +1,35 @@ -// The high-level logic can be expressed with the following pseudocode: -// -// if node_len == insert_len && node_key == insert_key: -// return Leaf[node_key, insert_value] -// -// common_len, common_key, node_len, node_key, insert_len, insert_key = -// consume_common_prefix(node_len, node_key, insert_len, insert_key) -// -// branch = [MPT_TYPE_BRANCH] + [0] * 17 -// -// if node_len > 0: -// node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) -// branch[node_key_first + 1] = Leaf[node_len, node_key, node_value] -// else: -// branch[17] = node_value -// -// if insert_len > 0: -// insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) -// branch[insert_key_first + 1] = Leaf[insert_len, insert_key, insert_value] -// else: -// branch[17] = insert_value -// -// if common_len > 0: -// return Extension[common_len, common_key, branch] -// else: -// return branch +/* +Insert into a leaf node. +The high-level logic can be expressed with the following pseudocode: + +if node_len == insert_len && node_key == insert_key: + return Leaf[node_key, insert_value] + +common_len, common_key, node_len, node_key, insert_len, insert_key = + split_common_prefix(node_len, node_key, insert_len, insert_key) + +branch = [MPT_TYPE_BRANCH] + [0] * 17 + +// Process the node's entry. 
+if node_len > 0: + node_key_first, node_len, node_key = split_first_nibble(node_len, node_key) + branch[node_key_first + 1] = [MPT_TYPE_LEAF, node_len, node_key, node_value] +else: + branch[17] = node_value + +// Process the inserted entry. +if insert_len > 0: + insert_key_first, insert_len, insert_key = split_first_nibble(insert_len, insert_key) + branch[insert_key_first + 1] = [MPT_TYPE_LEAF, insert_len, insert_key, insert_value] +else: + branch[17] = insert_value + +// Add an extension node if there is a common prefix. +if common_len > 0: + return [MPT_TYPE_EXTENSION, common_len, common_key, branch] +else: + return branch +*/ global mpt_insert_leaf: // stack: node_type, node_payload_ptr, insert_len, insert_key, insert_value_ptr, retdest @@ -61,15 +67,17 @@ global mpt_insert_leaf: // For the remaining cases, we will need a new branch node since the two keys diverge. // We may also need an extension node above it (if common_len > 0); we will handle that later. // For now, we allocate the branch node, initially with no children or value. - %get_trie_data_size + %get_trie_data_size // pointer to the branch node we're about to create PUSH @MPT_NODE_BRANCH %append_to_trie_data %rep 17 PUSH 0 %append_to_trie_data %endrep // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest - // Here, we branch based on whether each key continues beyond the common + // Now, we branch based on whether each key continues beyond the common // prefix, starting with the node key. 
+ +process_node_entry: DUP4 // node_len %jumpi(node_key_continues) @@ -80,7 +88,7 @@ global mpt_insert_leaf: %add_const(17) %mstore_trie_data -finished_processing_node_value: +process_inserted_entry: DUP6 // insert_len %jumpi(insert_key_continues) @@ -91,25 +99,27 @@ finished_processing_node_value: %add_const(17) %mstore_trie_data -finished_processing_insert_value: +maybe_add_extension_for_common_key: // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest // If common_len > 0, we need to add an extension node. - DUP2 %jumpi(extension_for_common_key) - // Otherwise, we simply return our branch node. + DUP2 %jumpi(add_extension_for_common_key) + // Otherwise, we simply return branch_ptr. SWAP8 %pop8 // stack: branch_ptr, retdest SWAP1 JUMP -extension_for_common_key: +add_extension_for_common_key: // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest - // return Extension[common_len, common_key, branch] - %get_trie_data_size + // Pseudocode: return [MPT_TYPE_EXTENSION, common_len, common_key, branch] + %get_trie_data_size // pointer to the extension node we're about to create // stack: extension_ptr, branch_ptr, common_len, common_key, ... PUSH @MPT_NODE_EXTENSION %append_to_trie_data SWAP2 %append_to_trie_data // Append common_len to our node + // stack: branch_ptr, extension_ptr, common_key, ... SWAP2 %append_to_trie_data // Append common_key to our node + // stack: extension_ptr, branch_ptr, ... 
SWAP1 %append_to_trie_data // Append branch_ptr to our node // stack: extension_ptr, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest SWAP6 @@ -121,11 +131,13 @@ extension_for_common_key: node_key_continues: // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest // branch[node_key_first + 1] = Leaf[node_len, node_key, node_value] + // To minimize stack manipulation, we won't actually mutate the node_len, node_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. DUP5 DUP5 // stack: node_len, node_key, branch_ptr, ... %split_first_nibble // stack: node_key_first, node_len, node_key, branch_ptr, ... - %get_trie_data_size + %get_trie_data_size // pointer to the leaf node we're about to create // stack: leaf_ptr, node_key_first, node_len, node_key, branch_ptr, ... SWAP1 DUP5 // branch_ptr @@ -138,16 +150,18 @@ node_key_continues: %append_to_trie_data // Append node_key to our leaf node // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest DUP8 %append_to_trie_data // Append node_value_ptr to our leaf node - %jump(finished_processing_node_value) + %jump(process_inserted_entry) insert_key_continues: // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest // branch[insert_key_first + 1] = Leaf[insert_len, insert_key, insert_value] + // To minimize stack manipulation, we won't actually mutate the insert_len, insert_key variables in our stack. + // Instead we will duplicate them, and leave the old ones alone; they won't be used. DUP7 DUP7 // stack: insert_len, insert_key, branch_ptr, ... %split_first_nibble // stack: insert_key_first, insert_len, insert_key, branch_ptr, ... 
- %get_trie_data_size + %get_trie_data_size // pointer to the leaf node we're about to create // stack: leaf_ptr, insert_key_first, insert_len, insert_key, branch_ptr, ... SWAP1 DUP5 // branch_ptr @@ -160,7 +174,7 @@ insert_key_continues: %append_to_trie_data // Append insert_key to our leaf node // stack: branch_ptr, common_len, common_key, node_len, node_key, insert_len, insert_key, node_value_ptr, insert_value_ptr, retdest DUP9 %append_to_trie_data // Append insert_value_ptr to our leaf node - %jump(finished_processing_insert_value) + %jump(maybe_add_extension_for_common_key) keys_match: // The keys match exactly, so we simply create a new leaf node with the new value.xs @@ -168,7 +182,7 @@ keys_match: %stack (node_len, node_key, insert_len, insert_key, node_payload_ptr, insert_value_ptr) -> (node_len, node_key, insert_value_ptr) // stack: common_len, common_key, insert_value_ptr, retdest - %get_trie_data_size + %get_trie_data_size // pointer to the leaf node we're about to create // stack: updated_leaf_ptr, common_len, common_key, insert_value_ptr, retdest PUSH @MPT_NODE_LEAF %append_to_trie_data SWAP1 %append_to_trie_data // Append common_len to our leaf node diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 872ef4af..103bdd2e 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -6,7 +6,7 @@ use ethereum_types::{BigEndianHash, H256}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1_rlp, test_account_2_rlp}; +use crate::cpu::kernel::tests::mpt::{test_account_1_rlp, test_account_2_rlp}; use crate::generation::mpt::{all_mpt_prover_inputs_reversed, AccountRlp}; use crate::generation::TrieInputs; @@ -36,7 +36,6 @@ fn mpt_insert_leaf_identical_keys() -> Result<()> { nibbles: key, v: 
test_account_2_rlp(), }; - test_state_trie(state_trie, insert) } @@ -56,7 +55,6 @@ fn mpt_insert_leaf_nonoverlapping_keys() -> Result<()> { }, v: test_account_2_rlp(), }; - test_state_trie(state_trie, insert) } @@ -76,7 +74,44 @@ fn mpt_insert_leaf_overlapping_keys() -> Result<()> { }, v: test_account_2_rlp(), }; + test_state_trie(state_trie, insert) +} +#[test] +fn mpt_insert_leaf_insert_key_extends_leaf_key() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + value: test_account_1_rlp(), + }; + let insert = InsertEntry { + nibbles: Nibbles { + count: 5, + packed: 0xABCDE.into(), + }, + v: test_account_2_rlp(), + }; + test_state_trie(state_trie, insert) +} + +#[test] +fn mpt_insert_leaf_leaf_key_extends_insert_key() -> Result<()> { + let state_trie = PartialTrie::Leaf { + nibbles: Nibbles { + count: 5, + packed: 0xABCDE.into(), + }, + value: test_account_1_rlp(), + }; + let insert = InsertEntry { + nibbles: Nibbles { + count: 3, + packed: 0xABC.into(), + }, + v: test_account_2_rlp(), + }; test_state_trie(state_trie, insert) } @@ -100,18 +135,32 @@ fn mpt_insert_branch_replacing_empty_child() -> Result<()> { } #[test] -#[ignore] // TODO: Enable when mpt_insert_extension is done. 
fn mpt_insert_extension_to_leaf_same_key() -> Result<()> { - let state_trie = extension_to_leaf(test_account_1_rlp()); - - let insert = InsertEntry { + let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + children[0xD] = Box::new(PartialTrie::Leaf { + nibbles: Nibbles { + count: 2, + packed: 0xEF.into(), + }, + value: test_account_1_rlp(), + }); + let state_trie = PartialTrie::Extension { nibbles: Nibbles { count: 3, + packed: 0xABC.into(), + }, + child: Box::new(PartialTrie::Branch { + children, + value: test_account_1_rlp(), + }), + }; + let insert = InsertEntry { + nibbles: Nibbles { + count: 6, packed: 0xABCDEF.into(), }, v: test_account_2_rlp(), }; - test_state_trie(state_trie, insert) } From 0d0067554e7dd32f49936e7768699de196caceee Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:59:02 +1100 Subject: [PATCH 11/17] Refactor and tidy up `mul.rs` (#764) * Refactor and tidy up `mul.rs`. * Jacqui PR comments. --- evm/src/arithmetic/modular.rs | 27 +++--- evm/src/arithmetic/mul.rs | 161 +++++++++++++++------------------- evm/src/arithmetic/utils.rs | 66 ++++++++------ 3 files changed, 124 insertions(+), 130 deletions(-) diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index 1fd31bb1..fd2a2e28 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -18,12 +18,13 @@ //! a(x) = \sum_{i=0}^15 a[i] x^i //! //! (so A = a(β)) and similarly for b(x), m(x) and c(x). Then -//! operation(A,B) = C (mod M) if and only if the polynomial +//! operation(A,B) = C (mod M) if and only if there exists q such that +//! the polynomial //! //! operation(a(x), b(x)) - c(x) - m(x) * q(x) //! -//! is zero when evaluated at x = β, i.e. it is divisible by (x - β). -//! Thus exists a polynomial s such that +//! is zero when evaluated at x = β, i.e. it is divisible by (x - β); +//! equivalently, there exists a polynomial s such that //! //! 
operation(a(x), b(x)) - c(x) - m(x) * q(x) - (x - β) * s(x) == 0 //! @@ -34,12 +35,12 @@ //! coefficients must be zero. The variable names of the constituent //! polynomials are (writing N for N_LIMBS=16): //! -//! a(x) = \sum_{i=0}^{N-1} input0[i] * β^i -//! b(x) = \sum_{i=0}^{N-1} input1[i] * β^i -//! c(x) = \sum_{i=0}^{N-1} output[i] * β^i -//! m(x) = \sum_{i=0}^{N-1} modulus[i] * β^i -//! q(x) = \sum_{i=0}^{2N-1} quot[i] * β^i -//! s(x) = \sum_i^{2N-2} aux[i] * β^i +//! a(x) = \sum_{i=0}^{N-1} input0[i] * x^i +//! b(x) = \sum_{i=0}^{N-1} input1[i] * x^i +//! c(x) = \sum_{i=0}^{N-1} output[i] * x^i +//! m(x) = \sum_{i=0}^{N-1} modulus[i] * x^i +//! q(x) = \sum_{i=0}^{2N-1} quot[i] * x^i +//! s(x) = \sum_i^{2N-2} aux[i] * x^i //! //! Because A, B, M and C are 256-bit numbers, the degrees of a, b, m //! and c are (at most) N-1 = 15. If m = 1, then Q would be A*B which @@ -211,7 +212,7 @@ fn generate_modular_op( // constr_poly must be zero when evaluated at x = β := // 2^LIMB_BITS, hence it's divisible by (x - β). `aux_limbs` is // the result of removing that root. 
- let aux_limbs = pol_remove_root_2exp::(constr_poly); + let aux_limbs = pol_remove_root_2exp::(constr_poly); for deg in 0..N_LIMBS { lv[MODULAR_OUTPUT[deg]] = F::from_canonical_i64(output_limbs[deg]); @@ -303,7 +304,8 @@ fn modular_constr_poly( pol_add_assign(&mut constr_poly, &output); // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) - let aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let mut aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + aux[2 * N_LIMBS - 1] = P::ZEROS; // zero out the MOD_IS_ZERO flag let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base)); @@ -397,7 +399,8 @@ fn modular_constr_poly_ext_circuit, const D: usize> let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); - let aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let mut aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + aux[2 * N_LIMBS - 1] = builder.zero_extension(); let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS)); let t = pol_adjoin_root_ext_circuit(builder, aux, base); pol_add_assign_ext_circuit(builder, &mut constr_poly, &t); diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index 9d6638f1..c98b9af8 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -3,30 +3,57 @@ //! This crate verifies an EVM MUL instruction, which takes two //! 256-bit inputs A and B, and produces a 256-bit output C satisfying //! -//! C = A*B (mod 2^256). +//! C = A*B (mod 2^256), //! -//! Inputs A and B, and output C, are given as arrays of 16-bit +//! i.e. C is the lower half of the usual long multiplication +//! A*B. Inputs A and B, and output C, are given as arrays of 16-bit //! limbs. For example, if the limbs of A are a[0]...a[15], then //! //! A = \sum_{i=0}^15 a[i] β^i, //! -//! where β = 2^16. To verify that A, B and C satisfy the equation we -//! proceed as follows. 
Define a(x) = \sum_{i=0}^15 a[i] x^i (so A = a(β)) -//! and similarly for b(x) and c(x). Then A*B = C (mod 2^256) if and only -//! if there exist polynomials q and m such that +//! where β = 2^16 = 2^LIMB_BITS. To verify that A, B and C satisfy +//! the equation we proceed as follows. Define //! -//! a(x)*b(x) - c(x) - m(x)*x^16 - (β - x)*q(x) == 0. +//! a(x) = \sum_{i=0}^15 a[i] x^i +//! +//! (so A = a(β)) and similarly for b(x) and c(x). Then A*B = C (mod +//! 2^256) if and only if there exists q such that the polynomial +//! +//! a(x) * b(x) - c(x) - x^16 * q(x) +//! +//! is zero when evaluated at x = β, i.e. it is divisible by (x - β); +//! equivalently, there exists a polynomial s (representing the +//! carries from the long multiplication) such that +//! +//! a(x) * b(x) - c(x) - x^16 * q(x) - (x - β) * s(x) == 0 +//! +//! As we only need the lower half of the product, we can omit q(x) +//! since it is multiplied by the modulus β^16 = 2^256. Thus we only +//! need to verify +//! +//! a(x) * b(x) - c(x) - (x - β) * s(x) == 0 +//! +//! In the code below, this "constraint polynomial" is constructed in +//! the variable `constr_poly`. It must be identically zero for the +//! multiplication operation to be verified, or, equivalently, each of +//! its coefficients must be zero. The variable names of the +//! constituent polynomials are (writing N for N_LIMBS=16): +//! +//! a(x) = \sum_{i=0}^{N-1} input0[i] * x^i +//! b(x) = \sum_{i=0}^{N-1} input1[i] * x^i +//! c(x) = \sum_{i=0}^{N-1} output[i] * x^i +//! s(x) = \sum_i^{2N-3} aux[i] * x^i //! //! Because A, B and C are 256-bit numbers, the degrees of a, b and c -//! are (at most) 15. Thus deg(a*b) <= 30, so deg(m) <= 14 and deg(q) -//! <= 29. However, the fact that we're verifying the equality modulo -//! 2^256 means that we can ignore terms of degree >= 16, since for -//! them evaluating at β gives a factor of β^16 = 2^256 which is 0. +//! are (at most) 15. Thus deg(a*b) <= 30 and deg(s) <= 29; however, +//! 
as we're only verifying the lower half of A*B, we only need to +//! know s(x) up to degree 14 (so that (x - β)*s(x) has degree 15). On +//! the other hand, the coefficients of s(x) can be as large as +//! 16*(β-2) or 20 bits. //! -//! Hence, to verify the equality, we don't need m(x) at all, and we -//! only need to know q(x) up to degree 14 (so that (β - x)*q(x) has -//! degree 15). On the other hand, the coefficients of q(x) can be as -//! large as 16*(β-2) or 20 bits. +//! Note that, unlike for the general modular multiplication (see the +//! file `modular.rs`), we don't need to check that output is reduced, +//! since any value of output is less than β^16 and is hence reduced. use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -35,64 +62,42 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::columns::*; -use crate::arithmetic::utils::{pol_mul_lo, pol_sub_assign}; +use crate::arithmetic::utils::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = MUL_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0_limbs = MUL_INPUT_0.map(|c| lv[c].to_canonical_u64() as i64); + let input1_limbs = MUL_INPUT_1.map(|c| lv[c].to_canonical_u64() as i64); - const MASK: u64 = (1u64 << LIMB_BITS) - 1u64; + const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; // Input and output have 16-bit limbs - let mut aux_in_limbs = [0u64; N_LIMBS]; - let mut output_limbs = [0u64; N_LIMBS]; + let mut output_limbs = [0i64; N_LIMBS]; // Column-wise pen-and-paper long multiplication on 16-bit limbs. // First calculate the coefficients of a(x)*b(x) (in unreduced_prod), // then do carry propagation to obtain C = c(β) = a(β)*b(β). 
- let mut cy = 0u64; + let mut cy = 0i64; let mut unreduced_prod = pol_mul_lo(input0_limbs, input1_limbs); for col in 0..N_LIMBS { let t = unreduced_prod[col] + cy; cy = t >> LIMB_BITS; output_limbs[col] = t & MASK; } - // In principle, the last cy could be dropped because this is // multiplication modulo 2^256. However, we need it below for - // aux_in_limbs to handle the fact that unreduced_prod will - // inevitably contain a one digit's worth that is > 2^256. + // aux_limbs to handle the fact that unreduced_prod will + // inevitably contain one digit's worth that is > 2^256. - for (&c, output_limb) in MUL_OUTPUT.iter().zip(output_limbs) { - lv[c] = F::from_canonical_u64(output_limb); - } pol_sub_assign(&mut unreduced_prod, &output_limbs); - // unreduced_prod is the coefficients of the polynomial a(x)*b(x) - c(x). - // This must be zero when evaluated at x = β = 2^LIMB_BITS, hence it's - // divisible by (β - x). If we write unreduced_prod as - // - // a(x)*b(x) - c(x) = \sum_{i=0}^n p_i x^i + terms of degree > n - // = (β - x) \sum_{i=0}^{n-1} q_i x^i + terms of degree > n - // - // then by comparing coefficients it is easy to see that - // - // q_0 = p_0 / β and q_i = (p_i + q_{i-1}) / β - // - // for 0 < i < n-1 (and the divisions are exact). Because we're - // only calculating the result modulo 2^256, we can ignore the - // terms of degree > n = 15. 
- aux_in_limbs[0] = unreduced_prod[0] >> LIMB_BITS; - for deg in 1..N_LIMBS - 1 { - aux_in_limbs[deg] = (unreduced_prod[deg] + aux_in_limbs[deg - 1]) >> LIMB_BITS; - } - aux_in_limbs[N_LIMBS - 1] = cy; + let mut aux_limbs = pol_remove_root_2exp::(unreduced_prod); + aux_limbs[N_LIMBS - 1] = -cy; for deg in 0..N_LIMBS { - let c = MUL_AUX_INPUT[deg]; - lv[c] = F::from_canonical_u64(aux_in_limbs[deg]); + lv[MUL_OUTPUT[deg]] = F::from_canonical_i64(output_limbs[deg]); + lv[MUL_AUX_INPUT[deg]] = F::from_noncanonical_i64(aux_limbs[deg]); } } @@ -115,29 +120,26 @@ pub fn eval_packed_generic( // must be identically zero for this multiplication to be // verified. // - // These two lines set constr_poly to the polynomial A(x)B(x) - C(x), - // where A, B and C are the polynomials + // These two lines set constr_poly to the polynomial a(x)b(x) - c(x), + // where a, b and c are the polynomials // - // A(x) = \sum_i input0_limbs[i] * 2^LIMB_BITS - // B(x) = \sum_i input1_limbs[i] * 2^LIMB_BITS - // C(x) = \sum_i output_limbs[i] * 2^LIMB_BITS + // a(x) = \sum_i input0_limbs[i] * x^i + // b(x) = \sum_i input1_limbs[i] * x^i + // c(x) = \sum_i output_limbs[i] * x^i // - // This polynomial should equal (2^LIMB_BITS - x) * Q(x) where Q is + // This polynomial should equal (x - β) * s(x) where s is // - // Q(x) = \sum_i aux_limbs[i] * 2^LIMB_BITS + // s(x) = \sum_i aux_limbs[i] * x^i // let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs); pol_sub_assign(&mut constr_poly, &output_limbs); - // This subtracts (2^LIMB_BITS - x) * Q(x) from constr_poly. + // This subtracts (x - β) * s(x) from constr_poly. let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); - constr_poly[0] -= base * aux_limbs[0]; - for deg in 1..N_LIMBS { - constr_poly[deg] -= (base * aux_limbs[deg]) - aux_limbs[deg - 1]; - } + pol_sub_assign(&mut constr_poly, &pol_adjoin_root(aux_limbs, base)); // At this point constr_poly holds the coefficients of the - // polynomial A(x)B(x) - C(x) - (2^LIMB_BITS - x)*Q(x).
The + // polynomial a(x)b(x) - c(x) - (x - β)*s(x). The // multiplication is valid if and only if all of those // coefficients are zero. for &c in &constr_poly { @@ -154,37 +156,14 @@ pub fn eval_ext_circuit, const D: usize>( let input0_limbs = MUL_INPUT_0.map(|c| lv[c]); let input1_limbs = MUL_INPUT_1.map(|c| lv[c]); let output_limbs = MUL_OUTPUT.map(|c| lv[c]); - let aux_in_limbs = MUL_AUX_INPUT.map(|c| lv[c]); + let aux_limbs = MUL_AUX_INPUT.map(|c| lv[c]); - let zero = builder.zero_extension(); - let mut constr_poly = [zero; N_LIMBS]; // pointless init + let mut constr_poly = pol_mul_lo_ext_circuit(builder, input0_limbs, input1_limbs); + pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs); - // Invariant: i + j = deg - for col in 0..N_LIMBS { - let mut acc = zero; - for i in 0..=col { - let j = col - i; - acc = builder.mul_add_extension(input0_limbs[i], input1_limbs[j], acc); - } - constr_poly[col] = builder.sub_extension(acc, output_limbs[col]); - } - - let base = F::from_canonical_u64(1 << LIMB_BITS); - let one = builder.one_extension(); - // constr_poly[0] = constr_poly[0] - base * aux_in_limbs[0] - constr_poly[0] = - builder.arithmetic_extension(F::ONE, -base, constr_poly[0], one, aux_in_limbs[0]); - for deg in 1..N_LIMBS { - // constr_poly[deg] -= (base*aux_in_limbs[deg] - aux_in_limbs[deg-1]) - let t = builder.arithmetic_extension( - base, - F::NEG_ONE, - aux_in_limbs[deg], - one, - aux_in_limbs[deg - 1], - ); - constr_poly[deg] = builder.sub_extension(constr_poly[deg], t); - } + let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS)); + let rhs = pol_adjoin_root_ext_circuit(builder, aux_limbs, base); + pol_sub_assign_ext_circuit(builder, &mut constr_poly, &rhs); for &c in &constr_poly { let filter = builder.mul_extension(is_mul, c); diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index b5356a78..ccb8bc0a 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ 
-225,6 +225,22 @@ where res } +pub(crate) fn pol_mul_lo_ext_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: [ExtensionTarget; N_LIMBS], + b: [ExtensionTarget; N_LIMBS], +) -> [ExtensionTarget; N_LIMBS] { + let zero = builder.zero_extension(); + let mut res = [zero; N_LIMBS]; + for deg in 0..N_LIMBS { + for i in 0..=deg { + let j = deg - i; + res[deg] = builder.mul_add_extension(a[i], b[j], res[deg]); + } + } + res +} + /// Adjoin M - N zeros to a, returning [a[0], a[1], ..., a[N-1], 0, 0, ..., 0]. pub(crate) fn pol_extend(a: [T; N]) -> [T; M] where @@ -248,11 +264,9 @@ pub(crate) fn pol_extend_ext_circuit, const D: usiz zero_extend } -/// Given polynomial a(x) = \sum_{i=0}^{2N-2} a[i] x^i and an element +/// Given polynomial a(x) = \sum_{i=0}^{N-2} a[i] x^i and an element /// `root`, return b = (x - root) * a(x). -/// -/// NB: Ignores element a[2 * N_LIMBS - 1], treating it as if it's 0. -pub(crate) fn pol_adjoin_root(a: [T; 2 * N_LIMBS], root: U) -> [T; 2 * N_LIMBS] +pub(crate) fn pol_adjoin_root(a: [T; N], root: U) -> [T; N] where T: Add + Copy + Default + Mul + Sub, U: Copy + Mul + Neg, @@ -261,66 +275,64 @@ where // coefficients, res[0] = -root*a[0] and // res[i] = a[i-1] - root * a[i] - let mut res = [T::default(); 2 * N_LIMBS]; + let mut res = [T::default(); N]; res[0] = -root * a[0]; - for deg in 1..(2 * N_LIMBS - 1) { + for deg in 1..N { res[deg] = a[deg - 1] - (root * a[deg]); } - // NB: We assume that a[2 * N_LIMBS - 1] = 0, so the last - // iteration has no "* root" term. 
- res[2 * N_LIMBS - 1] = a[2 * N_LIMBS - 2]; res } -pub(crate) fn pol_adjoin_root_ext_circuit, const D: usize>( +pub(crate) fn pol_adjoin_root_ext_circuit< + F: RichField + Extendable, + const D: usize, + const N: usize, +>( builder: &mut CircuitBuilder, - a: [ExtensionTarget; 2 * N_LIMBS], + a: [ExtensionTarget; N], root: ExtensionTarget, -) -> [ExtensionTarget; 2 * N_LIMBS] { +) -> [ExtensionTarget; N] { let zero = builder.zero_extension(); - let mut res = [zero; 2 * N_LIMBS]; + let mut res = [zero; N]; // res[deg] = NEG_ONE * root * a[0] + ZERO * zero res[0] = builder.arithmetic_extension(F::NEG_ONE, F::ZERO, root, a[0], zero); - for deg in 1..(2 * N_LIMBS - 1) { + for deg in 1..N { // res[deg] = NEG_ONE * root * a[deg] + ONE * a[deg - 1] res[deg] = builder.arithmetic_extension(F::NEG_ONE, F::ONE, root, a[deg], a[deg - 1]); } - // NB: We assumes that a[2 * N_LIMBS - 1] = 0, so the last - // iteration has no "* root" term. - res[2 * N_LIMBS - 1] = a[2 * N_LIMBS - 2]; res } -/// Given polynomial a(x) = \sum_{i=0}^{2N-1} a[i] x^i and a root of `a` +/// Given polynomial a(x) = \sum_{i=0}^{N-1} a[i] x^i and a root of `a` /// of the form 2^EXP, return q(x) satisfying a(x) = (x - root) * q(x). /// /// NB: We do not verify that a(2^EXP) = 0; if this doesn't hold the /// result is basically junk. /// -/// NB: The result could be returned in 2*N-1 elements, but we return -/// 2*N and set the last element to zero since the calling code -/// happens to require a result zero-extended to 2*N elements. -pub(crate) fn pol_remove_root_2exp(a: [T; 2 * N_LIMBS]) -> [T; 2 * N_LIMBS] +/// NB: The result could be returned in N-1 elements, but we return +/// N and set the last element to zero since the calling code +/// happens to require a result zero-extended to N elements. +pub(crate) fn pol_remove_root_2exp(a: [T; N]) -> [T; N] where T: Copy + Default + Neg + Shr + Sub, { // By assumption β := 2^EXP is a root of `a`, i.e. 
(x - β) divides // `a`; if we write // - // a(x) = \sum_{i=0}^{2N-1} a[i] x^i - // = (x - β) \sum_{i=0}^{2N-2} q[i] x^i + // a(x) = \sum_{i=0}^{N-1} a[i] x^i + // = (x - β) \sum_{i=0}^{N-2} q[i] x^i // // then by comparing coefficients it is easy to see that // // q[0] = -a[0] / β and q[i] = (q[i-1] - a[i]) / β // - // for 0 < i <= 2N-1 (and the divisions are exact). + // for 0 < i <= N-1 (and the divisions are exact). - let mut q = [T::default(); 2 * N_LIMBS]; + let mut q = [T::default(); N]; q[0] = -(a[0] >> EXP); // NB: Last element of q is deliberately left equal to zero. - for deg in 1..2 * N_LIMBS - 1 { + for deg in 1..N - 1 { q[deg] = (q[deg - 1] - a[deg]) >> EXP; } q From 68a5428500966679b746a096b61442b57e790be0 Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Wed, 12 Oct 2022 02:39:13 +1100 Subject: [PATCH 12/17] Represent input columns as ranges rather than arrays (#776) * Use std::ops::Range of columns rather than arrays of column indices. * Refactor reading from the local values table. * The inevitable post-push fmt/clippy commit. 
--- evm/src/arithmetic/add.rs | 40 +++++++++-------- evm/src/arithmetic/columns.rs | 79 ++++++++++++++------------------- evm/src/arithmetic/compare.rs | 47 ++++++++++---------- evm/src/arithmetic/modular.rs | 83 ++++++++++++++--------------------- evm/src/arithmetic/mul.rs | 40 ++++++++--------- evm/src/arithmetic/sub.rs | 39 +++++++++------- evm/src/arithmetic/utils.rs | 45 ++++++++++++++++--- 7 files changed, 192 insertions(+), 181 deletions(-) diff --git a/evm/src/arithmetic/add.rs b/evm/src/arithmetic/add.rs index d2520fb9..1bf798cc 100644 --- a/evm/src/arithmetic/add.rs +++ b/evm/src/arithmetic/add.rs @@ -6,6 +6,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::columns::*; +use crate::arithmetic::utils::read_value_u64_limbs; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; @@ -94,15 +95,12 @@ where } pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = ADD_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = ADD_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_u64_limbs(lv, ADD_INPUT_0); + let input1 = read_value_u64_limbs(lv, ADD_INPUT_1); // Input and output have 16-bit limbs - let (output_limbs, _) = u256_add_cc(input0_limbs, input1_limbs); - - for (&c, output_limb) in ADD_OUTPUT.iter().zip(output_limbs) { - lv[c] = F::from_canonical_u64(output_limb); - } + let (output_limbs, _) = u256_add_cc(input0, input1); + lv[ADD_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_u64(c))); } pub fn eval_packed_generic( @@ -114,15 +112,20 @@ pub fn eval_packed_generic( range_check_error!(ADD_OUTPUT, 16); let is_add = lv[IS_ADD]; - let input0_limbs = ADD_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = ADD_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = ADD_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[ADD_INPUT_0]; + let input1_limbs = &lv[ADD_INPUT_1]; + let 
output_limbs = &lv[ADD_OUTPUT]; // This computed output is not yet reduced; i.e. some limbs may be // more than 16 bits. - let output_computed = input0_limbs.zip(input1_limbs).map(|(a, b)| a + b); + let output_computed = input0_limbs.iter().zip(input1_limbs).map(|(&a, &b)| a + b); - eval_packed_generic_are_equal(yield_constr, is_add, output_computed, output_limbs); + eval_packed_generic_are_equal( + yield_constr, + is_add, + output_computed, + output_limbs.iter().copied(), + ); } #[allow(clippy::needless_collect)] @@ -132,17 +135,18 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_add = lv[IS_ADD]; - let input0_limbs = ADD_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = ADD_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = ADD_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[ADD_INPUT_0]; + let input1_limbs = &lv[ADD_INPUT_1]; + let output_limbs = &lv[ADD_OUTPUT]; // Since `map` is lazy and the closure passed to it borrows // `builder`, we can't then borrow builder again below in the call // to `eval_ext_circuit_are_equal`. The solution is to force // evaluation with `collect`. 
let output_computed = input0_limbs + .iter() .zip(input1_limbs) - .map(|(a, b)| builder.add_extension(a, b)) + .map(|(&a, &b)| builder.add_extension(a, b)) .collect::>>(); eval_ext_circuit_are_equal( @@ -150,7 +154,7 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr, is_add, output_computed.into_iter(), - output_limbs, + output_limbs.iter().copied(), ); } @@ -203,7 +207,7 @@ mod tests { for _ in 0..N_RND_TESTS { // set inputs to random values - for (&ai, bi) in ADD_INPUT_0.iter().zip(ADD_INPUT_1) { + for (ai, bi) in ADD_INPUT_0.zip(ADD_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/columns.rs b/evm/src/arithmetic/columns.rs index ca8ba549..ee73f223 100644 --- a/evm/src/arithmetic/columns.rs +++ b/evm/src/arithmetic/columns.rs @@ -1,5 +1,7 @@ //! Arithmetic unit +use std::ops::Range; + pub const LIMB_BITS: usize = 16; const EVM_REGISTER_BITS: usize = 256; @@ -44,57 +46,42 @@ pub(crate) const ALL_OPERATIONS: [usize; 16] = [ /// used by any arithmetic circuit, depending on which one is active /// this cycle. Can be increased as needed as other operations are /// implemented. 
-const NUM_SHARED_COLS: usize = 144; // only need 64 for add, sub, and mul +const NUM_SHARED_COLS: usize = 9 * N_LIMBS; // only need 64 for add, sub, and mul -const fn shared_col(i: usize) -> usize { - assert!(i < NUM_SHARED_COLS); - START_SHARED_COLS + i -} +const GENERAL_INPUT_0: Range = START_SHARED_COLS..START_SHARED_COLS + N_LIMBS; +const GENERAL_INPUT_1: Range = GENERAL_INPUT_0.end..GENERAL_INPUT_0.end + N_LIMBS; +const GENERAL_INPUT_2: Range = GENERAL_INPUT_1.end..GENERAL_INPUT_1.end + N_LIMBS; +const GENERAL_INPUT_3: Range = GENERAL_INPUT_2.end..GENERAL_INPUT_2.end + N_LIMBS; +const AUX_INPUT_0: Range = GENERAL_INPUT_3.end..GENERAL_INPUT_3.end + 2 * N_LIMBS; +const AUX_INPUT_1: Range = AUX_INPUT_0.end..AUX_INPUT_0.end + 2 * N_LIMBS; +const AUX_INPUT_2: Range = AUX_INPUT_1.end..AUX_INPUT_1.end + N_LIMBS; -const fn gen_input_cols(start: usize) -> [usize; N] { - let mut cols = [0usize; N]; - let mut i = 0; - while i < N { - cols[i] = shared_col(start + i); - i += 1; - } - cols -} +pub(crate) const ADD_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const ADD_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const ADD_OUTPUT: Range = GENERAL_INPUT_2; -const GENERAL_INPUT_0: [usize; N_LIMBS] = gen_input_cols::(0); -const GENERAL_INPUT_1: [usize; N_LIMBS] = gen_input_cols::(N_LIMBS); -const GENERAL_INPUT_2: [usize; N_LIMBS] = gen_input_cols::(2 * N_LIMBS); -const GENERAL_INPUT_3: [usize; N_LIMBS] = gen_input_cols::(3 * N_LIMBS); -const AUX_INPUT_0: [usize; 2 * N_LIMBS] = gen_input_cols::<{ 2 * N_LIMBS }>(4 * N_LIMBS); -const AUX_INPUT_1: [usize; 2 * N_LIMBS] = gen_input_cols::<{ 2 * N_LIMBS }>(6 * N_LIMBS); -const AUX_INPUT_2: [usize; N_LIMBS] = gen_input_cols::(8 * N_LIMBS); +pub(crate) const SUB_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const SUB_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const SUB_OUTPUT: Range = GENERAL_INPUT_2; -pub(crate) const ADD_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const ADD_INPUT_1: [usize; N_LIMBS] = 
GENERAL_INPUT_1; -pub(crate) const ADD_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; +pub(crate) const MUL_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const MUL_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const MUL_OUTPUT: Range = GENERAL_INPUT_2; +pub(crate) const MUL_AUX_INPUT: Range = GENERAL_INPUT_3; -pub(crate) const SUB_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const SUB_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const SUB_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; +pub(crate) const CMP_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const CMP_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2.start; +pub(crate) const CMP_AUX_INPUT: Range = GENERAL_INPUT_3; -pub(crate) const MUL_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const MUL_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const MUL_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_2; -pub(crate) const MUL_AUX_INPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; - -pub(crate) const CMP_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const CMP_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const CMP_OUTPUT: usize = GENERAL_INPUT_2[0]; -pub(crate) const CMP_AUX_INPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; - -pub(crate) const MODULAR_INPUT_0: [usize; N_LIMBS] = GENERAL_INPUT_0; -pub(crate) const MODULAR_INPUT_1: [usize; N_LIMBS] = GENERAL_INPUT_1; -pub(crate) const MODULAR_MODULUS: [usize; N_LIMBS] = GENERAL_INPUT_2; -pub(crate) const MODULAR_OUTPUT: [usize; N_LIMBS] = GENERAL_INPUT_3; -pub(crate) const MODULAR_QUO_INPUT: [usize; 2 * N_LIMBS] = AUX_INPUT_0; -// NB: Last value is not used in AUX, it is used in IS_ZERO -pub(crate) const MODULAR_AUX_INPUT: [usize; 2 * N_LIMBS] = AUX_INPUT_1; -pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1[2 * N_LIMBS - 1]; -pub(crate) const MODULAR_OUT_AUX_RED: [usize; N_LIMBS] = AUX_INPUT_2; +pub(crate) const MODULAR_INPUT_0: Range = GENERAL_INPUT_0; +pub(crate) const 
MODULAR_INPUT_1: Range = GENERAL_INPUT_1; +pub(crate) const MODULAR_MODULUS: Range = GENERAL_INPUT_2; +pub(crate) const MODULAR_OUTPUT: Range = GENERAL_INPUT_3; +pub(crate) const MODULAR_QUO_INPUT: Range = AUX_INPUT_0; +// NB: Last value is not used in AUX, it is used in MOD_IS_ZERO +pub(crate) const MODULAR_AUX_INPUT: Range = AUX_INPUT_1; +pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1; +pub(crate) const MODULAR_OUT_AUX_RED: Range = AUX_INPUT_2; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; diff --git a/evm/src/arithmetic/compare.rs b/evm/src/arithmetic/compare.rs index a6566db5..55dc5764 100644 --- a/evm/src/arithmetic/compare.rs +++ b/evm/src/arithmetic/compare.rs @@ -22,12 +22,13 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::add::{eval_ext_circuit_are_equal, eval_packed_generic_are_equal}; use crate::arithmetic::columns::*; use crate::arithmetic::sub::u256_sub_br; +use crate::arithmetic::utils::read_value_u64_limbs; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize) { - let input0 = CMP_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1 = CMP_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_u64_limbs(lv, CMP_INPUT_0); + let input1 = read_value_u64_limbs(lv, CMP_INPUT_1); let (diff, br) = match op { // input0 - input1 == diff + br*2^256 @@ -39,9 +40,7 @@ pub(crate) fn generate(lv: &mut [F; NUM_ARITH_COLUMNS], op: usize) _ => panic!("op code not a comparison"), }; - for (&c, diff_limb) in CMP_AUX_INPUT.iter().zip(diff) { - lv[c] = F::from_canonical_u64(diff_limb); - } + lv[CMP_AUX_INPUT].copy_from_slice(&diff.map(|c| F::from_canonical_u64(c))); lv[CMP_OUTPUT] = F::from_canonical_u64(br); } @@ -56,15 +55,17 @@ fn eval_packed_generic_check_is_one_bit( pub(crate) fn eval_packed_generic_lt( yield_constr: &mut ConstraintConsumer

, is_op: P, - input0: [P; N_LIMBS], - input1: [P; N_LIMBS], - aux: [P; N_LIMBS], + input0: &[P], + input1: &[P], + aux: &[P], output: P, ) { + debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); + // Verify (input0 < input1) == output by providing aux such that // input0 - input1 == aux + output*2^256. - let lhs_limbs = input0.iter().zip(input1).map(|(&a, b)| a - b); - let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.into_iter(), lhs_limbs); + let lhs_limbs = input0.iter().zip(input1).map(|(&a, &b)| a - b); + let cy = eval_packed_generic_are_equal(yield_constr, is_op, aux.iter().copied(), lhs_limbs); // We don't need to check that cy is 0 or 1, since output has // already been checked to be 0 or 1. yield_constr.constraint(is_op * (cy - output)); @@ -81,9 +82,9 @@ pub fn eval_packed_generic( let is_lt = lv[IS_LT]; let is_gt = lv[IS_GT]; - let input0 = CMP_INPUT_0.map(|c| lv[c]); - let input1 = CMP_INPUT_1.map(|c| lv[c]); - let aux = CMP_AUX_INPUT.map(|c| lv[c]); + let input0 = &lv[CMP_INPUT_0]; + let input1 = &lv[CMP_INPUT_1]; + let aux = &lv[CMP_AUX_INPUT]; let output = lv[CMP_OUTPUT]; let is_cmp = is_lt + is_gt; @@ -109,11 +110,13 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, yield_constr: &mut RecursiveConstraintConsumer, is_op: ExtensionTarget, - input0: [ExtensionTarget; N_LIMBS], - input1: [ExtensionTarget; N_LIMBS], - aux: [ExtensionTarget; N_LIMBS], + input0: &[ExtensionTarget], + input1: &[ExtensionTarget], + aux: &[ExtensionTarget], output: ExtensionTarget, ) { + debug_assert!(input0.len() == N_LIMBS && input1.len() == N_LIMBS && aux.len() == N_LIMBS); + // Since `map` is lazy and the closure passed to it borrows // `builder`, we can't then borrow builder again below in the call // to `eval_ext_circuit_are_equal`. 
The solution is to force @@ -121,14 +124,14 @@ pub(crate) fn eval_ext_circuit_lt, const D: usize>( let lhs_limbs = input0 .iter() .zip(input1) - .map(|(&a, b)| builder.sub_extension(a, b)) + .map(|(&a, &b)| builder.sub_extension(a, b)) .collect::>>(); let cy = eval_ext_circuit_are_equal( builder, yield_constr, is_op, - aux.into_iter(), + aux.iter().copied(), lhs_limbs.into_iter(), ); let good_output = builder.sub_extension(cy, output); @@ -144,9 +147,9 @@ pub fn eval_ext_circuit, const D: usize>( let is_lt = lv[IS_LT]; let is_gt = lv[IS_GT]; - let input0 = CMP_INPUT_0.map(|c| lv[c]); - let input1 = CMP_INPUT_1.map(|c| lv[c]); - let aux = CMP_AUX_INPUT.map(|c| lv[c]); + let input0 = &lv[CMP_INPUT_0]; + let input1 = &lv[CMP_INPUT_1]; + let aux = &lv[CMP_AUX_INPUT]; let output = lv[CMP_OUTPUT]; let is_cmp = builder.add_extension(is_lt, is_gt); @@ -210,7 +213,7 @@ mod tests { lv[other_op] = F::ZERO; // set inputs to random values - for (&ai, bi) in CMP_INPUT_0.iter().zip(CMP_INPUT_1) { + for (ai, bi) in CMP_INPUT_0.zip(CMP_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/modular.rs b/evm/src/arithmetic/modular.rs index fd2a2e28..53051cda 100644 --- a/evm/src/arithmetic/modular.rs +++ b/evm/src/arithmetic/modular.rs @@ -160,9 +160,9 @@ fn generate_modular_op( ) { // Inputs are all range-checked in [0, 2^16), so the "as i64" // conversion is safe. - let input0_limbs = MODULAR_INPUT_0.map(|c| F::to_canonical_u64(&lv[c]) as i64); - let input1_limbs = MODULAR_INPUT_1.map(|c| F::to_canonical_u64(&lv[c]) as i64); - let mut modulus_limbs = MODULAR_MODULUS.map(|c| F::to_canonical_u64(&lv[c]) as i64); + let input0_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_0); + let input1_limbs = read_value_i64_limbs(lv, MODULAR_INPUT_1); + let mut modulus_limbs = read_value_i64_limbs(lv, MODULAR_MODULUS); // The use of BigUints is just to avoid having to implement // modular reduction. 
@@ -175,12 +175,11 @@ fn generate_modular_op( let mut constr_poly = [0i64; 2 * N_LIMBS]; constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); + let mut mod_is_zero = F::ZERO; if modulus.is_zero() { modulus += 1u32; modulus_limbs[0] += 1i64; - lv[MODULAR_MOD_IS_ZERO] = F::ONE; - } else { - lv[MODULAR_MOD_IS_ZERO] = F::ZERO; + mod_is_zero = F::ONE; } let input = columns_to_biguint(&constr_poly); @@ -214,19 +213,11 @@ fn generate_modular_op( // the result of removing that root. let aux_limbs = pol_remove_root_2exp::(constr_poly); - for deg in 0..N_LIMBS { - lv[MODULAR_OUTPUT[deg]] = F::from_canonical_i64(output_limbs[deg]); - lv[MODULAR_OUT_AUX_RED[deg]] = F::from_canonical_i64(out_aux_red[deg]); - lv[MODULAR_QUO_INPUT[deg]] = F::from_canonical_i64(quot_limbs[deg]); - lv[MODULAR_QUO_INPUT[deg + N_LIMBS]] = F::from_canonical_i64(quot_limbs[deg + N_LIMBS]); - lv[MODULAR_AUX_INPUT[deg]] = F::from_noncanonical_i64(aux_limbs[deg]); - // Don't overwrite MODULAR_MOD_IS_ZERO, which is at the last - // index of MODULAR_AUX_INPUT - if deg < N_LIMBS - 1 { - lv[MODULAR_AUX_INPUT[deg + N_LIMBS]] = - F::from_noncanonical_i64(aux_limbs[deg + N_LIMBS]); - } - } + lv[MODULAR_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); + lv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(|c| F::from_canonical_i64(c))); + lv[MODULAR_QUO_INPUT].copy_from_slice(&quot_limbs.map(|c| F::from_canonical_i64(c))); + lv[MODULAR_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c))); + lv[MODULAR_MOD_IS_ZERO] = mod_is_zero; } /// Generate the output and auxiliary values for modular operations. 
@@ -262,7 +253,7 @@ fn modular_constr_poly( range_check_error!(MODULAR_AUX_INPUT, 20, signed); range_check_error!(MODULAR_OUTPUT, 16); - let mut modulus = MODULAR_MODULUS.map(|c| lv[c]); + let mut modulus = read_value::(lv, MODULAR_MODULUS); let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; // Check that mod_is_zero is zero or one @@ -277,22 +268,22 @@ fn modular_constr_poly( // modulus = 0. modulus[0] += mod_is_zero; - let output = MODULAR_OUTPUT.map(|c| lv[c]); + let output = &lv[MODULAR_OUTPUT]; // Verify that the output is reduced, i.e. output < modulus. - let out_aux_red = MODULAR_OUT_AUX_RED.map(|c| lv[c]); + let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; let is_less_than = P::ONES; eval_packed_generic_lt( yield_constr, filter, output, - modulus, + &modulus, out_aux_red, is_less_than, ); // prod = q(x) * m(x) - let quot = MODULAR_QUO_INPUT.map(|c| lv[c]); + let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let prod = pol_mul_wide2(quot, modulus); // higher order terms must be zero for &x in prod[2 * N_LIMBS..].iter() { @@ -301,10 +292,10 @@ fn modular_constr_poly( // constr_poly = c(x) + q(x) * m(x) let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); - pol_add_assign(&mut constr_poly, &output); + pol_add_assign(&mut constr_poly, output); // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) - let mut aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); aux[2 * N_LIMBS - 1] = P::ZEROS; // zero out the MOD_IS_ZERO flag let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS); pol_add_assign(&mut constr_poly, &pol_adjoin_root(aux, base)); @@ -324,8 +315,8 @@ pub(crate) fn eval_packed_generic( // constr_poly has 2*N_LIMBS limbs let constr_poly = modular_constr_poly(lv, yield_constr, filter); - let input0 = MODULAR_INPUT_0.map(|c| lv[c]); - let input1 = MODULAR_INPUT_1.map(|c| lv[c]); + let input0 = read_value(lv, MODULAR_INPUT_0); + let input1 = read_value(lv, 
MODULAR_INPUT_1); let add_input = pol_add(input0, input1); let mul_input = pol_mul_wide(input0, input1); @@ -362,7 +353,7 @@ fn modular_constr_poly_ext_circuit, const D: usize> yield_constr: &mut RecursiveConstraintConsumer, filter: ExtensionTarget, ) -> [ExtensionTarget; 2 * N_LIMBS] { - let mut modulus = MODULAR_MODULUS.map(|c| lv[c]); + let mut modulus = read_value::(lv, MODULAR_MODULUS); let mod_is_zero = lv[MODULAR_MOD_IS_ZERO]; let t = builder.mul_sub_extension(mod_is_zero, mod_is_zero, mod_is_zero); @@ -376,20 +367,20 @@ fn modular_constr_poly_ext_circuit, const D: usize> modulus[0] = builder.add_extension(modulus[0], mod_is_zero); - let output = MODULAR_OUTPUT.map(|c| lv[c]); - let out_aux_red = MODULAR_OUT_AUX_RED.map(|c| lv[c]); + let output = &lv[MODULAR_OUTPUT]; + let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; let is_less_than = builder.one_extension(); eval_ext_circuit_lt( builder, yield_constr, filter, output, - modulus, + &modulus, out_aux_red, is_less_than, ); - let quot = MODULAR_QUO_INPUT.map(|c| lv[c]); + let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); for &x in prod[2 * N_LIMBS..].iter() { let t = builder.mul_extension(filter, x); @@ -397,9 +388,9 @@ fn modular_constr_poly_ext_circuit, const D: usize> } let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); - pol_add_assign_ext_circuit(builder, &mut constr_poly, &output); + pol_add_assign_ext_circuit(builder, &mut constr_poly, output); - let mut aux = MODULAR_AUX_INPUT.map(|c| lv[c]); + let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); aux[2 * N_LIMBS - 1] = builder.zero_extension(); let base = builder.constant_extension(F::Extension::from_canonical_u64(1u64 << LIMB_BITS)); let t = pol_adjoin_root_ext_circuit(builder, aux, base); @@ -421,8 +412,8 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let constr_poly = modular_constr_poly_ext_circuit(lv, builder, 
yield_constr, filter); - let input0 = MODULAR_INPUT_0.map(|c| lv[c]); - let input1 = MODULAR_INPUT_1.map(|c| lv[c]); + let input0 = read_value(lv, MODULAR_INPUT_0); + let input1 = read_value(lv, MODULAR_INPUT_1); let add_input = pol_add_ext_circuit(builder, input0, input1); let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1); @@ -498,11 +489,7 @@ mod tests { for i in 0..N_RND_TESTS { // set inputs to random values - for (&ai, &bi, &mi) in izip!( - MODULAR_INPUT_0.iter(), - MODULAR_INPUT_1.iter(), - MODULAR_MODULUS.iter() - ) { + for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); lv[mi] = F::from_canonical_u16(rng.gen()); @@ -514,7 +501,7 @@ mod tests { if i > N_RND_TESTS / 2 { // 1 <= start < N_LIMBS let start = (rng.gen::() % (N_LIMBS - 1)) + 1; - for &mi in &MODULAR_MODULUS[start..N_LIMBS] { + for mi in MODULAR_MODULUS.skip(start) { lv[mi] = F::ZERO; } } @@ -552,11 +539,7 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs to random values and the modulus to zero; // the output is defined to be zero when modulus is zero. 
- for (&ai, &bi, &mi) in izip!( - MODULAR_INPUT_0.iter(), - MODULAR_INPUT_1.iter(), - MODULAR_MODULUS.iter() - ) { + for (ai, bi, mi) in izip!(MODULAR_INPUT_0, MODULAR_INPUT_1, MODULAR_MODULUS) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); lv[mi] = F::ZERO; @@ -565,7 +548,7 @@ mod tests { generate(&mut lv, op_filter); // check that the correct output was generated - assert!(MODULAR_OUTPUT.iter().all(|&oi| lv[oi] == F::ZERO)); + assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO)); let mut constraint_consumer = ConstraintConsumer::new( vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], @@ -580,7 +563,7 @@ mod tests { .all(|&acc| acc == F::ZERO)); // Corrupt one output limb by setting it to a non-zero value - let random_oi = MODULAR_OUTPUT[rng.gen::() % N_LIMBS]; + let random_oi = MODULAR_OUTPUT.start + rng.gen::() % N_LIMBS; lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); eval_packed_generic(&lv, &mut constraint_consumer); diff --git a/evm/src/arithmetic/mul.rs b/evm/src/arithmetic/mul.rs index c98b9af8..7dda18e2 100644 --- a/evm/src/arithmetic/mul.rs +++ b/evm/src/arithmetic/mul.rs @@ -67,8 +67,8 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::range_check_error; pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = MUL_INPUT_0.map(|c| lv[c].to_canonical_u64() as i64); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c].to_canonical_u64() as i64); + let input0 = read_value_i64_limbs(lv, MUL_INPUT_0); + let input1 = read_value_i64_limbs(lv, MUL_INPUT_1); const MASK: i64 = (1i64 << LIMB_BITS) - 1i64; @@ -79,7 +79,7 @@ pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { // First calculate the coefficients of a(x)*b(x) (in unreduced_prod), // then do carry propagation to obtain C = c(β) = a(β)*b(β). 
let mut cy = 0i64; - let mut unreduced_prod = pol_mul_lo(input0_limbs, input1_limbs); + let mut unreduced_prod = pol_mul_lo(input0, input1); for col in 0..N_LIMBS { let t = unreduced_prod[col] + cy; cy = t >> LIMB_BITS; @@ -90,15 +90,13 @@ pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { // aux_limbs to handle the fact that unreduced_prod will // inevitably contain one digit's worth that is > 2^256. + lv[MUL_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_i64(c))); pol_sub_assign(&mut unreduced_prod, &output_limbs); let mut aux_limbs = pol_remove_root_2exp::(unreduced_prod); aux_limbs[N_LIMBS - 1] = -cy; - for deg in 0..N_LIMBS { - lv[MUL_OUTPUT[deg]] = F::from_canonical_i64(output_limbs[deg]); - lv[MUL_AUX_INPUT[deg]] = F::from_noncanonical_i64(aux_limbs[deg]); - } + lv[MUL_AUX_INPUT].copy_from_slice(&aux_limbs.map(|c| F::from_noncanonical_i64(c))); } pub fn eval_packed_generic( @@ -111,10 +109,10 @@ pub fn eval_packed_generic( range_check_error!(MUL_AUX_INPUT, 20); let is_mul = lv[IS_MUL]; - let input0_limbs = MUL_INPUT_0.map(|c| lv[c]); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c]); - let output_limbs = MUL_OUTPUT.map(|c| lv[c]); - let aux_limbs = MUL_AUX_INPUT.map(|c| lv[c]); + let input0_limbs = read_value::(lv, MUL_INPUT_0); + let input1_limbs = read_value::(lv, MUL_INPUT_1); + let output_limbs = read_value::(lv, MUL_OUTPUT); + let aux_limbs = read_value::(lv, MUL_AUX_INPUT); // Constraint poly holds the coefficients of the polynomial that // must be identically zero for this multiplication to be @@ -123,13 +121,13 @@ pub fn eval_packed_generic( // These two lines set constr_poly to the polynomial a(x)b(x) - c(x), // where a, b and c are the polynomials // - // a(x) = \sum_i input0_limbs[i] * β^i - // b(x) = \sum_i input1_limbs[i] * β^i - // c(x) = \sum_i output_limbs[i] * β^i + // a(x) = \sum_i input0_limbs[i] * x^i + // b(x) = \sum_i input1_limbs[i] * x^i + // c(x) = \sum_i output_limbs[i] * x^i // - // This polynomial should equal where 
s is + // This polynomial should equal (x - β)*s(x) where s is // - // s(x) = \sum_i aux_limbs[i] * β^i + // s(x) = \sum_i aux_limbs[i] * x^i // let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs); pol_sub_assign(&mut constr_poly, &output_limbs); @@ -153,10 +151,10 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_mul = lv[IS_MUL]; - let input0_limbs = MUL_INPUT_0.map(|c| lv[c]); - let input1_limbs = MUL_INPUT_1.map(|c| lv[c]); - let output_limbs = MUL_OUTPUT.map(|c| lv[c]); - let aux_limbs = MUL_AUX_INPUT.map(|c| lv[c]); + let input0_limbs = read_value::(lv, MUL_INPUT_0); + let input1_limbs = read_value::(lv, MUL_INPUT_1); + let output_limbs = read_value::(lv, MUL_OUTPUT); + let aux_limbs = read_value::(lv, MUL_AUX_INPUT); let mut constr_poly = pol_mul_lo_ext_circuit(builder, input0_limbs, input1_limbs); pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs); @@ -220,7 +218,7 @@ mod tests { for _i in 0..N_RND_TESTS { // set inputs to random values - for (&ai, bi) in MUL_INPUT_0.iter().zip(MUL_INPUT_1) { + for (ai, bi) in MUL_INPUT_0.zip(MUL_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/sub.rs b/evm/src/arithmetic/sub.rs index 25834406..f8377651 100644 --- a/evm/src/arithmetic/sub.rs +++ b/evm/src/arithmetic/sub.rs @@ -6,6 +6,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use crate::arithmetic::add::{eval_ext_circuit_are_equal, eval_packed_generic_are_equal}; use crate::arithmetic::columns::*; +use crate::arithmetic::utils::read_value_u64_limbs; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::range_check_error; @@ -28,14 +29,12 @@ pub(crate) fn u256_sub_br(input0: [u64; N_LIMBS], input1: [u64; N_LIMBS]) -> ([u } pub fn generate(lv: &mut [F; NUM_ARITH_COLUMNS]) { - let input0_limbs = SUB_INPUT_0.map(|c| lv[c].to_canonical_u64()); - let input1_limbs = 
SUB_INPUT_1.map(|c| lv[c].to_canonical_u64()); + let input0 = read_value_u64_limbs(lv, SUB_INPUT_0); + let input1 = read_value_u64_limbs(lv, SUB_INPUT_1); - let (output_limbs, _) = u256_sub_br(input0_limbs, input1_limbs); + let (output_limbs, _) = u256_sub_br(input0, input1); - for (&c, output_limb) in SUB_OUTPUT.iter().zip(output_limbs) { - lv[c] = F::from_canonical_u64(output_limb); - } + lv[SUB_OUTPUT].copy_from_slice(&output_limbs.map(|c| F::from_canonical_u64(c))); } pub fn eval_packed_generic( @@ -47,13 +46,18 @@ pub fn eval_packed_generic( range_check_error!(SUB_OUTPUT, 16); let is_sub = lv[IS_SUB]; - let input0_limbs = SUB_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = SUB_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = SUB_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[SUB_INPUT_0]; + let input1_limbs = &lv[SUB_INPUT_1]; + let output_limbs = &lv[SUB_OUTPUT]; - let output_computed = input0_limbs.zip(input1_limbs).map(|(a, b)| a - b); + let output_computed = input0_limbs.iter().zip(input1_limbs).map(|(&a, &b)| a - b); - eval_packed_generic_are_equal(yield_constr, is_sub, output_limbs, output_computed); + eval_packed_generic_are_equal( + yield_constr, + is_sub, + output_limbs.iter().copied(), + output_computed, + ); } #[allow(clippy::needless_collect)] @@ -63,24 +67,25 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr: &mut RecursiveConstraintConsumer, ) { let is_sub = lv[IS_SUB]; - let input0_limbs = SUB_INPUT_0.iter().map(|&c| lv[c]); - let input1_limbs = SUB_INPUT_1.iter().map(|&c| lv[c]); - let output_limbs = SUB_OUTPUT.iter().map(|&c| lv[c]); + let input0_limbs = &lv[SUB_INPUT_0]; + let input1_limbs = &lv[SUB_INPUT_1]; + let output_limbs = &lv[SUB_OUTPUT]; // Since `map` is lazy and the closure passed to it borrows // `builder`, we can't then borrow builder again below in the call // to `eval_ext_circuit_are_equal`. The solution is to force // evaluation with `collect`. 
let output_computed = input0_limbs + .iter() .zip(input1_limbs) - .map(|(a, b)| builder.sub_extension(a, b)) + .map(|(&a, &b)| builder.sub_extension(a, b)) .collect::>>(); eval_ext_circuit_are_equal( builder, yield_constr, is_sub, - output_limbs, + output_limbs.iter().copied(), output_computed.into_iter(), ); } @@ -134,7 +139,7 @@ mod tests { for _ in 0..N_RND_TESTS { // set inputs to random values - for (&ai, bi) in SUB_INPUT_0.iter().zip(SUB_INPUT_1) { + for (ai, bi) in SUB_INPUT_0.zip(SUB_INPUT_1) { lv[ai] = F::from_canonical_u16(rng.gen()); lv[bi] = F::from_canonical_u16(rng.gen()); } diff --git a/evm/src/arithmetic/utils.rs b/evm/src/arithmetic/utils.rs index ccb8bc0a..871a9646 100644 --- a/evm/src/arithmetic/utils.rs +++ b/evm/src/arithmetic/utils.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, AddAssign, Mul, Neg, Shr, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Mul, Neg, Range, Shr, Sub, SubAssign}; use log::error; use plonky2::field::extension::Extendable; @@ -6,7 +6,7 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::arithmetic::columns::N_LIMBS; +use crate::arithmetic::columns::{NUM_ARITH_COLUMNS, N_LIMBS}; /// Emit an error message regarding unchecked range assumptions. /// Assumes the values in `cols` are `[cols[0], cols[0] + 1, ..., @@ -14,7 +14,7 @@ use crate::arithmetic::columns::N_LIMBS; pub(crate) fn _range_check_error( file: &str, line: u32, - cols: &[usize], + cols: Range, signedness: &str, ) { error!( @@ -23,8 +23,8 @@ pub(crate) fn _range_check_error( file, RC_BITS, signedness, - cols[0], - cols[0] + cols.len() - 1 + cols.start, + cols.end - 1, ); } @@ -34,7 +34,7 @@ macro_rules! range_check_error { $crate::arithmetic::utils::_range_check_error::<$rc_bits>( file!(), line!(), - &$cols, + $cols, "unsigned", ); }; @@ -42,7 +42,7 @@ macro_rules! 
range_check_error { $crate::arithmetic::utils::_range_check_error::<$rc_bits>( file!(), line!(), - &$cols, + $cols, "signed", ); }; @@ -337,3 +337,34 @@ where } q } + +/// Read the range `value_idxs` of values from `lv` into an array of +/// length `N`. Panics if the length of the range is not `N`. +pub(crate) fn read_value( + lv: &[T; NUM_ARITH_COLUMNS], + value_idxs: Range, +) -> [T; N] { + lv[value_idxs].try_into().unwrap() +} + +/// Read the range `value_idxs` of values from `lv` into an array of +/// length `N`, interpreting the values as `u64`s. Panics if the +/// length of the range is not `N`. +pub(crate) fn read_value_u64_limbs( + lv: &[F; NUM_ARITH_COLUMNS], + value_idxs: Range, +) -> [u64; N] { + let limbs: [_; N] = lv[value_idxs].try_into().unwrap(); + limbs.map(|c| F::to_canonical_u64(&c)) +} + +/// Read the range `value_idxs` of values from `lv` into an array of +/// length `N`, interpreting the values as `i64`s. Panics if the +/// length of the range is not `N`. +pub(crate) fn read_value_i64_limbs( + lv: &[F; NUM_ARITH_COLUMNS], + value_idxs: Range, +) -> [i64; N] { + let limbs: [_; N] = lv[value_idxs].try_into().unwrap(); + limbs.map(|c| F::to_canonical_u64(&c) as i64) +} From f4c0337af7b70e6fec25dff7451b8aebcbcdbd9d Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 10 Oct 2022 23:46:45 -0700 Subject: [PATCH 13/17] Interpreter feature to configure debug offsets --- evm/src/cpu/kernel/interpreter.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 2eb9dcb9..a85f3db9 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -75,6 +75,7 @@ pub struct Interpreter<'a> { pub(crate) generation_state: GenerationState, prover_inputs_map: &'a HashMap, pub(crate) halt_offsets: Vec, + pub(crate) debug_offsets: Vec, running: bool, } @@ -128,6 +129,7 @@ impl<'a> Interpreter<'a> { prover_inputs_map: 
prover_inputs, context: 0, halt_offsets: vec![DEFAULT_HALT_OFFSET], + debug_offsets: vec![], running: false, } } @@ -283,7 +285,7 @@ impl<'a> Interpreter<'a> { 0x55 => todo!(), // "SSTORE", 0x56 => self.run_jump(), // "JUMP", 0x57 => self.run_jumpi(), // "JUMPI", - 0x58 => todo!(), // "GETPC", + 0x58 => self.run_pc(), // "PC", 0x59 => self.run_msize(), // "MSIZE", 0x5a => todo!(), // "GAS", 0x5b => self.run_jumpdest(), // "JUMPDEST", @@ -318,9 +320,24 @@ impl<'a> Interpreter<'a> { 0xff => todo!(), // "SELFDESTRUCT", _ => bail!("Unrecognized opcode {}.", opcode), }; + + if self.debug_offsets.contains(&self.offset) { + println!("At {}, stack={:?}", self.offset_name(), self.stack()); + } + Ok(()) } + /// Get a string representation of the current offset for debugging purposes. + fn offset_name(&self) -> String { + // TODO: Not sure we should use KERNEL? Interpreter is more general in other places. + let label = KERNEL + .global_labels + .iter() + .find_map(|(k, v)| (*v == self.offset).then(|| k.clone())); + label.unwrap_or_else(|| self.offset.to_string()) + } + fn run_stop(&mut self) { self.running = false; } @@ -476,6 +493,7 @@ impl<'a> Interpreter<'a> { let bytes = (offset..offset + size) .map(|i| self.memory.mload_general(context, segment, i).byte(0)) .collect::>(); + println!("Hashing {:?}", &bytes); let hash = keccak(bytes); self.push(U256::from_big_endian(hash.as_bytes())); } @@ -544,6 +562,10 @@ impl<'a> Interpreter<'a> { } } + fn run_pc(&mut self) { + self.push((self.offset - 1).into()); + } + fn run_msize(&mut self) { let num_bytes = self.memory.context_memory[self.context].segments [Segment::MainMemory as usize] From 299aabf860241be96e4eaba4f1d7761289afbfb1 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 11 Oct 2022 08:46:40 -0700 Subject: [PATCH 14/17] Fix branch hashing bug --- evm/src/cpu/kernel/asm/memory/core.asm | 13 +++++ evm/src/cpu/kernel/asm/mpt/hash.asm | 56 ++++++++++++------- .../cpu/kernel/constants/global_metadata.rs | 8 ++- 
evm/src/cpu/kernel/interpreter.rs | 10 +++- evm/src/cpu/kernel/tests/mpt/hash.rs | 6 +- evm/src/cpu/kernel/tests/mpt/insert.rs | 34 ++++++++++- evm/src/memory/segments.rs | 12 +++- 7 files changed, 113 insertions(+), 26 deletions(-) diff --git a/evm/src/cpu/kernel/asm/memory/core.asm b/evm/src/cpu/kernel/asm/memory/core.asm index f4bcf1f1..2b4d2b68 100644 --- a/evm/src/cpu/kernel/asm/memory/core.asm +++ b/evm/src/cpu/kernel/asm/memory/core.asm @@ -55,6 +55,19 @@ // stack: (empty) %endmacro +// Store a single value from the given segment of kernel (context 0) memory. +%macro mstore_kernel(segment, offset) + // stack: value + PUSH $offset + // stack: offset, value + PUSH $segment + // stack: segment, offset, value + PUSH 0 // kernel has context 0 + // stack: context, segment, offset, value + MSTORE_GENERAL + // stack: (empty) +%endmacro + // Load from the kernel a big-endian u32, consisting of 4 bytes (c_3, c_2, c_1, c_0) %macro mload_kernel_u32(segment) // stack: offset diff --git a/evm/src/cpu/kernel/asm/mpt/hash.asm b/evm/src/cpu/kernel/asm/mpt/hash.asm index 511e9d18..8342d650 100644 --- a/evm/src/cpu/kernel/asm/mpt/hash.asm +++ b/evm/src/cpu/kernel/asm/mpt/hash.asm @@ -118,14 +118,22 @@ encode_node_branch: POP // stack: node_payload_ptr, encode_value, retdest + // Get the next unused offset within the encoded child buffers. + // Then immediately increment the next unused offset by 16, so any + // recursive calls will use nonoverlapping offsets. + %mload_global_metadata(@TRIE_ENCODED_CHILD_SIZE) + DUP1 %add_const(16) + %mstore_global_metadata(@TRIE_ENCODED_CHILD_SIZE) + // stack: base_offset, node_payload_ptr, encode_value, retdest + // We will call encode_or_hash_node on each child. For the i'th child, we - // will store the result in SEGMENT_KERNEL_GENERAL[i], and its length in - // SEGMENT_KERNEL_GENERAL_2[i]. + // will store the result in SEGMENT_TRIE_ENCODED_CHILD[base + i], and its length in + // SEGMENT_TRIE_ENCODED_CHILD_LEN[base + i]. 
%encode_child(0) %encode_child(1) %encode_child(2) %encode_child(3) %encode_child(4) %encode_child(5) %encode_child(6) %encode_child(7) %encode_child(8) %encode_child(9) %encode_child(10) %encode_child(11) %encode_child(12) %encode_child(13) %encode_child(14) %encode_child(15) - // stack: node_payload_ptr, encode_value, retdest + // stack: base_offset, node_payload_ptr, encode_value, retdest // Now, append each child to our RLP tape. PUSH 9 // rlp_pos; we start at 9 to leave room to prepend a list prefix @@ -133,6 +141,11 @@ encode_node_branch: %append_child(4) %append_child(5) %append_child(6) %append_child(7) %append_child(8) %append_child(9) %append_child(10) %append_child(11) %append_child(12) %append_child(13) %append_child(14) %append_child(15) + // stack: rlp_pos', base_offset, node_payload_ptr, encode_value, retdest + + // We no longer need base_offset. + SWAP1 + POP // stack: rlp_pos', node_payload_ptr, encode_value, retdest SWAP1 @@ -165,43 +178,44 @@ encode_node_branch_prepend_prefix: JUMP // Part of the encode_node_branch function. Encodes the i'th child. -// Stores the result in SEGMENT_KERNEL_GENERAL[i], and its length in -// SEGMENT_KERNEL_GENERAL_2[i]. +// Stores the result in SEGMENT_TRIE_ENCODED_CHILD[base + i], and its length in +// SEGMENT_TRIE_ENCODED_CHILD_LEN[base + i]. 
%macro encode_child(i) - // stack: node_payload_ptr, encode_value, retdest + // stack: base_offset, node_payload_ptr, encode_value, retdest PUSH %%after_encode - DUP3 DUP3 - // stack: node_payload_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest + DUP4 DUP4 + // stack: node_payload_ptr, encode_value, %%after_encode, base_offset, node_payload_ptr, encode_value, retdest %add_const($i) %mload_trie_data - // stack: child_i_ptr, encode_value, %%after_encode, node_payload_ptr, encode_value, retdest + // stack: child_i_ptr, encode_value, %%after_encode, base_offset, node_payload_ptr, encode_value, retdest %jump(encode_or_hash_node) %%after_encode: - // stack: result, result_len, node_payload_ptr, encode_value, retdest - %mstore_kernel_general($i) - %mstore_kernel_general_2($i) - // stack: node_payload_ptr, encode_value, retdest + // stack: result, result_len, base_offset, node_payload_ptr, encode_value, retdest + DUP3 %add_const($i) %mstore_kernel(@SEGMENT_TRIE_ENCODED_CHILD) + // stack: result_len, base_offset, node_payload_ptr, encode_value, retdest + DUP2 %add_const($i) %mstore_kernel(@SEGMENT_TRIE_ENCODED_CHILD_LEN) + // stack: base_offset, node_payload_ptr, encode_value, retdest %endmacro // Part of the encode_node_branch function. Appends the i'th child's RLP. %macro append_child(i) - // stack: rlp_pos, node_payload_ptr, encode_value, retdest - %mload_kernel_general($i) // load result - %mload_kernel_general_2($i) // load result_len - // stack: result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest + // stack: rlp_pos, base_offset, node_payload_ptr, encode_value, retdest + DUP2 %add_const($i) %mload_kernel(@SEGMENT_TRIE_ENCODED_CHILD) // load result + DUP3 %add_const($i) %mload_kernel(@SEGMENT_TRIE_ENCODED_CHILD_LEN) // load result_len + // stack: result_len, result, rlp_pos, base_offset, node_payload_ptr, encode_value, retdest // If result_len != 32, result is raw RLP, with an appropriate RLP prefix already. 
DUP1 %sub_const(32) %jumpi(%%unpack) // Otherwise, result is a hash, and we need to add the prefix 0x80 + 32 = 160. - // stack: result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest + // stack: result_len, result, rlp_pos, base_offset, node_payload_ptr, encode_value, retdest PUSH 160 DUP4 // rlp_pos %mstore_rlp SWAP2 %increment SWAP2 // rlp_pos += 1 %%unpack: - %stack (result_len, result, rlp_pos, node_payload_ptr, encode_value, retdest) - -> (rlp_pos, result, result_len, %%after_unpacking, node_payload_ptr, encode_value, retdest) + %stack (result_len, result, rlp_pos, base_offset, node_payload_ptr, encode_value, retdest) + -> (rlp_pos, result, result_len, %%after_unpacking, base_offset, node_payload_ptr, encode_value, retdest) %jump(mstore_unpacking_rlp) %%after_unpacking: - // stack: rlp_pos', node_payload_ptr, encode_value, retdest + // stack: rlp_pos', base_offset, node_payload_ptr, encode_value, retdest %endmacro encode_node_extension: diff --git a/evm/src/cpu/kernel/constants/global_metadata.rs b/evm/src/cpu/kernel/constants/global_metadata.rs index f3f34e7a..295cdfd5 100644 --- a/evm/src/cpu/kernel/constants/global_metadata.rs +++ b/evm/src/cpu/kernel/constants/global_metadata.rs @@ -31,10 +31,14 @@ pub(crate) enum GlobalMetadata { StateTrieRootDigestAfter = 11, TransactionTrieRootDigestAfter = 12, ReceiptTrieRootDigestAfter = 13, + + /// The sizes of the `TrieEncodedChild` and `TrieEncodedChildLen` buffers. In other words, the + /// next available offset in these buffers. 
+ TrieEncodedChildSize = 14, } impl GlobalMetadata { - pub(crate) const COUNT: usize = 14; + pub(crate) const COUNT: usize = 15; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -52,6 +56,7 @@ impl GlobalMetadata { Self::StateTrieRootDigestAfter, Self::TransactionTrieRootDigestAfter, Self::ReceiptTrieRootDigestAfter, + Self::TrieEncodedChildSize, ] } @@ -80,6 +85,7 @@ impl GlobalMetadata { GlobalMetadata::ReceiptTrieRootDigestAfter => { "GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_AFTER" } + GlobalMetadata::TrieEncodedChildSize => "TRIE_ENCODED_CHILD_SIZE", } } } diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 2eb9dcb9..5f3c7dcc 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,3 +1,5 @@ +//! An EVM interpreter for testing and debugging purposes. + use std::collections::HashMap; use anyhow::{anyhow, bail, ensure}; @@ -609,7 +611,13 @@ impl<'a> Interpreter<'a> { let segment = Segment::all()[self.pop().as_usize()]; let offset = self.pop().as_usize(); let value = self.pop(); - assert!(value.bits() <= segment.bit_range()); + assert!( + value.bits() <= segment.bit_range(), + "Value {} exceeds {:?} range of {} bits", + value, + segment, + segment.bit_range() + ); self.memory.mstore_general(context, segment, offset, value); } } diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 6b31a523..dd09f350 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -90,7 +90,11 @@ fn mpt_hash_branch_to_leaf() -> Result<()> { value: account_rlp.to_vec(), }; let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); - children[0] = Box::new(leaf); + children[5] = Box::new(PartialTrie::Branch { + children: children.clone(), + value: vec![], + }); + children[3] = Box::new(leaf); let state_trie = PartialTrie::Branch { children, value: vec![], diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs 
b/evm/src/cpu/kernel/tests/mpt/insert.rs index 103bdd2e..218e1681 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -135,7 +135,8 @@ fn mpt_insert_branch_replacing_empty_child() -> Result<()> { } #[test] -fn mpt_insert_extension_to_leaf_same_key() -> Result<()> { +fn mpt_insert_extension_nonoverlapping_keys() -> Result<()> { + // Existing keys are 0xABC, 0xABCDEF; inserted key is 0x12345. let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); children[0xD] = Box::new(PartialTrie::Leaf { nibbles: Nibbles { @@ -154,6 +155,37 @@ fn mpt_insert_extension_to_leaf_same_key() -> Result<()> { value: test_account_1_rlp(), }), }; + let insert = InsertEntry { + nibbles: Nibbles { + count: 5, + packed: 0x12345.into(), + }, + v: test_account_2_rlp(), + }; + test_state_trie(state_trie, insert) +} + +#[test] +fn mpt_insert_extension_insert_key_extends_node_key() -> Result<()> { + // Existing keys are 0xA, 0xABCD; inserted key is 0xABCDEF. + let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + children[0xB] = Box::new(PartialTrie::Leaf { + nibbles: Nibbles { + count: 2, + packed: 0xCD.into(), + }, + value: test_account_1_rlp(), + }); + let state_trie = PartialTrie::Extension { + nibbles: Nibbles { + count: 1, + packed: 0xA.into(), + }, + child: Box::new(PartialTrie::Branch { + children, + value: test_account_1_rlp(), + }), + }; let insert = InsertEntry { nibbles: Nibbles { count: 6, diff --git a/evm/src/memory/segments.rs b/evm/src/memory/segments.rs index 44390a9b..b6254900 100644 --- a/evm/src/memory/segments.rs +++ b/evm/src/memory/segments.rs @@ -38,10 +38,14 @@ pub(crate) enum Segment { /// `StorageTriePointers` with `StorageTrieCheckpointPointers`. /// See also `StateTrieCheckpointPointer`. StorageTrieCheckpointPointers = 15, + /// A buffer used to store the encodings of a branch node's children. 
+ TrieEncodedChild = 16, + /// A buffer used to store the lengths of the encodings of a branch node's children. + TrieEncodedChildLen = 17, } impl Segment { - pub(crate) const COUNT: usize = 16; + pub(crate) const COUNT: usize = 18; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -61,6 +65,8 @@ impl Segment { Self::StorageTrieAddresses, Self::StorageTriePointers, Self::StorageTrieCheckpointPointers, + Self::TrieEncodedChild, + Self::TrieEncodedChildLen, ] } @@ -83,6 +89,8 @@ impl Segment { Segment::StorageTrieAddresses => "SEGMENT_STORAGE_TRIE_ADDRS", Segment::StorageTriePointers => "SEGMENT_STORAGE_TRIE_PTRS", Segment::StorageTrieCheckpointPointers => "SEGMENT_STORAGE_TRIE_CHECKPOINT_PTRS", + Segment::TrieEncodedChild => "SEGMENT_TRIE_ENCODED_CHILD", + Segment::TrieEncodedChildLen => "SEGMENT_TRIE_ENCODED_CHILD_LEN", } } @@ -105,6 +113,8 @@ impl Segment { Segment::StorageTrieAddresses => 160, Segment::StorageTriePointers => 32, Segment::StorageTrieCheckpointPointers => 32, + Segment::TrieEncodedChild => 256, + Segment::TrieEncodedChildLen => 6, } } } From cb2e69a2c9d2247828fd20d11aa5449fa62e59fe Mon Sep 17 00:00:00 2001 From: BGluth Date: Tue, 11 Oct 2022 20:15:33 -0600 Subject: [PATCH 15/17] Updated `eth_trie_utils` to `0.2.0` --- evm/Cargo.toml | 2 +- evm/src/cpu/kernel/tests/mpt/hash.rs | 25 ++-- evm/src/cpu/kernel/tests/mpt/insert.rs | 193 ++++++------------------- evm/src/cpu/kernel/tests/mpt/load.rs | 11 +- evm/src/cpu/kernel/tests/mpt/mod.rs | 21 ++- 5 files changed, 76 insertions(+), 176 deletions(-) diff --git a/evm/Cargo.toml b/evm/Cargo.toml index 848dff15..a3dc09e2 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } -eth_trie_utils = "0.1.0" +eth_trie_utils = "0.2.0" anyhow = "1.0.40" env_logger = "0.9.0" ethereum-types = "0.14.0" diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs 
b/evm/src/cpu/kernel/tests/mpt/hash.rs index dd09f350..de519797 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -1,7 +1,8 @@ use anyhow::Result; -use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use eth_trie_utils::partial_trie::PartialTrie; use ethereum_types::{BigEndianHash, H256, U256}; +use super::nibbles; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::tests::mpt::extension_to_leaf; @@ -33,10 +34,7 @@ fn mpt_hash_leaf() -> Result<()> { let account_rlp = rlp::encode(&account); let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, + nibbles: nibbles(0xABC), value: account_rlp.to_vec(), }; @@ -83,18 +81,17 @@ fn mpt_hash_branch_to_leaf() -> Result<()> { let account_rlp = rlp::encode(&account); let leaf = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, + nibbles: nibbles(0xABC), value: account_rlp.to_vec(), - }; - let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); - children[5] = Box::new(PartialTrie::Branch { + } + .into(); + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[5] = PartialTrie::Branch { children: children.clone(), value: vec![], - }); - children[3] = Box::new(leaf); + } + .into(); + children[3] = leaf; let state_trie = PartialTrie::Branch { children, value: vec![], diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 218e1681..469ad1e4 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -1,8 +1,8 @@ use anyhow::Result; use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; -use eth_trie_utils::trie_builder::InsertEntry; use ethereum_types::{BigEndianHash, H256}; +use super::nibbles; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use 
crate::cpu::kernel::interpreter::Interpreter; @@ -12,218 +12,124 @@ use crate::generation::TrieInputs; #[test] fn mpt_insert_empty() -> Result<()> { - let insert = InsertEntry { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(Default::default(), insert) + test_state_trie(Default::default(), nibbles(0xABC), test_account_2_rlp()) } #[test] fn mpt_insert_leaf_identical_keys() -> Result<()> { - let key = Nibbles { - count: 3, - packed: 0xABC.into(), - }; + let key = nibbles(0xABC); let state_trie = PartialTrie::Leaf { nibbles: key, value: test_account_1_rlp(), }; - let insert = InsertEntry { - nibbles: key, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, key, test_account_2_rlp()) } #[test] fn mpt_insert_leaf_nonoverlapping_keys() -> Result<()> { let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, + nibbles: nibbles(0xABC), value: test_account_1_rlp(), }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 3, - packed: 0x123.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0x123), test_account_2_rlp()) } #[test] fn mpt_insert_leaf_overlapping_keys() -> Result<()> { let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, + nibbles: nibbles(0xABC), value: test_account_1_rlp(), }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 3, - packed: 0xADE.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0xADE), test_account_2_rlp()) } #[test] fn mpt_insert_leaf_insert_key_extends_leaf_key() -> Result<()> { let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, + nibbles: nibbles(0xABC), value: test_account_1_rlp(), }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 
5, - packed: 0xABCDE.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0xABCDE), test_account_2_rlp()) } #[test] fn mpt_insert_leaf_leaf_key_extends_insert_key() -> Result<()> { let state_trie = PartialTrie::Leaf { - nibbles: Nibbles { - count: 5, - packed: 0xABCDE.into(), - }, + nibbles: nibbles(0xABCDE), value: test_account_1_rlp(), }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0xABC), test_account_2_rlp()) } #[test] fn mpt_insert_branch_replacing_empty_child() -> Result<()> { - let children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + let children = std::array::from_fn(|_| PartialTrie::Empty.into()); let state_trie = PartialTrie::Branch { children, value: vec![], }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - v: test_account_2_rlp(), - }; - - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0xABC), test_account_2_rlp()) } #[test] fn mpt_insert_extension_nonoverlapping_keys() -> Result<()> { // Existing keys are 0xABC, 0xABCDEF; inserted key is 0x12345. 
- let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); - children[0xD] = Box::new(PartialTrie::Leaf { - nibbles: Nibbles { - count: 2, - packed: 0xEF.into(), - }, + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[0xD] = PartialTrie::Leaf { + nibbles: nibbles(0xEF), value: test_account_1_rlp(), - }); + } + .into(); let state_trie = PartialTrie::Extension { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - child: Box::new(PartialTrie::Branch { + nibbles: nibbles(0xABC), + child: PartialTrie::Branch { children, value: test_account_1_rlp(), - }), + } + .into(), }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 5, - packed: 0x12345.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0x12345), test_account_2_rlp()) } #[test] fn mpt_insert_extension_insert_key_extends_node_key() -> Result<()> { // Existing keys are 0xA, 0xABCD; inserted key is 0xABCDEF. 
- let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); - children[0xB] = Box::new(PartialTrie::Leaf { - nibbles: Nibbles { - count: 2, - packed: 0xCD.into(), - }, + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[0xB] = PartialTrie::Leaf { + nibbles: nibbles(0xCD), value: test_account_1_rlp(), - }); + } + .into(); let state_trie = PartialTrie::Extension { - nibbles: Nibbles { - count: 1, - packed: 0xA.into(), - }, - child: Box::new(PartialTrie::Branch { + nibbles: nibbles(0xA), + child: PartialTrie::Branch { children, value: test_account_1_rlp(), - }), + } + .into(), }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 6, - packed: 0xABCDEF.into(), - }, - v: test_account_2_rlp(), - }; - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0xABCDEF), test_account_2_rlp()) } #[test] fn mpt_insert_branch_to_leaf_same_key() -> Result<()> { let leaf = PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xBCD.into(), - }, + nibbles: nibbles(0xBCD), value: test_account_1_rlp(), - }; - let mut children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); - children[0xA] = Box::new(leaf); + } + .into(); + let mut children = std::array::from_fn(|_| PartialTrie::Empty.into()); + children[0xA] = leaf; let state_trie = PartialTrie::Branch { children, value: vec![], }; - let insert = InsertEntry { - nibbles: Nibbles { - count: 4, - packed: 0xABCD.into(), - }, - v: test_account_2_rlp(), - }; - - test_state_trie(state_trie, insert) + test_state_trie(state_trie, nibbles(0xABCD), test_account_2_rlp()) } -fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { +fn test_state_trie(state_trie: PartialTrie, k: Nibbles, v: Vec) -> Result<()> { let trie_inputs = TrieInputs { state_trie: state_trie.clone(), transactions_trie: Default::default(), @@ -249,7 +155,7 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { 
trie_data.push(0.into()); } let value_ptr = trie_data.len(); - let account: AccountRlp = rlp::decode(&insert.v).expect("Decoding failed"); + let account: AccountRlp = rlp::decode(&v).expect("Decoding failed"); let account_data = account.to_vec(); trie_data.push(account_data.len().into()); trie_data.extend(account_data); @@ -257,8 +163,8 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); interpreter.push(0xDEADBEEFu32.into()); interpreter.push(value_ptr.into()); // value_ptr - interpreter.push(insert.nibbles.packed); // key - interpreter.push(insert.nibbles.count.into()); // num_nibbles + interpreter.push(k.packed); // key + interpreter.push(k.count.into()); // num_nibbles interpreter.run()?; assert_eq!( @@ -281,18 +187,9 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { ); let hash = H256::from_uint(&interpreter.stack()[0]); - let updated_trie = apply_insert(state_trie, insert); + let updated_trie = state_trie.insert(k, v); let expected_state_trie_hash = updated_trie.calc_hash(); assert_eq!(hash, expected_state_trie_hash); Ok(()) } - -fn apply_insert(trie: PartialTrie, insert: InsertEntry) -> PartialTrie { - let mut trie = Box::new(trie); - if let Some(updated_trie) = PartialTrie::insert_into_trie(&mut trie, insert) { - *updated_trie - } else { - *trie - } -} diff --git a/evm/src/cpu/kernel/tests/mpt/load.rs b/evm/src/cpu/kernel/tests/mpt/load.rs index ccf8353e..0572458d 100644 --- a/evm/src/cpu/kernel/tests/mpt/load.rs +++ b/evm/src/cpu/kernel/tests/mpt/load.rs @@ -1,12 +1,12 @@ use anyhow::Result; -use eth_trie_utils::partial_trie::{Nibbles, PartialTrie}; +use eth_trie_utils::partial_trie::PartialTrie; use ethereum_types::{BigEndianHash, U256}; -use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use 
crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::tests::mpt::{extension_to_leaf, test_account_1, test_account_1_rlp}; +use crate::cpu::kernel::{aggregator::KERNEL, tests::mpt::nibbles}; use crate::generation::mpt::all_mpt_prover_inputs_reversed; use crate::generation::TrieInputs; @@ -54,10 +54,7 @@ fn load_all_mpts_empty() -> Result<()> { fn load_all_mpts_leaf() -> Result<()> { let trie_inputs = TrieInputs { state_trie: PartialTrie::Leaf { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, + nibbles: nibbles(0xABC), value: test_account_1_rlp(), }, transactions_trie: Default::default(), @@ -109,7 +106,7 @@ fn load_all_mpts_leaf() -> Result<()> { #[test] fn load_all_mpts_empty_branch() -> Result<()> { - let children = std::array::from_fn(|_| Box::new(PartialTrie::Empty)); + let children = std::array::from_fn(|_| PartialTrie::Empty.into()); let state_trie = PartialTrie::Branch { children, value: vec![], diff --git a/evm/src/cpu/kernel/tests/mpt/mod.rs b/evm/src/cpu/kernel/tests/mpt/mod.rs index e3414b38..2c7999df 100644 --- a/evm/src/cpu/kernel/tests/mpt/mod.rs +++ b/evm/src/cpu/kernel/tests/mpt/mod.rs @@ -9,6 +9,17 @@ mod insert; mod load; mod read; +/// Helper function to reduce code duplication. +/// Note that this preserves all nibbles (eg. `0x123` is not interpreted as `0x0123`). +pub(crate) fn nibbles>(v: T) -> Nibbles { + let packed = v.into(); + + Nibbles { + count: Nibbles::get_num_nibbles_in_key(&packed), + packed, + } +} + pub(crate) fn test_account_1() -> AccountRlp { AccountRlp { nonce: U256::from(1111), @@ -38,16 +49,14 @@ pub(crate) fn test_account_2_rlp() -> Vec { /// A `PartialTrie` where an extension node leads to a leaf node containing an account. 
pub(crate) fn extension_to_leaf(value: Vec) -> PartialTrie { PartialTrie::Extension { - nibbles: Nibbles { - count: 3, - packed: 0xABC.into(), - }, - child: Box::new(PartialTrie::Leaf { + nibbles: nibbles(0xABC), + child: PartialTrie::Leaf { nibbles: Nibbles { count: 3, packed: 0xDEF.into(), }, value, - }), + } + .into(), } } From 06475c2b61a0d3b26d3f3ded791456052b94251b Mon Sep 17 00:00:00 2001 From: BGluth Date: Tue, 11 Oct 2022 22:07:32 -0600 Subject: [PATCH 16/17] Bumped patch version --- evm/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/Cargo.toml b/evm/Cargo.toml index a3dc09e2..bcbc5cd9 100644 --- a/evm/Cargo.toml +++ b/evm/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] plonky2 = { path = "../plonky2", default-features = false, features = ["rand", "timing"] } plonky2_util = { path = "../util" } -eth_trie_utils = "0.2.0" +eth_trie_utils = "0.2.1" anyhow = "1.0.40" env_logger = "0.9.0" ethereum-types = "0.14.0" From ec3391f9c4e14055e811697dae957e6c85631639 Mon Sep 17 00:00:00 2001 From: Jacqueline Nabaglo Date: Thu, 13 Oct 2022 14:02:19 -0700 Subject: [PATCH 17/17] Add Fp254 ops to the CPU table (#779) * Add Fp254 ops to the CPU table * Add forgotten file --- evm/src/cpu/columns/ops.rs | 3 ++ evm/src/cpu/control_flow.rs | 5 ++- evm/src/cpu/cpu_stark.rs | 4 ++- evm/src/cpu/decode.rs | 5 ++- evm/src/cpu/kernel/interpreter.rs | 3 ++ evm/src/cpu/kernel/opcodes.rs | 3 ++ evm/src/cpu/mod.rs | 1 + evm/src/cpu/modfp254.rs | 53 +++++++++++++++++++++++++++++++ evm/src/cpu/stack.rs | 3 ++ 9 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 evm/src/cpu/modfp254.rs diff --git a/evm/src/cpu/columns/ops.rs b/evm/src/cpu/columns/ops.rs index e0cb2952..04d4d0f2 100644 --- a/evm/src/cpu/columns/ops.rs +++ b/evm/src/cpu/columns/ops.rs @@ -19,6 +19,9 @@ pub struct OpsColumnsView { pub mulmod: T, pub exp: T, pub signextend: T, + pub addfp254: T, + pub mulfp254: T, + pub subfp254: T, pub lt: T, pub gt: T, pub slt: T, 
diff --git a/evm/src/cpu/control_flow.rs b/evm/src/cpu/control_flow.rs index 3856726c..c7b7c6bb 100644 --- a/evm/src/cpu/control_flow.rs +++ b/evm/src/cpu/control_flow.rs @@ -9,7 +9,7 @@ use crate::cpu::columns::{CpuColumnsView, COL_MAP}; use crate::cpu::kernel::aggregator::KERNEL; // TODO: This list is incomplete. -const NATIVE_INSTRUCTIONS: [usize; 25] = [ +const NATIVE_INSTRUCTIONS: [usize; 28] = [ COL_MAP.op.add, COL_MAP.op.mul, COL_MAP.op.sub, @@ -20,6 +20,9 @@ const NATIVE_INSTRUCTIONS: [usize; 25] = [ COL_MAP.op.addmod, COL_MAP.op.mulmod, COL_MAP.op.signextend, + COL_MAP.op.addfp254, + COL_MAP.op.mulfp254, + COL_MAP.op.subfp254, COL_MAP.op.lt, COL_MAP.op.gt, COL_MAP.op.slt, diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index b11ff9f5..7b34cc4f 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -11,7 +11,7 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::{ - bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, simple_logic, stack, + bootstrap_kernel, control_flow, decode, dup_swap, jumps, membus, modfp254, simple_logic, stack, stack_bounds, syscalls, }; use crate::cross_table_lookup::Column; @@ -150,6 +150,7 @@ impl, const D: usize> Stark for CpuStark, const D: usize> Stark for CpuStark Interpreter<'a> { 0x09 => self.run_mulmod(), // "MULMOD", 0x0a => self.run_exp(), // "EXP", 0x0b => todo!(), // "SIGNEXTEND", + 0x0c => todo!(), // "ADDFP254", + 0x0d => todo!(), // "MULFP254", + 0x0e => todo!(), // "SUBFP254", 0x10 => self.run_lt(), // "LT", 0x11 => self.run_gt(), // "GT", 0x12 => todo!(), // "SLT", diff --git a/evm/src/cpu/kernel/opcodes.rs b/evm/src/cpu/kernel/opcodes.rs index c5133050..20601267 100644 --- a/evm/src/cpu/kernel/opcodes.rs +++ b/evm/src/cpu/kernel/opcodes.rs @@ -20,6 +20,9 @@ pub(crate) fn get_opcode(mnemonic: &str) -> u8 { 
"MULMOD" => 0x09, "EXP" => 0x0a, "SIGNEXTEND" => 0x0b, + "ADDFP254" => 0x0c, + "MULFP254" => 0x0d, + "SUBFP254" => 0x0e, "LT" => 0x10, "GT" => 0x11, "SLT" => 0x12, diff --git a/evm/src/cpu/mod.rs b/evm/src/cpu/mod.rs index bde06585..fda5db80 100644 --- a/evm/src/cpu/mod.rs +++ b/evm/src/cpu/mod.rs @@ -7,6 +7,7 @@ mod dup_swap; mod jumps; pub mod kernel; pub(crate) mod membus; +mod modfp254; mod simple_logic; mod stack; mod stack_bounds; diff --git a/evm/src/cpu/modfp254.rs b/evm/src/cpu/modfp254.rs new file mode 100644 index 00000000..defbf862 --- /dev/null +++ b/evm/src/cpu/modfp254.rs @@ -0,0 +1,53 @@ +use itertools::izip; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; + +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::columns::CpuColumnsView; + +// Python: +// >>> P = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +// >>> "[" + ", ".join(hex((P >> n) % 2**32) for n in range(0, 256, 32)) + "]" +const P_LIMBS: [u32; 8] = [ + 0xd87cfd47, 0x3c208c16, 0x6871ca8d, 0x97816a91, 0x8181585d, 0xb85045b6, 0xe131a029, 0x30644e72, +]; + +pub fn eval_packed( + lv: &CpuColumnsView

<P>, + yield_constr: &mut ConstraintConsumer<P>

, +) { + let filter = lv.is_cpu_cycle * (lv.op.addfp254 + lv.op.mulfp254 + lv.op.subfp254); + + // We want to use all the same logic as the usual mod operations, but without needing to read + // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // where the modulus goes in the generalized operations). + let channel_val = lv.mem_channels[2].value; + for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { + let p_limb = P::Scalar::from_canonical_u32(p_limb); + yield_constr.constraint(filter * (channel_limb - p_limb)); + } +} + +pub fn eval_ext_circuit, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = { + let flag_sum = builder.add_many_extension([lv.op.addfp254, lv.op.mulfp254, lv.op.subfp254]); + builder.mul_extension(lv.is_cpu_cycle, flag_sum) + }; + + // We want to use all the same logic as the usual mod operations, but without needing to read + // the modulus from the stack. We simply constrain `mem_channels[2]` to be our prime (that's + // where the modulus goes in the generalized operations). + let channel_val = lv.mem_channels[2].value; + for (channel_limb, p_limb) in izip!(channel_val, P_LIMBS) { + let p_limb = F::from_canonical_u32(p_limb); + let constr = builder.arithmetic_extension(F::ONE, -p_limb, filter, channel_limb, filter); + yield_constr.constraint(builder, constr); + } +} diff --git a/evm/src/cpu/stack.rs b/evm/src/cpu/stack.rs index 9bc08091..c72688ed 100644 --- a/evm/src/cpu/stack.rs +++ b/evm/src/cpu/stack.rs @@ -52,6 +52,9 @@ const STACK_BEHAVIORS: OpsColumnsView> = OpsColumnsView { mulmod: BASIC_TERNARY_OP, exp: None, // TODO signextend: BASIC_BINARY_OP, + addfp254: BASIC_BINARY_OP, + mulfp254: BASIC_BINARY_OP, + subfp254: BASIC_BINARY_OP, lt: BASIC_BINARY_OP, gt: BASIC_BINARY_OP, slt: BASIC_BINARY_OP,