Merge pull request #1180 from ethereum/list-rework

pyspec-SSZ: lists-rework (enable static generalized indices) + fully python class based now.
This commit is contained in:
Danny Ryan 2019-06-25 07:38:50 -06:00 committed by GitHub
commit df2a9e1b54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1314 additions and 871 deletions

View File

@ -77,6 +77,10 @@ MIN_EPOCHS_TO_INACTIVITY_PENALTY: 4
EPOCHS_PER_HISTORICAL_VECTOR: 65536
# 2**13 (= 8,192) epochs ~36 days
EPOCHS_PER_SLASHED_BALANCES_VECTOR: 8192
# 2**24 (= 16,777,216) historical roots, ~26,131 years
HISTORICAL_ROOTS_LIMIT: 16777216
# 2**40 (= 1,099,511,627,776) validator spots
VALIDATOR_REGISTRY_LIMIT: 1099511627776
# Reward and penalty quotients

View File

@ -78,6 +78,10 @@ EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS: 4096
EPOCHS_PER_HISTORICAL_VECTOR: 64
# [customized] smaller state
EPOCHS_PER_SLASHED_BALANCES_VECTOR: 64
# 2**24 (= 16,777,216) historical roots
HISTORICAL_ROOTS_LIMIT: 16777216
# 2**40 (= 1,099,511,627,776) validator spots
VALIDATOR_REGISTRY_LIMIT: 1099511627776
# Reward and penalty quotients

View File

@ -12,12 +12,7 @@ from typing import (
PHASE0_IMPORTS = '''from typing import (
Any,
Callable,
Dict,
List,
Set,
Tuple,
Any, Callable, Dict, Set, Sequence, Tuple,
)
from dataclasses import (
@ -30,8 +25,7 @@ from eth2spec.utils.ssz.ssz_impl import (
signing_root,
)
from eth2spec.utils.ssz.ssz_typing import (
# unused: uint8, uint16, uint32, uint128, uint256,
uint64, Container, Vector,
Bit, Bool, Container, List, Vector, Bytes, uint64,
Bytes4, Bytes32, Bytes48, Bytes96,
)
from eth2spec.utils.bls import (
@ -39,18 +33,11 @@ from eth2spec.utils.bls import (
bls_verify,
bls_verify_multiple,
)
# Note: 'int' type defaults to being interpreted as a uint64 by SSZ implementation.
from eth2spec.utils.hash_function import hash
'''
PHASE1_IMPORTS = '''from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Any, Callable, Dict, Optional, Set, Sequence, MutableSequence, Tuple,
)
from dataclasses import (
@ -65,8 +52,7 @@ from eth2spec.utils.ssz.ssz_impl import (
is_empty,
)
from eth2spec.utils.ssz.ssz_typing import (
# unused: uint8, uint16, uint32, uint128, uint256,
uint64, Container, Vector,
Bit, Bool, Container, List, Vector, Bytes, uint64,
Bytes4, Bytes32, Bytes48, Bytes96,
)
from eth2spec.utils.bls import (
@ -77,28 +63,7 @@ from eth2spec.utils.bls import (
from eth2spec.utils.hash_function import hash
'''
BYTE_TYPES = [4, 32, 48, 96]
SUNDRY_FUNCTIONS = '''
def get_ssz_type_by_name(name: str) -> Container:
return globals()[name]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], List[ValidatorIndex]] = {}
def compute_committee(indices: List[ValidatorIndex], # type: ignore
seed: Hash,
index: int,
count: int) -> List[ValidatorIndex]:
param_hash = (hash_tree_root(indices), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Monkey patch hash cache
_hash = hash
hash_cache: Dict[bytes, Hash] = {}
@ -110,6 +75,22 @@ def hash(x: bytes) -> Hash:
return hash_cache[x]
# Monkey patch validator compute committee code
_compute_committee = compute_committee
committee_cache: Dict[Tuple[Hash, Hash, int, int], Sequence[ValidatorIndex]] = {}
def compute_committee(indices: Sequence[ValidatorIndex], # type: ignore
seed: Hash,
index: int,
count: int) -> Sequence[ValidatorIndex]:
param_hash = (hash(b''.join(index.to_bytes(length=4, byteorder='little') for index in indices)), seed, index, count)
if param_hash not in committee_cache:
committee_cache[param_hash] = _compute_committee(indices, seed, index, count)
return committee_cache[param_hash]
# Access to overwrite spec constants based on configuration
def apply_constants_preset(preset: Dict[str, Any]) -> None:
global_vars = globals()
@ -124,6 +105,18 @@ def apply_constants_preset(preset: Dict[str, Any]) -> None:
'''
def strip_comments(raw: str) -> str:
comment_line_regex = re.compile('^\s+# ')
lines = raw.split('\n')
out = []
for line in lines:
if not comment_line_regex.match(line):
if ' #' in line:
line = line[:line.index(' #')]
out.append(line)
return '\n'.join(out)
def objects_to_spec(functions: Dict[str, str],
custom_types: Dict[str, str],
constants: Dict[str, str],
@ -151,7 +144,8 @@ def objects_to_spec(functions: Dict[str, str],
ssz_objects_instantiation_spec = '\n\n'.join(ssz_objects.values())
ssz_objects_reinitialization_spec = (
'def init_SSZ_types() -> None:\n global_vars = globals()\n\n '
+ '\n\n '.join([re.sub(r'(?!\n\n)\n', r'\n ', value[:-1]) for value in ssz_objects.values()])
+ '\n\n '.join([strip_comments(re.sub(r'(?!\n\n)\n', r'\n ', value[:-1]))
for value in ssz_objects.values()])
+ '\n\n'
+ '\n'.join(map(lambda x: ' global_vars[\'%s\'] = %s' % (x, x), ssz_objects.keys()))
)
@ -183,17 +177,32 @@ def combine_constants(old_constants: Dict[str, str], new_constants: Dict[str, st
return old_constants
ignored_dependencies = [
'Bit', 'Bool', 'Vector', 'List', 'Container', 'Hash', 'BLSPubkey', 'BLSSignature', 'Bytes', 'BytesN'
'Bytes4', 'Bytes32', 'Bytes48', 'Bytes96',
'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'bytes' # to be removed after updating spec doc
]
def dependency_order_ssz_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None:
"""
Determines which SSZ Object is depenedent on which other and orders them appropriately
"""
items = list(objects.items())
for key, value in items:
dependencies = re.findall(r'(: [A-Z][\w[]*)', value)
dependencies = map(lambda x: re.sub(r'\W|Vector|List|Container|Hash|BLSPubkey|BLSSignature|uint\d+|Bytes\d+|bytes', '', x), dependencies)
dependencies = []
for line in value.split('\n'):
if not re.match(r'\s+\w+: .+', line):
continue # skip whitespace etc.
line = line[line.index(':') + 1:] # strip of field name
if '#' in line:
line = line[:line.index('#')] # strip of comment
dependencies.extend(re.findall(r'(\w+)', line)) # catch all legible words, potential dependencies
dependencies = filter(lambda x: '_' not in x and x.upper() != x, dependencies) # filter out constants
dependencies = filter(lambda x: x not in ignored_dependencies, dependencies)
dependencies = filter(lambda x: x not in custom_types, dependencies)
for dep in dependencies:
if dep in custom_types or len(dep) == 0:
continue
key_list = list(objects.keys())
for item in [dep, key] + key_list[key_list.index(dep)+1:]:
objects[item] = objects.pop(item)

View File

@ -234,6 +234,8 @@ The following values are (non-configurable) constants used throughout the specif
| - | - | :-: | :-: |
| `EPOCHS_PER_HISTORICAL_VECTOR` | `2**16` (= 65,536) | epochs | ~0.8 years |
| `EPOCHS_PER_SLASHED_BALANCES_VECTOR` | `2**13` (= 8,192) | epochs | ~36 days |
| `HISTORICAL_ROOTS_LIMIT` | `2**24` (= 16,777,216) | historical roots | ~26,131 years |
| `VALIDATOR_REGISTRY_LIMIT` | `2**40` (= 1,099,511,627,776) | validator spots | |
### Rewards and penalties
@ -295,7 +297,7 @@ class Validator(Container):
pubkey: BLSPubkey
withdrawal_credentials: Hash # Commitment to pubkey for withdrawals and transfers
effective_balance: Gwei # Balance at stake
slashed: bool
slashed: Bool
# Status epochs
activation_eligibility_epoch: Epoch # When criteria for activation were met
activation_epoch: Epoch
@ -335,15 +337,15 @@ class AttestationData(Container):
```python
class AttestationDataAndCustodyBit(Container):
data: AttestationData
custody_bit: bool # Challengeable bit for the custody of crosslink data
custody_bit: Bit # Challengeable bit (SSZ-bool, 1 byte) for the custody of crosslink data
```
#### `IndexedAttestation`
```python
class IndexedAttestation(Container):
custody_bit_0_indices: List[ValidatorIndex] # Indices with custody bit equal to 0
custody_bit_1_indices: List[ValidatorIndex] # Indices with custody bit equal to 1
custody_bit_0_indices: List[ValidatorIndex, MAX_INDICES_PER_ATTESTATION] # Indices with custody bit equal to 0
custody_bit_1_indices: List[ValidatorIndex, MAX_INDICES_PER_ATTESTATION] # Indices with custody bit equal to 1
data: AttestationData
signature: BLSSignature
```
@ -352,7 +354,7 @@ class IndexedAttestation(Container):
```python
class PendingAttestation(Container):
aggregation_bitfield: bytes # Bit set for every attesting participant within a committee
aggregation_bitfield: Bytes[MAX_INDICES_PER_ATTESTATION // 8]
data: AttestationData
inclusion_delay: Slot
proposer_index: ValidatorIndex
@ -419,9 +421,9 @@ class AttesterSlashing(Container):
```python
class Attestation(Container):
aggregation_bitfield: bytes
aggregation_bitfield: Bytes[MAX_INDICES_PER_ATTESTATION // 8]
data: AttestationData
custody_bitfield: bytes
custody_bitfield: Bytes[MAX_INDICES_PER_ATTESTATION // 8]
signature: BLSSignature
```
@ -465,12 +467,12 @@ class BeaconBlockBody(Container):
eth1_data: Eth1Data # Eth1 data vote
graffiti: Bytes32 # Arbitrary data
# Operations
proposer_slashings: List[ProposerSlashing]
attester_slashings: List[AttesterSlashing]
attestations: List[Attestation]
deposits: List[Deposit]
voluntary_exits: List[VoluntaryExit]
transfers: List[Transfer]
proposer_slashings: List[ProposerSlashing, MAX_PROPOSER_SLASHINGS]
attester_slashings: List[AttesterSlashing, MAX_ATTESTER_SLASHINGS]
attestations: List[Attestation, MAX_ATTESTATIONS]
deposits: List[Deposit, MAX_DEPOSITS]
voluntary_exits: List[VoluntaryExit, MAX_VOLUNTARY_EXITS]
transfers: List[Transfer, MAX_TRANSFERS]
```
#### `BeaconBlock`
@ -498,14 +500,14 @@ class BeaconState(Container):
latest_block_header: BeaconBlockHeader
block_roots: Vector[Hash, SLOTS_PER_HISTORICAL_ROOT]
state_roots: Vector[Hash, SLOTS_PER_HISTORICAL_ROOT]
historical_roots: List[Hash]
historical_roots: List[Hash, HISTORICAL_ROOTS_LIMIT]
# Eth1
eth1_data: Eth1Data
eth1_data_votes: List[Eth1Data]
eth1_data_votes: List[Eth1Data, SLOTS_PER_ETH1_VOTING_PERIOD]
eth1_deposit_index: uint64
# Registry
validators: List[Validator]
balances: List[Gwei]
validators: List[Validator, VALIDATOR_REGISTRY_LIMIT]
balances: List[Gwei, VALIDATOR_REGISTRY_LIMIT]
# Shuffling
start_shard: Shard
randao_mixes: Vector[Hash, EPOCHS_PER_HISTORICAL_VECTOR]
@ -513,8 +515,8 @@ class BeaconState(Container):
# Slashings
slashed_balances: Vector[Gwei, EPOCHS_PER_SLASHED_BALANCES_VECTOR] # Sums of slashed effective balances
# Attestations
previous_epoch_attestations: List[PendingAttestation]
current_epoch_attestations: List[PendingAttestation]
previous_epoch_attestations: List[PendingAttestation, MAX_ATTESTATIONS * SLOTS_PER_EPOCH]
current_epoch_attestations: List[PendingAttestation, MAX_ATTESTATIONS * SLOTS_PER_EPOCH]
# Crosslinks
previous_crosslinks: Vector[Crosslink, SHARD_COUNT] # Previous epoch snapshot
current_crosslinks: Vector[Crosslink, SHARD_COUNT]
@ -623,13 +625,13 @@ def is_slashable_validator(validator: Validator, epoch: Epoch) -> bool:
"""
Check if ``validator`` is slashable.
"""
return validator.slashed is False and (validator.activation_epoch <= epoch < validator.withdrawable_epoch)
return (not validator.slashed) and (validator.activation_epoch <= epoch < validator.withdrawable_epoch)
```
### `get_active_validator_indices`
```python
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> List[ValidatorIndex]:
def get_active_validator_indices(state: BeaconState, epoch: Epoch) -> Sequence[ValidatorIndex]:
"""
Get active validator indices at ``epoch``.
"""
@ -795,7 +797,7 @@ def get_beacon_proposer_index(state: BeaconState) -> ValidatorIndex:
### `verify_merkle_branch`
```python
def verify_merkle_branch(leaf: Hash, proof: List[Hash], depth: int, index: int, root: Hash) -> bool:
def verify_merkle_branch(leaf: Hash, proof: Sequence[Hash], depth: int, index: int, root: Hash) -> bool:
"""
Verify that the given ``leaf`` is on the merkle branch ``proof``
starting with the given ``root``.
@ -839,7 +841,8 @@ def get_shuffled_index(index: ValidatorIndex, index_count: int, seed: Hash) -> V
### `compute_committee`
```python
def compute_committee(indices: List[ValidatorIndex], seed: Hash, index: int, count: int) -> List[ValidatorIndex]:
def compute_committee(indices: Sequence[ValidatorIndex],
seed: Hash, index: int, count: int) -> Sequence[ValidatorIndex]:
start = (len(indices) * index) // count
end = (len(indices) * (index + 1)) // count
return [indices[get_shuffled_index(ValidatorIndex(i), len(indices), seed)] for i in range(start, end)]
@ -848,7 +851,7 @@ def compute_committee(indices: List[ValidatorIndex], seed: Hash, index: int, cou
### `get_crosslink_committee`
```python
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> List[ValidatorIndex]:
def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> Sequence[ValidatorIndex]:
return compute_committee(
indices=get_active_validator_indices(state, epoch),
seed=generate_seed(state, epoch),
@ -862,7 +865,7 @@ def get_crosslink_committee(state: BeaconState, epoch: Epoch, shard: Shard) -> L
```python
def get_attesting_indices(state: BeaconState,
attestation_data: AttestationData,
bitfield: bytes) -> List[ValidatorIndex]:
bitfield: bytes) -> Sequence[ValidatorIndex]:
"""
Return the sorted attesting indices corresponding to ``attestation_data`` and ``bitfield``.
"""
@ -888,7 +891,7 @@ def bytes_to_int(data: bytes) -> int:
### `get_total_balance`
```python
def get_total_balance(state: BeaconState, indices: List[ValidatorIndex]) -> Gwei:
def get_total_balance(state: BeaconState, indices: Set[ValidatorIndex]) -> Gwei:
"""
Return the combined effective balance of the ``indices``. (1 Gwei minimum to avoid divisions by zero.)
"""
@ -1114,7 +1117,7 @@ def slash_validator(state: BeaconState,
### Genesis trigger
Before genesis has been triggered and whenever the deposit contract emits a `Deposit` log, call the function `is_genesis_trigger(deposits: List[Deposit], timestamp: uint64) -> bool` where:
Before genesis has been triggered and whenever the deposit contract emits a `Deposit` log, call the function `is_genesis_trigger(deposits: Sequence[Deposit], timestamp: uint64) -> bool` where:
* `deposits` is the list of all deposits, ordered chronologically, up to and including the deposit triggering the latest `Deposit` log
* `timestamp` is the Unix timestamp in the Ethereum 1.0 block that emitted the latest `Deposit` log
@ -1131,7 +1134,7 @@ When `is_genesis_trigger(deposits, timestamp) is True` for the first time, let:
*Note*: The function `is_genesis_trigger` has yet to be agreed upon by the community, and can be updated as necessary. We define the following testing placeholder:
```python
def is_genesis_trigger(deposits: List[Deposit], timestamp: uint64) -> bool:
def is_genesis_trigger(deposits: Sequence[Deposit], timestamp: uint64) -> bool:
# Process deposits
state = BeaconState()
for deposit in deposits:
@ -1153,10 +1156,10 @@ def is_genesis_trigger(deposits: List[Deposit], timestamp: uint64) -> bool:
Let `genesis_state = get_genesis_beacon_state(genesis_deposits, genesis_time, genesis_eth1_data)`.
```python
def get_genesis_beacon_state(deposits: List[Deposit], genesis_time: int, genesis_eth1_data: Eth1Data) -> BeaconState:
def get_genesis_beacon_state(deposits: Sequence[Deposit], genesis_time: int, eth1_data: Eth1Data) -> BeaconState:
state = BeaconState(
genesis_time=genesis_time,
eth1_data=genesis_eth1_data,
eth1_data=eth1_data,
latest_block_header=BeaconBlockHeader(body_root=hash_tree_root(BeaconBlockBody())),
)
@ -1170,8 +1173,10 @@ def get_genesis_beacon_state(deposits: List[Deposit], genesis_time: int, genesis
validator.activation_eligibility_epoch = GENESIS_EPOCH
validator.activation_epoch = GENESIS_EPOCH
# Populate active_index_roots
genesis_active_index_root = hash_tree_root(get_active_validator_indices(state, GENESIS_EPOCH))
# Populate active_index_roots
genesis_active_index_root = hash_tree_root(
List[ValidatorIndex, VALIDATOR_REGISTRY_LIMIT](get_active_validator_indices(state, GENESIS_EPOCH))
)
for index in range(EPOCHS_PER_HISTORICAL_VECTOR):
state.active_index_roots[index] = genesis_active_index_root
@ -1246,17 +1251,17 @@ def process_epoch(state: BeaconState) -> None:
```python
def get_total_active_balance(state: BeaconState) -> Gwei:
return get_total_balance(state, get_active_validator_indices(state, get_current_epoch(state)))
return get_total_balance(state, set(get_active_validator_indices(state, get_current_epoch(state))))
```
```python
def get_matching_source_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
def get_matching_source_attestations(state: BeaconState, epoch: Epoch) -> Sequence[PendingAttestation]:
assert epoch in (get_current_epoch(state), get_previous_epoch(state))
return state.current_epoch_attestations if epoch == get_current_epoch(state) else state.previous_epoch_attestations
```
```python
def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> Sequence[PendingAttestation]:
return [
a for a in get_matching_source_attestations(state, epoch)
if a.data.target_root == get_block_root(state, epoch)
@ -1264,7 +1269,7 @@ def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[P
```
```python
def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> Sequence[PendingAttestation]:
return [
a for a in get_matching_source_attestations(state, epoch)
if a.data.beacon_block_root == get_block_root_at_slot(state, get_attestation_data_slot(state, a.data))
@ -1273,22 +1278,22 @@ def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[Pen
```python
def get_unslashed_attesting_indices(state: BeaconState,
attestations: List[PendingAttestation]) -> List[ValidatorIndex]:
attestations: Sequence[PendingAttestation]) -> Set[ValidatorIndex]:
output = set() # type: Set[ValidatorIndex]
for a in attestations:
output = output.union(get_attesting_indices(state, a.data, a.aggregation_bitfield))
return sorted(filter(lambda index: not state.validators[index].slashed, list(output)))
return set(filter(lambda index: not state.validators[index].slashed, list(output)))
```
```python
def get_attesting_balance(state: BeaconState, attestations: List[PendingAttestation]) -> Gwei:
def get_attesting_balance(state: BeaconState, attestations: Sequence[PendingAttestation]) -> Gwei:
return get_total_balance(state, get_unslashed_attesting_indices(state, attestations))
```
```python
def get_winning_crosslink_and_attesting_indices(state: BeaconState,
epoch: Epoch,
shard: Shard) -> Tuple[Crosslink, List[ValidatorIndex]]:
shard: Shard) -> Tuple[Crosslink, Set[ValidatorIndex]]:
attestations = [a for a in get_matching_source_attestations(state, epoch) if a.data.crosslink.shard == shard]
crosslinks = list(filter(
lambda c: hash_tree_root(state.current_crosslinks[shard]) in (c.parent_root, hash_tree_root(c)),
@ -1361,7 +1366,7 @@ def process_crosslinks(state: BeaconState) -> None:
for epoch in (get_previous_epoch(state), get_current_epoch(state)):
for offset in range(get_epoch_committee_count(state, epoch)):
shard = Shard((get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT)
crosslink_committee = get_crosslink_committee(state, epoch, shard)
crosslink_committee = set(get_crosslink_committee(state, epoch, shard))
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard)
if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
state.current_crosslinks[shard] = winning_crosslink
@ -1377,7 +1382,7 @@ def get_base_reward(state: BeaconState, index: ValidatorIndex) -> Gwei:
```
```python
def get_attestation_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
def get_attestation_deltas(state: BeaconState) -> Tuple[Sequence[Gwei], Sequence[Gwei]]:
previous_epoch = get_previous_epoch(state)
total_balance = get_total_active_balance(state)
rewards = [Gwei(0) for _ in range(len(state.validators))]
@ -1428,13 +1433,13 @@ def get_attestation_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
```
```python
def get_crosslink_deltas(state: BeaconState) -> Tuple[List[Gwei], List[Gwei]]:
rewards = [Gwei(0) for index in range(len(state.validators))]
penalties = [Gwei(0) for index in range(len(state.validators))]
def get_crosslink_deltas(state: BeaconState) -> Tuple[Sequence[Gwei], Sequence[Gwei]]:
rewards = [Gwei(0) for _ in range(len(state.validators))]
penalties = [Gwei(0) for _ in range(len(state.validators))]
epoch = get_previous_epoch(state)
for offset in range(get_epoch_committee_count(state, epoch)):
shard = Shard((get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT)
crosslink_committee = get_crosslink_committee(state, epoch, shard)
crosslink_committee = set(get_crosslink_committee(state, epoch, shard))
winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard)
attesting_balance = get_total_balance(state, attesting_indices)
committee_balance = get_total_balance(state, crosslink_committee)
@ -1528,7 +1533,9 @@ def process_final_updates(state: BeaconState) -> None:
# Set active index root
index_root_position = (next_epoch + ACTIVATION_EXIT_DELAY) % EPOCHS_PER_HISTORICAL_VECTOR
state.active_index_roots[index_root_position] = hash_tree_root(
get_active_validator_indices(state, Epoch(next_epoch + ACTIVATION_EXIT_DELAY))
List[ValidatorIndex, VALIDATOR_REGISTRY_LIMIT](
get_active_validator_indices(state, Epoch(next_epoch + ACTIVATION_EXIT_DELAY))
)
)
# Set total slashed balances
state.slashed_balances[next_epoch % EPOCHS_PER_SLASHED_BALANCES_VECTOR] = (
@ -1610,16 +1617,16 @@ def process_operations(state: BeaconState, body: BeaconBlockBody) -> None:
assert len(body.deposits) == min(MAX_DEPOSITS, state.eth1_data.deposit_count - state.eth1_deposit_index)
# Verify that there are no duplicate transfers
assert len(body.transfers) == len(set(body.transfers))
all_operations = [
(body.proposer_slashings, MAX_PROPOSER_SLASHINGS, process_proposer_slashing),
(body.attester_slashings, MAX_ATTESTER_SLASHINGS, process_attester_slashing),
(body.attestations, MAX_ATTESTATIONS, process_attestation),
(body.deposits, MAX_DEPOSITS, process_deposit),
(body.voluntary_exits, MAX_VOLUNTARY_EXITS, process_voluntary_exit),
(body.transfers, MAX_TRANSFERS, process_transfer),
] # type: List[Tuple[List[Container], int, Callable]]
for operations, max_operations, function in all_operations:
assert len(operations) <= max_operations
all_operations = (
(body.proposer_slashings, process_proposer_slashing),
(body.attester_slashings, process_attester_slashing),
(body.attestations, process_attestation),
(body.deposits, process_deposit),
(body.voluntary_exits, process_voluntary_exit),
(body.transfers, process_transfer),
) # type: Sequence[Tuple[List, Callable]]
for operations, function in all_operations:
for operation in operations:
function(state, operation)
```

View File

@ -113,6 +113,13 @@ This document details the beacon chain additions and changes in Phase 1 of Ether
| - | - |
| `DOMAIN_CUSTODY_BIT_CHALLENGE` | `6` |
### TODO PLACEHOLDER
| Name | Value |
| - | - |
| `PLACEHOLDER` | `2**32` |
## Data structures
### Custody objects
@ -134,7 +141,7 @@ class CustodyBitChallenge(Container):
attestation: Attestation
challenger_index: ValidatorIndex
responder_key: BLSSignature
chunk_bits: bytes
chunk_bits: Bytes[PLACEHOLDER]
signature: BLSSignature
```
@ -171,9 +178,9 @@ class CustodyBitChallengeRecord(Container):
class CustodyResponse(Container):
challenge_index: uint64
chunk_index: uint64
chunk: Vector[bytes, BYTES_PER_CUSTODY_CHUNK]
data_branch: List[Bytes32]
chunk_bits_branch: List[Bytes32]
chunk: Vector[Bytes[PLACEHOLDER], BYTES_PER_CUSTODY_CHUNK]
data_branch: List[Bytes32, PLACEHOLDER]
chunk_bits_branch: List[Bytes32, PLACEHOLDER]
chunk_bits_leaf: Bytes32
```
@ -226,24 +233,25 @@ class Validator(Container):
```python
class BeaconState(Container):
custody_chunk_challenge_records: List[CustodyChunkChallengeRecord]
custody_bit_challenge_records: List[CustodyBitChallengeRecord]
custody_chunk_challenge_records: List[CustodyChunkChallengeRecord, PLACEHOLDER]
custody_bit_challenge_records: List[CustodyBitChallengeRecord, PLACEHOLDER]
custody_challenge_index: uint64
# Future derived secrets already exposed; contains the indices of the exposed validator
# at RANDAO reveal period % EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS
exposed_derived_secrets: Vector[List[ValidatorIndex], EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS]
exposed_derived_secrets: Vector[List[ValidatorIndex, PLACEHOLDER],
EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS]
```
#### `BeaconBlockBody`
```python
class BeaconBlockBody(Container):
custody_chunk_challenges: List[CustodyChunkChallenge]
custody_bit_challenges: List[CustodyBitChallenge]
custody_responses: List[CustodyResponse]
custody_key_reveals: List[CustodyKeyReveal]
early_derived_secret_reveals: List[EarlyDerivedSecretReveal]
custody_chunk_challenges: List[CustodyChunkChallenge, PLACEHOLDER]
custody_bit_challenges: List[CustodyBitChallenge, PLACEHOLDER]
custody_responses: List[CustodyResponse, PLACEHOLDER]
custody_key_reveals: List[CustodyKeyReveal, PLACEHOLDER]
early_derived_secret_reveals: List[EarlyDerivedSecretReveal, PLACEHOLDER]
```
## Helpers
@ -310,7 +318,7 @@ def get_validators_custody_reveal_period(state: BeaconState,
### `replace_empty_or_append`
```python
def replace_empty_or_append(list: List[Any], new_element: Any) -> int:
def replace_empty_or_append(list: MutableSequence[Any], new_element: Any) -> int:
for i in range(len(list)):
if is_empty(list[i]):
list[i] = new_element
@ -394,12 +402,11 @@ def process_early_derived_secret_reveal(state: BeaconState,
"""
revealed_validator = state.validators[reveal.revealed_index]
masker = state.validators[reveal.masker_index]
derived_secret_location = reveal.epoch % EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS
assert reveal.epoch >= get_current_epoch(state) + RANDAO_PENALTY_EPOCHS
assert reveal.epoch < get_current_epoch(state) + EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS
assert revealed_validator.slashed is False
assert not revealed_validator.slashed
assert reveal.revealed_index not in state.exposed_derived_secrets[derived_secret_location]
# Verify signature correctness

View File

@ -69,13 +69,19 @@ This document describes the shard data layer and the shard fork choice rule in P
| `DOMAIN_SHARD_PROPOSER` | `128` |
| `DOMAIN_SHARD_ATTESTER` | `129` |
### TODO PLACEHOLDER
| Name | Value |
| - | - |
| `PLACEHOLDER` | `2**32` |
## Data structures
### `ShardBlockBody`
```python
class ShardBlockBody(Container):
data: Vector[bytes, BYTES_PER_SHARD_BLOCK_BODY]
data: Vector[Bytes[PLACEHOLDER], BYTES_PER_SHARD_BLOCK_BODY]
```
### `ShardAttestation`
@ -86,7 +92,7 @@ class ShardAttestation(Container):
slot: Slot
shard: Shard
shard_block_root: Bytes32
aggregation_bitfield: bytes
aggregation_bitfield: Bytes[PLACEHOLDER]
aggregate_signature: BLSSignature
```
@ -100,7 +106,7 @@ class ShardBlock(Container):
parent_root: Bytes32
data: ShardBlockBody
state_root: Bytes32
attestations: List[ShardAttestation]
attestations: List[ShardAttestation, PLACEHOLDER]
signature: BLSSignature
```
@ -114,7 +120,7 @@ class ShardBlockHeader(Container):
parent_root: Bytes32
body_root: Bytes32
state_root: Bytes32
attestations: List[ShardAttestation]
attestations: List[ShardAttestation, PLACEHOLDER]
signature: BLSSignature
```
@ -127,7 +133,7 @@ def get_period_committee(state: BeaconState,
epoch: Epoch,
shard: Shard,
index: int,
count: int) -> List[ValidatorIndex]:
count: int) -> Sequence[ValidatorIndex]:
"""
Return committee for a period. Used to construct persistent committees.
"""
@ -153,7 +159,7 @@ def get_switchover_epoch(state: BeaconState, epoch: Epoch, index: ValidatorIndex
```python
def get_persistent_committee(state: BeaconState,
shard: Shard,
slot: Slot) -> List[ValidatorIndex]:
slot: Slot) -> Sequence[ValidatorIndex]:
"""
Return the persistent committee for the given ``shard`` at the given ``slot``.
"""
@ -187,7 +193,7 @@ def get_shard_proposer_index(state: BeaconState,
shard: Shard,
slot: Slot) -> Optional[ValidatorIndex]:
# Randomly shift persistent committee
persistent_committee = get_persistent_committee(state, shard, slot)
persistent_committee = list(get_persistent_committee(state, shard, slot))
seed = hash(state.current_shuffling_seed + int_to_bytes(shard, length=8) + int_to_bytes(slot, length=8))
random_index = bytes_to_int(seed[0:8]) % len(persistent_committee)
persistent_committee = persistent_committee[random_index:] + persistent_committee[:random_index]
@ -242,13 +248,13 @@ def verify_shard_attestation_signature(state: BeaconState,
### `compute_crosslink_data_root`
```python
def compute_crosslink_data_root(blocks: List[ShardBlock]) -> Bytes32:
def compute_crosslink_data_root(blocks: Sequence[ShardBlock]) -> Bytes32:
def is_power_of_two(value: int) -> bool:
return (value > 0) and (value & (value - 1) == 0)
def pad_to_power_of_2(values: List[bytes]) -> List[bytes]:
def pad_to_power_of_2(values: MutableSequence[bytes]) -> Sequence[bytes]:
while not is_power_of_two(len(values)):
values += [b'\x00' * BYTES_PER_SHARD_BLOCK_BODY]
values.append(b'\x00' * BYTES_PER_SHARD_BLOCK_BODY)
return values
def hash_tree_root_of_bytes(data: bytes) -> bytes:
@ -258,6 +264,8 @@ def compute_crosslink_data_root(blocks: List[ShardBlock]) -> Bytes32:
return data + b'\x00' * (length - len(data))
return hash(
# TODO untested code.
# Need to either pass a typed list to hash-tree-root, or merkleize_chunks(values, pad_to=2**x)
hash_tree_root(pad_to_power_of_2([
hash_tree_root_of_bytes(
zpad(serialize(get_shard_header(block)), BYTES_PER_SHARD_BLOCK_BODY)
@ -281,9 +289,9 @@ Let:
* `candidate` be a candidate `ShardBlock` for which validity is to be determined by running `is_valid_shard_block`
```python
def is_valid_shard_block(beacon_blocks: List[BeaconBlock],
def is_valid_shard_block(beacon_blocks: Sequence[BeaconBlock],
beacon_state: BeaconState,
valid_shard_blocks: List[ShardBlock],
valid_shard_blocks: Sequence[ShardBlock],
candidate: ShardBlock) -> bool:
# Check if block is already determined valid
for _, block in enumerate(valid_shard_blocks):
@ -330,7 +338,7 @@ def is_valid_shard_block(beacon_blocks: List[BeaconBlock],
assert proposer_index is not None
assert bls_verify(
pubkey=beacon_state.validators[proposer_index].pubkey,
message_hash=signing_root(block),
message_hash=signing_root(candidate),
signature=candidate.signature,
domain=get_domain(beacon_state, DOMAIN_SHARD_PROPOSER, slot_to_epoch(candidate.slot)),
)
@ -347,7 +355,7 @@ Let:
* `candidate` be a candidate `ShardAttestation` for which validity is to be determined by running `is_valid_shard_attestation`
```python
def is_valid_shard_attestation(valid_shard_blocks: List[ShardBlock],
def is_valid_shard_attestation(valid_shard_blocks: Sequence[ShardBlock],
beacon_state: BeaconState,
candidate: ShardAttestation) -> bool:
# Check shard block
@ -372,17 +380,17 @@ Let:
* `shard` be a valid `Shard`
* `shard_blocks` be the `ShardBlock` list such that `shard_blocks[slot]` is the canonical `ShardBlock` for shard `shard` at slot `slot`
* `beacon_state` be the canonical `BeaconState`
* `valid_attestations` be the list of valid `Attestation`, recursively defined
* `valid_attestations` be the set of valid `Attestation` objects, recursively defined
* `candidate` be a candidate `Attestation` which is valid under Phase 0 rules, and for which validity is to be determined under Phase 1 rules by running `is_valid_beacon_attestation`
```python
def is_valid_beacon_attestation(shard: Shard,
shard_blocks: List[ShardBlock],
shard_blocks: Sequence[ShardBlock],
beacon_state: BeaconState,
valid_attestations: List[Attestation],
valid_attestations: Set[Attestation],
candidate: Attestation) -> bool:
# Check if attestation is already determined valid
for _, attestation in enumerate(valid_attestations):
for attestation in valid_attestations:
if candidate == attestation:
return True

View File

@ -21,8 +21,8 @@ MAX_LIST_LENGTH = 10
@to_dict
def create_test_case_contents(value, typ):
yield "value", encode.encode(value, typ)
def create_test_case_contents(value):
yield "value", encode.encode(value)
yield "serialized", '0x' + serialize(value).hex()
yield "root", '0x' + hash_tree_root(value).hex()
if hasattr(value, "signature"):
@ -32,7 +32,7 @@ def create_test_case_contents(value, typ):
@to_dict
def create_test_case(rng: Random, name: str, typ, mode: random_value.RandomizationMode, chaos: bool):
value = random_value.get_random_ssz_object(rng, typ, MAX_BYTES_LENGTH, MAX_LIST_LENGTH, mode, chaos)
yield name, create_test_case_contents(value, typ)
yield name, create_test_case_contents(value)
def get_spec_ssz_types():

View File

@ -1,39 +1,29 @@
from typing import Any
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_list_type,
is_vector_type, is_bytes_type, is_bytesn_type, is_container_type,
read_vector_elem_type, read_list_elem_type,
SSZType, SSZValue, uint, Container, Bytes, List, Bool,
Vector, BytesN
)
def decode(data, typ):
if is_uint_type(typ):
return data
elif is_bool_type(typ):
assert data in (True, False)
return data
elif is_list_type(typ):
elem_typ = read_list_elem_type(typ)
return [decode(element, elem_typ) for element in data]
elif is_vector_type(typ):
elem_typ = read_vector_elem_type(typ)
return Vector(decode(element, elem_typ) for element in data)
elif is_bytes_type(typ):
return bytes.fromhex(data[2:])
elif is_bytesn_type(typ):
return BytesN(bytes.fromhex(data[2:]))
elif is_container_type(typ):
def decode(data: Any, typ: SSZType) -> SSZValue:
if issubclass(typ, (uint, Bool)):
return typ(data)
elif issubclass(typ, (List, Vector)):
return typ(decode(element, typ.elem_type) for element in data)
elif issubclass(typ, (Bytes, BytesN)):
return typ(bytes.fromhex(data[2:]))
elif issubclass(typ, Container):
temp = {}
for field, subtype in typ.get_fields():
temp[field] = decode(data[field], subtype)
if field + "_hash_tree_root" in data:
assert(data[field + "_hash_tree_root"][2:] ==
hash_tree_root(temp[field], subtype).hex())
for field_name, field_type in typ.get_fields().items():
temp[field_name] = decode(data[field_name], field_type)
if field_name + "_hash_tree_root" in data:
assert (data[field_name + "_hash_tree_root"][2:] ==
hash_tree_root(temp[field_name]).hex())
ret = typ(**temp)
if "hash_tree_root" in data:
assert(data["hash_tree_root"][2:] ==
hash_tree_root(ret, typ).hex())
assert (data["hash_tree_root"][2:] ==
hash_tree_root(ret).hex())
return ret
else:
raise Exception(f"Type not recognized: data={data}, typ={typ}")

View File

@ -1,36 +1,29 @@
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_list_type, is_vector_type, is_container_type,
read_elem_type,
uint
SSZValue, uint, Container, Bool
)
def encode(value, typ, include_hash_tree_roots=False):
if is_uint_type(typ):
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
def encode(value: SSZValue, include_hash_tree_roots=False):
if isinstance(value, uint):
# Larger uints are boxed and the class declares their byte length
if issubclass(typ, uint) and typ.byte_len > 8:
return str(value)
return value
elif is_bool_type(typ):
assert value in (True, False)
return value
elif is_list_type(typ) or is_vector_type(typ):
elem_typ = read_elem_type(typ)
return [encode(element, elem_typ, include_hash_tree_roots) for element in value]
elif isinstance(typ, type) and issubclass(typ, bytes): # both bytes and BytesN
if value.type().byte_len > 8:
return str(int(value))
return int(value)
elif isinstance(value, Bool):
return value == 1
elif isinstance(value, list): # normal python lists, ssz-List, Vector
return [encode(element, include_hash_tree_roots) for element in value]
elif isinstance(value, bytes): # both bytes and BytesN
return '0x' + value.hex()
elif is_container_type(typ):
elif isinstance(value, Container):
ret = {}
for field, subtype in typ.get_fields():
field_value = getattr(value, field)
ret[field] = encode(field_value, subtype, include_hash_tree_roots)
for field_value, field_name in zip(value, value.get_fields().keys()):
ret[field_name] = encode(field_value, include_hash_tree_roots)
if include_hash_tree_roots:
ret[field + "_hash_tree_root"] = '0x' + hash_tree_root(field_value, subtype).hex()
ret[field_name + "_hash_tree_root"] = '0x' + hash_tree_root(field_value).hex()
if include_hash_tree_roots:
ret["hash_tree_root"] = '0x' + hash_tree_root(value, typ).hex()
ret["hash_tree_root"] = '0x' + hash_tree_root(value).hex()
return ret
else:
raise Exception(f"Type not recognized: value={value}, typ={typ}")
raise Exception(f"Type not recognized: value={value}, typ={value.type()}")

View File

@ -1,18 +1,13 @@
from random import Random
from typing import Any
from enum import Enum
from eth2spec.utils.ssz.ssz_impl import is_basic_type
from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_list_type,
is_vector_type, is_bytes_type, is_bytesn_type, is_container_type,
read_vector_elem_type, read_list_elem_type,
uint_byte_size
SSZType, SSZValue, BasicValue, BasicType, uint, Container, Bytes, List, Bool,
Vector, BytesN
)
# in bytes
UINT_SIZES = (1, 2, 4, 8, 16, 32)
UINT_BYTE_SIZES = (1, 2, 4, 8, 16, 32)
random_mode_names = ("random", "zero", "max", "nil", "one", "lengthy")
@ -39,11 +34,11 @@ class RandomizationMode(Enum):
def get_random_ssz_object(rng: Random,
typ: Any,
typ: SSZType,
max_bytes_length: int,
max_list_length: int,
mode: RandomizationMode,
chaos: bool) -> Any:
chaos: bool) -> SSZValue:
"""
Create an object for a given type, filled with random data.
:param rng: The random number generator to use.
@ -56,33 +51,31 @@ def get_random_ssz_object(rng: Random,
"""
if chaos:
mode = rng.choice(list(RandomizationMode))
if is_bytes_type(typ):
if issubclass(typ, Bytes):
# Bytes array
if mode == RandomizationMode.mode_nil_count:
return b''
return typ(b'')
elif mode == RandomizationMode.mode_max_count:
return get_random_bytes_list(rng, max_bytes_length)
return typ(get_random_bytes_list(rng, max_bytes_length))
elif mode == RandomizationMode.mode_one_count:
return get_random_bytes_list(rng, 1)
return typ(get_random_bytes_list(rng, 1))
elif mode == RandomizationMode.mode_zero:
return b'\x00'
return typ(b'\x00')
elif mode == RandomizationMode.mode_max:
return b'\xff'
return typ(b'\xff')
else:
return get_random_bytes_list(rng, rng.randint(0, max_bytes_length))
elif is_bytesn_type(typ):
# BytesN
length = typ.length
return typ(get_random_bytes_list(rng, rng.randint(0, max_bytes_length)))
elif issubclass(typ, BytesN):
# Sanity, don't generate absurdly big random values
# If a client is aiming to performance-test, they should create a benchmark suite.
assert length <= max_bytes_length
assert typ.length <= max_bytes_length
if mode == RandomizationMode.mode_zero:
return b'\x00' * length
return typ(b'\x00' * typ.length)
elif mode == RandomizationMode.mode_max:
return b'\xff' * length
return typ(b'\xff' * typ.length)
else:
return get_random_bytes_list(rng, length)
elif is_basic_type(typ):
return typ(get_random_bytes_list(rng, typ.length))
elif issubclass(typ, BasicValue):
# Basic types
if mode == RandomizationMode.mode_zero:
return get_min_basic_value(typ)
@ -90,32 +83,31 @@ def get_random_ssz_object(rng: Random,
return get_max_basic_value(typ)
else:
return get_random_basic_value(rng, typ)
elif is_vector_type(typ):
# Vector
elem_typ = read_vector_elem_type(typ)
return [
get_random_ssz_object(rng, elem_typ, max_bytes_length, max_list_length, mode, chaos)
elif issubclass(typ, Vector):
return typ(
get_random_ssz_object(rng, typ.elem_type, max_bytes_length, max_list_length, mode, chaos)
for _ in range(typ.length)
]
elif is_list_type(typ):
# List
elem_typ = read_list_elem_type(typ)
length = rng.randint(0, max_list_length)
)
elif issubclass(typ, List):
length = rng.randint(0, min(typ.length, max_list_length))
if mode == RandomizationMode.mode_one_count:
length = 1
elif mode == RandomizationMode.mode_max_count:
length = max_list_length
return [
get_random_ssz_object(rng, elem_typ, max_bytes_length, max_list_length, mode, chaos)
if typ.length < length: # SSZ imposes a hard limit on lists, we can't put in more than that
length = typ.length
return typ(
get_random_ssz_object(rng, typ.elem_type, max_bytes_length, max_list_length, mode, chaos)
for _ in range(length)
]
elif is_container_type(typ):
)
elif issubclass(typ, Container):
# Container
return typ(**{
field:
get_random_ssz_object(rng, subtype, max_bytes_length, max_list_length, mode, chaos)
for field, subtype in typ.get_fields()
field_name:
get_random_ssz_object(rng, field_type, max_bytes_length, max_list_length, mode, chaos)
for field_name, field_type in typ.get_fields().items()
})
else:
raise Exception(f"Type not recognized: typ={typ}")
@ -125,34 +117,31 @@ def get_random_bytes_list(rng: Random, length: int) -> bytes:
return bytes(rng.getrandbits(8) for _ in range(length))
def get_random_basic_value(rng: Random, typ) -> Any:
if is_bool_type(typ):
return rng.choice((True, False))
elif is_uint_type(typ):
size = uint_byte_size(typ)
assert size in UINT_SIZES
return rng.randint(0, 256**size - 1)
def get_random_basic_value(rng: Random, typ: BasicType) -> BasicValue:
if issubclass(typ, Bool):
return typ(rng.choice((True, False)))
elif issubclass(typ, uint):
assert typ.byte_len in UINT_BYTE_SIZES
return typ(rng.randint(0, 256 ** typ.byte_len - 1))
else:
raise ValueError(f"Not a basic type: typ={typ}")
def get_min_basic_value(typ) -> Any:
if is_bool_type(typ):
return False
elif is_uint_type(typ):
size = uint_byte_size(typ)
assert size in UINT_SIZES
return 0
def get_min_basic_value(typ: BasicType) -> BasicValue:
if issubclass(typ, Bool):
return typ(False)
elif issubclass(typ, uint):
assert typ.byte_len in UINT_BYTE_SIZES
return typ(0)
else:
raise ValueError(f"Not a basic type: typ={typ}")
def get_max_basic_value(typ) -> Any:
if is_bool_type(typ):
return True
elif is_uint_type(typ):
size = uint_byte_size(typ)
assert size in UINT_SIZES
return 256**size - 1
def get_max_basic_value(typ: BasicType) -> BasicValue:
if issubclass(typ, Bool):
return typ(True)
elif issubclass(typ, uint):
assert typ.byte_len in UINT_BYTE_SIZES
return typ(256 ** typ.byte_len - 1)
else:
raise ValueError(f"Not a basic type: typ={typ}")

View File

@ -8,32 +8,31 @@ def translate_typ(typ) -> ssz.BaseSedes:
:param typ: The spec type, a class.
:return: The Py-SSZ equivalent.
"""
if spec_ssz.is_container_type(typ):
if issubclass(typ, spec_ssz.Container):
return ssz.Container(
[translate_typ(field_typ) for (field_name, field_typ) in typ.get_fields()])
elif spec_ssz.is_bytesn_type(typ):
[translate_typ(field_typ) for field_name, field_typ in typ.get_fields().items()])
elif issubclass(typ, spec_ssz.BytesN):
return ssz.ByteVector(typ.length)
elif spec_ssz.is_bytes_type(typ):
elif issubclass(typ, spec_ssz.Bytes):
return ssz.ByteList()
elif spec_ssz.is_vector_type(typ):
return ssz.Vector(translate_typ(spec_ssz.read_vector_elem_type(typ)), typ.length)
elif spec_ssz.is_list_type(typ):
return ssz.List(translate_typ(spec_ssz.read_list_elem_type(typ)))
elif spec_ssz.is_bool_type(typ):
elif issubclass(typ, spec_ssz.Vector):
return ssz.Vector(translate_typ(typ.elem_type), typ.length)
elif issubclass(typ, spec_ssz.List):
return ssz.List(translate_typ(typ.elem_type))
elif issubclass(typ, spec_ssz.Bool):
return ssz.boolean
elif spec_ssz.is_uint_type(typ):
size = spec_ssz.uint_byte_size(typ)
if size == 1:
elif issubclass(typ, spec_ssz.uint):
if typ.byte_len == 1:
return ssz.uint8
elif size == 2:
elif typ.byte_len == 2:
return ssz.uint16
elif size == 4:
elif typ.byte_len == 4:
return ssz.uint32
elif size == 8:
elif typ.byte_len == 8:
return ssz.uint64
elif size == 16:
elif typ.byte_len == 16:
return ssz.uint128
elif size == 32:
elif typ.byte_len == 32:
return ssz.uint256
else:
raise TypeError("invalid uint size")
@ -48,37 +47,33 @@ def translate_value(value, typ):
:param typ: The type from the spec to translate into
:return: the translated value
"""
if spec_ssz.is_uint_type(typ):
size = spec_ssz.uint_byte_size(typ)
if size == 1:
if issubclass(typ, spec_ssz.uint):
if typ.byte_len == 1:
return spec_ssz.uint8(value)
elif size == 2:
elif typ.byte_len == 2:
return spec_ssz.uint16(value)
elif size == 4:
elif typ.byte_len == 4:
return spec_ssz.uint32(value)
elif size == 8:
# uint64 is default (TODO this is changing soon)
return value
elif size == 16:
elif typ.byte_len == 8:
return spec_ssz.uint64(value)
elif typ.byte_len == 16:
return spec_ssz.uint128(value)
elif size == 32:
elif typ.byte_len == 32:
return spec_ssz.uint256(value)
else:
raise TypeError("invalid uint size")
elif spec_ssz.is_list_type(typ):
elem_typ = spec_ssz.read_elem_type(typ)
return [translate_value(elem, elem_typ) for elem in value]
elif spec_ssz.is_bool_type(typ):
elif issubclass(typ, spec_ssz.List):
return [translate_value(elem, typ.elem_type) for elem in value]
elif issubclass(typ, spec_ssz.Bool):
return value
elif spec_ssz.is_vector_type(typ):
elem_typ = spec_ssz.read_elem_type(typ)
return typ(*(translate_value(elem, elem_typ) for elem in value))
elif spec_ssz.is_bytesn_type(typ):
elif issubclass(typ, spec_ssz.Vector):
return typ(*(translate_value(elem, typ.elem_type) for elem in value))
elif issubclass(typ, spec_ssz.BytesN):
return typ(value)
elif spec_ssz.is_bytes_type(typ):
elif issubclass(typ, spec_ssz.Bytes):
return value
elif spec_ssz.is_container_type(typ):
return typ(**{f_name: translate_value(f_val, f_typ) for (f_name, f_val, f_typ)
in zip(typ.get_field_names(), value, typ.get_field_types())})
if issubclass(typ, spec_ssz.Container):
return typ(**{f_name: translate_value(f_val, f_typ) for (f_val, (f_name, f_typ))
in zip(value, typ.get_fields().items())})
else:
raise TypeError("Type not supported: {}".format(typ))

View File

@ -11,7 +11,7 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
epoch = current_epoch + spec.CUSTODY_PERIOD_TO_RANDAO_PADDING
reveal = bls_sign(
message_hash=spec.hash_tree_root(epoch),
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[revealed_index],
domain=spec.get_domain(
state=state,
@ -20,14 +20,14 @@ def get_valid_early_derived_secret_reveal(spec, state, epoch=None):
),
)
mask = bls_sign(
message_hash=spec.hash_tree_root(epoch),
message_hash=spec.hash_tree_root(spec.Epoch(epoch)),
privkey=privkeys[masker_index],
domain=spec.get_domain(
state=state,
domain_type=spec.DOMAIN_RANDAO,
message_epoch=epoch,
),
)
)[:32] # TODO(Carl): mask is 32 bytes, and signature is 96? Correct to slice the first 32 out?
return spec.EarlyDerivedSecretReveal(
revealed_index=revealed_index,

View File

@ -1,5 +1,6 @@
from eth2spec.test.helpers.keys import pubkeys
from eth2spec.utils.ssz.ssz_impl import hash_tree_root
from eth2spec.utils.ssz.ssz_typing import List
def build_mock_validator(spec, i: int, balance: int):
@ -40,7 +41,8 @@ def create_genesis_state(spec, num_validators):
validator.activation_eligibility_epoch = spec.GENESIS_EPOCH
validator.activation_epoch = spec.GENESIS_EPOCH
genesis_active_index_root = hash_tree_root(spec.get_active_validator_indices(state, spec.GENESIS_EPOCH))
genesis_active_index_root = hash_tree_root(List[spec.ValidatorIndex, spec.VALIDATOR_REGISTRY_LIMIT](
spec.get_active_validator_indices(state, spec.GENESIS_EPOCH)))
for index in range(spec.EPOCHS_PER_HISTORICAL_VECTOR):
state.active_index_roots[index] = genesis_active_index_root

View File

@ -114,21 +114,6 @@ def test_invalid_withdrawal_credentials_top_up(spec, state):
yield from run_deposit_processing(spec, state, deposit, validator_index, valid=True, effective=True)
@with_all_phases
@spec_state_test
def test_wrong_index(spec, state):
validator_index = len(state.validators)
amount = spec.MAX_EFFECTIVE_BALANCE
deposit = prepare_state_and_deposit(spec, state, validator_index, amount)
# mess up eth1_deposit_index
deposit.index = state.eth1_deposit_index + 1
sign_deposit_data(spec, state, deposit.data, privkeys[validator_index])
yield from run_deposit_processing(spec, state, deposit, validator_index, valid=False)
@with_all_phases
@spec_state_test
def test_wrong_deposit_for_deposit_count(spec, state):
@ -172,9 +157,6 @@ def test_wrong_deposit_for_deposit_count(spec, state):
yield from run_deposit_processing(spec, state, deposit_2, index_2, valid=False)
# TODO: test invalid signature
@with_all_phases
@spec_state_test
def test_bad_merkle_proof(spec, state):
@ -183,7 +165,7 @@ def test_bad_merkle_proof(spec, state):
deposit = prepare_state_and_deposit(spec, state, validator_index, amount)
# mess up merkle branch
deposit.proof[-1] = spec.ZERO_HASH
deposit.proof[5] = spec.ZERO_HASH
sign_deposit_data(spec, state, deposit.data, privkeys[validator_index])

View File

@ -1,5 +1,4 @@
from copy import deepcopy
from typing import List
from eth2spec.utils.ssz.ssz_impl import signing_root
from eth2spec.utils.bls import bls_sign
@ -28,7 +27,7 @@ def test_empty_block_transition(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert len(state.eth1_data_votes) == pre_eth1_votes + 1
@ -48,7 +47,7 @@ def test_skipped_slots(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert state.slot == block.slot
@ -69,7 +68,7 @@ def test_empty_epoch_transition(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert state.slot == block.slot
@ -90,7 +89,7 @@ def test_empty_epoch_transition(spec, state):
# state_transition_and_sign_block(spec, state, block)
# yield 'blocks', [block], List[spec.BeaconBlock]
# yield 'blocks', [block]
# yield 'post', state
# assert state.slot == block.slot
@ -120,7 +119,7 @@ def test_proposer_slashing(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
# check if slashed
@ -155,7 +154,7 @@ def test_attester_slashing(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
slashed_validator = state.validators[validator_index]
@ -193,7 +192,7 @@ def test_deposit_in_block(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert len(state.validators) == initial_registry_len + 1
@ -221,7 +220,7 @@ def test_deposit_top_up(spec, state):
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert len(state.validators) == initial_registry_len
@ -256,7 +255,7 @@ def test_attestation(spec, state):
sign_block(spec, state, epoch_block)
state_transition_and_sign_block(spec, state, epoch_block)
yield 'blocks', [attestation_block, epoch_block], List[spec.BeaconBlock]
yield 'blocks', [attestation_block, epoch_block]
yield 'post', state
assert len(state.current_epoch_attestations) == 0
@ -303,7 +302,7 @@ def test_voluntary_exit(spec, state):
sign_block(spec, state, exit_block)
state_transition_and_sign_block(spec, state, exit_block)
yield 'blocks', [initiate_exit_block, exit_block], List[spec.BeaconBlock]
yield 'blocks', [initiate_exit_block, exit_block]
yield 'post', state
assert state.validators[validator_index].exit_epoch < spec.FAR_FUTURE_EPOCH
@ -334,7 +333,7 @@ def test_voluntary_exit(spec, state):
# state_transition_and_sign_block(spec, state, block)
# yield 'blocks', [block], List[spec.BeaconBlock]
# yield 'blocks', [block]
# yield 'post', state
# sender_balance = get_balance(state, sender_index)
@ -362,7 +361,7 @@ def test_balance_driven_status_transitions(spec, state):
sign_block(spec, state, block)
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert state.validators[validator_index].exit_epoch < spec.FAR_FUTURE_EPOCH
@ -379,7 +378,7 @@ def test_historical_batch(spec, state):
block = build_empty_block_for_next_slot(spec, state, signed=True)
state_transition_and_sign_block(spec, state, block)
yield 'blocks', [block], List[spec.BeaconBlock]
yield 'blocks', [block]
yield 'post', state
assert state.slot == block.slot
@ -408,7 +407,7 @@ def test_historical_batch(spec, state):
# state_transition_and_sign_block(spec, state, block)
# yield 'blocks', [block], List[spec.BeaconBlock]
# yield 'blocks', [block]
# yield 'post', state
# assert state.slot % spec.SLOTS_PER_ETH1_VOTING_PERIOD == 0

View File

@ -1,5 +1,4 @@
from copy import deepcopy
from typing import List
from eth2spec.test.context import spec_state_test, never_bls, with_all_phases
from eth2spec.test.helpers.state import next_epoch, state_transition_and_sign_block
@ -39,11 +38,13 @@ def next_epoch_with_attestations(spec,
state,
fill_cur_epoch,
fill_prev_epoch):
assert state.slot % spec.SLOTS_PER_EPOCH == 0
post_state = deepcopy(state)
blocks = []
for _ in range(spec.SLOTS_PER_EPOCH):
block = build_empty_block_for_next_slot(spec, post_state)
if fill_cur_epoch:
if fill_cur_epoch and post_state.slot >= spec.MIN_ATTESTATION_INCLUSION_DELAY:
slot_to_attest = post_state.slot - spec.MIN_ATTESTATION_INCLUSION_DELAY + 1
if slot_to_attest >= spec.get_epoch_start_slot(spec.get_current_epoch(post_state)):
cur_attestation = get_valid_attestation(spec, post_state, slot_to_attest)
@ -63,11 +64,13 @@ def next_epoch_with_attestations(spec,
@with_all_phases
@never_bls
@spec_state_test
def test_finality_rule_4(spec, state):
def test_finality_no_updates_at_genesis(spec, state):
assert spec.get_current_epoch(state) == spec.GENESIS_EPOCH
yield 'pre', state
blocks = []
for epoch in range(4):
for epoch in range(2):
prev_state, new_blocks, state = next_epoch_with_attestations(spec, state, True, False)
blocks += new_blocks
@ -77,15 +80,37 @@ def test_finality_rule_4(spec, state):
# justification/finalization skipped at GENESIS_EPOCH + 1
elif epoch == 1:
check_finality(spec, state, prev_state, False, False, False)
elif epoch == 2:
yield 'blocks', blocks
yield 'post', state
@with_all_phases
@never_bls
@spec_state_test
def test_finality_rule_4(spec, state):
# get past first two epochs that finality does not run on
next_epoch(spec, state)
apply_empty_block(spec, state)
next_epoch(spec, state)
apply_empty_block(spec, state)
yield 'pre', state
blocks = []
for epoch in range(2):
prev_state, new_blocks, state = next_epoch_with_attestations(spec, state, True, False)
blocks += new_blocks
if epoch == 0:
check_finality(spec, state, prev_state, True, False, False)
elif epoch >= 3:
elif epoch == 1:
# rule 4 of finality
check_finality(spec, state, prev_state, True, True, True)
assert state.finalized_epoch == prev_state.current_justified_epoch
assert state.finalized_root == prev_state.current_justified_root
yield 'blocks', blocks, List[spec.BeaconBlock]
yield 'blocks', blocks
yield 'post', state
@ -116,7 +141,7 @@ def test_finality_rule_1(spec, state):
assert state.finalized_epoch == prev_state.previous_justified_epoch
assert state.finalized_root == prev_state.previous_justified_root
yield 'blocks', blocks, List[spec.BeaconBlock]
yield 'blocks', blocks
yield 'post', state
@ -149,7 +174,7 @@ def test_finality_rule_2(spec, state):
blocks += new_blocks
yield 'blocks', blocks, List[spec.BeaconBlock]
yield 'blocks', blocks
yield 'post', state
@ -199,5 +224,5 @@ def test_finality_rule_3(spec, state):
assert state.finalized_epoch == prev_state.current_justified_epoch
assert state.finalized_root == prev_state.current_justified_root
yield 'blocks', blocks, List[spec.BeaconBlock]
yield 'blocks', blocks
yield 'post', state

View File

@ -4,7 +4,7 @@ from .hash_function import hash
ZERO_BYTES32 = b'\x00' * 32
zerohashes = [ZERO_BYTES32]
for layer in range(1, 32):
for layer in range(1, 100):
zerohashes.append(hash(zerohashes[layer - 1] + zerohashes[layer - 1]))
@ -44,11 +44,35 @@ def next_power_of_two(v: int) -> int:
return 1 << (v - 1).bit_length()
def merkleize_chunks(chunks):
tree = chunks[::]
margin = next_power_of_two(len(chunks)) - len(chunks)
tree.extend([ZERO_BYTES32] * margin)
tree = [ZERO_BYTES32] * len(tree) + tree
for i in range(len(tree) // 2 - 1, 0, -1):
tree[i] = hash(tree[i * 2] + tree[i * 2 + 1])
return tree[1]
def merkleize_chunks(chunks, pad_to: int = 1):
count = len(chunks)
depth = max(count - 1, 0).bit_length()
max_depth = max(depth, (pad_to - 1).bit_length())
tmp = [None for _ in range(max_depth + 1)]
def merge(h, i):
j = 0
while True:
if i & (1 << j) == 0:
if i == count and j < depth:
h = hash(h + zerohashes[j]) # keep going if we are complementing the void to the next power of 2
else:
break
else:
h = hash(tmp[j] + h)
j += 1
tmp[j] = h
# merge in leaf by leaf.
for i in range(count):
merge(chunks[i], i)
# complement with 0 if empty, or if not the right power of 2
if 1 << depth != count:
merge(zerohashes[0], count)
# the next power of two may be smaller than the ultimate virtual size, complement with zero-hashes at each depth.
for j in range(depth, max_depth):
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
return tmp[max_depth]

View File

@ -1,11 +1,7 @@
from ..merkle_minimal import merkleize_chunks, hash
from eth2spec.utils.ssz.ssz_typing import (
is_uint_type, is_bool_type, is_container_type,
is_list_kind, is_vector_kind,
read_vector_elem_type, read_elem_type,
uint_byte_size,
infer_input_type,
get_zero_value,
from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bool, Container, List, Bytes, uint,
)
# SSZ Serialization
@ -14,68 +10,47 @@ from eth2spec.utils.ssz.ssz_typing import (
BYTES_PER_LENGTH_OFFSET = 4
def is_basic_type(typ):
return is_uint_type(typ) or is_bool_type(typ)
def serialize_basic(value, typ):
if is_uint_type(typ):
return value.to_bytes(uint_byte_size(typ), 'little')
elif is_bool_type(typ):
def serialize_basic(value: SSZValue):
if isinstance(value, uint):
return value.to_bytes(value.type().byte_len, 'little')
elif isinstance(value, Bool):
if value:
return b'\x01'
else:
return b'\x00'
else:
raise Exception("Type not supported: {}".format(typ))
raise Exception(f"Type not supported: {type(value)}")
def deserialize_basic(value, typ):
if is_uint_type(typ):
def deserialize_basic(value, typ: BasicType):
if issubclass(typ, uint):
return typ(int.from_bytes(value, 'little'))
elif is_bool_type(typ):
elif issubclass(typ, Bool):
assert value in (b'\x00', b'\x01')
return True if value == b'\x01' else False
return typ(value == b'\x01')
else:
raise Exception("Type not supported: {}".format(typ))
raise Exception(f"Type not supported: {typ}")
def is_fixed_size(typ):
if is_basic_type(typ):
return True
elif is_list_kind(typ):
return False
elif is_vector_kind(typ):
return is_fixed_size(read_vector_elem_type(typ))
elif is_container_type(typ):
return all(is_fixed_size(t) for t in typ.get_field_types())
def is_empty(obj: SSZValue):
return type(obj).default() == obj
def serialize(obj: SSZValue):
if isinstance(obj, BasicValue):
return serialize_basic(obj)
elif isinstance(obj, Series):
return encode_series(obj)
else:
raise Exception("Type not supported: {}".format(typ))
raise Exception(f"Type not supported: {type(obj)}")
def is_empty(obj):
return get_zero_value(type(obj)) == obj
@infer_input_type
def serialize(obj, typ=None):
if is_basic_type(typ):
return serialize_basic(obj, typ)
elif is_list_kind(typ) or is_vector_kind(typ):
return encode_series(obj, [read_elem_type(typ)] * len(obj))
elif is_container_type(typ):
return encode_series(obj.get_field_values(), typ.get_field_types())
else:
raise Exception("Type not supported: {}".format(typ))
def encode_series(values, types):
# bytes and bytesN are already in the right format.
if isinstance(values, bytes):
def encode_series(values: Series):
if isinstance(values, bytes): # Bytes and BytesN are already like serialized output
return values
# Recursively serialize
parts = [(is_fixed_size(types[i]), serialize(values[i], typ=types[i])) for i in range(len(values))]
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
# Compute and check lengths
fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET
@ -107,10 +82,10 @@ def encode_series(values, types):
# -----------------------------
def pack(values, subtype):
if isinstance(values, bytes):
def pack(values: Series):
if isinstance(values, bytes): # Bytes and BytesN are already packed
return values
return b''.join([serialize_basic(value, subtype) for value in values])
return b''.join([serialize_basic(value) for value in values])
def chunkify(bytez):
@ -123,41 +98,50 @@ def mix_in_length(root, length):
return hash(root + length.to_bytes(32, 'little'))
def is_bottom_layer_kind(typ):
def is_bottom_layer_kind(typ: SSZType):
return (
is_basic_type(typ) or
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ))
isinstance(typ, BasicType) or
(issubclass(typ, Elements) and isinstance(typ.elem_type, BasicType))
)
@infer_input_type
def get_typed_values(obj, typ=None):
if is_container_type(typ):
return obj.get_typed_values()
elif is_list_kind(typ) or is_vector_kind(typ):
elem_type = read_elem_type(typ)
return list(zip(obj, [elem_type] * len(obj)))
def item_length(typ: SSZType) -> int:
if issubclass(typ, BasicValue):
return typ.byte_len
else:
raise Exception("Invalid type")
return 32
@infer_input_type
def hash_tree_root(obj, typ=None):
if is_bottom_layer_kind(typ):
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ))
leaves = chunkify(data)
def chunk_count(typ: SSZType) -> int:
if isinstance(typ, BasicType):
return 1
elif issubclass(typ, Elements):
return (typ.length * item_length(typ.elem_type) + 31) // 32
elif issubclass(typ, Container):
return len(typ.get_fields())
else:
fields = get_typed_values(obj, typ=typ)
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
if is_list_kind(typ):
return mix_in_length(merkleize_chunks(leaves), len(obj))
raise Exception(f"Type not supported: {typ}")
def hash_tree_root(obj: SSZValue):
if isinstance(obj, Series):
if is_bottom_layer_kind(obj.type()):
leaves = chunkify(pack(obj))
else:
leaves = [hash_tree_root(value) for value in obj]
elif isinstance(obj, BasicValue):
leaves = chunkify(serialize_basic(obj))
else:
raise Exception(f"Type not supported: {type(obj)}")
if isinstance(obj, (List, Bytes)):
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
else:
return merkleize_chunks(leaves)
@infer_input_type
def signing_root(obj, typ):
assert is_container_type(typ)
def signing_root(obj: Container):
# ignore last field
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]]
fields = [field for field in obj][:-1]
leaves = [hash_tree_root(f) for f in fields]
return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -1,149 +1,183 @@
from typing import Dict, Iterator
import copy
from types import GeneratorType
from typing import (
List,
Iterable,
TypeVar,
Type,
NewType,
Union,
)
from typing_inspect import get_origin
# SSZ integers
# -----------------------------
class uint(int):
class DefaultingTypeMeta(type):
def default(cls):
raise Exception("Not implemented")
class SSZType(DefaultingTypeMeta):
def is_fixed_size(cls):
raise Exception("Not implemented")
class SSZValue(object, metaclass=SSZType):
def type(self):
return self.__class__
class BasicType(SSZType):
byte_len = 0
def is_fixed_size(cls):
return True
class BasicValue(int, SSZValue, metaclass=BasicType):
pass
class Bool(BasicValue): # can't subclass bool.
byte_len = 1
def __new__(cls, value, *args, **kwargs):
if value < 0 or value > 1:
raise ValueError(f"value {value} out of bounds for bit")
return super().__new__(cls, value)
@classmethod
def default(cls):
return cls(0)
def __bool__(self):
return self > 0
# Alias for Bool
class Bit(Bool):
pass
class uint(BasicValue, metaclass=BasicType):
def __new__(cls, value, *args, **kwargs):
if value < 0:
raise ValueError("unsigned types must not be negative")
if cls.byte_len and value.bit_length() > (cls.byte_len << 3):
raise ValueError("value out of bounds for uint{}".format(cls.byte_len * 8))
return super().__new__(cls, value)
def __add__(self, other):
return self.__class__(super().__add__(coerce_type_maybe(other, self.__class__, strict=True)))
def __sub__(self, other):
return self.__class__(super().__sub__(coerce_type_maybe(other, self.__class__, strict=True)))
@classmethod
def default(cls):
return cls(0)
class uint8(uint):
byte_len = 1
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 8:
raise ValueError("value out of bounds for uint8")
return super().__new__(cls, value)
# Alias for uint8
byte = NewType('byte', uint8)
class byte(uint8):
pass
class uint16(uint):
byte_len = 2
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 16:
raise ValueError("value out of bounds for uint16")
return super().__new__(cls, value)
class uint32(uint):
byte_len = 4
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 32:
raise ValueError("value out of bounds for uint16")
return super().__new__(cls, value)
class uint64(uint):
byte_len = 8
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 64:
raise ValueError("value out of bounds for uint64")
return super().__new__(cls, value)
class uint128(uint):
byte_len = 16
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 128:
raise ValueError("value out of bounds for uint128")
return super().__new__(cls, value)
class uint256(uint):
byte_len = 32
def __new__(cls, value, *args, **kwargs):
if value.bit_length() > 256:
raise ValueError("value out of bounds for uint256")
return super().__new__(cls, value)
def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
v_typ = type(v)
# shortcut if it's already the type we are looking for
if v_typ == typ:
return v
elif isinstance(v, int):
if isinstance(v, uint): # do not coerce from one uintX to another uintY
if issubclass(typ, uint) and v.type().byte_len == typ.byte_len:
return typ(v)
# revert to default behavior below if-else. (ValueError/bare)
else:
return typ(v)
elif isinstance(v, (list, tuple)):
return typ(*v)
elif isinstance(v, (bytes, BytesN, Bytes)):
return typ(v)
elif isinstance(v, GeneratorType):
return typ(v)
# just return as-is, Value-checkers will take care of it not being coerced, if we are not strict.
if strict and not isinstance(v, typ):
raise ValueError("Type coercion of {} to {} failed".format(v, typ))
return v
def is_uint_type(typ):
# All integers are uint in the scope of the spec here.
# Since we default to uint64. Bounds can be checked elsewhere.
# However, some are wrapped in a NewType
if hasattr(typ, '__supertype__'):
# get the type that the NewType is wrapping
typ = typ.__supertype__
class Series(SSZValue):
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
def __iter__(self) -> Iterator[SSZValue]:
raise Exception("Not implemented")
def uint_byte_size(typ):
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
if isinstance(typ, type):
if issubclass(typ, uint):
return typ.byte_len
elif issubclass(typ, int):
# Default to uint64
return 8
else:
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
# SSZ Container base class
# -----------------------------
# Note: importing ssz functionality locally, to avoid import loop
class Container(object):
class Container(Series, metaclass=SSZType):
def __init__(self, **kwargs):
cls = self.__class__
for f, t in cls.get_fields():
for f, t in cls.get_fields().items():
if f not in kwargs:
setattr(self, f, get_zero_value(t))
setattr(self, f, t.default())
else:
setattr(self, f, kwargs[f])
value = coerce_type_maybe(kwargs[f], t)
if not isinstance(value, t):
raise ValueError(f"Bad input for class {self.__class__}:"
f" field: {f} type: {t} value: {value} value type: {type(value)}")
setattr(self, f, value)
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
return serialize(self)
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
return hash_tree_root(self)
def signing_root(self):
from .ssz_impl import signing_root
return signing_root(self, self.__class__)
return signing_root(self)
def get_field_values(self):
cls = self.__class__
return [getattr(self, field) for field in cls.get_field_names()]
def __setattr__(self, name, value):
if name not in self.__class__.__annotations__:
raise AttributeError("Cannot change non-existing SSZ-container attribute")
field_typ = self.__class__.__annotations__[name]
value = coerce_type_maybe(value, field_typ)
if not isinstance(value, field_typ):
raise ValueError(f"Cannot set field of {self.__class__}:"
f" field: {name} type: {field_typ} value: {value} value type: {type(value)}")
super().__setattr__(name, value)
def __repr__(self):
return repr({field: getattr(self, field) for field in self.get_field_names()})
return repr({field: (getattr(self, field) if hasattr(self, field) else 'unset')
for field in self.get_fields().keys()})
def __str__(self):
output = []
for field in self.get_field_names():
output.append(f'{field}: {getattr(self, field)}')
output = [f'{self.__class__.__name__}']
for field in self.get_fields().keys():
output.append(f' {field}: {getattr(self, field)}')
return "\n".join(output)
def __eq__(self, other):
@ -156,404 +190,261 @@ class Container(object):
return copy.deepcopy(self)
@classmethod
def get_fields_dict(cls):
def get_fields(cls) -> Dict[str, SSZType]:
if not hasattr(cls, '__annotations__'): # no container fields
return {}
return dict(cls.__annotations__)
@classmethod
def get_fields(cls):
return list(dict(cls.__annotations__).items())
def get_typed_values(self):
return list(zip(self.get_field_values(), self.get_field_types()))
def default(cls):
return cls(**{f: t.default() for f, t in cls.get_fields().items()})
@classmethod
def get_field_names(cls):
return list(cls.__annotations__.keys())
def is_fixed_size(cls):
return all(t.is_fixed_size() for t in cls.get_fields().values())
@classmethod
def get_field_types(cls):
# values of annotations are the types corresponding to the fields, not instance values.
return list(cls.__annotations__.values())
def __iter__(self) -> Iterator[SSZValue]:
return iter([getattr(self, field) for field in self.get_fields().keys()])
# SSZ vector
# -----------------------------
class ParamsBase(Series):
_has_params = False
def __new__(cls, *args, **kwargs):
if not cls._has_params:
raise Exception("cannot init bare type without params")
return super().__new__(cls, **kwargs)
def _is_vector_instance_of(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector (b) is not an instance of Vector[X, Y] (a)
return False
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
# Vector[X, Y] (b) is an instance of Vector (a)
return True
else:
# Vector[X, Y] (a) is an instance of Vector[X, Y] (b)
return a.elem_type == b.elem_type and a.length == b.length
class ParamsMeta(SSZType):
def _is_equal_vector_type(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector == Vector
return True
else:
# Vector != Vector[X, Y]
return False
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector[X, Y] != Vector
return False
else:
# Vector[X, Y] == Vector[X, Y]
return a.elem_type == b.elem_type and a.length == b.length
class VectorMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'elem_type' in attrs and 'length' in attrs:
setattr(out, 'elem_type', attrs['elem_type'])
setattr(out, 'length', attrs['length'])
if hasattr(out, "_has_params") and getattr(out, "_has_params"):
for k, v in attrs.items():
setattr(out, k, v)
return out
def __getitem__(self, params):
if not isinstance(params, tuple) or len(params) != 2:
raise Exception("Vector must be instantiated with two args: elem type and length")
o = self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]})
o._name = 'Vector'
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
return o
def __subclasscheck__(self, sub):
return _is_vector_instance_of(self, sub)
def __instancecheck__(self, other):
return _is_vector_instance_of(self, other.__class__)
def __eq__(self, other):
return _is_equal_vector_type(self, other)
def __ne__(self, other):
return not _is_equal_vector_type(self, other)
def __hash__(self):
return hash(self.__class__)
class Vector(metaclass=VectorMeta):
def __init__(self, *args: Iterable):
cls = self.__class__
if not hasattr(cls, 'elem_type'):
raise TypeError("Type Vector without elem_type data cannot be instantiated")
elif not hasattr(cls, 'length'):
raise TypeError("Type Vector without length data cannot be instantiated")
if len(args) != cls.length:
if len(args) == 0:
args = [get_zero_value(cls.elem_type) for _ in range(cls.length)]
else:
raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args)))
self.items = list(args)
# cannot check non-type objects, or parametrized types
if isinstance(cls.elem_type, type) and not hasattr(cls.elem_type, '__args__'):
for i, item in enumerate(self.items):
if not issubclass(cls.elem_type, type(item)):
raise TypeError("Typed vector cannot hold differently typed value"
" at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type))
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
def __str__(self):
return f"{self.__name__}~{self.__class__.__name__}"
def __repr__(self):
return repr({'length': self.__class__.length, 'items': self.items})
return self, self.__class__
def __getitem__(self, key):
return self.items[key]
def __setitem__(self, key, value):
self.items[key] = value
def __iter__(self):
return iter(self.items)
def __len__(self):
return len(self.items)
def __eq__(self, other):
return self.hash_tree_root() == other.hash_tree_root()
# SSZ BytesN
# -----------------------------
def _is_bytes_n_instance_of(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
elif not hasattr(b, 'length'):
# BytesN (b) is not an instance of BytesN[X] (a)
return False
elif not hasattr(a, 'length'):
# BytesN[X] (b) is an instance of BytesN (a)
return True
else:
# BytesN[X] (a) is an instance of BytesN[X] (b)
return a.length == b.length
def _is_equal_bytes_n_type(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
elif not hasattr(a, 'length'):
if not hasattr(b, 'length'):
# BytesN == BytesN
return True
else:
# BytesN != BytesN[X]
return False
elif not hasattr(b, 'length'):
# BytesN[X] != BytesN
return False
else:
# BytesN[X] == BytesN[X]
return a.length == b.length
class BytesNMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'length' in attrs:
setattr(out, 'length', attrs['length'])
out._name = 'BytesN'
out.elem_type = byte
return out
def __getitem__(self, n):
return self.__class__(self.__name__, (BytesN,), {'length': n})
def __subclasscheck__(self, sub):
return _is_bytes_n_instance_of(self, sub)
def __instancecheck__(self, other):
return _is_bytes_n_instance_of(self, other.__class__)
def __eq__(self, other):
if other == ():
return False
return _is_equal_bytes_n_type(self, other)
def __ne__(self, other):
return not _is_equal_bytes_n_type(self, other)
def __hash__(self):
return hash(self.__class__)
def parse_bytes(val):
if val is None:
return None
elif isinstance(val, str):
# TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val)
return None
elif isinstance(val, bytes):
return val
elif isinstance(val, int):
return bytes([val])
elif isinstance(val, (list, GeneratorType)):
return bytes(val)
else:
return None
class BytesN(bytes, metaclass=BytesNMeta):
def __new__(cls, *args):
if not hasattr(cls, 'length'):
return
bytesval = None
if len(args) == 1:
val: Union[bytes, int, str] = args[0]
bytesval = parse_bytes(val)
elif len(args) > 1:
# TODO: each int is 1 byte, check size, create bytesval
bytesval = bytes(args)
if bytesval is None:
if cls.length == 0:
bytesval = b''
def attr_from_params(self, p):
# single key params are valid too. Wrap them in a tuple.
params = p if isinstance(p, tuple) else (p,)
res = {'_has_params': True}
i = 0
for (name, typ) in self.__annotations__.items():
if hasattr(self.__class__, name):
res[name] = getattr(self.__class__, name)
else:
bytesval = b'\x00' * cls.length
if len(bytesval) != cls.length:
raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
return super().__new__(cls, bytesval)
if i >= len(params):
i += 1
continue
param = params[i]
if not isinstance(param, typ):
raise TypeError(
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
res[name] = param
i += 1
if len(params) != i:
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
return res
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
def __subclasscheck__(self, subclass):
# check regular class system if we can, solves a lot of the normal cases.
if super().__subclasscheck__(subclass):
return True
# if they are not normal subclasses, they are of the same class.
# then they should have the same name
if subclass.__name__ != self.__name__:
return False
# If they do have the same name, they should also have the same params.
for name, typ in self.__annotations__.items():
if hasattr(self, name) and hasattr(subclass, name) \
and getattr(subclass, name) != getattr(self, name):
return False
return True
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
def __instancecheck__(self, obj):
return self.__subclasscheck__(obj.__class__)
class Bytes4(BytesN):
length = 4
class ElementsType(ParamsMeta):
elem_type: SSZType
length: int
class Bytes32(BytesN):
length = 32
class Elements(ParamsBase, metaclass=ElementsType):
pass
class Bytes48(BytesN):
length = 48
class BaseList(list, Elements):
def __init__(self, *args):
items = self.extract_args(*args)
if not self.value_check(items):
raise ValueError(f"Bad input for class {self.__class__}: {items}")
super().__init__(items)
@classmethod
def value_check(cls, value):
return all(isinstance(v, cls.elem_type) for v in value) and len(value) <= cls.length
@classmethod
def extract_args(cls, *args):
x = list(args)
if len(x) == 1 and isinstance(x[0], (GeneratorType, list, tuple)):
x = list(x[0])
x = [coerce_type_maybe(v, cls.elem_type) for v in x]
return x
def __str__(self):
cls = self.__class__
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})"
def __getitem__(self, k) -> SSZValue:
if isinstance(k, int): # check if we are just doing a lookup, and not slicing
if k < 0:
raise IndexError(f"cannot get item in type {self.__class__} at negative index {k}")
if k > len(self):
raise IndexError(f"cannot get item in type {self.__class__}"
f" at out of bounds index {k}")
return super().__getitem__(k)
def __setitem__(self, k, v):
if k < 0:
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
if k > len(self):
raise IndexError(f"cannot set item in type {self.__class__}"
f" at out of bounds index {k} (to {v}, bound: {len(self)})")
super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def append(self, v):
super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def __iter__(self) -> Iterator[SSZValue]:
return super().__iter__()
def last(self):
# be explict about getting the last item, for the non-python readers, and negative-index safety
return self[len(self) - 1]
class Bytes96(BytesN):
length = 96
class List(BaseList):
@classmethod
def default(cls):
return cls()
# SSZ Defaults
# -----------------------------
def get_zero_value(typ):
if is_uint_type(typ):
return uint64(0)
elif is_list_type(typ):
return []
elif is_bool_type(typ):
@classmethod
def is_fixed_size(cls):
return False
elif is_vector_type(typ):
return typ()
elif is_bytesn_type(typ):
return typ()
elif is_bytes_type(typ):
class Vector(BaseList):
@classmethod
def value_check(cls, value):
# check length limit strictly
return len(value) == cls.length and super().value_check(value)
@classmethod
def default(cls):
return cls(cls.elem_type.default() for _ in range(cls.length))
@classmethod
def is_fixed_size(cls):
return cls.elem_type.is_fixed_size()
def append(self, v):
# Deep-copy and other utils like to change the internals during work.
# Only complain if we had the right size.
if len(self) == self.__class__.length:
raise Exception("cannot modify vector length")
else:
super().append(v)
def pop(self, *args):
raise Exception("cannot modify vector length")
class BytesType(ElementsType):
elem_type: SSZType = byte
length: int
class BaseBytes(bytes, Elements, metaclass=BytesType):
def __new__(cls, *args) -> "BaseBytes":
extracted_val = cls.extract_args(*args)
if not cls.value_check(extracted_val):
raise ValueError(f"Bad input for class {cls}: {extracted_val}")
return super().__new__(cls, extracted_val)
@classmethod
def extract_args(cls, *args):
x = args
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)):
x = x[0]
if isinstance(x, bytes): # Includes BytesLike
return x
else:
return bytes(x) # E.g. GeneratorType put into bytes.
@classmethod
def value_check(cls, value):
# check type and virtual length limit
return isinstance(value, bytes) and len(value) <= cls.length
def __str__(self):
cls = self.__class__
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
class Bytes(BaseBytes):
@classmethod
def default(cls):
return b''
elif is_container_type(typ):
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()})
else:
raise Exception("Type not supported: {}".format(typ))
@classmethod
def is_fixed_size(cls):
return False
# Type helpers
# -----------------------------
class BytesN(BaseBytes):
@classmethod
def extract_args(cls, *args):
if len(args) == 0:
return cls.default()
else:
return super().extract_args(*args)
@classmethod
def default(cls):
return b'\x00' * cls.length
@classmethod
def value_check(cls, value):
# check length limit strictly
return len(value) == cls.length and super().value_check(value)
@classmethod
def is_fixed_size(cls):
return True
def infer_type(obj):
if is_uint_type(obj.__class__):
return obj.__class__
elif isinstance(obj, int):
return uint64
elif isinstance(obj, list):
return List[infer_type(obj[0])]
elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)):
return obj.__class__
else:
raise Exception("Unknown type for {}".format(obj))
def infer_input_type(fn):
"""
Decorator to run infer_type on the obj if typ argument is None
"""
def infer_helper(obj, typ=None, **kwargs):
if typ is None:
typ = infer_type(obj)
return fn(obj, typ=typ, **kwargs)
return infer_helper
def is_bool_type(typ):
"""
Check if the given type is a bool.
"""
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, bool)
def is_list_type(typ):
"""
Check if the given type is a list.
"""
return get_origin(typ) is List or get_origin(typ) is list
def is_bytes_type(typ):
"""
Check if the given type is a ``bytes``.
"""
# Do not accept subclasses of bytes here, to avoid confusion with BytesN
return typ == bytes
def is_bytesn_type(typ):
"""
Check if the given type is a BytesN.
"""
return isinstance(typ, type) and issubclass(typ, BytesN)
def is_list_kind(typ):
"""
Check if the given type is a kind of list. Can be bytes.
"""
return is_list_type(typ) or is_bytes_type(typ)
def is_vector_type(typ):
"""
Check if the given type is a vector.
"""
return isinstance(typ, type) and issubclass(typ, Vector)
def is_vector_kind(typ):
"""
Check if the given type is a kind of vector. Can be BytesN.
"""
return is_vector_type(typ) or is_bytesn_type(typ)
def is_container_type(typ):
"""
Check if the given type is a container.
"""
return isinstance(typ, type) and issubclass(typ, Container)
T = TypeVar('T')
L = TypeVar('L')
def read_list_elem_type(list_typ: Type[List[T]]) -> T:
if list_typ.__args__ is None or len(list_typ.__args__) != 1:
raise TypeError("Supplied list-type is invalid, no element type found.")
return list_typ.__args__[0]
def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T:
return vector_typ.elem_type
def read_elem_type(typ):
if typ == bytes or (isinstance(typ, type) and issubclass(typ, bytes)): # bytes or bytesN
return byte
elif is_list_type(typ):
return read_list_elem_type(typ)
elif is_vector_type(typ):
return read_vector_elem_type(typ)
else:
raise TypeError("Unexpected type: {}".format(typ))
# Helpers for common BytesN types.
Bytes4: BytesType = BytesN[4]
Bytes32: BytesType = BytesN[32]
Bytes48: BytesType = BytesN[48]
Bytes96: BytesType = BytesN[96]

View File

@ -0,0 +1,241 @@
from typing import Iterable
from .ssz_impl import serialize, hash_tree_root
from .ssz_typing import (
Bit, Bool, Container, List, Vector, Bytes, BytesN,
uint8, uint16, uint32, uint64, uint256, byte
)
from ..hash_function import hash as bytes_hash
import pytest
class EmptyTestStruct(Container):
pass
class SingleFieldTestStruct(Container):
A: byte
class SmallTestStruct(Container):
A: uint16
B: uint16
class FixedTestStruct(Container):
A: uint8
B: uint64
C: uint32
class VarTestStruct(Container):
A: uint16
B: List[uint16, 1024]
C: uint8
class ComplexTestStruct(Container):
A: uint16
B: List[uint16, 128]
C: uint8
D: Bytes[256]
E: VarTestStruct
F: Vector[FixedTestStruct, 4]
G: Vector[VarTestStruct, 2]
sig_test_data = [0 for i in range(96)]
for k, v in {0: 1, 32: 2, 64: 3, 95: 0xff}.items():
sig_test_data[k] = v
def chunk(hex: str) -> str:
return (hex + ("00" * 32))[:64] # just pad on the right, to 32 bytes (64 hex chars)
def h(a: str, b: str) -> str:
return bytes_hash(bytes.fromhex(a) + bytes.fromhex(b)).hex()
# zero hashes, as strings, for
zero_hashes = [chunk("")]
for layer in range(1, 32):
zero_hashes.append(h(zero_hashes[layer - 1], zero_hashes[layer - 1]))
def merge(a: str, branch: Iterable[str]) -> str:
"""
Merge (out on left, branch on right) leaf a with branch items, branch is from bottom to top.
"""
out = a
for b in branch:
out = h(out, b)
return out
test_data = [
("bit F", Bit(False), "00", chunk("00")),
("bit T", Bit(True), "01", chunk("01")),
("bool F", Bool(False), "00", chunk("00")),
("bool T", Bool(True), "01", chunk("01")),
("uint8 00", uint8(0x00), "00", chunk("00")),
("uint8 01", uint8(0x01), "01", chunk("01")),
("uint8 ab", uint8(0xab), "ab", chunk("ab")),
("byte 00", byte(0x00), "00", chunk("00")),
("byte 01", byte(0x01), "01", chunk("01")),
("byte ab", byte(0xab), "ab", chunk("ab")),
("uint16 0000", uint16(0x0000), "0000", chunk("0000")),
("uint16 abcd", uint16(0xabcd), "cdab", chunk("cdab")),
("uint32 00000000", uint32(0x00000000), "00000000", chunk("00000000")),
("uint32 01234567", uint32(0x01234567), "67452301", chunk("67452301")),
("small (4567, 0123)", SmallTestStruct(A=0x4567, B=0x0123), "67452301", h(chunk("6745"), chunk("2301"))),
("small [4567, 0123]::2", Vector[uint16, 2](uint16(0x4567), uint16(0x0123)), "67452301", chunk("67452301")),
("uint32 01234567", uint32(0x01234567), "67452301", chunk("67452301")),
("uint64 0000000000000000", uint64(0x00000000), "0000000000000000", chunk("0000000000000000")),
("uint64 0123456789abcdef", uint64(0x0123456789abcdef), "efcdab8967452301", chunk("efcdab8967452301")),
("sig", BytesN[96](*sig_test_data),
"0100000000000000000000000000000000000000000000000000000000000000"
"0200000000000000000000000000000000000000000000000000000000000000"
"03000000000000000000000000000000000000000000000000000000000000ff",
h(h(chunk("01"), chunk("02")),
h("03000000000000000000000000000000000000000000000000000000000000ff", chunk("")))),
("emptyTestStruct", EmptyTestStruct(), "", chunk("")),
("singleFieldTestStruct", SingleFieldTestStruct(A=0xab), "ab", chunk("ab")),
("uint16 list", List[uint16, 32](uint16(0xaabb), uint16(0xc0ad), uint16(0xeeff)), "bbaaadc0ffee",
h(h(chunk("bbaaadc0ffee"), chunk("")), chunk("03000000")) # max length: 32 * 2 = 64 bytes = 2 chunks
),
("uint32 list", List[uint32, 128](uint32(0xaabb), uint32(0xc0ad), uint32(0xeeff)), "bbaa0000adc00000ffee0000",
# max length: 128 * 4 = 512 bytes = 16 chunks
h(merge(chunk("bbaa0000adc00000ffee0000"), zero_hashes[0:4]), chunk("03000000"))
),
("uint256 list", List[uint256, 32](uint256(0xaabb), uint256(0xc0ad), uint256(0xeeff)),
"bbaa000000000000000000000000000000000000000000000000000000000000"
"adc0000000000000000000000000000000000000000000000000000000000000"
"ffee000000000000000000000000000000000000000000000000000000000000",
h(merge(h(h(chunk("bbaa"), chunk("adc0")), h(chunk("ffee"), chunk(""))), zero_hashes[2:5]), chunk("03000000"))
),
("uint256 list long", List[uint256, 128](i for i in range(1, 20)),
"".join([i.to_bytes(length=32, byteorder='little').hex() for i in range(1, 20)]),
h(merge(
h(
h(
h(
h(h(chunk("01"), chunk("02")), h(chunk("03"), chunk("04"))),
h(h(chunk("05"), chunk("06")), h(chunk("07"), chunk("08"))),
),
h(
h(h(chunk("09"), chunk("0a")), h(chunk("0b"), chunk("0c"))),
h(h(chunk("0d"), chunk("0e")), h(chunk("0f"), chunk("10"))),
)
),
h(
h(
h(h(chunk("11"), chunk("12")), h(chunk("13"), chunk(""))),
zero_hashes[2]
),
zero_hashes[3]
)
),
zero_hashes[5:7]), chunk("13000000")) # 128 chunks = 7 deep
),
("fixedTestStruct", FixedTestStruct(A=0xab, B=0xaabbccdd00112233, C=0x12345678), "ab33221100ddccbbaa78563412",
h(h(chunk("ab"), chunk("33221100ddccbbaa")), h(chunk("78563412"), chunk("")))),
("varTestStruct nil", VarTestStruct(A=0xabcd, C=0xff), "cdab07000000ff",
h(h(chunk("cdab"), h(zero_hashes[6], chunk("00000000"))), h(chunk("ff"), chunk("")))),
("varTestStruct empty", VarTestStruct(A=0xabcd, B=List[uint16, 1024](), C=0xff), "cdab07000000ff",
h(h(chunk("cdab"), h(zero_hashes[6], chunk("00000000"))), h(chunk("ff"), chunk("")))), # log2(1024*2/32)= 6 deep
("varTestStruct some", VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
"cdab07000000ff010002000300",
h(
h(
chunk("cdab"),
h(
merge(
chunk("010002000300"),
zero_hashes[0:6]
),
chunk("03000000") # length mix in
)
),
h(chunk("ff"), chunk(""))
)),
("complexTestStruct",
ComplexTestStruct(
A=0xaabb,
B=List[uint16, 128](0x1122, 0x3344),
C=0xff,
D=Bytes[256](b"foobar"),
E=VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
F=Vector[FixedTestStruct, 4](
FixedTestStruct(A=0xcc, B=0x4242424242424242, C=0x13371337),
FixedTestStruct(A=0xdd, B=0x3333333333333333, C=0xabcdabcd),
FixedTestStruct(A=0xee, B=0x4444444444444444, C=0x00112233),
FixedTestStruct(A=0xff, B=0x5555555555555555, C=0x44556677)),
G=Vector[VarTestStruct, 2](
VarTestStruct(A=0xdead, B=List[uint16, 1024](1, 2, 3), C=0x11),
VarTestStruct(A=0xbeef, B=List[uint16, 1024](4, 5, 6), C=0x22)),
),
"bbaa"
"47000000" # offset of B, []uint16
"ff"
"4b000000" # offset of foobar
"51000000" # offset of E
"cc424242424242424237133713"
"dd3333333333333333cdabcdab"
"ee444444444444444433221100"
"ff555555555555555577665544"
"5e000000" # pointer to G
"22114433" # contents of B
"666f6f626172" # foobar
"cdab07000000ff010002000300" # contents of E
"08000000" "15000000" # [start G]: local offsets of [2]varTestStruct
"adde0700000011010002000300"
"efbe0700000022040005000600",
h(
h(
h( # A and B
chunk("bbaa"),
h(merge(chunk("22114433"), zero_hashes[0:3]), chunk("02000000")) # 2*128/32 = 8 chunks
),
h( # C and D
chunk("ff"),
h(merge(chunk("666f6f626172"), zero_hashes[0:3]), chunk("06000000")) # 256/32 = 8 chunks
)
),
h(
h( # E and F
h(h(chunk("cdab"), h(merge(chunk("010002000300"), zero_hashes[0:6]), chunk("03000000"))),
h(chunk("ff"), chunk(""))),
h(
h(
h(h(chunk("cc"), chunk("4242424242424242")), h(chunk("37133713"), chunk(""))),
h(h(chunk("dd"), chunk("3333333333333333")), h(chunk("cdabcdab"), chunk(""))),
),
h(
h(h(chunk("ee"), chunk("4444444444444444")), h(chunk("33221100"), chunk(""))),
h(h(chunk("ff"), chunk("5555555555555555")), h(chunk("77665544"), chunk(""))),
),
)
),
h( # G and padding
h(
h(h(chunk("adde"), h(merge(chunk("010002000300"), zero_hashes[0:6]), chunk("03000000"))),
h(chunk("11"), chunk(""))),
h(h(chunk("efbe"), h(merge(chunk("040005000600"), zero_hashes[0:6]), chunk("03000000"))),
h(chunk("22"), chunk(""))),
),
chunk("")
)
)
))
]
@pytest.mark.parametrize("name, value, serialized, _", test_data)
def test_serialize(name, value, serialized, _):
assert serialize(value) == bytes.fromhex(serialized)
@pytest.mark.parametrize("name, value, _, root", test_data)
def test_hash_tree_root(name, value, _, root):
assert hash_tree_root(value) == bytes.fromhex(root)

View File

@ -0,0 +1,233 @@
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType,
Elements, Bit, Bool, Container, List, Vector, Bytes, BytesN,
byte, uint, uint8, uint16, uint32, uint64, uint128, uint256,
Bytes32, Bytes48
)
def expect_value_error(fn, msg):
try:
fn()
raise AssertionError(msg)
except ValueError:
pass
def test_subclasses():
for u in [uint, uint8, uint16, uint32, uint64, uint128, uint256]:
assert issubclass(u, uint)
assert issubclass(u, int)
assert issubclass(u, BasicValue)
assert issubclass(u, SSZValue)
assert isinstance(u, SSZType)
assert isinstance(u, BasicType)
assert issubclass(Bool, BasicValue)
assert isinstance(Bool, BasicType)
for c in [Container, List, Vector, Bytes, BytesN]:
assert issubclass(c, Series)
assert issubclass(c, SSZValue)
assert isinstance(c, SSZType)
assert not issubclass(c, BasicValue)
assert not isinstance(c, BasicType)
for c in [List, Vector, Bytes, BytesN]:
assert issubclass(c, Elements)
assert isinstance(c, ElementsType)
def test_basic_instances():
for u in [uint, uint8, byte, uint16, uint32, uint64, uint128, uint256]:
v = u(123)
assert isinstance(v, uint)
assert isinstance(v, int)
assert isinstance(v, BasicValue)
assert isinstance(v, SSZValue)
assert isinstance(Bool(True), BasicValue)
assert isinstance(Bool(False), BasicValue)
assert isinstance(Bit(True), Bool)
assert isinstance(Bit(False), Bool)
def test_basic_value_bounds():
max = {
Bool: 2 ** 1,
Bit: 2 ** 1,
uint8: 2 ** (8 * 1),
byte: 2 ** (8 * 1),
uint16: 2 ** (8 * 2),
uint32: 2 ** (8 * 4),
uint64: 2 ** (8 * 8),
uint128: 2 ** (8 * 16),
uint256: 2 ** (8 * 32),
}
for k, v in max.items():
# this should work
assert k(v - 1) == v - 1
# but we do not allow overflows
expect_value_error(lambda: k(v), "no overflows allowed")
for k, _ in max.items():
# this should work
assert k(0) == 0
# but we do not allow underflows
expect_value_error(lambda: k(-1), "no underflows allowed")
def test_container():
class Foo(Container):
a: uint8
b: uint32
empty = Foo()
assert empty.a == uint8(0)
assert empty.b == uint32(0)
assert issubclass(Foo, Container)
assert issubclass(Foo, SSZValue)
assert issubclass(Foo, Series)
assert Foo.is_fixed_size()
x = Foo(a=uint8(123), b=uint32(45))
assert x.a == 123
assert x.b == 45
assert isinstance(x.a, uint8)
assert isinstance(x.b, uint32)
assert x.type().is_fixed_size()
class Bar(Container):
a: uint8
b: List[uint8, 1024]
assert not Bar.is_fixed_size()
y = Bar(a=123, b=List[uint8, 1024](uint8(1), uint8(2)))
assert y.a == 123
assert isinstance(y.a, uint8)
assert len(y.b) == 2
assert isinstance(y.a, uint8)
assert isinstance(y.b, List[uint8, 1024])
assert not y.type().is_fixed_size()
assert y.b[0] == 1
v: List = y.b
assert v.type().elem_type == uint8
assert v.type().length == 1024
y.a = 42
try:
y.a = 256 # out of bounds
assert False
except ValueError:
pass
try:
y.a = uint16(255) # within bounds, wrong type
assert False
except ValueError:
pass
try:
y.not_here = 5
assert False
except AttributeError:
pass
def test_list():
typ = List[uint64, 128]
assert issubclass(typ, List)
assert issubclass(typ, SSZValue)
assert issubclass(typ, Series)
assert issubclass(typ, Elements)
assert isinstance(typ, ElementsType)
assert not typ.is_fixed_size()
assert len(typ()) == 0 # empty
assert len(typ(uint64(0))) == 1 # single arg
assert len(typ(uint64(i) for i in range(10))) == 10 # generator
assert len(typ(uint64(0), uint64(1), uint64(2))) == 3 # args
assert isinstance(typ(1, 2, 3, 4, 5)[4], uint64) # coercion
assert isinstance(typ(i for i in range(10))[9], uint64) # coercion in generator
v = typ(uint64(0))
v[0] = uint64(123)
assert v[0] == 123
assert isinstance(v[0], uint64)
assert isinstance(v, List)
assert isinstance(v, List[uint64, 128])
assert isinstance(v, typ)
assert isinstance(v, SSZValue)
assert isinstance(v, Series)
assert issubclass(v.type(), Elements)
assert isinstance(v.type(), ElementsType)
assert len(typ([i for i in range(10)])) == 10 # cast py list to SSZ list
foo = List[uint32, 128](0 for i in range(128))
foo[0] = 123
foo[1] = 654
foo[127] = 222
assert sum(foo) == 999
try:
foo[3] = 2 ** 32 # out of bounds
except ValueError:
pass
try:
foo[3] = uint64(2 ** 32 - 1) # within bounds, wrong type
assert False
except ValueError:
pass
try:
foo[128] = 100
assert False
except IndexError:
pass
try:
foo[-1] = 100 # valid in normal python lists
assert False
except IndexError:
pass
try:
foo[128] = 100 # out of bounds
assert False
except IndexError:
pass
def test_bytesn_subclass():
assert isinstance(BytesN[32](b'\xab' * 32), Bytes32)
assert not isinstance(BytesN[32](b'\xab' * 32), Bytes48)
assert issubclass(BytesN[32](b'\xab' * 32).type(), Bytes32)
assert issubclass(BytesN[32], Bytes32)
class Hash(Bytes32):
pass
assert isinstance(Hash(b'\xab' * 32), Bytes32)
assert not isinstance(Hash(b'\xab' * 32), Bytes48)
assert issubclass(Hash(b'\xab' * 32).type(), Bytes32)
assert issubclass(Hash, Bytes32)
assert not issubclass(Bytes48, Bytes32)
assert len(Bytes32() + Bytes48()) == 80
def test_uint_math():
assert uint8(0) + uint8(uint32(16)) == uint8(16) # allow explict casting to make invalid addition valid
expect_value_error(lambda: uint8(0) - uint8(1), "no underflows allowed")
expect_value_error(lambda: uint8(1) + uint8(255), "no overflows allowed")
expect_value_error(lambda: uint8(0) + 256, "no overflows allowed")
expect_value_error(lambda: uint8(42) + uint32(123), "no mixed types")
expect_value_error(lambda: uint32(42) + uint8(123), "no mixed types")
assert type(uint32(1234) + 56) == uint32

View File

@ -0,0 +1,58 @@
import pytest
from .merkle_minimal import zerohashes, merkleize_chunks
from .hash_function import hash
def h(a: bytes, b: bytes) -> bytes:
return hash(a + b)
def e(v: int) -> bytes:
return v.to_bytes(length=32, byteorder='little')
def z(i: int) -> bytes:
return zerohashes[i]
cases = [
(0, 0, 1, z(0)),
(0, 1, 1, e(0)),
(1, 0, 2, h(z(0), z(0))),
(1, 1, 2, h(e(0), z(0))),
(1, 2, 2, h(e(0), e(1))),
(2, 0, 4, h(h(z(0), z(0)), z(1))),
(2, 1, 4, h(h(e(0), z(0)), z(1))),
(2, 2, 4, h(h(e(0), e(1)), z(1))),
(2, 3, 4, h(h(e(0), e(1)), h(e(2), z(0)))),
(2, 4, 4, h(h(e(0), e(1)), h(e(2), e(3)))),
(3, 0, 8, h(h(h(z(0), z(0)), z(1)), z(2))),
(3, 1, 8, h(h(h(e(0), z(0)), z(1)), z(2))),
(3, 2, 8, h(h(h(e(0), e(1)), z(1)), z(2))),
(3, 3, 8, h(h(h(e(0), e(1)), h(e(2), z(0))), z(2))),
(3, 4, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), z(2))),
(3, 5, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), z(0)), z(1)))),
(3, 6, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(z(0), z(0))))),
(3, 7, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), z(0))))),
(3, 8, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7))))),
(4, 0, 16, h(h(h(h(z(0), z(0)), z(1)), z(2)), z(3))),
(4, 1, 16, h(h(h(h(e(0), z(0)), z(1)), z(2)), z(3))),
(4, 2, 16, h(h(h(h(e(0), e(1)), z(1)), z(2)), z(3))),
(4, 3, 16, h(h(h(h(e(0), e(1)), h(e(2), z(0))), z(2)), z(3))),
(4, 4, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), z(2)), z(3))),
(4, 5, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), z(0)), z(1))), z(3))),
(4, 6, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(z(0), z(0)))), z(3))),
(4, 7, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), z(0)))), z(3))),
(4, 8, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7)))), z(3))),
(4, 9, 16,
h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7)))), h(h(h(e(8), z(0)), z(1)), z(2)))),
]
@pytest.mark.parametrize(
'depth,count,pow2,value',
cases,
)
def test_merkleize_chunks(depth, count, pow2, value):
chunks = [e(i) for i in range(count)]
assert merkleize_chunks(chunks, pad_to=pow2) == value

View File

@ -2,6 +2,5 @@ eth-utils>=1.3.0,<2
eth-typing>=2.1.0,<3.0.0
pycryptodome==3.7.3
py_ecc>=1.6.0
typing_inspect==0.4.0
dataclasses==0.6
ssz==0.1.0a10

View File

@ -9,7 +9,6 @@ setup(
"eth-typing>=2.1.0,<3.0.0",
"pycryptodome==3.7.3",
"py_ecc>=1.6.0",
"typing_inspect==0.4.0",
"ssz==0.1.0a10",
"dataclasses==0.6",
]