Add support for retrieving values from maps.
This commit is contained in:
parent
7a689bf56f
commit
4ecbe602e6
|
@ -0,0 +1,56 @@
|
|||
import io
|
||||
|
||||
def test_single_key(get_contract, get_last_out):
|
||||
code = """
|
||||
amap: bytes32[bytes32]
|
||||
|
||||
|
||||
@public
|
||||
def set(key: bytes32, value: bytes32):
|
||||
self.amap[key] = value
|
||||
|
||||
|
||||
@public
|
||||
def get(key: bytes32) -> bytes32:
|
||||
vdb
|
||||
return self.amap[key]
|
||||
"""
|
||||
|
||||
stdin = io.StringIO(
|
||||
"self.amap['one']\n"
|
||||
)
|
||||
stdout = io.StringIO()
|
||||
c = get_contract(code, stdin=stdin, stdout=stdout)
|
||||
c.functions.set(b'one', b'hello!').transact()
|
||||
res = c.functions.get(b'one').call({'gas': 600000})
|
||||
|
||||
assert res[:6] == b'hello!'
|
||||
assert 'hello!' in stdout.getvalue()
|
||||
|
||||
|
||||
def test_double_key(get_contract, get_last_out):
|
||||
code = """
|
||||
amap: (bytes32)[bytes32][bytes32]
|
||||
|
||||
|
||||
@public
|
||||
def set(key1: bytes32, key2: bytes32, value: bytes32):
|
||||
self.amap[key1][key2] = value
|
||||
|
||||
|
||||
@public
|
||||
def get(key1: bytes32, key2: bytes32) -> bytes32:
|
||||
vdb
|
||||
return self.amap[key1][key2]
|
||||
"""
|
||||
|
||||
stdin = io.StringIO(
|
||||
"self.amap[one][two]\n"
|
||||
)
|
||||
stdout = io.StringIO()
|
||||
c = get_contract(code, stdin=stdin, stdout=stdout)
|
||||
c.functions.set(b'one', b'two', b'hello!').transact()
|
||||
res = c.functions.get(b'one', b'two').call({'gas': 600000})
|
||||
|
||||
assert res[:6] == b'hello!'
|
||||
assert 'hello!' in stdout.getvalue()
|
|
@ -0,0 +1,107 @@
|
|||
from eth_hash.auto import keccak
|
||||
from eth_utils import to_hex
|
||||
from eth_abi import decode_single
|
||||
from evm.utils.numeric import (
|
||||
big_endian_to_int,
|
||||
int_to_big_endian,
|
||||
)
|
||||
|
||||
base_types = (
|
||||
'int128',
|
||||
'uint256',
|
||||
'address',
|
||||
'bytes32'
|
||||
)
|
||||
|
||||
|
||||
def print_var(stdout, value, var_typ):
|
||||
|
||||
if isinstance(value, int):
|
||||
v = int_to_big_endian(value)
|
||||
else:
|
||||
v = value
|
||||
|
||||
if isinstance(v, bytes):
|
||||
if var_typ in ('int128', 'uint256'):
|
||||
stdout.write(str(decode_single(var_typ, value)) + '\n')
|
||||
elif var_typ == 'address':
|
||||
stdout.write(to_hex(v[12:]) + '\n')
|
||||
elif var_typ.startswith('bytes'):
|
||||
stdout.write(v.decode() + '\n')
|
||||
else:
|
||||
stdout.write(v.decode() + '\n')
|
||||
|
||||
|
||||
def parse_local(stdout, local_variables, computation, line):
|
||||
var_info = local_variables[line]
|
||||
local_type = var_info['type']
|
||||
if local_type in base_types:
|
||||
start_position = var_info['position']
|
||||
value = computation.memory_read(start_position, 32)
|
||||
print_var(stdout, value, local_type)
|
||||
else:
|
||||
stdout.write('Can not read local of type\n')
|
||||
|
||||
|
||||
def get_keys(n):
|
||||
out = []
|
||||
name = n
|
||||
for _ in range(name.count('[')):
|
||||
open_pos = name.find('[')
|
||||
close_pos = name.find(']')
|
||||
key = name[open_pos + 1:close_pos].replace('\'', '').replace('"', '')
|
||||
name = name[close_pos + 1:]
|
||||
out.append(key)
|
||||
return out
|
||||
|
||||
|
||||
def get_hash(var_pos, keys, _type):
|
||||
key_inp = b''
|
||||
|
||||
key_inp = keccak(
|
||||
int_to_big_endian(var_pos).rjust(32, b'\0') +
|
||||
keys[0].encode().ljust(32, b'\0')
|
||||
)
|
||||
for key in keys[1:]:
|
||||
key_inp = keccak(key_inp + key.encode().ljust(32, b'\0'))
|
||||
slot = big_endian_to_int(key_inp)
|
||||
return slot
|
||||
|
||||
|
||||
def valid_subscript(name, global_type):
|
||||
if name.count('[') != name.count(']'):
|
||||
return False
|
||||
elif global_type.count('[') != name.count('['):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_global(stdout, global_vars, computation, line):
|
||||
# print global value.
|
||||
name = line.split('.')[1]
|
||||
var_name = name[:name.find('[')] if '[' in name else name
|
||||
|
||||
if var_name not in global_vars:
|
||||
stdout.write('Global named "{}" not found.'.format(var_name) + '\n')
|
||||
return
|
||||
|
||||
global_type = global_vars[var_name]['type']
|
||||
slot = None
|
||||
|
||||
if global_type in base_types:
|
||||
slot = global_vars[var_name]['position']
|
||||
elif global_type.startswith('mapping') and valid_subscript(name, global_type):
|
||||
keys = get_keys(name)
|
||||
var_pos = global_vars[var_name]['position']
|
||||
slot = get_hash(var_pos, keys, global_type)
|
||||
|
||||
if slot is not None:
|
||||
value = computation.state.account_db.get_storage(
|
||||
address=computation.msg.storage_address,
|
||||
slot=slot,
|
||||
)
|
||||
if global_type.startswith('mapping'):
|
||||
global_type = global_type[global_type.find('(') + 1: global_type.find('[')]
|
||||
print_var(stdout, value, global_type)
|
||||
else:
|
||||
stdout.write('Can not read global of type "{}".\n'.format(global_type))
|
19
vdb/vdb.py
19
vdb/vdb.py
|
@ -1,4 +1,5 @@
|
|||
import cmd
|
||||
import readline
|
||||
|
||||
from eth_utils import to_hex
|
||||
|
||||
|
@ -20,7 +21,6 @@ commands = [
|
|||
|
||||
|
||||
def history(stdout):
|
||||
import readline
|
||||
for i in range(1, readline.get_current_history_length() + 1):
|
||||
stdout.write("%3d %s" % (i, readline.get_history_item(i)) + '\n')
|
||||
|
||||
|
@ -105,7 +105,22 @@ class VyperDebugCmd(cmd.Cmd):
|
|||
for name, info in variables.items():
|
||||
self.stdout.write('{}\t\t{}'.format(name, info['type']) + '\n')
|
||||
|
||||
def completenames(self, text, *ignored):
|
||||
line = text.strip()
|
||||
if 'self.' in line:
|
||||
return [
|
||||
'self.' + x
|
||||
for x in self.globals.keys()
|
||||
if x.startswith(line.split('self.')[1])
|
||||
]
|
||||
else:
|
||||
dotext = 'do_' + text
|
||||
cmds = [a[3:] for a in self.get_names() if a.startswith(dotext)]
|
||||
_, local_vars = self._get_fn_name_locals()
|
||||
return cmds + [x for x in local_vars.keys() if x.startswith(line)]
|
||||
|
||||
def default(self, line):
|
||||
line = line.strip()
|
||||
fn_name, local_variables = self._get_fn_name_locals()
|
||||
|
||||
if line.startswith('self.') and len(line) > 4:
|
||||
|
@ -114,7 +129,7 @@ class VyperDebugCmd(cmd.Cmd):
|
|||
)
|
||||
elif line in local_variables:
|
||||
parse_local(
|
||||
self.stdout, self.local_vars, self.computation, line
|
||||
self.stdout, local_variables, self.computation, line
|
||||
)
|
||||
else:
|
||||
self.stdout.write('*** Unknown syntax: %s\n' % line)
|
||||
|
|
Loading…
Reference in New Issue