Square roots (#22)
* Add modular square root for p ≡ 3 (mod 4) * Exhaustive tests for sqrt with p ≡ 3 (mod 4) * fix typo
This commit is contained in:
parent
a6e4517be2
commit
42109d4f1c
|
@ -44,6 +44,7 @@ task test, "Run all tests":
|
|||
test "", "tests/test_io_fields"
|
||||
test "", "tests/test_finite_fields.nim"
|
||||
test "", "tests/test_finite_fields_mulsquare.nim"
|
||||
test "", "tests/test_finite_fields_sqrt.nim"
|
||||
test "", "tests/test_finite_fields_powinv.nim"
|
||||
|
||||
test "", "tests/test_finite_fields_vs_gmp.nim"
|
||||
|
@ -68,6 +69,7 @@ task test, "Run all tests":
|
|||
test "-d:Constantine32", "tests/test_io_fields"
|
||||
test "-d:Constantine32", "tests/test_finite_fields.nim"
|
||||
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
|
||||
test "-d:Constantine32", "tests/test_finite_fields_sqrt.nim"
|
||||
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
|
||||
|
||||
test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim"
|
||||
|
@ -92,6 +94,7 @@ task test_no_gmp, "Run tests that don't require GMP":
|
|||
test "", "tests/test_io_fields"
|
||||
test "", "tests/test_finite_fields.nim"
|
||||
test "", "tests/test_finite_fields_mulsquare.nim"
|
||||
test "", "tests/test_finite_fields_sqrt.nim"
|
||||
test "", "tests/test_finite_fields_powinv.nim"
|
||||
|
||||
# Towers of extension fields
|
||||
|
@ -112,6 +115,7 @@ task test_no_gmp, "Run tests that don't require GMP":
|
|||
test "-d:Constantine32", "tests/test_io_fields"
|
||||
test "-d:Constantine32", "tests/test_finite_fields.nim"
|
||||
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
|
||||
test "-d:Constantine32", "tests/test_finite_fields_sqrt.nim"
|
||||
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
|
||||
|
||||
# Towers of extension fields
|
||||
|
|
|
@ -143,6 +143,10 @@ func isZero*(a: BigInt): CTBool[Word] =
|
|||
## Returns true if a big int is equal to zero
|
||||
a.limbs.isZero
|
||||
|
||||
func isOne*(a: BigInt): CTBool[Word] =
|
||||
## Returns true if a big int is equal to one
|
||||
a.limbs.isOne
|
||||
|
||||
func isOdd*(a: BigInt): CTBool[Word] =
|
||||
## Returns true if a is odd
|
||||
a.limbs.isOdd
|
||||
|
|
|
@ -68,6 +68,24 @@ func toBig*(src: Fp): auto {.noInit.} =
|
|||
r.redc(src.mres, Fp.C.Mod, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul())
|
||||
return r
|
||||
|
||||
# Copy
|
||||
# ------------------------------------------------------------
|
||||
|
||||
func ccopy*(a: var Fp, b: Fp, ctl: CTBool[Word]) =
|
||||
## Constant-time conditional copy
|
||||
## If ctl is true: b is copied into a
|
||||
## if ctl is false: b is not copied and a is untouched
|
||||
## Time and memory accesses are the same whether a copy occurs or not
|
||||
ccopy(a.mres, b.mres, ctl)
|
||||
|
||||
func cswap*(a, b: var Fp, ctl: CTBool) =
|
||||
## Swap ``a`` and ``b`` if ``ctl`` is true
|
||||
##
|
||||
## Constant-time:
|
||||
## Whether ``ctl`` is true or not, the same
|
||||
## memory accesses are done (unless the compiler tries to be clever)
|
||||
cswap(a.mres, b.mres, ctl)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Field arithmetic primitives
|
||||
|
@ -92,6 +110,14 @@ func `==`*(a, b: Fp): CTBool[Word] =
|
|||
## Constant-time equality check
|
||||
a.mres == b.mres
|
||||
|
||||
func isZero*(a: Fp): CTBool[Word] =
|
||||
## Constant-time check if zero
|
||||
a.mres.isZero()
|
||||
|
||||
func isOne*(a: Fp): CTBool[Word] =
|
||||
## Constant-time check if one
|
||||
a.mres == Fp.C.getMontyOne()
|
||||
|
||||
func setZero*(a: var Fp) =
|
||||
## Set ``a`` to zero
|
||||
a.mres.setZero()
|
||||
|
@ -214,6 +240,65 @@ func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) =
|
|||
Fp.C.canUseNoCarryMontySquare()
|
||||
)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Field arithmetic square roots
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
func isSquare*[C](a: Fp[C]): CTBool[Word] =
|
||||
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
|
||||
##
|
||||
## Assumes that the prime modulus ``p`` is public.
|
||||
# Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion)
|
||||
# as it can reuse the exponentiation implementation
|
||||
# Note that we don't care about leaking the bits of p
|
||||
# as we assume that
|
||||
var xi {.noInit.} = a # TODO: is noInit necessary? see https://github.com/mratsim/constantine/issues/21
|
||||
xi.powUnsafeExponent(C.getPrimeMinus1div2_BE())
|
||||
result = xi.isOne()
|
||||
# 0 is also a square
|
||||
result = result or xi.isZero()
|
||||
|
||||
func sqrt_p3mod4*[C](a: var Fp[C]) =
|
||||
## Compute the square root of ``a``
|
||||
##
|
||||
## This requires ``a`` to be a square
|
||||
## and the prime field modulus ``p``: p ≡ 3 (mod 4)
|
||||
##
|
||||
## The result is undefined otherwise
|
||||
##
|
||||
## The square root, if it exist is multivalued,
|
||||
## i.e. both x² == (-x)²
|
||||
## This procedure returns a deterministic result
|
||||
static: doAssert C.Mod.limbs[0].BaseType mod 4 == 3
|
||||
a.powUnsafeExponent(C.getPrimePlus1div4_BE())
|
||||
|
||||
func sqrt_if_square_p3mod4*[C](a: var Fp[C]): CTBool[Word] =
|
||||
## If ``a`` is a square, compute the square root of ``a``
|
||||
## if not, ``a`` is unmodified.
|
||||
##
|
||||
## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4)
|
||||
##
|
||||
## The result is undefined otherwise
|
||||
##
|
||||
## The square root, if it exist is multivalued,
|
||||
## i.e. both x² == (-x)²
|
||||
## This procedure returns a deterministic result
|
||||
static: doAssert C.Mod.limbs[0].BaseType mod 4 == 3
|
||||
|
||||
var a1 {.noInit.} = a
|
||||
a1.powUnsafeExponent(C.getPrimeMinus3div4_BE())
|
||||
|
||||
var a1a {.noInit.}: Fp[C]
|
||||
a1a.prod(a1, a)
|
||||
|
||||
var a0 {.noInit.}: Fp[C]
|
||||
a0.prod(a1a, a1)
|
||||
|
||||
result = not(a0.mres == C.getMontyPrimeMinus1())
|
||||
a.ccopy(a1a, result)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Field arithmetic ergonomic primitives
|
||||
|
|
|
@ -91,6 +91,19 @@ func dbl(a: var BigInt): bool =
|
|||
|
||||
result = bool(carry)
|
||||
|
||||
func sub(a: var BigInt, w: BaseType): bool =
|
||||
## Limbs substraction, sub a number that fits in a word
|
||||
## Returns the carry
|
||||
var borrow, diff: BaseType
|
||||
subB(borrow, diff, BaseType(a.limbs[0]), w, borrow)
|
||||
a.limbs[0] = Word(diff)
|
||||
for i in 1 ..< a.limbs.len:
|
||||
let ai = BaseType(a.limbs[i])
|
||||
subB(borrow, diff, ai, 0, borrow)
|
||||
a.limbs[i] = Word(diff)
|
||||
|
||||
result = bool(borrow)
|
||||
|
||||
func csub(a: var BigInt, b: BigInt, ctl: bool): bool =
|
||||
## In-place optional substraction
|
||||
##
|
||||
|
@ -254,6 +267,12 @@ func montyOne*(M: BigInt): BigInt =
|
|||
## This is equivalent to R (mod M) in the natural domain
|
||||
r_powmod(1, M)
|
||||
|
||||
func montyPrimeMinus1*(P: BigInt): BigInt =
|
||||
## Compute P-1 in the Montgomery domain
|
||||
## For use in constant-time sqrt
|
||||
result = P
|
||||
discard result.csub(P.montyOne(), true)
|
||||
|
||||
func primeMinus2_BE*[bits: static int](
|
||||
P: BigInt[bits]
|
||||
): array[(bits+7) div 8, byte] {.noInit.} =
|
||||
|
@ -263,13 +282,15 @@ func primeMinus2_BE*[bits: static int](
|
|||
## when using inversion by Little Fermat Theorem a^-1 = a^(p-2) mod p
|
||||
|
||||
var tmp = P
|
||||
discard tmp.csub(BigInt[bits].fromRawUint([byte 2], bigEndian), true)
|
||||
discard tmp.sub(2)
|
||||
|
||||
result.exportRawUint(tmp, bigEndian)
|
||||
|
||||
func primePlus1div2*(P: BigInt): BigInt =
|
||||
## Compute (P+1)/2, assumes P is odd
|
||||
## For use in constant-time modular inversion
|
||||
##
|
||||
## Warning ⚠️: Result is in the canonical domain (not Montgomery)
|
||||
checkOddModulus(P)
|
||||
|
||||
# (P+1)/2 = P/2 + 1 if P is odd,
|
||||
|
@ -280,3 +301,63 @@ func primePlus1div2*(P: BigInt): BigInt =
|
|||
result.shiftRight(1)
|
||||
let carry = result.add(1)
|
||||
doAssert not carry
|
||||
|
||||
func primeMinus1div2_BE*[bits: static int](
|
||||
P: BigInt[bits]
|
||||
): array[(bits+7) div 8, byte] {.noInit.} =
|
||||
## For an input prime `p`, compute (p-1)/2
|
||||
## and return the result as a canonical byte array / octet string
|
||||
## For use to check if a number is a square (quadratic residue)
|
||||
## in a field by Euler's criterion
|
||||
##
|
||||
# Output size:
|
||||
# - (bits + 7) div 8: bits => byte conversion rounded up
|
||||
# - (bits + 7 - 1): dividing by 2 means 1 bit is unused
|
||||
# => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings)
|
||||
|
||||
var tmp = P
|
||||
discard tmp.sub(1)
|
||||
tmp.shiftRight(1)
|
||||
|
||||
result.exportRawUint(tmp, bigEndian)
|
||||
|
||||
func primeMinus3div4_BE*[bits: static int](
|
||||
P: BigInt[bits]
|
||||
): array[(bits+7) div 8, byte] {.noInit.} =
|
||||
## For an input prime `p`, compute (p-3)/4
|
||||
## and return the result as a canonical byte array / octet string
|
||||
## For use to check if a number is a square (quadratic residue)
|
||||
## and if so compute the square root in a fused manner
|
||||
##
|
||||
# Output size:
|
||||
# - (bits + 7) div 8: bits => byte conversion rounded up
|
||||
# - (bits + 7 - 2): dividing by 4 means 2 bits is unused
|
||||
# => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings)
|
||||
|
||||
var tmp = P
|
||||
discard tmp.sub(3)
|
||||
tmp.shiftRight(2)
|
||||
|
||||
result.exportRawUint(tmp, bigEndian)
|
||||
|
||||
func primePlus1Div4_BE*[bits: static int](
|
||||
P: BigInt[bits]
|
||||
): array[(bits+7) div 8, byte] {.noInit.} =
|
||||
## For an input prime `p`, compute (p+1)/4
|
||||
## and return the result as a canonical byte array / octet string
|
||||
## For use to check if a number is a square (quadratic residue)
|
||||
## in a field by Euler's criterion
|
||||
##
|
||||
# Output size:
|
||||
# - (bits + 7) div 8: bits => byte conversion rounded up
|
||||
# - (bits + 7 - 1): dividing by 4 means 2 bits are unused
|
||||
# but we also add 1 to an odd number so using an extra bit
|
||||
# => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings)
|
||||
checkOddModulus(P)
|
||||
|
||||
# First we do P+1/2 in a way that guarantees no overflow
|
||||
var tmp = primePlus1div2(P)
|
||||
# then divide by 2
|
||||
tmp.shiftRight(1)
|
||||
|
||||
result.exportRawUint(tmp, bigEndian)
|
||||
|
|
|
@ -47,6 +47,18 @@ declareCurves:
|
|||
testingCurve: true
|
||||
bitsize: 7
|
||||
modulus: "0x65" # 101 in hex
|
||||
curve Fake103: # 103 ≡ 3 (mod 4)
|
||||
testingCurve: true
|
||||
bitsize: 7
|
||||
modulus: "0x67" # 103 in hex
|
||||
curve Fake10007: # 10007 ≡ 3 (mod 4)
|
||||
testingCurve: true
|
||||
bitsize: 14
|
||||
modulus: "0x2717" # 10007 in hex
|
||||
curve Fake65519: # 65519 ≡ 3 (mod 4)
|
||||
testingCurve: true
|
||||
bitsize: 16
|
||||
modulus: "0xFFEF" # 65519 in hex
|
||||
curve Mersenne61:
|
||||
testingCurve: true
|
||||
bitsize: 61
|
||||
|
@ -206,6 +218,13 @@ macro genMontyMagics(T: typed): untyped =
|
|||
bindSym($curve & "_Modulus")
|
||||
)
|
||||
)
|
||||
# const MyCurve_MontyPrimeMinus1 = montyPrimeMinus1(MyCurve_Modulus)
|
||||
result.add newConstStmt(
|
||||
ident($curve & "_MontyPrimeMinus1"), newCall(
|
||||
bindSym"montyPrimeMinus1",
|
||||
bindSym($curve & "_Modulus")
|
||||
)
|
||||
)
|
||||
# const MyCurve_InvModExponent = primeMinus2_BE(MyCurve_Modulus)
|
||||
result.add newConstStmt(
|
||||
ident($curve & "_InvModExponent"), newCall(
|
||||
|
@ -220,6 +239,27 @@ macro genMontyMagics(T: typed): untyped =
|
|||
bindSym($curve & "_Modulus")
|
||||
)
|
||||
)
|
||||
# const MyCurve_PrimeMinus1div2_BE = primeMinus1div2_BE(MyCurve_Modulus)
|
||||
result.add newConstStmt(
|
||||
ident($curve & "_PrimeMinus1div2_BE"), newCall(
|
||||
bindSym"primeMinus1div2_BE",
|
||||
bindSym($curve & "_Modulus")
|
||||
)
|
||||
)
|
||||
# const MyCurve_PrimeMinus3div4_BE = primeMinus3div4_BE(MyCurve_Modulus)
|
||||
result.add newConstStmt(
|
||||
ident($curve & "_PrimeMinus3div4_BE"), newCall(
|
||||
bindSym"primeMinus3div4_BE",
|
||||
bindSym($curve & "_Modulus")
|
||||
)
|
||||
)
|
||||
# const MyCurve_PrimePlus1div4_BE = primePlus1div4_BE(MyCurve_Modulus)
|
||||
result.add newConstStmt(
|
||||
ident($curve & "_PrimePlus1div4_BE"), newCall(
|
||||
bindSym"primePlus1div4_BE",
|
||||
bindSym($curve & "_Modulus")
|
||||
)
|
||||
)
|
||||
|
||||
# echo result.toStrLit
|
||||
|
||||
|
@ -247,14 +287,31 @@ macro getMontyOne*(C: static Curve): untyped =
|
|||
## Get one in Montgomery representation (i.e. R mod P)
|
||||
result = bindSym($C & "_MontyOne")
|
||||
|
||||
macro getMontyPrimeMinus1*(C: static Curve): untyped =
|
||||
## Get (P+1) / 2 for an odd prime
|
||||
result = bindSym($C & "_MontyPrimeMinus1")
|
||||
|
||||
macro getInvModExponent*(C: static Curve): untyped =
|
||||
## Get modular inversion exponent (Modulus-2 in canonical representation)
|
||||
result = bindSym($C & "_InvModExponent")
|
||||
|
||||
macro getPrimePlus1div2*(C: static Curve): untyped =
|
||||
## Get (P+1) / 2 for an odd prime
|
||||
## Warning ⚠️: Result in canonical domain (not Montgomery)
|
||||
result = bindSym($C & "_PrimePlus1div2")
|
||||
|
||||
macro getPrimeMinus1div2_BE*(C: static Curve): untyped =
|
||||
## Get (P-1) / 2 in big-endian serialized format
|
||||
result = bindSym($C & "_PrimeMinus1div2_BE")
|
||||
|
||||
macro getPrimeMinus3div4_BE*(C: static Curve): untyped =
|
||||
## Get (P-3) / 2 in big-endian serialized format
|
||||
result = bindSym($C & "_PrimeMinus3div4_BE")
|
||||
|
||||
macro getPrimePlus1div4_BE*(C: static Curve): untyped =
|
||||
## Get (P+1) / 4 for an odd prime in big-endian serialized format
|
||||
result = bindSym($C & "_PrimePlus1div4_BE")
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Debug info printed at compile-time
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
import std/unittest, std/times,
|
||||
../constantine/arithmetic,
|
||||
../constantine/io/[io_bigints, io_fields],
|
||||
../constantine/config/curves,
|
||||
../constantine/config/[curves, common],
|
||||
# Test utilities
|
||||
../helpers/prng
|
||||
|
||||
|
@ -22,8 +22,6 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed
|
|||
|
||||
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
|
||||
|
||||
import ../constantine/config/common
|
||||
|
||||
proc sanity(C: static Curve) =
|
||||
test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $C.canUseNoCarryMontySquare & "]":
|
||||
block: # 0² mod
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import std/unittest, std/times,
|
||||
../constantine/[arithmetic, primitives],
|
||||
../constantine/io/[io_fields],
|
||||
../constantine/config/[curves, common],
|
||||
# Test utilities
|
||||
../helpers/prng,
|
||||
# Standard library
|
||||
std/tables
|
||||
|
||||
const Iters = 128
|
||||
|
||||
var rng: RngState
|
||||
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
|
||||
rng.seed(seed)
|
||||
echo "test_finite_fields_sqrt xoshiro512** seed: ", seed
|
||||
|
||||
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
|
||||
|
||||
proc exhaustiveCheck_p3mod4(C: static Curve, modulus: static int) =
|
||||
test "Exhaustive square root check for p ≡ 3 (mod 4) on " & $Curve(C):
|
||||
var squares_to_roots: Table[uint16, set[uint16]]
|
||||
|
||||
# Create all squares
|
||||
# -------------------------
|
||||
for i in 0'u16 ..< modulus:
|
||||
var a{.noInit.}: Fp[C]
|
||||
a.fromUint(i)
|
||||
|
||||
a.square()
|
||||
|
||||
var r_bytes: array[8, byte]
|
||||
r_bytes.exportRawUint(a, cpuEndian)
|
||||
let r = uint16(cast[uint64](r_bytes))
|
||||
|
||||
squares_to_roots.mgetOrPut(r, default(set[uint16])).incl(i)
|
||||
|
||||
# From Euler's criterion
|
||||
# there is exactly (p-1)/2 squares in 𝔽p* (without 0)
|
||||
# and so (p-1)/2 + 1 in 𝔽p (with 0)
|
||||
check: squares_to_roots.len == (modulus-1) div 2 + 1
|
||||
|
||||
# Check squares
|
||||
# -------------------------
|
||||
for i in 0'u16 ..< modulus:
|
||||
var a{.noInit.}: Fp[C]
|
||||
a.fromUint(i)
|
||||
|
||||
if i in squares_to_roots:
|
||||
var a2 = a
|
||||
check:
|
||||
bool a.isSquare()
|
||||
bool a.sqrt_if_square_p3mod4()
|
||||
|
||||
# 2 different code paths have the same result
|
||||
# (despite 2 square roots existing per square)
|
||||
a2.sqrt_p3mod4()
|
||||
check: bool(a == a2)
|
||||
|
||||
var r_bytes: array[8, byte]
|
||||
r_bytes.exportRawUint(a, cpuEndian)
|
||||
let r = uint16(cast[uint64](r_bytes))
|
||||
|
||||
# r is one of the 2 square roots of `i`
|
||||
check: r in squares_to_roots[i]
|
||||
|
||||
else:
|
||||
let a2 = a
|
||||
|
||||
check:
|
||||
bool not a.isSquare()
|
||||
bool not a.sqrt_if_square_p3mod4()
|
||||
bool (a == a2) # a shouldn't be modified
|
||||
|
||||
proc randomSqrtCheck_p3mod4(C: static Curve) =
|
||||
test "Random square root check for p ≡ 3 (mod 4) on " & $Curve(C):
|
||||
for _ in 0 ..< Iters:
|
||||
let a = rng.random(Fp[C])
|
||||
var na{.noInit.}: Fp[C]
|
||||
na.neg(a)
|
||||
|
||||
var a2 = a
|
||||
var na2 = na
|
||||
a2.square()
|
||||
na2.square()
|
||||
check:
|
||||
bool a2 == na2
|
||||
bool a2.isSquare()
|
||||
|
||||
var r, s = a2
|
||||
r.sqrt_p3mod4()
|
||||
let ok = s.sqrt_if_square_p3mod4()
|
||||
check:
|
||||
bool ok
|
||||
bool(r == s)
|
||||
bool(r == a or r == na)
|
||||
|
||||
proc main() =
|
||||
suite "Modular square root":
|
||||
exhaustiveCheck_p3mod4 Fake103, 103
|
||||
exhaustiveCheck_p3mod4 Fake10007, 10007
|
||||
exhaustiveCheck_p3mod4 Fake65519, 65519
|
||||
randomSqrtCheck_p3mod4 Mersenne61
|
||||
randomSqrtCheck_p3mod4 Mersenne127
|
||||
randomSqrtCheck_p3mod4 BN254
|
||||
randomSqrtCheck_p3mod4 P256
|
||||
randomSqrtCheck_p3mod4 Secp256k1
|
||||
randomSqrtCheck_p3mod4 BLS12_381
|
||||
randomSqrtCheck_p3mod4 BN446
|
||||
randomSqrtCheck_p3mod4 FKM12_447
|
||||
randomSqrtCheck_p3mod4 BLS12_461
|
||||
randomSqrtCheck_p3mod4 BN462
|
||||
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
-d:testingCurves
|
Loading…
Reference in New Issue