Square roots (#22)

* Add modular square root for p ≡ 3 (mod 4)

* Exhaustive tests for sqrt with p ≡ 3 (mod 4)

* fix typo
This commit is contained in:
Mamy Ratsimbazafy 2020-04-11 23:53:21 +02:00 committed by GitHub
parent a6e4517be2
commit 42109d4f1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 355 additions and 4 deletions

View File

@ -44,6 +44,7 @@ task test, "Run all tests":
test "", "tests/test_io_fields" test "", "tests/test_io_fields"
test "", "tests/test_finite_fields.nim" test "", "tests/test_finite_fields.nim"
test "", "tests/test_finite_fields_mulsquare.nim" test "", "tests/test_finite_fields_mulsquare.nim"
test "", "tests/test_finite_fields_sqrt.nim"
test "", "tests/test_finite_fields_powinv.nim" test "", "tests/test_finite_fields_powinv.nim"
test "", "tests/test_finite_fields_vs_gmp.nim" test "", "tests/test_finite_fields_vs_gmp.nim"
@ -68,6 +69,7 @@ task test, "Run all tests":
test "-d:Constantine32", "tests/test_io_fields" test "-d:Constantine32", "tests/test_io_fields"
test "-d:Constantine32", "tests/test_finite_fields.nim" test "-d:Constantine32", "tests/test_finite_fields.nim"
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim" test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
test "-d:Constantine32", "tests/test_finite_fields_sqrt.nim"
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim" test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim" test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim"
@ -92,6 +94,7 @@ task test_no_gmp, "Run tests that don't require GMP":
test "", "tests/test_io_fields" test "", "tests/test_io_fields"
test "", "tests/test_finite_fields.nim" test "", "tests/test_finite_fields.nim"
test "", "tests/test_finite_fields_mulsquare.nim" test "", "tests/test_finite_fields_mulsquare.nim"
test "", "tests/test_finite_fields_sqrt.nim"
test "", "tests/test_finite_fields_powinv.nim" test "", "tests/test_finite_fields_powinv.nim"
# Towers of extension fields # Towers of extension fields
@ -112,6 +115,7 @@ task test_no_gmp, "Run tests that don't require GMP":
test "-d:Constantine32", "tests/test_io_fields" test "-d:Constantine32", "tests/test_io_fields"
test "-d:Constantine32", "tests/test_finite_fields.nim" test "-d:Constantine32", "tests/test_finite_fields.nim"
test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim" test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim"
test "-d:Constantine32", "tests/test_finite_fields_sqrt.nim"
test "-d:Constantine32", "tests/test_finite_fields_powinv.nim" test "-d:Constantine32", "tests/test_finite_fields_powinv.nim"
# Towers of extension fields # Towers of extension fields

View File

@ -143,6 +143,10 @@ func isZero*(a: BigInt): CTBool[Word] =
## Returns true if a big int is equal to zero ## Returns true if a big int is equal to zero
a.limbs.isZero a.limbs.isZero
func isOne*(a: BigInt): CTBool[Word] =
## Returns true if a big int is equal to one
a.limbs.isOne
func isOdd*(a: BigInt): CTBool[Word] = func isOdd*(a: BigInt): CTBool[Word] =
## Returns true if a is odd ## Returns true if a is odd
a.limbs.isOdd a.limbs.isOdd

View File

@ -68,6 +68,24 @@ func toBig*(src: Fp): auto {.noInit.} =
r.redc(src.mres, Fp.C.Mod, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul()) r.redc(src.mres, Fp.C.Mod, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul())
return r return r
# Copy
# ------------------------------------------------------------
func ccopy*(a: var Fp, b: Fp, ctl: CTBool[Word]) =
## Constant-time conditional copy
## If ctl is true: b is copied into a
## if ctl is false: b is not copied and a is untouched
## Time and memory accesses are the same whether a copy occurs or not
ccopy(a.mres, b.mres, ctl)
func cswap*(a, b: var Fp, ctl: CTBool) =
## Swap ``a`` and ``b`` if ``ctl`` is true
##
## Constant-time:
## Whether ``ctl`` is true or not, the same
## memory accesses are done (unless the compiler tries to be clever)
cswap(a.mres, b.mres, ctl)
# ############################################################ # ############################################################
# #
# Field arithmetic primitives # Field arithmetic primitives
@ -92,6 +110,14 @@ func `==`*(a, b: Fp): CTBool[Word] =
## Constant-time equality check ## Constant-time equality check
a.mres == b.mres a.mres == b.mres
func isZero*(a: Fp): CTBool[Word] =
## Constant-time check if zero
a.mres.isZero()
func isOne*(a: Fp): CTBool[Word] =
## Constant-time check if one
a.mres == Fp.C.getMontyOne()
func setZero*(a: var Fp) = func setZero*(a: var Fp) =
## Set ``a`` to zero ## Set ``a`` to zero
a.mres.setZero() a.mres.setZero()
@ -214,6 +240,65 @@ func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) =
Fp.C.canUseNoCarryMontySquare() Fp.C.canUseNoCarryMontySquare()
) )
# ############################################################
#
# Field arithmetic square roots
#
# ############################################################
func isSquare*[C](a: Fp[C]): CTBool[Word] =
## Returns true if ``a`` is a square (quadratic residue) in 𝔽p
##
## Assumes that the prime modulus ``p`` is public.
# Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion)
# as it can reuse the exponentiation implementation
# Note that we don't care about leaking the bits of p
# as we assume that
var xi {.noInit.} = a # TODO: is noInit necessary? see https://github.com/mratsim/constantine/issues/21
xi.powUnsafeExponent(C.getPrimeMinus1div2_BE())
result = xi.isOne()
# 0 is also a square
result = result or xi.isZero()
func sqrt_p3mod4*[C](a: var Fp[C]) =
## Compute the square root of ``a``
##
## This requires ``a`` to be a square
## and the prime field modulus ``p``: p ≡ 3 (mod 4)
##
## The result is undefined otherwise
##
## The square root, if it exist is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
static: doAssert C.Mod.limbs[0].BaseType mod 4 == 3
a.powUnsafeExponent(C.getPrimePlus1div4_BE())
func sqrt_if_square_p3mod4*[C](a: var Fp[C]): CTBool[Word] =
## If ``a`` is a square, compute the square root of ``a``
## if not, ``a`` is unmodified.
##
## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4)
##
## The result is undefined otherwise
##
## The square root, if it exist is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
static: doAssert C.Mod.limbs[0].BaseType mod 4 == 3
var a1 {.noInit.} = a
a1.powUnsafeExponent(C.getPrimeMinus3div4_BE())
var a1a {.noInit.}: Fp[C]
a1a.prod(a1, a)
var a0 {.noInit.}: Fp[C]
a0.prod(a1a, a1)
result = not(a0.mres == C.getMontyPrimeMinus1())
a.ccopy(a1a, result)
# ############################################################ # ############################################################
# #
# Field arithmetic ergonomic primitives # Field arithmetic ergonomic primitives

View File

@ -91,6 +91,19 @@ func dbl(a: var BigInt): bool =
result = bool(carry) result = bool(carry)
func sub(a: var BigInt, w: BaseType): bool =
## Limbs substraction, sub a number that fits in a word
## Returns the carry
var borrow, diff: BaseType
subB(borrow, diff, BaseType(a.limbs[0]), w, borrow)
a.limbs[0] = Word(diff)
for i in 1 ..< a.limbs.len:
let ai = BaseType(a.limbs[i])
subB(borrow, diff, ai, 0, borrow)
a.limbs[i] = Word(diff)
result = bool(borrow)
func csub(a: var BigInt, b: BigInt, ctl: bool): bool = func csub(a: var BigInt, b: BigInt, ctl: bool): bool =
## In-place optional substraction ## In-place optional substraction
## ##
@ -254,6 +267,12 @@ func montyOne*(M: BigInt): BigInt =
## This is equivalent to R (mod M) in the natural domain ## This is equivalent to R (mod M) in the natural domain
r_powmod(1, M) r_powmod(1, M)
func montyPrimeMinus1*(P: BigInt): BigInt =
## Compute P-1 in the Montgomery domain
## For use in constant-time sqrt
result = P
discard result.csub(P.montyOne(), true)
func primeMinus2_BE*[bits: static int]( func primeMinus2_BE*[bits: static int](
P: BigInt[bits] P: BigInt[bits]
): array[(bits+7) div 8, byte] {.noInit.} = ): array[(bits+7) div 8, byte] {.noInit.} =
@ -263,13 +282,15 @@ func primeMinus2_BE*[bits: static int](
## when using inversion by Little Fermat Theorem a^-1 = a^(p-2) mod p ## when using inversion by Little Fermat Theorem a^-1 = a^(p-2) mod p
var tmp = P var tmp = P
discard tmp.csub(BigInt[bits].fromRawUint([byte 2], bigEndian), true) discard tmp.sub(2)
result.exportRawUint(tmp, bigEndian) result.exportRawUint(tmp, bigEndian)
func primePlus1div2*(P: BigInt): BigInt = func primePlus1div2*(P: BigInt): BigInt =
## Compute (P+1)/2, assumes P is odd ## Compute (P+1)/2, assumes P is odd
## For use in constant-time modular inversion ## For use in constant-time modular inversion
##
## Warning ⚠️: Result is in the canonical domain (not Montgomery)
checkOddModulus(P) checkOddModulus(P)
# (P+1)/2 = P/2 + 1 if P is odd, # (P+1)/2 = P/2 + 1 if P is odd,
@ -280,3 +301,63 @@ func primePlus1div2*(P: BigInt): BigInt =
result.shiftRight(1) result.shiftRight(1)
let carry = result.add(1) let carry = result.add(1)
doAssert not carry doAssert not carry
func primeMinus1div2_BE*[bits: static int](
P: BigInt[bits]
): array[(bits+7) div 8, byte] {.noInit.} =
## For an input prime `p`, compute (p-1)/2
## and return the result as a canonical byte array / octet string
## For use to check if a number is a square (quadratic residue)
## in a field by Euler's criterion
##
# Output size:
# - (bits + 7) div 8: bits => byte conversion rounded up
# - (bits + 7 - 1): dividing by 2 means 1 bit is unused
# => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings)
var tmp = P
discard tmp.sub(1)
tmp.shiftRight(1)
result.exportRawUint(tmp, bigEndian)
func primeMinus3div4_BE*[bits: static int](
P: BigInt[bits]
): array[(bits+7) div 8, byte] {.noInit.} =
## For an input prime `p`, compute (p-3)/4
## and return the result as a canonical byte array / octet string
## For use to check if a number is a square (quadratic residue)
## and if so compute the square root in a fused manner
##
# Output size:
# - (bits + 7) div 8: bits => byte conversion rounded up
# - (bits + 7 - 2): dividing by 4 means 2 bits is unused
# => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings)
var tmp = P
discard tmp.sub(3)
tmp.shiftRight(2)
result.exportRawUint(tmp, bigEndian)
func primePlus1Div4_BE*[bits: static int](
P: BigInt[bits]
): array[(bits+7) div 8, byte] {.noInit.} =
## For an input prime `p`, compute (p+1)/4
## and return the result as a canonical byte array / octet string
## For use to check if a number is a square (quadratic residue)
## in a field by Euler's criterion
##
# Output size:
# - (bits + 7) div 8: bits => byte conversion rounded up
# - (bits + 7 - 1): dividing by 4 means 2 bits are unused
# but we also add 1 to an odd number so using an extra bit
# => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings)
checkOddModulus(P)
# First we do P+1/2 in a way that guarantees no overflow
var tmp = primePlus1div2(P)
# then divide by 2
tmp.shiftRight(1)
result.exportRawUint(tmp, bigEndian)

View File

@ -47,6 +47,18 @@ declareCurves:
testingCurve: true testingCurve: true
bitsize: 7 bitsize: 7
modulus: "0x65" # 101 in hex modulus: "0x65" # 101 in hex
curve Fake103: # 103 ≡ 3 (mod 4)
testingCurve: true
bitsize: 7
modulus: "0x67" # 103 in hex
curve Fake10007: # 10007 ≡ 3 (mod 4)
testingCurve: true
bitsize: 14
modulus: "0x2717" # 10007 in hex
curve Fake65519: # 65519 ≡ 3 (mod 4)
testingCurve: true
bitsize: 16
modulus: "0xFFEF" # 65519 in hex
curve Mersenne61: curve Mersenne61:
testingCurve: true testingCurve: true
bitsize: 61 bitsize: 61
@ -206,6 +218,13 @@ macro genMontyMagics(T: typed): untyped =
bindSym($curve & "_Modulus") bindSym($curve & "_Modulus")
) )
) )
# const MyCurve_MontyPrimeMinus1 = montyPrimeMinus1(MyCurve_Modulus)
result.add newConstStmt(
ident($curve & "_MontyPrimeMinus1"), newCall(
bindSym"montyPrimeMinus1",
bindSym($curve & "_Modulus")
)
)
# const MyCurve_InvModExponent = primeMinus2_BE(MyCurve_Modulus) # const MyCurve_InvModExponent = primeMinus2_BE(MyCurve_Modulus)
result.add newConstStmt( result.add newConstStmt(
ident($curve & "_InvModExponent"), newCall( ident($curve & "_InvModExponent"), newCall(
@ -220,6 +239,27 @@ macro genMontyMagics(T: typed): untyped =
bindSym($curve & "_Modulus") bindSym($curve & "_Modulus")
) )
) )
# const MyCurve_PrimeMinus1div2_BE = primeMinus1div2_BE(MyCurve_Modulus)
result.add newConstStmt(
ident($curve & "_PrimeMinus1div2_BE"), newCall(
bindSym"primeMinus1div2_BE",
bindSym($curve & "_Modulus")
)
)
# const MyCurve_PrimeMinus3div4_BE = primeMinus3div4_BE(MyCurve_Modulus)
result.add newConstStmt(
ident($curve & "_PrimeMinus3div4_BE"), newCall(
bindSym"primeMinus3div4_BE",
bindSym($curve & "_Modulus")
)
)
# const MyCurve_PrimePlus1div4_BE = primePlus1div4_BE(MyCurve_Modulus)
result.add newConstStmt(
ident($curve & "_PrimePlus1div4_BE"), newCall(
bindSym"primePlus1div4_BE",
bindSym($curve & "_Modulus")
)
)
# echo result.toStrLit # echo result.toStrLit
@ -247,14 +287,31 @@ macro getMontyOne*(C: static Curve): untyped =
## Get one in Montgomery representation (i.e. R mod P) ## Get one in Montgomery representation (i.e. R mod P)
result = bindSym($C & "_MontyOne") result = bindSym($C & "_MontyOne")
macro getMontyPrimeMinus1*(C: static Curve): untyped =
## Get (P+1) / 2 for an odd prime
result = bindSym($C & "_MontyPrimeMinus1")
macro getInvModExponent*(C: static Curve): untyped = macro getInvModExponent*(C: static Curve): untyped =
## Get modular inversion exponent (Modulus-2 in canonical representation) ## Get modular inversion exponent (Modulus-2 in canonical representation)
result = bindSym($C & "_InvModExponent") result = bindSym($C & "_InvModExponent")
macro getPrimePlus1div2*(C: static Curve): untyped = macro getPrimePlus1div2*(C: static Curve): untyped =
## Get (P+1) / 2 for an odd prime ## Get (P+1) / 2 for an odd prime
## Warning ⚠️: Result in canonical domain (not Montgomery)
result = bindSym($C & "_PrimePlus1div2") result = bindSym($C & "_PrimePlus1div2")
macro getPrimeMinus1div2_BE*(C: static Curve): untyped =
## Get (P-1) / 2 in big-endian serialized format
result = bindSym($C & "_PrimeMinus1div2_BE")
macro getPrimeMinus3div4_BE*(C: static Curve): untyped =
## Get (P-3) / 2 in big-endian serialized format
result = bindSym($C & "_PrimeMinus3div4_BE")
macro getPrimePlus1div4_BE*(C: static Curve): untyped =
## Get (P+1) / 4 for an odd prime in big-endian serialized format
result = bindSym($C & "_PrimePlus1div4_BE")
# ############################################################ # ############################################################
# #
# Debug info printed at compile-time # Debug info printed at compile-time

View File

@ -9,7 +9,7 @@
import std/unittest, std/times, import std/unittest, std/times,
../constantine/arithmetic, ../constantine/arithmetic,
../constantine/io/[io_bigints, io_fields], ../constantine/io/[io_bigints, io_fields],
../constantine/config/curves, ../constantine/config/[curves, common],
# Test utilities # Test utilities
../helpers/prng ../helpers/prng
@ -22,8 +22,6 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
import ../constantine/config/common
proc sanity(C: static Curve) = proc sanity(C: static Curve) =
test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $C.canUseNoCarryMontySquare & "]": test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $C.canUseNoCarryMontySquare & "]":
block: # 0² mod block: # 0² mod

View File

@ -0,0 +1,121 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import std/unittest, std/times,
../constantine/[arithmetic, primitives],
../constantine/io/[io_fields],
../constantine/config/[curves, common],
# Test utilities
../helpers/prng,
# Standard library
std/tables
const Iters = 128
var rng: RngState
let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
rng.seed(seed)
echo "test_finite_fields_sqrt xoshiro512** seed: ", seed
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
proc exhaustiveCheck_p3mod4(C: static Curve, modulus: static int) =
test "Exhaustive square root check for p ≡ 3 (mod 4) on " & $Curve(C):
var squares_to_roots: Table[uint16, set[uint16]]
# Create all squares
# -------------------------
for i in 0'u16 ..< modulus:
var a{.noInit.}: Fp[C]
a.fromUint(i)
a.square()
var r_bytes: array[8, byte]
r_bytes.exportRawUint(a, cpuEndian)
let r = uint16(cast[uint64](r_bytes))
squares_to_roots.mgetOrPut(r, default(set[uint16])).incl(i)
# From Euler's criterion
# there is exactly (p-1)/2 squares in 𝔽p* (without 0)
# and so (p-1)/2 + 1 in 𝔽p (with 0)
check: squares_to_roots.len == (modulus-1) div 2 + 1
# Check squares
# -------------------------
for i in 0'u16 ..< modulus:
var a{.noInit.}: Fp[C]
a.fromUint(i)
if i in squares_to_roots:
var a2 = a
check:
bool a.isSquare()
bool a.sqrt_if_square_p3mod4()
# 2 different code paths have the same result
# (despite 2 square roots existing per square)
a2.sqrt_p3mod4()
check: bool(a == a2)
var r_bytes: array[8, byte]
r_bytes.exportRawUint(a, cpuEndian)
let r = uint16(cast[uint64](r_bytes))
# r is one of the 2 square roots of `i`
check: r in squares_to_roots[i]
else:
let a2 = a
check:
bool not a.isSquare()
bool not a.sqrt_if_square_p3mod4()
bool (a == a2) # a shouldn't be modified
proc randomSqrtCheck_p3mod4(C: static Curve) =
test "Random square root check for p ≡ 3 (mod 4) on " & $Curve(C):
for _ in 0 ..< Iters:
let a = rng.random(Fp[C])
var na{.noInit.}: Fp[C]
na.neg(a)
var a2 = a
var na2 = na
a2.square()
na2.square()
check:
bool a2 == na2
bool a2.isSquare()
var r, s = a2
r.sqrt_p3mod4()
let ok = s.sqrt_if_square_p3mod4()
check:
bool ok
bool(r == s)
bool(r == a or r == na)
proc main() =
suite "Modular square root":
exhaustiveCheck_p3mod4 Fake103, 103
exhaustiveCheck_p3mod4 Fake10007, 10007
exhaustiveCheck_p3mod4 Fake65519, 65519
randomSqrtCheck_p3mod4 Mersenne61
randomSqrtCheck_p3mod4 Mersenne127
randomSqrtCheck_p3mod4 BN254
randomSqrtCheck_p3mod4 P256
randomSqrtCheck_p3mod4 Secp256k1
randomSqrtCheck_p3mod4 BLS12_381
randomSqrtCheck_p3mod4 BN446
randomSqrtCheck_p3mod4 FKM12_447
randomSqrtCheck_p3mod4 BLS12_461
randomSqrtCheck_p3mod4 BN462
main()

View File

@ -0,0 +1 @@
-d:testingCurves