Fix curve parser, implement smoke test for finite field

This commit is contained in:
Mamy André-Ratsimbazafy 2020-02-12 23:57:51 +01:00
parent 6226d86726
commit 3eb22f8fc7
No known key found for this signature in database
GPG Key ID: 7B88AD1FE79492E1
9 changed files with 224 additions and 122 deletions

View File

@ -8,7 +8,66 @@
import import
# Internal # Internal
./private/curves_config_parser ./curves_parser, ./common,
../primitives/constant_time,
../math/bigints_checked
# ############################################################
#
# Montgomery Magic Constant precomputation
#
# ############################################################
func montyMagic(M: static BigInt): static Word {.inline.} =
## Returns the Montgomery domain magic constant for the input modulus:
## -1/M[0] mod LimbSize
## M[0] is the least significant limb of M
## M must be odd and greater than 2.
# Test vectors: https://www.researchgate.net/publication/4107322_Montgomery_modular_multiplication_architecture_for_public_key_cryptosystems
# on p354
# Reference C impl: http://www.hackersdelight.org/hdcodetxt/mont64.c.txt
# ######################################################################
# Implementation of modular multiplicative inverse
# Assuming 2 positive integers a and m the modulo
#
# We are looking for z that solves `az ≡ 1 mod m`
#
# References:
# - Knuth, The Art of Computer Programming, Vol2 p342
# - Menezes, Handbook of Applied Cryptography (HAC), p610
# http://cacr.uwaterloo.ca/hac/about/chap14.pdf
# Starting from the extended GCD formula (Bezout identity),
# `ax + by = gcd(x,y)` with input x,y and outputs a, b, gcd
# We assume a and m are coprimes, i.e. gcd is 1, otherwise no inverse
# `ax + my = 1` <=> `ax + my ≡ 1 mod m` <=> `ax ≡ 1 mod m`
# For Montgomery magic number, we are in a special case
# where a = M and m = 2^LimbSize.
# For a and m to be coprimes, a must be odd.
# `m` (2^LimbSize) being a power of 2 greatly simplifies computation:
# - https://crypto.stackexchange.com/questions/47493/how-to-determine-the-multiplicative-inverse-modulo-64-or-other-power-of-two
# - http://groups.google.com/groups?selm=1994Apr6.093116.27805%40mnemosyne.cs.du.edu
# - https://mumble.net/~campbell/2015/01/21/inverse-mod-power-of-two
# - https://eprint.iacr.org/2017/411
# We have the following relation
# ax ≡ 1 (mod 2^k) <=> ax(2 - ax) ≡ 1 (mod 2^(2k))
#
# To get -1/M0 mod LimbSize
# we can either negate the resulting x of `ax(2 - ax) ≡ 1 (mod 2^(2k))`
# or do ax(2 + ax) ≡ 1 (mod 2^(2k))
const
M0 = M.limbs[0]
k = log2(WordBitSize)
result = M0 # Start from an inverse of M0 modulo 2, M0 is odd and it's own inverse
for _ in 0 ..< k:
result *= 2 + M * result # x' = x(2 + ax) (`+` to avoid negating at the end)
# ############################################################ # ############################################################
# #
@ -48,5 +107,5 @@ else:
# Fake curve for testing field arithmetic # Fake curve for testing field arithmetic
declareCurves: declareCurves:
curve Fake101: curve Fake101:
bitsize: 101 bitsize: 7
modulus: "0x65" # 101 in hex modulus: "0x65" # 101 in hex

View File

@ -10,7 +10,7 @@ import
# Standard library # Standard library
macros, macros,
# Internal # Internal
../io, ../bigints, ../montgomery_magic ../io/io_bigints, ../math/bigints_checked
# Macro to parse declarative curves configuration. # Macro to parse declarative curves configuration.
@ -60,6 +60,8 @@ macro declareCurves*(curves: untyped): untyped =
var curveModStmts = newStmtList() var curveModStmts = newStmtList()
var curveModWhenStmt = nnkWhenStmt.newTree() var curveModWhenStmt = nnkWhenStmt.newTree()
let Fp = ident"Fp"
for curveDesc in curves: for curveDesc in curves:
curveDesc.expectKind(nnkCommand) curveDesc.expectKind(nnkCommand)
doAssert curveDesc[0].eqIdent"curve" doAssert curveDesc[0].eqIdent"curve"
@ -86,17 +88,20 @@ macro declareCurves*(curves: untyped): untyped =
curve, bitSize curve, bitSize
) )
# const BN254_Modulus = fromHex(BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47") # const BN254_Modulus = Fp[BN254](value: fromHex(BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47"))
let modulusID = ident($curve & "_Modulus") let modulusID = ident($curve & "_Modulus")
curveModStmts.add newConstStmt( curveModStmts.add newConstStmt(
modulusID, modulusID,
newCall( nnkObjConstr.newTree(
bindSym"fromHex", nnkBracketExpr.newTree(Fp, curve),
nnkBracketExpr.newTree( nnkExprColonExpr.newTree(
bindSym"BigInt", ident"value",
bitSize newCall(
), bindSym"fromHex",
modulus nnkBracketExpr.newTree(bindSym"BigInt", bitSize),
modulus
)
)
) )
) )
@ -109,12 +114,14 @@ macro declareCurves*(curves: untyped): untyped =
), ),
modulusID modulusID
) )
# end for ---------------------------------------------------
result = newStmtList() result = newStmtList()
# type Curve = enum # type Curve = enum
let Curve = ident"Curve"
result.add newEnum( result.add newEnum(
name = ident"Curve", name = Curve,
fields = Curves, fields = Curves,
public = true, public = true,
pure = false pure = false
@ -122,10 +129,45 @@ macro declareCurves*(curves: untyped): untyped =
# const CurveBitSize*: array[Curve, int] = ... # const CurveBitSize*: array[Curve, int] = ...
let cbs = ident("CurveBitSize") let cbs = ident("CurveBitSize")
result.add quote do: result.add newConstStmt(
const `cbs`*: array[Curve, int] = `CurveBitSize` cbs, CurveBitSize
)
result.add curveModStmts # Need template indirection in the type section to avoid Nim sigmatch bug
# template matchingBigInt(C: static Curve): untyped =
# BigInt[CurveBitSize[C]]
let C = ident"C"
let matchingBigInt = ident"matchingBigInt"
result.add newProc(
name = matchingBigInt,
params = [ident"untyped", newIdentDefs(C, nnkStaticTy.newTree(Curve))],
body = nnkBracketExpr.newTree(bindSym"BigInt", nnkBracketExpr.newTree(cbs, C)),
procType = nnkTemplateDef
)
# type
# `Fp`*[C: static Curve] = object
# ## All operations on a field are modulo P
# ## P being the prime modulus of the Curve C
# value*: matchingBigInt(C)
result.add nnkTypeSection.newTree(
nnkTypeDef.newTree(
nnkPostfix.newTree(ident"*", Fp),
nnkGenericParams.newTree(newIdentDefs(
C, nnkStaticTy.newTree(Curve), newEmptyNode()
)),
nnkObjectTy.newTree(
newEmptyNode(),
newEmptyNode(),
nnkRecList.newTree(
newIdentDefs(
nnkPostfix.newTree(ident"*", ident"value"),
newCall(matchingBigInt, C)
)
)
)
)
)
# Add 'else: {.error: "Unreachable".}' to the when statements # Add 'else: {.error: "Unreachable".}' to the when statements
curveModWhenStmt.add nnkElse.newTree( curveModWhenStmt.add nnkElse.newTree(
@ -137,6 +179,8 @@ macro declareCurves*(curves: untyped): untyped =
) )
) )
result.add curveModStmts
# proc Mod(curve: static Curve): auto # proc Mod(curve: static Curve): auto
result.add newProc( result.add newProc(
name = nnkPostfix.newTree(ident"*", ident"Mod"), name = nnkPostfix.newTree(ident"*", ident"Mod"),
@ -148,8 +192,7 @@ macro declareCurves*(curves: untyped): untyped =
) )
], ],
body = curveModWhenStmt, body = curveModWhenStmt,
procType = nnkFuncDef, procType = nnkFuncDef
pragmas = nnkPragma.newTree(ident"compileTime")
) )
# proc MontyMagic(curve: static Curve): static Word # proc MontyMagic(curve: static Curve): static Word
@ -163,11 +206,11 @@ macro declareCurves*(curves: untyped): untyped =
) )
], ],
body = newCall( body = newCall(
bindSym"montyMagic", ident"montyMagic",
newCall(ident"Mod", ident"curve") newCall(ident"Mod", ident"curve")
), ),
procType = nnkFuncDef, procType = nnkFuncDef,
pragmas = nnkPragma.newTree(ident"compileTime") pragmas = nnkPragma.newTree(ident"compileTime")
) )
# echo result.toStrLit # echo result.toStrLit()

View File

@ -0,0 +1,23 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
./io_bigints,
../math/finite_fields
# ############################################################
#
# Parsing from canonical inputs to internal representation
#
# ############################################################
func fromUint*(dst: var Fp,
src: SomeUnsignedInt) =
## Parse a regular unsigned integer
## and store it into a BigInt of size `bits`
dst.value.fromRawUint(cast[array[sizeof(src), byte]](src), cpuEndian)

View File

@ -81,7 +81,7 @@ func setInternalBitLength*(a: var BigInt) {.inline.} =
## from the announced BigInt bitsize ## from the announced BigInt bitsize
## and set the bitLength field of that BigInt ## and set the bitLength field of that BigInt
## to that computed value. ## to that computed value.
a.bitLength = static(a.bits + a.bits div WordBitSize) a.bitLength = uint32 static(a.bits + a.bits div WordBitSize)
func isZero*(a: BigInt): CTBool[Word] = func isZero*(a: BigInt): CTBool[Word] =
## Returns true if a big int is equal to zero ## Returns true if a big int is equal to zero

View File

@ -15,13 +15,22 @@
# We assume that p is prime known at compile-time # We assume that p is prime known at compile-time
# We assume that p is not even (requirement for Montgomery form) # We assume that p is not even (requirement for Montgomery form)
import ./primitives, ./bigints, ./curves_config import
../primitives/constant_time,
../config/[common, curves],
./bigints_checked
type # type
Fp*[C: static Curve] = object # Fp*[C: static Curve] = object
## P is the prime modulus of the Curve C # ## P is the prime modulus of the Curve C
## All operations on a field are modulo P # ## All operations on a field are modulo P
value: BigInt[CurveBitSize[C]] # value*: BigInt[CurveBitSize[C]]
export Fp # defined in ../config/curves to avoid recursive module dependencies
debug:
func `==`*(a, b: Fp): CTBool[Word] =
## Returns true if 2 big ints are equal
a.value == b.value
# ############################################################ # ############################################################
# #
@ -30,14 +39,23 @@ type
# ############################################################ # ############################################################
template add(a: var Fp, b: Fp, ctl: CTBool[Word]): CTBool[Word] = template add(a: var Fp, b: Fp, ctl: CTBool[Word]): CTBool[Word] =
## Constant-time big integer in-place optional addition
## The addition is only performed if ctl is "true"
## The result carry is always computed.
##
## a and b MAY be the same buffer
## a and b MUST have the same announced bitlength (i.e. `bits` static parameters)
add(a.value, b.value, ctl) add(a.value, b.value, ctl)
template sub(a: var Fp, b: Fp, ctl: CTBool[Word]): CTBool[Word] = template sub(a: var Fp, b: Fp, ctl: CTBool[Word]): CTBool[Word] =
## Constant-time big integer in-place optional substraction
## The substraction is only performed if ctl is "true"
## The result carry is always computed.
##
## a and b MAY be the same buffer
## a and b MUST have the same announced bitlength (i.e. `bits` static parameters)
sub(a.value, b.value, ctl) sub(a.value, b.value, ctl)
template `[]`(a: Fp, idx: int): Word =
a.value.limbs[idx]
# ############################################################ # ############################################################
# #
# Field arithmetic primitives # Field arithmetic primitives
@ -47,17 +65,13 @@ template `[]`(a: Fp, idx: int): Word =
# No exceptions allowed # No exceptions allowed
{.push raises: [].} {.push raises: [].}
func `+`*(a, b: Fp): Fp {.noInit.}= func `+=`*(a: var Fp, b: Fp) =
## Addition over Fp ## Addition over Fp
var ctl = add(a, b, CtTrue)
ctl = ctl or not sub(a, Fp.C.Mod, CtFalse)
discard sub(a, Fp.C.Mod, ctl)
# Non-CT implementation from Stint func `-=`*(a: var Fp, b: Fp) =
# ## Substraction over Fp
# let b_from_p = p - b # Don't do a + b directly to avoid overflows let ctl = sub(a, b, CtTrue)
# if a >= b_from_p: discard add(a, Fp.C.Mod, ctl)
# return a - b_from_p
# return m - b_from_p + a
result = a
var ctl = add(result, b, CtTrue)
ctl = ctl or not sub(result, Fp.C.Mod, CtFalse)
sub(result, Fp.C.Mod, ctl)

View File

@ -18,57 +18,6 @@ import
# No exceptions allowed # No exceptions allowed
{.push raises: [].} {.push raises: [].}
func montyMagic*(M: static BigInt): static Word {.inline.} =
## Returns the Montgomery domain magic constant for the input modulus:
## -1/M[0] mod LimbSize
## M[0] is the least significant limb of M
## M must be odd and greater than 2.
# Test vectors: https://www.researchgate.net/publication/4107322_Montgomery_modular_multiplication_architecture_for_public_key_cryptosystems
# on p354
# Reference C impl: http://www.hackersdelight.org/hdcodetxt/mont64.c.txt
# ######################################################################
# Implementation of modular multiplicative inverse
# Assuming 2 positive integers a and m the modulo
#
# We are looking for z that solves `az ≡ 1 mod m`
#
# References:
# - Knuth, The Art of Computer Programming, Vol2 p342
# - Menezes, Handbook of Applied Cryptography (HAC), p610
# http://cacr.uwaterloo.ca/hac/about/chap14.pdf
# Starting from the extended GCD formula (Bezout identity),
# `ax + by = gcd(x,y)` with input x,y and outputs a, b, gcd
# We assume a and m are coprimes, i.e. gcd is 1, otherwise no inverse
# `ax + my = 1` <=> `ax + my ≡ 1 mod m` <=> `ax ≡ 1 mod m`
# For Montgomery magic number, we are in a special case
# where a = M and m = 2^LimbSize.
# For a and m to be coprimes, a must be odd.
# `m` (2^LimbSize) being a power of 2 greatly simplifies computation:
# - https://crypto.stackexchange.com/questions/47493/how-to-determine-the-multiplicative-inverse-modulo-64-or-other-power-of-two
# - http://groups.google.com/groups?selm=1994Apr6.093116.27805%40mnemosyne.cs.du.edu
# - https://mumble.net/~campbell/2015/01/21/inverse-mod-power-of-two
# - https://eprint.iacr.org/2017/411
# We have the following relation
# ax ≡ 1 (mod 2^k) <=> ax(2 - ax) ≡ 1 (mod 2^(2k))
#
# To get -1/M0 mod LimbSize
# we can either negate the resulting x of `ax(2 - ax) ≡ 1 (mod 2^(2k))`
# or do ax(2 + ax) ≡ 1 (mod 2^(2k))
const
M0 = M.limbs[0]
k = log2(WordBitSize)
result = M0 # Start from an inverse of M0 modulo 2, M0 is odd and it's own inverse
for _ in static(0 ..< k):
result *= 2 + M * result # x' = x(2 + ax) (`+` to avoid negating at the end)
# ############################################################ # ############################################################
# #
# Montgomery domain primitives # Montgomery domain primitives

View File

@ -29,32 +29,6 @@ type
# return and/or accept CTBool, we don't want them # return and/or accept CTBool, we don't want them
# to require unnecessarily 8 bytes instead of 4 bytes # to require unnecessarily 8 bytes instead of 4 bytes
# ############################################################
#
# Bit hacks
#
# ############################################################
template isMsbSet*[T: Ct](x: T): CTBool[T] =
## Returns the most significant bit of an integer
const msb_pos = T.sizeof * 8 - 1
(CTBool[T])(x shr msb_pos)
func log2*(x: uint32): uint32 =
## Find the log base 2 of a 32-bit or less integer.
## using De Bruijn multiplication
## Works at compile-time, guaranteed constant-time.
# https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
const lookup: array[32, uint8] = [0'u8, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18,
22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31]
var v = x
v = v or v shr 1 # first round down to one less than a power of 2
v = v or v shr 2
v = v or v shr 4
v = v or v shr 8
v = v or v shr 16
lookup[(v * 0x07C4ACDD'u32) shr 27]
# ############################################################ # ############################################################
# #
# Pragmas # Pragmas
@ -166,6 +140,32 @@ template `-`*[T: Ct](x: T): T =
{.emit:[neg, " = -", x, ";"].} {.emit:[neg, " = -", x, ";"].}
neg neg
# ############################################################
#
# Bit hacks
#
# ############################################################
template isMsbSet*[T: Ct](x: T): CTBool[T] =
## Returns the most significant bit of an integer
const msb_pos = T.sizeof * 8 - 1
(CTBool[T])(x shr msb_pos)
func log2*(x: uint32): uint32 =
## Find the log base 2 of a 32-bit or less integer.
## using De Bruijn multiplication
## Works at compile-time, guaranteed constant-time.
# https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
const lookup: array[32, uint8] = [0'u8, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18,
22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31]
var v = x
v = v or v shr 1 # first round down to one less than a power of 2
v = v or v shr 2
v = v or v shr 4
v = v or v shr 8
v = v or v shr 16
lookup[(v * 0x07C4ACDD'u32) shr 27]
# ############################################################ # ############################################################
# #
# Hardened Boolean primitives # Hardened Boolean primitives
@ -258,7 +258,7 @@ template isNonZero*[T: Ct](x: T): CTBool[T] =
isMsbSet(x_NZ or -x_NZ) isMsbSet(x_NZ or -x_NZ)
template isZero*[T: Ct](x: T): CTBool[T] = template isZero*[T: Ct](x: T): CTBool[T] =
not x.isNonZero not isNonZero(x)
# ############################################################ # ############################################################
# #

View File

@ -7,11 +7,23 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms. # at your option. This file may not be copied, modified, or distributed except according to those terms.
import unittest, random, import unittest, random,
../constantine/math/[io, primitives, finite_fields] ../constantine/math/finite_fields,
../constantine/io/io_fields,
../constantine/config/curves
static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option"
proc main() = proc main() =
suite "Basic arithmetic over finite fields": suite "Basic arithmetic over finite fields":
test "Addition mod 101": test "Addition mod 101":
block: block:
var x: Fp[Fake101] var x, y, z: Fp[Fake101]
x.fromUint()
x.fromUint(80'u32)
y.fromUint(10'u32)
z.fromUint(90'u32)
x += y
check: bool(z == x)
main()

View File

@ -0,0 +1,2 @@
-d:testingCurves
-d:debugConstantine