From 4e747fb8879540012655e534e9d8fd214e47be87 Mon Sep 17 00:00:00 2001 From: protolambda Date: Thu, 20 Jun 2019 20:25:22 +0200 Subject: [PATCH] fixes for class based ssz typing --- scripts/build_spec.py | 32 ++++++++--------- specs/core/0_beacon-chain.md | 8 +++-- .../pyspec/eth2spec/test/helpers/custody.py | 4 +-- .../pyspec/eth2spec/utils/merkle_minimal.py | 2 +- .../pyspec/eth2spec/utils/ssz/ssz_impl.py | 5 +-- .../pyspec/eth2spec/utils/ssz/ssz_typing.py | 36 ++++++++++--------- .../eth2spec/utils/ssz/test_ssz_typing.py | 20 ++++++++++- 7 files changed, 66 insertions(+), 41 deletions(-) diff --git a/scripts/build_spec.py b/scripts/build_spec.py index 612edbd00..d33ba6642 100644 --- a/scripts/build_spec.py +++ b/scripts/build_spec.py @@ -65,22 +65,6 @@ def get_ssz_type_by_name(name: str) -> Container: return globals()[name] -# Monkey patch validator compute committee code -_compute_committee = compute_committee -committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {} - - -def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore - seed: Hash, - index: int, - count: int) -> Tuple[ValidatorIndex, ...]: - param_hash = (hash_tree_root(indices), seed, index, count) - - if param_hash not in committee_cache: - committee_cache[param_hash] = _compute_committee(indices, seed, index, count) - return committee_cache[param_hash] - - # Monkey patch hash cache _hash = hash hash_cache: Dict[bytes, Hash] = {} @@ -92,6 +76,22 @@ def hash(x: bytes) -> Hash: return hash_cache[x] +# Monkey patch validator compute committee code +_compute_committee = compute_committee +committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {} + + +def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore + seed: Hash, + index: int, + count: int) -> Tuple[ValidatorIndex, ...]: + param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count) + + if param_hash not in committee_cache: + committee_cache[param_hash] = _compute_committee(indices, seed, index, count) + return committee_cache[param_hash] + + # Access to overwrite spec constants based on configuration def apply_constants_preset(preset: Dict[str, Any]) -> None: global_vars = globals() diff --git a/specs/core/0_beacon-chain.md b/specs/core/0_beacon-chain.md index 7cbb9b67b..9a02c16e4 100644 --- a/specs/core/0_beacon-chain.md +++ b/specs/core/0_beacon-chain.md @@ -651,11 +651,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool: ### `get_active_validator_indices` ```python -def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Tuple[ValidatorIndex, ...]: +def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE]: """ Get active validator indices at ``epoch``. """ - return tuple(ValidatorIndex(i) for i, v in enumerate(state.validators) if is_active_validator(v, epoch)) + return List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE]( + i for i, v in enumerate(state.validators) if is_active_validator(v, epoch) + ) ``` ### `increase_balance` @@ -873,7 +875,7 @@ def compute_committee(indices: Tuple[ValidatorIndex, ...], ```python def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]: return compute_committee( - indices=get_active_validator_indices(state, epoch), + indices=tuple(get_active_validator_indices(state, epoch)), seed=generate_seed(state, epoch), index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT, count=get_epoch_committee_count(state, epoch), diff --git a/test_libs/pyspec/eth2spec/test/helpers/custody.py b/test_libs/pyspec/eth2spec/test/helpers/custody.py index 67df12fcd..b49a6be1f 100644 --- a/test_libs/pyspec/eth2spec/test/helpers/custody.py +++ b/test_libs/pyspec/eth2spec/test/helpers/custody.py @@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None): epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING reveal = bls_sign( - message_hash=spec.hash_tree_root(epoch), + message_hash=spec.hash_tree_root(spec.Epoch(epoch)), privkey=privkeys[revealed_index], domain=spec.get_domain( state=state, @@ -20,7 +20,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None): ), ) mask = bls_sign( - message_hash=spec.hash_tree_root(epoch), + message_hash=spec.hash_tree_root(spec.Epoch(epoch)), privkey=privkeys[masker_index], domain=spec.get_domain( state=state, diff --git a/test_libs/pyspec/eth2spec/utils/merkle_minimal.py b/test_libs/pyspec/eth2spec/utils/merkle_minimal.py index 21583ee92..038b555cf 100644 --- a/test_libs/pyspec/eth2spec/utils/merkle_minimal.py +++ b/test_libs/pyspec/eth2spec/utils/merkle_minimal.py @@ -4,7 +4,7 @@ from .hash_function import hash ZERO_BYTES32 = b'\x00' * 32 zerohashes = [ZERO_BYTES32] -for layer in range(1, 32): +for layer in range(1, 100): zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1])) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index a9c36649b..4b64c9162 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,7 +1,7 @@ from ..merkle_minimal import merkleize_chunks from ..hash_function import hash from .ssz_typing import ( - SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint + SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint, ) # SSZ Serialization @@ -143,5 +143,6 @@ def hash_tree_root(obj: SSZValue): def signing_root(obj: Container): # ignore last field - leaves = [hash_tree_root(field) for field in obj[:-1]] + fields = [field for field in obj][:-1] + leaves = [hash_tree_root(f) for f in fields] return merkleize_chunks(chunkify(b''.join(leaves))) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 51a790853..381dadf9e 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -98,7 +98,7 @@ def coerce_type_maybe(v, typ: SSZType, strict: bool = False): return typ(v) elif isinstance(v, (list, tuple)): return typ(*v) - elif isinstance(v, bytes): + elif isinstance(v, (bytes, BytesN, Bytes)): return typ(v) elif isinstance(v, GeneratorType): return typ(v) @@ -154,7 +154,8 @@ class Container(Series, metaclass=SSZType): super().__setattr__(name, value) def __repr__(self): - return repr({field: getattr(self, field) for field in self.get_fields().keys()}) + return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset') + for field in self.get_fields().keys()}) def __str__(self): output = [f'{self.__class__.__name__}'] @@ -236,15 +237,24 @@ class ParamsMeta(SSZType): raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i)) return res - def __instancecheck__(self, obj): - if obj.__class__.__name__ != self.__name__: + def __subclasscheck__(self, subclass): + # check regular class system if we can, solves a lot of the normal cases. + if super().__subclasscheck__(subclass): + return True + # if they are not normal subclasses, they are of the same class. + # then they should have the same name + if subclass.__name__ != self.__name__: return False + # If they do have the same name, they should also have the same params. for name, typ in self.__annotations__.items(): - if hasattr(self, name) and hasattr(obj.__class__, name) \ - and getattr(obj.__class__, name) != getattr(self, name): + if hasattr(self, name) and hasattr(subclass, name) \ + and getattr(subclass, name) != getattr(self, name): return False return True + def __instancecheck__(self, obj): + return self.__subclasscheck__(obj.__class__) + class ElementsType(ParamsMeta): elem_type: SSZType @@ -305,9 +315,6 @@ class Elements(ParamsBase, metaclass=ElementsType): def __iter__(self) -> Iterator[SSZValue]: return iter(self.items) - def __eq__(self, other): - return self.items == other.items - class List(Elements): @@ -366,9 +373,6 @@ class BytesLike(Elements, metaclass=BytesType): cls = self.__class__ return f"{cls.__name__}[{cls.length}]: {self.hex()}" - def hex(self) -> str: - return self.items.hex() - class Bytes(BytesLike): @@ -398,7 +402,7 @@ class BytesN(BytesLike): # Helpers for common BytesN types. -Bytes4 = BytesN[4] -Bytes32 = BytesN[32] -Bytes48 = BytesN[48] -Bytes96 = BytesN[96] +Bytes4: BytesType = BytesN[4] +Bytes32: BytesType = BytesN[32] +Bytes48: BytesType = BytesN[48] +Bytes96: BytesType = BytesN[96] diff --git a/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py index 6bb56f4e5..daa923aa7 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py @@ -1,7 +1,8 @@ from .ssz_typing import ( SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN, - uint, uint8, uint16, uint32, uint64, uint128, uint256 + uint, uint8, uint16, uint32, uint64, uint128, uint256, + Bytes32, Bytes48 ) @@ -193,3 +194,20 @@ def test_list(): assert False except IndexError: pass + + +def test_bytesn_subclass(): + assert isinstance(BytesN[32](b'\xab' * 32), Bytes32) + assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48) + assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32) + assert issubclass(BytesN[32], Bytes32) + + class Hash(Bytes32): + pass + + assert isinstance(Hash(b'\xab' * 32), Bytes32) + assert not isinstance(Hash(b'\xab' * 32), Bytes48) + assert issubclass(Hash(b'\xab' * 32).type(), Bytes32) + assert issubclass(Hash, Bytes32) + + assert not issubclass(Bytes48, Bytes32)