Implement constant-time `div2` on finite and extension fields

This commit is contained in:
Mamy André-Ratsimbazafy 2020-04-15 02:12:45 +02:00
parent 8a9cb9287c
commit aff44f4d8e
No known key found for this signature in database
GPG Key ID: 7B88AD1FE79492E1
7 changed files with 129 additions and 45 deletions

View File

@ -18,7 +18,7 @@ import
../constantine/arithmetic, ../constantine/arithmetic,
../constantine/towers, ../constantine/towers,
# Helpers # Helpers
../helpers/[timers, prng, static_for], ../helpers/[timers, prng_unsafe, static_for],
# Standard library # Standard library
std/[monotimes, times, strformat, strutils, macros] std/[monotimes, times, strformat, strutils, macros]
@ -82,42 +82,42 @@ template bench(op: string, T: typedesc, iters: int, body: untyped): untyped =
report(op, fixFieldDisplay(T), start, stop, startClk, stopClk, iters) report(op, fixFieldDisplay(T), start, stop, startClk, stopClk, iters)
proc addBench*(T: typedesc, iters: int) = proc addBench*(T: typedesc, iters: int) =
var x = rng.random(T) var x = rng.random_unsafe(T)
let y = rng.random(T) let y = rng.random_unsafe(T)
bench("Addition", T, iters): bench("Addition", T, iters):
x += y x += y
proc subBench*(T: typedesc, iters: int) = proc subBench*(T: typedesc, iters: int) =
var x = rng.random(T) var x = rng.random_unsafe(T)
let y = rng.random(T) let y = rng.random_unsafe(T)
preventOptimAway(x) preventOptimAway(x)
bench("Substraction", T, iters): bench("Substraction", T, iters):
x -= y x -= y
proc negBench*(T: typedesc, iters: int) = proc negBench*(T: typedesc, iters: int) =
var r: T var r: T
let x = rng.random(T) let x = rng.random_unsafe(T)
bench("Negation", T, iters): bench("Negation", T, iters):
r.neg(x) r.neg(x)
proc mulBench*(T: typedesc, iters: int) = proc mulBench*(T: typedesc, iters: int) =
var r: T var r: T
let x = rng.random(T) let x = rng.random_unsafe(T)
let y = rng.random(T) let y = rng.random_unsafe(T)
preventOptimAway(r) preventOptimAway(r)
bench("Multiplication", T, iters): bench("Multiplication", T, iters):
r.prod(x, y) r.prod(x, y)
proc sqrBench*(T: typedesc, iters: int) = proc sqrBench*(T: typedesc, iters: int) =
var r: T var r: T
let x = rng.random(T) let x = rng.random_unsafe(T)
preventOptimAway(r) preventOptimAway(r)
bench("Squaring", T, iters): bench("Squaring", T, iters):
r.square(x) r.square(x)
proc invBench*(T: typedesc, iters: int) = proc invBench*(T: typedesc, iters: int) =
var r: T var r: T
let x = rng.random(T) let x = rng.random_unsafe(T)
preventOptimAway(r) preventOptimAway(r)
bench("Inversion", T, iters): bench("Inversion", T, iters):
r.inv(x) r.inv(x)

View File

@ -249,6 +249,20 @@ func reduce*[aBits, mBits](r: var BigInt[mBits], a: BigInt[aBits], M: BigInt[mBi
# pass a pointer+length to a fixed session of the BSS. # pass a pointer+length to a fixed session of the BSS.
reduce(r.limbs, a.limbs, aBits, M.limbs, mBits) reduce(r.limbs, a.limbs, aBits, M.limbs, mBits)
func div2mod*[bits](a: var BigInt[bits], mp1div2: BigInt[bits]) =
## Compute a <- a/2 (mod M)
## `mp1div2` is the modulus (M+1)/2
##
## Normally if `a` is odd we add the modulus before dividing by 2
## but this may overflow and we might lose a bit before shifting.
## Instead we shift first and then add half the modulus rounded up
##
## Assuming M is odd, `mp1div2` can be precomputed without
## overflowing the "Limbs" by dividing by 2 first
## and add 1
## Otherwise `mp1div2` should be M/2
a.limbs.div2mod(mp1div2.limbs)
func steinsGCD*[bits](r: var BigInt[bits], a, F, M, mp1div2: BigInt[bits]) = func steinsGCD*[bits](r: var BigInt[bits], a, F, M, mp1div2: BigInt[bits]) =
## Compute F multiplied the modular inverse of ``a`` modulo M ## Compute F multiplied the modular inverse of ``a`` modulo M
## r ≡ F . a^-1 (mod M) ## r ≡ F . a^-1 (mod M)

View File

@ -179,9 +179,13 @@ func neg*(r: var Fp, a: Fp) =
## Negate modulo p ## Negate modulo p
discard r.mres.diff(Fp.C.Mod, a.mres) discard r.mres.diff(Fp.C.Mod, a.mres)
func div2*(a: var Fp) =
## Modular division by 2
a.mres.div2mod(Fp.C.getPrimePlus1div2())
# ############################################################ # ############################################################
# #
# Field arithmetic exponentiation and inversion # Field arithmetic exponentiation
# #
# ############################################################ # ############################################################
# #

View File

@ -14,6 +14,34 @@ import
# No exceptions allowed # No exceptions allowed
{.push raises: [].} {.push raises: [].}
# ############################################################
#
# Modular division by 2
#
# ############################################################
func div2mod*(a: var Limbs, mp1div2: Limbs) {.inline.}=
## Modular Division by 2
## `a` will be divided in-place
## `mp1div2` is the modulus (M+1)/2
##
## Normally if `a` is odd we add the modulus before dividing by 2
## but this may overflow and we might lose a bit before shifting.
## Instead we shift first and then add half the modulus rounded up
##
## Assuming M is odd, `mp1div2` can be precomputed without
## overflowing the "Limbs" by dividing by 2 first
## and add 1
## Otherwise `mp1div2` should be M/2
# if a.isOdd:
# a += M
# a = a shr 1
let wasOdd = a.isOdd()
a.shiftRight(1)
let carry = a.cadd(mp1div2, wasOdd)
debug: doAssert not carry.bool
# ############################################################ # ############################################################
# #
# Modular inversion # Modular inversion
@ -107,17 +135,8 @@ func steinsGCD*(v: var Limbs, a: Limbs, F, M: Limbs, bits: int, mp1div2: Limbs)
let neg = isOddA and (SecretBool) u.csub(v, isOddA) let neg = isOddA and (SecretBool) u.csub(v, isOddA)
let corrected = u.cadd(M, neg) let corrected = u.cadd(M, neg)
let isOddU = u.isOdd() # u = u/2 (mod M)
# if u.isOdd: u.div2mod(mp1div2)
# u += n
# u = u shr 1
#
# Warning ⚠️: u += n will overflow the BigInt
# and we might lose a bit on the next shift
# Instead we shift first and then add hald the modulus rounded up
u.shiftRight(1)
let carry = u.cadd(mp1div2, isOddU)
debug: doAssert not carry.bool
debug: debug:
doAssert bool a.isZero() doAssert bool a.isZero()

View File

@ -83,6 +83,11 @@ func isOne*(a: ExtensionField): SecretBool =
# Abelian group # Abelian group
# ------------------------------------------------------------------- # -------------------------------------------------------------------
func neg*(r: var ExtensionField, a: ExtensionField) =
## Field out-of-place negation
for fR, fA in fields(r, a):
fR.neg(fA)
func `+=`*(a: var ExtensionField, b: ExtensionField) = func `+=`*(a: var ExtensionField, b: ExtensionField) =
## Addition in the extension field ## Addition in the extension field
for fA, fB in fields(a, b): for fA, fB in fields(a, b):
@ -103,10 +108,10 @@ func double*(a: var ExtensionField) =
for fA in fields(a): for fA in fields(a):
fA.double() fA.double()
func neg*(r: var ExtensionField, a: ExtensionField) = func div2*(a: var ExtensionField) =
## Field out-of-place negation ## Field in-place division by 2
for fR, fA in fields(r, a): for fA in fields(a):
fR.neg(fA) fA.div2()
func sum*(r: var QuadraticExt, a, b: QuadraticExt) = func sum*(r: var QuadraticExt, a, b: QuadraticExt) =
## Sum ``a`` and ``b`` into ``r`` ## Sum ``a`` and ``b`` into ``r``

View File

@ -152,6 +152,32 @@ proc main() =
check: check:
computed == expected computed == expected
suite "Modular division by 2":
proc testRandomDiv2(curve: static Curve) =
test "Random modular div2 testing on " & $Curve(curve):
for _ in 0 ..< Iters:
let a = rng.random_unsafe(Fp[curve])
var a2 = a
a2.double()
a2.div2()
check: bool(a == a2)
a2.div2()
a2.double()
check: bool(a == a2)
testRandomDiv2 P224
testRandomDiv2 BN254_Nogami
testRandomDiv2 BN254_Snarks
testRandomDiv2 Curve25519
testRandomDiv2 P256
testRandomDiv2 Secp256k1
testRandomDiv2 BLS12_377
testRandomDiv2 BLS12_381
testRandomDiv2 BN446
testRandomDiv2 FKM12_447
testRandomDiv2 BLS12_461
testRandomDiv2 BN462
suite "Modular inversion over prime fields": suite "Modular inversion over prime fields":
test "Specific tests on Fp[BLS12_381]": test "Specific tests on Fp[BLS12_381]":
block: # No inverse exist for 0 --> should return 0 for projective/jacobian to affine coordinate conversion block: # No inverse exist for 0 --> should return 0 for projective/jacobian to affine coordinate conversion

View File

@ -61,11 +61,12 @@ proc runTowerTests*[N](
test(ExtField(ExtDegree, curve)) test(ExtField(ExtDegree, curve))
test "Addition, substraction negation are consistent": test "Addition, substraction negation are consistent":
proc test(Field: typedesc) = proc test(Field: typedesc, Iters: static int) =
# Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg # Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg
# (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1 # (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1
var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: Field var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: Field
for _ in 0 ..< Iters:
One.setOne() One.setOne()
a = rng.random_unsafe(Field) a = rng.random_unsafe(Field)
a2 = a a2 = a
@ -89,7 +90,22 @@ proc runTowerTests*[N](
check: bool accum.isOne() check: bool accum.isOne()
staticFor(curve, TestCurves): staticFor(curve, TestCurves):
test(ExtField(ExtDegree, curve)) test(ExtField(ExtDegree, curve), Iters)
test "Division by 2":
proc test(Field: typedesc, Iters: static int) =
for _ in 0 ..< Iters:
let a = rng.random_unsafe(Field)
var a2 = a
a2.double()
a2.div2()
check: bool(a == a2)
a2.div2()
a2.double()
check: bool(a == a2)
staticFor(curve, TestCurves):
test(ExtField(ExtDegree, curve), Iters)
test "Squaring 1 returns 1": test "Squaring 1 returns 1":
proc test(Field: typedesc) = proc test(Field: typedesc) =