From c226987ab03f057102f3791a4a7f86c7b36b881d Mon Sep 17 00:00:00 2001 From: mratsim Date: Sun, 2 Dec 2018 18:14:32 +0100 Subject: [PATCH] Add extended precision multiplication --- constantine/bigints.nim | 9 +- constantine/private/word_types_internal.nim | 110 ++++++++++++++++++-- 2 files changed, 106 insertions(+), 13 deletions(-) diff --git a/constantine/bigints.nim b/constantine/bigints.nim index c0d5c73..d585bb0 100644 --- a/constantine/bigints.nim +++ b/constantine/bigints.nim @@ -50,12 +50,15 @@ type BigInt*[bits: static int] = object limbs*: array[bits.words_required, Limb] -const highLimb* = (not Ct[uint64](0)) shr 1 +const HighLimb* = (not Ct[uint64](0)) shr 1 ## This represents 0x7F_FF_FF_FF__FF_FF_FF_FF ## also 0b0111...1111 ## This biggest representable number in our limbs. ## i.e. The most significant bit is never set at the end of each function +template `[]`*(a: Bigint, idx: int): Limb = + a.limbs[idx] + # ############################################################ # # BigInt primitives @@ -78,7 +81,7 @@ template addImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits for i in static(0 ..< a.limbs.len): let new_a = a.limbs[i] + b.limbs[i] + Limb(result) result = new_a.isMsbSet() - a[i] = ctl.mux(new_a and highLimb, a) + a[i] = ctl.mux(new_a and HighLimb, a) func add*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] = ## Constant-time big integer in-place addition @@ -96,7 +99,7 @@ template subImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits for i in static(0 ..< a.limbs.len): let new_a = a.limbs[i] - b.limbs[i] - Limb(result) result = new_a.isMsbSet() - a[i] = ctl.mux(new_a and highLimb, a) + a[i] = ctl.mux(new_a and HighLimb, a) func sub*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] = ## Constant-time big integer in-place addition diff --git a/constantine/private/word_types_internal.nim b/constantine/private/word_types_internal.nim index 
2a0f63e..5206d86 100644 --- a/constantine/private/word_types_internal.nim +++ b/constantine/private/word_types_internal.nim @@ -13,6 +13,41 @@ import ../word_types +func asm_x86_64_extMul(hi, lo: var uint64, a, b: uint64) {.inline.}= + ## Extended precision multiplication uint64 * uint64 --> uint128 + + # TODO !!! - Replace by constant-time, portable, non-assembly version + # -> use uint128? Compiler might add unwanted branches + + # MUL r/m64 + # Multiply RAX by r/m64 + # + # Inputs: + # - RAX + # - r/m + # Outputs: + # - High word in RDX + # - Low word in RAX + + asm """ + mulq %[operand] + : "=d" (`*hi`), "=a" (`*lo`) // Don't forget to dereference the var hidden pointer + : "a" (`a`), [operand] "rm" (`b`) + : // no clobbered registers + """ + +func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}= + ## Extended precision multiplication uint64 * uint64 --> uint128 + ## + ## TODO, at the moment only the x86_64 architecture is supported + ## as we use assembly. + ## Also we assume that the native integer multiplication + ## provided by the CPU is constant-time + + # Note, using C/Nim default `*` is inefficient + # and complicated to make constant-time + # See at the bottom. + func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}= ## Division uint128 by uint64 ## Warning ⚠️ : @@ -20,6 +55,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}= ## - if n_hi > d result is undefined # TODO !!! - Replace by constant-time, portable, non-assembly version + # -> use uint128?
Compiler might add unwanted branches # DIV r/m64 # Divide RDX:RAX (n_hi:n_lo) by r/m64 @@ -27,7 +63,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}= # Inputs # - numerator high word in RDX, # - numerator low word in RAX, - # - divisor as rm parameter (register or memory at the compiler discretion) + # - divisor as r/m parameter (register or memory at the compiler discretion) # Result # - Quotient in RAX # - Remainder in RDX @@ -44,6 +80,10 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.} ## - if n_hi == d, quotient does not fit in an uint64 ## - if n_hi > d result is undefined ## + ## To avoid issues, n_hi, n_lo, d should be normalized. + ## i.e. shifted (== multiplied by the same power of 2) + ## so that the most significant bit in d is set. + ## ## TODO, at the moment only x86_64 architecture are supported ## as we use assembly. ## Also we assume that the native integer division @@ -53,7 +93,7 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.} # and complicated to make constant-time # See at the bottom. # - # Furthermore compilers try to substitute division + # Furthermore compilers may try to substitute division # with a fast path that may have branches. It might also # be the same at the hardware level. 
@@ -65,18 +105,68 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.} asm_x86_64_div2n1n(T(q), T(r), T(n_hi), T(n_lo), T(d)) when isMainModule: + block: # Multiplication + var hi, lo: uint64 - var q, r: uint64 + asm_x86_64_extMul(hi, lo, 1 shl 32, 1 shl 33) # 2^65 + doAssert hi == 2 + doAssert lo == 0 - # (1 shl 64) div 3 - let n_hi = 1'u64 - let n_lo = 0'u64 - let d = 3'u64 + block: # Division + var q, r: uint64 - asm_x86_64_div2n1n(q, r, n_hi, n_lo, d) + # (1 shl 64) div 3 + let n_hi = 1'u64 + let n_lo = 0'u64 + let d = 3'u64 - doAssert q == 6148914691236517205'u64 - doAssert r == 1 + asm_x86_64_div2n1n(q, r, n_hi, n_lo, d) + + doAssert q == 6148914691236517205'u64 + doAssert r == 1 + +# ############################################################## +# +# Non-constant-time portable extended precision multiplication +# +# ############################################################## + +# implementation from Stint: https://github.com/status-im/nim-stint/blob/edb1ade37309390cc641cee07ab62e5459d9ca44/stint/private/uint_mul.nim#L91-L135 + +# template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) = +# const +# p = 64 div 2 +# base: uint64 = 1 shl p +# +# var +# x0, x1, x2, x3: uint64 +# +# let +# ul = lo(u) +# uh = hi(u) +# vl = lo(v) +# vh = hi(v) +# +# x0 = ul * vl +# x1 = ul * vh +# x2 = uh * vl +# x3 = uh * vh +# +# x1 += hi(x0) # This can't carry +# x1 += x2 # but this can +# if x1 < x2: # if carry, add it to x3 +# x3 += base +# +# op(result.hi, x3 + hi(x1)) +# op(result.lo, (x1 shl p) or lo(x0)) +# +# func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) = +# ## Extended precision multiplication +# extPrecMulImpl(result, `=`, u, v) +# +# func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) = +# ## Extended precision fused in-place addition & multiplication +# extPrecMulImpl(result, `+=`, u, v) # ############################################################ #