fixes for class based ssz typing

This commit is contained in:
protolambda 2019-06-20 20:25:22 +02:00
parent 7cdec746b4
commit 4e747fb887
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
7 changed files with 66 additions and 41 deletions

View File

@ -65,22 +65,6 @@ def get_ssz_type_by_name(name: str) -> Container:
return globals()[name] return globals()[name]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
seed: Hash,
index: int,
count: int) -> Tuple[ValidatorIndex, ...]:
param_hash = (hash_tree_root(indices), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Monkey patch hash cache # Monkey patch hash cache
_hash = hash _hash = hash
hash_cache: Dict[bytes, Hash] = {} hash_cache: Dict[bytes, Hash] = {}
@ -92,6 +76,22 @@ def hash(x: bytes) -> Hash:
return hash_cache[x] return hash_cache[x]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], Tuple[ValidatorIndex, ...]] = {}
def compute_committee(indices: Tuple[ValidatorIndex, ...], # type: ignore
seed: Hash,
index: int,
count: int) -> Tuple[ValidatorIndex, ...]:
param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Access to overwrite spec constants based on configuration # Access to overwrite spec constants based on configuration
def apply_constants_preset(preset: Dict[str, Any]) -> None: def apply_constants_preset(preset: Dict[str, Any]) -> None:
global_vars = globals() global_vars = globals()

View File

@ -651,11 +651,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool:
### `get_active_validator_indices` ### `get_active_validator_indices`
```python ```python
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Tuple[ValidatorIndex, ...]: def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE]:
""" """
Get active validator indices at ``epoch``. Get active validator indices at ``epoch``.
""" """
return tuple(ValidatorIndex(i) for i, v in enumerate(state.validators) if is_active_validator(v, epoch)) return List[ValidatorIndex, VALIDATOR_REGISTRY_SIZE](
i for i, v in enumerate(state.validators) if is_active_validator(v, epoch)
)
``` ```
### `increase_balance` ### `increase_balance`
@ -873,7 +875,7 @@ def compute_committee(indices: Tuple[ValidatorIndex, ...],
```python ```python
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]: def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Tuple[ValidatorIndex, ...]:
return compute_committee( return compute_committee(
indices=get_active_validator_indices(state, epoch), indices=tuple(get_active_validator_indices(state, epoch)),
seed=generate_seed(state, epoch), seed=generate_seed(state, epoch),
index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT, index=(shard + SHARD_COUNT - get_epoch_start_shard(state, epoch)) % SHARD_COUNT,
count=get_epoch_committee_count(state, epoch), count=get_epoch_committee_count(state, epoch),

View File

@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
reveal = bls_sign( reveal = bls_sign(
message_hash=spec.hash_tree_root(epoch), message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[revealed_index], privkey=privkeys[revealed_index],
domain=spec.get_domain( domain=spec.get_domain(
state=state, state=state,
@ -20,7 +20,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
), ),
) )
mask = bls_sign( mask = bls_sign(
message_hash=spec.hash_tree_root(epoch), message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[masker_index], privkey=privkeys[masker_index],
domain=spec.get_domain( domain=spec.get_domain(
state=state, state=state,

View File

@ -4,7 +4,7 @@ from .hash_function import hash
ZERO_BYTES32 = b'\x00' * 32 ZERO_BYTES32 = b'\x00' * 32
zerohashes = [ZERO_BYTES32] zerohashes = [ZERO_BYTES32]
for layer in range(1, 32): for layer in range(1, 100):
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1])) zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash from ..hash_function import hash
from .ssz_typing import ( from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
) )
# SSZ Serialization # SSZ Serialization
@ -143,5 +143,6 @@ def hash_tree_root(obj: SSZValue):
def signing_root(obj: Container): def signing_root(obj: Container):
# ignore last field # ignore last field
leaves = [hash_tree_root(field) for field in obj[:-1]] fields = [field for field in obj][:-1]
leaves = [hash_tree_root(f) for f in fields]
return merkleize_chunks(chunkify(b''.join(leaves))) return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -98,7 +98,7 @@ def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
return typ(v) return typ(v)
elif isinstance(v, (list, tuple)): elif isinstance(v, (list, tuple)):
return typ(*v) return typ(*v)
elif isinstance(v, bytes): elif isinstance(v, (bytes, BytesN, Bytes)):
return typ(v) return typ(v)
elif isinstance(v, GeneratorType): elif isinstance(v, GeneratorType):
return typ(v) return typ(v)
@ -154,7 +154,8 @@ class Container(Series, metaclass=SSZType):
super().__setattr__(name, value) super().__setattr__(name, value)
def __repr__(self): def __repr__(self):
return repr({field: getattr(self, field) for field in self.get_fields().keys()}) return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset')
for field in self.get_fields().keys()})
def __str__(self): def __str__(self):
output = [f'{self.__class__.__name__}'] output = [f'{self.__class__.__name__}']
@ -236,15 +237,24 @@ class ParamsMeta(SSZType):
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i)) raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
return res return res
def __instancecheck__(self, obj): def __subclasscheck__(self, subclass):
if obj.__class__.__name__ != self.__name__: # check regular class system if we can, solves a lot of the normal cases.
if super().__subclasscheck__(subclass):
return True
# if they are not normal subclasses, they are of the same class.
# then they should have the same name
if subclass.__name__ != self.__name__:
return False return False
# If they do have the same name, they should also have the same params.
for name, typ in self.__annotations__.items(): for name, typ in self.__annotations__.items():
if hasattr(self, name) and hasattr(obj.__class__, name) \ if hasattr(self, name) and hasattr(subclass, name) \
and getattr(obj.__class__, name) != getattr(self, name): and getattr(subclass, name) != getattr(self, name):
return False return False
return True return True
def __instancecheck__(self, obj):
return self.__subclasscheck__(obj.__class__)
class ElementsType(ParamsMeta): class ElementsType(ParamsMeta):
elem_type: SSZType elem_type: SSZType
@ -305,9 +315,6 @@ class Elements(ParamsBase, metaclass=ElementsType):
def __iter__(self) -> Iterator[SSZValue]: def __iter__(self) -> Iterator[SSZValue]:
return iter(self.items) return iter(self.items)
def __eq__(self, other):
return self.items == other.items
class List(Elements): class List(Elements):
@ -366,9 +373,6 @@ class BytesLike(Elements, metaclass=BytesType):
cls = self.__class__ cls = self.__class__
return f"{cls.__name__}[{cls.length}]: {self.hex()}" return f"{cls.__name__}[{cls.length}]: {self.hex()}"
def hex(self) -> str:
return self.items.hex()
class Bytes(BytesLike): class Bytes(BytesLike):
@ -398,7 +402,7 @@ class BytesN(BytesLike):
# Helpers for common BytesN types. # Helpers for common BytesN types.
Bytes4 = BytesN[4] Bytes4: BytesType = BytesN[4]
Bytes32 = BytesN[32] Bytes32: BytesType = BytesN[32]
Bytes48 = BytesN[48] Bytes48: BytesType = BytesN[48]
Bytes96 = BytesN[96] Bytes96: BytesType = BytesN[96]

View File

@ -1,7 +1,8 @@
from .ssz_typing import ( from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
Elements, Bit, Container, List, Vector, Bytes, BytesN, Elements, Bit, Container, List, Vector, Bytes, BytesN,
uint, uint8, uint16, uint32, uint64, uint128, uint256 uint, uint8, uint16, uint32, uint64, uint128, uint256,
Bytes32, Bytes48
) )
@ -193,3 +194,20 @@ def test_list():
assert False assert False
except IndexError: except IndexError:
pass pass
def test_bytesn_subclass():
assert isinstance(BytesN[32](b'\xab' * 32), Bytes32)
assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48)
assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32)
assert issubclass(BytesN[32], Bytes32)
class Hash(Bytes32):
pass
assert isinstance(Hash(b'\xab' * 32), Bytes32)
assert not isinstance(Hash(b'\xab' * 32), Bytes48)
assert issubclass(Hash(b'\xab' * 32).type(), Bytes32)
assert issubclass(Hash, Bytes32)
assert not issubclass(Bytes48, Bytes32)