Implement constant-time `div2` on finite and extension fields
This commit is contained in:
parent
8a9cb9287c
commit
aff44f4d8e
|
@ -18,7 +18,7 @@ import
|
||||||
../constantine/arithmetic,
|
../constantine/arithmetic,
|
||||||
../constantine/towers,
|
../constantine/towers,
|
||||||
# Helpers
|
# Helpers
|
||||||
../helpers/[timers, prng, static_for],
|
../helpers/[timers, prng_unsafe, static_for],
|
||||||
# Standard library
|
# Standard library
|
||||||
std/[monotimes, times, strformat, strutils, macros]
|
std/[monotimes, times, strformat, strutils, macros]
|
||||||
|
|
||||||
|
@ -82,42 +82,42 @@ template bench(op: string, T: typedesc, iters: int, body: untyped): untyped =
|
||||||
report(op, fixFieldDisplay(T), start, stop, startClk, stopClk, iters)
|
report(op, fixFieldDisplay(T), start, stop, startClk, stopClk, iters)
|
||||||
|
|
||||||
proc addBench*(T: typedesc, iters: int) =
|
proc addBench*(T: typedesc, iters: int) =
|
||||||
var x = rng.random(T)
|
var x = rng.random_unsafe(T)
|
||||||
let y = rng.random(T)
|
let y = rng.random_unsafe(T)
|
||||||
bench("Addition", T, iters):
|
bench("Addition", T, iters):
|
||||||
x += y
|
x += y
|
||||||
|
|
||||||
proc subBench*(T: typedesc, iters: int) =
|
proc subBench*(T: typedesc, iters: int) =
|
||||||
var x = rng.random(T)
|
var x = rng.random_unsafe(T)
|
||||||
let y = rng.random(T)
|
let y = rng.random_unsafe(T)
|
||||||
preventOptimAway(x)
|
preventOptimAway(x)
|
||||||
bench("Substraction", T, iters):
|
bench("Substraction", T, iters):
|
||||||
x -= y
|
x -= y
|
||||||
|
|
||||||
proc negBench*(T: typedesc, iters: int) =
|
proc negBench*(T: typedesc, iters: int) =
|
||||||
var r: T
|
var r: T
|
||||||
let x = rng.random(T)
|
let x = rng.random_unsafe(T)
|
||||||
bench("Negation", T, iters):
|
bench("Negation", T, iters):
|
||||||
r.neg(x)
|
r.neg(x)
|
||||||
|
|
||||||
proc mulBench*(T: typedesc, iters: int) =
|
proc mulBench*(T: typedesc, iters: int) =
|
||||||
var r: T
|
var r: T
|
||||||
let x = rng.random(T)
|
let x = rng.random_unsafe(T)
|
||||||
let y = rng.random(T)
|
let y = rng.random_unsafe(T)
|
||||||
preventOptimAway(r)
|
preventOptimAway(r)
|
||||||
bench("Multiplication", T, iters):
|
bench("Multiplication", T, iters):
|
||||||
r.prod(x, y)
|
r.prod(x, y)
|
||||||
|
|
||||||
proc sqrBench*(T: typedesc, iters: int) =
|
proc sqrBench*(T: typedesc, iters: int) =
|
||||||
var r: T
|
var r: T
|
||||||
let x = rng.random(T)
|
let x = rng.random_unsafe(T)
|
||||||
preventOptimAway(r)
|
preventOptimAway(r)
|
||||||
bench("Squaring", T, iters):
|
bench("Squaring", T, iters):
|
||||||
r.square(x)
|
r.square(x)
|
||||||
|
|
||||||
proc invBench*(T: typedesc, iters: int) =
|
proc invBench*(T: typedesc, iters: int) =
|
||||||
var r: T
|
var r: T
|
||||||
let x = rng.random(T)
|
let x = rng.random_unsafe(T)
|
||||||
preventOptimAway(r)
|
preventOptimAway(r)
|
||||||
bench("Inversion", T, iters):
|
bench("Inversion", T, iters):
|
||||||
r.inv(x)
|
r.inv(x)
|
||||||
|
|
|
@ -249,6 +249,20 @@ func reduce*[aBits, mBits](r: var BigInt[mBits], a: BigInt[aBits], M: BigInt[mBi
|
||||||
# pass a pointer+length to a fixed session of the BSS.
|
# pass a pointer+length to a fixed session of the BSS.
|
||||||
reduce(r.limbs, a.limbs, aBits, M.limbs, mBits)
|
reduce(r.limbs, a.limbs, aBits, M.limbs, mBits)
|
||||||
|
|
||||||
|
func div2mod*[bits](a: var BigInt[bits], mp1div2: BigInt[bits]) =
|
||||||
|
## Compute a <- a/2 (mod M)
|
||||||
|
## `mp1div2` is the modulus (M+1)/2
|
||||||
|
##
|
||||||
|
## Normally if `a` is odd we add the modulus before dividing by 2
|
||||||
|
## but this may overflow and we might lose a bit before shifting.
|
||||||
|
## Instead we shift first and then add half the modulus rounded up
|
||||||
|
##
|
||||||
|
## Assuming M is odd, `mp1div2` can be precomputed without
|
||||||
|
## overflowing the "Limbs" by dividing by 2 first
|
||||||
|
## and add 1
|
||||||
|
## Otherwise `mp1div2` should be M/2
|
||||||
|
a.limbs.div2mod(mp1div2.limbs)
|
||||||
|
|
||||||
func steinsGCD*[bits](r: var BigInt[bits], a, F, M, mp1div2: BigInt[bits]) =
|
func steinsGCD*[bits](r: var BigInt[bits], a, F, M, mp1div2: BigInt[bits]) =
|
||||||
## Compute F multiplied the modular inverse of ``a`` modulo M
|
## Compute F multiplied the modular inverse of ``a`` modulo M
|
||||||
## r ≡ F . a^-1 (mod M)
|
## r ≡ F . a^-1 (mod M)
|
||||||
|
|
|
@ -179,9 +179,13 @@ func neg*(r: var Fp, a: Fp) =
|
||||||
## Negate modulo p
|
## Negate modulo p
|
||||||
discard r.mres.diff(Fp.C.Mod, a.mres)
|
discard r.mres.diff(Fp.C.Mod, a.mres)
|
||||||
|
|
||||||
|
func div2*(a: var Fp) =
|
||||||
|
## Modular division by 2
|
||||||
|
a.mres.div2mod(Fp.C.getPrimePlus1div2())
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
# Field arithmetic exponentiation and inversion
|
# Field arithmetic exponentiation
|
||||||
#
|
#
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
|
|
|
@ -14,6 +14,34 @@ import
|
||||||
# No exceptions allowed
|
# No exceptions allowed
|
||||||
{.push raises: [].}
|
{.push raises: [].}
|
||||||
|
|
||||||
|
# ############################################################
|
||||||
|
#
|
||||||
|
# Modular division by 2
|
||||||
|
#
|
||||||
|
# ############################################################
|
||||||
|
|
||||||
|
func div2mod*(a: var Limbs, mp1div2: Limbs) {.inline.}=
|
||||||
|
## Modular Division by 2
|
||||||
|
## `a` will be divided in-place
|
||||||
|
## `mp1div2` is the modulus (M+1)/2
|
||||||
|
##
|
||||||
|
## Normally if `a` is odd we add the modulus before dividing by 2
|
||||||
|
## but this may overflow and we might lose a bit before shifting.
|
||||||
|
## Instead we shift first and then add half the modulus rounded up
|
||||||
|
##
|
||||||
|
## Assuming M is odd, `mp1div2` can be precomputed without
|
||||||
|
## overflowing the "Limbs" by dividing by 2 first
|
||||||
|
## and add 1
|
||||||
|
## Otherwise `mp1div2` should be M/2
|
||||||
|
|
||||||
|
# if a.isOdd:
|
||||||
|
# a += M
|
||||||
|
# a = a shr 1
|
||||||
|
let wasOdd = a.isOdd()
|
||||||
|
a.shiftRight(1)
|
||||||
|
let carry = a.cadd(mp1div2, wasOdd)
|
||||||
|
debug: doAssert not carry.bool
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
# Modular inversion
|
# Modular inversion
|
||||||
|
@ -107,17 +135,8 @@ func steinsGCD*(v: var Limbs, a: Limbs, F, M: Limbs, bits: int, mp1div2: Limbs)
|
||||||
let neg = isOddA and (SecretBool) u.csub(v, isOddA)
|
let neg = isOddA and (SecretBool) u.csub(v, isOddA)
|
||||||
let corrected = u.cadd(M, neg)
|
let corrected = u.cadd(M, neg)
|
||||||
|
|
||||||
let isOddU = u.isOdd()
|
# u = u/2 (mod M)
|
||||||
# if u.isOdd:
|
u.div2mod(mp1div2)
|
||||||
# u += n
|
|
||||||
# u = u shr 1
|
|
||||||
#
|
|
||||||
# Warning ⚠️: u += n will overflow the BigInt
|
|
||||||
# and we might lose a bit on the next shift
|
|
||||||
# Instead we shift first and then add hald the modulus rounded up
|
|
||||||
u.shiftRight(1)
|
|
||||||
let carry = u.cadd(mp1div2, isOddU)
|
|
||||||
debug: doAssert not carry.bool
|
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
doAssert bool a.isZero()
|
doAssert bool a.isZero()
|
||||||
|
|
|
@ -83,6 +83,11 @@ func isOne*(a: ExtensionField): SecretBool =
|
||||||
# Abelian group
|
# Abelian group
|
||||||
# -------------------------------------------------------------------
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
|
func neg*(r: var ExtensionField, a: ExtensionField) =
|
||||||
|
## Field out-of-place negation
|
||||||
|
for fR, fA in fields(r, a):
|
||||||
|
fR.neg(fA)
|
||||||
|
|
||||||
func `+=`*(a: var ExtensionField, b: ExtensionField) =
|
func `+=`*(a: var ExtensionField, b: ExtensionField) =
|
||||||
## Addition in the extension field
|
## Addition in the extension field
|
||||||
for fA, fB in fields(a, b):
|
for fA, fB in fields(a, b):
|
||||||
|
@ -103,10 +108,10 @@ func double*(a: var ExtensionField) =
|
||||||
for fA in fields(a):
|
for fA in fields(a):
|
||||||
fA.double()
|
fA.double()
|
||||||
|
|
||||||
func neg*(r: var ExtensionField, a: ExtensionField) =
|
func div2*(a: var ExtensionField) =
|
||||||
## Field out-of-place negation
|
## Field in-place division by 2
|
||||||
for fR, fA in fields(r, a):
|
for fA in fields(a):
|
||||||
fR.neg(fA)
|
fA.div2()
|
||||||
|
|
||||||
func sum*(r: var QuadraticExt, a, b: QuadraticExt) =
|
func sum*(r: var QuadraticExt, a, b: QuadraticExt) =
|
||||||
## Sum ``a`` and ``b`` into ``r``
|
## Sum ``a`` and ``b`` into ``r``
|
||||||
|
|
|
@ -152,6 +152,32 @@ proc main() =
|
||||||
check:
|
check:
|
||||||
computed == expected
|
computed == expected
|
||||||
|
|
||||||
|
suite "Modular division by 2":
|
||||||
|
proc testRandomDiv2(curve: static Curve) =
|
||||||
|
test "Random modular div2 testing on " & $Curve(curve):
|
||||||
|
for _ in 0 ..< Iters:
|
||||||
|
let a = rng.random_unsafe(Fp[curve])
|
||||||
|
var a2 = a
|
||||||
|
a2.double()
|
||||||
|
a2.div2()
|
||||||
|
check: bool(a == a2)
|
||||||
|
a2.div2()
|
||||||
|
a2.double()
|
||||||
|
check: bool(a == a2)
|
||||||
|
|
||||||
|
testRandomDiv2 P224
|
||||||
|
testRandomDiv2 BN254_Nogami
|
||||||
|
testRandomDiv2 BN254_Snarks
|
||||||
|
testRandomDiv2 Curve25519
|
||||||
|
testRandomDiv2 P256
|
||||||
|
testRandomDiv2 Secp256k1
|
||||||
|
testRandomDiv2 BLS12_377
|
||||||
|
testRandomDiv2 BLS12_381
|
||||||
|
testRandomDiv2 BN446
|
||||||
|
testRandomDiv2 FKM12_447
|
||||||
|
testRandomDiv2 BLS12_461
|
||||||
|
testRandomDiv2 BN462
|
||||||
|
|
||||||
suite "Modular inversion over prime fields":
|
suite "Modular inversion over prime fields":
|
||||||
test "Specific tests on Fp[BLS12_381]":
|
test "Specific tests on Fp[BLS12_381]":
|
||||||
block: # No inverse exist for 0 --> should return 0 for projective/jacobian to affine coordinate conversion
|
block: # No inverse exist for 0 --> should return 0 for projective/jacobian to affine coordinate conversion
|
||||||
|
|
|
@ -61,11 +61,12 @@ proc runTowerTests*[N](
|
||||||
test(ExtField(ExtDegree, curve))
|
test(ExtField(ExtDegree, curve))
|
||||||
|
|
||||||
test "Addition, substraction negation are consistent":
|
test "Addition, substraction negation are consistent":
|
||||||
proc test(Field: typedesc) =
|
proc test(Field: typedesc, Iters: static int) =
|
||||||
# Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg
|
# Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg
|
||||||
# (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1
|
# (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1
|
||||||
var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: Field
|
var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: Field
|
||||||
|
|
||||||
|
for _ in 0 ..< Iters:
|
||||||
One.setOne()
|
One.setOne()
|
||||||
a = rng.random_unsafe(Field)
|
a = rng.random_unsafe(Field)
|
||||||
a2 = a
|
a2 = a
|
||||||
|
@ -89,7 +90,22 @@ proc runTowerTests*[N](
|
||||||
check: bool accum.isOne()
|
check: bool accum.isOne()
|
||||||
|
|
||||||
staticFor(curve, TestCurves):
|
staticFor(curve, TestCurves):
|
||||||
test(ExtField(ExtDegree, curve))
|
test(ExtField(ExtDegree, curve), Iters)
|
||||||
|
|
||||||
|
test "Division by 2":
|
||||||
|
proc test(Field: typedesc, Iters: static int) =
|
||||||
|
for _ in 0 ..< Iters:
|
||||||
|
let a = rng.random_unsafe(Field)
|
||||||
|
var a2 = a
|
||||||
|
a2.double()
|
||||||
|
a2.div2()
|
||||||
|
check: bool(a == a2)
|
||||||
|
a2.div2()
|
||||||
|
a2.double()
|
||||||
|
check: bool(a == a2)
|
||||||
|
|
||||||
|
staticFor(curve, TestCurves):
|
||||||
|
test(ExtField(ExtDegree, curve), Iters)
|
||||||
|
|
||||||
test "Squaring 1 returns 1":
|
test "Squaring 1 returns 1":
|
||||||
proc test(Field: typedesc) =
|
proc test(Field: typedesc) =
|
||||||
|
|
Loading…
Reference in New Issue