Make uint64 be `class` for type hinting

This commit is contained in:
Hsiao-Wei Wang 2019-06-10 23:16:59 -04:00
parent 9fc197af67
commit 8b64f37d22
No known key found for this signature in database
GPG Key ID: 95B070122902DEA4
2 changed files with 34 additions and 23 deletions

View File

@ -15,7 +15,6 @@ PHASE0_IMPORTS = '''from typing import (
Any, Any,
Dict, Dict,
List, List,
NewType,
Tuple, Tuple,
) )
@ -41,7 +40,6 @@ PHASE1_IMPORTS = '''from typing import (
Any, Any,
Dict, Dict,
List, List,
NewType,
Tuple, Tuple,
) )
@ -65,11 +63,11 @@ from eth2spec.utils.bls import (
from eth2spec.utils.hash_function import hash from eth2spec.utils.hash_function import hash
''' '''
NEW_TYPES = { NEW_TYPES = {
'Slot': 'int', 'Slot': 'uint64',
'Epoch': 'int', 'Epoch': 'uint64',
'Shard': 'int', 'Shard': 'uint64',
'ValidatorIndex': 'int', 'ValidatorIndex': 'uint64',
'Gwei': 'int', 'Gwei': 'uint64',
} }
BYTE_TYPES = [4, 32, 48, 96] BYTE_TYPES = [4, 32, 48, 96]
SUNDRY_FUNCTIONS = ''' SUNDRY_FUNCTIONS = '''
@ -79,7 +77,7 @@ def get_ssz_type_by_name(name: str) -> Container:
# Monkey patch validator compute committee code # Monkey patch validator compute committee code
_compute_committee = compute_committee _compute_committee = compute_committee
committee_cache = {} committee_cache = {} # type: Dict[Tuple[Bytes32, Bytes32, ValidatorIndex, int], List[ValidatorIndex]]
def compute_committee(indices: List[ValidatorIndex], seed: Bytes32, index: int, count: int) -> List[ValidatorIndex]: def compute_committee(indices: List[ValidatorIndex], seed: Bytes32, index: int, count: int) -> List[ValidatorIndex]:
@ -95,10 +93,10 @@ def compute_committee(indices: List[ValidatorIndex], seed: Bytes32, index: int,
# Monkey patch hash cache # Monkey patch hash cache
_hash = hash _hash = hash
hash_cache = {} hash_cache: Dict[bytes, Bytes32] = {}
def hash(x): def hash(x: bytes) -> Bytes32:
if x in hash_cache: if x in hash_cache:
return hash_cache[x] return hash_cache[x]
else: else:
@ -108,7 +106,7 @@ def hash(x):
# Access to overwrite spec constants based on configuration # Access to overwrite spec constants based on configuration
def apply_constants_preset(preset: Dict[str, Any]): def apply_constants_preset(preset: Dict[str, Any]) -> None:
global_vars = globals() global_vars = globals()
for k, v in preset.items(): for k, v in preset.items():
global_vars[k] = v global_vars[k] = v
@ -132,20 +130,28 @@ def objects_to_spec(functions: Dict[str, str],
""" """
Given all the objects that constitute a spec, combine them into a single pyfile. Given all the objects that constitute a spec, combine them into a single pyfile.
""" """
new_type_definitions = \ new_type_definitions = (
'\n'.join(['''%s = NewType('%s', %s)''' % (key, key, value) for key, value in new_types.items()]) '\n\n'.join(
[
f"class {key}({value}):\n"
f" def __init__(self, _x: uint64) -> None:\n"
f" ...\n"
for key, value in new_types.items()
]
)
)
functions_spec = '\n\n'.join(functions.values()) functions_spec = '\n\n'.join(functions.values())
constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, constants[x]), constants)) constants_spec = '\n'.join(map(lambda x: '%s = %s' % (x, constants[x]), constants))
ssz_objects_instantiation_spec = '\n\n'.join(ssz_objects.values()) ssz_objects_instantiation_spec = '\n\n'.join(ssz_objects.values())
ssz_objects_reinitialization_spec = ( ssz_objects_reinitialization_spec = (
'def init_SSZ_types():\n global_vars = globals()\n\n ' 'def init_SSZ_types() -> None:\n global_vars = globals()\n\n '
+ '\n\n '.join([re.sub(r'(?!\n\n)\n', r'\n ', value[:-1]) for value in ssz_objects.values()]) + '\n\n '.join([re.sub(r'(?!\n\n)\n', r'\n ', value[:-1]) for value in ssz_objects.values()])
+ '\n\n' + '\n\n'
+ '\n'.join(map(lambda x: ' global_vars[\'%s\'] = %s' % (x, x), ssz_objects.keys())) + '\n'.join(map(lambda x: ' global_vars[\'%s\'] = %s' % (x, x), ssz_objects.keys()))
) )
spec = ( spec = (
imports imports
+ '\n' + new_type_definitions + '\n\n' + new_type_definitions
+ '\n\n' + constants_spec + '\n\n' + constants_spec
+ '\n\n\n' + ssz_objects_instantiation_spec + '\n\n\n' + ssz_objects_instantiation_spec
+ '\n\n' + functions_spec + '\n\n' + functions_spec

View File

@ -46,8 +46,13 @@ class uint32(uint):
return super().__new__(cls, value) return super().__new__(cls, value)
# We simply default to uint64. But do give it a name, for readability class uint64(uint):
uint64 = NewType('uint64', int) byte_len = 8
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 64:
raise ValueError("value out of bounds for uint128")
return super().__new__(cls, value)
class uint128(uint): class uint128(uint):
@ -409,12 +414,12 @@ class Bytes96(BytesN):
# SSZ Defaults # SSZ Defaults
# ----------------------------- # -----------------------------
def get_zero_value(typ): def get_zero_value(typ):
if is_uint_type(typ): if is_list_type(typ):
return 0
elif is_list_type(typ):
return [] return []
elif is_bool_type(typ): elif is_bool_type(typ):
return False return False
elif is_uint_type(typ):
return uint64(0)
elif is_vector_type(typ): elif is_vector_type(typ):
return typ() return typ()
elif is_bytesn_type(typ): elif is_bytesn_type(typ):
@ -432,12 +437,12 @@ def get_zero_value(typ):
def infer_type(obj): def infer_type(obj):
if is_uint_type(obj.__class__): if isinstance(obj, int):
return obj.__class__
elif isinstance(obj, int):
return uint64 return uint64
elif isinstance(obj, list): elif isinstance(obj, list):
return List[infer_type(obj[0])] return List[infer_type(obj[0])]
elif is_uint_type(obj.__class__):
return obj.__class__
elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)): elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)):
return obj.__class__ return obj.__class__
else: else: