diff --git a/bin/vyper-run b/bin/vyper-run index 1b7ff94..6eb13ad 100755 --- a/bin/vyper-run +++ b/bin/vyper-run @@ -1,7 +1,7 @@ #!/usr/bin/env python3.6 import argparse import vyper - +from collections import Counter from pprint import pprint from vyper import compiler from vyper.parser import ( @@ -78,6 +78,44 @@ def get_contract(w3, source_code, *args, **kwargs): return contract +def get_func_abi(abi, func_name, args): + # next(filter(lambda func: func["name"] == func_name, abi)) + def guess_type(v): + # is annotated type. + if ':' in v: + return v.split(':')[1] + # otherwise just guess + try: + int(v) + return 'int128' + except ValueError: + return 'bytes' + + func_name_count_map = dict(Counter([a['name'] for a in abi])) + for candidate_func_abi in abi: + if candidate_func_abi["type"] == "function": + # try func name first. + if candidate_func_abi["name"] == func_name and \ + func_name_count_map[candidate_func_abi['name']] == 1: + if len(args) != candidate_func_abi['inputs']: + print('Incorrect arguments for {}'.format(func_name)) + return + else: + return candidate_func_abi + # is overloaded function, use full signature. + else: + full_sig = "{func_name}({type_str})".format( + func_name=func_name, + type_str=','.join([guess_type(x) for x in args]) + ) + method = "{func_name}({type_str})".format( + func_name=candidate_func_abi['name'], + type_str=','.join([x['type'] for x in candidate_func_abi['inputs']]) + ) + if method == full_sig: + return candidate_func_abi + + if __name__ == '__main__': with open(args.input_file) as fh: @@ -118,17 +156,18 @@ if __name__ == '__main__': continue print('\n* Calling {}({})'.format(func_name, ','.join(args))) - func_abi = next(filter(lambda func: func["name"] == func_name, abi)) - if len(args) != len(func_abi['inputs']): - print('Argument mismatch, please provide correct arguments.') + + func_abi = get_func_abi(abi, func_name, args) + if not func_abi: + print('Did not find function in abi.') break cast_args = cast_types(args, func_abi) - res = getattr(contract.functions, func_name)(*cast_args).call({'gas': func_abi['gas'] + 22000}) + res = getattr(contract.functions, func_name)(*cast_args).call({'gas': func_abi.get('gas', 0) + 50000}) source_map = produce_source_map(code) set_evm_opcode_debugger(source_code=code, source_map=source_map) - tx_hash = getattr(contract.functions, func_name)(*cast_args).transact({'gas': func_abi['gas'] + 22000}) + tx_hash = getattr(contract.functions, func_name)(*cast_args).transact({'gas': func_abi.get('gas', 0) + 50000}) set_evm_opcode_pass() print('- Returns:')