Fused initialization and arithmetic finite field primitive to have Fp2 less verbose and more efficient

This commit is contained in:
Mamy André-Ratsimbazafy 2020-02-25 11:00:27 +01:00
parent 2aec16d8d8
commit 2aa33ea226
No known key found for this signature in database
GPG Key ID: 7B88AD1FE79492E1
5 changed files with 64 additions and 17 deletions

View File

@ -168,6 +168,13 @@ func diff*(r: var BigInt, a, b: BigInt): CTBool[Word] =
## Returns the borrow
diff(r.view, a.view, b.view)
func double*(r: var BigInt, a: BigInt): CTBool[Word] =
## Double `a` into `r`.
## `r` is initialized/overwritten
##
## Returns the carry
sum(r.view, a.view, a.view)
# ############################################################
#
# Comparisons

View File

@ -343,7 +343,7 @@ func sum*(r: BigIntViewMut, a, b: BigIntViewAny): CTBool[Word] =
## Returns the carry
checkMatchingBitlengths(a, b)
r.setBitLength(bitSizeof(M))
r.setBitLength(bitSizeof(a))
for i in 0 ..< a.numLimbs():
r[i] = a[i] + b[i] + Word(result)
@ -357,7 +357,7 @@ func diff*(r: BigIntViewMut, a, b: BigIntViewAny): CTBool[Word] =
## Returns the borrow
checkMatchingBitlengths(a, b)
r.setBitLength(bitSizeof(M))
r.setBitLength(bitSizeof(a))
for i in 0 ..< a.numLimbs():
r[i] = a[i] - b[i] - Word(result)

View File

@ -99,13 +99,13 @@ func setOne*(a: var Fp) =
a = Fp.C.getMontyOne()
func `+=`*(a: var Fp, b: Fp) =
## Addition modulo p
## In-place addition modulo p
var overflowed = add(a.mres, b.mres)
overflowed = overflowed or not csub(a.mres, Fp.C.Mod.mres, CtFalse) # a >= P
discard csub(a.mres, Fp.C.Mod.mres, overflowed)
func `-=`*(a: var Fp, b: Fp) =
## Substraction modulo p
## In-place substraction modulo p
let underflowed = sub(a.mres, b.mres)
discard cadd(a.mres, Fp.C.Mod.mres, underflowed)
@ -115,13 +115,46 @@ func double*(a: var Fp) =
overflowed = overflowed or not csub(a.mres, Fp.C.Mod.mres, CtFalse) # a >= P
discard csub(a.mres, Fp.C.Mod.mres, overflowed)
func sum*(r: var Fp, a, b: Fp) =
## Sum ``a`` and ``b`` into ``r`` module p
## r is initialized/overwritten
var overflowed = r.mres.sum(a.mres, b.mres)
overflowed = overflowed or not csub(r.mres, Fp.C.Mod.mres, CtFalse) # r >= P
discard csub(r.mres, Fp.C.Mod.mres, overflowed)
func diff*(r: var Fp, a, b: Fp) =
## Substract `b` from `a` and store the result into `r`.
## `r` is initialized/overwritten
var underflowed = r.mres.diff(a.mres, b.mres)
discard cadd(r.mres, Fp.C.Mod.mres, underflowed)
func double*(r: var Fp, a: Fp) =
## Double ``a`` into ``r``
## `r` is initialized/overwritten
var overflowed = r.mres.double(a.mres)
overflowed = overflowed or not csub(r.mres, Fp.C.Mod.mres, CtFalse) # r >= P
discard csub(r.mres, Fp.C.Mod.mres, overflowed)
func `+`*(a, b: Fp): Fp {.noInit.} =
## Addition modulo p
result.sum(a, b)
func `-`*(a, b: Fp): Fp {.noInit.} =
## Substraction modulo p
result.diff(a, b)
func prod*(r: var Fp, a, b: Fp) =
## Store the product of ``a`` by ``b`` modulo p into ``r``
## ``r`` is initialized / overwritten
r.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord())
func `*`*(a, b: Fp): Fp {.noInit.} =
## Multiplication modulo p
##
## It is recommended to assign with {.noInit.}
## as Fp elements are usually large and this
## routine will zero init internally the result.
result.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord())
result.prod(a, b)
func `*=`*(a: var Fp, b: Fp) =
## Multiplication modulo p
@ -132,7 +165,7 @@ func `*=`*(a: var Fp, b: Fp) =
## Cost
## Stack: 1 * ModulusBitSize
var tmp{.noInit.}: Fp
tmp.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord())
tmp.prod(a, b)
a = tmp
func square*(a: Fp): Fp {.noInit.} =

View File

@ -105,14 +105,9 @@ func square*(a: Fp2): Fp2 {.noInit.} =
# Stack: 6 * ModulusBitSize (4x 𝔽p element + 1 named temporaries + 1 multiplication temporary)
# as multiplications require a (shared) internal temporary
var c0mc1 = a.c0
c0mc1 -= a.c1 # c0mc1 = c0 - c1 [1 Sub]
result.c0 = c0mc1 # result.c0 = c0 - c1
result.c1 = a.c1
result.c1.double() # result.c1 = 2 c1 [1 Dbl, 1 Sub]
result.c0 += result.c1 # result.c0 = c0 - c1 + 2 c1 [1 Add, 1 Dbl, 1 Sub]
result.c0 *= c0mc1 # result.c0 = (c0 + c1)(c0 - c1) = c0² - c1² [1 Mul, 1 Add, 1 Dbl, 1 Sub]
result.c1 *= a.c0 # result.c1 = 2 c1 c0 [2 Mul, 1 Add, 1 Dbl, 1 Sub]
var c0mc1 {.noInit.}: Fp
c0mc1.diff(a.c0, a.c1) # c0mc1 = c0 - c1 [1 Sub]
result.c1.double(a.c1) # result.c1 = 2 c1 [1 Dbl, 1 Sub]
result.c0.sum(c0mc1, result.c1) # result.c0 = c0 - c1 + 2 c1 [1 Add, 1 Dbl, 1 Sub]
result.c0 *= c0mc1 # result.c0 = (c0 + c1)(c0 - c1) = c0² - c1² [1 Mul, 1 Add, 1 Dbl, 1 Sub]
result.c1 *= a.c0 # result.c1 = 2 c1 c0 [2 Mul, 1 Add, 1 Dbl, 1 Sub]

View File

@ -25,6 +25,7 @@ proc main() =
y.fromUint(10'u32)
z.fromUint(90'u32)
let u = x + y
x += y
var x_bytes: array[8, byte]
@ -33,6 +34,7 @@ proc main() =
check:
# Check equality in the Montgomery domain
bool(z == x)
bool(z == u)
# Check equality when converting back to natural domain
90'u64 == cast[uint64](x_bytes)
@ -43,6 +45,7 @@ proc main() =
y.fromUint(21'u32)
z.fromUint(0'u32)
let u = x + y
x += y
var x_bytes: array[8, byte]
@ -51,6 +54,7 @@ proc main() =
check:
# Check equality in the Montgomery domain
bool(z == x)
bool(z == u)
# Check equality when converting back to natural domain
0'u64 == cast[uint64](x_bytes)
@ -61,6 +65,7 @@ proc main() =
y.fromUint(22'u32)
z.fromUint(1'u32)
let u = x + y
x += y
var x_bytes: array[8, byte]
@ -69,6 +74,7 @@ proc main() =
check:
# Check equality in the Montgomery domain
bool(z == x)
bool(z == u)
# Check equality when converting back to natural domain
1'u64 == cast[uint64](x_bytes)
@ -80,6 +86,7 @@ proc main() =
y.fromUint(10'u32)
z.fromUint(70'u32)
let u = x - y
x -= y
var x_bytes: array[8, byte]
@ -88,6 +95,7 @@ proc main() =
check:
# Check equality in the Montgomery domain
bool(z == x)
bool(z == u)
# Check equality when converting back to natural domain
70'u64 == cast[uint64](x_bytes)
@ -98,6 +106,7 @@ proc main() =
y.fromUint(80'u32)
z.fromUint(0'u32)
let u = x - y
x -= y
var x_bytes: array[8, byte]
@ -106,6 +115,7 @@ proc main() =
check:
# Check equality in the Montgomery domain
bool(z == x)
bool(z == u)
# Check equality when converting back to natural domain
0'u64 == cast[uint64](x_bytes)
@ -116,6 +126,7 @@ proc main() =
y.fromUint(81'u32)
z.fromUint(100'u32)
let u = x - y
x -= y
var x_bytes: array[8, byte]
@ -124,6 +135,7 @@ proc main() =
check:
# Check equality in the Montgomery domain
bool(z == x)
bool(z == u)
# Check equality when converting back to natural domain
100'u64 == cast[uint64](x_bytes)