mirror of
https://github.com/status-im/research.git
synced 2025-01-26 06:49:42 +00:00
Cut down to 2 non-trivial FFTs
This commit is contained in:
parent
cc15f8e70c
commit
5f762aee81
@ -10,13 +10,13 @@ from poly_utils import PrimeField
|
||||
# We use maxdeg+1 instead of maxdeg because it's more mathematically
|
||||
# convenient in this case.
|
||||
|
||||
def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
|
||||
def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus, exclude_multiples_of=0):
|
||||
f = PrimeField(modulus)
|
||||
print('Proving %d values are degree <= %d' % (len(values), maxdeg_plus_1))
|
||||
|
||||
# If the degree we are checking for is less than or equal to 32,
|
||||
# use the polynomial directly as a proof
|
||||
if maxdeg_plus_1 <= 32:
|
||||
if maxdeg_plus_1 <= 16:
|
||||
print('Produced FRI proof')
|
||||
return [[x.to_bytes(32, 'big') for x in values]]
|
||||
|
||||
@ -30,6 +30,7 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
|
||||
|
||||
# Select a pseudo-random x coordinate
|
||||
special_x = int.from_bytes(m[1], 'big') % modulus
|
||||
special_x = root_of_unity + 5
|
||||
|
||||
# Calculate the "column" at that x coordinate
|
||||
# (see https://vitalik.ca/general/2017/11/22/starks_part_2.html)
|
||||
@ -45,7 +46,7 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
|
||||
m2 = merkelize(column)
|
||||
|
||||
# Pseudo-randomly select y indices to sample
|
||||
ys = get_pseudorandom_indices(m2[1], len(column), 40)
|
||||
ys = get_pseudorandom_indices(m2[1], len(column), 40, exclude_multiples_of=exclude_multiples_of)
|
||||
|
||||
# Compute the Merkle branches for the values in the polynomial and the column
|
||||
branches = []
|
||||
@ -63,10 +64,10 @@ def prove_low_degree(values, root_of_unity, maxdeg_plus_1, modulus):
|
||||
|
||||
# Recurse...
|
||||
return [o] + prove_low_degree(column, f.exp(root_of_unity, 4),
|
||||
maxdeg_plus_1 // 4, modulus)
|
||||
maxdeg_plus_1 // 4, modulus, exclude_multiples_of=exclude_multiples_of)
|
||||
|
||||
# Verify an FRI proof
|
||||
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus):
|
||||
def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, modulus, exclude_multiples_of=0):
|
||||
f = PrimeField(modulus)
|
||||
|
||||
# Calculate which root of unity we're working with
|
||||
@ -89,10 +90,11 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
|
||||
|
||||
# Calculate the pseudo-random x coordinate
|
||||
special_x = int.from_bytes(merkle_root, 'big') % modulus
|
||||
special_x = root_of_unity + 5
|
||||
|
||||
# Calculate the pseudo-randomly sampled y indices
|
||||
ys = get_pseudorandom_indices(root2, roudeg // 4, 40)
|
||||
|
||||
ys = get_pseudorandom_indices(root2, roudeg // 4, 40,
|
||||
exclude_multiples_of=exclude_multiples_of)
|
||||
|
||||
# Verify for each selected y coordinate that the four points from the
|
||||
# polynomial and the one point from the column that are on that y
|
||||
@ -122,16 +124,23 @@ def verify_low_degree_proof(merkle_root, root_of_unity, proof, maxdeg_plus_1, mo
|
||||
# Verify the direct components of the proof
|
||||
data = [int.from_bytes(x, 'big') for x in proof[-1]]
|
||||
print('Verifying degree <= %d' % maxdeg_plus_1)
|
||||
assert maxdeg_plus_1 <= 32
|
||||
assert maxdeg_plus_1 <= 16
|
||||
|
||||
# Check the Merkle root matches up
|
||||
mtree = merkelize(data)
|
||||
assert mtree[1] == merkle_root
|
||||
|
||||
# Check the degree of the data
|
||||
poly = fft(data, modulus, root_of_unity, inv=True)
|
||||
for i in range(maxdeg_plus_1, len(poly)):
|
||||
assert poly[i] == 0
|
||||
powers = get_power_cycle(root_of_unity, modulus)
|
||||
if exclude_multiples_of:
|
||||
pts = [x for x in range(len(data)) if x % exclude_multiples_of]
|
||||
else:
|
||||
pts = range(len(data))
|
||||
|
||||
poly = f.lagrange_interp([powers[x] for x in pts[:maxdeg_plus_1]],
|
||||
[data[x] for x in pts[:maxdeg_plus_1]])
|
||||
for x in pts[maxdeg_plus_1:]:
|
||||
assert f.eval_poly_at(poly, powers[x]) == data[x]
|
||||
|
||||
print('FRI proof verified')
|
||||
return True
|
||||
|
@ -10,7 +10,8 @@ modulus = 2**256 - 2**32 * 351 + 1
|
||||
f = PrimeField(modulus)
|
||||
nonresidue = 7
|
||||
|
||||
spot_check_security_factor = 240
|
||||
spot_check_security_factor = 80
|
||||
extension_factor = 8
|
||||
|
||||
# Compute a MIMC permutation for 2**logsteps steps
|
||||
def mimc(inp, logsteps, round_constants):
|
||||
@ -25,9 +26,8 @@ def mimc(inp, logsteps, round_constants):
|
||||
def mk_mimc_proof(inp, logsteps, round_constants):
|
||||
start_time = time.time()
|
||||
assert logsteps <= 29
|
||||
logprecision = logsteps + 3
|
||||
steps = 2**logsteps
|
||||
precision = 2**logprecision
|
||||
precision = steps * extension_factor
|
||||
|
||||
# Root of unity such that x^precision=1
|
||||
root_of_unity = f.exp(7, (modulus-1)//precision)
|
||||
@ -37,9 +37,9 @@ def mk_mimc_proof(inp, logsteps, round_constants):
|
||||
subroot = f.exp(root_of_unity, skips)
|
||||
|
||||
# Powers of the root of unity, our computational trace will be
|
||||
# along the sequence of roots of unity
|
||||
xs = get_power_cycle(subroot, modulus)
|
||||
last_step_position = xs[steps-1]
|
||||
# along the sequence of sub-roots
|
||||
xs = get_power_cycle(root_of_unity, modulus)
|
||||
last_step_position = xs[(steps-1)*extension_factor]
|
||||
|
||||
# Generate the computational trace
|
||||
values = [inp]
|
||||
@ -61,38 +61,29 @@ def mk_mimc_proof(inp, logsteps, round_constants):
|
||||
|
||||
# Create the composed polynomial such that
|
||||
# C(P(x), P(rx), K(x)) = P(rx) - P(x)**3 - K(x)
|
||||
term1 = f.multiply_base(values_polynomial, subroot)
|
||||
term2 = fft([f.exp(x, 3) for x in p_evaluations], modulus, root_of_unity, inv=True)[:len(values_polynomial) * 3 - 2]
|
||||
c_of_values = f.sub_polys(f.sub_polys(term1, term2), constants_polynomial)
|
||||
c_of_p_evaluations = [(p_evaluations[(i+8)%precision] - f.exp(p_evaluations[i], 3) -
|
||||
constants_mini_extension[i % len(constants_mini_extension)]) % modulus for i in range(precision)]
|
||||
c_of_p_evaluations = [(p_evaluations[(i+extension_factor)%precision] -
|
||||
f.exp(p_evaluations[i], 3) -
|
||||
constants_mini_extension[i % len(constants_mini_extension)])
|
||||
% modulus for i in range(precision)]
|
||||
print('Computed C(P, K) polynomial')
|
||||
|
||||
# Compute D(x) = C(P(x), P(rx), K(x)) / Z(x)
|
||||
# Z(x) = (x^steps - 1) / (x - x_atlast_step)
|
||||
d = f.divide_by_xnm1(f.mul_polys(c_of_values,
|
||||
[-last_step_position, 1]),
|
||||
steps)
|
||||
# Consistency check
|
||||
# assert (f.eval_poly_at(d, 90833) *
|
||||
# (f.exp(90833, steps) - 1) *
|
||||
# f.inv(f.eval_poly_at([-last_step_position, 1], 90833)) -
|
||||
# f.eval_poly_at(c_of_values, 90833)) % modulus == 0
|
||||
z_num_evaluations = [xs[(i * steps) % precision] - 1 for i in range(precision)]
|
||||
z_num_inv = f.multi_inv(z_num_evaluations)
|
||||
z_den_evaluations = [xs[i] - last_step_position for i in range(precision)]
|
||||
d_evaluations = [cp * zd * zni % modulus for cp, zd, zni in zip(c_of_p_evaluations, z_den_evaluations, z_num_inv)]
|
||||
print('Computed D polynomial')
|
||||
|
||||
# Compute interpolant of ((1, input), (x_atlast_step, output))
|
||||
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
|
||||
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
||||
b = f.div_polys(f.sub_polys(values_polynomial, interpolant), quotient)
|
||||
# Consistency check
|
||||
# assert f.eval_poly_at(f.add_polys(f.mul_polys(b, quotient), interpolant), 7045) == \
|
||||
# f.eval_poly_at(values_polynomial, 7045)
|
||||
print('Computed B polynomial')
|
||||
i_evaluations = [f.eval_poly_at(interpolant, x) for x in xs]
|
||||
|
||||
# Evaluate B and D across the entire subgroup
|
||||
d_evaluations = fft(d, modulus, root_of_unity)
|
||||
b_evaluations = fft(b, modulus, root_of_unity)
|
||||
print('Evaluated low-degree extension of B and D')
|
||||
quotient = f.mul_polys([-1, 1], [-last_step_position, 1])
|
||||
inv_q_evaluations = f.multi_inv([f.eval_poly_at(quotient, x) for x in xs])
|
||||
|
||||
b_evaluations = [((p - i) * invq) % modulus for p, i, invq in zip(p_evaluations, i_evaluations, inv_q_evaluations)]
|
||||
print('Computed B polynomial')
|
||||
|
||||
# Compute their Merkle roots
|
||||
p_mtree = merkelize(p_evaluations)
|
||||
@ -119,16 +110,19 @@ def mk_mimc_proof(inp, logsteps, round_constants):
|
||||
p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
|
||||
b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
|
||||
for i in range(precision)]
|
||||
|
||||
l_mtree = merkelize(l_evaluations)
|
||||
print('Computed random linear combination')
|
||||
|
||||
# Do some spot checks of the Merkle tree at pseudo-random coordinates
|
||||
# Do some spot checks of the Merkle tree at pseudo-random coordinates, excluding
|
||||
# multiples of `extension_factor`
|
||||
branches = []
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_pseudorandom_indices(l_mtree[1], precision - skips, samples)
|
||||
samples = spot_check_security_factor
|
||||
positions = get_pseudorandom_indices(l_mtree[1], precision, samples,
|
||||
exclude_multiples_of=extension_factor)
|
||||
for pos in positions:
|
||||
branches.append(mk_branch(p_mtree, pos))
|
||||
branches.append(mk_branch(p_mtree, pos + skips))
|
||||
branches.append(mk_branch(p_mtree, (pos + skips) % precision))
|
||||
branches.append(mk_branch(d_mtree, pos))
|
||||
branches.append(mk_branch(b_mtree, pos))
|
||||
branches.append(mk_branch(l_mtree, pos))
|
||||
@ -141,7 +135,7 @@ def mk_mimc_proof(inp, logsteps, round_constants):
|
||||
b_mtree[1],
|
||||
l_mtree[1],
|
||||
branches,
|
||||
prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus)]
|
||||
prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus, exclude_multiples_of=extension_factor)]
|
||||
print("STARK computed in %.4f sec" % (time.time() - start_time))
|
||||
return o
|
||||
|
||||
@ -150,9 +144,8 @@ def verify_mimc_proof(inp, logsteps, round_constants, output, proof):
|
||||
p_root, d_root, b_root, l_root, branches, fri_proof = proof
|
||||
start_time = time.time()
|
||||
|
||||
logprecision = logsteps + 3
|
||||
steps = 2**logsteps
|
||||
precision = 2**logprecision
|
||||
precision = steps * extension_factor
|
||||
|
||||
# Get (steps)th root of unity
|
||||
root_of_unity = f.exp(7, (modulus-1)//precision)
|
||||
@ -160,24 +153,25 @@ def verify_mimc_proof(inp, logsteps, round_constants, output, proof):
|
||||
|
||||
# Gets the polynomial representing the round constants
|
||||
skips2 = steps // len(round_constants)
|
||||
constants_mini_polynomial = fft(round_constants, modulus, f.exp(root_of_unity, 8 * skips2), inv=True)
|
||||
constants_mini_polynomial = fft(round_constants, modulus, f.exp(root_of_unity, extension_factor * skips2), inv=True)
|
||||
|
||||
# Verifies the low-degree proofs
|
||||
assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2, modulus)
|
||||
assert verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2, modulus, exclude_multiples_of=extension_factor)
|
||||
|
||||
# Performs the spot checks
|
||||
k1 = int.from_bytes(blake(p_root + d_root + b_root + b'\x01'), 'big')
|
||||
k2 = int.from_bytes(blake(p_root + d_root + b_root + b'\x02'), 'big')
|
||||
k3 = int.from_bytes(blake(p_root + d_root + b_root + b'\x03'), 'big')
|
||||
k4 = int.from_bytes(blake(p_root + d_root + b_root + b'\x04'), 'big')
|
||||
samples = spot_check_security_factor // (logprecision - logsteps)
|
||||
positions = get_pseudorandom_indices(l_root, precision - skips, samples)
|
||||
samples = spot_check_security_factor
|
||||
positions = get_pseudorandom_indices(l_root, precision, samples,
|
||||
exclude_multiples_of=extension_factor)
|
||||
last_step_position = f.exp(root_of_unity, (steps - 1) * skips)
|
||||
for i, pos in enumerate(positions):
|
||||
x = f.exp(root_of_unity, pos)
|
||||
x_to_the_steps = f.exp(x, steps)
|
||||
p_of_x = verify_branch(p_root, pos, branches[i*5])
|
||||
p_of_rx = verify_branch(p_root, pos+skips, branches[i*5 + 1])
|
||||
p_of_rx = verify_branch(p_root, (pos+skips)%precision, branches[i*5 + 1])
|
||||
d_of_x = verify_branch(d_root, pos, branches[i*5 + 2])
|
||||
b_of_x = verify_branch(b_root, pos, branches[i*5 + 3])
|
||||
l_of_x = verify_branch(l_root, pos, branches[i*5 + 4])
|
||||
@ -200,6 +194,6 @@ def verify_mimc_proof(inp, logsteps, round_constants, output, proof):
|
||||
k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
|
||||
k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0
|
||||
|
||||
print('Verified %d consistency checks' % (spot_check_security_factor // (logprecision - logsteps)))
|
||||
print('Verified %d consistency checks' % spot_check_security_factor)
|
||||
print('Verified STARK in %.4f sec' % (time.time() - start_time))
|
||||
return True
|
||||
|
@ -20,7 +20,7 @@ class PrimeField():
|
||||
# Modular inverse using the extended Euclidean algorithm
|
||||
def inv(self, a):
|
||||
if a == 0:
|
||||
raise Exception("Cannot invert 0")
|
||||
return 0
|
||||
lm, hm = 1, 0
|
||||
low, high = a % self.modulus, self.modulus
|
||||
while low > 1:
|
||||
@ -32,12 +32,12 @@ class PrimeField():
|
||||
def multi_inv(self, values):
|
||||
partials = [1]
|
||||
for i in range(len(values)):
|
||||
partials.append(self.mul(partials[-1], values[i]))
|
||||
partials.append(self.mul(partials[-1], values[i] or 1))
|
||||
inv = self.inv(partials[-1])
|
||||
outputs = [0] * len(values)
|
||||
for i in range(len(values), 0, -1):
|
||||
outputs[i-1] = self.mul(partials[i-1], inv)
|
||||
inv = self.mul(inv, values[i-1])
|
||||
outputs[i-1] = self.mul(partials[i-1], inv) if values[i-1] else 0
|
||||
inv = self.mul(inv, values[i-1] or 1)
|
||||
return outputs
|
||||
|
||||
def div(self, x, y):
|
||||
@ -80,11 +80,12 @@ class PrimeField():
|
||||
nums = [self.div_polys(root, [-x, 1]) for x in xs]
|
||||
# Generate denominators by evaluating numerator polys at each x
|
||||
denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
|
||||
invdenoms = self.multi_inv(denoms)
|
||||
# Generate output polynomial, which is the sum of the per-value numerator
|
||||
# polynomials rescaled to have the right y values
|
||||
b = [0 for y in ys]
|
||||
for i in range(len(xs)):
|
||||
yslice = self.div(ys[i], denoms[i])
|
||||
yslice = self.mul(ys[i], invdenoms[i])
|
||||
for j in range(len(ys)):
|
||||
if nums[i][j] and ys[i]:
|
||||
b[j] += nums[i][j] * yslice
|
||||
|
@ -34,10 +34,12 @@ def test_fri():
|
||||
|
||||
def test_stark():
|
||||
INPUT = 3
|
||||
LOGSTEPS = 13
|
||||
import sys
|
||||
LOGSTEPS = int(sys.argv[1]) if len(sys.argv) > 1 else 13
|
||||
# Full STARK test
|
||||
import random
|
||||
constants = [random.randrange(modulus) for i in range(64)]
|
||||
#constants = [random.randrange(modulus) for i in range(64)]
|
||||
constants = [(i**7) ^ 42 for i in range(64)]
|
||||
proof = mk_mimc_proof(INPUT, LOGSTEPS, constants)
|
||||
p_root, d_root, b_root, l_root, branches, fri_proof = proof
|
||||
L1 = bin_length(compress_branches(branches))
|
||||
|
@ -9,9 +9,14 @@ def get_power_cycle(r, modulus):
|
||||
return o[:-1]
|
||||
|
||||
# Extract pseudorandom indices from entropy
|
||||
def get_pseudorandom_indices(seed, modulus, count):
|
||||
def get_pseudorandom_indices(seed, modulus, count, exclude_multiples_of=0):
|
||||
assert modulus < 2**24
|
||||
data = seed
|
||||
while len(data) < 4 * count:
|
||||
data += blake(data[-32:])
|
||||
if exclude_multiples_of == 0:
|
||||
return [int.from_bytes(data[i: i+4], 'big') % modulus for i in range(0, count * 4, 4)]
|
||||
else:
|
||||
real_modulus = modulus * (exclude_multiples_of - 1) // exclude_multiples_of
|
||||
o = [int.from_bytes(data[i: i+4], 'big') % real_modulus for i in range(0, count * 4, 4)]
|
||||
return [x+1+x//(exclude_multiples_of-1) for x in o]
|
||||
|
Loading…
x
Reference in New Issue
Block a user