# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||
|
||
import
|
||
../../constantine/config/curves,
|
||
../../constantine/[arithmetic, primitives],
|
||
../../constantine/io/io_fields,
|
||
# Research
|
||
./strided_views,
|
||
./fft_lut
|
||
|
||
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
# Quirks of the Python impl:
# - no tests of FFT alone?
# - a lot of "if type(x) == tuple else"
#
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
# The Go port uses stride+offset to deal with the skip iterator.
#
# Other readable FFTs include:
# - https://github.com/kwantam/fffft
# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
#   https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
|
||
|
||
# ############################################################
#
#          Finite-Field Fast Fourier Transform
#
# ############################################################
#
# This is a research, unoptimized implementation of the
# Finite-Field Fast Fourier Transform.

# In the research phase we tolerate using
# - garbage collected types
# - and exceptions for fast prototyping
#
# In particular, in production all signed integers
# must be verified not to overflow
# and should not throw (or use unsigned)
|
||
|
||
# FFT Context
|
||
# ----------------------------------------------------------------
|
||
|
||
# `FFTStatus` is exported because the exported `fft` / `ifft` procs return it:
# without the export marker, importing modules cannot refer to
# `FFTS_Success` & co. to interpret the result.
type
  FFTStatus* = enum ## Outcome of an `fft` / `ifft` call
    FFTS_Success
    FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
    FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"

  FFTDescriptor*[F] = object
    ## Metadata for FFT on field F
    maxWidth: int
      ## Largest input length accepted by `fft` / `ifft`
      ## (`init` sets it to `1 shl maxScale`)
    rootOfUnity: F
      ## The root of unity that generates all roots
    expandedRootsOfUnity: seq[F]
      ## domain, starting and ending with 1
|
||
|
||
func isPowerOf2(n: SomeUnsignedInt): bool =
  ## Returns true iff `n` is a power of two (1, 2, 4, 8, ...).
  ##
  ## Zero is explicitly rejected: with unsigned wrap-around
  ## `0 and (0 - 1)` is also 0, so the bit trick alone would
  ## wrongly report 0 as a power of two.
  n != 0 and (n and (n - 1)) == 0
|
||
|
||
func nextPowerOf2(n: uint64): uint64 =
  ## Returns `n` if `n` is a power of 2,
  ## or the next biggest power of 2 otherwise.
  ##
  ## Edge cases:
  ## - `n` = 0 and `n` = 1 both yield 1 (the smallest power of 2).
  ## - if the answer does not fit in 64 bits (`n` > 2⁶³)
  ##   the unsigned addition wraps around and 0 is returned.
  if n <= 1:
    # Handled separately: `n - 1` would wrap around for n = 0.
    result = 1
  else:
    # Smear the highest set bit of n-1 into every lower position,
    # producing 2ᵏ - 1, then add one to reach 2ᵏ.
    var v = n - 1
    v = v or (v shr 1)
    v = v or (v shr 2)
    v = v or (v shr 4)
    v = v or (v shr 8)
    v = v or (v shr 16)
    v = v or (v shr 32)
    result = v + 1
|
||
|
||
func expandRootOfUnity[F](rootOfUnity: F): seq[F] =
  ## Builds the table of successive powers of `rootOfUnity`:
  ## 1, ω, ω², ... stopping at (and including) the first
  ## power after the initial entry that is 1 again.
  ## (The trailing 1 is kept for the reversed array.)

  # For a field of order q, there are gcd(n, q−1)
  # nth roots of unity, a.k.a. solutions to xⁿ ≡ 1 (mod q)
  # but it's likely too long to compute bigint GCD
  # so embrace heap (re-)allocations.
  # Figuring out how to right-size the buffers
  # in production will be fun.
  var one {.noInit.}: F
  one.setOne()
  result = @[one, rootOfUnity]

  while not bool(result[^1].isOne()):
    # next power = previous power * ω
    var nextPower {.noInit.}: F
    nextPower.prod(result[^1], rootOfUnity)
    result.add nextPower
|
||
|
||
# FFT Algorithm
|
||
# ----------------------------------------------------------------
|
||
|
||
# TODO: research Decimation in Time and Decimation in Frequency
|
||
# and FFT butterflies
|
||
|
||
func simpleFT[F](
       output: var View[F],
       vals: View[F],
       rootsOfUnity: View[F]
     ) =
  ## Direct O(n²) Fourier transform, used as the base case
  ## of the recursive FFT.
  ##
  ## With n = output.len, computes for every i in 0 ..< n
  ##   output[i] = Σⱼ vals[j] * rootsOfUnity[(i*j) mod n]
  let n = output.len

  for row in 0 ..< n:
    # The j = 0 term initialises the accumulator,
    # the remaining terms are added on top of it.
    var acc {.noInit.}: F
    acc.prod(vals[0], rootsOfUnity[0])

    for col in 1 ..< n:
      var term {.noInit.}: F
      term.prod(vals[col], rootsOfUnity[(row * col) mod n])
      acc += term

    output[row] = acc
|
||
|
||
func fft_internal[F](
       output: var View[F],
       vals: View[F],
       rootsOfUnity: View[F]
     ) =
  ## Recursive divide-and-conquer FFT step writing into `output`.
  ##
  ## Inputs of length 4 or less are handed to the quadratic `simpleFT`.
  ## Larger inputs are split via `splitAlternate` / `splitMiddle` /
  ## `skipHalf` (defined outside this block, presumably in
  ## ./strided_views), each half is transformed recursively, and the
  ## two halves are then merged with butterflies.
  if output.len > 4:
    # Recursive Divide-and-Conquer
    let (evens, odds) = vals.splitAlternate()
    var (firstHalf, secondHalf) = output.splitMiddle()
    let coarserRoots = rootsOfUnity.skipHalf()

    fft_internal(firstHalf, evens, coarserRoots)
    fft_internal(secondHalf, odds, coarserRoots)

    let mid = firstHalf.len
    for k in 0 ..< mid:
      # FFT Butterfly:
      #   t             = output[k+mid] * rootsOfUnity[k]
      #   output[k+mid] = output[k] - t
      #   output[k]     = output[k] + t
      var twiddled {.noInit.}: F
      twiddled.prod(output[k + mid], rootsOfUnity[k])
      output[k + mid].diff(output[k], twiddled)
      output[k] += twiddled
  else:
    # Base case
    simpleFT(output, vals, rootsOfUnity)
|
||
|
||
func fft*[F](
       desc: FFTDescriptor[F],
       output: var openArray[F],
       vals: openArray[F]): FFTStatus =
  ## Forward FFT of `vals`, written to `output`.
  ##
  ## Returns
  ## - `FFTS_TooManyValues` if `vals` is longer than `desc.maxWidth`
  ## - `FFTS_SizeNotPowerOfTwo` if `vals.len` is not a power of 2,
  ##   which includes an empty `vals`
  ## - `FFTS_Success` otherwise
  ##
  ## NOTE(review): `output.len` is not validated here; it looks like it
  ## is expected to equal `vals.len` - confirm against callers.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # The emptiness test is explicit: a zero length would otherwise
  # reach the `div vals.len` below and divide by zero.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2():
    return FFTS_SizeNotPowerOfTwo

  # Strided view over the precomputed roots with
  # step = maxWidth / vals.len
  let rootz = desc.expandedRootsOfUnity
                  .toView()
                  .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)

  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)
  return FFTS_Success
|
||
|
||
func ifft*[F](
       desc: FFTDescriptor[F],
       output: var openArray[F],
       vals: openArray[F]): FFTStatus =
  ## Inverse FFT of `vals`, written to `output`.
  ##
  ## Runs the same recursion as `fft` but over the reversed table of
  ## roots of unity, then multiplies every output element by the
  ## inverse of `vals.len`.
  ##
  ## Returns
  ## - `FFTS_TooManyValues` if `vals` is longer than `desc.maxWidth`
  ## - `FFTS_SizeNotPowerOfTwo` if `vals.len` is not a power of 2,
  ##   which includes an empty `vals`
  ## - `FFTS_Success` otherwise
  ##
  ## NOTE(review): `output.len` is not validated here; it looks like it
  ## is expected to equal `vals.len` - confirm against callers.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # The emptiness test is explicit: a zero length would otherwise
  # reach the `div vals.len` below and divide by zero.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2():
    return FFTS_SizeNotPowerOfTwo

  # Strided view over the reversed precomputed roots with
  # step = maxWidth / vals.len
  let rootz = desc.expandedRootsOfUnity
                  .toView()
                  .reversed()
                  .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)

  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)

  # Normalisation factor 1/n as a field element
  var invLen {.noInit.}: F
  invLen.fromUint(vals.len.uint64)
  invLen.inv()

  for i in 0 ..< output.len:
    output[i] *= invLen

  return FFTS_Success
|
||
|
||
# FFT Descriptor
|
||
# ----------------------------------------------------------------
|
||
|
||
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
  ## Creates an FFT descriptor supporting inputs of up to
  ## `1 shl maxScale` values.
  ##
  ## The generator is looked up at index `maxScale` in
  ## `scaleToRootOfUnity` for the curve of the descriptor's field
  ## (`T.F.C`), then expanded into the full table of its powers.
  let generator = scaleToRootOfUnity(T.F.C)[maxScale]

  result.maxWidth = 1 shl maxScale
  result.rootOfUnity = generator
  result.expandedRootsOfUnity = expandRootOfUnity(generator)
  # Aren't you tired of reading about unity?
|
||
|
||
# ############################################################
|
||
#
|
||
# Sanity checks
|
||
#
|
||
# ############################################################
|
||
|
||
{.experimental: "views".}

when isMainModule:
  # Self-test and micro-benchmark, only compiled when this file
  # is the main module.
  import
    std/[times, monotimes, strformat],
    ../../helpers/prng_unsafe

  proc roundtrip() =
    ## Checks that ifft(fft(data)) == data on 16 values (maxScale = 4).
    ## Prints the first mismatch and exits with code 1 on failure.
    let fftDesc = FFTDescriptor[Fr[BLS12_381]].init(maxScale = 4)
    var data = newSeq[Fr[BLS12_381]](fftDesc.maxWidth)
    for i in 0 ..< fftDesc.maxWidth:
      data[i].fromUint i.uint64

    var coefs = newSeq[Fr[BLS12_381]](data.len)
    let fftOk = fft(fftDesc, coefs, data)
    doAssert fftOk == FFTS_Success
    # display("coefs", 0, coefs)

    var res = newSeq[Fr[BLS12_381]](data.len)
    let ifftOk = ifft(fftDesc, res, coefs)
    doAssert ifftOk == FFTS_Success
    # display("res", 0, res)

    for i in 0 ..< res.len:
      if bool(res[i] != data[i]):
        echo "Error: expected ", data[i].toHex(), " but got ", res[i].toHex()
        quit 1

    echo "FFT round-trip check SUCCESS"

  proc warmup() =
    ## Busy loop run before timing anything.
    # Warmup - make sure cpu is on max perf
    let start = cpuTime()
    var foo = 123
    for i in 0 ..< 300_000_000:
      foo += i*i mod 456
      foo = foo mod 789

    # Compiler shouldn't optimize away the results as cpuTime rely on sideeffects
    let stop = cpuTime()
    echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"

  proc bench() =
    ## Times the forward FFT for scales 4 to 15 and prints the
    ## average over `NumIters` runs for each scale.
    echo "Starting benchmark ..."
    const NumIters = 100

    var rng: RngState
    rng.seed 0x1234
    # TODO: view types complain about mutable borrow
    # in `random_unsafe` due to pseudo view type LimbsViewMut
    # (which was views before Nim properly supported them)

    warmup()

    for scale in 4 ..< 16:
      # Setup

      let desc = FFTDescriptor[Fr[BLS12_381]].init(uint8 scale)
      var data = newSeq[Fr[BLS12_381]](desc.maxWidth)
      for i in 0 ..< desc.maxWidth:
        # data[i] = rng.random_unsafe(data[i].typeof())
        data[i].fromUint i.uint64

      var coefsOut = newSeq[Fr[BLS12_381]](data.len)

      # Bench
      let start = getMonotime()
      for _ in 0 ..< NumIters:
        let status = desc.fft(coefsOut, data)
        doAssert status == FFTS_Success
      let stop = getMonotime()

      let ns = inNanoseconds((stop-start) div NumIters)
      echo &"FFT scale {scale:>2}     {ns:>8} ns/op"

  roundtrip()
  # `bench` performs its own warm-up, so no separate `warmup()` call
  # here: running it twice only doubled the start-up time.
  bench()
|