Implement scale accumulate
parent c226987ab0
commit 22d8bc218d
@@ -45,6 +45,9 @@ template add(a: var Fp, b: Fp, ctl: CTBool[Limb]): CTBool[Limb] =

template sub(a: var Fp, b: Fp, ctl: CTBool[Limb]): CTBool[Limb] =
  sub(a.value, b.value, ctl)

template `[]`(a: Fp, idx: int): Limb =
  a.value.limbs[idx]

# ############################################################
#
# Field arithmetic primitives
@@ -71,6 +74,7 @@ template scaleadd_impl(a: var Fp, c: Limb) =
  ##
  ## With a word W = 2^LimbBitSize and a field Fp
  ## Does a <- a * W + c (mod p)
  const len = a.value.limbs.len

  when Fp.P.bits <= LimbBitSize:
    # If the prime fits in a single limb
@@ -80,4 +84,91 @@ template scaleadd_impl(a: var Fp, c: Limb) =
    let hi = a[0] shr 1                     # 64 - 63 = 1
    let lo = (a[0] shl LimbBitSize) or c    # Assumes most-significant bit in c is not set
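    # (hi, lo) is the radix-2^64 split of the at-most-126-bit value a[0]*2^63 + c:
    #   hi = (a[0]*2^63 + c) div 2^64 = a[0] shr 1
    #   lo = (a[0]*2^63 + c) mod 2^64 = (a[0] shl 63) or c    (since c < 2^63)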
    unsafe_div2n1n(q, a[0], hi, lo, Fp.P.limbs[0])  # (hi, lo) mod P
    return

  else:
    ## Multiple limbs
    let hi = a[^1]                        # Save the high word to detect carries
    const R = Fp.P.bits and LimbBitSize   # R = bits mod 64

    when R == 0:                          # If the number of bits is a multiple of 64
      let a1 = a[^2]
      let a0 = a[^1]
      moveMem(a[1], a[0], (len-1) * Limb.sizeof)  # we can just shift words
      a[0] = c                                    # and replace the first one by c
      const p0 = Fp.P[^1]
    else:                                 # Need to deal with partial word shifts at the edge.
      let a1 = ((a[^2] shl (LimbBitSize-R)) or (a[^3] shr R)) and HighLimb
      let a0 = ((a[^1] shl (LimbBitSize-R)) or (a[^2] shr R)) and HighLimb
      moveMem(a[1], a[0], (len-1) * Limb.sizeof)
      a[0] = c
      const p0 = ((Fp.P[^1] shl (LimbBitSize-R)) or (Fp.P[^2] shr R)) and HighLimb

    # p0 has its high bit set. (a0, a1)/p0 fits in a limb.
    # Get a quotient q, at most we will be 2 iterations off
    # from the true quotient
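    # (Classic normalized 2-by-1 estimate, cf. Knuth, TAOCP vol. 2, section 4.3.1:
    #  because p0 has its most significant bit set, the quotient estimated from the
    #  two most significant words is never too small and exceeds the true quotient
    #  by at most 2.)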

    let
      a_hi = a0 shr 1                     # 64 - 63 = 1
      a_lo = (a0 shl LimbBitSize) or a1
    var q, r: Limb
    unsafe_div2n1n(q, r, a_hi, a_lo, p0)  # Estimate quotient
    q = mux(                              # If n_hi == divisor
          a0 == p0, HighLimb,             # Quotient == HighLimb (0b0111...1111)
          mux(
            q == 0, 0,                    # elif q == 0, true quotient = 0
            q - 1                         # else, instead of being off by 0, 1 or 2,
          )                               # we return q-1 to be off by -1, 0 or 1
        )

    # Now subtract a*2^63 - q*p
    var carry = Limb(0)
    var over_p = Limb(1)                  # Track if the result is greater than or equal to the modulus

    for i in static(0 ..< Fp.P.limbs.len):
      var qp_lo: Limb

      block:                              # q*p
        var qp_hi: Limb
        unsafe_extendedPrecMul(qp_hi, qp_lo, q, Fp.P[i])  # q * p
        assert qp_lo.isMsbSet.not
        assert carry.isMsbSet.not
        qp_lo += carry                    # Add carry from previous limb
        let qp_carry = qp_lo.isMsbSet
        carry = mux(qp_carry, qp_hi + Limb(1), qp_hi)     # New carry

        qp_lo = qp_lo and HighLimb        # Normalize to u63

      block:                              # a*2^63 - q*p
        a[i] -= qp_lo
        carry += Limb(a[i].isMsbSet)      # Adjust if borrow
        a[i] = a[i] and HighLimb          # Normalize to u63

      over_p = mux(
                 a[i] == Fp.P[i], over_p,
                 a[i] > Fp.P[i]
               )

    # Fix quotient, the true quotient is either q-1, q or q+1
    #
    # if carry < hi, or carry == hi and over_p, we must do "a -= p"
    # if carry > hi (negative result) we must do "a += p"

    let neg = carry > hi
    let tooBig = not neg and (over_p or (carry < hi))

    add(a, Fp.P, neg)
    sub(a, Fp.P, tooBig)
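    # Both the add and the sub above are always executed; the CTBool ctl argument
    # selects whether the operation takes effect, so this final correction does
    # not branch on secret data.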

func scaleadd*(a: var Fp, c: Limb) =
  ## Scale-accumulate
  ##
  ## With a word W = 2^LimbBitSize and a field Fp
  ## Does a <- a * W + c (mod p)
  scaleadd_impl(a, c)

func scaleadd*(a: var Fp, c: static Limb) =
  ## Scale-accumulate
  ##
  ## With a word W = 2^LimbBitSize and a field Fp
  ## Does a <- a * W + c (mod p)
  scaleadd_impl(a, c)
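
For intuition, here is a minimal self-contained sketch (not part of this commit; the names and toy parameters below are made up) of the operation both scaleadd overloads implement, using a plain uint64 and an 8-bit word in place of Fp, Limb and the 63-bit LimbBitSize:

# Toy model of scaleadd: a <- a * W + c (mod p), with W = 2^LimbBitSize.
# The parameters are small enough that the whole value fits in a uint64;
# the real code instead works limb by limb with an estimated quotient,
# as in scaleadd_impl above.
const
  LimbBitSizeToy = 8              # stand-in for LimbBitSize (63 in the real code)
  W = 1'u64 shl LimbBitSizeToy    # the word W = 2^LimbBitSize
  p = 251'u64                     # a toy prime modulus

proc scaleaddToy(a: var uint64, c: uint64) =
  ## a <- a * W + c (mod p)
  a = (a * W + c) mod p

when isMainModule:
  var a = 17'u64
  scaleaddToy(a, 3)               # (17*256 + 3) mod 251
  echo a                          # prints 88
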
@@ -36,7 +36,7 @@ func asm_x86_64_extMul(hi, lo: var uint64, a, b: uint64) {.inline.}=
      : // no clobbered registers
  """

-func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
+func unsafe_extendedPrecMul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
  ## Extended precision multiplication uint64 * uint64 --> uint128
  ##
  ## TODO, at the moment only the x86_64 architecture is supported

@@ -48,6 +48,13 @@ func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}
  # and complicated to make constant-time
  # See at the bottom.

  type T = uint64

  when not defined(amd64):
    {.error: "At the moment only x86_64 architecture is supported".}
  else:
    asm_x86_64_extMul(T(hi), T(lo), T(a), T(b))
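
As a reference for what unsafe_extendedPrecMul computes via the x86_64 mul instruction, here is a portable schoolbook sketch on plain uint64 (not part of this commit; the function name is made up, and it is not audited for constant-time behaviour):

# 64x64 -> 128-bit multiply by splitting each operand into 32-bit halves;
# (hi, lo) receives the same 128-bit product the inline assembly produces.
func extendedPrecMulPortable(hi, lo: var uint64, a, b: uint64) =
  const mask = 0xFFFF_FFFF'u64
  let
    aLo = a and mask
    aHi = a shr 32
    bLo = b and mask
    bHi = b shr 32
    ll  = aLo * bLo                 # low  * low
    lh  = aLo * bHi                 # low  * high
    hl  = aHi * bLo                 # high * low
    hh  = aHi * bHi                 # high * high
    mid = (ll shr 32) + (lh and mask) + (hl and mask)   # fits in 64 bits
  lo = (mid shl 32) or (ll and mask)
  hi = hh + (lh shr 32) + (hl shr 32) + (mid shr 32)

when isMainModule:
  var hi, lo: uint64
  extendedPrecMulPortable(hi, lo, 0xFFFF_FFFF_FFFF_FFFF'u64, 2'u64)
  assert hi == 1'u64 and lo == 0xFFFF_FFFF_FFFF_FFFE'u64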

func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
  ## Division uint128 by uint64
  ## Warning ⚠️ :

@@ -63,7 +63,9 @@ func `or`*[T: Ct](x, y: T): T {.magic: "BitorI".}
func `xor`*[T: Ct](x, y: T): T {.magic: "BitxorI".}
func `not`*[T: Ct](x: T): T {.magic: "BitnotI".}
func `+`*[T: Ct](x, y: T): T {.magic: "AddU".}
func `+=`*[T: Ct](x: var T, y: T) {.magic: "Inc".}
func `-`*[T: Ct](x, y: T): T {.magic: "SubU".}
func `-=`*[T: Ct](x: var T, y: T) {.magic: "Dec".}
func `shr`*[T: Ct](x: T, y: SomeInteger): T {.magic: "ShrI".}
func `shl`*[T: Ct](x: T, y: SomeInteger): T {.magic: "ShlI".}