Merge pull request #1180 from ethereum/list-rework

pyspec-SSZ: lists-rework (enable static generalized indices) + fully python class based now.
This commit is contained in:
Danny Ryan 2019-06-25 07:38:50 -06:00 committed by GitHub
commit df2a9e1b54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1314 additions and 871 deletions

View File

@ -77,6 +77,10 @@ MIN_EPOCHS_TO_INACTIVITY_PENALTY: 4
EPOCHS_PER_HISTORICAL_VECTOR: 65536 EPOCHS_PER_HISTORICAL_VECTOR: 65536
# 2**13 (= 8,192) epochs ~36 days # 2**13 (= 8,192) epochs ~36 days
EPOCHS_PER_SLASHED_BALANCES_VECTOR: 8192 EPOCHS_PER_SLASHED_BALANCES_VECTOR: 8192
# 2**24 (= 16,777,216) historical roots, ~26,131 years
HISTORICAL_ROOTS_LIMIT: 16777216
# 2**40 (= 1,099,511,627,776) validator spots
VALIDATOR_REGISTRY_LIMIT: 1099511627776
# Reward and penalty quotients # Reward and penalty quotients

View File

@ -78,6 +78,10 @@ EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS: 4096
EPOCHS_PER_HISTORICAL_VECTOR: 64 EPOCHS_PER_HISTORICAL_VECTOR: 64
# [customized] smaller state # [customized] smaller state
EPOCHS_PER_SLASHED_BALANCES_VECTOR: 64 EPOCHS_PER_SLASHED_BALANCES_VECTOR: 64
# 2**24 (= 16,777,216) historical roots
HISTORICAL_ROOTS_LIMIT: 16777216
# 2**40 (= 1,099,511,627,776) validator spots
VALIDATOR_REGISTRY_LIMIT: 1099511627776
# Reward and penalty quotients # Reward and penalty quotients

View File

@ -12,12 +12,7 @@ from typing import (
PHASE0_IMPORTS = '''from typing import ( PHASE0_IMPORTS = '''from typing import (
Any, Any, Callable, Dict, Set, Sequence, Tuple,
Callable,
Dict,
List,
Set,
Tuple,
) )
from dataclasses import ( from dataclasses import (
@ -30,8 +25,7 @@ from eth2spec.utils.ssz.ssz_impl import (
signing_root, signing_root,
) )
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
# unused: uint8, uint16, uint32, uint128, uint256, Bit, Bool, Container, List, Vector, Bytes, uint64,
uint64, Container, Vector,
Bytes4, Bytes32, Bytes48, Bytes96, Bytes4, Bytes32, Bytes48, Bytes96,
) )
from eth2spec.utils.bls import ( from eth2spec.utils.bls import (
@ -39,18 +33,11 @@ from eth2spec.utils.bls import (
bls_verify, bls_verify,
bls_verify_multiple, bls_verify_multiple,
) )
# Note: 'int' type defaults to being interpreted as a uint64 by SSZ implementation.
from eth2spec.utils.hash_function import hash from eth2spec.utils.hash_function import hash
''' '''
PHASE1_IMPORTS = '''from typing import ( PHASE1_IMPORTS = '''from typing import (
Any, Any, Callable, Dict, Optional, Set, Sequence, MutableSequence, Tuple,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
) )
from dataclasses import ( from dataclasses import (
@ -65,8 +52,7 @@ from eth2spec.utils.ssz.ssz_impl import (
is_empty, is_empty,
) )
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
# unused: uint8, uint16, uint32, uint128, uint256, Bit, Bool, Container, List, Vector, Bytes, uint64,
uint64, Container, Vector,
Bytes4, Bytes32, Bytes48, Bytes96, Bytes4, Bytes32, Bytes48, Bytes96,
) )
from eth2spec.utils.bls import ( from eth2spec.utils.bls import (
@ -77,28 +63,7 @@ from eth2spec.utils.bls import (
from eth2spec.utils.hash_function import hash from eth2spec.utils.hash_function import hash
''' '''
BYTE_TYPES = [4, 32, 48, 96]
SUNDRY_FUNCTIONS = ''' SUNDRY_FUNCTIONS = '''
def get_ssz_type_by_name(name: str) -> Container:
return globals()[name]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], List[ValidatorIndex]] = {}
def compute_committee(indices: List[ValidatorIndex], # type: ignore
seed: Hash,
index: int,
count: int) -> List[ValidatorIndex]:
param_hash = (hash_tree_root(indices), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Monkey patch hash cache # Monkey patch hash cache
_hash = hash _hash = hash
hash_cache: Dict[bytes, Hash] = {} hash_cache: Dict[bytes, Hash] = {}
@ -110,6 +75,22 @@ def hash(x: bytes) -> Hash:
return hash_cache[x] return hash_cache[x]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], Sequence[ValidatorIndex]] = {}
def compute_committee(indices: Sequence[ValidatorIndex], # type: ignore
seed: Hash,
index: int,
count: int) -> Sequence[ValidatorIndex]:
param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Access to overwrite spec constants based on configuration # Access to overwrite spec constants based on configuration
def apply_constants_preset(preset: Dict[str, Any]) -> None: def apply_constants_preset(preset: Dict[str, Any]) -> None:
global_vars = globals() global_vars = globals()
@ -124,6 +105,18 @@ def apply_constants_preset(preset: Dict[str, Any]) -> None:
''' '''
def strip_comments(raw: str) -> str:
comment_line_regex = re.compile('^\s+# ')
lines = raw.split('\n')
out = []
for line in lines:
if not comment_line_regex.match(line):
if ' #' in line:
line = line[:line.index(' #')]
out.append(line)
return '\n'.join(out)
def objects_to_spec(functions: Dict[str, str], def objects_to_spec(functions: Dict[str, str],
custom_types: Dict[str, str], custom_types: Dict[str, str],
constants: Dict[str, str], constants: Dict[str, str],
@ -151,7 +144,8 @@ def objects_to_spec(functions: Dict[str, str],
ssz_objects_instantiation_spec = '\n\n'.join(ssz_objects.values()) ssz_objects_instantiation_spec = '\n\n'.join(ssz_objects.values())
ssz_objects_reinitialization_spec = ( ssz_objects_reinitialization_spec = (
'def init_SSZ_types() -> None:\n global_vars = globals()\n\n ' 'def init_SSZ_types() -> None:\n global_vars = globals()\n\n '
+ '\n\n '.join([re.sub(r'(?!\n\n)\n', r'\n ', value[:-1]) for value in ssz_objects.values()]) + '\n\n '.join([strip_comments(re.sub(r'(?!\n\n)\n', r'\n ', value[:-1]))
for value in ssz_objects.values()])
+ '\n\n' + '\n\n'
+ '\n'.join(map(lambda x: ' global_vars[\'%s\'] = %s' % (x, x), ssz_objects.keys())) + '\n'.join(map(lambda x: ' global_vars[\'%s\'] = %s' % (x, x), ssz_objects.keys()))
) )
@ -183,17 +177,32 @@ def combine_constants(old_constants: Dict[str, str], new_constants: Dict[str, st
return old_constants return old_constants
ignored_dependencies = [
'Bit', 'Bool', 'Vector', 'List', 'Container', 'Hash', 'BLSPubkey', 'BLSSignature', 'Bytes', 'BytesN'
'Bytes4', 'Bytes32', 'Bytes48', 'Bytes96',
'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'bytes' # to be removed after updating spec doc
]
def dependency_order_ssz_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None: def dependency_order_ssz_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None:
""" """
Determines which SSZ Object is depenedent on which other and orders them appropriately Determines which SSZ Object is depenedent on which other and orders them appropriately
""" """
items = list(objects.items()) items = list(objects.items())
for key, value in items: for key, value in items:
dependencies = re.findall(r'(: [A-Z][\w[]*)', value) dependencies = []
dependencies = map(lambda x: re.sub(r'\W|Vector|List|Container|Hash|BLSPubkey|BLSSignature|uint\d+|Bytes\d+|bytes', '', x), dependencies) for line in value.split('\n'):
if not re.match(r'\s+\w+: .+', line):
continue # skip whitespace etc.
line = line[line.index(':') + 1:] # strip of field name
if '#' in line:
line = line[:line.index('#')] # strip of comment
dependencies.extend(re.findall(r'(\w+)', line)) # catch all legible words, potential dependencies
dependencies = filter(lambda x: '_' not in x and x.upper() != x, dependencies) # filter out constants
dependencies = filter(lambda x: x not in ignored_dependencies, dependencies)
dependencies = filter(lambda x: x not in custom_types, dependencies)
for dep in dependencies: for dep in dependencies:
if dep in custom_types or len(dep) == 0:
continue
key_list = list(objects.keys()) key_list = list(objects.keys())
for item in [dep, key] + key_list[key_list.index(dep)+1:]: for item in [dep, key] + key_list[key_list.index(dep)+1:]:
objects[item] = objects.pop(item) objects[item] = objects.pop(item)

View File

@ -234,6 +234,8 @@ The following values are (non-configurable) constants used throughout the specif
| - | - | :-: | :-: | | - | - | :-: | :-: |
| `EPOCHS_PER_HISTORICAL_VECTOR` | `2**16` (= 65,536) | epochs | ~0.8 years | | `EPOCHS_PER_HISTORICAL_VECTOR` | `2**16` (= 65,536) | epochs | ~0.8 years |
| `EPOCHS_PER_SLASHED_BALANCES_VECTOR` | `2**13` (= 8,192) | epochs | ~36 days | | `EPOCHS_PER_SLASHED_BALANCES_VECTOR` | `2**13` (= 8,192) | epochs | ~36 days |
| `HISTORICAL_ROOTS_LIMIT` | `2**24` (= 16,777,216) | historical roots | ~26,131 years |
| `VALIDATOR_REGISTRY_LIMIT` | `2**40` (= 1,099,511,627,776) | validator spots | |
### Rewards and penalties ### Rewards and penalties
@ -295,7 +297,7 @@ class Validator(Container):
pubkey: BLSPubkey pubkey: BLSPubkey
withdrawal_credentials: Hash # Commitment to pubkey for withdrawals and transfers withdrawal_credentials: Hash # Commitment to pubkey for withdrawals and transfers
effective_balance: Gwei # Balance at stake effective_balance: Gwei # Balance at stake
slashed: bool slashed: Bool
# Status epochs # Status epochs
activation_eligibility_epoch: Epoch # When criteria for activation were met activation_eligibility_epoch: Epoch # When criteria for activation were met
activation_epoch: Epoch activation_epoch: Epoch
@ -335,15 +337,15 @@ class AttestationData(Container):
```python ```python
class AttestationDataAndCustodyBit(Container): class AttestationDataAndCustodyBit(Container):
data: AttestationData data: AttestationData
custody_bit: bool # Challengeable bit for the custody of crosslink data custody_bit: Bit # Challengeable bit (SSZ-bool, 1 byte) for the custody of crosslink data
``` ```
#### `IndexedAttestation` #### `IndexedAttestation`
```python ```python
class IndexedAttestation(Container): class IndexedAttestation(Container):
custody_bit_0_indices: List[ValidatorIndex] # Indices with custody bit equal to 0 custody_bit_0_indices: List[ValidatorIndex, MAX_INDICES_PER_ATTESTATION] # Indices with custody bit equal to 0
custody_bit_1_indices: List[ValidatorIndex] # Indices with custody bit equal to 1 custody_bit_1_indices: List[ValidatorIndex, MAX_INDICES_PER_ATTESTATION] # Indices with custody bit equal to 1
data: AttestationData data: AttestationData
signature: BLSSignature signature: BLSSignature
``` ```
@ -352,7 +354,7 @@ class IndexedAttestation(Container):
```python ```python
class PendingAttestation(Container): class PendingAttestation(Container):
aggregation_bitfield: bytes # Bit set for every attesting participant within a committee aggregation_bitfield: Bytes[MAX_INDICES_PER_ATTESTATION // 8]
data: AttestationData data: AttestationData
inclusion_delay: Slot inclusion_delay: Slot
proposer_index: ValidatorIndex proposer_index: ValidatorIndex
@ -419,9 +421,9 @@ class AttesterSlashing(Container):
```python ```python
class Attestation(Container): class Attestation(Container):
aggregation_bitfield: bytes aggregation_bitfield: Bytes[MAX_INDICES_PER_ATTESTATION // 8]
data: AttestationData data: AttestationData
custody_bitfield: bytes custody_bitfield: Bytes[MAX_INDICES_PER_ATTESTATION // 8]
signature: BLSSignature signature: BLSSignature
``` ```
@ -465,12 +467,12 @@ class BeaconBlockBody(Container):
eth1_data: Eth1Data # Eth1 data vote eth1_data: Eth1Data # Eth1 data vote
graffiti: Bytes32 # Arbitrary data graffiti: Bytes32 # Arbitrary data
# Operations # Operations
proposer_slashings: List[ProposerSlashing] proposer_slashings: List[ProposerSlashing, MAX_PROPOSER_SLASHINGS]
attester_slashings: List[AttesterSlashing] attester_slashings: List[AttesterSlashing, MAX_ATTESTER_SLASHINGS]
attestations: List[Attestation] attestations: List[Attestation, MAX_ATTESTATIONS]
deposits: List[Deposit] deposits: List[Deposit, MAX_DEPOSITS]
voluntary_exits: List[VoluntaryExit] voluntary_exits: List[VoluntaryExit, MAX_VOLUNTARY_EXITS]
transfers: List[Transfer] transfers: List[Transfer, MAX_TRANSFERS]
``` ```
#### `BeaconBlock` #### `BeaconBlock`
@ -498,14 +500,14 @@ class BeaconState(Container):
latest_block_header: BeaconBlockHeader latest_block_header: BeaconBlockHeader
block_roots: Vector[Hash, SLOTS_PER_HISTORICAL_ROOT] block_roots: Vector[Hash, SLOTS_PER_HISTORICAL_ROOT]
state_roots: Vector[Hash, SLOTS_PER_HISTORICAL_ROOT] state_roots: Vector[Hash, SLOTS_PER_HISTORICAL_ROOT]
historical_roots: List[Hash] historical_roots: List[Hash, HISTORICAL_ROOTS_LIMIT]
# Eth1 # Eth1
eth1_data: Eth1Data eth1_data: Eth1Data
eth1_data_votes: List[Eth1Data] eth1_data_votes: List[Eth1Data, SLOTS_PER_ETH1_VOTING_PERIOD]
eth1_deposit_index: uint64 eth1_deposit_index: uint64
# Registry # Registry
validators: List[Validator] validators: List[Validator, VALIDATOR_REGISTRY_LIMIT]
balances: List[Gwei] balances: List[Gwei, VALIDATOR_REGISTRY_LIMIT]
# Shuffling # Shuffling
start_shard: Shard start_shard: Shard
randao_mixes: Vector[Hash, EPOCHS_PER_HISTORICAL_VECTOR] randao_mixes: Vector[Hash, EPOCHS_PER_HISTORICAL_VECTOR]
@ -513,8 +515,8 @@ class BeaconState(Container):
# Slashings # Slashings
slashed_balances: Vector[Gwei, EPOCHS_PER_SLASHED_BALANCES_VECTOR] # Sums of slashed effective balances slashed_balances: Vector[Gwei, EPOCHS_PER_SLASHED_BALANCES_VECTOR] # Sums of slashed effective balances
# Attestations # Attestations
previous_epoch_attestations: List[PendingAttestation] previous_epoch_attestations: List[PendingAttestation, MAX_ATTESTATIONS * SLOTS_PER_EPOCH]
current_epoch_attestations: List[PendingAttestation] current_epoch_attestations: List[PendingAttestation, MAX_ATTESTATIONS * SLOTS_PER_EPOCH]
# Crosslinks # Crosslinks
previous_crosslinks: Vector[Crosslink, SHARD_COUNT] # Previous epoch snapshot previous_crosslinks: Vector[Crosslink, SHARD_COUNT] # Previous epoch snapshot
current_crosslinks: Vector[Crosslink, SHARD_COUNT] current_crosslinks: Vector[Crosslink, SHARD_COUNT]
@ -623,13 +625,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool:
""" """
Check if ``validator`` is slashable. Check if ``validator`` is slashable.
""" """
return validator.slashed is False and (validator.activation_epoch <= epoch < validator.withdrawable_epoch) return (not validator.slashed) and (validator.activation_epoch <= epoch < validator.withdrawable_epoch)
``` ```
### `get_active_validator_indices` ### `get_active_validator_indices`
```python ```python
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex]: def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Sequence[ValidatorIndex]:
""" """
Get active validator indices at ``epoch``. Get active validator indices at ``epoch``.
""" """
@ -795,7 +797,7 @@ def get_beacon_proposer_index(state: BeaconState) -> ValidatorIndex:
### `verify_merkle_branch` ### `verify_merkle_branch`
```python ```python
def verify_merkle_branch(leaf: Hash, proof: List[Hash], depth: int, index: int, root: Hash) -> bool: def verify_merkle_branch(leaf: Hash, proof: Sequence[Hash], depth: int, index: int, root: Hash) -> bool:
""" """
Verify that the given ``leaf`` is on the merkle branch ``proof`` Verify that the given ``leaf`` is on the merkle branch ``proof``
starting with the given ``root``. starting with the given ``root``.
@ -839,7 +841,8 @@ def get_shuffled_index(index: ValidatorIndex, index_count: int, seed: Hash) -> V
### `compute_committee` ### `compute_committee`
```python ```python
def compute_committee(indices: List[ValidatorIndex], seed: Hash, index: int, count: int) -> List[ValidatorIndex]: def compute_committee(indices: Sequence[ValidatorIndex],
seed: Hash, index: int, count: int) -> Sequence[ValidatorIndex]:
start = (len(indices) * index) // count start = (len(indices) * index) // count
end = (len(indices) * (index + 1)) // count end = (len(indices) * (index + 1)) // count
return [indices[get_shuffled_index(ValidatorIndex(i), len(indices), seed)] for i in range(start, end)] return [indices[get_shuffled_index(ValidatorIndex(i), len(indices), seed)] for i in range(start, end)]
@ -848,7 +851,7 @@ def compute_committee(indices: List[ValidatorIndex], seed: Hash, index: int, cou
### `get_crosslink_committee` ### `get_crosslink_committee`
```python ```python
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> List[ValidatorIndex]: def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Sequence[ValidatorIndex]:
return compute_committee( return compute_committee(
indices=get_active_validator_indices(state, epoch), indices=get_active_validator_indices(state, epoch),
seed=generate_seed(state, epoch), seed=generate_seed(state, epoch),
@ -862,7 +865,7 @@ def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> L
```python ```python
def get_attesting_indices(state: BeaconState, def get_attesting_indices(state: BeaconState,
attestation_data: AttestationData, attestation_data: AttestationData,
bitfield: bytes) -> List[ValidatorIndex]: bitfield: bytes) -> Sequence[ValidatorIndex]:
""" """
Return the sorted attesting indices corresponding to ``attestation_data`` and ``bitfield``. Return the sorted attesting indices corresponding to ``attestation_data`` and ``bitfield``.
""" """
@ -888,7 +891,7 @@ def bytes_to_int(data: bytes) -> int:
### `get_total_balance` ### `get_total_balance`
```python ```python
def get_total_balance(state: BeaconState, indices: List[ValidatorIndex]) -> Gwei: def get_total_balance(state: BeaconState, indices: Set[ValidatorIndex]) -> Gwei:
""" """
Return the combined effective balance of the ``indices``. (1 Gwei minimum to avoid divisions by zero.) Return the combined effective balance of the ``indices``. (1 Gwei minimum to avoid divisions by zero.)
""" """
@ -1114,7 +1117,7 @@ def slash_validator(state: BeaconState,
### Genesis trigger ### Genesis trigger
Before genesis has been triggered and whenever the deposit contract emits a `Deposit` log, call the function `is_genesis_trigger(deposits: List[Deposit], timestamp: uint64) -> bool` where: Before genesis has been triggered and whenever the deposit contract emits a `Deposit` log, call the function `is_genesis_trigger(deposits: Sequence[Deposit], timestamp: uint64) -> bool` where:
* `deposits` is the list of all deposits, ordered chronologically, up to and including the deposit triggering the latest `Deposit` log * `deposits` is the list of all deposits, ordered chronologically, up to and including the deposit triggering the latest `Deposit` log
* `timestamp` is the Unix timestamp in the Ethereum 1.0 block that emitted the latest `Deposit` log * `timestamp` is the Unix timestamp in the Ethereum 1.0 block that emitted the latest `Deposit` log
@ -1131,7 +1134,7 @@ When `is_genesis_trigger(deposits, timestamp) is True` for the first time, let:
*Note*: The function `is_genesis_trigger` has yet to be agreed upon by the community, and can be updated as necessary. We define the following testing placeholder: *Note*: The function `is_genesis_trigger` has yet to be agreed upon by the community, and can be updated as necessary. We define the following testing placeholder:
```python ```python
def is_genesis_trigger(deposits: List[Deposit], timestamp: uint64) -> bool: def is_genesis_trigger(deposits: Sequence[Deposit], timestamp: uint64) -> bool:
# Process deposits # Process deposits
state = BeaconState() state = BeaconState()
for deposit in deposits: for deposit in deposits:
@ -1153,10 +1156,10 @@ def is_genesis_trigger(deposits: List[Deposit], timestamp: uint64) -> bool:
Let `genesis_state = get_genesis_beacon_state(genesis_deposits, genesis_time, genesis_eth1_data)`. Let `genesis_state = get_genesis_beacon_state(genesis_deposits, genesis_time, genesis_eth1_data)`.
```python ```python
def get_genesis_beacon_state(deposits: List[Deposit], genesis_time: int, genesis_eth1_data: Eth1Data) -> BeaconState: def get_genesis_beacon_state(deposits: Sequence[Deposit], genesis_time: int, eth1_data: Eth1Data) -> BeaconState:
state = BeaconState( state = BeaconState(
genesis_time=genesis_time, genesis_time=genesis_time,
eth1_data=genesis_eth1_data, eth1_data=eth1_data,
latest_block_header=BeaconBlockHeader(body_root=hash_tree_root(BeaconBlockBody())), latest_block_header=BeaconBlockHeader(body_root=hash_tree_root(BeaconBlockBody())),
) )
@ -1170,8 +1173,10 @@ def get_genesis_beacon_state(deposits: List[Deposit], genesis_time: int, genesis
validator.activation_eligibility_epoch = GENESIS_EPOCH validator.activation_eligibility_epoch = GENESIS_EPOCH
validator.activation_epoch = GENESIS_EPOCH validator.activation_epoch = GENESIS_EPOCH
# Populate active_index_roots # Populate active_index_roots
genesis_active_index_root = hash_tree_root(get_active_validator_indices(state, GENESIS_EPOCH)) genesis_active_index_root = hash_tree_root(
List[ValidatorIndex, VALIDATOR_REGISTRY_LIMIT](get_active_validator_indices(state, GENESIS_EPOCH))
)
for index in range(EPOCHS_PER_HISTORICAL_VECTOR): for index in range(EPOCHS_PER_HISTORICAL_VECTOR):
state.active_index_roots[index] = genesis_active_index_root state.active_index_roots[index] = genesis_active_index_root
@ -1246,17 +1251,17 @@ def process_epoch(state: BeaconState) -> None:
```python ```python
def get_total_active_balance(state: BeaconState) -> Gwei: def get_total_active_balance(state: BeaconState) -> Gwei:
return get_total_balance(state, get_active_validator_indices(state, get_current_epoch(state))) return get_total_balance(state, set(get_active_validator_indices(state, get_current_epoch(state))))
``` ```
```python ```python
def get_matching_source_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]: def get_matching_source_attestations(state: BeaconState, epoch: Epoch) -> Sequence[PendingAttestation]:
assert epoch in (get_current_epoch(state), get_previous_epoch(state)) assert epoch in (get_current_epoch(state), get_previous_epoch(state))
return state.current_epoch_attestations if epoch == get_current_epoch(state) else state.previous_epoch_attestations return state.current_epoch_attestations if epoch == get_current_epoch(state) else state.previous_epoch_attestations
``` ```
```python ```python
def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]: def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> Sequence[PendingAttestation]:
return [ return [
a for a in get_matching_source_attestations(state, epoch) a for a in get_matching_source_attestations(state, epoch)
if a.data.target_root == get_block_root(state, epoch) if a.data.target_root == get_block_root(state, epoch)
@ -1264,7 +1269,7 @@ def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[P
``` ```
```python ```python
def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]: def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> Sequence[PendingAttestation]:
return [ return [
a for a in get_matching_source_attestations(state, epoch) a for a in get_matching_source_attestations(state, epoch)
if a.data.beacon_block_root == get_block_root_at_slot(state, get_attestation_data_slot(state, a.data)) if a.data.beacon_block_root == get_block_root_at_slot(state, get_attestation_data_slot(state, a.data))
@ -1273,22 +1278,22 @@ def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[Pen
```python ```python
def get_unslashed_attesting_indices(state: BeaconState, def get_unslashed_attesting_indices(state: BeaconState,
attestations: List[PendingAttestation]) -> List[ValidatorIndex]: attestations: Sequence[PendingAttestation]) -> Set[ValidatorIndex]:
output = set() # type: Set[ValidatorIndex] output = set() # type: Set[ValidatorIndex]
for a in attestations: for a in attestations:
output = output.union(get_attesting_indices(state, a.data, a.aggregation_bitfield)) output = output.union(get_attesting_indices(state, a.data, a.aggregation_bitfield))
return sorted(filter(lambda index: not state.validators[index].slashed, list(output))) return set(filter(lambda index: not state.validators[index].slashed, list(output)))
``` ```
```python ```python
def get_attesting_balance(state: BeaconState, attestations: List[PendingAttestation]) -> Gwei: def get_attesting_balance(state: BeaconState, attestations: Sequence[PendingAttestation]) -> Gwei:
return get_total_balance(state, get_unslashed_attesting_indices(state, attestations)) return get_total_balance(state, get_unslashed_attesting_indices(state, attestations))
``` ```
```python ```python
def get_winning_crosslink_and_attesting_indices(state: BeaconState, def get_winning_crosslink_and_attesting_indices(state: BeaconState,
epoch: Epoch, epoch: Epoch,
shard: Shard) -> Tuple[Crosslink, List[ValidatorIndex]]: shard: Shard) -> Tuple[Crosslink, Set[ValidatorIndex]]:
attestations = [a for a in get_matching_source_attestations(state, epoch) if a.data.crosslink.shard == shard] attestations = [a for a in get_matching_source_attestations(state, epoch) if a.data.crosslink.shard == shard]
crosslinks = list(filter( crosslinks = list(filter(
lambda c: hash_tree_root(state.current_crosslinks[shard]) in (c.parent_root, hash_tree_root(c)), lambda c: hash_tree_root(state.current_crosslinks[shard]) in (c.parent_root, hash_tree_root(c)),
@ -1361,7 +1366,7 @@ def process_crosslinks(state: BeaconState) -> None:
for epoch in (get_previous_epoch(state), get_current_epoch(state)): for epoch in (get_previous_epoch(state), get_current_epoch(state)):
for offset in range(get_epoch_committee_count(state, epoch)): for offset in range(get_epoch_committee_count(state, epoch)):
shard = Shard((get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT) shard = Shard((get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT)
crosslink_committee = get_crosslink_committee(state, epoch, shard) crosslink_committee = set(get_crosslink_committee(state, epoch, shard))
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard) winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard)
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee): if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
state.current_crosslinks[shard] = winning_crosslink state.current_crosslinks[shard] = winning_crosslink
@ -1377,7 +1382,7 @@ def get_base_reward(state: BeaconState, index: ValidatorIndex) -> Gwei:
``` ```
```python ```python
def get_attestation_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]: def get_attestation_deltas(state: BeaconState) -> Tuple[Sequence[Gwei], Sequence[Gwei]]:
previous_epoch = get_previous_epoch(state) previous_epoch = get_previous_epoch(state)
total_balance = get_total_active_balance(state) total_balance = get_total_active_balance(state)
rewards = [Gwei(0) for _ in range(len(state.validators))] rewards = [Gwei(0) for _ in range(len(state.validators))]
@ -1428,13 +1433,13 @@ def get_attestation_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
``` ```
```python ```python
def get_crosslink_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]: def get_crosslink_deltas(state: BeaconState) -> Tuple[Sequence[Gwei], Sequence[Gwei]]:
rewards = [Gwei(0) for index in range(len(state.validators))] rewards = [Gwei(0) for _ in range(len(state.validators))]
penalties = [Gwei(0) for index in range(len(state.validators))] penalties = [Gwei(0) for _ in range(len(state.validators))]
epoch = get_previous_epoch(state) epoch = get_previous_epoch(state)
for offset in range(get_epoch_committee_count(state, epoch)): for offset in range(get_epoch_committee_count(state, epoch)):
shard = Shard((get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT) shard = Shard((get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT)
crosslink_committee = get_crosslink_committee(state, epoch, shard) crosslink_committee = set(get_crosslink_committee(state, epoch, shard))
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard) winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard)
attesting_balance = get_total_balance(state, attesting_indices) attesting_balance = get_total_balance(state, attesting_indices)
committee_balance = get_total_balance(state, crosslink_committee) committee_balance = get_total_balance(state, crosslink_committee)
@ -1528,7 +1533,9 @@ def process_final_updates(state: BeaconState) -> None:
# Set active index root # Set active index root
index_root_position = (next_epoch + ACTIVATION_EXIT_DELAY) % EPOCHS_PER_HISTORICAL_VECTOR index_root_position = (next_epoch + ACTIVATION_EXIT_DELAY) % EPOCHS_PER_HISTORICAL_VECTOR
state.active_index_roots[index_root_position] = hash_tree_root( state.active_index_roots[index_root_position] = hash_tree_root(
get_active_validator_indices(state, Epoch(next_epoch + ACTIVATION_EXIT_DELAY)) List[ValidatorIndex, VALIDATOR_REGISTRY_LIMIT](
get_active_validator_indices(state, Epoch(next_epoch + ACTIVATION_EXIT_DELAY))
)
) )
# Set total slashed balances # Set total slashed balances
state.slashed_balances[next_epoch % EPOCHS_PER_SLASHED_BALANCES_VECTOR] = ( state.slashed_balances[next_epoch % EPOCHS_PER_SLASHED_BALANCES_VECTOR] = (
@ -1610,16 +1617,16 @@ def process_operations(state: BeaconState, body: BeaconBlockBody) -> None:
assert len(body.deposits) == min(MAX_DEPOSITS, state.eth1_data.deposit_count - state.eth1_deposit_index) assert len(body.deposits) == min(MAX_DEPOSITS, state.eth1_data.deposit_count - state.eth1_deposit_index)
# Verify that there are no duplicate transfers # Verify that there are no duplicate transfers
assert len(body.transfers) == len(set(body.transfers)) assert len(body.transfers) == len(set(body.transfers))
all_operations = [
(body.proposer_slashings, MAX_PROPOSER_SLASHINGS, process_proposer_slashing), all_operations = (
(body.attester_slashings, MAX_ATTESTER_SLASHINGS, process_attester_slashing), (body.proposer_slashings, process_proposer_slashing),
(body.attestations, MAX_ATTESTATIONS, process_attestation), (body.attester_slashings, process_attester_slashing),
(body.deposits, MAX_DEPOSITS, process_deposit), (body.attestations, process_attestation),
(body.voluntary_exits, MAX_VOLUNTARY_EXITS, process_voluntary_exit), (body.deposits, process_deposit),
(body.transfers, MAX_TRANSFERS, process_transfer), (body.voluntary_exits, process_voluntary_exit),
] # type: List[Tuple[List[Container], int, Callable]] (body.transfers, process_transfer),
for operations, max_operations, function in all_operations: ) # type: Sequence[Tuple[List, Callable]]
assert len(operations) <= max_operations for operations, function in all_operations:
for operation in operations: for operation in operations:
function(state, operation) function(state, operation)
``` ```

View File

@ -113,6 +113,13 @@ This document details the beacon chain additions and changes in Phase 1 of Ether
| - | - | | - | - |
| `DOMAIN_CUSTODY_BIT_CHALLENGE` | `6` | | `DOMAIN_CUSTODY_BIT_CHALLENGE` | `6` |
### TODO PLACEHOLDER
| Name | Value |
| - | - |
| `PLACEHOLDER` | `2**32` |
## Data structures ## Data structures
### Custody objects ### Custody objects
@ -134,7 +141,7 @@ class CustodyBitChallenge(Container):
attestation: Attestation attestation: Attestation
challenger_index: ValidatorIndex challenger_index: ValidatorIndex
responder_key: BLSSignature responder_key: BLSSignature
chunk_bits: bytes chunk_bits: Bytes[PLACEHOLDER]
signature: BLSSignature signature: BLSSignature
``` ```
@ -171,9 +178,9 @@ class CustodyBitChallengeRecord(Container):
class CustodyResponse(Container): class CustodyResponse(Container):
challenge_index: uint64 challenge_index: uint64
chunk_index: uint64 chunk_index: uint64
chunk: Vector[bytes, BYTES_PER_CUSTODY_CHUNK] chunk: Vector[Bytes[PLACEHOLDER], BYTES_PER_CUSTODY_CHUNK]
data_branch: List[Bytes32] data_branch: List[Bytes32, PLACEHOLDER]
chunk_bits_branch: List[Bytes32] chunk_bits_branch: List[Bytes32, PLACEHOLDER]
chunk_bits_leaf: Bytes32 chunk_bits_leaf: Bytes32
``` ```
@ -226,24 +233,25 @@ class Validator(Container):
```python ```python
class BeaconState(Container): class BeaconState(Container):
custody_chunk_challenge_records: List[CustodyChunkChallengeRecord] custody_chunk_challenge_records: List[CustodyChunkChallengeRecord, PLACEHOLDER]
custody_bit_challenge_records: List[CustodyBitChallengeRecord] custody_bit_challenge_records: List[CustodyBitChallengeRecord, PLACEHOLDER]
custody_challenge_index: uint64 custody_challenge_index: uint64
# Future derived secrets already exposed; contains the indices of the exposed validator # Future derived secrets already exposed; contains the indices of the exposed validator
# at RANDAO reveal period % EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS # at RANDAO reveal period % EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS
exposed_derived_secrets: Vector[List[ValidatorIndex], EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS] exposed_derived_secrets: Vector[List[ValidatorIndex, PLACEHOLDER],
EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS]
``` ```
#### `BeaconBlockBody` #### `BeaconBlockBody`
```python ```python
class BeaconBlockBody(Container): class BeaconBlockBody(Container):
custody_chunk_challenges: List[CustodyChunkChallenge] custody_chunk_challenges: List[CustodyChunkChallenge, PLACEHOLDER]
custody_bit_challenges: List[CustodyBitChallenge] custody_bit_challenges: List[CustodyBitChallenge, PLACEHOLDER]
custody_responses: List[CustodyResponse] custody_responses: List[CustodyResponse, PLACEHOLDER]
custody_key_reveals: List[CustodyKeyReveal] custody_key_reveals: List[CustodyKeyReveal, PLACEHOLDER]
early_derived_secret_reveals: List[EarlyDerivedSecretReveal] early_derived_secret_reveals: List[EarlyDerivedSecretReveal, PLACEHOLDER]
``` ```
## Helpers ## Helpers
@ -310,7 +318,7 @@ def get_validators_custody_reveal_period(state: BeaconState,
### `replace_empty_or_append` ### `replace_empty_or_append`
```python ```python
def replace_empty_or_append(list: List[Any], new_element: Any) -> int: def replace_empty_or_append(list: MutableSequence[Any], new_element: Any) -> int:
for i in range(len(list)): for i in range(len(list)):
if is_empty(list[i]): if is_empty(list[i]):
list[i] = new_element list[i] = new_element
@ -394,12 +402,11 @@ def process_early_derived_secret_reveal(state: BeaconState,
""" """
revealed_validator = state.validators[reveal.revealed_index] revealed_validator = state.validators[reveal.revealed_index]
masker = state.validators[reveal.masker_index]
derived_secret_location = reveal.epoch % EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS derived_secret_location = reveal.epoch % EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS
assert reveal.epoch >= get_current_epoch(state) + RANDAO_PENALTY_EPOCHS assert reveal.epoch >= get_current_epoch(state) + RANDAO_PENALTY_EPOCHS
assert reveal.epoch < get_current_epoch(state) + EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS assert reveal.epoch < get_current_epoch(state) + EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS
assert revealed_validator.slashed is False assert not revealed_validator.slashed
assert reveal.revealed_index not in state.exposed_derived_secrets[derived_secret_location] assert reveal.revealed_index not in state.exposed_derived_secrets[derived_secret_location]
# Verify signature correctness # Verify signature correctness

View File

@ -69,13 +69,19 @@ This document describes the shard data layer and the shard fork choice rule in P
| `DOMAIN_SHARD_PROPOSER` | `128` | | `DOMAIN_SHARD_PROPOSER` | `128` |
| `DOMAIN_SHARD_ATTESTER` | `129` | | `DOMAIN_SHARD_ATTESTER` | `129` |
### TODO PLACEHOLDER
| Name | Value |
| - | - |
| `PLACEHOLDER` | `2**32` |
## Data structures ## Data structures
### `ShardBlockBody` ### `ShardBlockBody`
```python ```python
class ShardBlockBody(Container): class ShardBlockBody(Container):
data: Vector[bytes, BYTES_PER_SHARD_BLOCK_BODY] data: Vector[Bytes[PLACEHOLDER], BYTES_PER_SHARD_BLOCK_BODY]
``` ```
### `ShardAttestation` ### `ShardAttestation`
@ -86,7 +92,7 @@ class ShardAttestation(Container):
slot: Slot slot: Slot
shard: Shard shard: Shard
shard_block_root: Bytes32 shard_block_root: Bytes32
aggregation_bitfield: bytes aggregation_bitfield: Bytes[PLACEHOLDER]
aggregate_signature: BLSSignature aggregate_signature: BLSSignature
``` ```
@ -100,7 +106,7 @@ class ShardBlock(Container):
parent_root: Bytes32 parent_root: Bytes32
data: ShardBlockBody data: ShardBlockBody
state_root: Bytes32 state_root: Bytes32
attestations: List[ShardAttestation] attestations: List[ShardAttestation, PLACEHOLDER]
signature: BLSSignature signature: BLSSignature
``` ```
@ -114,7 +120,7 @@ class ShardBlockHeader(Container):
parent_root: Bytes32 parent_root: Bytes32
body_root: Bytes32 body_root: Bytes32
state_root: Bytes32 state_root: Bytes32
attestations: List[ShardAttestation] attestations: List[ShardAttestation, PLACEHOLDER]
signature: BLSSignature signature: BLSSignature
``` ```
@ -127,7 +133,7 @@ def get_period_committee(state: BeaconState,
epoch: Epoch, epoch: Epoch,
shard: Shard, shard: Shard,
index: int, index: int,
count: int) -> List[ValidatorIndex]: count: int) -> Sequence[ValidatorIndex]:
""" """
Return committee for a period. Used to construct persistent committees. Return committee for a period. Used to construct persistent committees.
""" """
@ -153,7 +159,7 @@ def get_switchover_epoch(state: BeaconState, epoch: Epoch, index: ValidatorIndex
```python ```python
def get_persistent_committee(state: BeaconState, def get_persistent_committee(state: BeaconState,
shard: Shard, shard: Shard,
slot: Slot) -> List[ValidatorIndex]: slot: Slot) -> Sequence[ValidatorIndex]:
""" """
Return the persistent committee for the given ``shard`` at the given ``slot``. Return the persistent committee for the given ``shard`` at the given ``slot``.
""" """
@ -187,7 +193,7 @@ def get_shard_proposer_index(state: BeaconState,
shard: Shard, shard: Shard,
slot: Slot) -> Optional[ValidatorIndex]: slot: Slot) -> Optional[ValidatorIndex]:
# Randomly shift persistent committee # Randomly shift persistent committee
persistent_committee = get_persistent_committee(state, shard, slot) persistent_committee = list(get_persistent_committee(state, shard, slot))
seed = hash(state.current_shuffling_seed + int_to_bytes(shard, length=8) + int_to_bytes(slot, length=8)) seed = hash(state.current_shuffling_seed + int_to_bytes(shard, length=8) + int_to_bytes(slot, length=8))
random_index = bytes_to_int(seed[0:8]) % len(persistent_committee) random_index = bytes_to_int(seed[0:8]) % len(persistent_committee)
persistent_committee = persistent_committee[random_index:] + persistent_committee[:random_index] persistent_committee = persistent_committee[random_index:] + persistent_committee[:random_index]
@ -242,13 +248,13 @@ def verify_shard_attestation_signature(state: BeaconState,
### `compute_crosslink_data_root` ### `compute_crosslink_data_root`
```python ```python
def compute_crosslink_data_root(blocks: List[ShardBlock]) -> Bytes32: def compute_crosslink_data_root(blocks: Sequence[ShardBlock]) -> Bytes32:
def is_power_of_two(value: int) -> bool: def is_power_of_two(value: int) -> bool:
return (value > 0) and (value & (value - 1) == 0) return (value > 0) and (value & (value - 1) == 0)
def pad_to_power_of_2(values: List[bytes]) -> List[bytes]: def pad_to_power_of_2(values: MutableSequence[bytes]) -> Sequence[bytes]:
while not is_power_of_two(len(values)): while not is_power_of_two(len(values)):
values += [b'\x00' * BYTES_PER_SHARD_BLOCK_BODY] values.append(b'\x00' * BYTES_PER_SHARD_BLOCK_BODY)
return values return values
def hash_tree_root_of_bytes(data: bytes) -> bytes: def hash_tree_root_of_bytes(data: bytes) -> bytes:
@ -258,6 +264,8 @@ def compute_crosslink_data_root(blocks: List[ShardBlock]) -> Bytes32:
return data + b'\x00' * (length - len(data)) return data + b'\x00' * (length - len(data))
return hash( return hash(
# TODO untested code.
# Need to either pass a typed list to hash-tree-root, or merkleize_chunks(values, pad_to=2**x)
hash_tree_root(pad_to_power_of_2([ hash_tree_root(pad_to_power_of_2([
hash_tree_root_of_bytes( hash_tree_root_of_bytes(
zpad(serialize(get_shard_header(block)), BYTES_PER_SHARD_BLOCK_BODY) zpad(serialize(get_shard_header(block)), BYTES_PER_SHARD_BLOCK_BODY)
@ -281,9 +289,9 @@ Let:
* `candidate` be a candidate `ShardBlock` for which validity is to be determined by running `is_valid_shard_block` * `candidate` be a candidate `ShardBlock` for which validity is to be determined by running `is_valid_shard_block`
```python ```python
def is_valid_shard_block(beacon_blocks: List[BeaconBlock], def is_valid_shard_block(beacon_blocks: Sequence[BeaconBlock],
beacon_state: BeaconState, beacon_state: BeaconState,
valid_shard_blocks: List[ShardBlock], valid_shard_blocks: Sequence[ShardBlock],
candidate: ShardBlock) -> bool: candidate: ShardBlock) -> bool:
# Check if block is already determined valid # Check if block is already determined valid
for _, block in enumerate(valid_shard_blocks): for _, block in enumerate(valid_shard_blocks):
@ -330,7 +338,7 @@ def is_valid_shard_block(beacon_blocks: List[BeaconBlock],
assert proposer_index is not None assert proposer_index is not None
assert bls_verify( assert bls_verify(
pubkey=beacon_state.validators[proposer_index].pubkey, pubkey=beacon_state.validators[proposer_index].pubkey,
message_hash=signing_root(block), message_hash=signing_root(candidate),
signature=candidate.signature, signature=candidate.signature,
domain=get_domain(beacon_state, DOMAIN_SHARD_PROPOSER, slot_to_epoch(candidate.slot)), domain=get_domain(beacon_state, DOMAIN_SHARD_PROPOSER, slot_to_epoch(candidate.slot)),
) )
@ -347,7 +355,7 @@ Let:
* `candidate` be a candidate `ShardAttestation` for which validity is to be determined by running `is_valid_shard_attestation` * `candidate` be a candidate `ShardAttestation` for which validity is to be determined by running `is_valid_shard_attestation`
```python ```python
def is_valid_shard_attestation(valid_shard_blocks: List[ShardBlock], def is_valid_shard_attestation(valid_shard_blocks: Sequence[ShardBlock],
beacon_state: BeaconState, beacon_state: BeaconState,
candidate: ShardAttestation) -> bool: candidate: ShardAttestation) -> bool:
# Check shard block # Check shard block
@ -372,17 +380,17 @@ Let:
* `shard` be a valid `Shard` * `shard` be a valid `Shard`
* `shard_blocks` be the `ShardBlock` list such that `shard_blocks[slot]` is the canonical `ShardBlock` for shard `shard` at slot `slot` * `shard_blocks` be the `ShardBlock` list such that `shard_blocks[slot]` is the canonical `ShardBlock` for shard `shard` at slot `slot`
* `beacon_state` be the canonical `BeaconState` * `beacon_state` be the canonical `BeaconState`
* `valid_attestations` be the list of valid `Attestation`, recursively defined * `valid_attestations` be the set of valid `Attestation` objects, recursively defined
* `candidate` be a candidate `Attestation` which is valid under Phase 0 rules, and for which validity is to be determined under Phase 1 rules by running `is_valid_beacon_attestation` * `candidate` be a candidate `Attestation` which is valid under Phase 0 rules, and for which validity is to be determined under Phase 1 rules by running `is_valid_beacon_attestation`
```python ```python
def is_valid_beacon_attestation(shard: Shard, def is_valid_beacon_attestation(shard: Shard,
shard_blocks: List[ShardBlock], shard_blocks: Sequence[ShardBlock],
beacon_state: BeaconState, beacon_state: BeaconState,
valid_attestations: List[Attestation], valid_attestations: Set[Attestation],
candidate: Attestation) -> bool: candidate: Attestation) -> bool:
# Check if attestation is already determined valid # Check if attestation is already determined valid
for _, attestation in enumerate(valid_attestations): for attestation in valid_attestations:
if candidate == attestation: if candidate == attestation:
return True return True

View File

@ -21,8 +21,8 @@ MAX_LIST_LENGTH = 10
@to_dict @to_dict
def create_test_case_contents(value, typ): def create_test_case_contents(value):
yield "value", encode.encode(value, typ) yield "value", encode.encode(value)
yield "serialized", '0x' + serialize(value).hex() yield "serialized", '0x' + serialize(value).hex()
yield "root", '0x' + hash_tree_root(value).hex() yield "root", '0x' + hash_tree_root(value).hex()
if hasattr(value, "signature"): if hasattr(value, "signature"):
@ -32,7 +32,7 @@ def create_test_case_contents(value, typ):
@to_dict @to_dict
def create_test_case(rng: Random, name: str, typ, mode: random_value.RandomizationMode, chaos: bool): def create_test_case(rng: Random, name: str, typ, mode: random_value.RandomizationMode, chaos: bool):
value = random_value.get_random_ssz_object(rng, typ, MAX_BYTES_LENGTH, MAX_LIST_LENGTH, mode, chaos) value = random_value.get_random_ssz_object(rng, typ, MAX_BYTES_LENGTH, MAX_LIST_LENGTH, mode, chaos)
yield name, create_test_case_contents(value, typ) yield name, create_test_case_contents(value)
def get_spec_ssz_types(): def get_spec_ssz_types():

View File

@ -1,39 +1,29 @@
from typing import Any
from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_list_type, SSZType, SSZValue, uint, Container, Bytes, List, Bool,
is_vector_type, is_bytes_type, is_bytesn_type, is_container_type,
read_vector_elem_type, read_list_elem_type,
Vector, BytesN Vector, BytesN
) )
def decode(data, typ): def decode(data: Any, typ: SSZType) -> SSZValue:
if is_uint_type(typ): if issubclass(typ, (uint, Bool)):
return data return typ(data)
elif is_bool_type(typ): elif issubclass(typ, (List, Vector)):
assert data in (True, False) return typ(decode(element, typ.elem_type) for element in data)
return data elif issubclass(typ, (Bytes, BytesN)):
elif is_list_type(typ): return typ(bytes.fromhex(data[2:]))
elem_typ = read_list_elem_type(typ) elif issubclass(typ, Container):
return [decode(element, elem_typ) for element in data]
elif is_vector_type(typ):
elem_typ = read_vector_elem_type(typ)
return Vector(decode(element, elem_typ) for element in data)
elif is_bytes_type(typ):
return bytes.fromhex(data[2:])
elif is_bytesn_type(typ):
return BytesN(bytes.fromhex(data[2:]))
elif is_container_type(typ):
temp = {} temp = {}
for field, subtype in typ.get_fields(): for field_name, field_type in typ.get_fields().items():
temp[field] = decode(data[field], subtype) temp[field_name] = decode(data[field_name], field_type)
if field + "_hash_tree_root" in data: if field_name + "_hash_tree_root" in data:
assert(data[field + "_hash_tree_root"][2:] == assert (data[field_name + "_hash_tree_root"][2:] ==
hash_tree_root(temp[field], subtype).hex()) hash_tree_root(temp[field_name]).hex())
ret = typ(**temp) ret = typ(**temp)
if "hash_tree_root" in data: if "hash_tree_root" in data:
assert(data["hash_tree_root"][2:] == assert (data["hash_tree_root"][2:] ==
hash_tree_root(ret, typ).hex()) hash_tree_root(ret).hex())
return ret return ret
else: else:
raise Exception(f"Type not recognized: data={data}, typ={typ}") raise Exception(f"Type not recognized: data={data}, typ={typ}")

View File

@ -1,36 +1,29 @@
from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_list_type, is_vector_type, is_container_type, SSZValue, uint, Container, Bool
read_elem_type,
uint
) )
def encode(value, typ, include_hash_tree_roots=False): def encode(value: SSZValue, include_hash_tree_roots=False):
if is_uint_type(typ): if isinstance(value, uint):
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
# Larger uints are boxed and the class declares their byte length # Larger uints are boxed and the class declares their byte length
if issubclass(typ, uint) and typ.byte_len > 8: if value.type().byte_len > 8:
return str(value) return str(int(value))
return value return int(value)
elif is_bool_type(typ): elif isinstance(value, Bool):
assert value in (True, False) return value == 1
return value elif isinstance(value, list): # normal python lists, ssz-List, Vector
elif is_list_type(typ) or is_vector_type(typ): return [encode(element, include_hash_tree_roots) for element in value]
elem_typ = read_elem_type(typ) elif isinstance(value, bytes): # both bytes and BytesN
return [encode(element, elem_typ, include_hash_tree_roots) for element in value]
elif isinstance(typ, type) and issubclass(typ, bytes): # both bytes and BytesN
return '0x' + value.hex() return '0x' + value.hex()
elif is_container_type(typ): elif isinstance(value, Container):
ret = {} ret = {}
for field, subtype in typ.get_fields(): for field_value, field_name in zip(value, value.get_fields().keys()):
field_value = getattr(value, field) ret[field_name] = encode(field_value, include_hash_tree_roots)
ret[field] = encode(field_value, subtype, include_hash_tree_roots)
if include_hash_tree_roots: if include_hash_tree_roots:
ret[field + "_hash_tree_root"] = '0x' + hash_tree_root(field_value, subtype).hex() ret[field_name + "_hash_tree_root"] = '0x' + hash_tree_root(field_value).hex()
if include_hash_tree_roots: if include_hash_tree_roots:
ret["hash_tree_root"] = '0x' + hash_tree_root(value, typ).hex() ret["hash_tree_root"] = '0x' + hash_tree_root(value).hex()
return ret return ret
else: else:
raise Exception(f"Type not recognized: value={value}, typ={typ}") raise Exception(f"Type not recognized: value={value}, typ={value.type()}")

View File

@ -1,18 +1,13 @@
from random import Random from random import Random
from typing import Any
from enum import Enum from enum import Enum
from eth2spec.utils.ssz.ssz_impl import is_basic_type
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_list_type, SSZType, SSZValue, BasicValue, BasicType, uint, Container, Bytes, List, Bool,
is_vector_type, is_bytes_type, is_bytesn_type, is_container_type, Vector, BytesN
read_vector_elem_type, read_list_elem_type,
uint_byte_size
) )
# in bytes # in bytes
UINT_SIZES = (1, 2, 4, 8, 16, 32) UINT_BYTE_SIZES = (1, 2, 4, 8, 16, 32)
random_mode_names = ("random", "zero", "max", "nil", "one", "lengthy") random_mode_names = ("random", "zero", "max", "nil", "one", "lengthy")
@ -39,11 +34,11 @@ class RandomizationMode(Enum):
def get_random_ssz_object(rng: Random, def get_random_ssz_object(rng: Random,
typ: Any, typ: SSZType,
max_bytes_length: int, max_bytes_length: int,
max_list_length: int, max_list_length: int,
mode: RandomizationMode, mode: RandomizationMode,
chaos: bool) -> Any: chaos: bool) -> SSZValue:
""" """
Create an object for a given type, filled with random data. Create an object for a given type, filled with random data.
:param rng: The random number generator to use. :param rng: The random number generator to use.
@ -56,33 +51,31 @@ def get_random_ssz_object(rng: Random,
""" """
if chaos: if chaos:
mode = rng.choice(list(RandomizationMode)) mode = rng.choice(list(RandomizationMode))
if is_bytes_type(typ): if issubclass(typ, Bytes):
# Bytes array # Bytes array
if mode == RandomizationMode.mode_nil_count: if mode == RandomizationMode.mode_nil_count:
return b'' return typ(b'')
elif mode == RandomizationMode.mode_max_count: elif mode == RandomizationMode.mode_max_count:
return get_random_bytes_list(rng, max_bytes_length) return typ(get_random_bytes_list(rng, max_bytes_length))
elif mode == RandomizationMode.mode_one_count: elif mode == RandomizationMode.mode_one_count:
return get_random_bytes_list(rng, 1) return typ(get_random_bytes_list(rng, 1))
elif mode == RandomizationMode.mode_zero: elif mode == RandomizationMode.mode_zero:
return b'\x00' return typ(b'\x00')
elif mode == RandomizationMode.mode_max: elif mode == RandomizationMode.mode_max:
return b'\xff' return typ(b'\xff')
else: else:
return get_random_bytes_list(rng, rng.randint(0, max_bytes_length)) return typ(get_random_bytes_list(rng, rng.randint(0, max_bytes_length)))
elif is_bytesn_type(typ): elif issubclass(typ, BytesN):
# BytesN
length = typ.length
# Sanity, don't generate absurdly big random values # Sanity, don't generate absurdly big random values
# If a client is aiming to performance-test, they should create a benchmark suite. # If a client is aiming to performance-test, they should create a benchmark suite.
assert length <= max_bytes_length assert typ.length <= max_bytes_length
if mode == RandomizationMode.mode_zero: if mode == RandomizationMode.mode_zero:
return b'\x00' * length return typ(b'\x00' * typ.length)
elif mode == RandomizationMode.mode_max: elif mode == RandomizationMode.mode_max:
return b'\xff' * length return typ(b'\xff' * typ.length)
else: else:
return get_random_bytes_list(rng, length) return typ(get_random_bytes_list(rng, typ.length))
elif is_basic_type(typ): elif issubclass(typ, BasicValue):
# Basic types # Basic types
if mode == RandomizationMode.mode_zero: if mode == RandomizationMode.mode_zero:
return get_min_basic_value(typ) return get_min_basic_value(typ)
@ -90,32 +83,31 @@ def get_random_ssz_object(rng: Random,
return get_max_basic_value(typ) return get_max_basic_value(typ)
else: else:
return get_random_basic_value(rng, typ) return get_random_basic_value(rng, typ)
elif is_vector_type(typ): elif issubclass(typ, Vector):
# Vector return typ(
elem_typ = read_vector_elem_type(typ) get_random_ssz_object(rng, typ.elem_type, max_bytes_length, max_list_length, mode, chaos)
return [
get_random_ssz_object(rng, elem_typ, max_bytes_length, max_list_length, mode, chaos)
for _ in range(typ.length) for _ in range(typ.length)
] )
elif is_list_type(typ): elif issubclass(typ, List):
# List length = rng.randint(0, min(typ.length, max_list_length))
elem_typ = read_list_elem_type(typ)
length = rng.randint(0, max_list_length)
if mode == RandomizationMode.mode_one_count: if mode == RandomizationMode.mode_one_count:
length = 1 length = 1
elif mode == RandomizationMode.mode_max_count: elif mode == RandomizationMode.mode_max_count:
length = max_list_length length = max_list_length
return [ if typ.length < length: # SSZ imposes a hard limit on lists, we can't put in more than that
get_random_ssz_object(rng, elem_typ, max_bytes_length, max_list_length, mode, chaos) length = typ.length
return typ(
get_random_ssz_object(rng, typ.elem_type, max_bytes_length, max_list_length, mode, chaos)
for _ in range(length) for _ in range(length)
] )
elif is_container_type(typ): elif issubclass(typ, Container):
# Container # Container
return typ(**{ return typ(**{
field: field_name:
get_random_ssz_object(rng, subtype, max_bytes_length, max_list_length, mode, chaos) get_random_ssz_object(rng, field_type, max_bytes_length, max_list_length, mode, chaos)
for field, subtype in typ.get_fields() for field_name, field_type in typ.get_fields().items()
}) })
else: else:
raise Exception(f"Type not recognized: typ={typ}") raise Exception(f"Type not recognized: typ={typ}")
@ -125,34 +117,31 @@ def get_random_bytes_list(rng: Random, length: int) -> bytes:
return bytes(rng.getrandbits(8) for _ in range(length)) return bytes(rng.getrandbits(8) for _ in range(length))
def get_random_basic_value(rng: Random, typ) -> Any: def get_random_basic_value(rng: Random, typ: BasicType) -> BasicValue:
if is_bool_type(typ): if issubclass(typ, Bool):
return rng.choice((True, False)) return typ(rng.choice((True, False)))
elif is_uint_type(typ): elif issubclass(typ, uint):
size = uint_byte_size(typ) assert typ.byte_len in UINT_BYTE_SIZES
assert size in UINT_SIZES return typ(rng.randint(0, 256 ** typ.byte_len - 1))
return rng.randint(0, 256**size - 1)
else: else:
raise ValueError(f"Not a basic type: typ={typ}") raise ValueError(f"Not a basic type: typ={typ}")
def get_min_basic_value(typ) -> Any: def get_min_basic_value(typ: BasicType) -> BasicValue:
if is_bool_type(typ): if issubclass(typ, Bool):
return False return typ(False)
elif is_uint_type(typ): elif issubclass(typ, uint):
size = uint_byte_size(typ) assert typ.byte_len in UINT_BYTE_SIZES
assert size in UINT_SIZES return typ(0)
return 0
else: else:
raise ValueError(f"Not a basic type: typ={typ}") raise ValueError(f"Not a basic type: typ={typ}")
def get_max_basic_value(typ) -> Any: def get_max_basic_value(typ: BasicType) -> BasicValue:
if is_bool_type(typ): if issubclass(typ, Bool):
return True return typ(True)
elif is_uint_type(typ): elif issubclass(typ, uint):
size = uint_byte_size(typ) assert typ.byte_len in UINT_BYTE_SIZES
assert size in UINT_SIZES return typ(256 ** typ.byte_len - 1)
return 256**size - 1
else: else:
raise ValueError(f"Not a basic type: typ={typ}") raise ValueError(f"Not a basic type: typ={typ}")

View File

@ -8,32 +8,31 @@ def translate_typ(typ) -> ssz.BaseSedes:
:param typ: The spec type, a class. :param typ: The spec type, a class.
:return: The Py-SSZ equivalent. :return: The Py-SSZ equivalent.
""" """
if spec_ssz.is_container_type(typ): if issubclass(typ, spec_ssz.Container):
return ssz.Container( return ssz.Container(
[translate_typ(field_typ) for (field_name, field_typ) in typ.get_fields()]) [translate_typ(field_typ) for field_name, field_typ in typ.get_fields().items()])
elif spec_ssz.is_bytesn_type(typ): elif issubclass(typ, spec_ssz.BytesN):
return ssz.ByteVector(typ.length) return ssz.ByteVector(typ.length)
elif spec_ssz.is_bytes_type(typ): elif issubclass(typ, spec_ssz.Bytes):
return ssz.ByteList() return ssz.ByteList()
elif spec_ssz.is_vector_type(typ): elif issubclass(typ, spec_ssz.Vector):
return ssz.Vector(translate_typ(spec_ssz.read_vector_elem_type(typ)), typ.length) return ssz.Vector(translate_typ(typ.elem_type), typ.length)
elif spec_ssz.is_list_type(typ): elif issubclass(typ, spec_ssz.List):
return ssz.List(translate_typ(spec_ssz.read_list_elem_type(typ))) return ssz.List(translate_typ(typ.elem_type))
elif spec_ssz.is_bool_type(typ): elif issubclass(typ, spec_ssz.Bool):
return ssz.boolean return ssz.boolean
elif spec_ssz.is_uint_type(typ): elif issubclass(typ, spec_ssz.uint):
size = spec_ssz.uint_byte_size(typ) if typ.byte_len == 1:
if size == 1:
return ssz.uint8 return ssz.uint8
elif size == 2: elif typ.byte_len == 2:
return ssz.uint16 return ssz.uint16
elif size == 4: elif typ.byte_len == 4:
return ssz.uint32 return ssz.uint32
elif size == 8: elif typ.byte_len == 8:
return ssz.uint64 return ssz.uint64
elif size == 16: elif typ.byte_len == 16:
return ssz.uint128 return ssz.uint128
elif size == 32: elif typ.byte_len == 32:
return ssz.uint256 return ssz.uint256
else: else:
raise TypeError("invalid uint size") raise TypeError("invalid uint size")
@ -48,37 +47,33 @@ def translate_value(value, typ):
:param typ: The type from the spec to translate into :param typ: The type from the spec to translate into
:return: the translated value :return: the translated value
""" """
if spec_ssz.is_uint_type(typ): if issubclass(typ, spec_ssz.uint):
size = spec_ssz.uint_byte_size(typ) if typ.byte_len == 1:
if size == 1:
return spec_ssz.uint8(value) return spec_ssz.uint8(value)
elif size == 2: elif typ.byte_len == 2:
return spec_ssz.uint16(value) return spec_ssz.uint16(value)
elif size == 4: elif typ.byte_len == 4:
return spec_ssz.uint32(value) return spec_ssz.uint32(value)
elif size == 8: elif typ.byte_len == 8:
# uint64 is default (TODO this is changing soon) return spec_ssz.uint64(value)
return value elif typ.byte_len == 16:
elif size == 16:
return spec_ssz.uint128(value) return spec_ssz.uint128(value)
elif size == 32: elif typ.byte_len == 32:
return spec_ssz.uint256(value) return spec_ssz.uint256(value)
else: else:
raise TypeError("invalid uint size") raise TypeError("invalid uint size")
elif spec_ssz.is_list_type(typ): elif issubclass(typ, spec_ssz.List):
elem_typ = spec_ssz.read_elem_type(typ) return [translate_value(elem, typ.elem_type) for elem in value]
return [translate_value(elem, elem_typ) for elem in value] elif issubclass(typ, spec_ssz.Bool):
elif spec_ssz.is_bool_type(typ):
return value return value
elif spec_ssz.is_vector_type(typ): elif issubclass(typ, spec_ssz.Vector):
elem_typ = spec_ssz.read_elem_type(typ) return typ(*(translate_value(elem, typ.elem_type) for elem in value))
return typ(*(translate_value(elem, elem_typ) for elem in value)) elif issubclass(typ, spec_ssz.BytesN):
elif spec_ssz.is_bytesn_type(typ):
return typ(value) return typ(value)
elif spec_ssz.is_bytes_type(typ): elif issubclass(typ, spec_ssz.Bytes):
return value return value
elif spec_ssz.is_container_type(typ): if issubclass(typ, spec_ssz.Container):
return typ(**{f_name: translate_value(f_val, f_typ) for (f_name, f_val, f_typ) return typ(**{f_name: translate_value(f_val, f_typ) for (f_val, (f_name, f_typ))
in zip(typ.get_field_names(), value, typ.get_field_types())}) in zip(value, typ.get_fields().items())})
else: else:
raise TypeError("Type not supported: {}".format(typ)) raise TypeError("Type not supported: {}".format(typ))

View File

@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
reveal = bls_sign( reveal = bls_sign(
message_hash=spec.hash_tree_root(epoch), message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[revealed_index], privkey=privkeys[revealed_index],
domain=spec.get_domain( domain=spec.get_domain(
state=state, state=state,
@ -20,14 +20,14 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
), ),
) )
mask = bls_sign( mask = bls_sign(
message_hash=spec.hash_tree_root(epoch), message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[masker_index], privkey=privkeys[masker_index],
domain=spec.get_domain( domain=spec.get_domain(
state=state, state=state,
domain_type=spec.DOMAIN_RANDAO, domain_type=spec.DOMAIN_RANDAO,
message_epoch=epoch, message_epoch=epoch,
), ),
) )[:32] # TODO(Carl): mask is 32 bytes, and signature is 96? Correct to slice the first 32 out?
return spec.EarlyDerivedSecretReveal( return spec.EarlyDerivedSecretReveal(
revealed_index=revealed_index, revealed_index=revealed_index,

View File

@ -1,5 +1,6 @@
from eth2spec.test.helpers.keys import pubkeys from eth2spec.test.helpers.keys import pubkeys
from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import List
def build_mock_validator(spec, i: int, balance: int): def build_mock_validator(spec, i: int, balance: int):
@ -40,7 +41,8 @@ def create_genesis_state(spec, num_validators):
validator.activation_eligibility_epoch = spec.GENESIS_EPOCH validator.activation_eligibility_epoch = spec.GENESIS_EPOCH
validator.activation_epoch = spec.GENESIS_EPOCH validator.activation_epoch = spec.GENESIS_EPOCH
genesis_active_index_root = hash_tree_root(spec.get_active_validator_indices(state, spec.GENESIS_EPOCH)) genesis_active_index_root = hash_tree_root(List[spec.ValidatorIndex, spec.VALIDATOR_REGISTRY_LIMIT](
spec.get_active_validator_indices(state, spec.GENESIS_EPOCH)))
for index in range(spec.EPOCHS_PER_HISTORICAL_VECTOR): for index in range(spec.EPOCHS_PER_HISTORICAL_VECTOR):
state.active_index_roots[index] = genesis_active_index_root state.active_index_roots[index] = genesis_active_index_root

View File

@ -114,21 +114,6 @@ def test_invalid_withdrawal_credentials_top_up(spec, state):
yield from run_deposit_processing(spec, state, deposit, validator_index, valid=True, effective=True) yield from run_deposit_processing(spec, state, deposit, validator_index, valid=True, effective=True)
@with_all_phases
@spec_state_test
def test_wrong_index(spec, state):
validator_index = len(state.validators)
amount = spec.MAX_EFFECTIVE_BALANCE
deposit = prepare_state_and_deposit(spec, state, validator_index, amount)
# mess up eth1_deposit_index
deposit.index = state.eth1_deposit_index + 1
sign_deposit_data(spec, state, deposit.data, privkeys[validator_index])
yield from run_deposit_processing(spec, state, deposit, validator_index, valid=False)
@with_all_phases @with_all_phases
@spec_state_test @spec_state_test
def test_wrong_deposit_for_deposit_count(spec, state): def test_wrong_deposit_for_deposit_count(spec, state):
@ -172,9 +157,6 @@ def test_wrong_deposit_for_deposit_count(spec, state):
yield from run_deposit_processing(spec, state, deposit_2, index_2, valid=False) yield from run_deposit_processing(spec, state, deposit_2, index_2, valid=False)
# TODO: test invalid signature
@with_all_phases @with_all_phases
@spec_state_test @spec_state_test
def test_bad_merkle_proof(spec, state): def test_bad_merkle_proof(spec, state):
@ -183,7 +165,7 @@ def test_bad_merkle_proof(spec, state):
deposit = prepare_state_and_deposit(spec, state, validator_index, amount) deposit = prepare_state_and_deposit(spec, state, validator_index, amount)
# mess up merkle branch # mess up merkle branch
deposit.proof[-1] = spec.ZERO_HASH deposit.proof[5] = spec.ZERO_HASH
sign_deposit_data(spec, state, deposit.data, privkeys[validator_index]) sign_deposit_data(spec, state, deposit.data, privkeys[validator_index])

View File

@ -1,5 +1,4 @@
from copy import deepcopy from copy import deepcopy
from typing import List
from eth2spec.utils.ssz.ssz_impl import signing_root from eth2spec.utils.ssz.ssz_impl import signing_root
from eth2spec.utils.bls import bls_sign from eth2spec.utils.bls import bls_sign
@ -28,7 +27,7 @@ def test_empty_block_transition(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert len(state.eth1_data_votes) == pre_eth1_votes + 1 assert len(state.eth1_data_votes) == pre_eth1_votes + 1
@ -48,7 +47,7 @@ def test_skipped_slots(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert state.slot == block.slot assert state.slot == block.slot
@ -69,7 +68,7 @@ def test_empty_epoch_transition(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert state.slot == block.slot assert state.slot == block.slot
@ -90,7 +89,7 @@ def test_empty_epoch_transition(spec, state):
# state_transition_and_sign_block(spec, state, block) # state_transition_and_sign_block(spec, state, block)
# yield 'blocks', [block], List[spec.BeaconBlock] # yield 'blocks', [block]
# yield 'post', state # yield 'post', state
# assert state.slot == block.slot # assert state.slot == block.slot
@ -120,7 +119,7 @@ def test_proposer_slashing(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
# check if slashed # check if slashed
@ -155,7 +154,7 @@ def test_attester_slashing(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
slashed_validator = state.validators[validator_index] slashed_validator = state.validators[validator_index]
@ -193,7 +192,7 @@ def test_deposit_in_block(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert len(state.validators) == initial_registry_len + 1 assert len(state.validators) == initial_registry_len + 1
@ -221,7 +220,7 @@ def test_deposit_top_up(spec, state):
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert len(state.validators) == initial_registry_len assert len(state.validators) == initial_registry_len
@ -256,7 +255,7 @@ def test_attestation(spec, state):
sign_block(spec, state, epoch_block) sign_block(spec, state, epoch_block)
state_transition_and_sign_block(spec, state, epoch_block) state_transition_and_sign_block(spec, state, epoch_block)
yield 'blocks', [attestation_block, epoch_block], List[spec.BeaconBlock] yield 'blocks', [attestation_block, epoch_block]
yield 'post', state yield 'post', state
assert len(state.current_epoch_attestations) == 0 assert len(state.current_epoch_attestations) == 0
@ -303,7 +302,7 @@ def test_voluntary_exit(spec, state):
sign_block(spec, state, exit_block) sign_block(spec, state, exit_block)
state_transition_and_sign_block(spec, state, exit_block) state_transition_and_sign_block(spec, state, exit_block)
yield 'blocks', [initiate_exit_block, exit_block], List[spec.BeaconBlock] yield 'blocks', [initiate_exit_block, exit_block]
yield 'post', state yield 'post', state
assert state.validators[validator_index].exit_epoch < spec.FAR_FUTURE_EPOCH assert state.validators[validator_index].exit_epoch < spec.FAR_FUTURE_EPOCH
@ -334,7 +333,7 @@ def test_voluntary_exit(spec, state):
# state_transition_and_sign_block(spec, state, block) # state_transition_and_sign_block(spec, state, block)
# yield 'blocks', [block], List[spec.BeaconBlock] # yield 'blocks', [block]
# yield 'post', state # yield 'post', state
# sender_balance = get_balance(state, sender_index) # sender_balance = get_balance(state, sender_index)
@ -362,7 +361,7 @@ def test_balance_driven_status_transitions(spec, state):
sign_block(spec, state, block) sign_block(spec, state, block)
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert state.validators[validator_index].exit_epoch < spec.FAR_FUTURE_EPOCH assert state.validators[validator_index].exit_epoch < spec.FAR_FUTURE_EPOCH
@ -379,7 +378,7 @@ def test_historical_batch(spec, state):
block = build_empty_block_for_next_slot(spec, state, signed=True) block = build_empty_block_for_next_slot(spec, state, signed=True)
state_transition_and_sign_block(spec, state, block) state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock] yield 'blocks', [block]
yield 'post', state yield 'post', state
assert state.slot == block.slot assert state.slot == block.slot
@ -408,7 +407,7 @@ def test_historical_batch(spec, state):
# state_transition_and_sign_block(spec, state, block) # state_transition_and_sign_block(spec, state, block)
# yield 'blocks', [block], List[spec.BeaconBlock] # yield 'blocks', [block]
# yield 'post', state # yield 'post', state
# assert state.slot % spec.SLOTS_PER_ETH1_VOTING_PERIOD == 0 # assert state.slot % spec.SLOTS_PER_ETH1_VOTING_PERIOD == 0

View File

@ -1,5 +1,4 @@
from copy import deepcopy from copy import deepcopy
from typing import List
from eth2spec.test.context import spec_state_test, never_bls, with_all_phases from eth2spec.test.context import spec_state_test, never_bls, with_all_phases
from eth2spec.test.helpers.state import next_epoch, state_transition_and_sign_block from eth2spec.test.helpers.state import next_epoch, state_transition_and_sign_block
@ -39,11 +38,13 @@ def next_epoch_with_attestations(spec,
state, state,
fill_cur_epoch, fill_cur_epoch,
fill_prev_epoch): fill_prev_epoch):
assert state.slot % spec.SLOTS_PER_EPOCH == 0
post_state = deepcopy(state) post_state = deepcopy(state)
blocks = [] blocks = []
for _ in range(spec.SLOTS_PER_EPOCH): for _ in range(spec.SLOTS_PER_EPOCH):
block = build_empty_block_for_next_slot(spec, post_state) block = build_empty_block_for_next_slot(spec, post_state)
if fill_cur_epoch: if fill_cur_epoch and post_state.slot >= spec.MIN_ATTESTATION_INCLUSION_DELAY:
slot_to_attest = post_state.slot - spec.MIN_ATTESTATION_INCLUSION_DELAY + 1 slot_to_attest = post_state.slot - spec.MIN_ATTESTATION_INCLUSION_DELAY + 1
if slot_to_attest >= spec.get_epoch_start_slot(spec.get_current_epoch(post_state)): if slot_to_attest >= spec.get_epoch_start_slot(spec.get_current_epoch(post_state)):
cur_attestation = get_valid_attestation(spec, post_state, slot_to_attest) cur_attestation = get_valid_attestation(spec, post_state, slot_to_attest)
@ -63,11 +64,13 @@ def next_epoch_with_attestations(spec,
@with_all_phases @with_all_phases
@never_bls @never_bls
@spec_state_test @spec_state_test
def test_finality_rule_4(spec, state): def test_finality_no_updates_at_genesis(spec, state):
assert spec.get_current_epoch(state) == spec.GENESIS_EPOCH
yield 'pre', state yield 'pre', state
blocks = [] blocks = []
for epoch in range(4): for epoch in range(2):
prev_state, new_blocks, state = next_epoch_with_attestations(spec, state, True, False) prev_state, new_blocks, state = next_epoch_with_attestations(spec, state, True, False)
blocks += new_blocks blocks += new_blocks
@ -77,15 +80,37 @@ def test_finality_rule_4(spec, state):
# justification/finalization skipped at GENESIS_EPOCH + 1 # justification/finalization skipped at GENESIS_EPOCH + 1
elif epoch == 1: elif epoch == 1:
check_finality(spec, state, prev_state, False, False, False) check_finality(spec, state, prev_state, False, False, False)
elif epoch == 2:
yield 'blocks', blocks
yield 'post', state
@with_all_phases
@never_bls
@spec_state_test
def test_finality_rule_4(spec, state):
# get past first two epochs that finality does not run on
next_epoch(spec, state)
apply_empty_block(spec, state)
next_epoch(spec, state)
apply_empty_block(spec, state)
yield 'pre', state
blocks = []
for epoch in range(2):
prev_state, new_blocks, state = next_epoch_with_attestations(spec, state, True, False)
blocks += new_blocks
if epoch == 0:
check_finality(spec, state, prev_state, True, False, False) check_finality(spec, state, prev_state, True, False, False)
elif epoch >= 3: elif epoch == 1:
# rule 4 of finality # rule 4 of finality
check_finality(spec, state, prev_state, True, True, True) check_finality(spec, state, prev_state, True, True, True)
assert state.finalized_epoch == prev_state.current_justified_epoch assert state.finalized_epoch == prev_state.current_justified_epoch
assert state.finalized_root == prev_state.current_justified_root assert state.finalized_root == prev_state.current_justified_root
yield 'blocks', blocks, List[spec.BeaconBlock] yield 'blocks', blocks
yield 'post', state yield 'post', state
@ -116,7 +141,7 @@ def test_finality_rule_1(spec, state):
assert state.finalized_epoch == prev_state.previous_justified_epoch assert state.finalized_epoch == prev_state.previous_justified_epoch
assert state.finalized_root == prev_state.previous_justified_root assert state.finalized_root == prev_state.previous_justified_root
yield 'blocks', blocks, List[spec.BeaconBlock] yield 'blocks', blocks
yield 'post', state yield 'post', state
@ -149,7 +174,7 @@ def test_finality_rule_2(spec, state):
blocks += new_blocks blocks += new_blocks
yield 'blocks', blocks, List[spec.BeaconBlock] yield 'blocks', blocks
yield 'post', state yield 'post', state
@ -199,5 +224,5 @@ def test_finality_rule_3(spec, state):
assert state.finalized_epoch == prev_state.current_justified_epoch assert state.finalized_epoch == prev_state.current_justified_epoch
assert state.finalized_root == prev_state.current_justified_root assert state.finalized_root == prev_state.current_justified_root
yield 'blocks', blocks, List[spec.BeaconBlock] yield 'blocks', blocks
yield 'post', state yield 'post', state

View File

@ -4,7 +4,7 @@ from .hash_function import hash
ZERO_BYTES32 = b'\x00' * 32 ZERO_BYTES32 = b'\x00' * 32
zerohashes = [ZERO_BYTES32] zerohashes = [ZERO_BYTES32]
for layer in range(1, 32): for layer in range(1, 100):
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1])) zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))
@ -44,11 +44,35 @@ def next_power_of_two(v: int) -> int:
return 1 << (v - 1).bit_length() return 1 << (v - 1).bit_length()
def merkleize_chunks(chunks): def merkleize_chunks(chunks, pad_to: int = 1):
tree = chunks[::] count = len(chunks)
margin = next_power_of_two(len(chunks)) - len(chunks) depth = max(count - 1, 0).bit_length()
tree.extend([ZERO_BYTES32] * margin) max_depth = max(depth, (pad_to - 1).bit_length())
tree = [ZERO_BYTES32] * len(tree) + tree tmp = [None for _ in range(max_depth + 1)]
for i in range(len(tree) // 2 - 1, 0, -1):
tree[i] = hash(tree[i * 2] + tree[i * 2 + 1]) def merge(h, i):
return tree[1] j = 0
while True:
if i & (1 << j) == 0:
if i == count and j < depth:
h = hash(h + zerohashes[j]) # keep going if we are complementing the void to the next power of 2
else:
break
else:
h = hash(tmp[j] + h)
j += 1
tmp[j] = h
# merge in leaf by leaf.
for i in range(count):
merge(chunks[i], i)
# complement with 0 if empty, or if not the right power of 2
if 1 << depth != count:
merge(zerohashes[0], count)
# the next power of two may be smaller than the ultimate virtual size, complement with zero-hashes at each depth.
for j in range(depth, max_depth):
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
return tmp[max_depth]

View File

@ -1,11 +1,7 @@
from ..merkle_minimal import merkleize_chunks, hash from ..merkle_minimal import merkleize_chunks
from eth2spec.utils.ssz.ssz_typing import ( from ..hash_function import hash
is_uint_type, is_bool_type, is_container_type, from .ssz_typing import (
is_list_kind, is_vector_kind, SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bool, Container, List, Bytes, uint,
read_vector_elem_type, read_elem_type,
uint_byte_size,
infer_input_type,
get_zero_value,
) )
# SSZ Serialization # SSZ Serialization
@ -14,68 +10,47 @@ from eth2spec.utils.ssz.ssz_typing import (
BYTES_PER_LENGTH_OFFSET = 4 BYTES_PER_LENGTH_OFFSET = 4
def is_basic_type(typ): def serialize_basic(value: SSZValue):
return is_uint_type(typ) or is_bool_type(typ) if isinstance(value, uint):
return value.to_bytes(value.type().byte_len, 'little')
elif isinstance(value, Bool):
def serialize_basic(value, typ):
if is_uint_type(typ):
return value.to_bytes(uint_byte_size(typ), 'little')
elif is_bool_type(typ):
if value: if value:
return b'\x01' return b'\x01'
else: else:
return b'\x00' return b'\x00'
else: else:
raise Exception("Type not supported: {}".format(typ)) raise Exception(f"Type not supported: {type(value)}")
def deserialize_basic(value, typ): def deserialize_basic(value, typ: BasicType):
if is_uint_type(typ): if issubclass(typ, uint):
return typ(int.from_bytes(value, 'little')) return typ(int.from_bytes(value, 'little'))
elif is_bool_type(typ): elif issubclass(typ, Bool):
assert value in (b'\x00', b'\x01') assert value in (b'\x00', b'\x01')
return True if value == b'\x01' else False return typ(value == b'\x01')
else: else:
raise Exception("Type not supported: {}".format(typ)) raise Exception(f"Type not supported: {typ}")
def is_fixed_size(typ): def is_empty(obj: SSZValue):
if is_basic_type(typ): return type(obj).default() == obj
return True
elif is_list_kind(typ):
return False def serialize(obj: SSZValue):
elif is_vector_kind(typ): if isinstance(obj, BasicValue):
return is_fixed_size(read_vector_elem_type(typ)) return serialize_basic(obj)
elif is_container_type(typ): elif isinstance(obj, Series):
return all(is_fixed_size(t) for t in typ.get_field_types()) return encode_series(obj)
else: else:
raise Exception("Type not supported: {}".format(typ)) raise Exception(f"Type not supported: {type(obj)}")
def is_empty(obj): def encode_series(values: Series):
return get_zero_value(type(obj)) == obj if isinstance(values, bytes): # Bytes and BytesN are already like serialized output
@infer_input_type
def serialize(obj, typ=None):
if is_basic_type(typ):
return serialize_basic(obj, typ)
elif is_list_kind(typ) or is_vector_kind(typ):
return encode_series(obj, [read_elem_type(typ)] * len(obj))
elif is_container_type(typ):
return encode_series(obj.get_field_values(), typ.get_field_types())
else:
raise Exception("Type not supported: {}".format(typ))
def encode_series(values, types):
# bytes and bytesN are already in the right format.
if isinstance(values, bytes):
return values return values
# Recursively serialize # Recursively serialize
parts = [(is_fixed_size(types[i]), serialize(values[i], typ=types[i])) for i in range(len(values))] parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
# Compute and check lengths # Compute and check lengths
fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET
@ -107,10 +82,10 @@ def encode_series(values, types):
# ----------------------------- # -----------------------------
def pack(values, subtype): def pack(values: Series):
if isinstance(values, bytes): if isinstance(values, bytes): # Bytes and BytesN are already packed
return values return values
return b''.join([serialize_basic(value, subtype) for value in values]) return b''.join([serialize_basic(value) for value in values])
def chunkify(bytez): def chunkify(bytez):
@ -123,41 +98,50 @@ def mix_in_length(root, length):
return hash(root + length.to_bytes(32, 'little')) return hash(root + length.to_bytes(32, 'little'))
def is_bottom_layer_kind(typ): def is_bottom_layer_kind(typ: SSZType):
return ( return (
is_basic_type(typ) or isinstance(typ, BasicType) or
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ)) (issubclass(typ, Elements) and isinstance(typ.elem_type, BasicType))
) )
@infer_input_type def item_length(typ: SSZType) -> int:
def get_typed_values(obj, typ=None): if issubclass(typ, BasicValue):
if is_container_type(typ): return typ.byte_len
return obj.get_typed_values()
elif is_list_kind(typ) or is_vector_kind(typ):
elem_type = read_elem_type(typ)
return list(zip(obj, [elem_type] * len(obj)))
else: else:
raise Exception("Invalid type") return 32
@infer_input_type def chunk_count(typ: SSZType) -> int:
def hash_tree_root(obj, typ=None): if isinstance(typ, BasicType):
if is_bottom_layer_kind(typ): return 1
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ)) elif issubclass(typ, Elements):
leaves = chunkify(data) return (typ.length * item_length(typ.elem_type) + 31) // 32
elif issubclass(typ, Container):
return len(typ.get_fields())
else: else:
fields = get_typed_values(obj, typ=typ) raise Exception(f"Type not supported: {typ}")
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
if is_list_kind(typ):
return mix_in_length(merkleize_chunks(leaves), len(obj)) def hash_tree_root(obj: SSZValue):
if isinstance(obj, Series):
if is_bottom_layer_kind(obj.type()):
leaves = chunkify(pack(obj))
else:
leaves = [hash_tree_root(value) for value in obj]
elif isinstance(obj, BasicValue):
leaves = chunkify(serialize_basic(obj))
else:
raise Exception(f"Type not supported: {type(obj)}")
if isinstance(obj, (List, Bytes)):
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
else: else:
return merkleize_chunks(leaves) return merkleize_chunks(leaves)
@infer_input_type def signing_root(obj: Container):
def signing_root(obj, typ):
assert is_container_type(typ)
# ignore last field # ignore last field
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]] fields = [field for field in obj][:-1]
leaves = [hash_tree_root(f) for f in fields]
return merkleize_chunks(chunkify(b''.join(leaves))) return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -1,149 +1,183 @@
from typing import Dict, Iterator
import copy import copy
from types import GeneratorType from types import GeneratorType
from typing import (
List,
Iterable,
TypeVar,
Type,
NewType,
Union,
)
from typing_inspect import get_origin
# SSZ integers
# -----------------------------
class uint(int): class DefaultingTypeMeta(type):
def default(cls):
raise Exception("Not implemented")
class SSZType(DefaultingTypeMeta):
def is_fixed_size(cls):
raise Exception("Not implemented")
class SSZValue(object, metaclass=SSZType):
def type(self):
return self.__class__
class BasicType(SSZType):
byte_len = 0 byte_len = 0
def is_fixed_size(cls):
return True
class BasicValue(int, SSZValue, metaclass=BasicType):
pass
class Bool(BasicValue): # can't subclass bool.
byte_len = 1
def __new__(cls, value, *args, **kwargs):
if value < 0 or value > 1:
raise ValueError(f"value {value} out of bounds for bit")
return super().__new__(cls, value)
@classmethod
def default(cls):
return cls(0)
def __bool__(self):
return self > 0
# Alias for Bool
class Bit(Bool):
pass
class uint(BasicValue, metaclass=BasicType):
def __new__(cls, value, *args, **kwargs): def __new__(cls, value, *args, **kwargs):
if value < 0: if value < 0:
raise ValueError("unsigned types must not be negative") raise ValueError("unsigned types must not be negative")
if cls.byte_len and value.bit_length() > (cls.byte_len << 3):
raise ValueError("value out of bounds for uint{}".format(cls.byte_len * 8))
return super().__new__(cls, value) return super().__new__(cls, value)
def __add__(self, other):
return self.__class__(super().__add__(coerce_type_maybe(other, self.__class__, strict=True)))
def __sub__(self, other):
return self.__class__(super().__sub__(coerce_type_maybe(other, self.__class__, strict=True)))
@classmethod
def default(cls):
return cls(0)
class uint8(uint): class uint8(uint):
byte_len = 1 byte_len = 1
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 8:
raise ValueError("value out of bounds for uint8")
return super().__new__(cls, value)
# Alias for uint8 # Alias for uint8
byte = NewType('byte', uint8) class byte(uint8):
pass
class uint16(uint): class uint16(uint):
byte_len = 2 byte_len = 2
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 16:
raise ValueError("value out of bounds for uint16")
return super().__new__(cls, value)
class uint32(uint): class uint32(uint):
byte_len = 4 byte_len = 4
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 32:
raise ValueError("value out of bounds for uint16")
return super().__new__(cls, value)
class uint64(uint): class uint64(uint):
byte_len = 8 byte_len = 8
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 64:
raise ValueError("value out of bounds for uint64")
return super().__new__(cls, value)
class uint128(uint): class uint128(uint):
byte_len = 16 byte_len = 16
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 128:
raise ValueError("value out of bounds for uint128")
return super().__new__(cls, value)
class uint256(uint): class uint256(uint):
byte_len = 32 byte_len = 32
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 256: def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
raise ValueError("value out of bounds for uint256") v_typ = type(v)
return super().__new__(cls, value) # shortcut if it's already the type we are looking for
if v_typ == typ:
return v
elif isinstance(v, int):
if isinstance(v, uint): # do not coerce from one uintX to another uintY
if issubclass(typ, uint) and v.type().byte_len == typ.byte_len:
return typ(v)
# revert to default behavior below if-else. (ValueError/bare)
else:
return typ(v)
elif isinstance(v, (list, tuple)):
return typ(*v)
elif isinstance(v, (bytes, BytesN, Bytes)):
return typ(v)
elif isinstance(v, GeneratorType):
return typ(v)
# just return as-is, Value-checkers will take care of it not being coerced, if we are not strict.
if strict and not isinstance(v, typ):
raise ValueError("Type coercion of {} to {} failed".format(v, typ))
return v
def is_uint_type(typ): class Series(SSZValue):
# All integers are uint in the scope of the spec here.
# Since we default to uint64. Bounds can be checked elsewhere.
# However, some are wrapped in a NewType
if hasattr(typ, '__supertype__'):
# get the type that the NewType is wrapping
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool) def __iter__(self) -> Iterator[SSZValue]:
raise Exception("Not implemented")
def uint_byte_size(typ):
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
if isinstance(typ, type):
if issubclass(typ, uint):
return typ.byte_len
elif issubclass(typ, int):
# Default to uint64
return 8
else:
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
# SSZ Container base class
# -----------------------------
# Note: importing ssz functionality locally, to avoid import loop # Note: importing ssz functionality locally, to avoid import loop
class Container(object): class Container(Series, metaclass=SSZType):
def __init__(self, **kwargs): def __init__(self, **kwargs):
cls = self.__class__ cls = self.__class__
for f, t in cls.get_fields(): for f, t in cls.get_fields().items():
if f not in kwargs: if f not in kwargs:
setattr(self, f, get_zero_value(t)) setattr(self, f, t.default())
else: else:
setattr(self, f, kwargs[f]) value = coerce_type_maybe(kwargs[f], t)
if not isinstance(value, t):
raise ValueError(f"Bad input for class {self.__class__}:"
f" field: {f} type: {t} value: {value} value type: {type(value)}")
setattr(self, f, value)
def serialize(self): def serialize(self):
from .ssz_impl import serialize from .ssz_impl import serialize
return serialize(self, self.__class__) return serialize(self)
def hash_tree_root(self): def hash_tree_root(self):
from .ssz_impl import hash_tree_root from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__) return hash_tree_root(self)
def signing_root(self): def signing_root(self):
from .ssz_impl import signing_root from .ssz_impl import signing_root
return signing_root(self, self.__class__) return signing_root(self)
def get_field_values(self): def __setattr__(self, name, value):
cls = self.__class__ if name not in self.__class__.__annotations__:
return [getattr(self, field) for field in cls.get_field_names()] raise AttributeError("Cannot change non-existing SSZ-container attribute")
field_typ = self.__class__.__annotations__[name]
value = coerce_type_maybe(value, field_typ)
if not isinstance(value, field_typ):
raise ValueError(f"Cannot set field of {self.__class__}:"
f" field: {name} type: {field_typ} value: {value} value type: {type(value)}")
super().__setattr__(name, value)
def __repr__(self): def __repr__(self):
return repr({field: getattr(self, field) for field in self.get_field_names()}) return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset')
for field in self.get_fields().keys()})
def __str__(self): def __str__(self):
output = [] output = [f'{self.__class__.__name__}']
for field in self.get_field_names(): for field in self.get_fields().keys():
output.append(f'{field}: {getattr(self, field)}') output.append(f' {field}: {getattr(self, field)}')
return "\n".join(output) return "\n".join(output)
def __eq__(self, other): def __eq__(self, other):
@ -156,404 +190,261 @@ class Container(object):
return copy.deepcopy(self) return copy.deepcopy(self)
@classmethod @classmethod
def get_fields_dict(cls): def get_fields(cls) -> Dict[str, SSZType]:
if not hasattr(cls, '__annotations__'): # no container fields
return {}
return dict(cls.__annotations__) return dict(cls.__annotations__)
@classmethod @classmethod
def get_fields(cls): def default(cls):
return list(dict(cls.__annotations__).items()) return cls(**{f: t.default() for f, t in cls.get_fields().items()})
def get_typed_values(self):
return list(zip(self.get_field_values(), self.get_field_types()))
@classmethod @classmethod
def get_field_names(cls): def is_fixed_size(cls):
return list(cls.__annotations__.keys()) return all(t.is_fixed_size() for t in cls.get_fields().values())
@classmethod def __iter__(self) -> Iterator[SSZValue]:
def get_field_types(cls): return iter([getattr(self, field) for field in self.get_fields().keys()])
# values of annotations are the types corresponding to the fields, not instance values.
return list(cls.__annotations__.values())
# SSZ vector class ParamsBase(Series):
# ----------------------------- _has_params = False
def __new__(cls, *args, **kwargs):
if not cls._has_params:
raise Exception("cannot init bare type without params")
return super().__new__(cls, **kwargs)
def _is_vector_instance_of(a, b): class ParamsMeta(SSZType):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector (b) is not an instance of Vector[X, Y] (a)
return False
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
# Vector[X, Y] (b) is an instance of Vector (a)
return True
else:
# Vector[X, Y] (a) is an instance of Vector[X, Y] (b)
return a.elem_type == b.elem_type and a.length == b.length
def _is_equal_vector_type(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector == Vector
return True
else:
# Vector != Vector[X, Y]
return False
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector[X, Y] != Vector
return False
else:
# Vector[X, Y] == Vector[X, Y]
return a.elem_type == b.elem_type and a.length == b.length
class VectorMeta(type):
def __new__(cls, class_name, parents, attrs): def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs) out = type.__new__(cls, class_name, parents, attrs)
if 'elem_type' in attrs and 'length' in attrs: if hasattr(out, "_has_params") and getattr(out, "_has_params"):
setattr(out, 'elem_type', attrs['elem_type']) for k, v in attrs.items():
setattr(out, 'length', attrs['length']) setattr(out, k, v)
return out return out
def __getitem__(self, params): def __getitem__(self, params):
if not isinstance(params, tuple) or len(params) != 2: o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
raise Exception("Vector must be instantiated with two args: elem type and length")
o = self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]})
o._name = 'Vector'
return o return o
def __subclasscheck__(self, sub): def __str__(self):
return _is_vector_instance_of(self, sub) return f"{self.__name__}~{self.__class__.__name__}"
def __instancecheck__(self, other):
return _is_vector_instance_of(self, other.__class__)
def __eq__(self, other):
return _is_equal_vector_type(self, other)
def __ne__(self, other):
return not _is_equal_vector_type(self, other)
def __hash__(self):
return hash(self.__class__)
class Vector(metaclass=VectorMeta):
def __init__(self, *args: Iterable):
cls = self.__class__
if not hasattr(cls, 'elem_type'):
raise TypeError("Type Vector without elem_type data cannot be instantiated")
elif not hasattr(cls, 'length'):
raise TypeError("Type Vector without length data cannot be instantiated")
if len(args) != cls.length:
if len(args) == 0:
args = [get_zero_value(cls.elem_type) for _ in range(cls.length)]
else:
raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args)))
self.items = list(args)
# cannot check non-type objects, or parametrized types
if isinstance(cls.elem_type, type) and not hasattr(cls.elem_type, '__args__'):
for i, item in enumerate(self.items):
if not issubclass(cls.elem_type, type(item)):
raise TypeError("Typed vector cannot hold differently typed value"
" at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type))
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
def __repr__(self): def __repr__(self):
return repr({'length': self.__class__.length, 'items': self.items}) return self, self.__class__
def __getitem__(self, key): def attr_from_params(self, p):
return self.items[key] # single key params are valid too. Wrap them in a tuple.
params = p if isinstance(p, tuple) else (p,)
def __setitem__(self, key, value): res = {'_has_params': True}
self.items[key] = value i = 0
for (name, typ) in self.__annotations__.items():
def __iter__(self): if hasattr(self.__class__, name):
return iter(self.items) res[name] = getattr(self.__class__, name)
def __len__(self):
return len(self.items)
def __eq__(self, other):
return self.hash_tree_root() == other.hash_tree_root()
# SSZ BytesN
# -----------------------------
def _is_bytes_n_instance_of(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
elif not hasattr(b, 'length'):
# BytesN (b) is not an instance of BytesN[X] (a)
return False
elif not hasattr(a, 'length'):
# BytesN[X] (b) is an instance of BytesN (a)
return True
else:
# BytesN[X] (a) is an instance of BytesN[X] (b)
return a.length == b.length
def _is_equal_bytes_n_type(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
elif not hasattr(a, 'length'):
if not hasattr(b, 'length'):
# BytesN == BytesN
return True
else:
# BytesN != BytesN[X]
return False
elif not hasattr(b, 'length'):
# BytesN[X] != BytesN
return False
else:
# BytesN[X] == BytesN[X]
return a.length == b.length
class BytesNMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'length' in attrs:
setattr(out, 'length', attrs['length'])
out._name = 'BytesN'
out.elem_type = byte
return out
def __getitem__(self, n):
return self.__class__(self.__name__, (BytesN,), {'length': n})
def __subclasscheck__(self, sub):
return _is_bytes_n_instance_of(self, sub)
def __instancecheck__(self, other):
return _is_bytes_n_instance_of(self, other.__class__)
def __eq__(self, other):
if other == ():
return False
return _is_equal_bytes_n_type(self, other)
def __ne__(self, other):
return not _is_equal_bytes_n_type(self, other)
def __hash__(self):
return hash(self.__class__)
def parse_bytes(val):
if val is None:
return None
elif isinstance(val, str):
# TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val)
return None
elif isinstance(val, bytes):
return val
elif isinstance(val, int):
return bytes([val])
elif isinstance(val, (list, GeneratorType)):
return bytes(val)
else:
return None
class BytesN(bytes, metaclass=BytesNMeta):
def __new__(cls, *args):
if not hasattr(cls, 'length'):
return
bytesval = None
if len(args) == 1:
val: Union[bytes, int, str] = args[0]
bytesval = parse_bytes(val)
elif len(args) > 1:
# TODO: each int is 1 byte, check size, create bytesval
bytesval = bytes(args)
if bytesval is None:
if cls.length == 0:
bytesval = b''
else: else:
bytesval = b'\x00' * cls.length if i >= len(params):
if len(bytesval) != cls.length: i += 1
raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval))) continue
return super().__new__(cls, bytesval) param = params[i]
if not isinstance(param, typ):
raise TypeError(
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
res[name] = param
i += 1
if len(params) != i:
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
return res
def serialize(self): def __subclasscheck__(self, subclass):
from .ssz_impl import serialize # check regular class system if we can, solves a lot of the normal cases.
return serialize(self, self.__class__) if super().__subclasscheck__(subclass):
return True
# if they are not normal subclasses, they are of the same class.
# then they should have the same name
if subclass.__name__ != self.__name__:
return False
# If they do have the same name, they should also have the same params.
for name, typ in self.__annotations__.items():
if hasattr(self, name) and hasattr(subclass, name) \
and getattr(subclass, name) != getattr(self, name):
return False
return True
def hash_tree_root(self): def __instancecheck__(self, obj):
from .ssz_impl import hash_tree_root return self.__subclasscheck__(obj.__class__)
return hash_tree_root(self, self.__class__)
class Bytes4(BytesN): class ElementsType(ParamsMeta):
length = 4 elem_type: SSZType
length: int
class Bytes32(BytesN): class Elements(ParamsBase, metaclass=ElementsType):
length = 32 pass
class Bytes48(BytesN): class BaseList(list, Elements):
length = 48
def __init__(self, *args):
items = self.extract_args(*args)
if not self.value_check(items):
raise ValueError(f"Bad input for class {self.__class__}: {items}")
super().__init__(items)
@classmethod
def value_check(cls, value):
return all(isinstance(v, cls.elem_type) for v in value) and len(value) <= cls.length
@classmethod
def extract_args(cls, *args):
x = list(args)
if len(x) == 1 and isinstance(x[0], (GeneratorType, list, tuple)):
x = list(x[0])
x = [coerce_type_maybe(v, cls.elem_type) for v in x]
return x
def __str__(self):
cls = self.__class__
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})"
def __getitem__(self, k) -> SSZValue:
if isinstance(k, int): # check if we are just doing a lookup, and not slicing
if k < 0:
raise IndexError(f"cannot get item in type {self.__class__} at negative index {k}")
if k > len(self):
raise IndexError(f"cannot get item in type {self.__class__}"
f" at out of bounds index {k}")
return super().__getitem__(k)
def __setitem__(self, k, v):
if k < 0:
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
if k > len(self):
raise IndexError(f"cannot set item in type {self.__class__}"
f" at out of bounds index {k} (to {v}, bound: {len(self)})")
super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def append(self, v):
super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def __iter__(self) -> Iterator[SSZValue]:
return super().__iter__()
def last(self):
# be explict about getting the last item, for the non-python readers, and negative-index safety
return self[len(self) - 1]
class Bytes96(BytesN): class List(BaseList):
length = 96
@classmethod
def default(cls):
return cls()
# SSZ Defaults @classmethod
# ----------------------------- def is_fixed_size(cls):
def get_zero_value(typ):
if is_uint_type(typ):
return uint64(0)
elif is_list_type(typ):
return []
elif is_bool_type(typ):
return False return False
elif is_vector_type(typ):
return typ()
elif is_bytesn_type(typ): class Vector(BaseList):
return typ()
elif is_bytes_type(typ): @classmethod
def value_check(cls, value):
# check length limit strictly
return len(value) == cls.length and super().value_check(value)
@classmethod
def default(cls):
return cls(cls.elem_type.default() for _ in range(cls.length))
@classmethod
def is_fixed_size(cls):
return cls.elem_type.is_fixed_size()
def append(self, v):
# Deep-copy and other utils like to change the internals during work.
# Only complain if we had the right size.
if len(self) == self.__class__.length:
raise Exception("cannot modify vector length")
else:
super().append(v)
def pop(self, *args):
raise Exception("cannot modify vector length")
class BytesType(ElementsType):
elem_type: SSZType = byte
length: int
class BaseBytes(bytes, Elements, metaclass=BytesType):
def __new__(cls, *args) -> "BaseBytes":
extracted_val = cls.extract_args(*args)
if not cls.value_check(extracted_val):
raise ValueError(f"Bad input for class {cls}: {extracted_val}")
return super().__new__(cls, extracted_val)
@classmethod
def extract_args(cls, *args):
x = args
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)):
x = x[0]
if isinstance(x, bytes): # Includes BytesLike
return x
else:
return bytes(x) # E.g. GeneratorType put into bytes.
@classmethod
def value_check(cls, value):
# check type and virtual length limit
return isinstance(value, bytes) and len(value) <= cls.length
def __str__(self):
cls = self.__class__
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
class Bytes(BaseBytes):
@classmethod
def default(cls):
return b'' return b''
elif is_container_type(typ):
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()}) @classmethod
else: def is_fixed_size(cls):
raise Exception("Type not supported: {}".format(typ)) return False
# Type helpers class BytesN(BaseBytes):
# -----------------------------
@classmethod
def extract_args(cls, *args):
if len(args) == 0:
return cls.default()
else:
return super().extract_args(*args)
@classmethod
def default(cls):
return b'\x00' * cls.length
@classmethod
def value_check(cls, value):
# check length limit strictly
return len(value) == cls.length and super().value_check(value)
@classmethod
def is_fixed_size(cls):
return True
def infer_type(obj): # Helpers for common BytesN types.
if is_uint_type(obj.__class__): Bytes4: BytesType = BytesN[4]
return obj.__class__ Bytes32: BytesType = BytesN[32]
elif isinstance(obj, int): Bytes48: BytesType = BytesN[48]
return uint64 Bytes96: BytesType = BytesN[96]
elif isinstance(obj, list):
return List[infer_type(obj[0])]
elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)):
return obj.__class__
else:
raise Exception("Unknown type for {}".format(obj))
def infer_input_type(fn):
"""
Decorator to run infer_type on the obj if typ argument is None
"""
def infer_helper(obj, typ=None, **kwargs):
if typ is None:
typ = infer_type(obj)
return fn(obj, typ=typ, **kwargs)
return infer_helper
def is_bool_type(typ):
"""
Check if the given type is a bool.
"""
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, bool)
def is_list_type(typ):
"""
Check if the given type is a list.
"""
return get_origin(typ) is List or get_origin(typ) is list
def is_bytes_type(typ):
"""
Check if the given type is a ``bytes``.
"""
# Do not accept subclasses of bytes here, to avoid confusion with BytesN
return typ == bytes
def is_bytesn_type(typ):
"""
Check if the given type is a BytesN.
"""
return isinstance(typ, type) and issubclass(typ, BytesN)
def is_list_kind(typ):
"""
Check if the given type is a kind of list. Can be bytes.
"""
return is_list_type(typ) or is_bytes_type(typ)
def is_vector_type(typ):
"""
Check if the given type is a vector.
"""
return isinstance(typ, type) and issubclass(typ, Vector)
def is_vector_kind(typ):
"""
Check if the given type is a kind of vector. Can be BytesN.
"""
return is_vector_type(typ) or is_bytesn_type(typ)
def is_container_type(typ):
"""
Check if the given type is a container.
"""
return isinstance(typ, type) and issubclass(typ, Container)
T = TypeVar('T')
L = TypeVar('L')
def read_list_elem_type(list_typ: Type[List[T]]) -> T:
if list_typ.__args__ is None or len(list_typ.__args__) != 1:
raise TypeError("Supplied list-type is invalid, no element type found.")
return list_typ.__args__[0]
def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T:
return vector_typ.elem_type
def read_elem_type(typ):
if typ == bytes or (isinstance(typ, type) and issubclass(typ, bytes)): # bytes or bytesN
return byte
elif is_list_type(typ):
return read_list_elem_type(typ)
elif is_vector_type(typ):
return read_vector_elem_type(typ)
else:
raise TypeError("Unexpected type: {}".format(typ))

View File

@ -0,0 +1,241 @@
from typing import Iterable
from .ssz_impl import serialize, hash_tree_root
from .ssz_typing import (
Bit, Bool, Container, List, Vector, Bytes, BytesN,
uint8, uint16, uint32, uint64, uint256, byte
)
from ..hash_function import hash as bytes_hash
import pytest
class EmptyTestStruct(Container):
pass
class SingleFieldTestStruct(Container):
A: byte
class SmallTestStruct(Container):
A: uint16
B: uint16
class FixedTestStruct(Container):
A: uint8
B: uint64
C: uint32
class VarTestStruct(Container):
A: uint16
B: List[uint16, 1024]
C: uint8
class ComplexTestStruct(Container):
A: uint16
B: List[uint16, 128]
C: uint8
D: Bytes[256]
E: VarTestStruct
F: Vector[FixedTestStruct, 4]
G: Vector[VarTestStruct, 2]
sig_test_data = [0 for i in range(96)]
for k, v in {0: 1, 32: 2, 64: 3, 95: 0xff}.items():
sig_test_data[k] = v
def chunk(hex: str) -> str:
return (hex + ("00" * 32))[:64] # just pad on the right, to 32 bytes (64 hex chars)
def h(a: str, b: str) -> str:
return bytes_hash(bytes.fromhex(a) + bytes.fromhex(b)).hex()
# zero hashes, as strings, for
zero_hashes = [chunk("")]
for layer in range(1, 32):
zero_hashes.append(h(zero_hashes[layer - 1], zero_hashes[layer - 1]))
def merge(a: str, branch: Iterable[str]) -> str:
"""
Merge (out on left, branch on right) leaf a with branch items, branch is from bottom to top.
"""
out = a
for b in branch:
out = h(out, b)
return out
test_data = [
("bit F", Bit(False), "00", chunk("00")),
("bit T", Bit(True), "01", chunk("01")),
("bool F", Bool(False), "00", chunk("00")),
("bool T", Bool(True), "01", chunk("01")),
("uint8 00", uint8(0x00), "00", chunk("00")),
("uint8 01", uint8(0x01), "01", chunk("01")),
("uint8 ab", uint8(0xab), "ab", chunk("ab")),
("byte 00", byte(0x00), "00", chunk("00")),
("byte 01", byte(0x01), "01", chunk("01")),
("byte ab", byte(0xab), "ab", chunk("ab")),
("uint16 0000", uint16(0x0000), "0000", chunk("0000")),
("uint16 abcd", uint16(0xabcd), "cdab", chunk("cdab")),
("uint32 00000000", uint32(0x00000000), "00000000", chunk("00000000")),
("uint32 01234567", uint32(0x01234567), "67452301", chunk("67452301")),
("small (4567, 0123)", SmallTestStruct(A=0x4567, B=0x0123), "67452301", h(chunk("6745"), chunk("2301"))),
("small [4567, 0123]::2", Vector[uint16, 2](uint16(0x4567), uint16(0x0123)), "67452301", chunk("67452301")),
("uint32 01234567", uint32(0x01234567), "67452301", chunk("67452301")),
("uint64 0000000000000000", uint64(0x00000000), "0000000000000000", chunk("0000000000000000")),
("uint64 0123456789abcdef", uint64(0x0123456789abcdef), "efcdab8967452301", chunk("efcdab8967452301")),
("sig", BytesN[96](*sig_test_data),
"0100000000000000000000000000000000000000000000000000000000000000"
"0200000000000000000000000000000000000000000000000000000000000000"
"03000000000000000000000000000000000000000000000000000000000000ff",
h(h(chunk("01"), chunk("02")),
h("03000000000000000000000000000000000000000000000000000000000000ff", chunk("")))),
("emptyTestStruct", EmptyTestStruct(), "", chunk("")),
("singleFieldTestStruct", SingleFieldTestStruct(A=0xab), "ab", chunk("ab")),
("uint16 list", List[uint16, 32](uint16(0xaabb), uint16(0xc0ad), uint16(0xeeff)), "bbaaadc0ffee",
h(h(chunk("bbaaadc0ffee"), chunk("")), chunk("03000000")) # max length: 32 * 2 = 64 bytes = 2 chunks
),
("uint32 list", List[uint32, 128](uint32(0xaabb), uint32(0xc0ad), uint32(0xeeff)), "bbaa0000adc00000ffee0000",
# max length: 128 * 4 = 512 bytes = 16 chunks
h(merge(chunk("bbaa0000adc00000ffee0000"), zero_hashes[0:4]), chunk("03000000"))
),
("uint256 list", List[uint256, 32](uint256(0xaabb), uint256(0xc0ad), uint256(0xeeff)),
"bbaa000000000000000000000000000000000000000000000000000000000000"
"adc0000000000000000000000000000000000000000000000000000000000000"
"ffee000000000000000000000000000000000000000000000000000000000000",
h(merge(h(h(chunk("bbaa"), chunk("adc0")), h(chunk("ffee"), chunk(""))), zero_hashes[2:5]), chunk("03000000"))
),
("uint256 list long", List[uint256, 128](i for i in range(1, 20)),
"".join([i.to_bytes(length=32, byteorder='little').hex() for i in range(1, 20)]),
h(merge(
h(
h(
h(
h(h(chunk("01"), chunk("02")), h(chunk("03"), chunk("04"))),
h(h(chunk("05"), chunk("06")), h(chunk("07"), chunk("08"))),
),
h(
h(h(chunk("09"), chunk("0a")), h(chunk("0b"), chunk("0c"))),
h(h(chunk("0d"), chunk("0e")), h(chunk("0f"), chunk("10"))),
)
),
h(
h(
h(h(chunk("11"), chunk("12")), h(chunk("13"), chunk(""))),
zero_hashes[2]
),
zero_hashes[3]
)
),
zero_hashes[5:7]), chunk("13000000")) # 128 chunks = 7 deep
),
("fixedTestStruct", FixedTestStruct(A=0xab, B=0xaabbccdd00112233, C=0x12345678), "ab33221100ddccbbaa78563412",
h(h(chunk("ab"), chunk("33221100ddccbbaa")), h(chunk("78563412"), chunk("")))),
("varTestStruct nil", VarTestStruct(A=0xabcd, C=0xff), "cdab07000000ff",
h(h(chunk("cdab"), h(zero_hashes[6], chunk("00000000"))), h(chunk("ff"), chunk("")))),
("varTestStruct empty", VarTestStruct(A=0xabcd, B=List[uint16, 1024](), C=0xff), "cdab07000000ff",
h(h(chunk("cdab"), h(zero_hashes[6], chunk("00000000"))), h(chunk("ff"), chunk("")))), # log2(1024*2/32)= 6 deep
("varTestStruct some", VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
"cdab07000000ff010002000300",
h(
h(
chunk("cdab"),
h(
merge(
chunk("010002000300"),
zero_hashes[0:6]
),
chunk("03000000") # length mix in
)
),
h(chunk("ff"), chunk(""))
)),
("complexTestStruct",
ComplexTestStruct(
A=0xaabb,
B=List[uint16, 128](0x1122, 0x3344),
C=0xff,
D=Bytes[256](b"foobar"),
E=VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
F=Vector[FixedTestStruct, 4](
FixedTestStruct(A=0xcc, B=0x4242424242424242, C=0x13371337),
FixedTestStruct(A=0xdd, B=0x3333333333333333, C=0xabcdabcd),
FixedTestStruct(A=0xee, B=0x4444444444444444, C=0x00112233),
FixedTestStruct(A=0xff, B=0x5555555555555555, C=0x44556677)),
G=Vector[VarTestStruct, 2](
VarTestStruct(A=0xdead, B=List[uint16, 1024](1, 2, 3), C=0x11),
VarTestStruct(A=0xbeef, B=List[uint16, 1024](4, 5, 6), C=0x22)),
),
"bbaa"
"47000000" # offset of B, []uint16
"ff"
"4b000000" # offset of foobar
"51000000" # offset of E
"cc424242424242424237133713"
"dd3333333333333333cdabcdab"
"ee444444444444444433221100"
"ff555555555555555577665544"
"5e000000" # pointer to G
"22114433" # contents of B
"666f6f626172" # foobar
"cdab07000000ff010002000300" # contents of E
"08000000" "15000000" # [start G]: local offsets of [2]varTestStruct
"adde0700000011010002000300"
"efbe0700000022040005000600",
h(
h(
h( # A and B
chunk("bbaa"),
h(merge(chunk("22114433"), zero_hashes[0:3]), chunk("02000000")) # 2*128/32 = 8 chunks
),
h( # C and D
chunk("ff"),
h(merge(chunk("666f6f626172"), zero_hashes[0:3]), chunk("06000000")) # 256/32 = 8 chunks
)
),
h(
h( # E and F
h(h(chunk("cdab"), h(merge(chunk("010002000300"), zero_hashes[0:6]), chunk("03000000"))),
h(chunk("ff"), chunk(""))),
h(
h(
h(h(chunk("cc"), chunk("4242424242424242")), h(chunk("37133713"), chunk(""))),
h(h(chunk("dd"), chunk("3333333333333333")), h(chunk("cdabcdab"), chunk(""))),
),
h(
h(h(chunk("ee"), chunk("4444444444444444")), h(chunk("33221100"), chunk(""))),
h(h(chunk("ff"), chunk("5555555555555555")), h(chunk("77665544"), chunk(""))),
),
)
),
h( # G and padding
h(
h(h(chunk("adde"), h(merge(chunk("010002000300"), zero_hashes[0:6]), chunk("03000000"))),
h(chunk("11"), chunk(""))),
h(h(chunk("efbe"), h(merge(chunk("040005000600"), zero_hashes[0:6]), chunk("03000000"))),
h(chunk("22"), chunk(""))),
),
chunk("")
)
)
))
]
@pytest.mark.parametrize("name, value, serialized, _", test_data)
def test_serialize(name, value, serialized, _):
assert serialize(value) == bytes.fromhex(serialized)
@pytest.mark.parametrize("name, value, _, root", test_data)
def test_hash_tree_root(name, value, _, root):
assert hash_tree_root(value) == bytes.fromhex(root)

View File

@ -0,0 +1,233 @@
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
Elements, Bit, Bool, Container, List, Vector, Bytes, BytesN,
byte, uint, uint8, uint16, uint32, uint64, uint128, uint256,
Bytes32, Bytes48
)
def expect_value_error(fn, msg):
try:
fn()
raise AssertionError(msg)
except ValueError:
pass
def test_subclasses():
for u in [uint, uint8, uint16, uint32, uint64, uint128, uint256]:
assert issubclass(u, uint)
assert issubclass(u, int)
assert issubclass(u, BasicValue)
assert issubclass(u, SSZValue)
assert isinstance(u, SSZType)
assert isinstance(u, BasicType)
assert issubclass(Bool, BasicValue)
assert isinstance(Bool, BasicType)
for c in [Container, List, Vector, Bytes, BytesN]:
assert issubclass(c, Series)
assert issubclass(c, SSZValue)
assert isinstance(c, SSZType)
assert not issubclass(c, BasicValue)
assert not isinstance(c, BasicType)
for c in [List, Vector, Bytes, BytesN]:
assert issubclass(c, Elements)
assert isinstance(c, ElementsType)
def test_basic_instances():
for u in [uint, uint8, byte, uint16, uint32, uint64, uint128, uint256]:
v = u(123)
assert isinstance(v, uint)
assert isinstance(v, int)
assert isinstance(v, BasicValue)
assert isinstance(v, SSZValue)
assert isinstance(Bool(True), BasicValue)
assert isinstance(Bool(False), BasicValue)
assert isinstance(Bit(True), Bool)
assert isinstance(Bit(False), Bool)
def test_basic_value_bounds():
max = {
Bool: 2 ** 1,
Bit: 2 ** 1,
uint8: 2 ** (8 * 1),
byte: 2 ** (8 * 1),
uint16: 2 ** (8 * 2),
uint32: 2 ** (8 * 4),
uint64: 2 ** (8 * 8),
uint128: 2 ** (8 * 16),
uint256: 2 ** (8 * 32),
}
for k, v in max.items():
# this should work
assert k(v - 1) == v - 1
# but we do not allow overflows
expect_value_error(lambda: k(v), "no overflows allowed")
for k, _ in max.items():
# this should work
assert k(0) == 0
# but we do not allow underflows
expect_value_error(lambda: k(-1), "no underflows allowed")
def test_container():
class Foo(Container):
a: uint8
b: uint32
empty = Foo()
assert empty.a == uint8(0)
assert empty.b == uint32(0)
assert issubclass(Foo, Container)
assert issubclass(Foo, SSZValue)
assert issubclass(Foo, Series)
assert Foo.is_fixed_size()
x = Foo(a=uint8(123), b=uint32(45))
assert x.a == 123
assert x.b == 45
assert isinstance(x.a, uint8)
assert isinstance(x.b, uint32)
assert x.type().is_fixed_size()
class Bar(Container):
a: uint8
b: List[uint8, 1024]
assert not Bar.is_fixed_size()
y = Bar(a=123, b=List[uint8, 1024](uint8(1), uint8(2)))
assert y.a == 123
assert isinstance(y.a, uint8)
assert len(y.b) == 2
assert isinstance(y.a, uint8)
assert isinstance(y.b, List[uint8, 1024])
assert not y.type().is_fixed_size()
assert y.b[0] == 1
v: List = y.b
assert v.type().elem_type == uint8
assert v.type().length == 1024
y.a = 42
try:
y.a = 256 # out of bounds
assert False
except ValueError:
pass
try:
y.a = uint16(255) # within bounds, wrong type
assert False
except ValueError:
pass
try:
y.not_here = 5
assert False
except AttributeError:
pass
def test_list():
typ = List[uint64, 128]
assert issubclass(typ, List)
assert issubclass(typ, SSZValue)
assert issubclass(typ, Series)
assert issubclass(typ, Elements)
assert isinstance(typ, ElementsType)
assert not typ.is_fixed_size()
assert len(typ()) == 0 # empty
assert len(typ(uint64(0))) == 1 # single arg
assert len(typ(uint64(i) for i in range(10))) == 10 # generator
assert len(typ(uint64(0), uint64(1), uint64(2))) == 3 # args
assert isinstance(typ(1, 2, 3, 4, 5)[4], uint64) # coercion
assert isinstance(typ(i for i in range(10))[9], uint64) # coercion in generator
v = typ(uint64(0))
v[0] = uint64(123)
assert v[0] == 123
assert isinstance(v[0], uint64)
assert isinstance(v, List)
assert isinstance(v, List[uint64, 128])
assert isinstance(v, typ)
assert isinstance(v, SSZValue)
assert isinstance(v, Series)
assert issubclass(v.type(), Elements)
assert isinstance(v.type(), ElementsType)
assert len(typ([i for i in range(10)])) == 10 # cast py list to SSZ list
foo = List[uint32, 128](0 for i in range(128))
foo[0] = 123
foo[1] = 654
foo[127] = 222
assert sum(foo) == 999
try:
foo[3] = 2 ** 32 # out of bounds
except ValueError:
pass
try:
foo[3] = uint64(2 ** 32 - 1) # within bounds, wrong type
assert False
except ValueError:
pass
try:
foo[128] = 100
assert False
except IndexError:
pass
try:
foo[-1] = 100 # valid in normal python lists
assert False
except IndexError:
pass
try:
foo[128] = 100 # out of bounds
assert False
except IndexError:
pass
def test_bytesn_subclass():
assert isinstance(BytesN[32](b'\xab' * 32), Bytes32)
assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48)
assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32)
assert issubclass(BytesN[32], Bytes32)
class Hash(Bytes32):
pass
assert isinstance(Hash(b'\xab' * 32), Bytes32)
assert not isinstance(Hash(b'\xab' * 32), Bytes48)
assert issubclass(Hash(b'\xab' * 32).type(), Bytes32)
assert issubclass(Hash, Bytes32)
assert not issubclass(Bytes48, Bytes32)
assert len(Bytes32() + Bytes48()) == 80
def test_uint_math():
assert uint8(0) + uint8(uint32(16)) == uint8(16) # allow explict casting to make invalid addition valid
expect_value_error(lambda: uint8(0) - uint8(1), "no underflows allowed")
expect_value_error(lambda: uint8(1) + uint8(255), "no overflows allowed")
expect_value_error(lambda: uint8(0) + 256, "no overflows allowed")
expect_value_error(lambda: uint8(42) + uint32(123), "no mixed types")
expect_value_error(lambda: uint32(42) + uint8(123), "no mixed types")
assert type(uint32(1234) + 56) == uint32

View File

@ -0,0 +1,58 @@
import pytest
from .merkle_minimal import zerohashes, merkleize_chunks
from .hash_function import hash
def h(a: bytes, b: bytes) -> bytes:
return hash(a + b)
def e(v: int) -> bytes:
return v.to_bytes(length=32, byteorder='little')
def z(i: int) -> bytes:
return zerohashes[i]
cases = [
(0, 0, 1, z(0)),
(0, 1, 1, e(0)),
(1, 0, 2, h(z(0), z(0))),
(1, 1, 2, h(e(0), z(0))),
(1, 2, 2, h(e(0), e(1))),
(2, 0, 4, h(h(z(0), z(0)), z(1))),
(2, 1, 4, h(h(e(0), z(0)), z(1))),
(2, 2, 4, h(h(e(0), e(1)), z(1))),
(2, 3, 4, h(h(e(0), e(1)), h(e(2), z(0)))),
(2, 4, 4, h(h(e(0), e(1)), h(e(2), e(3)))),
(3, 0, 8, h(h(h(z(0), z(0)), z(1)), z(2))),
(3, 1, 8, h(h(h(e(0), z(0)), z(1)), z(2))),
(3, 2, 8, h(h(h(e(0), e(1)), z(1)), z(2))),
(3, 3, 8, h(h(h(e(0), e(1)), h(e(2), z(0))), z(2))),
(3, 4, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), z(2))),
(3, 5, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), z(0)), z(1)))),
(3, 6, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(z(0), z(0))))),
(3, 7, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), z(0))))),
(3, 8, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7))))),
(4, 0, 16, h(h(h(h(z(0), z(0)), z(1)), z(2)), z(3))),
(4, 1, 16, h(h(h(h(e(0), z(0)), z(1)), z(2)), z(3))),
(4, 2, 16, h(h(h(h(e(0), e(1)), z(1)), z(2)), z(3))),
(4, 3, 16, h(h(h(h(e(0), e(1)), h(e(2), z(0))), z(2)), z(3))),
(4, 4, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), z(2)), z(3))),
(4, 5, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), z(0)), z(1))), z(3))),
(4, 6, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(z(0), z(0)))), z(3))),
(4, 7, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), z(0)))), z(3))),
(4, 8, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7)))), z(3))),
(4, 9, 16,
h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7)))), h(h(h(e(8), z(0)), z(1)), z(2)))),
]
@pytest.mark.parametrize(
'depth,count,pow2,value',
cases,
)
def test_merkleize_chunks(depth, count, pow2, value):
chunks = [e(i) for i in range(count)]
assert merkleize_chunks(chunks, pad_to=pow2) == value

View File

@ -2,6 +2,5 @@ eth-utils>=1.3.0,<2
eth-typing>=2.1.0,<3.0.0 eth-typing>=2.1.0,<3.0.0
pycryptodome==3.7.3 pycryptodome==3.7.3
py_ecc>=1.6.0 py_ecc>=1.6.0
typing_inspect==0.4.0
dataclasses==0.6 dataclasses==0.6
ssz==0.1.0a10 ssz==0.1.0a10

View File

@ -9,7 +9,6 @@ setup(
"eth-typing>=2.1.0,<3.0.0", "eth-typing>=2.1.0,<3.0.0",
"pycryptodome==3.7.3", "pycryptodome==3.7.3",
"py_ecc>=1.6.0", "py_ecc>=1.6.0",
"typing_inspect==0.4.0",
"ssz==0.1.0a10", "ssz==0.1.0a10",
"dataclasses==0.6", "dataclasses==0.6",
] ]