From eb94c3d1bc4be812f9c5a66772476d6a3b9b0aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sat, 15 Feb 2020 02:59:08 +0100 Subject: [PATCH] Add Montgomery Modular Multiplication --- constantine/math/bigints_checked.nim | 7 ++++ constantine/math/bigints_raw.nim | 60 +++++++++++++++++++++++++++- constantine/math/finite_fields.nim | 10 +++++ tests/test_finite_fields.nim | 47 +++++++++++++++++++++- 4 files changed, 121 insertions(+), 3 deletions(-) diff --git a/constantine/math/bigints_checked.nim b/constantine/math/bigints_checked.nim index 4185683..17ffacb 100644 --- a/constantine/math/bigints_checked.nim +++ b/constantine/math/bigints_checked.nim @@ -131,3 +131,10 @@ func unsafeRedc*[mBits](nres: var BigInt[mBits], N: BigInt[mBits], montyMagic: s ## Caller must take care of properly switching between ## the natural and montgomery domain. redc(nres.view, N.view, Word(montyMagic)) + +func montyMul*[mBits](r: var BigInt[mBits], a, b, M: BigInt[mBits], montyMagic: static BaseType) = + ## Compute r <- a*b (mod M) in the Montgomery domain + ## + ## This resets r to zero before processing. Use {.noInit.} + ## to avoid duplicating with Nim zero-init policy + montyMul(r.view, a.view, b.view, M.view, Word(montyMagic)) diff --git a/constantine/math/bigints_raw.nim b/constantine/math/bigints_raw.nim index 31f6802..2001031 100644 --- a/constantine/math/bigints_raw.nim +++ b/constantine/math/bigints_raw.nim @@ -179,7 +179,7 @@ template checkValidModulus(m: BigIntViewConst) = ## This is only checked ## with "-d:debugConstantine" and when assertions are on. debug: - assert not m[^1].isZero.bool, "Internal Error: the modulus must use all declared bits" + assert not isZero(m[^1]).bool, "Internal Error: the modulus must use all declared bits" template checkOddModulus(m: BigIntViewConst) = ## CHeck that the modulus is odd @@ -212,6 +212,11 @@ func isZero*(a: BigIntViewAny): CTBool[Word] = accum = accum or a[i] result = accum.isZero() +func setZero*(a: BigIntViewMut) = + ## Set a BigInt to 0 + ## It's bit size is unchanged + zeroMem(a[0].unsafeAddr, a.numLimbs() * sizeof(Word)) + # The arithmetic primitives all accept a control input that indicates # if it is a placebo operation. It stills performs the # same memory accesses to be side-channel attack resistant. @@ -382,6 +387,12 @@ func reduce*(r: BigIntViewMut, a: BigIntViewAny, M: BigIntViewConst) = for i in countdown(aOffset, 0): r.shlAddMod(a[i], M) +# ############################################################ +# +# Montgomery Arithmetic +# +# ############################################################ + func montgomeryResidue*(a: BigIntViewMut, N: BigIntViewConst) = ## Transform a bigint ``a`` from it's natural representation (mod N) ## to a the Montgomery n-residue representation @@ -401,6 +412,9 @@ func montgomeryResidue*(a: BigIntViewMut, N: BigIntViewConst) = for i in countdown(nLen, 1): a.shlAddMod(Zero, N) +template wordMul(a, b: Word): Word = + (a * b) and MaxWord + func redc*(a: BigIntViewMut, N: BigIntViewConst, montyMagic: Word) = ## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N) ## to the regular natural representation (mod N) @@ -424,9 +438,10 @@ func redc*(a: BigIntViewMut, N: BigIntViewConst, montyMagic: Word) = let nLen = N.numLimbs() for i in 0 ..< nLen: - let z0 = Word(BaseType(a[0]) * BaseType(montyMagic)) and MaxWord + let z0 = wordMul(a[0], montyMagic) var carry = DoubleWord(0) + for j in 0 ..< nLen: let z = DoubleWord(a[i]) + unsafeExtPrecMul(z0, N[i]) + carry carry = z shr WordBitSize @@ -434,3 +449,44 @@ func redc*(a: BigIntViewMut, N: BigIntViewConst, montyMagic: Word) = a[j] = Word(z) and MaxWord a[^1] = Word(carry) + +func montyMul*( + r: BigIntViewMut, a, b: distinct BigIntViewAny, + M: BigIntViewConst, montyMagic: Word) = + ## Compute r <- a*b (mod M) in the Montgomery domain + ## `montyMagic` = -1/M (mod Word). Our words are 2^31 or 2^63 + ## + ## This resets r to zero before processing. Use {.noInit.} + ## to avoid duplicating with Nim zero-init policy + # i.e. c'R <- a'R b'R * R^-1 (mod M) in the natural domain + # as in the Montgomery domain all numbers are scaled by R + + checkValidModulus(M) + checkOddModulus(M) + checkMatchingBitlengths(r, M) + checkMatchingBitlengths(a, M) + checkMatchingBitlengths(b, M) + + let nLen = M.numLimbs() + setZero(r) + + var r_hi = Zero # represents the high word that is used in intermediate computation before reduction mod M + for i in 0 ..< nLen: + + let zi = (r[0] + wordMul(a[i], b[0])).wordMul(montyMagic) + var carry = Zero + + for j in 0 ..< nLen: + let z = DoubleWord(r[j]) + unsafeExtPrecMul(a[i], b[j]) + + unsafeExtPrecMul(zi, M[j]) + DoubleWord(carry) + carry = Word(z shr WordBitSize) + if j != 0: + r[j-1] = Word(z) and MaxWord + + r_hi += carry + r[^1] = r_hi and MaxWord + r_hi = r_hi shr WordBitSize + + # If the extra word is not zero or if r-M does not borrow (i.e. r > M) + # Then substract M + discard r.sub(M, r_hi.isNonZero() or not r.sub(M, CtFalse)) diff --git a/constantine/math/finite_fields.nim b/constantine/math/finite_fields.nim index db2a7a5..b8c0e59 100644 --- a/constantine/math/finite_fields.nim +++ b/constantine/math/finite_fields.nim @@ -41,6 +41,7 @@ debug: # No exceptions allowed {.push raises: [].} +{.push inline.} # ############################################################ # @@ -98,3 +99,12 @@ func `-=`*(a: var Fq, b: Fq) = ## Substraction over Fq let ctl = sub(a, b, CtTrue) discard add(a, Fq.C.Mod, ctl) + +func `*`*(a, b: Fq): Fq {.noInit.} = + ## Multiplication over Fq + ## + ## It is recommended to assign with {.noInit.} + ## as Fq elements are usually large and this + ## routine will zero init internally the result. + result.nres.setInternalBitLength() + result.nres.montyMul(a.nres, b.nres, Fq.C.Mod.nres, montyMagic(Fq.C)) diff --git a/tests/test_finite_fields.nim b/tests/test_finite_fields.nim index c54d342..00ae70c 100644 --- a/tests/test_finite_fields.nim +++ b/tests/test_finite_fields.nim @@ -44,7 +44,15 @@ proc main() = z.fromUint(0'u32) x += y - check: bool(z == x) + + var x_bytes: array[8, byte] + x_bytes.serializeRawUint(x, cpuEndian) + + check: + # Check equality in the Montgomery domain + bool(z == x) + # Check equality when converting back to natural domain + 0'u64 == cast[uint64](x_bytes) block: var x, y, z: Fq[Fake101] @@ -119,4 +127,41 @@ proc main() = # Check equality when converting back to natural domain 100'u64 == cast[uint64](x_bytes) + test "Multiplication mod 101": + block: + var x, y, z: Fq[Fake101] + + x.fromUint(10'u32) + y.fromUint(10'u32) + z.fromUint(100'u32) + + let r = x * y + + var r_bytes: array[8, byte] + r_bytes.serializeRawUint(r, cpuEndian) + + check: + # Check equality in the Montgomery domain + bool(z == r) + # Check equality when converting back to natural domain + 100'u64 == cast[uint64](r_bytes) + + block: + var x, y, z: Fq[Fake101] + + x.fromUint(10'u32) + y.fromUint(11'u32) + z.fromUint(9'u32) + + let r = x * y + + var r_bytes: array[8, byte] + r_bytes.serializeRawUint(r, cpuEndian) + + check: + # Check equality in the Montgomery domain + bool(z == r) + # Check equality when converting back to natural domain + 9'u64 == cast[uint64](r_bytes) + main()