From 1889fc4eeb772af2008b8eacf380b2cf11cd158e Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Sun, 12 Apr 2020 16:09:38 +0200 Subject: [PATCH] Improve bn curve family support (#23) * Allow tagging BarretoNaehrig family * Refactor the constant generation and fix XDeclaredButNotUsed * BN field inversion via addition chain (but slower than generic :/ so deactivated) --- .../arithmetic/finite_fields_inversion.nim | 31 ++++- constantine/arithmetic/limbs_modular.nim | 6 +- constantine/arithmetic/precomputed.nim | 52 ++++++++ constantine/config/curves.nim | 117 ++++++++++++++---- constantine/config/curves_parser.nim | 81 ++++++++++-- tests/test_finite_fields_powinv.nim | 80 ++++++++++-- tests/test_finite_fields_sqrt.nim | 6 +- 7 files changed, 319 insertions(+), 54 deletions(-) diff --git a/constantine/arithmetic/finite_fields_inversion.nim b/constantine/arithmetic/finite_fields_inversion.nim index 4f657a6..47bb180 100644 --- a/constantine/arithmetic/finite_fields_inversion.nim +++ b/constantine/arithmetic/finite_fields_inversion.nim @@ -120,6 +120,30 @@ func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) = # Note: it only works for u positive, in particular BN254 doesn't work :/ # Is there a way to only use a^-u or even powers? +func invmod_addchain_bn[C](r: var Fp[C], a: Fp[C]) = + ## Inversion on BN prime fields with positive base parameter `u` + ## via Little Fermat theorem and leveraging the prime low Hamming weight + ## + ## Requires a `bn` curve with a positive parameter `u` + # TODO: debug for input "0x0d2007d8aaface1b8501bfbe792974166e8f9ad6106e5b563604f0aea9ab06f6" + # see test suite + static: doAssert C.canUseFast_BN_Inversion() + + var v0 {.noInit.}, v1 {.noInit.}: Fp[C] + + v0 = a + v0.powUnsafeExponent(C.getBN_param_6u_minus_1_BE()) # v0 <- a^(6u-1) + v1.prod(v0, a) # v1 <- a^(6u) + v1.powUnsafeExponent(C.getBN_param_u_BE()) # v1 <- a^(6u²) + r.square(v1) # r <- a^(12u²) + v1.square(r) # v1 <- a^(24u²) + v0 *= v1 # v0 <- a^(24u²) a^(6u-1) + v1 *= r # v1 <- a^(24u²) a^(12u²) = a^(36u²) + v1.powUnsafeExponent(C.getBN_param_u_BE()) # v1 <- a^(36u³) + r.prod(v0, v1) # r <- a^(36u³) a^(24u²) a^(6u-1) + v1.powUnsafeExponent(C.getBN_param_u_BE()) # v1 <- a^(36u⁴) + r *= v1 # r <- a^(36u⁴) a^(36u³) a^(24u²) a^(6u-1) = a^(p-2) = a^(-1) + # ############################################################ # # Dispatch @@ -128,7 +152,8 @@ func invmod_addchain(r: var Fp[Secp256k1], a: Fp[Secp256k1]) = func inv*(r: var Fp, a: Fp) = ## Inversion modulo p - # For now we don't activate the addition chain. - # Performance is equal to GCD and it does not pass test on 𝔽p2 - # We need faster squaring/multiplications + # For now we don't activate the addition chains + # neither for Secp256k1 nor BN curves + # Performance is slower than GCD + # To be revisited with faster squaring/multiplications r.mres.steinsGCD(a.mres, Fp.C.getR2modP(), Fp.C.Mod, Fp.C.getPrimePlus1div2()) diff --git a/constantine/arithmetic/limbs_modular.nim b/constantine/arithmetic/limbs_modular.nim index 12bc7c2..2e3966f 100644 --- a/constantine/arithmetic/limbs_modular.nim +++ b/constantine/arithmetic/limbs_modular.nim @@ -276,7 +276,7 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int, block: # q*p # q * p + carry (doubleword) carry from previous limb - muladd1(carry, qp_lo, q, M[i], Word carry) + muladd1(carry, qp_lo, q, M[i], carry) block: # a*2^64 - q*p var borrow: Borrow @@ -293,8 +293,8 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int, # if carry < q or carry == q and over_p we must do "a -= p" # if carry > hi (negative result) we must do "a += p" - result.neg = Word(carry) > hi - result.tooBig = not(result.neg) and (over_p or (Word(carry) < hi)) + result.neg = carry > hi + result.tooBig = not(result.neg) and (over_p or (carry < hi)) func shlAddMod(a: LimbsViewMut, aLen: int, c: Word, M: LimbsViewConst, mBits: int) = diff --git a/constantine/arithmetic/precomputed.nim b/constantine/arithmetic/precomputed.nim index 49360bb..f333950 100644 --- a/constantine/arithmetic/precomputed.nim +++ b/constantine/arithmetic/precomputed.nim @@ -104,6 +104,22 @@ func sub(a: var BigInt, w: BaseType): bool = result = bool(borrow) +func cadd(a: var BigInt, b: BigInt, ctl: bool): bool = + ## In-place optional addition + ## + ## It is NOT constant-time and is intended + ## only for compile-time precomputation + ## of non-secret data. + var carry, sum: BaseType + for i in 0 ..< a.limbs.len: + let ai = BaseType(a.limbs[i]) + let bi = BaseType(b.limbs[i]) + addC(carry, sum, ai, bi, carry) + if ctl: + a.limbs[i] = Word(sum) + + result = bool(carry) + func csub(a: var BigInt, b: BigInt, ctl: bool): bool = ## In-place optional substraction ## @@ -361,3 +377,39 @@ func primePlus1Div4_BE*[bits: static int]( tmp.shiftRight(1) result.exportRawUint(tmp, bigEndian) + +func toCanonicalIntRepr*[bits: static int]( + a: BigInt[bits] + ): array[(bits+7) div 8, byte] {.noInit.} = + ## Export a bigint to its canonical BigEndian representation + ## (octet-string) + result.exportRawUint(a, bigEndian) + +func bn_6u_minus_1_BE*[bits: static int]( + u: BigInt[bits] + ): array[(bits+7+3) div 8, byte] {.noInit.} = + ## For a BN curve + ## Precompute 6u-1 (for Little Fermat inversion) + ## and store it in canonical integer representation + # TODO: optimize output size + # each extra 0-bit is an extra useless squaring for a public exponent + # For example, for BN254-Snarks, u = 0x44E992B44A6909F1 (63-bit) + # and 6u+1 is 65-bit (not 66 as inferred) + + # Zero-extend "u" + var u_ext: BigInt[bits+3] + + for i in 0 ..< u.limbs.len: + u_ext.limbs[i] = u.limbs[i] + + # Addition chain to u -> 6u + discard u_ext.dbl() # u_ext = 2u + let u_ext2 = u_ext # u_ext2 = 2u + discard u_ext.dbl() # u_ext = 4u + discard u_ext.cadd(u_ext2, true) # u_ext = 6u + + # Sustract 1 + discard u_ext.sub(1) + + # Export + result.exportRawUint(u_ext, bigEndian) diff --git a/constantine/config/curves.nim b/constantine/config/curves.nim index e43820e..28261a8 100644 --- a/constantine/config/curves.nim +++ b/constantine/config/curves.nim @@ -13,8 +13,6 @@ import ./curves_parser, ./common, ../arithmetic/[precomputed, bigints] -{.push used.} - # ############################################################ # # Configuration of finite fields @@ -39,6 +37,22 @@ import # - type Curve* = enum # - proc Mod*(curve: static Curve): auto # which returns the field modulus of the curve +# - proc Family*(curve: static Curve): CurveFamily +# which returns the curve family +# - proc get_BN_param_u_BE*(curve: static Curve): array[N, byte] +# which returns the "u" parameter of a BN curve +# as a big-endian canonical integer representation +# if it's a BN curve and u is positive +# - proc get_BN_param_6u_minus1_BE*(curve: static Curve): array[N, byte] +# which returns the "6u-1" parameter of a BN curve +# as a big-endian canonical integer representation +# if it's a BN curve and u is positive. +# This is used for optimized field inversion for BN curves + +type + CurveFamily* = enum + NoFamily + BarretoNaehrig # BN curve declareCurves: # ----------------------------------------------------------------------------- @@ -74,11 +88,15 @@ declareCurves: curve BN254_Nogami: # Integer Variable χ–Based Ate Pairing, 2008, Nogami et al bitsize: 254 modulus: "0x2523648240000001ba344d80000000086121000000000013a700000000000013" + family: BarretoNaehrig # Equation: Y^2 = X^3 + 2 # u: -(2^62 + 2^55 + 1) curve BN254_Snarks: # Zero-Knowledge proofs curve (SNARKS, STARKS, Ethereum) bitsize: 254 modulus: "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" + family: BarretoNaehrig + bn_u_bitwidth: 63 + bn_u: "0x44E992B44A6909F1" # Equation: Y^2 = X^3 + 3 # u: 4965661367192848881 curve Curve25519: # Bernstein curve @@ -96,6 +114,8 @@ declareCurves: # https://github.com/ethereum/EIPs/blob/41dea9615/EIPS/eip-2539.md bitsize: 377 modulus: "0x01ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001" + # u: 3 * 2^46 * (7 * 13 * 499) + 1 + # u: 0x8508c00000000001 curve BLS12_381: bitsize: 381 modulus: "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" @@ -104,6 +124,7 @@ declareCurves: curve BN446: bitsize: 446 modulus: "0x2400000000000000002400000002d00000000d800000021c0000001800000000870000000b0400000057c00000015c000000132000000067" + family: BarretoNaehrig # u = 2^110 + 2^36 + 1 curve FKM12_447: # Fotiadis-Konstantinou-Martindale bitsize: 447 @@ -144,6 +165,7 @@ declareCurves: # https://hal.archives-ouvertes.fr/hal-01534101/file/main.pdf bitsize: 462 modulus: "0x240480360120023ffffffffff6ff0cf6b7d9bfca0000000000d812908f41c8020ffffffffff6ff66fc6ff687f640000000002401b00840138013" + family: BarretoNaehrig # u = 2^114 + 2^101 - 2^14 - 1 # ############################################################ @@ -165,29 +187,34 @@ func getCurveBitSize*(C: static Curve): static int = template matchingBigInt*(C: static Curve): untyped = BigInt[CurveBitSize[C]] +func family*(C: static Curve): CurveFamily = + result = static(CurveFamilies[C]) + # ############################################################ # -# Autogeneration of precomputed Montgomery constants in ROM +# Autogeneration of precomputed constants in ROM # # ############################################################ -macro genMontyMagics(T: typed): untyped = +macro genConstants(): untyped = ## Store ## - the Montgomery magic constant "R^2 mod N" in ROM ## For each curve under the private symbol "MyCurve_R2modP" ## - the Montgomery magic constant -1/P mod 2^WordBitSize ## For each curve under the private symbol "MyCurve_NegInvModWord - T.getImpl.expectKind(nnkTypeDef) - T.getImpl[2].expectKind(nnkEnumTy) - + ## - ... result = newStmtList() - let E = T.getImpl[2] - for i in 1 ..< E.len: - let curve = E[i] + template used(name: string): NimNode = + nnkPragmaExpr.newTree( + ident(name), + nnkPragma.newTree(ident"used") + ) + + for curve in Curve.low .. Curve.high: # const MyCurve_CanUseNoCarryMontyMul = useNoCarryMontyMul(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_CanUseNoCarryMontyMul"), newCall( + used($curve & "_CanUseNoCarryMontyMul"), newCall( bindSym"useNoCarryMontyMul", bindSym($curve & "_Modulus") ) @@ -195,7 +222,7 @@ macro genMontyMagics(T: typed): untyped = # const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_CanUseNoCarryMontySquare"), newCall( + used($curve & "_CanUseNoCarryMontySquare"), newCall( bindSym"useNoCarryMontySquare", bindSym($curve & "_Modulus") ) @@ -203,7 +230,7 @@ macro genMontyMagics(T: typed): untyped = # const MyCurve_R2modP = r2mod(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_R2modP"), newCall( + used($curve & "_R2modP"), newCall( bindSym"r2mod", bindSym($curve & "_Modulus") ) @@ -211,64 +238,87 @@ macro genMontyMagics(T: typed): untyped = # const MyCurve_NegInvModWord = negInvModWord(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_NegInvModWord"), newCall( + used($curve & "_NegInvModWord"), newCall( bindSym"negInvModWord", bindSym($curve & "_Modulus") ) ) # const MyCurve_montyOne = montyOne(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_MontyOne"), newCall( + used($curve & "_MontyOne"), newCall( bindSym"montyOne", bindSym($curve & "_Modulus") ) ) # const MyCurve_MontyPrimeMinus1 = montyPrimeMinus1(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_MontyPrimeMinus1"), newCall( + used($curve & "_MontyPrimeMinus1"), newCall( bindSym"montyPrimeMinus1", bindSym($curve & "_Modulus") ) ) # const MyCurve_InvModExponent = primeMinus2_BE(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_InvModExponent"), newCall( + used($curve & "_InvModExponent"), newCall( bindSym"primeMinus2_BE", bindSym($curve & "_Modulus") ) ) # const MyCurve_PrimePlus1div2 = primePlus1div2(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_PrimePlus1div2"), newCall( + used($curve & "_PrimePlus1div2"), newCall( bindSym"primePlus1div2", bindSym($curve & "_Modulus") ) ) # const MyCurve_PrimeMinus1div2_BE = primeMinus1div2_BE(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_PrimeMinus1div2_BE"), newCall( + used($curve & "_PrimeMinus1div2_BE"), newCall( bindSym"primeMinus1div2_BE", bindSym($curve & "_Modulus") ) ) # const MyCurve_PrimeMinus3div4_BE = primeMinus3div4_BE(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_PrimeMinus3div4_BE"), newCall( + used($curve & "_PrimeMinus3div4_BE"), newCall( bindSym"primeMinus3div4_BE", bindSym($curve & "_Modulus") ) ) # const MyCurve_PrimePlus1div4_BE = primePlus1div4_BE(MyCurve_Modulus) result.add newConstStmt( - ident($curve & "_PrimePlus1div4_BE"), newCall( + used($curve & "_PrimePlus1div4_BE"), newCall( bindSym"primePlus1div4_BE", bindSym($curve & "_Modulus") ) ) - # echo result.toStrLit + if CurveFamilies[curve] == BarretoNaehrig: + # when declared(MyCurve_BN_param_u): + # const MyCurve_BN_u_BE = toCanonicalIntRepr(MyCurve_BN_param_u) + # const MyCurve_BN_6u_minus_1_BE = bn_6u_minus_1_BE(MyCurve_BN_param_u) + var bnStmts = newStmtList() + bnStmts.add newConstStmt( + used($curve & "_BN_u_BE"), newCall( + bindSym"toCanonicalIntRepr", + ident($curve & "_BN_param_u") + ) + ) + bnStmts.add newConstStmt( + used($curve & "_BN_6u_minus_1_BE"), newCall( + bindSym"bn_6u_minus_1_BE", + ident($curve & "_BN_param_u") + ) + ) -genMontyMagics(Curve) + result.add nnkWhenStmt.newTree( + nnkElifBranch.newTree( + newCall(ident"declared", ident($curve & "_BN_param_u")), + bnStmts + ) + ) + +genConstants() macro canUseNoCarryMontyMul*(C: static Curve): untyped = ## Returns true if the Modulus is compatible with a fast @@ -317,13 +367,30 @@ macro getPrimePlus1div4_BE*(C: static Curve): untyped = ## Get (P+1) / 4 for an odd prime in big-endian serialized format result = bindSym($C & "_PrimePlus1div4_BE") +# Family specific +# ------------------------------------------------------- +macro canUseFast_BN_Inversion*(C: static Curve): untyped = + ## A BN curve can use the fast BN inversion if the parameter "u" is positive + if CurveFamilies[C] != BarretoNaehrig: + return newLit false + return bindSym($C & "_BN_can_use_fast_inversion") + +macro getBN_param_u_BE*(C: static Curve): untyped = + ## Get the ``u`` parameter of a BN curve in canonical big-endian representation + result = bindSym($C & "_BN_u_BE") + +macro getBN_param_6u_minus_1_BE*(C: static Curve): untyped = + ## Get the ``6u-1`` from the ``u`` parameter + ## of a BN curve in canonical big-endian representation + result = bindSym($C & "_BN_6u_minus_1_BE") + # ############################################################ # # Debug info printed at compile-time # # ############################################################ -macro debugConsts(): untyped = +macro debugConsts(): untyped {.used.} = let curves = bindSym("Curve") let E = curves.getImpl[2] @@ -343,5 +410,5 @@ macro debugConsts(): untyped = result.add quote do: echo "----------------------------------------------------------------------------" -# debug: +# debug: # displayed with -d:debugConstantine # debugConsts() diff --git a/constantine/config/curves_parser.nim b/constantine/config/curves_parser.nim index ba2564a..75e020f 100644 --- a/constantine/config/curves_parser.nim +++ b/constantine/config/curves_parser.nim @@ -56,10 +56,14 @@ macro declareCurves*(curves: untyped): untyped = # StrLit "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" var Curves: seq[NimNode] - var CurveBitSize = nnKBracket.newTree() + var MapCurveBitWidth = nnkBracket.newTree() + var MapCurveFamily = nnkBracket.newTree() var curveModStmts = newStmtList() + var curveExtraStmts = newStmtList() for curveDesc in curves: + # Checks + # ----------------------------------------------- curveDesc.expectKind(nnkCommand) doAssert curveDesc[0].eqIdent"curve" curveDesc[1].expectKind(nnkIdent) # Curve name @@ -67,32 +71,37 @@ macro declareCurves*(curves: untyped): untyped = curveDesc[2][0].expectKind(nnkCall) curveDesc[2][1].expectKind(nnkCall) + # Mandatory fields + # ----------------------------------------------- let curve = curveDesc[1] + let curveParams = curveDesc[2] var offset = 0 var testCurve = false - if curveDesc[2][0][0].eqident"testingCurve": + if curveParams[0][0].eqident"testingCurve": offset = 1 - testCurve = curveDesc[2][0][1].boolVal + testCurve = curveParams[0][1].boolVal - let sizeSection = curveDesc[2][offset] + let sizeSection = curveParams[offset] doAssert sizeSection[0].eqIdent"bitsize" sizeSection[1].expectKind(nnkStmtList) let bitSize = sizeSection[1][0] - let modSection = curveDesc[2][offset+1] + let modSection = curveParams[offset+1] doAssert modSection[0].eqIdent"modulus" modSection[1].expectKind(nnkStmtList) let modulus = modSection[1][0] + # Construct the constants + # ----------------------------------------------- if not testCurve or defined(testingCurves): Curves.add curve # "BN254: 254" for array construction - CurveBitSize.add nnkExprColonExpr.newTree( + MapCurveBitWidth.add nnkExprColonExpr.newTree( curve, bitSize ) - # const BN254_Modulus = fromHex(BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") + # const BN254_Snarks_Modulus = fromHex(BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") let modulusID = ident($curve & "_Modulus") curveModStmts.add newConstStmt( modulusID, @@ -103,6 +112,56 @@ macro declareCurves*(curves: untyped): untyped = ) ) + # Family specific + # ----------------------------------------------- + if offset + 2 < curveParams.len: + let familySection = curveParams[offset+2] + doAssert familySection[0].eqIdent"family" + familySection[1].expectKind(nnkStmtList) + let family = familySection[1][0] + + MapCurveFamily.add nnkExprColonExpr.newTree( + curve, family + ) + + # BN curves + # ----------------------------------------------- + if family.eqIdent"BarretoNaehrig": + if offset + 5 == curveParams.len: + if curveParams[offset+3][0].eqIdent"bn_u_bitwidth" and + curveParams[offset+4][0].eqIdent"bn_u": + + let bn_u_bitwidth = curveParams[offset+3][1][0] + let bn_u = curveParams[offset+4][1][0] + + # const BN254_Snarks_BN_can_use_fast_inversion = ... + curveExtraStmts.add newConstStmt( + ident($curve & "_BN_can_use_fast_inversion"), + if ($bn_u)[0] == '-': newLit false # negative ``u`` can use the specialized fast inversion + else: newLit true + ) + + # const BN254_Snarks_BN_param_u = fromHex(BigInt[63], "0x44E992B44A6909F1") + curveExtraStmts.add newConstStmt( + ident($curve & "_BN_param_u"), + newCall( + bindSym"fromHex", + nnkBracketExpr.newTree(bindSym"BigInt", bn_u_bitwidth), + bn_u + ) + ) + else: + # const BN254_Snarks_BN_can_use_fast_inversion = ... + curveExtraStmts.add newConstStmt( + ident($curve & "_BN_can_use_fast_inversion"), + newLit false + ) + + else: + MapCurveFamily.add nnkExprColonExpr.newTree( + curve, ident"NoFamily" + ) + # end for --------------------------------------------------- result = newStmtList() @@ -117,11 +176,15 @@ macro declareCurves*(curves: untyped): untyped = ) # const CurveBitSize: array[Curve, int] = ... - let cbs = ident("CurveBitSize") result.add newConstStmt( - cbs, CurveBitSize + ident("CurveBitSize"), MapCurveBitWidth + ) + # const CurveFamily: array[Curve, CurveFamily] = ... + result.add newConstStmt( + ident("CurveFamilies"), MapCurveFamily ) result.add curveModStmts + result.add curveExtraStmts # echo result.toStrLit() diff --git a/tests/test_finite_fields_powinv.nim b/tests/test_finite_fields_powinv.nim index f0bc51d..7e966b5 100644 --- a/tests/test_finite_fields_powinv.nim +++ b/tests/test_finite_fields_powinv.nim @@ -6,13 +6,23 @@ # * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). # at your option. This file may not be copied, modified, or distributed except according to those terms. -import unittest, - ../constantine/arithmetic, +import ../constantine/arithmetic, ../constantine/io/[io_bigints, io_fields], - ../constantine/config/curves + ../constantine/config/curves, + # Test utilities + ../helpers/prng, + # Standard library + std/unittest, std/times static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" +const Iters = 512 + +var rng: RngState +let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32 +rng.seed(seed) +echo "test_finite_fields_powinv xoshiro512** seed: ", seed + proc main() = suite "Modular exponentiation over finite fields": test "n² mod 101": @@ -143,17 +153,65 @@ proc main() = computed == expected suite "Modular inversion over prime fields": - test "x^(-1) mod p": - var r, x: Fp[BLS12_381] + test "Specific test on Fp[BLS12_381]": + var r, x: Fp[BLS12_381] - # BN254 field modulus - x.fromHex("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") + # BN254 field modulus + x.fromHex("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") - let expected = "0x0636759a0f3034fa47174b2c0334902f11e9915b7bd89c6a2b3082b109abbc9837da17201f6d8286fe6203caa1b9d4c8" + let expected = "0x0636759a0f3034fa47174b2c0334902f11e9915b7bd89c6a2b3082b109abbc9837da17201f6d8286fe6203caa1b9d4c8" + r.inv(x) + let computed = r.toHex() + + check: + computed == expected + + test "Specific tests on Fp[BN254_Snarks]": + block: + var r, x: Fp[BN254_Snarks] + x.setOne() r.inv(x) - let computed = r.toHex() + check: bool r.isOne() - check: - computed == expected + block: + var r, x, expected: Fp[BN254_Snarks] + x.fromHex"0x076ef96647587df443d86a7ac8aa12f3f52d5d775287a6f5e47764a59d378309" + expected.fromHex"2d2ef0cd23dd8ec9e9b47c130942ecd7d7fda5e2dd5af19114bc34565ee355b8" + + r.inv(x) + check: bool(r == expected) + + block: + var r, x, expected: Fp[BN254_Snarks] + x.fromHex"0x0d2007d8aaface1b8501bfbe792974166e8f9ad6106e5b563604f0aea9ab06f6" + expected.fromHex"1b632d8aa572c4356debe80f772228dee49c203f34066a998fba5194b98e56c3" + + r.inv(x) + check: bool(r == expected) + + proc testRandomInv(curve: static Curve) = + test "Random inversion testing on " & $Curve(curve): + var aInv, r: Fp[curve] + + for _ in 0 ..< Iters: + let a = rng.random(Fp[curve]) + aInv.inv(a) + r.prod(a, aInv) + check: bool r.isOne() + r.prod(aInv, a) + check: bool r.isOne() + + testRandomInv P224 + testRandomInv BN254_Nogami + testRandomInv BN254_Snarks + testRandomInv Curve25519 + testRandomInv P256 + testRandomInv Secp256k1 + testRandomInv BLS12_377 + testRandomInv BLS12_381 + testRandomInv BN446 + testRandomInv FKM12_447 + testRandomInv BLS12_461 + testRandomInv BN462 main() diff --git a/tests/test_finite_fields_sqrt.nim b/tests/test_finite_fields_sqrt.nim index 0d29595..86af23b 100644 --- a/tests/test_finite_fields_sqrt.nim +++ b/tests/test_finite_fields_sqrt.nim @@ -6,14 +6,14 @@ # * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). # at your option. This file may not be copied, modified, or distributed except according to those terms. -import std/unittest, std/times, - ../constantine/[arithmetic, primitives], +import ../constantine/[arithmetic, primitives], ../constantine/io/[io_fields], ../constantine/config/[curves, common], # Test utilities ../helpers/prng, # Standard library - std/tables + std/tables, + std/unittest, std/times const Iters = 128