Fused initialization and arithmetic finite field primitive to have Fp2 less verbose and more efficient
This commit is contained in:
parent
2aec16d8d8
commit
2aa33ea226
|
@ -168,6 +168,13 @@ func diff*(r: var BigInt, a, b: BigInt): CTBool[Word] =
|
|||
## Returns the borrow
|
||||
diff(r.view, a.view, b.view)
|
||||
|
||||
func double*(r: var BigInt, a: BigInt): CTBool[Word] =
|
||||
## Double `a` into `r`.
|
||||
## `r` is initialized/overwritten
|
||||
##
|
||||
## Returns the carry
|
||||
sum(r.view, a.view, a.view)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Comparisons
|
||||
|
|
|
@ -343,7 +343,7 @@ func sum*(r: BigIntViewMut, a, b: BigIntViewAny): CTBool[Word] =
|
|||
## Returns the carry
|
||||
checkMatchingBitlengths(a, b)
|
||||
|
||||
r.setBitLength(bitSizeof(M))
|
||||
r.setBitLength(bitSizeof(a))
|
||||
|
||||
for i in 0 ..< a.numLimbs():
|
||||
r[i] = a[i] + b[i] + Word(result)
|
||||
|
@ -357,7 +357,7 @@ func diff*(r: BigIntViewMut, a, b: BigIntViewAny): CTBool[Word] =
|
|||
## Returns the borrow
|
||||
checkMatchingBitlengths(a, b)
|
||||
|
||||
r.setBitLength(bitSizeof(M))
|
||||
r.setBitLength(bitSizeof(a))
|
||||
|
||||
for i in 0 ..< a.numLimbs():
|
||||
r[i] = a[i] - b[i] - Word(result)
|
||||
|
|
|
@ -99,13 +99,13 @@ func setOne*(a: var Fp) =
|
|||
a = Fp.C.getMontyOne()
|
||||
|
||||
func `+=`*(a: var Fp, b: Fp) =
|
||||
## Addition modulo p
|
||||
## In-place addition modulo p
|
||||
var overflowed = add(a.mres, b.mres)
|
||||
overflowed = overflowed or not csub(a.mres, Fp.C.Mod.mres, CtFalse) # a >= P
|
||||
discard csub(a.mres, Fp.C.Mod.mres, overflowed)
|
||||
|
||||
func `-=`*(a: var Fp, b: Fp) =
|
||||
## Substraction modulo p
|
||||
## In-place substraction modulo p
|
||||
let underflowed = sub(a.mres, b.mres)
|
||||
discard cadd(a.mres, Fp.C.Mod.mres, underflowed)
|
||||
|
||||
|
@ -115,13 +115,46 @@ func double*(a: var Fp) =
|
|||
overflowed = overflowed or not csub(a.mres, Fp.C.Mod.mres, CtFalse) # a >= P
|
||||
discard csub(a.mres, Fp.C.Mod.mres, overflowed)
|
||||
|
||||
func sum*(r: var Fp, a, b: Fp) =
|
||||
## Sum ``a`` and ``b`` into ``r`` module p
|
||||
## r is initialized/overwritten
|
||||
var overflowed = r.mres.sum(a.mres, b.mres)
|
||||
overflowed = overflowed or not csub(r.mres, Fp.C.Mod.mres, CtFalse) # r >= P
|
||||
discard csub(r.mres, Fp.C.Mod.mres, overflowed)
|
||||
|
||||
func diff*(r: var Fp, a, b: Fp) =
|
||||
## Substract `b` from `a` and store the result into `r`.
|
||||
## `r` is initialized/overwritten
|
||||
var underflowed = r.mres.diff(a.mres, b.mres)
|
||||
discard cadd(r.mres, Fp.C.Mod.mres, underflowed)
|
||||
|
||||
func double*(r: var Fp, a: Fp) =
|
||||
## Double ``a`` into ``r``
|
||||
## `r` is initialized/overwritten
|
||||
var overflowed = r.mres.double(a.mres)
|
||||
overflowed = overflowed or not csub(r.mres, Fp.C.Mod.mres, CtFalse) # r >= P
|
||||
discard csub(r.mres, Fp.C.Mod.mres, overflowed)
|
||||
|
||||
func `+`*(a, b: Fp): Fp {.noInit.} =
|
||||
## Addition modulo p
|
||||
result.sum(a, b)
|
||||
|
||||
func `-`*(a, b: Fp): Fp {.noInit.} =
|
||||
## Substraction modulo p
|
||||
result.diff(a, b)
|
||||
|
||||
func prod*(r: var Fp, a, b: Fp) =
|
||||
## Store the product of ``a`` by ``b`` modulo p into ``r``
|
||||
## ``r`` is initialized / overwritten
|
||||
r.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord())
|
||||
|
||||
func `*`*(a, b: Fp): Fp {.noInit.} =
|
||||
## Multiplication modulo p
|
||||
##
|
||||
## It is recommended to assign with {.noInit.}
|
||||
## as Fp elements are usually large and this
|
||||
## routine will zero init internally the result.
|
||||
result.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord())
|
||||
result.prod(a, b)
|
||||
|
||||
func `*=`*(a: var Fp, b: Fp) =
|
||||
## Multiplication modulo p
|
||||
|
@ -132,7 +165,7 @@ func `*=`*(a: var Fp, b: Fp) =
|
|||
## Cost
|
||||
## Stack: 1 * ModulusBitSize
|
||||
var tmp{.noInit.}: Fp
|
||||
tmp.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord())
|
||||
tmp.prod(a, b)
|
||||
a = tmp
|
||||
|
||||
func square*(a: Fp): Fp {.noInit.} =
|
||||
|
|
|
@ -105,14 +105,9 @@ func square*(a: Fp2): Fp2 {.noInit.} =
|
|||
# Stack: 6 * ModulusBitSize (4x 𝔽p element + 1 named temporaries + 1 multiplication temporary)
|
||||
# as multiplications require a (shared) internal temporary
|
||||
|
||||
var c0mc1 = a.c0
|
||||
c0mc1 -= a.c1 # c0mc1 = c0 - c1 [1 Sub]
|
||||
result.c0 = c0mc1 # result.c0 = c0 - c1
|
||||
|
||||
result.c1 = a.c1
|
||||
result.c1.double() # result.c1 = 2 c1 [1 Dbl, 1 Sub]
|
||||
|
||||
result.c0 += result.c1 # result.c0 = c0 - c1 + 2 c1 [1 Add, 1 Dbl, 1 Sub]
|
||||
result.c0 *= c0mc1 # result.c0 = (c0 + c1)(c0 - c1) = c0² - c1² [1 Mul, 1 Add, 1 Dbl, 1 Sub]
|
||||
|
||||
result.c1 *= a.c0 # result.c1 = 2 c1 c0 [2 Mul, 1 Add, 1 Dbl, 1 Sub]
|
||||
var c0mc1 {.noInit.}: Fp
|
||||
c0mc1.diff(a.c0, a.c1) # c0mc1 = c0 - c1 [1 Sub]
|
||||
result.c1.double(a.c1) # result.c1 = 2 c1 [1 Dbl, 1 Sub]
|
||||
result.c0.sum(c0mc1, result.c1) # result.c0 = c0 - c1 + 2 c1 [1 Add, 1 Dbl, 1 Sub]
|
||||
result.c0 *= c0mc1 # result.c0 = (c0 + c1)(c0 - c1) = c0² - c1² [1 Mul, 1 Add, 1 Dbl, 1 Sub]
|
||||
result.c1 *= a.c0 # result.c1 = 2 c1 c0 [2 Mul, 1 Add, 1 Dbl, 1 Sub]
|
||||
|
|
|
@ -25,6 +25,7 @@ proc main() =
|
|||
y.fromUint(10'u32)
|
||||
z.fromUint(90'u32)
|
||||
|
||||
let u = x + y
|
||||
x += y
|
||||
|
||||
var x_bytes: array[8, byte]
|
||||
|
@ -33,6 +34,7 @@ proc main() =
|
|||
check:
|
||||
# Check equality in the Montgomery domain
|
||||
bool(z == x)
|
||||
bool(z == u)
|
||||
# Check equality when converting back to natural domain
|
||||
90'u64 == cast[uint64](x_bytes)
|
||||
|
||||
|
@ -43,6 +45,7 @@ proc main() =
|
|||
y.fromUint(21'u32)
|
||||
z.fromUint(0'u32)
|
||||
|
||||
let u = x + y
|
||||
x += y
|
||||
|
||||
var x_bytes: array[8, byte]
|
||||
|
@ -51,6 +54,7 @@ proc main() =
|
|||
check:
|
||||
# Check equality in the Montgomery domain
|
||||
bool(z == x)
|
||||
bool(z == u)
|
||||
# Check equality when converting back to natural domain
|
||||
0'u64 == cast[uint64](x_bytes)
|
||||
|
||||
|
@ -61,6 +65,7 @@ proc main() =
|
|||
y.fromUint(22'u32)
|
||||
z.fromUint(1'u32)
|
||||
|
||||
let u = x + y
|
||||
x += y
|
||||
|
||||
var x_bytes: array[8, byte]
|
||||
|
@ -69,6 +74,7 @@ proc main() =
|
|||
check:
|
||||
# Check equality in the Montgomery domain
|
||||
bool(z == x)
|
||||
bool(z == u)
|
||||
# Check equality when converting back to natural domain
|
||||
1'u64 == cast[uint64](x_bytes)
|
||||
|
||||
|
@ -80,6 +86,7 @@ proc main() =
|
|||
y.fromUint(10'u32)
|
||||
z.fromUint(70'u32)
|
||||
|
||||
let u = x - y
|
||||
x -= y
|
||||
|
||||
var x_bytes: array[8, byte]
|
||||
|
@ -88,6 +95,7 @@ proc main() =
|
|||
check:
|
||||
# Check equality in the Montgomery domain
|
||||
bool(z == x)
|
||||
bool(z == u)
|
||||
# Check equality when converting back to natural domain
|
||||
70'u64 == cast[uint64](x_bytes)
|
||||
|
||||
|
@ -98,6 +106,7 @@ proc main() =
|
|||
y.fromUint(80'u32)
|
||||
z.fromUint(0'u32)
|
||||
|
||||
let u = x - y
|
||||
x -= y
|
||||
|
||||
var x_bytes: array[8, byte]
|
||||
|
@ -106,6 +115,7 @@ proc main() =
|
|||
check:
|
||||
# Check equality in the Montgomery domain
|
||||
bool(z == x)
|
||||
bool(z == u)
|
||||
# Check equality when converting back to natural domain
|
||||
0'u64 == cast[uint64](x_bytes)
|
||||
|
||||
|
@ -116,6 +126,7 @@ proc main() =
|
|||
y.fromUint(81'u32)
|
||||
z.fromUint(100'u32)
|
||||
|
||||
let u = x - y
|
||||
x -= y
|
||||
|
||||
var x_bytes: array[8, byte]
|
||||
|
@ -124,6 +135,7 @@ proc main() =
|
|||
check:
|
||||
# Check equality in the Montgomery domain
|
||||
bool(z == x)
|
||||
bool(z == u)
|
||||
# Check equality when converting back to natural domain
|
||||
100'u64 == cast[uint64](x_bytes)
|
||||
|
||||
|
|
Loading…
Reference in New Issue