Add extended precision multiplication

This commit is contained in:
mratsim 2018-12-02 18:14:32 +01:00
parent 057ce0cbf9
commit c226987ab0
2 changed files with 106 additions and 13 deletions

View File

@ -50,12 +50,15 @@ type
BigInt*[bits: static int] = object BigInt*[bits: static int] = object
limbs*: array[bits.words_required, Limb] limbs*: array[bits.words_required, Limb]
const highLimb* = (not Ct[uint64](0)) shr 1 const HighLimb* = (not Ct[uint64](0)) shr 1
## This represents 0x7F_FF_FF_FF__FF_FF_FF_FF ## This represents 0x7F_FF_FF_FF__FF_FF_FF_FF
## also 0b0111...1111 ## also 0b0111...1111
## This biggest representable number in our limbs. ## This biggest representable number in our limbs.
## i.e. The most significant bit is never set at the end of each function ## i.e. The most significant bit is never set at the end of each function
template `[]`*(a: Bigint, idx: int): Limb =
a.limbs[idx]
# ############################################################ # ############################################################
# #
# BigInt primitives # BigInt primitives
@ -78,7 +81,7 @@ template addImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits
for i in static(0 ..< a.limbs.len): for i in static(0 ..< a.limbs.len):
let new_a = a.limbs[i] + b.limbs[i] + Limb(result) let new_a = a.limbs[i] + b.limbs[i] + Limb(result)
result = new_a.isMsbSet() result = new_a.isMsbSet()
a[i] = ctl.mux(new_a and highLimb, a) a[i] = ctl.mux(new_a and HighLimb, a)
func add*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] = func add*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
## Constant-time big integer in-place addition ## Constant-time big integer in-place addition
@ -96,7 +99,7 @@ template subImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits
for i in static(0 ..< a.limbs.len): for i in static(0 ..< a.limbs.len):
let new_a = a.limbs[i] - b.limbs[i] - Limb(result) let new_a = a.limbs[i] - b.limbs[i] - Limb(result)
result = new_a.isMsbSet() result = new_a.isMsbSet()
a[i] = ctl.mux(new_a and highLimb, a) a[i] = ctl.mux(new_a and HighLimb, a)
func sub*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] = func sub*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
## Constant-time big integer in-place addition ## Constant-time big integer in-place addition

View File

@ -13,6 +13,41 @@
import ../word_types import ../word_types
func asm_x86_64_extMul(hi, lo: var uint64, a, b: uint64) {.inline.}=
## Extended precision multiplication uint64 * uint64 --> uint128
# TODO !!! - Replace by constant-time, portable, non-assembly version
# -> use uint128? Compiler might add unwanted branches
# MUL r/m64
# Multiply RAX by r/m64
#
# Inputs:
# - RAX
# - r/m
# Outputs:
# - High word in RDX
# - Low word in RAX
asm """
mulq %[operand]
: "=d" (`*hi`), "=a" (`*lo`) // Don't forget to dereference the var hidden pointer
: "a" (`a`), [operand] "rm" (`b`)
: // no clobbered registers
"""
func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
## Extended precision multiplication uint64 * uint64 --> uint128
##
## TODO, at the moment only x86_64 architecture are supported
## as we use assembly.
## Also we assume that the native integer division
## provided by the PU is constant-time
# Note, using C/Nim default `*` is inefficient
# and complicated to make constant-time
# See at the bottom.
func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}= func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
## Division uint128 by uint64 ## Division uint128 by uint64
## Warning ⚠️ : ## Warning ⚠️ :
@ -20,6 +55,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
## - if n_hi > d result is undefined ## - if n_hi > d result is undefined
# TODO !!! - Replace by constant-time, portable, non-assembly version # TODO !!! - Replace by constant-time, portable, non-assembly version
# -> use uint128? Compiler might add unwanted branches
# DIV r/m64 # DIV r/m64
# Divide RDX:RAX (n_hi:n_lo) by r/m64 # Divide RDX:RAX (n_hi:n_lo) by r/m64
@ -27,7 +63,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
# Inputs # Inputs
# - numerator high word in RDX, # - numerator high word in RDX,
# - numerator low word in RAX, # - numerator low word in RAX,
# - divisor as rm parameter (register or memory at the compiler discretion) # - divisor as r/m parameter (register or memory at the compiler discretion)
# Result # Result
# - Quotient in RAX # - Quotient in RAX
# - Remainder in RDX # - Remainder in RDX
@ -44,6 +80,10 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
## - if n_hi == d, quotient does not fit in an uint64 ## - if n_hi == d, quotient does not fit in an uint64
## - if n_hi > d result is undefined ## - if n_hi > d result is undefined
## ##
## To avoid issues, n_hi, n_lo, d should be normalized.
## i.e. shifted (== multiplied by the same power of 2)
## so that the most significant bit in d is set.
##
## TODO, at the moment only x86_64 architecture are supported ## TODO, at the moment only x86_64 architecture are supported
## as we use assembly. ## as we use assembly.
## Also we assume that the native integer division ## Also we assume that the native integer division
@ -53,7 +93,7 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
# and complicated to make constant-time # and complicated to make constant-time
# See at the bottom. # See at the bottom.
# #
# Furthermore compilers try to substitute division # Furthermore compilers may try to substitute division
# with a fast path that may have branches. It might also # with a fast path that may have branches. It might also
# be the same at the hardware level. # be the same at the hardware level.
@ -65,18 +105,68 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
asm_x86_64_div2n1n(T(q), T(r), T(n_hi), T(n_lo), T(d)) asm_x86_64_div2n1n(T(q), T(r), T(n_hi), T(n_lo), T(d))
when isMainModule: when isMainModule:
block: # Multiplication
var hi, lo: uint64
var q, r: uint64 asm_x86_64_extMul(hi, lo, 1 shl 32, 1 shl 33) # 2^65
doAssert hi == 2
doAssert lo == 0
# (1 shl 64) div 3 block: # Division
let n_hi = 1'u64 var q, r: uint64
let n_lo = 0'u64
let d = 3'u64
asm_x86_64_div2n1n(q, r, n_hi, n_lo, d) # (1 shl 64) div 3
let n_hi = 1'u64
let n_lo = 0'u64
let d = 3'u64
doAssert q == 6148914691236517205'u64 asm_x86_64_div2n1n(q, r, n_hi, n_lo, d)
doAssert r == 1
doAssert q == 6148914691236517205'u64
doAssert r == 1
# ##############################################################
#
# Non-constant-time portable extended precision multiplication
#
# ##############################################################
# implementation from Stint: https://github.com/status-im/nim-stint/blob/edb1ade37309390cc641cee07ab62e5459d9ca44/stint/private/uint_mul.nim#L91-L135
# template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) =
# const
# p = 64 div 2
# base: uint64 = 1 shl p
#
# var
# x0, x1, x2, x3: uint64
#
# let
# ul = lo(u)
# uh = hi(u)
# vl = lo(v)
# vh = hi(v)
#
# x0 = ul * vl
# x1 = ul * vh
# x2 = uh * vl
# x3 = uh * vh
#
# x1 += hi(x0) # This can't carry
# x1 += x2 # but this can
# if x1 < x2: # if carry, add it to x3
# x3 += base
#
# op(result.hi, x3 + hi(x1))
# op(result.lo, (x1 shl p) or lo(x0))
#
# func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) =
# ## Extended precision multiplication
# extPrecMulImpl(result, `=`, u, v)
#
# func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) =
# ## Extended precision fused in-place addition & multiplication
# extPrecMulImpl(result, `+=`, u, v)
# ############################################################ # ############################################################
# #