mirror of
https://github.com/status-im/eth2.0-specs.git
synced 2025-01-12 19:54:34 +00:00
fixes for class based ssz typing
This commit is contained in:
parent
7cdec746b4
commit
4e747fb887
@ -65,22 +65,6 @@ def get_ssz_type_by_name(name: str) -> Container:
|
|||||||
return globals()[name]
|
return globals()[name]
|
||||||
|
|
||||||
|
|
||||||
# Monkey patch validator compute committee code
|
|
||||||
_compute_committee = compute_committee
|
|
||||||
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
|
|
||||||
seed: Hash,
|
|
||||||
index: int,
|
|
||||||
count: int) -> Tuple[ValidatorIndex, ...]:
|
|
||||||
param_hash = (hash_tree_root(indices), seed, index, count)
|
|
||||||
|
|
||||||
if param_hash not in committee_cache:
|
|
||||||
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
|
|
||||||
return committee_cache[param_hash]
|
|
||||||
|
|
||||||
|
|
||||||
# Monkey patch hash cache
|
# Monkey patch hash cache
|
||||||
_hash = hash
|
_hash = hash
|
||||||
hash_cache: Dict[bytes, Hash] = {}
|
hash_cache: Dict[bytes, Hash] = {}
|
||||||
@ -92,6 +76,22 @@ def hash(x: bytes) -> Hash:
|
|||||||
return hash_cache[x]
|
return hash_cache[x]
|
||||||
|
|
||||||
|
|
||||||
|
# Monkey patch validator compute committee code
|
||||||
|
_compute_committee = compute_committee
|
||||||
|
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
|
||||||
|
seed: Hash,
|
||||||
|
index: int,
|
||||||
|
count: int) -> Tuple[ValidatorIndex, ...]:
|
||||||
|
param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count)
|
||||||
|
|
||||||
|
if param_hash not in committee_cache:
|
||||||
|
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
|
||||||
|
return committee_cache[param_hash]
|
||||||
|
|
||||||
|
|
||||||
# Access to overwrite spec constants based on configuration
|
# Access to overwrite spec constants based on configuration
|
||||||
def apply_constants_preset(preset: Dict[str, Any]) -> None:
|
def apply_constants_preset(preset: Dict[str, Any]) -> None:
|
||||||
global_vars = globals()
|
global_vars = globals()
|
||||||
|
@ -651,11 +651,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool:
|
|||||||
### `get_active_validator_indices`
|
### `get_active_validator_indices`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Tuple[ValidatorIndex, ...]:
|
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE]:
|
||||||
"""
|
"""
|
||||||
Get active validator indices at ``epoch``.
|
Get active validator indices at ``epoch``.
|
||||||
"""
|
"""
|
||||||
return tuple(ValidatorIndex(i) for i, v in enumerate(state.validators) if is_active_validator(v, epoch))
|
return List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE](
|
||||||
|
i for i, v in enumerate(state.validators) if is_active_validator(v, epoch)
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### `increase_balance`
|
### `increase_balance`
|
||||||
@ -873,7 +875,7 @@ def compute_committee(indices: Tuple[ValidatorIndex, ...],
|
|||||||
```python
|
```python
|
||||||
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]:
|
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]:
|
||||||
return compute_committee(
|
return compute_committee(
|
||||||
indices=get_active_validator_indices(state, epoch),
|
indices=tuple(get_active_validator_indices(state, epoch)),
|
||||||
seed=generate_seed(state, epoch),
|
seed=generate_seed(state, epoch),
|
||||||
index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT,
|
index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT,
|
||||||
count=get_epoch_committee_count(state, epoch),
|
count=get_epoch_committee_count(state, epoch),
|
||||||
|
@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
|
|||||||
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
|
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
|
||||||
|
|
||||||
reveal = bls_sign(
|
reveal = bls_sign(
|
||||||
message_hash=spec.hash_tree_root(epoch),
|
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
|
||||||
privkey=privkeys[revealed_index],
|
privkey=privkeys[revealed_index],
|
||||||
domain=spec.get_domain(
|
domain=spec.get_domain(
|
||||||
state=state,
|
state=state,
|
||||||
@ -20,7 +20,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
mask = bls_sign(
|
mask = bls_sign(
|
||||||
message_hash=spec.hash_tree_root(epoch),
|
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
|
||||||
privkey=privkeys[masker_index],
|
privkey=privkeys[masker_index],
|
||||||
domain=spec.get_domain(
|
domain=spec.get_domain(
|
||||||
state=state,
|
state=state,
|
||||||
|
@ -4,7 +4,7 @@ from .hash_function import hash
|
|||||||
ZERO_BYTES32 = b'\x00' * 32
|
ZERO_BYTES32 = b'\x00' * 32
|
||||||
|
|
||||||
zerohashes = [ZERO_BYTES32]
|
zerohashes = [ZERO_BYTES32]
|
||||||
for layer in range(1, 32):
|
for layer in range(1, 100):
|
||||||
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))
|
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from ..merkle_minimal import merkleize_chunks
|
from ..merkle_minimal import merkleize_chunks
|
||||||
from ..hash_function import hash
|
from ..hash_function import hash
|
||||||
from .ssz_typing import (
|
from .ssz_typing import (
|
||||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint
|
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# SSZ Serialization
|
# SSZ Serialization
|
||||||
@ -143,5 +143,6 @@ def hash_tree_root(obj: SSZValue):
|
|||||||
|
|
||||||
def signing_root(obj: Container):
|
def signing_root(obj: Container):
|
||||||
# ignore last field
|
# ignore last field
|
||||||
leaves = [hash_tree_root(field) for field in obj[:-1]]
|
fields = [field for field in obj][:-1]
|
||||||
|
leaves = [hash_tree_root(f) for f in fields]
|
||||||
return merkleize_chunks(chunkify(b''.join(leaves)))
|
return merkleize_chunks(chunkify(b''.join(leaves)))
|
||||||
|
@ -98,7 +98,7 @@ def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
|
|||||||
return typ(v)
|
return typ(v)
|
||||||
elif isinstance(v, (list, tuple)):
|
elif isinstance(v, (list, tuple)):
|
||||||
return typ(*v)
|
return typ(*v)
|
||||||
elif isinstance(v, bytes):
|
elif isinstance(v, (bytes, BytesN, Bytes)):
|
||||||
return typ(v)
|
return typ(v)
|
||||||
elif isinstance(v, GeneratorType):
|
elif isinstance(v, GeneratorType):
|
||||||
return typ(v)
|
return typ(v)
|
||||||
@ -154,7 +154,8 @@ class Container(Series, metaclass=SSZType):
|
|||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return repr({field: getattr(self, field) for field in self.get_fields().keys()})
|
return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset')
|
||||||
|
for field in self.get_fields().keys()})
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
output = [f'{self.__class__.__name__}']
|
output = [f'{self.__class__.__name__}']
|
||||||
@ -236,15 +237,24 @@ class ParamsMeta(SSZType):
|
|||||||
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
|
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def __instancecheck__(self, obj):
|
def __subclasscheck__(self, subclass):
|
||||||
if obj.__class__.__name__ != self.__name__:
|
# check regular class system if we can, solves a lot of the normal cases.
|
||||||
|
if super().__subclasscheck__(subclass):
|
||||||
|
return True
|
||||||
|
# if they are not normal subclasses, they are of the same class.
|
||||||
|
# then they should have the same name
|
||||||
|
if subclass.__name__ != self.__name__:
|
||||||
return False
|
return False
|
||||||
|
# If they do have the same name, they should also have the same params.
|
||||||
for name, typ in self.__annotations__.items():
|
for name, typ in self.__annotations__.items():
|
||||||
if hasattr(self, name) and hasattr(obj.__class__, name) \
|
if hasattr(self, name) and hasattr(subclass, name) \
|
||||||
and getattr(obj.__class__, name) != getattr(self, name):
|
and getattr(subclass, name) != getattr(self, name):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def __instancecheck__(self, obj):
|
||||||
|
return self.__subclasscheck__(obj.__class__)
|
||||||
|
|
||||||
|
|
||||||
class ElementsType(ParamsMeta):
|
class ElementsType(ParamsMeta):
|
||||||
elem_type: SSZType
|
elem_type: SSZType
|
||||||
@ -305,9 +315,6 @@ class Elements(ParamsBase, metaclass=ElementsType):
|
|||||||
def __iter__(self) -> Iterator[SSZValue]:
|
def __iter__(self) -> Iterator[SSZValue]:
|
||||||
return iter(self.items)
|
return iter(self.items)
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return self.items == other.items
|
|
||||||
|
|
||||||
|
|
||||||
class List(Elements):
|
class List(Elements):
|
||||||
|
|
||||||
@ -366,9 +373,6 @@ class BytesLike(Elements, metaclass=BytesType):
|
|||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
|
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
|
||||||
|
|
||||||
def hex(self) -> str:
|
|
||||||
return self.items.hex()
|
|
||||||
|
|
||||||
|
|
||||||
class Bytes(BytesLike):
|
class Bytes(BytesLike):
|
||||||
|
|
||||||
@ -398,7 +402,7 @@ class BytesN(BytesLike):
|
|||||||
|
|
||||||
|
|
||||||
# Helpers for common BytesN types.
|
# Helpers for common BytesN types.
|
||||||
Bytes4 = BytesN[4]
|
Bytes4: BytesType = BytesN[4]
|
||||||
Bytes32 = BytesN[32]
|
Bytes32: BytesType = BytesN[32]
|
||||||
Bytes48 = BytesN[48]
|
Bytes48: BytesType = BytesN[48]
|
||||||
Bytes96 = BytesN[96]
|
Bytes96: BytesType = BytesN[96]
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from .ssz_typing import (
|
from .ssz_typing import (
|
||||||
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
|
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
|
||||||
Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
||||||
uint, uint8, uint16, uint32, uint64, uint128, uint256
|
uint, uint8, uint16, uint32, uint64, uint128, uint256,
|
||||||
|
Bytes32, Bytes48
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -193,3 +194,20 @@ def test_list():
|
|||||||
assert False
|
assert False
|
||||||
except IndexError:
|
except IndexError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_bytesn_subclass():
|
||||||
|
assert isinstance(BytesN[32](b'\xab' * 32), Bytes32)
|
||||||
|
assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48)
|
||||||
|
assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32)
|
||||||
|
assert issubclass(BytesN[32], Bytes32)
|
||||||
|
|
||||||
|
class Hash(Bytes32):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert isinstance(Hash(b'\xab' * 32), Bytes32)
|
||||||
|
assert not isinstance(Hash(b'\xab' * 32), Bytes48)
|
||||||
|
assert issubclass(Hash(b'\xab' * 32).type(), Bytes32)
|
||||||
|
assert issubclass(Hash, Bytes32)
|
||||||
|
|
||||||
|
assert not issubclass(Bytes48, Bytes32)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user