diff --git a/constantine/pairing/mul_fp12_by_lines.nim b/constantine/pairing/mul_fp12_by_lines.nim index 065db1a..1fb6ba5 100644 --- a/constantine/pairing/mul_fp12_by_lines.nim +++ b/constantine/pairing/mul_fp12_by_lines.nim @@ -131,30 +131,61 @@ func mul_sparse_by_line_xyz000*[C: static Curve]( # = a0 b0 + a2 b0 - v0 + v1 # = a2 b0 + v1 - var b0 {.noInit.}, v0{.noInit.}, v1{.noInit.}, t{.noInit.}: Fp4[C] + when false: + var b0 {.noInit.}, v0{.noInit.}, v1{.noInit.}, t{.noInit.}: Fp4[C] - b0.c0 = l.x - b0.c1 = l.y + b0.c0 = l.x + b0.c1 = l.y - v0.prod(f.c0, b0) - v1.mul_sparse_by_y0(f.c1, l.z) + v0.prod(f.c0, b0) + v1.mul_sparse_by_x0(f.c1, l.z) - # r1 = (a0 + a1) * (b0 + b1) - v0 - v1 - f.c1 += f.c0 # r1 = a0 + a1 - t = b0 - t.c0 += l.z # t = b0 + b1 - f.c1 *= t # r2 = (a0 + a1)(b0 + b1) - f.c1 -= v0 - f.c1 -= v1 # r2 = (a0 + a1)(b0 + b1) - v0 - v1 + # r1 = (a0 + a1) * (b0 + b1) - v0 - v1 + f.c1 += f.c0 # r1 = a0 + a1 + t = b0 + t.c0 += l.z # t = b0 + b1 + f.c1 *= t # r2 = (a0 + a1)(b0 + b1) + f.c1 -= v0 + f.c1 -= v1 # r2 = (a0 + a1)(b0 + b1) - v0 - v1 - # r0 = ξ a2 b1 + v0 - f.c0.mul_sparse_by_y0(f.c2, l.z) - f.c0 *= SexticNonResidue - f.c0 += v0 + # r0 = ξ a2 b1 + v0 + f.c0.mul_sparse_by_x0(f.c2, l.z) + f.c0 *= SexticNonResidue + f.c0 += v0 - # r2 = a2 b0 + v1 - f.c2 *= b0 - f.c2 += v1 + # r2 = a2 b0 + v1 + f.c2 *= b0 + f.c2 += v1 + + else: # Lazy reduction + var V0{.noInit.}, V1{.noInit.}, f2x{.noInit.}: doublePrec(Fp4[C]) + var t{.noInit.}: Fp2[C] + + V0.prod2x_disjoint(f.c0, l.x, l.y) + V1.mul2x_sparse_by_x0(f.c1, l.z) + + # r1 = (a0 + a1) * (b0 + b1) - v0 - v1 + when false: # TODO: what's the condition? + f.c1.sumUnr(f.c1, f.c0) + t.sumUnr(l.x, l.z) # b0 is (x, y) + else: + f.c1.sum(f.c1, f.c0) + t.sum(l.x, l.z) # b0 is (x, y) + f2x.prod2x_disjoint(f.c1, t, l.y) # b1 is (z, 0) + f2x.diff2xMod(f2x, V0) + f2x.diff2xMod(f2x, V1) + f.c1.redc2x(f2x) + + # r0 = ξ a2 b1 + v0 + f2x.mul2x_sparse_by_x0(f.c2, l.z) + f2x.prod2x(f2x, SexticNonResidue) + f2x.sum2xMod(f2x, V0) + f.c0.redc2x(f2x) + + # r2 = a2 b0 + v1 + f2x.prod2x_disjoint(f.c2, l.x, l.y) + f2x.sum2xMod(f2x, V1) + f.c2.redc2x(f2x) func mul_sparse_by_line_xy000z*[C: static Curve]( f: var Fp12[C], l: Line[Fp2[C]]) = @@ -182,31 +213,63 @@ func mul_sparse_by_line_xy000z*[C: static Curve]( # r2 = (a0 + a2) * (b0 + b2) - v0 - v2 + v1 # = (a0 + a2) * (b0 + b2) - v0 - v2 - var b0 {.noInit.}, v0{.noInit.}, v2{.noInit.}, t{.noInit.}: Fp4[C] + when false: + var b0 {.noInit.}, v0{.noInit.}, v2{.noInit.}, t{.noInit.}: Fp4[C] - b0.c0 = l.x - b0.c1 = l.y + b0.c0 = l.x + b0.c1 = l.y - v0.prod(f.c0, b0) - v2.mul_sparse_by_0y(f.c2, l.z) + v0.prod(f.c0, b0) + v2.mul_sparse_by_0y(f.c2, l.z) - # r2 = (a0 + a2) * (b0 + b2) - v0 - v2 - f.c2 += f.c0 # r2 = a0 + a2 - t = b0 - t.c1 += l.z # t = b0 + b2 - f.c2 *= t # r2 = (a0 + a2)(b0 + b2) - f.c2 -= v0 - f.c2 -= v2 # r2 = (a0 + a2)(b0 + b2) - v0 - v2 + # r2 = (a0 + a2) * (b0 + b2) - v0 - v2 + f.c2 += f.c0 # r2 = a0 + a2 + t = b0 + t.c1 += l.z # t = b0 + b2 + f.c2 *= t # r2 = (a0 + a2)(b0 + b2) + f.c2 -= v0 + f.c2 -= v2 # r2 = (a0 + a2)(b0 + b2) - v0 - v2 - # r0 = ξ a1 b2 + v0 - f.c0.mul_sparse_by_0y(f.c1, l.z) - f.c0 *= SexticNonResidue - f.c0 += v0 + # r0 = ξ a1 b2 + v0 + f.c0.mul_sparse_by_0y(f.c1, l.z) + f.c0 *= SexticNonResidue + f.c0 += v0 - # r1 = a1 b0 + ξ v2 - f.c1 *= b0 - v2 *= SexticNonResidue - f.c1 += v2 + # r1 = a1 b0 + ξ v2 + f.c1 *= b0 + v2 *= SexticNonResidue + f.c1 += v2 + + else: # Lazy reduction + var V0{.noInit.}, V2{.noInit.}, f2x{.noInit.}: doublePrec(Fp4[C]) + var t{.noInit.}: Fp2[C] + + V0.prod2x_disjoint(f.c0, l.x, l.y) + V2.mul2x_sparse_by_0y(f.c2, l.z) + + # r2 = (a0 + a2) * (b0 + b2) - v0 - v2 + when false: # TODO: what's the condition + f.c2.sumUnr(f.c2, f.c0) + t.sumUnr(l.y, l.z) # b0 is (x, y) + else: + f.c2.sum(f.c2, f.c0) + t.sum(l.y, l.z) # b0 is (x, y) + f2x.prod2x_disjoint(f.c2, l.x, t) # b2 is (0, z) + f2x.diff2xMod(f2x, V0) + f2x.diff2xMod(f2x, V2) + f.c2.redc2x(f2x) + + # r0 = ξ a1 b2 + v0 + f2x.mul2x_sparse_by_0y(f.c1, l.z) + f2x.prod2x(f2x, SexticNonResidue) + f2x.sum2xMod(f2x, V0) + f.c0.redc2x(f2x) + + # r1 = a1 b0 + ξ v2 + f2x.prod2x_disjoint(f.c1, l.x, l.y) + V2.prod2x(V2, SexticNonResidue) + f2x.sum2xMod(f2x, V2) + f.c1.redc2x(f2x) func mul*[C](f: var Fp12[C], line: Line[Fp2[C]]) {.inline.} = when C.getSexticTwist() == D_Twist: diff --git a/constantine/pairing/mul_fp6_by_lines.nim b/constantine/pairing/mul_fp6_by_lines.nim index 5d012a8..d8d593a 100644 --- a/constantine/pairing/mul_fp6_by_lines.nim +++ b/constantine/pairing/mul_fp6_by_lines.nim @@ -26,10 +26,10 @@ import func mul_sparse_by_line_xyz000*[C: static Curve]( f: var Fp6[C], l: Line[Fp[C]]) = - ## Sparse multiplication of an 𝔽p12 element - ## by a sparse 𝔽p12 element coming from an D-Twist line function. + ## Sparse multiplication of an 𝔽p6 element + ## by a sparse 𝔽p6 element coming from an D-Twist line function. ## The sparse element is represented by a packed Line type - ## with coordinates (x,y,z) matching 𝔽p12 coordinates xyz000 + ## with coordinates (x,y,z) matching 𝔽p6 coordinates xyz000 static: doAssert C.getSexticTwist() == D_Twist diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 1d7dafc..6ddc08c 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -239,17 +239,17 @@ func prod*(r: var ExtensionField, a: ExtensionField, b: static int) = # ############################################################ type - QuadraticExt2x[F] = object + QuadraticExt2x*[F] = object ## Quadratic Extension field for lazy reduced fields - coords: array[2, F] + coords*: array[2, F] - CubicExt2x[F] = object + CubicExt2x*[F] = object ## Cubic Extension field for lazy reduced fields - coords: array[3, F] + coords*: array[3, F] - ExtensionField2x[F] = QuadraticExt2x[F] or CubicExt2x[F] + ExtensionField2x*[F] = QuadraticExt2x[F] or CubicExt2x[F] -template doublePrec(T: type ExtensionField): type = +template doublePrec*(T: type ExtensionField): type = # For now naive unrolling, recursive template don't match # and I don't want to deal with types in macros when T is QuadraticExt: @@ -258,40 +258,43 @@ template doublePrec(T: type ExtensionField): type = elif T.F is Fp: # Fp2Dbl QuadraticExt2x[doublePrec(T.F)] elif T is CubicExt: - when T.F is QuadraticExt: # Fp6Dbl - CubicExt2x[QuadraticExt2x[doublePrec(T.F.F)]] + when T.F is QuadraticExt: # + when T.F.F is QuadraticExt: # Fp12 + CubicExt2x[QuadraticExt2x[QuadraticExt2x[doublePrec(T.F.F.F)]]] + elif T.F.F is Fp: # Fp6 + CubicExt2x[QuadraticExt2x[doublePrec(T.F.F)]] -func has1extraBit(F: type Fp): bool = +func has1extraBit*(F: type Fp): bool = ## We construct extensions only on Fp (and not Fr) getSpareBits(F) >= 1 -func has2extraBits(F: type Fp): bool = +func has2extraBits*(F: type Fp): bool = ## We construct extensions only on Fp (and not Fr) getSpareBits(F) >= 2 -func has1extraBit(E: type ExtensionField): bool = +func has1extraBit*(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) getSpareBits(Fp[E.F.C]) >= 1 -func has2extraBits(E: type ExtensionField): bool = +func has2extraBits*(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) getSpareBits(Fp[E.F.C]) >= 2 template C(E: type ExtensionField2x): Curve = E.F.C -template c0(a: ExtensionField2x): auto = +template c0*(a: ExtensionField2x): auto = a.coords[0] -template c1(a: ExtensionField2x): auto = +template c1*(a: ExtensionField2x): auto = a.coords[1] -template c2(a: CubicExt2x): auto = +template c2*(a: CubicExt2x): auto = a.coords[2] -template `c0=`(a: var ExtensionField2x, v: auto) = +template `c0=`*(a: var ExtensionField2x, v: auto) = a.coords[0] = v -template `c1=`(a: var ExtensionField2x, v: auto) = +template `c1=`*(a: var ExtensionField2x, v: auto) = a.coords[1] = v -template `c2=`(a: var CubicExt2x, v: auto) = +template `c2=`*(a: var CubicExt2x, v: auto) = a.coords[2] = v # Initialization @@ -305,32 +308,32 @@ func setZero*(a: var ExtensionField2x) = # Abelian group # ------------------------------------------------------------------- -func sumUnr(r: var ExtensionField, a, b: ExtensionField) = +func sumUnr*(r: var ExtensionField, a, b: ExtensionField) = ## Sum ``a`` and ``b`` into ``r`` staticFor i, 0, a.coords.len: r.coords[i].sumUnr(a.coords[i], b.coords[i]) -func diff2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) = +func diff2xUnr*(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-precision substraction without reduction staticFor i, 0, a.coords.len: r.coords[i].diff2xUnr(a.coords[i], b.coords[i]) -func diff2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = +func diff2xMod*(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-precision modular substraction staticFor i, 0, a.coords.len: r.coords[i].diff2xMod(a.coords[i], b.coords[i]) -func sum2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) = +func sum2xUnr*(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-precision addition without reduction staticFor i, 0, a.coords.len: r.coords[i].sum2xUnr(a.coords[i], b.coords[i]) -func sum2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = +func sum2xMod*(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-precision modular addition staticFor i, 0, a.coords.len: r.coords[i].sum2xMod(a.coords[i], b.coords[i]) -func neg2xMod(r: var ExtensionField2x, a: ExtensionField2x) = +func neg2xMod*(r: var ExtensionField2x, a: ExtensionField2x) = ## Double-precision modular negation staticFor i, 0, a.coords.len: r.coords[i].neg2xMod(a.coords[i], b.coords[i]) @@ -338,7 +341,7 @@ func neg2xMod(r: var ExtensionField2x, a: ExtensionField2x) = # Reductions # ------------------------------------------------------------------- -func redc2x(r: var ExtensionField, a: ExtensionField2x) = +func redc2x*(r: var ExtensionField, a: ExtensionField2x) = ## Reduction staticFor i, 0, a.coords.len: r.coords[i].redc2x(a.coords[i]) @@ -346,7 +349,7 @@ func redc2x(r: var ExtensionField, a: ExtensionField2x) = # Multiplication by a small integer known at compile-time # ------------------------------------------------------------------- -func prod2x(r: var ExtensionField2x, a: ExtensionField2x, b: static int) = +func prod2x*(r: var ExtensionField2x, a: ExtensionField2x, b: static int) = ## Multiplication by a small integer known at compile-time for i in 0 ..< a.coords.len: r.coords[i].prod2x(a.coords[i], b) @@ -360,20 +363,20 @@ func prod2x(r: var FpDbl, a: FpDbl, _: type NonResidue){.inline.} = static: doAssert FpDbl.C.getNonResidueFp() != -1, "𝔽p2 should be specialized for complex extension" r.prod2x(a, FpDbl.C.getNonResidueFp()) -func prod2x[C: static Curve]( - r {.noalias.}: var QuadraticExt2x[FpDbl[C]], - a {.noalias.}: QuadraticExt2x[FpDbl[C]], +func prod2x*[C: static Curve]( + r: var QuadraticExt2x[FpDbl[C]], + a: QuadraticExt2x[FpDbl[C]], _: type NonResidue) {.inline.} = ## Multiplication by non-residue - ## ! no aliasing! const complex = C.getNonResidueFp() == -1 const U = C.getNonResidueFp2()[0] const V = C.getNonResidueFp2()[1] const Beta {.used.} = C.getNonResidueFp() when complex and U == 1 and V == 1: - r.c0.diff2xMod(a.c0, a.c1) - r.c1.sum2xMod(a.c0, a.c1) + let a1 = a.c1 + r.c1.sum2xMod(a.c0, a1) + r.c0.diff2xMod(a.c0, a1) else: # Case: # - BN254_Snarks, QNR_Fp: -1, SNR_Fp2: 9+1𝑖 (𝑖 = √-1) @@ -383,10 +386,10 @@ func prod2x[C: static Curve]( # mul_sparse_by_0v # r0 = β a1 v # r1 = a0 v - # r and a don't alias, we use `r` as a temp location - r.c1.prod2x(a.c1, V) - r.c0.prod2x(r.c1, NonResidue) + var t {.noInit.}: FpDbl[C] + t.prod2x(a.c1, V) r.c1.prod2x(a.c0, V) + r.c0.prod2x(t, NonResidue) else: # ξ = u + v x # and x² = β @@ -395,15 +398,43 @@ func prod2x[C: static Curve]( # => u c0 + β v c1 + (v c0 + u c1) x var t {.noInit.}: FpDbl[C] - r.c0.prod2x(a.c0, U) + t.prod2x(a.c0, U) when V == 1 and Beta == -1: # Case BN254_Snarks - r.c0.diff2xMod(r.c0, a.c1) # r0 = u c0 + β v c1 + t.diff2xMod(t, a.c1) # r0 = u c0 + β v c1 else: {.error: "Unimplemented".} - r.c1.prod2x(a.c0, V) - t.prod2x(a.c1, U) - r.c1.sum2xMod(r.c1, t) # r1 = v c0 + u c1 + + r.c1.prod2x(a.c1, U) + when V == 1: # r1 = v c0 + u c1 + r.c1.sum2xMod(r.c1, a.c0) + # aliasing: a.c0 is unused + `=`(r.c0, t) # "r.c0 = t", is refused by the compiler. + else: + {.error: "Unimplemented".} + +func prod2x*( + r: var QuadraticExt2x, + a: QuadraticExt2x, + _: type NonResidue) {.inline.} = + ## Multiplication by non-residue + static: doAssert not(r.c0 is FpDbl), "Wrong dispatch, there is a specific non-residue multiplication for the base extension." + let t = a.c0 + r.c0.prod2x(a.c1, NonResidue) + `=`(r.c1, t) # "r.c1 = t", is refused by the compiler. + +func prod2x*( + r: var CubicExt2x, + a: CubicExt2x, + _: type NonResidue) {.inline.} = + ## Multiplication by non-residue + ## For all curves γ = v with v the factor for cubic extension coordinate + ## and v³ = ξ + ## (c0 + c1 v + c2 v²) v => ξ c2 + c0 v + c1 v² + let t = a.c2 + r.c1 = a.c0 + r.c2 = a.c1 + r.c0.prod2x(t, NonResidue) # ############################################################ # # @@ -414,8 +445,8 @@ func prod2x[C: static Curve]( # Forward declarations # ---------------------------------------------------------------------- -func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) -func square2x(r: var QuadraticExt2x, a: QuadraticExt) +func prod2x*(r: var QuadraticExt2x, a, b: QuadraticExt) +func square2x*(r: var QuadraticExt2x, a: QuadraticExt) # Commutative ring implementation for complex quadratic extension fields # ---------------------------------------------------------------------- @@ -455,10 +486,9 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) - # Require 2 extra bits - when QuadraticExt.has2extraBits(): + when QuadraticExt.has1extraBit(): t0.sumUnr(a.c1, a.c1) - t1.sum(a.c0, a.c1) + t1.sumUnr(a.c0, a.c1) else: t0.double(a.c1) t1.sum(a.c0, a.c1) @@ -481,7 +511,7 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = # - cyclotomic square in Fp2 -> Fp6 -> Fp12 towering # needs Fp4 as special case -func prod2x_disjoint[Fdbl, F]( +func prod2x_disjoint*[Fdbl, F]( r: var QuadraticExt2x[FDbl], a: QuadraticExt[F], b0, b1: F) = @@ -502,17 +532,13 @@ func prod2x_disjoint[Fdbl, F]( t1.sum(b0, b1) r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) - when F.has1extraBit(): - r.c1.diff2xMod(r.c1, V0) - r.c1.diff2xMod(r.c1, V1) - else: - r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 - r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1 + r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 + r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1 r.c0.prod2x(V1, NonResidue) # r0 = β a1 b1 r.c0.sum2xMod(r.c0, V0) # r0 = a0 b0 + β a1 b1 -func square2x_disjoint[Fdbl, F]( +func square2x_disjoint*[Fdbl, F]( r: var QuadraticExt2x[FDbl], a0, a1: F) = ## Return (a0, a1)² in r @@ -535,17 +561,108 @@ func square2x_disjoint[Fdbl, F]( r.c1.diff2xMod(r.c1, V0) r.c1.diff2xMod(r.c1, V1) +# Sparse multiplication +# ------------------------------------------------------------------- + +func mul2x_sparse_by_x0*[Fdbl, F]( + r: var QuadraticExt2x[Fdbl], a: QuadraticExt[F], + sparseB: auto) = + ## Multiply `a` by `b` with sparse coordinates (x, 0) + ## On a generic quadratic extension field + # Algorithm (with β the non-residue in the base field) + # + # r0 = a0 b0 + β a1 b1 + # r1 = (a0 + a1) (b0 + b1) - a0 b0 - a1 b1 (Karatsuba) + # + # with b1 = 0, hence + # + # r0 = a0 b0 + # r1 = (a0 + a1) b0 - a0 b0 = a1 b0 + static: doAssert Fdbl is doublePrec(F) + + when typeof(sparseB) is typeof(a): + template b(): untyped = sparseB.c0 + elif typeof(sparseB) is typeof(a.c0): + template b(): untyped = sparseB + else: + {.error: "sparseB type is " & $typeof(sparseB) & + " which does not match with either a (" & $typeof(a) & + ") or a.c0 (" & $typeof(a.c0) & ")".} + + r.c0.prod2x(a.c0, b) + r.c1.prod2x(a.c1, b) + +func mul2x_sparse_by_0y*[Fdbl, F]( + r: var QuadraticExt2x[Fdbl], a: QuadraticExt[F], + sparseB: auto) = + ## Multiply `a` by `b` with sparse coordinates (0, y) + ## On a generic quadratic extension field + # Algorithm (with β the non-residue in the base field) + # + # r0 = a0 b0 + β a1 b1 + # r1 = (a0 + a1) (b0 + b1) - a0 b0 - a1 b1 (Karatsuba) + # + # with b0 = 0, hence + # + # r0 = β a1 b1 + # r1 = (a0 + a1) b1 - a1 b1 = a0 b1 + static: doAssert Fdbl is doublePrec(F) + + when typeof(sparseB) is typeof(a): + template b(): untyped = sparseB.c1 + elif typeof(sparseB) is typeof(a.c0): + template b(): untyped = sparseB + else: + {.error: "sparseB type is " & $typeof(sparseB) & + " which does not match with either a (" & $typeof(a) & + ") or a.c0 (" & $typeof(a.c0) & ")".} + + r.c0.prod2x(a.c1, b) + r.c0.prod2x(r.c0, NonResidue) + r.c1.prod2x(a.c0, b) + +# Inversion +# ------------------------------------------------------------------- + +func inv2xImpl(r: var QuadraticExt, a: QuadraticExt) = + ## Compute the multiplicative inverse of ``a`` + ## + ## The inverse of 0 is 0. + ## + ## Inversion routine is using lazy reduction + mixin fromComplexExtension + + # [2 Sqr, 1 Add] + var V0 {.noInit.}, V1 {.noInit.}: doublePrec(typeof(r.c0)) + var t {.noInit.}: typeof(r.c0) + V0.square2x(a.c0) + V1.square2x(a.c1) + when r.fromComplexExtension(): + V0.sum2xUnr(V0, V1) + else: + V1.prod2x(V1, NonResidue) + V0.diff2xMod(V0, V1) # v0 = a0² - β a1² (the norm / squared magnitude of a) + + # [1 Inv, 2 Sqr, 1 Add] + t.redc2x(V0) + t.inv() # v1 = 1 / (a0² - β a1²) + + # [1 Inv, 2 Mul, 2 Sqr, 1 Add, 1 Neg] + r.c0.prod(a.c0, t) # r0 = a0 / (a0² - β a1²) + t.neg() # v0 = -1 / (a0² - β a1²) + r.c1.prod(a.c1, t) # r1 = -a1 / (a0² - β a1²) + # Dispatch # ---------------------------------------------------------------------- -func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) = +func prod2x*(r: var QuadraticExt2x, a, b: QuadraticExt) = mixin fromComplexExtension when a.fromComplexExtension(): r.prod2x_complex(a, b) else: r.prod2x_disjoint(a, b.c0, b.c1) -func square2x(r: var QuadraticExt2x, a: QuadraticExt) = +func square2x*(r: var QuadraticExt2x, a: QuadraticExt) = mixin fromComplexExtension when a.fromComplexExtension(): r.square2x_complex(a) @@ -558,6 +675,98 @@ func square2x(r: var QuadraticExt2x, a: QuadraticExt) = # # # ############################################################ +# Commutative ring implementation for Cubic Extension Fields +# ------------------------------------------------------------------- + +func square2x_Chung_Hasan_SQR2(r: var CubicExt2x, a: CubicExt) = + ## Returns r = a² + var m01{.noInit.}, m12{.noInit.}: typeof(r.c0) # double-width + var t{.noInit.}: typeof(a.c0) # single width + + m01.prod2x(a.c0, a.c1) + m01.sum2xMod(m01, m01) # 2a₀a₁ + m12.prod2x(a.c1, a.c2) + m12.sum2xMod(m12, m12) # 2a₁a₂ + r.c0.square2x(a.c2) # borrow r₀ = a₂² for a moment + + # r₂ = (a₀ - a₁ + a₂)² + t.sum(a.c2, a.c0) + t -= a.c1 + r.c2.square2x(t) + + # r₂ = (a₀ - a₁ + a₂)² + 2a₀a₁ + 2a₁a₂ - a₂² + r.c2.sum2xMod(r.c2, m01) + r.c2.sum2xMod(r.c2, m12) + r.c2.diff2xMod(r.c2, r.c0) + + # r₁ = 2a₀a₁ + β a₂² + r.c1.prod2x(r.c0, NonResidue) + r.c1.sum2xMod(r.c1, m01) + + # r₂ = (a₀ - a₁ + a₂)² + 2a₀a₁ + 2a₁a₂ - a₀² - a₂² + r.c0.square2x(a.c0) + r.c2.diff2xMod(r.c2, r.c0) + + # r₀ = a₀² + β 2a₁a₂ + m12.prod2x(m12, NonResidue) + r.c0.sum2xMod(r.c0, m12) + +func prod2xImpl(r: var CubicExt2x, a, b: CubicExt) = + var V0 {.noInit.}, V1 {.noInit.}, V2 {.noinit.}: typeof(r.c0) + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) + + V0.prod2x(a.c0, b.c0) + V1.prod2x(a.c1, b.c1) + V2.prod2x(a.c2, b.c2) + + # r₀ = β ((a₁ + a₂)(b₁ + b₂) - v₁ - v₂) + v₀ + when false: # CubicExt.has1extraBit(): + t0.sumUnr(a.c1, a.c2) + t1.sumUnr(b.c1, b.c2) + else: + t0.sum(a.c1, a.c2) + t1.sum(b.c1, b.c2) + r.c0.prod2x(t0, t1) # r cannot alias a or b since it's double precision + r.c0.diff2xMod(r.c0, V1) + r.c0.diff2xMod(r.c0, V2) + r.c0.prod2x(r.c0, NonResidue) + r.c0.sum2xMod(r.c0, V0) + + # r₁ = (a₀ + a₁) * (b₀ + b₁) - v₀ - v₁ + β v₂ + when false: # CubicExt.has1extraBit(): + t0.sumUnr(a.c0, a.c1) + t1.sumUnr(b.c0, b.c1) + else: + t0.sum(a.c0, a.c1) + t1.sum(b.c0, b.c1) + r.c1.prod2x(t0, t1) + r.c1.diff2xMod(r.c1, V0) + r.c1.diff2xMod(r.c1, V1) + r.c2.prod2x(V2, NonResidue) # r₂ is unused and cannot alias + r.c1.sum2xMod(r.c1, r.c2) + + # r₂ = (a₀ + a₂) * (b₀ + b₂) - v₀ - v₂ + v₁ + when false: # CubicExt.has1extraBit(): + t0.sumUnr(a.c0, a.c2) + t1.sumUnr(b.c0, b.c2) + else: + t0.sum(a.c0, a.c2) + t1.sum(b.c0, b.c2) + r.c2.prod2x(t0, t1) + r.c2.diff2xMod(r.c2, V0) + r.c2.diff2xMod(r.c2, V2) + r.c2.sum2xMod(r.c2, V1) + +# Dispatch +# ---------------------------------------------------------------------- + +func square2x*(r: var CubicExt2x, a: CubicExt) {.inline.} = + ## Returns r = a² + square2x_Chung_Hasan_SQR2(r, a) + +func prod2x*(r: var CubicExt2x, a, b: CubicExt) {.inline.} = + ## Returns r = ab + prod2xImpl(r, a, b) # ############################################################ # # @@ -800,7 +1009,10 @@ func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) = v1 *= NonResidue r.c0.sum(v0, v1) -func mul_sparse_generic_by_x0(r: var QuadraticExt, a, sparseB: QuadraticExt) = +# Sparse multiplication +# ------------------------------------------------------------------- + +func mul_sparse_generic_by_x0(r: var QuadraticExt, a: QuadraticExt, sparseB: auto) = ## Multiply `a` by `b` with sparse coordinates (x, 0) ## On a generic quadratic extension field # Algorithm (with β the non-residue in the base field) @@ -812,10 +1024,17 @@ func mul_sparse_generic_by_x0(r: var QuadraticExt, a, sparseB: QuadraticExt) = # # r0 = a0 b0 # r1 = (a0 + a1) b0 - a0 b0 = a1 b0 - template b(): untyped = sparseB + when typeof(sparseB) is typeof(a): + template b(): untyped = sparseB.c0 + elif typeof(sparseB) is typeof(a.c0): + template b(): untyped = sparseB + else: + {.error: "sparseB type is " & $typeof(sparseB) & + " which does not match with either a (" & $typeof(a) & + ") or a.c0 (" & $typeof(a.c0) & ")".} - r.c0.prod(a.c0, b.c0) - r.c1.prod(a.c1, b.c0) + r.c0.prod(a.c0, b) + r.c1.prod(a.c1, b) func mul_sparse_generic_by_0y( r: var QuadraticExt, a: QuadraticExt, @@ -870,6 +1089,9 @@ func mul_sparse_generic_by_0y( # aliasing: a unneeded now r.c0.prod(t, NonResidue) +# Inversion +# ------------------------------------------------------------------- + func invImpl(r: var QuadraticExt, a: QuadraticExt) = ## Compute the multiplicative inverse of ``a`` ## @@ -959,7 +1181,10 @@ func inv*(r: var QuadraticExt, a: QuadraticExt) = ## Incidentally this avoids extra check ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve - r.invImpl(a) + when true: + r.invImpl(a) + else: # Lazy reduction, doesn't seem to gain speed. + r.inv2xImpl(a) func inv*(a: var QuadraticExt) = ## Compute the multiplicative inverse of ``a`` @@ -968,7 +1193,7 @@ func inv*(a: var QuadraticExt) = ## Incidentally this avoids extra check ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve - a.invImpl(a) + a.inv(a) func `*=`*(a: var QuadraticExt, b: QuadraticExt) = ## In-place multiplication @@ -1147,6 +1372,12 @@ func prodImpl(r: var CubicExt, a, b: CubicExt) = # Finish r₀ r.c0.sum(t0, v0) +# Sparse multiplication +# ------------------------------------------------------------------- + +# Inversion +# ------------------------------------------------------------------- + func invImpl(r: var CubicExt, a: CubicExt) = ## Compute the multiplicative inverse of ``a`` ## @@ -1208,7 +1439,15 @@ func invImpl(r: var CubicExt, a: CubicExt) = func square*(r: var CubicExt, a: CubicExt) = ## Returns r = a² - square_Chung_Hasan_SQR3(r, a) + when CubicExt.F.C == BW6_761 or # Too large + CubicExt.F.C == BN254_Snarks: # 50 cycles slower on Fp2->Fp4->Fp12 towering + square_Chung_Hasan_SQR3(r, a) + else: + var d {.noInit.}: doublePrec(typeof(a)) + d.square2x_Chung_Hasan_SQR2(a) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) + r.c2.redc2x(d.c2) func square*(a: var CubicExt) = ## In-place squaring @@ -1216,11 +1455,25 @@ func square*(a: var CubicExt) = func prod*(r: var CubicExt, a, b: CubicExt) = ## In-place multiplication - r.prodImpl(a, b) + when CubicExt.F.C == BW6_761: # Too large + r.prodImpl(a, b) + else: + var d {.noInit.}: doublePrec(typeof(r)) + d.prod2x(a, b) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) + r.c2.redc2x(d.c2) func `*=`*(a: var CubicExt, b: CubicExt) = ## In-place multiplication - a.prodImpl(a, b) + when CubicExt.F.C == BW6_761: # Too large + a.prodImpl(a, b) + else: + var d {.noInit.}: doublePrec(typeof(a)) + d.prod2x(a, b) + a.c0.redc2x(d.c0) + a.c1.redc2x(d.c1) + a.c2.redc2x(d.c2) func inv*(r: var CubicExt, a: CubicExt) = ## Compute the multiplicative inverse of ``a`` diff --git a/constantine/tower_field_extensions/tower_instantiation.nim b/constantine/tower_field_extensions/tower_instantiation.nim index bb337d1..a638630 100644 --- a/constantine/tower_field_extensions/tower_instantiation.nim +++ b/constantine/tower_field_extensions/tower_instantiation.nim @@ -226,7 +226,7 @@ func `*=`*(a: var Fp4, _: type NonResidue) {.inline.} = a.prod(a, NonResidue) func prod*(r: var Fp6, a: Fp6, _: type NonResidue) {.inline.} = - ## Multiply an element of 𝔽p4 by the non-residue + ## Multiply an element of 𝔽p6 by the non-residue ## chosen to construct the next extension or the twist: ## - if quadratic non-residue: 𝔽p12 ## - if cubic non-residue: 𝔽p18 @@ -243,7 +243,7 @@ func prod*(r: var Fp6, a: Fp6, _: type NonResidue) {.inline.} = r.c0.prod(t, NonResidue) func `*=`*(a: var Fp6, _: type NonResidue) {.inline.} = - ## Multiply an element of 𝔽p4 by the non-residue + ## Multiply an element of 𝔽p6 by the non-residue ## chosen to construct the next extension or the twist: ## - if quadratic non-residue: 𝔽p12 ## - if cubic non-residue: 𝔽p18 @@ -272,12 +272,6 @@ func `*=`*(a: var Fp2, b: Fp) = a.c0 *= b a.c1 *= b -func mul_sparse_by_y0*[C: static Curve](r: var Fp4[C], a: Fp4[C], b: Fp2[C]) = - ## Sparse multiplication of an Fp4 element - ## with coordinates (a₀, a₁) by (b₀, 0) - r.c0.prod(a.c0, b) - r.c1.prod(a.c1, b) - func mul_sparse_by_0y0*[C: static Curve](r: var Fp6[C], a: Fp6[C], b: Fp2[C]) = ## Sparse multiplication of an Fp6 element ## with coordinates (a₀, a₁, a₂) by (0, b₁, 0) diff --git a/tests/t_fp12_exponentiation.nim b/tests/t_fp12_exponentiation.nim index 3953287..b2a847b 100644 --- a/tests/t_fp12_exponentiation.nim +++ b/tests/t_fp12_exponentiation.nim @@ -68,7 +68,7 @@ proc test_sameBaseProduct(C: static Curve, gen: RandomGen) = xapb.powUnsafeExponent(apb, window = 3) xa *= xb - check: bool(xa == xapb) + doAssert: bool(xa == xapb) proc test_powpow(C: static Curve, gen: RandomGen) = ## (xᴬ)ᴮ = xᴬᴮ - power of power @@ -86,7 +86,7 @@ proc test_powpow(C: static Curve, gen: RandomGen) = x.powUnsafeExponent(b, window = 3) y.powUnsafeExponent(ab, window = 3) - check: bool(x == y) + doAssert: bool(x == y) proc test_powprod(C: static Curve, gen: RandomGen) = ## (xy)ᴬ = xᴬyᴬ - power of product @@ -105,7 +105,7 @@ proc test_powprod(C: static Curve, gen: RandomGen) = x *= y - check: bool(x == xy) + doAssert: bool(x == xy) proc test_pow0(C: static Curve, gen: RandomGen) = ## x⁰ = 1 @@ -113,7 +113,7 @@ proc test_pow0(C: static Curve, gen: RandomGen) = var a: BigInt[128] # 0-init x.powUnsafeExponent(a, window=3) - check: bool x.isOne() + doAssert: bool x.isOne() proc test_0pow0(C: static Curve, gen: RandomGen) = ## 0⁰ = 1 @@ -121,7 +121,7 @@ proc test_0pow0(C: static Curve, gen: RandomGen) = var a: BigInt[128] # 0-init x.powUnsafeExponent(a, window=3) - check: bool x.isOne() + doAssert: bool x.isOne() proc test_powinv(C: static Curve, gen: RandomGen) = ## xᴬ / xᴮ = xᴬ⁻ᴮ - quotient of power @@ -150,7 +150,7 @@ proc test_powinv(C: static Curve, gen: RandomGen) = discard amb.diff(a, b) xamb.powUnsafeExponent(amb, window = 3) - check: bool(xa == xamb) + doAssert: bool(xa == xamb) proc test_invpow(C: static Curve, gen: RandomGen) = ## (x / y)ᴬ = xᴬ / yᴬ - power of quotient @@ -173,7 +173,7 @@ proc test_invpow(C: static Curve, gen: RandomGen) = xqya *= invy xqya.powUnsafeExponent(a, window = 3) - check: bool(xa == xqya) + doAssert: bool(xa == xqya) suite "Exponentiation in 𝔽p12" & " [" & $WordBitwidth & "-bit mode]": staticFor(curve, TestCurves): diff --git a/tests/t_fp_tower_template.nim b/tests/t_fp_tower_template.nim index 1e8541c..3750d29 100644 --- a/tests/t_fp_tower_template.nim +++ b/tests/t_fp_tower_template.nim @@ -256,16 +256,22 @@ proc runTowerTests*[N]( staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve)): r.prod(x, Z) - check: bool(r == Z) + doAssert bool(r == Z), + "\nExpected zero but got (" & $ExtField(ExtDegree, curve) & "): " & x.toHex() test(ExtField(ExtDegree, curve)): r.prod(Z, x) - check: bool(r == Z) + doAssert bool(r == Z), + "\nExpected zero but got (" & $ExtField(ExtDegree, curve) & "): " & x.toHex() test(ExtField(ExtDegree, curve)): r.prod(x, O) - check: bool(r == x) + doAssert bool(r == x), + "\n(" & $ExtField(ExtDegree, curve) & "): Expected one: " & O.toHex() & "\n" & + "got: " & x.toHex() test(ExtField(ExtDegree, curve)): r.prod(O, x) - check: bool(r == x) + doAssert bool(r == x), + "\n(" & $ExtField(ExtDegree, curve) & "): Expected one: " & O.toHex() & "\n" & + "got: " & x.toHex() test "Multiplication and Squaring are consistent": proc test(Field: typedesc, Iters: static int, gen: static RandomGen) = @@ -276,7 +282,9 @@ proc runTowerTests*[N]( rMul.prod(a, a) rSqr.square(a) - doAssert bool(rMul == rSqr), "Failure with a (" & $Field & "): " & a.toHex() + doAssert bool(rMul == rSqr), "Failure with a (" & $Field & "): " & a.toHex() & "\n" & + "Mul: " & rMul.toHex() & "\n" & + "Sqr: " & rSqr.toHex() & "\n" staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) @@ -295,7 +303,9 @@ proc runTowerTests*[N]( rSqr.square(a) rNegSqr.square(na) - doAssert bool(rSqr == rNegSqr), "Failure with a (" & $Field & "): " & a.toHex() + doAssert bool(rSqr == rNegSqr), "Failure with a (" & $Field & "): " & a.toHex() & "\n" & + "Sqr: " & rSqr.toHex() & "\n" & + "SqrNeg: " & rNegSqr.toHex() & "\n" staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform)