Add primitive for window-based modular exponentiation
This commit is contained in:
parent
285b6aad1a
commit
d7d20c50b6
|
@ -22,3 +22,4 @@ task test, "Run all tests":
|
||||||
test "", "tests/test_bigints_multimod.nim"
|
test "", "tests/test_bigints_multimod.nim"
|
||||||
test "", "tests/test_bigints_vs_gmp.nim"
|
test "", "tests/test_bigints_vs_gmp.nim"
|
||||||
test "", "tests/test_finite_fields.nim"
|
test "", "tests/test_finite_fields.nim"
|
||||||
|
test "", "tests/test_finite_fields_vs_gmp.nim"
|
||||||
|
|
|
@ -54,6 +54,7 @@ import
|
||||||
../primitives/extended_precision,
|
../primitives/extended_precision,
|
||||||
../config/common
|
../config/common
|
||||||
from sugar import distinctBase
|
from sugar import distinctBase
|
||||||
|
from bitops import countSetBits # only used on modulus and public values
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
|
@ -116,6 +117,16 @@ type
|
||||||
## Mutable view into a BigInt
|
## Mutable view into a BigInt
|
||||||
BigIntViewAny* = BigIntViewConst or BigIntViewMut
|
BigIntViewAny* = BigIntViewConst or BigIntViewMut
|
||||||
|
|
||||||
|
BigIntLeakedConst* = distinct BigIntViewConst
|
||||||
|
## BigInt which information will be leaked
|
||||||
|
## besides the announced bit length
|
||||||
|
## This is only suitable for values
|
||||||
|
## that are publicly known
|
||||||
|
|
||||||
|
SensitiveInt* = distinct int
|
||||||
|
## Integer that contains sensitive information
|
||||||
|
## and will not be manipulated in a constant-time manner
|
||||||
|
|
||||||
# No exceptions allowed
|
# No exceptions allowed
|
||||||
{.push raises: [].}
|
{.push raises: [].}
|
||||||
|
|
||||||
|
@ -187,6 +198,12 @@ template checkOddModulus(m: BigIntViewConst) =
|
||||||
debug:
|
debug:
|
||||||
assert bool(BaseType(m[0]) and 1), "Internal Error: the modulus must be odd to use the Montgomery representation."
|
assert bool(BaseType(m[0]) and 1), "Internal Error: the modulus must be odd to use the Montgomery representation."
|
||||||
|
|
||||||
|
template checkWordShift(k: int) =
|
||||||
|
## Checks that the shift is less than the word bit size
|
||||||
|
debug:
|
||||||
|
assert k <= WordBitSize, "Internal Error: the shift must be less than the word bit size"
|
||||||
|
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
func `$`*(a: BigIntViewAny): string =
|
func `$`*(a: BigIntViewAny): string =
|
||||||
let len = a.numLimbs()
|
let len = a.numLimbs()
|
||||||
|
@ -205,6 +222,9 @@ debug:
|
||||||
#
|
#
|
||||||
# ############################################################
|
# ############################################################
|
||||||
|
|
||||||
|
template mask*(w: Word): Word =
|
||||||
|
w and MaxWord
|
||||||
|
|
||||||
func isZero*(a: BigIntViewAny): CTBool[Word] =
|
func isZero*(a: BigIntViewAny): CTBool[Word] =
|
||||||
## Returns true if a big int is equal to zero
|
## Returns true if a big int is equal to zero
|
||||||
var accum: Word
|
var accum: Word
|
||||||
|
@ -217,7 +237,6 @@ func setZero*(a: BigIntViewMut) =
|
||||||
## It's bit size is unchanged
|
## It's bit size is unchanged
|
||||||
zeroMem(a[0].unsafeAddr, a.numLimbs() * sizeof(Word))
|
zeroMem(a[0].unsafeAddr, a.numLimbs() * sizeof(Word))
|
||||||
|
|
||||||
|
|
||||||
func cmov*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]) =
|
func cmov*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]) =
|
||||||
## Constant-time conditional copy
|
## Constant-time conditional copy
|
||||||
## If ctl is true: b is copied into a
|
## If ctl is true: b is copied into a
|
||||||
|
@ -243,7 +262,7 @@ func add*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
|
||||||
for i in 0 ..< a.numLimbs():
|
for i in 0 ..< a.numLimbs():
|
||||||
let new_a = a[i] + b[i] + Word(result)
|
let new_a = a[i] + b[i] + Word(result)
|
||||||
result = new_a.isMsbSet()
|
result = new_a.isMsbSet()
|
||||||
a[i] = ctl.mux(new_a and MaxWord, a[i])
|
a[i] = ctl.mux(new_a.mask(), a[i])
|
||||||
|
|
||||||
func sub*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
|
func sub*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
|
||||||
## Constant-time big integer in-place optional substraction
|
## Constant-time big integer in-place optional substraction
|
||||||
|
@ -257,7 +276,37 @@ func sub*(a: BigIntViewMut, b: BigIntViewAny, ctl: CTBool[Word]): CTBool[Word] =
|
||||||
for i in 0 ..< a.numLimbs():
|
for i in 0 ..< a.numLimbs():
|
||||||
let new_a = a[i] - b[i] - Word(result)
|
let new_a = a[i] - b[i] - Word(result)
|
||||||
result = new_a.isMsbSet()
|
result = new_a.isMsbSet()
|
||||||
a[i] = ctl.mux(new_a and MaxWord, a[i])
|
a[i] = ctl.mux(new_a.mask(), a[i])
|
||||||
|
|
||||||
|
func dec*(a: BigIntViewMut, w: Word): CTBool[Word] =
|
||||||
|
## Decrement a big int by a small word
|
||||||
|
# returns the result carry
|
||||||
|
|
||||||
|
a[0] -= w
|
||||||
|
result = a[0].isMsbSet()
|
||||||
|
a[0] = a[0].mask()
|
||||||
|
for i in 1 ..< a.numLimbs():
|
||||||
|
a[i] -= Word(result)
|
||||||
|
result = a[i].isMsbSet()
|
||||||
|
a[i] = a[i].mask()
|
||||||
|
|
||||||
|
func shiftRight*(a: BigIntViewMut, k: int) =
|
||||||
|
## Shift right by k.
|
||||||
|
##
|
||||||
|
## k MUST be less than the base word size (2^31 or 2^63)
|
||||||
|
# We don't reuse shr for this in-place operation
|
||||||
|
# Do we need to return the shifted out part?
|
||||||
|
#
|
||||||
|
# Note: for speed, loading a[i] and a[i+1]
|
||||||
|
# instead of a[i-1] and a[i]
|
||||||
|
# is probably easier to parallelize for the compiler
|
||||||
|
# (antidependence WAR vs loop-carried dependence RAW)
|
||||||
|
checkWordShift(k)
|
||||||
|
|
||||||
|
let len = a.numLimbs()
|
||||||
|
for i in 0 ..< len-1:
|
||||||
|
a[i] = (a[i] shr k) or mask(a[i+1] shl (WordBitSize - k))
|
||||||
|
a[len-1] = a[len-1] shr k
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
|
@ -301,11 +350,11 @@ func shlAddMod(a: BigIntViewMut, c: Word, M: BigIntViewConst) =
|
||||||
a1 = a[^1]
|
a1 = a[^1]
|
||||||
m0 = M[^1]
|
m0 = M[^1]
|
||||||
else: # Else: need to deal with partial word shifts at the edge.
|
else: # Else: need to deal with partial word shifts at the edge.
|
||||||
a0 = ((a[^1] shl (WordBitSize-R)) or (a[^2] shr R)) and MaxWord
|
a0 = mask((a[^1] shl (WordBitSize-R)) or (a[^2] shr R))
|
||||||
moveMem(a[1].addr, a[0].addr, (aLen-1) * Word.sizeof)
|
moveMem(a[1].addr, a[0].addr, (aLen-1) * Word.sizeof)
|
||||||
a[0] = c
|
a[0] = c
|
||||||
a1 = ((a[^1] shl (WordBitSize-R)) or (a[^2] shr R)) and MaxWord
|
a1 = mask((a[^1] shl (WordBitSize-R)) or (a[^2] shr R))
|
||||||
m0 = ((M[^1] shl (WordBitSize-R)) or (M[^2] shr R)) and MaxWord
|
m0 = mask((M[^1] shl (WordBitSize-R)) or (M[^2] shr R))
|
||||||
|
|
||||||
# m0 has its high bit set. (a0, a1)/p0 fits in a limb.
|
# m0 has its high bit set. (a0, a1)/p0 fits in a limb.
|
||||||
# Get a quotient q, at most we will be 2 iterations off
|
# Get a quotient q, at most we will be 2 iterations off
|
||||||
|
@ -335,12 +384,12 @@ func shlAddMod(a: BigIntViewMut, c: Word, M: BigIntViewConst) =
|
||||||
# q * p + carry (doubleword) carry from previous limb
|
# q * p + carry (doubleword) carry from previous limb
|
||||||
let qp = unsafeExtPrecMul(q, M[i]) + carry.DoubleWord
|
let qp = unsafeExtPrecMul(q, M[i]) + carry.DoubleWord
|
||||||
carry = Word(qp shr WordBitSize) # New carry: high digit besides LSB
|
carry = Word(qp shr WordBitSize) # New carry: high digit besides LSB
|
||||||
qp_lo = qp.Word and MaxWord # Normalize to u63
|
qp_lo = qp.Word.mask() # Normalize to u63
|
||||||
|
|
||||||
block: # a*2^63 - q*p
|
block: # a*2^63 - q*p
|
||||||
a[i] -= qp_lo
|
a[i] -= qp_lo
|
||||||
carry += Word(a[i].isMsbSet) # Adjust if borrow
|
carry += Word(a[i].isMsbSet) # Adjust if borrow
|
||||||
a[i] = a[i] and MaxWord # Normalize to u63
|
a[i] = a[i].mask() # Normalize to u63
|
||||||
|
|
||||||
over_p = mux(
|
over_p = mux(
|
||||||
a[i] == M[i], over_p,
|
a[i] == M[i], over_p,
|
||||||
|
@ -404,7 +453,7 @@ func reduce*(r: BigIntViewMut, a: BigIntViewAny, M: BigIntViewConst) =
|
||||||
# ############################################################
|
# ############################################################
|
||||||
|
|
||||||
template wordMul(a, b: Word): Word =
|
template wordMul(a, b: Word): Word =
|
||||||
(a * b) and MaxWord
|
mask(a * b)
|
||||||
|
|
||||||
func montyMul*(
|
func montyMul*(
|
||||||
r: BigIntViewMut, a, b: distinct BigIntViewAny,
|
r: BigIntViewMut, a, b: distinct BigIntViewAny,
|
||||||
|
@ -448,10 +497,10 @@ func montyMul*(
|
||||||
unsafeExtPrecMul(zi, M[j]) + DoubleWord(carry)
|
unsafeExtPrecMul(zi, M[j]) + DoubleWord(carry)
|
||||||
carry = Word(z shr WordBitSize)
|
carry = Word(z shr WordBitSize)
|
||||||
if j != 0:
|
if j != 0:
|
||||||
r[j-1] = Word(z) and MaxWord
|
r[j-1] = Word(z).mask()
|
||||||
|
|
||||||
r_hi += carry
|
r_hi += carry
|
||||||
r[^1] = r_hi and MaxWord
|
r[^1] = r_hi.mask()
|
||||||
r_hi = r_hi shr WordBitSize
|
r_hi = r_hi shr WordBitSize
|
||||||
|
|
||||||
# If the extra word is not zero or if r-M does not borrow (i.e. r > M)
|
# If the extra word is not zero or if r-M does not borrow (i.e. r > M)
|
||||||
|
@ -509,3 +558,40 @@ func montyResidue*(
|
||||||
checkMatchingBitlengths(a, N)
|
checkMatchingBitlengths(a, N)
|
||||||
|
|
||||||
montyMul(r, a, r2ModN, N, negInvModWord)
|
montyMul(r, a, r2ModN, N, negInvModWord)
|
||||||
|
|
||||||
|
# ############################################################
|
||||||
|
#
|
||||||
|
# Sensitive Primitives
|
||||||
|
#
|
||||||
|
# ############################################################
|
||||||
|
# Warning: Primitives that expose bits of information
|
||||||
|
# due to non-constant time.
|
||||||
|
# Proper usage is enforced by compiler.
|
||||||
|
# Only apply explicitly to public data like the field modulus
|
||||||
|
|
||||||
|
template bitSizeof(v: BigIntLeakedConst): uint32 =
|
||||||
|
distinctBase(type v)(v).bitLength
|
||||||
|
|
||||||
|
template numLimbs*(v: BigIntLeakedConst): int =
|
||||||
|
## Compute the number of limbs from
|
||||||
|
## the **internal** bitlength
|
||||||
|
(bitSizeof(v).int + WordPhysBitSize - 1) shr divShiftor
|
||||||
|
|
||||||
|
template `[]`*(v: BigIntLeakedConst, limbIdx: int): SensitiveInt =
|
||||||
|
SensitiveInt distinctBase(type v)(v).limbs[limbIdx]
|
||||||
|
|
||||||
|
func popcount(a: BigIntLeakedConst): SensitiveInt =
|
||||||
|
## Count the number of bits set in an integer
|
||||||
|
## also called popcount or Hamming Weight
|
||||||
|
## ⚠️⚠️⚠️: This is only intended for use on public data
|
||||||
|
var accum: int
|
||||||
|
for i in 0 ..< a.numLimbs:
|
||||||
|
accum += countSetBits(a[i].BaseType)
|
||||||
|
return SensitiveInt accum
|
||||||
|
|
||||||
|
func lastBits(a: BigIntLeakedConst, k: int): SensitiveInt =
|
||||||
|
## Returns the last bits of an integer
|
||||||
|
## k MUST be less than the base word size (2^31 or 2^63)
|
||||||
|
checkWordShift(k)
|
||||||
|
let mask = BaseType((1 shl k) - 1)
|
||||||
|
return SensitiveInt(a[0].BaseType and mask)
|
||||||
|
|
|
@ -261,7 +261,7 @@ proc main() =
|
||||||
# Check equality when converting back to natural domain
|
# Check equality when converting back to natural domain
|
||||||
new_x == 1'u64 shl 61 - 2
|
new_x == 1'u64 shl 61 - 2
|
||||||
|
|
||||||
test "Multiplication mod 101":
|
test "Multiplication mod 2^61 - 1":
|
||||||
block:
|
block:
|
||||||
var x, y, z: Fq[Mersenne61]
|
var x, y, z: Fq[Mersenne61]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue