Added QAP generator explainer code
This commit is contained in:
parent
4d496f48ca
commit
eec84d1c0c
Binary file not shown.
|
@ -0,0 +1,238 @@
|
|||
import ast
|
||||
if 'arg' not in dir(ast):
|
||||
ast.arg = type(None)
|
||||
|
||||
def parse(code):
|
||||
return ast.parse(code).body
|
||||
|
||||
# Takes code of the form
|
||||
# def foo(arg1, arg2 ...):
|
||||
# x = arg1 + arg2
|
||||
# y = ...
|
||||
# return x + y
|
||||
# And extracts the inputs and the body, where
|
||||
# it expects the body to be a sequence of
|
||||
# variable assignments (variables are immutable;
|
||||
# can only be set once) and a return statement at the end
|
||||
def extract_inputs_and_body(code):
|
||||
o = []
|
||||
if len(code) != 1 or not isinstance(code[0], ast.FunctionDef):
|
||||
raise Exception("Expecting function declaration")
|
||||
# Gather the list of input variables
|
||||
inputs = []
|
||||
for arg in code[0].args.args:
|
||||
if isinstance(arg, ast.arg):
|
||||
assert isinstance(arg.arg, str)
|
||||
inputs.append(arg.arg)
|
||||
elif isinstance(arg, ast.Name):
|
||||
inputs.append(arg.id)
|
||||
else:
|
||||
raise Exception("Invalid arg: %r" % ast.dump(arg))
|
||||
# Gather the body
|
||||
body = []
|
||||
returned = False
|
||||
for c in code[0].body:
|
||||
if not isinstance(c, (ast.Assign, ast.Return)):
|
||||
raise Exception("Expected variable assignment or return")
|
||||
if returned:
|
||||
raise Exception("Cannot do stuff after a return statement")
|
||||
if isinstance(c, ast.Return):
|
||||
returned = True
|
||||
body.append(c)
|
||||
return inputs, body
|
||||
|
||||
# Convert a body with potentially complex expressions into
|
||||
# simple expressions of the form x = y or x = y * z
|
||||
def flatten_body(body):
|
||||
o = []
|
||||
for c in body:
|
||||
o.extend(flatten_stmt(c))
|
||||
return o
|
||||
|
||||
# Generate a dummy variable
|
||||
next_symbol = [0]
|
||||
def mksymbol():
|
||||
next_symbol[0] += 1
|
||||
return 'sym_'+str(next_symbol[0])
|
||||
|
||||
# "Flatten" a single statement into a list of simple statements.
|
||||
# First extract the target variable, then flatten the expression
|
||||
def flatten_stmt(stmt):
|
||||
# Get target variable
|
||||
if isinstance(stmt, ast.Assign):
|
||||
assert len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name)
|
||||
target = stmt.targets[0].id
|
||||
elif isinstance(stmt, ast.Return):
|
||||
target = '~out'
|
||||
# Get inner content
|
||||
return flatten_expr(target, stmt.value)
|
||||
|
||||
# Main method for flattening an expression
|
||||
def flatten_expr(target, expr):
|
||||
# x = y
|
||||
if isinstance(expr, ast.Name):
|
||||
return [['set', target, expr.id]]
|
||||
# x = 5
|
||||
elif isinstance(expr, ast.Num):
|
||||
return [['set', target, expr.n]]
|
||||
# x = y (op) z
|
||||
# Or, for that matter, x = y (op) 5
|
||||
elif isinstance(expr, ast.BinOp):
|
||||
if isinstance(expr.op, ast.Add):
|
||||
op = '+'
|
||||
elif isinstance(expr.op, ast.Mult):
|
||||
op = '*'
|
||||
elif isinstance(expr.op, ast.Sub):
|
||||
op = '-'
|
||||
elif isinstance(expr.op, ast.Div):
|
||||
op = '/'
|
||||
# Exponentiation gets compiled to repeat multiplication,
|
||||
# requires constant exponent
|
||||
elif isinstance(expr.op, ast.Pow):
|
||||
assert isinstance(expr.right, ast.Num)
|
||||
if expr.right.n == 0:
|
||||
return [['set', target, 1]]
|
||||
elif expr.right.n == 1:
|
||||
return flatten_expr(target, expr.left)
|
||||
else: # This could be made more efficient via square-and-multiply but oh well
|
||||
if isinstance(expr.left, (ast.Name, ast.Num)):
|
||||
nxt = base = expr.left.id if isinstance(expr.left, ast.Name) else expr.left.n
|
||||
o = []
|
||||
else:
|
||||
nxt = base = mksymbol()
|
||||
o = flatten_expr(base, expr.left)
|
||||
for i in range(1, expr.right.n):
|
||||
latest = nxt
|
||||
nxt = target if i == expr.right.n - 1 else mksymbol()
|
||||
o.append(['*', nxt, latest, base])
|
||||
return o
|
||||
else:
|
||||
raise Exception("Bad operation: " % ast.dump(stmt.op))
|
||||
# If the subexpression is a variable or a number, then include it directly
|
||||
if isinstance(expr.left, (ast.Name, ast.Num)):
|
||||
var1 = expr.left.id if isinstance(expr.left, ast.Name) else expr.left.n
|
||||
sub1 = []
|
||||
# If one of the subexpressions is itself a compound expression, recursively
|
||||
# apply this method to it using an intermediate variable
|
||||
else:
|
||||
var1 = mksymbol()
|
||||
sub1 = flatten_expr(var1, expr.left)
|
||||
# Same for right subexpression as for left subexpression
|
||||
if isinstance(expr.right, (ast.Name, ast.Num)):
|
||||
var2 = expr.right.id if isinstance(expr.right, ast.Name) else expr.right.n
|
||||
sub2 = []
|
||||
else:
|
||||
var2 = mksymbol()
|
||||
sub2 = flatten_expr(var2, expr.right)
|
||||
# Last expression represents the assignment; sub1 and sub2 represent the
|
||||
# processing for the subexpression if any
|
||||
return sub1 + sub2 + [[op, target, var1, var2]]
|
||||
else:
|
||||
raise Exception("Unexpected statement value: %r" % stmt.value)
|
||||
|
||||
# Adds a variable or number into one of the vectors; if it's a variable
|
||||
# then the slot associated with that variable is set to 1, and if it's
|
||||
# a number then the slot associated with 1 gets set to that number
|
||||
def insert_var(arr, varz, var, used, reverse=False):
|
||||
if isinstance(var, str):
|
||||
if var not in used:
|
||||
raise Exception("Using a variable before it is set!")
|
||||
arr[varz.index(var)] += (-1 if reverse else 1)
|
||||
elif isinstance(var, int):
|
||||
arr[0] += var * (-1 if reverse else 1)
|
||||
|
||||
# Maps input, output and intermediate variables to indices
|
||||
def get_var_placement(inputs, flatcode):
|
||||
return ['~one'] + [x for x in inputs] + ['~out'] + [c[1] for c in flatcode if c[1] not in inputs and c[1] != '~out']
|
||||
|
||||
|
||||
# Convert the flattened code generated above into a rank-1 constraint system
|
||||
def flatcode_to_r1cs(inputs, flatcode):
|
||||
varz = get_var_placement(inputs, flatcode)
|
||||
A, B, C = [], [], []
|
||||
used = {i: True for i in inputs}
|
||||
for x in flatcode:
|
||||
a, b, c = [0] * len(varz), [0] * len(varz), [0] * len(varz)
|
||||
if x[1] in used:
|
||||
raise Exception("Variable already used: %r" % x[1])
|
||||
used[x[1]] = True
|
||||
if x[0] == 'set':
|
||||
a[varz.index(x[1])] += 1
|
||||
insert_var(a, varz, x[2], used, reverse=True)
|
||||
b[0] = 1
|
||||
elif x[0] == '+' or x[0] == '-':
|
||||
c[varz.index(x[1])] = 1
|
||||
insert_var(a, varz, x[2], used)
|
||||
insert_var(a, varz, x[3], used, reverse=(x[0] == '-'))
|
||||
b[0] = 1
|
||||
elif x[0] == '*':
|
||||
c[varz.index(x[1])] = 1
|
||||
insert_var(a, varz, x[2], used)
|
||||
insert_var(b, varz, x[3], used)
|
||||
elif x[0] == '/':
|
||||
insert_var(c, varz, x[2], used)
|
||||
a[varz.index(x[1])] = 1
|
||||
insert_var(b, varz, x[3], used)
|
||||
A.append(a)
|
||||
B.append(b)
|
||||
C.append(c)
|
||||
return A, B, C
|
||||
|
||||
# Get a variable or number given an existing input vector
|
||||
def grab_var(varz, assignment, var):
|
||||
if isinstance(var, str):
|
||||
return assignment[varz.index(var)]
|
||||
elif isinstance(var, int):
|
||||
return var
|
||||
else:
|
||||
raise Exception("What kind of expression is this? %r" % var)
|
||||
|
||||
# Goes through flattened code and completes the input vector
|
||||
def assign_variables(inputs, input_vars, flatcode):
|
||||
varz = get_var_placement(inputs, flatcode)
|
||||
assignment = [0] * len(varz)
|
||||
assignment[0] = 1
|
||||
for i, inp in enumerate(input_vars):
|
||||
assignment[i + 1] = inp
|
||||
for x in flatcode:
|
||||
if x[0] == 'set':
|
||||
assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2])
|
||||
elif x[0] == '+':
|
||||
assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) + grab_var(varz, assignment, x[3])
|
||||
elif x[0] == '-':
|
||||
assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) - grab_var(varz, assignment, x[3])
|
||||
elif x[0] == '*':
|
||||
assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) * grab_var(varz, assignment, x[3])
|
||||
elif x[0] == '/':
|
||||
assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) / grab_var(varz, assignment, x[3])
|
||||
return assignment
|
||||
|
||||
|
||||
def code_to_r1cs_with_inputs(code, input_vars):
|
||||
inputs, body = extract_inputs_and_body(parse(code))
|
||||
print 'Inputs'
|
||||
print inputs
|
||||
print 'Body'
|
||||
print body
|
||||
flatcode = flatten_body(body)
|
||||
print 'Flatcode'
|
||||
print flatcode
|
||||
print 'Input var assignment'
|
||||
print get_var_placement(inputs, flatcode)
|
||||
A, B, C = flatcode_to_r1cs(inputs, flatcode)
|
||||
r = assign_variables(inputs, input_vars, flatcode)
|
||||
return r, A, B, C
|
||||
|
||||
r, A, B, C = code_to_r1cs_with_inputs("""
|
||||
def qeval(x):
|
||||
y = x**3
|
||||
return y + x + 5
|
||||
""", [3])
|
||||
print 'r'
|
||||
print r
|
||||
print 'A'
|
||||
for x in A: print x
|
||||
print 'B'
|
||||
for x in B: print x
|
||||
print 'C'
|
||||
for x in C: print x
|
|
@ -0,0 +1,132 @@
|
|||
# Polynomials are stored as arrays, where the ith element in
|
||||
# the array is the ith degree coefficient
|
||||
|
||||
# Multiply two polynomials
|
||||
def multiply_polys(a, b):
|
||||
o = [0] * (len(a) + len(b) - 1)
|
||||
for i in range(len(a)):
|
||||
for j in range(len(b)):
|
||||
o[i + j] += a[i] * b[j]
|
||||
return o
|
||||
|
||||
# Add two polynomials
|
||||
def add_polys(a, b, subtract=False):
|
||||
o = [0] * max(len(a), len(b))
|
||||
for i in range(len(a)):
|
||||
o[i] += a[i]
|
||||
for i in range(len(b)):
|
||||
o[i] += b[i] * (-1 if subtract else 1) # Reuse the function structure for subtraction
|
||||
return o
|
||||
|
||||
def subtract_polys(a, b):
|
||||
return add_polys(a, b, subtract=True)
|
||||
|
||||
# Divide a/b, return quotient and remainder
|
||||
def div_polys(a, b):
|
||||
o = [0] * (len(a) - len(b) + 1)
|
||||
remainder = a
|
||||
while len(remainder) >= len(b):
|
||||
leading_fac = remainder[-1] / b[-1]
|
||||
pos = len(remainder) - len(b)
|
||||
o[pos] = leading_fac
|
||||
remainder = subtract_polys(remainder, multiply_polys(b, [0] * pos + [leading_fac]))[:-1]
|
||||
return o, remainder
|
||||
|
||||
# Evaluate a polynomial at a point
|
||||
def eval_poly(poly, x):
|
||||
return sum([poly[i] * x**i for i in range(len(poly))])
|
||||
|
||||
# Make a polynomial which is zero at {1, 2 ... total_pts}, except
|
||||
# for `point_loc` where the value is `height`
|
||||
def mk_singleton(point_loc, height, total_pts):
|
||||
fac = 1
|
||||
for i in range(1, total_pts + 1):
|
||||
if i != point_loc:
|
||||
fac *= point_loc - i
|
||||
o = [height * 1.0 / fac]
|
||||
for i in range(1, total_pts + 1):
|
||||
if i != point_loc:
|
||||
o = multiply_polys(o, [-i, 1])
|
||||
return o
|
||||
|
||||
# Assumes vec[0] = p(1), vec[1] = p(2), etc, tries to find p,
|
||||
# expresses result as [deg 0 coeff, deg 1 coeff...]
|
||||
def lagrange_interp(vec):
|
||||
o = []
|
||||
for i in range(len(vec)):
|
||||
o = add_polys(o, mk_singleton(i + 1, vec[i], len(vec)))
|
||||
for i in range(len(vec)):
|
||||
assert abs(eval_poly(o, i + 1) - vec[i] < 10**-10), \
|
||||
(o, eval_poly(o, i + 1), i+1)
|
||||
return o
|
||||
|
||||
def transpose(matrix):
|
||||
return list(map(list, zip(*matrix)))
|
||||
|
||||
# A, B, C = matrices of m vectors of length n, where for each
|
||||
# 0 <= i < m, we want to satisfy A[i] * B[i] - C[i] = 0
|
||||
def r1cs_to_qap(A, B, C):
|
||||
A, B, C = transpose(A), transpose(B), transpose(C)
|
||||
new_A = [lagrange_interp(a) for a in A]
|
||||
new_B = [lagrange_interp(b) for b in B]
|
||||
new_C = [lagrange_interp(c) for c in C]
|
||||
Z = [1]
|
||||
for i in range(1, len(A[0]) + 1):
|
||||
Z = multiply_polys(Z, [-i, 1])
|
||||
return (new_A, new_B, new_C, Z)
|
||||
|
||||
def create_solution_polynomials(r, new_A, new_B, new_C):
|
||||
Apoly = []
|
||||
for rval, a in zip(r, new_A):
|
||||
Apoly = add_polys(Apoly, multiply_polys([rval], a))
|
||||
Bpoly = []
|
||||
for rval, b in zip(r, new_B):
|
||||
Bpoly = add_polys(Bpoly, multiply_polys([rval], b))
|
||||
Cpoly = []
|
||||
for rval, c in zip(r, new_C):
|
||||
Cpoly = add_polys(Cpoly, multiply_polys([rval], c))
|
||||
o = subtract_polys(multiply_polys(Apoly, Bpoly), Cpoly)
|
||||
for i in range(1, len(new_A[0]) + 1):
|
||||
assert abs(eval_poly(o, i)) < 10**-10, (eval_poly(o, i), i)
|
||||
return Apoly, Bpoly, Cpoly, o
|
||||
|
||||
def create_divisor_polynomial(sol, Z):
|
||||
quot, rem = div_polys(sol, Z)
|
||||
for x in rem:
|
||||
assert abs(x) < 10**-10
|
||||
return quot
|
||||
|
||||
r = [1, 3, 35, 9, 27, 30]
|
||||
A = [[0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0],
|
||||
[0, 1, 0, 0, 1, 0],
|
||||
[5, 0, 0, 0, 0, 1]]
|
||||
B = [[0, 1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0]]
|
||||
C = [[0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 1, 0, 0, 0]]
|
||||
|
||||
Ap, Bp, Cp, Z = r1cs_to_qap(A, B, C)
|
||||
print 'Ap'
|
||||
for x in Ap: print x
|
||||
print 'Bp'
|
||||
for x in Bp: print x
|
||||
print 'Cp'
|
||||
for x in Cp: print x
|
||||
print 'Z'
|
||||
print Z
|
||||
Apoly, Bpoly, Cpoly, sol = create_solution_polynomials(r, Ap, Bp, Cp)
|
||||
print 'Apoly'
|
||||
print Apoly
|
||||
print 'Bpoly'
|
||||
print Bpoly
|
||||
print 'Cpoly'
|
||||
print Cpoly
|
||||
print 'Sol'
|
||||
print sol
|
||||
print 'Z cofactor'
|
||||
print create_divisor_polynomial(sol, Z)
|
Loading…
Reference in New Issue