diff --git a/constantine/arithmetic/bigints_checked.nim b/constantine/arithmetic/bigints_checked.nim index 1f5339c..2525c18 100644 --- a/constantine/arithmetic/bigints_checked.nim +++ b/constantine/arithmetic/bigints_checked.nim @@ -168,6 +168,13 @@ func diff*(r: var BigInt, a, b: BigInt): CTBool[Word] = ## Returns the borrow diff(r.view, a.view, b.view) +func double*(r: var BigInt, a: BigInt): CTBool[Word] = + ## Double `a` into `r`. + ## `r` is initialized/overwritten + ## + ## Returns the carry + sum(r.view, a.view, a.view) + # ############################################################ # # Comparisons diff --git a/constantine/arithmetic/bigints_raw.nim b/constantine/arithmetic/bigints_raw.nim index 159081d..d9eda88 100644 --- a/constantine/arithmetic/bigints_raw.nim +++ b/constantine/arithmetic/bigints_raw.nim @@ -343,7 +343,7 @@ func sum*(r: BigIntViewMut, a, b: BigIntViewAny): CTBool[Word] = ## Returns the carry checkMatchingBitlengths(a, b) - r.setBitLength(bitSizeof(M)) + r.setBitLength(bitSizeof(a)) for i in 0 ..< a.numLimbs(): r[i] = a[i] + b[i] + Word(result) @@ -357,7 +357,7 @@ func diff*(r: BigIntViewMut, a, b: BigIntViewAny): CTBool[Word] = ## Returns the borrow checkMatchingBitlengths(a, b) - r.setBitLength(bitSizeof(M)) + r.setBitLength(bitSizeof(a)) for i in 0 ..< a.numLimbs(): r[i] = a[i] - b[i] - Word(result) diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 233ebf9..29585e6 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -99,13 +99,13 @@ func setOne*(a: var Fp) = a = Fp.C.getMontyOne() func `+=`*(a: var Fp, b: Fp) = - ## Addition modulo p + ## In-place addition modulo p var overflowed = add(a.mres, b.mres) overflowed = overflowed or not csub(a.mres, Fp.C.Mod.mres, CtFalse) # a >= P discard csub(a.mres, Fp.C.Mod.mres, overflowed) func `-=`*(a: var Fp, b: Fp) = - ## Substraction modulo p + ## In-place substraction modulo p let underflowed = sub(a.mres, b.mres) discard cadd(a.mres, Fp.C.Mod.mres, underflowed) @@ -115,13 +115,46 @@ func double*(a: var Fp) = overflowed = overflowed or not csub(a.mres, Fp.C.Mod.mres, CtFalse) # a >= P discard csub(a.mres, Fp.C.Mod.mres, overflowed) +func sum*(r: var Fp, a, b: Fp) = + ## Sum ``a`` and ``b`` into ``r`` module p + ## r is initialized/overwritten + var overflowed = r.mres.sum(a.mres, b.mres) + overflowed = overflowed or not csub(r.mres, Fp.C.Mod.mres, CtFalse) # r >= P + discard csub(r.mres, Fp.C.Mod.mres, overflowed) + +func diff*(r: var Fp, a, b: Fp) = + ## Substract `b` from `a` and store the result into `r`. + ## `r` is initialized/overwritten + var underflowed = r.mres.diff(a.mres, b.mres) + discard cadd(r.mres, Fp.C.Mod.mres, underflowed) + +func double*(r: var Fp, a: Fp) = + ## Double ``a`` into ``r`` + ## `r` is initialized/overwritten + var overflowed = r.mres.double(a.mres) + overflowed = overflowed or not csub(r.mres, Fp.C.Mod.mres, CtFalse) # r >= P + discard csub(r.mres, Fp.C.Mod.mres, overflowed) + +func `+`*(a, b: Fp): Fp {.noInit.} = + ## Addition modulo p + result.sum(a, b) + +func `-`*(a, b: Fp): Fp {.noInit.} = + ## Substraction modulo p + result.diff(a, b) + +func prod*(r: var Fp, a, b: Fp) = + ## Store the product of ``a`` by ``b`` modulo p into ``r`` + ## ``r`` is initialized / overwritten + r.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord()) + func `*`*(a, b: Fp): Fp {.noInit.} = ## Multiplication modulo p ## ## It is recommended to assign with {.noInit.} ## as Fp elements are usually large and this ## routine will zero init internally the result. - result.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord()) + result.prod(a, b) func `*=`*(a: var Fp, b: Fp) = ## Multiplication modulo p @@ -132,7 +165,7 @@ func `*=`*(a: var Fp, b: Fp) = ## Cost ## Stack: 1 * ModulusBitSize var tmp{.noInit.}: Fp - tmp.mres.montyMul(a.mres, b.mres, Fp.C.Mod.mres, Fp.C.getNegInvModWord()) + tmp.prod(a, b) a = tmp func square*(a: Fp): Fp {.noInit.} = diff --git a/constantine/tower_field_extensions/fp2_complex.nim b/constantine/tower_field_extensions/fp2_complex.nim index 89f7549..223aca9 100644 --- a/constantine/tower_field_extensions/fp2_complex.nim +++ b/constantine/tower_field_extensions/fp2_complex.nim @@ -105,14 +105,9 @@ func square*(a: Fp2): Fp2 {.noInit.} = # Stack: 6 * ModulusBitSize (4x 𝔽p element + 1 named temporaries + 1 multiplication temporary) # as multiplications require a (shared) internal temporary - var c0mc1 = a.c0 - c0mc1 -= a.c1 # c0mc1 = c0 - c1 [1 Sub] - result.c0 = c0mc1 # result.c0 = c0 - c1 - - result.c1 = a.c1 - result.c1.double() # result.c1 = 2 c1 [1 Dbl, 1 Sub] - - result.c0 += result.c1 # result.c0 = c0 - c1 + 2 c1 [1 Add, 1 Dbl, 1 Sub] - result.c0 *= c0mc1 # result.c0 = (c0 + c1)(c0 - c1) = c0² - c1² [1 Mul, 1 Add, 1 Dbl, 1 Sub] - - result.c1 *= a.c0 # result.c1 = 2 c1 c0 [2 Mul, 1 Add, 1 Dbl, 1 Sub] + var c0mc1 {.noInit.}: Fp + c0mc1.diff(a.c0, a.c1) # c0mc1 = c0 - c1 [1 Sub] + result.c1.double(a.c1) # result.c1 = 2 c1 [1 Dbl, 1 Sub] + result.c0.sum(c0mc1, result.c1) # result.c0 = c0 - c1 + 2 c1 [1 Add, 1 Dbl, 1 Sub] + result.c0 *= c0mc1 # result.c0 = (c0 + c1)(c0 - c1) = c0² - c1² [1 Mul, 1 Add, 1 Dbl, 1 Sub] + result.c1 *= a.c0 # result.c1 = 2 c1 c0 [2 Mul, 1 Add, 1 Dbl, 1 Sub] diff --git a/tests/test_finite_fields.nim b/tests/test_finite_fields.nim index 78ab08d..1088e2b 100644 --- a/tests/test_finite_fields.nim +++ b/tests/test_finite_fields.nim @@ -25,6 +25,7 @@ proc main() = y.fromUint(10'u32) z.fromUint(90'u32) + let u = x + y x += y var x_bytes: array[8, byte] @@ -33,6 +34,7 @@ proc main() = check: # Check equality in the Montgomery domain bool(z == x) + bool(z == u) # Check equality when converting back to natural domain 90'u64 == cast[uint64](x_bytes) @@ -43,6 +45,7 @@ proc main() = y.fromUint(21'u32) z.fromUint(0'u32) + let u = x + y x += y var x_bytes: array[8, byte] @@ -51,6 +54,7 @@ proc main() = check: # Check equality in the Montgomery domain bool(z == x) + bool(z == u) # Check equality when converting back to natural domain 0'u64 == cast[uint64](x_bytes) @@ -61,6 +65,7 @@ proc main() = y.fromUint(22'u32) z.fromUint(1'u32) + let u = x + y x += y var x_bytes: array[8, byte] @@ -69,6 +74,7 @@ proc main() = check: # Check equality in the Montgomery domain bool(z == x) + bool(z == u) # Check equality when converting back to natural domain 1'u64 == cast[uint64](x_bytes) @@ -80,6 +86,7 @@ proc main() = y.fromUint(10'u32) z.fromUint(70'u32) + let u = x - y x -= y var x_bytes: array[8, byte] @@ -88,6 +95,7 @@ proc main() = check: # Check equality in the Montgomery domain bool(z == x) + bool(z == u) # Check equality when converting back to natural domain 70'u64 == cast[uint64](x_bytes) @@ -98,6 +106,7 @@ proc main() = y.fromUint(80'u32) z.fromUint(0'u32) + let u = x - y x -= y var x_bytes: array[8, byte] @@ -106,6 +115,7 @@ proc main() = check: # Check equality in the Montgomery domain bool(z == x) + bool(z == u) # Check equality when converting back to natural domain 0'u64 == cast[uint64](x_bytes) @@ -116,6 +126,7 @@ proc main() = y.fromUint(81'u32) z.fromUint(100'u32) + let u = x - y x -= y var x_bytes: array[8, byte] @@ -124,6 +135,7 @@ proc main() = check: # Check equality in the Montgomery domain bool(z == x) + bool(z == u) # Check equality when converting back to natural domain 100'u64 == cast[uint64](x_bytes)