From 658ede2191e7da57a37a9c12f5c21cc4bc4c7020 Mon Sep 17 00:00:00 2001
From: Hsiao-Wei Wang <hsiaowei.eth@gmail.com>
Date: Fri, 9 Apr 2021 20:34:51 +0800
Subject: [PATCH] Refactor pyspec builder with `SpecAdjustment` classes

---
 setup.py | 339 +++++++++++++++++++++++++++++++++----------------------
 1 file changed, 206 insertions(+), 133 deletions(-)

diff --git a/setup.py b/setup.py
index 514f75c50..65b04fcf2 100644
--- a/setup.py
+++ b/setup.py
@@ -6,15 +6,35 @@ from distutils.util import convert_path
 import os
 import re
 from typing import Dict, NamedTuple, List
+from abc import ABC, abstractmethod
+
 
 FUNCTION_REGEX = r'^def [\w_]*'
 
-
 # Definitions in context.py
 PHASE0 = 'phase0'
 ALTAIR = 'altair'
 MERGE = 'merge'
 
+CONFIG_LOADER = '''
+apply_constants_config(globals())
+'''
+
+# The helper functions that are used when defining constants
+CONSTANT_DEP_SUNDRY_CONSTANTS_FUNCTIONS = '''
+def ceillog2(x: int) -> uint64:
+    if x < 1:
+        raise ValueError(f"ceillog2 accepts only positive values, x={x}")
+    return uint64((x - 1).bit_length())
+
+
+def floorlog2(x: int) -> uint64:
+    if x < 1:
+        raise ValueError(f"floorlog2 accepts only positive values, x={x}")
+    return uint64(x.bit_length() - 1)
+'''
+
+
 class SpecObject(NamedTuple):
     functions: Dict[str, str]
     custom_types: Dict[str, str]
@@ -111,11 +131,55 @@ def get_spec(file_name: str) -> SpecObject:
     )
 
 
-CONFIG_LOADER = '''
-apply_constants_config(globals())
-'''
+class SpecAdjustment(ABC):
+    @classmethod
+    @abstractmethod
+    def imports_and_predefinitions(cls) -> str:
+        """
+        Importing functions and defining special types/constants for building pyspec.
+        """
+        raise NotImplementedError()
 
-PHASE0_IMPORTS = '''from eth2spec.config.config_util import apply_constants_config
+    @classmethod
+    @abstractmethod
+    def sundry_functions(cls) -> str:
+        """
+        The functions that are (1) defined abstractly in specs or (2) adjusted for getting better performance.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    def hardcoded_ssz_dep_constants(cls) -> Dict[str, str]:
+        """
+        The constants that are required for SSZ objects.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    def hardcoded_custom_type_dep_constants(cls) -> Dict[str, str]:
+        """
+        The constants that are required for custom types.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    def invariant_checks(cls) -> str:
+        """
+        The invariant checks
+        """
+        raise NotImplementedError()
+
+
+#
+# Phase0SpecAdjustment
+#
+class Phase0SpecAdjustment(SpecAdjustment):
+    @classmethod
+    def imports_and_predefinitions(cls) -> str:
+        return '''from eth2spec.config.config_util import apply_constants_config
 from typing import (
     Any, Callable, Dict, Set, Sequence, Tuple, Optional, TypeVar
 )
@@ -141,82 +205,9 @@ SSZObject = TypeVar('SSZObject', bound=View)
 CONFIG_NAME = 'mainnet'
 '''
 
-ALTAIR_IMPORTS = '''from eth2spec.phase0 import spec as phase0
-from eth2spec.config.config_util import apply_constants_config
-from typing import (
-    Any, Dict, Set, Sequence, NewType, Tuple, TypeVar, Callable, Optional, Union
-)
-
-from dataclasses import (
-    dataclass,
-    field,
-)
-
-from lru import LRU
-
-from eth2spec.utils.ssz.ssz_impl import hash_tree_root, copy, uint_to_bytes
-from eth2spec.utils.ssz.ssz_typing import (
-    View, boolean, Container, List, Vector, uint8, uint32, uint64,
-    Bytes1, Bytes4, Bytes32, Bytes48, Bytes96, Bitlist, Bitvector,
-    Path,
-)
-from eth2spec.utils import bls
-
-from eth2spec.utils.hash_function import hash
-
-# Whenever altair is loaded, make sure we have the latest phase0
-from importlib import reload
-reload(phase0)
-
-
-SSZVariableName = str
-GeneralizedIndex = NewType('GeneralizedIndex', int)
-SSZObject = TypeVar('SSZObject', bound=View)
-
-CONFIG_NAME = 'mainnet'
-'''
-
-MERGE_IMPORTS = '''from eth2spec.phase0 import spec as phase0
-from eth2spec.config.config_util import apply_constants_config
-from typing import (
-    Any, Callable, Dict, Set, Sequence, Tuple, Optional, TypeVar
-)
-
-from dataclasses import (
-    dataclass,
-    field,
-)
-
-from lru import LRU
-
-from eth2spec.utils.ssz.ssz_impl import hash_tree_root, copy, uint_to_bytes
-from eth2spec.utils.ssz.ssz_typing import (
-    View, boolean, Container, List, Vector, uint8, uint32, uint64, uint256,
-    Bytes1, Bytes4, Bytes20, Bytes32, Bytes48, Bytes96, Bitlist,
-    ByteList, ByteVector
-)
-from eth2spec.utils import bls
-
-from eth2spec.utils.hash_function import hash
-
-SSZObject = TypeVar('SSZObject', bound=View)
-
-CONFIG_NAME = 'mainnet'
-'''
-
-SUNDRY_CONSTANTS_FUNCTIONS = '''
-def ceillog2(x: int) -> uint64:
-    if x < 1:
-        raise ValueError(f"ceillog2 accepts only positive values, x={x}")
-    return uint64((x - 1).bit_length())
-
-
-def floorlog2(x: int) -> uint64:
-    if x < 1:
-        raise ValueError(f"floorlog2 accepts only positive values, x={x}")
-    return uint64(x.bit_length() - 1)
-'''
-PHASE0_SUNDRY_FUNCTIONS = '''
+    @classmethod
+    def sundry_functions(cls) -> str:
+        return '''
 def get_eth1_data(block: Eth1Block) -> Eth1Data:
     """
     A stub function return mocking Eth1Data.
@@ -287,9 +278,62 @@ get_attesting_indices = cache_this(
     ),
     _get_attesting_indices, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)'''
 
+    @classmethod
+    def hardcoded_ssz_dep_constants(cls) -> Dict[str, str]:
+        return {}
 
-ALTAIR_SUNDRY_FUNCTIONS = '''
+    @classmethod
+    def hardcoded_custom_type_dep_constants(cls) -> Dict[str, str]:
+        return {}
 
+    @classmethod
+    def invariant_checks(cls) -> str:
+        return ''
+
+
+#
+# AltairSpecAdjustment
+#
+class AltairSpecAdjustment(Phase0SpecAdjustment):
+    @classmethod
+    def imports_and_predefinitions(cls) -> str:
+        return '''from eth2spec.phase0 import spec as phase0
+from eth2spec.config.config_util import apply_constants_config
+from typing import (
+    Any, Dict, Set, Sequence, NewType, Tuple, TypeVar, Callable, Optional, Union
+)
+
+from dataclasses import (
+    dataclass,
+    field,
+)
+
+from lru import LRU
+
+from eth2spec.utils.ssz.ssz_impl import hash_tree_root, copy, uint_to_bytes
+from eth2spec.utils.ssz.ssz_typing import (
+    View, boolean, Container, List, Vector, uint8, uint32, uint64,
+    Bytes1, Bytes4, Bytes32, Bytes48, Bytes96, Bitlist, Bitvector,
+    Path,
+)
+from eth2spec.utils import bls
+
+from eth2spec.utils.hash_function import hash
+
+# Whenever altair is loaded, make sure we have the latest phase0
+from importlib import reload
+reload(phase0)
+
+
+SSZVariableName = str
+GeneralizedIndex = NewType('GeneralizedIndex', int)
+SSZObject = TypeVar('SSZObject', bound=View)
+
+CONFIG_NAME = 'mainnet'
+'''
+    @classmethod
+    def sundry_functions(cls) -> str:
+        return super().sundry_functions() + '\n\n' + '''
 def get_generalized_index(ssz_class: Any, *path: Sequence[Union[int, SSZVariableName]]) -> GeneralizedIndex:
     ssz_path = Path(ssz_class)
     for item in path:
@@ -297,7 +341,59 @@ def get_generalized_index(ssz_class: Any, *path: Sequence[Union[int, SSZVariable
     return GeneralizedIndex(ssz_path.gindex())'''
 
 
-MERGE_SUNDRY_FUNCTIONS = """
+    @classmethod
+    def hardcoded_ssz_dep_constants(cls) -> Dict[str, str]:
+        constants = {
+            'FINALIZED_ROOT_INDEX': 'GeneralizedIndex(105)',
+            'NEXT_SYNC_COMMITTEE_INDEX': 'GeneralizedIndex(55)',
+        }
+        return {**super().hardcoded_ssz_dep_constants(), **constants}
+
+    @classmethod
+    def invariant_checks(cls) -> str:
+        return '''
+assert (
+    TIMELY_HEAD_WEIGHT + TIMELY_SOURCE_WEIGHT + TIMELY_TARGET_WEIGHT + SYNC_REWARD_WEIGHT + PROPOSER_WEIGHT
+) == WEIGHT_DENOMINATOR'''
+
+
+#
+# MergeSpecAdjustment
+#
+class MergeSpecAdjustment(Phase0SpecAdjustment):
+    @classmethod
+    def imports_and_predefinitions(cls):
+        return '''from eth2spec.phase0 import spec as phase0
+from eth2spec.config.config_util import apply_constants_config
+from typing import (
+    Any, Callable, Dict, Set, Sequence, Tuple, Optional, TypeVar
+)
+
+from dataclasses import (
+    dataclass,
+    field,
+)
+
+from lru import LRU
+
+from eth2spec.utils.ssz.ssz_impl import hash_tree_root, copy, uint_to_bytes
+from eth2spec.utils.ssz.ssz_typing import (
+    View, boolean, Container, List, Vector, uint8, uint32, uint64, uint256,
+    Bytes1, Bytes4, Bytes20, Bytes32, Bytes48, Bytes96, Bitlist,
+    ByteList, ByteVector
+)
+from eth2spec.utils import bls
+
+from eth2spec.utils.hash_function import hash
+
+SSZObject = TypeVar('SSZObject', bound=View)
+
+CONFIG_NAME = 'mainnet'
+'''
+
+    @classmethod
+    def sundry_functions(cls) -> str:
+        return super().sundry_functions() + '\n\n' + """
 ExecutionState = Any
 
 
@@ -321,22 +417,18 @@ def produce_execution_payload(parent_hash: Bytes32) -> ExecutionPayload:
     pass"""
 
 
-# The constants that depend on SSZ objects
-# Will verify the value at the end of the spec
-ALTAIR_HARDCODED_SSZ_DEP_CONSTANTS = {
-    'FINALIZED_ROOT_INDEX': 'GeneralizedIndex(105)',
-    'NEXT_SYNC_COMMITTEE_INDEX': 'GeneralizedIndex(55)',
-}
+    @classmethod
+    def hardcoded_custom_type_dep_constants(cls) -> str:
+        constants = {
+            'MAX_BYTES_PER_OPAQUE_TRANSACTION': 'uint64(2**20)',
+        }
+        return {**super().hardcoded_custom_type_dep_constants(), **constants}
 
 
-ALTAIR_INVAIANT_CHECKS = '''
-assert (
-    TIMELY_HEAD_WEIGHT + TIMELY_SOURCE_WEIGHT + TIMELY_TARGET_WEIGHT + SYNC_REWARD_WEIGHT + PROPOSER_WEIGHT
-) == WEIGHT_DENOMINATOR'''
-
-
-MERGE_HARDCODED_CUSTOM_TYPE_DEP_CONSTANTS = {
-    'MAX_BYTES_PER_OPAQUE_TRANSACTION': 'uint64(2**20)',
+spec_adjustments = {
+    PHASE0: Phase0SpecAdjustment,
+    ALTAIR: AltairSpecAdjustment,
+    MERGE: MergeSpecAdjustment,
 }
 
 
@@ -352,7 +444,7 @@ def is_merge(fork):
     return fork == MERGE
 
 
-def objects_to_spec(spec_object: SpecObject, imports: str, fork: str, ordered_class_objects: Dict[str, str]) -> str:
+def objects_to_spec(spec_object: SpecObject, adjustment: SpecAdjustment, fork: str, ordered_class_objects: Dict[str, str]) -> str:
     """
     Given all the objects that constitute a spec, combine them into a single pyfile.
     """
@@ -382,41 +474,29 @@ def objects_to_spec(spec_object: SpecObject, imports: str, fork: str, ordered_cl
             spec_object.constants[k] += "  # noqa: E501"
     constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, spec_object.constants[x]), spec_object.constants))
     ordered_class_objects_spec = '\n\n'.join(ordered_class_objects.values())
-
-    if is_altair(fork):
-        altair_ssz_dep_constants = '\n'.join(map(lambda x: '%s = %s' % (x, ALTAIR_HARDCODED_SSZ_DEP_CONSTANTS[x]), ALTAIR_HARDCODED_SSZ_DEP_CONSTANTS))
-
-    if is_merge(fork):
-        merge_custom_type_dep_constants = '\n'.join(map(lambda x: '%s = %s' % (x, MERGE_HARDCODED_CUSTOM_TYPE_DEP_CONSTANTS[x]), MERGE_HARDCODED_CUSTOM_TYPE_DEP_CONSTANTS))
-
-    
+    ssz_dep_constants = '\n'.join(map(lambda x: '%s = %s' % (x, adjustment.hardcoded_ssz_dep_constants()[x]), adjustment.hardcoded_ssz_dep_constants()))
+    ssz_dep_constants_verification = '\n'.join(map(lambda x: 'assert %s == %s' % (x, spec_object.ssz_dep_constants[x]), adjustment.hardcoded_ssz_dep_constants()))
+    custom_type_dep_constants = '\n'.join(map(lambda x: '%s = %s' % (x, adjustment.hardcoded_custom_type_dep_constants()[x]), adjustment.hardcoded_custom_type_dep_constants()))
     spec = (
-            imports
+            adjustment.imports_and_predefinitions()
             + '\n\n' + f"fork = \'{fork}\'\n"
             # The constants that some SSZ containers require. Need to be defined before `new_type_definitions`
-            + ('\n\n' + merge_custom_type_dep_constants  + '\n' if is_merge(fork) else '')
+            + ('\n\n' + custom_type_dep_constants + '\n' if custom_type_dep_constants != '' else '')
             + '\n\n' + new_type_definitions
-            + '\n' + SUNDRY_CONSTANTS_FUNCTIONS
+            + '\n' + CONSTANT_DEP_SUNDRY_CONSTANTS_FUNCTIONS
             # The constants that some SSZ containers require. Need to be defined before `constants_spec`
-            + ('\n\n' + altair_ssz_dep_constants if is_altair(fork) else '')
+            + ('\n\n' + ssz_dep_constants if ssz_dep_constants != '' else '')
             + '\n\n' + constants_spec
             + '\n\n' + CONFIG_LOADER
             + '\n\n' + ordered_class_objects_spec
             + '\n\n' + functions_spec
-            # Functions to make pyspec work
-            + '\n' + PHASE0_SUNDRY_FUNCTIONS
-            + ('\n' + ALTAIR_SUNDRY_FUNCTIONS if is_altair(fork) else '')
-            + ('\n' + MERGE_SUNDRY_FUNCTIONS if is_merge(fork) else '')
+            + '\n' + adjustment.sundry_functions()
+            # Since some constants are hardcoded in setup.py, the following assertions verify that the hardcoded constants are
+            # as same as the spec definition.
+            + ('\n\n\n' + ssz_dep_constants_verification if ssz_dep_constants_verification != '' else '')
+            + ('\n' + adjustment.invariant_checks() if adjustment.invariant_checks() != '' else '')
+            + '\n'
     )
-
-    # Since some constants are hardcoded in setup.py, the following assertions verify that the hardcoded constants are
-    # as same as the spec definition.
-    if is_altair(fork):
-        altair_ssz_dep_constants_verification = '\n'.join(map(lambda x: 'assert %s == %s' % (x, spec_object.ssz_dep_constants[x]), ALTAIR_HARDCODED_SSZ_DEP_CONSTANTS))
-        spec += '\n\n\n' + altair_ssz_dep_constants_verification
-        spec += '\n' + ALTAIR_INVAIANT_CHECKS
-
-    spec += '\n'
     return spec
 
 
@@ -496,13 +576,6 @@ def combine_spec_objects(spec0: SpecObject, spec1: SpecObject) -> SpecObject:
     )
 
 
-fork_imports = {
-    'phase0': PHASE0_IMPORTS,
-    'altair': ALTAIR_IMPORTS,
-    'merge': MERGE_IMPORTS,
-}
-
-
 def build_spec(fork: str, source_files: List[str]) -> str:
     all_specs = [get_spec(spec) for spec in source_files]
 
@@ -513,7 +586,7 @@ def build_spec(fork: str, source_files: List[str]) -> str:
     class_objects = {**spec_object.ssz_objects, **spec_object.dataclasses}
     dependency_order_class_objects(class_objects, spec_object.custom_types)
 
-    return objects_to_spec(spec_object, fork_imports[fork], fork, class_objects)
+    return objects_to_spec(spec_object, spec_adjustments[fork], fork, class_objects)
 
 
 class PySpecCommand(Command):
@@ -611,7 +684,7 @@ class BuildPyCommand(build_py):
         self.run_command('pyspec')
 
     def run(self):
-        for spec_fork in fork_imports:
+        for spec_fork in spec_adjustments:
             self.run_pyspec_cmd(spec_fork=spec_fork)
 
         super(BuildPyCommand, self).run()
@@ -639,7 +712,7 @@ class PyspecDevCommand(Command):
 
     def run(self):
         print("running build_py command")
-        for spec_fork in fork_imports:
+        for spec_fork in spec_adjustments:
             self.run_pyspec_cmd(spec_fork=spec_fork)
 
 commands = {