diff --git a/tests/variables/test_globals.py b/tests/variables/test_globals.py new file mode 100644 index 0000000..cde8d79 --- /dev/null +++ b/tests/variables/test_globals.py @@ -0,0 +1,56 @@ +import io + +def test_single_key(get_contract, get_last_out): + code = """ +amap: bytes32[bytes32] + + +@public +def set(key: bytes32, value: bytes32): + self.amap[key] = value + + +@public +def get(key: bytes32) -> bytes32: + vdb + return self.amap[key] + """ + + stdin = io.StringIO( + "self.amap['one']\n" + ) + stdout = io.StringIO() + c = get_contract(code, stdin=stdin, stdout=stdout) + c.functions.set(b'one', b'hello!').transact() + res = c.functions.get(b'one').call({'gas': 600000}) + + assert res[:6] == b'hello!' + assert 'hello!' in stdout.getvalue() + + +def test_double_key(get_contract, get_last_out): + code = """ +amap: (bytes32)[bytes32][bytes32] + + +@public +def set(key1: bytes32, key2: bytes32, value: bytes32): + self.amap[key1][key2] = value + + +@public +def get(key1: bytes32, key2: bytes32) -> bytes32: + vdb + return self.amap[key1][key2] + """ + + stdin = io.StringIO( + "self.amap[one][two]\n" + ) + stdout = io.StringIO() + c = get_contract(code, stdin=stdin, stdout=stdout) + c.functions.set(b'one', b'two', b'hello!').transact() + res = c.functions.get(b'one', b'two').call({'gas': 600000}) + + assert res[:6] == b'hello!' + assert 'hello!' in stdout.getvalue() diff --git a/vdb/variables.py b/vdb/variables.py new file mode 100644 index 0000000..fb681a6 --- /dev/null +++ b/vdb/variables.py @@ -0,0 +1,107 @@ +from eth_hash.auto import keccak +from eth_utils import to_hex +from eth_abi import decode_single +from evm.utils.numeric import ( + big_endian_to_int, + int_to_big_endian, +) + +base_types = ( + 'int128', + 'uint256', + 'address', + 'bytes32' +) + + +def print_var(stdout, value, var_typ): + + if isinstance(value, int): + v = int_to_big_endian(value) + else: + v = value + + if isinstance(v, bytes): + if var_typ in ('int128', 'uint256'): + stdout.write(str(decode_single(var_typ, value)) + '\n') + elif var_typ == 'address': + stdout.write(to_hex(v[12:]) + '\n') + elif var_typ.startswith('bytes'): + stdout.write(v.decode() + '\n') + else: + stdout.write(v.decode() + '\n') + + +def parse_local(stdout, local_variables, computation, line): + var_info = local_variables[line] + local_type = var_info['type'] + if local_type in base_types: + start_position = var_info['position'] + value = computation.memory_read(start_position, 32) + print_var(stdout, value, local_type) + else: + stdout.write('Can not read local of type\n') + + +def get_keys(n): + out = [] + name = n + for _ in range(name.count('[')): + open_pos = name.find('[') + close_pos = name.find(']') + key = name[open_pos + 1:close_pos].replace('\'', '').replace('"', '') + name = name[close_pos + 1:] + out.append(key) + return out + + +def get_hash(var_pos, keys, _type): + key_inp = b'' + + key_inp = keccak( + int_to_big_endian(var_pos).rjust(32, b'\0') + + keys[0].encode().ljust(32, b'\0') + ) + for key in keys[1:]: + key_inp = keccak(key_inp + key.encode().ljust(32, b'\0')) + slot = big_endian_to_int(key_inp) + return slot + + +def valid_subscript(name, global_type): + if name.count('[') != name.count(']'): + return False + elif global_type.count('[') != name.count('['): + return False + return True + + +def parse_global(stdout, global_vars, computation, line): + # print global value. + name = line.split('.')[1] + var_name = name[:name.find('[')] if '[' in name else name + + if var_name not in global_vars: + stdout.write('Global named "{}" not found.'.format(var_name) + '\n') + return + + global_type = global_vars[var_name]['type'] + slot = None + + if global_type in base_types: + slot = global_vars[var_name]['position'] + elif global_type.startswith('mapping') and valid_subscript(name, global_type): + keys = get_keys(name) + var_pos = global_vars[var_name]['position'] + slot = get_hash(var_pos, keys, global_type) + + if slot is not None: + value = computation.state.account_db.get_storage( + address=computation.msg.storage_address, + slot=slot, + ) + if global_type.startswith('mapping'): + global_type = global_type[global_type.find('(') + 1: global_type.find('[')] + print_var(stdout, value, global_type) + else: + stdout.write('Can not read global of type "{}".\n'.format(global_type)) diff --git a/vdb/vdb.py b/vdb/vdb.py index cf438b9..5677aea 100644 --- a/vdb/vdb.py +++ b/vdb/vdb.py @@ -1,4 +1,5 @@ import cmd +import readline from eth_utils import to_hex @@ -20,7 +21,6 @@ commands = [ def history(stdout): - import readline for i in range(1, readline.get_current_history_length() + 1): stdout.write("%3d %s" % (i, readline.get_history_item(i)) + '\n') @@ -105,7 +105,22 @@ class VyperDebugCmd(cmd.Cmd): for name, info in variables.items(): self.stdout.write('{}\t\t{}'.format(name, info['type']) + '\n') + def completenames(self, text, *ignored): + line = text.strip() + if 'self.' in line: + return [ + 'self.' + x + for x in self.globals.keys() + if x.startswith(line.split('self.')[1]) + ] + else: + dotext = 'do_' + text + cmds = [a[3:] for a in self.get_names() if a.startswith(dotext)] + _, local_vars = self._get_fn_name_locals() + return cmds + [x for x in local_vars.keys() if x.startswith(line)] + def default(self, line): + line = line.strip() fn_name, local_variables = self._get_fn_name_locals() if line.startswith('self.') and len(line) > 4: @@ -114,7 +129,7 @@ class VyperDebugCmd(cmd.Cmd): ) elif line in local_variables: parse_local( - self.stdout, self.local_vars, self.computation, line + self.stdout, local_variables, self.computation, line ) else: self.stdout.write('*** Unknown syntax: %s\n' % line)