[Research] KZG polynomial commitment - part 1 FFT (#151)

* FFT compiles, now on to debugging ... [skip CI]

* Fix FFT and add bench [skip ci]

* rename + add KZG resources

* rename fft_fr

* Implement FFT on elliptic curves =)

* FFT G1 bench
This commit is contained in:
Mamy Ratsimbazafy 2021-02-06 22:11:17 +01:00 committed by GitHub
parent 94419db783
commit 54887b1777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 902 additions and 9 deletions

View File

@ -23,13 +23,13 @@ export zoo_inversions
{.push raises: [].} {.push raises: [].}
{.push inline.} {.push inline.}
func inv_euclid*(r: var Fp, a: Fp) = func inv_euclid*(r: var FF, a: FF) =
## Inversion modulo p via ## Inversion modulo p via
## Niels Moller constant-time version of ## Niels Moller constant-time version of
## Stein's GCD derived from extended binary Euclid algorithm ## Stein's GCD derived from extended binary Euclid algorithm
r.mres.steinsGCD(a.mres, Fp.getR2modP(), Fp.C.Mod, Fp.getPrimePlus1div2()) r.mres.steinsGCD(a.mres, FF.getR2modP(), FF.fieldMod(), FF.getPrimePlus1div2())
func inv*(r: var Fp, a: Fp) = func inv*(r: var FF, a: FF) =
## Inversion modulo p ## Inversion modulo p
## ##
## The inverse of 0 is 0. ## The inverse of 0 is 0.
@ -40,22 +40,19 @@ func inv*(r: var Fp, a: Fp) =
# neither for Secp256k1 nor BN curves # neither for Secp256k1 nor BN curves
# Performance is slower than GCD # Performance is slower than GCD
# To be revisited with faster squaring/multiplications # To be revisited with faster squaring/multiplications
when Fp.C.hasInversionAddchain(): when FF is Fp and FF.C.hasInversionAddchain():
r.inv_addchain(a) r.inv_addchain(a)
else: else:
r.inv_euclid(a) r.inv_euclid(a)
func inv*(a: var Fp) = func inv*(a: var FF) =
## Inversion modulo p ## Inversion modulo p
## ##
## The inverse of 0 is 0. ## The inverse of 0 is 0.
## Incidentally this avoids extra check ## Incidentally this avoids extra check
## to convert Jacobian and Projective coordinates ## to convert Jacobian and Projective coordinates
## to affine for elliptic curve ## to affine for elliptic curve
when Fp.C.hasInversionAddchain(): a.inv(a)
a.inv_addchain(a)
else:
a.inv_euclid(a)
{.pop.} # inline {.pop.} # inline
{.pop.} # raises no exceptions {.pop.} # raises no exceptions

8
research/README.md Normal file
View File

@ -0,0 +1,8 @@
# Research
This folder stashes experimentations before they are productionized into the library.
- `kzg`: KZG Polynomial Commitments\
Constant-Size Commitments to Polynomials and Their Applications\
Aniket Kate, Gregory M. Zaverucha, Ian Goldberg, 2010\
https://www.iacr.org/archive/asiacrypt2010/6477178/6477178.pdf

View File

@ -0,0 +1,12 @@
# KZG Polynomial Commitment research
Research for Ethereum 2.0 phase 1
to implement the Data Availability Sampling protocol
See
- https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html
- https://github.com/protolambda/go-kate
- FK20: https://github.com/khovratovich/Kate/blob/master/Kate_amortized.pdf
- https://github.com/ethereum/research/tree/master/polynomial_reconstruction
- https://github.com/ethereum/research/tree/master/kzg_data_availability

View File

@ -0,0 +1,280 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../../constantine/config/curves,
../../constantine/[arithmetic, primitives],
../../constantine/io/io_fields,
# Research
./strided_views,
./fft_lut
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
# Quirks of the Python impl:
# - no tests of FFT alone?
# - a lot of "if type(x) == tuple else"
#
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
# The go port uses stride+offset to deal with skip iterator.
#
# Other readable FFTs includes:
# - https://github.com/kwantam/fffft
# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# ############################################################
#
# Finite-Field Fast Fourier Transform
#
# ############################################################
#
# This is a research, unoptimized implementation of
# Finite Field Fast Fourier Transform
# In research phase we tolerate using
# - garbage collected types
# - and exceptions for fast prototyping
#
# In particular, in production all signed integers
# must be verified not to overflow
# and should not throw (or use unsigned)
# FFT Context
# ----------------------------------------------------------------
type
FFTStatus = enum
FFTS_Success
FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"
FFTDescriptor[F] = object
## Metadata for FFT on field F
maxWidth: int
rootOfUnity: F
## The root of unity that generates all roots
expandedRootsOfUnity: seq[F]
## domain, starting and ending with 1
func isPowerOf2(n: SomeUnsignedInt): bool =
(n and (n - 1)) == 0
func nextPowerOf2(n: uint64): uint64 =
## Returns x if x is a power of 2
## or the next biggest power of 2
1'u64 shl (log2(n-1) + 1)
func expandRootOfUnity[F](rootOfUnity: F): seq[F] =
## From a generator root of unity
## expand to width + 1 values.
## (Last value is 1 for the reverse array)
# For a field of order q, there are gcd(n, q1)
# nth roots of unity, a.k.a. solutions to xⁿ ≡ 1 (mod q)
# but it's likely too long to compute bigint GCD
# so embrace heap (re-)allocations.
# Figuring out how to do to right size the buffers
# in production will be fun.
result.setLen(2)
result[0].setOne()
result[1] = rootOfUnity
while not result[^1].isOne().bool:
result.setLen(result.len + 1)
result[^1].prod(result[^2], rootOfUnity)
# FFT Algorithm
# ----------------------------------------------------------------
# TODO: research Decimation in Time and Decimation in Frequency
# and FFT butterflies
func simpleFT[F](
output: var View[F],
vals: View[F],
rootsOfUnity: View[F]
) =
# FFT is a recursive algorithm
# This is the base-case using a O(n²) algorithm
let L = output.len
var last {.noInit.}, v {.noInit.}: F
for i in 0 ..< L:
last.prod(vals[0], rootsOfUnity[0])
for j in 1 ..< L:
v.prod(vals[j], rootsOfUnity[(i*j) mod L])
last += v
output[i] = last
func fft_internal[F](
output: var View[F],
vals: View[F],
rootsOfUnity: View[F]
) =
if output.len <= 4:
simpleFT(output, vals, rootsOfUnity)
return
# Recursive Divide-and-Conquer
let (evenVals, oddVals) = vals.splitAlternate()
var (outLeft, outRight) = output.splitMiddle()
let halfROI = rootsOfUnity.skipHalf()
fft_internal(outLeft, evenVals, halfROI)
fft_internal(outRight, oddVals, halfROI)
let half = outLeft.len
var y_times_root{.noinit.}: F
for i in 0 ..< half:
# FFT Butterfly
y_times_root .prod(output[i+half], rootsOfUnity[i])
output[i+half] .diff(output[i], y_times_root)
output[i] += y_times_root
func fft*[F](
desc: FFTDescriptor[F],
output: var openarray[F],
vals: openarray[F]): FFT_Status =
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
return FFTS_SizeNotPowerOfTwo
let rootz = desc.expandedRootsOfUnity
.toView()
.slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)
var voutput = output.toView()
fft_internal(voutput, vals.toView(), rootz)
return FFTS_Success
func ifft*[F](
desc: FFTDescriptor[F],
output: var openarray[F],
vals: openarray[F]): FFT_Status =
## Inverse FFT
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
return FFTS_SizeNotPowerOfTwo
let rootz = desc.expandedRootsOfUnity
.toView()
.reversed()
.slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)
var voutput = output.toView()
fft_internal(voutput, vals.toView(), rootz)
var invLen {.noInit.}: F
invLen.fromUint(vals.len.uint64)
invLen.inv()
for i in 0..< output.len:
output[i] *= invLen
return FFTS_Success
# FFT Descriptor
# ----------------------------------------------------------------
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
result.maxWidth = 1 shl maxScale
result.rootOfUnity = scaleToRootOfUnity(T.F.C)[maxScale]
result.expandedRootsOfUnity =
result.rootOfUnity.expandRootOfUnity()
# Aren't you tired of reading about unity?
# ############################################################
#
# Sanity checks
#
# ############################################################
{.experimental: "views".}
when isMainModule:
import
std/[times, monotimes, strformat],
../../helpers/prng_unsafe
proc roundtrip() =
let fftDesc = FFTDescriptor[Fr[BLS12_381]].init(maxScale = 4)
var data = newSeq[Fr[BLS12_381]](fftDesc.maxWidth)
for i in 0 ..< fftDesc.maxWidth:
data[i].fromUint i.uint64
var coefs = newSeq[Fr[BLS12_381]](data.len)
let fftOk = fft(fftDesc, coefs, data)
doAssert fftOk == FFTS_Success
# display("coefs", 0, coefs)
var res = newSeq[Fr[BLS12_381]](data.len)
let ifftOk = ifft(fftDesc, res, coefs)
doAssert ifftOk == FFTS_Success
# display("res", 0, coefs)
for i in 0 ..< res.len:
if bool(res[i] != data[i]):
echo "Error: expected ", data[i].toHex(), " but got ", res[i].toHex()
quit 1
echo "FFT round-trip check SUCCESS"
proc warmup() =
# Warmup - make sure cpu is on max perf
let start = cpuTime()
var foo = 123
for i in 0 ..< 300_000_000:
foo += i*i mod 456
foo = foo mod 789
# Compiler shouldn't optimize away the results as cpuTime rely on sideeffects
let stop = cpuTime()
echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"
proc bench() =
echo "Starting benchmark ..."
const NumIters = 100
var rng: RngState
rng.seed 0x1234
# TODO: view types complain about mutable borrow
# in `random_unsafe` due to pseudo view type LimbsViewMut
# (which was views before Nim properly supported them)
warmup()
for scale in 4 ..< 16:
# Setup
let desc = FFTDescriptor[Fr[BLS12_381]].init(uint8 scale)
var data = newSeq[Fr[BLS12_381]](desc.maxWidth)
for i in 0 ..< desc.maxWidth:
# data[i] = rng.random_unsafe(data[i].typeof())
data[i].fromUint i.uint64
var coefsOut = newSeq[Fr[BLS12_381]](data.len)
# Bench
let start = getMonotime()
for i in 0 ..< NumIters:
let status = desc.fft(coefsOut, data)
doAssert status == FFTS_Success
let stop = getMonotime()
let ns = inNanoseconds((stop-start) div NumIters)
echo &"FFT scale {scale:>2} {ns:>8} ns/op"
roundtrip()
warmup()
bench()

View File

@ -0,0 +1,302 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
../../constantine/config/curves,
../../constantine/[arithmetic, primitives],
../../constantine/elliptic/[
ec_endomorphism_accel,
ec_shortweierstrass_affine,
ec_shortweierstrass_projective,
ec_shortweierstrass_jacobian,
],
../../constantine/io/[io_fields, io_ec],
# Research
./strided_views,
./fft_lut
# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py
# Quirks of the Python impl:
# - no tests of FFT alone?
# - a lot of "if type(x) == tuple else"
#
# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21
# The go port uses stride+offset to deal with skip iterator.
#
# Other readable FFTs includes:
# - https://github.com/kwantam/fffft
# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go
# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370
# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315
# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
# ############################################################
#
# Finite-Field Fast Fourier Transform
#
# ############################################################
#
# This is a research, unoptimized implementation of
# Finite Field Fast Fourier Transform
# In research phase we tolerate using
# - garbage collected types
# - and exceptions for fast prototyping
#
# In particular, in production all signed integers
# must be verified not to overflow
# and should not throw (or use unsigned)
# FFT Context
# ----------------------------------------------------------------
type
FFTStatus = enum
FFTS_Success
FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)"
FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length"
FFTDescriptor[EC] = object
## Metadata for FFT on Elliptic Curve
maxWidth: int
rootOfUnity: matchingOrderBigInt(EC.F.C)
## The root of unity that generates all roots
expandedRootsOfUnity: seq[matchingOrderBigInt(EC.F.C)]
## domain, starting and ending with 1
func isPowerOf2(n: SomeUnsignedInt): bool =
(n and (n - 1)) == 0
func nextPowerOf2(n: uint64): uint64 =
## Returns x if x is a power of 2
## or the next biggest power of 2
1'u64 shl (log2(n-1) + 1)
func expandRootOfUnity[F](rootOfUnity: F): auto {.noInit.} =
## From a generator root of unity
## expand to width + 1 values.
## (Last value is 1 for the reverse array)
# For a field of order q, there are gcd(n, q1)
# nth roots of unity, a.k.a. solutions to xⁿ ≡ 1 (mod q)
# but it's likely too long to compute bigint GCD
# so embrace heap (re-)allocations.
# Figuring out how to do to right size the buffers
# in production will be fun.
var r: seq[matchingOrderBigInt(F.C)]
r.setLen(2)
r[0].setOne()
r[1] = rootOfUnity.toBig()
var cur = rootOfUnity
while not r[^1].isOne().bool:
cur *= rootOfUnity
r.setLen(r.len + 1)
r[^1] = cur.toBig()
return r
# FFT Algorithm
# ----------------------------------------------------------------
func simpleFT[EC; bits: static int](
output: var View[EC],
vals: View[EC],
rootsOfUnity: View[BigInt[bits]]
) =
# FFT is a recursive algorithm
# This is the base-case using a O(n²) algorithm
let L = output.len
var last {.noInit.}, v {.noInit.}: EC
for i in 0 ..< L:
last = vals[0]
last.scalarMulGLV_m2w2(rootsOfUnity[0])
for j in 1 ..< L:
v = vals[j]
v.scalarMulGLV_m2w2(rootsOfUnity[(i*j) mod L])
last += v
output[i] = last
func fft_internal[EC; bits: static int](
output: var View[EC],
vals: View[EC],
rootsOfUnity: View[BigInt[bits]]
) =
if output.len <= 4:
simpleFT(output, vals, rootsOfUnity)
return
# Recursive Divide-and-Conquer
let (evenVals, oddVals) = vals.splitAlternate()
var (outLeft, outRight) = output.splitMiddle()
let halfROI = rootsOfUnity.skipHalf()
fft_internal(outLeft, evenVals, halfROI)
fft_internal(outRight, oddVals, halfROI)
let half = outLeft.len
var y_times_root{.noinit.}: EC
for i in 0 ..< half:
# FFT Butterfly
y_times_root = output[i+half]
y_times_root .scalarMulGLV_m2w2(rootsOfUnity[i])
output[i+half] .diff(output[i], y_times_root)
output[i] += y_times_root
func fft*[EC](
desc: FFTDescriptor[EC],
output: var openarray[EC],
vals: openarray[EC]): FFT_Status =
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
return FFTS_SizeNotPowerOfTwo
let rootz = desc.expandedRootsOfUnity
.toView()
.slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)
var voutput = output.toView()
fft_internal(voutput, vals.toView(), rootz)
return FFTS_Success
func ifft*[EC](
desc: FFTDescriptor[EC],
output: var openarray[EC],
vals: openarray[EC]): FFT_Status =
## Inverse FFT
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
return FFTS_SizeNotPowerOfTwo
let rootz = desc.expandedRootsOfUnity
.toView()
.reversed()
.slice(0, desc.maxWidth-1, desc.maxWidth div vals.len)
var voutput = output.toView()
fft_internal(voutput, vals.toView(), rootz)
var invLen {.noInit.}: Fr[EC.F.C]
invLen.fromUint(vals.len.uint64)
invLen.inv()
let inv = invLen.toBig()
for i in 0..< output.len:
output[i].scalarMulGLV_m2w2(inv)
return FFTS_Success
# FFT Descriptor
# ----------------------------------------------------------------
proc init*(T: type FFTDescriptor, maxScale: uint8): T =
result.maxWidth = 1 shl maxScale
let root = scaleToRootOfUnity(T.EC.F.C)[maxScale]
result.rootOfUnity = root.toBig()
result.expandedRootsOfUnity = root.expandRootOfUnity()
# Aren't you tired of reading about unity?
# ############################################################
#
# Sanity checks
#
# ############################################################
{.experimental: "views".}
when isMainModule:
import
std/[times, monotimes, strformat],
../../helpers/prng_unsafe
type G1 = ECP_ShortW_Prj[Fp[BLS12_381], NotOnTwist]
var Generator1: ECP_ShortW_Aff[Fp[BLS12_381], NotOnTwist]
doAssert Generator1.fromHex(
"0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb",
"0x08b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1"
)
proc roundtrip() =
let fftDesc = FFTDescriptor[G1].init(maxScale = 4)
var data = newSeq[G1](fftDesc.maxWidth)
data[0].projectiveFromAffine(Generator1)
for i in 1 ..< fftDesc.maxWidth:
data[i].madd(data[i-1], Generator1)
var coefs = newSeq[G1](data.len)
let fftOk = fft(fftDesc, coefs, data)
doAssert fftOk == FFTS_Success
# display("coefs", 0, coefs)
var res = newSeq[G1](data.len)
let ifftOk = ifft(fftDesc, res, coefs)
doAssert ifftOk == FFTS_Success
# display("res", 0, coefs)
for i in 0 ..< res.len:
if bool(res[i] != data[i]):
echo "Error: expected ", data[i].toHex(), " but got ", res[i].toHex()
quit 1
echo "FFT round-trip check SUCCESS"
proc warmup() =
# Warmup - make sure cpu is on max perf
let start = cpuTime()
var foo = 123
for i in 0 ..< 300_000_000:
foo += i*i mod 456
foo = foo mod 789
# Compiler shouldn't optimize away the results as cpuTime rely on sideeffects
let stop = cpuTime()
echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"
proc bench() =
echo "Starting benchmark ..."
const NumIters = 3
var rng: RngState
rng.seed 0x1234
# TODO: view types complain about mutable borrow
# in `random_unsafe` due to pseudo view type LimbsViewMut
# (which was views before Nim properly supported them)
warmup()
for scale in 4 ..< 16:
# Setup
let desc = FFTDescriptor[G1].init(uint8 scale)
var data = newSeq[G1](desc.maxWidth)
data[0].projectiveFromAffine(Generator1)
for i in 1 ..< desc.maxWidth:
data[i].madd(data[i-1], Generator1)
var coefsOut = newSeq[G1](data.len)
# Bench
let start = getMonotime()
for i in 0 ..< NumIters:
let status = desc.fft(coefsOut, data)
doAssert status == FFTS_Success
let stop = getMonotime()
let ns = inNanoseconds((stop-start) div NumIters)
echo &"FFT scale {scale:>2} {ns:>8} ns/op"
roundtrip()
warmup()
bench()

View File

@ -0,0 +1,47 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
import
std/macros,
../../constantine/config/[curves, common],
../../constantine/[arithmetic, primitives],
../../constantine/io/io_fields
# TODO automate this
# we can precompute everything in Sage
# and auto-generate the file.
const BLS12_381_Fr_primitive_root = 5
func buildRootLUT(F: type Fr): array[32, F] =
## [pow(PRIMITIVE_ROOT, (MODULUS - 1) // (2**i), MODULUS) for i in range(32)]
var exponent {.noInit.}: BigInt[F.C.getCurveOrderBitwidth()]
exponent = F.C.getCurveOrder()
exponent -= One
# Start by the end
var i = result.len - 1
exponent.shiftRight(i)
result[i].fromUint(BLS12_381_Fr_primitive_root)
result[i].powUnsafeExponent(exponent)
while i > 0:
result[i-1].square(result[i])
dec i
# debugEcho "Fr[BLS12_81] - Roots of Unity:"
# for i in 0 ..< result.len:
# debugEcho " ", i, ": ", result[i].toHex()
# debugEcho "Fr[BLS12_81] - Roots of Unity -- FIN\n"
let BLS12_381_Fr_ScaleToRootOfUnity* = buildRootLUT(Fr[BLS12_381])
{.experimental: "dynamicBindSym".}
macro scaleToRootOfUnity*(C: static Curve): untyped =
return bindSym($C & "_Fr_ScaleToRootOfUnity")

View File

@ -0,0 +1,247 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
# Strided View - Monodimensional Tensors
# ----------------------------------------------------------------
#
# FFT uses recursive divide-and-conquer.
# In code this means need strided views
# to enable different logical views of the same memory buffer.
# Strided views are monodimensional tensors:
# See Arraymancer backend:
# https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/laser/tensor/datatypes.nim#L28-L32
# Or the minimal tensor implementation challenge:
# https://github.com/SimonDanisch/julia-challenge/blob/b8ed3b6/nim/nim_sol_mratsim.nim#L4-L26
{.experimental: "views".}
type
View*[T] = object
## A strided view over an (unowned) data buffer
len*: int
stride: int
offset: int
data: lent UncheckedArray[T]
func `[]`*[T](v: View[T], idx: int): lent T {.inline.} =
v.data[v.offset + idx*v.stride]
func `[]`*[T](v: var View[T], idx: int): var T {.inline.} =
# Experimental views indeed ...
cast[ptr UncheckedArray[T]](v.data)[v.offset + idx*v.stride]
func `[]=`*[T](v: var View[T], idx: int, val: T) {.inline.} =
# Experimental views indeed ...
cast[ptr UncheckedArray[T]](v.data)[v.offset + idx*v.stride] = val
func toView*[T](oa: openArray[T]): View[T] {.inline.} =
result.len = oa.len
result.stride = 1
result.offset = 0
result.data = cast[lent UncheckedArray[T]](oa[0].unsafeAddr)
iterator items*[T](v: View[T]): lent T =
var cur = v.offset
for _ in 0 ..< v.len:
yield v.data[cur]
cur += v.stride
func `$`*(v: View): string =
result = "View["
var first = true
for elem in v:
if not first:
result &= ", "
else:
first = false
result &= $elem
result &= ']'
func toHex*(v: View): string =
mixin toHex
result = "View["
var first = true
for elem in v:
if not first:
result &= ", "
else:
first = false
result &= elem.toHex()
result &= ']'
# FFT-specific splitting
# -------------------------------------------------------------------------------
func splitAlternate*(t: View): tuple[even, odd: View] {.inline.} =
## Split the tensor into 2
## partitioning the input every other index
## even: indices [0, 2, 4, ...]
## odd: indices [ 1, 3, 5, ...]
assert (t.len and 1) == 0, "The tensor must contain an even number of elements"
let half = t.len shr 1
let skipHalf = t.stride shl 1
result.even.len = half
result.even.stride = skipHalf
result.even.offset = t.offset
result.even.data = t.data
result.odd.len = half
result.odd.stride = skipHalf
result.odd.offset = t.offset + t.stride
result.odd.data = t.data
func splitMiddle*(t: View): tuple[left, right: View] {.inline.} =
## Split the tensor into 2
## partitioning into left and right halves.
## left: indices [0, 1, 2, 3]
## right: indices [4, 5, 6, 7]
assert (t.len and 1) == 0, "The tensor must contain an even number of elements"
let half = t.len shr 1
result.left.len = half
result.left.stride = t.stride
result.left.offset = t.offset
result.left.data = t.data
result.right.len = half
result.right.stride = t.stride
result.right.offset = t.offset + half
result.right.data = t.data
func skipHalf*(t: View): View {.inline.} =
## Pick one every other indices
## output: [0, 2, 4, ...]
assert (t.len and 1) == 0, "The tensor must contain an even number of elements"
result.len = t.len shr 1
result.stride = t.stride shl 1
result.offset = t.offset
result.data = t.data
func slice*(v: View, start, stop, step: int): View {.inline.} =
## Slice a view
## stop is inclusive
# General tensor slicing algorithm is
# https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/tensor/private/p_accessors_macros_read.nim#L26-L56
#
# for i, slice in slices:
# # Check if we start from the end
# let a = if slice.a_from_end: result.shape[i] - slice.a
# else: slice.a
#
# let b = if slice.b_from_end: result.shape[i] - slice.b
# else: slice.b
#
# # Compute offset:
# result.offset += a * result.strides[i]
# # Now change shape and strides
# result.strides[i] *= slice.step
# result.shape[i] = abs((b-a) div slice.step) + 1
#
# with slices being of size 1, as we have a monodimensional Tensor
# and the slice being a..<b with the reverse case: len-1 -> 0
#
# result is preinitialized with a copy of v (shape, stride, offset, data)
result.offset = v.offset + start * v.stride
result.stride = v.stride * step
result.len = abs((stop-start) div step) + 1
result.data = v.data
func reversed*(v: View): View {.inline.} =
# Hopefully the compiler optimizes div by -1
v.slice(v.len-1, 0, -1)
# ############################################################
#
# Debugging helpers
#
# ############################################################
import strformat, strutils
func display*[F](name: string, indent: int, oa: openArray[F]) =
debugEcho indent(name & ", openarray of " & $F & " of length " & $oa.len, indent)
for i in 0 ..< oa.len:
debugEcho indent(&" {i:>2}: {oa[i].toHex()}", indent)
debugEcho indent(name & " " & $F & " -- FIN\n", indent)
func display*[F](name: string, indent: int, v: View[F]) =
debugEcho indent(name & ", view of " & $F & " of length " & $v.len, indent)
for i in 0 ..< v.len:
debugEcho indent(&" {i:>2}: {v[i].toHex()}", indent)
debugEcho indent(name & " " & $F & " -- FIN\n", indent)
# ############################################################
#
# Sanity checks
#
# ############################################################
when isMainModule:
proc main() =
var x = [0, 1, 2, 3, 4, 5, 6, 7]
let v = x.toView()
echo "view: ", v
echo "reversed: ", v.reversed()
block:
let (even, odd) = v.splitAlternate()
echo "\nSplit Alternate"
echo "----------------"
echo "even: ", even
echo "odd: ", odd
block:
let (ee, eo) = even.splitAlternate()
echo ""
echo "even-even: ", ee
echo "even-odd: ", eo
echo "even-even rev: ", ee.reversed()
echo "even-odd rev: ", eo.reversed()
block:
let (oe, oo) = odd.splitAlternate()
echo ""
echo "odd-even: ", oe
echo "odd-odd: ", oo
echo "odd-even rev: ", oe.reversed()
echo "odd-odd rev: ", oo.reversed()
echo "\nSkip Half"
echo "----------------"
echo "skipHalf: ", v.skipHalf()
echo "skipQuad: ", v.skipHalf().skipHalf()
echo "skipQuad rev: ", v.skipHalf().skipHalf().reversed()
echo "\nSplit middle"
echo "----------------"
block:
let (left, right) = v.splitMiddle()
echo "left: ", left
echo "right: ", right
block:
let (ll, lr) = left.splitMiddle()
echo ""
echo "left-left: ", ll
echo "left-right: ", lr
echo "left-left rev: ", ll.reversed()
echo "left-right rev: ", lr.reversed()
block:
let (rl, rr) = right.splitMiddle()
echo ""
echo "right-left: ", rl
echo "right-right: ", rr
echo "right-left rev: ", rl.reversed()
echo "right-right rev: ", rr.reversed()
main()