diff --git a/constantine.nimble b/constantine.nimble index c947216..b8b0678 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -44,6 +44,7 @@ task test, "Run all tests": test "", "tests/test_io_fields" test "", "tests/test_finite_fields.nim" test "", "tests/test_finite_fields_mulsquare.nim" + test "", "tests/test_finite_fields_sqrt.nim" test "", "tests/test_finite_fields_powinv.nim" test "", "tests/test_finite_fields_vs_gmp.nim" @@ -68,6 +69,7 @@ task test, "Run all tests": test "-d:Constantine32", "tests/test_io_fields" test "-d:Constantine32", "tests/test_finite_fields.nim" test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim" + test "-d:Constantine32", "tests/test_finite_fields_sqrt.nim" test "-d:Constantine32", "tests/test_finite_fields_powinv.nim" test "-d:Constantine32", "tests/test_finite_fields_vs_gmp.nim" @@ -92,6 +94,7 @@ task test_no_gmp, "Run tests that don't require GMP": test "", "tests/test_io_fields" test "", "tests/test_finite_fields.nim" test "", "tests/test_finite_fields_mulsquare.nim" + test "", "tests/test_finite_fields_sqrt.nim" test "", "tests/test_finite_fields_powinv.nim" # Towers of extension fields @@ -112,6 +115,7 @@ task test_no_gmp, "Run tests that don't require GMP": test "-d:Constantine32", "tests/test_io_fields" test "-d:Constantine32", "tests/test_finite_fields.nim" test "-d:Constantine32", "tests/test_finite_fields_mulsquare.nim" + test "-d:Constantine32", "tests/test_finite_fields_sqrt.nim" test "-d:Constantine32", "tests/test_finite_fields_powinv.nim" # Towers of extension fields diff --git a/constantine/arithmetic/bigints.nim b/constantine/arithmetic/bigints.nim index a7c3719..9a67c5f 100644 --- a/constantine/arithmetic/bigints.nim +++ b/constantine/arithmetic/bigints.nim @@ -143,6 +143,10 @@ func isZero*(a: BigInt): CTBool[Word] = ## Returns true if a big int is equal to zero a.limbs.isZero +func isOne*(a: BigInt): CTBool[Word] = + ## Returns true if a big int is equal to one + a.limbs.isOne + func isOdd*(a: BigInt): CTBool[Word] = ## Returns true if a is odd a.limbs.isOdd diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 4083639..a12aa2a 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -68,6 +68,24 @@ func toBig*(src: Fp): auto {.noInit.} = r.redc(src.mres, Fp.C.Mod, Fp.C.getNegInvModWord(), Fp.C.canUseNoCarryMontyMul()) return r +# Copy +# ------------------------------------------------------------ + +func ccopy*(a: var Fp, b: Fp, ctl: CTBool[Word]) = + ## Constant-time conditional copy + ## If ctl is true: b is copied into a + ## if ctl is false: b is not copied and a is untouched + ## Time and memory accesses are the same whether a copy occurs or not + ccopy(a.mres, b.mres, ctl) + +func cswap*(a, b: var Fp, ctl: CTBool) = + ## Swap ``a`` and ``b`` if ``ctl`` is true + ## + ## Constant-time: + ## Whether ``ctl`` is true or not, the same + ## memory accesses are done (unless the compiler tries to be clever) + cswap(a.mres, b.mres, ctl) + # ############################################################ # # Field arithmetic primitives @@ -92,6 +110,14 @@ func `==`*(a, b: Fp): CTBool[Word] = ## Constant-time equality check a.mres == b.mres +func isZero*(a: Fp): CTBool[Word] = + ## Constant-time check if zero + a.mres.isZero() + +func isOne*(a: Fp): CTBool[Word] = + ## Constant-time check if one + a.mres == Fp.C.getMontyOne() + func setZero*(a: var Fp) = ## Set ``a`` to zero a.mres.setZero() @@ -214,6 +240,65 @@ func powUnsafeExponent*(a: var Fp, exponent: openarray[byte]) = Fp.C.canUseNoCarryMontySquare() ) +# ############################################################ +# +# Field arithmetic square roots +# +# ############################################################ + +func isSquare*[C](a: Fp[C]): CTBool[Word] = + ## Returns true if ``a`` is a square (quadratic residue) in 𝔽p + ## + ## Assumes that the prime modulus ``p`` is public. + # Implementation: we use exponentiation by (p-1)/2 (Euler(s criterion) + # as it can reuse the exponentiation implementation + # Note that we don't care about leaking the bits of p + # as we assume that + var xi {.noInit.} = a # TODO: is noInit necessary? see https://github.com/mratsim/constantine/issues/21 + xi.powUnsafeExponent(C.getPrimeMinus1div2_BE()) + result = xi.isOne() + # 0 is also a square + result = result or xi.isZero() + +func sqrt_p3mod4*[C](a: var Fp[C]) = + ## Compute the square root of ``a`` + ## + ## This requires ``a`` to be a square + ## and the prime field modulus ``p``: p ≡ 3 (mod 4) + ## + ## The result is undefined otherwise + ## + ## The square root, if it exist is multivalued, + ## i.e. both x² == (-x)² + ## This procedure returns a deterministic result + static: doAssert C.Mod.limbs[0].BaseType mod 4 == 3 + a.powUnsafeExponent(C.getPrimePlus1div4_BE()) + +func sqrt_if_square_p3mod4*[C](a: var Fp[C]): CTBool[Word] = + ## If ``a`` is a square, compute the square root of ``a`` + ## if not, ``a`` is unmodified. + ## + ## This assumes that the prime field modulus ``p``: p ≡ 3 (mod 4) + ## + ## The result is undefined otherwise + ## + ## The square root, if it exist is multivalued, + ## i.e. both x² == (-x)² + ## This procedure returns a deterministic result + static: doAssert C.Mod.limbs[0].BaseType mod 4 == 3 + + var a1 {.noInit.} = a + a1.powUnsafeExponent(C.getPrimeMinus3div4_BE()) + + var a1a {.noInit.}: Fp[C] + a1a.prod(a1, a) + + var a0 {.noInit.}: Fp[C] + a0.prod(a1a, a1) + + result = not(a0.mres == C.getMontyPrimeMinus1()) + a.ccopy(a1a, result) + # ############################################################ # # Field arithmetic ergonomic primitives diff --git a/constantine/arithmetic/precomputed.nim b/constantine/arithmetic/precomputed.nim index 64da688..49360bb 100644 --- a/constantine/arithmetic/precomputed.nim +++ b/constantine/arithmetic/precomputed.nim @@ -91,6 +91,19 @@ func dbl(a: var BigInt): bool = result = bool(carry) +func sub(a: var BigInt, w: BaseType): bool = + ## Limbs substraction, sub a number that fits in a word + ## Returns the carry + var borrow, diff: BaseType + subB(borrow, diff, BaseType(a.limbs[0]), w, borrow) + a.limbs[0] = Word(diff) + for i in 1 ..< a.limbs.len: + let ai = BaseType(a.limbs[i]) + subB(borrow, diff, ai, 0, borrow) + a.limbs[i] = Word(diff) + + result = bool(borrow) + func csub(a: var BigInt, b: BigInt, ctl: bool): bool = ## In-place optional substraction ## @@ -254,6 +267,12 @@ func montyOne*(M: BigInt): BigInt = ## This is equivalent to R (mod M) in the natural domain r_powmod(1, M) +func montyPrimeMinus1*(P: BigInt): BigInt = + ## Compute P-1 in the Montgomery domain + ## For use in constant-time sqrt + result = P + discard result.csub(P.montyOne(), true) + func primeMinus2_BE*[bits: static int]( P: BigInt[bits] ): array[(bits+7) div 8, byte] {.noInit.} = @@ -263,13 +282,15 @@ func primeMinus2_BE*[bits: static int]( ## when using inversion by Little Fermat Theorem a^-1 = a^(p-2) mod p var tmp = P - discard tmp.csub(BigInt[bits].fromRawUint([byte 2], bigEndian), true) + discard tmp.sub(2) result.exportRawUint(tmp, bigEndian) func primePlus1div2*(P: BigInt): BigInt = ## Compute (P+1)/2, assumes P is odd ## For use in constant-time modular inversion + ## + ## Warning ⚠️: Result is in the canonical domain (not Montgomery) checkOddModulus(P) # (P+1)/2 = P/2 + 1 if P is odd, @@ -280,3 +301,63 @@ func primePlus1div2*(P: BigInt): BigInt = result.shiftRight(1) let carry = result.add(1) doAssert not carry + +func primeMinus1div2_BE*[bits: static int]( + P: BigInt[bits] + ): array[(bits+7) div 8, byte] {.noInit.} = + ## For an input prime `p`, compute (p-1)/2 + ## and return the result as a canonical byte array / octet string + ## For use to check if a number is a square (quadratic residue) + ## in a field by Euler's criterion + ## + # Output size: + # - (bits + 7) div 8: bits => byte conversion rounded up + # - (bits + 7 - 1): dividing by 2 means 1 bit is unused + # => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings) + + var tmp = P + discard tmp.sub(1) + tmp.shiftRight(1) + + result.exportRawUint(tmp, bigEndian) + +func primeMinus3div4_BE*[bits: static int]( + P: BigInt[bits] + ): array[(bits+7) div 8, byte] {.noInit.} = + ## For an input prime `p`, compute (p-3)/4 + ## and return the result as a canonical byte array / octet string + ## For use to check if a number is a square (quadratic residue) + ## and if so compute the square root in a fused manner + ## + # Output size: + # - (bits + 7) div 8: bits => byte conversion rounded up + # - (bits + 7 - 2): dividing by 4 means 2 bits is unused + # => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings) + + var tmp = P + discard tmp.sub(3) + tmp.shiftRight(2) + + result.exportRawUint(tmp, bigEndian) + +func primePlus1Div4_BE*[bits: static int]( + P: BigInt[bits] + ): array[(bits+7) div 8, byte] {.noInit.} = + ## For an input prime `p`, compute (p+1)/4 + ## and return the result as a canonical byte array / octet string + ## For use to check if a number is a square (quadratic residue) + ## in a field by Euler's criterion + ## + # Output size: + # - (bits + 7) div 8: bits => byte conversion rounded up + # - (bits + 7 - 1): dividing by 4 means 2 bits are unused + # but we also add 1 to an odd number so using an extra bit + # => TODO: reduce the output size (to potentially save a byte and corresponding multiplication/squarings) + checkOddModulus(P) + + # First we do P+1/2 in a way that guarantees no overflow + var tmp = primePlus1div2(P) + # then divide by 2 + tmp.shiftRight(1) + + result.exportRawUint(tmp, bigEndian) diff --git a/constantine/config/curves.nim b/constantine/config/curves.nim index ea607fa..f58b584 100644 --- a/constantine/config/curves.nim +++ b/constantine/config/curves.nim @@ -47,6 +47,18 @@ declareCurves: testingCurve: true bitsize: 7 modulus: "0x65" # 101 in hex + curve Fake103: # 103 ≡ 3 (mod 4) + testingCurve: true + bitsize: 7 + modulus: "0x67" # 103 in hex + curve Fake10007: # 10007 ≡ 3 (mod 4) + testingCurve: true + bitsize: 14 + modulus: "0x2717" # 10007 in hex + curve Fake65519: # 65519 ≡ 3 (mod 4) + testingCurve: true + bitsize: 16 + modulus: "0xFFEF" # 65519 in hex curve Mersenne61: testingCurve: true bitsize: 61 @@ -206,6 +218,13 @@ macro genMontyMagics(T: typed): untyped = bindSym($curve & "_Modulus") ) ) + # const MyCurve_MontyPrimeMinus1 = montyPrimeMinus1(MyCurve_Modulus) + result.add newConstStmt( + ident($curve & "_MontyPrimeMinus1"), newCall( + bindSym"montyPrimeMinus1", + bindSym($curve & "_Modulus") + ) + ) # const MyCurve_InvModExponent = primeMinus2_BE(MyCurve_Modulus) result.add newConstStmt( ident($curve & "_InvModExponent"), newCall( @@ -220,6 +239,27 @@ macro genMontyMagics(T: typed): untyped = bindSym($curve & "_Modulus") ) ) + # const MyCurve_PrimeMinus1div2_BE = primeMinus1div2_BE(MyCurve_Modulus) + result.add newConstStmt( + ident($curve & "_PrimeMinus1div2_BE"), newCall( + bindSym"primeMinus1div2_BE", + bindSym($curve & "_Modulus") + ) + ) + # const MyCurve_PrimeMinus3div4_BE = primeMinus3div4_BE(MyCurve_Modulus) + result.add newConstStmt( + ident($curve & "_PrimeMinus3div4_BE"), newCall( + bindSym"primeMinus3div4_BE", + bindSym($curve & "_Modulus") + ) + ) + # const MyCurve_PrimePlus1div4_BE = primePlus1div4_BE(MyCurve_Modulus) + result.add newConstStmt( + ident($curve & "_PrimePlus1div4_BE"), newCall( + bindSym"primePlus1div4_BE", + bindSym($curve & "_Modulus") + ) + ) # echo result.toStrLit @@ -247,14 +287,31 @@ macro getMontyOne*(C: static Curve): untyped = ## Get one in Montgomery representation (i.e. R mod P) result = bindSym($C & "_MontyOne") +macro getMontyPrimeMinus1*(C: static Curve): untyped = + ## Get (P+1) / 2 for an odd prime + result = bindSym($C & "_MontyPrimeMinus1") + macro getInvModExponent*(C: static Curve): untyped = ## Get modular inversion exponent (Modulus-2 in canonical representation) result = bindSym($C & "_InvModExponent") macro getPrimePlus1div2*(C: static Curve): untyped = ## Get (P+1) / 2 for an odd prime + ## Warning ⚠️: Result in canonical domain (not Montgomery) result = bindSym($C & "_PrimePlus1div2") +macro getPrimeMinus1div2_BE*(C: static Curve): untyped = + ## Get (P-1) / 2 in big-endian serialized format + result = bindSym($C & "_PrimeMinus1div2_BE") + +macro getPrimeMinus3div4_BE*(C: static Curve): untyped = + ## Get (P-3) / 2 in big-endian serialized format + result = bindSym($C & "_PrimeMinus3div4_BE") + +macro getPrimePlus1div4_BE*(C: static Curve): untyped = + ## Get (P+1) / 4 for an odd prime in big-endian serialized format + result = bindSym($C & "_PrimePlus1div4_BE") + # ############################################################ # # Debug info printed at compile-time diff --git a/tests/test_finite_fields_mulsquare.nim b/tests/test_finite_fields_mulsquare.nim index a340730..2a836cb 100644 --- a/tests/test_finite_fields_mulsquare.nim +++ b/tests/test_finite_fields_mulsquare.nim @@ -9,7 +9,7 @@ import std/unittest, std/times, ../constantine/arithmetic, ../constantine/io/[io_bigints, io_fields], - ../constantine/config/curves, + ../constantine/config/[curves, common], # Test utilities ../helpers/prng @@ -22,8 +22,6 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" -import ../constantine/config/common - proc sanity(C: static Curve) = test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $C.canUseNoCarryMontySquare & "]": block: # 0² mod diff --git a/tests/test_finite_fields_sqrt.nim b/tests/test_finite_fields_sqrt.nim new file mode 100644 index 0000000..7dbff2c --- /dev/null +++ b/tests/test_finite_fields_sqrt.nim @@ -0,0 +1,121 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import std/unittest, std/times, + ../constantine/[arithmetic, primitives], + ../constantine/io/[io_fields], + ../constantine/config/[curves, common], + # Test utilities + ../helpers/prng, + # Standard library + std/tables + +const Iters = 128 + +var rng: RngState +let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32 +rng.seed(seed) +echo "test_finite_fields_sqrt xoshiro512** seed: ", seed + +static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" + +proc exhaustiveCheck_p3mod4(C: static Curve, modulus: static int) = + test "Exhaustive square root check for p ≡ 3 (mod 4) on " & $Curve(C): + var squares_to_roots: Table[uint16, set[uint16]] + + # Create all squares + # ------------------------- + for i in 0'u16 ..< modulus: + var a{.noInit.}: Fp[C] + a.fromUint(i) + + a.square() + + var r_bytes: array[8, byte] + r_bytes.exportRawUint(a, cpuEndian) + let r = uint16(cast[uint64](r_bytes)) + + squares_to_roots.mgetOrPut(r, default(set[uint16])).incl(i) + + # From Euler's criterion + # there is exactly (p-1)/2 squares in 𝔽p* (without 0) + # and so (p-1)/2 + 1 in 𝔽p (with 0) + check: squares_to_roots.len == (modulus-1) div 2 + 1 + + # Check squares + # ------------------------- + for i in 0'u16 ..< modulus: + var a{.noInit.}: Fp[C] + a.fromUint(i) + + if i in squares_to_roots: + var a2 = a + check: + bool a.isSquare() + bool a.sqrt_if_square_p3mod4() + + # 2 different code paths have the same result + # (despite 2 square roots existing per square) + a2.sqrt_p3mod4() + check: bool(a == a2) + + var r_bytes: array[8, byte] + r_bytes.exportRawUint(a, cpuEndian) + let r = uint16(cast[uint64](r_bytes)) + + # r is one of the 2 square roots of `i` + check: r in squares_to_roots[i] + + else: + let a2 = a + + check: + bool not a.isSquare() + bool not a.sqrt_if_square_p3mod4() + bool (a == a2) # a shouldn't be modified + +proc randomSqrtCheck_p3mod4(C: static Curve) = + test "Random square root check for p ≡ 3 (mod 4) on " & $Curve(C): + for _ in 0 ..< Iters: + let a = rng.random(Fp[C]) + var na{.noInit.}: Fp[C] + na.neg(a) + + var a2 = a + var na2 = na + a2.square() + na2.square() + check: + bool a2 == na2 + bool a2.isSquare() + + var r, s = a2 + r.sqrt_p3mod4() + let ok = s.sqrt_if_square_p3mod4() + check: + bool ok + bool(r == s) + bool(r == a or r == na) + +proc main() = + suite "Modular square root": + exhaustiveCheck_p3mod4 Fake103, 103 + exhaustiveCheck_p3mod4 Fake10007, 10007 + exhaustiveCheck_p3mod4 Fake65519, 65519 + randomSqrtCheck_p3mod4 Mersenne61 + randomSqrtCheck_p3mod4 Mersenne127 + randomSqrtCheck_p3mod4 BN254 + randomSqrtCheck_p3mod4 P256 + randomSqrtCheck_p3mod4 Secp256k1 + randomSqrtCheck_p3mod4 BLS12_381 + randomSqrtCheck_p3mod4 BN446 + randomSqrtCheck_p3mod4 FKM12_447 + randomSqrtCheck_p3mod4 BLS12_461 + randomSqrtCheck_p3mod4 BN462 + +main() diff --git a/tests/test_finite_fields_sqrt.nim.cfg b/tests/test_finite_fields_sqrt.nim.cfg new file mode 100644 index 0000000..0922c18 --- /dev/null +++ b/tests/test_finite_fields_sqrt.nim.cfg @@ -0,0 +1 @@ +-d:testingCurves