Implement constant-time `div2` on finite and extension fields
This commit is contained in:
parent
8a9cb9287c
commit
aff44f4d8e
|
@ -18,7 +18,7 @@ import
|
|||
../constantine/arithmetic,
|
||||
../constantine/towers,
|
||||
# Helpers
|
||||
../helpers/[timers, prng, static_for],
|
||||
../helpers/[timers, prng_unsafe, static_for],
|
||||
# Standard library
|
||||
std/[monotimes, times, strformat, strutils, macros]
|
||||
|
||||
|
@ -82,42 +82,42 @@ template bench(op: string, T: typedesc, iters: int, body: untyped): untyped =
|
|||
report(op, fixFieldDisplay(T), start, stop, startClk, stopClk, iters)
|
||||
|
||||
proc addBench*(T: typedesc, iters: int) =
|
||||
var x = rng.random(T)
|
||||
let y = rng.random(T)
|
||||
var x = rng.random_unsafe(T)
|
||||
let y = rng.random_unsafe(T)
|
||||
bench("Addition", T, iters):
|
||||
x += y
|
||||
|
||||
proc subBench*(T: typedesc, iters: int) =
|
||||
var x = rng.random(T)
|
||||
let y = rng.random(T)
|
||||
var x = rng.random_unsafe(T)
|
||||
let y = rng.random_unsafe(T)
|
||||
preventOptimAway(x)
|
||||
bench("Substraction", T, iters):
|
||||
x -= y
|
||||
|
||||
proc negBench*(T: typedesc, iters: int) =
|
||||
var r: T
|
||||
let x = rng.random(T)
|
||||
let x = rng.random_unsafe(T)
|
||||
bench("Negation", T, iters):
|
||||
r.neg(x)
|
||||
|
||||
proc mulBench*(T: typedesc, iters: int) =
|
||||
var r: T
|
||||
let x = rng.random(T)
|
||||
let y = rng.random(T)
|
||||
let x = rng.random_unsafe(T)
|
||||
let y = rng.random_unsafe(T)
|
||||
preventOptimAway(r)
|
||||
bench("Multiplication", T, iters):
|
||||
r.prod(x, y)
|
||||
|
||||
proc sqrBench*(T: typedesc, iters: int) =
|
||||
var r: T
|
||||
let x = rng.random(T)
|
||||
let x = rng.random_unsafe(T)
|
||||
preventOptimAway(r)
|
||||
bench("Squaring", T, iters):
|
||||
r.square(x)
|
||||
|
||||
proc invBench*(T: typedesc, iters: int) =
|
||||
var r: T
|
||||
let x = rng.random(T)
|
||||
let x = rng.random_unsafe(T)
|
||||
preventOptimAway(r)
|
||||
bench("Inversion", T, iters):
|
||||
r.inv(x)
|
||||
|
|
|
@ -249,6 +249,20 @@ func reduce*[aBits, mBits](r: var BigInt[mBits], a: BigInt[aBits], M: BigInt[mBi
|
|||
# pass a pointer+length to a fixed session of the BSS.
|
||||
reduce(r.limbs, a.limbs, aBits, M.limbs, mBits)
|
||||
|
||||
func div2mod*[bits](a: var BigInt[bits], mp1div2: BigInt[bits]) =
|
||||
## Compute a <- a/2 (mod M)
|
||||
## `mp1div2` is the modulus (M+1)/2
|
||||
##
|
||||
## Normally if `a` is odd we add the modulus before dividing by 2
|
||||
## but this may overflow and we might lose a bit before shifting.
|
||||
## Instead we shift first and then add half the modulus rounded up
|
||||
##
|
||||
## Assuming M is odd, `mp1div2` can be precomputed without
|
||||
## overflowing the "Limbs" by dividing by 2 first
|
||||
## and add 1
|
||||
## Otherwise `mp1div2` should be M/2
|
||||
a.limbs.div2mod(mp1div2.limbs)
|
||||
|
||||
func steinsGCD*[bits](r: var BigInt[bits], a, F, M, mp1div2: BigInt[bits]) =
|
||||
## Compute F multiplied the modular inverse of ``a`` modulo M
|
||||
## r ≡ F . a^-1 (mod M)
|
||||
|
|
|
@ -179,9 +179,13 @@ func neg*(r: var Fp, a: Fp) =
|
|||
## Negate modulo p
|
||||
discard r.mres.diff(Fp.C.Mod, a.mres)
|
||||
|
||||
func div2*(a: var Fp) =
|
||||
## Modular division by 2
|
||||
a.mres.div2mod(Fp.C.getPrimePlus1div2())
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Field arithmetic exponentiation and inversion
|
||||
# Field arithmetic exponentiation
|
||||
#
|
||||
# ############################################################
|
||||
#
|
||||
|
|
|
@ -14,6 +14,34 @@ import
|
|||
# No exceptions allowed
|
||||
{.push raises: [].}
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Modular division by 2
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
func div2mod*(a: var Limbs, mp1div2: Limbs) {.inline.}=
|
||||
## Modular Division by 2
|
||||
## `a` will be divided in-place
|
||||
## `mp1div2` is the modulus (M+1)/2
|
||||
##
|
||||
## Normally if `a` is odd we add the modulus before dividing by 2
|
||||
## but this may overflow and we might lose a bit before shifting.
|
||||
## Instead we shift first and then add half the modulus rounded up
|
||||
##
|
||||
## Assuming M is odd, `mp1div2` can be precomputed without
|
||||
## overflowing the "Limbs" by dividing by 2 first
|
||||
## and add 1
|
||||
## Otherwise `mp1div2` should be M/2
|
||||
|
||||
# if a.isOdd:
|
||||
# a += M
|
||||
# a = a shr 1
|
||||
let wasOdd = a.isOdd()
|
||||
a.shiftRight(1)
|
||||
let carry = a.cadd(mp1div2, wasOdd)
|
||||
debug: doAssert not carry.bool
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Modular inversion
|
||||
|
@ -107,17 +135,8 @@ func steinsGCD*(v: var Limbs, a: Limbs, F, M: Limbs, bits: int, mp1div2: Limbs)
|
|||
let neg = isOddA and (SecretBool) u.csub(v, isOddA)
|
||||
let corrected = u.cadd(M, neg)
|
||||
|
||||
let isOddU = u.isOdd()
|
||||
# if u.isOdd:
|
||||
# u += n
|
||||
# u = u shr 1
|
||||
#
|
||||
# Warning ⚠️: u += n will overflow the BigInt
|
||||
# and we might lose a bit on the next shift
|
||||
# Instead we shift first and then add hald the modulus rounded up
|
||||
u.shiftRight(1)
|
||||
let carry = u.cadd(mp1div2, isOddU)
|
||||
debug: doAssert not carry.bool
|
||||
# u = u/2 (mod M)
|
||||
u.div2mod(mp1div2)
|
||||
|
||||
debug:
|
||||
doAssert bool a.isZero()
|
||||
|
|
|
@ -83,6 +83,11 @@ func isOne*(a: ExtensionField): SecretBool =
|
|||
# Abelian group
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
func neg*(r: var ExtensionField, a: ExtensionField) =
|
||||
## Field out-of-place negation
|
||||
for fR, fA in fields(r, a):
|
||||
fR.neg(fA)
|
||||
|
||||
func `+=`*(a: var ExtensionField, b: ExtensionField) =
|
||||
## Addition in the extension field
|
||||
for fA, fB in fields(a, b):
|
||||
|
@ -103,10 +108,10 @@ func double*(a: var ExtensionField) =
|
|||
for fA in fields(a):
|
||||
fA.double()
|
||||
|
||||
func neg*(r: var ExtensionField, a: ExtensionField) =
|
||||
## Field out-of-place negation
|
||||
for fR, fA in fields(r, a):
|
||||
fR.neg(fA)
|
||||
func div2*(a: var ExtensionField) =
|
||||
## Field in-place division by 2
|
||||
for fA in fields(a):
|
||||
fA.div2()
|
||||
|
||||
func sum*(r: var QuadraticExt, a, b: QuadraticExt) =
|
||||
## Sum ``a`` and ``b`` into ``r``
|
||||
|
|
|
@ -152,6 +152,32 @@ proc main() =
|
|||
check:
|
||||
computed == expected
|
||||
|
||||
suite "Modular division by 2":
|
||||
proc testRandomDiv2(curve: static Curve) =
|
||||
test "Random modular div2 testing on " & $Curve(curve):
|
||||
for _ in 0 ..< Iters:
|
||||
let a = rng.random_unsafe(Fp[curve])
|
||||
var a2 = a
|
||||
a2.double()
|
||||
a2.div2()
|
||||
check: bool(a == a2)
|
||||
a2.div2()
|
||||
a2.double()
|
||||
check: bool(a == a2)
|
||||
|
||||
testRandomDiv2 P224
|
||||
testRandomDiv2 BN254_Nogami
|
||||
testRandomDiv2 BN254_Snarks
|
||||
testRandomDiv2 Curve25519
|
||||
testRandomDiv2 P256
|
||||
testRandomDiv2 Secp256k1
|
||||
testRandomDiv2 BLS12_377
|
||||
testRandomDiv2 BLS12_381
|
||||
testRandomDiv2 BN446
|
||||
testRandomDiv2 FKM12_447
|
||||
testRandomDiv2 BLS12_461
|
||||
testRandomDiv2 BN462
|
||||
|
||||
suite "Modular inversion over prime fields":
|
||||
test "Specific tests on Fp[BLS12_381]":
|
||||
block: # No inverse exist for 0 --> should return 0 for projective/jacobian to affine coordinate conversion
|
||||
|
|
|
@ -61,11 +61,12 @@ proc runTowerTests*[N](
|
|||
test(ExtField(ExtDegree, curve))
|
||||
|
||||
test "Addition, substraction negation are consistent":
|
||||
proc test(Field: typedesc) =
|
||||
proc test(Field: typedesc, Iters: static int) =
|
||||
# Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg
|
||||
# (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1
|
||||
var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: Field
|
||||
|
||||
for _ in 0 ..< Iters:
|
||||
One.setOne()
|
||||
a = rng.random_unsafe(Field)
|
||||
a2 = a
|
||||
|
@ -89,7 +90,22 @@ proc runTowerTests*[N](
|
|||
check: bool accum.isOne()
|
||||
|
||||
staticFor(curve, TestCurves):
|
||||
test(ExtField(ExtDegree, curve))
|
||||
test(ExtField(ExtDegree, curve), Iters)
|
||||
|
||||
test "Division by 2":
|
||||
proc test(Field: typedesc, Iters: static int) =
|
||||
for _ in 0 ..< Iters:
|
||||
let a = rng.random_unsafe(Field)
|
||||
var a2 = a
|
||||
a2.double()
|
||||
a2.div2()
|
||||
check: bool(a == a2)
|
||||
a2.div2()
|
||||
a2.double()
|
||||
check: bool(a == a2)
|
||||
|
||||
staticFor(curve, TestCurves):
|
||||
test(ExtField(ExtDegree, curve), Iters)
|
||||
|
||||
test "Squaring 1 returns 1":
|
||||
proc test(Field: typedesc) =
|
||||
|
|
Loading…
Reference in New Issue