diff --git a/casper_sm/.test.py.swp b/casper_sm/.test.py.swp deleted file mode 100644 index 19aed53..0000000 Binary files a/casper_sm/.test.py.swp and /dev/null differ diff --git a/zksnark/code_to_r1cs.py b/zksnark/code_to_r1cs.py new file mode 100644 index 0000000..d85a61d --- /dev/null +++ b/zksnark/code_to_r1cs.py @@ -0,0 +1,238 @@ +import ast +if 'arg' not in dir(ast): + ast.arg = type(None) + +def parse(code): + return ast.parse(code).body + +# Takes code of the form +# def foo(arg1, arg2 ...): +# x = arg1 + arg2 +# y = ... +# return x + y +# And extracts the inputs and the body, where +# it expects the body to be a sequence of +# variable assignments (variables are immutable; +# can only be set once) and a return statement at the end +def extract_inputs_and_body(code): + o = [] + if len(code) != 1 or not isinstance(code[0], ast.FunctionDef): + raise Exception("Expecting function declaration") + # Gather the list of input variables + inputs = [] + for arg in code[0].args.args: + if isinstance(arg, ast.arg): + assert isinstance(arg.arg, str) + inputs.append(arg.arg) + elif isinstance(arg, ast.Name): + inputs.append(arg.id) + else: + raise Exception("Invalid arg: %r" % ast.dump(arg)) + # Gather the body + body = [] + returned = False + for c in code[0].body: + if not isinstance(c, (ast.Assign, ast.Return)): + raise Exception("Expected variable assignment or return") + if returned: + raise Exception("Cannot do stuff after a return statement") + if isinstance(c, ast.Return): + returned = True + body.append(c) + return inputs, body + +# Convert a body with potentially complex expressions into +# simple expressions of the form x = y or x = y * z +def flatten_body(body): + o = [] + for c in body: + o.extend(flatten_stmt(c)) + return o + +# Generate a dummy variable +next_symbol = [0] +def mksymbol(): + next_symbol[0] += 1 + return 'sym_'+str(next_symbol[0]) + +# "Flatten" a single statement into a list of simple statements. +# First extract the target variable, then flatten the expression +def flatten_stmt(stmt): + # Get target variable + if isinstance(stmt, ast.Assign): + assert len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name) + target = stmt.targets[0].id + elif isinstance(stmt, ast.Return): + target = '~out' + # Get inner content + return flatten_expr(target, stmt.value) + +# Main method for flattening an expression +def flatten_expr(target, expr): + # x = y + if isinstance(expr, ast.Name): + return [['set', target, expr.id]] + # x = 5 + elif isinstance(expr, ast.Num): + return [['set', target, expr.n]] + # x = y (op) z + # Or, for that matter, x = y (op) 5 + elif isinstance(expr, ast.BinOp): + if isinstance(expr.op, ast.Add): + op = '+' + elif isinstance(expr.op, ast.Mult): + op = '*' + elif isinstance(expr.op, ast.Sub): + op = '-' + elif isinstance(expr.op, ast.Div): + op = '/' + # Exponentiation gets compiled to repeat multiplication, + # requires constant exponent + elif isinstance(expr.op, ast.Pow): + assert isinstance(expr.right, ast.Num) + if expr.right.n == 0: + return [['set', target, 1]] + elif expr.right.n == 1: + return flatten_expr(target, expr.left) + else: # This could be made more efficient via square-and-multiply but oh well + if isinstance(expr.left, (ast.Name, ast.Num)): + nxt = base = expr.left.id if isinstance(expr.left, ast.Name) else expr.left.n + o = [] + else: + nxt = base = mksymbol() + o = flatten_expr(base, expr.left) + for i in range(1, expr.right.n): + latest = nxt + nxt = target if i == expr.right.n - 1 else mksymbol() + o.append(['*', nxt, latest, base]) + return o + else: + raise Exception("Bad operation: " % ast.dump(stmt.op)) + # If the subexpression is a variable or a number, then include it directly + if isinstance(expr.left, (ast.Name, ast.Num)): + var1 = expr.left.id if isinstance(expr.left, ast.Name) else expr.left.n + sub1 = [] + # If one of the subexpressions is itself a compound expression, recursively + # apply this method to it using an intermediate variable + else: + var1 = mksymbol() + sub1 = flatten_expr(var1, expr.left) + # Same for right subexpression as for left subexpression + if isinstance(expr.right, (ast.Name, ast.Num)): + var2 = expr.right.id if isinstance(expr.right, ast.Name) else expr.right.n + sub2 = [] + else: + var2 = mksymbol() + sub2 = flatten_expr(var2, expr.right) + # Last expression represents the assignment; sub1 and sub2 represent the + # processing for the subexpression if any + return sub1 + sub2 + [[op, target, var1, var2]] + else: + raise Exception("Unexpected statement value: %r" % stmt.value) + +# Adds a variable or number into one of the vectors; if it's a variable +# then the slot associated with that variable is set to 1, and if it's +# a number then the slot associated with 1 gets set to that number +def insert_var(arr, varz, var, used, reverse=False): + if isinstance(var, str): + if var not in used: + raise Exception("Using a variable before it is set!") + arr[varz.index(var)] += (-1 if reverse else 1) + elif isinstance(var, int): + arr[0] += var * (-1 if reverse else 1) + +# Maps input, output and intermediate variables to indices +def get_var_placement(inputs, flatcode): + return ['~one'] + [x for x in inputs] + ['~out'] + [c[1] for c in flatcode if c[1] not in inputs and c[1] != '~out'] + + +# Convert the flattened code generated above into a rank-1 constraint system +def flatcode_to_r1cs(inputs, flatcode): + varz = get_var_placement(inputs, flatcode) + A, B, C = [], [], [] + used = {i: True for i in inputs} + for x in flatcode: + a, b, c = [0] * len(varz), [0] * len(varz), [0] * len(varz) + if x[1] in used: + raise Exception("Variable already used: %r" % x[1]) + used[x[1]] = True + if x[0] == 'set': + a[varz.index(x[1])] += 1 + insert_var(a, varz, x[2], used, reverse=True) + b[0] = 1 + elif x[0] == '+' or x[0] == '-': + c[varz.index(x[1])] = 1 + insert_var(a, varz, x[2], used) + insert_var(a, varz, x[3], used, reverse=(x[0] == '-')) + b[0] = 1 + elif x[0] == '*': + c[varz.index(x[1])] = 1 + insert_var(a, varz, x[2], used) + insert_var(b, varz, x[3], used) + elif x[0] == '/': + insert_var(c, varz, x[2], used) + a[varz.index(x[1])] = 1 + insert_var(b, varz, x[3], used) + A.append(a) + B.append(b) + C.append(c) + return A, B, C + +# Get a variable or number given an existing input vector +def grab_var(varz, assignment, var): + if isinstance(var, str): + return assignment[varz.index(var)] + elif isinstance(var, int): + return var + else: + raise Exception("What kind of expression is this? %r" % var) + +# Goes through flattened code and completes the input vector +def assign_variables(inputs, input_vars, flatcode): + varz = get_var_placement(inputs, flatcode) + assignment = [0] * len(varz) + assignment[0] = 1 + for i, inp in enumerate(input_vars): + assignment[i + 1] = inp + for x in flatcode: + if x[0] == 'set': + assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) + elif x[0] == '+': + assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) + grab_var(varz, assignment, x[3]) + elif x[0] == '-': + assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) - grab_var(varz, assignment, x[3]) + elif x[0] == '*': + assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) * grab_var(varz, assignment, x[3]) + elif x[0] == '/': + assignment[varz.index(x[1])] = grab_var(varz, assignment, x[2]) / grab_var(varz, assignment, x[3]) + return assignment + + +def code_to_r1cs_with_inputs(code, input_vars): + inputs, body = extract_inputs_and_body(parse(code)) + print 'Inputs' + print inputs + print 'Body' + print body + flatcode = flatten_body(body) + print 'Flatcode' + print flatcode + print 'Input var assignment' + print get_var_placement(inputs, flatcode) + A, B, C = flatcode_to_r1cs(inputs, flatcode) + r = assign_variables(inputs, input_vars, flatcode) + return r, A, B, C + +r, A, B, C = code_to_r1cs_with_inputs(""" +def qeval(x): + y = x**3 + return y + x + 5 +""", [3]) +print 'r' +print r +print 'A' +for x in A: print x +print 'B' +for x in B: print x +print 'C' +for x in C: print x diff --git a/zksnark/qap_creator.py b/zksnark/qap_creator.py new file mode 100644 index 0000000..4340f70 --- /dev/null +++ b/zksnark/qap_creator.py @@ -0,0 +1,132 @@ +# Polynomials are stored as arrays, where the ith element in +# the array is the ith degree coefficient + +# Multiply two polynomials +def multiply_polys(a, b): + o = [0] * (len(a) + len(b) - 1) + for i in range(len(a)): + for j in range(len(b)): + o[i + j] += a[i] * b[j] + return o + +# Add two polynomials +def add_polys(a, b, subtract=False): + o = [0] * max(len(a), len(b)) + for i in range(len(a)): + o[i] += a[i] + for i in range(len(b)): + o[i] += b[i] * (-1 if subtract else 1) # Reuse the function structure for subtraction + return o + +def subtract_polys(a, b): + return add_polys(a, b, subtract=True) + +# Divide a/b, return quotient and remainder +def div_polys(a, b): + o = [0] * (len(a) - len(b) + 1) + remainder = a + while len(remainder) >= len(b): + leading_fac = remainder[-1] / b[-1] + pos = len(remainder) - len(b) + o[pos] = leading_fac + remainder = subtract_polys(remainder, multiply_polys(b, [0] * pos + [leading_fac]))[:-1] + return o, remainder + +# Evaluate a polynomial at a point +def eval_poly(poly, x): + return sum([poly[i] * x**i for i in range(len(poly))]) + +# Make a polynomial which is zero at {1, 2 ... total_pts}, except +# for `point_loc` where the value is `height` +def mk_singleton(point_loc, height, total_pts): + fac = 1 + for i in range(1, total_pts + 1): + if i != point_loc: + fac *= point_loc - i + o = [height * 1.0 / fac] + for i in range(1, total_pts + 1): + if i != point_loc: + o = multiply_polys(o, [-i, 1]) + return o + +# Assumes vec[0] = p(1), vec[1] = p(2), etc, tries to find p, +# expresses result as [deg 0 coeff, deg 1 coeff...] +def lagrange_interp(vec): + o = [] + for i in range(len(vec)): + o = add_polys(o, mk_singleton(i + 1, vec[i], len(vec))) + for i in range(len(vec)): + assert abs(eval_poly(o, i + 1) - vec[i] < 10**-10), \ + (o, eval_poly(o, i + 1), i+1) + return o + +def transpose(matrix): + return list(map(list, zip(*matrix))) + +# A, B, C = matrices of m vectors of length n, where for each +# 0 <= i < m, we want to satisfy A[i] * B[i] - C[i] = 0 +def r1cs_to_qap(A, B, C): + A, B, C = transpose(A), transpose(B), transpose(C) + new_A = [lagrange_interp(a) for a in A] + new_B = [lagrange_interp(b) for b in B] + new_C = [lagrange_interp(c) for c in C] + Z = [1] + for i in range(1, len(A[0]) + 1): + Z = multiply_polys(Z, [-i, 1]) + return (new_A, new_B, new_C, Z) + +def create_solution_polynomials(r, new_A, new_B, new_C): + Apoly = [] + for rval, a in zip(r, new_A): + Apoly = add_polys(Apoly, multiply_polys([rval], a)) + Bpoly = [] + for rval, b in zip(r, new_B): + Bpoly = add_polys(Bpoly, multiply_polys([rval], b)) + Cpoly = [] + for rval, c in zip(r, new_C): + Cpoly = add_polys(Cpoly, multiply_polys([rval], c)) + o = subtract_polys(multiply_polys(Apoly, Bpoly), Cpoly) + for i in range(1, len(new_A[0]) + 1): + assert abs(eval_poly(o, i)) < 10**-10, (eval_poly(o, i), i) + return Apoly, Bpoly, Cpoly, o + +def create_divisor_polynomial(sol, Z): + quot, rem = div_polys(sol, Z) + for x in rem: + assert abs(x) < 10**-10 + return quot + +r = [1, 3, 35, 9, 27, 30] +A = [[0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 1, 0], + [5, 0, 0, 0, 0, 1]] +B = [[0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0]] +C = [[0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 1], + [0, 0, 1, 0, 0, 0]] + +Ap, Bp, Cp, Z = r1cs_to_qap(A, B, C) +print 'Ap' +for x in Ap: print x +print 'Bp' +for x in Bp: print x +print 'Cp' +for x in Cp: print x +print 'Z' +print Z +Apoly, Bpoly, Cpoly, sol = create_solution_polynomials(r, Ap, Bp, Cp) +print 'Apoly' +print Apoly +print 'Bpoly' +print Bpoly +print 'Cpoly' +print Cpoly +print 'Sol' +print sol +print 'Z cofactor' +print create_divisor_polynomial(sol, Z)