constantine/research/kzg_poly_commit/fft_g1.nim

303 lines
9.0 KiB
Nim
Raw Normal View History

# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../../constantine/config/curves,
../../constantine/[arithmetic, primitives],
../../constantine/elliptic/[
ec_scalar_mul,
ec_shortweierstrass_affine,
ec_shortweierstrass_projective,
ec_shortweierstrass_jacobian,
],
../../constantine/io/[io_fields, io_ec],
# Research
./strided_views,
./fft_lut
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
# Quirks of the Python impl:
# - no tests of FFT alone?
# - a lot of "if type(x) == tuple else"
#
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
# The go port uses stride+offset to deal with skip iterator.
#
# Other readable FFT implementations include:
# - https://github.com/kwantam/fffft
# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# ############################################################
#
# Finite-Field Fast Fourier Transform
#
# ############################################################
#
# This is a research, unoptimized implementation of
# Finite Field Fast Fourier Transform
# In research phase we tolerate using
# - garbage collected types
# - and exceptions for fast prototyping
#
# In particular, in production all signed integers
# must be verified not to overflow
# and should not throw (or use unsigned)
# FFT Context
# ----------------------------------------------------------------
type
  # Status code returned by `fft` and `ifft`.
  # Error members carry their human-readable message as the enum's string value.
  FFTStatus = enum
    FFTS_Success
    FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
    FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"
  FFTDescriptor*[EC] = object
    ## Metadata for FFT on Elliptic Curve
    # Upper bound on the input length accepted by `fft`/`ifft`;
    # `init` sets it to 2^maxScale.
    maxWidth: int
    rootOfUnity: matchingOrderBigInt(EC.F.C)
      ## The root of unity that generates all roots
    expandedRootsOfUnity: seq[matchingOrderBigInt(EC.F.C)]
      ## domain, starting and ending with 1
func isPowerOf2(n: SomeUnsignedInt): bool =
  ## Returns true iff `n` is an exact power of two (1, 2, 4, 8, ...).
  ##
  ## Zero is not a power of two and is rejected, so a zero length
  ## can never pass as a valid transform size.
  # A power of two has exactly one bit set: clearing the lowest set bit
  # with `n and (n - 1)` must leave nothing. Zero also satisfies that
  # bit test, hence the explicit non-zero check.
  n != 0 and (n and (n - 1)) == 0
func nextPowerOf2(n: uint64): uint64 =
  ## Returns n if n is a power of 2
  ## or the next biggest power of 2.
  ## Both 0 and 1 map to 1.
  # NOTE(review): inputs above 2^63 have no representable answer;
  # the shift amount reaches 64 there - confirm callers never go that high.
  if n <= 1'u64:
    # `n - 1` would be 0 (whose logarithm is undefined) for n = 1
    # and would wrap around to the maximum value for n = 0.
    1'u64
  else:
    1'u64 shl (log2(n-1) + 1)
func expandRootOfUnity[F](rootOfUnity: F): auto {.noInit.} =
  ## Lists the successive powers of `rootOfUnity` as big integers:
  ## 1, w, w^2, ... up to and including the first power equal to 1 again.
  ## (The trailing 1 is there for the reversed array.)
  # For a field of order q, there are gcd(n, q-1)
  # nth roots of unity, a.k.a. solutions to xⁿ ≡ 1 (mod q)
  # but a bigint GCD is likely too long to compute,
  # so the sequence just grows on the heap one element at a time.
  # Figuring out how to right-size the buffers
  # in production will be fun.
  var powers = newSeq[matchingOrderBigInt(F.C)](2)
  powers[0].setOne()
  powers[1] = rootOfUnity.toBig()

  # Keep multiplying by the generator until the cycle closes on 1.
  var power = rootOfUnity
  while not powers[^1].isOne().bool:
    power *= rootOfUnity
    powers.add power.toBig()
  powers
# FFT Algorithm
# ----------------------------------------------------------------
func simpleFT[EC; bits: static int](
       output: var View[EC],
       vals: View[EC],
       rootsOfUnity: View[BigInt[bits]]
     ) =
  ## Quadratic-time transform used as the base case of the recursive FFT.
  ##
  ## For every output index `row`, accumulates
  ## `vals[col] * rootsOfUnity[(row*col) mod n]` over all `col`,
  ## where `n` is the length of `output`.
  let n = output.len
  var acc {.noInit.}, term {.noInit.}: EC
  for row in 0 ..< n:
    # The col = 0 term always pairs with the first root.
    acc = vals[0]
    acc.scalarMul(rootsOfUnity[0])
    for col in 1 ..< n:
      term = vals[col]
      term.scalarMul(rootsOfUnity[(row*col) mod n])
      acc += term
    output[row] = acc
func fft_internal[EC; bits: static int](
       output: var View[EC],
       vals: View[EC],
       rootsOfUnity: View[BigInt[bits]]
     ) =
  ## Recursive divide-and-conquer FFT writing into `output`.
  ##
  ## Outputs of length 4 or less go to the quadratic `simpleFT`.
  ## Larger ones are split into even/odd-indexed inputs, transformed
  ## into the two halves of `output`, then merged with butterflies.
  if output.len <= 4:
    simpleFT(output, vals, rootsOfUnity)
  else:
    let (evens, odds) = vals.splitAlternate()
    var (lower, upper) = output.splitMiddle()
    let coarserRoots = rootsOfUnity.skipHalf()

    fft_internal(lower, evens, coarserRoots)
    fft_internal(upper, odds, coarserRoots)

    let mid = lower.len
    var twiddled {.noInit.}: EC
    for k in 0 ..< mid:
      # Butterfly: the upper slot must be rewritten first,
      # as it still needs the not-yet-updated lower slot.
      twiddled = output[k+mid]
      twiddled.scalarMul(rootsOfUnity[k])
      output[k+mid].diff(output[k], twiddled)
      output[k] += twiddled
func fft*[EC](
       desc: FFTDescriptor[EC],
       output: var openarray[EC],
       vals: openarray[EC]): FFT_Status =
  ## Forward FFT of `vals` into `output`.
  ##
  ## Returns `FFTS_TooManyValues` when `vals` is longer than
  ## `desc.maxWidth`, `FFTS_SizeNotPowerOfTwo` when its length is
  ## not a power of two, and `FFTS_Success` once `output` is filled.
  if vals.len > desc.maxWidth:
    result = FFTS_TooManyValues
  elif not isPowerOf2(vals.len.uint64):
    result = FFTS_SizeNotPowerOfTwo
  else:
    # Stride through the precomputed domain so that only the roots
    # matching the input length are visited.
    let stride = desc.maxWidth div vals.len
    let rootz = desc.expandedRootsOfUnity
                    .toView()
                    .slice(0, desc.maxWidth-1, stride)

    var voutput = output.toView()
    fft_internal(voutput, vals.toView(), rootz)
    result = FFTS_Success
func ifft*[EC](
       desc: FFTDescriptor[EC],
       output: var openarray[EC],
       vals: openarray[EC]): FFT_Status =
  ## Inverse FFT of `vals` into `output`.
  ##
  ## Runs the same transform as `fft` over the reversed domain,
  ## then multiplies every output by the inverse of the input length.
  ## Returns the same status codes as `fft`.
  if vals.len > desc.maxWidth:
    result = FFTS_TooManyValues
  elif not isPowerOf2(vals.len.uint64):
    result = FFTS_SizeNotPowerOfTwo
  else:
    let stride = desc.maxWidth div vals.len
    let rootz = desc.expandedRootsOfUnity
                    .toView()
                    .reversed()
                    .slice(0, desc.maxWidth-1, stride)

    var voutput = output.toView()
    fft_internal(voutput, vals.toView(), rootz)

    # Scale by 1/len, computed in the scalar field.
    var invLen {.noInit.}: Fr[EC.F.C]
    invLen.fromUint(vals.len.uint64)
    invLen.inv()
    let inv = invLen.toBig()

    for point in output.mitems:
      point.scalarMul(inv)
    result = FFTS_Success
# FFT Descriptor
# ----------------------------------------------------------------
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
  ## Builds a descriptor for inputs of up to 2^maxScale points.
  ##
  ## The generator is looked up in the `scaleToRootOfUnity` table
  ## at index `maxScale` and expanded into the full domain.
  let generator = scaleToRootOfUnity(T.EC.F.C)[maxScale]
  result.maxWidth = 1 shl maxScale
  result.rootOfUnity = generator.toBig()
  result.expandedRootsOfUnity = expandRootOfUnity(generator)
# Aren't you tired of reading about unity?
# ############################################################
#
# Sanity checks
#
# ############################################################
{.experimental: "views".}
when isMainModule:
  import
    std/[times, monotimes, strformat],
    ../../helpers/prng_unsafe

  type G1 = ECP_ShortW_Prj[Fp[BLS12_381], NotOnTwist]

  # Affine point parsed from hex coordinates;
  # every test input below is built from it.
  var Generator1: ECP_ShortW_Aff[Fp[BLS12_381], NotOnTwist]
  doAssert Generator1.fromHex(
    "0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb",
    "0x08b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1"
  )

  proc generatorMultiples(count: int): seq[G1] =
    ## Returns `count` points: `Generator1` first, then each next
    ## element is the previous one plus `Generator1`.
    result = newSeq[G1](count)
    result[0].projectiveFromAffine(Generator1)
    for i in 1 ..< count:
      result[i].madd(result[i-1], Generator1)

  proc roundtrip() =
    ## Checks that `ifft(fft(x))` gives back `x`;
    ## exits with status 1 on the first mismatch.
    let desc = FFTDescriptor[G1].init(maxScale = 4)
    let input = generatorMultiples(desc.maxWidth)

    var transformed = newSeq[G1](input.len)
    doAssert fft(desc, transformed, input) == FFTS_Success
    # display("coefs", 0, transformed)

    var recovered = newSeq[G1](input.len)
    doAssert ifft(desc, recovered, transformed) == FFTS_Success
    # display("res", 0, recovered)

    for i, got in recovered:
      if bool(got != input[i]):
        echo "Error: expected ", input[i].toHex(), " but got ", got.toHex()
        quit 1

    echo "FFT round-trip check SUCCESS"

  proc warmup() =
    ## Keeps the CPU busy for a while so it reaches max perf
    ## before anything is timed.
    let start = cpuTime()
    var foo = 123
    for i in 0 ..< 300_000_000:
      foo += i*i mod 456
      foo = foo mod 789
    let stop = cpuTime()

    # Printing `foo` keeps the compiler from optimizing the loop away;
    # cpuTime relies on side effects as well.
    echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"

  proc bench() =
    ## Times the forward FFT for scales 4 to 9,
    ## averaging over `NumIters` runs per scale.
    echo "Starting benchmark ..."
    const NumIters = 3

    var rng: RngState
    rng.seed 0x1234
    # TODO: view types complain about mutable borrow
    #       in `random_unsafe` due to pseudo view type LimbsViewMut
    #       (which was views before Nim properly supported them)

    warmup()

    for scale in 4 ..< 10:
      # Setup
      let desc = FFTDescriptor[G1].init(uint8 scale)
      let data = generatorMultiples(desc.maxWidth)
      var coefsOut = newSeq[G1](data.len)

      # Bench
      let start = getMonotime()
      for _ in 0 ..< NumIters:
        doAssert desc.fft(coefsOut, data) == FFTS_Success
      let stop = getMonotime()

      let ns = inNanoseconds((stop-start) div NumIters)
      echo &"FFT scale {scale:>2}     {ns:>8} ns/op"

  roundtrip()
  warmup()
  bench()