Added shuffling algos

This commit is contained in:
Vitalik Buterin 2018-12-18 17:33:28 -05:00
parent ba690c0307
commit 3aa84a76dd
No known key found for this signature in database
GPG Key ID: A99D082A6179F987
4 changed files with 173 additions and 0 deletions

View File

@ -0,0 +1,44 @@
from hashlib import blake2s
def hash(x):
    """Return the BLAKE2s digest of the bytes ``x``, truncated to 32 bytes.

    Note: this deliberately shadows the builtin ``hash`` within this module.
    """
    digest = blake2s(x).digest()
    return digest[:32]
def numhash(x, i, seed, modulus):
    """Return digit ``i`` (base ``modulus``) of the hash of ``x`` and ``seed``.

    ``x`` is serialized as a 32-byte big-endian integer, ``seed`` is appended,
    and the hash of the result is read as a big-endian integer. Digit ``i`` of
    that integer, written in base ``modulus``, is returned. Only the four
    lowest digits (``0 <= i < 4``) may be requested.
    """
    assert 0 <= i < 4
    preimage = x.to_bytes(32, 'big') + seed
    digest_value = int.from_bytes(hash(preimage), 'big')
    divisor = modulus ** i
    return (digest_value // divisor) % modulus
def numhash_all(x, seed, modulus):
    """Return the four lowest base-``modulus`` digits of the hash of ``x`` and ``seed``.

    The hash input is ``x`` as a 32-byte big-endian integer followed by
    ``seed``. Entry ``i`` of the returned list equals
    ``numhash(x, i, seed, modulus)``, but the hash is computed only once.
    """
    preimage = x.to_bytes(32, 'big') + seed
    digest_value = int.from_bytes(hash(preimage), 'big')
    digits = []
    for exponent in range(4):
        digits.append((digest_value // modulus ** exponent) % modulus)
    return digits
def next_perfect_square(n):
    """Return the smallest perfect square that is greater than or equal to ``n``.

    The floating-point square root is used only as a first estimate. Above
    roughly 2**53 that estimate can be off by one, so it is corrected with
    exact integer arithmetic before it is used.
    """
    root = int(n ** 0.5)
    # Pull the estimate down if float rounding overshot the true floor root.
    while root * root > n:
        root -= 1
    # Push the estimate up if float rounding undershot the true floor root.
    while (root + 1) * (root + 1) <= n:
        root += 1
    # ``root`` is now exactly floor(sqrt(n)).
    if root * root == n:
        return n
    return (root + 1) ** 2
def multi_feistel(modulus, xs, seed, precompute=False):
    """Map every entry of ``xs`` through a seeded 4-round Feistel network.

    Each value is split into a (left, right) pair of base-``side`` digits,
    where ``side`` is the square root of the smallest perfect square that is
    at least ``modulus``. Four Feistel rounds are applied to the pair; if the
    recombined value is not below ``modulus`` the rounds are applied again to
    that value until it is.

    Parameters:
        modulus: exclusive upper bound of the produced values.
        xs: iterable of integers to map.
        seed: bytes mixed into every round hash.
        precompute: when true, the round values for every possible right
            half are tabulated up front via ``numhash_all``; otherwise
            ``numhash`` is called on demand.

    Returns:
        A list with one mapped value per entry of ``xs``, in the same order.
    """
    side = int(next_perfect_square(modulus) ** 0.5)
    table = None
    if precompute:
        table = [numhash_all(r, seed, modulus) for r in range(side)]
    mapped = []
    for value in xs:
        while True:
            left, right = divmod(value, side)
            for round_index in range(4):
                if precompute:
                    mix = table[right][round_index]
                else:
                    mix = numhash(right, round_index, seed, modulus)
                # Classic Feistel step: the halves swap and the old left half
                # is combined with the round value of the old right half.
                left, right = right, (left + mix) % side
            value = left * side + right
            # Results outside [0, modulus) are fed through the rounds again.
            if value < modulus:
                break
        mapped.append(value)
    return mapped
def feistel_shuffle(values, seed):
    """Return the items of ``values`` reordered by the seeded Feistel mapping.

    Output slot ``k`` holds ``values[j]``, where ``j`` is what
    ``multi_feistel`` maps ``k`` to. The precomputed round table is used.
    """
    size = len(values)
    order = multi_feistel(size, list(range(size)), seed, True)
    return [values[j] for j in order]
def feistel_shuffle_partial(values, seed, count):
    """Return the items for output slots ``0 .. count - 1`` of the Feistel shuffle.

    Only the requested slots are mapped through ``multi_feistel``, and the
    round hashes are computed on demand rather than tabulated.
    """
    wanted_slots = list(range(count))
    order = multi_feistel(len(values), wanted_slots, seed, False)
    return [values[j] for j in order]

View File

@ -0,0 +1,52 @@
from hashlib import blake2s
def hash(x):
    """Return the BLAKE2s digest of the bytes ``x``, truncated to 32 bytes.

    Note: this deliberately shadows the builtin ``hash`` within this module.
    """
    digest = blake2s(x).digest()
    return digest[:32]
def fisher_yates_shuffle(values, seed):
    """
    Return a shuffled copy of ``values``, using ``seed`` as the entropy source.

    The seed is re-hashed repeatedly; every digest is consumed in 3-byte
    big-endian chunks, and each accepted chunk selects the element that is
    swapped into the next output position. ``values`` itself is not modified.
    """
    total = len(values)
    # Entropy is drawn 3 bytes (24 bits) at a time.
    chunk_size = 3
    # Largest value a single 24-bit draw can take.
    ceiling = (1 << (chunk_size * 8)) - 1
    # The draw range bounds the list size; a longer list is a logic error.
    assert total < ceiling

    shuffled = list(values)
    digest = seed
    cursor = 0
    # Start offsets of the whole 3-byte chunks inside a 32-byte digest.
    chunk_starts = range(0, 32 - (32 % chunk_size), chunk_size)
    while cursor < total - 1:
        # A fresh digest supplies the next batch of chunks.
        digest = hash(digest)
        for start in chunk_starts:
            # Number of positions not yet fixed; stop when one is left.
            left = total - cursor
            if left == 1:
                break
            draw = int.from_bytes(digest[start:start + chunk_size], 'big')
            # Draws at or above this cutoff would introduce modulo bias when
            # reduced into ``left``; they are discarded and the next chunk
            # is read instead.
            cutoff = ceiling - ceiling % left
            if draw >= cutoff:
                continue
            # Swap the chosen element into the current position.
            target = cursor + draw % left
            shuffled[cursor], shuffled[target] = shuffled[target], shuffled[cursor]
            cursor += 1
    return shuffled

View File

@ -0,0 +1,42 @@
from hashlib import blake2s
def hash(x):
    """Return the BLAKE2s digest of the bytes ``x``, truncated to 32 bytes.

    Note: this deliberately shadows the builtin ``hash`` within this module.
    """
    digest = blake2s(x).digest()
    return digest[:32]
def is_prime(x):
    """Return True if the integer ``x`` is prime, using trial division.

    Values below 2 are not prime. Rejecting them matters to callers that
    search upward for a prime modulus: accepting 1 would yield a modulus of 1.
    """
    if x < 2:
        return False
    # Stop at the first divisor found instead of collecting all of them.
    return all(x % divisor != 0 for divisor in range(2, int(x ** 0.5) + 1))
def values_at_position(n, positions, seed, precompute=False):
    """Map each entry of ``positions`` through a seeded permutation of ``range(n)``.

    The permutation is built from 40 rounds of ``x -> (x**power + a) % p``,
    where ``p`` is the smallest prime >= ``max(n, 2)`` and ``a`` is a 4-byte
    word taken from the (periodically re-hashed) seed. A value that lands in
    the "forbidden" slice ``[n, p)`` is pushed through the same round again
    until it falls below ``n``.

    Parameters:
        n: size of the range being permuted.
        positions: sliceable sequence of integers to map.
        seed: bytes supplying the per-round additive constants.
        precompute: when true, ``x**power % p`` is tabulated for every ``x``
            in ``range(p)`` up front; otherwise ``pow`` is called on demand.

    Returns:
        A list with one mapped value per entry of ``positions``, in order.
    """
    # Work modulo the smallest prime p >= n. The search is floored at 2: with
    # p == 1 the exponent search below would never terminate, because
    # (p - 1) % power is 0 for every candidate exponent.
    p = max(n, 2)
    while not is_prime(p):
        p += 1
    # x -> x**power is a permutation mod p when the prime ``power`` does not
    # divide p - 1; take the smallest odd prime with that property.
    power = 3
    while (p - 1) % power == 0 or not is_prime(power):
        power += 2
    values = positions[::]
    power_of = [pow(i, power, p) for i in range(p)] if precompute else None
    for round_number in range(40):
        # Each round reads one of the eight 4-byte words of the current seed.
        offset = (round_number % 8) * 4
        a = int.from_bytes(seed[offset:offset + 4], 'big')
        if precompute:
            values = [(power_of[v] + a) % p for v in values]
        else:
            values = [(pow(v, power, p) + a) % p for v in values]
        # Re-apply the round to anything that landed in the forbidden slice.
        for i, v in enumerate(values):
            if v < n:
                continue
            while v >= n:
                if precompute:
                    v = (power_of[v] + a) % p
                else:
                    v = (pow(v, power, p) + a) % p
            values[i] = v
        # NOTE(review): the seed is re-hashed after rounds 0, 8, 16, ... rather
        # than after each full pass over its eight words. Kept as-is, since
        # moving it would change every shuffle output.
        if round_number % 8 == 0:
            seed = hash(seed)
    return values
def prime_shuffle(values, seed):
    """Return the items of ``values`` reordered by the seeded prime-field mapping.

    Output slot ``k`` holds ``values[j]``, where ``j`` is what
    ``values_at_position`` maps ``k`` to. The precomputed power table is used.
    """
    size = len(values)
    order = values_at_position(size, list(range(size)), seed, True)
    return [values[j] for j in order]
def prime_shuffle_partial(values, seed, count):
    """Return the items for output slots ``0 .. count - 1`` of the prime shuffle.

    Only the requested slots are mapped through ``values_at_position``, and
    modular powers are computed on demand rather than tabulated.
    """
    wanted_slots = list(range(count))
    order = values_at_position(len(values), wanted_slots, seed, False)
    return [values[j] for j in order]

35
shuffling/test_shuffle.py Normal file
View File

@ -0,0 +1,35 @@
from prime_shuffle import prime_shuffle, prime_shuffle_partial
from feistel_shuffle import feistel_shuffle, feistel_shuffle_partial
from fisher_yates_shuffle import fisher_yates_shuffle
import time
# Size of the list that is shuffled in full, and of the prefix ("committee")
# that the partial variants compute.
count = 100000
subcount = 500


def _time_shuffle(title, shuffle, shuffle_partial=None):
    """Run one shuffle (and optionally its partial variant) and print timings."""
    print(title)
    started = time.time()
    shuffled = shuffle(range(count), b'doge' * 8)
    print(shuffled[:10])
    if shuffle_partial is None:
        print("Total runtime: ", time.time() - started)
        return
    # Taken after the preview print, so the full-shuffle figure includes it.
    midpoint = time.time()
    committee = shuffle_partial(range(count), b'doge' * 8, subcount)
    print(committee[:10])
    print("Total runtime: ", midpoint - started)
    print("Runtime to compute committee: ", time.time() - midpoint)
    print("\n")


_time_shuffle("Testing prime shuffle", prime_shuffle, prime_shuffle_partial)
_time_shuffle("Testing feistel shuffle", feistel_shuffle, feistel_shuffle_partial)
_time_shuffle("Testing Fisher-Yates shuffle", fisher_yates_shuffle)