diff --git a/benchmarks/bench_fields_template.nim b/benchmarks/bench_fields_template.nim index da602c0..f85e597 100644 --- a/benchmarks/bench_fields_template.nim +++ b/benchmarks/bench_fields_template.nim @@ -18,7 +18,7 @@ import ../constantine/arithmetic, ../constantine/towers, # Helpers - ../helpers/[timers, prng, static_for], + ../helpers/[timers, prng_unsafe, static_for], # Standard library std/[monotimes, times, strformat, strutils, macros] @@ -82,42 +82,42 @@ template bench(op: string, T: typedesc, iters: int, body: untyped): untyped = report(op, fixFieldDisplay(T), start, stop, startClk, stopClk, iters) proc addBench*(T: typedesc, iters: int) = - var x = rng.random(T) - let y = rng.random(T) + var x = rng.random_unsafe(T) + let y = rng.random_unsafe(T) bench("Addition", T, iters): x += y proc subBench*(T: typedesc, iters: int) = - var x = rng.random(T) - let y = rng.random(T) + var x = rng.random_unsafe(T) + let y = rng.random_unsafe(T) preventOptimAway(x) bench("Substraction", T, iters): x -= y proc negBench*(T: typedesc, iters: int) = var r: T - let x = rng.random(T) + let x = rng.random_unsafe(T) bench("Negation", T, iters): r.neg(x) proc mulBench*(T: typedesc, iters: int) = var r: T - let x = rng.random(T) - let y = rng.random(T) + let x = rng.random_unsafe(T) + let y = rng.random_unsafe(T) preventOptimAway(r) bench("Multiplication", T, iters): r.prod(x, y) proc sqrBench*(T: typedesc, iters: int) = var r: T - let x = rng.random(T) + let x = rng.random_unsafe(T) preventOptimAway(r) bench("Squaring", T, iters): r.square(x) proc invBench*(T: typedesc, iters: int) = var r: T - let x = rng.random(T) + let x = rng.random_unsafe(T) preventOptimAway(r) bench("Inversion", T, iters): r.inv(x) diff --git a/constantine/arithmetic/bigints.nim b/constantine/arithmetic/bigints.nim index 258e23d..26280d0 100644 --- a/constantine/arithmetic/bigints.nim +++ b/constantine/arithmetic/bigints.nim @@ -249,6 +249,20 @@ func reduce*[aBits, mBits](r: var BigInt[mBits], a: BigInt[aBits], M: BigInt[mBi # pass a pointer+length to a fixed session of the BSS. reduce(r.limbs, a.limbs, aBits, M.limbs, mBits) +func div2mod*[bits](a: var BigInt[bits], mp1div2: BigInt[bits]) = + ## Compute a <- a/2 (mod M) + ## `mp1div2` is the modulus (M+1)/2 + ## + ## Normally if `a` is odd we add the modulus before dividing by 2 + ## but this may overflow and we might lose a bit before shifting. + ## Instead we shift first and then add half the modulus rounded up + ## + ## Assuming M is odd, `mp1div2` can be precomputed without + ## overflowing the "Limbs" by dividing by 2 first + ## and add 1 + ## Otherwise `mp1div2` should be M/2 + a.limbs.div2mod(mp1div2.limbs) + func steinsGCD*[bits](r: var BigInt[bits], a, F, M, mp1div2: BigInt[bits]) = ## Compute F multiplied the modular inverse of ``a`` modulo M ## r ≡ F . a^-1 (mod M) diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 88a51e1..eb888ff 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -179,9 +179,13 @@ func neg*(r: var Fp, a: Fp) = ## Negate modulo p discard r.mres.diff(Fp.C.Mod, a.mres) +func div2*(a: var Fp) = + ## Modular division by 2 + a.mres.div2mod(Fp.C.getPrimePlus1div2()) + # ############################################################ # -# Field arithmetic exponentiation and inversion +# Field arithmetic exponentiation # # ############################################################ # diff --git a/constantine/arithmetic/limbs_modular.nim b/constantine/arithmetic/limbs_modular.nim index 12de1e3..894f0c0 100644 --- a/constantine/arithmetic/limbs_modular.nim +++ b/constantine/arithmetic/limbs_modular.nim @@ -14,6 +14,34 @@ import # No exceptions allowed {.push raises: [].} +# ############################################################ +# +# Modular division by 2 +# +# ############################################################ + +func div2mod*(a: var Limbs, mp1div2: Limbs) {.inline.}= + ## Modular Division by 2 + ## `a` will be divided in-place + ## `mp1div2` is the modulus (M+1)/2 + ## + ## Normally if `a` is odd we add the modulus before dividing by 2 + ## but this may overflow and we might lose a bit before shifting. + ## Instead we shift first and then add half the modulus rounded up + ## + ## Assuming M is odd, `mp1div2` can be precomputed without + ## overflowing the "Limbs" by dividing by 2 first + ## and add 1 + ## Otherwise `mp1div2` should be M/2 + + # if a.isOdd: + # a += M + # a = a shr 1 + let wasOdd = a.isOdd() + a.shiftRight(1) + let carry = a.cadd(mp1div2, wasOdd) + debug: doAssert not carry.bool + # ############################################################ # # Modular inversion @@ -107,17 +135,8 @@ func steinsGCD*(v: var Limbs, a: Limbs, F, M: Limbs, bits: int, mp1div2: Limbs) let neg = isOddA and (SecretBool) u.csub(v, isOddA) let corrected = u.cadd(M, neg) - let isOddU = u.isOdd() - # if u.isOdd: - # u += n - # u = u shr 1 - # - # Warning ⚠️: u += n will overflow the BigInt - # and we might lose a bit on the next shift - # Instead we shift first and then add hald the modulus rounded up - u.shiftRight(1) - let carry = u.cadd(mp1div2, isOddU) - debug: doAssert not carry.bool + # u = u/2 (mod M) + u.div2mod(mp1div2) debug: doAssert bool a.isZero() diff --git a/constantine/tower_field_extensions/tower_common.nim b/constantine/tower_field_extensions/tower_common.nim index dcc0cd7..7fa9906 100644 --- a/constantine/tower_field_extensions/tower_common.nim +++ b/constantine/tower_field_extensions/tower_common.nim @@ -83,6 +83,11 @@ func isOne*(a: ExtensionField): SecretBool = # Abelian group # ------------------------------------------------------------------- +func neg*(r: var ExtensionField, a: ExtensionField) = + ## Field out-of-place negation + for fR, fA in fields(r, a): + fR.neg(fA) + func `+=`*(a: var ExtensionField, b: ExtensionField) = ## Addition in the extension field for fA, fB in fields(a, b): @@ -103,10 +108,10 @@ func double*(a: var ExtensionField) = for fA in fields(a): fA.double() -func neg*(r: var ExtensionField, a: ExtensionField) = - ## Field out-of-place negation - for fR, fA in fields(r, a): - fR.neg(fA) +func div2*(a: var ExtensionField) = + ## Field in-place division by 2 + for fA in fields(a): + fA.div2() func sum*(r: var QuadraticExt, a, b: QuadraticExt) = ## Sum ``a`` and ``b`` into ``r`` diff --git a/tests/test_finite_fields_powinv.nim b/tests/test_finite_fields_powinv.nim index e40d399..bb59f24 100644 --- a/tests/test_finite_fields_powinv.nim +++ b/tests/test_finite_fields_powinv.nim @@ -152,6 +152,32 @@ proc main() = check: computed == expected + suite "Modular division by 2": + proc testRandomDiv2(curve: static Curve) = + test "Random modular div2 testing on " & $Curve(curve): + for _ in 0 ..< Iters: + let a = rng.random_unsafe(Fp[curve]) + var a2 = a + a2.double() + a2.div2() + check: bool(a == a2) + a2.div2() + a2.double() + check: bool(a == a2) + + testRandomDiv2 P224 + testRandomDiv2 BN254_Nogami + testRandomDiv2 BN254_Snarks + testRandomDiv2 Curve25519 + testRandomDiv2 P256 + testRandomDiv2 Secp256k1 + testRandomDiv2 BLS12_377 + testRandomDiv2 BLS12_381 + testRandomDiv2 BN446 + testRandomDiv2 FKM12_447 + testRandomDiv2 BLS12_461 + testRandomDiv2 BN462 + suite "Modular inversion over prime fields": test "Specific tests on Fp[BLS12_381]": block: # No inverse exist for 0 --> should return 0 for projective/jacobian to affine coordinate conversion diff --git a/tests/test_fp_tower_template.nim b/tests/test_fp_tower_template.nim index aa39e3f..e21172b 100644 --- a/tests/test_fp_tower_template.nim +++ b/tests/test_fp_tower_template.nim @@ -61,35 +61,51 @@ proc runTowerTests*[N]( test(ExtField(ExtDegree, curve)) test "Addition, substraction negation are consistent": - proc test(Field: typedesc) = + proc test(Field: typedesc, Iters: static int) = # Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg # (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1 var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: Field - One.setOne() - a = rng.random_unsafe(Field) - a2 = a - a2.double() - na.neg(a) + for _ in 0 ..< Iters: + One.setOne() + a = rng.random_unsafe(Field) + a2 = a + a2.double() + na.neg(a) - b = rng.random_unsafe(Field) - b2.double(b) - nb.neg(b) + b = rng.random_unsafe(Field) + b2.double(b) + nb.neg(b) - accum.diff(One, na) - accum -= b - accum += na - accum -= a2 + accum.diff(One, na) + accum -= b + accum += na + accum -= a2 - var t{.noInit.}: Field - t.sum(a2, b2) - t += nb + var t{.noInit.}: Field + t.sum(a2, b2) + t += nb - accum += t - check: bool accum.isOne() + accum += t + check: bool accum.isOne() staticFor(curve, TestCurves): - test(ExtField(ExtDegree, curve)) + test(ExtField(ExtDegree, curve), Iters) + + test "Division by 2": + proc test(Field: typedesc, Iters: static int) = + for _ in 0 ..< Iters: + let a = rng.random_unsafe(Field) + var a2 = a + a2.double() + a2.div2() + check: bool(a == a2) + a2.div2() + a2.double() + check: bool(a == a2) + + staticFor(curve, TestCurves): + test(ExtField(ExtDegree, curve), Iters) test "Squaring 1 returns 1": proc test(Field: typedesc) =