mirror of
https://github.com/codex-storage/constantine.git
synced 2025-01-27 02:54:56 +00:00
Add extended precision multiplication
This commit is contained in:
parent
057ce0cbf9
commit
c226987ab0
@ -50,12 +50,15 @@ type
|
||||
BigInt*[bits: static int] = object
|
||||
limbs*: array[bits.words_required, Limb]
|
||||
|
||||
const highLimb* = (not Ct[uint64](0)) shr 1
|
||||
const HighLimb* = (not Ct[uint64](0)) shr 1
|
||||
## This represents 0x7F_FF_FF_FF__FF_FF_FF_FF
|
||||
## also 0b0111...1111
|
||||
## This biggest representable number in our limbs.
|
||||
## i.e. The most significant bit is never set at the end of each function
|
||||
|
||||
template `[]`*(a: Bigint, idx: int): Limb =
|
||||
a.limbs[idx]
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# BigInt primitives
|
||||
@ -78,7 +81,7 @@ template addImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits
|
||||
for i in static(0 ..< a.limbs.len):
|
||||
let new_a = a.limbs[i] + b.limbs[i] + Limb(result)
|
||||
result = new_a.isMsbSet()
|
||||
a[i] = ctl.mux(new_a and highLimb, a)
|
||||
a[i] = ctl.mux(new_a and HighLimb, a)
|
||||
|
||||
func add*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
|
||||
## Constant-time big integer in-place addition
|
||||
@ -96,7 +99,7 @@ template subImpl[bits](result: CTBool[Limb], a: var BigInt[bits], b: BigInt[bits
|
||||
for i in static(0 ..< a.limbs.len):
|
||||
let new_a = a.limbs[i] - b.limbs[i] - Limb(result)
|
||||
result = new_a.isMsbSet()
|
||||
a[i] = ctl.mux(new_a and highLimb, a)
|
||||
a[i] = ctl.mux(new_a and HighLimb, a)
|
||||
|
||||
func sub*[bits](a: var BigInt[bits], b: BigInt[bits], ctl: CTBool[Limb]): CTBool[Limb] =
|
||||
## Constant-time big integer in-place addition
|
||||
|
@ -13,6 +13,41 @@
|
||||
|
||||
import ../word_types
|
||||
|
||||
func asm_x86_64_extMul(hi, lo: var uint64, a, b: uint64) {.inline.}=
|
||||
## Extended precision multiplication uint64 * uint64 --> uint128
|
||||
|
||||
# TODO !!! - Replace by constant-time, portable, non-assembly version
|
||||
# -> use uint128? Compiler might add unwanted branches
|
||||
|
||||
# MUL r/m64
|
||||
# Multiply RAX by r/m64
|
||||
#
|
||||
# Inputs:
|
||||
# - RAX
|
||||
# - r/m
|
||||
# Outputs:
|
||||
# - High word in RDX
|
||||
# - Low word in RAX
|
||||
|
||||
asm """
|
||||
mulq %[operand]
|
||||
: "=d" (`*hi`), "=a" (`*lo`) // Don't forget to dereference the var hidden pointer
|
||||
: "a" (`a`), [operand] "rm" (`b`)
|
||||
: // no clobbered registers
|
||||
"""
|
||||
|
||||
func unsafe_extendedPrecMul(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.}=
|
||||
## Extended precision multiplication uint64 * uint64 --> uint128
|
||||
##
|
||||
## TODO, at the moment only x86_64 architecture are supported
|
||||
## as we use assembly.
|
||||
## Also we assume that the native integer division
|
||||
## provided by the PU is constant-time
|
||||
|
||||
# Note, using C/Nim default `*` is inefficient
|
||||
# and complicated to make constant-time
|
||||
# See at the bottom.
|
||||
|
||||
func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
||||
## Division uint128 by uint64
|
||||
## Warning ⚠️ :
|
||||
@ -20,6 +55,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
||||
## - if n_hi > d result is undefined
|
||||
|
||||
# TODO !!! - Replace by constant-time, portable, non-assembly version
|
||||
# -> use uint128? Compiler might add unwanted branches
|
||||
|
||||
# DIV r/m64
|
||||
# Divide RDX:RAX (n_hi:n_lo) by r/m64
|
||||
@ -27,7 +63,7 @@ func asm_x86_64_div2n1n(q, r: var uint64, n_hi, n_lo, d: uint64) {.inline.}=
|
||||
# Inputs
|
||||
# - numerator high word in RDX,
|
||||
# - numerator low word in RAX,
|
||||
# - divisor as rm parameter (register or memory at the compiler discretion)
|
||||
# - divisor as r/m parameter (register or memory at the compiler discretion)
|
||||
# Result
|
||||
# - Quotient in RAX
|
||||
# - Remainder in RDX
|
||||
@ -44,6 +80,10 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
|
||||
## - if n_hi == d, quotient does not fit in an uint64
|
||||
## - if n_hi > d result is undefined
|
||||
##
|
||||
## To avoid issues, n_hi, n_lo, d should be normalized.
|
||||
## i.e. shifted (== multiplied by the same power of 2)
|
||||
## so that the most significant bit in d is set.
|
||||
##
|
||||
## TODO, at the moment only x86_64 architecture are supported
|
||||
## as we use assembly.
|
||||
## Also we assume that the native integer division
|
||||
@ -53,7 +93,7 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
|
||||
# and complicated to make constant-time
|
||||
# See at the bottom.
|
||||
#
|
||||
# Furthermore compilers try to substitute division
|
||||
# Furthermore compilers may try to substitute division
|
||||
# with a fast path that may have branches. It might also
|
||||
# be the same at the hardware level.
|
||||
|
||||
@ -65,18 +105,68 @@ func unsafe_div2n1n*(q, r: var Ct[uint64], n_hi, n_lo, d: Ct[uint64]) {.inline.}
|
||||
asm_x86_64_div2n1n(T(q), T(r), T(n_hi), T(n_lo), T(d))
|
||||
|
||||
when isMainModule:
|
||||
block: # Multiplication
|
||||
var hi, lo: uint64
|
||||
|
||||
var q, r: uint64
|
||||
asm_x86_64_extMul(hi, lo, 1 shl 32, 1 shl 33) # 2^65
|
||||
doAssert hi == 2
|
||||
doAssert lo == 0
|
||||
|
||||
# (1 shl 64) div 3
|
||||
let n_hi = 1'u64
|
||||
let n_lo = 0'u64
|
||||
let d = 3'u64
|
||||
block: # Division
|
||||
var q, r: uint64
|
||||
|
||||
asm_x86_64_div2n1n(q, r, n_hi, n_lo, d)
|
||||
# (1 shl 64) div 3
|
||||
let n_hi = 1'u64
|
||||
let n_lo = 0'u64
|
||||
let d = 3'u64
|
||||
|
||||
doAssert q == 6148914691236517205'u64
|
||||
doAssert r == 1
|
||||
asm_x86_64_div2n1n(q, r, n_hi, n_lo, d)
|
||||
|
||||
doAssert q == 6148914691236517205'u64
|
||||
doAssert r == 1
|
||||
|
||||
# ##############################################################
|
||||
#
|
||||
# Non-constant-time portable extended precision multiplication
|
||||
#
|
||||
# ##############################################################
|
||||
|
||||
# implementation from Stint: https://github.com/status-im/nim-stint/blob/edb1ade37309390cc641cee07ab62e5459d9ca44/stint/private/uint_mul.nim#L91-L135
|
||||
|
||||
# template extPrecMulImpl(result: var UintImpl[uint64], op: untyped, u, v: uint64) =
|
||||
# const
|
||||
# p = 64 div 2
|
||||
# base: uint64 = 1 shl p
|
||||
#
|
||||
# var
|
||||
# x0, x1, x2, x3: uint64
|
||||
#
|
||||
# let
|
||||
# ul = lo(u)
|
||||
# uh = hi(u)
|
||||
# vl = lo(v)
|
||||
# vh = hi(v)
|
||||
#
|
||||
# x0 = ul * vl
|
||||
# x1 = ul * vh
|
||||
# x2 = uh * vl
|
||||
# x3 = uh * vh
|
||||
#
|
||||
# x1 += hi(x0) # This can't carry
|
||||
# x1 += x2 # but this can
|
||||
# if x1 < x2: # if carry, add it to x3
|
||||
# x3 += base
|
||||
#
|
||||
# op(result.hi, x3 + hi(x1))
|
||||
# op(result.lo, (x1 shl p) or lo(x0))
|
||||
#
|
||||
# func extPrecMul*(result: var UintImpl[uint64], u, v: uint64) =
|
||||
# ## Extended precision multiplication
|
||||
# extPrecMulImpl(result, `=`, u, v)
|
||||
#
|
||||
# func extPrecAddMul(result: var UintImpl[uint64], u, v: uint64) =
|
||||
# ## Extended precision fused in-place addition & multiplication
|
||||
# extPrecMulImpl(result, `+=`, u, v)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
|
Loading…
x
Reference in New Issue
Block a user