Fix isZeroMask in SignedSecretWord

This commit is contained in:
Mamy Ratsimbazafy 2023-01-29 01:05:54 +01:00
parent 915f89fdd6
commit a385acf2b8
No known key found for this signature in database
GPG Key ID: 6227262F49BE273A
3 changed files with 59 additions and 57 deletions

View File

@ -15,7 +15,7 @@ import
# ############################################################ # ############################################################
# #
# Primitives based on Bézout's identity # Primitives based on Bézout's identity
# #
# ############################################################ # ############################################################
# #
@ -78,7 +78,7 @@ func invMod2powK(M0: BaseType, k: static BaseType): BaseType =
# ############################################################### # ###############################################################
# #
# Modular inversion # Modular inversion
# #
# ############################################################### # ###############################################################
@ -136,17 +136,17 @@ func canonicalize(
M: LimbsUnsaturated M: LimbsUnsaturated
) = ) =
## Compute a = sign*a (mod M) ## Compute a = sign*a (mod M)
## ##
## with a in range (-2*M, M) ## with a in range (-2*M, M)
## result in range [0, M) ## result in range [0, M)
const const
UnsatBitWidth = WordBitWidth - a.Excess UnsatBitWidth = WordBitWidth - a.Excess
Max = SignedSecretWord(MaxWord shr a.Excess) Max = SignedSecretWord(MaxWord shr a.Excess)
# Operate in registers # Operate in registers
var z = a var z = a
# Add M if `z` is negative # Add M if `z` is negative
# -> range (-M, M) # -> range (-M, M)
z.cadd(M, z.isNegMask()) z.cadd(M, z.isNegMask())
@ -170,7 +170,7 @@ func canonicalize(
proc partitionDivsteps(bits, wordBitWidth: int): tuple[totalIters, numChunks, chunkSize, cutoff: int] = proc partitionDivsteps(bits, wordBitWidth: int): tuple[totalIters, numChunks, chunkSize, cutoff: int] =
# Given the field modulus number of bits # Given the field modulus number of bits
# and the effective word size # and the effective word size
# Returns: # Returns:
# - the total number of iterations that guarantees GCD convergence # - the total number of iterations that guarantees GCD convergence
# - the number of chunks of divsteps to compute # - the number of chunks of divsteps to compute
@ -178,7 +178,7 @@ proc partitionDivsteps(bits, wordBitWidth: int): tuple[totalIters, numChunks, ch
# - a cutoff chunk, # - a cutoff chunk,
# before this chunk ID, the number of divsteps is "base number + 1" # before this chunk ID, the number of divsteps is "base number + 1"
# afterward it's "base number" # afterward it's "base number"
let totalIters = let totalIters =
if bits == 256: if bits == 256:
# https://github.com/sipa/safegcd-bounds/tree/master/coq # https://github.com/sipa/safegcd-bounds/tree/master/coq
# For 256-bit inputs, 590 divsteps are sufficient with hddivstep variant (half-delta divstep) # For 256-bit inputs, 590 divsteps are sufficient with hddivstep variant (half-delta divstep)
@ -203,12 +203,12 @@ func batchedDivsteps(
k: static int k: static int
): SignedSecretWord = ): SignedSecretWord =
## Bernstein-Yang half-delta (hdelta) batch of divsteps ## Bernstein-Yang half-delta (hdelta) batch of divsteps
## ##
## Output: ## Output:
## - return hdelta for the next batch of divsteps ## - return hdelta for the next batch of divsteps
## - mutate t, the transition matrix to apply `numIters` divsteps at once ## - mutate t, the transition matrix to apply `numIters` divsteps at once
## t is scaled by 2ᵏ ## t is scaled by 2ᵏ
## ##
## Input: ## Input:
## - f0, bottom limb of f ## - f0, bottom limb of f
## - g0, bottom limb of g ## - g0, bottom limb of g
@ -281,13 +281,13 @@ func matVecMul_shr_k_mod_M[N, E: static int](
invMod2powK: SecretWord invMod2powK: SecretWord
) = ) =
## Compute ## Compute
## ##
## [u v] [d] ## [u v] [d]
## [q r]/2ᵏ.[e] mod M ## [q r]/2ᵏ.[e] mod M
## ##
## d, e will be in range (-2*modulus,modulus) ## d, e will be in range (-2*modulus,modulus)
## and output limbs in (-2ᵏ, 2ᵏ) ## and output limbs in (-2ᵏ, 2ᵏ)
static: doAssert k == WordBitWidth - E static: doAssert k == WordBitWidth - E
const Max = SignedSecretWord(MaxWord shr E) const Max = SignedSecretWord(MaxWord shr E)
@ -303,7 +303,7 @@ func matVecMul_shr_k_mod_M[N, E: static int](
# Double-signed-word carries # Double-signed-word carries
var cd, ce: DSWord var cd, ce: DSWord
# First iteration of [u v] [d] # First iteration of [u v] [d]
# [q r].[e] # [q r].[e]
cd.ssumprodAccNoCarry(u, d[0], v, e[0]) cd.ssumprodAccNoCarry(u, d[0], v, e[0])
ce.ssumprodAccNoCarry(q, d[0], r, e[0]) ce.ssumprodAccNoCarry(q, d[0], r, e[0])
@ -317,7 +317,7 @@ func matVecMul_shr_k_mod_M[N, E: static int](
md.cadd(v, sign_e) md.cadd(v, sign_e)
me.cadd(q, sign_d) me.cadd(q, sign_d)
me.cadd(r, sign_e) me.cadd(r, sign_e)
md = md - (SignedSecretWord(invMod2powK * SecretWord(cd.lo) + SecretWord(md)) and Max) md = md - (SignedSecretWord(invMod2powK * SecretWord(cd.lo) + SecretWord(md)) and Max)
me = me - (SignedSecretWord(invMod2powK * SecretWord(ce.lo) + SecretWord(me)) and Max) me = me - (SignedSecretWord(invMod2powK * SecretWord(ce.lo) + SecretWord(me)) and Max)
@ -338,18 +338,18 @@ func matVecMul_shr_k_mod_M[N, E: static int](
e[i-1] = ce.lo and Max e[i-1] = ce.lo and Max
cd.ashr(k) cd.ashr(k)
ce.ashr(k) ce.ashr(k)
d[N-1] = cd.lo d[N-1] = cd.lo
e[N-1] = ce.lo e[N-1] = ce.lo
func matVecMul_shr_k[N, E: static int]( func matVecMul_shr_k[N, E: static int](
t: TransitionMatrix, t: TransitionMatrix,
f, g: var LimbsUnsaturated[N, E], f, g: var LimbsUnsaturated[N, E],
k: static int k: static int
) = ) =
## Compute ## Compute
## ##
## [u v] [f] ## [u v] [f]
## [q r].[g] / 2ᵏ ## [q r].[g] / 2ᵏ
static: doAssert k == WordBitWidth - E static: doAssert k == WordBitWidth - E
@ -363,8 +363,8 @@ func matVecMul_shr_k[N, E: static int](
# Double-signed-word carries # Double-signed-word carries
var cf, cg: DSWord var cf, cg: DSWord
# First iteration of [u v] [f] # First iteration of [u v] [f]
# [q r].[g] # [q r].[g]
cf.ssumprodAccNoCarry(u, f[0], v, g[0]) cf.ssumprodAccNoCarry(u, f[0], v, g[0])
cg.ssumprodAccNoCarry(q, f[0], r, g[0]) cg.ssumprodAccNoCarry(q, f[0], r, g[0])
@ -383,7 +383,7 @@ func matVecMul_shr_k[N, E: static int](
g[i-1] = cg.lo and Max g[i-1] = cg.lo and Max
cf.ashr(k) cf.ashr(k)
cg.ashr(k) cg.ashr(k)
f[N-1] = cf.lo f[N-1] = cf.lo
g[N-1] = cg.lo g[N-1] = cg.lo
@ -414,11 +414,11 @@ func invmodImpl[N, E](
# Compute transition matrix and next hdelta # Compute transition matrix and next hdelta
hdelta = t.batchedDivsteps(hdelta, f[0], g[0], numIters, k) hdelta = t.batchedDivsteps(hdelta, f[0], g[0], numIters, k)
# Apply the transition matrix # Apply the transition matrix
# [u v] [d] # [u v] [d]
# [q r]/2ᵏ.[e] mod M # [q r]/2ᵏ.[e] mod M
t.matVecMul_shr_k_mod_M(d, e, k, M, invMod2powK) t.matVecMul_shr_k_mod_M(d, e, k, M, invMod2powK)
# [u v] [f] # [u v] [f]
# [q r]/2ᵏ.[g] # [q r]/2ᵏ.[g]
t.matVecMul_shr_k(f, g, k) t.matVecMul_shr_k(f, g, k)
d.canonicalize(signMask = f.isNegMask(), M) d.canonicalize(signMask = f.isNegMask(), M)
@ -453,12 +453,12 @@ func invmod*(
F, M: static Limbs, bits: static int) = F, M: static Limbs, bits: static int) =
## Compute the scaled modular inverse of ``a`` modulo M ## Compute the scaled modular inverse of ``a`` modulo M
## r ≡ F.a⁻¹ (mod M) (compile-time factor and modulus overload) ## r ≡ F.a⁻¹ (mod M) (compile-time factor and modulus overload)
## ##
## with F and M known at compile-time ## with F and M known at compile-time
## ##
## M MUST be odd, M does not need to be prime. ## M MUST be odd, M does not need to be prime.
## ``a`` MUST be less than M. ## ``a`` MUST be less than M.
const Excess = 2 const Excess = 2
const k = WordBitWidth - Excess const k = WordBitWidth - Excess
const NumUnsatWords = (bits + k - 1) div k const NumUnsatWords = (bits + k - 1) div k
@ -506,13 +506,13 @@ func batchedDivstepsSymbol(
): tuple[hdelta, L: SignedSecretWord] = ): tuple[hdelta, L: SignedSecretWord] =
## Bernstein-Yang half-delta (hdelta) batch of divsteps ## Bernstein-Yang half-delta (hdelta) batch of divsteps
## with Legendre symbol tracking ## with Legendre symbol tracking
## ##
## Output: ## Output:
## - return hdelta for the next batch of divsteps ## - return hdelta for the next batch of divsteps
## - Returns the intermediate Legendre symbol ## - Returns the intermediate Legendre symbol
## - mutate t, the transition matrix to apply `numIters` divsteps at once ## - mutate t, the transition matrix to apply `numIters` divsteps at once
## t is scaled by 2ᵏ ## t is scaled by 2ᵏ
## ##
## Input: ## Input:
## - f0, bottom limb of f ## - f0, bottom limb of f
## - g0, bottom limb of g ## - g0, bottom limb of g
@ -618,20 +618,20 @@ func legendreImpl[N, E](
numIters, k) numIters, k)
else: else:
(hdelta, L) = t.batchedDivstepsSymbol(hdelta, f[0], g[0], numIters, k) (hdelta, L) = t.batchedDivstepsSymbol(hdelta, f[0], g[0], numIters, k)
# [u v] [f] # [u v] [f]
# [q r]/2ᵏ.[g] # [q r]/2ᵏ.[g]
t.matVecMul_shr_k(f, g, k) t.matVecMul_shr_k(f, g, k)
accL = (accL + L) and SignedSecretWord(3) accL = (accL + L) and SignedSecretWord(3)
accL = (accL + ((accL.isOdd() xor f.isNeg()))) and SignedSecretWord(3) accL = (accL + ((accL.isOdd() xor f.isNeg()))) and SignedSecretWord(3)
accL = (accL + accL.isOdd()) and SignedSecretWord(3) accL = (accL + accL.isOdd()) and SignedSecretWord(3)
accL = SignedSecretWord(1)-accL accL = SignedSecretWord(1)-accL
accL.csetZero(f.isZeroMask()) accL.csetZero(not f.isZeroMask())
return SecretWord(accL) return SecretWord(accL)
func legendre*(a, M: Limbs, bits: static int): SecretWord = func legendre*(a, M: Limbs, bits: static int): SecretWord =
## Compute the Legendre symbol ## Compute the Legendre symbol
## ##
## (a/p)ₗ ≡ a^((p-1)/2) ≡ 1 (mod p), iff a is a square ## (a/p)ₗ ≡ a^((p-1)/2) ≡ 1 (mod p), iff a is a square
## ≡ -1 (mod p), iff a is quadratic non-residue ## ≡ -1 (mod p), iff a is quadratic non-residue
## ≡ 0 (mod p), iff a is 0 ## ≡ 0 (mod p), iff a is 0
@ -645,24 +645,24 @@ func legendre*(a, M: Limbs, bits: static int): SecretWord =
var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess] var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess]
a2.fromPackedRepr(a) a2.fromPackedRepr(a)
legendreImpl(a2, m2, k, bits) legendreImpl(a2, m2, k, bits)
func legendre*(a: Limbs, M: static Limbs, bits: static int): SecretWord = func legendre*(a: Limbs, M: static Limbs, bits: static int): SecretWord =
## Compute the Legendre symbol (compile-time modulus overload) ## Compute the Legendre symbol (compile-time modulus overload)
## ##
## (a/p)ₗ ≡ a^((p-1)/2) ≡ 1 (mod p), iff a is a square ## (a/p)ₗ ≡ a^((p-1)/2) ≡ 1 (mod p), iff a is a square
## ≡ -1 (mod p), iff a is quadratic non-residue ## ≡ -1 (mod p), iff a is quadratic non-residue
## ≡ 0 (mod p), iff a is 0 ## ≡ 0 (mod p), iff a is 0
const Excess = 2 const Excess = 2
const k = WordBitWidth - Excess const k = WordBitWidth - Excess
const NumUnsatWords = (bits + k - 1) div k const NumUnsatWords = (bits + k - 1) div k
# Convert values to unsaturated repr # Convert values to unsaturated repr
const m2 = LimbsUnsaturated[NumUnsatWords, Excess].fromPackedRepr(M) const m2 = LimbsUnsaturated[NumUnsatWords, Excess].fromPackedRepr(M)
var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess] var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess]
a2.fromPackedRepr(a) a2.fromPackedRepr(a)
legendreImpl(a2, m2, k, bits) legendreImpl(a2, m2, k, bits)

View File

@ -17,7 +17,7 @@ type
## This allows efficient handling of carries and signs without intrinsics or assembly. ## This allows efficient handling of carries and signs without intrinsics or assembly.
# #
# Comparison with packed representation: # Comparison with packed representation:
# #
# Packed representation # Packed representation
# - pro: uses less words (important for multiplication which is O(n²) with n the number of words) # - pro: uses less words (important for multiplication which is O(n²) with n the number of words)
# - pro: less "mental overhead" to keep track (clear/shift) excess bits # - pro: less "mental overhead" to keep track (clear/shift) excess bits
@ -86,7 +86,7 @@ template `[]=`*(a: LimbsUnsaturated, idx: int, val: SignedSecretWord) =
func fromPackedRepr*[LU, E, LP: static int]( func fromPackedRepr*[LU, E, LP: static int](
dst: var LimbsUnsaturated[LU, E], dst: var LimbsUnsaturated[LU, E],
src: Limbs[LP]) = src: Limbs[LP]) =
## Converts from an packed representation to an unsaturated representation ## Converts from an packed representation to an unsaturated representation
const UnsatBitWidth = WordBitWidth-E const UnsatBitWidth = WordBitWidth-E
const Max = MaxWord shr E const Max = MaxWord shr E
@ -100,13 +100,13 @@ func fromPackedRepr*[LU, E, LP: static int](
srcIdx, dstIdx = 0 srcIdx, dstIdx = 0
hi, lo = Zero hi, lo = Zero
accLen = 0 accLen = 0
while srcIdx < src.len: while srcIdx < src.len:
# Form a 2-word buffer (hi, lo) # Form a 2-word buffer (hi, lo)
let w = if src_idx < src.len: src[srcIdx] let w = if src_idx < src.len: src[srcIdx]
else: Zero else: Zero
inc srcIdx inc srcIdx
if accLen == 0: if accLen == 0:
lo = w and Max lo = w and Max
hi = w shr UnsatBitWidth hi = w shr UnsatBitWidth
@ -124,7 +124,7 @@ func fromPackedRepr*[LU, E, LP: static int](
accLen -= s accLen -= s
lo = ((lo shr s) or (hi shl (UnsatBitWidth - s))) and Max lo = ((lo shr s) or (hi shl (UnsatBitWidth - s))) and Max
hi = hi shr s hi = hi shr s
if dstIdx < dst.words.len: if dstIdx < dst.words.len:
dst[dstIdx] = SignedSecretWord lo dst[dstIdx] = SignedSecretWord lo
@ -138,7 +138,7 @@ func fromPackedRepr*(T: type LimbsUnsaturated, src: Limbs): T =
func fromUnsatRepr*[LU, E, LP: static int]( func fromUnsatRepr*[LU, E, LP: static int](
dst: var Limbs[LP], dst: var Limbs[LP],
src: LimbsUnsaturated[LU, E]) = src: LimbsUnsaturated[LU, E]) =
## Converts from an packed representation to an unsaturated representation ## Converts from an packed representation to an unsaturated representation
const UnsatBitWidth = WordBitWidth-E const UnsatBitWidth = WordBitWidth-E
static: static:
@ -165,7 +165,7 @@ func fromUnsatRepr*[LU, E, LP: static int](
inc dstIdx inc dstIdx
accLen -= WordBitWidth accLen -= WordBitWidth
acc = nextWord shr (UnsatBitWidth - accLen) acc = nextWord shr (UnsatBitWidth - accLen)
if dst_idx < dst.len: if dst_idx < dst.len:
dst[dst_idx] = acc dst[dst_idx] = acc
@ -280,9 +280,10 @@ func isOdd*(a: SignedSecretWord): SignedSecretWord {.inline.} =
a and SignedSecretWord(1) a and SignedSecretWord(1)
func isZeroMask*(a: SignedSecretWord): SignedSecretWord {.inline.} = func isZeroMask*(a: SignedSecretWord): SignedSecretWord {.inline.} =
## Produce the -1 mask if a is negative ## Produce the -1 mask if a is 0
## and 0 otherwise ## and 0 otherwise
not SignedSecretWord(a.SecretWord().isZero()) # In x86 assembly, we can use "neg" + "sbb"
-SignedSecretWord(a.SecretWord().isZero())
func isNegMask*(a: SignedSecretWord): SignedSecretWord {.inline.} = func isNegMask*(a: SignedSecretWord): SignedSecretWord {.inline.} =
## Produce the -1 mask if a is negative ## Produce the -1 mask if a is negative
@ -295,7 +296,7 @@ func isOddMask*(a: SignedSecretWord): SignedSecretWord {.inline.} =
-(a and SignedSecretWord(1)) -(a and SignedSecretWord(1))
func csetZero*(a: var SignedSecretWord, mask: SignedSecretWord) {.inline.} = func csetZero*(a: var SignedSecretWord, mask: SignedSecretWord) {.inline.} =
## Conditionally set `a` to 0 ## Conditionally set `a` to 0
## mask must be 0 (0x00000...0000) (kept as is) ## mask must be 0 (0x00000...0000) (kept as is)
## or -1 (0xFFFF...FFFF) (zeroed) ## or -1 (0xFFFF...FFFF) (zeroed)
a = a and mask a = a and mask
@ -303,7 +304,7 @@ func csetZero*(a: var SignedSecretWord, mask: SignedSecretWord) {.inline.} =
func cneg*( func cneg*(
a: SignedSecretWord, a: SignedSecretWord,
mask: SignedSecretWord): SignedSecretWord {.inline.} = mask: SignedSecretWord): SignedSecretWord {.inline.} =
## Conditionally negate `a` ## Conditionally negate `a`
## mask must be 0 (0x00000...0000) (no negation) ## mask must be 0 (0x00000...0000) (no negation)
## or -1 (0xFFFF...FFFF) (negation) ## or -1 (0xFFFF...FFFF) (negation)
(a xor mask) - mask (a xor mask) - mask
@ -312,7 +313,7 @@ func cadd*(
a: var SignedSecretWord, a: var SignedSecretWord,
b: SignedSecretWord, b: SignedSecretWord,
mask: SignedSecretWord) {.inline.} = mask: SignedSecretWord) {.inline.} =
## Conditionally add `b` to `a` ## Conditionally add `b` to `a`
## mask must be 0 (0x00000...0000) (no addition) ## mask must be 0 (0x00000...0000) (no addition)
## or -1 (0xFFFF...FFFF) (addition) ## or -1 (0xFFFF...FFFF) (addition)
a = a + (b and mask) a = a + (b and mask)
@ -321,7 +322,7 @@ func csub*(
a: var SignedSecretWord, a: var SignedSecretWord,
b: SignedSecretWord, b: SignedSecretWord,
mask: SignedSecretWord) {.inline.} = mask: SignedSecretWord) {.inline.} =
## Conditionally substract `b` from `a` ## Conditionally substract `b` from `a`
## mask must be 0 (0x00000...0000) (no substraction) ## mask must be 0 (0x00000...0000) (no substraction)
## or -1 (0xFFFF...FFFF) (substraction) ## or -1 (0xFFFF...FFFF) (substraction)
a = a - (b and mask) a = a - (b and mask)
@ -335,7 +336,7 @@ func isZeroMask*(a: LimbsUnsaturated): SignedSecretWord {.inline.} =
var accum = SignedSecretWord(0) var accum = SignedSecretWord(0)
for i in 0 ..< a.words.len: for i in 0 ..< a.words.len:
accum = accum or a.words[i] accum = accum or a.words[i]
return accum.isZeroMask() return accum.isZeroMask()
func isNeg*(a: LimbsUnsaturated): SignedSecretWord {.inline.} = func isNeg*(a: LimbsUnsaturated): SignedSecretWord {.inline.} =
@ -351,10 +352,10 @@ func isNegMask*(a: LimbsUnsaturated): SignedSecretWord {.inline.} =
func cneg*( func cneg*(
a: var LimbsUnsaturated, a: var LimbsUnsaturated,
mask: SignedSecretWord) {.inline.} = mask: SignedSecretWord) {.inline.} =
## Conditionally negate `a` ## Conditionally negate `a`
## mask must be 0 (0x00000...0000) (no negation) ## mask must be 0 (0x00000...0000) (no negation)
## or -1 (0xFFFF...FFFF) (negation) ## or -1 (0xFFFF...FFFF) (negation)
## ##
## Carry propagation is deferred ## Carry propagation is deferred
for i in 0 ..< a.words.len: for i in 0 ..< a.words.len:
a[i] = a[i].cneg(mask) a[i] = a[i].cneg(mask)
@ -363,10 +364,10 @@ func cadd*(
a: var LimbsUnsaturated, a: var LimbsUnsaturated,
b: LimbsUnsaturated, b: LimbsUnsaturated,
mask: SignedSecretWord) {.inline.} = mask: SignedSecretWord) {.inline.} =
## Conditionally add `b` to `a` ## Conditionally add `b` to `a`
## mask must be 0 (0x00000...0000) (no addition) ## mask must be 0 (0x00000...0000) (no addition)
## or -1 (0xFFFF...FFFF) (addition) ## or -1 (0xFFFF...FFFF) (addition)
## ##
## Carry propagation is deferred ## Carry propagation is deferred
for i in 0 ..< a.words.len: for i in 0 ..< a.words.len:
a[i].cadd(b[i], mask) a[i].cadd(b[i], mask)

View File

@ -173,7 +173,7 @@ template `xor`*[T: Ct](x, y: CTBool[T]): CTBool[T] =
template cneg*[T: Ct](x: T, ctl: CTBool[T]): T = template cneg*[T: Ct](x: T, ctl: CTBool[T]): T =
# Conditional negate if ctl is true # Conditional negate if ctl is true
(x xor -T(ctl)) + T(ctl) (x xor -T(ctl)) + T(ctl)
# ############################################################ # ############################################################
# #
@ -215,6 +215,7 @@ template isNonZero*[T: Ct](x: T): CTBool[T] =
isMsbSet(x_NZ or -x_NZ) isMsbSet(x_NZ or -x_NZ)
template isZero*[T: Ct](x: T): CTBool[T] = template isZero*[T: Ct](x: T): CTBool[T] =
# In x86 assembly, we can use "neg" + "adc"
not isNonZero(x) not isNonZero(x)
# ############################################################ # ############################################################