Infer types where possible, e.g. uint64+uint64=uint64

This commit is contained in:
protolambda 2020-06-26 15:41:47 +02:00
parent b239f6108c
commit 531184f42b
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
1 changed files with 17 additions and 20 deletions

View File

@ -561,9 +561,9 @@ def integer_squareroot(n: uint64) -> uint64:
x = n x = n
y = (x + 1) // 2 y = (x + 1) // 2
while y < x: while y < x:
x = uint64(y) x = y
y = (x + n // x) // 2 y = (x + n // x) // 2
return uint64(x) return x
``` ```
#### `xor` #### `xor`
@ -732,15 +732,11 @@ def compute_shuffled_index(index: uint64, index_count: uint64, seed: Bytes32) ->
# Swap or not (https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf) # Swap or not (https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf)
# See the 'generalized domain' algorithm on page 3 # See the 'generalized domain' algorithm on page 3
for current_round in map(uint64, range(SHUFFLE_ROUND_COUNT)): for current_round in range(SHUFFLE_ROUND_COUNT):
pivot = bytes_to_int(hash(seed + int_to_bytes(current_round, length=1))[0:8]) % index_count pivot = bytes_to_int(hash(seed + int_to_bytes(current_round, length=1))[0:8]) % index_count
flip = uint64((pivot + index_count - index) % index_count) flip = (pivot + index_count - index) % index_count
position = max(index, flip) position = max(index, flip)
source = hash( source = hash(seed + int_to_bytes(current_round, length=1) + int_to_bytes(position // 256, length=4))
seed
+ int_to_bytes(current_round, length=1)
+ int_to_bytes(uint64(position // 256), length=4)
)
byte = source[(position % 256) // 8] byte = source[(position % 256) // 8]
bit = (byte >> (position % 8)) % 2 bit = (byte >> (position % 8)) % 2
index = flip if bit else index index = flip if bit else index
@ -757,10 +753,11 @@ def compute_proposer_index(state: BeaconState, indices: Sequence[ValidatorIndex]
""" """
assert len(indices) > 0 assert len(indices) > 0
MAX_RANDOM_BYTE = 2**8 - 1 MAX_RANDOM_BYTE = 2**8 - 1
i = 0 i = uint64(0)
total = uint64(len(indices))
while True: while True:
candidate_index = indices[compute_shuffled_index(uint64(i % len(indices)), uint64(len(indices)), seed)] candidate_index = indices[compute_shuffled_index(i % total, total, seed)]
random_byte = hash(seed + int_to_bytes(uint64(i // 32), length=8))[i % 32] random_byte = hash(seed + int_to_bytes(i // 32, length=8))[i % 32]
effective_balance = state.validators[candidate_index].effective_balance effective_balance = state.validators[candidate_index].effective_balance
if effective_balance * MAX_RANDOM_BYTE >= MAX_EFFECTIVE_BALANCE * random_byte: if effective_balance * MAX_RANDOM_BYTE >= MAX_EFFECTIVE_BALANCE * random_byte:
return candidate_index return candidate_index
@ -938,7 +935,7 @@ def get_validator_churn_limit(state: BeaconState) -> uint64:
Return the validator churn limit for the current epoch. Return the validator churn limit for the current epoch.
""" """
active_validator_indices = get_active_validator_indices(state, get_current_epoch(state)) active_validator_indices = get_active_validator_indices(state, get_current_epoch(state))
return max(MIN_PER_EPOCH_CHURN_LIMIT, uint64(len(active_validator_indices) // CHURN_LIMIT_QUOTIENT)) return max(MIN_PER_EPOCH_CHURN_LIMIT, uint64(len(active_validator_indices)) // CHURN_LIMIT_QUOTIENT)
``` ```
#### `get_seed` #### `get_seed`
@ -961,7 +958,7 @@ def get_committee_count_per_slot(state: BeaconState, epoch: Epoch) -> uint64:
""" """
return max(uint64(1), min( return max(uint64(1), min(
MAX_COMMITTEES_PER_SLOT, MAX_COMMITTEES_PER_SLOT,
uint64(len(get_active_validator_indices(state, epoch)) // SLOTS_PER_EPOCH // TARGET_COMMITTEE_SIZE), uint64(len(get_active_validator_indices(state, epoch))) // SLOTS_PER_EPOCH // TARGET_COMMITTEE_SIZE,
)) ))
``` ```
@ -977,8 +974,8 @@ def get_beacon_committee(state: BeaconState, slot: Slot, index: CommitteeIndex)
return compute_committee( return compute_committee(
indices=get_active_validator_indices(state, epoch), indices=get_active_validator_indices(state, epoch),
seed=get_seed(state, epoch, DOMAIN_BEACON_ATTESTER), seed=get_seed(state, epoch, DOMAIN_BEACON_ATTESTER),
index=uint64((slot % SLOTS_PER_EPOCH) * committees_per_slot + index), index=(slot % SLOTS_PER_EPOCH) * committees_per_slot + index,
count=uint64(committees_per_slot * SLOTS_PER_EPOCH), count=committees_per_slot * SLOTS_PER_EPOCH,
) )
``` ```
@ -1501,12 +1498,12 @@ def get_attestation_deltas(state: BeaconState) -> Tuple[Sequence[Gwei], Sequence
_, inactivity_penalties = get_inactivity_penalty_deltas(state) _, inactivity_penalties = get_inactivity_penalty_deltas(state)
rewards = [ rewards = [
Gwei(source_rewards[i] + target_rewards[i] + head_rewards[i] + inclusion_delay_rewards[i]) source_rewards[i] + target_rewards[i] + head_rewards[i] + inclusion_delay_rewards[i]
for i in range(len(state.validators)) for i in range(len(state.validators))
] ]
penalties = [ penalties = [
Gwei(source_penalties[i] + target_penalties[i] + head_penalties[i] + inactivity_penalties[i]) source_penalties[i] + target_penalties[i] + head_penalties[i] + inactivity_penalties[i]
for i in range(len(state.validators)) for i in range(len(state.validators))
] ]
@ -1773,13 +1770,13 @@ def process_deposit(state: BeaconState, deposit: Deposit) -> None:
assert is_valid_merkle_branch( assert is_valid_merkle_branch(
leaf=hash_tree_root(deposit.data), leaf=hash_tree_root(deposit.data),
branch=deposit.proof, branch=deposit.proof,
depth=uint64(DEPOSIT_CONTRACT_TREE_DEPTH + 1), # Add 1 for the List length mix-in depth=DEPOSIT_CONTRACT_TREE_DEPTH + 1, # Add 1 for the List length mix-in
index=state.eth1_deposit_index, index=state.eth1_deposit_index,
root=state.eth1_data.deposit_root, root=state.eth1_data.deposit_root,
) )
# Deposits must be processed in order # Deposits must be processed in order
state.eth1_deposit_index = uint64(state.eth1_deposit_index + 1) state.eth1_deposit_index += 1
pubkey = deposit.data.pubkey pubkey = deposit.data.pubkey
amount = deposit.data.amount amount = deposit.data.amount