fixes for class based ssz typing

This commit is contained in:
protolambda 2019-06-20 20:25:22 +02:00
parent 7cdec746b4
commit 4e747fb887
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
7 changed files with 66 additions and 41 deletions

View File

@ -65,22 +65,6 @@ def get_ssz_type_by_name(name: str) -> Container:
return globals()[name]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
seed: Hash,
index: int,
count: int) -> Tuple[ValidatorIndex, ...]:
param_hash = (hash_tree_root(indices), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Monkey patch hash cache
_hash = hash
hash_cache: Dict[bytes, Hash] = {}
@ -92,6 +76,22 @@ def hash(x: bytes) -> Hash:
return hash_cache[x]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
seed: Hash,
index: int,
count: int) -> Tuple[ValidatorIndex, ...]:
param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Access to overwrite spec constants based on configuration
def apply_constants_preset(preset: Dict[str, Any]) -> None:
global_vars = globals()

View File

@ -651,11 +651,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool:
### `get_active_validator_indices`
```python
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Tuple[ValidatorIndex, ...]:
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE]:
"""
Get active validator indices at ``epoch``.
"""
return tuple(ValidatorIndex(i) for i, v in enumerate(state.validators) if is_active_validator(v, epoch))
return List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE](
i for i, v in enumerate(state.validators) if is_active_validator(v, epoch)
)
```
### `increase_balance`
@ -873,7 +875,7 @@ def compute_committee(indices: Tuple[ValidatorIndex, ...],
```python
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]:
return compute_committee(
indices=get_active_validator_indices(state, epoch),
indices=tuple(get_active_validator_indices(state, epoch)),
seed=generate_seed(state, epoch),
index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT,
count=get_epoch_committee_count(state, epoch),

View File

@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
reveal = bls_sign(
message_hash=spec.hash_tree_root(epoch),
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[revealed_index],
domain=spec.get_domain(
state=state,
@ -20,7 +20,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
),
)
mask = bls_sign(
message_hash=spec.hash_tree_root(epoch),
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[masker_index],
domain=spec.get_domain(
state=state,

View File

@ -4,7 +4,7 @@ from .hash_function import hash
ZERO_BYTES32 = b'\x00' * 32
zerohashes = [ZERO_BYTES32]
for layer in range(1, 32):
for layer in range(1, 100):
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
)
# SSZ Serialization
@ -143,5 +143,6 @@ def hash_tree_root(obj: SSZValue):
def signing_root(obj: Container):
# ignore last field
leaves = [hash_tree_root(field) for field in obj[:-1]]
fields = [field for field in obj][:-1]
leaves = [hash_tree_root(f) for f in fields]
return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -98,7 +98,7 @@ def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
return typ(v)
elif isinstance(v, (list, tuple)):
return typ(*v)
elif isinstance(v, bytes):
elif isinstance(v, (bytes, BytesN, Bytes)):
return typ(v)
elif isinstance(v, GeneratorType):
return typ(v)
@ -154,7 +154,8 @@ class Container(Series, metaclass=SSZType):
super().__setattr__(name, value)
def __repr__(self):
return repr({field: getattr(self, field) for field in self.get_fields().keys()})
return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset')
for field in self.get_fields().keys()})
def __str__(self):
output = [f'{self.__class__.__name__}']
@ -236,15 +237,24 @@ class ParamsMeta(SSZType):
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
return res
def __instancecheck__(self, obj):
if obj.__class__.__name__ != self.__name__:
def __subclasscheck__(self, subclass):
# check regular class system if we can, solves a lot of the normal cases.
if super().__subclasscheck__(subclass):
return True
# if they are not normal subclasses, they are of the same class.
# then they should have the same name
if subclass.__name__ != self.__name__:
return False
# If they do have the same name, they should also have the same params.
for name, typ in self.__annotations__.items():
if hasattr(self, name) and hasattr(obj.__class__, name) \
and getattr(obj.__class__, name) != getattr(self, name):
if hasattr(self, name) and hasattr(subclass, name) \
and getattr(subclass, name) != getattr(self, name):
return False
return True
def __instancecheck__(self, obj):
return self.__subclasscheck__(obj.__class__)
class ElementsType(ParamsMeta):
elem_type: SSZType
@ -305,9 +315,6 @@ class Elements(ParamsBase, metaclass=ElementsType):
def __iter__(self) -> Iterator[SSZValue]:
return iter(self.items)
def __eq__(self, other):
return self.items == other.items
class List(Elements):
@ -366,9 +373,6 @@ class BytesLike(Elements, metaclass=BytesType):
cls = self.__class__
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
def hex(self) -> str:
return self.items.hex()
class Bytes(BytesLike):
@ -398,7 +402,7 @@ class BytesN(BytesLike):
# Helpers for common BytesN types.
Bytes4 = BytesN[4]
Bytes32 = BytesN[32]
Bytes48 = BytesN[48]
Bytes96 = BytesN[96]
Bytes4: BytesType = BytesN[4]
Bytes32: BytesType = BytesN[32]
Bytes48: BytesType = BytesN[48]
Bytes96: BytesType = BytesN[96]

View File

@ -1,7 +1,8 @@
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
Elements, Bit, Container, List, Vector, Bytes, BytesN,
uint, uint8, uint16, uint32, uint64, uint128, uint256
uint, uint8, uint16, uint32, uint64, uint128, uint256,
Bytes32, Bytes48
)
@ -193,3 +194,20 @@ def test_list():
assert False
except IndexError:
pass
def test_bytesn_subclass():
assert isinstance(BytesN[32](b'\xab' * 32), Bytes32)
assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48)
assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32)
assert issubclass(BytesN[32], Bytes32)
class Hash(Bytes32):
pass
assert isinstance(Hash(b'\xab' * 32), Bytes32)
assert not isinstance(Hash(b'\xab' * 32), Bytes48)
assert issubclass(Hash(b'\xab' * 32).type(), Bytes32)
assert issubclass(Hash, Bytes32)
assert not issubclass(Bytes48, Bytes32)