Add modular inversion + test vs GMP
This commit is contained in:
parent
8cbbd40a0c
commit
68727e5c8d
|
@ -136,6 +136,16 @@ macro genMontyMagics(T: typed): untyped =
|
|||
)
|
||||
)
|
||||
)
|
||||
# const MyCurve_InvModExponent = primeMinus2_BE(MyCurve_Modulus)
|
||||
result.add newConstStmt(
|
||||
ident($curve & "_InvModExponent"), newCall(
|
||||
bindSym"primeMinus2_BE",
|
||||
nnkDotExpr.newTree(
|
||||
bindSym($curve & "_Modulus"),
|
||||
ident"mres"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# echo result.toStrLit
|
||||
|
||||
|
@ -153,6 +163,10 @@ macro getMontyOne*(C: static Curve): untyped =
|
|||
## Get one in Montgomery representation (i.e. R mod P)
|
||||
result = bindSym($C & "_MontyOne")
|
||||
|
||||
macro getInvModExponent*(C: static Curve): untyped =
|
||||
## Get modular inversion exponent (Modulus-2 in canonical representation)
|
||||
result = bindSym($C & "_InvModExponent")
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Debug info printed at compile-time
|
||||
|
|
|
@ -151,6 +151,15 @@ func fromUint*(
|
|||
# ############################################################
|
||||
import strutils
|
||||
|
||||
template toByte(x: SomeUnsignedInt): byte =
|
||||
## At compile-time, conversion to bytes checks the range
|
||||
## we want to ensure this is done at the register level
|
||||
## at runtime in a single "mov byte" instruction
|
||||
when nimvm:
|
||||
byte(x and 0xFF)
|
||||
else:
|
||||
byte(x)
|
||||
|
||||
template blobFrom(dst: var openArray[byte], src: SomeUnsignedInt, startIdx: int, endian: static Endianness) =
|
||||
## Write an integer into a raw binary blob
|
||||
## Swapping endianness if needed
|
||||
|
@ -159,10 +168,10 @@ template blobFrom(dst: var openArray[byte], src: SomeUnsignedInt, startIdx: int,
|
|||
|
||||
when endian == cpuEndian:
|
||||
for i in 0 ..< sizeof(src):
|
||||
dst[startIdx+i] = byte((src shr (i * 8)))
|
||||
dst[startIdx+i] = toByte((src shr (i * 8)))
|
||||
else:
|
||||
for i in 0 ..< sizeof(src):
|
||||
dst[startIdx+sizeof(src)-1-i] = byte((src shr (i * 8)))
|
||||
dst[startIdx+sizeof(src)-1-i] = toByte((src shr (i * 8)))
|
||||
|
||||
func exportRawUintLE(
|
||||
dst: var openarray[byte],
|
||||
|
@ -203,11 +212,11 @@ func exportRawUintLE(
|
|||
# we can just copy each byte
|
||||
# tail is inclusive
|
||||
for i in 0 ..< tail:
|
||||
dst[dst_idx+i] = byte(lo shr (i*8))
|
||||
dst[dst_idx+i] = toByte(lo shr (i*8))
|
||||
else: # TODO check this
|
||||
# We need to copy from the end
|
||||
for i in 0 ..< tail:
|
||||
dst[dst_idx+i] = byte(lo shr ((tail-i)*8))
|
||||
dst[dst_idx+i] = toByte(lo shr ((tail-i)*8))
|
||||
return
|
||||
|
||||
func exportRawUintBE(
|
||||
|
@ -251,12 +260,13 @@ func exportRawUintBE(
|
|||
# When requesting little-endian on little-endian platform
|
||||
# we can just copy each byte
|
||||
# tail is inclusive
|
||||
debugEcho "tail: "
|
||||
for i in 0 ..< tail:
|
||||
dst[tail-i] = byte(lo shr (i*8))
|
||||
dst[tail-1-i] = toByte(lo shr (i*8))
|
||||
else: # TODO check this
|
||||
# We need to copy from the end
|
||||
for i in 0 ..< tail:
|
||||
dst[tail-i] = byte(lo shr ((tail-i)*8))
|
||||
dst[tail-1-i] = toByte(lo shr ((tail-i)*8))
|
||||
return
|
||||
|
||||
func exportRawUint*(
|
||||
|
|
|
@ -215,3 +215,29 @@ func montyPowUnsafeExponent*[mBits, eBits: static int](
|
|||
scratchPtrs[i] = scratchSpace[i].view()
|
||||
|
||||
montyPowUnsafeExponent(a.view, expBE, M.view, one.view, Word(negInvModWord), scratchPtrs)
|
||||
|
||||
func montyPowUnsafeExponent*[mBits: static int](
|
||||
a: var BigInt[mBits], exponent: openarray[byte],
|
||||
M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int) =
|
||||
## Compute a <- a^exponent (mod M)
|
||||
## ``a`` in the Montgomery domain
|
||||
## ``exponent`` is a BigInt in canonical representation
|
||||
##
|
||||
## Warning ⚠️ :
|
||||
## This is an optimization for public exponent
|
||||
## Otherwise bits of the exponent can be retrieved with:
|
||||
## - memory access analysis
|
||||
## - power analysis
|
||||
## - timing analysis
|
||||
##
|
||||
## This uses fixed window optimization
|
||||
## A window size in the range [1, 5] must be chosen
|
||||
|
||||
const scratchLen = if windowSize == 1: 2
|
||||
else: (1 shl windowSize) + 1
|
||||
var scratchSpace {.noInit.}: array[scratchLen, BigInt[mBits]]
|
||||
var scratchPtrs {.noInit.}: array[scratchLen, BigIntViewMut]
|
||||
for i in 0 ..< scratchLen:
|
||||
scratchPtrs[i] = scratchSpace[i].view()
|
||||
|
||||
montyPowUnsafeExponent(a.view, exponent, M.view, one.view, Word(negInvModWord), scratchPtrs)
|
||||
|
|
|
@ -134,7 +134,11 @@ func pow*(a: var Fq, exponent: BigInt) =
|
|||
## ``a``: a field element to be exponentiated
|
||||
## ``exponent``: a big integer
|
||||
const windowSize = 5 # TODO: find best window size for each curves
|
||||
a.mres.montyPow(exponent, Fq.C.Mod.mres, Fq.C.getMontyOne(), Fq.C.getNegInvModWord(), windowSize)
|
||||
a.mres.montyPow(
|
||||
exponent,
|
||||
Fq.C.Mod.mres, Fq.C.getMontyOne(),
|
||||
Fq.C.getNegInvModWord(), windowSize
|
||||
)
|
||||
|
||||
func powUnsafeExponent*(a: var Fq, exponent: BigInt) =
|
||||
## Exponentiation over Fq
|
||||
|
@ -148,4 +152,20 @@ func powUnsafeExponent*(a: var Fq, exponent: BigInt) =
|
|||
## - power analysis
|
||||
## - timing analysis
|
||||
const windowSize = 5 # TODO: find best window size for each curves
|
||||
a.mres.montyPowUnsafeExponent(exponent, Fq.C.Mod.mres, Fq.C.getMontyOne(), Fq.C.getNegInvModWord(), windowSize)
|
||||
a.mres.montyPowUnsafeExponent(
|
||||
exponent,
|
||||
Fq.C.Mod.mres, Fq.C.getMontyOne(),
|
||||
Fq.C.getNegInvModWord(), windowSize
|
||||
)
|
||||
|
||||
func inv*(a: var Fq) =
|
||||
## Modular inversion
|
||||
## Warning ⚠️ :
|
||||
## - This assumes that `Fq` is a prime field
|
||||
|
||||
const windowSize = 5 # TODO: find best window size for each curves
|
||||
a.mres.montyPowUnsafeExponent(
|
||||
Fq.C.getInvModExponent(),
|
||||
Fq.C.Mod.mres, Fq.C.getMontyOne(),
|
||||
Fq.C.getNegInvModWord(), windowSize
|
||||
)
|
||||
|
|
|
@ -9,7 +9,8 @@
|
|||
import
|
||||
./bigints_checked,
|
||||
../primitives/constant_time,
|
||||
../config/common
|
||||
../config/common,
|
||||
../io/io_bigints
|
||||
|
||||
# Precomputed constants
|
||||
# ############################################################
|
||||
|
@ -187,3 +188,15 @@ func montyOne*(M: BigInt): BigInt =
|
|||
## Returns "1 (mod M)" in the Montgomery domain.
|
||||
## This is equivalent to R (mod M) in the natural domain
|
||||
r_powmod(1, M)
|
||||
|
||||
func primeMinus2_BE*[bits: static int](
|
||||
P: BigInt[bits]
|
||||
): array[(bits+7) div 8, byte] {.noInit.} =
|
||||
## Compute an input prime-2
|
||||
## and return the result as a canonical byte array / octet string
|
||||
## For use to precompute modular inverse exponent.
|
||||
|
||||
var tmp = P
|
||||
discard tmp.sub(BigInt[bits].fromRawUint([byte 2], bigEndian), true)
|
||||
|
||||
result.exportRawUint(tmp, bigEndian)
|
||||
|
|
|
@ -125,4 +125,18 @@ proc main() =
|
|||
check:
|
||||
computed == expected
|
||||
|
||||
suite "Modular inversion over prime fields":
|
||||
test "x^(-1) mod p":
|
||||
var x: Fq[BLS12_381]
|
||||
|
||||
# BN254 field modulus
|
||||
x.fromHex("0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47")
|
||||
|
||||
let expected = "0x0636759a0f3034fa47174b2c0334902f11e9915b7bd89c6a2b3082b109abbc9837da17201f6d8286fe6203caa1b9d4c8"
|
||||
x.inv()
|
||||
let computed = x.toHex()
|
||||
|
||||
check:
|
||||
computed == expected
|
||||
|
||||
main()
|
||||
|
|
|
@ -47,7 +47,7 @@ const # https://gmplib.org/manual/Integer-Import-and-Export.html
|
|||
GMP_MostSignificantWordFirst = 1'i32
|
||||
GMP_LeastSignificantWordFirst = -1'i32
|
||||
|
||||
proc main() =
|
||||
proc mainMul() =
|
||||
var gmpRng: gmp_randstate_t
|
||||
gmp_randinit_mt(gmpRng)
|
||||
# The GMP seed varies between run so that
|
||||
|
@ -66,7 +66,7 @@ proc main() =
|
|||
|
||||
randomTests(128, curve):
|
||||
# echo "--------------------------------------------------------------------------------"
|
||||
echo "Testing: random input on ", $curve
|
||||
echo "Testing: random modular multiplication on ", $curve
|
||||
|
||||
const bits = CurveParams[curve][0]
|
||||
|
||||
|
@ -127,4 +127,79 @@ proc main() =
|
|||
" GMP: " & rGMP.toHex() & "\n" &
|
||||
" Constantine: " & rConstantine.toHex()
|
||||
|
||||
main()
|
||||
proc mainInv() =
|
||||
var gmpRng: gmp_randstate_t
|
||||
gmp_randinit_mt(gmpRng)
|
||||
# The GMP seed varies between run so that
|
||||
# test coverage increases as the library gets tested.
|
||||
# This requires to dump the seed in the console or the function inputs
|
||||
# to be able to reproduce a bug
|
||||
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
|
||||
echo "GMP seed: ", seed
|
||||
gmp_randseed_ui(gmpRng, seed)
|
||||
|
||||
var a, p, r: mpz_t
|
||||
mpz_init(a)
|
||||
mpz_init(p)
|
||||
mpz_init(r)
|
||||
|
||||
randomTests(128, curve):
|
||||
# echo "--------------------------------------------------------------------------------"
|
||||
echo "Testing: random modular inversion on ", $curve
|
||||
|
||||
const bits = CurveParams[curve][0]
|
||||
|
||||
# Generate random value in the range 0 ..< 2^(bits-1)
|
||||
mpz_urandomb(a, gmpRng, uint bits)
|
||||
# Set modulus to curve modulus
|
||||
let err = mpz_set_str(p, CurveParams[curve][1], 0)
|
||||
doAssert err == 0
|
||||
|
||||
#########################################################
|
||||
# Conversion buffers
|
||||
const len = csize (bits + 7) div 8
|
||||
|
||||
# Note: GMP does not pad right the bigendian numbers if there is extra space
|
||||
var aBuf: array[len, byte]
|
||||
|
||||
var aW: csize # Word written by GMP
|
||||
discard mpz_export(aBuf[0].addr, aW.addr, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, a)
|
||||
|
||||
# Since the modulus is using all bits, it's we can test for exact amount copy
|
||||
doAssert len >= aW, "Expected at most " & $len & " bytes but wrote " & $aW & " for " & toHex(aBuf) & " (little-endian)"
|
||||
|
||||
# Build the bigint - TODO more fields codecs
|
||||
let aTest = Fq[curve].fromBig BigInt[bits].fromRawUint(aBuf[0 ..< aW], bigEndian)
|
||||
|
||||
#########################################################
|
||||
# Modular inversion
|
||||
let exist = mpz_invert(r, a, p)
|
||||
doAssert exist != 0
|
||||
|
||||
var rTest = aTest
|
||||
rTest.inv()
|
||||
|
||||
#########################################################
|
||||
# Check
|
||||
var rGMP: array[len, byte]
|
||||
var rW: csize # Word written by GMP
|
||||
discard mpz_export(rGMP[0].addr, rW.addr, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, r)
|
||||
|
||||
var rConstantine: array[len, byte]
|
||||
exportRawUint(rConstantine, rTest, bigEndian)
|
||||
|
||||
# echo "rGMP: ", rGMP.toHex()
|
||||
# echo "rConstantine: ", rConstantine.toHex()
|
||||
|
||||
doAssert rGMP[0 ..< rW] == rConstantine[^rW..^1], block:
|
||||
# Reexport as bigEndian for debugging
|
||||
discard mpz_export(aBuf[0].addr, aW.addr, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, a)
|
||||
"\nModular Inversion on curve " & $curve & " with operand\n" &
|
||||
" a: 0x" & aBuf.toHex & "\n" &
|
||||
" p: " & CurveParams[curve][1] & "\n" &
|
||||
"failed:" & "\n" &
|
||||
" GMP: " & rGMP.toHex() & "\n" &
|
||||
" Constantine: " & rConstantine.toHex()
|
||||
|
||||
mainMul()
|
||||
mainInv()
|
||||
|
|
Loading…
Reference in New Issue