Add support for retrieving values from maps.

This commit is contained in:
Jacques Wagener 2018-07-24 10:25:15 +02:00
parent 7a689bf56f
commit 4ecbe602e6
No known key found for this signature in database
GPG Key ID: C294D1025DA0E923
3 changed files with 180 additions and 2 deletions

View File

@ -0,0 +1,56 @@
import io
def test_single_key(get_contract, get_last_out):
code = """
amap: bytes32[bytes32]
@public
def set(key: bytes32, value: bytes32):
self.amap[key] = value
@public
def get(key: bytes32) -> bytes32:
vdb
return self.amap[key]
"""
stdin = io.StringIO(
"self.amap['one']\n"
)
stdout = io.StringIO()
c = get_contract(code, stdin=stdin, stdout=stdout)
c.functions.set(b'one', b'hello!').transact()
res = c.functions.get(b'one').call({'gas': 600000})
assert res[:6] == b'hello!'
assert 'hello!' in stdout.getvalue()
def test_double_key(get_contract, get_last_out):
code = """
amap: (bytes32)[bytes32][bytes32]
@public
def set(key1: bytes32, key2: bytes32, value: bytes32):
self.amap[key1][key2] = value
@public
def get(key1: bytes32, key2: bytes32) -> bytes32:
vdb
return self.amap[key1][key2]
"""
stdin = io.StringIO(
"self.amap[one][two]\n"
)
stdout = io.StringIO()
c = get_contract(code, stdin=stdin, stdout=stdout)
c.functions.set(b'one', b'two', b'hello!').transact()
res = c.functions.get(b'one', b'two').call({'gas': 600000})
assert res[:6] == b'hello!'
assert 'hello!' in stdout.getvalue()

107
vdb/variables.py Normal file
View File

@ -0,0 +1,107 @@
from eth_hash.auto import keccak
from eth_utils import to_hex
from eth_abi import decode_single
from evm.utils.numeric import (
big_endian_to_int,
int_to_big_endian,
)
base_types = (
'int128',
'uint256',
'address',
'bytes32'
)
def print_var(stdout, value, var_typ):
if isinstance(value, int):
v = int_to_big_endian(value)
else:
v = value
if isinstance(v, bytes):
if var_typ in ('int128', 'uint256'):
stdout.write(str(decode_single(var_typ, value)) + '\n')
elif var_typ == 'address':
stdout.write(to_hex(v[12:]) + '\n')
elif var_typ.startswith('bytes'):
stdout.write(v.decode() + '\n')
else:
stdout.write(v.decode() + '\n')
def parse_local(stdout, local_variables, computation, line):
var_info = local_variables[line]
local_type = var_info['type']
if local_type in base_types:
start_position = var_info['position']
value = computation.memory_read(start_position, 32)
print_var(stdout, value, local_type)
else:
stdout.write('Can not read local of type\n')
def get_keys(n):
out = []
name = n
for _ in range(name.count('[')):
open_pos = name.find('[')
close_pos = name.find(']')
key = name[open_pos + 1:close_pos].replace('\'', '').replace('"', '')
name = name[close_pos + 1:]
out.append(key)
return out
def get_hash(var_pos, keys, _type):
key_inp = b''
key_inp = keccak(
int_to_big_endian(var_pos).rjust(32, b'\0') +
keys[0].encode().ljust(32, b'\0')
)
for key in keys[1:]:
key_inp = keccak(key_inp + key.encode().ljust(32, b'\0'))
slot = big_endian_to_int(key_inp)
return slot
def valid_subscript(name, global_type):
if name.count('[') != name.count(']'):
return False
elif global_type.count('[') != name.count('['):
return False
return True
def parse_global(stdout, global_vars, computation, line):
# print global value.
name = line.split('.')[1]
var_name = name[:name.find('[')] if '[' in name else name
if var_name not in global_vars:
stdout.write('Global named "{}" not found.'.format(var_name) + '\n')
return
global_type = global_vars[var_name]['type']
slot = None
if global_type in base_types:
slot = global_vars[var_name]['position']
elif global_type.startswith('mapping') and valid_subscript(name, global_type):
keys = get_keys(name)
var_pos = global_vars[var_name]['position']
slot = get_hash(var_pos, keys, global_type)
if slot is not None:
value = computation.state.account_db.get_storage(
address=computation.msg.storage_address,
slot=slot,
)
if global_type.startswith('mapping'):
global_type = global_type[global_type.find('(') + 1: global_type.find('[')]
print_var(stdout, value, global_type)
else:
stdout.write('Can not read global of type "{}".\n'.format(global_type))

View File

@ -1,4 +1,5 @@
import cmd
import readline
from eth_utils import to_hex
@ -20,7 +21,6 @@ commands = [
def history(stdout):
import readline
for i in range(1, readline.get_current_history_length() + 1):
stdout.write("%3d %s" % (i, readline.get_history_item(i)) + '\n')
@ -105,7 +105,22 @@ class VyperDebugCmd(cmd.Cmd):
for name, info in variables.items():
self.stdout.write('{}\t\t{}'.format(name, info['type']) + '\n')
def completenames(self, text, *ignored):
line = text.strip()
if 'self.' in line:
return [
'self.' + x
for x in self.globals.keys()
if x.startswith(line.split('self.')[1])
]
else:
dotext = 'do_' + text
cmds = [a[3:] for a in self.get_names() if a.startswith(dotext)]
_, local_vars = self._get_fn_name_locals()
return cmds + [x for x in local_vars.keys() if x.startswith(line)]
def default(self, line):
line = line.strip()
fn_name, local_variables = self._get_fn_name_locals()
if line.startswith('self.') and len(line) > 4:
@ -114,7 +129,7 @@ class VyperDebugCmd(cmd.Cmd):
)
elif line in local_variables:
parse_local(
self.stdout, self.local_vars, self.computation, line
self.stdout, local_variables, self.computation, line
)
else:
self.stdout.write('*** Unknown syntax: %s\n' % line)