diff --git a/.cirrus.yml b/.cirrus.yml index ffbd820..727a73f 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -322,3 +322,10 @@ task: test_script: - ./ci/cirrus.sh << : *CAT_LOGS + +task: + name: "sage prover" + << : *LINUX_CONTAINER + test_script: + - cd sage + - sage prove_group_implementations.sage diff --git a/ci/linux-debian.Dockerfile b/ci/linux-debian.Dockerfile index fdba12a..5cccbb5 100644 --- a/ci/linux-debian.Dockerfile +++ b/ci/linux-debian.Dockerfile @@ -19,7 +19,8 @@ RUN apt-get install --no-install-recommends --no-upgrade -y \ gcc-arm-linux-gnueabihf libc6-dev-armhf-cross libc6-dbg:armhf \ gcc-aarch64-linux-gnu libc6-dev-arm64-cross libc6-dbg:arm64 \ gcc-powerpc64le-linux-gnu libc6-dev-ppc64el-cross libc6-dbg:ppc64el \ - wine gcc-mingw-w64-x86-64 + wine gcc-mingw-w64-x86-64 \ + sagemath # Run a dummy command in wine to make it set up configuration RUN wine64-stable xcopy || true diff --git a/sage/group_prover.sage b/sage/group_prover.sage index b200bfe..9305c21 100644 --- a/sage/group_prover.sage +++ b/sage/group_prover.sage @@ -164,6 +164,9 @@ class constraints: def negate(self): return constraints(zero=self.nonzero, nonzero=self.zero) + def map(self, fun): + return constraints(zero={fun(k): v for k, v in self.zero.items()}, nonzero={fun(k): v for k, v in self.nonzero.items()}) + def __add__(self, other): zero = self.zero.copy() zero.update(other.zero) @@ -177,6 +180,30 @@ class constraints: def __repr__(self): return "%s" % self +def normalize_factor(p): + """Normalizes the sign of primitive polynomials (as returned by factor()) + + This function ensures that the polynomial has a positive leading coefficient. + + This is necessary because recent sage versions (starting with v9.3 or v9.4, + we don't know) are inconsistent about the placement of the minus sign in + polynomial factorizations: + ``` + sage: R. = PolynomialRing(QQ,8,order='invlex') + sage: R((-2 * (bx - ax)) ^ 1).factor() + (-2) * (bx - ax) + sage: R((-2 * (bx - ax)) ^ 2).factor() + (4) * (-bx + ax)^2 + sage: R((-2 * (bx - ax)) ^ 3).factor() + (8) * (-bx + ax)^3 + ``` + """ + # Assert p is not 0 and that its non-zero coeffients are coprime. + # (We could just work with the primitive part p/p.content() but we want to be + # aware if factor() does not return a primitive part in future sage versions.) + assert p.content() == 1 + # Ensure that the first non-zero coefficient is positive. + return p if p.lc() > 0 else -p def conflicts(R, con): """Check whether any of the passed non-zero assumptions is implied by the zero assumptions""" @@ -204,10 +231,10 @@ def get_nonzero_set(R, assume): nonzero = set() for nz in map(numerator, assume.nonzero): for (f,n) in nz.factor(): - nonzero.add(f) + nonzero.add(normalize_factor(f)) rnz = zero.reduce(nz) for (f,n) in rnz.factor(): - nonzero.add(f) + nonzero.add(normalize_factor(f)) return nonzero @@ -222,27 +249,27 @@ def prove_nonzero(R, exprs, assume): return (False, [exprs[expr]]) allexprs = reduce(lambda a,b: numerator(a)*numerator(b), exprs, 1) for (f, n) in allexprs.factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: ok = False if ok: return (True, None) ok = True - for (f, n) in zero.reduce(numerator(allexprs)).factor(): - if f not in nonzero: + for (f, n) in zero.reduce(allexprs).factor(): + if normalize_factor(f) not in nonzero: ok = False if ok: return (True, None) ok = True for expr in exprs: for (f,n) in numerator(expr).factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: ok = False if ok: return (True, None) ok = True for expr in exprs: for (f,n) in zero.reduce(numerator(expr)).factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: expl.add(exprs[expr]) if expl: return (False, list(expl)) @@ -254,7 +281,7 @@ def prove_zero(R, exprs, assume): """Check whether all of the passed expressions are provably zero, given assumptions""" r, e = prove_nonzero(R, dict(map(lambda x: (fastfrac(R, x.bot, 1), exprs[x]), exprs)), assume) if not r: - return (False, map(lambda x: "Possibly zero denominator: %s" % x, e)) + return (False, list(map(lambda x: "Possibly zero denominator: %s" % x, e))) zero = R.ideal(list(map(numerator, assume.zero))) nonzero = prod(x for x in assume.nonzero) expl = [] @@ -279,8 +306,8 @@ def describe_extra(R, assume, assumeExtra): if base not in zero: add = [] for (f, n) in numerator(base).factor(): - if f not in nonzero: - add += ["%s" % f] + if normalize_factor(f) not in nonzero: + add += ["%s" % normalize_factor(f)] if add: ret.add((" * ".join(add)) + " = 0 [%s]" % assumeExtra.zero[base]) # Iterate over the extra nonzero expressions @@ -288,8 +315,8 @@ def describe_extra(R, assume, assumeExtra): nzr = zeroextra.reduce(numerator(nz)) if nzr not in zeroextra: for (f,n) in nzr.factor(): - if zeroextra.reduce(f) not in nonzero: - ret.add("%s != 0" % zeroextra.reduce(f)) + if normalize_factor(zeroextra.reduce(f)) not in nonzero: + ret.add("%s != 0" % normalize_factor(zeroextra.reduce(f))) return ", ".join(x for x in ret) @@ -299,22 +326,21 @@ def check_symbolic(R, assumeLaw, assumeAssert, assumeBranch, require): if conflicts(R, assume): # This formula does not apply - return None + return (True, None) describe = describe_extra(R, assumeLaw + assumeBranch, assumeAssert) + if describe != "": + describe = " (assuming " + describe + ")" ok, msg = prove_zero(R, require.zero, assume) if not ok: - return "FAIL, %s fails (assuming %s)" % (str(msg), describe) + return (False, "FAIL, %s fails%s" % (str(msg), describe)) res, expl = prove_nonzero(R, require.nonzero, assume) if not res: - return "FAIL, %s fails (assuming %s)" % (str(expl), describe) + return (False, "FAIL, %s fails%s" % (str(expl), describe)) - if describe != "": - return "OK (assuming %s)" % describe - else: - return "OK" + return (True, "OK%s" % describe) def concrete_verify(c): diff --git a/sage/prove_group_implementations.sage b/sage/prove_group_implementations.sage index a97e732..49af9b1 100644 --- a/sage/prove_group_implementations.sage +++ b/sage/prove_group_implementations.sage @@ -292,15 +292,18 @@ def formula_secp256k1_gej_add_ge_old(branch, a, b): return (constraints(zero={b.Z - 1 : 'b.z=1', b.Infinity : 'b_finite'}), constraints(zero=zero, nonzero=nonzero), jacobianpoint(rx, ry, rz)) if __name__ == "__main__": - check_symbolic_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var) - check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var) - check_symbolic_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var) - check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge) - check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old) + success = True + success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var) + success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var) + success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var) + success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge) + success = success & (not check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old)) if len(sys.argv) >= 2 and sys.argv[1] == "--exhaustive": - check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var, 43) - check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var, 43) - check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var, 43) - check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge, 43) - check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old, 43) + success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var, 43) + success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var, 43) + success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var, 43) + success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge, 43) + success = success & (not check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old, 43)) + + sys.exit(int(not success)) diff --git a/sage/weierstrass_prover.sage b/sage/weierstrass_prover.sage index b770c6d..be9cfd4 100644 --- a/sage/weierstrass_prover.sage +++ b/sage/weierstrass_prover.sage @@ -184,6 +184,7 @@ def check_exhaustive_jacobian_weierstrass(name, A, B, branches, formula, p): if r: points.append(point) + ret = True for za in range(1, p): for zb in range(1, p): for pa in points: @@ -211,8 +212,11 @@ def check_exhaustive_jacobian_weierstrass(name, A, B, branches, formula, p): match = True r, e = concrete_verify(require) if not r: + ret = False print(" failure in branch %i for (%s,%s,%s,%s) + (%s,%s,%s,%s) = (%s,%s,%s,%s): %s" % (branch, pA.X, pA.Y, pA.Z, pA.Infinity, pB.X, pB.Y, pB.Z, pB.Infinity, pC.X, pC.Y, pC.Z, pC.Infinity, e)) + print() + return ret def check_symbolic_function(R, assumeAssert, assumeBranch, f, A, B, pa, pb, pA, pB, pC): @@ -244,15 +248,21 @@ def check_symbolic_jacobian_weierstrass(name, A, B, branches, formula): print("Formula " + name + ":") count = 0 + ret = True for branch in range(branches): assumeFormula, assumeBranch, pC = formula(branch, pA, pB) + assumeBranch = assumeBranch.map(lift) + assumeFormula = assumeFormula.map(lift) pC.X = lift(pC.X) pC.Y = lift(pC.Y) pC.Z = lift(pC.Z) pC.Infinity = lift(pC.Infinity) for key in laws_jacobian_weierstrass: - res[key].append((check_symbolic_function(R, assumeFormula, assumeBranch, laws_jacobian_weierstrass[key], A, B, pa, pb, pA, pB, pC), branch)) + success, msg = check_symbolic_function(R, assumeFormula, assumeBranch, laws_jacobian_weierstrass[key], A, B, pa, pb, pA, pB, pC) + if not success: + ret = False + res[key].append((msg, branch)) for key in res: print(" %s:" % key) @@ -262,3 +272,4 @@ def check_symbolic_jacobian_weierstrass(name, A, B, branches, formula): print(" branch %i: %s" % (x[1], x[0])) print() + return ret