Windowed GLV acceleration - 25% faster signing on G1 (#74)

* Fix 8x bigger than necessary encoding size of miniscalars in scalar mul

* initial windowed GLV-SAC implementation

* Simplify table encoding to match k0 without flipping bits
This commit is contained in:
Mamy Ratsimbazafy 2020-08-25 00:02:30 +02:00 committed by GitHub
parent 00ff599106
commit 6ac974d65e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 285 additions and 13 deletions

View File

@ -59,9 +59,13 @@ proc main() =
separator()
scalarMulGenericBench(ECP_SWei_Proj[Fp[curve]], scratchSpaceSize = 1 shl 4, MulIters)
separator()
scalarMulGenericBench(ECP_SWei_Proj[Fp[curve]], scratchSpaceSize = 1 shl 5, MulIters)
separator()
scalarMulEndo(ECP_SWei_Proj[Fp[curve]], MulIters)
separator()
separator()
scalarMulEndoWindow(ECP_SWei_Proj[Fp[curve]], MulIters)
separator()
separator()
main()
notes()

View File

@ -172,6 +172,22 @@ proc scalarMulEndo*(T: typedesc, iters: int) =
else:
{.error: "Not implemented".}
proc scalarMulEndoWindow*(T: typedesc, iters: int) =
  ## Benchmark `scalarMulGLV_m2w2` (endomorphism acceleration, window of 2)
  ## on a random point of type `T`, using a random scalar as wide as
  ## the curve order, over `iters` iterations.
  ##
  ## Only point types over `Fp` (reported as "G1") are supported;
  ## any other field stops compilation with "Not implemented".
  const
    orderBits = T.F.C.getCurveOrderBitwidth()
    groupName = when T.F is Fp: "G1" else: "G2"
    benchName = "EC ScalarMul Window-2 " & groupName & " (endomorphism accelerated)"

  # The point is drawn before the scalar, in that order.
  let
    basePoint = rng.random_unsafe(T) # TODO: clear cofactor
    k = rng.random_unsafe(BigInt[orderBits])
  var acc {.noInit.}: T

  bench(benchName, T, iters):
    # Restart from the same input point on every iteration
    acc = basePoint
    when T.F is Fp:
      acc.scalarMulGLV_m2w2(k)
    else:
      {.error: "Not implemented".}
proc scalarMulUnsafeDoubleAddBench*(T: typedesc, iters: int) =
const bits = T.F.C.getCurveOrderBitwidth()
const G1_or_G2 = when T.F is Fp: "G1" else: "G2"

View File

@ -78,7 +78,7 @@ proc `[]`(recoding: Recoded,
## 0 <= digitIdx < LengthInDigits
## returns digit ∈ {0, 1}
const len = Recoded.LengthInDigits
assert digitIdx < len
# assert digitIdx * BitSize < len
let slot = distinctBase(recoding)[
(len-1 - digitIdx) shr Shift
@ -92,7 +92,7 @@ proc `[]=`(recoding: var Recoded,
## returns digit ∈ {0, 1}
## This is write-once
const len = Recoded.LengthInDigits
assert digitIdx < Recoded.LengthInDigits
# assert digitIdx * BitSize < Recoded.LengthInDigits
let slot = distinctBase(recoding)[
(len-1 - digitIdx) shr Shift
@ -185,15 +185,10 @@ func buildLookupTable[M: static int, F](
# The recoding allows usage of 2^(n-1) table instead of the usual 2^n with NAF
let msb = u.log2() # No undefined, u != 0
lut[u].sum(lut[u.clearBit(msb)], endomorphisms[msb])
# } # highlight bug, ...
func tableIndex(glv: GLV_SAC, bit: int): SecretWord =
## Compose the secret table index from
## the GLV-SAC representation and the "bit" accessed
##
## For every mini-scalar `i` visited by the loop below, the lowest bit
## of its recoded digit at position `bit` is placed at bit `i-1`
## of the returned index.
## Mini-scalar 0 never contributes to the index: the loop starts at 1.
# TODO:
# We are currently storing 2-bit for 0, 1, -1 in the GLV-SAC representation
# but since columns have all the same sign, determined by k0,
# we only need 0 and 1 dividing storage per 2
# NOTE(review): this TODO may be stale, the digit accessors in this file
# document digits as being in {0, 1} - confirm and drop it if so.
#
# `result` starts out zeroed, each iteration ORs a single bit into it.
# Presumably `staticFor` unrolls `i` over 1 ..< GLV_SAC.M - confirm against its definition.
staticFor i, 1, GLV_SAC.M:
result = result or SecretWord((glv[i][bit] and 1) shl (i-1))
@ -256,7 +251,7 @@ func scalarMulGLV*[scalBits](
let k0isOdd = miniScalars[0].isOdd()
discard miniScalars[0].cadd(SecretWord(1), not k0isOdd)
var recoded: GLV_SAC[2, L] # zero-init required
var recoded: GLV_SAC[M, L] # zero-init required
recoded.nDimMultiScalarRecoding(miniScalars)
# 6. Proceed to GLV accelerated scalar multiplication
@ -274,6 +269,178 @@ func scalarMulGLV*[scalBits](
P.diff(Q, lut[0]) # Contains Q - P0
P.ccopy(Q, k0isOdd)
# Windowed GLV
# ----------------------------------------------------------------
# Config
# - 2 dimensional decomposition
# - Window of size 2
# -> precomputation 2^((2*2)-1) = 8
# Encoding explanation:
# - Coding is in big endian
# digits are grouped 2-by-2
# - k0 column has the following sign and encoding
# - `paper` -> `impl` is `value`
# with ternary encoding from the paper and 𝟙 denoting -1
# - 0t1𝟙 -> 0b01 is 1
# - 0t11 -> 0b00 is 3
# - 0t𝟙1 -> 0b10 is -1
# - 0t𝟙𝟙 -> 0b11 is -3
# - if k0 == 1 (0t1𝟙 - 0b01) or -1 (0t𝟙1 - 0b10):
# then kn is encoded with
# (signed opposite 2-complement)
# - 0t00 -> 0b00 is 0
# - 0t0𝟙 -> 0b01 is -1
# - 0t10 -> 0b10 is 2
# - 0t1𝟙 -> 0b11 is 1
# if k0 == 3 (0b00) or -3 (0b11):
# then kn is encoded with
# (unsigned integer)
# - 0t00 -> 0b00 is 0
# - 0t01 -> 0b01 is 1
# - 0t10 -> 0b10 is 2
# - 0t11 -> 0b11 is 3
func buildLookupTable_m2w2[F](
       P0: ECP_SWei_Proj[F],
       P1: ECP_SWei_Proj[F],
       lut: var array[8, ECP_SWei_Proj[F]],
     ) =
  ## Fill the 8-entry lookup table used by the GLV scalar multiplication
  ## with a 2-dimensional decomposition and a window of size 2.
  ##
  ## The table index is `0b_p_k1k1`:
  ## - the top bit `p` selects the magnitude of the k0 window,
  ##   1 for |k0| = 1 and 0 for |k0| = 3
  ## - the low 2 bits are the raw 2-bit window of k1
  ## Entries hold the combination for a positive k0 window.

  # Upper half: |k0| = 1, k1 window read as "signed opposite 2-complement"
  lut[0b100] = P0                   # [0b01, 0b00] ≡ P0
  lut[0b101].diff(P0, P1)           # [0b01, 0b01] ≡ P0 - P1
  lut[0b111].sum(P0, P1)            # [0b01, 0b11] ≡ P0 + P1
  lut[0b110].sum(lut[0b111], P1)    # [0b01, 0b10] ≡ P0 + 2P1

  # Lower half: |k0| = 3, k1 window read as an unsigned integer
  lut[0b000].double(P0)             # 2P0 ...
  lut[0b000] += P0                  # [0b00, 0b00] ≡ 3P0
  for k1 in 1 .. 3:
    # [0b00, k1] ≡ 3P0 + k1*P1, each entry built from the previous one
    lut[k1].sum(lut[k1 - 1], P1)
func w2Get(recoding: Recoded,
           digitIdx: int): uint8 {.inline.} =
  ## Read window number `digitIdx` of a recoded mini-scalar,
  ## for windows of size 2.
  ##
  ## Window `digitIdx` covers the 1-bit digits `2*digitIdx` and
  ## `2*digitIdx + 1`.
  ## Returns the raw 2-bit pattern, i.e. a value in 0 .. 3.
  const
    windowBits = 2
    # Windows per byte minus one, selects the window offset inside a byte
    inByteMask = sizeof(byte) * 8 div 2 - 1
    # Once shifted into place, keeps the 2 bits of the window
    windowMask = 1 shl windowBits - 1
    totalDigits = Recoded.LengthInDigits

  # assert digitIdx * windowBits < totalDigits

  # Storage slot that holds digit `2*digitIdx`.
  # NOTE(review): `Shift` is a constant defined outside this function,
  # presumably log2 of the number of 1-bit digits per slot - confirm.
  let packed = distinctBase(recoding)[
    (totalDigits - 1 - 2*digitIdx) shr Shift
  ]
  let offset = windowBits * (digitIdx and inByteMask)
  result = packed shr offset and windowMask
func w2TableIndex(glv: GLV_SAC, bit2: int, isNeg: var SecretBool): SecretWord {.inline.} =
  ## Build the secret lookup-table index for window `bit2`
  ## of a GLV-SAC recoding with windows of size 2.
  ##
  ## - `isNeg` receives the high bit of the k0 window
  ## - the returned index is `(p shl 2) or k1` with `p` the xor of the
  ##   two bits of the k0 window and `k1` the raw 2-bit window of k1
  let
    signWindow = glv[0].w2Get(bit2)
    k1Window = glv[1].w2Get(bit2)
    highBit = signWindow shr 1

  # assert signWindow < 4 and k1Window < 4

  isNeg = SecretBool(highBit)
  # 1 when the two bits of the k0 window differ, 0 when they are equal
  let bitsDiffer = highBit xor (signWindow and 1)
  result = SecretWord((bitsDiffer shl 2) or k1Window)
func computeRecodedLength(bitWidth, window: int): int =
  ## Length reserved for a recoded scalar of `bitWidth` bits
  ## processed with windows of size `window`.
  ##
  ## Computes lw = ⌈bitWidth / window⌉ + 1 and returns `lw + (lw mod window)`.
  ## With `window = 2` the result is `lw` rounded up to an even number.
  #
  # Strangely, in the paper lw = ⌈log2 r/w⌉+1 does not depend on
  # "m", the GLV decomposition dimension.
  let windowCount = (bitWidth + window - 1) div window # ceiling division
  let lw = windowCount + 1
  lw + lw mod window
func scalarMulGLV_m2w2*[scalBits](
       P0: var ECP_SWei_Proj,
       scalar: BigInt[scalBits]
     ) =
  ## Elliptic Curve Scalar Multiplication
  ##
  ##   P0 <- [scalar] P0
  ##
  ## Accelerated by an endomorphism via the
  ## GLV (Gallant-Lambert-Vanstone) decomposition,
  ## specialized for a 2-dimensional decomposition and a window of size 2.
  ##
  ## `scalBits` must equal the bitwidth of the curve order (checked at
  ## compile-time). Only BN254_Snarks and BLS12_381 are supported,
  ## other curves are rejected at compile-time.
  const C = P0.F.C # curve
  static: doAssert scalBits == C.getCurveOrderBitwidth()

  # 1. Second base point: P0 with its x coordinate multiplied
  #    by the curve's cubic root of unity mod p
  var P1 = P0
  P1.x *= C.getCubicRootOfUnity_mod_p()

  # 2. Split the scalar into 2 mini-scalars
  const recLen = computeRecodedLength(C.getCurveOrderBitwidth(), 2)
  var k {.noInit.}: array[2, BigInt[recLen]]
  when C == BN254_Snarks:
    decomposeScalar_BN254_Snarks_G1(scalar, k)
  elif C == BLS12_381:
    decomposeScalar_BLS12_381_G1(scalar, k)
  else:
    {.error: "Unsupported curve for GLV acceleration".}

  # 3. TODO: handle negative mini-scalars
  #    Either negate the associated base and the scalar
  #    Or use Algorithm 3 from Faz et al which can encode the sign
  #    in the GLV representation at the price of 1 bit

  # 4. Precomputed combinations of P0 and P1
  var table {.noInit.}: array[8, ECP_SWei_Proj]
  buildLookupTable_m2w2(P0, P1, table)
  # TODO: Montgomery simultaneous inversion (or other simultaneous inversion techniques)
  #       so that mixed addition formulas can be used in the main loop

  # 5. Recoding
  #    The base mini-scalar (the one carrying the signs) must be odd.
  #    It is made odd in constant-time to protect
  #    the secret least-significant bit.
  let k0WasOdd = k[0].isOdd()
  discard k[0].cadd(SecretWord(1), not k0WasOdd)

  var digits: GLV_SAC[2, recLen] # zero-init required
  digits.nDimMultiScalarRecoding(k)

  # 6. Main loop, from the most significant window down to window 0
  const topWindow = (recLen div 2) - 1
  var
    acc {.noInit.}: typeof(P0)
    term {.noInit.}: typeof(acc)
    negate: SecretBool

  # The sign reported for the top window is not applied.
  acc.secretLookup(table, digits.w2TableIndex(topWindow, negate))
  for w in countdown(topWindow - 1, 0):
    # acc <- 4 acc, one doubling per bit of the window
    for _ in 0 ..< 2:
      acc.double()
    term.secretLookup(table, digits.w2TableIndex(w, negate))
    term.cneg(negate)
    acc += term

  # 7. Correction when the base mini-scalar had to be made odd:
  #    P0 first receives acc - P0, then the ccopy with `k0WasOdd`
  #    puts back acc for the case where no adjustment was made.
  P0.diff(acc, P0)
  P0.ccopy(acc, k0WasOdd)
# Sanity checks
# ----------------------------------------------------------------
# See page 7 of
@ -297,7 +464,7 @@ when isMainModule:
of 1: "1"
else:
raise newException(ValueError, "Unexpected encoded value: " & $glvSac[j][i])
) # " # Unbreak VSCode highlighting bug
)
result.add " ]\n"
@ -355,6 +522,7 @@ when isMainModule:
kRecoded.nDimMultiScalarRecoding(k)
echo "Recoded bytesize: ", sizeof(kRecoded)
echo kRecoded.toString()
var lut: array[1 shl (M-1), string]
@ -422,7 +590,6 @@ when isMainModule:
main_decomp()
echo "---------------------------------------------"
# This tests the multiplication against the Table 1
@ -506,3 +673,80 @@ when isMainModule:
echo Q
mainFullMul()
echo "---------------------------------------------"
func buildLookupTable_m2w2(
       lut: var array[8, array[2, int]],
     ) =
  ## Build the integer model of the lookup table for GLV with
  ## a 2-dimensional decomposition and a window of size 2.
  ##
  ## Entry `lut[i]` is `[c0, c1]`, the coefficients of `c0*P0 + c1*P1`.
  ##
  ## The layout follows the index produced by `w2TableIndex`,
  ## `(p shl 2) or k1`, where `p` is the xor of the two bits of the
  ## k0 window:
  ## - p = 1 (k0 window 0b01 or 0b10) means |k0| = 1 -> entries 4..7
  ## - p = 0 (k0 window 0b00 or 0b11) means |k0| = 3 -> entries 0..3
  ## This is the same layout as the elliptic curve table built by
  ## the `buildLookupTable_m2w2` overload working on curve points.

  # with [k0, k1] the mini-scalars with digits of size 2-bit
  #
  # 4 = 0b100 - encodes [0b01, 0b00] ≡ P0
  lut[4] = [1, 0]
  # 5 = 0b101 - encodes [0b01, 0b01] ≡ P0 - P1
  lut[5] = [1, -1]
  # 7 = 0b111 - encodes [0b01, 0b11] ≡ P0 + P1
  lut[7] = [1, 1]
  # 6 = 0b110 - encodes [0b01, 0b10] ≡ P0 + 2P1
  lut[6] = [1, 2]

  # 0 = 0b000 - encodes [0b00, 0b00] ≡ 3P0
  lut[0] = [3, 0]
  # 1 = 0b001 - encodes [0b00, 0b01] ≡ 3P0 + P1
  lut[1] = [3, 1]
  # 2 = 0b010 - encodes [0b00, 0b10] ≡ 3P0 + 2P1
  lut[2] = [3, 2]
  # 3 = 0b011 - encodes [0b00, 0b11] ≡ 3P0 + 3P1
  lut[3] = [3, 3]
proc mainFullMulWindowed() =
  ## Sanity check of the windowed (size 2) GLV-SAC evaluation on integers:
  ## recodes the multi-scalar (11, 14), then replays the
  ## "multiply by 4 and add a table entry" loop on coefficient pairs,
  ## printing every intermediate value.
  const
    M = 2              # 2-dimensional decomposition
    miniBitwidth = 8   # Bitwidth of the miniscalars resulting from scalar decomposition
    W = 2              # Window
    L = computeRecodedLength(miniBitwidth, W)
    topWindow = (L div 2) - 1

  var k: MultiScalar[M, L]
  k[0].fromUint(11)
  k[1].fromUint(14)

  var kRecoded: GLV_SAC[M, L]
  kRecoded.nDimMultiScalarRecoding(k)

  echo "Recoded bytesize: ", sizeof(kRecoded)
  echo kRecoded.toString()

  var lut: array[8, array[range[P0..P1], int]]
  buildLookupTable_m2w2(lut)
  echo lut

  # k[0] is assumed odd to keep the test simple:
  # no conditional subtraction is needed at the end
  assert bool k[0].isOdd()

  var acc: array[Endo, int]
  var isNeg: SecretBool

  # Most significant window: taken as-is from the table,
  # applying the sign is unneeded by construction
  let topIdx = kRecoded.w2TableIndex(topWindow, isNeg)
  for p, coef in lut[int(topIdx)]:
    acc[p] = coef

  # Remaining windows, most significant first
  for i in countdown(topWindow - 1, 0):
    # Q = 4Q
    for e in acc.mitems:
      e = e * 4
    echo "4Q: ", acc

    # Q = Q + sign_l-1 P[K_l-1]
    let idx = kRecoded.w2TableIndex(i, isNeg)
    let sign = if bool(isNeg): -1 else: 1
    for p, coef in lut[int(idx)]:
      acc[p] += sign * coef
    echo "Q + sign_l-1 P[K_l-1]: ", acc

  echo acc
mainFullMulWindowed()

View File

@ -105,7 +105,9 @@ func decomposeScalar_BLS12_381_G1*[M, scalBits, L: static int](
## - needs a Lattice type
## - needs to better support negative bigints, (extra bit for sign?)
static: doAssert L == (scalBits + M - 1) div M + 1
# Equal when no window, greater otherwise
static: doAssert L >= (scalBits + M - 1) div M + 1
# 𝛼0 = (0x2d91d232ec7e0b3d7 * s) >> 256
# 𝛼1 = (0x24ccef014a773d2d25398fd0300ff6565 * s) >> 256
const

View File

@ -57,7 +57,7 @@ template checkScalarMulScratchspaceLen(len: int) =
func getWindowLen(bufLen: int): uint =
  ## Compute the maximum window size that fits in the scratchspace buffer.
  ##
  ## Starts from a window of 5 and shrinks it until
  ## `2^window + 1` no longer exceeds `bufLen`.
  checkScalarMulScratchspaceLen(bufLen)
  const largestWindow = 5'u
  result = largestWindow
  while bufLen < (1 shl result) + 1:
    dec result

View File

@ -44,15 +44,18 @@ proc test(
impl = P
reference = P
endo = P
endoW = P
scratchSpace: array[1 shl 4, EC]
impl.scalarMulGeneric(exponentCanonical, scratchSpace)
reference.unsafe_ECmul_double_add(exponentCanonical)
endo.scalarMulGLV(exponent)
endoW.scalarMulGLV_m2w2(exponent)
doAssert: bool(Q == reference)
doAssert: bool(Q == impl)
doAssert: bool(Q == endo)
doAssert: bool(Q == endoW)
suite "Scalar Multiplication (cofactor cleared): BLS12_381 implementation vs SageMath" & " [" & $WordBitwidth & "-bit mode]":
# Generated via sage sage/testgen_bls12_381.sage

View File

@ -44,15 +44,18 @@ proc test(
impl = P
reference = P
endo = P
endoW = P
scratchSpace: array[1 shl 4, EC]
impl.scalarMulGeneric(exponentCanonical, scratchSpace)
reference.unsafe_ECmul_double_add(exponentCanonical)
endo.scalarMulGLV(exponent)
endoW.scalarMulGLV_m2w2(exponent)
doAssert: bool(Q == reference)
doAssert: bool(Q == impl)
doAssert: bool(Q == endo)
doAssert: bool(Q == endoW)
suite "Scalar Multiplication G1: BN254 implementation vs SageMath" & " [" & $WordBitwidth & "-bit mode]":
# Generated via sage sage/testgen_bn254_snarks.sage