diff --git a/evm/src/cpu/kernel/asm/mpt/insert.asm b/evm/src/cpu/kernel/asm/mpt/insert.asm index 9a5ed7f2..cc748e79 100644 --- a/evm/src/cpu/kernel/asm/mpt/insert.asm +++ b/evm/src/cpu/kernel/asm/mpt/insert.asm @@ -106,13 +106,40 @@ mpt_insert_branch_nonterminal_after_recursion: JUMP mpt_insert_extension: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest PANIC // TODO mpt_insert_leaf: - // stack: node_type, node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_type, node_payload_ptr, insert_len, insert_key, value_ptr, retdest POP - // stack: node_payload_ptr, num_nibbles, key, value_ptr, retdest + // stack: node_payload_ptr, insert_len, insert_key, value_ptr, retdest + %stack (node_payload_ptr, insert_len, insert_key) -> (insert_len, insert_key, node_payload_ptr) + // stack: insert_len, insert_key, node_payload_ptr, value_ptr, retdest + DUP3 %increment %mload_trie_data + // stack: node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + DUP4 %mload_trie_data + // stack: node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + // TODO: Maybe skip %split_common_prefix if lengths & keys exactly match. + %split_common_prefix + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + DUP3 DUP6 ADD %jumpi(mpt_insert_leaf_not_exact_match) + // If we got here, the node key exactly matches the insert key, so we will + // keep the same leaf node structure and just replace its value. + %stack (common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr) + -> (common_len, common_key, value_ptr) + // stack: common_len, common_key, value_ptr, retdest + %get_trie_data_size + // stack: updated_leaf_ptr, common_len, common_key, value_ptr, retdest + PUSH @MPT_NODE_LEAF %append_to_trie_data + SWAP1 %append_to_trie_data // append common_len + SWAP1 %append_to_trie_data // append common_key + SWAP1 %append_to_trie_data // append value_ptr + // stack: updated_leaf_ptr, retdest + SWAP1 + JUMP +mpt_insert_leaf_not_exact_match: + // stack: common_len, common_key, node_len, node_key, insert_len, insert_key, node_payload_ptr, value_ptr, retdest + // %get_trie_data_size PANIC // TODO diff --git a/evm/src/cpu/kernel/asm/mpt/util.asm b/evm/src/cpu/kernel/asm/mpt/util.asm index 19cad943..a3ff7f38 100644 --- a/evm/src/cpu/kernel/asm/mpt/util.asm +++ b/evm/src/cpu/kernel/asm/mpt/util.asm @@ -72,3 +72,88 @@ POP // stack: first_nibble, num_nibbles, key %endmacro + +// Split off the common prefix among two key parts. Roughly equivalent to +// def split_common_prefix(len_1, key_1, len_2, key_2): +// bits_1 = len_1 * 4 +// bits_2 = len_2 * 4 +// len_common = 0 +// key_common = 0 +// while True: +// if bits_1 * bits_2 == 0: +// break +// first_nib_1 = (key_1 >> (bits_1 - 4)) & 0xF +// first_nib_2 = (key_2 >> (bits_2 - 4)) & 0xF +// if first_nib_1 != first_nib_2: +// break +// len_common += 1 +// key_common = key_common * 16 + first_nib_1 +// bits_1 -= 4 +// bits_2 -= 4 +// key_1 -= (first_nib_1 << bits_1) +// key_2 -= (first_nib_2 << bits_2) +// len_1 = bits_1 // 4 +// len_2 = bits_2 // 4 +// return (len_common, key_common, len_1, key_1, len_2, key_2) +%macro split_common_prefix + // stack: len_1, key_1, len_2, key_2 + %mul_const(4) + SWAP2 %mul_const(4) SWAP2 + // stack: bits_1, key_1, bits_2, key_2 + PUSH 0 + PUSH 0 + +%%loop: + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + + // if bits_1 * bits_2 == 0: break + DUP3 DUP6 MUL ISZERO %jumpi(%%return) + + // first_nib_2 = (key_2 >> (bits_2 - 4)) & 0xF + DUP6 DUP6 %sub_const(4) SHR %and_const(0xF) + // first_nib_1 = (key_1 >> (bits_1 - 4)) & 0xF + DUP5 DUP5 %sub_const(4) SHR %and_const(0xF) + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // if first_nib_1 != first_nib_2: break + DUP2 DUP2 SUB %jumpi(%%return_with_first_nibs) + + // len_common += 1 + SWAP2 %increment SWAP2 + + // key_common = key_common * 16 + first_nib_1 + SWAP3 + %mul_const(16) + DUP4 ADD + SWAP3 + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // bits_1 -= 4 + SWAP4 %sub_const(4) SWAP4 + // bits_2 -= 4 + SWAP6 %sub_const(4) SWAP6 + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // key_1 -= (first_nib_1 << bits_1) + DUP5 SHL + // stack: first_nib_1 << bits_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + DUP6 SUB + // stack: key_1, first_nib_2, len_common, key_common, bits_1, key_1_old, bits_2, key_2 + SWAP5 POP + // stack: first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + + // key_2 -= (first_nib_2 << bits_2) + DUP6 SHL + // stack: first_nib_2 << bits_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + DUP7 SUB + // stack: key_2, len_common, key_common, bits_1, key_1, bits_2, key_2_old + SWAP6 POP + // stack: len_common, key_common, bits_1, key_1, bits_2, key_2 + + %jump(%%loop) +%%return_with_first_nibs: + // stack: first_nib_1, first_nib_2, len_common, key_common, bits_1, key_1, bits_2, key_2 + %pop2 +%%return: + // stack: len_common, key_common, len_1, key_1, len_2, key_2 +%endmacro diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 7aeb4a1a..11927d52 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -23,7 +23,6 @@ fn mpt_insert_empty() -> Result<()> { } #[test] -#[ignore] // TODO: Enable when mpt_insert_leaf is done. fn mpt_insert_leaf_same_key() -> Result<()> { let key = Nibbles { count: 3, @@ -142,7 +141,12 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { interpreter.push(insert.nibbles.count.into()); // num_nibbles interpreter.run()?; - assert_eq!(interpreter.stack().len(), 0); + assert_eq!( + interpreter.stack().len(), + 0, + "Expected empty stack after insert, found {:?}", + interpreter.stack() + ); // Now, execute mpt_hash_state_trie. interpreter.offset = mpt_hash_state_trie; @@ -152,7 +156,7 @@ fn test_state_trie(state_trie: PartialTrie, insert: InsertEntry) -> Result<()> { assert_eq!( interpreter.stack().len(), 1, - "Expected 1 item on stack, found {:?}", + "Expected 1 item on stack after hashing, found {:?}", interpreter.stack() ); let hash = H256::from_uint(&interpreter.stack()[0]);