[Research] KZG polynomial commitment - part 1 FFT (#151)
* FFT compiles, now on to debugging ... [skip CI] * Fix FFT and add bench [skip ci] * rename + add KZG resources * rename fft_fr * Implement FFT on elliptic curves =) * FFT G1 bench
This commit is contained in:
parent
94419db783
commit
54887b1777
|
@ -23,13 +23,13 @@ export zoo_inversions
|
|||
{.push raises: [].}
|
||||
{.push inline.}
|
||||
|
||||
func inv_euclid*(r: var Fp, a: Fp) =
|
||||
func inv_euclid*(r: var FF, a: FF) =
|
||||
## Inversion modulo p via
|
||||
## Niels Moller constant-time version of
|
||||
## Stein's GCD derived from extended binary Euclid algorithm
|
||||
r.mres.steinsGCD(a.mres, Fp.getR2modP(), Fp.C.Mod, Fp.getPrimePlus1div2())
|
||||
r.mres.steinsGCD(a.mres, FF.getR2modP(), FF.fieldMod(), FF.getPrimePlus1div2())
|
||||
|
||||
func inv*(r: var Fp, a: Fp) =
|
||||
func inv*(r: var FF, a: FF) =
|
||||
## Inversion modulo p
|
||||
##
|
||||
## The inverse of 0 is 0.
|
||||
|
@ -40,22 +40,19 @@ func inv*(r: var Fp, a: Fp) =
|
|||
# neither for Secp256k1 nor BN curves
|
||||
# Performance is slower than GCD
|
||||
# To be revisited with faster squaring/multiplications
|
||||
when Fp.C.hasInversionAddchain():
|
||||
when FF is Fp and FF.C.hasInversionAddchain():
|
||||
r.inv_addchain(a)
|
||||
else:
|
||||
r.inv_euclid(a)
|
||||
|
||||
func inv*(a: var Fp) =
|
||||
func inv*(a: var FF) =
|
||||
## Inversion modulo p
|
||||
##
|
||||
## The inverse of 0 is 0.
|
||||
## Incidentally this avoids extra check
|
||||
## to convert Jacobian and Projective coordinates
|
||||
## to affine for elliptic curve
|
||||
when Fp.C.hasInversionAddchain():
|
||||
a.inv_addchain(a)
|
||||
else:
|
||||
a.inv_euclid(a)
|
||||
a.inv(a)
|
||||
|
||||
{.pop.} # inline
|
||||
{.pop.} # raises no exceptions
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# Research
|
||||
|
||||
This folder holds experiments before they are productionized into the library.
|
||||
|
||||
- `kzg`: KZG Polynomial Commitments\
|
||||
Constant-Size Commitments to Polynomials and Their Applications\
|
||||
Aniket Kate, Gregory M. Zaverucha, Ian Goldberg, 2010\
|
||||
https://www.iacr.org/archive/asiacrypt2010/6477178/6477178.pdf
|
|
@ -0,0 +1,12 @@
|
|||
# KZG Polynomial Commitment research
|
||||
|
||||
Research for Ethereum 2.0 phase 1
|
||||
to implement the Data Availability Sampling protocol
|
||||
|
||||
See
|
||||
|
||||
- https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html
|
||||
- https://github.com/protolambda/go-kate
|
||||
- FK20: https://github.com/khovratovich/Kate/blob/master/Kate_amortized.pdf
|
||||
- https://github.com/ethereum/research/tree/master/polynomial_reconstruction
|
||||
- https://github.com/ethereum/research/tree/master/kzg_data_availability
|
|
@ -0,0 +1,280 @@
|
|||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import
|
||||
../../constantine/config/curves,
|
||||
../../constantine/[arithmetic, primitives],
|
||||
../../constantine/io/io_fields,
|
||||
# Research
|
||||
./strided_views,
|
||||
./fft_lut
|
||||
|
||||
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
|
||||
# Quirks of the Python impl:
|
||||
# - no tests of FFT alone?
|
||||
# - a lot of "if type(x) == tuple else"
|
||||
#
|
||||
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
|
||||
# The go port uses stride+offset to deal with skip iterator.
|
||||
#
|
||||
# Other readable FFT implementations include:
|
||||
# - https://github.com/kwantam/fffft
|
||||
# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go
|
||||
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
|
||||
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
|
||||
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
|
||||
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Finite-Field Fast Fourier Transform
|
||||
#
|
||||
# ############################################################
|
||||
#
|
||||
# This is a research, unoptimized implementation of
|
||||
# Finite Field Fast Fourier Transform
|
||||
|
||||
# In research phase we tolerate using
|
||||
# - garbage collected types
|
||||
# - and exceptions for fast prototyping
|
||||
#
|
||||
# In particular, in production all signed integers
|
||||
# must be verified not to overflow
|
||||
# and should not throw (or use unsigned)
|
||||
|
||||
# FFT Context
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
type
  # Outcome reported by `fft` and `ifft`.
  FFTStatus = enum
    FFTS_Success
    FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
    FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"

  FFTDescriptor[F] = object
    ## Precomputed parameters for running FFTs over the field F
    maxWidth: int
      ## Largest input length accepted by `fft` and `ifft`
    rootOfUnity: F
      ## Generator from which all the other roots are derived
    expandedRootsOfUnity: seq[F]
      ## Successive powers of the generator; first and last entries are 1
|
||||
|
||||
func isPowerOf2(n: SomeUnsignedInt): bool =
  ## Returns true when `n` has exactly one bit set.
  ##
  ## Zero is rejected explicitly: `0 and (0 - 1)` is also 0, so the bit
  ## trick alone would report 0 as a power of two.
  n != 0 and (n and (n - 1)) == 0
|
||||
|
||||
func nextPowerOf2(n: uint64): uint64 =
  ## Returns `n` if it is a power of 2, otherwise the next bigger power of 2.
  ##
  ## 0 and 1 both map to 1. When the answer does not fit in 64 bits
  ## (n > 2^63) the unsigned addition wraps around and 0 is returned.
  if n <= 1:
    # n - 1 would underflow (n = 0) or be 0 (n = 1), which has no logarithm
    return 1
  # Copy the highest set bit of n-1 into every lower position ...
  var v = n - 1
  v = v or (v shr 1)
  v = v or (v shr 2)
  v = v or (v shr 4)
  v = v or (v shr 8)
  v = v or (v shr 16)
  v = v or (v shr 32)
  # ... so that adding 1 carries into the next power of two.
  result = v + 1
|
||||
|
||||
func expandRootOfUnity[F](rootOfUnity: F): seq[F] =
  ## Lists the successive powers of `rootOfUnity`, starting at 1 and
  ## stopping as soon as a power is 1 again, so the returned sequence
  ## both begins and ends with 1 (the trailing 1 serves the reversed view).
  # For a field of order q there are gcd(n, q−1) nth roots of unity,
  # i.e. solutions to xⁿ ≡ 1 (mod q). Computing that bigint GCD up front
  # is likely too slow, so the sequence is simply grown on the heap.
  # Right-sizing the buffer is left for the production version.
  var one {.noInit.}: F
  one.setOne()
  result = @[one, rootOfUnity]

  while not bool(result[^1].isOne()):
    var next {.noInit.}: F
    next.prod(result[^1], rootOfUnity)
    result.add next
|
||||
|
||||
# FFT Algorithm
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
# TODO: research Decimation in Time and Decimation in Frequency
|
||||
# and FFT butterflies
|
||||
|
||||
func simpleFT[F](
       output: var View[F],
       vals: View[F],
       rootsOfUnity: View[F]
     ) =
  ## Quadratic-time transform used as the base case of the recursive FFT:
  ## output[i] = Σ_j vals[j] * rootsOfUnity[(i*j) mod len]
  let n = output.len

  for row in 0 ..< n:
    var acc {.noInit.}, term {.noInit.}: F
    # j = 0 term
    acc.prod(vals[0], rootsOfUnity[0])
    for col in 1 ..< n:
      term.prod(vals[col], rootsOfUnity[(row*col) mod n])
      acc += term
    output[row] = acc
|
||||
|
||||
func fft_internal[F](
       output: var View[F],
       vals: View[F],
       rootsOfUnity: View[F]
     ) =
  ## Recursive FFT step: transforms `vals` into `output`.
  ## Inputs of length 4 or less are handled by `simpleFT`.
  if output.len <= 4:
    simpleFT(output, vals, rootsOfUnity)
  else:
    # Divide: even/odd inputs feed the lower/upper half of the output,
    # each using every other root.
    let (evens, odds) = vals.splitAlternate()
    var (lower, upper) = output.splitMiddle()
    let coarserRoots = rootsOfUnity.skipHalf()

    fft_internal(lower, evens, coarserRoots)
    fft_internal(upper, odds, coarserRoots)

    # Conquer: butterfly combining the two half-size results in place
    let mid = lower.len
    for k in 0 ..< mid:
      var twiddled {.noInit.}: F
      twiddled.prod(output[k+mid], rootsOfUnity[k])
      output[k+mid].diff(output[k], twiddled)
      output[k] += twiddled
|
||||
|
||||
func fft*[F](
       desc: FFTDescriptor[F],
       output: var openarray[F],
       vals: openarray[F]): FFT_Status =
  ## Forward FFT of `vals` into `output`.
  ##
  ## Returns
  ## - FFTS_TooManyValues when `vals` is longer than `desc.maxWidth`
  ## - FFTS_SizeNotPowerOfTwo when `vals` is empty or its length
  ##   is not a power of 2
  ## - FFTS_Success otherwise
  ##
  ## `output` must have the same length as `vals`; this is asserted
  ## because the transform indexes both buffers without bounds checks.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # An empty input passes the bit-trick power-of-two test and would
  # divide by zero when computing the root stride below.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2():
    return FFTS_SizeNotPowerOfTwo
  doAssert output.len == vals.len, "FFT output and input must have the same length"

  # Keep one root every (maxWidth div len): the len-th roots of unity
  let rootz = desc.expandedRootsOfUnity
                  .toView()
                  .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)

  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)
  return FFTS_Success
|
||||
|
||||
func ifft*[F](
       desc: FFTDescriptor[F],
       output: var openarray[F],
       vals: openarray[F]): FFT_Status =
  ## Inverse FFT of `vals` into `output`.
  ##
  ## Same transform as `fft` but walking the roots of unity in reverse
  ## order, followed by a multiplication of every output by 1/len.
  ##
  ## Returns
  ## - FFTS_TooManyValues when `vals` is longer than `desc.maxWidth`
  ## - FFTS_SizeNotPowerOfTwo when `vals` is empty or its length
  ##   is not a power of 2
  ## - FFTS_Success otherwise
  ##
  ## `output` must have the same length as `vals`; this is asserted
  ## because the transform indexes both buffers without bounds checks.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # An empty input passes the bit-trick power-of-two test and would
  # divide by zero when computing the root stride below.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2():
    return FFTS_SizeNotPowerOfTwo
  doAssert output.len == vals.len, "FFT output and input must have the same length"

  let rootz = desc.expandedRootsOfUnity
                  .toView()
                  .reversed()
                  .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)

  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)

  # Normalize by the inverse of the input length
  var invLen {.noInit.}: F
  invLen.fromUint(vals.len.uint64)
  invLen.inv()

  for elem in output.mitems:
    elem *= invLen

  return FFTS_Success
|
||||
|
||||
# FFT Descriptor
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
  ## Builds a descriptor for inputs of up to 2^maxScale values.
  ## The generator is read from the `scaleToRootOfUnity` table of the
  ## field's curve at index `maxScale`, then expanded into all its powers.
  let generator = scaleToRootOfUnity(T.F.C)[maxScale]

  result.maxWidth = 1 shl maxScale
  result.rootOfUnity = generator
  result.expandedRootsOfUnity = expandRootOfUnity(generator)
|
||||
# Aren't you tired of reading about unity?
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Sanity checks
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
{.experimental: "views".}
|
||||
|
||||
when isMainModule:
|
||||
import
|
||||
std/[times, monotimes, strformat],
|
||||
../../helpers/prng_unsafe
|
||||
|
||||
proc roundtrip() =
  ## Sanity check: fft followed by ifft must give back the input.
  ## Uses the values 0 ..< maxWidth in Fr[BLS12_381] and quits with
  ## exit code 1 on the first mismatch.
  let fftDesc = FFTDescriptor[Fr[BLS12_381]].init(maxScale = 4)

  var data = newSeq[Fr[BLS12_381]](fftDesc.maxWidth)
  for i, elem in data.mpairs:
    elem.fromUint i.uint64

  var coefs = newSeq[Fr[BLS12_381]](data.len)
  let fftOk = fft(fftDesc, coefs, data)
  doAssert fftOk == FFTS_Success
  # display("coefs", 0, coefs)

  var res = newSeq[Fr[BLS12_381]](data.len)
  let ifftOk = ifft(fftDesc, res, coefs)
  doAssert ifftOk == FFTS_Success
  # display("res", 0, res)

  for i, got in res:
    if bool(got != data[i]):
      echo "Error: expected ", data[i].toHex(), " but got ", got.toHex()
      quit 1

  echo "FFT round-trip check SUCCESS"
|
||||
|
||||
proc warmup() =
  ## Burns CPU cycles before benchmarking so the CPU is at max perf,
  ## then prints the elapsed CPU time together with the computed value.
  let start = cpuTime()

  var foo = 123
  for i in 0 ..< 300_000_000:
    foo = (foo + (i*i) mod 456) mod 789

  # The result is printed so the compiler cannot drop the loop;
  # cpuTime itself relies on side effects.
  let stop = cpuTime()
  echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"
|
||||
|
||||
|
||||
proc bench() =
  ## Times `fft` over Fr[BLS12_381] for input sizes 2^4 up to 2^15
  ## and prints the average duration of one call for each size.
  echo "Starting benchmark ..."
  const NumIters = 100

  var rng: RngState
  rng.seed 0x1234
  # TODO: view types complain about mutable borrow
  # in `random_unsafe` due to the pseudo view type LimbsViewMut
  # (which predates proper view support in Nim),
  # hence the deterministic inputs below.

  warmup()

  for scale in 4 .. 15:
    # Setup
    let desc = FFTDescriptor[Fr[BLS12_381]].init(uint8 scale)
    var data = newSeq[Fr[BLS12_381]](desc.maxWidth)
    for i, elem in data.mpairs:
      # elem = rng.random_unsafe(elem.typeof())
      elem.fromUint i.uint64

    var coefsOut = newSeq[Fr[BLS12_381]](data.len)

    # Bench
    let start = getMonotime()
    for _ in 1 .. NumIters:
      let status = desc.fft(coefsOut, data)
      doAssert status == FFTS_Success
    let stop = getMonotime()

    let ns = inNanoseconds((stop-start) div NumIters)
    echo &"FFT scale {scale:>2} {ns:>8} ns/op"
|
||||
|
||||
roundtrip()
|
||||
warmup()
|
||||
bench()
|
|
@ -0,0 +1,302 @@
|
|||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import
|
||||
../../constantine/config/curves,
|
||||
../../constantine/[arithmetic, primitives],
|
||||
../../constantine/elliptic/[
|
||||
ec_endomorphism_accel,
|
||||
ec_shortweierstrass_affine,
|
||||
ec_shortweierstrass_projective,
|
||||
ec_shortweierstrass_jacobian,
|
||||
],
|
||||
../../constantine/io/[io_fields, io_ec],
|
||||
# Research
|
||||
./strided_views,
|
||||
./fft_lut
|
||||
|
||||
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
|
||||
# Quirks of the Python impl:
|
||||
# - no tests of FFT alone?
|
||||
# - a lot of "if type(x) == tuple else"
|
||||
#
|
||||
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
|
||||
# The go port uses stride+offset to deal with skip iterator.
|
||||
#
|
||||
# Other readable FFT implementations include:
|
||||
# - https://github.com/kwantam/fffft
|
||||
# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go
|
||||
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
|
||||
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
|
||||
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
|
||||
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Elliptic Curve Fast Fourier Transform
|
||||
#
|
||||
# ############################################################
|
||||
#
|
||||
# This is a research, unoptimized implementation of
|
||||
# the Fast Fourier Transform over elliptic curve points
|
||||
|
||||
# In research phase we tolerate using
|
||||
# - garbage collected types
|
||||
# - and exceptions for fast prototyping
|
||||
#
|
||||
# In particular, in production all signed integers
|
||||
# must be verified not to overflow
|
||||
# and should not throw (or use unsigned)
|
||||
|
||||
# FFT Context
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
type
  # Outcome reported by `fft` and `ifft`.
  FFTStatus = enum
    FFTS_Success
    FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
    FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"

  FFTDescriptor[EC] = object
    ## Precomputed parameters for running FFTs over elliptic curve points
    maxWidth: int
      ## Largest input length accepted by `fft` and `ifft`
    rootOfUnity: matchingOrderBigInt(EC.F.C)
      ## Generator from which all the other roots are derived
    expandedRootsOfUnity: seq[matchingOrderBigInt(EC.F.C)]
      ## Successive powers of the generator; first and last entries are 1
|
||||
|
||||
func isPowerOf2(n: SomeUnsignedInt): bool =
  ## Returns true when `n` has exactly one bit set.
  ##
  ## Zero is rejected explicitly: `0 and (0 - 1)` is also 0, so the bit
  ## trick alone would report 0 as a power of two.
  n != 0 and (n and (n - 1)) == 0
|
||||
|
||||
func nextPowerOf2(n: uint64): uint64 =
  ## Returns `n` if it is a power of 2, otherwise the next bigger power of 2.
  ##
  ## 0 and 1 both map to 1. When the answer does not fit in 64 bits
  ## (n > 2^63) the unsigned addition wraps around and 0 is returned.
  if n <= 1:
    # n - 1 would underflow (n = 0) or be 0 (n = 1), which has no logarithm
    return 1
  # Copy the highest set bit of n-1 into every lower position ...
  var v = n - 1
  v = v or (v shr 1)
  v = v or (v shr 2)
  v = v or (v shr 4)
  v = v or (v shr 8)
  v = v or (v shr 16)
  v = v or (v shr 32)
  # ... so that adding 1 carries into the next power of two.
  result = v + 1
|
||||
|
||||
func expandRootOfUnity[F](rootOfUnity: F): auto {.noInit.} =
  ## Lists the successive powers of `rootOfUnity` as big integers
  ## (converted with `toBig`), starting at 1 and stopping as soon as a
  ## power is 1 again, so the returned sequence begins and ends with 1
  ## (the trailing 1 serves the reversed view).
  # For a field of order q there are gcd(n, q−1) nth roots of unity,
  # i.e. solutions to xⁿ ≡ 1 (mod q). Computing that bigint GCD up front
  # is likely too slow, so the sequence is simply grown on the heap.
  # Right-sizing the buffer is left for the production version.
  var expanded: seq[matchingOrderBigInt(F.C)]
  expanded.setLen(2)
  expanded[0].setOne()
  expanded[1] = rootOfUnity.toBig()

  # The running power stays a field element; only the stored copy is converted
  var power = rootOfUnity
  while not bool(expanded[^1].isOne()):
    power *= rootOfUnity
    expanded.add power.toBig()

  return expanded
|
||||
|
||||
# FFT Algorithm
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
func simpleFT[EC; bits: static int](
       output: var View[EC],
       vals: View[EC],
       rootsOfUnity: View[BigInt[bits]]
     ) =
  ## Quadratic-time transform used as the base case of the recursive FFT:
  ## output[i] = Σ_j [rootsOfUnity[(i*j) mod len]] vals[j]
  ## where [k]P is the scalar multiplication of point P by k.
  let n = output.len

  for row in 0 ..< n:
    var acc {.noInit.}, term {.noInit.}: EC
    # j = 0 term
    acc = vals[0]
    acc.scalarMulGLV_m2w2(rootsOfUnity[0])
    for col in 1 ..< n:
      term = vals[col]
      term.scalarMulGLV_m2w2(rootsOfUnity[(row*col) mod n])
      acc += term
    output[row] = acc
|
||||
|
||||
func fft_internal[EC; bits: static int](
       output: var View[EC],
       vals: View[EC],
       rootsOfUnity: View[BigInt[bits]]
     ) =
  ## Recursive FFT step: transforms the points in `vals` into `output`.
  ## Inputs of length 4 or less are handled by `simpleFT`.
  if output.len <= 4:
    simpleFT(output, vals, rootsOfUnity)
  else:
    # Divide: even/odd inputs feed the lower/upper half of the output,
    # each using every other root.
    let (evens, odds) = vals.splitAlternate()
    var (lower, upper) = output.splitMiddle()
    let coarserRoots = rootsOfUnity.skipHalf()

    fft_internal(lower, evens, coarserRoots)
    fft_internal(upper, odds, coarserRoots)

    # Conquer: butterfly combining the two half-size results in place
    let mid = lower.len
    for k in 0 ..< mid:
      var twiddled {.noInit.}: EC
      twiddled = output[k+mid]
      twiddled.scalarMulGLV_m2w2(rootsOfUnity[k])
      output[k+mid].diff(output[k], twiddled)
      output[k] += twiddled
|
||||
|
||||
func fft*[EC](
       desc: FFTDescriptor[EC],
       output: var openarray[EC],
       vals: openarray[EC]): FFT_Status =
  ## Forward FFT of the elliptic curve points `vals` into `output`.
  ##
  ## Returns
  ## - FFTS_TooManyValues when `vals` is longer than `desc.maxWidth`
  ## - FFTS_SizeNotPowerOfTwo when `vals` is empty or its length
  ##   is not a power of 2
  ## - FFTS_Success otherwise
  ##
  ## `output` must have the same length as `vals`; this is asserted
  ## because the transform indexes both buffers without bounds checks.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # An empty input passes the bit-trick power-of-two test and would
  # divide by zero when computing the root stride below.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2():
    return FFTS_SizeNotPowerOfTwo
  doAssert output.len == vals.len, "FFT output and input must have the same length"

  # Keep one root every (maxWidth div len): the len-th roots of unity
  let rootz = desc.expandedRootsOfUnity
                  .toView()
                  .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)

  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)
  return FFTS_Success
|
||||
|
||||
func ifft*[EC](
       desc: FFTDescriptor[EC],
       output: var openarray[EC],
       vals: openarray[EC]): FFT_Status =
  ## Inverse FFT of the elliptic curve points `vals` into `output`.
  ##
  ## Same transform as `fft` but walking the roots of unity in reverse
  ## order, followed by a scalar multiplication of every output by 1/len.
  ##
  ## Returns
  ## - FFTS_TooManyValues when `vals` is longer than `desc.maxWidth`
  ## - FFTS_SizeNotPowerOfTwo when `vals` is empty or its length
  ##   is not a power of 2
  ## - FFTS_Success otherwise
  ##
  ## `output` must have the same length as `vals`; this is asserted
  ## because the transform indexes both buffers without bounds checks.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # An empty input passes the bit-trick power-of-two test and would
  # divide by zero when computing the root stride below.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2():
    return FFTS_SizeNotPowerOfTwo
  doAssert output.len == vals.len, "FFT output and input must have the same length"

  let rootz = desc.expandedRootsOfUnity
                  .toView()
                  .reversed()
                  .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)

  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)

  # Normalize by the inverse of the input length, computed in Fr.
  # The local is not called `inv` so it cannot shadow the `inv` routine.
  var invLen {.noInit.}: Fr[EC.F.C]
  invLen.fromUint(vals.len.uint64)
  invLen.inv()
  let invLenBig = invLen.toBig()

  for point in output.mitems:
    point.scalarMulGLV_m2w2(invLenBig)

  return FFTS_Success
|
||||
|
||||
# FFT Descriptor
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
  ## Builds a descriptor for inputs of up to 2^maxScale points.
  ## The generator is read from the `scaleToRootOfUnity` table of the
  ## curve at index `maxScale`; it is stored as a big integer together
  ## with the list of all its powers.
  let generator = scaleToRootOfUnity(T.EC.F.C)[maxScale]

  result.maxWidth = 1 shl maxScale
  result.expandedRootsOfUnity = expandRootOfUnity(generator)
  result.rootOfUnity = generator.toBig()
|
||||
# Aren't you tired of reading about unity?
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Sanity checks
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
{.experimental: "views".}
|
||||
|
||||
when isMainModule:
|
||||
import
|
||||
std/[times, monotimes, strformat],
|
||||
../../helpers/prng_unsafe
|
||||
|
||||
type G1 = ECP_ShortW_Prj[Fp[BLS12_381], NotOnTwist]
|
||||
var Generator1: ECP_ShortW_Aff[Fp[BLS12_381], NotOnTwist]
|
||||
doAssert Generator1.fromHex(
|
||||
"0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb",
|
||||
"0x08b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1"
|
||||
)
|
||||
|
||||
proc roundtrip() =
  ## Sanity check: fft followed by ifft must give back the input.
  ## The input is Generator1 repeatedly added to itself
  ## (data[i] = data[i-1] + Generator1); quits with exit code 1
  ## on the first mismatch.
  let fftDesc = FFTDescriptor[G1].init(maxScale = 4)

  var data = newSeq[G1](fftDesc.maxWidth)
  data[0].projectiveFromAffine(Generator1)
  for i in 1 .. data.high:
    data[i].madd(data[i-1], Generator1)

  var coefs = newSeq[G1](data.len)
  let fftOk = fft(fftDesc, coefs, data)
  doAssert fftOk == FFTS_Success
  # display("coefs", 0, coefs)

  var res = newSeq[G1](data.len)
  let ifftOk = ifft(fftDesc, res, coefs)
  doAssert ifftOk == FFTS_Success
  # display("res", 0, res)

  for i, got in res:
    if bool(got != data[i]):
      echo "Error: expected ", data[i].toHex(), " but got ", got.toHex()
      quit 1

  echo "FFT round-trip check SUCCESS"
|
||||
|
||||
proc warmup() =
  ## Burns CPU cycles before benchmarking so the CPU is at max perf,
  ## then prints the elapsed CPU time together with the computed value.
  let start = cpuTime()

  var foo = 123
  for i in 0 ..< 300_000_000:
    foo = (foo + (i*i) mod 456) mod 789

  # The result is printed so the compiler cannot drop the loop;
  # cpuTime itself relies on side effects.
  let stop = cpuTime()
  echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"
|
||||
|
||||
|
||||
proc bench() =
  ## Times `fft` over G1 points for input sizes 2^4 up to 2^15
  ## and prints the average duration of one call for each size.
  echo "Starting benchmark ..."
  const NumIters = 3

  var rng: RngState
  rng.seed 0x1234
  # TODO: view types complain about mutable borrow
  # in `random_unsafe` due to the pseudo view type LimbsViewMut
  # (which predates proper view support in Nim),
  # hence the deterministic inputs below.

  warmup()

  for scale in 4 .. 15:
    # Setup
    let desc = FFTDescriptor[G1].init(uint8 scale)
    var data = newSeq[G1](desc.maxWidth)
    data[0].projectiveFromAffine(Generator1)
    for i in 1 .. data.high:
      data[i].madd(data[i-1], Generator1)

    var coefsOut = newSeq[G1](data.len)

    # Bench
    let start = getMonotime()
    for _ in 1 .. NumIters:
      let status = desc.fft(coefsOut, data)
      doAssert status == FFTS_Success
    let stop = getMonotime()

    let ns = inNanoseconds((stop-start) div NumIters)
    echo &"FFT scale {scale:>2} {ns:>8} ns/op"
|
||||
|
||||
roundtrip()
|
||||
warmup()
|
||||
bench()
|
|
@ -0,0 +1,47 @@
|
|||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import
|
||||
std/macros,
|
||||
../../constantine/config/[curves, common],
|
||||
../../constantine/[arithmetic, primitives],
|
||||
../../constantine/io/io_fields
|
||||
|
||||
# TODO automate this
|
||||
# we can precompute everything in Sage
|
||||
# and auto-generate the file.
|
||||
|
||||
const BLS12_381_Fr_primitive_root = 5
|
||||
|
||||
func buildRootLUT(F: type Fr): array[32, F] =
  ## Table of roots of unity indexed by scale:
  ## [pow(PRIMITIVE_ROOT, (MODULUS - 1) // (2**i), MODULUS) for i in range(32)]
  let top = result.high

  # Only the last entry needs an exponentiation:
  # exponent = (order - 1) shifted right by the last index
  var exponent {.noInit.}: BigInt[F.C.getCurveOrderBitwidth()]
  exponent = F.C.getCurveOrder()
  exponent -= One
  exponent.shiftRight(top)

  result[top].fromUint(BLS12_381_Fr_primitive_root)
  result[top].powUnsafeExponent(exponent)

  # Every earlier entry is the square of its successor
  for i in countdown(top, 1):
    result[i-1].square(result[i])
|
||||
|
||||
# debugEcho "Fr[BLS12_381] - Roots of Unity:"
|
||||
# for i in 0 ..< result.len:
|
||||
# debugEcho " ", i, ": ", result[i].toHex()
|
||||
# debugEcho "Fr[BLS12_381] - Roots of Unity -- FIN\n"
|
||||
|
||||
let BLS12_381_Fr_ScaleToRootOfUnity* = buildRootLUT(Fr[BLS12_381])
|
||||
|
||||
{.experimental: "dynamicBindSym".}
|
||||
macro scaleToRootOfUnity*(C: static Curve): untyped =
  ## Resolves to the symbol named "<curve>_Fr_ScaleToRootOfUnity"
  ## for the given curve, e.g. BLS12_381_Fr_ScaleToRootOfUnity.
  let tableName = $C & "_Fr_ScaleToRootOfUnity"
  result = bindSym(tableName)
|
|
@ -0,0 +1,247 @@
|
|||
# Constantine
|
||||
# Copyright (c) 2018-2019 Status Research & Development GmbH
|
||||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
# Strided View - Monodimensional Tensors
|
||||
# ----------------------------------------------------------------
|
||||
#
|
||||
# FFT uses recursive divide-and-conquer.
|
||||
# In code this means need strided views
|
||||
# to enable different logical views of the same memory buffer.
|
||||
# Strided views are monodimensional tensors:
|
||||
# See Arraymancer backend:
|
||||
# https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/laser/tensor/datatypes.nim#L28-L32
|
||||
# Or the minimal tensor implementation challenge:
|
||||
# https://github.com/SimonDanisch/julia-challenge/blob/b8ed3b6/nim/nim_sol_mratsim.nim#L4-L26
|
||||
|
||||
{.experimental: "views".}
|
||||
|
||||
type
  View*[T] = object
    ## A strided view over an (unowned) data buffer.
    ## Logical element `i` lives at `data[offset + i*stride]`.
    len*: int
      ## Number of elements visible through the view
    stride: int
      ## Distance in the buffer between two consecutive elements
    offset: int
      ## Position in the buffer of the first element
    data: lent UncheckedArray[T]
      ## The borrowed buffer
|
||||
|
||||
func `[]`*[T](v: View[T], idx: int): lent T {.inline.} =
  ## Read access to logical element `idx` (no bounds checking).
  let pos = v.offset + idx*v.stride
  v.data[pos]
|
||||
|
||||
func `[]`*[T](v: var View[T], idx: int): var T {.inline.} =
  ## Mutable access to logical element `idx` (no bounds checking).
  # The buffer is stored as `lent`, hence the cast to get a mutable location.
  let pos = v.offset + idx*v.stride
  cast[ptr UncheckedArray[T]](v.data)[pos]
|
||||
|
||||
func `[]=`*[T](v: var View[T], idx: int, val: T) {.inline.} =
  ## Stores `val` in logical element `idx` (no bounds checking).
  # The buffer is stored as `lent`, hence the cast to get a mutable location.
  let pos = v.offset + idx*v.stride
  cast[ptr UncheckedArray[T]](v.data)[pos] = val
|
||||
|
||||
func toView*[T](oa: openArray[T]): View[T] {.inline.} =
  ## Wraps `oa` in a contiguous view (stride 1, offset 0) covering all
  ## of its elements. The view borrows the memory of `oa`.
  # Takes the address of the first element: `oa` is indexed at 0
  result.data = cast[lent UncheckedArray[T]](oa[0].unsafeAddr)
  result.offset = 0
  result.stride = 1
  result.len = oa.len
|
||||
|
||||
iterator items*[T](v: View[T]): lent T =
  ## Yields the `v.len` elements of the view in logical order.
  for i in 0 ..< v.len:
    yield v.data[v.offset + i*v.stride]
|
||||
|
||||
func `$`*(v: View): string =
  ## Renders the view as "View[a, b, c]" using `$` on each element.
  result = "View["
  var emitted = 0
  for elem in v:
    if emitted > 0:
      result.add ", "
    result.add $elem
    inc emitted
  result.add ']'
|
||||
|
||||
func toHex*(v: View): string =
  ## Renders the view as "View[a, b, c]" using `toHex` on each element.
  mixin toHex

  result = "View["
  var emitted = 0
  for elem in v:
    if emitted > 0:
      result.add ", "
    result.add elem.toHex()
    inc emitted
  result.add ']'
|
||||
|
||||
# FFT-specific splitting
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
func splitAlternate*(t: View): tuple[even, odd: View] {.inline.} =
  ## Splits the view in 2 by taking every other element:
  ##   even: indices [0, 2, 4, ...]
  ##   odd:  indices [1, 3, 5, ...]
  ## Both halves share the buffer of `t`.
  assert (t.len and 1) == 0, "The tensor must contain an even number of elements"

  let
    halfLen = t.len shr 1
    doubledStride = t.stride shl 1

  # Same buffer, same length and stride; only the starting point differs
  result.even.data = t.data
  result.even.offset = t.offset
  result.even.stride = doubledStride
  result.even.len = halfLen

  result.odd.data = t.data
  result.odd.offset = t.offset + t.stride
  result.odd.stride = doubledStride
  result.odd.len = halfLen
|
||||
|
||||
func splitMiddle*(t: View): tuple[left, right: View] {.inline.} =
  ## Split the tensor into 2
  ## partitioning into left and right halves, e.g. for 8 elements
  ##   left:  indices [0, 1, 2, 3]
  ##   right: indices [4, 5, 6, 7]
  ## Both halves share the buffer of `t` and keep its stride.
  assert (t.len and 1) == 0, "The tensor must contain an even number of elements"

  let half = t.len shr 1

  result.left.len = half
  result.left.stride = t.stride
  result.left.offset = t.offset
  result.left.data = t.data

  result.right.len = half
  result.right.stride = t.stride
  # Logical element `half` sits `half` strides after the first one.
  # Adding `half` alone is only right for contiguous (stride 1) views.
  result.right.offset = t.offset + half * t.stride
  result.right.data = t.data
|
||||
|
||||
func skipHalf*(t: View): View {.inline.} =
  ## Keeps every other element of the view
  ##   output: indices [0, 2, 4, ...]
  ## The result shares the buffer of `t`.
  assert (t.len and 1) == 0, "The tensor must contain an even number of elements"

  result.data = t.data
  result.offset = t.offset
  result.stride = t.stride shl 1
  result.len = t.len shr 1
|
||||
|
||||
func slice*(v: View, start, stop, step: int): View {.inline.} =
  ## Slice a view, taking one element every `step` from `start` to `stop`.
  ## `stop` is inclusive. A negative `step` walks the view backwards.
  ## The result shares the buffer of `v`.
  # Monodimensional specialization of the general tensor slicing algorithm
  # https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/tensor/private/p_accessors_macros_read.nim#L26-L56
  # which, per dimension, does
  #
  #   offset  += a * stride
  #   stride  *= step
  #   shape    = abs((b-a) div step) + 1
  #
  # with a and b the (inclusive) bounds of the slice.
  # Here every field of the result is computed from `v` directly.
  let span = stop - start

  result.data = v.data
  result.len = abs(span div step) + 1
  result.stride = v.stride * step
  result.offset = v.offset + start * v.stride
|
||||
|
||||
func reversed*(v: View): View {.inline.} =
  ## Returns a view over the same elements in reverse order.
  ## An empty view is returned unchanged.
  if v.len == 0:
    # slice(-1, 0, -1) would report 2 elements and start before the buffer
    return v
  # Hopefully the compiler optimizes div by -1
  result = v.slice(v.len-1, 0, -1)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Debugging helpers
|
||||
#
|
||||
# ############################################################
|
||||
import strformat, strutils
|
||||
|
||||
func display*[F](name: string, indent: int, oa: openArray[F]) =
  ## Debug helper: prints a header, each element of `oa` in hexadecimal
  ## with its index, and a footer, all indented by `indent` spaces.
  let header = name & ", openarray of " & $F & " of length " & $oa.len
  let footer = name & " " & $F & " -- FIN\n"

  debugEcho indent(header, indent)
  for i in 0 ..< oa.len:
    debugEcho indent(&" {i:>2}: {oa[i].toHex()}", indent)
  debugEcho indent(footer, indent)
|
||||
|
||||
func display*[F](name: string, indent: int, v: View[F]) =
  ## Debug helper: prints a header, each element of `v` in hexadecimal
  ## with its index, and a footer, all indented by `indent` spaces.
  let header = name & ", view of " & $F & " of length " & $v.len
  let footer = name & " " & $F & " -- FIN\n"

  debugEcho indent(header, indent)
  for i in 0 ..< v.len:
    debugEcho indent(&" {i:>2}: {v[i].toHex()}", indent)
  debugEcho indent(footer, indent)
|
||||
|
||||
# ############################################################
|
||||
#
|
||||
# Sanity checks
|
||||
#
|
||||
# ############################################################
|
||||
|
||||
when isMainModule:
|
||||
proc main() =
  ## Prints the result of every view transformation
  ## on the array [0, 1, ..., 7] for visual inspection.
  template banner(title: string) =
    echo title
    echo "----------------"

  var x = [0, 1, 2, 3, 4, 5, 6, 7]
  let v = x.toView()

  echo "view: ", v
  echo "reversed: ", v.reversed()

  block:
    let halves = v.splitAlternate()
    banner "\nSplit Alternate"
    echo "even: ", halves.even
    echo "odd: ", halves.odd

    block:
      let (ee, eo) = halves.even.splitAlternate()
      echo ""
      echo "even-even: ", ee
      echo "even-odd: ", eo
      echo "even-even rev: ", ee.reversed()
      echo "even-odd rev: ", eo.reversed()

    block:
      let (oe, oo) = halves.odd.splitAlternate()
      echo ""
      echo "odd-even: ", oe
      echo "odd-odd: ", oo
      echo "odd-even rev: ", oe.reversed()
      echo "odd-odd rev: ", oo.reversed()

  banner "\nSkip Half"
  echo "skipHalf: ", v.skipHalf()
  echo "skipQuad: ", v.skipHalf().skipHalf()
  echo "skipQuad rev: ", v.skipHalf().skipHalf().reversed()

  banner "\nSplit middle"
  block:
    let halves = v.splitMiddle()
    echo "left: ", halves.left
    echo "right: ", halves.right

    block:
      let (ll, lr) = halves.left.splitMiddle()
      echo ""
      echo "left-left: ", ll
      echo "left-right: ", lr
      echo "left-left rev: ", ll.reversed()
      echo "left-right rev: ", lr.reversed()

    block:
      let (rl, rr) = halves.right.splitMiddle()
      echo ""
      echo "right-left: ", rl
      echo "right-right: ", rr
      echo "right-left rev: ", rl.reversed()
      echo "right-right rev: ", rr.reversed()
||||
main()
|
Loading…
Reference in New Issue