Add primitive for window-based modular exponentiation

This commit is contained in:
Mamy André-Ratsimbazafy 2020-02-17 00:13:42 +01:00
parent 285b6aad1a
commit d7d20c50b6
No known key found for this signature in database
GPG Key ID: 7B88AD1FE79492E1
3 changed files with 99 additions and 12 deletions

View File

@ -22,3 +22,4 @@ task test, "Run all tests":
test "", "tests/test_bigints_multimod.nim" test "", "tests/test_bigints_multimod.nim"
test "", "tests/test_bigints_vs_gmp.nim" test "", "tests/test_bigints_vs_gmp.nim"
test "", "tests/test_finite_fields.nim" test "", "tests/test_finite_fields.nim"
test "", "tests/test_finite_fields_vs_gmp.nim"

View File

@ -54,6 +54,7 @@ import
../primitives/extended_precision, ../primitives/extended_precision,
../config/common ../config/common
from sugar import distinctBase from sugar import distinctBase
from bitops import countSetBits # only used on modulus and public values
# ############################################################ # ############################################################
# #
@ -116,6 +117,16 @@ type
## Mutable view into a BigInt ## Mutable view into a BigInt
BigIntViewAny* = BigIntViewConst or BigIntViewMut BigIntViewAny* = BigIntViewConst or BigIntViewMut
BigIntLeakedConst* = distinct BigIntViewConst
## BigInt which information will be leaked
## besides the announced bit length
## This is only suitable for values
## that are publicly known
SensitiveInt* = distinct int
## Integer that contains sensitive information
## and will not be manipulated in a constant-time manner
# No exceptions allowed # No exceptions allowed
{.push raises: [].} {.push raises: [].}
@ -187,6 +198,12 @@ template checkOddModulus(m: BigIntViewConst) =
debug: debug:
assert bool(BaseType(m[0]) and 1), "Internal Error: the modulus must be odd to use the Montgomery representation." assert bool(BaseType(m[0]) and 1), "Internal Error: the modulus must be odd to use the Montgomery representation."
template checkWordShift(k: int) =
## Checks that the shift is less than the word bit size
debug:
assert k <= WordBitSize, "Internal Error: the shift must be less than the word bit size"
debug: debug:
func `$`*(a: BigIntViewAny): string = func `$`*(a: BigIntViewAny): string =
let len = a.numLimbs() let len = a.numLimbs()
@ -205,6 +222,9 @@ debug:
# #
# ############################################################ # ############################################################
template mask*(w: Word): Word =
w and MaxWord
func isZero*(a: BigIntViewAny): CTBool[Word] = func isZero*(a: BigIntViewAny): CTBool[Word] =
## Returns true if a big int is equal to zero ## Returns true if a big int is equal to zero
var accum: Word var accum: Word
@ -217,7 +237,6 @@ func setZero*(a: BigIntViewMut) =
## It's bit size is unchanged ## It's bit size is unchanged
zeroMem(a[0].unsafeAddr, a.numLimbs() * sizeof(Word)) zeroMem(a[0].unsafeAddr, a.numLimbs() * sizeof(Word))
func cmov*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]) = func cmov*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]) =
## Constant-time conditional copy ## Constant-time conditional copy
## If ctl is true: b is copied into a ## If ctl is true: b is copied into a
@ -243,7 +262,7 @@ func add*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
for i in 0 ..< a.numLimbs(): for i in 0 ..< a.numLimbs():
let new_a = a[i] + b[i] + Word(result) let new_a = a[i] + b[i] + Word(result)
result = new_a.isMsbSet() result = new_a.isMsbSet()
a[i] = ctl.mux(new_a and MaxWord, a[i]) a[i] = ctl.mux(new_a.mask(), a[i])
func sub*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] = func sub*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
## Constant-time big integer in-place optional substraction ## Constant-time big integer in-place optional substraction
@ -257,7 +276,37 @@ func sub*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
for i in 0 ..< a.numLimbs(): for i in 0 ..< a.numLimbs():
let new_a = a[i] - b[i] - Word(result) let new_a = a[i] - b[i] - Word(result)
result = new_a.isMsbSet() result = new_a.isMsbSet()
a[i] = ctl.mux(new_a and MaxWord, a[i]) a[i] = ctl.mux(new_a.mask(), a[i])
func dec*(a: BigIntViewMut, w: Word): CTBool[Word] =
## Decrement a big int by a small word
# returns the result carry
a[0] -= w
result = a[0].isMsbSet()
a[0] = a[0].mask()
for i in 1 ..< a.numLimbs():
a[i] -= Word(result)
result = a[i].isMsbSet()
a[i] = a[i].mask()
func shiftRight*(a: BigIntViewMut, k: int) =
## Shift right by k.
##
## k MUST be less than the base word size (2^31 or 2^63)
# We don't reuse shr for this in-place operation
# Do we need to return the shifted out part?
#
# Note: for speed, loading a[i] and a[i+1]
# instead of a[i-1] and a[i]
# is probably easier to parallelize for the compiler
# (antidependence WAR vs loop-carried dependence RAW)
checkWordShift(k)
let len = a.numLimbs()
for i in 0 ..< len-1:
a[i] = (a[i] shr k) or mask(a[i+1] shl (WordBitSize - k))
a[len-1] = a[len-1] shr k
# ############################################################ # ############################################################
# #
@ -301,11 +350,11 @@ func shlAddMod(a: BigIntViewMut, c: Word, M: BigIntViewConst) =
a1 = a[^1] a1 = a[^1]
m0 = M[^1] m0 = M[^1]
else: # Else: need to deal with partial word shifts at the edge. else: # Else: need to deal with partial word shifts at the edge.
a0 = ((a[^1] shl (WordBitSize-R)) or (a[^2] shr R)) and MaxWord a0 = mask((a[^1] shl (WordBitSize-R)) or (a[^2] shr R))
moveMem(a[1].addr, a[0].addr, (aLen-1) * Word.sizeof) moveMem(a[1].addr, a[0].addr, (aLen-1) * Word.sizeof)
a[0] = c a[0] = c
a1 = ((a[^1] shl (WordBitSize-R)) or (a[^2] shr R)) and MaxWord a1 = mask((a[^1] shl (WordBitSize-R)) or (a[^2] shr R))
m0 = ((M[^1] shl (WordBitSize-R)) or (M[^2] shr R)) and MaxWord m0 = mask((M[^1] shl (WordBitSize-R)) or (M[^2] shr R))
# m0 has its high bit set. (a0, a1)/p0 fits in a limb. # m0 has its high bit set. (a0, a1)/p0 fits in a limb.
# Get a quotient q, at most we will be 2 iterations off # Get a quotient q, at most we will be 2 iterations off
@ -335,12 +384,12 @@ func shlAddMod(a: BigIntViewMut, c: Word, M: BigIntViewConst) =
# q * p + carry (doubleword) carry from previous limb # q * p + carry (doubleword) carry from previous limb
let qp = unsafeExtPrecMul(q, M[i]) + carry.DoubleWord let qp = unsafeExtPrecMul(q, M[i]) + carry.DoubleWord
carry = Word(qp shr WordBitSize) # New carry: high digit besides LSB carry = Word(qp shr WordBitSize) # New carry: high digit besides LSB
qp_lo = qp.Word and MaxWord # Normalize to u63 qp_lo = qp.Word.mask() # Normalize to u63
block: # a*2^63 - q*p block: # a*2^63 - q*p
a[i] -= qp_lo a[i] -= qp_lo
carry += Word(a[i].isMsbSet) # Adjust if borrow carry += Word(a[i].isMsbSet) # Adjust if borrow
a[i] = a[i] and MaxWord # Normalize to u63 a[i] = a[i].mask() # Normalize to u63
over_p = mux( over_p = mux(
a[i] == M[i], over_p, a[i] == M[i], over_p,
@ -404,7 +453,7 @@ func reduce*(r: BigIntViewMut, a: BigIntViewAny, M: BigIntViewConst) =
# ############################################################ # ############################################################
template wordMul(a, b: Word): Word = template wordMul(a, b: Word): Word =
(a * b) and MaxWord mask(a * b)
func montyMul*( func montyMul*(
r: BigIntViewMut, a, b: distinct BigIntViewAny, r: BigIntViewMut, a, b: distinct BigIntViewAny,
@ -448,10 +497,10 @@ func montyMul*(
unsafeExtPrecMul(zi, M[j]) + DoubleWord(carry) unsafeExtPrecMul(zi, M[j]) + DoubleWord(carry)
carry = Word(z shr WordBitSize) carry = Word(z shr WordBitSize)
if j != 0: if j != 0:
r[j-1] = Word(z) and MaxWord r[j-1] = Word(z).mask()
r_hi += carry r_hi += carry
r[^1] = r_hi and MaxWord r[^1] = r_hi.mask()
r_hi = r_hi shr WordBitSize r_hi = r_hi shr WordBitSize
# If the extra word is not zero or if r-M does not borrow (i.e. r > M) # If the extra word is not zero or if r-M does not borrow (i.e. r > M)
@ -509,3 +558,40 @@ func montyResidue*(
checkMatchingBitlengths(a, N) checkMatchingBitlengths(a, N)
montyMul(r, a, r2ModN, N, negInvModWord) montyMul(r, a, r2ModN, N, negInvModWord)
# ############################################################
#
# Sensitive Primitives
#
# ############################################################
# Warning: Primitives that expose bits of information
# due to non-constant time.
# Proper usage is enforced by compiler.
# Only apply explicitly to public data like the field modulus
template bitSizeof(v: BigIntLeakedConst): uint32 =
distinctBase(type v)(v).bitLength
template numLimbs*(v: BigIntLeakedConst): int =
## Compute the number of limbs from
## the **internal** bitlength
(bitSizeof(v).int + WordPhysBitSize - 1) shr divShiftor
template `[]`*(v: BigIntLeakedConst, limbIdx: int): SensitiveInt =
SensitiveInt distinctBase(type v)(v).limbs[limbIdx]
func popcount(a: BigIntLeakedConst): SensitiveInt =
## Count the number of bits set in an integer
## also called popcount or Hamming Weight
## ⚠️⚠️⚠️: This is only intended for use on public data
var accum: int
for i in 0 ..< a.numLimbs:
accum += countSetBits(a[i].BaseType)
return SensitiveInt accum
func lastBits(a: BigIntLeakedConst, k: int): SensitiveInt =
## Returns the last bits of an integer
## k MUST be less than the base word size (2^31 or 2^63)
checkWordShift(k)
let mask = BaseType((1 shl k) - 1)
return SensitiveInt(a[0].BaseType and mask)

View File

@ -261,7 +261,7 @@ proc main() =
# Check equality when converting back to natural domain # Check equality when converting back to natural domain
new_x == 1'u64 shl 61 - 2 new_x == 1'u64 shl 61 - 2
test "Multiplication mod 101": test "Multiplication mod 2^61 - 1":
block: block:
var x, y, z: Fq[Mersenne61] var x, y, z: Fq[Mersenne61]