Refactor the sample code and fix #47

This commit is contained in:
Hsiao-Wei Wang 2018-10-07 14:21:37 +08:00
parent d158e11006
commit 12a1bd2473
No known key found for this signature in database
GPG Key ID: 95B070122902DEA4
1 changed files with 114 additions and 47 deletions

View File

@ -317,61 +317,109 @@ def get_active_validator_indices(validators):
Now, a function that shuffles this list: Now, a function that shuffles this list:
```python ```python
def shuffle(lst, seed): def shuffle(values: List[Any],
seed: Hash32) -> List[Any]:
"""
Returns the shuffled ``values`` with seed as entropy.
Mainly for shuffling active validators in-protocol.
"""
values_count = len(values)
# entropy is consumed in 3 byte chunks # entropy is consumed in 3 byte chunks
# rand_max is defined to remove the modulo bias from this entropy source # rand_max is defined to remove the modulo bias from this entropy source
rand_max = 2**24 SAMPLE_RANGE = 2 ** 24
assert len(lst) <= rand_max assert values_count <= SAMPLE_RANGE:
o = [x for x in lst] output = [x for x in values]
source = seed source = seed
i = 0 index = 0
while i < len(lst): while index < values_count:
source = hash(source) # Re-hash the source
for pos in range(0, 30, 3): source = blake(source)
m = int.from_bytes(source[pos:pos+3], 'big') for position in range(0, 30, 3): # gets indices 3 bytes at a time
remaining = len(lst) - i # Select a 3-byte sampled int
sample_from_source = int.from_bytes(source[position:position + 3], 'big')
# `remaining` is the size of remaining indices of this round
remaining = values_count - index
if remaining == 0: if remaining == 0:
break break
rand_max = rand_max - rand_max % remaining
if m < rand_max: # Set a random maximum bound of sample_from_source
replacement_pos = (m % remaining) + i rand_max = SAMPLE_RANGE - SAMPLE_RANGE % remaining
o[i], o[replacement_pos] = o[replacement_pos], o[i]
i += 1 # Select `replacement_position` with the given `sample_from_source` and `remaining`
return o if sample_from_source < rand_max:
# Use random number to get `replacement_position`, where it's not `index`
replacement_position = (sample_from_source % remaining) + index
# Swap the index-th and replacement_position-th elements
(output[index], output[replacement_position]) = (
output[replacement_position],
output[index]
)
index += 1
else:
pass
return output
``` ```
Here's a function that splits a list into `N` pieces: Here's a function that splits a list into `N` pieces:
```python ```python
def split(lst, N): def split(seq: List[Any], pieces: int) -> List[Any]:
return [lst[len(lst)*i//N: len(lst)*(i+1)//N] for i in range(N)] """
Returns the split ``seq`` in ``pieces`` pieces in protocol.
"""
list_length = len(seq)
return [
seq[(list_length * i // pieces): (list_length * (i + 1) // pieces)]
for i in range(pieces)
]
``` ```
Now, our combined helper method: Now, our combined helper method:
```python ```python
def get_new_shuffling(seed, validators, crosslinking_start_shard): def get_new_shuffling(seed: Hash32,
validators: List[ValidatorRecord],
crosslinking_start_shard: int) -> List[List[ShardAndCommittee]]:
active_validators = get_active_validator_indices(validators) active_validators = get_active_validator_indices(validators)
if len(active_validators) >= CYCLE_LENGTH * MIN_COMMITTEE_SIZE: active_validators_size = len(active_validators)
committees_per_slot = min(len(active_validators) // CYCLE_LENGTH // (MIN_COMMITTEE_SIZE * 2) + 1, SHARD_COUNT // CYCLE_LENGTH)
if active_validators_size >= CYCLE_LENGTH * MIN_COMMITTEE_SIZE:
committees_per_slot = min(active_validators_size // CYCLE_LENGTH // (MIN_COMMITTEE_SIZE * 2) + 1, SHARD_COUNT // CYCLE_LENGTH)
slots_per_committee = 1 slots_per_committee = 1
else: else:
committees_per_slot = 1 committees_per_slot = 1
slots_per_committee = 1 slots_per_committee = 1
while len(active_validators) * slots_per_committee < CYCLE_LENGTH * MIN_COMMITTEE_SIZE \ while active_validators_size * slots_per_committee < CYCLE_LENGTH * MIN_COMMITTEE_SIZE \
and slots_per_committee < CYCLE_LENGTH: and slots_per_committee < CYCLE_LENGTH:
slots_per_committee *= 2 slots_per_committee *= 2
o = [] output = []
for i, slot_indices in enumerate(split(shuffle(active_validators, seed), CYCLE_LENGTH)):
# Shuffle with seed
shuffled_active_validator_indices = shuffle(active_validators, seed)
# Split the shuffled list into cycle_length pieces
validators_per_slot = split(shuffled_active_validator_indices, CYCLE_LENGTH)
for slot, slot_indices in enumerate(validators_per_slot):
# Split the shuffled list into committees_per_slot pieces
shard_indices = split(slot_indices, committees_per_slot) shard_indices = split(slot_indices, committees_per_slot)
shard_start = crosslinking_start_shard + \
i * committees_per_slot // slots_per_committee shard_id_start = crosslinking_start_shard + (
o.append([ShardAndCommittee( slot * committees_per_slot // slots_per_committee
shard = (shard_start + j) % SHARD_COUNT, )
committee = indices shards_and_committees_for_shard_indices = [
) for j, indices in enumerate(shard_indices)]) ShardAndCommittee(
return o shard_id = (shard_id_start + j) % SHARD_COUNT,
committee = indices
)
for slot, indices in enumerate(shard_indices)
]
output.append(shards_and_committees_for_shard_indices)
return output
``` ```
Here's a diagram of what's going on: Here's a diagram of what's going on:
@ -381,13 +429,16 @@ Here's a diagram of what's going on:
We also define two functions for retrieving data from the state: We also define two functions for retrieving data from the state:
```python ```python
def get_shards_and_committees_for_slot(crystallized_state, slot): def get_shards_and_committees_for_slot(crystallized_state: CrystallizedState,
earliest_slot_in_array = crystallized_state.last_state_recalculation_slot - CYCLE_LENGTH slot: int) -> List[ShardAndCommittee]:
earliest_slot_in_array = crystallized_state.last_state_recalculation - CYCLE_LENGTH
assert earliest_slot_in_array <= slot < earliest_slot_in_array + CYCLE_LENGTH * 2 assert earliest_slot_in_array <= slot < earliest_slot_in_array + CYCLE_LENGTH * 2
return crystallized_state.shard_and_committee_for_slots[slot - earliest_slot_in_array] return crystallized_state.shard_and_committee_for_slots[slot - earliest_slot_in_array]
def get_block_hash(active_state, curblock, slot): def get_block_hash(active_state:ActiveState,
earliest_slot_in_array = curblock.slot - CYCLE_LENGTH * 2 current_block: BeaconBlock,
slot: int):
earliest_slot_in_array = current_block.slot - CYCLE_LENGTH * 2
assert earliest_slot_in_array <= slot < earliest_slot_in_array + CYCLE_LENGTH * 2 assert earliest_slot_in_array <= slot < earliest_slot_in_array + CYCLE_LENGTH * 2
return active_state.recent_block_hashes[slot - earliest_slot_in_array] return active_state.recent_block_hashes[slot - earliest_slot_in_array]
``` ```
@ -397,7 +448,10 @@ def get_block_hash(active_state, curblock, slot):
We define a function to "add a link" to the validator hash chain, used when a validator is added or removed: We define a function to "add a link" to the validator hash chain, used when a validator is added or removed:
```python ```python
def add_validator_set_change_record(crystallized_state, index, pubkey, flag): def add_validator_set_change_record(crystallized_state: CrystallizedState,
index: int,
pubkey: int,
flag: int) -> None:
crystallized_state.validator_set_delta_hash_chain = \ crystallized_state.validator_set_delta_hash_chain = \
hash(crystallized_state.validator_set_delta_hash_chain + hash(crystallized_state.validator_set_delta_hash_chain +
bytes1(flag) + bytes3(index) + bytes32(pubkey)) bytes1(flag) + bytes3(index) + bytes32(pubkey))
@ -406,7 +460,7 @@ def add_validator_set_change_record(crystallized_state, index, pubkey, flag):
Finally, we abstractly define `int_sqrt(n)` for use in reward/penalty calculations as the largest integer `k` such that `k**2 <= n`. Here is one possible implementation, though clients are free to use their own including standard libraries for [integer square root](https://en.wikipedia.org/wiki/Integer_square_root) if available and meet the specification. Finally, we abstractly define `int_sqrt(n)` for use in reward/penalty calculations as the largest integer `k` such that `k**2 <= n`. Here is one possible implementation, though clients are free to use their own including standard libraries for [integer square root](https://en.wikipedia.org/wiki/Integer_square_root) if available and meet the specification.
```python ```python
def int_sqrt(n): def int_sqrt(n: int):
x = n x = n
y = (x + 1) // 2 y = (x + 1) // 2
while y < x: while y < x:
@ -421,13 +475,18 @@ def int_sqrt(n):
Run the following code: Run the following code:
```python ```python
def on_startup(initial_validator_entries): def on_startup(initial_validator_entries: List[Any]) -> None:
# Induct validators # Induct validators
validators = [] validators = []
for pubkey, proof_of_possession, withdrawal_shard, withdrawal_address, \ for pubkey, proof_of_possession, withdrawal_shard, withdrawal_address, \
randao_commitment in initial_validator_entries: randao_commitment in initial_validator_entries:
add_validator(validators, pubkey, proof_of_possession, add_validator(validators,
withdrawal_shard, withdrawal_address, randao_commitment) pubkey,
proof_of_possession,
withdrawal_shard,
withdrawal_address,
randao_commitment
)
# Setup crystallized state # Setup crystallized state
cs = CrystallizedState() cs = CrystallizedState()
x = get_new_shuffling(bytes([0] * 32), validators, 0) x = get_new_shuffling(bytes([0] * 32), validators, 0)
@ -447,8 +506,12 @@ The `CrystallizedState()` and `ActiveState()` constructors should initialize all
This routine should be run for every validator that is inducted as part of a log created on the PoW chain [TODO: explain where to check for these logs]. These logs should be processed in the order in which they are emitted by the PoW chain. Define `min_empty_validator(validators)` as a function that returns the lowest validator index `i` such that `validators[i].status == WITHDRAWN`, otherwise `None`. This routine should be run for every validator that is inducted as part of a log created on the PoW chain [TODO: explain where to check for these logs]. These logs should be processed in the order in which they are emitted by the PoW chain. Define `min_empty_validator(validators)` as a function that returns the lowest validator index `i` such that `validators[i].status == WITHDRAWN`, otherwise `None`.
```python ```python
def add_validator(validators, pubkey, proof_of_possession, withdrawal_shard, def add_validator(validators: List[ValidatorRecord],
withdrawal_address, randao_commitment): pubkey: int,
proof_of_possession: bytes,
withdrawal_shard: int,
withdrawal_address: Address,
randao_commitment: Hash32) -> int:
# if following assert fails, validator induction failed # if following assert fails, validator induction failed
# move on to next validator registration log # move on to next validator registration log
assert BLSVerify(pub=pubkey, assert BLSVerify(pub=pubkey,
@ -479,8 +542,10 @@ This procedure should be carried out every block.
First, set `recent_block_hashes` to the output of the following, where `parent_hash` is the hash of the immediate previous block (ie. must be equal to `ancestor_hashes[0]`): First, set `recent_block_hashes` to the output of the following, where `parent_hash` is the hash of the immediate previous block (ie. must be equal to `ancestor_hashes[0]`):
```python ```python
def get_new_recent_block_hashes(old_block_hashes, parent_slot, def get_new_recent_block_hashes(old_block_hashes: List[Hash32],
current_slot, parent_hash): parent_slot: int,
current_slot: int,
parent_hash: Hash32) -> List[Hash32]:
d = current_slot - parent_slot d = current_slot - parent_slot
return old_block_hashes[d:] + [parent_hash] * min(d, len(old_block_hashes)) return old_block_hashes[d:] + [parent_hash] * min(d, len(old_block_hashes))
``` ```
@ -488,7 +553,9 @@ def get_new_recent_block_hashes(old_block_hashes, parent_slot,
The output of `get_block_hash` should not change, except that it will no longer throw for `current_slot - 1`, and will now throw for `current_slot - CYCLE_LENGTH * 2 - 1`. Also, check that the block's `ancestor_hashes` array was correctly updated, using the following algorithm: The output of `get_block_hash` should not change, except that it will no longer throw for `current_slot - 1`, and will now throw for `current_slot - CYCLE_LENGTH * 2 - 1`. Also, check that the block's `ancestor_hashes` array was correctly updated, using the following algorithm:
```python ```python
def update_ancestor_hashes(parent_ancestor_hashes, parent_slot_number, parent_hash): def update_ancestor_hashes(parent_ancestor_hashes: List[Hash32],
parent_slot_number: int,
parent_hash: Hash32) -> List[Hash32]:
new_ancestor_hashes = copy.copy(parent_ancestor_hashes) new_ancestor_hashes = copy.copy(parent_ancestor_hashes)
for i in range(32): for i in range(32):
if parent_slot_number % 2**i == 0: if parent_slot_number % 2**i == 0:
@ -598,9 +665,9 @@ A dynasty transition can happen after a state recalculation if all of the follow
Then, run the following algorithm to update the validator set: Then, run the following algorithm to update the validator set:
```python ```python
def change_validators(validators): def change_validators(validators: List[ValidatorRecord]) -> None:
# The active validator set # The active validator set
active_validators = get_active_validator_indices(validators, dynasty) active_validators = get_active_validator_indices(validators)
# The total balance of active validators # The total balance of active validators
total_balance = sum([v.balance for i, v in enumerate(validators) if i in active_validators]) total_balance = sum([v.balance for i, v in enumerate(validators) if i in active_validators])
# The maximum total wei that can deposit+withdraw # The maximum total wei that can deposit+withdraw