diff --git a/scripts/build_spec.py b/scripts/build_spec.py index a18107528..cdcbc2456 100644 --- a/scripts/build_spec.py +++ b/scripts/build_spec.py @@ -5,6 +5,83 @@ from argparse import ArgumentParser from typing import Tuple, List +IMPORTS = '''from typing import ( + Any, + Dict, + List, + NewType, + Tuple, +) + +from eth2spec.utils.minimal_ssz import ( + SSZType, + hash_tree_root, + signing_root, +) + +from eth2spec.utils.bls_stub import ( + bls_aggregate_pubkeys, + bls_verify, + bls_verify_multiple, +) + +from eth2spec.utils.hash_function import hash +''' +NEW_TYPE_DEFINITIONS = ''' +Slot = NewType('Slot', int) # uint64 +Epoch = NewType('Epoch', int) # uint64 +Shard = NewType('Shard', int) # uint64 +ValidatorIndex = NewType('ValidatorIndex', int) # uint64 +Gwei = NewType('Gwei', int) # uint64 +Bytes32 = NewType('Bytes32', bytes) # bytes32 +BLSPubkey = NewType('BLSPubkey', bytes) # bytes48 +BLSSignature = NewType('BLSSignature', bytes) # bytes96 +Store = None +''' +SUNDRY_FUNCTIONS = ''' +# Monkey patch validator compute committee code +_compute_committee = compute_committee +committee_cache = {} + + +def compute_committee(indices: List[ValidatorIndex], seed: Bytes32, index: int, count: int) -> List[ValidatorIndex]: + param_hash = (hash_tree_root(indices), seed, index, count) + + if param_hash in committee_cache: + return committee_cache[param_hash] + else: + ret = _compute_committee(indices, seed, index, count) + committee_cache[param_hash] = ret + return ret + + +# Monkey patch hash cache +_hash = hash +hash_cache = {} + + +def hash(x): + if x in hash_cache: + return hash_cache[x] + else: + ret = _hash(x) + hash_cache[x] = ret + return ret + + +# Access to overwrite spec constants based on configuration +def apply_constants_preset(preset: Dict[str, Any]): + global_vars = globals() + for k, v in preset.items(): + global_vars[k] = v + + # Deal with derived constants + global_vars['GENESIS_EPOCH'] = slot_to_epoch(GENESIS_SLOT) + + # Initialize SSZ types again, to account for changed lengths + init_SSZ_types() +''' + def split_and_label(regex_pattern: str, text: str) -> List[str]: ''' Splits a string based on regex, but down not remove the matched text. @@ -72,111 +149,71 @@ def merger(oldfile:str, newfile:str) -> str: return ''.join(elem for elem in map(lambda x: x[1], old_objects)) +def objects_to_spec(functions, constants, ssz_objects): + functions_spec = '\n\n'.join(functions.values()) + constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, constants[x]),constants)) + ssz_objects_instantiation_spec = '\n'.join(map(lambda x: '%s = SSZType(%s)' % (x, ssz_objects[x][:-1]), ssz_objects)) + ssz_objects_reinitialization_spec = '\n'.join( + map(lambda x: ' global_vars[%s] = SSZType(%s })' % (x, re.sub('( ){4}', ' '*8, ssz_objects[x][:-2])), ssz_objects)) + ssz_objects_reinitialization_spec = ( + 'def init_SSZ_types():\n global_vars = globals()\n' + + ssz_objects_reinitialization_spec + ) + return ( + IMPORTS + + '\n' + NEW_TYPE_DEFINITIONS + + '\n' + constants_spec + + '\n' + ssz_objects_instantiation_spec + + '\n\n\n' + functions_spec + + '\n' + SUNDRY_FUNCTIONS + + '\n\n' + ssz_objects_reinitialization_spec + + '\n' + ) + +def combine_functions(old_funcitons, new_functions): + for key, value in new_functions.items(): + old_funcitons[key] = value + # TODO: Add insert functionality + return old_funcitons + + +def combine_constants(old_constants, new_constants): + for key, value in new_constants.items(): + old_constants[key] = value + return old_constants + +def combine_ssz_objects(old_objects, new_objects): + remove_encasing = lambda x: x[1:-1] + old_objects = map(remove_encasing, old_objects) + new_objects = map(remove_encasing, new_objects) + for key, value in new_objects.items(): + old_objects[key] += value + reapply_encasing = lambda x: '{%s}' %x + return map(reapply_encasing, old_objects) + + def build_phase0_spec(sourcefile, outfile=None): - code_lines = [] - code_lines.append(""" -from typing import ( - Any, - Dict, - List, - NewType, - Tuple, -) - -from eth2spec.utils.minimal_ssz import ( - SSZType, - hash_tree_root, - signing_root, -) - -from eth2spec.utils.bls_stub import ( - bls_aggregate_pubkeys, - bls_verify, - bls_verify_multiple, -) - -from eth2spec.utils.hash_function import hash - - -# stub, will get overwritten by real var -SLOTS_PER_EPOCH = 64 - -Slot = NewType('Slot', int) # uint64 -Epoch = NewType('Epoch', int) # uint64 -Shard = NewType('Shard', int) # uint64 -ValidatorIndex = NewType('ValidatorIndex', int) # uint64 -Gwei = NewType('Gwei', int) # uint64 -Bytes32 = NewType('Bytes32', bytes) # bytes32 -BLSPubkey = NewType('BLSPubkey', bytes) # bytes48 -BLSSignature = NewType('BLSSignature', bytes) # bytes96 -Store = None -""") - - code_lines += function_puller.get_spec(sourcefile) - - code_lines.append(""" -# Monkey patch validator compute committee code -_compute_committee = compute_committee -committee_cache = {} - - -def compute_committee(indices: List[ValidatorIndex], seed: Bytes32, index: int, count: int) -> List[ValidatorIndex]: - param_hash = (hash_tree_root(indices), seed, index, count) - - if param_hash in committee_cache: - return committee_cache[param_hash] - else: - ret = _compute_committee(indices, seed, index, count) - committee_cache[param_hash] = ret - return ret - - -# Monkey patch hash cache -_hash = hash -hash_cache = {} - - -def hash(x): - if x in hash_cache: - return hash_cache[x] - else: - ret = _hash(x) - hash_cache[x] = ret - return ret - - -# Access to overwrite spec constants based on configuration -def apply_constants_preset(preset: Dict[str, Any]): - global_vars = globals() - for k, v in preset.items(): - global_vars[k] = v - - # Deal with derived constants - global_vars['GENESIS_EPOCH'] = slot_to_epoch(GENESIS_SLOT) - - # Initialize SSZ types again, to account for changed lengths - init_SSZ_types() - -""") - + functions, constants, ssz_objects = function_puller.get_spec(sourcefile) + spec = objects_to_spec(functions, constants, ssz_objects) if outfile is not None: with open(outfile, 'w') as out: - out.write("\n".join(code_lines)) + out.write(spec) else: - return "\n".join(code_lines) + return spec def build_phase1_spec(phase0_sourcefile, phase1_sourcefile, outfile=None): - phase0_code = build_phase0_spec(phase0_sourcefile) - phase1_code = build_phase0_spec(phase1_sourcefile) - phase0_code, phase1_code = inserter(phase0_code, phase1_code) - phase1_code = merger(phase0_code, phase1_code) - + phase0_functions, phase0_constants, phase0_ssz_objects = function_puller.get_spec(phase0_sourcefile) + phase1_functions, phase1_constants, phase1_ssz_objects = function_puller.get_spec(phase1_sourcefile) + functions = combine_functions(phase0_functions, phase1_functions) + constants = combine_constants(phase0_constants, phase1_constants) + ssz_objects = combine_functions(phase0_ssz_objects, phase1_ssz_objects) + spec = objects_to_spec(functions, constants, ssz_objects) if outfile is not None: with open(outfile, 'w') as out: - out.write(phase1_code) + out.write(spec) else: - return phase1_code + return spec if __name__ == '__main__': diff --git a/scripts/function_puller.py b/scripts/function_puller.py index 97bc62821..720f07502 100644 --- a/scripts/function_puller.py +++ b/scripts/function_puller.py @@ -1,13 +1,21 @@ import sys +import re from typing import List +from collections import defaultdict -def get_spec(file_name: str) -> List[str]: +FUNCTION_REGEX = r'^def [\w_]*' + + +def get_spec(file_name: str): code_lines = [] - pulling_from = None - current_name = None - current_typedef = None - type_defs = [] + pulling_from = None # line number of start of latest object + current_name = None # most recent section title + functions = defaultdict(str) + constants = {} + ssz_objects = defaultdict(str) + function_matcher = re.compile(FUNCTION_REGEX) + # type_defs = [] for linenum, line in enumerate(open(file_name).readlines()): line = line.rstrip() if pulling_from is None and len(line) > 0 and line[0] == '#' and line[-1] == '`': @@ -16,29 +24,21 @@ def get_spec(file_name: str) -> List[str]: assert pulling_from is None pulling_from = linenum + 1 elif line[:3] == '```': - if pulling_from is None: - pulling_from = linenum - else: - if current_typedef is not None: - assert code_lines[-1] == '}' - code_lines[-1] = '})' - current_typedef[-1] = '})' - type_defs.append((current_name, current_typedef)) - pulling_from = None - current_typedef = None + pulling_from = None else: - if pulling_from == linenum and line == '{': - code_lines.append('%s = SSZType({' % current_name) - current_typedef = ['global_vars["%s"] = SSZType({' % current_name] - elif pulling_from is not None: - # Add some whitespace between functions - if line[:3] == 'def': - code_lines.append('') - code_lines.append('') - code_lines.append(line) - # Remember type def lines - if current_typedef is not None: - current_typedef.append(line) + # # Handle SSZObjects + # if pulling_from == linenum and line == '{': + # code_lines.append('%s = SSZType({' % current_name) + # Handle function definitions + if pulling_from is not None: + match = function_matcher.match(line) + if match is not None: + current_name = match.group(0) + if function_matcher.match(current_name) is None: + ssz_objects[current_name] += line + '\n' + else: + functions[current_name] += line + '\n' + # Handle constant table entries elif pulling_from is None and len(line) > 0 and line[0] == '|': row = line[1:].split('|') if len(row) >= 2: @@ -53,18 +53,5 @@ def get_spec(file_name: str) -> List[str]: if c not in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ_0123456789': eligible = False if eligible: - code_lines.append(row[0] + ' = ' + (row[1].replace('**TBD**', '0x1234567890123456789012345678901234567890'))) - # Build type-def re-initialization - code_lines.append('\n') - code_lines.append('def init_SSZ_types():') - code_lines.append(' global_vars = globals()') - for ssz_type_name, ssz_type in type_defs: - code_lines.append('') - for type_line in ssz_type: - if len(type_line) > 0: - code_lines.append(' ' + type_line) - code_lines.append('\n') - code_lines.append('def get_ssz_type_by_name(name: str) -> SSZType:') - code_lines.append(' return globals()[name]') - code_lines.append('') - return code_lines + constants[row[0]] = row[1].replace('**TBD**', '0x1234567890123456789012345678901234567890') + return functions, constants, ssz_objects