# research/zksnark/code_to_r1cs.py
#
# Python, 239 lines (8.6 KiB); snapshot dated 2016-12-11 02:29:08 -05:00.
import ast
# Python 2's ast module has no `arg` node class (function parameters are
# ast.Name nodes there).  Alias it to NoneType so that
# `isinstance(node, ast.arg)` simply evaluates to False on Python 2.
if not hasattr(ast, 'arg'):
    ast.arg = type(None)
def parse(code):
    """Parse Python source text and return its list of top-level statements."""
    module = ast.parse(code)
    return module.body
# Takes code of the form
# def foo(arg1, arg2 ...):
# x = arg1 + arg2
# y = ...
# return x + y
# And extracts the inputs and the body, where
# it expects the body to be a sequence of
# variable assignments (variables are immutable;
# can only be set once) and a return statement at the end
def extract_inputs_and_body(code):
o = []
if len(code) != 1 or not isinstance(code[0], ast.FunctionDef):
raise Exception("Expecting function declaration")
# Gather the list of input variables
inputs = []
for arg in code[0].args.args:
if isinstance(arg, ast.arg):
assert isinstance(arg.arg, str)
inputs.append(arg.arg)
elif isinstance(arg, ast.Name):
inputs.append(arg.id)
else:
raise Exception("Invalid arg: %r" % ast.dump(arg))
# Gather the body
body = []
returned = False
for c in code[0].body:
if not isinstance(c, (ast.Assign, ast.Return)):
raise Exception("Expected variable assignment or return")
if returned:
raise Exception("Cannot do stuff after a return statement")
if isinstance(c, ast.Return):
returned = True
body.append(c)
return inputs, body
# Turn a body of possibly nested expressions into a flat list of simple
# statements (x = y, x = y op z), preserving statement order.
def flatten_body(body):
    """Flatten each statement of ``body`` in turn and concatenate the results."""
    return [simple for stmt in body for simple in flatten_stmt(stmt)]
# Counter for fresh intermediate variable names.  It is a one-element list
# that mksymbol() mutates in place, so no `global` statement is needed.
next_symbol = [0]


def mksymbol():
    """Return a new, never-before-returned name: 'sym_1', 'sym_2', ..."""
    next_symbol[0] = next_symbol[0] + 1
    return 'sym_%d' % next_symbol[0]
# "Flatten" a single statement into a list of simple statements.
# First extract the target variable, then flatten the expression
def flatten_stmt(stmt):
    """Flatten one body statement into a list of simple statements.

    ``x = <expr>`` assigns to variable ``x``; ``return <expr>`` assigns to
    the reserved output variable ``'~out'``.

    Raises:
        Exception: if the statement is not an assignment to a single plain
        variable or a ``return`` with a value.
    """
    if isinstance(stmt, ast.Assign):
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
            raise Exception("Expected assignment to a single variable: %s"
                            % ast.dump(stmt))
        target = stmt.targets[0].id
    elif isinstance(stmt, ast.Return):
        if stmt.value is None:
            raise Exception("Return statement must return a value")
        target = '~out'
    else:
        # Previously fell through with `target` unbound (UnboundLocalError).
        raise Exception("Expected variable assignment or return, got %s"
                        % type(stmt).__name__)
    # Get inner content
    return flatten_expr(target, stmt.value)
# Main method for flattening an expression, plus its private helpers.

# AST binary-operator class -> opcode used in the flattened code.
_BINARY_OPS = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/'}


def _literal_number(expr):
    """Return the value of a numeric-literal node, or None if ``expr`` is not one.

    Accepts ``ast.Num`` (Python 2 and Python 3 before 3.8) and
    ``ast.Constant`` (emitted from 3.8 on; ``ast.Num`` is deprecated there
    and removed in 3.14).  A unary ``-``/``+`` applied to a literal is
    folded, because Python 3 parses ``-5`` as ``UnaryOp(USub, 5)`` whereas
    Python 2 parses it as ``Num(-5)``.  Booleans, strings, None etc. are
    not treated as numbers.
    """
    if isinstance(expr, ast.UnaryOp) and isinstance(expr.op, (ast.USub, ast.UAdd)):
        value = _literal_number(expr.operand)
        if value is not None and isinstance(expr.op, ast.USub):
            value = -value
        return value
    constant_cls = getattr(ast, 'Constant', None)
    if constant_cls is not None and isinstance(expr, constant_cls):
        value = expr.value
        if isinstance(value, bool) or not isinstance(value, (int, float, complex)):
            return None
        return value
    # Look ast.Num up via vars() so that newer Pythons, where attribute
    # access to it warns or where it no longer exists, are handled quietly.
    num_cls = vars(ast).get('Num')
    if num_cls is not None and isinstance(expr, num_cls):
        return expr.n
    return None


def _flatten_operand(expr):
    """Make ``expr`` referenceable from a simple statement.

    Returns ``(var, stmts)``.  A variable name or number literal is used
    directly and ``stmts`` is empty.  Any compound expression is assigned
    to a fresh intermediate symbol, and ``stmts`` computes it.
    """
    if isinstance(expr, ast.Name):
        return expr.id, []
    value = _literal_number(expr)
    if value is not None:
        return value, []
    sym = mksymbol()
    return sym, flatten_expr(sym, expr)


def _flatten_pow(target, expr):
    """Expand ``target = base ** k`` where k is a non-negative integer literal.

    k == 0 gives ``target = 1``, k == 1 a plain copy, and larger k gives
    k - 1 chained multiplications by the base.  This could be made more
    efficient via square-and-multiply but oh well.
    """
    exponent = _literal_number(expr.right)
    # A negative exponent would emit no statement for `target` at all, and
    # a fractional one would crash in range(); reject both up front.
    if not isinstance(exponent, int) or exponent < 0:
        raise Exception("Exponent must be a non-negative integer literal: %s"
                        % ast.dump(expr.right))
    if exponent == 0:
        return [['set', target, 1]]
    if exponent == 1:
        return flatten_expr(target, expr.left)
    base, stmts = _flatten_operand(expr.left)
    nxt = base
    for i in range(1, exponent):
        latest = nxt
        # The last multiplication writes straight into the target.
        nxt = target if i == exponent - 1 else mksymbol()
        stmts.append(['*', nxt, latest, base])
    return stmts


def flatten_expr(target, expr):
    """Flatten ``target = expr`` into a list of simple statements.

    Each statement is a list ``[op, target, arg1(, arg2)]``.  ``op`` is
    ``'set'`` (``target = arg1``) or one of ``'+'``, ``'-'``, ``'*'``,
    ``'/'`` (``target = arg1 op arg2``).  Arguments are variable names or
    numbers.  Compound sub-expressions are assigned to fresh ``sym_N``
    intermediates (from ``mksymbol``), emitted before the statement that
    uses them.

    ``x ** k`` requires a non-negative integer literal ``k`` and is
    expanded into repeated multiplication.  ``-e`` becomes ``0 - e``.

    Raises:
        Exception: on unsupported operators, exponents or expression types.
    """
    # x = y
    if isinstance(expr, ast.Name):
        return [['set', target, expr.id]]
    # x = 5 (including -5, which Python 3 parses as a unary minus)
    value = _literal_number(expr)
    if value is not None:
        return [['set', target, value]]
    # x = -y / x = +y
    if isinstance(expr, ast.UnaryOp) and isinstance(expr.op, (ast.USub, ast.UAdd)):
        if isinstance(expr.op, ast.UAdd):
            return flatten_expr(target, expr.operand)
        var, sub = _flatten_operand(expr.operand)
        return sub + [['-', target, 0, var]]
    if not isinstance(expr, ast.BinOp):
        raise Exception("Unexpected statement value: %r" % (expr,))
    # x = y (op) z, or for that matter x = y (op) 5
    if isinstance(expr.op, ast.Pow):
        return _flatten_pow(target, expr)
    op = _BINARY_OPS.get(type(expr.op))
    if op is None:
        raise Exception("Bad operation: %s" % ast.dump(expr.op))
    # Left operand first, then right, so intermediate symbols are numbered
    # in left-to-right order.
    var1, sub1 = _flatten_operand(expr.left)
    var2, sub2 = _flatten_operand(expr.right)
    # The last statement is the assignment itself; sub1/sub2 compute the
    # operands, if they were compound.
    return sub1 + sub2 + [[op, target, var1, var2]]
# Adds a variable or number into one of the vectors: a variable adds +/-1 to
# its own slot, a number adds +/-that number to the slot of the constant 1.
def insert_var(arr, varz, var, used, reverse=False):
    """Add ``var`` (negated if ``reverse``) into coefficient vector ``arr``.

    ``arr`` is indexed like ``varz``.  A variable name adds +/-1 at its own
    position.  An integer constant adds +/-value at slot 0, the ``'~one'``
    slot in ``get_var_placement``'s layout.

    Raises:
        Exception: if a variable is not yet in ``used``, or ``var`` is
        neither a variable name nor an integer.
    """
    import numbers  # local import keeps this block self-contained

    sign = -1 if reverse else 1
    if isinstance(var, str):
        if var not in used:
            raise Exception("Using a variable before it is set!")
        arr[varz.index(var)] += sign
    elif isinstance(var, numbers.Integral):  # also Python 2 `long`
        arr[0] += var * sign
    else:
        # Anything else (e.g. a float) cannot be encoded.  Fail loudly
        # rather than silently dropping the term from the constraint.
        raise Exception("What kind of expression is this? %r" % var)
# Decide which position each variable occupies in the R1CS vectors.
def get_var_placement(inputs, flatcode):
    """Return the ordered variable list that defines vector positions.

    Layout: ``'~one'`` (the constant-1 slot), the inputs in order,
    ``'~out'``, then every other assignment target of ``flatcode`` in
    program order.
    """
    placement = ['~one']
    placement.extend(inputs)
    placement.append('~out')
    placement.extend(stmt[1] for stmt in flatcode
                     if stmt[1] not in inputs and stmt[1] != '~out')
    return placement
# Convert the flattened code generated above into a rank-1 constraint system
def flatcode_to_r1cs(inputs, flatcode):
    """Build the R1CS matrices (A, B, C) for ``flatcode``.

    Each flat statement yields one row (a, b, c).  Each row is indexed like
    ``get_var_placement(inputs, flatcode)`` and encodes
    ``<a, s> * <b, s> = <c, s>`` for a witness vector ``s`` whose slot 0 is 1:

      'set'    target = x        ->  (target - x) * 1 = 0
      '+'/'-'  target = x +/- y  ->  (x +/- y) * 1 = target
      '*'      target = x * y    ->  x * y = target
      '/'      target = x / y    ->  target * y = x

    Raises:
        Exception: if a variable is assigned twice, an operand is used before
        it is set (including a statement reading its own target), or an
        opcode is unknown.
    """
    varz = get_var_placement(inputs, flatcode)
    A, B, C = [], [], []
    used = {i: True for i in inputs}
    for x in flatcode:
        a, b, c = [0] * len(varz), [0] * len(varz), [0] * len(varz)
        op, target = x[0], x[1]
        if target in used:
            raise Exception("Variable already used: %r" % target)
        target_idx = varz.index(target)
        if op == 'set':
            a[target_idx] += 1
            insert_var(a, varz, x[2], used, reverse=True)
            b[0] = 1
        elif op == '+' or op == '-':
            c[target_idx] = 1
            insert_var(a, varz, x[2], used)
            insert_var(a, varz, x[3], used, reverse=(op == '-'))
            b[0] = 1
        elif op == '*':
            c[target_idx] = 1
            insert_var(a, varz, x[2], used)
            insert_var(b, varz, x[3], used)
        elif op == '/':
            insert_var(c, varz, x[2], used)
            a[target_idx] = 1
            insert_var(b, varz, x[3], used)
        else:
            # An all-zero row would be trivially satisfied; refuse instead.
            raise Exception("Unknown operation: %r" % op)
        # Mark the target as set only after its operands were checked, so a
        # statement reading its own target (e.g. y = y * 2) is rejected.
        used[target] = True
        A.append(a)
        B.append(b)
        C.append(c)
    return A, B, C
# Resolve one operand of a flat statement against the witness built so far.
def grab_var(varz, assignment, var):
    """Return the current value of variable ``var``, or ``var`` itself if it is an int.

    Raises:
        Exception: if ``var`` is neither a variable name nor an int.
    """
    if isinstance(var, int):
        return var
    if not isinstance(var, str):
        raise Exception("What kind of expression is this? %r" % var)
    return assignment[varz.index(var)]
# Goes through flattened code and completes the input vector
def assign_variables(inputs, input_vars, flatcode):
    """Execute ``flatcode`` on concrete inputs to build the full witness vector.

    Args:
        inputs: input variable names, in parameter order.
        input_vars: concrete values for ``inputs``, in the same order.
        flatcode: flat statements ``[op, target, arg1(, arg2)]``.

    Returns:
        A list indexed like ``get_var_placement(inputs, flatcode)``: slot 0
        holds 1, followed by the input values, then '~out' and every
        intermediate variable as computed by ``flatcode``.

    Raises:
        Exception: if the number of values does not match the number of
        inputs, or a statement has an unknown opcode.
    """
    input_vars = list(input_vars)
    # Fail loudly on a mismatch.  Otherwise missing inputs would silently
    # stay 0, and extra values would land in the '~out'/intermediate slots.
    if len(input_vars) != len(inputs):
        raise Exception("Expected %d input values, got %d"
                        % (len(inputs), len(input_vars)))
    varz = get_var_placement(inputs, flatcode)
    assignment = [0] * len(varz)
    assignment[0] = 1
    for i, inp in enumerate(input_vars):
        assignment[i + 1] = inp
    for x in flatcode:
        op = x[0]
        if op == 'set':
            value = grab_var(varz, assignment, x[2])
        elif op in ('+', '-', '*', '/'):
            left = grab_var(varz, assignment, x[2])
            right = grab_var(varz, assignment, x[3])
            if op == '+':
                value = left + right
            elif op == '-':
                value = left - right
            elif op == '*':
                value = left * right
            else:
                # NOTE: on Python 2, int / int floors, so the witness only
                # satisfies target * right == left for exact divisions.
                value = left / right
        else:
            raise Exception("Unknown operation: %r" % op)
        assignment[varz.index(x[1])] = value
    return assignment
def code_to_r1cs_with_inputs(code, input_vars, verbose=True):
    """Compile a single-function source string to R1CS plus a witness.

    Args:
        code: source text of exactly one function definition, in the subset
            accepted by extract_inputs_and_body and flatten_body.
        input_vars: concrete input values, in parameter order.
        verbose: print each intermediate stage (default True, as before).

    Returns:
        (r, A, B, C): the witness vector and the constraint matrices, all
        indexed by get_var_placement(inputs, flatcode).
    """
    # print(x) with a single argument behaves the same on Python 2 and 3.
    inputs, body = extract_inputs_and_body(parse(code))
    if verbose:
        print('Inputs')
        print(inputs)
        print('Body')
        print(body)
    flatcode = flatten_body(body)
    if verbose:
        print('Flatcode')
        print(flatcode)
        print('Input var assignment')
        print(get_var_placement(inputs, flatcode))
    A, B, C = flatcode_to_r1cs(inputs, flatcode)
    r = assign_variables(inputs, input_vars, flatcode)
    return r, A, B, C
# Demo: compile x**3 + x + 5 and evaluate it at x = 3.  The variable layout
# is ['~one', 'x', '~out', 'sym_1', 'y', 'sym_2'], so r = [1, 3, 35, 9, 27, 30].
# The function body inside the string must be indented, or ast.parse fails.
r, A, B, C = code_to_r1cs_with_inputs("""
def qeval(x):
    y = x**3
    return y + x + 5
""", [3])
print('r')
print(r)
print('A')
for x in A:
    print(x)
print('B')
for x in B:
    print(x)
print('C')
for x in C:
    print(x)