fixes for class based ssz typing
This commit is contained in:
parent
7cdec746b4
commit
4e747fb887
|
@ -65,22 +65,6 @@ def get_ssz_type_by_name(name: str) -> Container:
|
|||
return globals()[name]
|
||||
|
||||
|
||||
# Monkey patch validator compute committee code
|
||||
_compute_committee = compute_committee
|
||||
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
|
||||
|
||||
|
||||
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
|
||||
seed: Hash,
|
||||
index: int,
|
||||
count: int) -> Tuple[ValidatorIndex, ...]:
|
||||
param_hash = (hash_tree_root(indices), seed, index, count)
|
||||
|
||||
if param_hash not in committee_cache:
|
||||
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
|
||||
return committee_cache[param_hash]
|
||||
|
||||
|
||||
# Monkey patch hash cache
|
||||
_hash = hash
|
||||
hash_cache: Dict[bytes, Hash] = {}
|
||||
|
@ -92,6 +76,22 @@ def hash(x: bytes) -> Hash:
|
|||
return hash_cache[x]
|
||||
|
||||
|
||||
# Monkey patch validator compute committee code
|
||||
_compute_committee = compute_committee
|
||||
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
|
||||
|
||||
|
||||
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
|
||||
seed: Hash,
|
||||
index: int,
|
||||
count: int) -> Tuple[ValidatorIndex, ...]:
|
||||
param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count)
|
||||
|
||||
if param_hash not in committee_cache:
|
||||
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
|
||||
return committee_cache[param_hash]
|
||||
|
||||
|
||||
# Access to overwrite spec constants based on configuration
|
||||
def apply_constants_preset(preset: Dict[str, Any]) -> None:
|
||||
global_vars = globals()
|
||||
|
|
|
@ -651,11 +651,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool:
|
|||
### `get_active_validator_indices`
|
||||
|
||||
```python
|
||||
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Tuple[ValidatorIndex, ...]:
|
||||
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE]:
|
||||
"""
|
||||
Get active validator indices at ``epoch``.
|
||||
"""
|
||||
return tuple(ValidatorIndex(i) for i, v in enumerate(state.validators) if is_active_validator(v, epoch))
|
||||
return List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE](
|
||||
i for i, v in enumerate(state.validators) if is_active_validator(v, epoch)
|
||||
)
|
||||
```
|
||||
|
||||
### `increase_balance`
|
||||
|
@ -873,7 +875,7 @@ def compute_committee(indices: Tuple[ValidatorIndex, ...],
|
|||
```python
|
||||
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]:
|
||||
return compute_committee(
|
||||
indices=get_active_validator_indices(state, epoch),
|
||||
indices=tuple(get_active_validator_indices(state, epoch)),
|
||||
seed=generate_seed(state, epoch),
|
||||
index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT,
|
||||
count=get_epoch_committee_count(state, epoch),
|
||||
|
|
|
@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
|
|||
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
|
||||
|
||||
reveal = bls_sign(
|
||||
message_hash=spec.hash_tree_root(epoch),
|
||||
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
|
||||
privkey=privkeys[revealed_index],
|
||||
domain=spec.get_domain(
|
||||
state=state,
|
||||
|
@ -20,7 +20,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
|
|||
),
|
||||
)
|
||||
mask = bls_sign(
|
||||
message_hash=spec.hash_tree_root(epoch),
|
||||
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
|
||||
privkey=privkeys[masker_index],
|
||||
domain=spec.get_domain(
|
||||
state=state,
|
||||
|
|
|
@ -4,7 +4,7 @@ from .hash_function import hash
|
|||
ZERO_BYTES32 = b'\x00' * 32
|
||||
|
||||
zerohashes = [ZERO_BYTES32]
|
||||
for layer in range(1, 32):
|
||||
for layer in range(1, 100):
|
||||
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ..merkle_minimal import merkleize_chunks
|
||||
from ..hash_function import hash
|
||||
from .ssz_typing import (
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
|
||||
)
|
||||
|
||||
# SSZ Serialization
|
||||
|
@ -143,5 +143,6 @@ def hash_tree_root(obj: SSZValue):
|
|||
|
||||
def signing_root(obj: Container):
|
||||
# ignore last field
|
||||
leaves = [hash_tree_root(field) for field in obj[:-1]]
|
||||
fields = [field for field in obj][:-1]
|
||||
leaves = [hash_tree_root(f) for f in fields]
|
||||
return merkleize_chunks(chunkify(b''.join(leaves)))
|
||||
|
|
|
@ -98,7 +98,7 @@ def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
|
|||
return typ(v)
|
||||
elif isinstance(v, (list, tuple)):
|
||||
return typ(*v)
|
||||
elif isinstance(v, bytes):
|
||||
elif isinstance(v, (bytes, BytesN, Bytes)):
|
||||
return typ(v)
|
||||
elif isinstance(v, GeneratorType):
|
||||
return typ(v)
|
||||
|
@ -154,7 +154,8 @@ class Container(Series, metaclass=SSZType):
|
|||
super().__setattr__(name, value)
|
||||
|
||||
def __repr__(self):
|
||||
return repr({field: getattr(self, field) for field in self.get_fields().keys()})
|
||||
return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset')
|
||||
for field in self.get_fields().keys()})
|
||||
|
||||
def __str__(self):
|
||||
output = [f'{self.__class__.__name__}']
|
||||
|
@ -236,15 +237,24 @@ class ParamsMeta(SSZType):
|
|||
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
|
||||
return res
|
||||
|
||||
def __instancecheck__(self, obj):
|
||||
if obj.__class__.__name__ != self.__name__:
|
||||
def __subclasscheck__(self, subclass):
|
||||
# check regular class system if we can, solves a lot of the normal cases.
|
||||
if super().__subclasscheck__(subclass):
|
||||
return True
|
||||
# if they are not normal subclasses, they are of the same class.
|
||||
# then they should have the same name
|
||||
if subclass.__name__ != self.__name__:
|
||||
return False
|
||||
# If they do have the same name, they should also have the same params.
|
||||
for name, typ in self.__annotations__.items():
|
||||
if hasattr(self, name) and hasattr(obj.__class__, name) \
|
||||
and getattr(obj.__class__, name) != getattr(self, name):
|
||||
if hasattr(self, name) and hasattr(subclass, name) \
|
||||
and getattr(subclass, name) != getattr(self, name):
|
||||
return False
|
||||
return True
|
||||
|
||||
def __instancecheck__(self, obj):
|
||||
return self.__subclasscheck__(obj.__class__)
|
||||
|
||||
|
||||
class ElementsType(ParamsMeta):
|
||||
elem_type: SSZType
|
||||
|
@ -305,9 +315,6 @@ class Elements(ParamsBase, metaclass=ElementsType):
|
|||
def __iter__(self) -> Iterator[SSZValue]:
|
||||
return iter(self.items)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.items == other.items
|
||||
|
||||
|
||||
class List(Elements):
|
||||
|
||||
|
@ -366,9 +373,6 @@ class BytesLike(Elements, metaclass=BytesType):
|
|||
cls = self.__class__
|
||||
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
|
||||
|
||||
def hex(self) -> str:
|
||||
return self.items.hex()
|
||||
|
||||
|
||||
class Bytes(BytesLike):
|
||||
|
||||
|
@ -398,7 +402,7 @@ class BytesN(BytesLike):
|
|||
|
||||
|
||||
# Helpers for common BytesN types.
|
||||
Bytes4 = BytesN[4]
|
||||
Bytes32 = BytesN[32]
|
||||
Bytes48 = BytesN[48]
|
||||
Bytes96 = BytesN[96]
|
||||
Bytes4: BytesType = BytesN[4]
|
||||
Bytes32: BytesType = BytesN[32]
|
||||
Bytes48: BytesType = BytesN[48]
|
||||
Bytes96: BytesType = BytesN[96]
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from .ssz_typing import (
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
|
||||
Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
||||
uint, uint8, uint16, uint32, uint64, uint128, uint256
|
||||
uint, uint8, uint16, uint32, uint64, uint128, uint256,
|
||||
Bytes32, Bytes48
|
||||
)
|
||||
|
||||
|
||||
|
@ -193,3 +194,20 @@ def test_list():
|
|||
assert False
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bytesn_subclass():
|
||||
assert isinstance(BytesN[32](b'\xab' * 32), Bytes32)
|
||||
assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48)
|
||||
assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32)
|
||||
assert issubclass(BytesN[32], Bytes32)
|
||||
|
||||
class Hash(Bytes32):
|
||||
pass
|
||||
|
||||
assert isinstance(Hash(b'\xab' * 32), Bytes32)
|
||||
assert not isinstance(Hash(b'\xab' * 32), Bytes48)
|
||||
assert issubclass(Hash(b'\xab' * 32).type(), Bytes32)
|
||||
assert issubclass(Hash, Bytes32)
|
||||
|
||||
assert not issubclass(Bytes48, Bytes32)
|
||||
|
|
Loading…
Reference in New Issue