# constantine/research/kzg/fft_fr.nim
#
# (Residue of a GitHub file view: 274 lines, 8.2 KiB, with "Raw", "Blame"
#  and "History" links, and a notice that the file contains ambiguous
#  Unicode characters that might be confused with other characters.)

# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../../constantine/platforms/primitives,
../../constantine/math/config/curves,
../../constantine/math/arithmetic,
../../constantine/math/io/io_fields,
# Research
./strided_views,
./fft_lut
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
# Quirks of the Python impl:
# - no tests of FFT alone?
# - a lot of "if type(x) == tuple else"
#
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
# The go port uses stride+offset to deal with skip iterator.
#
# Other readable FFT implementations include:
# - https://github.com/kwantam/fffft
# - https://github.com/ConsenSys/gnark/blob/master/internal/math/bls381/fft/fft.go
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# ############################################################
#
# Finite-Field Fast Fourier Transform
#
# ############################################################
#
# This is a research, unoptimized implementation of
# Finite Field Fast Fourier Transform
# In research phase we tolerate using
# - garbage collected types
# - and exceptions for fast prototyping
#
# In particular, in production all signed integers
# must be verified not to overflow
# and should not throw (or use unsigned)
# FFT Context
# ----------------------------------------------------------------
type
  # Outcome of `fft` / `ifft`.
  # Exported (together with its members) because the exported `fft*` and
  # `ifft*` return it: without the export marker, callers in other modules
  # cannot compare the result against `FFTS_Success` etc.
  FFTStatus* = enum
    FFTS_Success
    FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
    FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"

  FFTDescriptor*[F] = object
    ## Metadata for FFT on field F
    maxWidth: int
      ## Largest input length accepted by `fft`/`ifft`
      ## (`init` sets it to `1 shl maxScale`)
    rootOfUnity: F
      ## The root of unity that generates all roots
    expandedRootsOfUnity: seq[F]
      ## Successive powers of `rootOfUnity`; the domain starts and ends with 1
func expandRootOfUnity[F](rootOfUnity: F): seq[F] =
  ## Expands a generator root of unity ω into its successive powers
  ## [ω⁰, ω¹, ω², ..., ωᵏ], where ωᵏ (k ≥ 1) is the first power equal to 1.
  ## The output therefore starts AND ends with 1; the trailing 1 is what
  ## lets the same buffer be read in reverse (as `ifft` does).
  ##
  ## NOTE(review): never terminates if `rootOfUnity` is not a root of unity
  ## (e.g. 0), since no power ever reaches 1.
  # For a field of order q there are gcd(n, q-1) n-th roots of unity,
  # i.e. solutions to xⁿ ≡ 1 (mod q). Rather than computing the order of ω
  # upfront (a bigint GCD), the sequence is simply grown until it wraps
  # around to 1. Sizing the buffer upfront is left for production code.
  var one: F
  one.setOne()
  result = @[one, rootOfUnity]
  while not bool(result[^1].isOne()):
    var next {.noInit.}: F
    next.prod(result[^1], rootOfUnity)
    result.add next
# FFT Algorithm
# ----------------------------------------------------------------
# TODO: research Decimation in Time and Decimation in Frequency
# and FFT butterflies
func simpleFT[F](
    output: var View[F],
    vals: View[F],
    rootsOfUnity: View[F]) =
  ## Quadratic-time (O(n²)) DFT, the base case of the recursive FFT:
  ##   output[i] = Σⱼ vals[j] · rootsOfUnity[(i·j) mod n],  n = output.len
  ## NOTE(review): assumes vals and rootsOfUnity have at least n
  ## elements; this is not checked here.
  let n = output.len
  var acc {.noInit.}, term {.noInit.}: F
  for row in 0 ..< n:
    # The j = 0 term seeds the accumulator.
    acc.prod(vals[0], rootsOfUnity[0])
    for col in 1 ..< n:
      term.prod(vals[col], rootsOfUnity[(row*col) mod n])
      acc += term
    output[row] = acc
func fft_internal[F](
    output: var View[F],
    vals: View[F],
    rootsOfUnity: View[F]) =
  ## Recursive divide-and-conquer FFT over strided views.
  ## Inputs of length ≤ 4 are handled by the quadratic `simpleFT`.
  ## Larger inputs work as follows:
  ##  - transform the even- and odd-indexed inputs (via `splitAlternate`)
  ##    into the left and right halves of `output`;
  ##  - recombine the two halves with butterflies:
  ##      output[k]        ← L[k] + rootsOfUnity[k] · R[k]
  ##      output[k + half] ← L[k] − rootsOfUnity[k] · R[k]
  if output.len > 4:
    let (evens, odds) = vals.splitAlternate()
    var (lo, hi) = output.splitMiddle()
    # The sub-transforms use a root table taken through `skipHalf`.
    let halfRoots = rootsOfUnity.skipHalf()
    fft_internal(lo, evens, halfRoots)
    fft_internal(hi, odds, halfRoots)

    let h = lo.len
    var t {.noinit.}: F
    for k in 0 ..< h:
      t.prod(output[k+h], rootsOfUnity[k])
      # Compute the difference first: output[k] is still needed unmodified.
      output[k+h].diff(output[k], t)
      output[k] += t
  else:
    simpleFT(output, vals, rootsOfUnity)
func fft*[F](
    desc: FFTDescriptor[F],
    output: var openarray[F],
    vals: openarray[F]): FFT_Status =
  ## Forward FFT (a DFT over the descriptor's roots of unity).
  ##
  ## Computes output[i] = Σⱼ vals[j] · ω^(i·j), where n = vals.len and ω is
  ## the root of unity for length n selected from the descriptor's domain.
  ##
  ## Returns:
  ##   - FFTS_TooManyValues if vals.len > desc.maxWidth
  ##   - FFTS_SizeNotPowerOfTwo if vals is empty or its length is not a
  ##     power of 2
  ##   - FFTS_Success otherwise; `output` then holds the transform.
  ## Raises AssertionDefect if output.len != vals.len.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # 0 is not a power of 2. Reject it explicitly: the bit-trick power-of-2
  # test may accept 0, which would then divide by zero in the stride below.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2_vartime():
    return FFTS_SizeNotPowerOfTwo
  # A size mismatch would make `simpleFT` read past `vals` or leave part of
  # `output` unwritten.
  doAssert output.len == vals.len, "fft: output and vals must have the same length"

  # Take every `stride`-th element of the expanded domain, giving the
  # vals.len powers of the length-n root. This presumes expandedRootsOfUnity
  # holds the powers of a root of order maxWidth, as built by `init`.
  let stride = desc.maxWidth div vals.len
  let rootz = desc.expandedRootsOfUnity.toView().slice(0, desc.maxWidth-1, stride)
  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)
  return FFTS_Success
func ifft*[F](
    desc: FFTDescriptor[F],
    output: var openarray[F],
    vals: openarray[F]): FFT_Status =
  ## Inverse FFT
  ##
  ## Computes output[i] = (1/n) · Σⱼ vals[j] · ω^(−i·j), where n = vals.len
  ## and ω is the root of unity for length n selected from the descriptor's
  ## domain. This undoes `fft` for the same descriptor and length.
  ##
  ## Returns:
  ##   - FFTS_TooManyValues if vals.len > desc.maxWidth
  ##   - FFTS_SizeNotPowerOfTwo if vals is empty or its length is not a
  ##     power of 2
  ##   - FFTS_Success otherwise; `output` then holds the inverse transform.
  ## Raises AssertionDefect if output.len != vals.len.
  if vals.len > desc.maxWidth:
    return FFTS_TooManyValues
  # 0 is not a power of 2. Reject it explicitly: the bit-trick power-of-2
  # test may accept 0, which would then divide by zero in the stride below.
  if vals.len == 0 or not vals.len.uint64.isPowerOf2_vartime():
    return FFTS_SizeNotPowerOfTwo
  doAssert output.len == vals.len, "ifft: output and vals must have the same length"

  # Read the domain backwards. It starts and ends with 1, so index k of the
  # reversed view holds the inverse of the k-th power. Then take every
  # `stride`-th element, as in `fft`.
  let stride = desc.maxWidth div vals.len
  let rootz = desc.expandedRootsOfUnity.toView().reversed().slice(0, desc.maxWidth-1, stride)
  var voutput = output.toView()
  fft_internal(voutput, vals.toView(), rootz)

  # Scale by 1/n to complete the inverse.
  var invLen {.noInit.}: F
  invLen.fromUint(vals.len.uint64)
  invLen.inv()
  for x in output.mitems:
    x *= invLen
  return FFTS_Success
# FFT Descriptor
# ----------------------------------------------------------------
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
  ## Builds an FFT descriptor for inputs of up to `1 shl maxScale` elements.
  ## The generator is looked up in the curve's `scaleToRootOfUnity` table at
  ## index `maxScale`, then expanded into its full list of powers.
  let generator = scaleToRootOfUnity(T.F.C)[maxScale]
  T(maxWidth: 1 shl maxScale,
    rootOfUnity: generator,
    expandedRootsOfUnity: generator.expandRootOfUnity())
# Aren't you tired of reading about unity?
# ############################################################
#
# Sanity checks
#
# ############################################################
{.experimental: "views".}

when isMainModule:
  import
    std/[times, monotimes, strformat],
    ../../helpers/prng_unsafe

  proc roundtrip() =
    ## Checks that ifft(fft(data)) == data for data = 0, 1, ..., maxWidth-1.
    let fftDesc = FFTDescriptor[Fr[BLS12_381]].init(maxScale = 4)
    var data = newSeq[Fr[BLS12_381]](fftDesc.maxWidth)
    for i in 0 ..< fftDesc.maxWidth:
      data[i].fromUint i.uint64

    var coefs = newSeq[Fr[BLS12_381]](data.len)
    let fftOk = fft(fftDesc, coefs, data)
    doAssert fftOk == FFTS_Success
    # display("coefs", 0, coefs)

    var res = newSeq[Fr[BLS12_381]](data.len)
    let ifftOk = ifft(fftDesc, res, coefs)
    doAssert ifftOk == FFTS_Success
    # display("res", 0, res)

    for i in 0 ..< res.len:
      if bool(res[i] != data[i]):
        echo "Error: expected ", data[i].toHex(), " but got ", res[i].toHex()
        quit 1

    echo "FFT round-trip check SUCCESS"

  proc warmup() =
    # Warmup - make sure cpu is on max perf
    let start = cpuTime()
    var foo = 123
    for i in 0 ..< 300_000_000:
      foo += i*i mod 456
      foo = foo mod 789

    # Compiler shouldn't optimize away the results as cpuTime rely on sideeffects
    let stop = cpuTime()
    echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"

  proc bench() =
    ## Times `fft` for input sizes 2^4 .. 2^15, averaged over NumIters runs.
    ## Warms the CPU up itself first.
    echo "Starting benchmark ..."
    const NumIters = 100

    var rng: RngState
    rng.seed 0x1234
    # TODO: view types complain about mutable borrow
    # in `random_unsafe` due to pseudo view type LimbsViewMut
    # (which was views before Nim properly supported them)

    warmup()

    for scale in 4 ..< 16:
      # Setup
      let desc = FFTDescriptor[Fr[BLS12_381]].init(uint8 scale)
      var data = newSeq[Fr[BLS12_381]](desc.maxWidth)
      for i in 0 ..< desc.maxWidth:
        # data[i] = rng.random_unsafe(data[i].typeof())
        data[i].fromUint i.uint64

      var coefsOut = newSeq[Fr[BLS12_381]](data.len)

      # Bench
      let start = getMonotime()
      for i in 0 ..< NumIters:
        let status = desc.fft(coefsOut, data)
        doAssert status == FFTS_Success
      let stop = getMonotime()

      let ns = inNanoseconds((stop-start) div NumIters)
      echo &"FFT scale {scale:>2} {ns:>8} ns/op"

  roundtrip()
  # No separate warmup() here: bench() already warms up before timing,
  # and a second 300M-iteration warmup back-to-back only wastes time.
  bench()