Improve abi function lookup.

This commit is contained in:
Jacques Wagener 2018-08-06 13:20:17 +02:00
parent 36f8f770d1
commit 1cd39c504b
1 changed files with 45 additions and 6 deletions

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3.6
import argparse
import vyper
from collections import Counter
from pprint import pprint
from vyper import compiler
from vyper.parser import (
@ -78,6 +78,44 @@ def get_contract(w3, source_code, *args, **kwargs):
return contract
def get_func_abi(abi, func_name, args):
# next(filter(lambda func: func["name"] == func_name, abi))
def guess_type(v):
# is annotated type.
if ':' in v:
return v.split(':')[1]
# otherwise just guess
try:
int(v)
return 'int128'
except ValueError:
return 'bytes'
func_name_count_map = dict(Counter([a['name'] for a in abi]))
for candidate_func_abi in abi:
if candidate_func_abi["type"] == "function":
# try func name first.
if candidate_func_abi["name"] == func_name and \
func_name_count_map[candidate_func_abi['name']] == 1:
if len(args) != candidate_func_abi['inputs']:
print('Incorrect arguments for {}'.format(func_name))
return
else:
return candidate_func_abi
# is overloaded function, use full signature.
else:
full_sig = "{func_name}({type_str})".format(
func_name=func_name,
type_str=','.join([guess_type(x) for x in args])
)
method = "{func_name}({type_str})".format(
func_name=candidate_func_abi['name'],
type_str=','.join([x['type'] for x in candidate_func_abi['inputs']])
)
if method == full_sig:
return candidate_func_abi
if __name__ == '__main__':
with open(args.input_file) as fh:
@ -118,17 +156,18 @@ if __name__ == '__main__':
continue
print('\n* Calling {}({})'.format(func_name, ','.join(args)))
func_abi = next(filter(lambda func: func["name"] == func_name, abi))
if len(args) != len(func_abi['inputs']):
print('Argument mismatch, please provide correct arguments.')
func_abi = get_func_abi(abi, func_name, args)
if not func_abi:
print('Did not find function in abi.')
break
cast_args = cast_types(args, func_abi)
res = getattr(contract.functions, func_name)(*cast_args).call({'gas': func_abi['gas'] + 22000})
res = getattr(contract.functions, func_name)(*cast_args).call({'gas': func_abi.get('gas', 0) + 50000})
source_map = produce_source_map(code)
set_evm_opcode_debugger(source_code=code, source_map=source_map)
tx_hash = getattr(contract.functions, func_name)(*cast_args).transact({'gas': func_abi['gas'] + 22000})
tx_hash = getattr(contract.functions, func_name)(*cast_args).transact({'gas': func_abi.get('gas', 0) + 50000})
set_evm_opcode_pass()
print('- Returns:')