Cut down to 2 non-trivial FFTs

This commit is contained in:
Vitalik Buterin 2018-07-11 11:46:21 -04:00
parent cc15f8e70c
commit 5f762aee81
5 changed files with 73 additions and 62 deletions

View File

@ -10,13 +10,13 @@ from poly_utils import PrimeField
# We use maxdeg+1 instead of maxdeg because it's more mathematically # We use maxdeg+1 instead of maxdeg because it's more mathematically
# convenient in this case. # convenient in this case.
def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus): def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus, exclude_multiples_of=0):
f = PrimeField(modulus) f = PrimeField(modulus)
print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1)) print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))
# If the degree we are checking for is less than or equal to 32, # If the degree we are checking for is less than or equal to 32,
# use the polynomial directly as a proof # use the polynomial directly as a proof
if maxdeg_plus_1 <= 32: if maxdeg_plus_1 <= 16:
print('Produced FRI proof') print('Produced FRI proof')
return [[x.to_bytes(32, 'big') for x in values]] return [[x.to_bytes(32, 'big') for x in values]]
@ -30,6 +30,7 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
# Select a pseudo-random x coordinate # Select a pseudo-random x coordinate
special_x = int.from_bytes(m[1], 'big') % modulus special_x = int.from_bytes(m[1], 'big') % modulus
special_x = root_of_unity + 5
# Calculate the "column" at that x coordinate # Calculate the "column" at that x coordinate
# (see https://vitalik.ca/general/2017/11/22/starks_part_2.html) # (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
@ -45,7 +46,7 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
m2 = merkelize(column) m2 = merkelize(column)
# Pseudo-randomly select y indices to sample # Pseudo-randomly select y indices to sample
ys = get_pseudorandom_indices(m2[1], len(column), 40) ys = get_pseudorandom_indices(m2[1], len(column), 40, exclude_multiples_of=exclude_multiples_of)
# Compute the Merkle branches for the values in the polynomial and the column # Compute the Merkle branches for the values in the polynomial and the column
branches = [] branches = []
@ -63,10 +64,10 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
# Recurse... # Recurse...
return [o] + prove_low_degree(column, f.exp(root_of_unity, 4), return [o] + prove_low_degree(column, f.exp(root_of_unity, 4),
maxdeg_plus_1 // 4, modulus) maxdeg_plus_1 // 4, modulus, exclude_multiples_of=exclude_multiples_of)
# Verify an FRI proof # Verify an FRI proof
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus): def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus, exclude_multiples_of=0):
f = PrimeField(modulus) f = PrimeField(modulus)
# Calculate which root of unity we're working with # Calculate which root of unity we're working with
@ -89,10 +90,11 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
# Calculate the pseudo-random x coordinate # Calculate the pseudo-random x coordinate
special_x = int.from_bytes(merkle_root, 'big') % modulus special_x = int.from_bytes(merkle_root, 'big') % modulus
special_x = root_of_unity + 5
# Calculate the pseudo-randomly sampled y indices # Calculate the pseudo-randomly sampled y indices
ys = get_pseudorandom_indices(root2, roudeg // 4, 40) ys = get_pseudorandom_indices(root2, roudeg // 4, 40,
exclude_multiples_of=exclude_multiples_of)
# Verify for each selected y coordinate that the four points from the # Verify for each selected y coordinate that the four points from the
# polynomial and the one point from the column that are on that y # polynomial and the one point from the column that are on that y
@ -122,16 +124,23 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
# Verify the direct components of the proof # Verify the direct components of the proof
data = [int.from_bytes(x, 'big') for x in proof[-1]] data = [int.from_bytes(x, 'big') for x in proof[-1]]
print('Verifying degree <= %d' % maxdeg_plus_1) print('Verifying degree <= %d' % maxdeg_plus_1)
assert maxdeg_plus_1 <= 32 assert maxdeg_plus_1 <= 16
# Check the Merkle root matches up # Check the Merkle root matches up
mtree = merkelize(data) mtree = merkelize(data)
assert mtree[1] == merkle_root assert mtree[1] == merkle_root
# Check the degree of the data # Check the degree of the data
poly = fft(data, modulus, root_of_unity, inv=True) powers = get_power_cycle(root_of_unity, modulus)
for i in range(maxdeg_plus_1, len(poly)): if exclude_multiples_of:
assert poly[i] == 0 pts = [x for x in range(len(data)) if x % exclude_multiples_of]
else:
pts = range(len(data))
poly = f.lagrange_interp([powers[x] for x in pts[:maxdeg_plus_1]],
[data[x] for x in pts[:maxdeg_plus_1]])
for x in pts[maxdeg_plus_1:]:
assert f.eval_poly_at(poly, powers[x]) == data[x]
print('FRI proof verified') print('FRI proof verified')
return True return True

View File

@ -10,7 +10,8 @@ modulus = 2**256 - 2**32 * 351 + 1
f = PrimeField(modulus) f = PrimeField(modulus)
nonresidue = 7 nonresidue = 7
spot_check_security_factor = 240 spot_check_security_factor = 80
extension_factor = 8
# Compute a MIMC permutation for 2**logsteps steps # Compute a MIMC permutation for 2**logsteps steps
def mimc(inp, logsteps, round_constants): def mimc(inp, logsteps, round_constants):
@ -25,9 +26,8 @@ def mimc(inp, logsteps, round_constants):
def mk_mimc_proof(inp, logsteps, round_constants): def mk_mimc_proof(inp, logsteps, round_constants):
start_time = time.time() start_time = time.time()
assert logsteps <= 29 assert logsteps <= 29
logprecision = logsteps + 3
steps = 2**logsteps steps = 2**logsteps
precision = 2**logprecision precision = steps * extension_factor
# Root of unity such that x^precision=1 # Root of unity such that x^precision=1
root_of_unity = f.exp(7, (modulus-1)//precision) root_of_unity = f.exp(7, (modulus-1)//precision)
@ -37,9 +37,9 @@ def mk_mimc_proof(inp, logsteps, round_constants):
subroot = f.exp(root_of_unity, skips) subroot = f.exp(root_of_unity, skips)
# Powers of the root of unity, our computational trace will be # Powers of the root of unity, our computational trace will be
# along the sequence of roots of unity # along the sequence of sub-roots
xs = get_power_cycle(subroot, modulus) xs = get_power_cycle(root_of_unity, modulus)
last_step_position = xs[steps-1] last_step_position = xs[(steps-1)*extension_factor]
# Generate the computational trace # Generate the computational trace
values = [inp] values = [inp]
@ -61,38 +61,29 @@ def mk_mimc_proof(inp, logsteps, round_constants):
# Create the composed polynomial such that # Create the composed polynomial such that
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x) # C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
term1 = f.multiply_base(values_polynomial, subroot) c_of_p_evaluations = [(p_evaluations[(i+extension_factor)%precision] -
term2 = fft([f.exp(x, 3) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2] f.exp(p_evaluations[i], 3) -
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial) constants_mini_extension[i % len(constants_mini_extension)])
c_of_p_evaluations = [(p_evaluations[(i+8)%precision] - f.exp(p_evaluations[i], 3) - % modulus for i in range(precision)]
constants_mini_extension[i % len(constants_mini_extension)]) % modulus for i in range(precision)]
print('Computed C(P, K) polynomial') print('Computed C(P, K) polynomial')
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x) # Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
# Z(x) = (x^steps - 1) / (x - x_atlast_step) # Z(x) = (x^steps - 1) / (x - x_atlast_step)
d = f.divide_by_xnm1(f.mul_polys(c_of_values, z_num_evaluations = [xs[(i * steps) % precision] - 1 for i in range(precision)]
[-last_step_position, 1]), z_num_inv = f.multi_inv(z_num_evaluations)
steps) z_den_evaluations = [xs[i] - last_step_position for i in range(precision)]
# Consistency check d_evaluations = [cp * zd * zni % modulus for cp, zd, zni in zip(c_of_p_evaluations, z_den_evaluations, z_num_inv)]
# assert (f.eval_poly_at(d, 90833) *
# (f.exp(90833, steps) - 1) *
# f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
# f.eval_poly_at(c_of_values, 90833)) % modulus == 0
print('Computed D polynomial') print('Computed D polynomial')
# Compute interpolant of ((1, input), (x_atlast_step, output)) # Compute interpolant of ((1, input), (x_atlast_step, output))
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output]) interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
quotient = f.mul_polys([-1, 1], [-last_step_position, 1]) i_evaluations = [f.eval_poly_at(interpolant, x) for x in xs]
b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient)
# Consistency check
# assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == \
# f.eval_poly_at(values_polynomial, 7045)
print('Computed B polynomial')
# Evaluate B and D across the entire subgroup quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
d_evaluations = fft(d, modulus, root_of_unity) inv_q_evaluations = f.multi_inv([f.eval_poly_at(quotient, x) for x in xs])
b_evaluations = fft(b, modulus, root_of_unity)
print('Evaluated low-degree extension of B and D') b_evaluations = [((p - i) * invq) % modulus for p, i, invq in zip(p_evaluations, i_evaluations, inv_q_evaluations)]
print('Computed B polynomial')
# Compute their Merkle roots # Compute their Merkle roots
p_mtree = merkelize(p_evaluations) p_mtree = merkelize(p_evaluations)
@ -119,16 +110,19 @@ def mk_mimc_proof(inp, logsteps, round_constants):
p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] + p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
for i in range(precision)] for i in range(precision)]
l_mtree = merkelize(l_evaluations) l_mtree = merkelize(l_evaluations)
print('Computed random linear combination') print('Computed random linear combination')
# Do some spot checks of the Merkle tree at pseudo-random coordinates # Do some spot checks of the Merkle tree at pseudo-random coordinates, excluding
# multiples of `extension_factor`
branches = [] branches = []
samples = spot_check_security_factor // (logprecision - logsteps) samples = spot_check_security_factor
positions = get_pseudorandom_indices(l_mtree[1], precision - skips, samples) positions = get_pseudorandom_indices(l_mtree[1], precision, samples,
exclude_multiples_of=extension_factor)
for pos in positions: for pos in positions:
branches.append(mk_branch(p_mtree, pos)) branches.append(mk_branch(p_mtree, pos))
branches.append(mk_branch(p_mtree, pos + skips)) branches.append(mk_branch(p_mtree, (pos + skips) % precision))
branches.append(mk_branch(d_mtree, pos)) branches.append(mk_branch(d_mtree, pos))
branches.append(mk_branch(b_mtree, pos)) branches.append(mk_branch(b_mtree, pos))
branches.append(mk_branch(l_mtree, pos)) branches.append(mk_branch(l_mtree, pos))
@ -141,7 +135,7 @@ def mk_mimc_proof(inp, logsteps, round_constants):
b_mtree[1], b_mtree[1],
l_mtree[1], l_mtree[1],
branches, branches,
prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus)] prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus, exclude_multiples_of=extension_factor)]
print("STARK computed in %.4f sec" % (time.time() - start_time)) print("STARK computed in %.4f sec" % (time.time() - start_time))
return o return o
@ -150,9 +144,8 @@ def verify_mimc_proof(inp, logsteps, round_constants, output, proof):
p_root, d_root, b_root, l_root, branches, fri_proof = proof p_root, d_root, b_root, l_root, branches, fri_proof = proof
start_time = time.time() start_time = time.time()
logprecision = logsteps + 3
steps = 2**logsteps steps = 2**logsteps
precision = 2**logprecision precision = steps * extension_factor
# Get (steps)th root of unity # Get (steps)th root of unity
root_of_unity = f.exp(7, (modulus-1)//precision) root_of_unity = f.exp(7, (modulus-1)//precision)
@ -160,24 +153,25 @@ def verify_mimc_proof(inp, logsteps, round_constants, output, proof):
# Gets the polynomial representing the round constants # Gets the polynomial representing the round constants
skips2 = steps // len(round_constants) skips2 = steps // len(round_constants)
constants_mini_polynomial = fft(round_constants, modulus, f.exp(root_of_unity, 8 * skips2), inv=True) constants_mini_polynomial = fft(round_constants, modulus, f.exp(root_of_unity, extension_factor * skips2), inv=True)
# Verifies the low-degree proofs # Verifies the low-degree proofs
assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2, modulus) assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2, modulus, exclude_multiples_of=extension_factor)
# Performs the spot checks # Performs the spot checks
k1 = int.from_bytes(blake(p_root + d_root + b_root + b'\x01'), 'big') k1 = int.from_bytes(blake(p_root + d_root + b_root + b'\x01'), 'big')
k2 = int.from_bytes(blake(p_root + d_root + b_root + b'\x02'), 'big') k2 = int.from_bytes(blake(p_root + d_root + b_root + b'\x02'), 'big')
k3 = int.from_bytes(blake(p_root + d_root + b_root + b'\x03'), 'big') k3 = int.from_bytes(blake(p_root + d_root + b_root + b'\x03'), 'big')
k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big') k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big')
samples = spot_check_security_factor // (logprecision - logsteps) samples = spot_check_security_factor
positions = get_pseudorandom_indices(l_root, precision - skips, samples) positions = get_pseudorandom_indices(l_root, precision, samples,
exclude_multiples_of=extension_factor)
last_step_position = f.exp(root_of_unity, (steps - 1) * skips) last_step_position = f.exp(root_of_unity, (steps - 1) * skips)
for i, pos in enumerate(positions): for i, pos in enumerate(positions):
x = f.exp(root_of_unity, pos) x = f.exp(root_of_unity, pos)
x_to_the_steps = f.exp(x, steps) x_to_the_steps = f.exp(x, steps)
p_of_x = verify_branch(p_root, pos, branches[i*5]) p_of_x = verify_branch(p_root, pos, branches[i*5])
p_of_rx = verify_branch(p_root, pos+skips, branches[i*5 + 1]) p_of_rx = verify_branch(p_root, (pos+skips)%precision, branches[i*5 + 1])
d_of_x = verify_branch(d_root, pos, branches[i*5 + 2]) d_of_x = verify_branch(d_root, pos, branches[i*5 + 2])
b_of_x = verify_branch(b_root, pos, branches[i*5 + 3]) b_of_x = verify_branch(b_root, pos, branches[i*5 + 3])
l_of_x = verify_branch(l_root, pos, branches[i*5 + 4]) l_of_x = verify_branch(l_root, pos, branches[i*5 + 4])
@ -200,6 +194,6 @@ def verify_mimc_proof(inp, logsteps, round_constants, output, proof):
k1 * p_of_x - k2 * p_of_x * x_to_the_steps - k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0 k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0
print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps))) print('Verified %d consistency checks' % spot_check_security_factor)
print('Verified STARK in %.4f sec' % (time.time() - start_time)) print('Verified STARK in %.4f sec' % (time.time() - start_time))
return True return True

View File

@ -20,7 +20,7 @@ class PrimeField():
# Modular inverse using the extended Euclidean algorithm # Modular inverse using the extended Euclidean algorithm
def inv(self, a): def inv(self, a):
if a == 0: if a == 0:
raise Exception("Cannot invert 0") return 0
lm, hm = 1, 0 lm, hm = 1, 0
low, high = a % self.modulus, self.modulus low, high = a % self.modulus, self.modulus
while low > 1: while low > 1:
@ -32,12 +32,12 @@ class PrimeField():
def multi_inv(self, values): def multi_inv(self, values):
partials = [1] partials = [1]
for i in range(len(values)): for i in range(len(values)):
partials.append(self.mul(partials[-1], values[i])) partials.append(self.mul(partials[-1], values[i] or 1))
inv = self.inv(partials[-1]) inv = self.inv(partials[-1])
outputs = [0] * len(values) outputs = [0] * len(values)
for i in range(len(values), 0, -1): for i in range(len(values), 0, -1):
outputs[i-1] = self.mul(partials[i-1], inv) outputs[i-1] = self.mul(partials[i-1], inv) if values[i-1] else 0
inv = self.mul(inv, values[i-1]) inv = self.mul(inv, values[i-1] or 1)
return outputs return outputs
def div(self, x, y): def div(self, x, y):
@ -80,11 +80,12 @@ class PrimeField():
nums = [self.div_polys(root, [-x, 1]) for x in xs] nums = [self.div_polys(root, [-x, 1]) for x in xs]
# Generate denominators by evaluating numerator polys at each x # Generate denominators by evaluating numerator polys at each x
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))] denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
invdenoms = self.multi_inv(denoms)
# Generate output polynomial, which is the sum of the per-value numerator # Generate output polynomial, which is the sum of the per-value numerator
# polynomials rescaled to have the right y values # polynomials rescaled to have the right y values
b = [0 for y in ys] b = [0 for y in ys]
for i in range(len(xs)): for i in range(len(xs)):
yslice = self.div(ys[i], denoms[i]) yslice = self.mul(ys[i], invdenoms[i])
for j in range(len(ys)): for j in range(len(ys)):
if nums[i][j] and ys[i]: if nums[i][j] and ys[i]:
b[j] += nums[i][j] * yslice b[j] += nums[i][j] * yslice

View File

@ -34,10 +34,12 @@ def test_fri():
def test_stark(): def test_stark():
INPUT = 3 INPUT = 3
LOGSTEPS = 13 import sys
LOGSTEPS = int(sys.argv[1]) if len(sys.argv) > 1 else 13
# Full STARK test # Full STARK test
import random import random
constants = [random.randrange(modulus) for i in range(64)] #constants = [random.randrange(modulus) for i in range(64)]
constants = [(i**7) ^ 42 for i in range(64)]
proof = mk_mimc_proof(INPUT, LOGSTEPS, constants) proof = mk_mimc_proof(INPUT, LOGSTEPS, constants)
p_root, d_root, b_root, l_root, branches, fri_proof = proof p_root, d_root, b_root, l_root, branches, fri_proof = proof
L1 = bin_length(compress_branches(branches)) L1 = bin_length(compress_branches(branches))

View File

@ -9,9 +9,14 @@ def get_power_cycle(r, modulus):
return o[:-1] return o[:-1]
# Extract pseudorandom indices from entropy # Extract pseudorandom indices from entropy
def get_pseudorandom_indices(seed, modulus, count): def get_pseudorandom_indices(seed, modulus, count, exclude_multiples_of=0):
assert modulus < 2**24 assert modulus < 2**24
data = seed data = seed
while len(data) < 4 * count: while len(data) < 4 * count:
data += blake(data[-32:]) data += blake(data[-32:])
return [int.from_bytes(data[i: i+4], 'big') % modulus for i in range(0, count * 4, 4)] if exclude_multiples_of == 0:
return [int.from_bytes(data[i: i+4], 'big') % modulus for i in range(0, count * 4, 4)]
else:
real_modulus = modulus * (exclude_multiples_of - 1) // exclude_multiples_of
o = [int.from_bytes(data[i: i+4], 'big') % real_modulus for i in range(0, count * 4, 4)]
return [x+1+x//(exclude_multiples_of-1) for x in o]