diff --git a/constantine/field_fp.nim b/constantine/field_fp.nim
index a38e9d6..a442654 100644
--- a/constantine/field_fp.nim
+++ b/constantine/field_fp.nim
@@ -45,6 +45,9 @@ template add(a: var Fp, b: Fp, ctl: CTBool[Limb]): CTBool[Limb] =
 template sub(a: var Fp, b: Fp, ctl: CTBool[Limb]): CTBool[Limb] =
   sub(a.value, b.value, ctl)
 
+template `[]`(a: Fp, idx: int): Limb =
+  a.value.limbs[idx]
+
 # ############################################################
 #
 #                  Field arithmetic primitives
@@ -71,6 +74,7 @@ template scaleadd_impl(a: var Fp, c: Limb) =
   ##
   ## With a word W = 2^LimbBitSize and a field Fp
   ## Does a <- a * W + c (mod p)
+  const len = a.value.limbs.len
 
   when Fp.P.bits <= LimbBitSize:
     # If the prime fits in a single limb
@@ -80,4 +84,91 @@ template scaleadd_impl(a: var Fp, c: Limb) =
     let hi = a[0] shr 1                             # 64 - 63 = 1
     let lo = a[0] shl LimbBitSize or c              # Assumes most-significant bit in c is not set
     unsafe_div2n1n(q, a[0], hi, lo, Fp.P.limbs[0])  # (hi, lo) mod P
-    return
+
+  else:
+    # Multiple limbs
+    let hi = a[^1]                                  # Save the high word to detect carries
+    const R = Fp.P.bits and LimbBitSize             # R = bits mod 64
+
+    when R == 0:                                    # If the number of bits is a multiple of 64
+      let a1 = a[^2]
+      let a0 = a[^1]
+      moveMem(a[1].addr, a[0].addr, (len-1) * Limb.sizeof)  # we can just shift words
+      a[0] = c                                      # and replace the first one by c
+      const p0 = Fp.P[^1]
+    else:                                           # Need to deal with partial word shifts at the edge.
+      let a1 = ((a[^2] shl (LimbBitSize-R)) or (a[^3] shr R)) and HighLimb
+      let a0 = ((a[^1] shl (LimbBitSize-R)) or (a[^2] shr R)) and HighLimb
+      moveMem(a[1].addr, a[0].addr, (len-1) * Limb.sizeof)
+      a[0] = c
+      const p0 = ((Fp.P[^1] shl (LimbBitSize-R)) or (Fp.P[^2] shr R)) and HighLimb
+
+    # p0 has its high bit set. (a0, a1)/p0 fits in a limb.
+    # Get a quotient q, at most we will be 2 iterations off
+    # from the true quotient
+
+    let
+      a_hi = a0 shr 1                               # 64 - 63 = 1
+      a_lo = (a0 shl LimbBitSize) or a1
+    var q, r: Limb
+    unsafe_div2n1n(q, r, a_hi, a_lo, p0)            # Estimate quotient
+    q = mux(                                        # If n_hi == divisor
+          a0 == p0, HighLimb,                       # Quotient == HighLimb (0b0111...1111)
+          mux(
+            q == 0, 0,                              # elif q == 0, true quotient = 0
+            q - 1                                   # else, instead of being off by 0, 1 or 2
+          )                                         # we return q-1 to be off by -1, 0 or 1
+        )
+
+    # Now subtract a*2^63 - q*p
+    var carry = Limb(0)
+    var over_p = Limb(1)                            # Track if the result is greater than the modulus
+
+    for i in static(0 ..< Fp.P.limbs.len):
+      var qp_lo: Limb
+
+      block: # q*p
+        var qp_hi: Limb
+        unsafe_extendedPrecMul(qp_hi, qp_lo, q, Fp.P[i])  # q * p
+        assert qp_lo.isMsbSet.not
+        assert carry.isMsbSet.not
+        qp_lo += carry                              # Add carry from previous limb
+        let qp_carry = qp_lo.isMsbSet
+        carry = mux(qp_carry, qp_hi + Limb(1), qp_hi)  # New carry
+
+        qp_lo = qp_lo and HighLimb                  # Normalize to u63
+
+      block: # a*2^63 - q*p
+        a[i] -= qp_lo
+        carry += Limb(a[i].isMsbSet)                # Adjust if borrow
+        a[i] = a[i] and HighLimb                    # Normalize to u63
+
+      over_p = mux(
+                a[i] == Fp.P[i], over_p,
+                a[i] > Fp.P[i]
+              )
+
+    # Fix quotient, the true quotient is either q-1, q or q+1
+    #
+    # if carry < hi, or carry == hi and over_p, we must do "a -= p"
+    # if carry > hi (negative result) we must do "a += p"
+
+    let neg = carry > hi
+    let tooBig = not neg and (over_p or (carry < hi))
+
+    discard add(a, Fp.P, neg)
+    discard sub(a, Fp.P, tooBig)
+
+func scaleadd*(a: var Fp, c: Limb) =
+  ## Scale-accumulate
+  ##
+  ## With a word W = 2^LimbBitSize and a field Fp
+  ## Does a <- a * W + c (mod p)
+  scaleadd_impl(a, c)
+
+func scaleadd*(a: var Fp, c: static Limb) =
+  ## Scale-accumulate
+  ##
+  ## With a word W = 2^LimbBitSize and a field Fp
+  ## Does a <- a * W + c (mod p)
+  scaleadd_impl(a, c)
diff --git a/constantine/private/word_types_internal.nim b/constantine/private/word_types_internal.nim
index 5206d86..15eb1f1 100644
--- a/constantine/private/word_types_internal.nim
+++ b/constantine/private/word_types_internal.nim
@@ -36,7 +36,7 @@ func asm_x86_64_extMul(hi, lo: var uint64, a, b: uint64) {.inline.}=
     : // no clobbered registers
   """
 
-func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
+func unsafe_extendedPrecMul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
   ## Extended precision multiplication uint64 * uint64 --> uint128
   ##
   ## TODO, at the moment only x86_64 architecture are supported
@@ -48,6 +48,13 @@ func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}
   # and complicated to make constant-time
   # See at the bottom.
 
+  type T = uint64
+
+  when not defined(amd64):
+    {.error: "At the moment only x86_64 architecture is supported".}
+  else:
+    asm_x86_64_extMul(T(hi), T(lo), T(a), T(b))
+
 func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
   ## Division uint128 by uint64
   ## Warning ⚠️ :
diff --git a/constantine/word_types.nim b/constantine/word_types.nim
index b684f10..43cb2b6 100644
--- a/constantine/word_types.nim
+++ b/constantine/word_types.nim
@@ -63,7 +63,9 @@ func `or`*[T: Ct](x, y: T): T {.magic: "BitorI".}
 func `xor`*[T: Ct](x, y: T): T {.magic: "BitxorI".}
 func `not`*[T: Ct](x: T): T {.magic: "BitnotI".}
 func `+`*[T: Ct](x, y: T): T {.magic: "AddU".}
+func `+=`*[T: Ct](x: var T, y: T) {.magic: "Inc".}
 func `-`*[T: Ct](x, y: T): T {.magic: "SubU".}
+func `-=`*[T: Ct](x: var T, y: T) {.magic: "Dec".}
 func `shr`*[T: Ct](x: T, y: SomeInteger): T {.magic: "ShrI".}
 func `shl`*[T: Ct](x: T, y: SomeInteger): T {.magic: "ShlI".}
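
Usage note: since scaleadd computes a <- a * W + c (mod p), feeding a big integer's words most-significant first is Horner's rule in base W = 2^LimbBitSize. A minimal sketch, assuming a concrete Fp instantiation is in scope and that a default-initialised Fp represents the field element 0 (both assumptions, not guaranteed by this diff):

# Sketch only: `Fp`, `Limb` and `scaleadd` are the types/procs introduced above;
# `acc` is assumed to start out as 0 (mod p).
var acc: Fp
let words = [Limb(0x12), Limb(0x34), Limb(0x56)]   # most-significant word first
for w in words:
  acc.scaleadd(w)          # acc <- acc * 2^LimbBitSize + w  (mod p)
# acc now holds (0x12 * W^2 + 0x34 * W + 0x56) mod p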
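The quotient fix-up in the multi-limb branch relies on mux for constant-time selection: mux(ctl, x, y) is expected to return x when ctl is set and y otherwise, with no secret-dependent branch. A reference model of that behaviour on plain uint64 (an illustration of the assumed semantics, not constantine's implementation):

func muxRef(ctl: bool, x, y: uint64): uint64 =
  ## Reference-only branchless select: ctl ? x : y
  let mask = 0'u64 - uint64(ctl)   # all ones if ctl is true, all zeroes otherwise
  y xor (mask and (x xor y))

assert muxRef(true, 3'u64, 7'u64) == 3'u64
assert muxRef(false, 3'u64, 7'u64) == 7'u64

The same masking idea is why the final correction can call add(a, Fp.P, neg) and sub(a, Fp.P, tooBig) unconditionally under a control flag instead of branching on the comparison results.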