Add extended precision multiplication
This commit is contained in:
parent
057ce0cbf9
commit
c226987ab0
|
@ -50,12 +50,15 @@ type
|
||||||
BigInt*[bits: static int] = object
|
BigInt*[bits: static int] = object
|
||||||
limbs*: array[bits.words_required, Limb]
|
limbs*: array[bits.words_required, Limb]
|
||||||
|
|
||||||
const highLimb* = (not Ct[uint64](0)) shr 1
|
const HighLimb* = (not Ct[uint64](0)) shr 1
|
||||||
## This represents 0x7F_FF_FF_FF__FF_FF_FF_FF
|
## This represents 0x7F_FF_FF_FF__FF_FF_FF_FF
|
||||||
## also 0b0111...1111
|
## also 0b0111...1111
|
||||||
## The biggest representable number in our limbs.
|
## The biggest representable number in our limbs.
|
||||||
## i.e. The most significant bit is never set at the end of each function
|
## i.e. The most significant bit is never set at the end of each function
|
||||||
|
|
||||||
|
template `[]`*(a: Bigint, idx: int): Limb =
|
||||||
|
a.limbs[idx]
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
# BigInt primitives
|
# BigInt primitives
|
||||||
|
@ -78,7 +81,7 @@ template addImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits
|
||||||
for i in static(0 ..< a.limbs.len):
|
for i in static(0 ..< a.limbs.len):
|
||||||
let new_a = a.limbs[i] + b.limbs[i] + Limb(result)
|
let new_a = a.limbs[i] + b.limbs[i] + Limb(result)
|
||||||
result = new_a.isMsbSet()
|
result = new_a.isMsbSet()
|
||||||
a[i] = ctl.mux(new_a and highLimb, a)
|
a[i] = ctl.mux(new_a and HighLimb, a)
|
||||||
|
|
||||||
func add*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
|
func add*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
|
||||||
## Constant-time big integer in-place addition
|
## Constant-time big integer in-place addition
|
||||||
|
@ -96,7 +99,7 @@ template subImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits
|
||||||
for i in static(0 ..< a.limbs.len):
|
for i in static(0 ..< a.limbs.len):
|
||||||
let new_a = a.limbs[i] - b.limbs[i] - Limb(result)
|
let new_a = a.limbs[i] - b.limbs[i] - Limb(result)
|
||||||
result = new_a.isMsbSet()
|
result = new_a.isMsbSet()
|
||||||
a[i] = ctl.mux(new_a and highLimb, a)
|
a[i] = ctl.mux(new_a and HighLimb, a)
|
||||||
|
|
||||||
func sub*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
|
func sub*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
|
||||||
## Constant-time big integer in-place subtraction
|
## Constant-time big integer in-place subtraction
|
||||||
|
|
|
@ -13,6 +13,41 @@
|
||||||
|
|
||||||
import ../word_types
|
import ../word_types
|
||||||
|
|
||||||
|
func asm_x86_64_extMul(hi, lo: var uint64, a, b: uint64) {.inline.}=
|
||||||
|
## Extended precision multiplication uint64 * uint64 --> uint128
|
||||||
|
|
||||||
|
# TODO !!! - Replace by constant-time, portable, non-assembly version
|
||||||
|
# -> use uint128? Compiler might add unwanted branches
|
||||||
|
|
||||||
|
# MUL r/m64
|
||||||
|
# Multiply RAX by r/m64
|
||||||
|
#
|
||||||
|
# Inputs:
|
||||||
|
# - RAX
|
||||||
|
# - r/m
|
||||||
|
# Outputs:
|
||||||
|
# - High word in RDX
|
||||||
|
# - Low word in RAX
|
||||||
|
|
||||||
|
asm """
|
||||||
|
mulq %[operand]
|
||||||
|
: "=d" (`*hi`), "=a" (`*lo`) // Don't forget to dereference the var hidden pointer
|
||||||
|
: "a" (`a`), [operand] "rm" (`b`)
|
||||||
|
: // no clobbered registers
|
||||||
|
"""
|
||||||
|
|
||||||
|
func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
|
||||||
|
## Extended precision multiplication uint64 * uint64 --> uint128
|
||||||
|
##
|
||||||
|
## TODO, at the moment only the x86_64 architecture is supported
|
||||||
|
## as we use assembly.
|
||||||
|
## Also we assume that the native integer multiplication
|
||||||
|
## provided by the CPU is constant-time
|
||||||
|
|
||||||
|
# Note, using C/Nim default `*` is inefficient
|
||||||
|
# and complicated to make constant-time
|
||||||
|
# See at the bottom.
|
||||||
|
|
||||||
func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
||||||
## Division uint128 by uint64
|
## Division uint128 by uint64
|
||||||
## Warning ⚠️ :
|
## Warning ⚠️ :
|
||||||
|
@ -20,6 +55,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
||||||
## - if n_hi > d result is undefined
|
## - if n_hi > d result is undefined
|
||||||
|
|
||||||
# TODO !!! - Replace by constant-time, portable, non-assembly version
|
# TODO !!! - Replace by constant-time, portable, non-assembly version
|
||||||
|
# -> use uint128? Compiler might add unwanted branches
|
||||||
|
|
||||||
# DIV r/m64
|
# DIV r/m64
|
||||||
# Divide RDX:RAX (n_hi:n_lo) by r/m64
|
# Divide RDX:RAX (n_hi:n_lo) by r/m64
|
||||||
|
@ -27,7 +63,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
||||||
# Inputs
|
# Inputs
|
||||||
# - numerator high word in RDX,
|
# - numerator high word in RDX,
|
||||||
# - numerator low word in RAX,
|
# - numerator low word in RAX,
|
||||||
# - divisor as rm parameter (register or memory at the compiler discretion)
|
# - divisor as r/m parameter (register or memory at the compiler discretion)
|
||||||
# Result
|
# Result
|
||||||
# - Quotient in RAX
|
# - Quotient in RAX
|
||||||
# - Remainder in RDX
|
# - Remainder in RDX
|
||||||
|
@ -44,6 +80,10 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
|
||||||
## - if n_hi == d, quotient does not fit in an uint64
|
## - if n_hi == d, quotient does not fit in an uint64
|
||||||
## - if n_hi > d result is undefined
|
## - if n_hi > d result is undefined
|
||||||
##
|
##
|
||||||
|
## To avoid issues, n_hi, n_lo, d should be normalized.
|
||||||
|
## i.e. shifted (== multiplied by the same power of 2)
|
||||||
|
## so that the most significant bit in d is set.
|
||||||
|
##
|
||||||
## TODO, at the moment only the x86_64 architecture is supported
|
## TODO, at the moment only the x86_64 architecture is supported
|
||||||
## as we use assembly.
|
## as we use assembly.
|
||||||
## Also we assume that the native integer division
|
## Also we assume that the native integer division
|
||||||
|
@ -53,7 +93,7 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
|
||||||
# and complicated to make constant-time
|
# and complicated to make constant-time
|
||||||
# See at the bottom.
|
# See at the bottom.
|
||||||
#
|
#
|
||||||
# Furthermore compilers try to substitute division
|
# Furthermore compilers may try to substitute division
|
||||||
# with a fast path that may have branches. It might also
|
# with a fast path that may have branches. It might also
|
||||||
# be the same at the hardware level.
|
# be the same at the hardware level.
|
||||||
|
|
||||||
|
@ -65,18 +105,68 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
|
||||||
asm_x86_64_div2n1n(T(q), T(r), T(n_hi), T(n_lo), T(d))
|
asm_x86_64_div2n1n(T(q), T(r), T(n_hi), T(n_lo), T(d))
|
||||||
|
|
||||||
when isMainModule:
|
when isMainModule:
|
||||||
|
block: # Multiplication
|
||||||
|
var hi, lo: uint64
|
||||||
|
|
||||||
var q, r: uint64
|
asm_x86_64_extMul(hi, lo, 1 shl 32, 1 shl 33) # 2^65
|
||||||
|
doAssert hi == 2
|
||||||
|
doAssert lo == 0
|
||||||
|
|
||||||
# (1 shl 64) div 3
|
block: # Division
|
||||||
let n_hi = 1'u64
|
var q, r: uint64
|
||||||
let n_lo = 0'u64
|
|
||||||
let d = 3'u64
|
|
||||||
|
|
||||||
asm_x86_64_div2n1n(q, r, n_hi, n_lo, d)
|
# (1 shl 64) div 3
|
||||||
|
let n_hi = 1'u64
|
||||||
|
let n_lo = 0'u64
|
||||||
|
let d = 3'u64
|
||||||
|
|
||||||
doAssert q == 6148914691236517205'u64
|
asm_x86_64_div2n1n(q, r, n_hi, n_lo, d)
|
||||||
doAssert r == 1
|
|
||||||
|
doAssert q == 6148914691236517205'u64
|
||||||
|
doAssert r == 1
|
||||||
|
|
||||||
|
# ##############################################################
|
||||||
|
#
|
||||||
|
# Non-constant-time portable extended precision multiplication
|
||||||
|
#
|
||||||
|
# ##############################################################
|
||||||
|
|
||||||
|
# implementation from Stint: https://github.com/status-im/nim-stint/blob/edb1ade37309390cc641cee07ab62e5459d9ca44/stint/private/uint_mul.nim#L91-L135
|
||||||
|
|
||||||
|
# template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) =
|
||||||
|
# const
|
||||||
|
# p = 64 div 2
|
||||||
|
# base: uint64 = 1 shl p
|
||||||
|
#
|
||||||
|
# var
|
||||||
|
# x0, x1, x2, x3: uint64
|
||||||
|
#
|
||||||
|
# let
|
||||||
|
# ul = lo(u)
|
||||||
|
# uh = hi(u)
|
||||||
|
# vl = lo(v)
|
||||||
|
# vh = hi(v)
|
||||||
|
#
|
||||||
|
# x0 = ul * vl
|
||||||
|
# x1 = ul * vh
|
||||||
|
# x2 = uh * vl
|
||||||
|
# x3 = uh * vh
|
||||||
|
#
|
||||||
|
# x1 += hi(x0) # This can't carry
|
||||||
|
# x1 += x2 # but this can
|
||||||
|
# if x1 < x2: # if carry, add it to x3
|
||||||
|
# x3 += base
|
||||||
|
#
|
||||||
|
# op(result.hi, x3 + hi(x1))
|
||||||
|
# op(result.lo, (x1 shl p) or lo(x0))
|
||||||
|
#
|
||||||
|
# func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) =
|
||||||
|
# ## Extended precision multiplication
|
||||||
|
# extPrecMulImpl(result, `=`, u, v)
|
||||||
|
#
|
||||||
|
# func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) =
|
||||||
|
# ## Extended precision fused in-place addition & multiplication
|
||||||
|
# extPrecMulImpl(result, `+=`, u, v)
|
||||||
|
|
||||||
# ############################################################
|
# ############################################################
|
||||||
#
|
#
|
||||||
|
|
Loading…
Reference in New Issue